mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-29 11:07:18 -04:00
* feat(distributed): add per-request node ID context holder Introduce pkg/distributedhdr, a leaf package carrying a per-request *atomic.Value holder for the picked worker node ID from the SmartRouter (core/services/nodes) up to the HTTP response writer wrapper (core/http/middleware). Avoids the import cycle that a shared key in either consumer would create. Exposes NewHolder, WithHolder, Holder, Stamp, Load, Inherit. The holder is atomic.Value so cross-goroutine publish from the router to the response writer wrapper is race-clean. Assisted-by: Claude:claude-opus-4-7[1m] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(distributed): add ExposeNodeHeader middleware + response writer wrapper New ApplicationConfig.ExposeNodeHeader bool + --expose-node-header CLI flag / LOCALAI_EXPOSE_NODE_HEADER env var (default off; the node ID reveals internal topology and is opt-in). The middleware creates a per-request *atomic.Value holder, attaches it to c.Request().Context() via distributedhdr.WithHolder, and wraps c.Response().Writer with a custom http.ResponseWriter that sets the X-LocalAI-Node header on first Write / WriteHeader / Flush by reading the holder. Implements http.Flusher, http.Hijacker, Unwrap so it composes cleanly with Echo and http.NewResponseController. request.go propagates the holder onto derived contexts via distributedhdr.Inherit so the holder survives the correlation-ID context replacement. Unit + race-clean concurrency + integration specs. Assisted-by: Claude:claude-opus-4-7[1m] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(distributed): stamp node ID in router and wire middleware to inference routes ModelRouterAdapter.Route stamps the picked node ID into the per-request holder via distributedhdr.Stamp(ctx, result.Node.ID) right after replica selection. 
Wire ExposeNodeHeader middleware to: - OpenAI chat/completion/embeddings + audio transcriptions/speech + image generations/inpainting - Anthropic /v1/messages - Ollama /api/chat, /api/generate, /api/embed, /api/embeddings - Jina /v1/rerank - LocalAI /v1/vad The middleware's wrapper reads the holder on first byte and sets the X-LocalAI-Node response header before delegating to the underlying writer. Per-request scope means no race under concurrent multi-replica routing. Assisted-by: Claude:claude-opus-4-7[1m] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(distributed): thread request context through backend Load + cover ctx propagation Five non-OpenAI backend helpers were silently using app.Context instead of the request context for the gRPC backend call: transcription, TTS, image generation, rerank, VAD. Effect: distributedhdr.Stamp in the router callback was a silent no-op for these paths, AND client cancellation didn't propagate to in-flight inference. Thread c.Request().Context() (or the equivalent input.Context after the request middleware has installed the correlation-ID derived context) through each helper and into ModelOptions via model.WithContext(ctx). ImageGeneration's signature gains a leading ctx parameter; in-tree callers (openai image, openai inpainting, openai inpainting_test) are updated to match. ModelEmbedding gains a leading ctx parameter for the same reason; the openai and ollama embedding handlers pass the request context through. chat_stream_workers.go defers the initial role=assistant chunk emission until the first token callback so the wrapper's lazy X-LocalAI-Node lookup against the loader runs AFTER ml.Load has stamped the per-modelID node ID; semantically identical for clients (role still arrives before any text). Regression test core/backend/ctx_propagation_test.go pins ctx propagation for all five helpers. Docs updated to enumerate the full endpoint coverage of the --expose-node-header flag. 
Assisted-by: Claude:claude-opus-4-7[1m] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
725 lines
22 KiB
Go
725 lines
22 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/mudler/LocalAI/core/config"
|
|
"github.com/mudler/LocalAI/core/schema"
|
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
|
"github.com/mudler/LocalAI/core/templates"
|
|
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
|
"github.com/mudler/LocalAI/pkg/model"
|
|
"github.com/mudler/LocalAI/pkg/utils"
|
|
"github.com/mudler/xlog"
|
|
)
|
|
|
|
// correlationIDKeyType is an unexported key type for context values, so the
// correlation ID key cannot collide with context keys defined elsewhere.
type correlationIDKeyType string

