From 68a1fd010f66734dd8ea7e43f9a5b88e87a8d2b6 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Sun, 8 Sep 2024 08:58:06 +0200 Subject: [PATCH 01/14] feat(stdiscosrv): make compression optional (and faster) --- cmd/stdiscosrv/apisrv.go | 15 ++++-- cmd/stdiscosrv/apisrv_test.go | 88 +++++++++++++++++++++++++++++++++++ cmd/stdiscosrv/main.go | 4 +- 3 files changed, 103 insertions(+), 4 deletions(-) diff --git a/cmd/stdiscosrv/apisrv.go b/cmd/stdiscosrv/apisrv.go index b8721ae68..4aa9c6adf 100644 --- a/cmd/stdiscosrv/apisrv.go +++ b/cmd/stdiscosrv/apisrv.go @@ -45,7 +45,9 @@ type apiSrv struct { listener net.Listener repl replicator // optional useHTTP bool + compression bool missesIncrease int + gzipWriters sync.Pool mapsMut sync.Mutex misses map[string]int32 @@ -61,13 +63,14 @@ type contextKey int const idKey contextKey = iota -func newAPISrv(addr string, cert tls.Certificate, db database, repl replicator, useHTTP bool, missesIncrease int) *apiSrv { +func newAPISrv(addr string, cert tls.Certificate, db database, repl replicator, useHTTP, compression bool, missesIncrease int) *apiSrv { return &apiSrv{ addr: addr, cert: cert, db: db, repl: repl, useHTTP: useHTTP, + compression: compression, misses: make(map[string]int32), missesIncrease: missesIncrease, } @@ -226,10 +229,16 @@ func (s *apiSrv) handleGET(w http.ResponseWriter, req *http.Request) { var bw io.Writer = w // Use compression if the client asks for it - if strings.Contains(req.Header.Get("Accept-Encoding"), "gzip") { + if s.compression && strings.Contains(req.Header.Get("Accept-Encoding"), "gzip") { + gw, ok := s.gzipWriters.Get().(*gzip.Writer) + if ok { + gw.Reset(w) + } else { + gw = gzip.NewWriter(w) + } w.Header().Set("Content-Encoding", "gzip") - gw := gzip.NewWriter(bw) defer gw.Close() + defer s.gzipWriters.Put(gw) bw = gw } diff --git a/cmd/stdiscosrv/apisrv_test.go b/cmd/stdiscosrv/apisrv_test.go index 457ed404f..505e85862 100644 --- a/cmd/stdiscosrv/apisrv_test.go +++ b/cmd/stdiscosrv/apisrv_test.go @@ -7,9 +7,19 @@ package main import ( + "context" + "crypto/tls" "fmt" + "io" "net" + "net/http" + "net/http/httptest" + "os" + "strings" "testing" + + "github.com/syncthing/syncthing/lib/protocol" + "github.com/syncthing/syncthing/lib/tlsutil" ) func TestFixupAddresses(t *testing.T) { @@ -94,3 +104,81 @@ func addr(host string, port int) *net.TCPAddr { Port: port, } } + +func BenchmarkAPIRequests(b *testing.B) { + db, err := newLevelDBStore(b.TempDir()) + if err != nil { + b.Fatal(err) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go db.Serve(ctx) + api := newAPISrv("127.0.0.1:0", tls.Certificate{}, db, nil, true, true, 1) + srv := httptest.NewServer(http.HandlerFunc(api.handler)) + + kf := b.TempDir() + "/cert" + crt, err := tlsutil.NewCertificate(kf+".crt", kf+".key", "localhost", 7) + if err != nil { + b.Fatal(err) + } + certBs, err := os.ReadFile(kf + ".crt") + if err != nil { + b.Fatal(err) + } + certString := string(strings.ReplaceAll(string(certBs), "\n", " ")) + + devID := protocol.NewDeviceID(crt.Certificate[0]) + devIDString := devID.String() + + b.Run("Announce", func(b *testing.B) { + b.ReportAllocs() + url := srv.URL + "/v2/?device=" + devIDString + for i := 0; i < b.N; i++ { + req, _ := http.NewRequest(http.MethodPost, url, strings.NewReader(`{"addresses":["tcp://10.10.10.10:42000"]}`)) + req.Header.Set("X-Ssl-Cert", certString) + resp, err := http.DefaultClient.Do(req) + if err != nil { + b.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusNoContent { + b.Fatalf("unexpected status %s", resp.Status) + } + } + }) + + b.Run("Lookup", func(b *testing.B) { + b.ReportAllocs() + url := srv.URL + "/v2/?device=" + devIDString + for i := 0; i < b.N; i++ { + req, _ := http.NewRequest(http.MethodGet, url, nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + b.Fatal(err) + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b.Fatalf("unexpected status %s", resp.Status) + } + } + }) + + b.Run("LookupNoCompression", func(b *testing.B) { + b.ReportAllocs() + url := srv.URL + "/v2/?device=" + devIDString + for i := 0; i < b.N; i++ { + req, _ := http.NewRequest(http.MethodGet, url, nil) + req.Header.Set("Accept-Encoding", "identity") // disable compression + resp, err := http.DefaultClient.Do(req) + if err != nil { + b.Fatal(err) + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b.Fatalf("unexpected status %s", resp.Status) + } + } + }) +} diff --git a/cmd/stdiscosrv/main.go b/cmd/stdiscosrv/main.go index 0b8c907f5..4a1245266 100644 --- a/cmd/stdiscosrv/main.go +++ b/cmd/stdiscosrv/main.go @@ -80,6 +80,7 @@ func main() { var replCertFile string var replKeyFile string var useHTTP bool + var compression bool var largeDB bool var amqpAddress string missesIncrease := 1 @@ -92,6 +93,7 @@ func main() { flag.StringVar(&dir, "db-dir", "./discovery.db", "Database directory") flag.BoolVar(&debug, "debug", false, "Print debug output") flag.BoolVar(&useHTTP, "http", false, "Listen on HTTP (behind an HTTPS proxy)") + flag.BoolVar(&compression, "compression", true, "Enable GZIP compression of responses") flag.StringVar(&listen, "listen", ":8443", "Listen address") flag.StringVar(&metricsListen, "metrics-listen", "", "Metrics listen address") flag.StringVar(&replicationPeers, "replicate", "", "Replication peers, id@address, comma separated") @@ -225,7 +227,7 @@ func main() { }() // Start the main API server. - qs := newAPISrv(listen, cert, db, repl, useHTTP, missesIncrease) + qs := newAPISrv(listen, cert, db, repl, useHTTP, compression, missesIncrease) main.Add(qs) // If we have a metrics port configured, start a metrics handler. From aed2c66e52aca81ef4402fcf9c9ae7d15464f43c Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Fri, 6 Sep 2024 11:14:23 +0200 Subject: [PATCH 02/14] feat(discosrv): in-memory storage with S3 backing --- cmd/stdiscosrv/amqp.go | 23 +- cmd/stdiscosrv/apisrv.go | 123 ++++---- cmd/stdiscosrv/apisrv_test.go | 7 +- cmd/stdiscosrv/database.go | 483 ++++++++++++++++++-------------- cmd/stdiscosrv/database.pb.go | 107 ++----- cmd/stdiscosrv/database.proto | 6 +- cmd/stdiscosrv/database_test.go | 52 +--- cmd/stdiscosrv/main.go | 73 ++--- cmd/stdiscosrv/replication.go | 19 +- cmd/stdiscosrv/stats.go | 32 ++- go.mod | 3 + go.sum | 10 + 12 files changed, 485 insertions(+), 453 deletions(-) diff --git a/cmd/stdiscosrv/amqp.go b/cmd/stdiscosrv/amqp.go index e09919e27..e32eea49e 100644 --- a/cmd/stdiscosrv/amqp.go +++ b/cmd/stdiscosrv/amqp.go @@ -7,11 +7,14 @@ package main import ( + "bytes" "context" "fmt" "io" + "log" amqp "github.com/rabbitmq/amqp091-go" + "github.com/syncthing/syncthing/lib/protocol" "github.com/thejerf/suture/v4" ) @@ -49,7 +52,7 @@ func newAMQPReplicator(broker, clientID string, db database) *amqpReplicator { } } -func (s *amqpReplicator) send(key string, ps []DatabaseAddress, seen int64) { +func (s *amqpReplicator) send(key *protocol.DeviceID, ps []DatabaseAddress, seen int64) { s.sender.send(key, ps, seen) } @@ -109,9 +112,9 @@ func (s *amqpSender) String() string { return fmt.Sprintf("amqpSender(%q)", s.broker) } -func (s *amqpSender) send(key string, ps []DatabaseAddress, seen int64) { +func (s *amqpSender) send(key *protocol.DeviceID, ps []DatabaseAddress, seen int64) { item := ReplicationRecord{ - Key: key, + Key: key[:], Addresses: ps, Seen: seen, } @@ -161,8 +164,20 @@ func (s *amqpReceiver) Serve(ctx context.Context) error { replicationRecvsTotal.WithLabelValues("error").Inc() return fmt.Errorf("replication unmarshal: %w", err) } + if bytes.Equal(rec.Key, []byte("")) { + continue + } + id, err := protocol.DeviceIDFromBytes(rec.Key) + if err != nil { + id, err = protocol.DeviceIDFromString(string(rec.Key)) + } + if err != nil { + log.Println("Replication device ID:", err) + replicationRecvsTotal.WithLabelValues("error").Inc() + continue + } - if err := s.db.merge(rec.Key, rec.Addresses, rec.Seen); err != nil { + if err := s.db.merge(&id, rec.Addresses, rec.Seen); err != nil { return fmt.Errorf("replication database merge: %w", err) } diff --git a/cmd/stdiscosrv/apisrv.go b/cmd/stdiscosrv/apisrv.go index 4aa9c6adf..d60a55896 100644 --- a/cmd/stdiscosrv/apisrv.go +++ b/cmd/stdiscosrv/apisrv.go @@ -46,11 +46,9 @@ type apiSrv struct { repl replicator // optional useHTTP bool compression bool - missesIncrease int gzipWriters sync.Pool - - mapsMut sync.Mutex - misses map[string]int32 + seenTracker *retryAfterTracker + notSeenTracker *retryAfterTracker } type requestID int64 @@ -63,20 +61,30 @@ type contextKey int const idKey contextKey = iota -func newAPISrv(addr string, cert tls.Certificate, db database, repl replicator, useHTTP, compression bool, missesIncrease int) *apiSrv { +func newAPISrv(addr string, cert tls.Certificate, db database, repl replicator, useHTTP, compression bool) *apiSrv { return &apiSrv{ - addr: addr, - cert: cert, - db: db, - repl: repl, - useHTTP: useHTTP, - compression: compression, - misses: make(map[string]int32), - missesIncrease: missesIncrease, + addr: addr, + cert: cert, + db: db, + repl: repl, + useHTTP: useHTTP, + compression: compression, + seenTracker: &retryAfterTracker{ + name: "seenTracker", + bucketStarts: time.Now(), + desiredRate: 250, + currentDelay: notFoundRetryUnknownMinSeconds, + }, + notSeenTracker: &retryAfterTracker{ + name: "notSeenTracker", + bucketStarts: time.Now(), + desiredRate: 250, + currentDelay: notFoundRetryUnknownMaxSeconds / 2, + }, } } -func (s *apiSrv) Serve(_ context.Context) error { +func (s *apiSrv) Serve(ctx context.Context) error { if s.useHTTP { listener, err := net.Listen("tcp", s.addr) if err != nil { @@ -110,6 +118,11 @@ func (s *apiSrv) Serve(_ context.Context) error { ErrorLog: log.New(io.Discard, "", 0), } + go func() { + <-ctx.Done() + srv.Shutdown(context.Background()) + }() + err := srv.Serve(s.listener) if err != nil { log.Println("Serve:", err) @@ -186,8 +199,7 @@ func (s *apiSrv) handleGET(w http.ResponseWriter, req *http.Request) { return } - key := deviceID.String() - rec, err := s.db.get(key) + rec, err := s.db.get(&deviceID) if err != nil { // some sort of internal error lookupRequestsTotal.WithLabelValues("internal_error").Inc() @@ -197,27 +209,14 @@ func (s *apiSrv) handleGET(w http.ResponseWriter, req *http.Request) { } if len(rec.Addresses) == 0 { - lookupRequestsTotal.WithLabelValues("not_found").Inc() - - s.mapsMut.Lock() - misses := s.misses[key] - if misses < rec.Misses { - misses = rec.Misses + var afterS int + if rec.Seen == 0 { + afterS = s.notSeenTracker.retryAfterS() + lookupRequestsTotal.WithLabelValues("not_found_ever").Inc() + } else { + afterS = s.seenTracker.retryAfterS() + lookupRequestsTotal.WithLabelValues("not_found_recent").Inc() } - misses += int32(s.missesIncrease) - s.misses[key] = misses - s.mapsMut.Unlock() - - if misses >= notFoundMissesWriteInterval { - rec.Misses = misses - rec.Missed = time.Now().UnixNano() - rec.Addresses = nil - // rec.Seen retained from get - s.db.put(key, rec) - } - - afterS := notFoundRetryAfterSeconds(int(misses)) - retryAfterHistogram.Observe(float64(afterS)) w.Header().Set("Retry-After", strconv.Itoa(afterS)) http.Error(w, "Not Found", http.StatusNotFound) return @@ -301,7 +300,6 @@ func (s *apiSrv) Stop() { } func (s *apiSrv) handleAnnounce(deviceID protocol.DeviceID, addresses []string) error { - key := deviceID.String() now := time.Now() expire := now.Add(addressExpiryTime).UnixNano() @@ -317,9 +315,9 @@ func (s *apiSrv) handleAnnounce(deviceID protocol.DeviceID, addresses []string) seen := now.UnixNano() if s.repl != nil { - s.repl.send(key, dbAddrs, seen) + s.repl.send(&deviceID, dbAddrs, seen) } - return s.db.merge(key, dbAddrs, seen) + return s.db.merge(&deviceID, dbAddrs, seen) } func handlePing(w http.ResponseWriter, _ *http.Request) { @@ -503,15 +501,44 @@ func errorRetryAfterString() string { return strconv.Itoa(errorRetryAfterSeconds + rand.Intn(errorRetryFuzzSeconds)) } -func notFoundRetryAfterSeconds(misses int) int { - retryAfterS := notFoundRetryMinSeconds + notFoundRetryIncSeconds*misses - if retryAfterS > notFoundRetryMaxSeconds { - retryAfterS = notFoundRetryMaxSeconds - } - retryAfterS += rand.Intn(notFoundRetryFuzzSeconds) - return retryAfterS -} - func reannounceAfterString() string { return strconv.Itoa(reannounceAfterSeconds + rand.Intn(reannounzeFuzzSeconds)) } + +type retryAfterTracker struct { + name string + desiredRate float64 // requests per second + + mut sync.Mutex + lastCount int // requests in the last bucket + curCount int // requests in the current bucket + bucketStarts time.Time // start of the current bucket + currentDelay int // current delay in seconds +} + +func (t *retryAfterTracker) retryAfterS() int { + now := time.Now() + t.mut.Lock() + if durS := now.Sub(t.bucketStarts).Seconds(); durS > float64(t.currentDelay) { + t.bucketStarts = now + t.lastCount = t.curCount + lastRate := float64(t.lastCount) / durS + + switch { + case t.currentDelay > notFoundRetryUnknownMinSeconds && + lastRate < 0.75*t.desiredRate: + t.currentDelay = max(8*t.currentDelay/10, notFoundRetryUnknownMinSeconds) + case t.currentDelay < notFoundRetryUnknownMaxSeconds && + lastRate > 1.25*t.desiredRate: + t.currentDelay = min(3*t.currentDelay/2, notFoundRetryUnknownMaxSeconds) + } + + t.curCount = 0 + } + if t.curCount == 0 { + retryAfterLevel.WithLabelValues(t.name).Set(float64(t.currentDelay)) + } + t.curCount++ + t.mut.Unlock() + return t.currentDelay + rand.Intn(t.currentDelay/4) +} diff --git a/cmd/stdiscosrv/apisrv_test.go b/cmd/stdiscosrv/apisrv_test.go index 505e85862..07723a0e0 100644 --- a/cmd/stdiscosrv/apisrv_test.go +++ b/cmd/stdiscosrv/apisrv_test.go @@ -106,14 +106,11 @@ func addr(host string, port int) *net.TCPAddr { } func BenchmarkAPIRequests(b *testing.B) { - db, err := newLevelDBStore(b.TempDir()) - if err != nil { - b.Fatal(err) - } + db := newInMemoryStore(b.TempDir(), 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() go db.Serve(ctx) - api := newAPISrv("127.0.0.1:0", tls.Certificate{}, db, nil, true, true, 1) + api := newAPISrv("127.0.0.1:0", tls.Certificate{}, db, nil, true, true) srv := httptest.NewServer(http.HandlerFunc(api.handler)) kf := b.TempDir() + "/cert" diff --git a/cmd/stdiscosrv/database.go b/cmd/stdiscosrv/database.go index 1bc841df5..29686e584 100644 --- a/cmd/stdiscosrv/database.go +++ b/cmd/stdiscosrv/database.go @@ -10,17 +10,25 @@ package main import ( + "bufio" "context" + "encoding/binary" + "io" "log" "net" "net/url" + "os" + "path" "sort" "time" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/puzpuzpuz/xsync/v3" + "github.com/syncthing/syncthing/lib/protocol" "github.com/syncthing/syncthing/lib/sliceutil" - "github.com/syndtr/goleveldb/leveldb" - "github.com/syndtr/goleveldb/leveldb/storage" - "github.com/syndtr/goleveldb/leveldb/util" ) type clock interface { @@ -34,270 +42,296 @@ func (defaultClock) Now() time.Time { } type database interface { - put(key string, rec DatabaseRecord) error - merge(key string, addrs []DatabaseAddress, seen int64) error - get(key string) (DatabaseRecord, error) + put(key *protocol.DeviceID, rec DatabaseRecord) error + merge(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) error + get(key *protocol.DeviceID) (DatabaseRecord, error) } -type levelDBStore struct { - db *leveldb.DB - inbox chan func() - clock clock - marshalBuf []byte +type inMemoryStore struct { + m *xsync.MapOf[protocol.DeviceID, DatabaseRecord] + dir string + flushInterval time.Duration + clock clock } -func newLevelDBStore(dir string) (*levelDBStore, error) { - db, err := leveldb.OpenFile(dir, levelDBOptions) - if err != nil { - return nil, err +func newInMemoryStore(dir string, flushInterval time.Duration) *inMemoryStore { + s := &inMemoryStore{ + m: xsync.NewMapOf[protocol.DeviceID, DatabaseRecord](), + dir: dir, + flushInterval: flushInterval, + clock: defaultClock{}, } - return &levelDBStore{ - db: db, - inbox: make(chan func(), 16), - clock: defaultClock{}, - }, nil -} - -func newMemoryLevelDBStore() (*levelDBStore, error) { - db, err := leveldb.Open(storage.NewMemStorage(), nil) - if err != nil { - return nil, err - } - return &levelDBStore{ - db: db, - inbox: make(chan func(), 16), - clock: defaultClock{}, - }, nil -} - -func (s *levelDBStore) put(key string, rec DatabaseRecord) error { - t0 := time.Now() - defer func() { - databaseOperationSeconds.WithLabelValues(dbOpPut).Observe(time.Since(t0).Seconds()) - }() - - rc := make(chan error) - - s.inbox <- func() { - size := rec.Size() - if len(s.marshalBuf) < size { - s.marshalBuf = make([]byte, size) + err := s.read() + if os.IsNotExist(err) { + // Try to read from AWS + fd, cerr := os.Create(path.Join(s.dir, "records.db")) + if cerr != nil { + log.Println("Error creating database file:", err) + return s } - n, _ := rec.MarshalTo(s.marshalBuf) - rc <- s.db.Put([]byte(key), s.marshalBuf[:n], nil) + if err := s3Download(fd); err != nil { + log.Printf("Error reading database from S3: %v", err) + } + _ = fd.Close() + err = s.read() } - - err := <-rc if err != nil { - databaseOperations.WithLabelValues(dbOpPut, dbResError).Inc() - } else { - databaseOperations.WithLabelValues(dbOpPut, dbResSuccess).Inc() + log.Println("Error reading database:", err) } - - return err + s.calculateStatistics() + return s } -func (s *levelDBStore) merge(key string, addrs []DatabaseAddress, seen int64) error { +func (s *inMemoryStore) put(key *protocol.DeviceID, rec DatabaseRecord) error { + t0 := time.Now() + s.m.Store(*key, rec) + databaseOperations.WithLabelValues(dbOpPut, dbResSuccess).Inc() + databaseOperationSeconds.WithLabelValues(dbOpPut).Observe(time.Since(t0).Seconds()) + return nil +} + +func (s *inMemoryStore) merge(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) error { t0 := time.Now() - defer func() { - databaseOperationSeconds.WithLabelValues(dbOpMerge).Observe(time.Since(t0).Seconds()) - }() - rc := make(chan error) newRec := DatabaseRecord{ Addresses: addrs, Seen: seen, } - s.inbox <- func() { - // grab the existing record - oldRec, err := s.get(key) - if err != nil { - // "not found" is not an error from get, so this is serious - // stuff only - rc <- err - return - } - newRec = merge(newRec, oldRec) + oldRec, _ := s.m.Load(*key) + newRec = merge(newRec, oldRec) + s.m.Store(*key, newRec) - // We replicate s.put() functionality here ourselves instead of - // calling it because we want to serialize our get above together - // with the put in the same function. - size := newRec.Size() - if len(s.marshalBuf) < size { - s.marshalBuf = make([]byte, size) - } - n, _ := newRec.MarshalTo(s.marshalBuf) - rc <- s.db.Put([]byte(key), s.marshalBuf[:n], nil) - } + databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc() + databaseOperationSeconds.WithLabelValues(dbOpMerge).Observe(time.Since(t0).Seconds()) - err := <-rc - if err != nil { - databaseOperations.WithLabelValues(dbOpMerge, dbResError).Inc() - } else { - databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc() - } - - return err + return nil } -func (s *levelDBStore) get(key string) (DatabaseRecord, error) { +func (s *inMemoryStore) get(key *protocol.DeviceID) (DatabaseRecord, error) { t0 := time.Now() defer func() { databaseOperationSeconds.WithLabelValues(dbOpGet).Observe(time.Since(t0).Seconds()) }() - keyBs := []byte(key) - val, err := s.db.Get(keyBs, nil) - if err == leveldb.ErrNotFound { + rec, ok := s.m.Load(*key) + if !ok { databaseOperations.WithLabelValues(dbOpGet, dbResNotFound).Inc() return DatabaseRecord{}, nil } - if err != nil { - databaseOperations.WithLabelValues(dbOpGet, dbResError).Inc() - return DatabaseRecord{}, err - } - - var rec DatabaseRecord - - if err := rec.Unmarshal(val); err != nil { - databaseOperations.WithLabelValues(dbOpGet, dbResUnmarshalError).Inc() - return DatabaseRecord{}, nil - } rec.Addresses = expire(rec.Addresses, s.clock.Now().UnixNano()) databaseOperations.WithLabelValues(dbOpGet, dbResSuccess).Inc() return rec, nil } -func (s *levelDBStore) Serve(ctx context.Context) error { - t := time.NewTimer(0) +func (s *inMemoryStore) Serve(ctx context.Context) error { + t := time.NewTimer(s.flushInterval) defer t.Stop() - defer s.db.Close() - // Start the statistics serve routine. It will exit with us when - // statisticsTrigger is closed. - statisticsTrigger := make(chan struct{}) - statisticsDone := make(chan struct{}) - go s.statisticsServe(statisticsTrigger, statisticsDone) + if s.flushInterval <= 0 { + t.Stop() + } loop: for { select { - case fn := <-s.inbox: - // Run function in serialized order. - fn() - case <-t.C: - // Trigger the statistics routine to do its thing in the - // background. - statisticsTrigger <- struct{}{} - - case <-statisticsDone: - // The statistics routine is done with one iteratation, schedule - // the next. - t.Reset(databaseStatisticsInterval) + if err := s.write(); err != nil { + log.Println("Error writing database:", err) + } + s.calculateStatistics() + t.Reset(s.flushInterval) case <-ctx.Done(): // We're done. - close(statisticsTrigger) break loop } } - // Also wait for statisticsServe to return - <-statisticsDone + return s.write() +} + +func (s *inMemoryStore) calculateStatistics() { + t0 := time.Now() + nowNanos := t0.UnixNano() + cutoff24h := t0.Add(-24 * time.Hour).UnixNano() + cutoff1w := t0.Add(-7 * 24 * time.Hour).UnixNano() + current, currentIPv4, currentIPv6, last24h, last1w, errors := 0, 0, 0, 0, 0, 0 + + s.m.Range(func(key protocol.DeviceID, rec DatabaseRecord) bool { + // If there are addresses that have not expired it's a current + // record, otherwise account it based on when it was last seen + // (last 24 hours or last week) or finally as inactice. + addrs := expire(rec.Addresses, nowNanos) + switch { + case len(addrs) > 0: + current++ + seenIPv4, seenIPv6 := false, false + for _, addr := range addrs { + uri, err := url.Parse(addr.Address) + if err != nil { + continue + } + host, _, err := net.SplitHostPort(uri.Host) + if err != nil { + continue + } + if ip := net.ParseIP(host); ip != nil && ip.To4() != nil { + seenIPv4 = true + } else if ip != nil { + seenIPv6 = true + } + if seenIPv4 && seenIPv6 { + break + } + } + if seenIPv4 { + currentIPv4++ + } + if seenIPv6 { + currentIPv6++ + } + case rec.Seen > cutoff24h: + last24h++ + case rec.Seen > cutoff1w: + last1w++ + default: + // drop the record if it's older than a week + s.m.Delete(key) + } + return true + }) + + databaseKeys.WithLabelValues("current").Set(float64(current)) + databaseKeys.WithLabelValues("currentIPv4").Set(float64(currentIPv4)) + databaseKeys.WithLabelValues("currentIPv6").Set(float64(currentIPv6)) + databaseKeys.WithLabelValues("last24h").Set(float64(last24h)) + databaseKeys.WithLabelValues("last1w").Set(float64(last1w)) + databaseKeys.WithLabelValues("error").Set(float64(errors)) + databaseStatisticsSeconds.Set(time.Since(t0).Seconds()) +} + +func (s *inMemoryStore) write() (err error) { + t0 := time.Now() + defer func() { + if err == nil { + databaseWriteSeconds.Set(time.Since(t0).Seconds()) + databaseLastWritten.Set(float64(t0.Unix())) + } + }() + + dbf := path.Join(s.dir, "records.db") + fd, err := os.Create(dbf + ".tmp") + if err != nil { + return err + } + bw := bufio.NewWriter(fd) + + var buf []byte + var rangeErr error + now := s.clock.Now().UnixNano() + cutoff1w := s.clock.Now().Add(-7 * 24 * time.Hour).UnixNano() + s.m.Range(func(key protocol.DeviceID, value DatabaseRecord) bool { + if value.Seen < cutoff1w { + // drop the record if it's older than a week + return true + } + rec := ReplicationRecord{ + Key: key[:], + Addresses: expire(value.Addresses, now), + Seen: value.Seen, + } + s := rec.Size() + if s+4 > len(buf) { + buf = make([]byte, s+4) + } + n, err := rec.MarshalTo(buf[4:]) + if err != nil { + rangeErr = err + return false + } + binary.BigEndian.PutUint32(buf, uint32(n)) + if _, err := bw.Write(buf[:n+4]); err != nil { + rangeErr = err + return false + } + return true + }) + if rangeErr != nil { + _ = fd.Close() + return rangeErr + } + + if err := bw.Flush(); err != nil { + _ = fd.Close + return err + } + if err := fd.Close(); err != nil { + return err + } + if err := os.Rename(dbf+".tmp", dbf); err != nil { + return err + } + + if os.Getenv("PODINDEX") == "0" { + // Upload to S3 + fd, err = os.Open(dbf) + if err != nil { + log.Printf("Error uploading database to S3: %v", err) + return nil + } + defer fd.Close() + if err := s3Upload(fd); err != nil { + log.Printf("Error uploading database to S3: %v", err) + } + } return nil } -func (s *levelDBStore) statisticsServe(trigger <-chan struct{}, done chan<- struct{}) { - defer close(done) +func (s *inMemoryStore) read() error { + fd, err := os.Open(path.Join(s.dir, "records.db")) + if err != nil { + return err + } + defer fd.Close() - for range trigger { - t0 := time.Now() - nowNanos := t0.UnixNano() - cutoff24h := t0.Add(-24 * time.Hour).UnixNano() - cutoff1w := t0.Add(-7 * 24 * time.Hour).UnixNano() - cutoff2Mon := t0.Add(-60 * 24 * time.Hour).UnixNano() - current, currentIPv4, currentIPv6, last24h, last1w, inactive, errors := 0, 0, 0, 0, 0, 0, 0 - - iter := s.db.NewIterator(&util.Range{}, nil) - for iter.Next() { - // Attempt to unmarshal the record and count the - // failure if there's something wrong with it. - var rec DatabaseRecord - if err := rec.Unmarshal(iter.Value()); err != nil { - errors++ - continue - } - - // If there are addresses that have not expired it's a current - // record, otherwise account it based on when it was last seen - // (last 24 hours or last week) or finally as inactice. - addrs := expire(rec.Addresses, nowNanos) - switch { - case len(addrs) > 0: - current++ - seenIPv4, seenIPv6 := false, false - for _, addr := range addrs { - uri, err := url.Parse(addr.Address) - if err != nil { - continue - } - host, _, err := net.SplitHostPort(uri.Host) - if err != nil { - continue - } - if ip := net.ParseIP(host); ip != nil && ip.To4() != nil { - seenIPv4 = true - } else if ip != nil { - seenIPv6 = true - } - if seenIPv4 && seenIPv6 { - break - } - } - if seenIPv4 { - currentIPv4++ - } - if seenIPv6 { - currentIPv6++ - } - case rec.Seen > cutoff24h: - last24h++ - case rec.Seen > cutoff1w: - last1w++ - case rec.Seen > cutoff2Mon: - inactive++ - case rec.Missed < cutoff2Mon: - // It hasn't been seen lately and we haven't recorded - // someone asking for this device in a long time either; - // delete the record. - if err := s.db.Delete(iter.Key(), nil); err != nil { - databaseOperations.WithLabelValues(dbOpDelete, dbResError).Inc() - } else { - databaseOperations.WithLabelValues(dbOpDelete, dbResSuccess).Inc() - } - default: - inactive++ + br := bufio.NewReader(fd) + var buf []byte + for { + var n uint32 + if err := binary.Read(br, binary.BigEndian, &n); err != nil { + if err == io.EOF { + break } + return err + } + if int(n) > len(buf) { + buf = make([]byte, n) + } + if _, err := io.ReadFull(br, buf[:n]); err != nil { + return err + } + rec := ReplicationRecord{} + if err := rec.Unmarshal(buf[:n]); err != nil { + return err + } + key, err := protocol.DeviceIDFromBytes(rec.Key) + if err != nil { + key, err = protocol.DeviceIDFromString(string(rec.Key)) + } + if err != nil { + log.Println("Bad device ID:", err) + continue } - iter.Release() - - databaseKeys.WithLabelValues("current").Set(float64(current)) - databaseKeys.WithLabelValues("currentIPv4").Set(float64(currentIPv4)) - databaseKeys.WithLabelValues("currentIPv6").Set(float64(currentIPv6)) - databaseKeys.WithLabelValues("last24h").Set(float64(last24h)) - databaseKeys.WithLabelValues("last1w").Set(float64(last1w)) - databaseKeys.WithLabelValues("inactive").Set(float64(inactive)) - databaseKeys.WithLabelValues("error").Set(float64(errors)) - databaseStatisticsSeconds.Set(time.Since(t0).Seconds()) - - // Signal that we are done and can be scheduled again. - done <- struct{}{} + s.m.Store(key, DatabaseRecord{ + Addresses: rec.Addresses, + Seen: rec.Seen, + }) } + return nil } // merge returns the merged result of the two database records a and b. The @@ -411,3 +445,36 @@ func (s databaseAddressOrder) Swap(a, b int) { func (s databaseAddressOrder) Len() int { return len(s) } + +func s3Upload(r io.Reader) error { + sess, err := session.NewSession(&aws.Config{ + Region: aws.String("fr-par"), + Endpoint: aws.String("s3.fr-par.scw.cloud"), + }) + if err != nil { + return err + } + uploader := s3manager.NewUploader(sess) + _, err = uploader.Upload(&s3manager.UploadInput{ + Bucket: aws.String("syncthing-discovery"), + Key: aws.String("discovery.db"), + Body: r, + }) + return err +} + +func s3Download(w io.WriterAt) error { + sess, err := session.NewSession(&aws.Config{ + Region: aws.String("fr-par"), + Endpoint: aws.String("s3.fr-par.scw.cloud"), + }) + if err != nil { + return err + } + downloader := s3manager.NewDownloader(sess) + _, err = downloader.Download(w, &s3.GetObjectInput{ + Bucket: aws.String("syncthing-discovery"), + Key: aws.String("discovery.db"), + }) + return err +} diff --git a/cmd/stdiscosrv/database.pb.go b/cmd/stdiscosrv/database.pb.go index ccc7ed340..cf51a7143 100644 --- a/cmd/stdiscosrv/database.pb.go +++ b/cmd/stdiscosrv/database.pb.go @@ -25,9 +25,7 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package type DatabaseRecord struct { Addresses []DatabaseAddress `protobuf:"bytes,1,rep,name=addresses,proto3" json:"addresses"` - Misses int32 `protobuf:"varint,2,opt,name=misses,proto3" json:"misses,omitempty"` Seen int64 `protobuf:"varint,3,opt,name=seen,proto3" json:"seen,omitempty"` - Missed int64 `protobuf:"varint,4,opt,name=missed,proto3" json:"missed,omitempty"` } func (m *DatabaseRecord) Reset() { *m = DatabaseRecord{} } @@ -64,7 +62,7 @@ func (m *DatabaseRecord) XXX_DiscardUnknown() { var xxx_messageInfo_DatabaseRecord proto.InternalMessageInfo type ReplicationRecord struct { - Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + Key []byte `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` Addresses []DatabaseAddress `protobuf:"bytes,2,rep,name=addresses,proto3" json:"addresses"` Seen int64 `protobuf:"varint,3,opt,name=seen,proto3" json:"seen,omitempty"` } @@ -149,24 +147,23 @@ func init() { func init() { proto.RegisterFile("database.proto", fileDescriptor_b90fe3356ea5df07) } var fileDescriptor_b90fe3356ea5df07 = []byte{ - // 270 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x90, 0x41, 0x4a, 0xc4, 0x30, - 0x18, 0x85, 0x9b, 0x49, 0x1d, 0x99, 0x08, 0xa3, 0x06, 0x94, 0x20, 0x12, 0x4b, 0xdd, 0x74, 0xd5, - 0x01, 0x5d, 0xb9, 0x74, 0xd0, 0x0b, 0xe4, 0x06, 0xe9, 0xe4, 0x77, 0x08, 0x3a, 0x4d, 0x49, 0x2a, - 0xe8, 0x29, 0xf4, 0x58, 0x5d, 0xce, 0xd2, 0x95, 0x68, 0x7b, 0x11, 0x69, 0x26, 0x55, 0x14, 0x37, - 0xb3, 0x7b, 0xdf, 0xff, 0xbf, 0x97, 0xbc, 0x84, 0x4c, 0x95, 0xac, 0x65, 0x21, 0x1d, 0xe4, 0x95, - 0x35, 0xb5, 0xa1, 0xf1, 0x4a, 0xea, 0xf2, 0xe4, 0xdc, 0x42, 0x65, 0xdc, 0xcc, 0x8f, 0x8a, 0xc7, - 0xbb, 0xd9, 0xd2, 0x2c, 0x8d, 0x07, 0xaf, 0x36, 0xd6, 0xf4, 0x05, 0x91, 0xe9, 0x4d, 0x48, 0x0b, - 0x58, 0x18, 0xab, 0xe8, 0x15, 0x99, 0x48, 0xa5, 0x2c, 0x38, 0x07, 0x8e, 0xa1, 0x04, 0x67, 0x7b, - 0x17, 0x47, 0x79, 0x7f, 0x62, 0x3e, 0x18, 0xaf, 0x37, 0xeb, 0x79, 0xdc, 0xbc, 0x9f, 0x45, 0xe2, - 0xc7, 0x4d, 0x8f, 0xc9, 0x78, 0xa5, 0x7d, 0x6e, 0x94, 0xa0, 0x6c, 0x47, 0x04, 0xa2, 0x94, 0xc4, - 0x0e, 0xa0, 0x64, 0x38, 0x41, 0x19, 0x16, 0x5e, 0x7f, 0x7b, 0x15, 0x8b, 0xfd, 0x34, 0x50, 0x5a, - 0x93, 0x43, 0x01, 0xd5, 0x83, 0x5e, 0xc8, 0x5a, 0x9b, 0x32, 0x74, 0x3a, 0x20, 0xf8, 0x1e, 0x9e, - 0x19, 0x4a, 0x50, 0x36, 0x11, 0xbd, 0xfc, 0xdd, 0x72, 0xb4, 0x55, 0xcb, 0x7f, 0xda, 0xa4, 0xb7, - 0x64, 0xff, 0x4f, 0x8e, 0x32, 0xb2, 0x1b, 0x32, 0xe1, 0xde, 0x01, 0xfb, 0x0d, 0x3c, 0x55, 0xda, - 0x86, 0x77, 0x62, 0x31, 0xe0, 0xfc, 0xb4, 0xf9, 0xe4, 0x51, 0xd3, 0x72, 0xb4, 0x6e, 0x39, 0xfa, - 0x68, 0x39, 0x7a, 0xed, 0x78, 0xb4, 0xee, 0x78, 0xf4, 0xd6, 0xf1, 0xa8, 0x18, 0xfb, 0x3f, 0xbf, - 0xfc, 0x0a, 0x00, 0x00, 0xff, 0xff, 0x7a, 0xa2, 0xf6, 0x1e, 0xb0, 0x01, 0x00, 0x00, + // 243 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4b, 0x49, 0x2c, 0x49, + 0x4c, 0x4a, 0x2c, 0x4e, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0xc9, 0x4d, 0xcc, 0xcc, + 0x93, 0x52, 0x2e, 0x4a, 0x2d, 0xc8, 0x2f, 0xd6, 0x07, 0x0b, 0x25, 0x95, 0xa6, 0xe9, 0xa7, 0xe7, + 0xa7, 0xe7, 0x83, 0x39, 0x60, 0x16, 0x44, 0xa9, 0x52, 0x3c, 0x17, 0x9f, 0x0b, 0x54, 0x73, 0x50, + 0x6a, 0x72, 0x7e, 0x51, 0x8a, 0x90, 0x25, 0x17, 0x67, 0x62, 0x4a, 0x4a, 0x51, 0x6a, 0x71, 0x71, + 0x6a, 0xb1, 0x04, 0xa3, 0x02, 0xb3, 0x06, 0xb7, 0x91, 0xa8, 0x1e, 0xc8, 0x40, 0x3d, 0x98, 0x42, + 0x47, 0x88, 0xb4, 0x13, 0xcb, 0x89, 0x7b, 0xf2, 0x0c, 0x41, 0x08, 0xd5, 0x42, 0x42, 0x5c, 0x2c, + 0xc5, 0xa9, 0xa9, 0x79, 0x12, 0xcc, 0x0a, 0x8c, 0x1a, 0xcc, 0x41, 0x60, 0xb6, 0x52, 0x09, 0x97, + 0x60, 0x50, 0x6a, 0x41, 0x4e, 0x66, 0x72, 0x62, 0x49, 0x66, 0x7e, 0x1e, 0xd4, 0x0e, 0x01, 0x2e, + 0xe6, 0xec, 0xd4, 0x4a, 0x09, 0x46, 0x05, 0x46, 0x0d, 0x9e, 0x20, 0x10, 0x13, 0xd5, 0x56, 0x26, + 0x8a, 0x6d, 0x75, 0xe5, 0xe2, 0x47, 0xd3, 0x27, 0x24, 0xc1, 0xc5, 0x0e, 0xd5, 0x03, 0xb6, 0x97, + 0x33, 0x08, 0xc6, 0x05, 0xc9, 0xa4, 0x56, 0x14, 0x64, 0x16, 0x81, 0x6d, 0x06, 0x99, 0x01, 0xe3, + 0x3a, 0xc9, 0x9c, 0x78, 0x28, 0xc7, 0x70, 0xe2, 0x91, 0x1c, 0xe3, 0x85, 0x47, 0x72, 0x8c, 0x0f, + 0x1e, 0xc9, 0x31, 0x4e, 0x78, 0x2c, 0xc7, 0x70, 0xe1, 0xb1, 0x1c, 0xc3, 0x8d, 0xc7, 0x72, 0x0c, + 0x49, 0x6c, 0xe0, 0x20, 0x34, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0xc6, 0x0b, 0x9b, 0x77, 0x7f, + 0x01, 0x00, 0x00, } func (m *DatabaseRecord) Marshal() (dAtA []byte, err error) { @@ -189,21 +186,11 @@ func (m *DatabaseRecord) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l - if m.Missed != 0 { - i = encodeVarintDatabase(dAtA, i, uint64(m.Missed)) - i-- - dAtA[i] = 0x20 - } if m.Seen != 0 { i = encodeVarintDatabase(dAtA, i, uint64(m.Seen)) i-- dAtA[i] = 0x18 } - if m.Misses != 0 { - i = encodeVarintDatabase(dAtA, i, uint64(m.Misses)) - i-- - dAtA[i] = 0x10 - } if len(m.Addresses) > 0 { for iNdEx := len(m.Addresses) - 1; iNdEx >= 0; iNdEx-- { { @@ -328,15 +315,9 @@ func (m *DatabaseRecord) Size() (n int) { n += 1 + l + sovDatabase(uint64(l)) } } - if m.Misses != 0 { - n += 1 + sovDatabase(uint64(m.Misses)) - } if m.Seen != 0 { n += 1 + sovDatabase(uint64(m.Seen)) } - if m.Missed != 0 { - n += 1 + sovDatabase(uint64(m.Missed)) - } return n } @@ -447,25 +428,6 @@ func (m *DatabaseRecord) Unmarshal(dAtA []byte) error { return err } iNdEx = postIndex - case 2: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Misses", wireType) - } - m.Misses = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowDatabase - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.Misses |= int32(b&0x7F) << shift - if b < 0x80 { - break - } - } case 3: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Seen", wireType) @@ -485,25 +447,6 @@ func (m *DatabaseRecord) Unmarshal(dAtA []byte) error { break } } - case 4: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Missed", wireType) - } - m.Missed = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowDatabase - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.Missed |= int64(b&0x7F) << shift - if b < 0x80 { - break - } - } default: iNdEx = preIndex skippy, err := skipDatabase(dAtA[iNdEx:]) @@ -558,7 +501,7 @@ func (m *ReplicationRecord) Unmarshal(dAtA []byte) error { if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType) } - var stringLen uint64 + var byteLen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowDatabase @@ -568,23 +511,25 @@ func (m *ReplicationRecord) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - stringLen |= uint64(b&0x7F) << shift + byteLen |= int(b&0x7F) << shift if b < 0x80 { break } } - intStringLen := int(stringLen) - if intStringLen < 0 { + if byteLen < 0 { return ErrInvalidLengthDatabase } - postIndex := iNdEx + intStringLen + postIndex := iNdEx + byteLen if postIndex < 0 { return ErrInvalidLengthDatabase } if postIndex > l { return io.ErrUnexpectedEOF } - m.Key = string(dAtA[iNdEx:postIndex]) + m.Key = append(m.Key[:0], dAtA[iNdEx:postIndex]...) + if m.Key == nil { + m.Key = []byte{} + } iNdEx = postIndex case 2: if wireType != 2 { diff --git a/cmd/stdiscosrv/database.proto b/cmd/stdiscosrv/database.proto index 919cc601e..e53dd9f1b 100644 --- a/cmd/stdiscosrv/database.proto +++ b/cmd/stdiscosrv/database.proto @@ -17,15 +17,11 @@ option (gogoproto.goproto_sizecache_all) = false; message DatabaseRecord { repeated DatabaseAddress addresses = 1 [(gogoproto.nullable) = false]; - int32 misses = 2; // Number of lookups* without hits int64 seen = 3; // Unix nanos, last device announce - int64 missed = 4; // Unix nanos, last* failed lookup } -// *) Not every lookup results in a write, so may not be completely accurate - message ReplicationRecord { - string key = 1; + bytes key = 1; // raw 32 byte device ID repeated DatabaseAddress addresses = 2 [(gogoproto.nullable) = false]; int64 seen = 3; // Unix nanos, last device announce } diff --git a/cmd/stdiscosrv/database_test.go b/cmd/stdiscosrv/database_test.go index 14b15059d..a476424f2 100644 --- a/cmd/stdiscosrv/database_test.go +++ b/cmd/stdiscosrv/database_test.go @@ -11,29 +11,25 @@ import ( "fmt" "testing" "time" + + "github.com/syncthing/syncthing/lib/protocol" ) func TestDatabaseGetSet(t *testing.T) { - db, err := newMemoryLevelDBStore() - if err != nil { - t.Fatal(err) - } + db := newInMemoryStore(t.TempDir(), 0) ctx, cancel := context.WithCancel(context.Background()) go db.Serve(ctx) defer cancel() // Check missing record - rec, err := db.get("abcd") + rec, err := db.get(&protocol.EmptyDeviceID) if err != nil { t.Error("not found should not be an error") } if len(rec.Addresses) != 0 { t.Error("addresses should be empty") } - if rec.Misses != 0 { - t.Error("missing should be zero") - } // Set up a clock @@ -46,13 +42,13 @@ func TestDatabaseGetSet(t *testing.T) { rec.Addresses = []DatabaseAddress{ {Address: "tcp://1.2.3.4:5", Expires: tc.Now().Add(time.Minute).UnixNano()}, } - if err := db.put("abcd", rec); err != nil { + if err := db.put(&protocol.EmptyDeviceID, rec); err != nil { t.Fatal(err) } // Verify it - rec, err = db.get("abcd") + rec, err = db.get(&protocol.EmptyDeviceID) if err != nil { t.Fatal(err) } @@ -72,13 +68,13 @@ func TestDatabaseGetSet(t *testing.T) { addrs := []DatabaseAddress{ {Address: "tcp://6.7.8.9:0", Expires: tc.Now().Add(time.Minute).UnixNano()}, } - if err := db.merge("abcd", addrs, tc.Now().UnixNano()); err != nil { + if err := db.merge(&protocol.EmptyDeviceID, addrs, tc.Now().UnixNano()); err != nil { t.Fatal(err) } // Verify it - rec, err = db.get("abcd") + rec, err = db.get(&protocol.EmptyDeviceID) if err != nil { t.Fatal(err) } @@ -101,7 +97,7 @@ func TestDatabaseGetSet(t *testing.T) { // Verify it - rec, err = db.get("abcd") + rec, err = db.get(&protocol.EmptyDeviceID) if err != nil { t.Fatal(err) } @@ -114,40 +110,18 @@ func TestDatabaseGetSet(t *testing.T) { t.Error("incorrect address") } - // Put a record with misses - - rec = DatabaseRecord{Misses: 42, Missed: tc.Now().UnixNano()} - if err := db.put("efgh", rec); err != nil { - t.Fatal(err) - } - - // Verify it - - rec, err = db.get("efgh") - if err != nil { - t.Fatal(err) - } - if len(rec.Addresses) != 0 { - t.Log(rec.Addresses) - t.Fatal("should have no addresses") - } - if rec.Misses != 42 { - t.Log(rec.Misses) - t.Error("incorrect misses") - } - // Set an address addrs = []DatabaseAddress{ {Address: "tcp://6.7.8.9:0", Expires: tc.Now().Add(time.Minute).UnixNano()}, } - if err := db.merge("efgh", addrs, tc.Now().UnixNano()); err != nil { + if err := db.merge(&protocol.GlobalDeviceID, addrs, tc.Now().UnixNano()); err != nil { t.Fatal(err) } // Verify it - rec, err = db.get("efgh") + rec, err = db.get(&protocol.GlobalDeviceID) if err != nil { t.Fatal(err) } @@ -155,10 +129,6 @@ func TestDatabaseGetSet(t *testing.T) { t.Log(rec.Addresses) t.Fatal("should have one address") } - if rec.Misses != 0 { - t.Log(rec.Misses) - t.Error("should have no misses") - } } func TestFilter(t *testing.T) { diff --git a/cmd/stdiscosrv/main.go b/cmd/stdiscosrv/main.go index 4a1245266..3a4be8e6f 100644 --- a/cmd/stdiscosrv/main.go +++ b/cmd/stdiscosrv/main.go @@ -14,6 +14,7 @@ import ( "net" "net/http" "os" + "os/signal" "runtime" "strings" "time" @@ -24,7 +25,6 @@ import ( "github.com/syncthing/syncthing/lib/protocol" "github.com/syncthing/syncthing/lib/rand" "github.com/syncthing/syncthing/lib/tlsutil" - "github.com/syndtr/goleveldb/leveldb/opt" "github.com/thejerf/suture/v4" ) @@ -39,17 +39,12 @@ const ( errorRetryAfterSeconds = 1500 errorRetryFuzzSeconds = 300 - // Retry for not found is minSeconds + failures * incSeconds + - // random(fuzz), where failures is the number of consecutive lookups - // with no answer, up to maxSeconds. The fuzz is applied after capping - // to maxSeconds. - notFoundRetryMinSeconds = 60 - notFoundRetryMaxSeconds = 3540 - notFoundRetryIncSeconds = 10 - notFoundRetryFuzzSeconds = 60 - - // How often (in requests) we serialize the missed counter to database. - notFoundMissesWriteInterval = 10 + // Retry for not found is notFoundRetrySeenSeconds for records we have + // seen an announcement for (but it's not active right now) and + // notFoundRetryUnknownSeconds for records we have never seen (or not + // seen within the last week). + notFoundRetryUnknownMinSeconds = 60 + notFoundRetryUnknownMaxSeconds = 3600 httpReadTimeout = 5 * time.Second httpWriteTimeout = 5 * time.Second @@ -59,14 +54,6 @@ const ( replicationOutboxSize = 10000 ) -// These options make the database a little more optimized for writes, at -// the expense of some memory usage and risk of losing writes in a (system) -// crash. -var levelDBOptions = &opt.Options{ - NoSync: true, - WriteBuffer: 32 << 20, // default 4<<20 -} - var debug = false func main() { @@ -81,16 +68,15 @@ func main() { var replKeyFile string var useHTTP bool var compression bool - var largeDB bool var amqpAddress string - missesIncrease := 1 + var flushInterval time.Duration log.SetOutput(os.Stdout) log.SetFlags(0) flag.StringVar(&certFile, "cert", "./cert.pem", "Certificate file") flag.StringVar(&keyFile, "key", "./key.pem", "Key file") - flag.StringVar(&dir, "db-dir", "./discovery.db", "Database directory") + flag.StringVar(&dir, "db-dir", ".", "Database directory") flag.BoolVar(&debug, "debug", false, "Print debug output") flag.BoolVar(&useHTTP, "http", false, "Listen on HTTP (behind an HTTPS proxy)") flag.BoolVar(&compression, "compression", true, "Enable GZIP compression of responses") @@ -100,9 +86,8 @@ func main() { flag.StringVar(&replicationListen, "replication-listen", ":19200", "Replication listen address") flag.StringVar(&replCertFile, "replication-cert", "", "Certificate file for replication") flag.StringVar(&replKeyFile, "replication-key", "", "Key file for replication") - flag.BoolVar(&largeDB, "large-db", false, "Use larger database settings") flag.StringVar(&amqpAddress, "amqp-address", "", "Address to AMQP broker") - flag.IntVar(&missesIncrease, "misses-increase", 1, "How many times to increase the misses counter on each miss") + flag.DurationVar(&flushInterval, "flush-interval", 5*time.Minute, "Interval between database flushes") showVersion := flag.Bool("version", false, "Show version") flag.Parse() @@ -113,15 +98,6 @@ func main() { buildInfo.WithLabelValues(build.Version, runtime.Version(), build.User, build.Date.UTC().Format("2006-01-02T15:04:05Z")).Set(1) - if largeDB { - levelDBOptions.BlockCacheCapacity = 64 << 20 - levelDBOptions.BlockSize = 64 << 10 - levelDBOptions.CompactionTableSize = 16 << 20 - levelDBOptions.CompactionTableSizeMultiplier = 2.0 - levelDBOptions.WriteBuffer = 64 << 20 - levelDBOptions.CompactionL0Trigger = 8 - } - cert, err := tls.LoadX509KeyPair(certFile, keyFile) if os.IsNotExist(err) { log.Println("Failed to load keypair. Generating one, this might take a while...") @@ -190,10 +166,7 @@ func main() { }) // Start the database. - db, err := newLevelDBStore(dir) - if err != nil { - log.Fatalln("Open database:", err) - } + db := newInMemoryStore(dir, flushInterval) main.Add(db) // Start any replication senders. @@ -218,16 +191,8 @@ func main() { main.Add(kr) } - go func() { - for range time.NewTicker(time.Second).C { - for _, r := range repl { - r.send("", nil, time.Now().UnixNano()) - } - } - }() - // Start the main API server. - qs := newAPISrv(listen, cert, db, repl, useHTTP, compression, missesIncrease) + qs := newAPISrv(listen, cert, db, repl, useHTTP, compression) main.Add(qs) // If we have a metrics port configured, start a metrics handler. @@ -239,6 +204,18 @@ func main() { }() } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Cancel on signal + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, os.Interrupt) + go func() { + sig := <-signalChan + log.Printf("Received signal %s; shutting down", sig) + cancel() + }() + // Engage! - main.Serve(context.Background()) + main.Serve(ctx) } diff --git a/cmd/stdiscosrv/replication.go b/cmd/stdiscosrv/replication.go index e7aa2894a..a39304e1a 100644 --- a/cmd/stdiscosrv/replication.go +++ b/cmd/stdiscosrv/replication.go @@ -26,7 +26,7 @@ const ( ) type replicator interface { - send(key string, addrs []DatabaseAddress, seen int64) + send(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) } // a replicationSender tries to connect to the remote address and provide @@ -144,9 +144,9 @@ func (s *replicationSender) String() string { return fmt.Sprintf("replicationSender(%q)", s.dst) } -func (s *replicationSender) send(key string, ps []DatabaseAddress, seen int64) { +func (s *replicationSender) send(key *protocol.DeviceID, ps []DatabaseAddress, seen int64) { item := ReplicationRecord{ - Key: key, + Key: key[:], Addresses: ps, Seen: seen, } @@ -163,7 +163,7 @@ func (s *replicationSender) send(key string, ps []DatabaseAddress, seen int64) { // a replicationMultiplexer sends to multiple replicators type replicationMultiplexer []replicator -func (m replicationMultiplexer) send(key string, ps []DatabaseAddress, seen int64) { +func (m replicationMultiplexer) send(key *protocol.DeviceID, ps []DatabaseAddress, seen int64) { for _, s := range m { // each send is nonblocking s.send(key, ps, seen) @@ -290,9 +290,18 @@ func (l *replicationListener) handle(ctx context.Context, conn net.Conn) { replicationRecvsTotal.WithLabelValues("error").Inc() continue } + id, err := protocol.DeviceIDFromBytes(rec.Key) + if err != nil { + id, err = protocol.DeviceIDFromString(string(rec.Key)) + } + if err != nil { + log.Println("Replication device ID:", err) + replicationRecvsTotal.WithLabelValues("error").Inc() + continue + } // Store - l.db.merge(rec.Key, rec.Addresses, rec.Seen) + l.db.merge(&id, rec.Addresses, rec.Seen) replicationRecvsTotal.WithLabelValues("success").Inc() } } diff --git a/cmd/stdiscosrv/stats.go b/cmd/stdiscosrv/stats.go index ba9ccb40d..619a4ba32 100644 --- a/cmd/stdiscosrv/stats.go +++ b/cmd/stdiscosrv/stats.go @@ -96,13 +96,28 @@ var ( Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001}, }, []string{"operation"}) - retryAfterHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{ - Namespace: "syncthing", - Subsystem: "discovery", - Name: "retry_after_seconds", - Help: "Retry-After header value in seconds.", - Buckets: prometheus.ExponentialBuckets(60, 2, 7), // 60, 120, 240, 480, 960, 1920, 3840 - }) + databaseWriteSeconds = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "syncthing", + Subsystem: "discovery", + Name: "database_write_seconds", + Help: "Time spent writing the database.", + }) + databaseLastWritten = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "syncthing", + Subsystem: "discovery", + Name: "database_last_written", + Help: "Timestamp of the last successful database write.", + }) + + retryAfterLevel = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "syncthing", + Subsystem: "discovery", + Name: "retry_after_seconds", + Help: "Retry-After header value in seconds.", + }, []string{"name"}) ) const ( @@ -123,5 +138,6 @@ func init() { replicationSendsTotal, replicationRecvsTotal, databaseKeys, databaseStatisticsSeconds, databaseOperations, databaseOperationSeconds, - retryAfterHistogram) + databaseWriteSeconds, databaseLastWritten, + retryAfterLevel) } diff --git a/go.mod b/go.mod index 89503f138..bfc8bc654 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22.0 require ( github.com/AudriusButkevicius/recli v0.0.7-0.20220911121932-d000ce8fbf0f github.com/alecthomas/kong v0.9.0 + github.com/aws/aws-sdk-go v1.55.5 github.com/calmh/incontainer v1.0.0 github.com/calmh/xdr v1.1.0 github.com/ccding/go-stun v0.1.5 @@ -28,6 +29,7 @@ require ( github.com/oschwald/geoip2-golang v1.11.0 github.com/pierrec/lz4/v4 v4.1.21 github.com/prometheus/client_golang v1.19.1 + github.com/puzpuzpuz/xsync/v3 v3.4.0 github.com/quic-go/quic-go v0.46.0 github.com/rabbitmq/amqp091-go v1.10.0 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 @@ -67,6 +69,7 @@ require ( github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/nxadm/tail v1.4.11 // indirect diff --git a/go.sum b/go.sum index 7dc7397a0..7ad39adcd 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= +github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= +github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/calmh/glob v0.0.0-20220615080505-1d823af5017b h1:Fjm4GuJ+TGMgqfGHN42IQArJb77CfD/mAwLbDUoJe6g= @@ -124,6 +126,10 @@ github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh6 github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= @@ -194,6 +200,8 @@ github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4= +github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= github.com/quic-go/quic-go v0.46.0 h1:uuwLClEEyk1DNvchH8uCByQVjo3yKL9opKulExNDs7Y= github.com/quic-go/quic-go v0.46.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI= github.com/rabbitmq/amqp091-go v1.10.0 h1:STpn5XsHlHGcecLmMFCtg7mqq0RnD+zFr4uzukfVhBw= @@ -381,7 +389,9 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkep gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= From 77f7778292a9dcce5b70b6e179cd10c006a05708 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Sat, 7 Sep 2024 09:05:49 +0200 Subject: [PATCH 03/14] feat(stdiscosrv): enable HTTP profiler --- cmd/stdiscosrv/main.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmd/stdiscosrv/main.go b/cmd/stdiscosrv/main.go index 3a4be8e6f..330cb049b 100644 --- a/cmd/stdiscosrv/main.go +++ b/cmd/stdiscosrv/main.go @@ -19,6 +19,8 @@ import ( "strings" "time" + _ "net/http/pprof" + "github.com/prometheus/client_golang/prometheus/promhttp" _ "github.com/syncthing/syncthing/lib/automaxprocs" "github.com/syncthing/syncthing/lib/build" From 822b6ac36bf25da728eac09aa42f1aeb1a029c40 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Sun, 8 Sep 2024 21:16:13 +0200 Subject: [PATCH 04/14] chore(stdiscosrv): reduce unnecessary allocations in merge --- cmd/stdiscosrv/database.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/stdiscosrv/database.go b/cmd/stdiscosrv/database.go index 29686e584..a6cbd424f 100644 --- a/cmd/stdiscosrv/database.go +++ b/cmd/stdiscosrv/database.go @@ -350,7 +350,7 @@ func merge(a, b DatabaseRecord) DatabaseRecord { } res := DatabaseRecord{ - Addresses: make([]DatabaseAddress, 0, len(a.Addresses)+len(b.Addresses)), + Addresses: make([]DatabaseAddress, 0, max(len(a.Addresses), len(b.Addresses))), Seen: a.Seen, } if b.Seen > a.Seen { From f9b72330a86b68a897527a2bd42d4bf772ec6c44 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Mon, 9 Sep 2024 08:24:09 +0200 Subject: [PATCH 05/14] chore(stdiscosrv): reduce allocations in cert handling --- cmd/stdiscosrv/apisrv.go | 35 ++++++++++++++++++++++++++--------- cmd/stdiscosrv/apisrv_test.go | 4 +++- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/cmd/stdiscosrv/apisrv.go b/cmd/stdiscosrv/apisrv.go index d60a55896..1d202d7cf 100644 --- a/cmd/stdiscosrv/apisrv.go +++ b/cmd/stdiscosrv/apisrv.go @@ -367,7 +367,7 @@ func certificateBytes(req *http.Request) ([]byte, error) { } bs = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: hdr}) - } else if hdr := req.Header.Get("X-Forwarded-Tls-Client-Cert"); hdr != "" { + } else if cert := req.Header.Get("X-Forwarded-Tls-Client-Cert"); cert != "" { // Traefik 2 passtlsclientcert // // The certificate is in PEM format, maybe with URL encoding @@ -375,19 +375,36 @@ func certificateBytes(req *http.Request) ([]byte, error) { // statements. We need to decode, reinstate the newlines every 64 // character and add statements for the PEM decoder - if strings.Contains(hdr, "%") { - if unesc, err := url.QueryUnescape(hdr); err == nil { - hdr = unesc + if strings.Contains(cert, "%") { + if unesc, err := url.QueryUnescape(cert); err == nil { + cert = unesc } } - for i := 64; i < len(hdr); i += 65 { - hdr = hdr[:i] + "\n" + hdr[i:] + const ( + header = "-----BEGIN CERTIFICATE-----" + footer = "-----END CERTIFICATE-----" + ) + + var b bytes.Buffer + b.Grow(len(header) + 1 + len(cert) + len(cert)/64 + 1 + len(footer) + 1) + + b.WriteString(header) + b.WriteByte('\n') + + for i := 0; i < len(cert); i += 64 { + end := i + 64 + if end > len(cert) { + end = len(cert) + } + b.WriteString(cert[i:end]) + b.WriteByte('\n') } - hdr = "-----BEGIN CERTIFICATE-----\n" + hdr - hdr += "\n-----END CERTIFICATE-----\n" - bs = []byte(hdr) + b.WriteString(footer) + b.WriteByte('\n') + + bs = b.Bytes() } if bs == nil { diff --git a/cmd/stdiscosrv/apisrv_test.go b/cmd/stdiscosrv/apisrv_test.go index 07723a0e0..fa96de8bd 100644 --- a/cmd/stdiscosrv/apisrv_test.go +++ b/cmd/stdiscosrv/apisrv_test.go @@ -15,6 +15,7 @@ import ( "net/http" "net/http/httptest" "os" + "regexp" "strings" "testing" @@ -122,6 +123,7 @@ func BenchmarkAPIRequests(b *testing.B) { if err != nil { b.Fatal(err) } + certBs = regexp.MustCompile(`---[^\n]+---\n`).ReplaceAll(certBs, nil) certString := string(strings.ReplaceAll(string(certBs), "\n", " ")) devID := protocol.NewDeviceID(crt.Certificate[0]) @@ -132,7 +134,7 @@ func BenchmarkAPIRequests(b *testing.B) { url := srv.URL + "/v2/?device=" + devIDString for i := 0; i < b.N; i++ { req, _ := http.NewRequest(http.MethodPost, url, strings.NewReader(`{"addresses":["tcp://10.10.10.10:42000"]}`)) - req.Header.Set("X-Ssl-Cert", certString) + req.Header.Set("X-Forwarded-Tls-Client-Cert", certString) resp, err := http.DefaultClient.Do(req) if err != nil { b.Fatal(err) From 5c2fcbfd196f98f7b3e61967003fa77c2660a5af Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Tue, 10 Sep 2024 11:51:07 +0200 Subject: [PATCH 06/14] chore(stdiscosrv): simplify sorting --- cmd/stdiscosrv/apisrv.go | 4 ++-- cmd/stdiscosrv/database.go | 41 ++++++++++---------------------------- 2 files changed, 12 insertions(+), 33 deletions(-) diff --git a/cmd/stdiscosrv/apisrv.go b/cmd/stdiscosrv/apisrv.go index 1d202d7cf..2c7c874c9 100644 --- a/cmd/stdiscosrv/apisrv.go +++ b/cmd/stdiscosrv/apisrv.go @@ -22,7 +22,7 @@ import ( "net" "net/http" "net/url" - "sort" + "slices" "strconv" "strings" "sync" @@ -311,7 +311,7 @@ func (s *apiSrv) handleAnnounce(deviceID protocol.DeviceID, addresses []string) // The address slice must always be sorted for database merges to work // properly. - sort.Sort(databaseAddressOrder(dbAddrs)) + slices.SortFunc(dbAddrs, DatabaseAddress.Cmp) seen := now.UnixNano() if s.repl != nil { diff --git a/cmd/stdiscosrv/database.go b/cmd/stdiscosrv/database.go index a6cbd424f..407168260 100644 --- a/cmd/stdiscosrv/database.go +++ b/cmd/stdiscosrv/database.go @@ -11,6 +11,7 @@ package main import ( "bufio" + "cmp" "context" "encoding/binary" "io" @@ -19,7 +20,7 @@ import ( "net/url" "os" "path" - "sort" + "slices" "time" "github.com/aws/aws-sdk-go/aws" @@ -326,6 +327,7 @@ func (s *inMemoryStore) read() error { continue } + slices.SortFunc(rec.Addresses, DatabaseAddress.Cmp) s.m.Store(key, DatabaseRecord{ Addresses: rec.Addresses, Seen: rec.Seen, @@ -339,15 +341,6 @@ func (s *inMemoryStore) read() error { // chosen for any duplicates. func merge(a, b DatabaseRecord) DatabaseRecord { // Both lists must be sorted for this to work. - if !sort.IsSorted(databaseAddressOrder(a.Addresses)) { - log.Println("Warning: bug: addresses not correctly sorted in merge") - a.Addresses = sortedAddressCopy(a.Addresses) - } - if !sort.IsSorted(databaseAddressOrder(b.Addresses)) { - // no warning because this is the side we read from disk and it may - // legitimately predate correct sorting. - b.Addresses = sortedAddressCopy(b.Addresses) - } res := DatabaseRecord{ Addresses: make([]DatabaseAddress, 0, max(len(a.Addresses), len(b.Addresses))), @@ -425,27 +418,6 @@ func expire(addrs []DatabaseAddress, now int64) []DatabaseAddress { return addrs } -func sortedAddressCopy(addrs []DatabaseAddress) []DatabaseAddress { - sorted := make([]DatabaseAddress, len(addrs)) - copy(sorted, addrs) - sort.Sort(databaseAddressOrder(sorted)) - return sorted -} - -type databaseAddressOrder []DatabaseAddress - -func (s databaseAddressOrder) Less(a, b int) bool { - return s[a].Address < s[b].Address -} - -func (s databaseAddressOrder) Swap(a, b int) { - s[a], s[b] = s[b], s[a] -} - -func (s databaseAddressOrder) Len() int { - return len(s) -} - func s3Upload(r io.Reader) error { sess, err := session.NewSession(&aws.Config{ Region: aws.String("fr-par"), @@ -478,3 +450,10 @@ func s3Download(w io.WriterAt) error { }) return err } + +func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) { + if c := cmp.Compare(d.Address, other.Address); c != 0 { + return c + } + return cmp.Compare(d.Expires, other.Expires) +} From 66fb65b01f638a2700d2e51a734f8f53b32cf486 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Tue, 10 Sep 2024 12:22:37 +0200 Subject: [PATCH 07/14] chore(stdiscosrv): use order-preserving expire --- cmd/stdiscosrv/database.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cmd/stdiscosrv/database.go b/cmd/stdiscosrv/database.go index 407168260..24291378c 100644 --- a/cmd/stdiscosrv/database.go +++ b/cmd/stdiscosrv/database.go @@ -29,7 +29,6 @@ import ( "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/puzpuzpuz/xsync/v3" "github.com/syncthing/syncthing/lib/protocol" - "github.com/syncthing/syncthing/lib/sliceutil" ) type clock interface { @@ -405,12 +404,14 @@ loop: // expire returns the list of addresses after removing expired entries. // Expiration happen in place, so the slice given as the parameter is -// destroyed. Internal order is not preserved. +// destroyed. Internal order is preserved. func expire(addrs []DatabaseAddress, now int64) []DatabaseAddress { i := 0 for i < len(addrs) { if addrs[i].Expires < now { - addrs = sliceutil.RemoveAndZero(addrs, i) + copy(addrs[i:], addrs[i+1:]) + addrs[len(addrs)-1] = DatabaseAddress{} + addrs = addrs[:len(addrs)-1] continue } i++ From 3d59740a0a151d199aa48d9e6ea754c0abf988a2 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Tue, 10 Sep 2024 15:49:17 +0200 Subject: [PATCH 08/14] chore(stdiscosrv): database writing logging --- cmd/stdiscosrv/database.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cmd/stdiscosrv/database.go b/cmd/stdiscosrv/database.go index 24291378c..c016a5ec8 100644 --- a/cmd/stdiscosrv/database.go +++ b/cmd/stdiscosrv/database.go @@ -214,7 +214,9 @@ func (s *inMemoryStore) calculateStatistics() { func (s *inMemoryStore) write() (err error) { t0 := time.Now() + log.Println("Writing database") defer func() { + log.Println("Finished writing database") if err == nil { databaseWriteSeconds.Set(time.Since(t0).Seconds()) databaseLastWritten.Set(float64(t0.Unix())) @@ -276,6 +278,7 @@ func (s *inMemoryStore) write() (err error) { if os.Getenv("PODINDEX") == "0" { // Upload to S3 + log.Println("Uploading database") fd, err = os.Open(dbf) if err != nil { log.Printf("Error uploading database to S3: %v", err) @@ -285,6 +288,7 @@ func (s *inMemoryStore) write() (err error) { if err := s3Upload(fd); err != nil { log.Printf("Error uploading database to S3: %v", err) } + log.Println("Finished uploading database") } return nil From b794726e1fc172e699dede1c33a3cbc65263e598 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Tue, 10 Sep 2024 19:41:22 +0200 Subject: [PATCH 09/14] chore(stdiscosrv): sched in loop --- cmd/stdiscosrv/database.go | 16 +++++++++++++++- cmd/stdiscosrv/main.go | 2 +- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/cmd/stdiscosrv/database.go b/cmd/stdiscosrv/database.go index c016a5ec8..e36fa72cb 100644 --- a/cmd/stdiscosrv/database.go +++ b/cmd/stdiscosrv/database.go @@ -14,12 +14,14 @@ import ( "cmp" "context" "encoding/binary" + "errors" "io" "log" "net" "net/url" "os" "path" + "runtime" "slices" "time" @@ -159,7 +161,13 @@ func (s *inMemoryStore) calculateStatistics() { cutoff1w := t0.Add(-7 * 24 * time.Hour).UnixNano() current, currentIPv4, currentIPv6, last24h, last1w, errors := 0, 0, 0, 0, 0, 0 + n := 0 s.m.Range(func(key protocol.DeviceID, rec DatabaseRecord) bool { + if n%1000 == 0 { + runtime.Gosched() + } + n++ + // If there are addresses that have not expired it's a current // record, otherwise account it based on when it was last seen // (last 24 hours or last week) or finally as inactice. @@ -234,7 +242,13 @@ func (s *inMemoryStore) write() (err error) { var rangeErr error now := s.clock.Now().UnixNano() cutoff1w := s.clock.Now().Add(-7 * 24 * time.Hour).UnixNano() + n := 0 s.m.Range(func(key protocol.DeviceID, value DatabaseRecord) bool { + if n%1000 == 0 { + runtime.Gosched() + } + n++ + if value.Seen < cutoff1w { // drop the record if it's older than a week return true @@ -306,7 +320,7 @@ func (s *inMemoryStore) read() error { for { var n uint32 if err := binary.Read(br, binary.BigEndian, &n); err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { break } return err diff --git a/cmd/stdiscosrv/main.go b/cmd/stdiscosrv/main.go index 330cb049b..9f118ab16 100644 --- a/cmd/stdiscosrv/main.go +++ b/cmd/stdiscosrv/main.go @@ -74,7 +74,7 @@ func main() { var flushInterval time.Duration log.SetOutput(os.Stdout) - log.SetFlags(0) + // log.SetFlags(0) flag.StringVar(&certFile, "cert", "./cert.pem", "Certificate file") flag.StringVar(&keyFile, "key", "./key.pem", "Key file") From f3f5557c8e0f8a58287832efa20e22a8e5de4135 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Wed, 11 Sep 2024 08:15:32 +0200 Subject: [PATCH 10/14] chore(stdiscosrv): improve expire, logging --- cmd/stdiscosrv/database.go | 62 ++++++++++++++++----------------- cmd/stdiscosrv/database_test.go | 2 +- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/cmd/stdiscosrv/database.go b/cmd/stdiscosrv/database.go index e36fa72cb..0135a4071 100644 --- a/cmd/stdiscosrv/database.go +++ b/cmd/stdiscosrv/database.go @@ -63,7 +63,7 @@ func newInMemoryStore(dir string, flushInterval time.Duration) *inMemoryStore { flushInterval: flushInterval, clock: defaultClock{}, } - err := s.read() + nr, err := s.read() if os.IsNotExist(err) { // Try to read from AWS fd, cerr := os.Create(path.Join(s.dir, "records.db")) @@ -75,11 +75,12 @@ func newInMemoryStore(dir string, flushInterval time.Duration) *inMemoryStore { log.Printf("Error reading database from S3: %v", err) } _ = fd.Close() - err = s.read() + nr, err = s.read() } if err != nil { log.Println("Error reading database:", err) } + log.Printf("Read %d records from database", nr) s.calculateStatistics() return s } @@ -122,7 +123,7 @@ func (s *inMemoryStore) get(key *protocol.DeviceID) (DatabaseRecord, error) { return DatabaseRecord{}, nil } - rec.Addresses = expire(rec.Addresses, s.clock.Now().UnixNano()) + rec.Addresses = expire(rec.Addresses, s.clock.Now()) databaseOperations.WithLabelValues(dbOpGet, dbResSuccess).Inc() return rec, nil } @@ -139,10 +140,13 @@ loop: for { select { case <-t.C: + log.Println("Flushing database") if err := s.write(); err != nil { log.Println("Error writing database:", err) } + log.Println("Calculating statistics") s.calculateStatistics() + log.Println("Finished calculating statistics") t.Reset(s.flushInterval) case <-ctx.Done(): @@ -155,10 +159,9 @@ loop: } func (s *inMemoryStore) calculateStatistics() { - t0 := time.Now() - nowNanos := t0.UnixNano() - cutoff24h := t0.Add(-24 * time.Hour).UnixNano() - cutoff1w := t0.Add(-7 * 24 * time.Hour).UnixNano() + now := s.clock.Now() + cutoff24h := now.Add(-24 * time.Hour).UnixNano() + cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano() current, currentIPv4, currentIPv6, last24h, last1w, errors := 0, 0, 0, 0, 0, 0 n := 0 @@ -168,15 +171,11 @@ func (s *inMemoryStore) calculateStatistics() { } n++ - // If there are addresses that have not expired it's a current - // record, otherwise account it based on when it was last seen - // (last 24 hours or last week) or finally as inactice. - addrs := expire(rec.Addresses, nowNanos) switch { - case len(addrs) > 0: + case len(rec.Addresses) > 0: current++ seenIPv4, seenIPv6 := false, false - for _, addr := range addrs { + for _, addr := range rec.Addresses { uri, err := url.Parse(addr.Address) if err != nil { continue @@ -217,7 +216,7 @@ func (s *inMemoryStore) calculateStatistics() { databaseKeys.WithLabelValues("last24h").Set(float64(last24h)) databaseKeys.WithLabelValues("last1w").Set(float64(last1w)) databaseKeys.WithLabelValues("error").Set(float64(errors)) - databaseStatisticsSeconds.Set(time.Since(t0).Seconds()) + databaseStatisticsSeconds.Set(time.Since(now).Seconds()) } func (s *inMemoryStore) write() (err error) { @@ -240,8 +239,8 @@ func (s *inMemoryStore) write() (err error) { var buf []byte var rangeErr error - now := s.clock.Now().UnixNano() - cutoff1w := s.clock.Now().Add(-7 * 24 * time.Hour).UnixNano() + now := s.clock.Now() + cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano() n := 0 s.m.Range(func(key protocol.DeviceID, value DatabaseRecord) bool { if n%1000 == 0 { @@ -308,32 +307,33 @@ func (s *inMemoryStore) write() (err error) { return nil } -func (s *inMemoryStore) read() error { +func (s *inMemoryStore) read() (int, error) { fd, err := os.Open(path.Join(s.dir, "records.db")) if err != nil { - return err + return 0, err } defer fd.Close() br := bufio.NewReader(fd) var buf []byte + nr := 0 for { var n uint32 if err := binary.Read(br, binary.BigEndian, &n); err != nil { if errors.Is(err, io.EOF) { break } - return err + return nr, err } if int(n) > len(buf) { buf = make([]byte, n) } if _, err := io.ReadFull(br, buf[:n]); err != nil { - return err + return nr, err } rec := ReplicationRecord{} if err := rec.Unmarshal(buf[:n]); err != nil { - return err + return nr, err } key, err := protocol.DeviceIDFromBytes(rec.Key) if err != nil { @@ -349,8 +349,9 @@ func (s *inMemoryStore) read() error { Addresses: rec.Addresses, Seen: rec.Seen, }) + nr++ } - return nil + return nr, nil } // merge returns the merged result of the two database records a and b. The @@ -423,18 +424,15 @@ loop: // expire returns the list of addresses after removing expired entries. // Expiration happen in place, so the slice given as the parameter is // destroyed. Internal order is preserved. -func expire(addrs []DatabaseAddress, now int64) []DatabaseAddress { - i := 0 - for i < len(addrs) { - if addrs[i].Expires < now { - copy(addrs[i:], addrs[i+1:]) - addrs[len(addrs)-1] = DatabaseAddress{} - addrs = addrs[:len(addrs)-1] - continue +func expire(addrs []DatabaseAddress, now time.Time) []DatabaseAddress { + cutoff := now.UnixNano() + naddrs := addrs[:0] + for i := range addrs { + if addrs[i].Expires >= cutoff { + naddrs = append(naddrs, addrs[i]) } - i++ } - return addrs + return naddrs } func s3Upload(r io.Reader) error { diff --git a/cmd/stdiscosrv/database_test.go b/cmd/stdiscosrv/database_test.go index a476424f2..b90592e85 100644 --- a/cmd/stdiscosrv/database_test.go +++ b/cmd/stdiscosrv/database_test.go @@ -160,7 +160,7 @@ func TestFilter(t *testing.T) { } for _, tc := range cases { - res := expire(tc.a, 10) + res := expire(tc.a, time.Unix(0, 10)) if fmt.Sprint(res) != fmt.Sprint(tc.b) { t.Errorf("Incorrect result %v, expected %v", res, tc.b) } From 63e465928244c3ff6b85863750c56a6c6b5f132e Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Wed, 11 Sep 2024 08:34:11 +0200 Subject: [PATCH 11/14] chore(stdiscosrv): less garbage in statistics --- cmd/stdiscosrv/database.go | 35 ++++++++++++----------------------- 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/cmd/stdiscosrv/database.go b/cmd/stdiscosrv/database.go index 0135a4071..f23ec73f9 100644 --- a/cmd/stdiscosrv/database.go +++ b/cmd/stdiscosrv/database.go @@ -17,12 +17,11 @@ import ( "errors" "io" "log" - "net" - "net/url" "os" "path" "runtime" "slices" + "strings" "time" "github.com/aws/aws-sdk-go/aws" @@ -140,13 +139,13 @@ loop: for { select { case <-t.C: + log.Println("Calculating statistics") + s.calculateStatistics() log.Println("Flushing database") if err := s.write(); err != nil { log.Println("Error writing database:", err) } - log.Println("Calculating statistics") - s.calculateStatistics() - log.Println("Finished calculating statistics") + log.Println("Finished flushing database") t.Reset(s.flushInterval) case <-ctx.Done(): @@ -162,7 +161,7 @@ func (s *inMemoryStore) calculateStatistics() { now := s.clock.Now() cutoff24h := now.Add(-24 * time.Hour).UnixNano() cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano() - current, currentIPv4, currentIPv6, last24h, last1w, errors := 0, 0, 0, 0, 0, 0 + current, currentIPv4, currentIPv6, last24h, last1w := 0, 0, 0, 0, 0 n := 0 s.m.Range(func(key protocol.DeviceID, rec DatabaseRecord) bool { @@ -171,23 +170,16 @@ func (s *inMemoryStore) calculateStatistics() { } n++ + addresses := expire(rec.Addresses, now) switch { - case len(rec.Addresses) > 0: + case len(addresses) > 0: current++ seenIPv4, seenIPv6 := false, false for _, addr := range rec.Addresses { - uri, err := url.Parse(addr.Address) - if err != nil { - continue - } - host, _, err := net.SplitHostPort(uri.Host) - if err != nil { - continue - } - if ip := net.ParseIP(host); ip != nil && ip.To4() != nil { - seenIPv4 = true - } else if ip != nil { + if strings.Contains(addr.Address, "[") { seenIPv6 = true + } else { + seenIPv4 = true } if seenIPv4 && seenIPv6 { break @@ -215,15 +207,12 @@ func (s *inMemoryStore) calculateStatistics() { databaseKeys.WithLabelValues("currentIPv6").Set(float64(currentIPv6)) databaseKeys.WithLabelValues("last24h").Set(float64(last24h)) databaseKeys.WithLabelValues("last1w").Set(float64(last1w)) - databaseKeys.WithLabelValues("error").Set(float64(errors)) databaseStatisticsSeconds.Set(time.Since(now).Seconds()) } func (s *inMemoryStore) write() (err error) { t0 := time.Now() - log.Println("Writing database") defer func() { - log.Println("Finished writing database") if err == nil { databaseWriteSeconds.Set(time.Since(t0).Seconds()) databaseLastWritten.Set(float64(t0.Unix())) @@ -254,7 +243,7 @@ func (s *inMemoryStore) write() (err error) { } rec := ReplicationRecord{ Key: key[:], - Addresses: expire(value.Addresses, now), + Addresses: value.Addresses, Seen: value.Seen, } s := rec.Size() @@ -346,7 +335,7 @@ func (s *inMemoryStore) read() (int, error) { slices.SortFunc(rec.Addresses, DatabaseAddress.Cmp) s.m.Store(key, DatabaseRecord{ - Addresses: rec.Addresses, + Addresses: expire(rec.Addresses, s.clock.Now()), Seen: rec.Seen, }) nr++ From 6505e123bbd4488dca616df146d1b47af04540fc Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Wed, 11 Sep 2024 11:31:09 +0200 Subject: [PATCH 12/14] chore(stdiscosrv): clean up s3 handling --- cmd/stdiscosrv/apisrv_test.go | 2 +- cmd/stdiscosrv/database.go | 52 +++------------- cmd/stdiscosrv/database_test.go | 2 +- cmd/stdiscosrv/main.go | 103 ++++++++++++++++++-------------- cmd/stdiscosrv/s3.go | 97 ++++++++++++++++++++++++++++++ 5 files changed, 164 insertions(+), 92 deletions(-) create mode 100644 cmd/stdiscosrv/s3.go diff --git a/cmd/stdiscosrv/apisrv_test.go b/cmd/stdiscosrv/apisrv_test.go index fa96de8bd..272c0ca12 100644 --- a/cmd/stdiscosrv/apisrv_test.go +++ b/cmd/stdiscosrv/apisrv_test.go @@ -107,7 +107,7 @@ func addr(host string, port int) *net.TCPAddr { } func BenchmarkAPIRequests(b *testing.B) { - db := newInMemoryStore(b.TempDir(), 0) + db := newInMemoryStore(b.TempDir(), 0, nil) ctx, cancel := context.WithCancel(context.Background()) defer cancel() go db.Serve(ctx) diff --git a/cmd/stdiscosrv/database.go b/cmd/stdiscosrv/database.go index f23ec73f9..95de82eab 100644 --- a/cmd/stdiscosrv/database.go +++ b/cmd/stdiscosrv/database.go @@ -24,10 +24,6 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/puzpuzpuz/xsync/v3" "github.com/syncthing/syncthing/lib/protocol" ) @@ -52,25 +48,27 @@ type inMemoryStore struct { m *xsync.MapOf[protocol.DeviceID, DatabaseRecord] dir string flushInterval time.Duration + s3 *s3Copier clock clock } -func newInMemoryStore(dir string, flushInterval time.Duration) *inMemoryStore { +func newInMemoryStore(dir string, flushInterval time.Duration, s3 *s3Copier) *inMemoryStore { s := &inMemoryStore{ m: xsync.NewMapOf[protocol.DeviceID, DatabaseRecord](), dir: dir, flushInterval: flushInterval, + s3: s3, clock: defaultClock{}, } nr, err := s.read() - if os.IsNotExist(err) { + if os.IsNotExist(err) && s3 != nil { // Try to read from AWS fd, cerr := os.Create(path.Join(s.dir, "records.db")) if cerr != nil { log.Println("Error creating database file:", err) return s } - if err := s3Download(fd); err != nil { + if err := s3.downloadLatest(fd); err != nil { log.Printf("Error reading database from S3: %v", err) } _ = fd.Close() @@ -278,16 +276,15 @@ func (s *inMemoryStore) write() (err error) { return err } - if os.Getenv("PODINDEX") == "0" { - // Upload to S3 - log.Println("Uploading database") + // Upload to S3 + if s.s3 != nil { fd, err = os.Open(dbf) if err != nil { log.Printf("Error uploading database to S3: %v", err) return nil } defer fd.Close() - if err := s3Upload(fd); err != nil { + if err := s.s3.upload(fd); err != nil { log.Printf("Error uploading database to S3: %v", err) } log.Println("Finished uploading database") @@ -424,39 +421,6 @@ func expire(addrs []DatabaseAddress, now time.Time) []DatabaseAddress { return naddrs } -func s3Upload(r io.Reader) error { - sess, err := session.NewSession(&aws.Config{ - Region: aws.String("fr-par"), - Endpoint: aws.String("s3.fr-par.scw.cloud"), - }) - if err != nil { - return err - } - uploader := s3manager.NewUploader(sess) - _, err = uploader.Upload(&s3manager.UploadInput{ - Bucket: aws.String("syncthing-discovery"), - Key: aws.String("discovery.db"), - Body: r, - }) - return err -} - -func s3Download(w io.WriterAt) error { - sess, err := session.NewSession(&aws.Config{ - Region: aws.String("fr-par"), - Endpoint: aws.String("s3.fr-par.scw.cloud"), - }) - if err != nil { - return err - } - downloader := s3manager.NewDownloader(sess) - _, err = downloader.Download(w, &s3.GetObjectInput{ - Bucket: aws.String("syncthing-discovery"), - Key: aws.String("discovery.db"), - }) - return err -} - func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) { if c := cmp.Compare(d.Address, other.Address); c != 0 { return c diff --git a/cmd/stdiscosrv/database_test.go b/cmd/stdiscosrv/database_test.go index b90592e85..89a557175 100644 --- a/cmd/stdiscosrv/database_test.go +++ b/cmd/stdiscosrv/database_test.go @@ -16,7 +16,7 @@ import ( ) func TestDatabaseGetSet(t *testing.T) { - db := newInMemoryStore(t.TempDir(), 0) + db := newInMemoryStore(t.TempDir(), 0, nil) ctx, cancel := context.WithCancel(context.Background()) go db.Serve(ctx) defer cancel() diff --git a/cmd/stdiscosrv/main.go b/cmd/stdiscosrv/main.go index 9f118ab16..daeb4b3c7 100644 --- a/cmd/stdiscosrv/main.go +++ b/cmd/stdiscosrv/main.go @@ -9,7 +9,6 @@ package main import ( "context" "crypto/tls" - "flag" "log" "net" "net/http" @@ -21,6 +20,7 @@ import ( _ "net/http/pprof" + "github.com/alecthomas/kong" "github.com/prometheus/client_golang/prometheus/promhttp" _ "github.com/syncthing/syncthing/lib/automaxprocs" "github.com/syncthing/syncthing/lib/build" @@ -58,52 +58,52 @@ const ( var debug = false +type CLI struct { + Cert string `group:"Listen" help:"Certificate file" default:"./cert.pem" env:"DISCOVERY_CERT_FILE"` + Key string `group:"Listen" help:"Key file" default:"./key.pem" env:"DISCOVERY_KEY_FILE"` + HTTP bool `group:"Listen" help:"Listen on HTTP (behind an HTTPS proxy)" env:"DISCOVERY_HTTP"` + Compression bool `group:"Listen" help:"Enable GZIP compression of responses" env:"DISCOVERY_COMPRESSION"` + Listen string `group:"Listen" help:"Listen address" default:":8443" env:"DISCOVERY_LISTEN"` + MetricsListen string `group:"Listen" help:"Metrics listen address" env:"DISCOVERY_METRICS_LISTEN"` + + Replicate []string `group:"Legacy replication" help:"Replication peers, id@address, comma separated" env:"DISCOVERY_REPLICATE"` + ReplicationListen string `group:"Legacy replication" help:"Replication listen address" default:":19200" env:"DISCOVERY_REPLICATION_LISTEN"` + ReplicationCert string `group:"Legacy replication" help:"Certificate file for replication" env:"DISCOVERY_REPLICATION_CERT_FILE"` + ReplicationKey string `group:"Legacy replication" help:"Key file for replication" env:"DISCOVERY_REPLICATION_KEY_FILE"` + + AMQPAddress string `group:"AMQP replication" help:"Address to AMQP broker" env:"DISCOVERY_AMQP_ADDRESS"` + + DBDir string `group:"Database" help:"Database directory" default:"." env:"DISCOVERY_DB_DIR"` + DBFlushInterval time.Duration `group:"Database" help:"Interval between database flushes" default:"5m" env:"DISCOVERY_DB_FLUSH_INTERVAL"` + + DBS3Endpoint string `name:"db-s3-endpoint" group:"Database (S3 backup)" help:"S3 endpoint for database" env:"DISCOVERY_DB_S3_ENDPOINT"` + DBS3Region string `name:"db-s3-region" group:"Database (S3 backup)" help:"S3 region for database" env:"DISCOVERY_DB_S3_REGION"` + DBS3Bucket string `name:"db-s3-bucket" group:"Database (S3 backup)" help:"S3 bucket for database" env:"DISCOVERY_DB_S3_BUCKET"` + DBS3AccessKeyID string `name:"db-s3-access-key-id" group:"Database (S3 backup)" help:"S3 access key ID for database" env:"DISCOVERY_DB_S3_ACCESS_KEY_ID"` + DBS3SecretKey string `name:"db-s3-secret-key" group:"Database (S3 backup)" help:"S3 secret key for database" env:"DISCOVERY_DB_S3_SECRET_KEY"` + + Debug bool `short:"d" help:"Print debug output" env:"DISCOVERY_DEBUG"` + Version bool `short:"v" help:"Print version and exit"` +} + func main() { - var listen string - var dir string - var metricsListen string - var replicationListen string - var replicationPeers string - var certFile string - var keyFile string - var replCertFile string - var replKeyFile string - var useHTTP bool - var compression bool - var amqpAddress string - var flushInterval time.Duration - log.SetOutput(os.Stdout) - // log.SetFlags(0) - flag.StringVar(&certFile, "cert", "./cert.pem", "Certificate file") - flag.StringVar(&keyFile, "key", "./key.pem", "Key file") - flag.StringVar(&dir, "db-dir", ".", "Database directory") - flag.BoolVar(&debug, "debug", false, "Print debug output") - flag.BoolVar(&useHTTP, "http", false, "Listen on HTTP (behind an HTTPS proxy)") - flag.BoolVar(&compression, "compression", true, "Enable GZIP compression of responses") - flag.StringVar(&listen, "listen", ":8443", "Listen address") - flag.StringVar(&metricsListen, "metrics-listen", "", "Metrics listen address") - flag.StringVar(&replicationPeers, "replicate", "", "Replication peers, id@address, comma separated") - flag.StringVar(&replicationListen, "replication-listen", ":19200", "Replication listen address") - flag.StringVar(&replCertFile, "replication-cert", "", "Certificate file for replication") - flag.StringVar(&replKeyFile, "replication-key", "", "Key file for replication") - flag.StringVar(&amqpAddress, "amqp-address", "", "Address to AMQP broker") - flag.DurationVar(&flushInterval, "flush-interval", 5*time.Minute, "Interval between database flushes") - showVersion := flag.Bool("version", false, "Show version") - flag.Parse() + var cli CLI + kong.Parse(&cli) + debug = cli.Debug log.Println(build.LongVersionFor("stdiscosrv")) - if *showVersion { + if cli.Version { return } buildInfo.WithLabelValues(build.Version, runtime.Version(), build.User, build.Date.UTC().Format("2006-01-02T15:04:05Z")).Set(1) - cert, err := tls.LoadX509KeyPair(certFile, keyFile) + cert, err := tls.LoadX509KeyPair(cli.Cert, cli.Key) if os.IsNotExist(err) { log.Println("Failed to load keypair. Generating one, this might take a while...") - cert, err = tlsutil.NewCertificate(certFile, keyFile, "stdiscosrv", 20*365) + cert, err = tlsutil.NewCertificate(cli.Cert, cli.Key, "stdiscosrv", 20*365) if err != nil { log.Fatalln("Failed to generate X509 key pair:", err) } @@ -114,8 +114,8 @@ func main() { log.Println("Server device ID is", devID) replCert := cert - if replCertFile != "" && replKeyFile != "" { - replCert, err = tls.LoadX509KeyPair(replCertFile, replKeyFile) + if cli.ReplicationCert != "" && cli.ReplicationKey != "" { + replCert, err = tls.LoadX509KeyPair(cli.ReplicationCert, cli.ReplicationKey) if err != nil { log.Fatalln("Failed to load replication keypair:", err) } @@ -126,8 +126,7 @@ func main() { // Parse the replication specs, if any. var allowedReplicationPeers []protocol.DeviceID var replicationDestinations []string - parts := strings.Split(replicationPeers, ",") - for _, part := range parts { + for _, part := range cli.Replicate { if part == "" { continue } @@ -165,10 +164,22 @@ func main() { // Root of the service tree. main := suture.New("main", suture.Spec{ PassThroughPanics: true, + Timeout: 2 * time.Minute, }) + // If configured, use S3 for database backups. + var s3c *s3Copier + if cli.DBS3Endpoint != "" { + hostname, err := os.Hostname() + if err != nil { + log.Fatalf("Failed to get hostname: %v", err) + } + key := hostname + ".db" + s3c = newS3Copier(cli.DBS3Endpoint, cli.DBS3Region, cli.DBS3Bucket, key, cli.DBS3AccessKeyID, cli.DBS3SecretKey) + } + // Start the database. - db := newInMemoryStore(dir, flushInterval) + db := newInMemoryStore(cli.DBDir, cli.DBFlushInterval, s3c) main.Add(db) // Start any replication senders. @@ -181,28 +192,28 @@ func main() { // If we have replication configured, start the replication listener. if len(allowedReplicationPeers) > 0 { - rl := newReplicationListener(replicationListen, replCert, allowedReplicationPeers, db) + rl := newReplicationListener(cli.ReplicationListen, replCert, allowedReplicationPeers, db) main.Add(rl) } // If we have an AMQP broker, start that - if amqpAddress != "" { + if cli.AMQPAddress != "" { clientID := rand.String(10) - kr := newAMQPReplicator(amqpAddress, clientID, db) + kr := newAMQPReplicator(cli.AMQPAddress, clientID, db) repl = append(repl, kr) main.Add(kr) } // Start the main API server. - qs := newAPISrv(listen, cert, db, repl, useHTTP, compression) + qs := newAPISrv(cli.Listen, cert, db, repl, cli.HTTP, cli.Compression) main.Add(qs) // If we have a metrics port configured, start a metrics handler. - if metricsListen != "" { + if cli.MetricsListen != "" { go func() { mux := http.NewServeMux() mux.Handle("/metrics", promhttp.Handler()) - log.Fatal(http.ListenAndServe(metricsListen, mux)) + log.Fatal(http.ListenAndServe(cli.MetricsListen, mux)) }() } diff --git a/cmd/stdiscosrv/s3.go b/cmd/stdiscosrv/s3.go new file mode 100644 index 000000000..f60cfee27 --- /dev/null +++ b/cmd/stdiscosrv/s3.go @@ -0,0 +1,97 @@ +// Copyright (C) 2024 The Syncthing Authors. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +package main + +import ( + "io" + "log" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" +) + +type s3Copier struct { + endpoint string + region string + bucket string + key string + accessKeyID string + secretKey string +} + +func newS3Copier(endpoint, region, bucket, key, accessKeyID, secretKey string) *s3Copier { + return &s3Copier{ + endpoint: endpoint, + region: region, + bucket: bucket, + key: key, + accessKeyID: accessKeyID, + secretKey: secretKey, + } +} + +func (s *s3Copier) upload(r io.Reader) error { + sess, err := session.NewSession(&aws.Config{ + Region: aws.String(s.region), + Endpoint: aws.String(s.endpoint), + Credentials: credentials.NewStaticCredentials(s.accessKeyID, s.secretKey, ""), + }) + if err != nil { + return err + } + + uploader := s3manager.NewUploader(sess) + _, err = uploader.Upload(&s3manager.UploadInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(s.key), + Body: r, + }) + return err +} + +func (s *s3Copier) downloadLatest(w io.WriterAt) error { + sess, err := session.NewSession(&aws.Config{ + Region: aws.String(s.region), + Endpoint: aws.String(s.endpoint), + Credentials: credentials.NewStaticCredentials(s.accessKeyID, s.secretKey, ""), + }) + if err != nil { + return err + } + + svc := s3.New(sess) + resp, err := svc.ListObjectsV2(&s3.ListObjectsV2Input{Bucket: aws.String(s.bucket)}) + if err != nil { + return err + } + + var lastKey string + var lastModified time.Time + var lastSize int64 + for _, item := range resp.Contents { + if item.LastModified.After(lastModified) && *item.Size > lastSize { + lastKey = *item.Key + lastModified = *item.LastModified + lastSize = *item.Size + } else if lastModified.Sub(*item.LastModified) < 5*time.Minute && *item.Size > lastSize { + lastKey = *item.Key + lastSize = *item.Size + } + } + + log.Println("Downloading database from", lastKey) + downloader := s3manager.NewDownloader(sess) + _, err = downloader.Download(w, &s3.GetObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(lastKey), + }) + return err +} From 1616edcee3dcf43f916b95aa6fda2c9fca27df8f Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Wed, 11 Sep 2024 11:43:58 +0200 Subject: [PATCH 13/14] chore(stdiscosrv): remove legacy replication --- cmd/stdiscosrv/amqp.go | 4 - cmd/stdiscosrv/apisrv.go | 4 + cmd/stdiscosrv/main.go | 94 ++-------- cmd/stdiscosrv/replication.go | 334 ---------------------------------- 4 files changed, 19 insertions(+), 417 deletions(-) delete mode 100644 cmd/stdiscosrv/replication.go diff --git a/cmd/stdiscosrv/amqp.go b/cmd/stdiscosrv/amqp.go index e32eea49e..bb99c1e6d 100644 --- a/cmd/stdiscosrv/amqp.go +++ b/cmd/stdiscosrv/amqp.go @@ -7,7 +7,6 @@ package main import ( - "bytes" "context" "fmt" "io" @@ -164,9 +163,6 @@ func (s *amqpReceiver) Serve(ctx context.Context) error { replicationRecvsTotal.WithLabelValues("error").Inc() return fmt.Errorf("replication unmarshal: %w", err) } - if bytes.Equal(rec.Key, []byte("")) { - continue - } id, err := protocol.DeviceIDFromBytes(rec.Key) if err != nil { id, err = protocol.DeviceIDFromString(string(rec.Key)) diff --git a/cmd/stdiscosrv/apisrv.go b/cmd/stdiscosrv/apisrv.go index 2c7c874c9..fced38e09 100644 --- a/cmd/stdiscosrv/apisrv.go +++ b/cmd/stdiscosrv/apisrv.go @@ -51,6 +51,10 @@ type apiSrv struct { notSeenTracker *retryAfterTracker } +type replicator interface { + send(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) +} + type requestID int64 func (i requestID) String() string { diff --git a/cmd/stdiscosrv/main.go b/cmd/stdiscosrv/main.go index daeb4b3c7..069dd02a5 100644 --- a/cmd/stdiscosrv/main.go +++ b/cmd/stdiscosrv/main.go @@ -10,12 +10,10 @@ import ( "context" "crypto/tls" "log" - "net" "net/http" "os" "os/signal" "runtime" - "strings" "time" _ "net/http/pprof" @@ -66,11 +64,6 @@ type CLI struct { Listen string `group:"Listen" help:"Listen address" default:":8443" env:"DISCOVERY_LISTEN"` MetricsListen string `group:"Listen" help:"Metrics listen address" env:"DISCOVERY_METRICS_LISTEN"` - Replicate []string `group:"Legacy replication" help:"Replication peers, id@address, comma separated" env:"DISCOVERY_REPLICATE"` - ReplicationListen string `group:"Legacy replication" help:"Replication listen address" default:":19200" env:"DISCOVERY_REPLICATION_LISTEN"` - ReplicationCert string `group:"Legacy replication" help:"Certificate file for replication" env:"DISCOVERY_REPLICATION_CERT_FILE"` - ReplicationKey string `group:"Legacy replication" help:"Key file for replication" env:"DISCOVERY_REPLICATION_KEY_FILE"` - AMQPAddress string `group:"AMQP replication" help:"Address to AMQP broker" env:"DISCOVERY_AMQP_ADDRESS"` DBDir string `group:"Database" help:"Database directory" default:"." env:"DISCOVERY_DB_DIR"` @@ -100,65 +93,21 @@ func main() { buildInfo.WithLabelValues(build.Version, runtime.Version(), build.User, build.Date.UTC().Format("2006-01-02T15:04:05Z")).Set(1) - cert, err := tls.LoadX509KeyPair(cli.Cert, cli.Key) - if os.IsNotExist(err) { - log.Println("Failed to load keypair. Generating one, this might take a while...") - cert, err = tlsutil.NewCertificate(cli.Cert, cli.Key, "stdiscosrv", 20*365) - if err != nil { - log.Fatalln("Failed to generate X509 key pair:", err) - } - } else if err != nil { - log.Fatalln("Failed to load keypair:", err) - } - devID := protocol.NewDeviceID(cert.Certificate[0]) - log.Println("Server device ID is", devID) - - replCert := cert - if cli.ReplicationCert != "" && cli.ReplicationKey != "" { - replCert, err = tls.LoadX509KeyPair(cli.ReplicationCert, cli.ReplicationKey) - if err != nil { - log.Fatalln("Failed to load replication keypair:", err) - } - } - replDevID := protocol.NewDeviceID(replCert.Certificate[0]) - log.Println("Replication device ID is", replDevID) - - // Parse the replication specs, if any. - var allowedReplicationPeers []protocol.DeviceID - var replicationDestinations []string - for _, part := range cli.Replicate { - if part == "" { - continue - } - - fields := strings.Split(part, "@") - switch len(fields) { - case 2: - // This is an id@address specification. Grab the address for the - // destination list. Try to resolve it once to catch obvious - // syntax errors here rather than having the sender service fail - // repeatedly later. - _, err := net.ResolveTCPAddr("tcp", fields[1]) + var cert tls.Certificate + if !cli.HTTP { + var err error + cert, err = tls.LoadX509KeyPair(cli.Cert, cli.Key) + if os.IsNotExist(err) { + log.Println("Failed to load keypair. Generating one, this might take a while...") + cert, err = tlsutil.NewCertificate(cli.Cert, cli.Key, "stdiscosrv", 20*365) if err != nil { - log.Fatalln("Resolving address:", err) + log.Fatalln("Failed to generate X509 key pair:", err) } - replicationDestinations = append(replicationDestinations, fields[1]) - fallthrough // N.B. - - case 1: - // The first part is always a device ID. - id, err := protocol.DeviceIDFromString(fields[0]) - if err != nil { - log.Fatalln("Parsing device ID:", err) - } - if id == protocol.EmptyDeviceID { - log.Fatalf("Missing device ID for peer in %q", part) - } - allowedReplicationPeers = append(allowedReplicationPeers, id) - - default: - log.Fatalln("Unrecognized replication spec:", part) + } else if err != nil { + log.Fatalln("Failed to load keypair:", err) } + devID := protocol.NewDeviceID(cert.Certificate[0]) + log.Println("Server device ID is", devID) } // Root of the service tree. @@ -182,26 +131,13 @@ func main() { db := newInMemoryStore(cli.DBDir, cli.DBFlushInterval, s3c) main.Add(db) - // Start any replication senders. - var repl replicationMultiplexer - for _, dst := range replicationDestinations { - rs := newReplicationSender(dst, replCert, allowedReplicationPeers) - main.Add(rs) - repl = append(repl, rs) - } - - // If we have replication configured, start the replication listener. - if len(allowedReplicationPeers) > 0 { - rl := newReplicationListener(cli.ReplicationListen, replCert, allowedReplicationPeers, db) - main.Add(rl) - } - - // If we have an AMQP broker, start that + // If we have an AMQP broker for replication, start that + var repl replicator if cli.AMQPAddress != "" { clientID := rand.String(10) kr := newAMQPReplicator(cli.AMQPAddress, clientID, db) - repl = append(repl, kr) main.Add(kr) + repl = kr } // Start the main API server. diff --git a/cmd/stdiscosrv/replication.go b/cmd/stdiscosrv/replication.go deleted file mode 100644 index a39304e1a..000000000 --- a/cmd/stdiscosrv/replication.go +++ /dev/null @@ -1,334 +0,0 @@ -// Copyright (C) 2018 The Syncthing Authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at https://mozilla.org/MPL/2.0/. - -package main - -import ( - "context" - "crypto/tls" - "encoding/binary" - "fmt" - io "io" - "log" - "net" - "time" - - "github.com/syncthing/syncthing/lib/protocol" -) - -const ( - replicationReadTimeout = time.Minute - replicationWriteTimeout = 30 * time.Second - replicationHeartbeatInterval = time.Second * 30 -) - -type replicator interface { - send(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) -} - -// a replicationSender tries to connect to the remote address and provide -// them with a feed of replication updates. -type replicationSender struct { - dst string - cert tls.Certificate // our certificate - allowedIDs []protocol.DeviceID - outbox chan ReplicationRecord -} - -func newReplicationSender(dst string, cert tls.Certificate, allowedIDs []protocol.DeviceID) *replicationSender { - return &replicationSender{ - dst: dst, - cert: cert, - allowedIDs: allowedIDs, - outbox: make(chan ReplicationRecord, replicationOutboxSize), - } -} - -func (s *replicationSender) Serve(ctx context.Context) error { - // Sleep a little at startup. Peers often restart at the same time, and - // this avoid the service failing and entering backoff state - // unnecessarily, while also reducing the reconnect rate to something - // reasonable by default. - time.Sleep(2 * time.Second) - - tlsCfg := &tls.Config{ - Certificates: []tls.Certificate{s.cert}, - MinVersion: tls.VersionTLS12, - InsecureSkipVerify: true, - } - - // Dial the TLS connection. - conn, err := tls.Dial("tcp", s.dst, tlsCfg) - if err != nil { - log.Println("Replication connect:", err) - return err - } - defer func() { - conn.SetWriteDeadline(time.Now().Add(time.Second)) - conn.Close() - }() - - // The replication stream is not especially latency sensitive, but it is - // quite a lot of data in small writes. Make it more efficient. - if tcpc, ok := conn.NetConn().(*net.TCPConn); ok { - _ = tcpc.SetNoDelay(false) - } - - // Get the other side device ID. - remoteID, err := deviceID(conn) - if err != nil { - log.Println("Replication connect:", err) - return err - } - - // Verify it's in the set of allowed device IDs. - if !deviceIDIn(remoteID, s.allowedIDs) { - log.Println("Replication connect: unexpected device ID:", remoteID) - return err - } - - heartBeatTicker := time.NewTicker(replicationHeartbeatInterval) - defer heartBeatTicker.Stop() - - // Send records. - buf := make([]byte, 1024) - for { - select { - case <-heartBeatTicker.C: - if len(s.outbox) > 0 { - // No need to send heartbeats if there are events/prevrious - // heartbeats to send, they will keep the connection alive. - continue - } - // Empty replication message is the heartbeat: - s.outbox <- ReplicationRecord{} - - case rec := <-s.outbox: - // Buffer must hold record plus four bytes for size - size := rec.Size() - if len(buf) < size+4 { - buf = make([]byte, size+4) - } - - // Record comes after the four bytes size - n, err := rec.MarshalTo(buf[4:]) - if err != nil { - // odd to get an error here, but we haven't sent anything - // yet so it's not fatal - replicationSendsTotal.WithLabelValues("error").Inc() - log.Println("Replication marshal:", err) - continue - } - binary.BigEndian.PutUint32(buf, uint32(n)) - - // Send - conn.SetWriteDeadline(time.Now().Add(replicationWriteTimeout)) - if _, err := conn.Write(buf[:4+n]); err != nil { - replicationSendsTotal.WithLabelValues("error").Inc() - log.Println("Replication write:", err) - // Yes, we are losing the replication event here. - return err - } - replicationSendsTotal.WithLabelValues("success").Inc() - - case <-ctx.Done(): - return nil - } - } -} - -func (s *replicationSender) String() string { - return fmt.Sprintf("replicationSender(%q)", s.dst) -} - -func (s *replicationSender) send(key *protocol.DeviceID, ps []DatabaseAddress, seen int64) { - item := ReplicationRecord{ - Key: key[:], - Addresses: ps, - Seen: seen, - } - - // The send should never block. The inbox is suitably buffered for at - // least a few seconds of stalls, which shouldn't happen in practice. - select { - case s.outbox <- item: - default: - replicationSendsTotal.WithLabelValues("drop").Inc() - } -} - -// a replicationMultiplexer sends to multiple replicators -type replicationMultiplexer []replicator - -func (m replicationMultiplexer) send(key *protocol.DeviceID, ps []DatabaseAddress, seen int64) { - for _, s := range m { - // each send is nonblocking - s.send(key, ps, seen) - } -} - -// replicationListener accepts incoming connections and reads replication -// items from them. Incoming items are applied to the KV store. -type replicationListener struct { - addr string - cert tls.Certificate - allowedIDs []protocol.DeviceID - db database -} - -func newReplicationListener(addr string, cert tls.Certificate, allowedIDs []protocol.DeviceID, db database) *replicationListener { - return &replicationListener{ - addr: addr, - cert: cert, - allowedIDs: allowedIDs, - db: db, - } -} - -func (l *replicationListener) Serve(ctx context.Context) error { - tlsCfg := &tls.Config{ - Certificates: []tls.Certificate{l.cert}, - ClientAuth: tls.RequestClientCert, - MinVersion: tls.VersionTLS12, - InsecureSkipVerify: true, - } - - lst, err := tls.Listen("tcp", l.addr, tlsCfg) - if err != nil { - log.Println("Replication listen:", err) - return err - } - defer lst.Close() - - for { - select { - case <-ctx.Done(): - return nil - default: - } - - // Accept a connection - conn, err := lst.Accept() - if err != nil { - log.Println("Replication accept:", err) - return err - } - - // Figure out the other side device ID - remoteID, err := deviceID(conn.(*tls.Conn)) - if err != nil { - log.Println("Replication accept:", err) - conn.SetWriteDeadline(time.Now().Add(time.Second)) - conn.Close() - continue - } - - // Verify it is in the set of allowed device IDs - if !deviceIDIn(remoteID, l.allowedIDs) { - log.Println("Replication accept: unexpected device ID:", remoteID) - conn.SetWriteDeadline(time.Now().Add(time.Second)) - conn.Close() - continue - } - - go l.handle(ctx, conn) - } -} - -func (l *replicationListener) String() string { - return fmt.Sprintf("replicationListener(%q)", l.addr) -} - -func (l *replicationListener) handle(ctx context.Context, conn net.Conn) { - defer func() { - conn.SetWriteDeadline(time.Now().Add(time.Second)) - conn.Close() - }() - - buf := make([]byte, 1024) - - for { - select { - case <-ctx.Done(): - return - default: - } - - conn.SetReadDeadline(time.Now().Add(replicationReadTimeout)) - - // First four bytes are the size - if _, err := io.ReadFull(conn, buf[:4]); err != nil { - log.Println("Replication read size:", err) - replicationRecvsTotal.WithLabelValues("error").Inc() - return - } - - // Read the rest of the record - size := int(binary.BigEndian.Uint32(buf[:4])) - if len(buf) < size { - buf = make([]byte, size) - } - - if size == 0 { - // Heartbeat, ignore - continue - } - - if _, err := io.ReadFull(conn, buf[:size]); err != nil { - log.Println("Replication read record:", err) - replicationRecvsTotal.WithLabelValues("error").Inc() - return - } - - // Unmarshal - var rec ReplicationRecord - if err := rec.Unmarshal(buf[:size]); err != nil { - log.Println("Replication unmarshal:", err) - replicationRecvsTotal.WithLabelValues("error").Inc() - continue - } - id, err := protocol.DeviceIDFromBytes(rec.Key) - if err != nil { - id, err = protocol.DeviceIDFromString(string(rec.Key)) - } - if err != nil { - log.Println("Replication device ID:", err) - replicationRecvsTotal.WithLabelValues("error").Inc() - continue - } - - // Store - l.db.merge(&id, rec.Addresses, rec.Seen) - replicationRecvsTotal.WithLabelValues("success").Inc() - } -} - -func deviceID(conn *tls.Conn) (protocol.DeviceID, error) { - // Handshake may not be complete on the server side yet, which we need - // to get the client certificate. - if !conn.ConnectionState().HandshakeComplete { - if err := conn.Handshake(); err != nil { - return protocol.DeviceID{}, err - } - } - - // We expect exactly one certificate. - certs := conn.ConnectionState().PeerCertificates - if len(certs) != 1 { - return protocol.DeviceID{}, fmt.Errorf("unexpected number of certificates (%d != 1)", len(certs)) - } - - return protocol.NewDeviceID(certs[0].Raw), nil -} - -func deviceIDIn(id protocol.DeviceID, ids []protocol.DeviceID) bool { - for _, candidate := range ids { - if id == candidate { - return true - } - } - return false -} From 94d0195b6333b4807eb7e697f427bc90a1ac7ee2 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Fri, 13 Sep 2024 08:38:03 +0200 Subject: [PATCH 14/14] chore(stdiscosrv): hide internal/undocumented flags --- cmd/stdiscosrv/main.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cmd/stdiscosrv/main.go b/cmd/stdiscosrv/main.go index 069dd02a5..fed332f60 100644 --- a/cmd/stdiscosrv/main.go +++ b/cmd/stdiscosrv/main.go @@ -64,16 +64,16 @@ type CLI struct { Listen string `group:"Listen" help:"Listen address" default:":8443" env:"DISCOVERY_LISTEN"` MetricsListen string `group:"Listen" help:"Metrics listen address" env:"DISCOVERY_METRICS_LISTEN"` - AMQPAddress string `group:"AMQP replication" help:"Address to AMQP broker" env:"DISCOVERY_AMQP_ADDRESS"` - DBDir string `group:"Database" help:"Database directory" default:"." env:"DISCOVERY_DB_DIR"` DBFlushInterval time.Duration `group:"Database" help:"Interval between database flushes" default:"5m" env:"DISCOVERY_DB_FLUSH_INTERVAL"` - DBS3Endpoint string `name:"db-s3-endpoint" group:"Database (S3 backup)" help:"S3 endpoint for database" env:"DISCOVERY_DB_S3_ENDPOINT"` - DBS3Region string `name:"db-s3-region" group:"Database (S3 backup)" help:"S3 region for database" env:"DISCOVERY_DB_S3_REGION"` - DBS3Bucket string `name:"db-s3-bucket" group:"Database (S3 backup)" help:"S3 bucket for database" env:"DISCOVERY_DB_S3_BUCKET"` - DBS3AccessKeyID string `name:"db-s3-access-key-id" group:"Database (S3 backup)" help:"S3 access key ID for database" env:"DISCOVERY_DB_S3_ACCESS_KEY_ID"` - DBS3SecretKey string `name:"db-s3-secret-key" group:"Database (S3 backup)" help:"S3 secret key for database" env:"DISCOVERY_DB_S3_SECRET_KEY"` + DBS3Endpoint string `name:"db-s3-endpoint" group:"Database (S3 backup)" hidden:"true" help:"S3 endpoint for database" env:"DISCOVERY_DB_S3_ENDPOINT"` + DBS3Region string `name:"db-s3-region" group:"Database (S3 backup)" hidden:"true" help:"S3 region for database" env:"DISCOVERY_DB_S3_REGION"` + DBS3Bucket string `name:"db-s3-bucket" group:"Database (S3 backup)" hidden:"true" help:"S3 bucket for database" env:"DISCOVERY_DB_S3_BUCKET"` + DBS3AccessKeyID string `name:"db-s3-access-key-id" group:"Database (S3 backup)" hidden:"true" help:"S3 access key ID for database" env:"DISCOVERY_DB_S3_ACCESS_KEY_ID"` + DBS3SecretKey string `name:"db-s3-secret-key" group:"Database (S3 backup)" hidden:"true" help:"S3 secret key for database" env:"DISCOVERY_DB_S3_SECRET_KEY"` + + AMQPAddress string `group:"AMQP replication" hidden:"true" help:"Address to AMQP broker" env:"DISCOVERY_AMQP_ADDRESS"` Debug bool `short:"d" help:"Print debug output" env:"DISCOVERY_DEBUG"` Version bool `short:"v" help:"Print version and exit"`