tsnet: add opt-in SSH support (Server.ListenSSH)

This adds tsnet.Server.ListenSSH which, if the SSH feature is linked,
returns a net.Listener whose Accept yields *tailssh.Session values (as
net.Conn). This lets tsnet apps accept incoming SSH connections to
implement custom TUI applications.

Basic apps can use net.Conn directly (Read/Write/Close). Rich apps
import ssh/tailssh and type-assert for peer identity, PTY, signals,
etc. If feature/ssh isn't imported, ListenSSH returns an error.

Includes a demo guess-the-number game in tsnet/example/ssh-game.

Updates tailscale/corp#37839

Change-Id: I4e7c3c96afb030cdf4da8f2d8b2253820628129a
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2026-03-10 21:33:12 +00:00
committed by Brad Fitzpatrick
parent c9333854fb
commit 3e34e721e8
7 changed files with 618 additions and 0 deletions

View File

@@ -144,6 +144,24 @@ func RegisterNewSSHServer(fn newSSHServerFunc) {
newSSHServer = fn
}
// HookListenSSH is set by the ssh/tailssh package (via feature/ssh) to provide
// an implementation of ListenSSH for use by tsnet.
var HookListenSSH feature.Hook[func(net.Listener, *LocalBackend, logger.Logf) (net.Listener, error)]
// ListenSSH wraps the given listener with an SSH server that authenticates
// connections using Tailscale peer identity. The returned listener's Accept
// yields net.Conn values that are *tailssh.Session.
//
// If the ssh/tailssh package has not been linked (e.g. via
// _ "tailscale.com/feature/ssh"), ListenSSH returns an error.
func (b *LocalBackend) ListenSSH(ln net.Listener, logf logger.Logf) (net.Listener, error) {
fn, ok := HookListenSSH.GetOk()
if !ok {
return nil, errors.New("SSH support not available; import _ \"tailscale.com/feature/ssh\"")
}
return fn(ln, b, logf)
}
// watchSession represents a WatchNotifications channel,
// an [ipnauth.Actor] that owns it (e.g., a connected GUI/CLI),
// and sessionID as required to close targeted buses.

136
ssh/tailssh/listen.go Normal file
View File

@@ -0,0 +1,136 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9
package tailssh
import (
"errors"
"net"
"net/netip"
"sync"
gliderssh "github.com/tailscale/gliderssh"
"tailscale.com/ipn/ipnlocal"
"tailscale.com/types/logger"
)
func init() {
ipnlocal.HookListenSSH.Set(listenSSH)
}
// listenSSH wraps rawLn with an SSH server that resolves Tailscale peer
// identity for each connection. The returned listener's Accept yields
// *Session values (as net.Conn).
func listenSSH(rawLn net.Listener, lb *ipnlocal.LocalBackend, logf logger.Logf) (net.Listener, error) {
hostKeys, err := getHostKeys(lb.TailscaleVarRoot(), logf)
if err != nil {
return nil, err
}
signers := make([]gliderssh.Signer, len(hostKeys))
for i, k := range hostKeys {
signers[i] = k
}
sl := &sshListener{
rawLn: rawLn,
sessions: make(chan net.Conn, 16),
done: make(chan struct{}),
}
sshSrv := &gliderssh.Server{
HostSigners: signers,
Handler: func(sess gliderssh.Session) {
srcAddr := sess.RemoteAddr().String()
ipp, err := netip.ParseAddrPort(srcAddr)
if err != nil {
logf("listenSSH: bad remote addr %q: %v", srcAddr, err)
sess.Exit(1)
return
}
node, userProfile, ok := lb.WhoIs("tcp", ipp)
if !ok {
logf("listenSSH: WhoIs failed for %v", srcAddr)
sess.Exit(1)
return
}
done := make(chan struct{})
s := newSession(sess, peerIdentity{
Node: node,
UserProfile: userProfile,
}, done)
// Send the session to the listener. If the listener is
// closed, drop the session.
select {
case sl.sessions <- s:
case <-sl.done:
sess.Exit(1)
return
}
// Block until the consumer is done with the session.
select {
case <-done:
case <-sess.Context().Done():
case <-sl.done:
}
},
}
go func() {
if err := sshSrv.Serve(rawLn); err != nil {
// Serve returns when the listener is closed. Only log
// unexpected errors.
select {
case <-sl.done:
default:
logf("listenSSH: Serve error: %v", err)
}
}
sl.Close()
}()
return sl, nil
}
// sshListener is a net.Listener that yields *Session values from its Accept
// method. It wraps a raw TCP listener with an SSH server.
type sshListener struct {
rawLn net.Listener
sessions chan net.Conn
done chan struct{}
closeOnce sync.Once
}
// Accept returns the next SSH session as a net.Conn. The returned value can
// be type-asserted to *Session.
func (sl *sshListener) Accept() (net.Conn, error) {
select {
case s, ok := <-sl.sessions:
if !ok {
return nil, errors.New("listener closed")
}
return s, nil
case <-sl.done:
return nil, errors.New("listener closed")
}
}
// Close closes the underlying raw listener and signals all pending sessions
// to terminate.
func (sl *sshListener) Close() error {
var err error
sl.closeOnce.Do(func() {
close(sl.done)
err = sl.rawLn.Close()
})
return err
}
// Addr returns the address of the underlying raw listener.
func (sl *sshListener) Addr() net.Addr {
return sl.rawLn.Addr()
}

