diff --git a/net/porttrack/porttrack.go b/net/porttrack/porttrack.go index 822e7200e..f71154f78 100644 --- a/net/porttrack/porttrack.go +++ b/net/porttrack/porttrack.go @@ -9,9 +9,9 @@ // // The magic address format is: // -// testport-report:HOST:PORT/LABEL +// testport-report-LABEL:PORT // -// where HOST:PORT is the collector's TCP address and LABEL identifies +// where localhost:PORT is the collector's TCP address and LABEL identifies // which listener this is (e.g. "main", "plaintext"). // // When [Listen] is called with a non-magic address, it falls through to @@ -31,17 +31,18 @@ "tailscale.com/util/testenv" ) -const magicPrefix = "testport-report:" +const magicPrefix = "testport-report-" // Collector is the parent/test side of the porttrack protocol. It // listens for port reports from child processes that used [Listen] // with a magic address obtained from [Collector.Addr]. type Collector struct { - ln net.Listener - mu sync.Mutex - cond *sync.Cond - ports map[string]int - err error // non-nil if a context passed to Port was cancelled + ln net.Listener + lnPort int + mu sync.Mutex + cond *sync.Cond + ports map[string]int + err error // non-nil if a context passed to Port was cancelled } // NewCollector creates a new Collector. The collector's TCP listener is @@ -53,8 +54,9 @@ func NewCollector(t testenv.TB) *Collector { t.Fatalf("porttrack.NewCollector: %v", err) } c := &Collector{ - ln: ln, - ports: make(map[string]int), + ln: ln, + lnPort: ln.Addr().(*net.TCPAddr).Port, + ports: make(map[string]int), } c.cond = sync.NewCond(&c.mu) go c.accept(t) @@ -100,7 +102,14 @@ func (c *Collector) handleConn(t testenv.TB, conn net.Conn) { // causes the child to bind to localhost:0 and report its actual port // back to this collector under the given label. func (c *Collector) Addr(label string) string { - return magicPrefix + c.ln.Addr().String() + "/" + label + for _, c := range label { + switch { + case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9', c == '-': + default: + panic(fmt.Sprintf("invalid label %q: only letters, digits, and hyphens are allowed", label)) + } + } + return fmt.Sprintf("%s%s:%d", magicPrefix, label, c.lnPort) } // Port blocks until the child process has reported the port for the @@ -145,13 +154,11 @@ func Listen(network, address string) (net.Listener, error) { return net.Listen(network, address) } - // rest is "HOST:PORT/LABEL" - slashIdx := strings.LastIndex(rest, "/") - if slashIdx < 0 { - return nil, fmt.Errorf("porttrack: malformed magic address %q: missing /LABEL", address) + // rest is LABEL:PORT. + label, collectorPort, ok := strings.Cut(rest, ":") + if !ok { + return nil, fmt.Errorf("porttrack: malformed magic address %q: missing :PORT", address) } - collectorAddr := rest[:slashIdx] - label := rest[slashIdx+1:] ln, err := net.Listen(network, "localhost:0") if err != nil { @@ -160,6 +167,7 @@ func Listen(network, address string) (net.Listener, error) { port := ln.Addr().(*net.TCPAddr).Port + collectorAddr := net.JoinHostPort("localhost", collectorPort) conn, err := net.Dial("tcp", collectorAddr) if err != nil { ln.Close()