mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-17 04:56:52 -04:00
* fix(middleware): parse OpenAI-spec tool_choice in /v1/chat/completions Follows up on #9526 (the 3-site setter fix) by addressing the remaining clause in #9508 — string mode and OpenAI-spec specific-function shape both silently failed in the /v1/chat/completions parsing path. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(middleware): restore LF endings and cover tool_choice parsing with specs The previous commit on this branch saved core/http/middleware/request.go with CRLF line endings, ballooning the diff against master to 684 / 651 for what is in reality a ~50-line parsing change. Restore LF (matches .editorconfig end_of_line = lf). Add 11 Ginkgo specs under "SetModelAndConfig tool_choice parsing (chat completions)" that parallel the existing MergeOpenResponsesConfig specs from #9509. They drive the full middleware chain (SetModelAndConfig + SetOpenAIRequest) and assert: * "required" -> ShouldUseFunctions=true, no specific name * "none" -> ShouldUseFunctions=false (tools disabled per OpenAI spec) * "auto" -> default, tools available, no specific name * {type:function, function:{name:X}} (spec) -> X is forced * {type:function, name:X} (legacy) -> X is forced * nested wins when both forms are present * malformed shapes (no type, wrong type, no name, empty name) are no-ops Update the inline comment on the string case to describe the actual mechanism: "none" reaches SetFunctionCallString("none") downstream and is then honored by ShouldUseFunctions() returning false. Before this PR json.Unmarshal([]byte("none"), &functions.Tool{}) failed silently, so "none" was ignored - making "none" actually work is a real behavior fix this PR brings. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Assisted-by: Claude:opus-4-7 [Claude Code] * fix(middleware): preserve pre-#9559 support for JSON-string-encoded tool_choice Some non-spec clients send tool_choice as a JSON-encoded string of an object form, e.g. 
"{\"type\":\"function\",\"function\":{\"name\":\"X\"}}". The pre-#9559 code accepted this by accident: its case string: branch ran json.Unmarshal([]byte(content), &functions.Tool{}), which succeeded for that double-encoded shape even though it failed for the legitimate plain string modes "auto" / "none" / "required". The first version of this PR routed every string straight to SetFunctionCallString as a mode, which fixed the plain-string cases but silently regressed the double-encoded one (funcs.Select("{...}") returns nothing). Restore the fallback: when a string looks like a JSON object, try parsing it as a tool_choice map first; fall through to mode-string handling only when no usable name comes out. Factor the map-name extraction into a small helper (extractToolChoiceFunctionName) so the string-fallback and the regular map case go through identical code, and accept both the OpenAI-spec nested shape and the legacy/Anthropic flat shape from either entry point. Add 3 Ginkgo specs covering the double-encoded case (nested form, legacy form, and the fall-through when the JSON has no usable name). Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Assisted-by: Claude:opus-4-7 [Claude Code] * test(middleware): silence errcheck on AfterEach os.RemoveAll The new tool_choice parsing tests added a second AfterEach that calls os.RemoveAll(modelDir) without checking the error; errcheck flagged it. Suppress with the standard _ = idiom. The pre-existing AfterEach on the earlier Describe still elides the check the same way it did before - leaving that untouched to keep this commit minimal. Assisted-by: Claude:opus-4-7 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
711 lines
22 KiB
Go
711 lines
22 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/mudler/LocalAI/core/config"
|
|
"github.com/mudler/LocalAI/core/schema"
|
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
|
"github.com/mudler/LocalAI/core/templates"
|
|
"github.com/mudler/LocalAI/pkg/model"
|
|
"github.com/mudler/LocalAI/pkg/utils"
|
|
"github.com/mudler/xlog"
|
|
)
|
|
|
|
// correlationIDKeyType is the dedicated type of the context key under which
// this package stores the correlation ID.
type correlationIDKeyType string

