diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 962cd8ecb..71452f662 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -1540,13 +1540,16 @@ func (sl *ServiceListener) Close() error { // We should only clean up state once. Otherwise we can stomp on state // created by new listeners. sl.closeOnce.Do(func() { + sl.s.mu.Lock() + defer sl.s.mu.Unlock() + // Two pieces of state we need to clear: // 1. The Service advertisement pref // 2. Artifacts in the serve config // Then we can close the listener. var adErr error - if err := sl.s.decrementServiceAdvertisement(sl.svcName); err != nil { + if err := sl.s.decrementServiceAdvertisementLocked(sl.svcName); err != nil { adErr = fmt.Errorf("managing Service advertisements: %w", err) } @@ -1594,10 +1597,7 @@ func (s *Server) advertiseService(name tailcfg.ServiceName) error { // advertising the Service. Advertisement of the Service will be withdrawn if // the count hits zero. It is an error to call this function when the Service is // not being advertised by this node. -func (s *Server) decrementServiceAdvertisement(name tailcfg.ServiceName) error { - s.mu.Lock() - defer s.mu.Unlock() - +func (s *Server) decrementServiceAdvertisementLocked(name tailcfg.ServiceName) error { cleanAdvertisement := func() error { delete(s.advertisedServices, name) advertised := s.lb.Prefs().AdvertiseServices() @@ -1683,7 +1683,11 @@ func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener, if err := s.advertiseService(svcName); err != nil { return nil, fmt.Errorf("advertising Service: %w", err) } - onError = append(onError, func() { s.decrementServiceAdvertisement(svcName) }) + onError = append(onError, func() { + s.mu.Lock() + defer s.mu.Unlock() + s.decrementServiceAdvertisementLocked(svcName) + }) srvCfg := new(ipn.ServeConfig) sc, srvCfgETag, err := s.lb.ServeConfigETag() diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 4824a58a0..12a810ebf 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -1427,6 +1427,7 @@ func TestListenService(t *testing.T) { func TestListenServiceClose(t *testing.T) { tstest.Shard(t) + const serviceName = "svc:foo" diffServeConfig := func(a, b ipn.ServeConfigView) string { // We treat a mapping from svc:foo to nil or the zero value as if it @@ -1457,7 +1458,7 @@ func TestListenServiceClose(t *testing.T) { name: "TCP", run: func(t *testing.T, s *Server) { before := s.lb.ServeConfig() - ln := must.Get(s.ListenService("svc:foo", ServiceModeTCP{Port: 8080})) + ln := must.Get(s.ListenService(serviceName, ServiceModeTCP{Port: 8080})) ln.Close() after := s.lb.ServeConfig() if diff := diffServeConfig(after, before); diff != "" { @@ -1469,7 +1470,7 @@ func TestListenServiceClose(t *testing.T) { name: "HTTP", run: func(t *testing.T, s *Server) { before := s.lb.ServeConfig() - ln := must.Get(s.ListenService("svc:foo", ServiceModeHTTP{Port: 8080})) + ln := must.Get(s.ListenService(serviceName, ServiceModeHTTP{Port: 8080})) ln.Close() after := s.lb.ServeConfig() if diff := diffServeConfig(after, before); diff != "" { @@ -1482,14 +1483,14 @@ func TestListenServiceClose(t *testing.T) { name: "two_listeners", run: func(t *testing.T, s *Server) { // Start a listener on 443. - ln1 := must.Get(s.ListenService("svc:foo", ServiceModeTCP{Port: 443})) + ln1 := must.Get(s.ListenService(serviceName, ServiceModeTCP{Port: 443})) defer ln1.Close() // Save the serve config for this original listener. before := s.lb.ServeConfig() // Now start and close a new listener on a different port. - ln2 := must.Get(s.ListenService("svc:foo", ServiceModeTCP{Port: 8080})) + ln2 := must.Get(s.ListenService(serviceName, ServiceModeTCP{Port: 8080})) ln2.Close() // The serve config for the original listener should be intact. @@ -1505,7 +1506,7 @@ func TestListenServiceClose(t *testing.T) { // should be automatically closed). name: "after_server_close", run: func(t *testing.T, s *Server) { - ln := must.Get(s.ListenService("svc:foo", ServiceModeTCP{Port: 8080})) + ln := must.Get(s.ListenService(serviceName, ServiceModeTCP{Port: 8080})) // Close the server, then close the listener. must.Do(s.Close()) @@ -1519,13 +1520,57 @@ func TestListenServiceClose(t *testing.T) { } }, }, + { + // Regression test for https://github.com/tailscale/tailscale/issues/19169, + // in which concurrent ServiceListener.Close calls (by different + // listeners) would fail. + name: "concurrent_close", + run: func(t *testing.T, s *Server) { + const concurrentCloseCalls = 100 + + readyGroup := new(sync.WaitGroup) + closedGroup := new(sync.WaitGroup) + closeThemAll := make(chan (struct{})) + errC := make(chan error, concurrentCloseCalls) + for i := range concurrentCloseCalls { + readyGroup.Add(1) + closedGroup.Add(1) + ln := must.Get(s.ListenService(serviceName, ServiceModeTCP{ + Port: uint16(i + 1), + })) + go func() { + readyGroup.Done() + <-closeThemAll + errC <- ln.Close() + closedGroup.Done() + }() + } + + readyGroup.Wait() + close(closeThemAll) + closedGroup.Wait() + close(errC) + + var errs []error + for err := range errC { + if err != nil { + errs = append(errs, err) + } + } + if len(errs) > 0 { + t.Fatalf("%d close errors; sample: %v", len(errs), errs[0]) + } + if diff := diffServeConfig(s.lb.ServeConfig(), (&ipn.ServeConfig{}).View()); diff != "" { + t.Fatalf("expected empty config (-got, +want):\n%s", diff) + } + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := t.Context() - const serviceName = "svc:foo" controlURL, control := startControl(t) serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host") serviceClient, _, _ := startServer(t, ctx, controlURL, "service-client") @@ -1533,7 +1578,6 @@ func TestListenServiceClose(t *testing.T) { tt.run(t, serviceHost) }) - } }