diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index b7531349d..7f2c6ff80 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -325,12 +325,12 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator c.Response().Header().Set("Cache-Control", "no-cache") c.Response().Header().Set("Connection", "keep-alive") c.Response().Header().Set("X-Correlation-ID", id) - // X-LocalAI-Node attribution is deferred to AFTER ml.Load runs - // inside the streaming worker. The worker publishes the picked - // node ID onto nodeIDCh (per request, never shared), and we - // read it and set the header below before the first SSE flush. - // Setting it here would attach the previous request's routing - // decision (or nothing on a cold cache), not this request's. + // X-LocalAI-Node attribution (when --expose-node-header is on) + // is handled by middleware.ExposeNodeHeader at the wrapper + // layer: the first c.Response().Write / Flush lazily reads the + // node ID from the loader (post-ml.Load) and stamps the header + // before the byte hits the underlying writer. No per-request + // chan / per-handler plumbing needed here. mcpStreamMaxIterations := 10 if config.Agent.MaxIterations > 0 { @@ -347,34 +347,13 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator responses := make(chan schema.OpenAIResponse) ended := make(chan streamWorkerResult, 1) - // Buffered size 1 so the worker's publish is always - // non-blocking even if the handler hasn't drained yet. - // One node-ID value per request: re-used across MCP loop - // iterations is fine because the worker only sends on the - // first token of the first iteration (or on the bail-out - // path); later iterations share the same routing decision. - nodeIDCh := make(chan string, 1) - nodeIDApplied := false - applyNodeIDHeader := func() { - if nodeIDApplied { - return - } - select { - case nodeID := <-nodeIDCh: - nodeIDApplied = true - if nodeID != "" { - c.Response().Header().Set(NodeHeaderName, nodeID) - } - default: - } - } go func() { if !shouldUseFn { - u, err := processStream(predInput, input, config, cl, startupOptions, ml, responses, id, created, nodeIDCh) + u, err := processStream(predInput, input, config, cl, startupOptions, ml, responses, id, created) ended <- streamWorkerResult{usage: u, err: err} } else { - u, err := processStreamWithTools(noActionName, predInput, input, config, cl, startupOptions, ml, responses, id, created, &textContentToReturn, nodeIDCh) + u, err := processStreamWithTools(noActionName, predInput, input, config, cl, startupOptions, ml, responses, id, created, &textContentToReturn) ended <- streamWorkerResult{usage: u, err: err} } }() @@ -419,12 +398,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator continue } xlog.Debug("Sending chunk", "chunk", string(respData)) - // Attach X-LocalAI-Node from the worker's per-request - // signal before the very first Flush() locks the - // headers. The worker pushes on nodeIDCh AFTER - // ml.Load returns, so by the time we hit our first - // chunk read the signal is already buffered. - applyNodeIDHeader() _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData)) if err != nil { xlog.Debug("Sending chunk failed", "error", err) @@ -449,10 +422,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator respData, marshalErr := json.Marshal(errorResp) if marshalErr != nil { xlog.Error("Failed to marshal error response", "error", marshalErr) - applyNodeIDHeader() fmt.Fprintf(c.Response().Writer, "data: {\"error\":{\"message\":\"Internal error\",\"type\":\"server_error\"}}\n\n") } else { - applyNodeIDHeader() fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) } fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") @@ -517,7 +488,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator "result": toolResult, } if mcpEventData, err := json.Marshal(mcpEvent); err == nil { - applyNodeIDHeader() fmt.Fprintf(c.Response().Writer, "data: %s\n\n", mcpEventData) c.Response().Flush() } @@ -559,7 +529,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator Object: "chat.completion.chunk", } respData, _ := json.Marshal(toolCallMsg) - applyNodeIDHeader() fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) c.Response().Flush() toolsCalled = true @@ -590,7 +559,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator Object: "chat.completion.chunk", } respData, _ := json.Marshal(resp) - applyNodeIDHeader() fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) // Trailing usage chunk per OpenAI spec: emit only when the @@ -609,7 +577,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator _, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", trailer) } - applyNodeIDHeader() fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() xlog.Debug("Stream ended") @@ -618,8 +585,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator // Safety fallback. The MCP iteration loop above always returns, // so this is structurally unreachable; if we ever reach it the - // per-iteration nodeIDCh has gone out of scope and there is no - // header to attach. Just close the stream cleanly. + // stream is closed cleanly. The middleware-installed wrapper + // still stamps X-LocalAI-Node on this final write if applicable. fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() return nil @@ -977,10 +944,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator respData, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(respData)) - // Attribute the response to a specific worker node when - // distributed mode is enabled and the operator opted in via - // --expose-node-header. No-op otherwise. - maybeSetNodeHeader(c, startupOptions, ml, input.Model) + // X-LocalAI-Node attribution (when --expose-node-header is on) + // is handled by middleware.ExposeNodeHeader at the wrapper + // layer; c.JSON's writes trigger the lazy stamp. // Return the prediction in the response body return c.JSON(200, resp) diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go index 593f00015..c2a3f45f2 100644 --- a/core/http/endpoints/openai/completion.go +++ b/core/http/endpoints/openai/completion.go @@ -26,32 +26,13 @@ import ( // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/completions [post] func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { - // process runs the streaming inference and (when nodeIDCh is non-nil) - // signals the picked distributed node ID back to the caller after - // ml.Load has run. The handler reads nodeIDCh before the first SSE - // flush so X-LocalAI-Node attribution reflects THIS request's routing - // decision, not the previous one. nodeIDCh must be buffered (>=1). - process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool, nodeIDCh chan<- string) error { - nodeIDSent := false - publishNodeID := func() { - if nodeIDSent { - return - } - nodeIDSent = true - if nodeIDCh == nil { - return - } - select { - case nodeIDCh <- resolveNodeID(appConfig, loader, req.Model): - default: - } - } + // process runs the streaming inference. X-LocalAI-Node attribution + // (when --expose-node-header is on) is handled by + // middleware.ExposeNodeHeader at the response writer wrapper layer: + // the first SSE write triggers a lazy lookup against the loader, so + // no in-band signalling is needed here. + process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error { tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool { - // First callback fires only after ml.Load has returned, so the - // loader's per-modelID node stamp is populated. Push the node - // ID before any SSE chunk leaves so the handler can attach the - // X-LocalAI-Node header before its first flush. - publishNodeID() created := int(time.Now().Unix()) usage := schema.OpenAIUsage{ @@ -87,10 +68,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva return true } _, _, _, err := ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback) - // Backend produced no tokens (immediate error/empty path): publish - // whatever node ID the loader has so the handler doesn't wait - // forever and the header is still set if applicable. - publishNodeID() close(responses) return err } @@ -134,12 +111,11 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Cache-Control", "no-cache") c.Response().Header().Set("Connection", "keep-alive") - // X-LocalAI-Node attribution is deferred to AFTER ml.Load runs - // inside the streaming worker. The worker publishes the picked - // node ID onto nodeIDCh (per request, never shared), and we - // read it and set the header below before the first SSE flush. - // Setting it here would attach the previous request's routing - // decision (or nothing on a cold cache), not this request's. + // X-LocalAI-Node attribution (when --expose-node-header is on) + // is handled by middleware.ExposeNodeHeader at the wrapper + // layer: the first c.Response().Write / Flush lazily reads the + // node ID from the loader (post-ml.Load) and stamps the header + // before the byte hits the underlying writer. if len(config.PromptStrings) > 1 { return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") @@ -159,26 +135,10 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva } responses := make(chan schema.OpenAIResponse) - // Buffered so the worker's publish is always non-blocking. - nodeIDCh := make(chan string, 1) - nodeIDApplied := false - applyNodeIDHeader := func() { - if nodeIDApplied { - return - } - select { - case nodeID := <-nodeIDCh: - nodeIDApplied = true - if nodeID != "" { - c.Response().Header().Set(NodeHeaderName, nodeID) - } - default: - } - } ended := make(chan error) go func() { - ended <- process(id, predInput, input, config, ml, responses, extraUsage, nodeIDCh) + ended <- process(id, predInput, input, config, ml, responses, extraUsage) }() var latestUsage *schema.OpenAIUsage @@ -206,12 +166,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva } xlog.Debug("Sending chunk", "chunk", string(respData)) - // Attach X-LocalAI-Node from the worker's per-request - // signal before the very first Flush() locks the - // headers. The worker pushes on nodeIDCh AFTER - // ml.Load returns, so by the time we hit our first - // chunk read the signal is already buffered. - applyNodeIDHeader() _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData)) if err != nil { return err @@ -238,7 +192,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva Object: "text_completion", } errorData, marshalErr := json.Marshal(errorResp) - applyNodeIDHeader() if marshalErr != nil { xlog.Error("Failed to marshal error response", "error", marshalErr) // Send a simple error message as fallback @@ -265,7 +218,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva Object: "text_completion", } respData, _ := json.Marshal(resp) - applyNodeIDHeader() fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) // Trailing usage chunk per OpenAI spec: emit only when the caller @@ -332,10 +284,8 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva jsonResult, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(jsonResult)) - // Attribute the response to a specific worker node when distributed - // mode is enabled and the operator opted in via --expose-node-header. - // No-op otherwise. - maybeSetNodeHeader(c, appConfig, ml, input.Model) + // X-LocalAI-Node attribution (when --expose-node-header is on) is + // handled by middleware.ExposeNodeHeader at the wrapper layer. // Return the prediction in the response body return c.JSON(200, resp) diff --git a/core/http/endpoints/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go index edf30b762..03e8d637e 100644 --- a/core/http/endpoints/openai/embeddings.go +++ b/core/http/endpoints/openai/embeddings.go @@ -102,10 +102,8 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app jsonResult, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(jsonResult)) - // Attribute the response to a specific worker node when distributed - // mode is enabled and the operator opted in via --expose-node-header. - // No-op otherwise. - maybeSetNodeHeader(c, appConfig, ml, input.Model) + // X-LocalAI-Node attribution (when --expose-node-header is on) is + // handled by middleware.ExposeNodeHeader at the wrapper layer. // Return the prediction in the response body return c.JSON(200, resp) diff --git a/core/http/routes/anthropic.go b/core/http/routes/anthropic.go index 68b3079bd..e1ba75e3c 100644 --- a/core/http/routes/anthropic.go +++ b/core/http/routes/anthropic.go @@ -35,6 +35,7 @@ func RegisterAnthropicRoutes(app *echo.Echo, ) messagesMiddleware := []echo.MiddlewareFunc{ + middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader()), middleware.UsageMiddleware(application.AuthDB()), middleware.TraceMiddleware(application), re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), diff --git a/core/http/routes/ollama.go b/core/http/routes/ollama.go index aba0d8e97..2bd98947b 100644 --- a/core/http/routes/ollama.go +++ b/core/http/routes/ollama.go @@ -18,6 +18,7 @@ func RegisterOllamaRoutes(app *echo.Echo, traceMiddleware := middleware.TraceMiddleware(application) usageMiddleware := middleware.UsageMiddleware(application.AuthDB()) + nodeHeaderMiddleware := middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader()) // Chat endpoint: POST /api/chat chatHandler := ollama.ChatEndpoint( @@ -27,6 +28,7 @@ func RegisterOllamaRoutes(app *echo.Echo, application.ApplicationConfig(), ) chatMiddleware := []echo.MiddlewareFunc{ + nodeHeaderMiddleware, usageMiddleware, traceMiddleware, re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), @@ -43,6 +45,7 @@ func RegisterOllamaRoutes(app *echo.Echo, application.ApplicationConfig(), ) generateMiddleware := []echo.MiddlewareFunc{ + nodeHeaderMiddleware, usageMiddleware, traceMiddleware, re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), @@ -58,6 +61,7 @@ func RegisterOllamaRoutes(app *echo.Echo, application.ApplicationConfig(), ) embedMiddleware := []echo.MiddlewareFunc{ + nodeHeaderMiddleware, usageMiddleware, traceMiddleware, re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)), diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go index bd7793ae9..a42580ab2 100644 --- a/core/http/routes/openai.go +++ b/core/http/routes/openai.go @@ -17,6 +17,12 @@ func RegisterOpenAIRoutes(app *echo.Echo, // openAI compatible API endpoint traceMiddleware := middleware.TraceMiddleware(application) usageMiddleware := middleware.UsageMiddleware(application.AuthDB()) + // X-LocalAI-Node attribution middleware: wraps the response writer and + // stamps the header on first write when --expose-node-header is on. No-op + // otherwise. Applied to every inference path that routes through + // ml.Load (chat, completion, embeddings) so distributed-mode operators + // can observe which worker served each request. + nodeHeaderMiddleware := middleware.ExposeNodeHeader(application.ApplicationConfig(), application.ModelLoader()) // realtime // TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions @@ -34,6 +40,7 @@ func RegisterOpenAIRoutes(app *echo.Echo, // chat chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig(), natsClient, application.LocalAIAssistant()) chatMiddleware := []echo.MiddlewareFunc{ + nodeHeaderMiddleware, usageMiddleware, traceMiddleware, re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), @@ -73,6 +80,7 @@ func RegisterOpenAIRoutes(app *echo.Echo, // completion completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()) completionMiddleware := []echo.MiddlewareFunc{ + nodeHeaderMiddleware, usageMiddleware, traceMiddleware, re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)), @@ -94,6 +102,7 @@ func RegisterOpenAIRoutes(app *echo.Echo, // embeddings embeddingHandler := openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) embeddingMiddleware := []echo.MiddlewareFunc{ + nodeHeaderMiddleware, usageMiddleware, traceMiddleware, re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),