diff --git a/net/netns/netns_android.go b/net/netns/netns_android.go index e747f61f4..55fdadbbc 100644 --- a/net/netns/netns_android.go +++ b/net/netns/netns_android.go @@ -10,6 +10,7 @@ "sync" "syscall" + "tailscale.com/envknob" "tailscale.com/net/netmon" "tailscale.com/types/logger" ) @@ -17,8 +18,14 @@ var ( androidProtectFuncMu sync.Mutex androidProtectFunc func(fd int) error + + androidBindToNetworkFuncMu sync.Mutex + androidBindToNetworkFunc func(fd int) error ) +// If enabled, we skip the bind-to-network hook and let the routing table decide. +var bindToInterfaceByRouteEnv = envknob.RegisterBool("TS_BIND_TO_INTERFACE_BY_ROUTE") + // UseSocketMark reports whether SO_MARK is in use. Android does not use SO_MARK. func UseSocketMark() bool { return false @@ -50,6 +57,16 @@ func SetAndroidProtectFunc(f func(fd int) error) { androidProtectFunc = f } +// SetAndroidBindToNetworkFunc registers a func that Android provides that binds +// the socket FD to the currently selected underlying etwork +// +// A nil func disables the hook. +func SetAndroidBindToNetworkFunc(f func(fd int) error) { + androidBindToNetworkFuncMu.Lock() + defer androidBindToNetworkFuncMu.Unlock() + androidBindToNetworkFunc = f +} + func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { return controlC } @@ -60,14 +77,42 @@ func control(logger.Logf, *netmon.Monitor) func(network, address string, c sysca // and net.ListenConfig.Control. func controlC(network, address string, c syscall.RawConn) error { var sockErr error + + // If route-based binding is enabled, we preserve the historical behavior: + // protect from VPN loops, but do NOT force-bind to a particular Network. + // This lets the OS routing table make per-destination decisions. + useRoute := bindToInterfaceByRoute.Load() || bindToInterfaceByRouteEnv() + err := c.Control(func(fd uintptr) { + fdInt := int(fd) + + // Protect from VPN loops androidProtectFuncMu.Lock() - f := androidProtectFunc + pf := androidProtectFunc androidProtectFuncMu.Unlock() - if f != nil { - sockErr = f(int(fd)) + if pf != nil { + if err := pf(fdInt); err != nil && sockErr == nil { + sockErr = err + return + } + } + + // Maybe bbind to currently active network + if useRoute { + return + } + + androidBindToNetworkFuncMu.Lock() + bf := androidBindToNetworkFunc + androidBindToNetworkFuncMu.Unlock() + if bf != nil { + if err := bf(fdInt); err != nil && sockErr == nil { + sockErr = err + return + } } }) + if err != nil { return fmt.Errorf("RawConn.Control on %T: %w", c, err) }