diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 1da686a19..c9b616582 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -167,11 +167,6 @@ func Wrap(logf logger.Logf, tdev tun.Device) *Wrapper { filterFlags: filter.LogAccepts | filter.LogDrops, } - go tun.poll() - go tun.pumpEvents() - // The buffer starts out consumed. - tun.bufferConsumed <- struct{}{} - f := tdev.(*wgtun.NativeTun).File() ring, err := uring.NewFile(f) if err != nil { @@ -179,6 +174,12 @@ func Wrap(logf logger.Logf, tdev tun.Device) *Wrapper { } else { tun.ring = ring } + + go tun.poll() + go tun.pumpEvents() + // The buffer starts out consumed. + tun.bufferConsumed <- struct{}{} + return tun } diff --git a/net/uring/io_uring.c b/net/uring/io_uring.c index ee65f4aaa..1e9e3439d 100644 --- a/net/uring/io_uring.c +++ b/net/uring/io_uring.c @@ -41,6 +41,34 @@ static int initialize(struct io_uring *ring, int fd) { return 0; } +struct req { + struct msghdr hdr; + struct iovec iov; + struct sockaddr_in sa; + char *buf; +}; + +typedef struct req goreq; + +static struct req *initializeReq(size_t sz) { + struct req *r = malloc(sizeof(struct req)); + memset(r, 0, sizeof(*r)); + r->buf = malloc(sz); + memset(r->buf, 0, sz); + r->iov.iov_base = r->buf; + r->iov.iov_len = sz; + r->hdr.msg_iov = &r->iov; + r->hdr.msg_iovlen = 1; + r->hdr.msg_name = &r->sa; + r->hdr.msg_namelen = sizeof(r->sa); + return r; +} + +static void freeReq(struct req *r) { + free(r->buf); + free(r); +} + // packNIdx packs a returned n (usually number of bytes) and a index into a request array into a 63-bit uint64. static uint64_t packNIdx(int n, size_t idx) { uint64_t idx64 = idx & 0xFFFFFFFF; // truncate to 32 bits, just to be careful (should never be larger than 8) @@ -48,12 +76,12 @@ static uint64_t packNIdx(int n, size_t idx) { return (n64 << 32) | idx64; } -static uint32_t ip(struct sockaddr_in *sa) { - return ntohl(sa->sin_addr.s_addr); +static uint32_t ip(struct req *r) { + return ntohl(r->sa.sin_addr.s_addr); } -static uint16_t port(struct sockaddr_in *sa) { - return ntohs(sa->sin_port); +static uint16_t port(struct req *r) { + return ntohs(r->sa.sin_port); } static uint32_t setIP(struct sockaddr_in *sa, uint32_t ip) { @@ -66,18 +94,9 @@ static uint16_t setPort(struct sockaddr_in *sa, uint16_t port) { // submit a recvmsg request via liburing // TODO: What recvfrom support arrives, maybe use that instead? -static int submit_recvmsg_request(struct io_uring *ring, struct msghdr *mhdr, struct iovec *iov, struct sockaddr_in *sender, char *buf, int buflen, size_t idx) { - iov->iov_base = buf; - iov->iov_len = buflen; - - mhdr->msg_iov = iov; - mhdr->msg_iovlen = 1; - - mhdr->msg_name = sender; - mhdr->msg_namelen = sizeof(struct sockaddr_in); - +static int submit_recvmsg_request(struct io_uring *ring, struct req *r, size_t idx) { struct io_uring_sqe *sqe = io_uring_get_sqe(ring); - io_uring_prep_recvmsg(sqe, 0, mhdr, 0); // use the 0th file in the list of registered fds + io_uring_prep_recvmsg(sqe, 0, &r->hdr, 0); // use the 0th file in the list of registered fds io_uring_sqe_set_flags(sqe, IOSQE_FIXED_FILE); io_uring_sqe_set_data(sqe, (void *)(idx)); io_uring_submit(ring); @@ -86,18 +105,10 @@ static int submit_recvmsg_request(struct io_uring *ring, struct msghdr *mhdr, st // submit a recvmsg request via liburing // TODO: What recvfrom support arrives, maybe use that instead? -static int submit_sendmsg_request(struct io_uring *ring, struct msghdr *mhdr, struct iovec *iov, struct sockaddr_in *sender, char *buf, int buflen, size_t idx) { - iov->iov_base = buf; - iov->iov_len = buflen; - - mhdr->msg_iov = iov; - mhdr->msg_iovlen = 1; - - mhdr->msg_name = sender; - mhdr->msg_namelen = sizeof(struct sockaddr_in); - +static int submit_sendmsg_request(struct io_uring *ring, struct req *r, int buflen, size_t idx) { + r->iov.iov_len = buflen; struct io_uring_sqe *sqe = io_uring_get_sqe(ring); - io_uring_prep_sendmsg(sqe, 0, mhdr, 0); // use the 0th file in the list of registered fds + io_uring_prep_sendmsg(sqe, 0, &r->hdr, 0); // use the 0th file in the list of registered fds io_uring_sqe_set_flags(sqe, IOSQE_FIXED_FILE); io_uring_sqe_set_data(sqe, (void *)(idx)); io_uring_submit(ring); @@ -138,13 +149,21 @@ again:; return nidx; } -// submit a write request via liburing -static int submit_write_request(struct io_uring *ring, char *buf, int buflen, size_t idx, struct iovec *iov) { - iov->iov_base = buf; - iov->iov_len = buflen; - +// submit a writev request via liburing +static int submit_writev_request(struct io_uring *ring, struct req *r, int buflen, size_t idx) { + r->iov.iov_len = buflen; struct io_uring_sqe *sqe = io_uring_get_sqe(ring); - io_uring_prep_writev(sqe, 0, iov, 1, 0); // use the 0th file in the list of registered fds + io_uring_prep_writev(sqe, 0, &r->iov, 1, 0); // use the 0th file in the list of registered fds + io_uring_sqe_set_flags(sqe, IOSQE_FIXED_FILE); + io_uring_sqe_set_data(sqe, (void *)(idx)); + int submitted = io_uring_submit(ring); + return 0; +} + +// submit a readv request via liburing +static int submit_readv_request(struct io_uring *ring, struct req *r, size_t idx) { + struct io_uring_sqe *sqe = io_uring_get_sqe(ring); + io_uring_prep_readv(sqe, 0, &r->iov, 1, 0); // use the 0th file in the list of registered fds io_uring_sqe_set_flags(sqe, IOSQE_FIXED_FILE); io_uring_sqe_set_data(sqe, (void *)(idx)); int submitted = io_uring_submit(ring); diff --git a/net/uring/io_uring_linux.go b/net/uring/io_uring_linux.go index 8210c18c0..d86a304c0 100644 --- a/net/uring/io_uring_linux.go +++ b/net/uring/io_uring_linux.go @@ -10,6 +10,7 @@ "fmt" "net" "os" + "reflect" "sync" "syscall" "time" @@ -19,6 +20,8 @@ "inet.af/netaddr" ) +const bufferSize = device.MaxSegmentSize + // A UDPConn is a recv-only UDP fd manager. // TODO: Support writes. // TODO: support multiplexing multiple fds? @@ -39,8 +42,8 @@ type UDPConn struct { file *os.File // must keep file from being GC'd fd C.int local net.Addr - recvReqs [8]udpReq - sendReqs [8]udpReq + recvReqs [8]*C.goreq + sendReqs [8]*C.goreq sendReqC chan int // indices into sendReqs } @@ -78,6 +81,14 @@ func NewUDPConn(conn *net.UDPConn) (*UDPConn, error) { fd: fd, local: conn.LocalAddr(), } + + // Initialize buffers + for _, reqs := range []*[8]*C.goreq{&u.recvReqs, &u.sendReqs} { + for i := range reqs { + reqs[i] = C.initializeReq(bufferSize) + } + } + // Initialize recv half. for i := range u.recvReqs { if err := u.submitRecvRequest(i); err != nil { @@ -93,23 +104,26 @@ func NewUDPConn(conn *net.UDPConn) (*UDPConn, error) { return u, nil } -type udpReq struct { - mhdr C.go_msghdr - iov C.go_iovec - sa C.go_sockaddr_in - buf [device.MaxSegmentSize]byte -} - func (u *UDPConn) submitRecvRequest(idx int) error { - r := &u.recvReqs[idx] // TODO: make a C struct instead of a Go struct, and pass that in, to simplify call sites. - errno := C.submit_recvmsg_request(u.recvRing, &r.mhdr, &r.iov, &r.sa, (*C.char)(unsafe.Pointer(&r.buf[0])), C.int(len(r.buf)), C.size_t(idx)) + errno := C.submit_recvmsg_request(u.recvRing, u.recvReqs[idx], C.size_t(idx)) if errno < 0 { return fmt.Errorf("uring.submitRecvRequest failed: %v", errno) // TODO: Improve } return nil } +// TODO: replace with unsafe.Slice once we are using Go 1.17. + +func sliceOf(ptr *C.char, n int) []byte { + var b []byte + h := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + h.Data = uintptr(unsafe.Pointer(ptr)) + h.Len = n + h.Cap = n + return b +} + func (u *UDPConn) ReadFromNetaddr(buf []byte) (int, netaddr.IPPort, error) { if u.fd == 0 { return 0, netaddr.IPPort{}, errors.New("invalid uring.UDPConn") @@ -119,13 +133,14 @@ func (u *UDPConn) ReadFromNetaddr(buf []byte) (int, netaddr.IPPort, error) { if err != nil { return 0, netaddr.IPPort{}, fmt.Errorf("ReadFromNetaddr: %v", err) } - r := &u.recvReqs[idx] - ip := C.ip(&r.sa) + r := u.recvReqs[idx] + ip := C.ip(r) var ip4 [4]byte binary.BigEndian.PutUint32(ip4[:], uint32(ip)) - port := C.port(&r.sa) + port := C.port(r) ipp := netaddr.IPPortFrom(netaddr.IPFrom4(ip4), uint16(port)) - copy(buf, r.buf[:n]) + rbuf := sliceOf(r.buf, n) + copy(buf, rbuf) // Queue up a new request. err = u.submitRecvRequest(int(idx)) if err != nil { @@ -154,6 +169,13 @@ func (u *UDPConn) Close() error { u.file.Close() u.file = nil u.fd = 0 + + // Free buffers + for _, reqs := range []*[8]*C.goreq{&u.recvReqs, &u.sendReqs} { + for _, r := range reqs { + C.freeReq(r) + } + } }) return nil } @@ -198,9 +220,10 @@ func (u *UDPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { return 0, fmt.Errorf("some WriteTo failed, maybe long ago: %v", err) } } - r := &u.sendReqs[idx] + r := u.sendReqs[idx] // Do the write. - copy(r.buf[:], p) + rbuf := sliceOf(r.buf, len(p)) + copy(rbuf, p) ip := binary.BigEndian.Uint32(udpAddr.IP) C.setIP(&r.sa, C.uint32_t(ip)) @@ -208,13 +231,10 @@ func (u *UDPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { // TODO: populate r.sa with ip/port C.submit_sendmsg_request( - u.sendRing, // ring - &r.mhdr, // msghdr - &r.iov, // iov -- TODO: populate and don't pass it - &r.sa, // sockaddr_in, ditto - (*C.char)(unsafe.Pointer(&r.buf[0])), // buffer ptr, ditto - C.int(len(p)), // buffer len, ditto - C.size_t(idx), // user data + u.sendRing, // ring + r, + C.int(len(p)), // buffer len, ditto + C.size_t(idx), // user data ) // Get an extra buffer, if available. nidx := C.peek_completion(u.sendRing) @@ -252,33 +272,63 @@ func (c *UDPConn) SetWriteDeadline(t time.Time) error { // TODO: Support reads // TODO: all the todos from UDPConn type File struct { - ptr *C.go_uring - close sync.Once - file *os.File // must keep file from being GC'd - fd C.int - reqs [8]fileReq - reqC chan int // indices into reqs + writeRing *C.go_uring + readRing *C.go_uring + close sync.Once + file *os.File // must keep file from being GC'd + fd C.int + readReqs [8]*C.goreq + writeReqs [8]*C.goreq + writeReqC chan int // indices into reqs } func NewFile(file *os.File) (*File, error) { - r := new(C.go_uring) fd := C.int(file.Fd()) - ret := C.initialize(r, fd) - if ret < 0 { - return nil, fmt.Errorf("uring initialization failed: %d", ret) - } u := &File{ - ptr: r, file: file, fd: fd, } - u.reqC = make(chan int, len(u.reqs)) - for i := range u.reqs { - u.reqC <- i + for _, ringPtr := range []**C.go_uring{&u.writeRing, &u.readRing} { + r := new(C.go_uring) + ret := C.initialize(r, fd) + if ret < 0 { + // TODO: handle unwinding partial initialization + return nil, fmt.Errorf("uring initialization failed: %d", ret) + } + *ringPtr = r + } + + // Initialize buffers + for _, reqs := range []*[8]*C.goreq{&u.readReqs, &u.writeReqs} { + for i := range reqs { + reqs[i] = C.initializeReq(bufferSize) + } + } + + // Initialize read half. + for i := range u.readReqs { + if err := u.submitReadvRequest(i); err != nil { + u.Close() // TODO: will this crash? + return nil, err + } + } + + u.writeReqC = make(chan int, len(u.writeReqs)) + for i := range u.writeReqs { + u.writeReqC <- i } return u, nil } +func (u *File) submitReadvRequest(idx int) error { + // TODO: make a C struct instead of a Go struct, and pass that in, to simplify call sites. + errno := C.submit_readv_request(u.readRing, u.readReqs[idx], C.size_t(idx)) + if errno < 0 { + return fmt.Errorf("uring.submitReadvRequest failed: %v", errno) // TODO: Improve + } + return nil +} + func unpackNIdx(nidx C.uint64_t) (n, idx int, err error) { if int64(nidx) < 0 { return 0, 0, fmt.Errorf("error %d", int64(nidx)) @@ -291,6 +341,30 @@ type fileReq struct { buf [device.MaxSegmentSize]byte } +// Read data into buf[offset:]. +// We are allowed to write junk into buf[offset-4:offset]. +func (u *File) Read(buf []byte, offset int) (n int, err error) { // read a packet from the device (without any additional headers) + if u.fd == 0 { + return 0, errors.New("invalid uring.File") + } + nidx := C.wait_completion(u.readRing) + n, idx, err := unpackNIdx(nidx) + if err != nil || n < 4 { + return 0, fmt.Errorf("Read: %v", err) + } + r := u.readReqs[idx] + // Ignore the first 4 bytes of r.buf, because it contains TUN IP header, which we don't use. + // TODO: open with NOPI? + rbuf := sliceOf(r.buf, n) + n = copy(buf[offset:], rbuf[4:]) + // Queue up a new request. + err = u.submitReadvRequest(int(idx)) + if err != nil { + panic("how should we handle this?") + } + return n, nil +} + func (u *File) Write(buf []byte) (int, error) { if u.fd == 0 { return 0, errors.New("invalid uring.FileConn") @@ -298,22 +372,23 @@ func (u *File) Write(buf []byte) (int, error) { // If we need a buffer, get a buffer, potentially blocking. var idx int select { - case idx = <-u.reqC: + case idx = <-u.writeReqC: default: // No request available. Get one from the kernel. - nidx := C.wait_completion(u.ptr) + nidx := C.wait_completion(u.writeRing) var err error _, idx, err = unpackNIdx(nidx) if err != nil { return 0, fmt.Errorf("some write failed, maybe long ago: %v", err) } } - r := &u.reqs[idx] + r := u.writeReqs[idx] // Do the write. - copy(r.buf[:], buf) - C.submit_write_request(u.ptr, (*C.char)(unsafe.Pointer(&r.buf[0])), C.int(len(buf)), C.size_t(idx), &r.iov) + rbuf := sliceOf(r.buf, len(buf)) + copy(rbuf, buf) + C.submit_writev_request(u.writeRing, r, C.int(len(buf)), C.size_t(idx)) // Get an extra buffer, if available. - nidx := C.peek_completion(u.ptr) + nidx := C.peek_completion(u.writeRing) if syscall.Errno(-nidx) == syscall.EAGAIN || syscall.Errno(-nidx) == syscall.EINTR { // Nothing waiting for us. } else { @@ -321,7 +396,7 @@ func (u *File) Write(buf []byte) (int, error) { if err == nil { // Put the request buffer back in the usable queue. // Should never block, by construction. - u.reqC <- idx + u.writeReqC <- idx } } return len(buf), nil @@ -330,11 +405,20 @@ func (u *File) Write(buf []byte) (int, error) { // TODO: the TODOs from UDPConn.Close func (u *File) Close() error { u.close.Do(func() { - C.io_uring_queue_exit(u.ptr) - u.ptr = nil u.file.Close() + C.io_uring_queue_exit(u.readRing) + C.io_uring_queue_exit(u.writeRing) + u.readRing = nil + u.writeRing = nil u.file = nil u.fd = 0 + + // Free buffers + for _, reqs := range []*[8]*C.goreq{&u.readReqs, &u.writeReqs} { + for _, r := range reqs { + C.freeReq(r) + } + } }) return nil }