mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-30 03:25:42 -04:00
refactor(openai): route node-header through middleware wrapper
Wire middleware.ExposeNodeHeader onto the OpenAI inference routes (chat, completions, embeddings) plus the Anthropic /v1/messages shim and the Ollama chat/generate/embed shims. The wrapper handles X-LocalAI-Node attribution from a single place, so the per-handler maybeSetNodeHeader calls and the per-request nodeIDCh rendezvous / applyNodeIDHeader plumbing in chat.go and completion.go are removed. For SSE: the wrapper's lazy stamp on the first Write / WriteHeader / Flush picks up the post-ml.Load node ID from the loader, replacing the chan signal the worker used to publish. The role=assistant first chunk emission stays where it is (inside the first token callback) so all writes still happen AFTER ml.Load has stamped the per-modelID node ID. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Assisted-by: Claude:claude-opus-4-7[1m]
This commit is contained in:
@@ -325,12 +325,12 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
c.Response().Header().Set("X-Correlation-ID", id)
|
||||
// X-LocalAI-Node attribution is deferred to AFTER ml.Load runs
|
||||
// inside the streaming worker. The worker publishes the picked
|
||||
// node ID onto nodeIDCh (per request, never shared), and we
|
||||
// read it and set the header below before the first SSE flush.
|
||||
// Setting it here would attach the previous request's routing
|
||||
// decision (or nothing on a cold cache), not this request's.
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on)
|
||||
// is handled by middleware.ExposeNodeHeader at the wrapper
|
||||
// layer: the first c.Response().Write / Flush lazily reads the
|
||||
// node ID from the loader (post-ml.Load) and stamps the header
|
||||
// before the byte hits the underlying writer. No per-request
|
||||
// chan / per-handler plumbing needed here.
|
||||
|
||||
mcpStreamMaxIterations := 10
|
||||
if config.Agent.MaxIterations > 0 {
|
||||
@@ -347,34 +347,13 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
ended := make(chan streamWorkerResult, 1)
|
||||
// Buffered size 1 so the worker's publish is always
|
||||
// non-blocking even if the handler hasn't drained yet.
|
||||
// One node-ID value per request: re-used across MCP loop
|
||||
// iterations is fine because the worker only sends on the
|
||||
// first token of the first iteration (or on the bail-out
|
||||
// path); later iterations share the same routing decision.
|
||||
nodeIDCh := make(chan string, 1)
|
||||
nodeIDApplied := false
|
||||
applyNodeIDHeader := func() {
|
||||
if nodeIDApplied {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case nodeID := <-nodeIDCh:
|
||||
nodeIDApplied = true
|
||||
if nodeID != "" {
|
||||
c.Response().Header().Set(NodeHeaderName, nodeID)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
if !shouldUseFn {
|
||||
u, err := processStream(predInput, input, config, cl, startupOptions, ml, responses, id, created, nodeIDCh)
|
||||
u, err := processStream(predInput, input, config, cl, startupOptions, ml, responses, id, created)
|
||||
ended <- streamWorkerResult{usage: u, err: err}
|
||||
} else {
|
||||
u, err := processStreamWithTools(noActionName, predInput, input, config, cl, startupOptions, ml, responses, id, created, &textContentToReturn, nodeIDCh)
|
||||
u, err := processStreamWithTools(noActionName, predInput, input, config, cl, startupOptions, ml, responses, id, created, &textContentToReturn)
|
||||
ended <- streamWorkerResult{usage: u, err: err}
|
||||
}
|
||||
}()
|
||||
@@ -419,12 +398,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
continue
|
||||
}
|
||||
xlog.Debug("Sending chunk", "chunk", string(respData))
|
||||
// Attach X-LocalAI-Node from the worker's per-request
|
||||
// signal before the very first Flush() locks the
|
||||
// headers. The worker pushes on nodeIDCh AFTER
|
||||
// ml.Load returns, so by the time we hit our first
|
||||
// chunk read the signal is already buffered.
|
||||
applyNodeIDHeader()
|
||||
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
|
||||
if err != nil {
|
||||
xlog.Debug("Sending chunk failed", "error", err)
|
||||
@@ -449,10 +422,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
respData, marshalErr := json.Marshal(errorResp)
|
||||
if marshalErr != nil {
|
||||
xlog.Error("Failed to marshal error response", "error", marshalErr)
|
||||
applyNodeIDHeader()
|
||||
fmt.Fprintf(c.Response().Writer, "data: {\"error\":{\"message\":\"Internal error\",\"type\":\"server_error\"}}\n\n")
|
||||
} else {
|
||||
applyNodeIDHeader()
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
}
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
@@ -517,7 +488,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
"result": toolResult,
|
||||
}
|
||||
if mcpEventData, err := json.Marshal(mcpEvent); err == nil {
|
||||
applyNodeIDHeader()
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", mcpEventData)
|
||||
c.Response().Flush()
|
||||
}
|
||||
@@ -559,7 +529,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
respData, _ := json.Marshal(toolCallMsg)
|
||||
applyNodeIDHeader()
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
c.Response().Flush()
|
||||
toolsCalled = true
|
||||
@@ -590,7 +559,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
applyNodeIDHeader()
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
|
||||
// Trailing usage chunk per OpenAI spec: emit only when the
|
||||
@@ -609,7 +577,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
_, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", trailer)
|
||||
}
|
||||
|
||||
applyNodeIDHeader()
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
xlog.Debug("Stream ended")
|
||||
@@ -618,8 +585,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
|
||||
// Safety fallback. The MCP iteration loop above always returns,
|
||||
// so this is structurally unreachable; if we ever reach it the
|
||||
// per-iteration nodeIDCh has gone out of scope and there is no
|
||||
// header to attach. Just close the stream cleanly.
|
||||
// stream is closed cleanly. The middleware-installed wrapper
|
||||
// still stamps X-LocalAI-Node on this final write if applicable.
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
@@ -977,10 +944,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
respData, _ := json.Marshal(resp)
|
||||
xlog.Debug("Response", "response", string(respData))
|
||||
|
||||
// Attribute the response to a specific worker node when
|
||||
// distributed mode is enabled and the operator opted in via
|
||||
// --expose-node-header. No-op otherwise.
|
||||
maybeSetNodeHeader(c, startupOptions, ml, input.Model)
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on)
|
||||
// is handled by middleware.ExposeNodeHeader at the wrapper
|
||||
// layer; c.JSON's writes trigger the lazy stamp.
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
|
||||
@@ -26,32 +26,13 @@ import (
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
// process runs the streaming inference and (when nodeIDCh is non-nil)
|
||||
// signals the picked distributed node ID back to the caller after
|
||||
// ml.Load has run. The handler reads nodeIDCh before the first SSE
|
||||
// flush so X-LocalAI-Node attribution reflects THIS request's routing
|
||||
// decision, not the previous one. nodeIDCh must be buffered (>=1).
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool, nodeIDCh chan<- string) error {
|
||||
nodeIDSent := false
|
||||
publishNodeID := func() {
|
||||
if nodeIDSent {
|
||||
return
|
||||
}
|
||||
nodeIDSent = true
|
||||
if nodeIDCh == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case nodeIDCh <- resolveNodeID(appConfig, loader, req.Model):
|
||||
default:
|
||||
}
|
||||
}
|
||||
// process runs the streaming inference. X-LocalAI-Node attribution
|
||||
// (when --expose-node-header is on) is handled by
|
||||
// middleware.ExposeNodeHeader at the response writer wrapper layer:
|
||||
// the first SSE write triggers a lazy lookup against the loader, so
|
||||
// no in-band signalling is needed here.
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
|
||||
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
// First callback fires only after ml.Load has returned, so the
|
||||
// loader's per-modelID node stamp is populated. Push the node
|
||||
// ID before any SSE chunk leaves so the handler can attach the
|
||||
// X-LocalAI-Node header before its first flush.
|
||||
publishNodeID()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
usage := schema.OpenAIUsage{
|
||||
@@ -87,10 +68,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
return true
|
||||
}
|
||||
_, _, _, err := ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
|
||||
// Backend produced no tokens (immediate error/empty path): publish
|
||||
// whatever node ID the loader has so the handler doesn't wait
|
||||
// forever and the header is still set if applicable.
|
||||
publishNodeID()
|
||||
close(responses)
|
||||
return err
|
||||
}
|
||||
@@ -134,12 +111,11 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
// X-LocalAI-Node attribution is deferred to AFTER ml.Load runs
|
||||
// inside the streaming worker. The worker publishes the picked
|
||||
// node ID onto nodeIDCh (per request, never shared), and we
|
||||
// read it and set the header below before the first SSE flush.
|
||||
// Setting it here would attach the previous request's routing
|
||||
// decision (or nothing on a cold cache), not this request's.
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on)
|
||||
// is handled by middleware.ExposeNodeHeader at the wrapper
|
||||
// layer: the first c.Response().Write / Flush lazily reads the
|
||||
// node ID from the loader (post-ml.Load) and stamps the header
|
||||
// before the byte hits the underlying writer.
|
||||
|
||||
if len(config.PromptStrings) > 1 {
|
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||
@@ -159,26 +135,10 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
}
|
||||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
// Buffered so the worker's publish is always non-blocking.
|
||||
nodeIDCh := make(chan string, 1)
|
||||
nodeIDApplied := false
|
||||
applyNodeIDHeader := func() {
|
||||
if nodeIDApplied {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case nodeID := <-nodeIDCh:
|
||||
nodeIDApplied = true
|
||||
if nodeID != "" {
|
||||
c.Response().Header().Set(NodeHeaderName, nodeID)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
ended := make(chan error)
|
||||
go func() {
|
||||
ended <- process(id, predInput, input, config, ml, responses, extraUsage, nodeIDCh)
|
||||
ended <- process(id, predInput, input, config, ml, responses, extraUsage)
|
||||
}()
|
||||
|
||||
var latestUsage *schema.OpenAIUsage
|
||||
@@ -206,12 +166,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
}
|
||||
|
||||
xlog.Debug("Sending chunk", "chunk", string(respData))
|
||||
// Attach X-LocalAI-Node from the worker's per-request
|
||||
// signal before the very first Flush() locks the
|
||||
// headers. The worker pushes on nodeIDCh AFTER
|
||||
// ml.Load returns, so by the time we hit our first
|
||||
// chunk read the signal is already buffered.
|
||||
applyNodeIDHeader()
|
||||
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -238,7 +192,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
Object: "text_completion",
|
||||
}
|
||||
errorData, marshalErr := json.Marshal(errorResp)
|
||||
applyNodeIDHeader()
|
||||
if marshalErr != nil {
|
||||
xlog.Error("Failed to marshal error response", "error", marshalErr)
|
||||
// Send a simple error message as fallback
|
||||
@@ -265,7 +218,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
Object: "text_completion",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
applyNodeIDHeader()
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
|
||||
// Trailing usage chunk per OpenAI spec: emit only when the caller
|
||||
@@ -332,10 +284,8 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
xlog.Debug("Response", "response", string(jsonResult))
|
||||
|
||||
// Attribute the response to a specific worker node when distributed
|
||||
// mode is enabled and the operator opted in via --expose-node-header.
|
||||
// No-op otherwise.
|
||||
maybeSetNodeHeader(c, appConfig, ml, input.Model)
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on) is
|
||||
// handled by middleware.ExposeNodeHeader at the wrapper layer.
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
|
||||
@@ -102,10 +102,8 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
xlog.Debug("Response", "response", string(jsonResult))
|
||||
|
||||
// Attribute the response to a specific worker node when distributed
|
||||
// mode is enabled and the operator opted in via --expose-node-header.
|
||||
// No-op otherwise.
|
||||
maybeSetNodeHeader(c, appConfig, ml, input.Model)
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on) is
|
||||
// handled by middleware.ExposeNodeHeader at the wrapper layer.
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
|
||||
@@ -35,6 +35,7 @@ func RegisterAnthropicRoutes(app *echo.Echo,
|
||||
)
|
||||
|
||||
messagesMiddleware := []echo.MiddlewareFunc{
|
||||
middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader()),
|
||||
middleware.UsageMiddleware(application.AuthDB()),
|
||||
middleware.TraceMiddleware(application),
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
|
||||
@@ -18,6 +18,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
|
||||
traceMiddleware := middleware.TraceMiddleware(application)
|
||||
usageMiddleware := middleware.UsageMiddleware(application.AuthDB())
|
||||
nodeHeaderMiddleware := middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader())
|
||||
|
||||
// Chat endpoint: POST /api/chat
|
||||
chatHandler := ollama.ChatEndpoint(
|
||||
@@ -27,6 +28,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
application.ApplicationConfig(),
|
||||
)
|
||||
chatMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
@@ -43,6 +45,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
application.ApplicationConfig(),
|
||||
)
|
||||
generateMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
@@ -58,6 +61,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
application.ApplicationConfig(),
|
||||
)
|
||||
embedMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
|
||||
|
||||
@@ -17,6 +17,12 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// openAI compatible API endpoint
|
||||
traceMiddleware := middleware.TraceMiddleware(application)
|
||||
usageMiddleware := middleware.UsageMiddleware(application.AuthDB())
|
||||
// X-LocalAI-Node attribution middleware: wraps the response writer and
|
||||
// stamps the header on first write when --expose-node-header is on. No-op
|
||||
// otherwise. Applied to every inference path that routes through
|
||||
// ml.Load (chat, completion, embeddings) so distributed-mode operators
|
||||
// can observe which worker served each request.
|
||||
nodeHeaderMiddleware := middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader())
|
||||
|
||||
// realtime
|
||||
// TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions
|
||||
@@ -34,6 +40,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// chat
|
||||
chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig(), natsClient, application.LocalAIAssistant())
|
||||
chatMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
@@ -73,6 +80,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// completion
|
||||
completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
|
||||
completionMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
|
||||
@@ -94,6 +102,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// embeddings
|
||||
embeddingHandler := openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
embeddingMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
|
||||
|
||||
Reference in New Issue
Block a user