From f274a0cfab11a05b8ee95e1ac735bc1969097e3d Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Tue, 6 Jul 2021 17:52:04 -0700 Subject: [PATCH] simplify change point at the cost of some very, very naughty unsafe --- net/tstun/wrap.go | 73 +++----------------- net/uring/file_test.go | 2 +- net/uring/io_uring_linux.go | 117 +++++++++++++++++++++++++++++--- wgengine/magicsock/magicsock.go | 6 +- 4 files changed, 122 insertions(+), 76 deletions(-) diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 7fa0ea4d9..93120e85f 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -12,13 +12,10 @@ "os" "sync" "sync/atomic" - "syscall" "time" - "golang.org/x/net/ipv6" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - wgtun "golang.zx2c4.com/wireguard/tun" // import twice to work around shadowing (TODO: fix properly) "inet.af/netaddr" "tailscale.com/net/packet" "tailscale.com/net/uring" @@ -66,9 +63,6 @@ type Wrapper struct { // tdev is the underlying Wrapper device. tdev tun.Device - // uring performs writes to the underlying Wrapper device. - ring *uring.File - closeOnce sync.Once lastActivityAtomic int64 // unix seconds of last send or receive @@ -167,12 +161,15 @@ func Wrap(logf logger.Logf, tdev tun.Device) *Wrapper { filterFlags: filter.LogAccepts | filter.LogDrops, } - f := tdev.(*wgtun.NativeTun).File() - ring, err := uring.NewFile(f) - if err != nil { - panic(err) // TODO: just log? wat? - } else { - tun.ring = ring + if uring.Available() { + uringTun, err := uring.NewTUN(tdev) + name, _ := tdev.Name() + if err != nil { + logf("not using io_uring for TUN %v: %v", name, err) + } else { + logf("using uring for TUN %v", name) + tdev = uringTun + } } go tun.poll() @@ -315,7 +312,7 @@ func (t *Wrapper) poll() { if t.isClosed() { return } - n, err = t.read(t.buffer[:], PacketStartOffset) + n, err = t.tdev.Read(t.buffer[:], PacketStartOffset) } t.sendOutbound(tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n], err: err}) } @@ -534,55 +531,7 @@ func (t *Wrapper) Write(buf []byte, offset int) (int, error) { } t.noteActivity() - return t.write(buf, offset) -} - -func (t *Wrapper) write(buf []byte, offset int) (int, error) { - if t.ring == nil { - return t.tdev.Write(buf, offset) - } - - // below copied from wireguard-go NativeTUN.Write - - // reserve space for header - buf = buf[offset-4:] - - // add packet information header - buf[0] = 0x00 - buf[1] = 0x00 - if buf[4]>>4 == ipv6.Version { - buf[2] = 0x86 - buf[3] = 0xdd - } else { - buf[2] = 0x08 - buf[3] = 0x00 - } - - n, err := t.ring.Write(buf) - if errors.Is(err, syscall.EBADFD) { - err = os.ErrClosed - } - return n, err -} - -func (t *Wrapper) read(buf []byte, offset int) (n int, err error) { - // TODO: upstream has graceful shutdown error handling here. - buff := buf[offset-4:] - const useIOUring = true - if useIOUring { - n, err = t.ring.Read(buff[:]) - } else { - n, err = t.tdev.(*wgtun.NativeTun).File().Read(buff[:]) - } - if errors.Is(err, syscall.EBADFD) { - err = os.ErrClosed - } - if n < 4 { - n = 0 - } else { - n -= 4 - } - return + return t.tdev.Write(buf, offset) } func (t *Wrapper) GetFilter() *filter.Filter { diff --git a/net/uring/file_test.go b/net/uring/file_test.go index 978f27a93..6975e3684 100644 --- a/net/uring/file_test.go +++ b/net/uring/file_test.go @@ -21,7 +21,7 @@ func TestFileRead(t *testing.T) { c.Assert(err, qt.IsNil) t.Cleanup(func() { f.Close() }) - uf, err := NewFile(f) + uf, err := newFile(f) if err != nil { t.Skipf("io_uring not available: %v", err) } diff --git a/net/uring/io_uring_linux.go b/net/uring/io_uring_linux.go index d7c834062..c030ae4ab 100644 --- a/net/uring/io_uring_linux.go +++ b/net/uring/io_uring_linux.go @@ -16,7 +16,9 @@ "time" "unsafe" + "golang.org/x/net/ipv6" "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" "inet.af/netaddr" ) @@ -48,7 +50,11 @@ type UDPConn struct { is4 bool } -func NewUDPConn(conn *net.UDPConn) (*UDPConn, error) { +func NewUDPConn(pconn net.PacketConn) (*UDPConn, error) { + conn, ok := pconn.(*net.UDPConn) + if !ok { + return nil, fmt.Errorf("cannot use io_uring with conn of type %T", pconn) + } // this is dumb local := conn.LocalAddr().String() ip, err := netaddr.ParseIPPort(local) @@ -290,7 +296,7 @@ func (c *UDPConn) SetWriteDeadline(t time.Time) error { // A File is a write-only file fd manager. // TODO: Support reads // TODO: all the todos from UDPConn -type File struct { +type file struct { writeRing *C.go_uring readRing *C.go_uring close sync.Once @@ -301,10 +307,10 @@ type File struct { writeReqC chan int // indices into reqs } -func NewFile(file *os.File) (*File, error) { - fd := C.int(file.Fd()) - u := &File{ - file: file, +func newFile(f *os.File) (*file, error) { + fd := C.int(f.Fd()) + u := &file{ + file: f, fd: fd, } for _, ringPtr := range []**C.go_uring{&u.writeRing, &u.readRing} { @@ -340,7 +346,7 @@ func NewFile(file *os.File) (*File, error) { return u, nil } -func (u *File) submitReadvRequest(idx int) error { +func (u *file) submitReadvRequest(idx int) error { // TODO: make a C struct instead of a Go struct, and pass that in, to simplify call sites. errno := C.submit_readv_request(u.readRing, u.readReqs[idx], C.size_t(idx)) if errno < 0 { @@ -363,7 +369,7 @@ type fileReq struct { // Read data into buf[offset:]. // We are allowed to write junk into buf[offset-4:offset]. -func (u *File) Read(buf []byte) (n int, err error) { // read a packet from the device (without any additional headers) +func (u *file) Read(buf []byte) (n int, err error) { // read a packet from the device (without any additional headers) if u.fd == 0 { return 0, errors.New("invalid uring.File") } @@ -383,7 +389,7 @@ func (u *File) Read(buf []byte) (n int, err error) { // read a packet from the d return n, nil } -func (u *File) Write(buf []byte) (int, error) { +func (u *file) Write(buf []byte) (int, error) { if u.fd == 0 { return 0, errors.New("invalid uring.FileConn") } @@ -421,7 +427,7 @@ func (u *File) Write(buf []byte) (int, error) { } // TODO: the TODOs from UDPConn.Close -func (u *File) Close() error { +func (u *file) Close() error { u.close.Do(func() { u.file.Close() // TODO: require kernel 5.5, send an abort SQE, handle aborts gracefully @@ -442,3 +448,94 @@ func (u *File) Close() error { }) return nil } + +// Wrap files into TUN devices. + +func NewTUN(d tun.Device) (tun.Device, error) { + nt, ok := d.(*tun.NativeTun) + if !ok { + return nil, fmt.Errorf("NewTUN only wraps *tun.NativeTun, got %T", d) + } + f, err := newFile(nt.File()) + if err != nil { + return nil, err + } + v := reflect.ValueOf(nt) + field, ok := v.Elem().Type().FieldByName("errors") + if !ok { + return nil, errors.New("could not find internal tun.NativeTun errors field") + } + ptr := unsafe.Pointer(nt) + ptr = unsafe.Pointer(uintptr(ptr) + field.Offset) // TODO: switch to unsafe.Add with Go 1.17...as if that's the worst thing in this line + c := *(*chan error)(ptr) + return &TUN{d: nt, f: f, errors: c}, nil +} + +// No nopi +type TUN struct { + d *tun.NativeTun + f *file + errors chan error +} + +func (t *TUN) File() *os.File { + return t.f.file +} + +func (t *TUN) Read(buf []byte, offset int) (int, error) { + select { + case err := <-t.errors: + return 0, err + default: + } + // TODO: upstream has graceful shutdown error handling here. + buff := buf[offset-4:] + n, err := t.f.Read(buff[:]) + if errors.Is(err, syscall.EBADFD) { + err = os.ErrClosed + } + if n < 4 { + n = 0 + } else { + n -= 4 + } + return n, err +} + +func (t *TUN) Write(buf []byte, offset int) (int, error) { + // below copied from wireguard-go NativeTun.Write + + // reserve space for header + buf = buf[offset-4:] + + // add packet information header + buf[0] = 0x00 + buf[1] = 0x00 + if buf[4]>>4 == ipv6.Version { + buf[2] = 0x86 + buf[3] = 0xdd + } else { + buf[2] = 0x08 + buf[3] = 0x00 + } + + n, err := t.f.Write(buf) + if errors.Is(err, syscall.EBADFD) { + err = os.ErrClosed + } + return n, err +} + +func (t *TUN) Flush() error { return t.d.Flush() } +func (t *TUN) MTU() (int, error) { return t.d.MTU() } +func (t *TUN) Name() (string, error) { return t.d.Name() } +func (t *TUN) Events() chan tun.Event { return t.d.Events() } + +func (t *TUN) Close() error { + err1 := t.f.Close() + err2 := t.d.Close() + if err1 != nil { + return err1 + } + return err2 +} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 573db07df..1cddd61f4 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -2698,11 +2698,11 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate // Success. ruc.pconn = pconn if uring.Available() { - uringConn, err := uring.NewUDPConn(pconn.(*net.UDPConn)) + uringConn, err := uring.NewUDPConn(pconn) if err != nil { - c.logf("not using io_uring for %v: %v", pconn.LocalAddr(), err) + c.logf("not using io_uring for UDP %v: %v", pconn.LocalAddr(), err) } else { - c.logf("using uring for %v", pconn.LocalAddr()) + c.logf("using uring for UDP %v", pconn.LocalAddr()) ruc.pconn = uringConn } }