diff --git a/core/p2p/federated.go b/core/p2p/federated.go index 45035d428..800d38b81 100644 --- a/core/p2p/federated.go +++ b/core/p2p/federated.go @@ -146,27 +146,6 @@ func extractModel(path, queryModel string, body []byte) string { return probe.Model } -// SelectBestServer picks the online federated peer to serve the next request -// using the shared cluster-routing policy (least in-flight, then most free -// VRAM). Returns "" when no peer is online. -func (fs *FederatedServer) SelectBestServer() string { - fs.syncTableStatus() - // Snapshot the node set before taking fs.Lock so the fs critical section - // only guards requestTable. GetAvailableNodes takes its own global mutex; - // calling it outside fs.Lock avoids a fs.Mutex -> node.mu lock ordering. - nodes := GetAvailableNodes(fs.service) - fs.Lock() - defer fs.Unlock() - candidates := buildFederatedCandidates(nodes, fs.requestTable, time.Now(), "") - best := clusterrouting.PickBestReplica(candidates) - if best == nil { - xlog.Debug("No online federated peers to select", "request_table", fs.requestTable) - return "" - } - xlog.Debug("Selected federated peer", "peer", best.NodeID, "request_table", fs.requestTable) - return best.NodeID -} - // affinityPreferred returns the peer the prefix index considers warm for this // chain, or "" when there is no match strong enough among the candidates. It // reuses prefixcache's per-model radix-tree Decide; the final load-guarded pick diff --git a/core/p2p/federated_server.go b/core/p2p/federated_server.go index 7ed9920dc..1d8b7c930 100644 --- a/core/p2p/federated_server.go +++ b/core/p2p/federated_server.go @@ -1,17 +1,66 @@ package p2p import ( + "bufio" + "bytes" "context" "errors" "fmt" "io" "net" + "net/http" + "strings" + "time" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/edgevpn/pkg/node" "github.com/mudler/xlog" ) +// ErrBodyTooLarge is returned by readRequest when the buffered request body +// exceeds the configured limit. The proxy turns it into a 413 response. +var ErrBodyTooLarge = errors.New("request body exceeds limit") + +// readRequest parses a single HTTP request from r and buffers its body (so the +// body can both be inspected for the model/prefix and replayed to the chosen +// peer). limit caps the buffered body in bytes; 0 means unlimited. A body over +// the cap returns ErrBodyTooLarge. The returned request has its body replaced +// with the buffered bytes and RequestURI cleared so it can be re-serialized +// with req.Write to the peer stream. +func readRequest(r *bufio.Reader, limit int64) (*http.Request, []byte, error) { + req, err := http.ReadRequest(r) + if err != nil { + return nil, nil, err + } + var body []byte + if req.Body != nil { + reader := io.Reader(req.Body) + if limit > 0 { + reader = io.LimitReader(req.Body, limit+1) + } + body, err = io.ReadAll(reader) + _ = req.Body.Close() + if err != nil { + return nil, nil, err + } + if limit > 0 && int64(len(body)) > limit { + return nil, nil, ErrBodyTooLarge + } + } + req.Body = io.NopCloser(bytes.NewReader(body)) + req.ContentLength = int64(len(body)) + req.RequestURI = "" + return req, body, nil +} + +// isWebsocketUpgrade reports whether req is a websocket handshake, which must be +// forwarded as a raw bidirectional duplex (not request/streamed-response) and +// is not body-capped or model-routed. +func isWebsocketUpgrade(req *http.Request) bool { + return strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") && + strings.EqualFold(req.Header.Get("Upgrade"), "websocket") +} + func (f *FederatedServer) Start(ctx context.Context) error { n, err := NewNode(f.p2ptoken) if err != nil { @@ -62,40 +111,60 @@ func (fs *FederatedServer) proxy(ctx context.Context, node *node.Node) error { continue } - // Handle connections in a new goroutine, forwarding to the p2p service + // Handle connections in a new goroutine, terminating HTTP and + // forwarding the request to the chosen p2p peer. go func() { - workerID := "" - if fs.workerTarget != "" { - workerID = fs.workerTarget - } else if fs.loadBalanced { - xlog.Debug("Load balancing request") - - workerID = fs.SelectBestServer() - if workerID == "" { - xlog.Debug("Best server not found, selecting random") - workerID = fs.RandomServer() + br := bufio.NewReader(conn) + req, body, err := readRequest(br, fs.bodyLimit) + if err != nil { + if err == ErrBodyTooLarge { + fs.sendHTMLResponse(conn, 413, "Request body too large") + return } - } else { - workerID = fs.RandomServer() - } - - if workerID == "" { - xlog.Error("No available nodes yet") - fs.sendHTMLResponse(conn, 503, "Sorry, waiting for nodes to connect") + xlog.Error("Failed to read request", "error", err) + _ = conn.Close() + return + } + + upgrade := isWebsocketUpgrade(req) + + now := time.Now() + var ( + workerID string + model string + chain []uint64 + ) + switch { + case fs.workerTarget != "": + workerID = fs.workerTarget + case !fs.loadBalanced: + // Explicit random mode (the RandomWorker flag): keep the + // historical random pick, no model/affinity routing. + workerID = fs.RandomServer() + case upgrade: + // Websocket: no readable model; route by load only. + workerID, _ = fs.selectPeer("", nil, now) + default: + model = extractModel(req.URL.Path, req.URL.Query().Get("model"), body) + workerID, chain = fs.selectPeer(model, body, now) + } + + if workerID == "" { + fs.sendHTMLResponse(conn, 503, "No federated peer available for this request") return } - xlog.Debug("Selected node", "node", workerID) nodeData, exists := GetNode(fs.service, workerID) if !exists { - xlog.Error("Node not found", "node", workerID) fs.sendHTMLResponse(conn, 404, "Node not found") return } - proxyP2PConnection(ctx, node, nodeData.ServiceID, conn) - if fs.loadBalanced { - fs.RecordRequest(workerID) + proxyHTTPToPeer(ctx, node, nodeData.ServiceID, conn, req, upgrade) + + fs.RecordRequest(workerID) + if !upgrade { + fs.observeServed(model, chain, workerID, now) } }() } @@ -132,6 +201,8 @@ func getHTTPStatusText(statusCode int) string { switch statusCode { case 503: return "Service Unavailable" + case 413: + return "Request Entity Too Large" case 404: return "Not Found" case 200: diff --git a/core/p2p/federated_test.go b/core/p2p/federated_test.go index 703e18fe7..f5c8ed5c3 100644 --- a/core/p2p/federated_test.go +++ b/core/p2p/federated_test.go @@ -1,6 +1,8 @@ package p2p import ( + "bufio" + "strings" "time" . "github.com/onsi/ginkgo/v2" @@ -145,3 +147,35 @@ var _ = Describe("affinityPreferred", func() { Expect(affinityPreferred(idx, "m1", nil, nil, cfg, ref)).To(Equal("")) }) }) + +var _ = Describe("L7 request handling", func() { + It("reads a buffered request and its body under the cap", func() { + raw := "POST /v1/chat/completions HTTP/1.1\r\nHost: x\r\nContent-Length: 28\r\n\r\n" + + `{"model":"m1","messages":[]}` + req, body, err := readRequest(bufio.NewReader(strings.NewReader(raw)), 1024) + Expect(err).ToNot(HaveOccurred()) + Expect(req.URL.Path).To(Equal("/v1/chat/completions")) + Expect(string(body)).To(ContainSubstring(`"model":"m1"`)) + }) + + It("rejects a body over the cap with ErrBodyTooLarge", func() { + big := strings.Repeat("a", 200) + raw := "POST /x HTTP/1.1\r\nHost: x\r\nContent-Length: 200\r\n\r\n" + big + _, _, err := readRequest(bufio.NewReader(strings.NewReader(raw)), 64) + Expect(err).To(MatchError(ErrBodyTooLarge)) + }) + + It("detects a websocket upgrade request", func() { + raw := "GET /v1/realtime HTTP/1.1\r\nHost: x\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n" + req, _, err := readRequest(bufio.NewReader(strings.NewReader(raw)), 1024) + Expect(err).ToNot(HaveOccurred()) + Expect(isWebsocketUpgrade(req)).To(BeTrue()) + }) + + It("does not flag a normal POST as a websocket upgrade", func() { + raw := "POST /v1/chat/completions HTTP/1.1\r\nHost: x\r\nContent-Length: 2\r\n\r\n{}" + req, _, err := readRequest(bufio.NewReader(strings.NewReader(raw)), 1024) + Expect(err).ToNot(HaveOccurred()) + Expect(isWebsocketUpgrade(req)).To(BeFalse()) + }) +}) diff --git a/core/p2p/p2p.go b/core/p2p/p2p.go index 4cf892c62..ff13e8b09 100644 --- a/core/p2p/p2p.go +++ b/core/p2p/p2p.go @@ -6,12 +6,14 @@ import ( "fmt" "io" "net" + "net/http" "os" "strings" "sync" "time" "github.com/ipfs/go-log" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/utils" @@ -87,37 +89,39 @@ func nodeAnnounce(ctx context.Context, node *node.Node) { ) } -func proxyP2PConnection(ctx context.Context, node *node.Node, serviceID string, conn net.Conn) { - ledger, _ := node.Ledger() +// openPeerStream resolves serviceID to its advertised peer in the services +// ledger and opens a libp2p stream to that peer over the service protocol. +// Returns the stream or an error describing which lookup step failed. +func openPeerStream(ctx context.Context, n *node.Node, serviceID string) (network.Stream, error) { + ledger, _ := n.Ledger() // Retrieve current ID for ip in the blockchain existingValue, found := ledger.GetKey(protocol.ServicesLedgerKey, serviceID) service := &types.Service{} existingValue.Unmarshal(service) // If mismatch, update the blockchain if !found { - zlog.Error("Service not found on blockchain") - conn.Close() - // ll.Debugf("service '%s' not found on blockchain", serviceID) - return + return nil, errors.New("service not found on blockchain") } // Decode the Peer d, err := peer.Decode(service.PeerID) if err != nil { - zlog.Error("cannot decode peer") - - conn.Close() - // ll.Debugf("could not decode peer '%s'", service.PeerID) - return + return nil, fmt.Errorf("cannot decode peer: %w", err) } // Open a stream - stream, err := node.Host().NewStream(ctx, d, protocol.ServiceProtocol.ID()) + stream, err := n.Host().NewStream(ctx, d, protocol.ServiceProtocol.ID()) if err != nil { - zlog.Error("cannot open stream peer", "error", err) + return nil, fmt.Errorf("cannot open stream peer: %w", err) + } + return stream, nil +} +func proxyP2PConnection(ctx context.Context, node *node.Node, serviceID string, conn net.Conn) { + stream, err := openPeerStream(ctx, node, serviceID) + if err != nil { + zlog.Error("Could not open peer stream", "error", err) conn.Close() - // ll.Debugf("could not open stream '%s'", err.Error()) return } // ll.Debugf("(service %s) Redirecting", serviceID, l.Addr().String()) @@ -131,6 +135,44 @@ func proxyP2PConnection(ctx context.Context, node *node.Node, serviceID string, conn.Close() } +// proxyHTTPToPeer forwards an already-parsed HTTP request to the chosen peer +// over a libp2p stream and streams the response back to conn. When duplex is +// true (a websocket upgrade) it runs a bidirectional copy after writing the +// request, so post-101 frames flow both ways. The response is never buffered, +// so SSE keeps flowing. +func proxyHTTPToPeer(ctx context.Context, n *node.Node, serviceID string, conn net.Conn, req *http.Request, duplex bool) { + stream, err := openPeerStream(ctx, n, serviceID) + if err != nil { + zlog.Error("Could not open peer stream", "error", err) + _ = conn.Close() + return + } + // Force the peer to close after responding so the one-way io.Copy below + // terminates. Without this the peer keeps the HTTP/1.1 connection alive and + // io.Copy(conn, stream) blocks forever, leaking the goroutine, conn, and + // stream. Websocket upgrades keep keep-alive: their duplex copy owns the + // lifetime. + req.Close = !duplex + if err := req.Write(stream); err != nil { + zlog.Error("Could not write request to peer", "error", err) + _ = stream.Close() + _ = conn.Close() + return + } + if duplex { + closer := make(chan struct{}, 2) + go copyStream(closer, stream, conn) + go copyStream(closer, conn, stream) + <-closer + _ = stream.Close() + _ = conn.Close() + return + } + _, _ = io.Copy(conn, stream) + _ = stream.Close() + _ = conn.Close() +} + func allocateLocalService(ctx context.Context, node *node.Node, listenAddr, service string) error { zlog.Info("Allocating service", "service", service, "address", listenAddr) // Open local port for listening