refactor(openai): route node-header through middleware wrapper

Wire middleware.ExposeNodeHeader onto the OpenAI inference routes
(chat, completions, embeddings) plus the Anthropic /v1/messages shim
and the Ollama chat/generate/embed shims. The wrapper handles
X-LocalAI-Node attribution from a single place, so the per-handler
maybeSetNodeHeader calls and the per-request nodeIDCh rendezvous /
applyNodeIDHeader plumbing in chat.go and completion.go are removed.

For SSE: the wrapper lazily stamps the header on the first Write /
WriteHeader / Flush, reading the post-ml.Load node ID from the loader.
This replaces the per-request channel the worker used to publish the
node ID on. The role=assistant first-chunk emission stays where it is
(inside the first token callback), so all writes still happen AFTER
ml.Load has stamped the per-modelID node ID.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7[1m]
This commit is contained in:
Ettore Di Giacinto
2026-05-24 21:23:04 +00:00
parent 799215cdc6
commit 1c4bdfd1d6
6 changed files with 43 additions and 115 deletions

View File

@@ -325,12 +325,12 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
c.Response().Header().Set("X-Correlation-ID", id)
// X-LocalAI-Node attribution is deferred to AFTER ml.Load runs
// inside the streaming worker. The worker publishes the picked
// node ID onto nodeIDCh (per request, never shared), and we
// read it and set the header below before the first SSE flush.
// Setting it here would attach the previous request's routing
// decision (or nothing on a cold cache), not this request's.
// X-LocalAI-Node attribution (when --expose-node-header is on)
// is handled by middleware.ExposeNodeHeader at the wrapper
// layer: the first c.Response().Write / Flush lazily reads the
// node ID from the loader (post-ml.Load) and stamps the header
// before the byte hits the underlying writer. No per-request
// chan / per-handler plumbing needed here.
mcpStreamMaxIterations := 10
if config.Agent.MaxIterations > 0 {
@@ -347,34 +347,13 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
responses := make(chan schema.OpenAIResponse)
ended := make(chan streamWorkerResult, 1)
// Buffered size 1 so the worker's publish is always
// non-blocking even if the handler hasn't drained yet.
// One node-ID value per request: re-used across MCP loop
// iterations is fine because the worker only sends on the
// first token of the first iteration (or on the bail-out
// path); later iterations share the same routing decision.
nodeIDCh := make(chan string, 1)
nodeIDApplied := false
applyNodeIDHeader := func() {
if nodeIDApplied {
return
}
select {
case nodeID := <-nodeIDCh:
nodeIDApplied = true
if nodeID != "" {
c.Response().Header().Set(NodeHeaderName, nodeID)
}
default:
}
}
go func() {
if !shouldUseFn {
u, err := processStream(predInput, input, config, cl, startupOptions, ml, responses, id, created, nodeIDCh)
u, err := processStream(predInput, input, config, cl, startupOptions, ml, responses, id, created)
ended <- streamWorkerResult{usage: u, err: err}
} else {
u, err := processStreamWithTools(noActionName, predInput, input, config, cl, startupOptions, ml, responses, id, created, &textContentToReturn, nodeIDCh)
u, err := processStreamWithTools(noActionName, predInput, input, config, cl, startupOptions, ml, responses, id, created, &textContentToReturn)
ended <- streamWorkerResult{usage: u, err: err}
}
}()
@@ -419,12 +398,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
continue
}
xlog.Debug("Sending chunk", "chunk", string(respData))
// Attach X-LocalAI-Node from the worker's per-request
// signal before the very first Flush() locks the
// headers. The worker pushes on nodeIDCh AFTER
// ml.Load returns, so by the time we hit our first
// chunk read the signal is already buffered.
applyNodeIDHeader()
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
if err != nil {
xlog.Debug("Sending chunk failed", "error", err)
@@ -449,10 +422,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
respData, marshalErr := json.Marshal(errorResp)
if marshalErr != nil {
xlog.Error("Failed to marshal error response", "error", marshalErr)
applyNodeIDHeader()
fmt.Fprintf(c.Response().Writer, "data: {\"error\":{\"message\":\"Internal error\",\"type\":\"server_error\"}}\n\n")
} else {
applyNodeIDHeader()
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
}
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
@@ -517,7 +488,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
"result": toolResult,
}
if mcpEventData, err := json.Marshal(mcpEvent); err == nil {
applyNodeIDHeader()
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", mcpEventData)
c.Response().Flush()
}
@@ -559,7 +529,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
Object: "chat.completion.chunk",
}
respData, _ := json.Marshal(toolCallMsg)
applyNodeIDHeader()
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
c.Response().Flush()
toolsCalled = true
@@ -590,7 +559,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
Object: "chat.completion.chunk",
}
respData, _ := json.Marshal(resp)
applyNodeIDHeader()
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
// Trailing usage chunk per OpenAI spec: emit only when the
@@ -609,7 +577,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
_, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", trailer)
}
applyNodeIDHeader()
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
xlog.Debug("Stream ended")
@@ -618,8 +585,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
// Safety fallback. The MCP iteration loop above always returns,
// so this is structurally unreachable; if we ever reach it the
// per-iteration nodeIDCh has gone out of scope and there is no
// header to attach. Just close the stream cleanly.
// stream is closed cleanly. The middleware-installed wrapper
// still stamps X-LocalAI-Node on this final write if applicable.
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
return nil
@@ -977,10 +944,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
respData, _ := json.Marshal(resp)
xlog.Debug("Response", "response", string(respData))
// Attribute the response to a specific worker node when
// distributed mode is enabled and the operator opted in via
// --expose-node-header. No-op otherwise.
maybeSetNodeHeader(c, startupOptions, ml, input.Model)
// X-LocalAI-Node attribution (when --expose-node-header is on)
// is handled by middleware.ExposeNodeHeader at the wrapper
// layer; c.JSON's writes trigger the lazy stamp.
// Return the prediction in the response body
return c.JSON(200, resp)