// CorrelationIDKey is the context key used to track a request across process
// boundaries.
const CorrelationIDKey = correlationIDKeyType("correlationID")
|
|
|
|
type RequestExtractor struct {
|
|
modelConfigLoader *config.ModelConfigLoader
|
|
modelLoader *model.ModelLoader
|
|
applicationConfig *config.ApplicationConfig
|
|
}
|
|
|
|
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
|
|
return &RequestExtractor{
|
|
modelConfigLoader: modelConfigLoader,
|
|
modelLoader: modelLoader,
|
|
applicationConfig: applicationConfig,
|
|
}
|
|
}
|
|
|
|
// Keys under which the middlewares in this file store values on the echo
// context for later stages of the chain.
const (
	// CONTEXT_LOCALS_KEY_MODEL_NAME holds the resolved model name (string).
	CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
	// CONTEXT_LOCALS_KEY_LOCALAI_REQUEST holds the bound request object.
	CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
	// CONTEXT_LOCALS_KEY_MODEL_CONFIG holds the loaded model configuration.
	CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
)
|
|
|
|
// TODO: Refactor to not return error if unchanged
|
|
func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) {
|
|
model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
|
if ok && model != "" {
|
|
return
|
|
}
|
|
model = c.Param("model")
|
|
|
|
if model == "" {
|
|
model = c.QueryParam("model")
|
|
}
|
|
|
|
// Check FormValue for multipart/form-data requests (e.g., /v1/images/inpainting)
|
|
if model == "" {
|
|
model = c.FormValue("model")
|
|
}
|
|
|
|
if model == "" {
|
|
// Set model from bearer token, if available
|
|
auth := c.Request().Header.Get("Authorization")
|
|
bearer := strings.TrimPrefix(auth, "Bearer ")
|
|
if bearer != "" && bearer != auth {
|
|
exists, err := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, galleryop.ALWAYS_INCLUDE)
|
|
if err == nil && exists {
|
|
model = bearer
|
|
}
|
|
}
|
|
}
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
|
|
}
|
|
|
|
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc {
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
re.setModelNameFromRequest(c)
|
|
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
|
if !ok || localModelName == "" {
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
|
|
xlog.Debug("context local model name not found, setting to default", "defaultModelName", defaultModelName)
|
|
}
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc {
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
re.setModelNameFromRequest(c)
|
|
localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
|
if localModelName != "" { // Don't overwrite existing values
|
|
return next(c)
|
|
}
|
|
|
|
modelNames, err := galleryop.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, galleryop.SKIP_IF_CONFIGURED)
|
|
if err != nil {
|
|
xlog.Error("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()", "error", err)
|
|
return next(c)
|
|
}
|
|
|
|
if len(modelNames) == 0 {
|
|
xlog.Warn("SetDefaultModelNameToFirstAvailable used with no matching models installed")
|
|
// This is non-fatal - making it so was breaking the case of direct installation of raw models
|
|
// return errors.New("this endpoint requires at least one model to be installed")
|
|
return next(c)
|
|
}
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
|
|
xlog.Debug("context local model name not found, setting to the first model", "first model name", modelNames[0])
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TODO: If context and cancel above belong on all methods, move that part of above into here!
|
|
// Otherwise, it's in its own method below for now
|
|
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc {
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
input := initializer()
|
|
if input == nil {
|
|
return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body")
|
|
}
|
|
if err := c.Bind(input); err != nil {
|
|
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err))
|
|
}
|
|
|
|
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
|
|
if input.ModelName(nil) == "" {
|
|
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
|
if ok && localModelName != "" {
|
|
xlog.Debug("overriding empty model name in request body with value found earlier in middleware chain", "context localModelName", localModelName)
|
|
input.ModelName(&localModelName)
|
|
}
|
|
}
|
|
|
|
modelName := input.ModelName(nil)
|
|
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(modelName, re.applicationConfig)
|
|
|
|
if err != nil {
|
|
xlog.Warn("Model Configuration File not found", "model", modelName, "error", err)
|
|
} else if cfg.Model == "" && modelName != "" {
|
|
xlog.Debug("config does not include model, using input", "input.ModelName", modelName)
|
|
cfg.Model = modelName
|
|
}
|
|
|
|
// If a model name was specified, verify it actually exists before proceeding.
|
|
// Check both configured models and loose model files in the model path.
|
|
// Skip the check for HuggingFace model IDs (contain "/") since backends
|
|
// like diffusers may download these on the fly.
|
|
if modelName != "" && !strings.Contains(modelName, "/") {
|
|
exists, existsErr := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, modelName, galleryop.ALWAYS_INCLUDE)
|
|
if existsErr == nil && !exists {
|
|
return c.JSON(http.StatusNotFound, schema.ErrorResponse{
|
|
Error: &schema.APIError{
|
|
Message: fmt.Sprintf("model %q not found. To see available models, call GET /v1/models", modelName),
|
|
Code: http.StatusNotFound,
|
|
Type: "invalid_request_error",
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
// Check if the model is disabled
|
|
if cfg != nil && cfg.IsDisabled() {
|
|
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
|
|
Error: &schema.APIError{
|
|
Message: fmt.Sprintf("model %q is disabled and cannot be loaded. Enable it via the System page or API to use it.", modelName),
|
|
Code: http.StatusForbidden,
|
|
Type: "model_disabled",
|
|
},
|
|
})
|
|
}
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
|
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error {
|
|
input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
|
if !ok || input.Model == "" {
|
|
return echo.ErrBadRequest
|
|
}
|
|
|
|
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
|
if !ok || cfg == nil {
|
|
return echo.ErrBadRequest
|
|
}
|
|
|
|
// Extract or generate the correlation ID
|
|
correlationID := c.Request().Header.Get("X-Correlation-ID")
|
|
if correlationID == "" {
|
|
correlationID = uuid.New().String()
|
|
}
|
|
c.Response().Header().Set("X-Correlation-ID", correlationID)
|
|
|
|
// Use the request context directly - Echo properly supports context cancellation!
|
|
// No need for workarounds like handleConnectionCancellation
|
|
reqCtx := c.Request().Context()
|
|
c1, cancel := context.WithCancel(re.applicationConfig.Context)
|
|
|
|
// Cancel when request context is cancelled (client disconnects)
|
|
go func() {
|
|
select {
|
|
case <-reqCtx.Done():
|
|
cancel()
|
|
case <-c1.Done():
|
|
// Already cancelled
|
|
}
|
|
}()
|
|
|
|
// Add the correlation ID to the new context
|
|
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
|
|
|
|
input.Context = ctxWithCorrelationID
|
|
input.Cancel = cancel
|
|
|
|
err := mergeOpenAIRequestAndModelConfig(cfg, input)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if cfg.Model == "" {
|
|
xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model)
|
|
cfg.Model = input.Model
|
|
}
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
|
|
|
return nil
|
|
}
|
|
|
|
// extractToolChoiceFunctionName returns the function name selected by a
// tool_choice object, or "" when the object does not select one.
//
// Two shapes are understood, both requiring "type" to be "function":
//   - nested (OpenAI spec):            {type:function, function:{name:...}}
//   - flat (legacy/Anthropic-compat):  {type:function, name:...}
//
// A non-empty nested name takes precedence over the flat one. Anything else
// (missing or different type, missing, empty or non-string name) yields "".
func extractToolChoiceFunctionName(m map[string]any) string {
	// A failed assertion leaves the zero value "", which never equals
	// "function", so one comparison covers both "missing" and "wrong type".
	if kind, _ := m["type"].(string); kind != "function" {
		return ""
	}

	if nested, isMap := m["function"].(map[string]any); isMap {
		if name, _ := nested["name"].(string); name != "" {
			return name
		}
	}

	// Flat form; a missing or non-string name collapses to "".
	flat, _ := m["name"].(string)
	return flat
}
|
|
|
|
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
|
|
if input.Echo {
|
|
config.Echo = input.Echo
|
|
}
|
|
if input.TopK != nil {
|
|
config.TopK = input.TopK
|
|
}
|
|
if input.TopP != nil {
|
|
config.TopP = input.TopP
|
|
}
|
|
if input.MinP != nil {
|
|
config.MinP = input.MinP
|
|
}
|
|
|
|
if input.Backend != "" {
|
|
config.Backend = input.Backend
|
|
}
|
|
|
|
if input.ClipSkip != 0 {
|
|
config.Diffusers.ClipSkip = input.ClipSkip
|
|
}
|
|
|
|
if input.NegativePromptScale != 0 {
|
|
config.NegativePromptScale = input.NegativePromptScale
|
|
}
|
|
|
|
if input.NegativePrompt != "" {
|
|
config.NegativePrompt = input.NegativePrompt
|
|
}
|
|
|
|
if input.RopeFreqBase != 0 {
|
|
config.RopeFreqBase = input.RopeFreqBase
|
|
}
|
|
|
|
if input.RopeFreqScale != 0 {
|
|
config.RopeFreqScale = input.RopeFreqScale
|
|
}
|
|
|
|
if input.Grammar != "" {
|
|
config.Grammar = input.Grammar
|
|
}
|
|
|
|
if input.Temperature != nil {
|
|
config.Temperature = input.Temperature
|
|
}
|
|
|
|
if input.Maxtokens != nil {
|
|
config.Maxtokens = input.Maxtokens
|
|
}
|
|
|
|
if input.ResponseFormat != nil {
|
|
switch responseFormat := input.ResponseFormat.(type) {
|
|
case string:
|
|
config.ResponseFormat = responseFormat
|
|
case map[string]any:
|
|
config.ResponseFormatMap = responseFormat
|
|
}
|
|
}
|
|
|
|
switch stop := input.Stop.(type) {
|
|
case string:
|
|
if stop != "" {
|
|
config.StopWords = append(config.StopWords, stop)
|
|
}
|
|
case []any:
|
|
for _, pp := range stop {
|
|
if s, ok := pp.(string); ok {
|
|
config.StopWords = append(config.StopWords, s)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(input.Tools) > 0 {
|
|
for _, tool := range input.Tools {
|
|
input.Functions = append(input.Functions, tool.Function)
|
|
}
|
|
}
|
|
|
|
if input.ToolsChoice != nil {
|
|
// OpenAI tool_choice has three valid shapes plus one tolerated
|
|
// non-spec form seen in the wild:
|
|
//
|
|
// 1. string mode: "auto" | "none" | "required"
|
|
// 2. specific tool: {"type":"function", "function":{"name":"..."}} (current spec)
|
|
// 3. legacy: {"type":"function", "name":"..."} (older / Anthropic-compat)
|
|
// 4. double-encoded: "{\"type\":\"function\", ...}" (some clients serialize the object)
|
|
//
|
|
// The pre-#9559 code unmarshalled the string case through
|
|
// json.Unmarshal([]byte(content), &functions.Tool{}), which:
|
|
// - failed for plain string modes (so "required" / "none" were
|
|
// silently ignored and tools stayed enabled regardless), but
|
|
// - happened to handle shape 4 by accident.
|
|
// It also could not parse shape 3 because functions.Tool has no
|
|
// flat top-level Name field.
|
|
//
|
|
// Mirror the parsing pattern from MergeOpenResponsesConfig (#9509),
|
|
// route results through the existing input.FunctionCall string/map
|
|
// dispatch downstream (see the switch on input.FunctionCall in this
|
|
// same function), and preserve the shape-4 fallback so non-spec
|
|
// clients don't silently break. Tracked in #9508; sibling fix in #9526.
|
|
switch content := input.ToolsChoice.(type) {
|
|
case string:
|
|
// "auto" is the default and needs no override. "none" and "required"
|
|
// both reach SetFunctionCallString via the input.FunctionCall string
|
|
// branch below; ShouldUseFunctions() then returns false for "none"
|
|
// (tools disabled) and true for "required" (mode engaged).
|
|
//
|
|
// If the string looks like a JSON object, try shape 4 first: parse
|
|
// it as a tool_choice map and use the resulting name. Falling back
|
|
// to mode-string handling when the parse yields no usable name keeps
|
|
// genuinely-malformed input from accidentally engaging a mode.
|
|
if content == "" || content == "auto" {
|
|
break
|
|
}
|
|
if strings.HasPrefix(strings.TrimSpace(content), "{") {
|
|
var nested map[string]any
|
|
if err := json.Unmarshal([]byte(content), &nested); err == nil {
|
|
if name := extractToolChoiceFunctionName(nested); name != "" {
|
|
input.FunctionCall = map[string]any{"name": name}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
input.FunctionCall = content
|
|
case map[string]any:
|
|
if name := extractToolChoiceFunctionName(content); name != "" {
|
|
input.FunctionCall = map[string]any{"name": name}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Decode each request's message content
|
|
imgIndex, vidIndex, audioIndex := 0, 0, 0
|
|
for i, m := range input.Messages {
|
|
nrOfImgsInMessage := 0
|
|
nrOfVideosInMessage := 0
|
|
nrOfAudiosInMessage := 0
|
|
|
|
switch content := m.Content.(type) {
|
|
case string:
|
|
input.Messages[i].StringContent = content
|
|
case []any:
|
|
dat, _ := json.Marshal(content)
|
|
c := []schema.Content{}
|
|
json.Unmarshal(dat, &c)
|
|
|
|
textContent := ""
|
|
// we will template this at the end
|
|
|
|
CONTENT:
|
|
for _, pp := range c {
|
|
switch pp.Type {
|
|
case "text":
|
|
textContent += pp.Text
|
|
//input.Messages[i].StringContent = pp.Text
|
|
case "video", "video_url":
|
|
// Decode content as base64 either if it's an URL or base64 text
|
|
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
|
|
if err != nil {
|
|
xlog.Error("Failed encoding video", "error", err)
|
|
continue CONTENT
|
|
}
|
|
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
|
|
vidIndex++
|
|
nrOfVideosInMessage++
|
|
case "audio_url", "audio":
|
|
// Decode content as base64 either if it's an URL or base64 text
|
|
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
|
|
if err != nil {
|
|
xlog.Error("Failed encoding audio", "error", err)
|
|
continue CONTENT
|
|
}
|
|
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
|
|
audioIndex++
|
|
nrOfAudiosInMessage++
|
|
case "input_audio":
|
|
// TODO: make sure that we only return base64 stuff
|
|
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
|
|
audioIndex++
|
|
nrOfAudiosInMessage++
|
|
case "image_url", "image":
|
|
// Decode content as base64 either if it's an URL or base64 text
|
|
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
|
|
if err != nil {
|
|
xlog.Error("Failed encoding image", "error", err)
|
|
continue CONTENT
|
|
}
|
|
|
|
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
|
|
|
imgIndex++
|
|
nrOfImgsInMessage++
|
|
}
|
|
}
|
|
|
|
// When the backend handles templating itself (UseTokenizerTemplate),
|
|
// it also injects media markers server-side (see
|
|
// oaicompat_chat_params_parse in llama.cpp). Emitting our own markers
|
|
// here would double-mark them and downstream consumers ignore
|
|
// StringContent in that path anyway, so just pass through plain text.
|
|
if config.TemplateConfig.UseTokenizerTemplate {
|
|
input.Messages[i].StringContent = textContent
|
|
} else {
|
|
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
|
|
TotalImages: imgIndex,
|
|
TotalVideos: vidIndex,
|
|
TotalAudios: audioIndex,
|
|
ImagesInMessage: nrOfImgsInMessage,
|
|
VideosInMessage: nrOfVideosInMessage,
|
|
AudiosInMessage: nrOfAudiosInMessage,
|
|
}, textContent)
|
|
}
|
|
}
|
|
}
|
|
|
|
if input.RepeatPenalty != 0 {
|
|
config.RepeatPenalty = input.RepeatPenalty
|
|
}
|
|
|
|
if input.FrequencyPenalty != 0 {
|
|
config.FrequencyPenalty = input.FrequencyPenalty
|
|
}
|
|
|
|
if input.PresencePenalty != 0 {
|
|
config.PresencePenalty = input.PresencePenalty
|
|
}
|
|
|
|
if input.Keep != 0 {
|
|
config.Keep = input.Keep
|
|
}
|
|
|
|
if input.Batch != 0 {
|
|
config.Batch = input.Batch
|
|
}
|
|
|
|
if input.IgnoreEOS {
|
|
config.IgnoreEOS = input.IgnoreEOS
|
|
}
|
|
|
|
if input.Seed != nil {
|
|
config.Seed = input.Seed
|
|
}
|
|
|
|
if input.TypicalP != nil {
|
|
config.TypicalP = input.TypicalP
|
|
}
|
|
|
|
xlog.Debug("input.Input", "input", fmt.Sprintf("%+v", input.Input))
|
|
|
|
switch inputs := input.Input.(type) {
|
|
case string:
|
|
if inputs != "" {
|
|
config.InputStrings = append(config.InputStrings, inputs)
|
|
}
|
|
case []any:
|
|
for _, pp := range inputs {
|
|
switch i := pp.(type) {
|
|
case string:
|
|
config.InputStrings = append(config.InputStrings, i)
|
|
case []any:
|
|
tokens := []int{}
|
|
inputStrings := []string{}
|
|
for _, ii := range i {
|
|
switch ii := ii.(type) {
|
|
case int:
|
|
tokens = append(tokens, ii)
|
|
case float64:
|
|
tokens = append(tokens, int(ii))
|
|
case string:
|
|
inputStrings = append(inputStrings, ii)
|
|
default:
|
|
xlog.Error("Unknown input type", "type", fmt.Sprintf("%T", ii))
|
|
}
|
|
}
|
|
config.InputToken = append(config.InputToken, tokens)
|
|
config.InputStrings = append(config.InputStrings, inputStrings...)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Can be either a string or an object
|
|
switch fnc := input.FunctionCall.(type) {
|
|
case string:
|
|
if fnc != "" {
|
|
config.SetFunctionCallString(fnc)
|
|
}
|
|
case map[string]any:
|
|
var name string
|
|
n, exists := fnc["name"]
|
|
if exists {
|
|
nn, e := n.(string)
|
|
if e {
|
|
name = nn
|
|
}
|
|
}
|
|
config.SetFunctionCallNameString(name)
|
|
}
|
|
|
|
switch p := input.Prompt.(type) {
|
|
case string:
|
|
config.PromptStrings = append(config.PromptStrings, p)
|
|
case []any:
|
|
for _, pp := range p {
|
|
if s, ok := pp.(string); ok {
|
|
config.PromptStrings = append(config.PromptStrings, s)
|
|
}
|
|
}
|
|
}
|
|
|
|
// If a quality was defined as number, convert it to step
|
|
if input.Quality != "" {
|
|
q, err := strconv.Atoi(input.Quality)
|
|
if err == nil {
|
|
config.Step = q
|
|
}
|
|
}
|
|
|
|
if valid, _ := config.Validate(); valid {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("unable to validate configuration after merging")
|
|
}
|
|
|
|
func (re *RequestExtractor) SetOpenResponsesRequest(c echo.Context) error {
|
|
input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenResponsesRequest)
|
|
if !ok || input.Model == "" {
|
|
return echo.ErrBadRequest
|
|
}
|
|
|
|
// Convert input items to Messages (this will be done in the endpoint handler)
|
|
// We store the input in the request for the endpoint to process
|
|
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
|
if !ok || cfg == nil {
|
|
return echo.ErrBadRequest
|
|
}
|
|
|
|
// Extract or generate the correlation ID (Open Responses uses x-request-id)
|
|
correlationID := c.Request().Header.Get("x-request-id")
|
|
if correlationID == "" {
|
|
correlationID = uuid.New().String()
|
|
}
|
|
c.Response().Header().Set("x-request-id", correlationID)
|
|
|
|
// Use the request context directly - Echo properly supports context cancellation!
|
|
reqCtx := c.Request().Context()
|
|
c1, cancel := context.WithCancel(re.applicationConfig.Context)
|
|
|
|
// Cancel when request context is cancelled (client disconnects)
|
|
go func() {
|
|
select {
|
|
case <-reqCtx.Done():
|
|
cancel()
|
|
case <-c1.Done():
|
|
// Already cancelled
|
|
}
|
|
}()
|
|
|
|
// Add the correlation ID to the new context
|
|
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
|
|
|
|
input.Context = ctxWithCorrelationID
|
|
input.Cancel = cancel
|
|
|
|
err := MergeOpenResponsesConfig(cfg, input)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if cfg.Model == "" {
|
|
xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model)
|
|
cfg.Model = input.Model
|
|
}
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
|
|
|
return nil
|
|
}
|
|
|
|
// MergeOpenResponsesConfig merges request parameters into the model configuration.
|
|
func MergeOpenResponsesConfig(config *config.ModelConfig, input *schema.OpenResponsesRequest) error {
|
|
// Temperature
|
|
if input.Temperature != nil {
|
|
config.Temperature = input.Temperature
|
|
}
|
|
|
|
// TopP
|
|
if input.TopP != nil {
|
|
config.TopP = input.TopP
|
|
}
|
|
|
|
// MaxOutputTokens -> Maxtokens
|
|
if input.MaxOutputTokens != nil {
|
|
config.Maxtokens = input.MaxOutputTokens
|
|
}
|
|
|
|
// Convert tools to functions - this will be handled in the endpoint handler
|
|
// We just validate that tools are present if needed
|
|
|
|
// Handle tool_choice
|
|
if input.ToolChoice != nil {
|
|
switch tc := input.ToolChoice.(type) {
|
|
case string:
|
|
// "auto", "required", or "none"
|
|
if tc == "required" {
|
|
config.SetFunctionCallString("required")
|
|
} else if tc == "none" {
|
|
// Don't use tools - handled in endpoint
|
|
}
|
|
// "auto" is default - let model decide
|
|
case map[string]any:
|
|
// Specific tool. OpenAI spec nests the function name under "function":
|
|
// {"type":"function", "function":{"name":"..."}}
|
|
// Legacy/Anthropic-compat form puts it at the top level:
|
|
// {"type":"function", "name":"..."}
|
|
// The old code only handled the legacy shape AND used the wrong
|
|
// setter (SetFunctionCallString writes the mode field; the
|
|
// specific-function name lives in a separate field read by
|
|
// ShouldCallSpecificFunction / FunctionToCall). Net effect: a
|
|
// correctly-formed OpenAI tool_choice never engaged grammar-based
|
|
// forcing, the model got the tools but no selection hint, and
|
|
// streamed raw JSON as delta.content instead of delta.tool_calls.
|
|
if tcType, ok := tc["type"].(string); ok && tcType == "function" {
|
|
var name string
|
|
if fn, ok := tc["function"].(map[string]any); ok {
|
|
if n, ok := fn["name"].(string); ok {
|
|
name = n
|
|
}
|
|
}
|
|
if name == "" {
|
|
if n, ok := tc["name"].(string); ok {
|
|
name = n
|
|
}
|
|
}
|
|
if name != "" {
|
|
config.SetFunctionCallNameString(name)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if valid, _ := config.Validate(); valid {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("unable to validate configuration after merging")
|
|
}
|