210
ssh/tailssh/session.go Normal file
View File

@@ -0,0 +1,210 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9
package tailssh
import (
"context"
"errors"
"io"
"maps"
"net"
"sync"
"time"
gliderssh "github.com/tailscale/gliderssh"
"tailscale.com/tailcfg"
)
var errNoDeadline = errors.New("tailssh.Session: deadlines not supported")
// Signal represents an SSH signal (e.g. "INT", "TERM").
type Signal = gliderssh.Signal
// Pty represents a PTY request and configuration.
type Pty struct {
// Term is the TERM environment variable value.
Term string
// Window is the initial window size.
Window Window
// Modes are the RFC 4254 terminal modes as opcode/value pairs.
Modes map[uint8]uint32
}
// Window represents the size of a PTY window.
type Window struct {
Width int // columns
Height int // rows
WidthPixels int // width in pixels
HeightPixels int // height in pixels
}
// peerIdentity contains the Tailscale identity of the connecting SSH peer.
type peerIdentity struct {
Node tailcfg.NodeView // node connecting
UserProfile tailcfg.UserProfile
}
// Session wraps a gliderlabs gliderssh.Session with Tailscale peer identity
// information. It implements net.Conn so callers that only need Read/Write/Close
// can use it directly. Callers that need SSH-specific functionality can
// type-assert from the net.Conn returned by the listener's Accept.
type Session struct {
// sess is the underlying gliderlabs SSH session.
sess gliderssh.Session
// peer is the Tailscale identity of the remote peer.
peer peerIdentity
// done is closed when the session handler should return,
// unblocking the gliderlabs handler goroutine.
doneOnce sync.Once // guards close(done)
done chan struct{}
}
// newSession creates a new Session wrapping the given gliderlabs session and
// peer identity. The done channel is closed by the session consumer to signal
// that the handler goroutine may return.
func newSession(sess gliderssh.Session, peer peerIdentity, done chan struct{}) *Session {
return &Session{
sess: sess,
peer: peer,
done: done,
}
}
// Read reads from the SSH channel (stdin from the client).
func (s *Session) Read(p []byte) (int, error) {
return s.sess.Read(p)
}
// Write writes to the SSH channel (stdout to the client).
func (s *Session) Write(p []byte) (int, error) {
return s.sess.Write(p)
}
// Close signals the session handler to return and closes the underlying channel.
func (s *Session) Close() error {
s.doneOnce.Do(func() { close(s.done) })
return nil
}
// RemoteAddr returns the net.Addr of the client side of the connection.
func (s *Session) RemoteAddr() net.Addr {
return s.sess.RemoteAddr()
}
// LocalAddr returns the net.Addr of the server side of the connection.
func (s *Session) LocalAddr() net.Addr {
return s.sess.LocalAddr()
}
// SetDeadline is not supported and returns an error.
func (s *Session) SetDeadline(t time.Time) error {
return errNoDeadline
}
// SetReadDeadline is not supported and returns an error.
func (s *Session) SetReadDeadline(t time.Time) error {
return errNoDeadline
}
// SetWriteDeadline is not supported and returns an error.
func (s *Session) SetWriteDeadline(t time.Time) error {
return errNoDeadline
}
// User returns the SSH username.
func (s *Session) User() string {
return s.sess.User()
}
// Peer returns the Tailscale identity of the remote node.
func (s *Session) Peer() tailcfg.NodeView { return s.peer.Node }
// UserProfile returns the Tailscale user profile of the remote node.
//
// For tagged nodes, this is same sort of UserProfile that is returned by the
// LocalAPI WhoIs API.
func (s *Session) UserProfile() tailcfg.UserProfile { return s.peer.UserProfile }
// Environ returns a copy of the environment variables set by the client.
func (s *Session) Environ() []string {
return s.sess.Environ()
}
// RawCommand returns the exact command string provided by the client.
func (s *Session) RawCommand() string {
return s.sess.RawCommand()
}
// Subsystem returns the subsystem requested by the client.
func (s *Session) Subsystem() string {
return s.sess.Subsystem()
}
// Pty returns PTY information, a channel of window size changes, and whether a
// PTY was requested.
func (s *Session) Pty() (_ Pty, _ <-chan Window, ok bool) {
gPty, gWinCh, ok := s.sess.Pty()
if !ok {
return Pty{}, nil, false
}
p := Pty{
Term: gPty.Term,
Window: Window{
Width: gPty.Window.Width,
Height: gPty.Window.Height,
WidthPixels: gPty.Window.WidthPixels,
HeightPixels: gPty.Window.HeightPixels,
},
}
if gPty.Modes != nil {
p.Modes = make(map[uint8]uint32, len(gPty.Modes))
maps.Copy(p.Modes, gPty.Modes)
}
// Convert the gliderlabs Window channel to our Window type.
winCh := make(chan Window, 1)
go func() {
defer close(winCh)
for gw := range gWinCh {
winCh <- Window{
Width: gw.Width,
Height: gw.Height,
WidthPixels: gw.WidthPixels,
HeightPixels: gw.HeightPixels,
}
}
}()
return p, winCh, true
}
// Signals registers a channel to receive signals from the client.
// Pass nil to unregister.
func (s *Session) Signals(c chan<- Signal) {
s.sess.Signals(c)
}
// Exit sends an exit status to the client and closes the session.
func (s *Session) Exit(code int) error {
err := s.sess.Exit(code)
s.Close()
return err
}
// Stderr returns an io.Writer for the SSH stderr channel.
func (s *Session) Stderr() io.Writer {
return s.sess.Stderr()
}
// Context returns the session's context, which is canceled when the client
// disconnects.
func (s *Session) Context() context.Context {
return s.sess.Context()
}