View File

@@ -26,32 +26,13 @@ import (
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/completions [post]
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
// process runs the streaming inference and (when nodeIDCh is non-nil)
// signals the picked distributed node ID back to the caller after
// ml.Load has run. The handler reads nodeIDCh before the first SSE
// flush so X-LocalAI-Node attribution reflects THIS request's routing
// decision, not the previous one. nodeIDCh must be buffered (>=1).
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool, nodeIDCh chan<- string) error {
nodeIDSent := false
publishNodeID := func() {
if nodeIDSent {
return
}
nodeIDSent = true
if nodeIDCh == nil {
return
}
select {
case nodeIDCh <- resolveNodeID(appConfig, loader, req.Model):
default:
}
}
// process runs the streaming inference. X-LocalAI-Node attribution
// (when --expose-node-header is on) is handled by
// middleware.ExposeNodeHeader at the response writer wrapper layer:
// the first SSE write triggers a lazy lookup against the loader, so
// no in-band signalling is needed here.
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
// First callback fires only after ml.Load has returned, so the
// loader's per-modelID node stamp is populated. Push the node
// ID before any SSE chunk leaves so the handler can attach the
// X-LocalAI-Node header before its first flush.
publishNodeID()
created := int(time.Now().Unix())
usage := schema.OpenAIUsage{
@@ -87,10 +68,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
return true
}
_, _, _, err := ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
// Backend produced no tokens (immediate error/empty path): publish
// whatever node ID the loader has so the handler doesn't wait
// forever and the header is still set if applicable.
publishNodeID()
close(responses)
return err
}
@@ -134,12 +111,11 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
// X-LocalAI-Node attribution is deferred to AFTER ml.Load runs
// inside the streaming worker. The worker publishes the picked
// node ID onto nodeIDCh (per request, never shared), and we
// read it and set the header below before the first SSE flush.
// Setting it here would attach the previous request's routing
// decision (or nothing on a cold cache), not this request's.
// X-LocalAI-Node attribution (when --expose-node-header is on)
// is handled by middleware.ExposeNodeHeader at the wrapper
// layer: the first c.Response().Write / Flush lazily reads the
// node ID from the loader (post-ml.Load) and stamps the header
// before the byte hits the underlying writer.
if len(config.PromptStrings) > 1 {
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
@@ -159,26 +135,10 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
}
responses := make(chan schema.OpenAIResponse)
// Buffered so the worker's publish is always non-blocking.
nodeIDCh := make(chan string, 1)
nodeIDApplied := false
applyNodeIDHeader := func() {
if nodeIDApplied {
return
}
select {
case nodeID := <-nodeIDCh:
nodeIDApplied = true
if nodeID != "" {
c.Response().Header().Set(NodeHeaderName, nodeID)
}
default:
}
}
ended := make(chan error)
go func() {
ended <- process(id, predInput, input, config, ml, responses, extraUsage, nodeIDCh)
ended <- process(id, predInput, input, config, ml, responses, extraUsage)
}()
var latestUsage *schema.OpenAIUsage
@@ -206,12 +166,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
}
xlog.Debug("Sending chunk", "chunk", string(respData))
// Attach X-LocalAI-Node from the worker's per-request
// signal before the very first Flush() locks the
// headers. The worker pushes on nodeIDCh AFTER
// ml.Load returns, so by the time we hit our first
// chunk read the signal is already buffered.
applyNodeIDHeader()
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
if err != nil {
return err
@@ -238,7 +192,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
Object: "text_completion",
}
errorData, marshalErr := json.Marshal(errorResp)
applyNodeIDHeader()
if marshalErr != nil {
xlog.Error("Failed to marshal error response", "error", marshalErr)
// Send a simple error message as fallback
@@ -265,7 +218,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
Object: "text_completion",
}
respData, _ := json.Marshal(resp)
applyNodeIDHeader()
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
// Trailing usage chunk per OpenAI spec: emit only when the caller
@@ -332,10 +284,8 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
jsonResult, _ := json.Marshal(resp)
xlog.Debug("Response", "response", string(jsonResult))
// Attribute the response to a specific worker node when distributed
// mode is enabled and the operator opted in via --expose-node-header.
// No-op otherwise.
maybeSetNodeHeader(c, appConfig, ml, input.Model)
// X-LocalAI-Node attribution (when --expose-node-header is on) is
// handled by middleware.ExposeNodeHeader at the wrapper layer.
// Return the prediction in the response body
return c.JSON(200, resp)

