diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index ea5af0897..23d525a69 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -144,6 +144,29 @@ func RegisterNewSSHServer(fn newSSHServerFunc) { newSSHServer = fn } +// listenSSHFunc, if non-nil, is the implementation of ListenSSH provided by the +// ssh/tailssh package via RegisterListenSSH. +var listenSSHFunc func(net.Listener, *LocalBackend, logger.Logf) (net.Listener, error) + +// RegisterListenSSH lets the conditionally linked ssh/tailssh package register +// an implementation of ListenSSH for use by tsnet. +func RegisterListenSSH(fn func(net.Listener, *LocalBackend, logger.Logf) (net.Listener, error)) { + listenSSHFunc = fn +} + +// ListenSSH wraps the given listener with an SSH server that authenticates +// connections using Tailscale peer identity. The returned listener's Accept +// yields net.Conn values that are *tailssh.Session. +// +// If the ssh/tailssh package has not been linked (e.g. via +// _ "tailscale.com/feature/ssh"), ListenSSH returns an error. +func (b *LocalBackend) ListenSSH(ln net.Listener, logf logger.Logf) (net.Listener, error) { + if listenSSHFunc == nil { + return nil, errors.New("SSH support not available; import _ \"tailscale.com/feature/ssh\"") + } + return listenSSHFunc(ln, b, logf) +} + // watchSession represents a WatchNotifications channel, // an [ipnauth.Actor] that owns it (e.g., a connected GUI/CLI), // and sessionID as required to close targeted buses. diff --git a/ssh/tailssh/listen.go b/ssh/tailssh/listen.go new file mode 100644 index 000000000..04fa14d87 --- /dev/null +++ b/ssh/tailssh/listen.go @@ -0,0 +1,136 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9 + +package tailssh + +import ( + "errors" + "net" + "net/netip" + "sync" + + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tempfork/gliderlabs/ssh" + "tailscale.com/types/logger" +) + +func init() { + ipnlocal.RegisterListenSSH(listenSSH) +} + +// listenSSH wraps rawLn with an SSH server that resolves Tailscale peer +// identity for each connection. The returned listener's Accept yields +// *Session values (as net.Conn). +func listenSSH(rawLn net.Listener, lb *ipnlocal.LocalBackend, logf logger.Logf) (net.Listener, error) { + hostKeys, err := getHostKeys(lb.TailscaleVarRoot(), logf) + if err != nil { + return nil, err + } + signers := make([]ssh.Signer, len(hostKeys)) + for i, k := range hostKeys { + signers[i] = k + } + + sl := &sshListener{ + rawLn: rawLn, + sessions: make(chan net.Conn, 16), + done: make(chan struct{}), + } + + sshSrv := &ssh.Server{ + HostSigners: signers, + Handler: func(sess ssh.Session) { + srcAddr := sess.RemoteAddr().String() + ipp, err := netip.ParseAddrPort(srcAddr) + if err != nil { + logf("listenSSH: bad remote addr %q: %v", srcAddr, err) + sess.Exit(1) + return + } + node, userProfile, ok := lb.WhoIs("tcp", ipp) + if !ok { + logf("listenSSH: WhoIs failed for %v", srcAddr) + sess.Exit(1) + return + } + + done := make(chan struct{}) + s := newSession(sess, PeerIdentity{ + Node: node, + UserProfile: userProfile, + }, done) + + // Send the session to the listener. If the listener is + // closed, drop the session. + select { + case sl.sessions <- s: + case <-sl.done: + sess.Exit(1) + return + } + + // Block until the consumer is done with the session. + select { + case <-done: + case <-sess.Context().Done(): + case <-sl.done: + } + }, + } + + go func() { + if err := sshSrv.Serve(rawLn); err != nil { + // Serve returns when the listener is closed. Only log + // unexpected errors. + select { + case <-sl.done: + default: + logf("listenSSH: Serve error: %v", err) + } + } + sl.Close() + }() + + return sl, nil +} + +// sshListener is a net.Listener that yields *Session values from its Accept +// method. It wraps a raw TCP listener with an SSH server. +type sshListener struct { + rawLn net.Listener + sessions chan net.Conn + done chan struct{} + closeOnce sync.Once +} + +// Accept returns the next SSH session as a net.Conn. The returned value can +// be type-asserted to *Session. +func (l *sshListener) Accept() (net.Conn, error) { + select { + case s, ok := <-l.sessions: + if !ok { + return nil, errors.New("listener closed") + } + return s, nil + case <-l.done: + return nil, errors.New("listener closed") + } +} + +// Close closes the underlying raw listener and signals all pending sessions +// to terminate. +func (l *sshListener) Close() error { + var err error + l.closeOnce.Do(func() { + close(l.done) + err = l.rawLn.Close() + }) + return err +} + +// Addr returns the address of the underlying raw listener. +func (l *sshListener) Addr() net.Addr { + return l.rawLn.Addr() +} diff --git a/ssh/tailssh/session.go b/ssh/tailssh/session.go new file mode 100644 index 000000000..d88df1455 --- /dev/null +++ b/ssh/tailssh/session.go @@ -0,0 +1,210 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9 + +package tailssh + +import ( + "context" + "errors" + "io" + "net" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/tempfork/gliderlabs/ssh" +) + +var errNoDeadline = errors.New("tailssh.Session: deadlines not supported") + +// Signal represents an SSH signal (e.g. "INT", "TERM"). +type Signal = ssh.Signal + +// Pty represents a PTY request and configuration. +type Pty struct { + // Term is the TERM environment variable value. + Term string + + // Window is the initial window size. + Window Window + + // Modes are the RFC 4254 terminal modes as opcode/value pairs. + Modes map[uint8]uint32 +} + +// Window represents the size of a PTY window. +type Window struct { + Width int + Height int + WidthPixels int + HeightPixels int +} + +// PeerIdentity contains the Tailscale identity of the connecting SSH peer. +type PeerIdentity struct { + Node tailcfg.NodeView + UserProfile tailcfg.UserProfile +} + +// Session wraps a gliderlabs ssh.Session with Tailscale peer identity +// information. It implements net.Conn so callers that only need Read/Write/Close +// can use it directly. Callers that need SSH-specific functionality can +// type-assert from the net.Conn returned by the listener's Accept. +type Session struct { + // sess is the underlying gliderlabs SSH session. + sess ssh.Session + + // peer is the Tailscale identity of the remote peer. + peer PeerIdentity + + // done is closed when the session handler should return, + // unblocking the gliderlabs handler goroutine. + done chan struct{} +} + +// newSession creates a new Session wrapping the given gliderlabs session and +// peer identity. The done channel is closed by the session consumer to signal +// that the handler goroutine may return. +func newSession(sess ssh.Session, peer PeerIdentity, done chan struct{}) *Session { + return &Session{ + sess: sess, + peer: peer, + done: done, + } +} + +// Read reads from the SSH channel (stdin from the client). +func (s *Session) Read(p []byte) (int, error) { + return s.sess.Read(p) +} + +// Write writes to the SSH channel (stdout to the client). +func (s *Session) Write(p []byte) (int, error) { + return s.sess.Write(p) +} + +// Close signals the session handler to return and closes the underlying channel. +func (s *Session) Close() error { + select { + case <-s.done: + default: + close(s.done) + } + return nil +} + +// RemoteAddr returns the net.Addr of the client side of the connection. +func (s *Session) RemoteAddr() net.Addr { + return s.sess.RemoteAddr() +} + +// LocalAddr returns the net.Addr of the server side of the connection. +func (s *Session) LocalAddr() net.Addr { + return s.sess.LocalAddr() +} + +// SetDeadline is not supported and returns an error. +func (s *Session) SetDeadline(t time.Time) error { + return errNoDeadline +} + +// SetReadDeadline is not supported and returns an error. +func (s *Session) SetReadDeadline(t time.Time) error { + return errNoDeadline +} + +// SetWriteDeadline is not supported and returns an error. +func (s *Session) SetWriteDeadline(t time.Time) error { + return errNoDeadline +} + +// User returns the SSH username. +func (s *Session) User() string { + return s.sess.User() +} + +// PeerIdentity returns the Tailscale identity of the remote peer. +func (s *Session) PeerIdentity() PeerIdentity { + return s.peer +} + +// Environ returns a copy of the environment variables set by the client. +func (s *Session) Environ() []string { + return s.sess.Environ() +} + +// RawCommand returns the exact command string provided by the client. +func (s *Session) RawCommand() string { + return s.sess.RawCommand() +} + +// Subsystem returns the subsystem requested by the client. +func (s *Session) Subsystem() string { + return s.sess.Subsystem() +} + +// Pty returns PTY information, a channel of window size changes, and whether +// a PTY was requested. The returned types use this package's Pty and Window +// types rather than the internal gliderlabs types. +func (s *Session) Pty() (Pty, <-chan Window, bool) { + gPty, gWinCh, ok := s.sess.Pty() + if !ok { + return Pty{}, nil, false + } + p := Pty{ + Term: gPty.Term, + Window: Window{ + Width: gPty.Window.Width, + Height: gPty.Window.Height, + WidthPixels: gPty.Window.WidthPixels, + HeightPixels: gPty.Window.HeightPixels, + }, + } + if gPty.Modes != nil { + p.Modes = make(map[uint8]uint32, len(gPty.Modes)) + for k, v := range gPty.Modes { + p.Modes[k] = v + } + } + + // Convert the gliderlabs Window channel to our Window type. + winCh := make(chan Window, 1) + go func() { + defer close(winCh) + for gw := range gWinCh { + winCh <- Window{ + Width: gw.Width, + Height: gw.Height, + WidthPixels: gw.WidthPixels, + HeightPixels: gw.HeightPixels, + } + } + }() + + return p, winCh, true +} + +// Signals registers a channel to receive signals from the client. +// Pass nil to unregister. +func (s *Session) Signals(c chan<- Signal) { + s.sess.Signals(c) +} + +// Exit sends an exit status to the client and closes the session. +func (s *Session) Exit(code int) error { + err := s.sess.Exit(code) + s.Close() + return err +} + +// Stderr returns an io.Writer for the SSH stderr channel. +func (s *Session) Stderr() io.Writer { + return s.sess.Stderr() +} + +// Context returns the session's context, which is canceled when the client +// disconnects. +func (s *Session) Context() context.Context { + return s.sess.Context() +} diff --git a/tsnet/example/ssh-game/ssh-game.go b/tsnet/example/ssh-game/ssh-game.go new file mode 100644 index 000000000..a1b07e50c --- /dev/null +++ b/tsnet/example/ssh-game/ssh-game.go @@ -0,0 +1,91 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// The ssh-game server demonstrates how to use tsnet's ListenSSH to build +// a custom SSH application. It runs a simple "guess the number" game. +// +// Usage: +// +// go run ./tsnet/example/ssh-game +// +// Then from another Tailscale node: +// +// ssh -p 2222 +package main + +import ( + "bufio" + "fmt" + "log" + "math/rand/v2" + "net" + "strings" + + _ "tailscale.com/feature/ssh" + "tailscale.com/ssh/tailssh" + "tailscale.com/tsnet" +) + +func main() { + s := &tsnet.Server{ + Hostname: "ssh-game", + } + defer s.Close() + + ln, err := s.ListenSSH(":2222") + if err != nil { + log.Fatal(err) + } + defer ln.Close() + log.Println("Listening on :2222") + + for { + conn, err := ln.Accept() + if err != nil { + log.Fatal(err) + } + go handleGame(conn) + } +} + +func handleGame(c net.Conn) { + sess, ok := c.(*tailssh.Session) + if !ok { + fmt.Fprintf(c, "unexpected connection type\n") + c.Close() + return + } + defer sess.Exit(0) + + peer := sess.PeerIdentity() + target := rand.IntN(100) + 1 + scanner := bufio.NewScanner(sess) + + fmt.Fprintf(sess, "Welcome, %s from %s!\r\n", + peer.UserProfile.LoginName, + peer.Node.ComputedName()) + fmt.Fprintf(sess, "I'm thinking of a number between 1 and 100.\r\n") + fmt.Fprintf(sess, "Can you guess it?\r\n\r\n") + + for attempts := 1; ; attempts++ { + fmt.Fprintf(sess, "Your guess: ") + if !scanner.Scan() { + return + } + line := strings.TrimSpace(scanner.Text()) + var guess int + if _, err := fmt.Sscanf(line, "%d", &guess); err != nil { + fmt.Fprintf(sess, "Please enter a number.\r\n") + continue + } + switch { + case guess < target: + fmt.Fprintf(sess, "Higher!\r\n") + case guess > target: + fmt.Fprintf(sess, "Lower!\r\n") + default: + fmt.Fprintf(sess, "Correct! You got it in %d attempts.\r\n", attempts) + return + } + } +} diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 4a116cf34..57cd5a004 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -1074,6 +1074,33 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) { return s.listen(network, addr, listenOnTailnet) } +// ListenSSH listens on the Tailscale network for SSH connections at the given +// addr (e.g. ":2222"). The returned listener's Accept method yields net.Conn +// values that are actually *tailssh.Session, providing access to the +// connecting peer's Tailscale identity, PTY information, signals, and more. +// +// Basic applications can use the returned connections as plain net.Conn +// (Read/Write/Close). Applications that need richer SSH semantics should +// type-assert to *tailssh.Session. +// +// SSH support must be linked into the binary by importing +// _ "tailscale.com/feature/ssh". Without that import, ListenSSH returns an +// error. +// +// If s has not been started yet, it will be started. +func (s *Server) ListenSSH(addr string) (net.Listener, error) { + rawLn, err := s.Listen("tcp", addr) + if err != nil { + return nil, err + } + sshLn, err := s.lb.ListenSSH(rawLn, s.logf) + if err != nil { + rawLn.Close() + return nil, err + } + return sshLn, nil +} + // ListenPacket announces on the Tailscale network. // // The network must be "udp", "udp4" or "udp6". The addr must be of the form