View File

@@ -0,0 +1,13 @@
<!-- README.md auto-generated by misc/genreadme; DO NOT EDIT. (or remove this line) -->
# ssh-game
The ssh-game server demonstrates how to use tsnet's ListenSSH to build a custom SSH application. It runs a simple "guess the number" game.
Usage:
go run ./tsnet/example/ssh-game
Then from another Tailscale node:
ssh -p 2222 <hostname>

View File

@@ -0,0 +1,92 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9
// The ssh-game server demonstrates how to use tsnet's ListenSSH to build
// a custom SSH application. It runs a simple "guess the number" game.
//
// Usage:
//
// go run ./tsnet/example/ssh-game
//
// Then from another Tailscale node:
//
// ssh -p 2222 <hostname>
package main
import (
"bufio"
"fmt"
"log"
"math/rand/v2"
"net"
"strings"
_ "tailscale.com/feature/ssh"
"tailscale.com/ssh/tailssh"
"tailscale.com/tsnet"
)
func main() {
s := &tsnet.Server{
Hostname: "ssh-game",
}
defer s.Close()
ln, err := s.ListenSSH(":2222")
if err != nil {
log.Fatal(err)
}
defer ln.Close()
log.Println("Listening on :2222")
for {
conn, err := ln.Accept()
if err != nil {
log.Fatal(err)
}
go handleGame(conn)
}
}
func handleGame(c net.Conn) {
sess, ok := c.(*tailssh.Session)
if !ok {
fmt.Fprintf(c, "unexpected connection type\n")
c.Close()
return
}
defer sess.Exit(0)
target := rand.IntN(100) + 1
scanner := bufio.NewScanner(sess)
fmt.Fprintf(sess, "Welcome, %s from %s!\r\n",
sess.UserProfile().LoginName,
sess.Peer().ComputedName())
fmt.Fprintf(sess, "I'm thinking of a number between 1 and 100.\r\n")
fmt.Fprintf(sess, "Can you guess it?\r\n\r\n")
for attempts := 1; ; attempts++ {
fmt.Fprintf(sess, "Your guess: ")
if !scanner.Scan() {
return
}
line := strings.TrimSpace(scanner.Text())
var guess int
if _, err := fmt.Sscanf(line, "%d", &guess); err != nil {
fmt.Fprintf(sess, "Please enter a number.\r\n")
continue
}
switch {
case guess < target:
fmt.Fprintf(sess, "Higher!\r\n")
case guess > target:
fmt.Fprintf(sess, "Lower!\r\n")
default:
fmt.Fprintf(sess, "Correct! You got it in %d attempts.\r\n", attempts)
return
}
}
}

