mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-31 20:21:26 -04:00
feat(distributed): add ExposeNodeHeader middleware + ResponseWriter wrapper
Introduce a per-request Echo middleware that wraps the response writer and lazily stamps X-LocalAI-Node on the first Write / WriteHeader / Flush. This replaces the chan-based per-request rendezvous and per-handler maybeSetNodeHeader calls with a single enforcement point. The wrapper reads the picked node ID by looking up the request's model in the ModelLoader at flush time (late binding), so the value reflects the post-ml.Load state of the loader rather than any pre-route guess. Off by default; gated by ApplicationConfig.ExposeNodeHeader. Ginkgo specs cover off/on, missing model, in-process model (no node ID), absent stash, buffered + streaming flush ordering, error path, and late binding under in-handler stamp. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Assisted-by: Claude:claude-opus-4-7[1m]
This commit is contained in:
122
core/http/middleware/node_header.go
Normal file
122
core/http/middleware/node_header.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// NodeHeaderName is the HTTP response header that, when --expose-node-header
|
||||
// is enabled, carries the ID of the distributed-mode worker node that served
|
||||
// the inference request. Off by default: node IDs reveal internal topology
|
||||
// and should not be exposed on a public endpoint.
|
||||
const NodeHeaderName = "X-LocalAI-Node"
|
||||
|
||||
// nodeHeaderWriter wraps an http.ResponseWriter and stamps the X-LocalAI-Node
|
||||
// header lazily on the first Write / WriteHeader / Flush call. The lazy
|
||||
// resolve is what makes this work for streaming: the picked node ID is only
|
||||
// known AFTER ml.Load runs (i.e. on the first SSE chunk), so resolving at
|
||||
// request entry would attach the previous request's routing decision (or
|
||||
// nothing on a cold cache).
|
||||
type nodeHeaderWriter struct {
|
||||
http.ResponseWriter
|
||||
resolve func() string
|
||||
set bool
|
||||
}
|
||||
|
||||
func (w *nodeHeaderWriter) maybeSet() {
|
||||
if w.set {
|
||||
return
|
||||
}
|
||||
w.set = true
|
||||
if id := w.resolve(); id != "" {
|
||||
w.Header().Set(NodeHeaderName, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *nodeHeaderWriter) Write(b []byte) (int, error) {
|
||||
w.maybeSet()
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (w *nodeHeaderWriter) WriteHeader(code int) {
|
||||
w.maybeSet()
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Flush keeps SSE handlers working: Echo's Response.Flush goes through
|
||||
// http.NewResponseController which walks Unwrap() chains and invokes Flush
|
||||
// on the first wrapper that implements http.Flusher. By implementing it
|
||||
// here we both stamp the header before the underlying writer flushes AND
|
||||
// keep the streaming path alive.
|
||||
func (w *nodeHeaderWriter) Flush() {
|
||||
w.maybeSet()
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack preserves WebSocket / raw-conn handlers that need to take over the
|
||||
// underlying TCP connection (e.g. /v1/realtime). Without this the wrapper
|
||||
// would silently break those endpoints.
|
||||
func (w *nodeHeaderWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if h, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return h.Hijack()
|
||||
}
|
||||
return nil, nil, errors.New("ResponseWriter does not implement http.Hijacker")
|
||||
}
|
||||
|
||||
// Unwrap lets http.NewResponseController reach through us to find optional
|
||||
// interfaces (CloseNotifier, SetReadDeadline, etc.) on the real writer.
|
||||
func (w *nodeHeaderWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
|
||||
// ExposeNodeHeader installs a per-request response writer wrapper that
|
||||
// stamps the X-LocalAI-Node header from the currently-loaded model's node
|
||||
// ID on the first write. Off by default; opted in via --expose-node-header
|
||||
// / LOCALAI_EXPOSE_NODE_HEADER. The model name is read from the standard
|
||||
// per-request context key set by the request-extractor middleware chain
|
||||
// (CONTEXT_LOCALS_KEY_MODEL_NAME), so any handler that goes through the
|
||||
// usual SetModelAndConfig wiring is automatically covered.
|
||||
//
|
||||
// Best-effort: under heavy concurrency for the same model across multiple
|
||||
// replicas, the header may reflect a recent routing decision rather than
|
||||
// this exact request's, because the model loader's per-modelID store entry
|
||||
// is overwritten on every routing decision. Acceptable for observability
|
||||
// and debugging.
|
||||
func ExposeNodeHeader(appCfg *config.ApplicationConfig, ml *model.ModelLoader) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if appCfg == nil || !appCfg.ExposeNodeHeader || ml == nil {
|
||||
return next(c)
|
||||
}
|
||||
orig := c.Response().Writer
|
||||
wrapper := &nodeHeaderWriter{
|
||||
ResponseWriter: orig,
|
||||
resolve: func() string {
|
||||
modelName, _ := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if modelName == "" {
|
||||
return ""
|
||||
}
|
||||
m := ml.CheckIsLoaded(modelName)
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
return m.NodeID()
|
||||
},
|
||||
}
|
||||
c.Response().Writer = wrapper
|
||||
defer func() {
|
||||
c.Response().Writer = orig
|
||||
}()
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
291
core/http/middleware/node_header_test.go
Normal file
291
core/http/middleware/node_header_test.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
)
|
||||
|
||||
// orderedWriter records the order in which header-snapshot vs body-byte
|
||||
// events happen. Used by the streaming spec to assert that the X-LocalAI-Node
|
||||
// header lands on the response BEFORE the first body byte is committed to
|
||||
// the underlying writer.
|
||||
type orderedWriter struct {
|
||||
http.ResponseWriter
|
||||
events []string
|
||||
}
|
||||
|
||||
func (o *orderedWriter) WriteHeader(code int) {
|
||||
o.events = append(o.events, "header:"+http.StatusText(code))
|
||||
o.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (o *orderedWriter) Write(b []byte) (int, error) {
|
||||
// Snapshot the X-LocalAI-Node header value AT THE INSTANT the underlying
|
||||
// writer is asked to commit bytes. This is what real HTTP clients
|
||||
// effectively observe: anything set on the header map AFTER this point
|
||||
// would be silently dropped.
|
||||
o.events = append(o.events, "write:node="+o.Header().Get(NodeHeaderName))
|
||||
return o.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (o *orderedWriter) Flush() {
|
||||
o.events = append(o.events, "flush:node="+o.Header().Get(NodeHeaderName))
|
||||
if f, ok := o.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("ExposeNodeHeader middleware", func() {
|
||||
const (
|
||||
modelID = "qwen3.5-0.8b"
|
||||
fakeNodeID = "node-abcdef"
|
||||
)
|
||||
|
||||
var (
|
||||
e *echo.Echo
|
||||
ml *model.ModelLoader
|
||||
appCfg *config.ApplicationConfig
|
||||
)
|
||||
|
||||
// loadModel pre-populates the loader's in-memory store with a model
|
||||
// entry whose NodeID is set to `nodeID` (or left empty). Marking the
|
||||
// model recently-healthy short-circuits the gRPC HealthCheck inside
|
||||
// CheckIsLoaded so the test does not try to dial a bogus address.
|
||||
loadModel := func(id, nodeID string) {
|
||||
m := model.NewModelWithClient(id, "10.0.0.1:50051", nil)
|
||||
if nodeID != "" {
|
||||
m.SetNodeID(nodeID)
|
||||
}
|
||||
m.MarkHealthy()
|
||||
store := model.NewInMemoryModelStore()
|
||||
store.Set(id, m)
|
||||
ml.SetModelStore(store)
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
e = echo.New()
|
||||
ml = model.NewModelLoader(&system.SystemState{})
|
||||
appCfg = &config.ApplicationConfig{}
|
||||
})
|
||||
|
||||
// run executes the middleware against a fake handler that stashes the
|
||||
// model name on the request context (the same way the
|
||||
// request-extractor middleware does in production) and then writes a
|
||||
// trivial body to trigger the wrapper. Returns the recorded response.
|
||||
run := func(handler echo.HandlerFunc) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
return rec
|
||||
}
|
||||
|
||||
When("ExposeNodeHeader is false", func() {
|
||||
It("does not set the X-LocalAI-Node header", func() {
|
||||
appCfg.ExposeNodeHeader = false
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("does not even install the wrapper (writer is unchanged)", func() {
|
||||
appCfg.ExposeNodeHeader = false
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
origWriter := c.Response().Writer
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
// Pass-through must leave the writer identity intact so
|
||||
// no overhead is added on the hot path when the feature
|
||||
// is off.
|
||||
Expect(c.Response().Writer).To(BeIdenticalTo(origWriter))
|
||||
return c.String(http.StatusOK, "ok")
|
||||
}
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true and the model is loaded with a node ID", func() {
|
||||
It("sets the X-LocalAI-Node header on a buffered response", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
|
||||
It("sets the header even on a 500 error response (Write still triggers maybeSet)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusInternalServerError, "boom")
|
||||
})
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusInternalServerError))
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true but no model is loaded for the request", func() {
|
||||
It("does not set the header (cold cache stays silent)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
// No model loaded.
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true and the model is loaded but has no node ID", func() {
|
||||
It("does not set the header (in-process model, not distributed)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, "") // local model: no node ID stamped
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true but no model name is stashed on the context", func() {
|
||||
It("does not set the header (handler did not opt in)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
// Intentionally skip the c.Set call.
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("the handler streams via Flush before any Write", func() {
|
||||
It("sets the header BEFORE the first byte hits the underlying writer", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
// Wrap the recorder with an order-tracking writer so we can
|
||||
// assert that the header is on the response map by the time
|
||||
// the first body byte is committed. This is the property
|
||||
// that protected the pre-refactor streaming bug: if the
|
||||
// wrapper stamped lazily but AFTER the byte commit, real
|
||||
// SSE clients would see the body without the header.
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
tracker := &orderedWriter{ResponseWriter: rec}
|
||||
c := e.NewContext(req, rec)
|
||||
c.Response().Writer = tracker
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
// Simulate an SSE handler: flush headers, then write a
|
||||
// chunk and flush again. The wrapper must stamp the
|
||||
// node ID on the first call - either Flush or Write,
|
||||
// whichever comes first.
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Flush()
|
||||
_, err := c.Response().Write([]byte("data: chunk\n\n"))
|
||||
return err
|
||||
}
|
||||
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
|
||||
// First recorded event on the underlying writer must show
|
||||
// the header already populated. The first event is either
|
||||
// flush or write; either way the node ID must be on it.
|
||||
Expect(tracker.events).ToNot(BeEmpty())
|
||||
Expect(tracker.events[0]).To(HavePrefix("flush:node=" + fakeNodeID))
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
})
|
||||
|
||||
When("the handler writes a body without an explicit WriteHeader", func() {
|
||||
It("still stamps the header before the implicit 200 commit", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
tracker := &orderedWriter{ResponseWriter: rec}
|
||||
c := e.NewContext(req, rec)
|
||||
c.Response().Writer = tracker
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
_, err := c.Response().Write([]byte("body"))
|
||||
return err
|
||||
}
|
||||
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
|
||||
// Echo's Response.Write calls WriteHeader on the underlying
|
||||
// writer first, then Write. Both must see the header
|
||||
// already populated (the wrapper's maybeSet ran inside both
|
||||
// WriteHeader and Write before they hit `tracker`).
|
||||
Expect(len(tracker.events)).To(BeNumerically(">=", 2))
|
||||
Expect(tracker.events[0]).To(HavePrefix("header:"))
|
||||
Expect(tracker.events[1]).To(Equal("write:node=" + fakeNodeID))
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
})
|
||||
|
||||
When("the model's node ID changes between request entry and first write", func() {
|
||||
It("uses the value present AT the first write (late binding)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, "stale-node-A")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
// Simulate ml.Load running mid-request and re-stamping
|
||||
// the model with this request's actual routing decision.
|
||||
m := ml.CheckIsLoaded(modelID)
|
||||
Expect(m).ToNot(BeNil())
|
||||
m.SetNodeID("fresh-node-B")
|
||||
return c.String(http.StatusOK, "ok")
|
||||
}
|
||||
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal("fresh-node-B"),
|
||||
"the wrapper must read the node ID lazily at first write, not eagerly at entry")
|
||||
})
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user