// CorrelationIDKey is the context key under which the request's correlation
// ID is stored, so a request can be tracked across process boundaries.
const CorrelationIDKey = correlationIDKeyType("correlationID")
|
|
|
|
type RequestExtractor struct {
|
|
modelConfigLoader *config.ModelConfigLoader
|
|
modelLoader *model.ModelLoader
|
|
applicationConfig *config.ApplicationConfig
|
|
}
|
|
|
|
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
|
|
return &RequestExtractor{
|
|
modelConfigLoader: modelConfigLoader,
|
|
modelLoader: modelLoader,
|
|
applicationConfig: applicationConfig,
|
|
}
|
|
}
|
|
|
|
// Keys under which the middlewares in this file store per-request values in
// the echo.Context locals:
//   - CONTEXT_LOCALS_KEY_MODEL_NAME: the resolved model name (a string)
//   - CONTEXT_LOCALS_KEY_LOCALAI_REQUEST: the parsed request body
//   - CONTEXT_LOCALS_KEY_MODEL_CONFIG: the model configuration for the request
const (
	CONTEXT_LOCALS_KEY_MODEL_NAME      = "MODEL_NAME"
	CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
	CONTEXT_LOCALS_KEY_MODEL_CONFIG    = "MODEL_CONFIG"
)
|
|
|
|
// TODO: Refactor to not return error if unchanged
|
|
func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) {
|
|
model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
|
if ok && model != "" {
|
|
return
|
|
}
|
|
model = c.Param("model")
|
|
|
|
if model == "" {
|
|
model = c.QueryParam("model")
|
|
}
|
|
|
|
// Check FormValue for multipart/form-data requests (e.g., /v1/images/inpainting)
|
|
if model == "" {
|
|
model = c.FormValue("model")
|
|
}
|
|
|
|
if model == "" {
|
|
// Set model from bearer token, if available
|
|
auth := c.Request().Header.Get("Authorization")
|
|
bearer := strings.TrimPrefix(auth, "Bearer ")
|
|
if bearer != "" && bearer != auth {
|
|
exists, err := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, galleryop.ALWAYS_INCLUDE)
|
|
if err == nil && exists {
|
|
model = bearer
|
|
}
|
|
}
|
|
}
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
|
|
}
|
|
|
|
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc {
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
re.setModelNameFromRequest(c)
|
|
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
|
if !ok || localModelName == "" {
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
|
|
xlog.Debug("context local model name not found, setting to default", "defaultModelName", defaultModelName)
|
|
}
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc {
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
re.setModelNameFromRequest(c)
|
|
localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
|
if localModelName != "" { // Don't overwrite existing values
|
|
return next(c)
|
|
}
|
|
|
|
modelNames, err := galleryop.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, galleryop.SKIP_IF_CONFIGURED)
|
|
if err != nil {
|
|
xlog.Error("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()", "error", err)
|
|
return next(c)
|
|
}
|
|
|
|
if len(modelNames) == 0 {
|
|
xlog.Warn("SetDefaultModelNameToFirstAvailable used with no matching models installed")
|
|
// This is non-fatal - making it so was breaking the case of direct installation of raw models
|
|
// return errors.New("this endpoint requires at least one model to be installed")
|
|
return next(c)
|
|
}
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
|
|
xlog.Debug("context local model name not found, setting to the first model", "first model name", modelNames[0])
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TODO: If context and cancel above belong on all methods, move that part of above into here!
|
|
// Otherwise, it's in its own method below for now
|
|
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc {
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
input := initializer()
|
|
if input == nil {
|
|
return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body")
|
|
}
|
|
if err := c.Bind(input); err != nil {
|
|
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err))
|
|
}
|
|
|
|
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
|
|
if input.ModelName(nil) == "" {
|
|
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
|
if ok && localModelName != "" {
|
|
xlog.Debug("overriding empty model name in request body with value found earlier in middleware chain", "context localModelName", localModelName)
|
|
input.ModelName(&localModelName)
|
|
}
|
|
}
|
|
|
|
modelName := input.ModelName(nil)
|
|
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(modelName, re.applicationConfig)
|
|
|
|
if err != nil {
|
|
xlog.Warn("Model Configuration File not found", "model", modelName, "error", err)
|
|
} else if cfg.Model == "" && modelName != "" {
|
|
xlog.Debug("config does not include model, using input", "input.ModelName", modelName)
|
|
cfg.Model = modelName
|
|
}
|
|
|
|
// If a model name was specified, verify it actually exists before proceeding.
|
|
// Check both configured models and loose model files in the model path.
|
|
// Skip the check for HuggingFace model IDs (contain "/") since backends
|
|
// like diffusers may download these on the fly.
|
|
if modelName != "" && !strings.Contains(modelName, "/") {
|
|
exists, existsErr := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, modelName, galleryop.ALWAYS_INCLUDE)
|
|
if existsErr == nil && !exists {
|
|
return c.JSON(http.StatusNotFound, schema.ErrorResponse{
|
|
Error: &schema.APIError{
|
|
Message: fmt.Sprintf("model %q not found. To see available models, call GET /v1/models", modelName),
|
|
Code: http.StatusNotFound,
|
|
Type: "invalid_request_error",
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
// Check if the model is disabled
|
|
if cfg != nil && cfg.IsDisabled() {
|
|
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
|
|
Error: &schema.APIError{
|
|
Message: fmt.Sprintf("model %q is disabled and cannot be loaded. Enable it via the System page or API to use it.", modelName),
|
|
Code: http.StatusForbidden,
|
|
Type: "model_disabled",
|
|
},
|
|
})
|
|
}
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
|
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error {
|
|
input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
|
if !ok || input.Model == "" {
|
|
return echo.ErrBadRequest
|
|
}
|
|
|
|
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
|
if !ok || cfg == nil {
|
|
return echo.ErrBadRequest
|
|
}
|
|
|
|
// Extract or generate the correlation ID
|
|
correlationID := c.Request().Header.Get("X-Correlation-ID")
|
|
if correlationID == "" {
|
|
correlationID = uuid.New().String()
|
|
}
|
|
c.Response().Header().Set("X-Correlation-ID", correlationID)
|
|
|
|
// Use the request context directly - Echo properly supports context cancellation!
|
|
// No need for workarounds like handleConnectionCancellation
|
|
reqCtx := c.Request().Context()
|
|
c1, cancel := context.WithCancel(re.applicationConfig.Context)
|
|
|
|
// Cancel when request context is cancelled (client disconnects)
|
|
go func() {
|
|
select {
|
|
case <-reqCtx.Done():
|
|
cancel()
|
|
case <-c1.Done():
|
|
// Already cancelled
|
|
}
|
|
}()
|
|
|
|
// Add the correlation ID to the new context
|
|
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
|
|
ctxWithCorrelationID = distributedhdr.Inherit(ctxWithCorrelationID, reqCtx)
|
|
|
|
input.Context = ctxWithCorrelationID
|
|
input.Cancel = cancel
|
|
|
|
err := mergeOpenAIRequestAndModelConfig(cfg, input)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if cfg.Model == "" {
|
|
xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model)
|
|
cfg.Model = input.Model
|
|
}
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
|
|
|
return nil
|
|
}
|
|
|
|
// extractToolChoiceFunctionName returns the specific function name selected by
// a decoded tool_choice object. Two shapes are understood:
//
//	{"type":"function", "function":{"name":"..."}}  (OpenAI spec, nested)
//	{"type":"function", "name":"..."}               (legacy / Anthropic-compat, flat)
//
// A non-empty nested name takes precedence over the flat one. The result is ""
// when "type" is not "function", when no string name is present, or when m is
// nil. A nil map is safe to index in Go.
func extractToolChoiceFunctionName(m map[string]any) string {
	if kind, _ := m["type"].(string); kind != "function" {
		return ""
	}

	// A missing or non-object "function" leaves nested nil; indexing it yields nil.
	nested, _ := m["function"].(map[string]any)
	if name, _ := nested["name"].(string); name != "" {
		return name
	}

	flat, _ := m["name"].(string)
	return flat
}
|
|
|
|
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
|
|
if input.Echo {
|
|
config.Echo = input.Echo
|
|
}
|
|
if input.TopK != nil {
|
|
config.TopK = input.TopK
|
|
}
|
|
if input.TopP != nil {
|
|
config.TopP = input.TopP
|
|
}
|
|
if input.MinP != nil {
|
|
config.MinP = input.MinP
|
|
}
|
|
|
|
if input.Backend != "" {
|
|
config.Backend = input.Backend
|
|
}
|
|
|
|
if input.ClipSkip != 0 {
|
|
config.Diffusers.ClipSkip = input.ClipSkip
|
|
}
|
|
|
|
if input.NegativePromptScale != 0 {
|
|
config.NegativePromptScale = input.NegativePromptScale
|
|
}
|
|
|
|
if input.NegativePrompt != "" {
|
|
config.NegativePrompt = input.NegativePrompt
|
|
}
|
|
|
|
if input.RopeFreqBase != 0 {
|
|
config.RopeFreqBase = input.RopeFreqBase
|
|
}
|
|
|
|
if input.RopeFreqScale != 0 {
|
|
config.RopeFreqScale = input.RopeFreqScale
|
|
}
|
|
|
|
if input.Grammar != "" {
|
|
config.Grammar = input.Grammar
|
|
}
|
|
|
|
if input.Temperature != nil {
|
|
config.Temperature = input.Temperature
|
|
}
|
|
|
|
// Collapse the modern max_completion_tokens alias into the
|
|
// legacy Maxtokens field so downstream code reads exactly one.
|
|
// MaxCompletionTokens wins on conflict — it's the canonical
|
|
// name per OpenAI's deprecation guidance, and a client that
|
|
// took the trouble to send it intends that value. Clearing
|
|
// the sibling prevents both names from being emitted if input
|
|
// is re-marshaled (cloud-proxy passthrough).
|
|
if input.MaxCompletionTokens != nil {
|
|
input.Maxtokens = input.MaxCompletionTokens
|
|
input.MaxCompletionTokens = nil
|
|
}
|
|
if input.Maxtokens != nil {
|
|
config.Maxtokens = input.Maxtokens
|
|
}
|
|
|
|
if input.ResponseFormat != nil {
|
|
switch responseFormat := input.ResponseFormat.(type) {
|
|
case string:
|
|
config.ResponseFormat = responseFormat
|
|
case map[string]any:
|
|
config.ResponseFormatMap = responseFormat
|
|
}
|
|
}
|
|
|
|
switch stop := input.Stop.(type) {
|
|
case string:
|
|
if stop != "" {
|
|
config.StopWords = append(config.StopWords, stop)
|
|
}
|
|
case []any:
|
|
for _, pp := range stop {
|
|
if s, ok := pp.(string); ok {
|
|
config.StopWords = append(config.StopWords, s)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(input.Tools) > 0 {
|
|
for _, tool := range input.Tools {
|
|
input.Functions = append(input.Functions, tool.Function)
|
|
}
|
|
}
|
|
|
|
if input.ToolsChoice != nil {
|
|
// OpenAI tool_choice has three valid shapes plus one tolerated
|
|
// non-spec form seen in the wild:
|
|
//
|
|
// 1. string mode: "auto" | "none" | "required"
|
|
// 2. specific tool: {"type":"function", "function":{"name":"..."}} (current spec)
|
|
// 3. legacy: {"type":"function", "name":"..."} (older / Anthropic-compat)
|
|
// 4. double-encoded: "{\"type\":\"function\", ...}" (some clients serialize the object)
|
|
//
|
|
// The pre-#9559 code unmarshalled the string case through
|
|
// json.Unmarshal([]byte(content), &functions.Tool{}), which:
|
|
// - failed for plain string modes (so "required" / "none" were
|
|
// silently ignored and tools stayed enabled regardless), but
|
|
// - happened to handle shape 4 by accident.
|
|
// It also could not parse shape 3 because functions.Tool has no
|
|
// flat top-level Name field.
|
|
//
|
|
// Mirror the parsing pattern from MergeOpenResponsesConfig (#9509),
|
|
// route results through the existing input.FunctionCall string/map
|
|
// dispatch downstream (see the switch on input.FunctionCall in this
|
|
// same function), and preserve the shape-4 fallback so non-spec
|
|
// clients don't silently break. Tracked in #9508; sibling fix in #9526.
|
|
switch content := input.ToolsChoice.(type) {
|
|
case string:
|
|
// "auto" is the default and needs no override. "none" and "required"
|
|
// both reach SetFunctionCallString via the input.FunctionCall string
|
|
// branch below; ShouldUseFunctions() then returns false for "none"
|
|
// (tools disabled) and true for "required" (mode engaged).
|
|
//
|
|
// If the string looks like a JSON object, try shape 4 first: parse
|
|
// it as a tool_choice map and use the resulting name. Falling back
|
|
// to mode-string handling when the parse yields no usable name keeps
|
|
// genuinely-malformed input from accidentally engaging a mode.
|
|
if content == "" || content == "auto" {
|
|
break
|
|
}
|
|
if strings.HasPrefix(strings.TrimSpace(content), "{") {
|
|
var nested map[string]any
|
|
if err := json.Unmarshal([]byte(content), &nested); err == nil {
|
|
if name := extractToolChoiceFunctionName(nested); name != "" {
|
|
input.FunctionCall = map[string]any{"name": name}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
input.FunctionCall = content
|
|
case map[string]any:
|
|
if name := extractToolChoiceFunctionName(content); name != "" {
|
|
input.FunctionCall = map[string]any{"name": name}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Decode each request's message content
|
|
imgIndex, vidIndex, audioIndex := 0, 0, 0
|
|
for i, m := range input.Messages {
|
|
nrOfImgsInMessage := 0
|
|
nrOfVideosInMessage := 0
|
|
nrOfAudiosInMessage := 0
|
|
|
|
switch content := m.Content.(type) {
|
|
case string:
|
|
input.Messages[i].StringContent = content
|
|
case []any:
|
|
dat, _ := json.Marshal(content)
|
|
c := []schema.Content{}
|
|
json.Unmarshal(dat, &c)
|
|
|
|
textContent := ""
|
|
// we will template this at the end
|
|
|
|
CONTENT:
|
|
for _, pp := range c {
|
|
switch pp.Type {
|
|
case "text":
|
|
textContent += pp.Text
|
|
//input.Messages[i].StringContent = pp.Text
|
|
case "video", "video_url":
|
|
// Decode content as base64 either if it's an URL or base64 text
|
|
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
|
|
if err != nil {
|
|
xlog.Error("Failed encoding video", "error", err)
|
|
continue CONTENT
|
|
}
|
|
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
|
|
vidIndex++
|
|
nrOfVideosInMessage++
|
|
case "audio_url", "audio":
|
|
// Decode content as base64 either if it's an URL or base64 text
|
|
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
|
|
if err != nil {
|
|
xlog.Error("Failed encoding audio", "error", err)
|
|
continue CONTENT
|
|
}
|
|
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
|
|
audioIndex++
|
|
nrOfAudiosInMessage++
|
|
case "input_audio":
|
|
// TODO: make sure that we only return base64 stuff
|
|
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
|
|
audioIndex++
|
|
nrOfAudiosInMessage++
|
|
case "image_url", "image":
|
|
// Decode content as base64 either if it's an URL or base64 text
|
|
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
|
|
if err != nil {
|
|
xlog.Error("Failed encoding image", "error", err)
|
|
continue CONTENT
|
|
}
|
|
|
|
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
|
|
|
imgIndex++
|
|
nrOfImgsInMessage++
|
|
}
|
|
}
|
|
|
|
// When the backend handles templating itself (UseTokenizerTemplate),
|
|
// it also injects media markers server-side (see
|
|
// oaicompat_chat_params_parse in llama.cpp). Emitting our own markers
|
|
// here would double-mark them and downstream consumers ignore
|
|
// StringContent in that path anyway, so just pass through plain text.
|
|
if config.TemplateConfig.UseTokenizerTemplate {
|
|
input.Messages[i].StringContent = textContent
|
|
} else {
|
|
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
|
|
TotalImages: imgIndex,
|
|
TotalVideos: vidIndex,
|
|
TotalAudios: audioIndex,
|
|
ImagesInMessage: nrOfImgsInMessage,
|
|
VideosInMessage: nrOfVideosInMessage,
|
|
AudiosInMessage: nrOfAudiosInMessage,
|
|
}, textContent)
|
|
}
|
|
}
|
|
}
|
|
|
|
if input.RepeatPenalty != 0 {
|
|
config.RepeatPenalty = input.RepeatPenalty
|
|
}
|
|
|
|
if input.FrequencyPenalty != 0 {
|
|
config.FrequencyPenalty = input.FrequencyPenalty
|
|
}
|
|
|
|
if input.PresencePenalty != 0 {
|
|
config.PresencePenalty = input.PresencePenalty
|
|
}
|
|
|
|
if input.Keep != 0 {
|
|
config.Keep = input.Keep
|
|
}
|
|
|
|
if input.Batch != 0 {
|
|
config.Batch = input.Batch
|
|
}
|
|
|
|
if input.IgnoreEOS {
|
|
config.IgnoreEOS = input.IgnoreEOS
|
|
}
|
|
|
|
if input.Seed != nil {
|
|
config.Seed = input.Seed
|
|
}
|
|
|
|
if input.TypicalP != nil {
|
|
config.TypicalP = input.TypicalP
|
|
}
|
|
|
|
xlog.Debug("input.Input", "input", fmt.Sprintf("%+v", input.Input))
|
|
|
|
switch inputs := input.Input.(type) {
|
|
case string:
|
|
if inputs != "" {
|
|
config.InputStrings = append(config.InputStrings, inputs)
|
|
}
|
|
case []any:
|
|
for _, pp := range inputs {
|
|
switch i := pp.(type) {
|
|
case string:
|
|
config.InputStrings = append(config.InputStrings, i)
|
|
case []any:
|
|
tokens := []int{}
|
|
inputStrings := []string{}
|
|
for _, ii := range i {
|
|
switch ii := ii.(type) {
|
|
case int:
|
|
tokens = append(tokens, ii)
|
|
case float64:
|
|
tokens = append(tokens, int(ii))
|
|
case string:
|
|
inputStrings = append(inputStrings, ii)
|
|
default:
|
|
xlog.Error("Unknown input type", "type", fmt.Sprintf("%T", ii))
|
|
}
|
|
}
|
|
config.InputToken = append(config.InputToken, tokens)
|
|
config.InputStrings = append(config.InputStrings, inputStrings...)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Can be either a string or an object
|
|
switch fnc := input.FunctionCall.(type) {
|
|
case string:
|
|
if fnc != "" {
|
|
config.SetFunctionCallString(fnc)
|
|
}
|
|
case map[string]any:
|
|
var name string
|
|
n, exists := fnc["name"]
|
|
if exists {
|
|
nn, e := n.(string)
|
|
if e {
|
|
name = nn
|
|
}
|
|
}
|
|
config.SetFunctionCallNameString(name)
|
|
}
|
|
|
|
switch p := input.Prompt.(type) {
|
|
case string:
|
|
config.PromptStrings = append(config.PromptStrings, p)
|
|
case []any:
|
|
for _, pp := range p {
|
|
if s, ok := pp.(string); ok {
|
|
config.PromptStrings = append(config.PromptStrings, s)
|
|
}
|
|
}
|
|
}
|
|
|
|
// If a quality was defined as number, convert it to step
|
|
if input.Quality != "" {
|
|
q, err := strconv.Atoi(input.Quality)
|
|
if err == nil {
|
|
config.Step = q
|
|
}
|
|
}
|
|
|
|
if valid, _ := config.Validate(); valid {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("unable to validate configuration after merging")
|
|
}
|
|
|
|
func (re *RequestExtractor) SetOpenResponsesRequest(c echo.Context) error {
|
|
input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenResponsesRequest)
|
|
if !ok || input.Model == "" {
|
|
return echo.ErrBadRequest
|
|
}
|
|
|
|
// Convert input items to Messages (this will be done in the endpoint handler)
|
|
// We store the input in the request for the endpoint to process
|
|
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
|
if !ok || cfg == nil {
|
|
return echo.ErrBadRequest
|
|
}
|
|
|
|
// Extract or generate the correlation ID (Open Responses uses x-request-id)
|
|
correlationID := c.Request().Header.Get("x-request-id")
|
|
if correlationID == "" {
|
|
correlationID = uuid.New().String()
|
|
}
|
|
c.Response().Header().Set("x-request-id", correlationID)
|
|
|
|
// Use the request context directly - Echo properly supports context cancellation!
|
|
reqCtx := c.Request().Context()
|
|
c1, cancel := context.WithCancel(re.applicationConfig.Context)
|
|
|
|
// Cancel when request context is cancelled (client disconnects)
|
|
go func() {
|
|
select {
|
|
case <-reqCtx.Done():
|
|
cancel()
|
|
case <-c1.Done():
|
|
// Already cancelled
|
|
}
|
|
}()
|
|
|
|
// Add the correlation ID to the new context
|
|
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
|
|
ctxWithCorrelationID = distributedhdr.Inherit(ctxWithCorrelationID, reqCtx)
|
|
|
|
input.Context = ctxWithCorrelationID
|
|
input.Cancel = cancel
|
|
|
|
err := MergeOpenResponsesConfig(cfg, input)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if cfg.Model == "" {
|
|
xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model)
|
|
cfg.Model = input.Model
|
|
}
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
|
|
|
return nil
|
|
}
|
|
|
|
// MergeOpenResponsesConfig merges request parameters into the model configuration.
|
|
func MergeOpenResponsesConfig(config *config.ModelConfig, input *schema.OpenResponsesRequest) error {
|
|
// Temperature
|
|
if input.Temperature != nil {
|
|
config.Temperature = input.Temperature
|
|
}
|
|
|
|
// TopP
|
|
if input.TopP != nil {
|
|
config.TopP = input.TopP
|
|
}
|
|
|
|
// MaxOutputTokens -> Maxtokens
|
|
if input.MaxOutputTokens != nil {
|
|
config.Maxtokens = input.MaxOutputTokens
|
|
}
|
|
|
|
// Convert tools to functions - this will be handled in the endpoint handler
|
|
// We just validate that tools are present if needed
|
|
|
|
// Handle tool_choice
|
|
if input.ToolChoice != nil {
|
|
switch tc := input.ToolChoice.(type) {
|
|
case string:
|
|
// "auto", "required", or "none"
|
|
if tc == "required" {
|
|
config.SetFunctionCallString("required")
|
|
} else if tc == "none" {
|
|
// Don't use tools - handled in endpoint
|
|
}
|
|
// "auto" is default - let model decide
|
|
case map[string]any:
|
|
// Specific tool. OpenAI spec nests the function name under "function":
|
|
// {"type":"function", "function":{"name":"..."}}
|
|
// Legacy/Anthropic-compat form puts it at the top level:
|
|
// {"type":"function", "name":"..."}
|
|
// The old code only handled the legacy shape AND used the wrong
|
|
// setter (SetFunctionCallString writes the mode field; the
|
|
// specific-function name lives in a separate field read by
|
|
// ShouldCallSpecificFunction / FunctionToCall). Net effect: a
|
|
// correctly-formed OpenAI tool_choice never engaged grammar-based
|
|
// forcing, the model got the tools but no selection hint, and
|
|
// streamed raw JSON as delta.content instead of delta.tool_calls.
|
|
if tcType, ok := tc["type"].(string); ok && tcType == "function" {
|
|
var name string
|
|
if fn, ok := tc["function"].(map[string]any); ok {
|
|
if n, ok := fn["name"].(string); ok {
|
|
name = n
|
|
}
|
|
}
|
|
if name == "" {
|
|
if n, ok := tc["name"].(string); ok {
|
|
name = n
|
|
}
|
|
}
|
|
if name != "" {
|
|
config.SetFunctionCallNameString(name)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if valid, _ := config.Validate(); valid {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("unable to validate configuration after merging")
|
|
}
|