122
tsnet/listenssh_test.go Normal file
View File

@@ -0,0 +1,122 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9
package tsnet
import (
"context"
"errors"
"fmt"
"net"
"strings"
"testing"
"time"
gossh "golang.org/x/crypto/ssh"
_ "tailscale.com/feature/ssh"
"tailscale.com/ssh/tailssh"
"tailscale.com/tstest"
)
// TestListenSSH starts two tsnet nodes on a test tailnet, has one listen
// for SSH via ListenSSH, and has the other connect using the Go
// x/crypto/ssh client. The server verifies the command string and echoes
// back the connecting peer's login name, verifying that WhoIs and
// Peer/UserProfile work end-to-end.
func TestListenSSH(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
defer cancel()
controlURL, _ := startControl(t)
srvNode, srvIP, _ := startServer(t, ctx, controlURL, "sshsrv")
clientNode, clientIP, _ := startServer(t, ctx, controlURL, "sshclient")
// Listen for SSH on srvNode.
ln, err := srvNode.ListenSSH(":22")
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { ln.Close() })
// Server goroutine: verify the command, then write the peer's login name back.
srvErrCh := make(chan error, 1)
go func() {
conn, err := ln.Accept()
if err != nil {
srvErrCh <- err
return
}
sess := conn.(*tailssh.Session)
defer sess.Exit(0)
if got := sess.RawCommand(); got != "test-whoami" {
srvErrCh <- fmt.Errorf("server got command %q, want %q", got, "test-whoami")
return
}
fmt.Fprintf(sess, "%s\n", sess.UserProfile().LoginName)
srvErrCh <- nil
}()
// Wait until srvNode knows about clientNode so WhoIs succeeds when the
// SSH connection arrives.
if err := tstest.WaitFor(30*time.Second, func() error {
lc, err := srvNode.LocalClient()
if err != nil {
return err
}
st, err := lc.Status(ctx)
if err != nil {
return err
}
for _, peer := range st.Peer {
for _, ip := range peer.TailscaleIPs {
if ip == clientIP {
return nil
}
}
}
return errors.New("clientNode not yet in srvNode's netmap")
}); err != nil {
t.Fatal(err)
}
// Dial srvNode's SSH listener from clientNode's Tailscale network.
addr := net.JoinHostPort(srvIP.String(), "22")
tcpConn, err := clientNode.Dial(ctx, "tcp", addr)
if err != nil {
t.Fatal(err)
}
// gliderssh defaults to NoClientAuth when no auth handler is registered,
// so no Auth methods are needed.
sshConn, chans, reqs, err := gossh.NewClientConn(tcpConn, addr, &gossh.ClientConfig{
User: "test",
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
})
if err != nil {
t.Fatal(err)
}
sshClient := gossh.NewClient(sshConn, chans, reqs)
defer sshClient.Close()
session, err := sshClient.NewSession()
if err != nil {
t.Fatal(err)
}
out, err := session.Output("test-whoami")
if err != nil {
t.Fatalf("session.Output: %v", err)
}
loginName := strings.TrimSpace(string(out))
if loginName == "" {
t.Error("SSH server returned empty login name; WhoIs or Peer/UserProfile may be broken")
}
t.Logf("peer login name from SSH server: %q", loginName)
if err := <-srvErrCh; err != nil {
t.Errorf("SSH server goroutine: %v", err)
}
}

View File

@@ -1264,6 +1264,33 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) {
return s.listen(network, addr, listenOnTailnet)
}
// ListenSSH listens on the Tailscale network for SSH connections at the given
// addr (e.g. ":2222"). The returned listener's Accept method yields net.Conn
// values that are actually *tailssh.Session, providing access to the
// connecting peer's Tailscale identity, PTY information, signals, and more.
//
// Basic applications can use the returned connections as plain net.Conn
// (Read/Write/Close). Applications that need richer SSH semantics should
// type-assert to *tailssh.Session.
//
// SSH support must be linked into the binary by importing
// _ "tailscale.com/feature/ssh". Without that import, ListenSSH returns an
// error.
//
// If s has not been started yet, it will be started.
func (s *Server) ListenSSH(addr string) (net.Listener, error) {
rawLn, err := s.Listen("tcp", addr)
if err != nil {
return nil, err
}
sshLn, err := s.lb.ListenSSH(rawLn, s.logf)
if err != nil {
rawLn.Close()
return nil, err
}
return sshLn, nil
}
// ListenPacket announces on the Tailscale network.
//
// The network must be "udp", "udp4" or "udp6". The addr must be of the form