diff --git a/net/udprelay/server.go b/net/udprelay/server.go index d59578780..45127dfae 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -77,8 +77,8 @@ type Server struct { closeCh chan struct{} netChecker *netcheck.Client - mu sync.Mutex // guards the following fields - macSecrets [][blake2s.Size]byte // [0] is most recent, max 2 elements + mu sync.Mutex // guards the following fields + macSecrets views.Slice[[blake2s.Size]byte] // [0] is most recent, max 2 elements macSecretRotatedAt mono.Time derpMap *tailcfg.DERPMap onlyStaticAddrPorts bool // no dynamic addr port discovery when set @@ -87,8 +87,11 @@ type Server struct { closed bool lamportID uint64 nextVNI uint32 - byVNI map[uint32]*serverEndpoint - byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint + // serverEndpointByVNI is consistent with serverEndpointByDisco while mu is + // held, i.e. mu must be held around write ops. Read ops in performance + // sensitive paths, e.g. packet forwarding, do not need to acquire mu. + serverEndpointByVNI sync.Map // key is uint32 (Geneve VNI), value is [*serverEndpoint] + serverEndpointByDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint } const macSecretRotationInterval = time.Minute * 2 @@ -100,23 +103,23 @@ type Server struct { ) // serverEndpoint contains Server-internal [endpoint.ServerEndpoint] state. -// serverEndpoint methods are not thread-safe. type serverEndpoint struct { // discoPubKeys contains the key.DiscoPublic of the served clients. The // indexing of this array aligns with the following fields, e.g. // discoSharedSecrets[0] is the shared secret to use when sealing // Disco protocol messages for transmission towards discoPubKeys[0]. - discoPubKeys key.SortedPairOfDiscoPublic - discoSharedSecrets [2]key.DiscoShared + discoPubKeys key.SortedPairOfDiscoPublic + discoSharedSecrets [2]key.DiscoShared + lamportID uint64 + vni uint32 + allocatedAt mono.Time + + mu sync.Mutex // guards the following fields inProgressGeneration [2]uint32 // or zero if a handshake has never started, or has just completed boundAddrPorts [2]netip.AddrPort // or zero value if a handshake has never completed for that relay leg lastSeen [2]mono.Time packetsRx [2]uint64 // num packets received from/sent by each client after they are bound bytesRx [2]uint64 // num bytes received from/sent by each client after they are bound - - lamportID uint64 - vni uint32 - allocatedAt mono.Time } func blakeMACFromBindMsg(blakeKey [blake2s.Size]byte, src netip.AddrPort, msg disco.BindUDPRelayEndpointCommon) ([blake2s.Size]byte, error) { @@ -141,7 +144,10 @@ func blakeMACFromBindMsg(blakeKey [blake2s.Size]byte, src netip.AddrPort, msg di return out, nil } -func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic, macSecrets [][blake2s.Size]byte, now mono.Time) (write []byte, to netip.AddrPort) { +func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic, macSecrets views.Slice[[blake2s.Size]byte], now mono.Time) (write []byte, to netip.AddrPort) { + e.mu.Lock() + defer e.mu.Unlock() + if senderIndex != 0 && senderIndex != 1 { return nil, netip.AddrPort{} } @@ -186,7 +192,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex } reply = append(reply, disco.Magic...) reply = serverDisco.AppendTo(reply) - mac, err := blakeMACFromBindMsg(macSecrets[0], from, m.BindUDPRelayEndpointCommon) + mac, err := blakeMACFromBindMsg(macSecrets.At(0), from, m.BindUDPRelayEndpointCommon) if err != nil { return nil, netip.AddrPort{} } @@ -206,7 +212,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex // silently drop return nil, netip.AddrPort{} } - for _, macSecret := range macSecrets { + for _, macSecret := range macSecrets.All() { mac, err := blakeMACFromBindMsg(macSecret, from, discoMsg.BindUDPRelayEndpointCommon) if err != nil { // silently drop @@ -230,7 +236,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex } } -func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic, macSecrets [][blake2s.Size]byte, now mono.Time) (write []byte, to netip.AddrPort) { +func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic, macSecrets views.Slice[[blake2s.Size]byte], now mono.Time) (write []byte, to netip.AddrPort) { senderRaw, isDiscoMsg := disco.Source(b) if !isDiscoMsg { // Not a Disco message @@ -265,7 +271,9 @@ func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []by } func (e *serverEndpoint) handleDataPacket(from netip.AddrPort, b []byte, now mono.Time) (write []byte, to netip.AddrPort) { - if !e.isBound() { + e.mu.Lock() + defer e.mu.Unlock() + if !e.isBoundLocked() { // not a control packet, but serverEndpoint isn't bound return nil, netip.AddrPort{} } @@ -287,7 +295,9 @@ func (e *serverEndpoint) handleDataPacket(from netip.AddrPort, b []byte, now mon } func (e *serverEndpoint) isExpired(now mono.Time, bindLifetime, steadyStateLifetime time.Duration) bool { - if !e.isBound() { + e.mu.Lock() + defer e.mu.Unlock() + if !e.isBoundLocked() { if now.Sub(e.allocatedAt) > bindLifetime { return true } @@ -299,9 +309,9 @@ func (e *serverEndpoint) isExpired(now mono.Time, bindLifetime, steadyStateLifet return false } -// isBound returns true if both clients have completed a 3-way handshake, +// isBoundLocked returns true if both clients have completed a 3-way handshake, // otherwise false. -func (e *serverEndpoint) isBound() bool { +func (e *serverEndpoint) isBoundLocked() bool { return e.boundAddrPorts[0].IsValid() && e.boundAddrPorts[1].IsValid() } @@ -313,15 +323,14 @@ func (e *serverEndpoint) isBound() bool { // used. func NewServer(logf logger.Logf, port uint16, onlyStaticAddrPorts bool) (s *Server, err error) { s = &Server{ - logf: logf, - disco: key.NewDisco(), - bindLifetime: defaultBindLifetime, - steadyStateLifetime: defaultSteadyStateLifetime, - closeCh: make(chan struct{}), - onlyStaticAddrPorts: onlyStaticAddrPorts, - byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint), - nextVNI: minVNI, - byVNI: make(map[uint32]*serverEndpoint), + logf: logf, + disco: key.NewDisco(), + bindLifetime: defaultBindLifetime, + steadyStateLifetime: defaultSteadyStateLifetime, + closeCh: make(chan struct{}), + onlyStaticAddrPorts: onlyStaticAddrPorts, + serverEndpointByDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint), + nextVNI: minVNI, } s.discoPublic = s.disco.Public() @@ -640,8 +649,8 @@ func (s *Server) Close() error { // acquire s.mu. s.mu.Lock() defer s.mu.Unlock() - clear(s.byVNI) - clear(s.byDisco) + s.serverEndpointByVNI.Clear() + clear(s.serverEndpointByDisco) s.closed = true s.bus.Close() }) @@ -659,10 +668,10 @@ func (s *Server) endpointGCLoop() { // holding s.mu for the duration. Keep it simple (and slow) for now. s.mu.Lock() defer s.mu.Unlock() - for k, v := range s.byDisco { + for k, v := range s.serverEndpointByDisco { if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) { - delete(s.byDisco, k) - delete(s.byVNI, v.vni) + delete(s.serverEndpointByDisco, k) + s.serverEndpointByVNI.Delete(v.vni) } } } @@ -690,12 +699,7 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to n if err != nil { return nil, netip.AddrPort{} } - // TODO: consider performance implications of holding s.mu for the remainder - // of this method, which does a bunch of disco/crypto work depending. Keep - // it simple (and slow) for now. - s.mu.Lock() - defer s.mu.Unlock() - e, ok := s.byVNI[gh.VNI.Get()] + e, ok := s.serverEndpointByVNI.Load(gh.VNI.Get()) if !ok { // unknown VNI return nil, netip.AddrPort{} @@ -708,27 +712,36 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to n return nil, netip.AddrPort{} } msg := b[packet.GeneveFixedHeaderLength:] - s.maybeRotateMACSecretLocked(now) - return e.handleSealedDiscoControlMsg(from, msg, s.discoPublic, s.macSecrets, now) + secrets := s.getMACSecrets(now) + return e.(*serverEndpoint).handleSealedDiscoControlMsg(from, msg, s.discoPublic, secrets, now) } - return e.handleDataPacket(from, b, now) + return e.(*serverEndpoint).handleDataPacket(from, b, now) +} + +func (s *Server) getMACSecrets(now mono.Time) views.Slice[[blake2s.Size]byte] { + s.mu.Lock() + defer s.mu.Unlock() + s.maybeRotateMACSecretLocked(now) + return s.macSecrets } func (s *Server) maybeRotateMACSecretLocked(now mono.Time) { if !s.macSecretRotatedAt.IsZero() && now.Sub(s.macSecretRotatedAt) < macSecretRotationInterval { return } - switch len(s.macSecrets) { + secrets := s.macSecrets.AsSlice() + switch len(secrets) { case 0: - s.macSecrets = make([][blake2s.Size]byte, 1, 2) + secrets = make([][blake2s.Size]byte, 1, 2) case 1: - s.macSecrets = append(s.macSecrets, [blake2s.Size]byte{}) + secrets = append(secrets, [blake2s.Size]byte{}) fallthrough case 2: - s.macSecrets[1] = s.macSecrets[0] + secrets[1] = secrets[0] } - rand.Read(s.macSecrets[0][:]) + rand.Read(secrets[0][:]) s.macSecretRotatedAt = now + s.macSecrets = views.SliceOf(secrets) return } @@ -838,7 +851,7 @@ func (s *Server) getNextVNILocked() (uint32, error) { } else { s.nextVNI++ } - _, ok := s.byVNI[vni] + _, ok := s.serverEndpointByVNI.Load(vni) if !ok { return vni, nil } @@ -877,7 +890,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv } pair := key.NewSortedPairOfDiscoPublic(discoA, discoB) - e, ok := s.byDisco[pair] + e, ok := s.serverEndpointByDisco[pair] if ok { // Return the existing allocation. Clients can resolve duplicate // [endpoint.ServerEndpoint]'s via [endpoint.ServerEndpoint.LamportID]. @@ -915,8 +928,8 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys.Get()[0]) e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys.Get()[1]) - s.byDisco[pair] = e - s.byVNI[e.vni] = e + s.serverEndpointByDisco[pair] = e + s.serverEndpointByVNI.Store(e.vni, e) s.logf("allocated endpoint vni=%d lamportID=%d disco[0]=%v disco[1]=%v", e.vni, e.lamportID, pair.Get()[0].ShortString(), pair.Get()[1].ShortString()) return endpoint.ServerEndpoint{ @@ -930,19 +943,19 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv }, nil } -// extractClientInfo constructs a [status.ClientInfo] for one of the two peer -// relay clients involved in this session. -func extractClientInfo(idx int, ep *serverEndpoint) status.ClientInfo { - if idx != 0 && idx != 1 { - panic(fmt.Sprintf("idx passed to extractClientInfo() must be 0 or 1; got %d", idx)) - } - - return status.ClientInfo{ - Endpoint: ep.boundAddrPorts[idx], - ShortDisco: ep.discoPubKeys.Get()[idx].ShortString(), - PacketsTx: ep.packetsRx[idx], - BytesTx: ep.bytesRx[idx], +// extractClientInfo constructs a [status.ClientInfo] for both relay clients +// involved in this session. +func (e *serverEndpoint) extractClientInfo() [2]status.ClientInfo { + e.mu.Lock() + defer e.mu.Unlock() + ret := [2]status.ClientInfo{} + for i := range e.boundAddrPorts { + ret[i].Endpoint = e.boundAddrPorts[i] + ret[i].ShortDisco = e.discoPubKeys.Get()[i].ShortString() + ret[i].PacketsTx = e.packetsRx[i] + ret[i].BytesTx = e.bytesRx[i] } + return ret } // GetSessions returns a slice of peer relay session statuses, with each @@ -955,14 +968,13 @@ func (s *Server) GetSessions() []status.ServerSession { if s.closed { return nil } - var sessions = make([]status.ServerSession, 0, len(s.byDisco)) - for _, se := range s.byDisco { - c1 := extractClientInfo(0, se) - c2 := extractClientInfo(1, se) + var sessions = make([]status.ServerSession, 0, len(s.serverEndpointByDisco)) + for _, se := range s.serverEndpointByDisco { + clientInfos := se.extractClientInfo() sessions = append(sessions, status.ServerSession{ VNI: se.vni, - Client1: c1, - Client2: c2, + Client1: clientInfos[0], + Client2: clientInfos[1], }) } return sessions diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go index bc7680107..c4b365641 100644 --- a/net/udprelay/server_test.go +++ b/net/udprelay/server_test.go @@ -339,19 +339,18 @@ func TestServer_getNextVNILocked(t *testing.T) { c := qt.New(t) s := &Server{ nextVNI: minVNI, - byVNI: make(map[uint32]*serverEndpoint), } for i := uint64(0); i < uint64(totalPossibleVNI); i++ { vni, err := s.getNextVNILocked() if err != nil { // using quicktest here triples test time t.Fatal(err) } - s.byVNI[vni] = nil + s.serverEndpointByVNI.Store(vni, nil) } c.Assert(s.nextVNI, qt.Equals, minVNI) _, err := s.getNextVNILocked() c.Assert(err, qt.IsNotNil) - delete(s.byVNI, minVNI) + s.serverEndpointByVNI.Delete(minVNI) _, err = s.getNextVNILocked() c.Assert(err, qt.IsNil) } @@ -455,17 +454,17 @@ func TestServer_maybeRotateMACSecretLocked(t *testing.T) { s := &Server{} start := mono.Now() s.maybeRotateMACSecretLocked(start) - qt.Assert(t, len(s.macSecrets), qt.Equals, 1) - macSecret := s.macSecrets[0] + qt.Assert(t, s.macSecrets.Len(), qt.Equals, 1) + macSecret := s.macSecrets.At(0) s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval - time.Nanosecond)) - qt.Assert(t, len(s.macSecrets), qt.Equals, 1) - qt.Assert(t, s.macSecrets[0], qt.Equals, macSecret) + qt.Assert(t, s.macSecrets.Len(), qt.Equals, 1) + qt.Assert(t, s.macSecrets.At(0), qt.Equals, macSecret) s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval)) - qt.Assert(t, len(s.macSecrets), qt.Equals, 2) - qt.Assert(t, s.macSecrets[1], qt.Equals, macSecret) - qt.Assert(t, s.macSecrets[0], qt.Not(qt.Equals), s.macSecrets[1]) + qt.Assert(t, s.macSecrets.Len(), qt.Equals, 2) + qt.Assert(t, s.macSecrets.At(1), qt.Equals, macSecret) + qt.Assert(t, s.macSecrets.At(0), qt.Not(qt.Equals), s.macSecrets.At(1)) s.maybeRotateMACSecretLocked(s.macSecretRotatedAt.Add(macSecretRotationInterval)) - qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets[0]) - qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets[1]) - qt.Assert(t, s.macSecrets[0], qt.Not(qt.Equals), s.macSecrets[1]) + qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets.At(0)) + qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets.At(1)) + qt.Assert(t, s.macSecrets.At(0), qt.Not(qt.Equals), s.macSecrets.At(1)) }