diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index c1183bf57..5104c8c8a 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -78,7 +78,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/derp/derpmap from tailscale.com/cmd/tailscaled+ tailscale.com/disco from tailscale.com/derp+ tailscale.com/health from tailscale.com/control/controlclient+ - tailscale.com/internal/deepprint from tailscale.com/ipn/ipnlocal+ + tailscale.com/internal/deephash from tailscale.com/ipn/ipnlocal+ tailscale.com/ipn from tailscale.com/ipn/ipnserver+ tailscale.com/ipn/ipnlocal from tailscale.com/ipn/ipnserver+ tailscale.com/ipn/ipnserver from tailscale.com/cmd/tailscaled diff --git a/go.mod b/go.mod index 29e42c2ed..3378debf4 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ require ( golang.zx2c4.com/wireguard/windows v0.1.2-0.20201113162609-9b85be97fdf8 gopkg.in/yaml.v2 v2.2.8 // indirect honnef.co/go/tools v0.1.0 - inet.af/netaddr v0.0.0-20210508014949-da1c2a70a83d + inet.af/netaddr v0.0.0-20210511181906-37180328850c inet.af/netstack v0.0.0-20210317161235-a1bf4e56ef22 inet.af/peercred v0.0.0-20210302202138-56e694897155 inet.af/wf v0.0.0-20210424212123-eaa011a774a4 diff --git a/go.sum b/go.sum index 062e51705..d7017de21 100644 --- a/go.sum +++ b/go.sum @@ -256,6 +256,8 @@ honnef.co/go/tools v0.1.0/go.mod h1:XtegFAyX/PfluP4921rXU5IkjkqBCDnUq4W8VCIoKvM= inet.af/netaddr v0.0.0-20210222205655-a1ec2b7b8c44/go.mod h1:I2i9ONCXRZDnG1+7O8fSuYzjcPxHQXrIfzD/IkR87x4= inet.af/netaddr v0.0.0-20210508014949-da1c2a70a83d h1:9tuJMxDV7THGfXWirKBD/v9rbsBC21bHd2eEYsYuIek= inet.af/netaddr v0.0.0-20210508014949-da1c2a70a83d/go.mod h1:z0nx+Dh+7N7CC8V5ayHtHGpZpxLQZZxkIaaz6HN65Ls= +inet.af/netaddr v0.0.0-20210511181906-37180328850c h1:rzDy/tC8LjEdN94+i0Bu22tTo/qE9cvhKyfD0HMU0NU= +inet.af/netaddr v0.0.0-20210511181906-37180328850c/go.mod h1:z0nx+Dh+7N7CC8V5ayHtHGpZpxLQZZxkIaaz6HN65Ls= inet.af/netstack v0.0.0-20210317161235-a1bf4e56ef22 h1:DNtszwGa6w76qlIr+PbPEnlBJdiRV8SaxeigOy0q1gg= inet.af/netstack v0.0.0-20210317161235-a1bf4e56ef22/go.mod h1:GVx+5OZtbG4TVOW5ilmyRZAZXr1cNwfqUEkTOtWK0PM= inet.af/peercred v0.0.0-20210302202138-56e694897155 h1:KojYNEYqDkZ2O3LdyTstR1l13L3ePKTIEM2h7ONkfkE= diff --git a/internal/deephash/deephash.go b/internal/deephash/deephash.go new file mode 100644 index 000000000..a82d04ba9 --- /dev/null +++ b/internal/deephash/deephash.go @@ -0,0 +1,174 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package deephash hashes a Go value recursively, in a predictable +// order, without looping. +package deephash + +import ( + "bufio" + "crypto/sha256" + "fmt" + "reflect" + + "inet.af/netaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/wgkey" +) + +func Hash(v ...interface{}) string { + h := sha256.New() + // 64 matches the chunk size in crypto/sha256/sha256.go + b := bufio.NewWriterSize(h, 64) + Print(b, v) + b.Flush() + return fmt.Sprintf("%x", h.Sum(nil)) +} + +// UpdateHash sets last to the hash of v and reports whether its value changed. +func UpdateHash(last *string, v ...interface{}) (changed bool) { + sig := Hash(v) + if *last != sig { + *last = sig + return true + } + return false +} + +func Print(w *bufio.Writer, v ...interface{}) { + print(w, reflect.ValueOf(v), make(map[uintptr]bool)) +} + +var ( + netaddrIPType = reflect.TypeOf(netaddr.IP{}) + netaddrIPPrefix = reflect.TypeOf(netaddr.IPPrefix{}) + wgkeyKeyType = reflect.TypeOf(wgkey.Key{}) + wgkeyPrivateType = reflect.TypeOf(wgkey.Private{}) + tailcfgDiscoKeyType = reflect.TypeOf(tailcfg.DiscoKey{}) +) + +func print(w *bufio.Writer, v reflect.Value, visited map[uintptr]bool) { + if !v.IsValid() { + return + } + + // Special case some common types. + if v.CanInterface() { + switch v.Type() { + case netaddrIPType: + var b []byte + var err error + if v.CanAddr() { + x := v.Addr().Interface().(*netaddr.IP) + b, err = x.MarshalText() + } else { + x := v.Interface().(netaddr.IP) + b, err = x.MarshalText() + } + if err == nil { + w.Write(b) + return + } + case netaddrIPPrefix: + var b []byte + var err error + if v.CanAddr() { + x := v.Addr().Interface().(*netaddr.IPPrefix) + b, err = x.MarshalText() + } else { + x := v.Interface().(netaddr.IPPrefix) + b, err = x.MarshalText() + } + if err == nil { + w.Write(b) + return + } + case wgkeyKeyType: + if v.CanAddr() { + x := v.Addr().Interface().(*wgkey.Key) + w.Write(x[:]) + } else { + x := v.Interface().(wgkey.Key) + w.Write(x[:]) + } + return + case wgkeyPrivateType: + if v.CanAddr() { + x := v.Addr().Interface().(*wgkey.Private) + w.Write(x[:]) + } else { + x := v.Interface().(wgkey.Private) + w.Write(x[:]) + } + return + case tailcfgDiscoKeyType: + if v.CanAddr() { + x := v.Addr().Interface().(*tailcfg.DiscoKey) + w.Write(x[:]) + } else { + x := v.Interface().(tailcfg.DiscoKey) + w.Write(x[:]) + } + return + } + } + + // Generic handling. + switch v.Kind() { + default: + panic(fmt.Sprintf("unhandled kind %v for type %v", v.Kind(), v.Type())) + case reflect.Ptr: + ptr := v.Pointer() + if visited[ptr] { + return + } + visited[ptr] = true + print(w, v.Elem(), visited) + return + case reflect.Struct: + w.WriteString("struct{\n") + for i, n := 0, v.NumField(); i < n; i++ { + fmt.Fprintf(w, " [%d]: ", i) + print(w, v.Field(i), visited) + w.WriteString("\n") + } + w.WriteString("}\n") + case reflect.Slice, reflect.Array: + if v.Type().Elem().Kind() == reflect.Uint8 && v.CanInterface() { + fmt.Fprintf(w, "%q", v.Interface()) + return + } + fmt.Fprintf(w, "[%d]{\n", v.Len()) + for i, ln := 0, v.Len(); i < ln; i++ { + fmt.Fprintf(w, " [%d]: ", i) + print(w, v.Index(i), visited) + w.WriteString("\n") + } + w.WriteString("}\n") + case reflect.Interface: + print(w, v.Elem(), visited) + case reflect.Map: + sm := newSortedMap(v) + fmt.Fprintf(w, "map[%d]{\n", len(sm.Key)) + for i, k := range sm.Key { + print(w, k, visited) + w.WriteString(": ") + print(w, sm.Value[i], visited) + w.WriteString("\n") + } + w.WriteString("}\n") + case reflect.String: + w.WriteString(v.String()) + case reflect.Bool: + fmt.Fprintf(w, "%v", v.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fmt.Fprintf(w, "%v", v.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + fmt.Fprintf(w, "%v", v.Uint()) + case reflect.Float32, reflect.Float64: + fmt.Fprintf(w, "%v", v.Float()) + case reflect.Complex64, reflect.Complex128: + fmt.Fprintf(w, "%v", v.Complex()) + } +} diff --git a/internal/deepprint/deepprint_test.go b/internal/deephash/deephash_test.go similarity index 55% rename from internal/deepprint/deepprint_test.go rename to internal/deephash/deephash_test.go index d7576f27a..909c69ec0 100644 --- a/internal/deepprint/deepprint_test.go +++ b/internal/deephash/deephash_test.go @@ -2,13 +2,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package deepprint +package deephash import ( - "bytes" "testing" "inet.af/netaddr" + "tailscale.com/tailcfg" + "tailscale.com/util/dnsname" "tailscale.com/wgengine/router" "tailscale.com/wgengine/wgcfg" ) @@ -18,10 +19,6 @@ func TestDeepPrint(t *testing.T) { // Mostly we're just testing that we don't panic on handled types. v := getVal() - var buf bytes.Buffer - Print(&buf, v) - t.Logf("Got: %s", buf.Bytes()) - hash1 := Hash(v) t.Logf("hash: %v", hash1) for i := 0; i < 20; i++ { @@ -39,7 +36,9 @@ func getVal() []interface{} { Addresses: []netaddr.IPPrefix{{Bits: 5, IP: netaddr.IPFrom16([16]byte{3: 3})}}, Peers: []wgcfg.Peer{ { - Endpoints: "foo:5", + Endpoints: wgcfg.Endpoints{ + IPPorts: wgcfg.NewIPPortSet(netaddr.MustParseIPPort("42.42.42.42:5")), + }, }, }, }, @@ -49,16 +48,25 @@ func getVal() []interface{} { netaddr.MustParseIPPrefix("1234::/64"), }, }, - map[string]string{ - "key1": "val1", - "key2": "val2", - "key3": "val3", - "key4": "val4", - "key5": "val5", - "key6": "val6", - "key7": "val7", - "key8": "val8", - "key9": "val9", + map[dnsname.FQDN][]netaddr.IP{ + dnsname.FQDN("a."): {netaddr.MustParseIP("1.2.3.4"), netaddr.MustParseIP("4.3.2.1")}, + dnsname.FQDN("b."): {netaddr.MustParseIP("8.8.8.8"), netaddr.MustParseIP("9.9.9.9")}, + }, + map[dnsname.FQDN][]netaddr.IPPort{ + dnsname.FQDN("a."): {netaddr.MustParseIPPort("1.2.3.4:11"), netaddr.MustParseIPPort("4.3.2.1:22")}, + dnsname.FQDN("b."): {netaddr.MustParseIPPort("8.8.8.8:11"), netaddr.MustParseIPPort("9.9.9.9:22")}, + }, + map[tailcfg.DiscoKey]bool{ + {1: 1}: true, + {1: 2}: false, }, } } + +func BenchmarkHash(b *testing.B) { + b.ReportAllocs() + v := getVal() + for i := 0; i < b.N; i++ { + Hash(v) + } +} diff --git a/internal/deepprint/fmtsort.go b/internal/deephash/fmtsort.go similarity index 99% rename from internal/deepprint/fmtsort.go rename to internal/deephash/fmtsort.go index 861679153..f4d3674f9 100644 --- a/internal/deepprint/fmtsort.go +++ b/internal/deephash/fmtsort.go @@ -10,7 +10,7 @@ // This is a slightly modified fork of Go's src/internal/fmtsort/sort.go -package deepprint +package deephash import ( "reflect" diff --git a/internal/deepprint/deepprint.go b/internal/deepprint/deepprint.go deleted file mode 100644 index 75088d553..000000000 --- a/internal/deepprint/deepprint.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package deepprint walks a Go value recursively, in a predictable -// order, without looping, and prints each value out to a given -// Writer, which is assumed to be a hash.Hash, as this package doesn't -// format things nicely. -// -// This is intended as a lighter version of go-spew, etc. We don't need its -// features when our writer is just a hash. -package deepprint - -import ( - "crypto/sha256" - "fmt" - "io" - "reflect" -) - -func Hash(v ...interface{}) string { - h := sha256.New() - Print(h, v) - return fmt.Sprintf("%x", h.Sum(nil)) -} - -// UpdateHash sets last to the hash of v and reports whether its value changed. -func UpdateHash(last *string, v ...interface{}) (changed bool) { - sig := Hash(v) - if *last != sig { - *last = sig - return true - } - return false -} - -func Print(w io.Writer, v ...interface{}) { - print(w, reflect.ValueOf(v), make(map[uintptr]bool)) -} - -func print(w io.Writer, v reflect.Value, visited map[uintptr]bool) { - if !v.IsValid() { - return - } - switch v.Kind() { - default: - panic(fmt.Sprintf("unhandled kind %v for type %v", v.Kind(), v.Type())) - case reflect.Ptr: - ptr := v.Pointer() - if visited[ptr] { - return - } - visited[ptr] = true - print(w, v.Elem(), visited) - return - case reflect.Struct: - fmt.Fprintf(w, "struct{\n") - t := v.Type() - for i, n := 0, v.NumField(); i < n; i++ { - sf := t.Field(i) - fmt.Fprintf(w, "%s: ", sf.Name) - print(w, v.Field(i), visited) - fmt.Fprintf(w, "\n") - } - case reflect.Slice, reflect.Array: - if v.Type().Elem().Kind() == reflect.Uint8 && v.CanInterface() { - fmt.Fprintf(w, "%q", v.Interface()) - return - } - fmt.Fprintf(w, "[%d]{\n", v.Len()) - for i, ln := 0, v.Len(); i < ln; i++ { - fmt.Fprintf(w, " [%d]: ", i) - print(w, v.Index(i), visited) - fmt.Fprintf(w, "\n") - } - fmt.Fprintf(w, "}\n") - case reflect.Interface: - print(w, v.Elem(), visited) - case reflect.Map: - sm := newSortedMap(v) - fmt.Fprintf(w, "map[%d]{\n", len(sm.Key)) - for i, k := range sm.Key { - print(w, k, visited) - fmt.Fprintf(w, ": ") - print(w, sm.Value[i], visited) - fmt.Fprintf(w, "\n") - } - fmt.Fprintf(w, "}\n") - - case reflect.String: - fmt.Fprintf(w, "%s", v.String()) - case reflect.Bool: - fmt.Fprintf(w, "%v", v.Bool()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fmt.Fprintf(w, "%v", v.Int()) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - fmt.Fprintf(w, "%v", v.Uint()) - case reflect.Float32, reflect.Float64: - fmt.Fprintf(w, "%v", v.Float()) - case reflect.Complex64, reflect.Complex128: - fmt.Fprintf(w, "%v", v.Complex()) - } -} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 22792ee8b..c31976c1c 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -30,7 +30,7 @@ "tailscale.com/client/tailscale/apitype" "tailscale.com/control/controlclient" "tailscale.com/health" - "tailscale.com/internal/deepprint" + "tailscale.com/internal/deephash" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/policy" @@ -905,7 +905,7 @@ func (b *LocalBackend) updateFilter(netMap *netmap.NetworkMap, prefs *ipn.Prefs) localNets := localNetsB.IPSet() logNets := logNetsB.IPSet() - changed := deepprint.UpdateHash(&b.filterHash, haveNetmap, addrs, packetFilter, localNets.Ranges(), logNets.Ranges(), shieldsUp) + changed := deephash.UpdateHash(&b.filterHash, haveNetmap, addrs, packetFilter, localNets.Ranges(), logNets.Ranges(), shieldsUp) if !changed { return } diff --git a/wgengine/bench/wg.go b/wgengine/bench/wg.go index 412799e15..21554e814 100644 --- a/wgengine/bench/wg.go +++ b/wgengine/bench/wg.go @@ -5,10 +5,10 @@ package main import ( + "errors" "io" "log" "os" - "strings" "sync" "testing" @@ -89,14 +89,23 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netadd var e1waitDoneOnce sync.Once e1.SetStatusCallback(func(st *wgengine.Status, err error) { + if errors.Is(err, wgengine.ErrEngineClosing) { + return + } if err != nil { log.Fatalf("e1 status err: %v", err) } logf("e1 status: %v", *st) var eps []string + var ipps []netaddr.IPPort for _, ep := range st.LocalAddrs { eps = append(eps, ep.Addr.String()) + ipps = append(ipps, ep.Addr) + } + endpoint := wgcfg.Endpoints{ + PublicKey: c1.PrivateKey.Public(), + IPPorts: wgcfg.NewIPPortSet(ipps...), } n := tailcfg.Node{ @@ -115,7 +124,7 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netadd p := wgcfg.Peer{ PublicKey: c1.PrivateKey.Public(), AllowedIPs: []netaddr.IPPrefix{a1}, - Endpoints: strings.Join(eps, ","), + Endpoints: endpoint, } c2.Peers = []wgcfg.Peer{p} e2.Reconfig(&c2, &router.Config{}, new(dns.Config)) @@ -124,14 +133,23 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netadd var e2waitDoneOnce sync.Once e2.SetStatusCallback(func(st *wgengine.Status, err error) { + if errors.Is(err, wgengine.ErrEngineClosing) { + return + } if err != nil { log.Fatalf("e2 status err: %v", err) } logf("e2 status: %v", *st) var eps []string + var ipps []netaddr.IPPort for _, ep := range st.LocalAddrs { eps = append(eps, ep.Addr.String()) + ipps = append(ipps, ep.Addr) + } + endpoint := wgcfg.Endpoints{ + PublicKey: c2.PrivateKey.Public(), + IPPorts: wgcfg.NewIPPortSet(ipps...), } n := tailcfg.Node{ @@ -150,7 +168,7 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netadd p := wgcfg.Peer{ PublicKey: c2.PrivateKey.Public(), AllowedIPs: []netaddr.IPPrefix{a2}, - Endpoints: strings.Join(eps, ","), + Endpoints: endpoint, } c1.Peers = []wgcfg.Peer{p} e1.Reconfig(&c1, &router.Config{}, new(dns.Config)) diff --git a/wgengine/magicsock/legacy.go b/wgengine/magicsock/legacy.go index c07166235..57b4765dd 100644 --- a/wgengine/magicsock/legacy.go +++ b/wgengine/magicsock/legacy.go @@ -10,7 +10,6 @@ "crypto/subtle" "encoding/binary" "errors" - "fmt" "hash" "net" "strings" @@ -27,6 +26,7 @@ "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/wgkey" + "tailscale.com/wgengine/wgcfg" ) var ( @@ -34,7 +34,11 @@ errDisabled = errors.New("magicsock: legacy networking disabled") ) -func (c *Conn) createLegacyEndpointLocked(pk key.Public, addrs string) (conn.Endpoint, error) { +// createLegacyEndpointLocked creates a new wireguard-go endpoint for a legacy connection. +// pk is the public key of the remote peer. addrs is the ordered set of addresses for the remote peer. +// rawDest is the encoded wireguard-go endpoint string. It should be treated as a black box. +// It is provided so that addrSet.DstToString can return it when requested by wireguard-go. +func (c *Conn) createLegacyEndpointLocked(pk key.Public, addrs wgcfg.IPPortSet, rawDest string) (conn.Endpoint, error) { if c.disableLegacy { return nil, errDisabled } @@ -43,18 +47,9 @@ func (c *Conn) createLegacyEndpointLocked(pk key.Public, addrs string) (conn.End Logf: c.logf, publicKey: pk, curAddr: -1, - rawdst: addrs, - } - - if addrs != "" { - for _, ep := range strings.Split(addrs, ",") { - ipp, err := netaddr.ParseIPPort(ep) - if err != nil { - return nil, fmt.Errorf("bogus address %q", ep) - } - a.ipPorts = append(a.ipPorts, ipp) - } + rawdst: rawDest, } + a.ipPorts = append(a.ipPorts, addrs.IPPorts()...) // If this endpoint is being updated, remember its old set of // endpoints so we can remove any (from c.addrsByUDP) that are diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 2a88c1d64..4291e13d9 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -11,6 +11,7 @@ "context" crand "crypto/rand" "encoding/binary" + "encoding/json" "errors" "fmt" "hash/fnv" @@ -27,7 +28,6 @@ "time" "github.com/tailscale/wireguard-go/conn" - "go4.org/mem" "golang.org/x/crypto/nacl/box" "golang.org/x/time/rate" "inet.af/netaddr" @@ -2736,42 +2736,30 @@ func packIPPort(ua netaddr.IPPort) []byte { } // ParseEndpoint is called by WireGuard to connect to an endpoint. -// -// keyAddrs is the 32 byte public key of the peer followed by addrs. -// Addrs is either: -// -// 1) a comma-separated list of UDP ip:ports (the peer doesn't have a discovery key) -// 2) ".disco.tailscale:12345", a magic value that means the peer -// is running code that supports active discovery, so ParseEndpoint returns -// a discoEndpoint. -func (c *Conn) ParseEndpoint(keyAddrs string) (conn.Endpoint, error) { - if len(keyAddrs) < 32 { - c.logf("[unexpected] ParseEndpoint keyAddrs too short: %q", keyAddrs) - return nil, errors.New("endpoint string too short") +// endpointStr is a json-serialized wgcfg.Endpoints struct. +// If those Endpoints contain an active discovery key, ParseEndpoint returns a discoEndpoint. +// Otherwise it returns a legacy endpoint. +func (c *Conn) ParseEndpoint(endpointStr string) (conn.Endpoint, error) { + var endpoints wgcfg.Endpoints + err := json.Unmarshal([]byte(endpointStr), &endpoints) + if err != nil { + return nil, fmt.Errorf("magicsock: ParseEndpoint: json.Unmarshal failed on %q: %w", endpointStr, err) } - var pk key.Public - copy(pk[:], keyAddrs) - addrs := keyAddrs[len(pk):] + pk := key.Public(endpoints.PublicKey) + discoKey := endpoints.DiscoKey + c.logf("magicsock: ParseEndpoint: key=%s: disco=%s ipps=%s", pk.ShortString(), discoKey.ShortString(), derpStr(endpoints.IPPorts.String())) + c.mu.Lock() defer c.mu.Unlock() - - c.logf("magicsock: ParseEndpoint: key=%s: %s", pk.ShortString(), derpStr(addrs)) - - if !strings.HasSuffix(addrs, wgcfg.EndpointDiscoSuffix) { - return c.createLegacyEndpointLocked(pk, addrs) - } - - discoHex := strings.TrimSuffix(addrs, wgcfg.EndpointDiscoSuffix) - discoKey, err := key.NewPublicFromHexMem(mem.S(discoHex)) - if err != nil { - return nil, fmt.Errorf("magicsock: invalid discokey endpoint %q for %v: %w", addrs, pk.ShortString(), err) + if discoKey.IsZero() { + return c.createLegacyEndpointLocked(pk, endpoints.IPPorts, endpointStr) } de := &discoEndpoint{ c: c, publicKey: tailcfg.NodeKey(pk), // peer public key (for WireGuard + DERP) discoKey: tailcfg.DiscoKey(discoKey), // for discovery mesages discoShort: tailcfg.DiscoKey(discoKey).ShortString(), - wgEndpoint: addrs, + wgEndpoint: endpointStr, sentPing: map[stun.TxID]sentPing{}, endpointState: map[netaddr.IPPort]*endpointState{}, } @@ -3115,7 +3103,7 @@ type discoEndpoint struct { discoKey tailcfg.DiscoKey // for discovery mesages discoShort string // ShortString of discoKey fakeWGAddr netaddr.IPPort // the UDP address we tell wireguard-go we're using - wgEndpoint string // string from ParseEndpoint: ".disco.tailscale:12345" + wgEndpoint string // string from ParseEndpoint, holds a JSON-serialized wgcfg.Endpoints // Owned by Conn.mu: lastPingFrom netaddr.IPPort diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index a98d48a45..8545a39dc 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -10,6 +10,7 @@ crand "crypto/rand" "crypto/tls" "encoding/binary" + "encoding/json" "errors" "fmt" "io/ioutil" @@ -167,7 +168,7 @@ func newMagicStack(t testing.TB, logf logger.Logf, l nettype.PacketListener, der tsTun.SetFilter(filter.NewAllowAllForTest(logf)) wgLogger := wglog.NewLogger(logf) - dev := device.NewDevice(tsTun, conn.Bind(), wgLogger.DeviceLogger, new(device.DeviceOptions)) + dev := device.NewDevice(tsTun, conn.Bind(), wgLogger.DeviceLogger) dev.Up() // Wait for magicsock to connect up to DERP. @@ -468,10 +469,14 @@ func makeConfigs(t *testing.T, addrs []netaddr.IPPort) []wgcfg.Config { if peerNum == i { continue } + publicKey := privKeys[peerNum].Public() peer := wgcfg.Peer{ - PublicKey: privKeys[peerNum].Public(), - AllowedIPs: addresses[peerNum], - Endpoints: addr.String(), + PublicKey: publicKey, + AllowedIPs: addresses[peerNum], + Endpoints: wgcfg.Endpoints{ + PublicKey: publicKey, + IPPorts: wgcfg.NewIPPortSet(addr), + }, PersistentKeepalive: 25, } cfg.Peers = append(cfg.Peers, peer) @@ -504,7 +509,7 @@ func TestDeviceStartStop(t *testing.T) { tun := tuntest.NewChannelTUN() wgLogger := wglog.NewLogger(t.Logf) - dev := device.NewDevice(tun.TUN(), conn.Bind(), wgLogger.DeviceLogger, new(device.DeviceOptions)) + dev := device.NewDevice(tun.TUN(), conn.Bind(), wgLogger.DeviceLogger) dev.Up() dev.Close() } @@ -1242,6 +1247,19 @@ func newNonLegacyTestConn(t testing.TB) *Conn { return conn } +func makeEndpoint(tb testing.TB, public tailcfg.NodeKey, disco tailcfg.DiscoKey) string { + tb.Helper() + ep := wgcfg.Endpoints{ + PublicKey: wgkey.Key(public), + DiscoKey: disco, + } + buf, err := json.Marshal(ep) + if err != nil { + tb.Fatal(err) + } + return string(buf) +} + // addTestEndpoint sets conn's network map to a single peer expected // to receive packets from sendConn (or DERP), and returns that peer's // nodekey and discokey. @@ -1261,7 +1279,7 @@ func addTestEndpoint(tb testing.TB, conn *Conn, sendConn net.PacketConn) (tailcf }, }) conn.SetPrivateKey(wgkey.Private{0: 1}) - _, err := conn.ParseEndpoint(string(nodeKey[:]) + "0000000000000000000000000000000000000000000000000000000000000001.disco.tailscale:12345") + _, err := conn.ParseEndpoint(makeEndpoint(tb, nodeKey, discoKey)) if err != nil { tb.Fatal(err) } @@ -1435,7 +1453,7 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) { }, }, }) - _, err := conn.ParseEndpoint(string(nodeKey1[:]) + "0000000000000000000000000000000000000000000000000000000000000001.disco.tailscale:12345") + _, err := conn.ParseEndpoint(makeEndpoint(t, nodeKey1, discoKey)) if err != nil { t.Fatal(err) } diff --git a/wgengine/userspace.go b/wgengine/userspace.go index d90cb8533..6e56bbd11 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -11,7 +11,6 @@ "errors" "fmt" "io" - "net" "os" "reflect" "runtime" @@ -27,7 +26,7 @@ "inet.af/netaddr" "tailscale.com/control/controlclient" "tailscale.com/health" - "tailscale.com/internal/deepprint" + "tailscale.com/internal/deephash" "tailscale.com/ipn/ipnstate" "tailscale.com/net/dns" "tailscale.com/net/dns/resolver" @@ -307,8 +306,6 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) } e.wgLogger = wglog.NewLogger(logf) - opts := &device.DeviceOptions{} - e.tundev.OnTSMPPongReceived = func(pong packet.TSMPPongReply) { e.mu.Lock() defer e.mu.Unlock() @@ -321,7 +318,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) // wgdev takes ownership of tundev, will close it when closed. e.logf("Creating wireguard device...") - e.wgdev = device.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger, opts) + e.wgdev = device.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger) closePool.addFunc(e.wgdev.Close) closePool.addFunc(func() { if err := e.magicConn.Close(); err != nil { @@ -502,15 +499,7 @@ func isTrimmablePeer(p *wgcfg.Peer, numPeers int) bool { if forceFullWireguardConfig(numPeers) { return false } - if !isSingleEndpoint(p.Endpoints) { - return false - } - - host, _, err := net.SplitHostPort(p.Endpoints) - if err != nil { - return false - } - if !strings.HasSuffix(host, ".disco.tailscale") { + if p.Endpoints.DiscoKey.IsZero() { return false } @@ -580,26 +569,6 @@ func (e *userspaceEngine) isActiveSince(dk tailcfg.DiscoKey, ip netaddr.IP, t ti return unixTime >= t.Unix() } -// discoKeyFromPeer returns the DiscoKey for a wireguard config's Peer. -// -// Invariant: isTrimmablePeer(p) == true, so it should have 1 endpoint with -// Host of form "<64-hex-digits>.disco.tailscale". If invariant is violated, -// we return the zero value. -func discoKeyFromPeer(p *wgcfg.Peer) tailcfg.DiscoKey { - if len(p.Endpoints) < 64 { - return tailcfg.DiscoKey{} - } - host, rest := p.Endpoints[:64], p.Endpoints[64:] - if !strings.HasPrefix(rest, ".disco.tailscale") { - return tailcfg.DiscoKey{} - } - k, err := key.NewPublicFromHexMem(mem.S(host)) - if err != nil { - return tailcfg.DiscoKey{} - } - return tailcfg.DiscoKey(k) -} - // discoChanged are the set of peers whose disco keys have changed, implying they've restarted. // If a peer is in this set and was previously in the live wireguard config, // it needs to be first removed and then re-added to flush out its wireguard session key. @@ -647,7 +616,7 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Publ } continue } - dk := discoKeyFromPeer(p) + dk := p.Endpoints.DiscoKey trackDisco = append(trackDisco, dk) recentlyActive := false for _, cidr := range p.AllowedIPs { @@ -664,7 +633,7 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Publ } } - if !deepprint.UpdateHash(&e.lastEngineSigTrim, min, trimmedDisco, trackDisco, trackIPs) { + if !deephash.UpdateHash(&e.lastEngineSigTrim, min, trimmedDisco, trackDisco, trackIPs) { // No changes return nil } @@ -785,8 +754,8 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, } e.mu.Unlock() - engineChanged := deepprint.UpdateHash(&e.lastEngineSigFull, cfg) - routerChanged := deepprint.UpdateHash(&e.lastRouterSig, routerCfg, dnsCfg) + engineChanged := deephash.UpdateHash(&e.lastEngineSigFull, cfg) + routerChanged := deephash.UpdateHash(&e.lastRouterSig, routerCfg, dnsCfg) if !engineChanged && !routerChanged { return ErrNoChanges } @@ -797,19 +766,19 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, // and a second time with it. discoChanged := make(map[key.Public]bool) { - prevEP := make(map[key.Public]string) + prevEP := make(map[key.Public]tailcfg.DiscoKey) for i := range e.lastCfgFull.Peers { - if p := &e.lastCfgFull.Peers[i]; isSingleEndpoint(p.Endpoints) { - prevEP[key.Public(p.PublicKey)] = p.Endpoints + if p := &e.lastCfgFull.Peers[i]; !p.Endpoints.DiscoKey.IsZero() { + prevEP[key.Public(p.PublicKey)] = p.Endpoints.DiscoKey } } for i := range cfg.Peers { p := &cfg.Peers[i] - if !isSingleEndpoint(p.Endpoints) { + if p.Endpoints.DiscoKey.IsZero() { continue } pub := key.Public(p.PublicKey) - if old, ok := prevEP[pub]; ok && old != p.Endpoints { + if old, ok := prevEP[pub]; ok && old != p.Endpoints.DiscoKey { discoChanged[pub] = true e.logf("wgengine: Reconfig: %s changed from %q to %q", pub.ShortString(), old, p.Endpoints) } @@ -853,11 +822,6 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, return nil } -// isSingleEndpoint reports whether endpoints contains exactly one host:port pair. -func isSingleEndpoint(s string) bool { - return s != "" && !strings.Contains(s, ",") -} - func (e *userspaceEngine) GetFilter() *filter.Filter { return e.tundev.GetFilter() } @@ -880,6 +844,8 @@ func (e *userspaceEngine) getStatusCallback() StatusCallback { var singleNewline = []byte{'\n'} +var ErrEngineClosing = errors.New("engine closing; no status") + func (e *userspaceEngine) getStatus() (*Status, error) { // Grab derpConns before acquiring wgLock to not violate lock ordering; // the DERPs method acquires magicsock.Conn.mu. @@ -893,7 +859,7 @@ func (e *userspaceEngine) getStatus() (*Status, error) { closing := e.closing e.mu.Unlock() if closing { - return nil, errors.New("engine closing; no status") + return nil, ErrEngineClosing } if e.wgdev == nil { diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go index 34a331233..ea5f3ef08 100644 --- a/wgengine/userspace_test.go +++ b/wgengine/userspace_test.go @@ -104,7 +104,7 @@ func TestUserspaceEngineReconfig(t *testing.T) { AllowedIPs: []netaddr.IPPrefix{ {IP: netaddr.IPv4(100, 100, 99, 1), Bits: 32}, }, - Endpoints: discoHex + ".disco.tailscale:12345", + Endpoints: wgcfg.Endpoints{DiscoKey: dkFromHex(discoHex)}, }, }, } diff --git a/wgengine/wgcfg/clone.go b/wgengine/wgcfg/clone.go index 1b35a5d95..a29e0262d 100644 --- a/wgengine/wgcfg/clone.go +++ b/wgengine/wgcfg/clone.go @@ -2,12 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Code generated by tailscale.com/cmd/cloner -type Config,Peer; DO NOT EDIT. +// Code generated by tailscale.com/cmd/cloner -type Config,Peer,Endpoints,IPPortSet; DO NOT EDIT. package wgcfg import ( "inet.af/netaddr" + "tailscale.com/tailcfg" "tailscale.com/types/wgkey" ) @@ -29,7 +30,7 @@ func (src *Config) Clone() *Config { } // A compilation failure here means this code must be regenerated, with command: -// tailscale.com/cmd/cloner -type Config,Peer +// tailscale.com/cmd/cloner -type Config,Peer,Endpoints,IPPortSet var _ConfigNeedsRegeneration = Config(struct { Name string PrivateKey wgkey.Private @@ -48,14 +49,53 @@ func (src *Peer) Clone() *Peer { dst := new(Peer) *dst = *src dst.AllowedIPs = append(src.AllowedIPs[:0:0], src.AllowedIPs...) + dst.Endpoints = *src.Endpoints.Clone() return dst } // A compilation failure here means this code must be regenerated, with command: -// tailscale.com/cmd/cloner -type Config,Peer +// tailscale.com/cmd/cloner -type Config,Peer,Endpoints,IPPortSet var _PeerNeedsRegeneration = Peer(struct { PublicKey wgkey.Key AllowedIPs []netaddr.IPPrefix - Endpoints string + Endpoints Endpoints PersistentKeepalive uint16 }{}) + +// Clone makes a deep copy of Endpoints. +// The result aliases no memory with the original. +func (src *Endpoints) Clone() *Endpoints { + if src == nil { + return nil + } + dst := new(Endpoints) + *dst = *src + dst.IPPorts = *src.IPPorts.Clone() + return dst +} + +// A compilation failure here means this code must be regenerated, with command: +// tailscale.com/cmd/cloner -type Config,Peer,Endpoints,IPPortSet +var _EndpointsNeedsRegeneration = Endpoints(struct { + PublicKey wgkey.Key + DiscoKey tailcfg.DiscoKey + IPPorts IPPortSet +}{}) + +// Clone makes a deep copy of IPPortSet. +// The result aliases no memory with the original. +func (src *IPPortSet) Clone() *IPPortSet { + if src == nil { + return nil + } + dst := new(IPPortSet) + *dst = *src + dst.ipp = append(src.ipp[:0:0], src.ipp...) + return dst +} + +// A compilation failure here means this code must be regenerated, with command: +// tailscale.com/cmd/cloner -type Config,Peer,Endpoints,IPPortSet +var _IPPortSetNeedsRegeneration = IPPortSet(struct { + ipp []netaddr.IPPort +}{}) diff --git a/wgengine/wgcfg/config.go b/wgengine/wgcfg/config.go index 3abaf48cc..31dc4b238 100644 --- a/wgengine/wgcfg/config.go +++ b/wgengine/wgcfg/config.go @@ -6,16 +6,15 @@ package wgcfg import ( + "encoding/json" + "strings" + "inet.af/netaddr" + "tailscale.com/tailcfg" "tailscale.com/types/wgkey" ) -//go:generate go run tailscale.com/cmd/cloner -type=Config,Peer -output=clone.go - -// EndpointDiscoSuffix is appended to the hex representation of a peer's discovery key -// and is then the sole wireguard endpoint for peers with a non-zero discovery key. -// This form is then recognize by magicsock's ParseEndpoint. -const EndpointDiscoSuffix = ".disco.tailscale:12345" +//go:generate go run tailscale.com/cmd/cloner -type=Config,Peer,Endpoints,IPPortSet -output=clone.go // Config is a WireGuard configuration. // It only supports the set of things Tailscale uses. @@ -31,10 +30,92 @@ type Config struct { type Peer struct { PublicKey wgkey.Key AllowedIPs []netaddr.IPPrefix - Endpoints string // comma-separated host/port pairs: "1.2.3.4:56,[::]:80" + Endpoints Endpoints PersistentKeepalive uint16 } +// Endpoints represents the routes to reach a remote node. +// It is serialized and provided to wireguard-go as a conn.Endpoint. +type Endpoints struct { + // PublicKey is the public key for the remote node. + PublicKey wgkey.Key `json:"pk"` + // DiscoKey is the disco key associated with the remote node. + DiscoKey tailcfg.DiscoKey `json:"dk,omitempty"` + // IPPorts is a set of possible ip+ports the remote node can be reached at. + // This is used only for legacy connections to pre-disco (pre-0.100) peers. + IPPorts IPPortSet `json:"ipp,omitempty"` +} + +func (e Endpoints) Equal(f Endpoints) bool { + if e.PublicKey != f.PublicKey { + return false + } + if e.DiscoKey != f.DiscoKey { + return false + } + return e.IPPorts.EqualUnordered(f.IPPorts) +} + +// IPPortSet is an immutable slice of netaddr.IPPorts. +type IPPortSet struct { + ipp []netaddr.IPPort +} + +// NewIPPortSet returns an IPPortSet containing the ports in ipp. +func NewIPPortSet(ipps ...netaddr.IPPort) IPPortSet { + return IPPortSet{ipp: append(ipps[:0:0], ipps...)} +} + +// String returns a comma-separated list of all IPPorts in s. +func (s IPPortSet) String() string { + buf := new(strings.Builder) + for i, ipp := range s.ipp { + if i > 0 { + buf.WriteByte(',') + } + buf.WriteString(ipp.String()) + } + return buf.String() +} + +// IPPorts returns a slice of netaddr.IPPorts containing the IPPorts in s. +func (s IPPortSet) IPPorts() []netaddr.IPPort { + return append(s.ipp[:0:0], s.ipp...) +} + +// EqualUnordered reports whether s and t contain the same IPPorts, regardless of order. +func (s IPPortSet) EqualUnordered(t IPPortSet) bool { + if len(s.ipp) != len(t.ipp) { + return false + } + // Check whether the endpoints are the same, regardless of order. + ipps := make(map[netaddr.IPPort]int, len(s.ipp)) + for _, ipp := range s.ipp { + ipps[ipp]++ + } + for _, ipp := range t.ipp { + ipps[ipp]-- + } + for _, n := range ipps { + if n != 0 { + return false + } + } + return true +} + +// MarshalJSON marshals s into JSON. +// It is necessary so that IPPortSet's fields can be unexported, to guarantee immutability. +func (s IPPortSet) MarshalJSON() ([]byte, error) { + return json.Marshal(s.ipp) +} + +// UnmarshalJSON unmarshals s from JSON. +// It is necessary so that IPPortSet's fields can be unexported, to guarantee immutability. +func (s *IPPortSet) UnmarshalJSON(b []byte) error { + return json.Unmarshal(b, &s.ipp) +} + // PeerWithKey returns the Peer with key k and reports whether it was found. func (config Config) PeerWithKey(k wgkey.Key) (Peer, bool) { for _, p := range config.Peers { diff --git a/wgengine/wgcfg/device_test.go b/wgengine/wgcfg/device_test.go index 25be7f8ae..c3b8ffba4 100644 --- a/wgengine/wgcfg/device_test.go +++ b/wgengine/wgcfg/device_test.go @@ -10,6 +10,7 @@ "io" "net" "os" + "reflect" "sort" "strings" "sync" @@ -90,7 +91,7 @@ func TestDeviceConfig(t *testing.T) { t.Errorf("on error, could not IpcGetOperation: %v", err) } w.Flush() - t.Errorf("cfg:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) + t.Errorf("config mismatch:\n---- got:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) } } @@ -127,7 +128,7 @@ func TestDeviceConfig(t *testing.T) { }) t.Run("device1 modify peer", func(t *testing.T) { - cfg1.Peers[0].Endpoints = "1.2.3.4:12345" + cfg1.Peers[0].Endpoints.IPPorts = NewIPPortSet(netaddr.MustParseIPPort("1.2.3.4:12345")) if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { t.Fatal(err) } @@ -135,7 +136,7 @@ func TestDeviceConfig(t *testing.T) { }) t.Run("device1 replace endpoint", func(t *testing.T) { - cfg1.Peers[0].Endpoints = "1.1.1.1:123" + cfg1.Peers[0].Endpoints.IPPorts = NewIPPortSet(netaddr.MustParseIPPort("1.1.1.1:123")) if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { t.Fatal(err) } @@ -176,7 +177,7 @@ func TestDeviceConfig(t *testing.T) { } peersEqual := func(p, q Peer) bool { return p.PublicKey == q.PublicKey && p.PersistentKeepalive == q.PersistentKeepalive && - p.Endpoints == q.Endpoints && cidrsEqual(p.AllowedIPs, q.AllowedIPs) + reflect.DeepEqual(p.Endpoints, q.Endpoints) && cidrsEqual(p.AllowedIPs, q.AllowedIPs) } if !peersEqual(peer0(origCfg), peer0(newCfg)) { t.Error("reconfig modified old peer") diff --git a/wgengine/wgcfg/nmcfg/nmcfg.go b/wgengine/wgcfg/nmcfg/nmcfg.go index 8fe1c062a..938bd6826 100644 --- a/wgengine/wgcfg/nmcfg/nmcfg.go +++ b/wgengine/wgcfg/nmcfg/nmcfg.go @@ -8,8 +8,6 @@ import ( "bytes" "fmt" - "net" - "strconv" "strings" "inet.af/netaddr" @@ -79,17 +77,19 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, cpeer.PersistentKeepalive = 25 // seconds } - if !peer.DiscoKey.IsZero() { - cpeer.Endpoints = fmt.Sprintf("%x.disco.tailscale:12345", peer.DiscoKey[:]) - } else { - if err := appendEndpoint(cpeer, peer.DERP); err != nil { + cpeer.Endpoints = wgcfg.Endpoints{PublicKey: wgkey.Key(peer.Key), DiscoKey: peer.DiscoKey} + if peer.DiscoKey.IsZero() { + // Legacy connection. Add IP+port endpoints. + var ipps []netaddr.IPPort + if err := appendEndpoint(cpeer, &ipps, peer.DERP); err != nil { return nil, err } for _, ep := range peer.Endpoints { - if err := appendEndpoint(cpeer, ep); err != nil { + if err := appendEndpoint(cpeer, &ipps, ep); err != nil { return nil, err } } + cpeer.Endpoints.IPPorts = wgcfg.NewIPPortSet(ipps...) } didExitNodeWarn := false for _, allowedIP := range peer.AllowedIPs { @@ -136,21 +136,14 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, return cfg, nil } -func appendEndpoint(peer *wgcfg.Peer, epStr string) error { +func appendEndpoint(peer *wgcfg.Peer, ipps *[]netaddr.IPPort, epStr string) error { if epStr == "" { return nil } - _, port, err := net.SplitHostPort(epStr) + ipp, err := netaddr.ParseIPPort(epStr) if err != nil { return fmt.Errorf("malformed endpoint %q for peer %v", epStr, peer.PublicKey.ShortString()) } - _, err = strconv.ParseUint(port, 10, 16) - if err != nil { - return fmt.Errorf("invalid port in endpoint %q for peer %v", epStr, peer.PublicKey.ShortString()) - } - if peer.Endpoints != "" { - peer.Endpoints += "," - } - peer.Endpoints += epStr + *ipps = append(*ipps, ipp) return nil } diff --git a/wgengine/wgcfg/parser.go b/wgengine/wgcfg/parser.go index 318728a6b..55dc95012 100644 --- a/wgengine/wgcfg/parser.go +++ b/wgengine/wgcfg/parser.go @@ -7,6 +7,7 @@ import ( "bufio" "encoding/hex" + "encoding/json" "fmt" "io" "net" @@ -26,21 +27,6 @@ func (e *ParseError) Error() string { return fmt.Sprintf("%s: %q", e.why, e.offender) } -func validateEndpoints(s string) error { - if s == "" { - // Otherwise strings.Split of the empty string produces [""]. - return nil - } - vals := strings.Split(s, ",") - for _, val := range vals { - _, _, err := parseEndpoint(val) - if err != nil { - return err - } - } - return nil -} - func parseEndpoint(s string) (host string, port uint16, err error) { i := strings.LastIndexByte(s, ':') if i < 0 { @@ -103,6 +89,7 @@ func FromUAPI(r io.Reader) (*Config, error) { } key := parts[0] value := parts[1] + valueBytes := scanner.Bytes()[len(key)+1:] if key == "public_key" { if deviceConfig { @@ -121,7 +108,7 @@ func FromUAPI(r io.Reader) (*Config, error) { if deviceConfig { err = cfg.handleDeviceLine(key, value) } else { - err = cfg.handlePeerLine(peer, key, value) + err = cfg.handlePeerLine(peer, key, value, valueBytes) } if err != nil { return nil, err @@ -165,14 +152,13 @@ func (cfg *Config) handlePublicKeyLine(value string) (*Peer, error) { return peer, nil } -func (cfg *Config) handlePeerLine(peer *Peer, key, value string) error { +func (cfg *Config) handlePeerLine(peer *Peer, key, value string, valueBytes []byte) error { switch key { case "endpoint": - err := validateEndpoints(value) + err := json.Unmarshal(valueBytes, &peer.Endpoints) if err != nil { return err } - peer.Endpoints = value case "persistent_keepalive_interval": n, err := strconv.ParseUint(value, 10, 16) if err != nil { diff --git a/wgengine/wgcfg/parser_test.go b/wgengine/wgcfg/parser_test.go index 9d4fe1992..e101a3a05 100644 --- a/wgengine/wgcfg/parser_test.go +++ b/wgengine/wgcfg/parser_test.go @@ -53,21 +53,3 @@ func TestParseEndpoint(t *testing.T) { t.Error("Error was expected") } } - -func TestValidateEndpoints(t *testing.T) { - tests := []struct { - in string - want error - }{ - {"", nil}, - {"1.2.3.4:5", nil}, - {"1.2.3.4:5,6.7.8.9:10", nil}, - {",", &ParseError{why: "Missing port from endpoint", offender: ""}}, - } - for _, tt := range tests { - got := validateEndpoints(tt.in) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("%q = %#v (%s); want %#v (%s)", tt.in, got, got, tt.want, tt.want) - } - } -} diff --git a/wgengine/wgcfg/writer.go b/wgengine/wgcfg/writer.go index 9e3462c38..c8d09a52a 100644 --- a/wgengine/wgcfg/writer.go +++ b/wgengine/wgcfg/writer.go @@ -5,11 +5,10 @@ package wgcfg import ( + "encoding/json" "fmt" "io" - "sort" "strconv" - "strings" "inet.af/netaddr" "tailscale.com/types/wgkey" @@ -53,8 +52,12 @@ func (cfg *Config) ToUAPI(w io.Writer, prev *Config) error { setPeer(p) set("protocol_version", "1") - if !endpointsEqual(oldPeer.Endpoints, p.Endpoints) { - set("endpoint", p.Endpoints) + if !oldPeer.Endpoints.Equal(p.Endpoints) { + buf, err := json.Marshal(p.Endpoints) + if err != nil { + return err + } + set("endpoint", string(buf)) } // TODO: replace_allowed_ips is expensive. @@ -90,24 +93,6 @@ func (cfg *Config) ToUAPI(w io.Writer, prev *Config) error { return stickyErr } -func endpointsEqual(x, y string) bool { - // Cheap comparisons. - if x == y { - return true - } - xs := strings.Split(x, ",") - ys := strings.Split(y, ",") - if len(xs) != len(ys) { - return false - } - // Otherwise, see if they're the same, but out of order. - sort.Strings(xs) - sort.Strings(ys) - x = strings.Join(xs, ",") - y = strings.Join(ys, ",") - return x == y -} - func cidrsEqual(x, y []netaddr.IPPrefix) bool { // TODO: re-implement using netaddr.IPSet.Equal. if len(x) != len(y) { diff --git a/wgengine/wglog/wglog.go b/wgengine/wglog/wglog.go index 9e7cc5482..4a9c5a6a9 100644 --- a/wgengine/wglog/wglog.go +++ b/wgengine/wglog/wglog.go @@ -85,7 +85,7 @@ func (x *Logger) SetPeers(peers []wgcfg.Peer) { // Construct a new peer public key log rewriter. replace := make(map[string]string) for _, peer := range peers { - old := "peer(" + wireguardGoString(peer.PublicKey) + ")" + old := wireguardGoString(peer.PublicKey) new := peer.PublicKey.ShortString() replace[old] = new } @@ -94,10 +94,17 @@ func (x *Logger) SetPeers(peers []wgcfg.Peer) { // wireguardGoString prints p in the same format used by wireguard-go. func wireguardGoString(k wgkey.Key) string { - base64Key := base64.StdEncoding.EncodeToString(k[:]) - abbreviatedKey := "invalid" - if len(base64Key) == 44 { - abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43] - } - return abbreviatedKey + const prefix = "peer(" + b := make([]byte, len(prefix)+44) + copy(b, prefix) + r := b[len(prefix):] + base64.StdEncoding.Encode(r, k[:]) + r = r[4:] + copy(r, "…") + r = r[len("…"):] + copy(r, b[len(prefix)+39:len(prefix)+43]) + r = r[4:] + r[0] = ')' + r = r[1:] + return string(b[:len(b)-len(r)]) } diff --git a/wgengine/wglog/wglog_test.go b/wgengine/wglog/wglog_test.go index b804b5959..f2aad667a 100644 --- a/wgengine/wglog/wglog_test.go +++ b/wgengine/wglog/wglog_test.go @@ -8,6 +8,7 @@ "fmt" "testing" + "tailscale.com/types/logger" "tailscale.com/types/wgkey" "tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wglog" @@ -70,3 +71,30 @@ func stringer(s string) stringerString { type stringerString string func (s stringerString) String() string { return string(s) } + +func BenchmarkSetPeers(b *testing.B) { + b.ReportAllocs() + x := wglog.NewLogger(logger.Discard) + peers := [][]wgcfg.Peer{genPeers(0), genPeers(15), genPeers(16), genPeers(15)} + for i := 0; i < b.N; i++ { + for _, p := range peers { + x.SetPeers(p) + } + } +} + +func genPeers(n int) []wgcfg.Peer { + if n > 32 { + panic("too many peers") + } + if n == 0 { + return nil + } + peers := make([]wgcfg.Peer, n) + for i := range peers { + var k wgkey.Key + k[n] = byte(n) + peers[i].PublicKey = k + } + return peers +}