diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 57733a66b..12c4cdd4d 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -144,6 +144,24 @@ func RegisterNewSSHServer(fn newSSHServerFunc) { newSSHServer = fn } +// HookListenSSH is set by the ssh/tailssh package (via feature/ssh) to provide +// an implementation of ListenSSH for use by tsnet. +var HookListenSSH feature.Hook[func(net.Listener, *LocalBackend, logger.Logf) (net.Listener, error)] + +// ListenSSH wraps the given listener with an SSH server that authenticates +// connections using Tailscale peer identity. The returned listener's Accept +// yields net.Conn values that are *tailssh.Session. +// +// If the ssh/tailssh package has not been linked (e.g. via +// _ "tailscale.com/feature/ssh"), ListenSSH returns an error. +func (b *LocalBackend) ListenSSH(ln net.Listener, logf logger.Logf) (net.Listener, error) { + fn, ok := HookListenSSH.GetOk() + if !ok { + return nil, errors.New("SSH support not available; import _ \"tailscale.com/feature/ssh\"") + } + return fn(ln, b, logf) +} + // watchSession represents a WatchNotifications channel, // an [ipnauth.Actor] that owns it (e.g., a connected GUI/CLI), // and sessionID as required to close targeted buses. diff --git a/ssh/tailssh/listen.go b/ssh/tailssh/listen.go new file mode 100644 index 000000000..8c98127f2 --- /dev/null +++ b/ssh/tailssh/listen.go @@ -0,0 +1,136 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9 + +package tailssh + +import ( + "errors" + "net" + "net/netip" + "sync" + + gliderssh "github.com/tailscale/gliderssh" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/types/logger" +) + +func init() { + ipnlocal.HookListenSSH.Set(listenSSH) +} + +// listenSSH wraps rawLn with an SSH server that resolves Tailscale peer +// identity for each connection. The returned listener's Accept yields +// *Session values (as net.Conn). +func listenSSH(rawLn net.Listener, lb *ipnlocal.LocalBackend, logf logger.Logf) (net.Listener, error) { + hostKeys, err := getHostKeys(lb.TailscaleVarRoot(), logf) + if err != nil { + return nil, err + } + signers := make([]gliderssh.Signer, len(hostKeys)) + for i, k := range hostKeys { + signers[i] = k + } + + sl := &sshListener{ + rawLn: rawLn, + sessions: make(chan net.Conn, 16), + done: make(chan struct{}), + } + + sshSrv := &gliderssh.Server{ + HostSigners: signers, + Handler: func(sess gliderssh.Session) { + srcAddr := sess.RemoteAddr().String() + ipp, err := netip.ParseAddrPort(srcAddr) + if err != nil { + logf("listenSSH: bad remote addr %q: %v", srcAddr, err) + sess.Exit(1) + return + } + node, userProfile, ok := lb.WhoIs("tcp", ipp) + if !ok { + logf("listenSSH: WhoIs failed for %v", srcAddr) + sess.Exit(1) + return + } + + done := make(chan struct{}) + s := newSession(sess, peerIdentity{ + Node: node, + UserProfile: userProfile, + }, done) + + // Send the session to the listener. If the listener is + // closed, drop the session. + select { + case sl.sessions <- s: + case <-sl.done: + sess.Exit(1) + return + } + + // Block until the consumer is done with the session. + select { + case <-done: + case <-sess.Context().Done(): + case <-sl.done: + } + }, + } + + go func() { + if err := sshSrv.Serve(rawLn); err != nil { + // Serve returns when the listener is closed. Only log + // unexpected errors. + select { + case <-sl.done: + default: + logf("listenSSH: Serve error: %v", err) + } + } + sl.Close() + }() + + return sl, nil +} + +// sshListener is a net.Listener that yields *Session values from its Accept +// method. It wraps a raw TCP listener with an SSH server. +type sshListener struct { + rawLn net.Listener + sessions chan net.Conn + done chan struct{} + closeOnce sync.Once +} + +// Accept returns the next SSH session as a net.Conn. The returned value can +// be type-asserted to *Session. +func (sl *sshListener) Accept() (net.Conn, error) { + select { + case s, ok := <-sl.sessions: + if !ok { + return nil, errors.New("listener closed") + } + return s, nil + case <-sl.done: + return nil, errors.New("listener closed") + } +} + +// Close closes the underlying raw listener and signals all pending sessions +// to terminate. +func (sl *sshListener) Close() error { + var err error + sl.closeOnce.Do(func() { + close(sl.done) + err = sl.rawLn.Close() + }) + return err +} + +// Addr returns the address of the underlying raw listener. +func (sl *sshListener) Addr() net.Addr { + return sl.rawLn.Addr() +} diff --git a/ssh/tailssh/session.go b/ssh/tailssh/session.go new file mode 100644 index 000000000..c19ae7bd4 --- /dev/null +++ b/ssh/tailssh/session.go @@ -0,0 +1,210 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9 + +package tailssh + +import ( + "context" + "errors" + "io" + "maps" + "net" + "sync" + "time" + + gliderssh "github.com/tailscale/gliderssh" + "tailscale.com/tailcfg" +) + +var errNoDeadline = errors.New("tailssh.Session: deadlines not supported") + +// Signal represents an SSH signal (e.g. "INT", "TERM"). +type Signal = gliderssh.Signal + +// Pty represents a PTY request and configuration. +type Pty struct { + // Term is the TERM environment variable value. + Term string + + // Window is the initial window size. + Window Window + + // Modes are the RFC 4254 terminal modes as opcode/value pairs. + Modes map[uint8]uint32 +} + +// Window represents the size of a PTY window. +type Window struct { + Width int // columns + Height int // rows + WidthPixels int // width in pixels + HeightPixels int // height in pixels +} + +// peerIdentity contains the Tailscale identity of the connecting SSH peer. +type peerIdentity struct { + Node tailcfg.NodeView // node connecting + UserProfile tailcfg.UserProfile +} + +// Session wraps a gliderlabs gliderssh.Session with Tailscale peer identity +// information. It implements net.Conn so callers that only need Read/Write/Close +// can use it directly. Callers that need SSH-specific functionality can +// type-assert from the net.Conn returned by the listener's Accept. +type Session struct { + // sess is the underlying gliderlabs SSH session. + sess gliderssh.Session + + // peer is the Tailscale identity of the remote peer. + peer peerIdentity + + // done is closed when the session handler should return, + // unblocking the gliderlabs handler goroutine. + doneOnce sync.Once // guards close(done) + done chan struct{} +} + +// newSession creates a new Session wrapping the given gliderlabs session and +// peer identity. The done channel is closed by the session consumer to signal +// that the handler goroutine may return. +func newSession(sess gliderssh.Session, peer peerIdentity, done chan struct{}) *Session { + return &Session{ + sess: sess, + peer: peer, + done: done, + } +} + +// Read reads from the SSH channel (stdin from the client). +func (s *Session) Read(p []byte) (int, error) { + return s.sess.Read(p) +} + +// Write writes to the SSH channel (stdout to the client). +func (s *Session) Write(p []byte) (int, error) { + return s.sess.Write(p) +} + +// Close signals the session handler to return and closes the underlying channel. +func (s *Session) Close() error { + s.doneOnce.Do(func() { close(s.done) }) + return nil +} + +// RemoteAddr returns the net.Addr of the client side of the connection. +func (s *Session) RemoteAddr() net.Addr { + return s.sess.RemoteAddr() +} + +// LocalAddr returns the net.Addr of the server side of the connection. +func (s *Session) LocalAddr() net.Addr { + return s.sess.LocalAddr() +} + +// SetDeadline is not supported and returns an error. +func (s *Session) SetDeadline(t time.Time) error { + return errNoDeadline +} + +// SetReadDeadline is not supported and returns an error. +func (s *Session) SetReadDeadline(t time.Time) error { + return errNoDeadline +} + +// SetWriteDeadline is not supported and returns an error. +func (s *Session) SetWriteDeadline(t time.Time) error { + return errNoDeadline +} + +// User returns the SSH username. +func (s *Session) User() string { + return s.sess.User() +} + +// Peer returns the Tailscale identity of the remote node. +func (s *Session) Peer() tailcfg.NodeView { return s.peer.Node } + +// UserProfile returns the Tailscale user profile of the remote node. +// +// For tagged nodes, this is same sort of UserProfile that is returned by the +// LocalAPI WhoIs API. +func (s *Session) UserProfile() tailcfg.UserProfile { return s.peer.UserProfile } + +// Environ returns a copy of the environment variables set by the client. +func (s *Session) Environ() []string { + return s.sess.Environ() +} + +// RawCommand returns the exact command string provided by the client. +func (s *Session) RawCommand() string { + return s.sess.RawCommand() +} + +// Subsystem returns the subsystem requested by the client. +func (s *Session) Subsystem() string { + return s.sess.Subsystem() +} + +// Pty returns PTY information, a channel of window size changes, and whether a +// PTY was requested. +func (s *Session) Pty() (_ Pty, _ <-chan Window, ok bool) { + gPty, gWinCh, ok := s.sess.Pty() + if !ok { + return Pty{}, nil, false + } + p := Pty{ + Term: gPty.Term, + Window: Window{ + Width: gPty.Window.Width, + Height: gPty.Window.Height, + WidthPixels: gPty.Window.WidthPixels, + HeightPixels: gPty.Window.HeightPixels, + }, + } + if gPty.Modes != nil { + p.Modes = make(map[uint8]uint32, len(gPty.Modes)) + maps.Copy(p.Modes, gPty.Modes) + } + + // Convert the gliderlabs Window channel to our Window type. + winCh := make(chan Window, 1) + go func() { + defer close(winCh) + for gw := range gWinCh { + winCh <- Window{ + Width: gw.Width, + Height: gw.Height, + WidthPixels: gw.WidthPixels, + HeightPixels: gw.HeightPixels, + } + } + }() + + return p, winCh, true +} + +// Signals registers a channel to receive signals from the client. +// Pass nil to unregister. +func (s *Session) Signals(c chan<- Signal) { + s.sess.Signals(c) +} + +// Exit sends an exit status to the client and closes the session. +func (s *Session) Exit(code int) error { + err := s.sess.Exit(code) + s.Close() + return err +} + +// Stderr returns an io.Writer for the SSH stderr channel. +func (s *Session) Stderr() io.Writer { + return s.sess.Stderr() +} + +// Context returns the session's context, which is canceled when the client +// disconnects. +func (s *Session) Context() context.Context { + return s.sess.Context() +} diff --git a/tsnet/example/ssh-game/README.md b/tsnet/example/ssh-game/README.md new file mode 100644 index 000000000..6fc0c598e --- /dev/null +++ b/tsnet/example/ssh-game/README.md @@ -0,0 +1,13 @@ + + +# ssh-game + +The ssh-game server demonstrates how to use tsnet's ListenSSH to build a custom SSH application. It runs a simple "guess the number" game. + +Usage: + + go run ./tsnet/example/ssh-game + +Then from another Tailscale node: + + ssh -p 2222 diff --git a/tsnet/example/ssh-game/ssh-game.go b/tsnet/example/ssh-game/ssh-game.go new file mode 100644 index 000000000..23d5d3d60 --- /dev/null +++ b/tsnet/example/ssh-game/ssh-game.go @@ -0,0 +1,92 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9 + +// The ssh-game server demonstrates how to use tsnet's ListenSSH to build +// a custom SSH application. It runs a simple "guess the number" game. +// +// Usage: +// +// go run ./tsnet/example/ssh-game +// +// Then from another Tailscale node: +// +// ssh -p 2222 +package main + +import ( + "bufio" + "fmt" + "log" + "math/rand/v2" + "net" + "strings" + + _ "tailscale.com/feature/ssh" + "tailscale.com/ssh/tailssh" + "tailscale.com/tsnet" +) + +func main() { + s := &tsnet.Server{ + Hostname: "ssh-game", + } + defer s.Close() + + ln, err := s.ListenSSH(":2222") + if err != nil { + log.Fatal(err) + } + defer ln.Close() + log.Println("Listening on :2222") + + for { + conn, err := ln.Accept() + if err != nil { + log.Fatal(err) + } + go handleGame(conn) + } +} + +func handleGame(c net.Conn) { + sess, ok := c.(*tailssh.Session) + if !ok { + fmt.Fprintf(c, "unexpected connection type\n") + c.Close() + return + } + defer sess.Exit(0) + + target := rand.IntN(100) + 1 + scanner := bufio.NewScanner(sess) + + fmt.Fprintf(sess, "Welcome, %s from %s!\r\n", + sess.UserProfile().LoginName, + sess.Peer().ComputedName()) + fmt.Fprintf(sess, "I'm thinking of a number between 1 and 100.\r\n") + fmt.Fprintf(sess, "Can you guess it?\r\n\r\n") + + for attempts := 1; ; attempts++ { + fmt.Fprintf(sess, "Your guess: ") + if !scanner.Scan() { + return + } + line := strings.TrimSpace(scanner.Text()) + var guess int + if _, err := fmt.Sscanf(line, "%d", &guess); err != nil { + fmt.Fprintf(sess, "Please enter a number.\r\n") + continue + } + switch { + case guess < target: + fmt.Fprintf(sess, "Higher!\r\n") + case guess > target: + fmt.Fprintf(sess, "Lower!\r\n") + default: + fmt.Fprintf(sess, "Correct! You got it in %d attempts.\r\n", attempts) + return + } + } +} diff --git a/tsnet/listenssh_test.go b/tsnet/listenssh_test.go new file mode 100644 index 000000000..989af1209 --- /dev/null +++ b/tsnet/listenssh_test.go @@ -0,0 +1,122 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9 + +package tsnet + +import ( + "context" + "errors" + "fmt" + "net" + "strings" + "testing" + "time" + + gossh "golang.org/x/crypto/ssh" + + _ "tailscale.com/feature/ssh" + "tailscale.com/ssh/tailssh" + "tailscale.com/tstest" +) + +// TestListenSSH starts two tsnet nodes on a test tailnet, has one listen +// for SSH via ListenSSH, and has the other connect using the Go +// x/crypto/ssh client. The server verifies the command string and echoes +// back the connecting peer's login name, verifying that WhoIs and +// Peer/UserProfile work end-to-end. +func TestListenSSH(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) + defer cancel() + + controlURL, _ := startControl(t) + srvNode, srvIP, _ := startServer(t, ctx, controlURL, "sshsrv") + clientNode, clientIP, _ := startServer(t, ctx, controlURL, "sshclient") + + // Listen for SSH on srvNode. + ln, err := srvNode.ListenSSH(":22") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { ln.Close() }) + + // Server goroutine: verify the command, then write the peer's login name back. + srvErrCh := make(chan error, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + srvErrCh <- err + return + } + sess := conn.(*tailssh.Session) + defer sess.Exit(0) + if got := sess.RawCommand(); got != "test-whoami" { + srvErrCh <- fmt.Errorf("server got command %q, want %q", got, "test-whoami") + return + } + fmt.Fprintf(sess, "%s\n", sess.UserProfile().LoginName) + srvErrCh <- nil + }() + + // Wait until srvNode knows about clientNode so WhoIs succeeds when the + // SSH connection arrives. + if err := tstest.WaitFor(30*time.Second, func() error { + lc, err := srvNode.LocalClient() + if err != nil { + return err + } + st, err := lc.Status(ctx) + if err != nil { + return err + } + for _, peer := range st.Peer { + for _, ip := range peer.TailscaleIPs { + if ip == clientIP { + return nil + } + } + } + return errors.New("clientNode not yet in srvNode's netmap") + }); err != nil { + t.Fatal(err) + } + + // Dial srvNode's SSH listener from clientNode's Tailscale network. + addr := net.JoinHostPort(srvIP.String(), "22") + tcpConn, err := clientNode.Dial(ctx, "tcp", addr) + if err != nil { + t.Fatal(err) + } + + // gliderssh defaults to NoClientAuth when no auth handler is registered, + // so no Auth methods are needed. + sshConn, chans, reqs, err := gossh.NewClientConn(tcpConn, addr, &gossh.ClientConfig{ + User: "test", + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }) + if err != nil { + t.Fatal(err) + } + sshClient := gossh.NewClient(sshConn, chans, reqs) + defer sshClient.Close() + + session, err := sshClient.NewSession() + if err != nil { + t.Fatal(err) + } + out, err := session.Output("test-whoami") + if err != nil { + t.Fatalf("session.Output: %v", err) + } + + loginName := strings.TrimSpace(string(out)) + if loginName == "" { + t.Error("SSH server returned empty login name; WhoIs or Peer/UserProfile may be broken") + } + t.Logf("peer login name from SSH server: %q", loginName) + + if err := <-srvErrCh; err != nil { + t.Errorf("SSH server goroutine: %v", err) + } +} diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index eb72d28d3..5bc0dd79c 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -1264,6 +1264,33 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) { return s.listen(network, addr, listenOnTailnet) } +// ListenSSH listens on the Tailscale network for SSH connections at the given +// addr (e.g. ":2222"). The returned listener's Accept method yields net.Conn +// values that are actually *tailssh.Session, providing access to the +// connecting peer's Tailscale identity, PTY information, signals, and more. +// +// Basic applications can use the returned connections as plain net.Conn +// (Read/Write/Close). Applications that need richer SSH semantics should +// type-assert to *tailssh.Session. +// +// SSH support must be linked into the binary by importing +// _ "tailscale.com/feature/ssh". Without that import, ListenSSH returns an +// error. +// +// If s has not been started yet, it will be started. +func (s *Server) ListenSSH(addr string) (net.Listener, error) { + rawLn, err := s.Listen("tcp", addr) + if err != nil { + return nil, err + } + sshLn, err := s.lb.ListenSSH(rawLn, s.logf) + if err != nil { + rawLn.Close() + return nil, err + } + return sshLn, nil +} + // ListenPacket announces on the Tailscale network. // // The network must be "udp", "udp4" or "udp6". The addr must be of the form