From 1e2fdfd745fec5339a003ba68cce0ce9f785859a Mon Sep 17 00:00:00 2001
From: Harry Harpham
Date: Mon, 30 Mar 2026 08:44:47 -0600
Subject: [PATCH] tsnet: fix bug in closing multiple ServiceListeners at once

Prior to this change, closing multiple ServiceListeners concurrently
could result in failures as the independent close operations vie for the
attention of the Server's LocalBackend. The close operations would each
obtain the current ETag of the serve config and try to write new serve
config using this ETag. When one write invalidated the ETag of another,
the latter would fail. Exacerbating the issue, ServiceListener.Close
cannot be retried.

This change resolves the bug by using Server.mu to synchronize across
all ServiceListener.Close operations, ensuring they happen serially.

Fixes #19169

Signed-off-by: Harry Harpham
---
 tsnet/tsnet.go      | 16 ++++++++-----
 tsnet/tsnet_test.go | 58 +++++++++++++++++++++++++++++++++++++++------
 2 files changed, 61 insertions(+), 13 deletions(-)

diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go
index 962cd8ecb..71452f662 100644
--- a/tsnet/tsnet.go
+++ b/tsnet/tsnet.go
@@ -1540,13 +1540,16 @@ func (sl *ServiceListener) Close() error {
 	// We should only clean up state once. Otherwise we can stomp on state
 	// created by new listeners.
 	sl.closeOnce.Do(func() {
+		sl.s.mu.Lock()
+		defer sl.s.mu.Unlock()
+
 		// Two pieces of state we need to clear:
 		// 1. The Service advertisement pref
 		// 2. Artifacts in the serve config
 		// Then we can close the listener.
 
 		var adErr error
-		if err := sl.s.decrementServiceAdvertisement(sl.svcName); err != nil {
+		if err := sl.s.decrementServiceAdvertisementLocked(sl.svcName); err != nil {
 			adErr = fmt.Errorf("managing Service advertisements: %w", err)
 		}
 
@@ -1594,10 +1597,7 @@ func (s *Server) advertiseService(name tailcfg.ServiceName) error {
 // advertising the Service. Advertisement of the Service will be withdrawn if
 // the count hits zero. It is an error to call this function when the Service is
 // not being advertised by this node.
-func (s *Server) decrementServiceAdvertisement(name tailcfg.ServiceName) error {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-
+func (s *Server) decrementServiceAdvertisementLocked(name tailcfg.ServiceName) error {
 	cleanAdvertisement := func() error {
 		delete(s.advertisedServices, name)
 		advertised := s.lb.Prefs().AdvertiseServices()
@@ -1683,7 +1683,11 @@ func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener,
 	if err := s.advertiseService(svcName); err != nil {
 		return nil, fmt.Errorf("advertising Service: %w", err)
 	}
-	onError = append(onError, func() { s.decrementServiceAdvertisement(svcName) })
+	onError = append(onError, func() {
+		s.mu.Lock()
+		defer s.mu.Unlock()
+		s.decrementServiceAdvertisementLocked(svcName)
+	})
 
 	srvCfg := new(ipn.ServeConfig)
 	sc, srvCfgETag, err := s.lb.ServeConfigETag()
diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go
index 4824a58a0..12a810ebf 100644
--- a/tsnet/tsnet_test.go
+++ b/tsnet/tsnet_test.go
@@ -1427,6 +1427,7 @@ func TestListenService(t *testing.T) {
 
 func TestListenServiceClose(t *testing.T) {
 	tstest.Shard(t)
+	const serviceName = "svc:foo"
 
 	diffServeConfig := func(a, b ipn.ServeConfigView) string {
 		// We treat a mapping from svc:foo to nil or the zero value as if it
@@ -1457,7 +1458,7 @@ func TestListenServiceClose(t *testing.T) {
 			name: "TCP",
 			run: func(t *testing.T, s *Server) {
 				before := s.lb.ServeConfig()
-				ln := must.Get(s.ListenService("svc:foo", ServiceModeTCP{Port: 8080}))
+				ln := must.Get(s.ListenService(serviceName, ServiceModeTCP{Port: 8080}))
 				ln.Close()
 				after := s.lb.ServeConfig()
 				if diff := diffServeConfig(after, before); diff != "" {
@@ -1469,7 +1470,7 @@ func TestListenServiceClose(t *testing.T) {
 			name: "HTTP",
 			run: func(t *testing.T, s *Server) {
 				before := s.lb.ServeConfig()
-				ln := must.Get(s.ListenService("svc:foo", ServiceModeHTTP{Port: 8080}))
+				ln := must.Get(s.ListenService(serviceName, ServiceModeHTTP{Port: 8080}))
 				ln.Close()
 				after := s.lb.ServeConfig()
 				if diff := diffServeConfig(after, before); diff != "" {
@@ -1482,14 +1483,14 @@ func TestListenServiceClose(t *testing.T) {
 			name: "two_listeners",
 			run: func(t *testing.T, s *Server) {
 				// Start a listener on 443.
-				ln1 := must.Get(s.ListenService("svc:foo", ServiceModeTCP{Port: 443}))
+				ln1 := must.Get(s.ListenService(serviceName, ServiceModeTCP{Port: 443}))
 				defer ln1.Close()
 
 				// Save the serve config for this original listener.
 				before := s.lb.ServeConfig()
 
 				// Now start and close a new listener on a different port.
-				ln2 := must.Get(s.ListenService("svc:foo", ServiceModeTCP{Port: 8080}))
+				ln2 := must.Get(s.ListenService(serviceName, ServiceModeTCP{Port: 8080}))
 				ln2.Close()
 
 				// The serve config for the original listener should be intact.
@@ -1505,7 +1506,7 @@ func TestListenServiceClose(t *testing.T) {
 			// should be automatically closed).
 			name: "after_server_close",
 			run: func(t *testing.T, s *Server) {
-				ln := must.Get(s.ListenService("svc:foo", ServiceModeTCP{Port: 8080}))
+				ln := must.Get(s.ListenService(serviceName, ServiceModeTCP{Port: 8080}))
 
 				// Close the server, then close the listener.
 				must.Do(s.Close())
@@ -1519,13 +1520,57 @@ func TestListenServiceClose(t *testing.T) {
 				}
 			},
 		},
+		{
+			// Regression test for https://github.com/tailscale/tailscale/issues/19169,
+			// in which concurrent ServiceListener.Close calls (by different
+			// listeners) would fail.
+			name: "concurrent_close",
+			run: func(t *testing.T, s *Server) {
+				const concurrentCloseCalls = 100
+
+				readyGroup := new(sync.WaitGroup)
+				closedGroup := new(sync.WaitGroup)
+				closeThemAll := make(chan (struct{}))
+				errC := make(chan error, concurrentCloseCalls)
+				for i := range concurrentCloseCalls {
+					readyGroup.Add(1)
+					closedGroup.Add(1)
+					ln := must.Get(s.ListenService(serviceName, ServiceModeTCP{
+						Port: uint16(i + 1),
+					}))
+					go func() {
+						readyGroup.Done()
+						<-closeThemAll
+						errC <- ln.Close()
+						closedGroup.Done()
+					}()
+				}
+
+				readyGroup.Wait()
+				close(closeThemAll)
+				closedGroup.Wait()
+				close(errC)
+
+				var errs []error
+				for err := range errC {
+					if err != nil {
+						errs = append(errs, err)
+					}
+				}
+				if len(errs) > 0 {
+					t.Fatalf("%d close errors; sample: %v", len(errs), errs[0])
+				}
+				if diff := diffServeConfig(s.lb.ServeConfig(), (&ipn.ServeConfig{}).View()); diff != "" {
+					t.Fatalf("expected empty config (-got, +want):\n%s", diff)
+				}
+			},
+		},
 	}
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			ctx := t.Context()
 
-			const serviceName = "svc:foo"
 			controlURL, control := startControl(t)
 			serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host")
 			serviceClient, _, _ := startServer(t, ctx, controlURL, "service-client")
@@ -1533,7 +1578,6 @@ func TestListenServiceClose(t *testing.T) {
 
 			tt.run(t, serviceHost)
 		})
-
 	}
 }
 