simplify change point

at the cost of some very, very naughty unsafe
This commit is contained in:
Josh Bleecher Snyder
2021-07-06 17:52:04 -07:00
parent f71ff18c11
commit f274a0cfab
4 changed files with 122 additions and 76 deletions

View File

@@ -12,13 +12,10 @@
"os"
"sync"
"sync/atomic"
"syscall"
"time"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
wgtun "golang.zx2c4.com/wireguard/tun" // import twice to work around shadowing (TODO: fix properly)
"inet.af/netaddr"
"tailscale.com/net/packet"
"tailscale.com/net/uring"
@@ -66,9 +63,6 @@ type Wrapper struct {
// tdev is the underlying Wrapper device.
tdev tun.Device
// uring performs writes to the underlying Wrapper device.
ring *uring.File
closeOnce sync.Once
lastActivityAtomic int64 // unix seconds of last send or receive
@@ -167,12 +161,15 @@ func Wrap(logf logger.Logf, tdev tun.Device) *Wrapper {
filterFlags: filter.LogAccepts | filter.LogDrops,
}
f := tdev.(*wgtun.NativeTun).File()
ring, err := uring.NewFile(f)
if err != nil {
panic(err) // TODO: just log? wat?
} else {
tun.ring = ring
if uring.Available() {
uringTun, err := uring.NewTUN(tdev)
name, _ := tdev.Name()
if err != nil {
logf("not using io_uring for TUN %v: %v", name, err)
} else {
logf("using uring for TUN %v", name)
tdev = uringTun
}
}
go tun.poll()
@@ -315,7 +312,7 @@ func (t *Wrapper) poll() {
if t.isClosed() {
return
}
n, err = t.read(t.buffer[:], PacketStartOffset)
n, err = t.tdev.Read(t.buffer[:], PacketStartOffset)
}
t.sendOutbound(tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n], err: err})
}
@@ -534,55 +531,7 @@ func (t *Wrapper) Write(buf []byte, offset int) (int, error) {
}
t.noteActivity()
return t.write(buf, offset)
}
func (t *Wrapper) write(buf []byte, offset int) (int, error) {
if t.ring == nil {
return t.tdev.Write(buf, offset)
}
// below copied from wireguard-go NativeTUN.Write
// reserve space for header
buf = buf[offset-4:]
// add packet information header
buf[0] = 0x00
buf[1] = 0x00
if buf[4]>>4 == ipv6.Version {
buf[2] = 0x86
buf[3] = 0xdd
} else {
buf[2] = 0x08
buf[3] = 0x00
}
n, err := t.ring.Write(buf)
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
}
return n, err
}
func (t *Wrapper) read(buf []byte, offset int) (n int, err error) {
// TODO: upstream has graceful shutdown error handling here.
buff := buf[offset-4:]
const useIOUring = true
if useIOUring {
n, err = t.ring.Read(buff[:])
} else {
n, err = t.tdev.(*wgtun.NativeTun).File().Read(buff[:])
}
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
}
if n < 4 {
n = 0
} else {
n -= 4
}
return
return t.tdev.Write(buf, offset)
}
func (t *Wrapper) GetFilter() *filter.Filter {

View File

@@ -21,7 +21,7 @@ func TestFileRead(t *testing.T) {
c.Assert(err, qt.IsNil)
t.Cleanup(func() { f.Close() })
uf, err := NewFile(f)
uf, err := newFile(f)
if err != nil {
t.Skipf("io_uring not available: %v", err)
}

View File

@@ -16,7 +16,9 @@
"time"
"unsafe"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"inet.af/netaddr"
)
@@ -48,7 +50,11 @@ type UDPConn struct {
is4 bool
}
func NewUDPConn(conn *net.UDPConn) (*UDPConn, error) {
func NewUDPConn(pconn net.PacketConn) (*UDPConn, error) {
conn, ok := pconn.(*net.UDPConn)
if !ok {
return nil, fmt.Errorf("cannot use io_uring with conn of type %T", pconn)
}
// this is dumb
local := conn.LocalAddr().String()
ip, err := netaddr.ParseIPPort(local)
@@ -290,7 +296,7 @@ func (c *UDPConn) SetWriteDeadline(t time.Time) error {
// A File is a write-only file fd manager.
// TODO: Support reads
// TODO: all the todos from UDPConn
type File struct {
type file struct {
writeRing *C.go_uring
readRing *C.go_uring
close sync.Once
@@ -301,10 +307,10 @@ type File struct {
writeReqC chan int // indices into reqs
}
func NewFile(file *os.File) (*File, error) {
fd := C.int(file.Fd())
u := &File{
file: file,
func newFile(f *os.File) (*file, error) {
fd := C.int(f.Fd())
u := &file{
file: f,
fd: fd,
}
for _, ringPtr := range []**C.go_uring{&u.writeRing, &u.readRing} {
@@ -340,7 +346,7 @@ func NewFile(file *os.File) (*File, error) {
return u, nil
}
func (u *File) submitReadvRequest(idx int) error {
func (u *file) submitReadvRequest(idx int) error {
// TODO: make a C struct instead of a Go struct, and pass that in, to simplify call sites.
errno := C.submit_readv_request(u.readRing, u.readReqs[idx], C.size_t(idx))
if errno < 0 {
@@ -363,7 +369,7 @@ type fileReq struct {
// Read data into buf[offset:].
// We are allowed to write junk into buf[offset-4:offset].
func (u *File) Read(buf []byte) (n int, err error) { // read a packet from the device (without any additional headers)
func (u *file) Read(buf []byte) (n int, err error) { // read a packet from the device (without any additional headers)
if u.fd == 0 {
return 0, errors.New("invalid uring.File")
}
@@ -383,7 +389,7 @@ func (u *File) Read(buf []byte) (n int, err error) { // read a packet from the d
return n, nil
}
func (u *File) Write(buf []byte) (int, error) {
func (u *file) Write(buf []byte) (int, error) {
if u.fd == 0 {
return 0, errors.New("invalid uring.FileConn")
}
@@ -421,7 +427,7 @@ func (u *File) Write(buf []byte) (int, error) {
}
// TODO: the TODOs from UDPConn.Close
func (u *File) Close() error {
func (u *file) Close() error {
u.close.Do(func() {
u.file.Close()
// TODO: require kernel 5.5, send an abort SQE, handle aborts gracefully
@@ -442,3 +448,94 @@ func (u *File) Close() error {
})
return nil
}
// Wrap files into TUN devices.
func NewTUN(d tun.Device) (tun.Device, error) {
nt, ok := d.(*tun.NativeTun)
if !ok {
return nil, fmt.Errorf("NewTUN only wraps *tun.NativeTun, got %T", d)
}
f, err := newFile(nt.File())
if err != nil {
return nil, err
}
v := reflect.ValueOf(nt)
field, ok := v.Elem().Type().FieldByName("errors")
if !ok {
return nil, errors.New("could not find internal tun.NativeTun errors field")
}
ptr := unsafe.Pointer(nt)
ptr = unsafe.Pointer(uintptr(ptr) + field.Offset) // TODO: switch to unsafe.Add with Go 1.17...as if that's the worst thing in this line
c := *(*chan error)(ptr)
return &TUN{d: nt, f: f, errors: c}, nil
}
// No nopi
type TUN struct {
d *tun.NativeTun
f *file
errors chan error
}
func (t *TUN) File() *os.File {
return t.f.file
}
func (t *TUN) Read(buf []byte, offset int) (int, error) {
select {
case err := <-t.errors:
return 0, err
default:
}
// TODO: upstream has graceful shutdown error handling here.
buff := buf[offset-4:]
n, err := t.f.Read(buff[:])
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
}
if n < 4 {
n = 0
} else {
n -= 4
}
return n, err
}
func (t *TUN) Write(buf []byte, offset int) (int, error) {
// below copied from wireguard-go NativeTun.Write
// reserve space for header
buf = buf[offset-4:]
// add packet information header
buf[0] = 0x00
buf[1] = 0x00
if buf[4]>>4 == ipv6.Version {
buf[2] = 0x86
buf[3] = 0xdd
} else {
buf[2] = 0x08
buf[3] = 0x00
}
n, err := t.f.Write(buf)
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
}
return n, err
}
func (t *TUN) Flush() error { return t.d.Flush() }
func (t *TUN) MTU() (int, error) { return t.d.MTU() }
func (t *TUN) Name() (string, error) { return t.d.Name() }
func (t *TUN) Events() chan tun.Event { return t.d.Events() }
func (t *TUN) Close() error {
err1 := t.f.Close()
err2 := t.d.Close()
if err1 != nil {
return err1
}
return err2
}

View File

@@ -2698,11 +2698,11 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate
// Success.
ruc.pconn = pconn
if uring.Available() {
uringConn, err := uring.NewUDPConn(pconn.(*net.UDPConn))
uringConn, err := uring.NewUDPConn(pconn)
if err != nil {
c.logf("not using io_uring for %v: %v", pconn.LocalAddr(), err)
c.logf("not using io_uring for UDP %v: %v", pconn.LocalAddr(), err)
} else {
c.logf("using uring for %v", pconn.LocalAddr())
c.logf("using uring for UDP %v", pconn.LocalAddr())
ruc.pconn = uringConn
}
}