diff --git a/cmd/stdiscosrv/apisrv_test.go b/cmd/stdiscosrv/apisrv_test.go index fa96de8bd..272c0ca12 100644 --- a/cmd/stdiscosrv/apisrv_test.go +++ b/cmd/stdiscosrv/apisrv_test.go @@ -107,7 +107,7 @@ func addr(host string, port int) *net.TCPAddr { } func BenchmarkAPIRequests(b *testing.B) { - db := newInMemoryStore(b.TempDir(), 0) + db := newInMemoryStore(b.TempDir(), 0, nil) ctx, cancel := context.WithCancel(context.Background()) defer cancel() go db.Serve(ctx) diff --git a/cmd/stdiscosrv/database.go b/cmd/stdiscosrv/database.go index f23ec73f9..95de82eab 100644 --- a/cmd/stdiscosrv/database.go +++ b/cmd/stdiscosrv/database.go @@ -24,10 +24,6 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/puzpuzpuz/xsync/v3" "github.com/syncthing/syncthing/lib/protocol" ) @@ -52,25 +48,27 @@ type inMemoryStore struct { m *xsync.MapOf[protocol.DeviceID, DatabaseRecord] dir string flushInterval time.Duration + s3 *s3Copier clock clock } -func newInMemoryStore(dir string, flushInterval time.Duration) *inMemoryStore { +func newInMemoryStore(dir string, flushInterval time.Duration, s3 *s3Copier) *inMemoryStore { s := &inMemoryStore{ m: xsync.NewMapOf[protocol.DeviceID, DatabaseRecord](), dir: dir, flushInterval: flushInterval, + s3: s3, clock: defaultClock{}, } nr, err := s.read() - if os.IsNotExist(err) { + if os.IsNotExist(err) && s3 != nil { // Try to read from AWS fd, cerr := os.Create(path.Join(s.dir, "records.db")) if cerr != nil { log.Println("Error creating database file:", err) return s } - if err := s3Download(fd); err != nil { + if err := s3.downloadLatest(fd); err != nil { log.Printf("Error reading database from S3: %v", err) } _ = fd.Close() @@ -278,16 +276,15 @@ func (s *inMemoryStore) write() (err error) { return err } - if os.Getenv("PODINDEX") == "0" { - // Upload to S3 - log.Println("Uploading database") + // Upload to S3 + if s.s3 != nil { fd, err = os.Open(dbf) if err != nil { log.Printf("Error uploading database to S3: %v", err) return nil } defer fd.Close() - if err := s3Upload(fd); err != nil { + if err := s.s3.upload(fd); err != nil { log.Printf("Error uploading database to S3: %v", err) } log.Println("Finished uploading database") @@ -424,39 +421,6 @@ func expire(addrs []DatabaseAddress, now time.Time) []DatabaseAddress { return naddrs } -func s3Upload(r io.Reader) error { - sess, err := session.NewSession(&aws.Config{ - Region: aws.String("fr-par"), - Endpoint: aws.String("s3.fr-par.scw.cloud"), - }) - if err != nil { - return err - } - uploader := s3manager.NewUploader(sess) - _, err = uploader.Upload(&s3manager.UploadInput{ - Bucket: aws.String("syncthing-discovery"), - Key: aws.String("discovery.db"), - Body: r, - }) - return err -} - -func s3Download(w io.WriterAt) error { - sess, err := session.NewSession(&aws.Config{ - Region: aws.String("fr-par"), - Endpoint: aws.String("s3.fr-par.scw.cloud"), - }) - if err != nil { - return err - } - downloader := s3manager.NewDownloader(sess) - _, err = downloader.Download(w, &s3.GetObjectInput{ - Bucket: aws.String("syncthing-discovery"), - Key: aws.String("discovery.db"), - }) - return err -} - func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) { if c := cmp.Compare(d.Address, other.Address); c != 0 { return c diff --git a/cmd/stdiscosrv/database_test.go b/cmd/stdiscosrv/database_test.go index b90592e85..89a557175 100644 --- a/cmd/stdiscosrv/database_test.go +++ b/cmd/stdiscosrv/database_test.go @@ -16,7 +16,7 @@ import ( ) func TestDatabaseGetSet(t *testing.T) { - db := newInMemoryStore(t.TempDir(), 0) + db := newInMemoryStore(t.TempDir(), 0, nil) ctx, cancel := context.WithCancel(context.Background()) go db.Serve(ctx) defer cancel() diff --git a/cmd/stdiscosrv/main.go b/cmd/stdiscosrv/main.go index 9f118ab16..daeb4b3c7 100644 --- a/cmd/stdiscosrv/main.go +++ b/cmd/stdiscosrv/main.go @@ -9,7 +9,6 @@ package main import ( "context" "crypto/tls" - "flag" "log" "net" "net/http" @@ -21,6 +20,7 @@ import ( _ "net/http/pprof" + "github.com/alecthomas/kong" "github.com/prometheus/client_golang/prometheus/promhttp" _ "github.com/syncthing/syncthing/lib/automaxprocs" "github.com/syncthing/syncthing/lib/build" @@ -58,52 +58,52 @@ const ( var debug = false +type CLI struct { + Cert string `group:"Listen" help:"Certificate file" default:"./cert.pem" env:"DISCOVERY_CERT_FILE"` + Key string `group:"Listen" help:"Key file" default:"./key.pem" env:"DISCOVERY_KEY_FILE"` + HTTP bool `group:"Listen" help:"Listen on HTTP (behind an HTTPS proxy)" env:"DISCOVERY_HTTP"` + Compression bool `group:"Listen" help:"Enable GZIP compression of responses" env:"DISCOVERY_COMPRESSION"` + Listen string `group:"Listen" help:"Listen address" default:":8443" env:"DISCOVERY_LISTEN"` + MetricsListen string `group:"Listen" help:"Metrics listen address" env:"DISCOVERY_METRICS_LISTEN"` + + Replicate []string `group:"Legacy replication" help:"Replication peers, id@address, comma separated" env:"DISCOVERY_REPLICATE"` + ReplicationListen string `group:"Legacy replication" help:"Replication listen address" default:":19200" env:"DISCOVERY_REPLICATION_LISTEN"` + ReplicationCert string `group:"Legacy replication" help:"Certificate file for replication" env:"DISCOVERY_REPLICATION_CERT_FILE"` + ReplicationKey string `group:"Legacy replication" help:"Key file for replication" env:"DISCOVERY_REPLICATION_KEY_FILE"` + + AMQPAddress string `group:"AMQP replication" help:"Address to AMQP broker" env:"DISCOVERY_AMQP_ADDRESS"` + + DBDir string `group:"Database" help:"Database directory" default:"." env:"DISCOVERY_DB_DIR"` + DBFlushInterval time.Duration `group:"Database" help:"Interval between database flushes" default:"5m" env:"DISCOVERY_DB_FLUSH_INTERVAL"` + + DBS3Endpoint string `name:"db-s3-endpoint" group:"Database (S3 backup)" help:"S3 endpoint for database" env:"DISCOVERY_DB_S3_ENDPOINT"` + DBS3Region string `name:"db-s3-region" group:"Database (S3 backup)" help:"S3 region for database" env:"DISCOVERY_DB_S3_REGION"` + DBS3Bucket string `name:"db-s3-bucket" group:"Database (S3 backup)" help:"S3 bucket for database" env:"DISCOVERY_DB_S3_BUCKET"` + DBS3AccessKeyID string `name:"db-s3-access-key-id" group:"Database (S3 backup)" help:"S3 access key ID for database" env:"DISCOVERY_DB_S3_ACCESS_KEY_ID"` + DBS3SecretKey string `name:"db-s3-secret-key" group:"Database (S3 backup)" help:"S3 secret key for database" env:"DISCOVERY_DB_S3_SECRET_KEY"` + + Debug bool `short:"d" help:"Print debug output" env:"DISCOVERY_DEBUG"` + Version bool `short:"v" help:"Print version and exit"` +} + func main() { - var listen string - var dir string - var metricsListen string - var replicationListen string - var replicationPeers string - var certFile string - var keyFile string - var replCertFile string - var replKeyFile string - var useHTTP bool - var compression bool - var amqpAddress string - var flushInterval time.Duration - log.SetOutput(os.Stdout) - // log.SetFlags(0) - flag.StringVar(&certFile, "cert", "./cert.pem", "Certificate file") - flag.StringVar(&keyFile, "key", "./key.pem", "Key file") - flag.StringVar(&dir, "db-dir", ".", "Database directory") - flag.BoolVar(&debug, "debug", false, "Print debug output") - flag.BoolVar(&useHTTP, "http", false, "Listen on HTTP (behind an HTTPS proxy)") - flag.BoolVar(&compression, "compression", true, "Enable GZIP compression of responses") - flag.StringVar(&listen, "listen", ":8443", "Listen address") - flag.StringVar(&metricsListen, "metrics-listen", "", "Metrics listen address") - flag.StringVar(&replicationPeers, "replicate", "", "Replication peers, id@address, comma separated") - flag.StringVar(&replicationListen, "replication-listen", ":19200", "Replication listen address") - flag.StringVar(&replCertFile, "replication-cert", "", "Certificate file for replication") - flag.StringVar(&replKeyFile, "replication-key", "", "Key file for replication") - flag.StringVar(&amqpAddress, "amqp-address", "", "Address to AMQP broker") - flag.DurationVar(&flushInterval, "flush-interval", 5*time.Minute, "Interval between database flushes") - showVersion := flag.Bool("version", false, "Show version") - flag.Parse() + var cli CLI + kong.Parse(&cli) + debug = cli.Debug log.Println(build.LongVersionFor("stdiscosrv")) - if *showVersion { + if cli.Version { return } buildInfo.WithLabelValues(build.Version, runtime.Version(), build.User, build.Date.UTC().Format("2006-01-02T15:04:05Z")).Set(1) - cert, err := tls.LoadX509KeyPair(certFile, keyFile) + cert, err := tls.LoadX509KeyPair(cli.Cert, cli.Key) if os.IsNotExist(err) { log.Println("Failed to load keypair. Generating one, this might take a while...") - cert, err = tlsutil.NewCertificate(certFile, keyFile, "stdiscosrv", 20*365) + cert, err = tlsutil.NewCertificate(cli.Cert, cli.Key, "stdiscosrv", 20*365) if err != nil { log.Fatalln("Failed to generate X509 key pair:", err) } @@ -114,8 +114,8 @@ func main() { log.Println("Server device ID is", devID) replCert := cert - if replCertFile != "" && replKeyFile != "" { - replCert, err = tls.LoadX509KeyPair(replCertFile, replKeyFile) + if cli.ReplicationCert != "" && cli.ReplicationKey != "" { + replCert, err = tls.LoadX509KeyPair(cli.ReplicationCert, cli.ReplicationKey) if err != nil { log.Fatalln("Failed to load replication keypair:", err) } @@ -126,8 +126,7 @@ func main() { // Parse the replication specs, if any. var allowedReplicationPeers []protocol.DeviceID var replicationDestinations []string - parts := strings.Split(replicationPeers, ",") - for _, part := range parts { + for _, part := range cli.Replicate { if part == "" { continue } @@ -165,10 +164,22 @@ func main() { // Root of the service tree. main := suture.New("main", suture.Spec{ PassThroughPanics: true, + Timeout: 2 * time.Minute, }) + // If configured, use S3 for database backups. + var s3c *s3Copier + if cli.DBS3Endpoint != "" { + hostname, err := os.Hostname() + if err != nil { + log.Fatalf("Failed to get hostname: %v", err) + } + key := hostname + ".db" + s3c = newS3Copier(cli.DBS3Endpoint, cli.DBS3Region, cli.DBS3Bucket, key, cli.DBS3AccessKeyID, cli.DBS3SecretKey) + } + // Start the database. - db := newInMemoryStore(dir, flushInterval) + db := newInMemoryStore(cli.DBDir, cli.DBFlushInterval, s3c) main.Add(db) // Start any replication senders. @@ -181,28 +192,28 @@ func main() { // If we have replication configured, start the replication listener. if len(allowedReplicationPeers) > 0 { - rl := newReplicationListener(replicationListen, replCert, allowedReplicationPeers, db) + rl := newReplicationListener(cli.ReplicationListen, replCert, allowedReplicationPeers, db) main.Add(rl) } // If we have an AMQP broker, start that - if amqpAddress != "" { + if cli.AMQPAddress != "" { clientID := rand.String(10) - kr := newAMQPReplicator(amqpAddress, clientID, db) + kr := newAMQPReplicator(cli.AMQPAddress, clientID, db) repl = append(repl, kr) main.Add(kr) } // Start the main API server. - qs := newAPISrv(listen, cert, db, repl, useHTTP, compression) + qs := newAPISrv(cli.Listen, cert, db, repl, cli.HTTP, cli.Compression) main.Add(qs) // If we have a metrics port configured, start a metrics handler. - if metricsListen != "" { + if cli.MetricsListen != "" { go func() { mux := http.NewServeMux() mux.Handle("/metrics", promhttp.Handler()) - log.Fatal(http.ListenAndServe(metricsListen, mux)) + log.Fatal(http.ListenAndServe(cli.MetricsListen, mux)) }() } diff --git a/cmd/stdiscosrv/s3.go b/cmd/stdiscosrv/s3.go new file mode 100644 index 000000000..f60cfee27 --- /dev/null +++ b/cmd/stdiscosrv/s3.go @@ -0,0 +1,97 @@ +// Copyright (C) 2024 The Syncthing Authors. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +package main + +import ( + "io" + "log" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" +) + +type s3Copier struct { + endpoint string + region string + bucket string + key string + accessKeyID string + secretKey string +} + +func newS3Copier(endpoint, region, bucket, key, accessKeyID, secretKey string) *s3Copier { + return &s3Copier{ + endpoint: endpoint, + region: region, + bucket: bucket, + key: key, + accessKeyID: accessKeyID, + secretKey: secretKey, + } +} + +func (s *s3Copier) upload(r io.Reader) error { + sess, err := session.NewSession(&aws.Config{ + Region: aws.String(s.region), + Endpoint: aws.String(s.endpoint), + Credentials: credentials.NewStaticCredentials(s.accessKeyID, s.secretKey, ""), + }) + if err != nil { + return err + } + + uploader := s3manager.NewUploader(sess) + _, err = uploader.Upload(&s3manager.UploadInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(s.key), + Body: r, + }) + return err +} + +func (s *s3Copier) downloadLatest(w io.WriterAt) error { + sess, err := session.NewSession(&aws.Config{ + Region: aws.String(s.region), + Endpoint: aws.String(s.endpoint), + Credentials: credentials.NewStaticCredentials(s.accessKeyID, s.secretKey, ""), + }) + if err != nil { + return err + } + + svc := s3.New(sess) + resp, err := svc.ListObjectsV2(&s3.ListObjectsV2Input{Bucket: aws.String(s.bucket)}) + if err != nil { + return err + } + + var lastKey string + var lastModified time.Time + var lastSize int64 + for _, item := range resp.Contents { + if item.LastModified.After(lastModified) && *item.Size > lastSize { + lastKey = *item.Key + lastModified = *item.LastModified + lastSize = *item.Size + } else if lastModified.Sub(*item.LastModified) < 5*time.Minute && *item.Size > lastSize { + lastKey = *item.Key + lastSize = *item.Size + } + } + + log.Println("Downloading database from", lastKey) + downloader := s3manager.NewDownloader(sess) + _, err = downloader.Download(w, &s3.GetObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(lastKey), + }) + return err +}