diff --git a/net/uring/io_uring_linux.go b/net/uring/io_uring_linux.go index d7c8cc38b..c0500d45a 100644 --- a/net/uring/io_uring_linux.go +++ b/net/uring/io_uring_linux.go @@ -9,11 +9,11 @@ "fmt" "net" "os" - "runtime" "sync" "time" "unsafe" + "golang.zx2c4.com/wireguard/device" "inet.af/netaddr" ) @@ -35,6 +35,7 @@ type UDPConn struct { file *os.File // must keep file from being GC'd fd C.int local net.Addr + req req } func NewUDPConn(conn *net.UDPConn) (*UDPConn, error) { @@ -56,39 +57,53 @@ func NewUDPConn(conn *net.UDPConn) (*UDPConn, error) { const queue_depth = 8 // TODO: What value to use here? C.io_uring_queue_init(queue_depth, r, 0) - c := &UDPConn{ + u := &UDPConn{ ptr: r, conn: conn, file: file, fd: C.int(file.Fd()), local: conn.LocalAddr(), } - return c, nil + if err := u.submitRequest(); err != nil { + u.Close() // TODO: will this crash? + return nil, err + } + return u, nil +} + +type req struct { + mhdr C.go_msghdr + iov C.go_iovec + sa C.go_sockaddr_in + ip [4]byte // TODO: ipv6 + port C.uint16_t + buf [device.MaxSegmentSize]byte +} + +func (u *UDPConn) submitRequest() error { + // TODO: eventually separate submitting the request and waiting for the response. + errno := C.submit_recvmsg_request(u.fd, u.ptr, &u.req.mhdr, &u.req.iov, &u.req.sa, (*C.char)(unsafe.Pointer(&u.req.buf[0])), C.int(len(u.req.buf))) + if errno < 0 { + return fmt.Errorf("uring.submitRequest failed: %v", errno) // TODO: Improve + } + return nil } func (u *UDPConn) ReadFromNetaddr(buf []byte) (int, netaddr.IPPort, error) { if u.fd == 0 { return 0, netaddr.IPPort{}, errors.New("invalid uring.UDPConn") } - mhdr := new(C.go_msghdr) - iov := new(C.go_iovec) - sa := new(C.go_sockaddr_in) - // TODO: eventually separate submitting the request and waiting for the response. - errno := C.submit_recvmsg_request(u.fd, u.ptr, mhdr, iov, sa, (*C.char)(unsafe.Pointer(&buf[0])), C.int(len(buf))) - if errno < 0 { - return 0, netaddr.IPPort{}, fmt.Errorf("uring.UDPConn recv failed: %v", errno) // TODO: Improve errno - } - - a := new([4]byte) - var port C.uint16_t - n := C.receive_into(u.fd, u.ptr, (*C.char)(unsafe.Pointer(a)), &port) + n := C.receive_into(u.fd, u.ptr, (*C.char)(unsafe.Pointer(&u.req.ip)), &u.req.port) if n < 0 { return 0, netaddr.IPPort{}, errors.New("something wrong") } - ipp := netaddr.IPPortFrom(netaddr.IPFrom4(*a), uint16(port)) - runtime.KeepAlive(mhdr) - runtime.KeepAlive(iov) - runtime.KeepAlive(sa) + ipp := netaddr.IPPortFrom(netaddr.IPFrom4(u.req.ip), uint16(u.req.port)) + copy(buf, u.req.buf[:n]) + // Queue up a new request. + err := u.submitRequest() + if err != nil { + panic("how should we handle this?") + } return int(n), ipp, nil }