Files
tailscale/net/netutil/netutil_test.go
Achille Roussel 7f3bbc9865 net/netutil: add NewDefaultTransport to avoid http.DefaultTransport panics
Several packages built their HTTP transports with

    http.DefaultTransport.(*http.Transport).Clone()

The standard library only documents http.DefaultTransport as an
http.RoundTripper, so an application is free to replace it with a
RoundTripper that is not a *http.Transport (e.g. an instrumented or
tracing wrapper). When such an application embeds tsnet.Server, the
unchecked type assertion panics as soon as tsnet brings up its control
connection, DNS bootstrap, or log uploader.

Add netutil.NewDefaultTransport, which returns a clone of
http.DefaultTransport when it is still the standard *http.Transport
(preserving existing behavior) and otherwise returns a fresh transport
mirroring the stdlib defaults. Route every former clone site through it.

Updates #19937

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Achille Roussel <achille.roussel@gmail.com>
2026-06-01 12:28:36 -07:00

123 lines
3.0 KiB
Go

// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package netutil
import (
"io"
"net"
"net/http"
"runtime"
"testing"
"time"
)
// conn is a minimal net.Conn used by TestOneConnListener as an identity
// token: the test only compares it by pointer. Values built with new(conn)
// leave the embedded net.Conn nil, so calling any promoted Conn method on
// such a value would panic.
type conn struct{ net.Conn }
// TestOneConnListener exercises the listener returned by NewOneConnListener.
// The test checks four things:
//   - the first Accept yields the wrapped conn;
//   - a second Accept reports io.EOF with a nil Conn;
//   - Close before any Accept makes Accept report io.EOF;
//   - constructing it with a nil address still leaves Addr non-nil.
func TestOneConnListener(t *testing.T) {
	want := new(conn)
	addr := dummyAddr("a1")

	// Two Accepts: the conn is handed out once, then the listener is drained.
	single := NewOneConnListener(want, addr)
	if gotAddr := single.Addr(); gotAddr != addr {
		t.Errorf("Addr = %#v; want %#v", gotAddr, addr)
	}
	first, err := single.Accept()
	if err != nil {
		t.Fatal(err)
	}
	if first != want {
		t.Fatalf("didn't get c1; got %p", first)
	}
	second, err := single.Accept()
	if err != io.EOF {
		t.Errorf("got %v; want EOF", err)
	}
	if second != nil {
		t.Errorf("unexpected non-nil Conn")
	}

	// Close before Accept: nothing should be handed out.
	closed := NewOneConnListener(want, addr)
	closed.Close()
	if _, err := closed.Accept(); err != io.EOF {
		t.Fatalf("got %v; want EOF", err)
	}

	// Implicit addr: only non-nil-ness of the substituted address is checked.
	implicit := NewOneConnListener(want, nil)
	if implicit.Addr() == nil {
		t.Errorf("nil Addr")
	}
}
// roundTripperFunc adapts a plain function to the http.RoundTripper
// interface. Because it is deliberately not a *http.Transport, tests can
// install it as http.DefaultTransport to drive NewDefaultTransport down its
// fallback path.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by forwarding req to fn unchanged.
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return fn(req)
}
// TestNewDefaultTransport checks that NewDefaultTransport returns a usable
// *http.Transport carrying the stdlib defaults on both of its paths:
//
//   - the standard path, where http.DefaultTransport is still the stdlib
//     *http.Transport. The result must be a clone, not the global itself,
//     because callers mutate the transport they get back.
//   - the fallback path, where an application has replaced
//     http.DefaultTransport with a RoundTripper that is not a
//     *http.Transport. NewDefaultTransport must not panic and must not hand
//     the same transport to two callers.
//
// The test temporarily replaces the process-wide http.DefaultTransport, so
// it must not call t.Parallel.
func TestNewDefaultTransport(t *testing.T) {
	// checkDefaults asserts the fields both paths must share with the
	// stdlib's http.DefaultTransport; label prefixes each failure message.
	checkDefaults := func(label string, tr *http.Transport) {
		t.Helper()
		if tr == nil {
			t.Fatalf("%s: got nil transport", label)
		}
		if got, want := tr.MaxIdleConns, 100; got != want {
			t.Errorf("%s: MaxIdleConns = %d; want %d", label, got, want)
		}
		if got, want := tr.IdleConnTimeout, 90*time.Second; got != want {
			t.Errorf("%s: IdleConnTimeout = %v; want %v", label, got, want)
		}
		if got, want := tr.TLSHandshakeTimeout, 10*time.Second; got != want {
			t.Errorf("%s: TLSHandshakeTimeout = %v; want %v", label, got, want)
		}
		if got, want := tr.ExpectContinueTimeout, time.Second; got != want {
			t.Errorf("%s: ExpectContinueTimeout = %v; want %v", label, got, want)
		}
		if !tr.ForceAttemptHTTP2 {
			t.Errorf("%s: ForceAttemptHTTP2 = false; want true", label)
		}
		if tr.DialContext == nil {
			t.Errorf("%s: DialContext = nil; want non-nil", label)
		}
		if tr.Proxy == nil {
			t.Errorf("%s: Proxy = nil; want non-nil", label)
		}
	}

	// Standard case. Check the precondition with a comma-ok assertion. If
	// another test left a replacement installed, this fails loudly instead
	// of silently exercising the fallback path twice.
	orig := http.DefaultTransport
	if _, ok := orig.(*http.Transport); !ok {
		t.Fatalf("precondition: http.DefaultTransport is %T; want *http.Transport", orig)
	}
	tr := NewDefaultTransport()
	checkDefaults("standard", tr)
	if tr == http.DefaultTransport {
		t.Error("standard: got http.DefaultTransport itself; want a clone")
	}

	// Regression case: an application has replaced http.DefaultTransport
	// with a RoundTripper that is not a *http.Transport. t.Cleanup restores
	// the original even if a check below calls t.Fatal.
	t.Cleanup(func() { http.DefaultTransport = orig })
	http.DefaultTransport = roundTripperFunc(func(*http.Request) (*http.Response, error) {
		return nil, nil
	})
	tr = NewDefaultTransport()
	checkDefaults("fallback", tr)
	// Callers mutate the returned transport, so it must not be shared.
	if again := NewDefaultTransport(); again == tr {
		t.Error("fallback: two calls returned the same *http.Transport; want a fresh one per call")
	}
}
// TestIPForwardingEnabledLinux asserts two things for an interface name that
// does not exist: ipForwardingEnabledLinux reports false, and it returns no
// error. The check is Linux-only and is skipped elsewhere.
func TestIPForwardingEnabledLinux(t *testing.T) {
	if goos := runtime.GOOS; goos != "linux" {
		t.Skipf("skipping on %s", goos)
	}
	enabled, err := ipForwardingEnabledLinux(ipv4, "some-not-found-interface")
	switch {
	case err != nil:
		t.Fatal(err)
	case enabled:
		t.Errorf("got true; want false")
	}
}