From 83dadabf30ae14aab545e3d050a0d12ecb9e1c8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Claus=20Lensb=C3=B8l?= Date: Mon, 30 Mar 2026 18:01:26 -0400 Subject: [PATCH] WIP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Claus Lensbøl --- net/packet/tsmp.go | 13 +++++++++- net/packet/tsmp_test.go | 46 ++++++++++++++++++++++++--------- net/tstun/wrap.go | 1 + net/tstun/wrap_test.go | 4 +++ types/events/disco_update.go | 5 ++-- wgengine/magicsock/magicsock.go | 2 ++ wgengine/userspace.go | 42 +++++++++++++++++++++++------- wgengine/userspace_test.go | 2 +- 8 files changed, 89 insertions(+), 26 deletions(-) diff --git a/net/packet/tsmp.go b/net/packet/tsmp.go index ad1db311a..41562cc35 100644 --- a/net/packet/tsmp.go +++ b/net/packet/tsmp.go @@ -273,6 +273,7 @@ func (h TSMPPongReply) Marshal(buf []byte) error { type TSMPDiscoKeyAdvertisement struct { Src, Dst netip.Addr // Src and Dst are set from the parent IP Header when parsing. Key key.DiscoPublic + Request bool } func (ka *TSMPDiscoKeyAdvertisement) Marshal() ([]byte, error) { @@ -293,7 +294,12 @@ func (ka *TSMPDiscoKeyAdvertisement) Marshal() ([]byte, error) { payload := make([]byte, 0, 33) payload = append(payload, byte(TSMPTypeDiscoAdvertisement)) payload = ka.Key.AppendTo(payload) - if len(payload) != 33 { + if ka.Request { + payload = append(payload, 1) + } else { + payload = append(payload, 0) + } + if len(payload) != 34 { // Mostly to safeguard against ourselves changing this in the future. return []byte{}, fmt.Errorf("expected payload length 33, got %d", len(payload)) } @@ -312,6 +318,11 @@ func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok tka.Src = pp.Src.Addr() tka.Dst = pp.Dst.Addr() tka.Key = key.DiscoPublicFromRaw32(mem.B(p[1:33])) + if p[33] == 1 { + tka.Request = true + } else { + tka.Request = false + } return tka, true } diff --git a/net/packet/tsmp_test.go b/net/packet/tsmp_test.go index 01bb836d7..fede3dbe5 100644 --- a/net/packet/tsmp_test.go +++ b/net/packet/tsmp_test.go @@ -80,10 +80,10 @@ func TestTailscaleRejectedHeader(t *testing.T) { func TestTSMPDiscoKeyAdvertisementMarshal(t *testing.T) { var ( - // IPv4: Ver(4)Len(5), TOS, Len(53), ID, Flags, TTL(64), Proto(99), Cksum - headerV4, _ = hex.DecodeString("45000035000000004063705d") - // IPv6: Ver(6)TCFlow, Len(33), NextHdr(99), HopLim(64) - headerV6, _ = hex.DecodeString("6000000000216340") + // IPv4: Ver(4)Len(5), TOS, Len(54), ID, Flags, TTL(64), Proto(99), Cksum + headerV4, _ = hex.DecodeString("45000036000000004063705c") + // IPv6: Ver(6)TCFlow, Len(34), NextHdr(99), HopLim(64) + headerV6, _ = hex.DecodeString("6000000000226340") packetType = []byte{'a'} testKey = bytes.Repeat([]byte{'a'}, 32) @@ -107,20 +107,42 @@ func TestTSMPDiscoKeyAdvertisementMarshal(t *testing.T) { { name: "v4Header", tka: TSMPDiscoKeyAdvertisement{ - Src: srcV4, - Dst: dstV4, - Key: key.DiscoPublicFromRaw32(mem.B(testKey)), + Src: srcV4, + Dst: dstV4, + Key: key.DiscoPublicFromRaw32(mem.B(testKey)), + Request: false, }, - want: join(headerV4, srcV4.AsSlice(), dstV4.AsSlice(), packetType, testKey), + want: join(headerV4, srcV4.AsSlice(), dstV4.AsSlice(), packetType, testKey, []byte{0}), }, { name: "v6Header", tka: TSMPDiscoKeyAdvertisement{ - Src: srcV6, - Dst: dstV6, - Key: key.DiscoPublicFromRaw32(mem.B(testKey)), + Src: srcV6, + Dst: dstV6, + Key: key.DiscoPublicFromRaw32(mem.B(testKey)), + Request: false, }, - want: join(headerV6, srcV6.AsSlice(), dstV6.AsSlice(), packetType, testKey), + want: join(headerV6, srcV6.AsSlice(), dstV6.AsSlice(), packetType, testKey, []byte{0}), + }, + { + name: "v4Header_request", + tka: TSMPDiscoKeyAdvertisement{ + Src: srcV4, + Dst: dstV4, + Key: key.DiscoPublicFromRaw32(mem.B(testKey)), + Request: true, + }, + want: join(headerV4, srcV4.AsSlice(), dstV4.AsSlice(), packetType, testKey, []byte{1}), + }, + { + name: "v6Header_request", + tka: TSMPDiscoKeyAdvertisement{ + Src: srcV6, + Dst: dstV6, + Key: key.DiscoPublicFromRaw32(mem.B(testKey)), + Request: true, + }, + want: join(headerV6, srcV6.AsSlice(), dstV6.AsSlice(), packetType, testKey, []byte{1}), }, } diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 1b28eb157..afabd44a4 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -1181,6 +1181,7 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa t.discoKeyAdvertisementPub.Publish(events.DiscoKeyAdvertisement{ Src: discoKeyAdvert.Src, Key: discoKeyAdvert.Key, + Request: discoKeyAdvert.Request, }) } return filter.DropSilently, gro diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index 57b300513..bdb9821e4 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -974,6 +974,7 @@ func TestTSMPDisco(t *testing.T) { Src: src, Dst: dst, Key: discoKey.Public(), + Request: true, }).Marshal() var p packet.Parsed @@ -989,6 +990,9 @@ func TestTSMPDisco(t *testing.T) { if tda.Key.Compare(discoKey.Public()) != 0 { t.Errorf("Key did not match, expected %q, got %q", discoKey.Public(), tda.Key) } + if !tda.Request { + t.Errorf("Requested expected to be true, got false") + } }) } diff --git a/types/events/disco_update.go b/types/events/disco_update.go index 206c554a1..1b9b48bd9 100644 --- a/types/events/disco_update.go +++ b/types/events/disco_update.go @@ -17,8 +17,9 @@ // [controlclient.Direct], that injects the received key into the netmap as if // it was a netmap update from control. type DiscoKeyAdvertisement struct { - Src netip.Addr // Src field is populated by the IP header of the packet, not from the payload itself. - Key key.DiscoPublic + Src netip.Addr // Src field is populated by the IP header of the packet, not from the payload itself. + Key key.DiscoPublic + Request bool } // PeerDiscoKeyUpdate is an event sent on the [eventbus.Bus] when diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 6a2e9c39c..7c294214b 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -4317,6 +4317,7 @@ func (c *Conn) HandleDiscoKeyAdvertisement(node tailcfg.NodeView, update packet. type NewDiscoKeyAvailable struct { NodeFirstAddr netip.Addr NodeID tailcfg.NodeID + Request bool } // maybeSendTSMPDiscoAdvert conditionally emits an event indicating that we @@ -4340,6 +4341,7 @@ func (c *Conn) maybeSendTSMPDiscoAdvert(de *endpoint) { c.tsmpDiscoKeyAvailablePub.Publish(NewDiscoKeyAvailable{ NodeFirstAddr: de.nodeAddr, NodeID: de.nodeID, + Request: true, }) } } diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 5b81206d0..2cf2c62ab 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -621,9 +621,18 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) e.magicConn.HandleDiscoKeyAdvertisement(peer.Node, pkt) }) var tsmpRequestGroup singleflight.Group[netip.Addr, struct{}] + eventbus.SubscribeFunc(ec, func(update events.DiscoKeyAdvertisement) { + if update.Request { + go tsmpRequestGroup.Do(update.Src, func() (struct{}, error) { + e.sendTSMPDiscoAdvertisement(update.Src, false) + e.logf("wgengine: sending TSMP disco key advertisement to %v", update.Src) + return struct{}{}, nil + }) + } + }) eventbus.SubscribeFunc(ec, func(req magicsock.NewDiscoKeyAvailable) { go tsmpRequestGroup.Do(req.NodeFirstAddr, func() (struct{}, error) { - e.sendTSMPDiscoAdvertisement(req.NodeFirstAddr) + e.sendTSMPDiscoAdvertisement(req.NodeFirstAddr, req.Request) e.logf("wgengine: sending TSMP disco key advertisement to %v", req.NodeFirstAddr) return struct{}{}, nil }) @@ -1133,14 +1142,26 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, // If the key changed, mark the connection for reconfiguration. pub := p.PublicKey + if old, ok := prevEP[pub]; ok && old != p.DiscoKey { // If the disco key was learned via TSMP, we do not need to reset the // wireguard config as the new key was received over an existing wireguard // connection. - if discoTSMP, okTSMP := e.tsmpLearnedDisco[p.PublicKey]; okTSMP && - discoTSMP == p.DiscoKey { - delete(e.tsmpLearnedDisco, p.PublicKey) - e.logf("wgengine: Skipping reconfig (TSMP key): %s changed from %q to %q", pub.ShortString(), old, p.DiscoKey) + if discoTSMP, okTSMP := e.tsmpLearnedDisco[p.PublicKey]; okTSMP { + if discoTSMP == p.DiscoKey { + // Key matches, remove entry from map. + delete(e.tsmpLearnedDisco, p.PublicKey) + e.logf("wgengine: Skipping reconfig (TSMP key): %s changed from %q to %q", + pub.ShortString(), old, p.DiscoKey) + } else { + // Key does not match (we should never get here), log and let the + // entry be cleaned up in the node clean routine. Also makes it + // possible to test for this. + e.logf("wgengine: [unexpected] Reconfig: using TSMP key for %s (control stale): tsmp=%q control=%q old=%q", + pub.ShortString(), discoTSMP, p.DiscoKey, old) + p.DiscoKey = discoTSMP + } + // Skip session clear no matter what. continue } @@ -1562,7 +1583,7 @@ func (e *userspaceEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size in e.magicConn.Ping(peer, res, size, cb) case "TSMP": e.sendTSMPPing(ip, peer, res, cb) - e.sendTSMPDiscoAdvertisement(ip) + e.sendTSMPDiscoAdvertisement(ip, false) case "ICMP": e.sendICMPEchoRequest(ip, peer, res, cb) } @@ -1683,16 +1704,17 @@ func (e *userspaceEngine) sendTSMPPing(ip netip.Addr, peer tailcfg.NodeView, res e.tundev.InjectOutbound(tsmpPing) } -func (e *userspaceEngine) sendTSMPDiscoAdvertisement(ip netip.Addr) { +func (e *userspaceEngine) sendTSMPDiscoAdvertisement(ip netip.Addr, request bool) { srcIP, err := e.mySelfIPMatchingFamily(ip) if err != nil { e.logf("getting matching node: %s", err) return } tdka := packet.TSMPDiscoKeyAdvertisement{ - Src: srcIP, - Dst: ip, - Key: e.magicConn.DiscoPublicKey(), + Src: srcIP, + Dst: ip, + Key: e.magicConn.DiscoPublicKey(), + Request: request, } payload, err := tdka.Marshal() if err != nil { diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go index 210c528b3..82ae94b62 100644 --- a/wgengine/userspace_test.go +++ b/wgengine/userspace_test.go @@ -445,7 +445,7 @@ func TestTSMPKeyAdvertisement(t *testing.T) { addr := netip.MustParseAddr("100.100.99.1") previousValue := metricTSMPDiscoKeyAdvertisementSent.Value() - ue.sendTSMPDiscoAdvertisement(addr) + ue.sendTSMPDiscoAdvertisement(addr, false) if val := metricTSMPDiscoKeyAdvertisementSent.Value(); val <= previousValue { errs := metricTSMPDiscoKeyAdvertisementError.Value() t.Errorf("Expected 1 disco key advert, got %d, errors %d", val, errs)