View File

@@ -102,10 +102,8 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
jsonResult, _ := json.Marshal(resp)
xlog.Debug("Response", "response", string(jsonResult))
// Attribute the response to a specific worker node when distributed
// mode is enabled and the operator opted in via --expose-node-header.
// No-op otherwise.
maybeSetNodeHeader(c, appConfig, ml, input.Model)
// X-LocalAI-Node attribution (when --expose-node-header is on) is
// handled by middleware.ExposeNodeHeader at the wrapper layer.
// Return the prediction in the response body
return c.JSON(200, resp)

View File

@@ -35,6 +35,7 @@ func RegisterAnthropicRoutes(app *echo.Echo,
)
messagesMiddleware := []echo.MiddlewareFunc{
middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader()),
middleware.UsageMiddleware(application.AuthDB()),
middleware.TraceMiddleware(application),
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),

View File

@@ -18,6 +18,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
traceMiddleware := middleware.TraceMiddleware(application)
usageMiddleware := middleware.UsageMiddleware(application.AuthDB())
nodeHeaderMiddleware := middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader())
// Chat endpoint: POST /api/chat
chatHandler := ollama.ChatEndpoint(
@@ -27,6 +28,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
application.ApplicationConfig(),
)
chatMiddleware := []echo.MiddlewareFunc{
nodeHeaderMiddleware,
usageMiddleware,
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
@@ -43,6 +45,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
application.ApplicationConfig(),
)
generateMiddleware := []echo.MiddlewareFunc{
nodeHeaderMiddleware,
usageMiddleware,
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
@@ -58,6 +61,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
application.ApplicationConfig(),
)
embedMiddleware := []echo.MiddlewareFunc{
nodeHeaderMiddleware,
usageMiddleware,
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),

View File

@@ -17,6 +17,12 @@ func RegisterOpenAIRoutes(app *echo.Echo,
// openAI compatible API endpoint
traceMiddleware := middleware.TraceMiddleware(application)
usageMiddleware := middleware.UsageMiddleware(application.AuthDB())
// X-LocalAI-Node attribution middleware: wraps the response writer and
// stamps the header on first write when --expose-node-header is on. No-op
// otherwise. Applied to every inference path that routes through
// ml.Load (chat, completion, embeddings) so distributed-mode operators
// can observe which worker served each request.
nodeHeaderMiddleware := middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader())
// realtime
// TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions
@@ -34,6 +40,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
// chat
chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig(), natsClient, application.LocalAIAssistant())
chatMiddleware := []echo.MiddlewareFunc{
nodeHeaderMiddleware,
usageMiddleware,
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
@@ -73,6 +80,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
// completion
completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
completionMiddleware := []echo.MiddlewareFunc{
nodeHeaderMiddleware,
usageMiddleware,
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
@@ -94,6 +102,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
// embeddings
embeddingHandler := openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
embeddingMiddleware := []echo.MiddlewareFunc{
nodeHeaderMiddleware,
usageMiddleware,
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),