diff --git a/net/uring/file_linux.go b/net/uring/file_linux.go index 87e05567a..ee62bbcd4 100644 --- a/net/uring/file_linux.go +++ b/net/uring/file_linux.go @@ -11,8 +11,10 @@ "sync" "sync/atomic" "syscall" + "time" + "unsafe" - "golang.zx2c4.com/wireguard/device" + "tailscale.com/syncs" ) // A file is a file handle that uses io_uring for reads and writes. @@ -28,84 +30,130 @@ type file struct { // close ensures that file closes occur exactly once. close sync.Once + // closed indicates whether the file has been closed. + closed syncs.AtomicBool + // shutdown is a sequence of funcs to be called when the UDPConn closes. + shutdown []func() - // closed is an atomic variable that indicates whether the connection has been closed. - // TODO: Make an atomic bool type that we can use here. - closed uint32 + // file is the os file underlying this file. + file *os.File - file *os.File // must keep file from being GC'd - fd uintptr - readReqs [1]*C.goreq // Whoops! The kernel apparently cannot handle more than 1 concurrent preadv calls on a tun device! + // readReqs is an array of re-usable file preadv requests. + // We attempt to keep them all queued up for the kernel to fulfill. + // The array length is tied to the size of the uring. + readReqs [1]*C.goreq // Whoops! The kernel apparently cannot handle more than 1 concurrent preadv calls on a tun device! + + // writeReqs is an array of re-usable file pwritev requests. + // We dispatch them to the kernel as writes are requested. + // The array length is tied to the size of the uring. writeReqs [8]*C.goreq - writeReqC chan int // indices into reqs + + // writeReqC is a channel containing indices into writeReqs + // that are free to use (that is, not in the kernel). + writeReqC chan int + + // refcount counts the number of outstanding read/write requests. + // See the length comment for UDPConn.refcount for details. + refcount syncs.AtomicInt32 } func newFile(f *os.File) (*file, error) { - fd := f.Fd() u := &file{ - file: f, - fd: fd, + readRing: new(C.go_uring), + writeRing: new(C.go_uring), + file: f, } - for _, ringPtr := range []**C.go_uring{&u.writeRing, &u.readRing} { - r := new(C.go_uring) - ret := C.initialize(r, C.int(fd)) - if ret < 0 { - // TODO: handle unwinding partial initialization - return nil, fmt.Errorf("uring initialization failed: %d", ret) - } - *ringPtr = r + + fd := f.Fd() + if ret := C.initialize(u.readRing, C.int(fd)); ret < 0 { + u.doShutdown() + return nil, fmt.Errorf("readRing initialization failed: %w", syscall.Errno(-ret)) } + u.shutdown = append(u.shutdown, func() { + C.io_uring_queue_exit(u.readRing) + }) + + if ret := C.initialize(u.writeRing, C.int(fd)); ret < 0 { + u.doShutdown() + return nil, fmt.Errorf("writeRing initialization failed: %w", syscall.Errno(-ret)) + } + u.shutdown = append(u.shutdown, func() { + C.io_uring_queue_exit(u.writeRing) + }) // Initialize buffers for i := range &u.readReqs { - u.readReqs[i] = C.initializeReq(bufferSize, 0) + u.readReqs[i] = C.initializeReq(bufferSize, 0) // 0: not used for IP addresses } for i := range &u.writeReqs { - u.writeReqs[i] = C.initializeReq(bufferSize, 0) + u.writeReqs[i] = C.initializeReq(bufferSize, 0) // 0: not used for IP addresses } + u.shutdown = append(u.shutdown, func() { + for _, r := range u.readReqs { + C.freeReq(r) + } + for _, r := range u.writeReqs { + C.freeReq(r) + } + }) // Initialize read half. for i := range u.readReqs { if err := u.submitReadvRequest(i); err != nil { - u.Close() // TODO: will this crash? + u.doShutdown() return nil, err } } + // Initialize write half. u.writeReqC = make(chan int, len(u.writeReqs)) for i := range u.writeReqs { u.writeReqC <- i } + + // Initialization succeeded. + // Take ownership of the file. + u.shutdown = append(u.shutdown, func() { + u.file.Close() + }) return u, nil } func (u *file) submitReadvRequest(idx int) error { - // TODO: make a C struct instead of a Go struct, and pass that in, to simplify call sites. errno := C.submit_readv_request(u.readRing, u.readReqs[idx], C.size_t(idx)) if errno < 0 { - return fmt.Errorf("uring.submitReadvRequest failed: %v", errno) // TODO: Improve + return fmt.Errorf("uring.submitReadvRequest failed: %w", syscall.Errno(-errno)) } + atomic.AddInt32(u.readReqInKernel(idx), 1) // TODO: CAS? return nil } -type fileReq struct { - iov C.go_iovec - buf [device.MaxSegmentSize]byte +func (u *file) readReqInKernel(idx int) *int32 { + return (*int32)(unsafe.Pointer(&u.readReqs[idx].in_kernel)) } -// Read data into buf[offset:]. -// We are allowed to write junk into buf[offset-4:offset]. -func (u *file) Read(buf []byte) (n int, err error) { // read a packet from the device (without any additional headers) - if u.fd == 0 { // TODO: review all uses of u.fd for atomic read/write - return 0, errors.New("invalid uring.File") +// Read data into buf. +func (u *file) Read(buf []byte) (n int, err error) { + // The docs for the u.refcount field document this prologue. + u.refcount.Add(1) + defer u.refcount.Add(-1) + if u.closed.Get() { + return 0, os.ErrClosed } + n, idx, err := waitCompletion(u.readRing) + if errors.Is(err, syscall.ECANCELED) { + atomic.AddInt32(u.readReqInKernel(idx), -1) + return 0, os.ErrClosed + } if err != nil { return 0, fmt.Errorf("Read: io_uring failed to issue syscall: %w", err) } + atomic.AddInt32(u.readReqInKernel(idx), -1) if n < 0 { - // Syscall failed. - u.submitReadvRequest(int(idx)) // best effort attempt not to leak idx + // io_uring ran our syscall, which failed. + // Best effort attempt not to leak idx. + u.submitReadvRequest(int(idx)) return 0, fmt.Errorf("Read: syscall failed: %w", syscall.Errno(-n)) } // Success. @@ -121,9 +169,13 @@ func (u *file) Read(buf []byte) (n int, err error) { // read a packet from the d } func (u *file) Write(buf []byte) (int, error) { - if u.fd == 0 { - return 0, errors.New("invalid uring.FileConn") + // The docs for the u.refcount field document this prologue. + u.refcount.Add(1) + defer u.refcount.Add(-1) + if u.closed.Get() { + return 0, os.ErrClosed } + // If we need a buffer, get a buffer, potentially blocking. var idx int select { @@ -157,21 +209,30 @@ func (u *file) Write(buf []byte) (int, error) { func (u *file) Close() error { u.close.Do(func() { - atomic.StoreUintptr(&u.fd, 0) - u.file.Close() - u.file = nil - // TODO: bring the shutdown logic from UDPConn.Close here? - // Or is closing the file above enough, unlike for UDP? - C.io_uring_queue_exit(u.readRing) - C.io_uring_queue_exit(u.writeRing) - - // Free buffers - for _, r := range u.readReqs { - C.freeReq(r) - } - for _, r := range u.writeReqs { - C.freeReq(r) + // Announce to readers and writers that we are closing down. + // Busy loop until all reads and writes are unblocked. + // See the docs for u.refcount. + u.closed.Set(true) + for { + // Request that the kernel cancel all submitted reads. (Writes don't block indefinitely.) + for idx := range u.readReqs { + if atomic.LoadInt32(u.readReqInKernel(idx)) != 0 { + C.submit_cancel_request(u.readRing, C.size_t(idx)) + } + } + if u.refcount.Get() == 0 { + break + } + time.Sleep(time.Millisecond) } + // Do the rest of the shutdown. + u.doShutdown() }) return nil } + +func (u *file) doShutdown() { + for _, fn := range u.shutdown { + fn() + } +}