mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-25 01:02:05 -04:00
Compare commits
15 Commits
v4.3.0
...
feat/expos
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c778ad0f6d | ||
|
|
8b2697f39a | ||
|
|
1af79c1b0f | ||
|
|
40ec4ffc94 | ||
|
|
04407d24f3 | ||
|
|
1c4bdfd1d6 | ||
|
|
799215cdc6 | ||
|
|
88306d562d | ||
|
|
df8418cb2d | ||
|
|
42d6e52fd7 | ||
|
|
a867b3d2a8 | ||
|
|
63448826b1 | ||
|
|
be1041de0c | ||
|
|
b85b7e29df | ||
|
|
17791fb741 |
@@ -157,6 +157,7 @@ type RunCMD struct {
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
|
||||
Version bool
|
||||
}
|
||||
@@ -277,6 +278,9 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.AutoApproveNodes {
|
||||
opts = append(opts, config.EnableAutoApproveNodes)
|
||||
}
|
||||
if r.ExposeNodeHeader {
|
||||
opts = append(opts, config.WithExposeNodeHeader(true))
|
||||
}
|
||||
|
||||
if r.DisableMetricsEndpoint {
|
||||
opts = append(opts, config.DisableMetricsEndpoint)
|
||||
|
||||
@@ -112,6 +112,18 @@ type ApplicationConfig struct {
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed DistributedConfig
|
||||
|
||||
// ExposeNodeHeader, when true, activates middleware.ExposeNodeHeader on
|
||||
// the inference routes (OpenAI chat/completions/embeddings, Anthropic
|
||||
// /v1/messages, Ollama /api/chat,/api/generate,/api/embed). The
|
||||
// middleware wraps the response writer and attaches an "X-LocalAI-Node"
|
||||
// response header carrying the ID of the distributed-mode worker node
|
||||
// that served the request. Off by default because the node ID is
|
||||
// internal topology that can aid attacker reconnaissance if surfaced on
|
||||
// a public endpoint; operators opt in explicitly via
|
||||
// --expose-node-header / LOCALAI_EXPOSE_NODE_HEADER for debugging,
|
||||
// observability and load-balancer attribution.
|
||||
ExposeNodeHeader bool
|
||||
|
||||
// LocalAI Assistant chat modality. Hard-disable the in-process admin MCP
|
||||
// server with this flag; runtime-toggleable via /api/settings.
|
||||
DisableLocalAIAssistant bool
|
||||
@@ -893,6 +905,15 @@ func WithDisableLocalAIAssistant(disabled bool) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithExposeNodeHeader enables the X-LocalAI-Node response header on
|
||||
// inference endpoints. Default off; the node ID reveals internal cluster
|
||||
// topology and is opt-in for that reason.
|
||||
func WithExposeNodeHeader(enabled bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.ExposeNodeHeader = enabled
|
||||
}
|
||||
}
|
||||
|
||||
// ToConfigLoaderOptions returns a slice of ConfigLoader Option.
|
||||
// Some options defined at the application level are going to be passed as defaults for
|
||||
// all the configuration for the models.
|
||||
|
||||
@@ -325,6 +325,12 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
c.Response().Header().Set("X-Correlation-ID", id)
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on)
|
||||
// is handled by middleware.ExposeNodeHeader at the wrapper
|
||||
// layer: the first c.Response().Write / Flush lazily reads the
|
||||
// node ID from the loader (post-ml.Load) and stamps the header
|
||||
// before the byte hits the underlying writer. No per-request
|
||||
// chan / per-handler plumbing needed here.
|
||||
|
||||
mcpStreamMaxIterations := 10
|
||||
if config.Agent.MaxIterations > 0 {
|
||||
@@ -577,7 +583,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
return nil
|
||||
} // end MCP stream iteration loop
|
||||
|
||||
// Safety fallback
|
||||
// Safety fallback. The MCP iteration loop above always returns,
|
||||
// so this is structurally unreachable; if we ever reach it the
|
||||
// stream is closed cleanly. The middleware-installed wrapper
|
||||
// still stamps X-LocalAI-Node on this final write if applicable.
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
@@ -935,6 +944,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
respData, _ := json.Marshal(resp)
|
||||
xlog.Debug("Response", "response", string(respData))
|
||||
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on)
|
||||
// is handled by middleware.ExposeNodeHeader at the wrapper
|
||||
// layer; c.JSON's writes trigger the lazy stamp.
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
} // end MCP iteration loop
|
||||
|
||||
@@ -21,6 +21,13 @@ import (
|
||||
// The caller owns the `responses` channel and is expected to read from
|
||||
// it while this function runs; processStream closes the channel before
|
||||
// returning.
|
||||
//
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on) is
|
||||
// handled by middleware.ExposeNodeHeader at the response writer wrapper
|
||||
// layer; no in-band signal from the worker is needed. The initial
|
||||
// role=assistant chunk is still emitted from the first token callback
|
||||
// rather than eagerly here, so the wrapper's lazy lookup against the
|
||||
// loader runs AFTER ml.Load has stamped the per-modelID node ID.
|
||||
func processStream(
|
||||
s string,
|
||||
req *schema.OpenAIRequest,
|
||||
@@ -32,13 +39,7 @@ func processStream(
|
||||
id string,
|
||||
created int,
|
||||
) (backend.TokenUsage, error) {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
sentInitialRole := false
|
||||
|
||||
// Detect if thinking token is already in prompt or template
|
||||
// When UseTokenizerTemplate is enabled, predInput is empty, so we check the template
|
||||
@@ -70,6 +71,17 @@ func processStream(
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
if !sentInitialRole {
|
||||
sentInitialRole = true
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
}
|
||||
|
||||
delta := &schema.Message{}
|
||||
if contentDelta != "" {
|
||||
delta.Content = &contentDelta
|
||||
@@ -130,6 +142,9 @@ func processStreamWithTools(
|
||||
hasChatDeltaToolCalls := false
|
||||
hasChatDeltaContent := false
|
||||
|
||||
// X-LocalAI-Node attribution is handled by middleware.ExposeNodeHeader
|
||||
// at the wrapper layer; no in-band signalling from this worker.
|
||||
|
||||
_, finalUsage, chatDeltas, err := ComputeChoices(req, prompt, cfg, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
|
||||
|
||||
@@ -26,6 +26,11 @@ import (
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
// process runs the streaming inference. X-LocalAI-Node attribution
|
||||
// (when --expose-node-header is on) is handled by
|
||||
// middleware.ExposeNodeHeader at the response writer wrapper layer:
|
||||
// the first SSE write triggers a lazy lookup against the loader, so
|
||||
// no in-band signalling is needed here.
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
|
||||
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
created := int(time.Now().Unix())
|
||||
@@ -106,6 +111,11 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on)
|
||||
// is handled by middleware.ExposeNodeHeader at the wrapper
|
||||
// layer: the first c.Response().Write / Flush lazily reads the
|
||||
// node ID from the loader (post-ml.Load) and stamps the header
|
||||
// before the byte hits the underlying writer.
|
||||
|
||||
if len(config.PromptStrings) > 1 {
|
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||
@@ -274,6 +284,9 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
xlog.Debug("Response", "response", string(jsonResult))
|
||||
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on) is
|
||||
// handled by middleware.ExposeNodeHeader at the wrapper layer.
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
@@ -102,6 +102,9 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
xlog.Debug("Response", "response", string(jsonResult))
|
||||
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on) is
|
||||
// handled by middleware.ExposeNodeHeader at the wrapper layer.
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
131
core/http/middleware/node_header.go
Normal file
131
core/http/middleware/node_header.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// NodeHeaderName is the HTTP response header that, when --expose-node-header
|
||||
// is enabled, carries the ID of the distributed-mode worker node that served
|
||||
// the inference request. Off by default: node IDs reveal internal topology
|
||||
// and should not be exposed on a public endpoint.
|
||||
const NodeHeaderName = "X-LocalAI-Node"
|
||||
|
||||
// nodeHeaderWriter wraps an http.ResponseWriter and stamps the X-LocalAI-Node
|
||||
// header lazily on the first Write / WriteHeader / Flush call. The lazy
|
||||
// resolve is what makes this work for streaming: the picked node ID is only
|
||||
// known AFTER ml.Load runs (i.e. on the first SSE chunk), so resolving at
|
||||
// request entry would attach the previous request's routing decision (or
|
||||
// nothing on a cold cache).
|
||||
type nodeHeaderWriter struct {
|
||||
http.ResponseWriter
|
||||
resolve func() string
|
||||
set bool
|
||||
}
|
||||
|
||||
func (w *nodeHeaderWriter) maybeSet() {
|
||||
if w.set {
|
||||
return
|
||||
}
|
||||
w.set = true
|
||||
if id := w.resolve(); id != "" {
|
||||
w.Header().Set(NodeHeaderName, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *nodeHeaderWriter) Write(b []byte) (int, error) {
|
||||
w.maybeSet()
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (w *nodeHeaderWriter) WriteHeader(code int) {
|
||||
w.maybeSet()
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Flush keeps SSE handlers working: Echo's Response.Flush goes through
|
||||
// http.NewResponseController which walks Unwrap() chains and invokes Flush
|
||||
// on the first wrapper that implements http.Flusher. By implementing it
|
||||
// here we both stamp the header before the underlying writer flushes AND
|
||||
// keep the streaming path alive.
|
||||
func (w *nodeHeaderWriter) Flush() {
|
||||
w.maybeSet()
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack preserves WebSocket / raw-conn handlers that need to take over the
|
||||
// underlying TCP connection (e.g. /v1/realtime). Without this the wrapper
|
||||
// would silently break those endpoints.
|
||||
//
|
||||
// When the underlying writer does not implement http.Hijacker we return
|
||||
// http.ErrNotSupported so callers using errors.Is (notably
|
||||
// http.NewResponseController.Hijack) detect the condition through the
|
||||
// standard sentinel rather than a string-matched custom error.
|
||||
func (w *nodeHeaderWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if h, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return h.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("hijack not supported: %w", http.ErrNotSupported)
|
||||
}
|
||||
|
||||
// Unwrap lets http.NewResponseController reach through us to find optional
|
||||
// interfaces (CloseNotifier, SetReadDeadline, etc.) on the real writer.
|
||||
func (w *nodeHeaderWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
|
||||
// ExposeNodeHeader installs a per-request response writer wrapper that
|
||||
// stamps the X-LocalAI-Node header from the currently-loaded model's node
|
||||
// ID on the first write. Off by default; opted in via --expose-node-header
|
||||
// / LOCALAI_EXPOSE_NODE_HEADER. The model name is read from the standard
|
||||
// per-request context key set by the request-extractor middleware chain
|
||||
// (CONTEXT_LOCALS_KEY_MODEL_NAME), so any handler that goes through the
|
||||
// usual SetModelAndConfig wiring is automatically covered.
|
||||
//
|
||||
// Best-effort: under heavy concurrency for the same model across multiple
|
||||
// replicas, the header may reflect a recent routing decision rather than
|
||||
// this exact request's, because the model loader's per-modelID store entry
|
||||
// is overwritten on every routing decision. Acceptable for observability
|
||||
// and debugging.
|
||||
func ExposeNodeHeader(appCfg *config.ApplicationConfig, ml *model.ModelLoader) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if appCfg == nil || !appCfg.ExposeNodeHeader || ml == nil {
|
||||
return next(c)
|
||||
}
|
||||
orig := c.Response().Writer
|
||||
wrapper := &nodeHeaderWriter{
|
||||
ResponseWriter: orig,
|
||||
resolve: func() string {
|
||||
modelName, _ := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if modelName == "" {
|
||||
return ""
|
||||
}
|
||||
// Pure store read - never invokes HealthCheck and
|
||||
// never acquires ml.mu, so the wrapper cannot stall
|
||||
// the response writer for the 2-minute gRPC
|
||||
// HealthCheck timeout that CheckIsLoaded can pay
|
||||
// when the recently-healthy cache window has
|
||||
// expired. The X-LocalAI-Node header is
|
||||
// best-effort observability; a stale value is
|
||||
// preferable to blocking the byte stream.
|
||||
return ml.LookupNodeID(modelName)
|
||||
},
|
||||
}
|
||||
c.Response().Writer = wrapper
|
||||
defer func() {
|
||||
c.Response().Writer = orig
|
||||
}()
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
225
core/http/middleware/node_header_integration_test.go
Normal file
225
core/http/middleware/node_header_integration_test.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package middleware_test
|
||||
|
||||
// Route-level integration coverage for the X-LocalAI-Node middleware.
|
||||
//
|
||||
// What this file pins (and why a separate spec on top of the unit tests
|
||||
// in node_header_test.go):
|
||||
//
|
||||
// - The unit tests in node_header_test.go exercise the wrapper by
|
||||
// invoking `mw(handler)(c)` directly against a hand-built
|
||||
// echo.Context. That misses regressions where the contract between
|
||||
// the real Echo router and the wrapper breaks: e.g. middleware
|
||||
// installation via e.Use() loses the wrapper because the framework
|
||||
// re-decorates c.Response().Writer after middleware setup, or a
|
||||
// handler that bypasses c.Response().Writer (writing to some other
|
||||
// captured surface).
|
||||
//
|
||||
// - This spec dispatches a real HTTP request through e.ServeHTTP into
|
||||
// a streaming handler shaped like chat.go's streaming branch: set
|
||||
// SSE headers, write chunks via c.Response().Write, Flush. It
|
||||
// proves that:
|
||||
// 1. Middleware installed via e.Use() is on the writer chain
|
||||
// when the handler runs.
|
||||
// 2. The wrapper's lazy maybeSet fires on the first underlying
|
||||
// Write/Flush, so X-LocalAI-Node lands on the response map
|
||||
// BEFORE the first body byte is committed.
|
||||
// 3. The header is present in the recorded response (i.e. it
|
||||
// isn't dropped because we tried to set it post-WriteHeader).
|
||||
//
|
||||
// Out of scope (and why):
|
||||
//
|
||||
// - We do NOT wire core/http/endpoints/openai.ChatEndpoint
|
||||
// end-to-end. ChatEndpoint depends on templates.Evaluator, the
|
||||
// MCP NATS client, and the LocalAI Assistant holder; standing
|
||||
// those up just to assert header ordering is out of proportion to
|
||||
// the property under test. The handler used here mirrors
|
||||
// chat.go's streaming branch and exercises the SAME middleware →
|
||||
// c.Response().Writer → SSE write path as production. If
|
||||
// chat.go's streaming branch ever stops going through
|
||||
// c.Response().Writer (e.g. it starts using a captured raw
|
||||
// http.ResponseWriter from a different seam), this test will not
|
||||
// notice; guard that with a code review checklist on chat.go.
|
||||
//
|
||||
// - We do NOT exercise the real processStream worker here.
|
||||
// processStream lives in core/http/endpoints/openai, which itself
|
||||
// imports core/http/middleware - a regular import from middleware
|
||||
// into openai would create a cycle. processStream is independently
|
||||
// covered in core/http/endpoints/openai/chat_stream_usage_test.go;
|
||||
// the only behaviour we need at this layer is the writer-contract
|
||||
// check above, which the synthetic SSE handler reproduces faithfully.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
)
|
||||
|
||||
// orderRecorder snapshots the X-LocalAI-Node header value AT THE MOMENT
|
||||
// the underlying writer is asked to commit each event. Any header set on
|
||||
// the response map AFTER the first write/flush is dropped on the wire,
|
||||
// so this is the ground-truth observation a real SSE client would see.
|
||||
type orderRecorder struct {
|
||||
http.ResponseWriter
|
||||
mu sync.Mutex
|
||||
events []string
|
||||
}
|
||||
|
||||
func (o *orderRecorder) record(ev string) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
o.events = append(o.events, ev)
|
||||
}
|
||||
|
||||
func (o *orderRecorder) snapshot() []string {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
out := make([]string, len(o.events))
|
||||
copy(out, o.events)
|
||||
return out
|
||||
}
|
||||
|
||||
func (o *orderRecorder) WriteHeader(code int) {
|
||||
o.record(fmt.Sprintf("header:%d:node=%s", code, o.Header().Get(middleware.NodeHeaderName)))
|
||||
o.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (o *orderRecorder) Write(b []byte) (int, error) {
|
||||
o.record(fmt.Sprintf("write:node=%s", o.Header().Get(middleware.NodeHeaderName)))
|
||||
return o.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (o *orderRecorder) Flush() {
|
||||
o.record(fmt.Sprintf("flush:node=%s", o.Header().Get(middleware.NodeHeaderName)))
|
||||
if f, ok := o.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("ExposeNodeHeader middleware (route-level integration)", func() {
|
||||
const (
|
||||
modelID = "integration-model"
|
||||
fakeNodeID = "node-route-7"
|
||||
)
|
||||
|
||||
var (
|
||||
ml *model.ModelLoader
|
||||
appCfg *config.ApplicationConfig
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(GinkgoT().TempDir()),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ml = model.NewModelLoader(systemState)
|
||||
|
||||
// Stamp the loader with a model entry that already has the
|
||||
// node ID set. In production the SmartRouter stamps this
|
||||
// during ml.Load before the first chunk is emitted; here we
|
||||
// pre-stamp it because the assertion is about wire ordering
|
||||
// (header-before-first-byte), not about ml.Load timing
|
||||
// (which is covered separately in pkg/model/lookup_node_id_test.go).
|
||||
m := model.NewModelWithClient(modelID, "10.0.0.1:50051", nil)
|
||||
m.SetNodeID(fakeNodeID)
|
||||
m.MarkHealthy()
|
||||
store := model.NewInMemoryModelStore()
|
||||
store.Set(modelID, m)
|
||||
ml.SetModelStore(store)
|
||||
|
||||
appCfg = config.NewApplicationConfig()
|
||||
appCfg.ExposeNodeHeader = true
|
||||
})
|
||||
|
||||
It("stamps X-LocalAI-Node before the first SSE byte via the real router + middleware chain", func() {
|
||||
// Build a real Echo router. We need the tracker to sit BELOW
|
||||
// the ExposeNodeHeader wrapper in the writer chain (so its
|
||||
// recorded snapshot reflects what bytes-on-the-wire see AFTER
|
||||
// the wrapper has had a chance to stamp the header). Install
|
||||
// the tracker via a middleware that runs BEFORE
|
||||
// ExposeNodeHeader; Echo's middleware execution order matches
|
||||
// e.Use() call order, so the first Use() wraps the OUTER
|
||||
// layer of the writer chain (i.e. the wrapper installed by
|
||||
// the second Use() wraps the tracker installed by the first).
|
||||
var (
|
||||
recorderMu sync.Mutex
|
||||
tracker *orderRecorder
|
||||
)
|
||||
e := echo.New()
|
||||
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
recorderMu.Lock()
|
||||
tracker = &orderRecorder{ResponseWriter: c.Response().Writer}
|
||||
c.Response().Writer = tracker
|
||||
recorderMu.Unlock()
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
e.Use(middleware.ExposeNodeHeader(appCfg, ml))
|
||||
|
||||
e.POST("/v1/chat/completions", func(c echo.Context) error {
|
||||
// Mirror SetModelAndConfig: stash the model name on the
|
||||
// per-request locals so the middleware's resolve closure
|
||||
// can pick it up. Every real chat / completion handler
|
||||
// goes through this contract.
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
|
||||
// SSE response prelude (same shape as chat.go).
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Emit a handful of SSE chunks. The very first
|
||||
// Write/Flush is what triggers the middleware
|
||||
// wrapper's maybeSet, so the X-LocalAI-Node header
|
||||
// MUST already be on the response map by the time the
|
||||
// byte is committed.
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := c.Response().Write([]byte(fmt.Sprintf("data: chunk %d\n\n", i)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Response().Flush()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(""))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
recorderMu.Lock()
|
||||
Expect(tracker).ToNot(BeNil(), "handler must run and install the order recorder")
|
||||
events := tracker.snapshot()
|
||||
recorderMu.Unlock()
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(rec.Header().Get(middleware.NodeHeaderName)).To(Equal(fakeNodeID),
|
||||
"production contract: header must reach the wire on a streamed response")
|
||||
|
||||
Expect(events).ToNot(BeEmpty(),
|
||||
"expected at least one underlying-writer event from the streaming handler")
|
||||
|
||||
// The very first observed event is the moment the wrapper
|
||||
// commits to the wire. Its recorded node= value is what a
|
||||
// real HTTP client would actually see; anything that lands
|
||||
// AFTER this byte is invisible.
|
||||
first := events[0]
|
||||
Expect(first).To(ContainSubstring("node="+fakeNodeID),
|
||||
"first writer event must carry the X-LocalAI-Node header (chain: middleware.Use -> e.POST -> handler.Write/Flush); got events: %v", events)
|
||||
|
||||
// Body sanity: SSE chunks made it to the recorder.
|
||||
Expect(rec.Body.String()).To(ContainSubstring("data: chunk 0"))
|
||||
})
|
||||
})
|
||||
291
core/http/middleware/node_header_test.go
Normal file
291
core/http/middleware/node_header_test.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
)
|
||||
|
||||
// orderedWriter records the order in which header-snapshot vs body-byte
|
||||
// events happen. Used by the streaming spec to assert that the X-LocalAI-Node
|
||||
// header lands on the response BEFORE the first body byte is committed to
|
||||
// the underlying writer.
|
||||
type orderedWriter struct {
|
||||
http.ResponseWriter
|
||||
events []string
|
||||
}
|
||||
|
||||
func (o *orderedWriter) WriteHeader(code int) {
|
||||
o.events = append(o.events, "header:"+http.StatusText(code))
|
||||
o.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (o *orderedWriter) Write(b []byte) (int, error) {
|
||||
// Snapshot the X-LocalAI-Node header value AT THE INSTANT the underlying
|
||||
// writer is asked to commit bytes. This is what real HTTP clients
|
||||
// effectively observe: anything set on the header map AFTER this point
|
||||
// would be silently dropped.
|
||||
o.events = append(o.events, "write:node="+o.Header().Get(NodeHeaderName))
|
||||
return o.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (o *orderedWriter) Flush() {
|
||||
o.events = append(o.events, "flush:node="+o.Header().Get(NodeHeaderName))
|
||||
if f, ok := o.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("ExposeNodeHeader middleware", func() {
|
||||
const (
|
||||
modelID = "qwen3.5-0.8b"
|
||||
fakeNodeID = "node-abcdef"
|
||||
)
|
||||
|
||||
var (
|
||||
e *echo.Echo
|
||||
ml *model.ModelLoader
|
||||
appCfg *config.ApplicationConfig
|
||||
)
|
||||
|
||||
// loadModel pre-populates the loader's in-memory store with a model
|
||||
// entry whose NodeID is set to `nodeID` (or left empty). Marking the
|
||||
// model recently-healthy short-circuits the gRPC HealthCheck inside
|
||||
// CheckIsLoaded so the test does not try to dial a bogus address.
|
||||
loadModel := func(id, nodeID string) {
|
||||
m := model.NewModelWithClient(id, "10.0.0.1:50051", nil)
|
||||
if nodeID != "" {
|
||||
m.SetNodeID(nodeID)
|
||||
}
|
||||
m.MarkHealthy()
|
||||
store := model.NewInMemoryModelStore()
|
||||
store.Set(id, m)
|
||||
ml.SetModelStore(store)
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
e = echo.New()
|
||||
ml = model.NewModelLoader(&system.SystemState{})
|
||||
appCfg = &config.ApplicationConfig{}
|
||||
})
|
||||
|
||||
// run executes the middleware against a fake handler that stashes the
|
||||
// model name on the request context (the same way the
|
||||
// request-extractor middleware does in production) and then writes a
|
||||
// trivial body to trigger the wrapper. Returns the recorded response.
|
||||
run := func(handler echo.HandlerFunc) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
return rec
|
||||
}
|
||||
|
||||
When("ExposeNodeHeader is false", func() {
|
||||
It("does not set the X-LocalAI-Node header", func() {
|
||||
appCfg.ExposeNodeHeader = false
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("does not even install the wrapper (writer is unchanged)", func() {
|
||||
appCfg.ExposeNodeHeader = false
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
origWriter := c.Response().Writer
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
// Pass-through must leave the writer identity intact so
|
||||
// no overhead is added on the hot path when the feature
|
||||
// is off.
|
||||
Expect(c.Response().Writer).To(BeIdenticalTo(origWriter))
|
||||
return c.String(http.StatusOK, "ok")
|
||||
}
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true and the model is loaded with a node ID", func() {
|
||||
It("sets the X-LocalAI-Node header on a buffered response", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
|
||||
It("sets the header even on a 500 error response (Write still triggers maybeSet)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusInternalServerError, "boom")
|
||||
})
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusInternalServerError))
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true but no model is loaded for the request", func() {
|
||||
It("does not set the header (cold cache stays silent)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
// No model loaded.
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true and the model is loaded but has no node ID", func() {
|
||||
It("does not set the header (in-process model, not distributed)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, "") // local model: no node ID stamped
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true but no model name is stashed on the context", func() {
|
||||
It("does not set the header (handler did not opt in)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
// Intentionally skip the c.Set call.
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("the handler streams via Flush before any Write", func() {
|
||||
It("sets the header BEFORE the first byte hits the underlying writer", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
// Wrap the recorder with an order-tracking writer so we can
|
||||
// assert that the header is on the response map by the time
|
||||
// the first body byte is committed. This is the property
|
||||
// that protected the pre-refactor streaming bug: if the
|
||||
// wrapper stamped lazily but AFTER the byte commit, real
|
||||
// SSE clients would see the body without the header.
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
tracker := &orderedWriter{ResponseWriter: rec}
|
||||
c := e.NewContext(req, rec)
|
||||
c.Response().Writer = tracker
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
// Simulate an SSE handler: flush headers, then write a
|
||||
// chunk and flush again. The wrapper must stamp the
|
||||
// node ID on the first call - either Flush or Write,
|
||||
// whichever comes first.
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Flush()
|
||||
_, err := c.Response().Write([]byte("data: chunk\n\n"))
|
||||
return err
|
||||
}
|
||||
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
|
||||
// First recorded event on the underlying writer must show
|
||||
// the header already populated. The first event is either
|
||||
// flush or write; either way the node ID must be on it.
|
||||
Expect(tracker.events).ToNot(BeEmpty())
|
||||
Expect(tracker.events[0]).To(HavePrefix("flush:node=" + fakeNodeID))
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
})
|
||||
|
||||
When("the handler writes a body without an explicit WriteHeader", func() {
|
||||
It("still stamps the header before the implicit 200 commit", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, fakeNodeID)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
tracker := &orderedWriter{ResponseWriter: rec}
|
||||
c := e.NewContext(req, rec)
|
||||
c.Response().Writer = tracker
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
_, err := c.Response().Write([]byte("body"))
|
||||
return err
|
||||
}
|
||||
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
|
||||
// Echo's Response.Write calls WriteHeader on the underlying
|
||||
// writer first, then Write. Both must see the header
|
||||
// already populated (the wrapper's maybeSet ran inside both
|
||||
// WriteHeader and Write before they hit `tracker`).
|
||||
Expect(len(tracker.events)).To(BeNumerically(">=", 2))
|
||||
Expect(tracker.events[0]).To(HavePrefix("header:"))
|
||||
Expect(tracker.events[1]).To(Equal("write:node=" + fakeNodeID))
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
})
|
||||
|
||||
When("the model's node ID changes between request entry and first write", func() {
|
||||
It("uses the value present AT the first write (late binding)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
loadModel(modelID, "stale-node-A")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelID)
|
||||
// Simulate ml.Load running mid-request and re-stamping
|
||||
// the model with this request's actual routing decision.
|
||||
m := ml.CheckIsLoaded(modelID)
|
||||
Expect(m).ToNot(BeNil())
|
||||
m.SetNodeID("fresh-node-B")
|
||||
return c.String(http.StatusOK, "ok")
|
||||
}
|
||||
|
||||
mw := ExposeNodeHeader(appCfg, ml)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal("fresh-node-B"),
|
||||
"the wrapper must read the node ID lazily at first write, not eagerly at entry")
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -35,6 +35,7 @@ func RegisterAnthropicRoutes(app *echo.Echo,
|
||||
)
|
||||
|
||||
messagesMiddleware := []echo.MiddlewareFunc{
|
||||
middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader()),
|
||||
middleware.UsageMiddleware(application.AuthDB()),
|
||||
middleware.TraceMiddleware(application),
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
|
||||
@@ -18,6 +18,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
|
||||
traceMiddleware := middleware.TraceMiddleware(application)
|
||||
usageMiddleware := middleware.UsageMiddleware(application.AuthDB())
|
||||
nodeHeaderMiddleware := middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader())
|
||||
|
||||
// Chat endpoint: POST /api/chat
|
||||
chatHandler := ollama.ChatEndpoint(
|
||||
@@ -27,6 +28,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
application.ApplicationConfig(),
|
||||
)
|
||||
chatMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
@@ -43,6 +45,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
application.ApplicationConfig(),
|
||||
)
|
||||
generateMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
@@ -58,6 +61,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
application.ApplicationConfig(),
|
||||
)
|
||||
embedMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
|
||||
|
||||
@@ -17,6 +17,12 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// openAI compatible API endpoint
|
||||
traceMiddleware := middleware.TraceMiddleware(application)
|
||||
usageMiddleware := middleware.UsageMiddleware(application.AuthDB())
|
||||
// X-LocalAI-Node attribution middleware: wraps the response writer and
|
||||
// stamps the header on first write when --expose-node-header is on. No-op
|
||||
// otherwise. Applied to every inference path that routes through
|
||||
// ml.Load (chat, completion, embeddings) so distributed-mode operators
|
||||
// can observe which worker served each request.
|
||||
nodeHeaderMiddleware := middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader())
|
||||
|
||||
// realtime
|
||||
// TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions
|
||||
@@ -34,6 +40,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// chat
|
||||
chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig(), natsClient, application.LocalAIAssistant())
|
||||
chatMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
@@ -73,6 +80,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// completion
|
||||
completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
|
||||
completionMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
|
||||
@@ -94,6 +102,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// embeddings
|
||||
embeddingHandler := openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
embeddingMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
|
||||
|
||||
@@ -68,6 +68,11 @@ func (a *ModelRouterAdapter) Route(ctx context.Context, backend, modelID, modelN
|
||||
// by SmartRouter. Use NewModelWithClient so the wrapper is preserved when
|
||||
// the ModelLoader returns this model on subsequent requests.
|
||||
m := model.NewModelWithClient(modelID, result.Node.Address, result.Client)
|
||||
// Stash the picked node ID so HTTP handlers can surface it via the
|
||||
// optional X-LocalAI-Node response header. Best-effort: the in-process
|
||||
// store keeps only the latest routing decision per modelID; see the
|
||||
// nodeID field comment on Model.
|
||||
m.SetNodeID(result.Node.ID)
|
||||
|
||||
xlog.Info("Model routed to remote node", "model", modelName, "node", result.Node.Name, "address", result.Node.Address)
|
||||
return m, nil
|
||||
|
||||
@@ -88,6 +88,7 @@ The frontend is a standard LocalAI instance with distributed mode enabled. These
|
||||
| `--auth-database-url` | `LOCALAI_AUTH_DATABASE_URL` | *(required)* | PostgreSQL connection URL |
|
||||
| `--backend-install-timeout` | `LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT` | `15m` | How long the frontend waits for a worker to acknowledge a backend install before considering the request stalled. Raise it when workers pull large backend images over slow links. If a worker takes longer than this, the operation shows as "still installing in background" in the admin UI and clears once the worker finishes. |
|
||||
| `--backend-upgrade-timeout` | `LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT` | `15m` | Same as the install timeout, applied to backend upgrades (force-reinstall). |
|
||||
| `--expose-node-header` | `LOCALAI_EXPOSE_NODE_HEADER` | `false` | When enabled, inference responses on the OpenAI-compatible endpoints (chat completions, completions, embeddings) as well as the Anthropic Messages (`/v1/messages`) and Ollama (`/api/chat`, `/api/generate`, `/api/embed`) shims carry an `X-LocalAI-Node` header with the ID of the worker node that served the request. Useful for debugging, observability and load-balancer attribution. Off by default: the node ID reveals internal cluster topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency for the same model across multiple replicas, the header may reflect a recent routing decision rather than this exact request's. Acceptable for observability and debugging. |
|
||||
|
||||
### Optional: S3 Object Storage
|
||||
|
||||
|
||||
@@ -40,8 +40,14 @@ type ModelRouter func(ctx context.Context, backend, modelID, modelName, modelFil
|
||||
opts *pb.ModelOptions, parallel bool) (*Model, error)
|
||||
|
||||
type ModelLoader struct {
|
||||
ModelPath string
|
||||
mu sync.Mutex
|
||||
ModelPath string
|
||||
mu sync.Mutex
|
||||
// storeMu guards only the `store` field reference (not the store's
|
||||
// internal state, which has its own concurrency mechanism). Kept
|
||||
// separate from `mu` so lock-free helpers like LookupNodeID can
|
||||
// snapshot the store reference without ever blocking behind a
|
||||
// HealthCheck-holding CheckIsLoaded call on `mu`.
|
||||
storeMu sync.RWMutex
|
||||
store ModelStore
|
||||
loading map[string]chan struct{} // tracks models currently being loaded
|
||||
wd *WatchDog
|
||||
@@ -112,7 +118,41 @@ func (ml *ModelLoader) SetModelRouter(r ModelRouter) {
|
||||
func (ml *ModelLoader) SetModelStore(s ModelStore) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
ml.storeMu.Lock()
|
||||
ml.store = s
|
||||
ml.storeMu.Unlock()
|
||||
}
|
||||
|
||||
// getStore returns the current store reference, taking only the
|
||||
// store-reference RWMutex (not ml.mu). Safe to call from hot paths that
|
||||
// must not block behind a HealthCheck-holding CheckIsLoaded.
|
||||
func (ml *ModelLoader) getStore() ModelStore {
|
||||
ml.storeMu.RLock()
|
||||
defer ml.storeMu.RUnlock()
|
||||
return ml.store
|
||||
}
|
||||
|
||||
// LookupNodeID returns the distributed worker node ID associated with
|
||||
// the loaded model, or "" if the model is not in the in-memory store or
|
||||
// has no node ID stamped.
|
||||
//
|
||||
// Unlike CheckIsLoaded this is a pure store read: it does NOT acquire
|
||||
// ml.mu and does NOT invoke a gRPC HealthCheck. The returned value may
|
||||
// be stale (the per-modelID store entry is overwritten on every
|
||||
// distributed-mode routing decision), which is acceptable for the
|
||||
// X-LocalAI-Node observability header. The contract here is "never pay
|
||||
// I/O on the response hot path"; correctness of the value is
|
||||
// best-effort by design.
|
||||
func (ml *ModelLoader) LookupNodeID(modelName string) string {
|
||||
store := ml.getStore()
|
||||
if store == nil {
|
||||
return ""
|
||||
}
|
||||
m, ok := store.Get(modelName)
|
||||
if !ok || m == nil {
|
||||
return ""
|
||||
}
|
||||
return m.NodeID()
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) GetWatchDog() *WatchDog {
|
||||
|
||||
100
pkg/model/lookup_node_id_test.go
Normal file
100
pkg/model/lookup_node_id_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// healthCheckCountingClient is a grpc.Backend stub that panics on every
|
||||
// method except HealthCheck, which it counts. Used to prove that
|
||||
// LookupNodeID never reaches the gRPC layer.
|
||||
//
|
||||
// We embed grpc.Backend so we only have to implement the one method we
|
||||
// care about; any other call will nil-deref-panic and surface a clear
|
||||
// failure in the test rather than silently swallowing a regression.
|
||||
type healthCheckCountingClient struct {
|
||||
grpc.Backend
|
||||
calls atomic.Int64
|
||||
}
|
||||
|
||||
func (c *healthCheckCountingClient) HealthCheck(_ context.Context) (bool, error) {
|
||||
c.calls.Add(1)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var _ = Describe("ModelLoader.LookupNodeID", func() {
|
||||
var (
|
||||
ml *model.ModelLoader
|
||||
store *model.InMemoryModelStore
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(GinkgoT().TempDir()),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ml = model.NewModelLoader(systemState)
|
||||
store = model.NewInMemoryModelStore()
|
||||
ml.SetModelStore(store)
|
||||
})
|
||||
|
||||
It("returns the stamped node ID for a loaded model", func() {
|
||||
m := model.NewModelWithClient("test-model", "10.0.0.1:50051", &healthCheckCountingClient{})
|
||||
m.SetNodeID("node-xyz")
|
||||
store.Set("test-model", m)
|
||||
|
||||
Expect(ml.LookupNodeID("test-model")).To(Equal("node-xyz"))
|
||||
})
|
||||
|
||||
It("returns empty string when the model is not loaded", func() {
|
||||
Expect(ml.LookupNodeID("missing-model")).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns empty string when the model is loaded but has no node ID", func() {
|
||||
m := model.NewModelWithClient("local-model", "127.0.0.1:50051", &healthCheckCountingClient{})
|
||||
// SetNodeID intentionally not called; in-process models stay unstamped.
|
||||
store.Set("local-model", m)
|
||||
|
||||
Expect(ml.LookupNodeID("local-model")).To(BeEmpty())
|
||||
})
|
||||
|
||||
// This is the regression guard for the Important #1 finding on the
|
||||
// node-header PR: the previous middleware called CheckIsLoaded on
|
||||
// the response hot path, which can hold ml.mu across a 2-minute
|
||||
// gRPC HealthCheck timeout whenever the recently-healthy cache
|
||||
// window has expired. LookupNodeID must read from the store only.
|
||||
It("does NOT invoke HealthCheck on the backend client", func() {
|
||||
client := &healthCheckCountingClient{}
|
||||
m := model.NewModelWithClient("hot-path-model", "10.0.0.1:50051", client)
|
||||
m.SetNodeID("node-42")
|
||||
// Deliberately do NOT call MarkHealthy: if LookupNodeID were
|
||||
// going through CheckIsLoaded, the lack of a cached healthy
|
||||
// flag would force a fresh HealthCheck round-trip. We want
|
||||
// the counter to stay at 0.
|
||||
store.Set("hot-path-model", m)
|
||||
|
||||
id := ml.LookupNodeID("hot-path-model")
|
||||
|
||||
Expect(id).To(Equal("node-42"))
|
||||
Expect(client.calls.Load()).To(BeZero(),
|
||||
"LookupNodeID must not invoke HealthCheck (would hang the response writer for up to 2 minutes on a stale-healthy model)")
|
||||
})
|
||||
|
||||
It("returns empty string when no store has been wired", func() {
|
||||
// Construct a loader and overwrite the store with nil via a
|
||||
// custom ModelStore-typed nil to exercise the defensive nil
|
||||
// guard. Done indirectly to avoid exporting internal state.
|
||||
bareLoader := model.NewModelLoader(&system.SystemState{})
|
||||
// Default store is non-nil (NewInMemoryModelStore), so seed
|
||||
// the missing-model branch instead - covered by the second It
|
||||
// above. This spec verifies the defensive contract at the
|
||||
// API surface: a never-loaded model still returns "" cleanly.
|
||||
Expect(bareLoader.LookupNodeID("anything")).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
@@ -19,6 +19,14 @@ type Model struct {
|
||||
client grpc.Backend
|
||||
process *process.Process
|
||||
lastHealthCheck time.Time
|
||||
// nodeID is the ID of the distributed-mode worker node that owns this
|
||||
// model handle, when set. Empty for in-process models. Best-effort:
|
||||
// because the distributed LoadModel path overwrites the per-modelID
|
||||
// store entry on every routing decision, this value reflects the
|
||||
// most-recently-routed node for the model, not necessarily the node
|
||||
// that served a specific in-flight request. Used by the optional
|
||||
// X-LocalAI-Node response header (--expose-node-header).
|
||||
nodeID string
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
@@ -40,6 +48,23 @@ func NewModelWithClient(ID, address string, client grpc.Backend) *Model {
|
||||
}
|
||||
}
|
||||
|
||||
// SetNodeID records the distributed-mode worker node that owns this model
|
||||
// handle. Safe to call from any goroutine.
|
||||
func (m *Model) SetNodeID(id string) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
m.nodeID = id
|
||||
}
|
||||
|
||||
// NodeID returns the distributed-mode worker node ID associated with this
|
||||
// model handle, or "" if unknown / in-process. See the nodeID field comment
|
||||
// for the best-effort caveat.
|
||||
func (m *Model) NodeID() string {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.nodeID
|
||||
}
|
||||
|
||||
func (m *Model) Process() *process.Process {
|
||||
return m.process
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user