diff --git a/core/cli/run.go b/core/cli/run.go index 517052b9c..a67b35fad 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -83,6 +83,7 @@ type RunCMD struct { EnableTracing bool `env:"LOCALAI_ENABLE_TRACING,ENABLE_TRACING" help:"Enable API tracing" group:"api"` TracingMaxItems int `env:"LOCALAI_TRACING_MAX_ITEMS" default:"1024" help:"Maximum number of traces to keep" group:"api"` AgentJobRetentionDays int `env:"LOCALAI_AGENT_JOB_RETENTION_DAYS,AGENT_JOB_RETENTION_DAYS" default:"30" help:"Number of days to keep agent job history (default: 30)" group:"api"` + OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"` Version bool } @@ -249,6 +250,15 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { opts = append(opts, config.WithLRUEvictionRetryInterval(dur)) } + // Handle Open Responses store TTL + if r.OpenResponsesStoreTTL != "" && r.OpenResponsesStoreTTL != "0" { + dur, err := time.ParseDuration(r.OpenResponsesStoreTTL) + if err != nil { + return fmt.Errorf("invalid Open Responses store TTL: %w", err) + } + opts = append(opts, config.WithOpenResponsesStoreTTL(dur)) + } + // split ":" to get backend name and the uri for _, v := range r.ExternalGRPCBackends { backend := v[:strings.IndexByte(v, ':')] diff --git a/core/config/application_config.go b/core/config/application_config.go index 26b603f53..e96e8ac58 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -86,6 +86,8 @@ type ApplicationConfig struct { AgentJobRetentionDays int // Default: 30 days + OpenResponsesStoreTTL time.Duration // TTL for Open Responses store (0 = no expiration) + PathWithoutAuth []string } @@ -467,6 +469,12 @@ func WithAgentJobRetentionDays(days int) AppOption { } } +func WithOpenResponsesStoreTTL(ttl time.Duration) AppOption { + return func(o *ApplicationConfig) { + o.OpenResponsesStoreTTL = ttl + } +} + func WithEnforcedPredownloadScans(enforced bool) AppOption { return func(o *ApplicationConfig) { o.EnforcePredownloadScans = enforced @@ -594,6 +602,12 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings { } else { lruEvictionRetryInterval = "1s" // default } + var openResponsesStoreTTL string + if o.OpenResponsesStoreTTL > 0 { + openResponsesStoreTTL = o.OpenResponsesStoreTTL.String() + } else { + openResponsesStoreTTL = "0" // default: no expiration + } return RuntimeSettings{ WatchdogEnabled: &watchdogEnabled, @@ -628,6 +642,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings { AutoloadBackendGalleries: &autoloadBackendGalleries, ApiKeys: &apiKeys, AgentJobRetentionDays: &agentJobRetentionDays, + OpenResponsesStoreTTL: &openResponsesStoreTTL, } } @@ -769,6 +784,14 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req if settings.AgentJobRetentionDays != nil { o.AgentJobRetentionDays = *settings.AgentJobRetentionDays } + if settings.OpenResponsesStoreTTL != nil { + if *settings.OpenResponsesStoreTTL == "0" || *settings.OpenResponsesStoreTTL == "" { + o.OpenResponsesStoreTTL = 0 // No expiration + } else if dur, err := time.ParseDuration(*settings.OpenResponsesStoreTTL); err == nil { + o.OpenResponsesStoreTTL = dur + } + // This setting doesn't require restart, can be updated dynamically + } // Note: ApiKeys requires special handling (merging with startup keys) - handled in caller return requireRestart diff --git a/core/config/runtime_settings.go 
b/core/config/runtime_settings.go index 1a7f6db81..9c4d4531d 100644 --- a/core/config/runtime_settings.go +++ b/core/config/runtime_settings.go @@ -60,4 +60,7 @@ type RuntimeSettings struct { // Agent settings AgentJobRetentionDays *int `json:"agent_job_retention_days,omitempty"` + + // Open Responses settings + OpenResponsesStoreTTL *string `json:"open_responses_store_ttl,omitempty"` // TTL for stored responses (e.g., "1h", "30m", "0" = no expiration) } diff --git a/core/http/app.go b/core/http/app.go index 376a403dc..88cd3ffa8 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -193,6 +193,8 @@ func API(application *application.Application) (*echo.Echo, error) { corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",") } e.Use(middleware.CORSWithConfig(corsConfig)) + } else { + e.Use(middleware.CORS()) } // CSRF middleware @@ -214,6 +216,7 @@ func API(application *application.Application) (*echo.Echo, error) { routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application) routes.RegisterOpenAIRoutes(e, requestExtractor, application) routes.RegisterAnthropicRoutes(e, requestExtractor, application) + routes.RegisterOpenResponsesRoutes(e, requestExtractor, application) if !application.ApplicationConfig().DisableWebUI { routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application) routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService()) diff --git a/core/http/endpoints/localai/settings.go b/core/http/endpoints/localai/settings.go index 93746baaa..ce835e983 100644 --- a/core/http/endpoints/localai/settings.go +++ b/core/http/endpoints/localai/settings.go @@ -11,6 +11,7 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/endpoints/openresponses" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/xlog" @@ -84,6 +85,16 @@ func UpdateSettingsEndpoint(app *application.Application) echo.HandlerFunc { }) } } + if settings.OpenResponsesStoreTTL != nil { + if *settings.OpenResponsesStoreTTL != "0" && *settings.OpenResponsesStoreTTL != "" { + if _, err := time.ParseDuration(*settings.OpenResponsesStoreTTL); err != nil { + return c.JSON(http.StatusBadRequest, schema.SettingsResponse{ + Success: false, + Error: "Invalid open_responses_store_ttl format: " + err.Error(), + }) + } + } + } // Save to file if appConfig.DynamicConfigsDir == "" { @@ -144,6 +155,22 @@ func UpdateSettingsEndpoint(app *application.Application) echo.HandlerFunc { xlog.Info("Updated LRU eviction retry settings", "maxRetries", maxRetries, "retryInterval", retryInterval) } + // Update Open Responses store TTL dynamically + if settings.OpenResponsesStoreTTL != nil { + ttl := time.Duration(0) + if *settings.OpenResponsesStoreTTL != "0" && *settings.OpenResponsesStoreTTL != "" { + if dur, err := time.ParseDuration(*settings.OpenResponsesStoreTTL); err == nil { + ttl = dur + } else { + xlog.Warn("Invalid Open Responses store TTL format", "ttl", *settings.OpenResponsesStoreTTL, "error", err) + } + } + // Import the store package + store := openresponses.GetGlobalStore() + 
store.SetTTL(ttl) + xlog.Info("Updated Open Responses store TTL", "ttl", ttl) + } + // Check if agent job retention changed agentJobChanged := settings.AgentJobRetentionDays != nil diff --git a/core/http/endpoints/openresponses/responses.go b/core/http/endpoints/openresponses/responses.go new file mode 100644 index 000000000..337978506 --- /dev/null +++ b/core/http/endpoints/openresponses/responses.go @@ -0,0 +1,3301 @@ +package openresponses + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "time" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/templates" + "github.com/mudler/LocalAI/pkg/functions" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/utils" + "github.com/mudler/cogito" + "github.com/mudler/xlog" +) + +// ResponsesEndpoint is the Open Responses API endpoint +// https://www.openresponses.org/specification +// @Summary Create a response using the Open Responses API +// @Param request body schema.OpenResponsesRequest true "Request body" +// @Success 200 {object} schema.ORResponseResource "Response" +// @Router /v1/responses [post] +func ResponsesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + createdAt := time.Now().Unix() + responseID := fmt.Sprintf("resp_%s", uuid.New().String()) + + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenResponsesRequest) + if !ok || input.Model == "" { + return sendOpenResponsesError(c, 400, "invalid_request", "model is required", "") + } + + cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil { + return sendOpenResponsesError(c, 400, "invalid_request", "model configuration not found", "") + } + + // Initialize store with TTL from appConfig + store := GetGlobalStore() + if appConfig.OpenResponsesStoreTTL > 0 { + store.SetTTL(appConfig.OpenResponsesStoreTTL) + } + + // Check if storage is disabled for this request + shouldStore := true + if input.Store != nil && !*input.Store { + shouldStore = false + } + + // Handle previous_response_id if provided + var previousResponse *schema.ORResponseResource + var messages []schema.Message + if input.PreviousResponseID != "" { + stored, err := store.Get(input.PreviousResponseID) + if err != nil { + return sendOpenResponsesError(c, 404, "not_found", fmt.Sprintf("previous response not found: %s", input.PreviousResponseID), "previous_response_id") + } + previousResponse = stored.Response + + // Also convert previous response input to messages + previousInputMessages, err := convertORInputToMessages(stored.Request.Input, cfg) + if err != nil { + return sendOpenResponsesError(c, 400, "invalid_request", fmt.Sprintf("failed to convert previous input: %v", err), "") + } + + // Convert previous response output items to messages + previousOutputMessages, err := convertOROutputItemsToMessages(previousResponse.Output) + if err != nil { + return sendOpenResponsesError(c, 400, "invalid_request", fmt.Sprintf("failed to convert previous response: %v", err), "") + } + + // Concatenate: previous_input + previous_output + new_input + // Start with previous input messages + 
messages = previousInputMessages + // Add previous output as assistant messages + messages = append(messages, previousOutputMessages...) + } + + // Convert Open Responses input to internal Messages + newMessages, err := convertORInputToMessages(input.Input, cfg) + if err != nil { + return sendOpenResponsesError(c, 400, "invalid_request", fmt.Sprintf("failed to parse input: %v", err), "") + } + // Append new input messages + messages = append(messages, newMessages...) + + // Add instructions as system message if provided + if input.Instructions != "" { + messages = append([]schema.Message{{Role: "system", StringContent: input.Instructions}}, messages...) + } + + // Handle tools + var funcs functions.Functions + var shouldUseFn bool + var useMCP bool + + if len(input.Tools) > 0 { + // User-provided tools + funcs, shouldUseFn = convertORToolsToFunctions(input, cfg) + } else if cfg.MCP.Servers != "" || cfg.MCP.Stdio != "" { + // MCP tools (internal) + useMCP = true + } + + // Create OpenAI-compatible request for internal processing + openAIReq := &schema.OpenAIRequest{ + PredictionOptions: schema.PredictionOptions{ + BasicModelRequest: schema.BasicModelRequest{Model: input.Model}, + Temperature: input.Temperature, + TopP: input.TopP, + Maxtokens: input.MaxOutputTokens, + }, + Messages: messages, + Stream: input.Stream, + Context: input.Context, + Cancel: input.Cancel, + Functions: funcs, + } + + // Handle text_format -> response_format conversion + if input.TextFormat != nil { + openAIReq.ResponseFormat = convertTextFormatToResponseFormat(input.TextFormat) + } + + // Generate grammar for function calling (similar to OpenAI chat endpoint) + if shouldUseFn && !cfg.FunctionsConfig.GrammarConfig.NoGrammar { + // Add no-action function to allow model to respond without calling a tool + noActionName := "answer" + noActionDescription := "use this action to answer without performing any action" + if cfg.FunctionsConfig.NoActionFunctionName != "" { + noActionName = cfg.FunctionsConfig.NoActionFunctionName + } + if cfg.FunctionsConfig.NoActionDescriptionName != "" { + noActionDescription = cfg.FunctionsConfig.NoActionDescriptionName + } + + noActionGrammar := functions.Function{ + Name: noActionName, + Description: noActionDescription, + Parameters: map[string]interface{}{ + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to reply the user with", + }, + }, + }, + } + + // Make a copy of funcs to avoid modifying the original + funcsWithNoAction := make(functions.Functions, len(funcs)) + copy(funcsWithNoAction, funcs) + + // Append no-action function unless disabled + if !cfg.FunctionsConfig.DisableNoAction { + funcsWithNoAction = append(funcsWithNoAction, noActionGrammar) + } + + // Force picking one of the functions by the request + if cfg.FunctionToCall() != "" { + funcsWithNoAction = funcsWithNoAction.Select(cfg.FunctionToCall()) + } + + // Generate grammar to constrain model output to valid function calls + jsStruct := funcsWithNoAction.ToJSONStructure(cfg.FunctionsConfig.FunctionNameKey, cfg.FunctionsConfig.FunctionNameKey) + g, err := jsStruct.Grammar(cfg.FunctionsConfig.GrammarOptions()...) 
+ if err == nil { + cfg.Grammar = g + xlog.Debug("Open Responses - Generated grammar for function calling") + } else { + xlog.Error("Open Responses - Failed generating grammar for function calling", "error", err) + } + } + + // Template the prompt + predInput := evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn) + xlog.Debug("Open Responses - Prompt (after templating)", "prompt", predInput) + + // Handle background mode + isBackground := input.Background != nil && *input.Background + if isBackground { + // Background mode requires storage + if !shouldStore { + return sendOpenResponsesError(c, 400, "invalid_request_error", "background=true requires store=true", "background") + } + + // Create initial response with "queued" status + queuedResponse := buildORResponse(responseID, createdAt, nil, schema.ORStatusQueued, input, []schema.ORItemField{}, nil, true) + + // Create cancellable context for background execution + bgCtx, bgCancel := context.WithCancel(context.Background()) + + // Store the background response + store.StoreBackground(responseID, input, queuedResponse, bgCancel, input.Stream) + + // Start background processing goroutine + go func() { + defer bgCancel() + + // Update status to in_progress + store.UpdateStatus(responseID, schema.ORStatusInProgress, nil) + + var finalResponse *schema.ORResponseResource + var bgErr error + + if useMCP { + // Background MCP processing + finalResponse, bgErr = handleBackgroundMCPResponse(bgCtx, store, responseID, createdAt, input, cfg, ml, predInput, openAIReq, appConfig) + } else if input.Stream { + // Background streaming processing (buffer events) + finalResponse, bgErr = handleBackgroundStream(bgCtx, store, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn) + } else { + // Background non-streaming processing + finalResponse, bgErr = handleBackgroundNonStream(bgCtx, store, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn) + } + + if bgErr != nil { + xlog.Error("Background response failed", "response_id", responseID, "error", bgErr) + now := time.Now().Unix() + store.UpdateStatus(responseID, schema.ORStatusFailed, &now) + return + } + + // Update final response in store + if finalResponse != nil { + store.UpdateResponse(responseID, finalResponse) + } + }() + + // Return immediately with queued response + return c.JSON(200, queuedResponse) + } + + if useMCP { + // Use MCP agentic loop + return handleMCPResponse(c, responseID, createdAt, input, cfg, ml, predInput, openAIReq, appConfig, shouldStore) + } + + if input.Stream { + return handleOpenResponsesStream(c, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, shouldStore) + } + + return handleOpenResponsesNonStream(c, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, shouldStore) + } +} + +// convertORInputToMessages converts Open Responses input to internal Messages +func convertORInputToMessages(input interface{}, cfg *config.ModelConfig) ([]schema.Message, error) { + var messages []schema.Message + + switch v := input.(type) { + case string: + // Simple string = user message + return []schema.Message{{Role: "user", StringContent: v}}, nil + case []interface{}: + // Array of items + for _, itemRaw := range v { + itemMap, ok := itemRaw.(map[string]interface{}) + if !ok { + continue + } + + itemType, _ := itemMap["type"].(string) + switch itemType { + case "message": + msg, err := 
convertORMessageItem(itemMap, cfg) + if err != nil { + return nil, err + } + messages = append(messages, msg) + case "function_call_output": + // Convert function call output to tool role message + callID, _ := itemMap["call_id"].(string) + output := itemMap["output"] + var outputStr string + if str, ok := output.(string); ok { + outputStr = str + } else { + // Convert to JSON string + outputBytes, _ := json.Marshal(output) + outputStr = string(outputBytes) + } + // For tool messages, we use the Name field to store the call ID + messages = append(messages, schema.Message{ + Role: "tool", + Name: callID, + Content: outputStr, + StringContent: outputStr, + }) + case "item_reference": + // Handle item references - look up item in stored responses + // According to spec, item_reference uses "id" field, not "item_id" + itemID, ok := itemMap["id"].(string) + if !ok || itemID == "" { + return nil, fmt.Errorf("item_reference missing id") + } + + store := GetGlobalStore() + item, responseID, err := store.FindItem(itemID) + if err != nil { + return nil, fmt.Errorf("item not found: %s (from response %s): %w", itemID, responseID, err) + } + + // Log item reference resolution for debugging + xlog.Debug("Resolved item reference", "item_id", itemID, "response_id", responseID, "item_type", item.Type) + + // Convert referenced item to message based on its type + msg, err := convertORItemToMessage(item, responseID) + if err != nil { + return nil, fmt.Errorf("failed to convert referenced item %s from response %s: %w", itemID, responseID, err) + } + messages = append(messages, msg) + } + } + return messages, nil + default: + return nil, fmt.Errorf("unsupported input type: %T", input) + } +} + +// convertORItemToMessage converts a single ORItemField to a Message +// responseID is the ID of the response where this item was found (for logging/debugging) +func convertORItemToMessage(item *schema.ORItemField, responseID string) (schema.Message, error) { + switch item.Type { + case "message": + // Convert message item to message + var textContent string + if contentParts, ok := item.Content.([]schema.ORContentPart); ok { + for _, part := range contentParts { + if part.Type == "output_text" || part.Type == "input_text" { + textContent += part.Text + } + } + } else if str, ok := item.Content.(string); ok { + textContent = str + } + return schema.Message{ + Role: item.Role, + StringContent: textContent, + Content: textContent, + }, nil + case "function_call_output": + // Convert function call output to tool role message + var outputStr string + if str, ok := item.Output.(string); ok { + outputStr = str + } else { + // Convert to JSON string + outputBytes, _ := json.Marshal(item.Output) + outputStr = string(outputBytes) + } + return schema.Message{ + Role: "tool", + Name: item.CallID, + Content: outputStr, + StringContent: outputStr, + }, nil + default: + return schema.Message{}, fmt.Errorf("unsupported item type for conversion: %s (from response %s)", item.Type, responseID) + } +} + +// convertOROutputItemsToMessages converts Open Responses output items to internal Messages +func convertOROutputItemsToMessages(outputItems []schema.ORItemField) ([]schema.Message, error) { + var messages []schema.Message + + for _, item := range outputItems { + switch item.Type { + case "message": + // Convert message item to assistant message + var textContent string + if contentParts, ok := item.Content.([]schema.ORContentPart); ok && len(contentParts) > 0 { + for _, part := range contentParts { + if part.Type == "output_text" { + 
textContent += part.Text + } + } + } + messages = append(messages, schema.Message{ + Role: item.Role, + StringContent: textContent, + Content: textContent, + }) + case "function_call": + // Function calls are handled separately - they become tool calls in the next turn + // For now, we skip them as they're part of the model's output, not input + case "function_call_output": + // Convert function call output to tool role message + var outputStr string + if str, ok := item.Output.(string); ok { + outputStr = str + } else { + // Convert to JSON string + outputBytes, _ := json.Marshal(item.Output) + outputStr = string(outputBytes) + } + messages = append(messages, schema.Message{ + Role: "tool", + Name: item.CallID, + Content: outputStr, + StringContent: outputStr, + }) + } + } + + return messages, nil +} + +// convertORMessageItem converts an Open Responses message item to internal Message +func convertORMessageItem(itemMap map[string]interface{}, cfg *config.ModelConfig) (schema.Message, error) { + role, _ := itemMap["role"].(string) + msg := schema.Message{Role: role} + + content := itemMap["content"] + switch contentVal := content.(type) { + case string: + msg.StringContent = contentVal + msg.Content = contentVal + case []interface{}: + // Array of content parts + var textContent string + var stringImages []string + var stringVideos []string + var stringAudios []string + + for _, partRaw := range contentVal { + partMap, ok := partRaw.(map[string]interface{}) + if !ok { + continue + } + + partType, _ := partMap["type"].(string) + switch partType { + case "input_text": + if text, ok := partMap["text"].(string); ok { + textContent += text + } + case "input_image": + if imageURL, ok := partMap["image_url"].(string); ok { + // Convert to base64 data URI + base64, err := utils.GetContentURIAsBase64(imageURL) + if err != nil { + xlog.Error("Failed encoding image", "error", err) + continue + } + stringImages = append(stringImages, base64) + } + case "input_file": + if fileURL, ok := partMap["file_url"].(string); ok { + // Convert to base64 + base64, err := utils.GetContentURIAsBase64(fileURL) + if err != nil { + xlog.Error("Failed encoding file", "error", err) + continue + } + // For now, treat files as text content + textContent += base64 + } else if fileData, ok := partMap["file_data"].(string); ok { + // Already base64 + textContent += fileData + } + case "input_video": + if videoURL, ok := partMap["video_url"].(string); ok { + // Convert to base64 data URI + base64, err := utils.GetContentURIAsBase64(videoURL) + if err != nil { + xlog.Error("Failed encoding video", "error", err) + continue + } + stringVideos = append(stringVideos, base64) + } + case "input_audio": + if audioURL, ok := partMap["audio_url"].(string); ok { + // Convert to base64 data URI + base64, err := utils.GetContentURIAsBase64(audioURL) + if err != nil { + xlog.Error("Failed encoding audio", "error", err) + continue + } + stringAudios = append(stringAudios, base64) + } + } + } + + msg.StringContent = textContent + msg.Content = textContent + msg.StringImages = stringImages + msg.StringVideos = stringVideos + msg.StringAudios = stringAudios + + // Template multimodal content + if len(stringImages) > 0 || len(stringVideos) > 0 || len(stringAudios) > 0 { + msg.StringContent, _ = templates.TemplateMultiModal(cfg.TemplateConfig.Multimodal, templates.MultiModalOptions{ + TotalImages: len(stringImages), + TotalVideos: len(stringVideos), + TotalAudios: len(stringAudios), + ImagesInMessage: len(stringImages), + VideosInMessage: 
len(stringVideos), + AudiosInMessage: len(stringAudios), + }, textContent) + } + } + + return msg, nil +} + +// convertORToolsToFunctions converts Open Responses tools to internal Functions +func convertORToolsToFunctions(input *schema.OpenResponsesRequest, cfg *config.ModelConfig) (functions.Functions, bool) { + if len(input.Tools) == 0 { + return nil, false + } + + // Build allowed tools set if specified + allowedSet := make(map[string]bool) + if len(input.AllowedTools) > 0 { + for _, name := range input.AllowedTools { + allowedSet[name] = true + } + } + + var funcs functions.Functions + for _, tool := range input.Tools { + if tool.Type == "function" { + // Skip if not in allowed list (when allowed_tools is specified) + if len(allowedSet) > 0 && !allowedSet[tool.Name] { + continue + } + f := functions.Function{ + Name: tool.Name, + Description: tool.Description, + Parameters: tool.Parameters, + } + funcs = append(funcs, f) + } + } + + // Handle tool_choice + if input.ToolChoice != nil { + switch tc := input.ToolChoice.(type) { + case string: + switch tc { + case "required": + cfg.SetFunctionCallString("required") + case "none": + return nil, false + case "auto": + // "auto" is the default - let model decide whether to use tools + // Tools are available but not forced + } + case map[string]interface{}: + if tcType, ok := tc["type"].(string); ok && tcType == "function" { + if name, ok := tc["name"].(string); ok { + cfg.SetFunctionCallString(name) + } + } + } + } + + return funcs, len(funcs) > 0 && cfg.ShouldUseFunctions() +} + +// convertTextFormatToResponseFormat converts Open Responses text_format to OpenAI response_format +func convertTextFormatToResponseFormat(textFormat interface{}) interface{} { + switch tf := textFormat.(type) { + case map[string]interface{}: + if tfType, ok := tf["type"].(string); ok { + if tfType == "json_schema" { + return map[string]interface{}{ + "type": "json_schema", + "json_schema": tf, + } + } + return map[string]interface{}{"type": tfType} + } + case string: + return map[string]interface{}{"type": tf} + } + return nil +} + +// handleBackgroundNonStream handles background non-streaming responses +func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool) (*schema.ORResponseResource, error) { + images := []string{} + videos := []string{} + audios := []string{} + for _, m := range openAIReq.Messages { + images = append(images, m.StringImages...) + videos = append(videos, m.StringVideos...) + audios = append(audios, m.StringAudios...) 
+ } + + toolsJSON := serializeToolsForBackend(input.Tools) + toolChoiceJSON := "" + if input.ToolChoice != nil { + toolChoiceBytes, err := json.Marshal(input.ToolChoice) + if err == nil { + toolChoiceJSON = string(toolChoiceBytes) + } + } + + var logprobs *int + if input.TopLogprobs != nil && *input.TopLogprobs > 0 { + logprobs = input.TopLogprobs + } + + predFunc, err := backend.ModelInference( + ctx, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, logprobs, input.TopLogprobs, input.LogitBias) + if err != nil { + return nil, fmt.Errorf("model inference failed: %w", err) + } + + // Check for cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + prediction, err := predFunc() + if err != nil { + return nil, fmt.Errorf("prediction failed: %w", err) + } + + result := backend.Finetune(*cfg, predInput, prediction.Response) + + // Parse tool calls if using functions (same logic as regular handler) + var outputItems []schema.ORItemField + var toolCalls []schema.ToolCall + + if shouldUseFn { + cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig) + funcCallResults := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig) + textContent := functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig) + + noActionName := "answer" + if cfg.FunctionsConfig.NoActionFunctionName != "" { + noActionName = cfg.FunctionsConfig.NoActionFunctionName + } + + for i, fc := range funcCallResults { + if fc.Name == noActionName { + if fc.Arguments != "" { + var args map[string]interface{} + if err := json.Unmarshal([]byte(fc.Arguments), &args); err == nil { + if msg, ok := args["message"].(string); ok && msg != "" { + textContent = msg + } + } + } + continue + } + toolCalls = append(toolCalls, schema.ToolCall{ + Index: i, + ID: fmt.Sprintf("fc_%s", uuid.New().String()), + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: fc.Name, + Arguments: fc.Arguments, + }, + }) + } + + if textContent != "" { + outputItems = append(outputItems, schema.ORItemField{ + Type: "message", + ID: fmt.Sprintf("msg_%s", uuid.New().String()), + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)}, + }) + } + + for _, tc := range toolCalls { + outputItems = append(outputItems, schema.ORItemField{ + Type: "function_call", + ID: fmt.Sprintf("fc_%s", uuid.New().String()), + Status: "completed", + CallID: tc.ID, + Name: tc.FunctionCall.Name, + Arguments: tc.FunctionCall.Arguments, + }) + } + + if len(outputItems) == 0 && result != "" { + outputItems = append(outputItems, schema.ORItemField{ + Type: "message", + ID: fmt.Sprintf("msg_%s", uuid.New().String()), + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)}, + }) + } + } else { + outputItems = append(outputItems, schema.ORItemField{ + Type: "message", + ID: fmt.Sprintf("msg_%s", uuid.New().String()), + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)}, + }) + } + + now := time.Now().Unix() + response := buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, outputItems, &schema.ORUsage{ + InputTokens: prediction.Usage.Prompt, + OutputTokens: prediction.Usage.Completion, + TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion, + }, true) + + return response, nil +} 
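For reference, a minimal client-side sketch of the background flow implemented above (a sketch, not part of the patch: it assumes a LocalAI instance on localhost:8080, a placeholder model name, and that the GET /v1/responses/{id} retrieval route wired up by RegisterOpenResponsesRoutes is available; the JSON field names follow the handler and the Open Responses spec):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

func main() {
	// Create a background response; the endpoint returns immediately with status "queued".
	body, _ := json.Marshal(map[string]any{
		"model":      "my-model", // placeholder model name
		"input":      "Write a haiku about asynchronous jobs.",
		"background": true,
		"store":      true, // background=true requires store=true (enforced by the handler above)
	})
	resp, err := http.Post("http://localhost:8080/v1/responses", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	var created struct {
		ID     string `json:"id"`
		Status string `json:"status"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&created); err != nil {
		panic(err)
	}
	resp.Body.Close()

	// Poll the stored response until the background goroutine marks it completed or failed.
	for {
		r, err := http.Get("http://localhost:8080/v1/responses/" + created.ID)
		if err != nil {
			panic(err)
		}
		var polled struct {
			Status string `json:"status"`
		}
		if err := json.NewDecoder(r.Body).Decode(&polled); err != nil {
			panic(err)
		}
		r.Body.Close()
		fmt.Println("status:", polled.Status)
		if polled.Status == "completed" || polled.Status == "failed" {
			break
		}
		time.Sleep(2 * time.Second)
	}
}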
+ +// handleBackgroundStream handles background streaming responses with event buffering +func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool) (*schema.ORResponseResource, error) { + images := []string{} + videos := []string{} + audios := []string{} + for _, m := range openAIReq.Messages { + images = append(images, m.StringImages...) + videos = append(videos, m.StringVideos...) + audios = append(audios, m.StringAudios...) + } + + toolsJSON := serializeToolsForBackend(input.Tools) + toolChoiceJSON := "" + if input.ToolChoice != nil { + toolChoiceBytes, err := json.Marshal(input.ToolChoice) + if err == nil { + toolChoiceJSON = string(toolChoiceBytes) + } + } + + sequenceNumber := 0 + + // Emit response.created + responseCreated := buildORResponse(responseID, createdAt, nil, schema.ORStatusInProgress, input, []schema.ORItemField{}, nil, true) + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.created", + SequenceNumber: sequenceNumber, + Response: responseCreated, + }) + sequenceNumber++ + + // Emit response.in_progress + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.in_progress", + SequenceNumber: sequenceNumber, + Response: responseCreated, + }) + sequenceNumber++ + + var accumulatedText string + var collectedOutputItems []schema.ORItemField + outputIndex := 0 + currentMessageID := fmt.Sprintf("msg_%s", uuid.New().String()) + + // Emit output_item.added + messageItem := &schema.ORItemField{ + Type: "message", + ID: currentMessageID, + Status: "in_progress", + Role: "assistant", + Content: []schema.ORContentPart{}, + } + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: messageItem, + }) + sequenceNumber++ + + // Emit content_part.added + currentContentIndex := 0 + emptyPart := makeOutputTextPart("") + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.content_part.added", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: &currentContentIndex, + Part: &emptyPart, + }) + sequenceNumber++ + + // Token callback for streaming + tokenCallback := func(token string, tokenUsage backend.TokenUsage) bool { + select { + case <-ctx.Done(): + return false + default: + } + + accumulatedText += token + + // Buffer text delta + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.output_text.delta", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: &currentContentIndex, + Delta: strPtr(token), + Logprobs: emptyLogprobs(), + }) + sequenceNumber++ + return true + } + + var streamLogprobs *int + if input.TopLogprobs != nil && *input.TopLogprobs > 0 { + streamLogprobs = input.TopLogprobs + } + + predFunc, err := backend.ModelInference( + ctx, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, streamLogprobs, input.TopLogprobs, input.LogitBias) + if err != nil { + return nil, fmt.Errorf("model inference failed: %w", err) + } + + prediction, err := predFunc() + if err != nil { + return nil, fmt.Errorf("prediction failed: %w", err) + } + + // Emit 
output_text.done + streamEventLogprobs := convertLogprobsForStreaming(prediction.Logprobs) + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.output_text.done", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: &currentContentIndex, + Text: strPtr(accumulatedText), + Logprobs: logprobsPtr(streamEventLogprobs), + }) + sequenceNumber++ + + // Emit content_part.done + textPart := makeOutputTextPartWithLogprobs(accumulatedText, prediction.Logprobs) + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.content_part.done", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: &currentContentIndex, + Part: &textPart, + }) + sequenceNumber++ + + // Emit output_item.done + completedMessageItem := &schema.ORItemField{ + Type: "message", + ID: currentMessageID, + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(accumulatedText, prediction.Logprobs)}, + } + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: completedMessageItem, + }) + sequenceNumber++ + collectedOutputItems = append(collectedOutputItems, *completedMessageItem) + + // Build final response + now := time.Now().Unix() + response := buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, collectedOutputItems, &schema.ORUsage{ + InputTokens: prediction.Usage.Prompt, + OutputTokens: prediction.Usage.Completion, + TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion, + }, true) + + // Emit response.completed + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.completed", + SequenceNumber: sequenceNumber, + Response: response, + }) + + return response, nil +} + +// handleBackgroundMCPResponse handles background MCP responses +func handleBackgroundMCPResponse(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, predInput string, openAIReq *schema.OpenAIRequest, appConfig *config.ApplicationConfig) (*schema.ORResponseResource, error) { + // Check for cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Validate MCP config + if cfg.MCP.Servers == "" && cfg.MCP.Stdio == "" { + return nil, fmt.Errorf("no MCP servers configured") + } + + // Get MCP config from model config + remote, stdio, err := cfg.MCP.MCPConfigFromYAML() + if err != nil { + return nil, fmt.Errorf("failed to get MCP config: %w", err) + } + + // Get MCP sessions + sessions, err := mcpTools.SessionsFromMCPConfig(cfg.Name, remote, stdio) + if err != nil { + return nil, fmt.Errorf("failed to get MCP sessions: %w", err) + } + + if len(sessions) == 0 { + return nil, fmt.Errorf("no working MCP servers found") + } + + // Build fragment from messages + fragment := cogito.NewEmptyFragment() + for _, message := range openAIReq.Messages { + fragment = fragment.AddMessage(message.Role, message.StringContent) + } + fragmentPtr := &fragment + + // Get API address and key + _, port, err := net.SplitHostPort(appConfig.APIAddress) + if err != nil { + return nil, fmt.Errorf("failed to parse API address: %w", err) + } + apiKey := "" + if len(appConfig.ApiKeys) > 0 { + apiKey = appConfig.ApiKeys[0] + } + + // Create OpenAI LLM client + defaultLLM := cogito.NewOpenAILLM(cfg.Name, apiKey, 
"http://127.0.0.1:"+port) + + // Build cogito options + cogitoOpts := cfg.BuildCogitoOptions() + cogitoOpts = append( + cogitoOpts, + cogito.WithContext(ctx), + cogito.WithMCPs(sessions...), + ) + + if input.Stream { + return handleBackgroundMCPStream(ctx, store, responseID, createdAt, input, cfg, defaultLLM, fragmentPtr, cogitoOpts) + } + + // Non-streaming mode + return handleBackgroundMCPNonStream(ctx, store, responseID, createdAt, input, cfg, defaultLLM, fragmentPtr, cogitoOpts) +} + +// handleBackgroundMCPNonStream handles background non-streaming MCP responses +func handleBackgroundMCPNonStream(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, defaultLLM cogito.LLM, fragment *cogito.Fragment, cogitoOpts []cogito.Option) (*schema.ORResponseResource, error) { + frag := *fragment + + // Check for cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Set up callbacks for logging + cogitoOpts = append( + cogitoOpts, + cogito.WithStatusCallback(func(s string) { + xlog.Debug("[Open Responses MCP Background] Status", "model", cfg.Name, "status", s, "response_id", responseID) + }), + cogito.WithReasoningCallback(func(s string) { + xlog.Debug("[Open Responses MCP Background] Reasoning", "model", cfg.Name, "reasoning", s, "response_id", responseID) + }), + cogito.WithToolCallBack(func(t *cogito.ToolChoice, state *cogito.SessionState) cogito.ToolCallDecision { + xlog.Debug("[Open Responses MCP Background] Tool call", "model", cfg.Name, "tool", t.Name, "reasoning", t.Reasoning, "arguments", t.Arguments, "response_id", responseID) + return cogito.ToolCallDecision{ + Approved: true, + } + }), + cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) { + xlog.Debug("[Open Responses MCP Background] Tool call result", "model", cfg.Name, "tool", t.Name, "result", t.Result, "tool_arguments", t.ToolArguments, "response_id", responseID) + }), + ) + + // Execute tools + f, err := cogito.ExecuteTools(defaultLLM, frag, cogitoOpts...) 
+ if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) { + return nil, fmt.Errorf("failed to execute tools: %w", err) + } + + // Check for cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Get final response + f, err = defaultLLM.Ask(ctx, f) + if err != nil { + return nil, fmt.Errorf("failed to get response: %w", err) + } + + // Convert fragment to Open Responses format + fPtr := &f + outputItems := convertCogitoFragmentToORItems(fPtr) + + // Build response with all required fields + now := time.Now().Unix() + response := buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, outputItems, nil, true) + + return response, nil +} + +// handleBackgroundMCPStream handles background streaming MCP responses +func handleBackgroundMCPStream(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, defaultLLM cogito.LLM, fragment *cogito.Fragment, cogitoOpts []cogito.Option) (*schema.ORResponseResource, error) { + frag := *fragment + sequenceNumber := 0 + + // Emit response.created + responseCreated := buildORResponse(responseID, createdAt, nil, schema.ORStatusInProgress, input, []schema.ORItemField{}, nil, true) + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.created", + SequenceNumber: sequenceNumber, + Response: responseCreated, + }) + sequenceNumber++ + + // Emit response.in_progress + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.in_progress", + SequenceNumber: sequenceNumber, + Response: responseCreated, + }) + sequenceNumber++ + + // Create channels for streaming events + events := make(chan interface{}) + ended := make(chan error, 1) + var collectedOutputItems []schema.ORItemField + outputIndex := 0 + + // Set up callbacks + statusCallback := func(s string) { + select { + case <-ctx.Done(): + return + case events <- map[string]interface{}{ + "type": "status", + "message": s, + }: + } + } + + reasoningCallback := func(s string) { + select { + case <-ctx.Done(): + return + default: + } + itemID := fmt.Sprintf("reasoning_%s", uuid.New().String()) + outputIndex++ + item := &schema.ORItemField{ + Type: "reasoning", + ID: itemID, + Status: "in_progress", + } + collectedOutputItems = append(collectedOutputItems, *item) + + select { + case <-ctx.Done(): + return + case events <- map[string]interface{}{ + "type": "reasoning", + "item_id": itemID, + "output_index": outputIndex, + "content": s, + }: + } + } + + toolCallCallback := func(t *cogito.ToolChoice, state *cogito.SessionState) cogito.ToolCallDecision { + select { + case <-ctx.Done(): + return cogito.ToolCallDecision{Approved: false} + default: + } + toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) + outputIndex++ + item := &schema.ORItemField{ + Type: "function_call", + ID: toolCallID, + Status: "in_progress", + CallID: toolCallID, + Name: t.Name, + Arguments: "", + } + collectedOutputItems = append(collectedOutputItems, *item) + + select { + case <-ctx.Done(): + return cogito.ToolCallDecision{Approved: false} + case events <- map[string]interface{}{ + "type": "tool_call", + "item_id": toolCallID, + "output_index": outputIndex, + "name": t.Name, + "arguments": t.Arguments, + "reasoning": t.Reasoning, + }: + } + return cogito.ToolCallDecision{ + Approved: true, + } + } + + toolCallResultCallback := func(t cogito.ToolStatus) { + select { + case <-ctx.Done(): + return + default: + } + outputIndex++ + callID := fmt.Sprintf("fc_%s", 
uuid.New().String()) + item := schema.ORItemField{ + Type: "function_call_output", + ID: fmt.Sprintf("fco_%s", uuid.New().String()), + Status: "completed", + CallID: callID, + Output: t.Result, + } + collectedOutputItems = append(collectedOutputItems, item) + + select { + case <-ctx.Done(): + return + case events <- map[string]interface{}{ + "type": "tool_result", + "item_id": item.ID, + "output_index": outputIndex, + "name": t.Name, + "result": t.Result, + }: + } + } + + cogitoOpts = append(cogitoOpts, + cogito.WithStatusCallback(statusCallback), + cogito.WithReasoningCallback(reasoningCallback), + cogito.WithToolCallBack(toolCallCallback), + cogito.WithToolCallResultCallback(toolCallResultCallback), + ) + + // Execute tools in goroutine + go func() { + defer close(events) + + f, err := cogito.ExecuteTools(defaultLLM, frag, cogitoOpts...) + if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) { + select { + case <-ctx.Done(): + ended <- ctx.Err() + case events <- map[string]interface{}{ + "type": "error", + "message": fmt.Sprintf("Failed to execute tools: %v", err), + }: + ended <- err + } + return + } + + // Check for cancellation + select { + case <-ctx.Done(): + ended <- ctx.Err() + return + default: + } + + // Get final response + f, err = defaultLLM.Ask(ctx, f) + if err != nil { + select { + case <-ctx.Done(): + ended <- ctx.Err() + case events <- map[string]interface{}{ + "type": "error", + "message": fmt.Sprintf("Failed to get response: %v", err), + }: + ended <- err + } + return + } + + // Stream final assistant message + content := f.LastMessage().Content + messageID := fmt.Sprintf("msg_%s", uuid.New().String()) + outputIndex++ + item := schema.ORItemField{ + Type: "message", + ID: messageID, + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPart(content)}, + } + collectedOutputItems = append(collectedOutputItems, item) + + select { + case <-ctx.Done(): + ended <- ctx.Err() + case events <- map[string]interface{}{ + "type": "assistant", + "item_id": messageID, + "output_index": outputIndex, + "content": content, + }: + ended <- nil + } + }() + + // Process events from channel +LOOP: + for { + select { + case <-ctx.Done(): + break LOOP + case event := <-events: + if event == nil { + break LOOP + } + // Convert event to Open Responses format and buffer + bufferMCPEventAsOR(store, responseID, event, &sequenceNumber) + case err := <-ended: + if err != nil { + // Buffer error event + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "error", + SequenceNumber: sequenceNumber, + Error: &schema.ORErrorPayload{ + Type: "model_error", + Message: err.Error(), + }, + }) + sequenceNumber++ + + // Buffer failed response + responseFailed := buildORResponse(responseID, createdAt, nil, schema.ORStatusFailed, input, collectedOutputItems, nil, true) + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.failed", + SequenceNumber: sequenceNumber, + Response: responseFailed, + }) + return nil, err + } + + // Emit response.completed + now := time.Now().Unix() + responseCompleted := buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, collectedOutputItems, nil, true) + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.completed", + SequenceNumber: sequenceNumber, + Response: responseCompleted, + }) + + break LOOP + } + } + + // Build final response + now := time.Now().Unix() + response := buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, 
collectedOutputItems, nil, true) + + return response, nil +} + +// bufferEvent stores an SSE event in the response store for streaming resume +func bufferEvent(store *ResponseStore, responseID string, event *schema.ORStreamEvent) { + if err := store.AppendEvent(responseID, event); err != nil { + xlog.Error("Failed to buffer event", "response_id", responseID, "error", err) + } +} + +// handleOpenResponsesNonStream handles non-streaming responses +func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, shouldStore bool) error { + images := []string{} + videos := []string{} + audios := []string{} + for _, m := range openAIReq.Messages { + images = append(images, m.StringImages...) + videos = append(videos, m.StringVideos...) + audios = append(audios, m.StringAudios...) + } + + // Convert and serialize tools to OpenAI format for the backend + toolsJSON := serializeToolsForBackend(input.Tools) + toolChoiceJSON := "" + if input.ToolChoice != nil { + toolChoiceBytes, err := json.Marshal(input.ToolChoice) + if err == nil { + toolChoiceJSON = string(toolChoiceBytes) + } + } + + // Pass logprobs and logit_bias parameters if requested + var logprobs *int + if input.TopLogprobs != nil && *input.TopLogprobs > 0 { + logprobs = input.TopLogprobs + } + + predFunc, err := backend.ModelInference( + input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, logprobs, input.TopLogprobs, input.LogitBias) + if err != nil { + xlog.Error("Open Responses model inference failed", "error", err) + return sendOpenResponsesError(c, 500, "model_error", fmt.Sprintf("model inference failed: %v", err), "") + } + + prediction, err := predFunc() + if err != nil { + xlog.Error("Open Responses prediction failed", "error", err) + return sendOpenResponsesError(c, 500, "model_error", fmt.Sprintf("prediction failed: %v", err), "") + } + + result := backend.Finetune(*cfg, predInput, prediction.Response) + xlog.Debug("Open Responses - Raw model result", "result", result, "shouldUseFn", shouldUseFn) + + // Parse tool calls if using functions + var outputItems []schema.ORItemField + var toolCalls []schema.ToolCall + + if shouldUseFn { + // Clean up the result first (handle reasoning tags, etc.) 
+ cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig) + xlog.Debug("Open Responses - Cleaned result", "cleanedResult", cleanedResult) + + funcCallResults := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig) + textContent := functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig) + xlog.Debug("Open Responses - Parsed function calls", "count", len(funcCallResults), "textContent", textContent) + + // Check for noAction function (model chose to respond without tool) + noActionName := "answer" + if cfg.FunctionsConfig.NoActionFunctionName != "" { + noActionName = cfg.FunctionsConfig.NoActionFunctionName + } + + // Filter out noAction calls and extract the message + for i, fc := range funcCallResults { + if fc.Name == noActionName { + // This is a text response, not a tool call + // Try to extract the message from the arguments + if fc.Arguments != "" { + var args map[string]interface{} + if err := json.Unmarshal([]byte(fc.Arguments), &args); err == nil { + if msg, ok := args["message"].(string); ok && msg != "" { + textContent = msg + } + } + } + continue + } + toolCalls = append(toolCalls, schema.ToolCall{ + Index: i, + ID: fmt.Sprintf("fc_%s", uuid.New().String()), + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: fc.Name, + Arguments: fc.Arguments, + }, + }) + } + + // Add message item with text content (include logprobs if available) + if textContent != "" { + outputItems = append(outputItems, schema.ORItemField{ + Type: "message", + ID: fmt.Sprintf("msg_%s", uuid.New().String()), + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)}, + }) + } + + // Add function call items + for _, tc := range toolCalls { + outputItems = append(outputItems, schema.ORItemField{ + Type: "function_call", + ID: fmt.Sprintf("fc_%s", uuid.New().String()), + Status: "completed", + CallID: tc.ID, + Name: tc.FunctionCall.Name, + Arguments: tc.FunctionCall.Arguments, + }) + } + + // If we have no output items but the model did produce output, include the raw result as a message + // This handles cases where the function call parsing failed but we still have model output + if len(outputItems) == 0 && result != "" { + xlog.Debug("Open Responses - No parsed output, falling back to raw result") + outputItems = append(outputItems, schema.ORItemField{ + Type: "message", + ID: fmt.Sprintf("msg_%s", uuid.New().String()), + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)}, + }) + } + } else { + // Simple text response (include logprobs if available) + outputItems = []schema.ORItemField{ + { + Type: "message", + ID: fmt.Sprintf("msg_%s", uuid.New().String()), + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)}, + }, + } + } + + // Build response with all required fields + now := time.Now().Unix() + response := buildORResponse(responseID, createdAt, &now, "completed", input, outputItems, &schema.ORUsage{ + InputTokens: prediction.Usage.Prompt, + OutputTokens: prediction.Usage.Completion, + TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion, + }, shouldStore) + + // Store response for future reference (if enabled) + if shouldStore { + store := GetGlobalStore() + store.Store(responseID, input, response) + } + + return c.JSON(200, response) +} + +// handleOpenResponsesStream handles streaming 
responses +func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, shouldStore bool) error { + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") + + sequenceNumber := 0 + + // Emit response.created - use helper to create response with all required fields + responseCreated := buildORResponse(responseID, createdAt, nil, "in_progress", input, []schema.ORItemField{}, nil, shouldStore) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.created", + SequenceNumber: sequenceNumber, + Response: responseCreated, + }) + sequenceNumber++ + + // Emit response.in_progress + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.in_progress", + SequenceNumber: sequenceNumber, + Response: responseCreated, + }) + sequenceNumber++ + + images := []string{} + videos := []string{} + audios := []string{} + for _, m := range openAIReq.Messages { + images = append(images, m.StringImages...) + videos = append(videos, m.StringVideos...) + audios = append(audios, m.StringAudios...) + } + + // Convert and serialize tools to OpenAI format for the backend + toolsJSON := serializeToolsForBackend(input.Tools) + toolChoiceJSON := "" + if input.ToolChoice != nil { + toolChoiceBytes, err := json.Marshal(input.ToolChoice) + if err == nil { + toolChoiceJSON = string(toolChoiceBytes) + } + } + + // Track state for streaming + var currentMessageID string + var currentContentIndex int + var accumulatedText string + var lastEmittedToolCallCount int + outputIndex := 0 + inToolCallMode := false + + // Collect all output items for storage + var collectedOutputItems []schema.ORItemField + + if shouldUseFn { + // For tool calls, we need to track accumulated result and parse incrementally + // We'll handle this differently - track the full result and parse tool calls + accumulatedResult := "" + tokenCallback := func(token string, tokenUsage backend.TokenUsage) bool { + accumulatedResult += token + accumulatedText += token + + // Try to parse tool calls incrementally + cleanedResult := functions.CleanupLLMResult(accumulatedResult, cfg.FunctionsConfig) + + // Determine XML format from config + var xmlFormat *functions.XMLToolCallFormat + if cfg.FunctionsConfig.XMLFormat != nil { + xmlFormat = cfg.FunctionsConfig.XMLFormat + } else if cfg.FunctionsConfig.XMLFormatPreset != "" { + xmlFormat = functions.GetXMLFormatPreset(cfg.FunctionsConfig.XMLFormatPreset) + } + + // Try XML parsing first + partialResults, parseErr := functions.ParseXMLIterative(cleanedResult, xmlFormat, true) + if parseErr == nil && len(partialResults) > lastEmittedToolCallCount { + // New tool calls detected + if !inToolCallMode && currentMessageID != "" { + // Close the current message content part + textPart := makeOutputTextPart(functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.content_part.done", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: &currentContentIndex, + Part: &textPart, + }) + sequenceNumber++ + inToolCallMode = true + } + + // Emit new tool calls + for i := lastEmittedToolCallCount; i < len(partialResults); i++ { + tc := 
partialResults[i] + toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) + outputIndex++ + + // Emit function_call item added + functionCallItem := &schema.ORItemField{ + Type: "function_call", + ID: toolCallID, + Status: "in_progress", + CallID: toolCallID, + Name: tc.Name, + Arguments: "", + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: functionCallItem, + }) + sequenceNumber++ + + // Emit arguments delta + if tc.Arguments != "" { + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.function_call_arguments.delta", + SequenceNumber: sequenceNumber, + ItemID: toolCallID, + OutputIndex: &outputIndex, + Delta: strPtr(tc.Arguments), + }) + sequenceNumber++ + + // Emit arguments done + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.function_call_arguments.done", + SequenceNumber: sequenceNumber, + ItemID: toolCallID, + OutputIndex: &outputIndex, + Arguments: strPtr(tc.Arguments), + }) + sequenceNumber++ + + // Emit function_call item done + functionCallItem.Status = "completed" + functionCallItem.Arguments = tc.Arguments + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: functionCallItem, + }) + sequenceNumber++ + + // Collect item for storage + collectedOutputItems = append(collectedOutputItems, *functionCallItem) + } + } + lastEmittedToolCallCount = len(partialResults) + c.Response().Flush() + return true + } + + // Try JSON parsing as fallback + jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true) + if jsonErr == nil && len(jsonResults) > lastEmittedToolCallCount { + for i := lastEmittedToolCallCount; i < len(jsonResults); i++ { + jsonObj := jsonResults[i] + if name, ok := jsonObj["name"].(string); ok && name != "" { + args := "{}" + if argsVal, ok := jsonObj["arguments"]; ok { + if argsStr, ok := argsVal.(string); ok { + args = argsStr + } else { + argsBytes, _ := json.Marshal(argsVal) + args = string(argsBytes) + } + } + + toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) + outputIndex++ + + functionCallItem := &schema.ORItemField{ + Type: "function_call", + ID: toolCallID, + Status: "completed", + CallID: toolCallID, + Name: name, + Arguments: args, + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: functionCallItem, + }) + sequenceNumber++ + + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: functionCallItem, + }) + sequenceNumber++ + } + } + lastEmittedToolCallCount = len(jsonResults) + c.Response().Flush() + return true + } + + // If no tool calls detected yet, emit text delta + if !inToolCallMode { + if currentMessageID == "" { + // Emit output_item.added for message + currentMessageID = fmt.Sprintf("msg_%s", uuid.New().String()) + messageItem := &schema.ORItemField{ + Type: "message", + ID: currentMessageID, + Status: "in_progress", + Role: "assistant", + Content: []schema.ORContentPart{}, + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: messageItem, + }) + sequenceNumber++ + + // Emit content_part.added + currentContentIndex = 0 + emptyPart := makeOutputTextPart("") + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: 
"response.content_part.added", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: &currentContentIndex, + Part: &emptyPart, + }) + sequenceNumber++ + } + + // Emit text delta + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_text.delta", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: &currentContentIndex, + Delta: strPtr(token), + Logprobs: emptyLogprobs(), + }) + sequenceNumber++ + c.Response().Flush() + } + return true + } + + // Pass logprobs and logit_bias parameters if requested + var streamLogprobs *int + if input.TopLogprobs != nil && *input.TopLogprobs > 0 { + streamLogprobs = input.TopLogprobs + } + + predFunc, err := backend.ModelInference( + input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, streamLogprobs, input.TopLogprobs, input.LogitBias) + if err != nil { + xlog.Error("Open Responses stream model inference failed", "error", err) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "error", + SequenceNumber: sequenceNumber, + Error: &schema.ORErrorPayload{ + Type: "model_error", + Message: fmt.Sprintf("model inference failed: %v", err), + }, + }) + sequenceNumber++ + responseFailed := responseCreated + responseFailed.Status = "failed" + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.failed", + SequenceNumber: sequenceNumber, + Response: responseFailed, + }) + // Send [DONE] even on error + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil + } + + prediction, err := predFunc() + if err != nil { + xlog.Error("Open Responses stream prediction failed", "error", err) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "error", + SequenceNumber: sequenceNumber, + Error: &schema.ORErrorPayload{ + Type: "model_error", + Message: fmt.Sprintf("prediction failed: %v", err), + }, + }) + sequenceNumber++ + responseFailed := responseCreated + responseFailed.Status = "failed" + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.failed", + SequenceNumber: sequenceNumber, + Response: responseFailed, + }) + // Send [DONE] even on error + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil + } + + result := backend.Finetune(*cfg, predInput, prediction.Response) + cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig) + xlog.Debug("Open Responses Stream - Cleaned result", "cleanedResult", cleanedResult) + + parsedToolCalls := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig) + textContent := functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig) + + // Handle noAction function (model chose to respond without tool) + noActionName := "answer" + if cfg.FunctionsConfig.NoActionFunctionName != "" { + noActionName = cfg.FunctionsConfig.NoActionFunctionName + } + + // Filter out noAction calls and extract the message + var toolCalls []functions.FuncCallResults + for _, fc := range parsedToolCalls { + if fc.Name == noActionName { + // This is a text response, not a tool call + if fc.Arguments != "" { + var args map[string]interface{} + if err := json.Unmarshal([]byte(fc.Arguments), &args); err == nil { + if msg, ok := args["message"].(string); ok && msg != "" { + textContent = msg + } + } + } + continue + } + toolCalls = append(toolCalls, fc) + } + + xlog.Debug("Open Responses Stream - Parsed", "toolCalls", len(toolCalls), "textContent", textContent) + + // Convert prediction 
logprobs for streaming events + streamEventLogprobs := convertLogprobsForStreaming(prediction.Logprobs) + + // If we have no output but the model did produce something, use the raw result + if textContent == "" && len(toolCalls) == 0 && result != "" { + xlog.Debug("Open Responses Stream - No parsed output, using raw result") + textContent = result + } + + // Close message if we have text content + if currentMessageID != "" && textContent != "" && !inToolCallMode { + // Emit output_text.done + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_text.done", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: &currentContentIndex, + Text: strPtr(textContent), + Logprobs: logprobsPtr(streamEventLogprobs), + }) + sequenceNumber++ + + // Emit content_part.done (with actual logprobs) + textPart := makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.content_part.done", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: &currentContentIndex, + Part: &textPart, + }) + sequenceNumber++ + + // Emit output_item.done for message (with actual logprobs) + messageItem := &schema.ORItemField{ + Type: "message", + ID: currentMessageID, + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)}, + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: messageItem, + }) + sequenceNumber++ + + // Collect message item for storage + collectedOutputItems = append(collectedOutputItems, *messageItem) + } + + // Emit any remaining tool calls that weren't streamed + for i := lastEmittedToolCallCount; i < len(toolCalls); i++ { + tc := toolCalls[i] + toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) + outputIndex++ + + functionCallItem := &schema.ORItemField{ + Type: "function_call", + ID: toolCallID, + Status: "completed", + CallID: toolCallID, + Name: tc.Name, + Arguments: tc.Arguments, + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: functionCallItem, + }) + sequenceNumber++ + + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: functionCallItem, + }) + sequenceNumber++ + + // Collect function call item for storage + collectedOutputItems = append(collectedOutputItems, *functionCallItem) + } + + // Build final response with all items (include logprobs) + var allOutputItems []schema.ORItemField + if currentMessageID != "" && textContent != "" { + allOutputItems = append(allOutputItems, schema.ORItemField{ + Type: "message", + ID: currentMessageID, + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)}, + }) + } + for _, tc := range toolCalls { + toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) + allOutputItems = append(allOutputItems, schema.ORItemField{ + Type: "function_call", + ID: toolCallID, + Status: "completed", + CallID: toolCallID, + Name: tc.Name, + Arguments: tc.Arguments, + }) + } + + // Emit response.completed + now := time.Now().Unix() + responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, allOutputItems, 
&schema.ORUsage{ + InputTokens: prediction.Usage.Prompt, + OutputTokens: prediction.Usage.Completion, + TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion, + }, shouldStore) + + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.completed", + SequenceNumber: sequenceNumber, + Response: responseCompleted, + }) + + // Store response for future reference (if enabled) + if shouldStore { + store := GetGlobalStore() + store.Store(responseID, input, responseCompleted) + } + + // Send [DONE] + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + + return nil + } + + // Non-tool-call streaming path + // Emit output_item.added for message + currentMessageID = fmt.Sprintf("msg_%s", uuid.New().String()) + messageItem := &schema.ORItemField{ + Type: "message", + ID: currentMessageID, + Status: "in_progress", + Role: "assistant", + Content: []schema.ORContentPart{}, + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: messageItem, + }) + sequenceNumber++ + + // Emit content_part.added + currentContentIndex = 0 + emptyTextPart := makeOutputTextPart("") + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.content_part.added", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: &currentContentIndex, + Part: &emptyTextPart, + }) + sequenceNumber++ + + // Stream text deltas + tokenCallback := func(token string, tokenUsage backend.TokenUsage) bool { + accumulatedText += token + + // Emit text delta + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_text.delta", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: &currentContentIndex, + Delta: strPtr(token), + Logprobs: emptyLogprobs(), + }) + sequenceNumber++ + c.Response().Flush() + return true + } + + // Pass logprobs and logit_bias parameters if requested + var mcpLogprobs *int + if input.TopLogprobs != nil && *input.TopLogprobs > 0 { + mcpLogprobs = input.TopLogprobs + } + + predFunc, err := backend.ModelInference( + input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, mcpLogprobs, input.TopLogprobs, input.LogitBias) + if err != nil { + xlog.Error("Open Responses stream model inference failed", "error", err) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "error", + SequenceNumber: sequenceNumber, + Error: &schema.ORErrorPayload{ + Type: "model_error", + Message: fmt.Sprintf("model inference failed: %v", err), + }, + }) + sequenceNumber++ + responseFailed := responseCreated + responseFailed.Status = "failed" + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.failed", + SequenceNumber: sequenceNumber, + Response: responseFailed, + }) + // Send [DONE] even on error + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil + } + + prediction, err := predFunc() + if err != nil { + xlog.Error("Open Responses stream prediction failed", "error", err) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "error", + SequenceNumber: sequenceNumber, + Error: &schema.ORErrorPayload{ + Type: "model_error", + Message: fmt.Sprintf("prediction failed: %v", err), + }, + }) + sequenceNumber++ + responseFailed := responseCreated + responseFailed.Status = "failed" + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.failed", + SequenceNumber: sequenceNumber, + Response: responseFailed, + 
}) + // Send [DONE] even on error + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil + } + + result := backend.Finetune(*cfg, predInput, prediction.Response) + + // Convert prediction logprobs for streaming events + mcpStreamLogprobs := convertLogprobsForStreaming(prediction.Logprobs) + + // Emit output_text.done + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_text.done", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: &currentContentIndex, + Text: strPtr(result), + Logprobs: logprobsPtr(mcpStreamLogprobs), + }) + sequenceNumber++ + + // Emit content_part.done (with actual logprobs) + resultPart := makeOutputTextPartWithLogprobs(result, prediction.Logprobs) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.content_part.done", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: &currentContentIndex, + Part: &resultPart, + }) + sequenceNumber++ + + // Emit output_item.done (with actual logprobs) + messageItem.Status = "completed" + messageItem.Content = []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)} + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: messageItem, + }) + sequenceNumber++ + + // Emit response.completed + now := time.Now().Unix() + + // Collect final output items (use collected items if available, otherwise use messageItem) + var finalOutputItems []schema.ORItemField + if len(collectedOutputItems) > 0 { + finalOutputItems = collectedOutputItems + } else { + finalOutputItems = []schema.ORItemField{*messageItem} + } + responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, finalOutputItems, &schema.ORUsage{ + InputTokens: prediction.Usage.Prompt, + OutputTokens: prediction.Usage.Completion, + TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion, + }, shouldStore) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.completed", + SequenceNumber: sequenceNumber, + Response: responseCompleted, + }) + + // Store response for future reference (if enabled) + if shouldStore { + store := GetGlobalStore() + store.Store(responseID, input, responseCompleted) + } + + // Send [DONE] + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + + return nil +} + +// handleMCPResponse handles responses using MCP agentic loop +func handleMCPResponse(c echo.Context, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, predInput string, openAIReq *schema.OpenAIRequest, appConfig *config.ApplicationConfig, shouldStore bool) error { + ctx := input.Context + if ctx == nil { + ctx = c.Request().Context() + } + + // Validate MCP config + if cfg.MCP.Servers == "" && cfg.MCP.Stdio == "" { + return sendOpenResponsesError(c, 400, "invalid_request", "no MCP servers configured", "") + } + + // Get MCP config from model config + remote, stdio, err := cfg.MCP.MCPConfigFromYAML() + if err != nil { + return sendOpenResponsesError(c, 500, "server_error", fmt.Sprintf("failed to get MCP config: %v", err), "") + } + + // Get MCP sessions + sessions, err := mcpTools.SessionsFromMCPConfig(cfg.Name, remote, stdio) + if err != nil { + return sendOpenResponsesError(c, 500, "server_error", fmt.Sprintf("failed to get MCP sessions: %v", err), "") + } + + if len(sessions) == 0 { + return 
sendOpenResponsesError(c, 500, "server_error", "no working MCP servers found", "") + } + + // Build fragment from messages + fragment := cogito.NewEmptyFragment() + for _, message := range openAIReq.Messages { + fragment = fragment.AddMessage(message.Role, message.StringContent) + } + fragmentPtr := &fragment + + // Get API address and key + _, port, err := net.SplitHostPort(appConfig.APIAddress) + if err != nil { + return sendOpenResponsesError(c, 500, "server_error", fmt.Sprintf("failed to parse API address: %v", err), "") + } + apiKey := "" + if len(appConfig.ApiKeys) > 0 { + apiKey = appConfig.ApiKeys[0] + } + + ctxWithCancellation, cancel := context.WithCancel(ctx) + defer cancel() + + // Create OpenAI LLM client + defaultLLM := cogito.NewOpenAILLM(cfg.Name, apiKey, "http://127.0.0.1:"+port) + + // Build cogito options + cogitoOpts := cfg.BuildCogitoOptions() + cogitoOpts = append( + cogitoOpts, + cogito.WithContext(ctxWithCancellation), + cogito.WithMCPs(sessions...), + ) + + if input.Stream { + return handleMCPStream(c, responseID, createdAt, input, cfg, defaultLLM, fragmentPtr, cogitoOpts, ctxWithCancellation, cancel, shouldStore) + } + + // Non-streaming mode + return handleMCPNonStream(c, responseID, createdAt, input, cfg, defaultLLM, fragmentPtr, cogitoOpts, ctxWithCancellation, shouldStore) +} + +// sendSSEEvent sends a Server-Sent Event +func sendSSEEvent(c echo.Context, event *schema.ORStreamEvent) { + data, err := json.Marshal(event) + if err != nil { + xlog.Error("Failed to marshal SSE event", "error", err) + return + } + fmt.Fprintf(c.Response().Writer, "event: %s\ndata: %s\n\n", event.Type, string(data)) +} + +// handleMCPNonStream handles non-streaming MCP responses +func handleMCPNonStream(c echo.Context, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, defaultLLM cogito.LLM, fragment *cogito.Fragment, cogitoOpts []cogito.Option, ctx context.Context, shouldStore bool) error { + frag := *fragment + // Set up callbacks for logging + cogitoOpts = append( + cogitoOpts, + cogito.WithStatusCallback(func(s string) { + xlog.Debug("[Open Responses MCP] Status", "model", cfg.Name, "status", s) + }), + cogito.WithReasoningCallback(func(s string) { + xlog.Debug("[Open Responses MCP] Reasoning", "model", cfg.Name, "reasoning", s) + }), + cogito.WithToolCallBack(func(t *cogito.ToolChoice, state *cogito.SessionState) cogito.ToolCallDecision { + xlog.Debug("[Open Responses MCP] Tool call", "model", cfg.Name, "tool", t.Name, "reasoning", t.Reasoning, "arguments", t.Arguments) + return cogito.ToolCallDecision{ + Approved: true, + } + }), + cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) { + xlog.Debug("[Open Responses MCP] Tool call result", "model", cfg.Name, "tool", t.Name, "result", t.Result, "tool_arguments", t.ToolArguments) + }), + ) + + // Execute tools + f, err := cogito.ExecuteTools(defaultLLM, frag, cogitoOpts...) 
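+	// cogito.ErrNoToolSelected is deliberately tolerated below: it only means the model answered directly without selecting a tool, so it is not treated as a failure.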
+ if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) { + return sendOpenResponsesError(c, 500, "model_error", fmt.Sprintf("failed to execute tools: %v", err), "") + } + + // Get final response + f, err = defaultLLM.Ask(ctx, f) + if err != nil { + return sendOpenResponsesError(c, 500, "model_error", fmt.Sprintf("failed to get response: %v", err), "") + } + + // Convert fragment to Open Responses format + fPtr := &f + outputItems := convertCogitoFragmentToORItems(fPtr) + + // Build response with all required fields + now := time.Now().Unix() + response := buildORResponse(responseID, createdAt, &now, "completed", input, outputItems, nil, shouldStore) + + // Store response (if enabled) + if shouldStore { + store := GetGlobalStore() + store.Store(responseID, input, response) + } + + return c.JSON(200, response) +} + +// handleMCPStream handles streaming MCP responses +func handleMCPStream(c echo.Context, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, defaultLLM cogito.LLM, fragment *cogito.Fragment, cogitoOpts []cogito.Option, ctx context.Context, cancel context.CancelFunc, shouldStore bool) error { + frag := *fragment + // Set SSE headers + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") + + sequenceNumber := 0 + + // Emit response.created - use helper to create response with all required fields + responseCreated := buildORResponse(responseID, createdAt, nil, "in_progress", input, []schema.ORItemField{}, nil, shouldStore) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.created", + SequenceNumber: sequenceNumber, + Response: responseCreated, + }) + sequenceNumber++ + + // Emit response.in_progress + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.in_progress", + SequenceNumber: sequenceNumber, + Response: responseCreated, + }) + sequenceNumber++ + + // Create channels for streaming events + events := make(chan interface{}) + ended := make(chan error, 1) + var collectedOutputItems []schema.ORItemField + outputIndex := 0 + + // Set up callbacks + statusCallback := func(s string) { + events <- map[string]interface{}{ + "type": "status", + "message": s, + } + } + + reasoningCallback := func(s string) { + itemID := fmt.Sprintf("reasoning_%s", uuid.New().String()) + outputIndex++ + item := &schema.ORItemField{ + Type: "reasoning", + ID: itemID, + Status: "in_progress", + } + collectedOutputItems = append(collectedOutputItems, *item) + + events <- map[string]interface{}{ + "type": "reasoning", + "item_id": itemID, + "output_index": outputIndex, + "content": s, + } + } + + toolCallCallback := func(t *cogito.ToolChoice, state *cogito.SessionState) cogito.ToolCallDecision { + toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) + outputIndex++ + item := &schema.ORItemField{ + Type: "function_call", + ID: toolCallID, + Status: "in_progress", + CallID: toolCallID, + Name: t.Name, + Arguments: "", + } + collectedOutputItems = append(collectedOutputItems, *item) + + events <- map[string]interface{}{ + "type": "tool_call", + "item_id": toolCallID, + "output_index": outputIndex, + "name": t.Name, + "arguments": t.Arguments, + "reasoning": t.Reasoning, + } + return cogito.ToolCallDecision{ + Approved: true, + } + } + + toolCallResultCallback := func(t cogito.ToolStatus) { + outputIndex++ + callID := fmt.Sprintf("fc_%s", uuid.New().String()) + item := schema.ORItemField{ + Type: "function_call_output", + 
ID: fmt.Sprintf("fco_%s", uuid.New().String()), + Status: "completed", + CallID: callID, + Output: t.Result, + } + collectedOutputItems = append(collectedOutputItems, item) + + events <- map[string]interface{}{ + "type": "tool_result", + "item_id": item.ID, + "output_index": outputIndex, + "name": t.Name, + "result": t.Result, + } + } + + cogitoOpts = append(cogitoOpts, + cogito.WithStatusCallback(statusCallback), + cogito.WithReasoningCallback(reasoningCallback), + cogito.WithToolCallBack(toolCallCallback), + cogito.WithToolCallResultCallback(toolCallResultCallback), + ) + + // Execute tools in goroutine + go func() { + defer close(events) + + f, err := cogito.ExecuteTools(defaultLLM, frag, cogitoOpts...) + if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) { + events <- map[string]interface{}{ + "type": "error", + "message": fmt.Sprintf("Failed to execute tools: %v", err), + } + ended <- err + return + } + + // Get final response + f, err = defaultLLM.Ask(ctx, f) + if err != nil { + events <- map[string]interface{}{ + "type": "error", + "message": fmt.Sprintf("Failed to get response: %v", err), + } + ended <- err + return + } + + // Stream final assistant message + content := f.LastMessage().Content + messageID := fmt.Sprintf("msg_%s", uuid.New().String()) + outputIndex++ + item := schema.ORItemField{ + Type: "message", + ID: messageID, + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPart(content)}, + } + collectedOutputItems = append(collectedOutputItems, item) + + events <- map[string]interface{}{ + "type": "assistant", + "item_id": messageID, + "output_index": outputIndex, + "content": content, + } + + ended <- nil + }() + + // Stream events to client +LOOP: + for { + select { + case <-ctx.Done(): + cancel() + break LOOP + case event := <-events: + if event == nil { + break LOOP + } + // Convert event to Open Responses format and send + if err := sendMCPEventAsOR(c, event, &sequenceNumber); err != nil { + cancel() + return err + } + c.Response().Flush() + case err := <-ended: + if err == nil { + // Emit response.completed + now := time.Now().Unix() + responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, collectedOutputItems, nil, shouldStore) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.completed", + SequenceNumber: sequenceNumber, + Response: responseCompleted, + }) + sequenceNumber++ + + // Store response (if enabled) + if shouldStore { + store := GetGlobalStore() + store.Store(responseID, input, responseCompleted) + } + + // Send [DONE] + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + break LOOP + } + // Send error + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "error", + SequenceNumber: sequenceNumber, + Error: &schema.ORErrorPayload{ + Type: "model_error", + Message: err.Error(), + }, + }) + sequenceNumber++ + responseFailed := buildORResponse(responseID, createdAt, nil, "failed", input, collectedOutputItems, nil, shouldStore) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.failed", + SequenceNumber: sequenceNumber, + Response: responseFailed, + }) + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil + } + } + + return nil +} + +// convertCogitoFragmentToORItems converts a cogito fragment to Open Responses items +func convertCogitoFragmentToORItems(f *cogito.Fragment) []schema.ORItemField { + var items []schema.ORItemField + + // Get the last message (assistant response) + lastMsg := f.LastMessage() + 
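+	// Only the final assistant message is mapped to an output item here; intermediate tool-call traffic from the fragment is not converted.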
if lastMsg != nil && lastMsg.Content != "" { + items = append(items, schema.ORItemField{ + Type: "message", + ID: fmt.Sprintf("msg_%s", uuid.New().String()), + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPart(lastMsg.Content)}, + }) + } + + return items +} + +// sendMCPEventAsOR converts MCP events to Open Responses format and sends them +func sendMCPEventAsOR(c echo.Context, event interface{}, sequenceNumber *int) error { + eventMap, ok := event.(map[string]interface{}) + if !ok { + return nil + } + + eventType, _ := eventMap["type"].(string) + switch eventType { + case "status": + // Status events are informational, skip for now + return nil + case "reasoning": + itemID, _ := eventMap["item_id"].(string) + outputIndex, _ := eventMap["output_index"].(int) + + item := &schema.ORItemField{ + Type: "reasoning", + ID: itemID, + Status: "in_progress", + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: *sequenceNumber, + OutputIndex: &outputIndex, + Item: item, + }) + *sequenceNumber++ + // Note: reasoning content streaming would go here + return nil + case "tool_call": + itemID, _ := eventMap["item_id"].(string) + outputIndex, _ := eventMap["output_index"].(int) + name, _ := eventMap["name"].(string) + arguments, _ := eventMap["arguments"].(string) + + item := &schema.ORItemField{ + Type: "function_call", + ID: itemID, + Status: "in_progress", + CallID: itemID, + Name: name, + Arguments: "", + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: *sequenceNumber, + OutputIndex: &outputIndex, + Item: item, + }) + *sequenceNumber++ + + // Emit arguments + if arguments != "" { + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.function_call_arguments.delta", + SequenceNumber: *sequenceNumber, + ItemID: itemID, + OutputIndex: &outputIndex, + Delta: strPtr(arguments), + }) + *sequenceNumber++ + + item.Status = "completed" + item.Arguments = arguments + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.function_call_arguments.done", + SequenceNumber: *sequenceNumber, + ItemID: itemID, + OutputIndex: &outputIndex, + Arguments: strPtr(arguments), + }) + *sequenceNumber++ + + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: *sequenceNumber, + OutputIndex: &outputIndex, + Item: item, + }) + *sequenceNumber++ + } + return nil + case "tool_result": + itemID, _ := eventMap["item_id"].(string) + outputIndex, _ := eventMap["output_index"].(int) + result, _ := eventMap["result"].(string) + + item := &schema.ORItemField{ + Type: "function_call_output", + ID: itemID, + Status: "completed", + Output: result, + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: *sequenceNumber, + OutputIndex: &outputIndex, + Item: item, + }) + *sequenceNumber++ + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: *sequenceNumber, + OutputIndex: &outputIndex, + Item: item, + }) + *sequenceNumber++ + return nil + case "assistant": + itemID, _ := eventMap["item_id"].(string) + outputIndex, _ := eventMap["output_index"].(int) + content, _ := eventMap["content"].(string) + + item := &schema.ORItemField{ + Type: "message", + ID: itemID, + Status: "in_progress", + Role: "assistant", + Content: []schema.ORContentPart{}, + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: *sequenceNumber, + OutputIndex: 
&outputIndex, + Item: item, + }) + *sequenceNumber++ + + // Emit content part + emptyPart := makeOutputTextPart("") + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.content_part.added", + SequenceNumber: *sequenceNumber, + ItemID: itemID, + OutputIndex: &outputIndex, + ContentIndex: func() *int { i := 0; return &i }(), + Part: &emptyPart, + }) + *sequenceNumber++ + + // Emit text done + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_text.done", + SequenceNumber: *sequenceNumber, + ItemID: itemID, + OutputIndex: &outputIndex, + ContentIndex: func() *int { i := 0; return &i }(), + Text: strPtr(content), + Logprobs: emptyLogprobs(), + }) + *sequenceNumber++ + + // Emit content part done + contentPart := makeOutputTextPart(content) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.content_part.done", + SequenceNumber: *sequenceNumber, + ItemID: itemID, + OutputIndex: &outputIndex, + ContentIndex: func() *int { i := 0; return &i }(), + Part: &contentPart, + }) + *sequenceNumber++ + + // Emit item done + item.Status = "completed" + item.Content = []schema.ORContentPart{makeOutputTextPart(content)} + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: *sequenceNumber, + OutputIndex: &outputIndex, + Item: item, + }) + *sequenceNumber++ + return nil + case "error": + message, _ := eventMap["message"].(string) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "error", + SequenceNumber: *sequenceNumber, + Error: &schema.ORErrorPayload{ + Type: "model_error", + Message: message, + }, + }) + *sequenceNumber++ + return nil + } + + return nil +} + +// bufferMCPEventAsOR converts MCP events to Open Responses format and buffers them +func bufferMCPEventAsOR(store *ResponseStore, responseID string, event interface{}, sequenceNumber *int) { + eventMap, ok := event.(map[string]interface{}) + if !ok { + return + } + + eventType, _ := eventMap["type"].(string) + switch eventType { + case "status": + // Status events are informational, skip for now + return + case "reasoning": + itemID, _ := eventMap["item_id"].(string) + outputIndex, _ := eventMap["output_index"].(int) + + item := &schema.ORItemField{ + Type: "reasoning", + ID: itemID, + Status: "in_progress", + } + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: *sequenceNumber, + OutputIndex: &outputIndex, + Item: item, + }) + *sequenceNumber++ + // Note: reasoning content streaming would go here + return + case "tool_call": + itemID, _ := eventMap["item_id"].(string) + outputIndex, _ := eventMap["output_index"].(int) + name, _ := eventMap["name"].(string) + arguments, _ := eventMap["arguments"].(string) + + item := &schema.ORItemField{ + Type: "function_call", + ID: itemID, + Status: "in_progress", + CallID: itemID, + Name: name, + Arguments: "", + } + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: *sequenceNumber, + OutputIndex: &outputIndex, + Item: item, + }) + *sequenceNumber++ + + // Emit arguments + if arguments != "" { + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.function_call_arguments.delta", + SequenceNumber: *sequenceNumber, + ItemID: itemID, + OutputIndex: &outputIndex, + Delta: strPtr(arguments), + }) + *sequenceNumber++ + + item.Status = "completed" + item.Arguments = arguments + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.function_call_arguments.done", + SequenceNumber: *sequenceNumber, + 
ItemID: itemID, + OutputIndex: &outputIndex, + Arguments: strPtr(arguments), + }) + *sequenceNumber++ + + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: *sequenceNumber, + OutputIndex: &outputIndex, + Item: item, + }) + *sequenceNumber++ + } + return + case "tool_result": + itemID, _ := eventMap["item_id"].(string) + outputIndex, _ := eventMap["output_index"].(int) + result, _ := eventMap["result"].(string) + + item := &schema.ORItemField{ + Type: "function_call_output", + ID: itemID, + Status: "completed", + Output: result, + } + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: *sequenceNumber, + OutputIndex: &outputIndex, + Item: item, + }) + *sequenceNumber++ + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: *sequenceNumber, + OutputIndex: &outputIndex, + Item: item, + }) + *sequenceNumber++ + return + case "assistant": + itemID, _ := eventMap["item_id"].(string) + outputIndex, _ := eventMap["output_index"].(int) + content, _ := eventMap["content"].(string) + + item := &schema.ORItemField{ + Type: "message", + ID: itemID, + Status: "in_progress", + Role: "assistant", + Content: []schema.ORContentPart{}, + } + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: *sequenceNumber, + OutputIndex: &outputIndex, + Item: item, + }) + *sequenceNumber++ + + // Emit content part + emptyPart := makeOutputTextPart("") + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.content_part.added", + SequenceNumber: *sequenceNumber, + ItemID: itemID, + OutputIndex: &outputIndex, + ContentIndex: func() *int { i := 0; return &i }(), + Part: &emptyPart, + }) + *sequenceNumber++ + + // Emit text done + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.output_text.done", + SequenceNumber: *sequenceNumber, + ItemID: itemID, + OutputIndex: &outputIndex, + ContentIndex: func() *int { i := 0; return &i }(), + Text: strPtr(content), + Logprobs: emptyLogprobs(), + }) + *sequenceNumber++ + + // Emit content part done + contentPart := makeOutputTextPart(content) + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.content_part.done", + SequenceNumber: *sequenceNumber, + ItemID: itemID, + OutputIndex: &outputIndex, + ContentIndex: func() *int { i := 0; return &i }(), + Part: &contentPart, + }) + *sequenceNumber++ + + // Emit item done + item.Status = "completed" + item.Content = []schema.ORContentPart{makeOutputTextPart(content)} + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: *sequenceNumber, + OutputIndex: &outputIndex, + Item: item, + }) + *sequenceNumber++ + return + case "error": + message, _ := eventMap["message"].(string) + bufferEvent(store, responseID, &schema.ORStreamEvent{ + Type: "error", + SequenceNumber: *sequenceNumber, + Error: &schema.ORErrorPayload{ + Type: "model_error", + Message: message, + }, + }) + *sequenceNumber++ + return + } +} + +// getTopLogprobs returns the top_logprobs value, defaulting to 0 if nil +func getTopLogprobs(topLogprobs *int) int { + if topLogprobs != nil { + return *topLogprobs + } + return 0 +} + +// Helper functions for pointer types in streaming events +func strPtr(s string) *string { + return &s +} + +func logprobsPtr(lp []schema.ORLogProb) *[]schema.ORLogProb { + return &lp +} + +func emptyLogprobs() 
*[]schema.ORLogProb { + empty := []schema.ORLogProb{} + return &empty +} + +// makeOutputTextPart creates an output_text content part with all required fields per Open Responses spec +func makeOutputTextPart(text string) schema.ORContentPart { + return schema.ORContentPartWithLogprobs(text, nil) +} + +// makeOutputTextPartWithLogprobs creates an output_text content part with actual logprobs data +func makeOutputTextPartWithLogprobs(text string, logprobs *schema.Logprobs) schema.ORContentPart { + return schema.ORContentPartWithLogprobs(text, logprobs) +} + +// convertLogprobsForStreaming converts OpenAI-style logprobs to Open Responses format for streaming events +func convertLogprobsForStreaming(logprobs *schema.Logprobs) []schema.ORLogProb { + if logprobs == nil || len(logprobs.Content) == 0 { + return []schema.ORLogProb{} + } + + result := make([]schema.ORLogProb, 0, len(logprobs.Content)) + for _, lp := range logprobs.Content { + topLPs := make([]schema.ORTopLogProb, 0, len(lp.TopLogprobs)) + for _, tlp := range lp.TopLogprobs { + topLPs = append(topLPs, schema.ORTopLogProb{ + Token: tlp.Token, + Logprob: tlp.Logprob, + Bytes: tlp.Bytes, + }) + } + result = append(result, schema.ORLogProb{ + Token: lp.Token, + Logprob: lp.Logprob, + Bytes: lp.Bytes, + TopLogprobs: topLPs, + }) + } + return result +} + +// ensureUsageDetails ensures usage has all required detail fields +func ensureUsageDetails(usage *schema.ORUsage) *schema.ORUsage { + if usage == nil { + return nil + } + // Ensure details are always present (not nil) + if usage.InputTokensDetails == nil { + usage.InputTokensDetails = &schema.ORInputTokensDetails{CachedTokens: 0} + } + if usage.OutputTokensDetails == nil { + usage.OutputTokensDetails = &schema.OROutputTokensDetails{ReasoningTokens: 0} + } + return usage +} + +// buildORResponse creates a complete ORResponseResource with all required fields +func buildORResponse(responseID string, createdAt int64, completedAt *int64, status string, input *schema.OpenResponsesRequest, outputItems []schema.ORItemField, usage *schema.ORUsage, shouldStore bool) *schema.ORResponseResource { + // Ensure output is never null - always an array + if outputItems == nil { + outputItems = []schema.ORItemField{} + } + + // Ensure tools is never null - always an array + tools := input.Tools + if tools == nil { + tools = []schema.ORFunctionTool{} + } + + // Ensure metadata is never null - always a map + metadata := input.Metadata + if metadata == nil { + metadata = map[string]string{} + } + + // Set default values for sampling parameters + temperature := 1.0 + if input.Temperature != nil { + temperature = *input.Temperature + } + + topP := 1.0 + if input.TopP != nil { + topP = *input.TopP + } + + presencePenalty := 0.0 + if input.PresencePenalty != nil { + presencePenalty = *input.PresencePenalty + } + + frequencyPenalty := 0.0 + if input.FrequencyPenalty != nil { + frequencyPenalty = *input.FrequencyPenalty + } + + // Default truncation to "auto" + truncation := "auto" + if input.Truncation != "" { + truncation = input.Truncation + } + + // Default service_tier to "default" + serviceTier := "default" + if input.ServiceTier != "" { + serviceTier = input.ServiceTier + } + + // Default parallel_tool_calls to true + parallelToolCalls := true + if input.ParallelToolCalls != nil { + parallelToolCalls = *input.ParallelToolCalls + } + + // Default tool_choice: "auto" if tools are present, "none" otherwise + var toolChoice interface{} + if input.ToolChoice != nil { + toolChoice = input.ToolChoice + } else if 
len(tools) > 0 { + toolChoice = "auto" + } else { + toolChoice = "none" + } + + // Background defaults to false + background := false + if input.Background != nil { + background = *input.Background + } + + // Convert nullable string fields + var previousResponseID *string + if input.PreviousResponseID != "" { + previousResponseID = &input.PreviousResponseID + } + + var instructions *string + if input.Instructions != "" { + instructions = &input.Instructions + } + + // Convert reasoning + var reasoning *schema.ORReasoning + if input.Reasoning != nil { + reasoning = &schema.ORReasoning{ + Effort: input.Reasoning.Effort, + Summary: input.Reasoning.Summary, + } + } + + // Build default text config + textConfig := &schema.ORTextConfig{ + Format: &schema.ORTextFormat{ + Type: "text", + }, + } + + return &schema.ORResponseResource{ + ID: responseID, + Object: "response", + CreatedAt: createdAt, + CompletedAt: completedAt, + Status: status, + Model: input.Model, + Output: outputItems, + Error: nil, // null when no error + IncompleteDetails: nil, // null when complete + PreviousResponseID: previousResponseID, + Instructions: instructions, + + // Tool-related fields + Tools: tools, + ToolChoice: toolChoice, + ParallelToolCalls: parallelToolCalls, + MaxToolCalls: input.MaxToolCalls, + + // Sampling parameters + Temperature: temperature, + TopP: topP, + PresencePenalty: presencePenalty, + FrequencyPenalty: frequencyPenalty, + TopLogprobs: getTopLogprobs(input.TopLogprobs), + MaxOutputTokens: input.MaxOutputTokens, + + // Text format + Text: textConfig, + + // Truncation and reasoning + Truncation: truncation, + Reasoning: reasoning, + + // Usage + Usage: ensureUsageDetails(usage), + + // Metadata and operational flags + Metadata: metadata, + Store: shouldStore, + Background: background, + ServiceTier: serviceTier, + + // Safety and caching (nullable, not yet implemented) + SafetyIdentifier: nil, + PromptCacheKey: nil, + } +} + +// sendOpenResponsesError sends an error response +func sendOpenResponsesError(c echo.Context, statusCode int, errorType, message, param string) error { + errorResp := map[string]interface{}{ + "error": map[string]interface{}{ + "type": errorType, + "message": message, + }, + } + if param != "" { + errorResp["error"].(map[string]interface{})["param"] = param + } + return c.JSON(statusCode, errorResp) +} + +// convertORToolsToOpenAIFormat converts Open Responses tools to OpenAI format for the backend +// Open Responses format: { type, name, description, parameters } +// OpenAI format: { type, function: { name, description, parameters } } +func convertORToolsToOpenAIFormat(orTools []schema.ORFunctionTool) []functions.Tool { + result := make([]functions.Tool, 0, len(orTools)) + for _, t := range orTools { + result = append(result, functions.Tool{ + Type: "function", + Function: functions.Function{ + Name: t.Name, + Description: t.Description, + Parameters: t.Parameters, + }, + }) + } + return result +} + +// serializeToolsForBackend converts and serializes Open Responses tools to JSON for the backend +func serializeToolsForBackend(orTools []schema.ORFunctionTool) string { + if len(orTools) == 0 { + return "" + } + openAITools := convertORToolsToOpenAIFormat(orTools) + toolsBytes, err := json.Marshal(openAITools) + if err != nil { + return "" + } + return string(toolsBytes) +} + +// GetResponseEndpoint returns a handler for GET /responses/:id +// This endpoint is used for polling background responses or resuming streaming +// @Summary Get a response by ID +// @Description Retrieve 
a response by ID. Can be used for polling background responses or resuming streaming responses. +// @Param id path string true "Response ID" +// @Param stream query string false "Set to 'true' to resume streaming" +// @Param starting_after query int false "Sequence number to resume from (for streaming)" +// @Success 200 {object} schema.ORResponseResource "Response" +// @Failure 400 {object} map[string]interface{} "Bad Request" +// @Failure 404 {object} map[string]interface{} "Not Found" +// @Router /v1/responses/{id} [get] +func GetResponseEndpoint() func(c echo.Context) error { + return func(c echo.Context) error { + responseID := c.Param("id") + if responseID == "" { + return sendOpenResponsesError(c, 400, "invalid_request_error", "response ID is required", "id") + } + + store := GetGlobalStore() + stored, err := store.Get(responseID) + if err != nil { + return sendOpenResponsesError(c, 404, "not_found", fmt.Sprintf("response not found: %s", responseID), "id") + } + + // Check if streaming resume is requested + streamParam := c.QueryParam("stream") + if streamParam == "true" { + // Validate that the response was created with streaming enabled + if !stored.StreamEnabled { + return sendOpenResponsesError(c, 400, "invalid_request_error", "cannot stream a response that was not created with stream=true", "stream") + } + + // Get starting_after parameter + startingAfter := 0 + startingAfterParam := c.QueryParam("starting_after") + if startingAfterParam != "" { + if _, err := fmt.Sscanf(startingAfterParam, "%d", &startingAfter); err != nil { + return sendOpenResponsesError(c, 400, "invalid_request_error", "starting_after must be an integer", "starting_after") + } + } + + return handleStreamResume(c, store, responseID, stored, startingAfter) + } + + // Non-streaming: return the current response state + stored.mu.RLock() + response := stored.Response + stored.mu.RUnlock() + + return c.JSON(200, response) + } +} + +// handleStreamResume handles resuming a streaming response from a specific sequence number +func handleStreamResume(c echo.Context, store *ResponseStore, responseID string, stored *StoredResponse, startingAfter int) error { + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") + + // Get buffered events after the starting point + events, err := store.GetEventsAfter(responseID, startingAfter) + if err != nil { + return sendOpenResponsesError(c, 500, "server_error", fmt.Sprintf("failed to get events: %v", err), "") + } + + // Send all buffered events + for _, event := range events { + fmt.Fprintf(c.Response().Writer, "event: %s\ndata: %s\n\n", event.EventType, string(event.Data)) + c.Response().Flush() + } + + // Get the current status + stored.mu.RLock() + status := stored.Response.Status + stored.mu.RUnlock() + + // If response is still in progress, subscribe to new events + if status == schema.ORStatusQueued || status == schema.ORStatusInProgress { + eventsChan, err := store.GetEventsChan(responseID) + if err != nil { + // Response might have completed, just finish + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil + } + + // Track last sent sequence number + lastSeq := startingAfter + if len(events) > 0 { + lastSeq = events[len(events)-1].SequenceNumber + } + + // Wait for new events or completion + for { + select { + case <-c.Request().Context().Done(): + // Client disconnected + return nil + case <-eventsChan: + // New 
events available + newEvents, err := store.GetEventsAfter(responseID, lastSeq) + if err != nil { + break + } + for _, event := range newEvents { + fmt.Fprintf(c.Response().Writer, "event: %s\ndata: %s\n\n", event.EventType, string(event.Data)) + c.Response().Flush() + lastSeq = event.SequenceNumber + } + + // Check if response is now complete + stored.mu.RLock() + status = stored.Response.Status + stored.mu.RUnlock() + + if status != schema.ORStatusQueued && status != schema.ORStatusInProgress { + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil + } + case <-time.After(30 * time.Second): + // Timeout - send keepalive or check status + stored.mu.RLock() + status = stored.Response.Status + stored.mu.RUnlock() + + if status != schema.ORStatusQueued && status != schema.ORStatusInProgress { + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil + } + } + } + } + + // Response already complete + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil +} + +// CancelResponseEndpoint returns a handler for POST /responses/:id/cancel +// This endpoint cancels a background response if it's still in progress +// @Summary Cancel a response +// @Description Cancel a background response if it's still in progress +// @Param id path string true "Response ID" +// @Success 200 {object} schema.ORResponseResource "Response" +// @Failure 400 {object} map[string]interface{} "Bad Request" +// @Failure 404 {object} map[string]interface{} "Not Found" +// @Router /v1/responses/{id}/cancel [post] +func CancelResponseEndpoint() func(c echo.Context) error { + return func(c echo.Context) error { + responseID := c.Param("id") + if responseID == "" { + return sendOpenResponsesError(c, 400, "invalid_request_error", "response ID is required", "id") + } + + store := GetGlobalStore() + response, err := store.Cancel(responseID) + if err != nil { + return sendOpenResponsesError(c, 404, "not_found", fmt.Sprintf("response not found: %s", responseID), "id") + } + + // Return the final response object + return c.JSON(200, response) + } +} diff --git a/core/http/endpoints/openresponses/store.go b/core/http/endpoints/openresponses/store.go new file mode 100644 index 000000000..a548254fb --- /dev/null +++ b/core/http/endpoints/openresponses/store.go @@ -0,0 +1,453 @@ +package openresponses + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/xlog" +) + +// ResponseStore provides thread-safe storage for Open Responses API responses +type ResponseStore struct { + mu sync.RWMutex + responses map[string]*StoredResponse + ttl time.Duration // Time-to-live for stored responses (0 = no expiration) + cleanupCtx context.Context + cleanupCancel context.CancelFunc +} + +// StreamedEvent represents a buffered SSE event for streaming resume +type StreamedEvent struct { + SequenceNumber int `json:"sequence_number"` + EventType string `json:"event_type"` + Data []byte `json:"data"` // JSON-serialized event +} + +// StoredResponse contains a complete response with its input request and output items +type StoredResponse struct { + Request *schema.OpenResponsesRequest + Response *schema.ORResponseResource + Items map[string]*schema.ORItemField // item_id -> item mapping for quick lookup + StoredAt time.Time + ExpiresAt *time.Time // nil if no expiration + + // Background execution support + CancelFunc context.CancelFunc // For cancellation of background tasks + 
StreamEvents []StreamedEvent // Buffered events for streaming resume + StreamEnabled bool // Was created with stream=true + IsBackground bool // Was created with background=true + EventsChan chan struct{} // Signals new events for live subscribers + mu sync.RWMutex // Protect concurrent access to this response +} + +var ( + globalStore *ResponseStore + storeOnce sync.Once +) + +// GetGlobalStore returns the singleton response store instance +func GetGlobalStore() *ResponseStore { + storeOnce.Do(func() { + globalStore = NewResponseStore(0) // Default: no TTL, will be updated from appConfig + }) + return globalStore +} + +// SetTTL updates the TTL for the store +// This will affect all new responses stored after this call +func (s *ResponseStore) SetTTL(ttl time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + + // Stop existing cleanup loop if running + if s.cleanupCancel != nil { + s.cleanupCancel() + s.cleanupCancel = nil + s.cleanupCtx = nil + } + + s.ttl = ttl + + // If TTL > 0, start cleanup loop + if ttl > 0 { + s.cleanupCtx, s.cleanupCancel = context.WithCancel(context.Background()) + go s.cleanupLoop(s.cleanupCtx) + } + + xlog.Debug("Updated Open Responses store TTL", "ttl", ttl, "cleanup_running", ttl > 0) +} + +// NewResponseStore creates a new response store with optional TTL +// If ttl is 0, responses are stored indefinitely +func NewResponseStore(ttl time.Duration) *ResponseStore { + store := &ResponseStore{ + responses: make(map[string]*StoredResponse), + ttl: ttl, + } + + // Start cleanup goroutine if TTL is set + if ttl > 0 { + store.cleanupCtx, store.cleanupCancel = context.WithCancel(context.Background()) + go store.cleanupLoop(store.cleanupCtx) + } + + return store +} + +// Store stores a response with its request and items +func (s *ResponseStore) Store(responseID string, request *schema.OpenResponsesRequest, response *schema.ORResponseResource) { + s.mu.Lock() + defer s.mu.Unlock() + + // Build item index for quick lookup + items := make(map[string]*schema.ORItemField) + for i := range response.Output { + item := &response.Output[i] + if item.ID != "" { + items[item.ID] = item + } + } + + stored := &StoredResponse{ + Request: request, + Response: response, + Items: items, + StoredAt: time.Now(), + ExpiresAt: nil, + } + + // Set expiration if TTL is configured + if s.ttl > 0 { + expiresAt := time.Now().Add(s.ttl) + stored.ExpiresAt = &expiresAt + } + + s.responses[responseID] = stored + xlog.Debug("Stored Open Responses response", "response_id", responseID, "items_count", len(items)) +} + +// Get retrieves a stored response by ID +func (s *ResponseStore) Get(responseID string) (*StoredResponse, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + stored, exists := s.responses[responseID] + if !exists { + return nil, fmt.Errorf("response not found: %s", responseID) + } + + // Check expiration + if stored.ExpiresAt != nil && time.Now().After(*stored.ExpiresAt) { + // Expired, but we'll return it anyway and let caller handle cleanup + return nil, fmt.Errorf("response expired: %s", responseID) + } + + return stored, nil +} + +// GetItem retrieves a specific item from a stored response +func (s *ResponseStore) GetItem(responseID, itemID string) (*schema.ORItemField, error) { + stored, err := s.Get(responseID) + if err != nil { + return nil, err + } + + item, exists := stored.Items[itemID] + if !exists { + return nil, fmt.Errorf("item not found: %s in response %s", itemID, responseID) + } + + return item, nil +} + +// FindItem searches for an item across all stored responses 
+// Returns the item and the response ID it was found in +func (s *ResponseStore) FindItem(itemID string) (*schema.ORItemField, string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + now := time.Now() + for responseID, stored := range s.responses { + // Skip expired responses + if stored.ExpiresAt != nil && now.After(*stored.ExpiresAt) { + continue + } + + if item, exists := stored.Items[itemID]; exists { + return item, responseID, nil + } + } + + return nil, "", fmt.Errorf("item not found in any stored response: %s", itemID) +} + +// Delete removes a response from storage +func (s *ResponseStore) Delete(responseID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.responses, responseID) + xlog.Debug("Deleted Open Responses response", "response_id", responseID) +} + +// Cleanup removes expired responses +func (s *ResponseStore) Cleanup() int { + if s.ttl == 0 { + return 0 + } + + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + count := 0 + for id, stored := range s.responses { + if stored.ExpiresAt != nil && now.After(*stored.ExpiresAt) { + delete(s.responses, id) + count++ + } + } + + if count > 0 { + xlog.Debug("Cleaned up expired Open Responses", "count", count) + } + + return count +} + +// cleanupLoop runs periodic cleanup of expired responses +func (s *ResponseStore) cleanupLoop(ctx context.Context) { + if s.ttl == 0 { + return + } + + ticker := time.NewTicker(s.ttl / 2) // Cleanup at half TTL interval + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + xlog.Debug("Stopped Open Responses store cleanup loop") + return + case <-ticker.C: + s.Cleanup() + } + } +} + +// Count returns the number of stored responses +func (s *ResponseStore) Count() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.responses) +} + +// StoreBackground stores a background response with cancel function and optional streaming support +func (s *ResponseStore) StoreBackground(responseID string, request *schema.OpenResponsesRequest, response *schema.ORResponseResource, cancelFunc context.CancelFunc, streamEnabled bool) { + s.mu.Lock() + defer s.mu.Unlock() + + // Build item index for quick lookup + items := make(map[string]*schema.ORItemField) + for i := range response.Output { + item := &response.Output[i] + if item.ID != "" { + items[item.ID] = item + } + } + + stored := &StoredResponse{ + Request: request, + Response: response, + Items: items, + StoredAt: time.Now(), + ExpiresAt: nil, + CancelFunc: cancelFunc, + StreamEvents: []StreamedEvent{}, + StreamEnabled: streamEnabled, + IsBackground: true, + EventsChan: make(chan struct{}, 100), // Buffered channel for event notifications + } + + // Set expiration if TTL is configured + if s.ttl > 0 { + expiresAt := time.Now().Add(s.ttl) + stored.ExpiresAt = &expiresAt + } + + s.responses[responseID] = stored + xlog.Debug("Stored background Open Responses response", "response_id", responseID, "stream_enabled", streamEnabled) +} + +// UpdateStatus updates the status of a stored response +func (s *ResponseStore) UpdateStatus(responseID string, status string, completedAt *int64) error { + s.mu.RLock() + stored, exists := s.responses[responseID] + s.mu.RUnlock() + + if !exists { + return fmt.Errorf("response not found: %s", responseID) + } + + stored.mu.Lock() + defer stored.mu.Unlock() + + stored.Response.Status = status + stored.Response.CompletedAt = completedAt + + xlog.Debug("Updated response status", "response_id", responseID, "status", status) + return nil +} + +// UpdateResponse updates the entire response object for a 
stored response +func (s *ResponseStore) UpdateResponse(responseID string, response *schema.ORResponseResource) error { + s.mu.RLock() + stored, exists := s.responses[responseID] + s.mu.RUnlock() + + if !exists { + return fmt.Errorf("response not found: %s", responseID) + } + + stored.mu.Lock() + defer stored.mu.Unlock() + + // Rebuild item index + items := make(map[string]*schema.ORItemField) + for i := range response.Output { + item := &response.Output[i] + if item.ID != "" { + items[item.ID] = item + } + } + + stored.Response = response + stored.Items = items + + xlog.Debug("Updated response", "response_id", responseID, "status", response.Status, "items_count", len(items)) + return nil +} + +// AppendEvent appends a streaming event to the buffer for resume support +func (s *ResponseStore) AppendEvent(responseID string, event *schema.ORStreamEvent) error { + s.mu.RLock() + stored, exists := s.responses[responseID] + s.mu.RUnlock() + + if !exists { + return fmt.Errorf("response not found: %s", responseID) + } + + // Serialize the event + data, err := json.Marshal(event) + if err != nil { + return fmt.Errorf("failed to marshal event: %w", err) + } + + stored.mu.Lock() + stored.StreamEvents = append(stored.StreamEvents, StreamedEvent{ + SequenceNumber: event.SequenceNumber, + EventType: event.Type, + Data: data, + }) + stored.mu.Unlock() + + // Notify any subscribers of new event + select { + case stored.EventsChan <- struct{}{}: + default: + // Channel full, subscribers will catch up + } + + return nil +} + +// GetEventsAfter returns all events with sequence number greater than startingAfter +func (s *ResponseStore) GetEventsAfter(responseID string, startingAfter int) ([]StreamedEvent, error) { + s.mu.RLock() + stored, exists := s.responses[responseID] + s.mu.RUnlock() + + if !exists { + return nil, fmt.Errorf("response not found: %s", responseID) + } + + stored.mu.RLock() + defer stored.mu.RUnlock() + + var result []StreamedEvent + for _, event := range stored.StreamEvents { + if event.SequenceNumber > startingAfter { + result = append(result, event) + } + } + + return result, nil +} + +// Cancel cancels a background response if it's still in progress +func (s *ResponseStore) Cancel(responseID string) (*schema.ORResponseResource, error) { + s.mu.RLock() + stored, exists := s.responses[responseID] + s.mu.RUnlock() + + if !exists { + return nil, fmt.Errorf("response not found: %s", responseID) + } + + stored.mu.Lock() + defer stored.mu.Unlock() + + // If already in a terminal state, just return the response (idempotent) + status := stored.Response.Status + if status == schema.ORStatusCompleted || status == schema.ORStatusFailed || + status == schema.ORStatusIncomplete || status == schema.ORStatusCancelled { + xlog.Debug("Response already in terminal state", "response_id", responseID, "status", status) + return stored.Response, nil + } + + // Cancel the context if available + if stored.CancelFunc != nil { + stored.CancelFunc() + xlog.Debug("Cancelled background response", "response_id", responseID) + } + + // Update status to cancelled + now := time.Now().Unix() + stored.Response.Status = schema.ORStatusCancelled + stored.Response.CompletedAt = &now + + return stored.Response, nil +} + +// GetEventsChan returns the events notification channel for a response +func (s *ResponseStore) GetEventsChan(responseID string) (chan struct{}, error) { + s.mu.RLock() + stored, exists := s.responses[responseID] + s.mu.RUnlock() + + if !exists { + return nil, fmt.Errorf("response not found: %s", responseID) 
+ } + + return stored.EventsChan, nil +} + +// IsStreamEnabled checks if a response was created with streaming enabled +func (s *ResponseStore) IsStreamEnabled(responseID string) (bool, error) { + s.mu.RLock() + stored, exists := s.responses[responseID] + s.mu.RUnlock() + + if !exists { + return false, fmt.Errorf("response not found: %s", responseID) + } + + stored.mu.RLock() + defer stored.mu.RUnlock() + + return stored.StreamEnabled, nil +} diff --git a/core/http/endpoints/openresponses/store_suite_test.go b/core/http/endpoints/openresponses/store_suite_test.go new file mode 100644 index 000000000..7ab45cece --- /dev/null +++ b/core/http/endpoints/openresponses/store_suite_test.go @@ -0,0 +1,13 @@ +package openresponses + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestStore(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "ResponseStore Suite") +} diff --git a/core/http/endpoints/openresponses/store_test.go b/core/http/endpoints/openresponses/store_test.go new file mode 100644 index 000000000..e0dcdba68 --- /dev/null +++ b/core/http/endpoints/openresponses/store_test.go @@ -0,0 +1,626 @@ +package openresponses + +import ( + "context" + "fmt" + "time" + + "github.com/mudler/LocalAI/core/schema" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ResponseStore", func() { + var store *ResponseStore + + BeforeEach(func() { + store = NewResponseStore(0) // No TTL for most tests + }) + + AfterEach(func() { + // Clean up + }) + + Describe("Store and Get", func() { + It("should store and retrieve a response", func() { + responseID := "resp_test123" + request := &schema.OpenResponsesRequest{ + Model: "test-model", + Input: "Hello", + } + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + CreatedAt: time.Now().Unix(), + Status: "completed", + Model: "test-model", + Output: []schema.ORItemField{ + { + Type: "message", + ID: "msg_123", + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{{ + Type: "output_text", + Text: "Hello, world!", + Annotations: []schema.ORAnnotation{}, + Logprobs: []schema.ORLogProb{}, + }}, + }, + }, + } + + store.Store(responseID, request, response) + + stored, err := store.Get(responseID) + Expect(err).ToNot(HaveOccurred()) + Expect(stored).ToNot(BeNil()) + Expect(stored.Response.ID).To(Equal(responseID)) + Expect(stored.Request.Model).To(Equal("test-model")) + Expect(len(stored.Items)).To(Equal(1)) + Expect(stored.Items["msg_123"]).ToNot(BeNil()) + Expect(stored.Items["msg_123"].ID).To(Equal("msg_123")) + }) + + It("should return error for non-existent response", func() { + _, err := store.Get("nonexistent") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not found")) + }) + + It("should index all items by ID", func() { + responseID := "resp_test456" + request := &schema.OpenResponsesRequest{ + Model: "test-model", + Input: "Test", + } + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + Output: []schema.ORItemField{ + { + Type: "message", + ID: "msg_1", + Status: "completed", + Role: "assistant", + }, + { + Type: "function_call", + ID: "fc_1", + Status: "completed", + CallID: "fc_1", + Name: "test_function", + Arguments: `{"arg": "value"}`, + }, + { + Type: "message", + ID: "msg_2", + Status: "completed", + Role: "assistant", + }, + }, + } + + store.Store(responseID, request, response) + + stored, err := store.Get(responseID) + Expect(err).ToNot(HaveOccurred()) + 
Expect(len(stored.Items)).To(Equal(3)) + Expect(stored.Items["msg_1"]).ToNot(BeNil()) + Expect(stored.Items["fc_1"]).ToNot(BeNil()) + Expect(stored.Items["msg_2"]).ToNot(BeNil()) + }) + + It("should handle items without IDs", func() { + responseID := "resp_test789" + request := &schema.OpenResponsesRequest{ + Model: "test-model", + Input: "Test", + } + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + Output: []schema.ORItemField{ + { + Type: "message", + ID: "", // No ID + Status: "completed", + Role: "assistant", + }, + { + Type: "message", + ID: "msg_with_id", + Status: "completed", + Role: "assistant", + }, + }, + } + + store.Store(responseID, request, response) + + stored, err := store.Get(responseID) + Expect(err).ToNot(HaveOccurred()) + // Only items with IDs are indexed + Expect(len(stored.Items)).To(Equal(1)) + Expect(stored.Items["msg_with_id"]).ToNot(BeNil()) + }) + }) + + Describe("GetItem", func() { + It("should retrieve a specific item by ID", func() { + responseID := "resp_item_test" + itemID := "msg_specific" + request := &schema.OpenResponsesRequest{ + Model: "test-model", + Input: "Test", + } + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + Output: []schema.ORItemField{ + { + Type: "message", + ID: itemID, + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{{ + Type: "output_text", + Text: "Specific message", + Annotations: []schema.ORAnnotation{}, + Logprobs: []schema.ORLogProb{}, + }}, + }, + }, + } + + store.Store(responseID, request, response) + + item, err := store.GetItem(responseID, itemID) + Expect(err).ToNot(HaveOccurred()) + Expect(item).ToNot(BeNil()) + Expect(item.ID).To(Equal(itemID)) + Expect(item.Type).To(Equal("message")) + }) + + It("should return error for non-existent item", func() { + responseID := "resp_item_test2" + request := &schema.OpenResponsesRequest{ + Model: "test-model", + Input: "Test", + } + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + Output: []schema.ORItemField{ + { + Type: "message", + ID: "msg_existing", + Status: "completed", + }, + }, + } + + store.Store(responseID, request, response) + + _, err := store.GetItem(responseID, "nonexistent_item") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("item not found")) + }) + + It("should return error for non-existent response when getting item", func() { + _, err := store.GetItem("nonexistent_response", "any_item") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("response not found")) + }) + }) + + Describe("FindItem", func() { + It("should find an item across all stored responses", func() { + // Store first response + responseID1 := "resp_find_1" + itemID1 := "msg_find_1" + store.Store(responseID1, &schema.OpenResponsesRequest{Model: "test"}, &schema.ORResponseResource{ + ID: responseID1, + Object: "response", + Output: []schema.ORItemField{ + {Type: "message", ID: itemID1, Status: "completed"}, + }, + }) + + // Store second response + responseID2 := "resp_find_2" + itemID2 := "msg_find_2" + store.Store(responseID2, &schema.OpenResponsesRequest{Model: "test"}, &schema.ORResponseResource{ + ID: responseID2, + Object: "response", + Output: []schema.ORItemField{ + {Type: "message", ID: itemID2, Status: "completed"}, + }, + }) + + // Find item from first response + item, foundResponseID, err := store.FindItem(itemID1) + Expect(err).ToNot(HaveOccurred()) + Expect(item).ToNot(BeNil()) + 
Expect(item.ID).To(Equal(itemID1)) + Expect(foundResponseID).To(Equal(responseID1)) + + // Find item from second response + item, foundResponseID, err = store.FindItem(itemID2) + Expect(err).ToNot(HaveOccurred()) + Expect(item).ToNot(BeNil()) + Expect(item.ID).To(Equal(itemID2)) + Expect(foundResponseID).To(Equal(responseID2)) + }) + + It("should return error when item not found in any response", func() { + _, _, err := store.FindItem("nonexistent_item") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("item not found in any stored response")) + }) + }) + + Describe("Delete", func() { + It("should delete a stored response", func() { + responseID := "resp_delete_test" + request := &schema.OpenResponsesRequest{Model: "test"} + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + } + + store.Store(responseID, request, response) + Expect(store.Count()).To(Equal(1)) + + store.Delete(responseID) + Expect(store.Count()).To(Equal(0)) + + _, err := store.Get(responseID) + Expect(err).To(HaveOccurred()) + }) + + It("should handle deleting non-existent response gracefully", func() { + // Should not panic + store.Delete("nonexistent") + Expect(store.Count()).To(Equal(0)) + }) + }) + + Describe("Count", func() { + It("should return correct count of stored responses", func() { + Expect(store.Count()).To(Equal(0)) + + store.Store("resp_1", &schema.OpenResponsesRequest{Model: "test"}, &schema.ORResponseResource{ID: "resp_1", Object: "response"}) + Expect(store.Count()).To(Equal(1)) + + store.Store("resp_2", &schema.OpenResponsesRequest{Model: "test"}, &schema.ORResponseResource{ID: "resp_2", Object: "response"}) + Expect(store.Count()).To(Equal(2)) + + store.Delete("resp_1") + Expect(store.Count()).To(Equal(1)) + }) + }) + + Describe("TTL and Expiration", func() { + It("should set expiration when TTL is configured", func() { + ttlStore := NewResponseStore(100 * time.Millisecond) + responseID := "resp_ttl_test" + request := &schema.OpenResponsesRequest{Model: "test"} + response := &schema.ORResponseResource{ID: responseID, Object: "response"} + + ttlStore.Store(responseID, request, response) + + stored, err := ttlStore.Get(responseID) + Expect(err).ToNot(HaveOccurred()) + Expect(stored.ExpiresAt).ToNot(BeNil()) + Expect(stored.ExpiresAt.After(time.Now())).To(BeTrue()) + }) + + It("should not set expiration when TTL is 0", func() { + responseID := "resp_no_ttl" + request := &schema.OpenResponsesRequest{Model: "test"} + response := &schema.ORResponseResource{ID: responseID, Object: "response"} + + store.Store(responseID, request, response) + + stored, err := store.Get(responseID) + Expect(err).ToNot(HaveOccurred()) + Expect(stored.ExpiresAt).To(BeNil()) + }) + + It("should clean up expired responses", func() { + ttlStore := NewResponseStore(50 * time.Millisecond) + responseID := "resp_expire_test" + request := &schema.OpenResponsesRequest{Model: "test"} + response := &schema.ORResponseResource{ID: responseID, Object: "response"} + + ttlStore.Store(responseID, request, response) + Expect(ttlStore.Count()).To(Equal(1)) + + // Wait for expiration (longer than TTL and cleanup interval) + time.Sleep(150 * time.Millisecond) + + // Cleanup should remove expired response (may have already been cleaned by goroutine) + count := ttlStore.Cleanup() + // Count might be 0 if cleanup goroutine already ran, or 1 if we're first + Expect(count).To(BeNumerically(">=", 0)) + Expect(ttlStore.Count()).To(Equal(0)) + + _, err := ttlStore.Get(responseID) + 
Expect(err).To(HaveOccurred()) + }) + + It("should return error for expired response", func() { + ttlStore := NewResponseStore(50 * time.Millisecond) + responseID := "resp_expire_error" + request := &schema.OpenResponsesRequest{Model: "test"} + response := &schema.ORResponseResource{ID: responseID, Object: "response"} + + ttlStore.Store(responseID, request, response) + + // Wait for expiration (but not long enough for cleanup goroutine to remove it) + time.Sleep(75 * time.Millisecond) + + // Try to get before cleanup goroutine removes it + _, err := ttlStore.Get(responseID) + // Error could be "expired" or "not found" (if cleanup already ran) + Expect(err).To(HaveOccurred()) + // Either error message is acceptable + errMsg := err.Error() + Expect(errMsg).To(Or(ContainSubstring("expired"), ContainSubstring("not found"))) + }) + }) + + Describe("Thread Safety", func() { + It("should handle concurrent stores and gets", func() { + // This is a basic concurrency test + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func(id int) { + responseID := fmt.Sprintf("resp_concurrent_%d", id) + request := &schema.OpenResponsesRequest{Model: "test"} + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + Output: []schema.ORItemField{ + {Type: "message", ID: fmt.Sprintf("msg_%d", id), Status: "completed"}, + }, + } + store.Store(responseID, request, response) + + // Retrieve immediately + stored, err := store.Get(responseID) + Expect(err).ToNot(HaveOccurred()) + Expect(stored).ToNot(BeNil()) + done <- true + }(i) + } + + // Wait for all goroutines + for i := 0; i < 10; i++ { + <-done + } + + Expect(store.Count()).To(Equal(10)) + }) + }) + + Describe("GetGlobalStore", func() { + It("should return singleton instance", func() { + store1 := GetGlobalStore() + store2 := GetGlobalStore() + Expect(store1).To(Equal(store2)) + }) + + It("should persist data across GetGlobalStore calls", func() { + globalStore := GetGlobalStore() + responseID := "resp_global_test" + request := &schema.OpenResponsesRequest{Model: "test"} + response := &schema.ORResponseResource{ID: responseID, Object: "response"} + + globalStore.Store(responseID, request, response) + + // Get store again + globalStore2 := GetGlobalStore() + stored, err := globalStore2.Get(responseID) + Expect(err).ToNot(HaveOccurred()) + Expect(stored).ToNot(BeNil()) + }) + }) + + Describe("Background Mode Support", func() { + It("should store background response with cancel function", func() { + responseID := "resp_bg_test" + request := &schema.OpenResponsesRequest{Model: "test"} + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + Status: schema.ORStatusQueued, + } + + _, cancel := context.WithCancel(context.Background()) + defer cancel() + + store.StoreBackground(responseID, request, response, cancel, true) + + stored, err := store.Get(responseID) + Expect(err).ToNot(HaveOccurred()) + Expect(stored).ToNot(BeNil()) + Expect(stored.IsBackground).To(BeTrue()) + Expect(stored.StreamEnabled).To(BeTrue()) + Expect(stored.CancelFunc).ToNot(BeNil()) + }) + + It("should update status of stored response", func() { + responseID := "resp_status_test" + request := &schema.OpenResponsesRequest{Model: "test"} + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + Status: schema.ORStatusQueued, + } + + store.Store(responseID, request, response) + + err := store.UpdateStatus(responseID, schema.ORStatusInProgress, nil) + Expect(err).ToNot(HaveOccurred()) + + stored, err := 
store.Get(responseID) + Expect(err).ToNot(HaveOccurred()) + Expect(stored.Response.Status).To(Equal(schema.ORStatusInProgress)) + }) + + It("should append and retrieve streaming events", func() { + responseID := "resp_events_test" + request := &schema.OpenResponsesRequest{Model: "test"} + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + Status: schema.ORStatusInProgress, + } + + _, cancel := context.WithCancel(context.Background()) + defer cancel() + + store.StoreBackground(responseID, request, response, cancel, true) + + // Append events + event1 := &schema.ORStreamEvent{ + Type: "response.created", + SequenceNumber: 0, + } + event2 := &schema.ORStreamEvent{ + Type: "response.in_progress", + SequenceNumber: 1, + } + event3 := &schema.ORStreamEvent{ + Type: "response.output_text.delta", + SequenceNumber: 2, + } + + err := store.AppendEvent(responseID, event1) + Expect(err).ToNot(HaveOccurred()) + err = store.AppendEvent(responseID, event2) + Expect(err).ToNot(HaveOccurred()) + err = store.AppendEvent(responseID, event3) + Expect(err).ToNot(HaveOccurred()) + + // Get all events after -1 (all events) + events, err := store.GetEventsAfter(responseID, -1) + Expect(err).ToNot(HaveOccurred()) + Expect(events).To(HaveLen(3)) + + // Get events after sequence 1 + events, err = store.GetEventsAfter(responseID, 1) + Expect(err).ToNot(HaveOccurred()) + Expect(events).To(HaveLen(1)) + Expect(events[0].SequenceNumber).To(Equal(2)) + }) + + It("should cancel an in-progress response", func() { + responseID := "resp_cancel_test" + request := &schema.OpenResponsesRequest{Model: "test"} + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + Status: schema.ORStatusInProgress, + } + + _, cancel := context.WithCancel(context.Background()) + defer cancel() + + store.StoreBackground(responseID, request, response, cancel, false) + + // Cancel the response + cancelledResponse, err := store.Cancel(responseID) + Expect(err).ToNot(HaveOccurred()) + Expect(cancelledResponse.Status).To(Equal(schema.ORStatusCancelled)) + Expect(cancelledResponse.CompletedAt).ToNot(BeNil()) + }) + + It("should be idempotent when cancelling already completed response", func() { + responseID := "resp_idempotent_cancel" + request := &schema.OpenResponsesRequest{Model: "test"} + completedAt := time.Now().Unix() + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + Status: schema.ORStatusCompleted, + CompletedAt: &completedAt, + } + + store.Store(responseID, request, response) + + // Try to cancel a completed response + cancelledResponse, err := store.Cancel(responseID) + Expect(err).ToNot(HaveOccurred()) + // Status should remain completed (not changed to cancelled) + Expect(cancelledResponse.Status).To(Equal(schema.ORStatusCompleted)) + }) + + It("should check if streaming is enabled", func() { + responseID := "resp_stream_check" + request := &schema.OpenResponsesRequest{Model: "test"} + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + Status: schema.ORStatusQueued, + } + + _, cancel := context.WithCancel(context.Background()) + defer cancel() + + store.StoreBackground(responseID, request, response, cancel, true) + + enabled, err := store.IsStreamEnabled(responseID) + Expect(err).ToNot(HaveOccurred()) + Expect(enabled).To(BeTrue()) + + // Store another without streaming + responseID2 := "resp_no_stream" + store.StoreBackground(responseID2, request, response, cancel, false) + + enabled2, err := 
store.IsStreamEnabled(responseID2) + Expect(err).ToNot(HaveOccurred()) + Expect(enabled2).To(BeFalse()) + }) + + It("should notify subscribers of new events", func() { + responseID := "resp_events_chan" + request := &schema.OpenResponsesRequest{Model: "test"} + response := &schema.ORResponseResource{ + ID: responseID, + Object: "response", + Status: schema.ORStatusInProgress, + } + + _, cancel := context.WithCancel(context.Background()) + defer cancel() + + store.StoreBackground(responseID, request, response, cancel, true) + + eventsChan, err := store.GetEventsChan(responseID) + Expect(err).ToNot(HaveOccurred()) + Expect(eventsChan).ToNot(BeNil()) + + // Append an event + event := &schema.ORStreamEvent{ + Type: "response.output_text.delta", + SequenceNumber: 0, + } + + go func() { + time.Sleep(10 * time.Millisecond) + store.AppendEvent(responseID, event) + }() + + // Wait for notification + select { + case <-eventsChan: + // Event received + case <-time.After(1 * time.Second): + Fail("Timeout waiting for event notification") + } + }) + }) +}) diff --git a/core/http/http_suite_test.go b/core/http/http_suite_test.go index 94467437f..805eb5b52 100644 --- a/core/http/http_suite_test.go +++ b/core/http/http_suite_test.go @@ -1,13 +1,33 @@ package http_test import ( + "os" + "path/filepath" "testing" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) +var ( + tmpdir string + modelDir string +) + func TestLocalAI(t *testing.T) { RegisterFailHandler(Fail) + + var err error + tmpdir, err = os.MkdirTemp("", "") + Expect(err).ToNot(HaveOccurred()) + modelDir = filepath.Join(tmpdir, "models") + err = os.Mkdir(modelDir, 0750) + Expect(err).ToNot(HaveOccurred()) + + AfterSuite(func() { + err := os.RemoveAll(tmpdir) + Expect(err).ToNot(HaveOccurred()) + }) + RunSpecs(t, "LocalAI HTTP test suite") } diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go index 76d7fee64..115a00149 100644 --- a/core/http/middleware/request.go +++ b/core/http/middleware/request.go @@ -484,3 +484,103 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema. } return fmt.Errorf("unable to validate configuration after merging") } + +func (re *RequestExtractor) SetOpenResponsesRequest(c echo.Context) error { + input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenResponsesRequest) + if !ok || input.Model == "" { + return echo.ErrBadRequest + } + + cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil { + return echo.ErrBadRequest + } + + // Extract or generate the correlation ID (Open Responses uses x-request-id) + correlationID := c.Request().Header.Get("x-request-id") + if correlationID == "" { + correlationID = uuid.New().String() + } + c.Response().Header().Set("x-request-id", correlationID) + + // Use the request context directly - Echo properly supports context cancellation! 
+ reqCtx := c.Request().Context() + c1, cancel := context.WithCancel(re.applicationConfig.Context) + + // Cancel when request context is cancelled (client disconnects) + go func() { + select { + case <-reqCtx.Done(): + cancel() + case <-c1.Done(): + // Already cancelled + } + }() + + // Add the correlation ID to the new context + ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) + + input.Context = ctxWithCorrelationID + input.Cancel = cancel + + err := mergeOpenResponsesRequestAndModelConfig(cfg, input) + if err != nil { + return err + } + + if cfg.Model == "" { + xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) + cfg.Model = input.Model + } + + c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return nil +} + +func mergeOpenResponsesRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenResponsesRequest) error { + // Temperature + if input.Temperature != nil { + config.Temperature = input.Temperature + } + + // TopP + if input.TopP != nil { + config.TopP = input.TopP + } + + // MaxOutputTokens -> Maxtokens + if input.MaxOutputTokens != nil { + config.Maxtokens = input.MaxOutputTokens + } + + // Convert tools to functions - this will be handled in the endpoint handler + // We just validate that tools are present if needed + + // Handle tool_choice + if input.ToolChoice != nil { + switch tc := input.ToolChoice.(type) { + case string: + // "auto", "required", or "none" + if tc == "required" { + config.SetFunctionCallString("required") + } else if tc == "none" { + // Don't use tools - handled in endpoint + } + // "auto" is default - let model decide + case map[string]interface{}: + // Specific tool: {type:"function", name:"..."} + if tcType, ok := tc["type"].(string); ok && tcType == "function" { + if name, ok := tc["name"].(string); ok { + config.SetFunctionCallString(name) + } + } + } + } + + if valid, _ := config.Validate(); valid { + return nil + } + return fmt.Errorf("unable to validate configuration after merging") +} diff --git a/core/http/openresponses_test.go b/core/http/openresponses_test.go new file mode 100644 index 000000000..61a448c62 --- /dev/null +++ b/core/http/openresponses_test.go @@ -0,0 +1,1027 @@ +package http_test + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/config" + . "github.com/mudler/LocalAI/core/http" + "github.com/mudler/LocalAI/pkg/system" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/mudler/xlog" +) + +const testModel = "Qwen3-VL-2B-Instruct-GGUF" + +var _ = Describe("Open Responses API", func() { + var app *echo.Echo + var c context.Context + var cancel context.CancelFunc + + commonOpts := []config.AppOption{ + config.WithDebug(true), + } + + Context("API with ephemeral models", func() { + BeforeEach(func(sc SpecContext) { + var err error + + backendPath := os.Getenv("BACKENDS_PATH") + + c, cancel = context.WithCancel(context.Background()) + + systemState, err := system.GetSystemState( + system.WithBackendPath(backendPath), + system.WithModelPath(modelDir), + ) + Expect(err).ToNot(HaveOccurred()) + + application, err := application.New( + append(commonOpts, + config.WithContext(c), + config.WithSystemState(systemState), + config.WithApiKeys([]string{apiKey}), + config.WithModelsURL("https://huggingface.co/unsloth/Qwen3-VL-2B-Instruct-GGUF"), + )...) + Expect(err).ToNot(HaveOccurred()) + + app, err = API(application) + Expect(err).ToNot(HaveOccurred()) + + go func() { + if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { + xlog.Error("server error", "error", err) + } + }() + + // Wait for API to be ready + Eventually(func() error { + resp, err := http.Get("http://127.0.0.1:9090/healthz") + if err != nil { + return err + } + resp.Body.Close() + return nil + }, "2m").ShouldNot(HaveOccurred()) + }) + + AfterEach(func(sc SpecContext) { + cancel() + if app != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := app.Shutdown(ctx) + Expect(err).ToNot(HaveOccurred()) + } + + }) + + Context("HTTP Protocol Compliance", func() { + It("MUST accept application/json Content-Type", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + // Should accept the request (may fail on model not found, but should accept Content-Type) + Expect(resp.StatusCode).To(Or(Equal(200), Equal(400), Equal(500))) + }) + + It("MUST return application/json for non-streaming responses", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + "stream": false, + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + contentType := resp.Header.Get("Content-Type") + if resp.StatusCode == 200 { + Expect(contentType).To(ContainSubstring("application/json")) + } + }) + + It("MUST return text/event-stream for streaming responses", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + "stream": true, + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", 
bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + contentType := resp.Header.Get("Content-Type") + if resp.StatusCode == 200 { + Expect(contentType).To(Equal("text/event-stream")) + } + }) + + It("MUST end streaming with [DONE] terminal event", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + "stream": true, + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + body, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + bodyStr := string(body) + // Should end with [DONE] + Expect(bodyStr).To(ContainSubstring("data: [DONE]")) + } + }) + + It("MUST have event field matching type in body", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + "stream": true, + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + body, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + bodyStr := string(body) + + // Parse SSE events + lines := strings.Split(bodyStr, "\n") + for i, line := range lines { + if strings.HasPrefix(line, "event: ") { + eventType := strings.TrimPrefix(line, "event: ") + // Next line should be data: with matching type + if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") { + dataLine := strings.TrimPrefix(lines[i+1], "data: ") + var eventData map[string]interface{} + if err := json.Unmarshal([]byte(dataLine), &eventData); err == nil { + if typeVal, ok := eventData["type"].(string); ok { + Expect(typeVal).To(Equal(eventType)) + } + } + } + } + } + } + }) + }) + + Context("Response Structure", func() { + It("MUST return id field", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + var response map[string]interface{} + body, _ := io.ReadAll(resp.Body) + err = json.Unmarshal(body, &response) + Expect(err).ToNot(HaveOccurred()) + Expect(response).To(HaveKey("id")) + Expect(response["id"]).ToNot(BeEmpty()) + } + }) + + It("MUST return object field as 'response'", func() { + reqBody := map[string]interface{}{ + "model": 
testModel, + "input": "Hello", + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + var response map[string]interface{} + body, _ := io.ReadAll(resp.Body) + err = json.Unmarshal(body, &response) + Expect(err).ToNot(HaveOccurred()) + Expect(response).To(HaveKey("object")) + Expect(response["object"]).To(Equal("response")) + } + }) + + It("MUST return created_at timestamp", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + var response map[string]interface{} + body, _ := io.ReadAll(resp.Body) + err = json.Unmarshal(body, &response) + Expect(err).ToNot(HaveOccurred()) + Expect(response).To(HaveKey("created_at")) + // Should be a number (unix timestamp) + createdAt, ok := response["created_at"].(float64) + Expect(ok).To(BeTrue()) + Expect(createdAt).To(BeNumerically(">", 0)) + } + }) + + It("MUST return status field", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + var response map[string]interface{} + body, _ := io.ReadAll(resp.Body) + err = json.Unmarshal(body, &response) + Expect(err).ToNot(HaveOccurred()) + Expect(response).To(HaveKey("status")) + status, ok := response["status"].(string) + Expect(ok).To(BeTrue()) + Expect(status).To(BeElementOf("in_progress", "completed", "failed", "incomplete")) + } + }) + + It("MUST return model field", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + var response map[string]interface{} + body, _ := io.ReadAll(resp.Body) + err = json.Unmarshal(body, &response) + Expect(err).ToNot(HaveOccurred()) + Expect(response).To(HaveKey("model")) + Expect(response["model"]).ToNot(BeEmpty()) + } + }) + + It("MUST return output array of items", 
func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + var response map[string]interface{} + body, _ := io.ReadAll(resp.Body) + err = json.Unmarshal(body, &response) + Expect(err).ToNot(HaveOccurred()) + Expect(response).To(HaveKey("output")) + output, ok := response["output"].([]interface{}) + Expect(ok).To(BeTrue()) + Expect(output).ToNot(BeNil()) + } + }) + }) + + Context("Items", func() { + It("MUST include id field on all items", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + var response map[string]interface{} + body, _ := io.ReadAll(resp.Body) + err = json.Unmarshal(body, &response) + Expect(err).ToNot(HaveOccurred()) + + output, ok := response["output"].([]interface{}) + if ok { + for _, item := range output { + itemMap, ok := item.(map[string]interface{}) + Expect(ok).To(BeTrue()) + Expect(itemMap).To(HaveKey("id")) + Expect(itemMap["id"]).ToNot(BeEmpty()) + } + } + } + }) + + It("MUST include type field on all items", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + var response map[string]interface{} + body, _ := io.ReadAll(resp.Body) + err = json.Unmarshal(body, &response) + Expect(err).ToNot(HaveOccurred()) + + output, ok := response["output"].([]interface{}) + if ok { + for _, item := range output { + itemMap, ok := item.(map[string]interface{}) + Expect(ok).To(BeTrue()) + Expect(itemMap).To(HaveKey("type")) + Expect(itemMap["type"]).ToNot(BeEmpty()) + } + } + } + }) + + It("MUST include status field on all items", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + var 
response map[string]interface{} + body, _ := io.ReadAll(resp.Body) + err = json.Unmarshal(body, &response) + Expect(err).ToNot(HaveOccurred()) + + output, ok := response["output"].([]interface{}) + if ok { + for _, item := range output { + itemMap, ok := item.(map[string]interface{}) + Expect(ok).To(BeTrue()) + Expect(itemMap).To(HaveKey("status")) + status, ok := itemMap["status"].(string) + Expect(ok).To(BeTrue()) + Expect(status).To(BeElementOf("in_progress", "completed", "incomplete")) + } + } + } + }) + + It("MUST support message items with role field", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": []map[string]interface{}{ + { + "type": "message", + "role": "user", + "content": []map[string]interface{}{ + { + "type": "input_text", + "text": "Hello", + }, + }, + }, + }, + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + var response map[string]interface{} + body, _ := io.ReadAll(resp.Body) + err = json.Unmarshal(body, &response) + Expect(err).ToNot(HaveOccurred()) + + output, ok := response["output"].([]interface{}) + if ok && len(output) > 0 { + itemMap, ok := output[0].(map[string]interface{}) + Expect(ok).To(BeTrue()) + if itemMap["type"] == "message" { + Expect(itemMap).To(HaveKey("role")) + role, ok := itemMap["role"].(string) + Expect(ok).To(BeTrue()) + Expect(role).To(BeElementOf("user", "assistant", "system", "developer")) + } + } + } + }) + }) + + Context("Content Types", func() { + It("MUST support input_text content", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": []map[string]interface{}{ + { + "type": "message", + "role": "user", + "content": []map[string]interface{}{ + { + "type": "input_text", + "text": "Hello world", + }, + }, + }, + }, + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + // Should accept the request + Expect(resp.StatusCode).To(Or(Equal(200), Equal(400), Equal(500))) + }) + + It("MUST support input_image content with URL", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": []map[string]interface{}{ + { + "type": "message", + "role": "user", + "content": []map[string]interface{}{ + { + "type": "input_image", + "image_url": "https://example.com/image.png", + "detail": "auto", + }, + }, + }, + }, + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + // Should accept the request + 
Expect(resp.StatusCode).To(Or(Equal(200), Equal(400), Equal(500))) + }) + + It("MUST support input_image content with base64", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": []map[string]interface{}{ + { + "type": "message", + "role": "user", + "content": []map[string]interface{}{ + { + "type": "input_image", + "image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==", + "detail": "auto", + }, + }, + }, + }, + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + // Should accept the request + Expect(resp.StatusCode).To(Or(Equal(200), Equal(400), Equal(500))) + }) + + It("MUST support output_text content", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + var response map[string]interface{} + body, _ := io.ReadAll(resp.Body) + err = json.Unmarshal(body, &response) + Expect(err).ToNot(HaveOccurred()) + + output, ok := response["output"].([]interface{}) + if ok && len(output) > 0 { + itemMap, ok := output[0].(map[string]interface{}) + Expect(ok).To(BeTrue()) + if itemMap["type"] == "message" { + content, ok := itemMap["content"].([]interface{}) + if ok && len(content) > 0 { + contentMap, ok := content[0].(map[string]interface{}) + if ok { + contentType, _ := contentMap["type"].(string) + if contentType == "output_text" { + Expect(contentMap).To(HaveKey("text")) + } + } + } + } + } + } + }) + }) + + Context("Streaming Events", func() { + It("MUST emit response.created as first event", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + "stream": true, + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + body, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + bodyStr := string(body) + + // Should contain response.created event + Expect(bodyStr).To(ContainSubstring("response.created")) + } + }) + + It("MUST include sequence_number in all events", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Hello", + "stream": true, + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + 
Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode == 200 { + body, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + bodyStr := string(body) + + // Parse SSE events and check for sequence_number + lines := strings.Split(bodyStr, "\n") + for _, line := range lines { + if strings.HasPrefix(line, "data: ") { + dataLine := strings.TrimPrefix(line, "data: ") + if dataLine != "[DONE]" { + var eventData map[string]interface{} + if err := json.Unmarshal([]byte(dataLine), &eventData); err == nil { + if _, hasType := eventData["type"]; hasType { + Expect(eventData).To(HaveKey("sequence_number")) + } + } + } + } + } + } + }) + }) + + Context("Error Handling", func() { + It("MUST return structured error with type and message fields", func() { + reqBody := map[string]interface{}{ + "model": "nonexistent-model", + "input": "Hello", + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + var errorResp map[string]interface{} + body, _ := io.ReadAll(resp.Body) + json.Unmarshal(body, &errorResp) + + if errorResp["error"] != nil { + errorObj, ok := errorResp["error"].(map[string]interface{}) + if ok { + Expect(errorObj).To(HaveKey("type")) + Expect(errorObj).To(HaveKey("message")) + } + } + } + }) + }) + + Context("Previous Response ID", func() { + It("should load previous response and concatenate context", func() { + // First, create a response + reqBody1 := map[string]interface{}{ + "model": testModel, + "input": "What is 2+2?", + } + payload1, err := json.Marshal(reqBody1) + Expect(err).ToNot(HaveOccurred()) + + req1, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload1)) + Expect(err).ToNot(HaveOccurred()) + req1.Header.Set("Content-Type", "application/json") + req1.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp1, err := client.Do(req1) + Expect(err).ToNot(HaveOccurred()) + defer resp1.Body.Close() + + // Check if first response succeeded + if resp1.StatusCode != 200 { + Skip("First response failed, skipping previous_response_id test (backend may not be available)") + } + + var response1 map[string]interface{} + body1, err := io.ReadAll(resp1.Body) + Expect(err).ToNot(HaveOccurred()) + err = json.Unmarshal(body1, &response1) + Expect(err).ToNot(HaveOccurred()) + + responseID, ok := response1["id"].(string) + Expect(ok).To(BeTrue()) + Expect(responseID).ToNot(BeEmpty()) + + // Now create a new response with previous_response_id + reqBody2 := map[string]interface{}{ + "model": testModel, + "input": "What about 3+3?", + "previous_response_id": responseID, + } + payload2, err := json.Marshal(reqBody2) + Expect(err).ToNot(HaveOccurred()) + + req2, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload2)) + Expect(err).ToNot(HaveOccurred()) + req2.Header.Set("Content-Type", "application/json") + req2.Header.Set("Authorization", bearerKey) + + resp2, err := 
client.Do(req2) + Expect(err).ToNot(HaveOccurred()) + defer resp2.Body.Close() + + var response2 map[string]interface{} + body2, err := io.ReadAll(resp2.Body) + Expect(err).ToNot(HaveOccurred()) + err = json.Unmarshal(body2, &response2) + Expect(err).ToNot(HaveOccurred()) + + Expect(response2["previous_response_id"]).To(Equal(responseID)) + Expect(response2["status"]).To(Equal("completed")) + }) + + It("should return error for invalid previous_response_id", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": "Test", + "previous_response_id": "nonexistent_response_id", + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + Expect(resp.StatusCode).To(Equal(404)) + + var errorResp map[string]interface{} + body, _ := io.ReadAll(resp.Body) + json.Unmarshal(body, &errorResp) + + if errorResp["error"] != nil { + errorObj, ok := errorResp["error"].(map[string]interface{}) + if ok { + Expect(errorObj["type"]).To(Equal("not_found")) + Expect(errorObj["param"]).To(Equal("previous_response_id")) + } + } + }) + }) + + Context("Item Reference", func() { + It("should resolve item_reference in input", func() { + // First, create a response with items + reqBody1 := map[string]interface{}{ + "model": testModel, + "input": "Hello", + } + payload1, err := json.Marshal(reqBody1) + Expect(err).ToNot(HaveOccurred()) + + req1, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload1)) + Expect(err).ToNot(HaveOccurred()) + req1.Header.Set("Content-Type", "application/json") + req1.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp1, err := client.Do(req1) + Expect(err).ToNot(HaveOccurred()) + defer resp1.Body.Close() + + // Check if first response succeeded + if resp1.StatusCode != 200 { + Skip("First response failed, skipping item_reference test (backend may not be available)") + } + + var response1 map[string]interface{} + body1, err := io.ReadAll(resp1.Body) + Expect(err).ToNot(HaveOccurred()) + err = json.Unmarshal(body1, &response1) + Expect(err).ToNot(HaveOccurred()) + + // Get the first output item ID + output, ok := response1["output"].([]interface{}) + Expect(ok).To(BeTrue()) + Expect(len(output)).To(BeNumerically(">", 0)) + + firstItem, ok := output[0].(map[string]interface{}) + Expect(ok).To(BeTrue()) + itemID, ok := firstItem["id"].(string) + Expect(ok).To(BeTrue()) + Expect(itemID).ToNot(BeEmpty()) + + // Now create a new response with item_reference + reqBody2 := map[string]interface{}{ + "model": testModel, + "input": []interface{}{ + map[string]interface{}{ + "type": "item_reference", + "item_id": itemID, + }, + map[string]interface{}{ + "type": "message", + "role": "user", + "content": "Continue from the previous message", + }, + }, + } + payload2, err := json.Marshal(reqBody2) + Expect(err).ToNot(HaveOccurred()) + + req2, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload2)) + Expect(err).ToNot(HaveOccurred()) + req2.Header.Set("Content-Type", "application/json") + req2.Header.Set("Authorization", bearerKey) + + resp2, err := client.Do(req2) + Expect(err).ToNot(HaveOccurred()) + defer 
resp2.Body.Close() + + // Should succeed (item reference resolved) + Expect(resp2.StatusCode).To(Equal(200)) + }) + + It("should return error for invalid item_reference", func() { + reqBody := map[string]interface{}{ + "model": testModel, + "input": []interface{}{ + map[string]interface{}{ + "type": "item_reference", + "item_id": "nonexistent_item_id", + }, + }, + } + payload, err := json.Marshal(reqBody) + Expect(err).ToNot(HaveOccurred()) + + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) + + client := &http.Client{} + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + // Should return error + Expect(resp.StatusCode).To(BeNumerically(">=", 400)) + }) + }) + }) +}) diff --git a/core/http/routes/openresponses.go b/core/http/routes/openresponses.go new file mode 100644 index 000000000..19cadbbae --- /dev/null +++ b/core/http/routes/openresponses.go @@ -0,0 +1,58 @@ +package routes + +import ( + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/endpoints/openresponses" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" +) + +func RegisterOpenResponsesRoutes(app *echo.Echo, + re *middleware.RequestExtractor, + application *application.Application) { + + // Open Responses API endpoint + responsesHandler := openresponses.ResponsesEndpoint( + application.ModelConfigLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ) + + responsesMiddleware := []echo.MiddlewareFunc{ + middleware.TraceMiddleware(application), + re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenResponsesRequest) }), + setOpenResponsesRequestContext(re), + } + + // Main Open Responses endpoint + app.POST("/v1/responses", responsesHandler, responsesMiddleware...) + + // Also support without version prefix for compatibility + app.POST("/responses", responsesHandler, responsesMiddleware...) 
+ + // GET /responses/:id - Retrieve a response (for polling background requests) + getResponseHandler := openresponses.GetResponseEndpoint() + app.GET("/v1/responses/:id", getResponseHandler, middleware.TraceMiddleware(application)) + app.GET("/responses/:id", getResponseHandler, middleware.TraceMiddleware(application)) + + // POST /responses/:id/cancel - Cancel a background response + cancelResponseHandler := openresponses.CancelResponseEndpoint() + app.POST("/v1/responses/:id/cancel", cancelResponseHandler, middleware.TraceMiddleware(application)) + app.POST("/responses/:id/cancel", cancelResponseHandler, middleware.TraceMiddleware(application)) +} + +// setOpenResponsesRequestContext sets up the context and cancel function for Open Responses requests +func setOpenResponsesRequestContext(re *middleware.RequestExtractor) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if err := re.SetOpenResponsesRequest(c); err != nil { + return err + } + return next(c) + } + } +} diff --git a/core/http/views/settings.html b/core/http/views/settings.html index 97587f0e3..4f5815adf 100644 --- a/core/http/views/settings.html +++ b/core/http/views/settings.html @@ -485,6 +485,28 @@ + +
+            Configure Open Responses API response storage
+            Time-to-live for stored responses (e.g., 1h, 30m, 0 = no expiration)
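
A minimal client sketch against the routes registered in core/http/routes/openresponses.go above: POST /v1/responses to create a response, GET /v1/responses/:id to poll it, and POST /v1/responses/:id/cancel to cancel it. The base URL, API key, and model name are placeholders and error handling is trimmed; this is illustrative only, not part of the patch.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

const (
	baseURL = "http://127.0.0.1:8080" // placeholder LocalAI address
	apiKey  = "sk-example"            // placeholder API key
)

// doJSON sends a JSON request with the Authorization header and decodes the JSON reply.
func doJSON(method, url string, body any) (map[string]any, error) {
	var buf bytes.Buffer
	if body != nil {
		if err := json.NewEncoder(&buf).Encode(body); err != nil {
			return nil, err
		}
	}
	req, err := http.NewRequest(method, url, &buf)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	var out map[string]any
	return out, json.NewDecoder(resp.Body).Decode(&out)
}

func main() {
	// Create a non-streaming response.
	created, err := doJSON("POST", baseURL+"/v1/responses", map[string]any{
		"model": "my-model", // placeholder model name
		"input": "Hello",
	})
	if err != nil {
		panic(err)
	}
	id, _ := created["id"].(string)
	fmt.Println("created:", id, "status:", created["status"])

	// Poll the stored response (useful with background mode).
	got, err := doJSON("GET", baseURL+"/v1/responses/"+id, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println("retrieved status:", got["status"])

	// Cancel it; per the store's Cancel semantics, cancelling an already
	// completed response is idempotent and leaves the status unchanged.
	cancelled, err := doJSON("POST", baseURL+"/v1/responses/"+id+"/cancel", nil)
	if err != nil {
		panic(err)
	}
	fmt.Println("after cancel:", cancelled["status"])
}
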
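A small sketch of the TTL behaviour the store tests above exercise: NewResponseStore with a non-zero duration stamps ExpiresAt on stored entries and Cleanup drops expired ones, while NewResponseStore(0) disables expiration. Constructor and method names are taken from this patch; the model name and response ID are placeholders.

package main

import (
	"fmt"
	"time"

	"github.com/mudler/LocalAI/core/http/endpoints/openresponses"
	"github.com/mudler/LocalAI/core/schema"
)

func main() {
	// One-hour TTL; NewResponseStore(0) would keep responses indefinitely.
	store := openresponses.NewResponseStore(time.Hour)

	id := "resp_example" // placeholder response ID
	store.Store(id,
		&schema.OpenResponsesRequest{Model: "my-model"}, // placeholder model
		&schema.ORResponseResource{ID: id, Object: "response", Status: "completed"},
	)

	// With a non-zero TTL the stored entry carries an expiration timestamp.
	if stored, err := store.Get(id); err == nil && stored.ExpiresAt != nil {
		fmt.Println("expires at:", stored.ExpiresAt.Format(time.RFC3339))
	}

	// Cleanup removes entries whose ExpiresAt has passed and reports how many were dropped.
	fmt.Println("expired entries removed:", store.Cleanup())
}
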
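A sketch of how buffered streaming events could be replayed and then followed for resume support, using GetEventsAfter, GetEventsChan and IsStreamEnabled from this patch. The standalone wiring, the use of the global store, and treating StreamedEvent.Data as JSON bytes are assumptions for illustration.

package main

import (
	"encoding/json"
	"fmt"
	"time"

	"github.com/mudler/LocalAI/core/http/endpoints/openresponses"
)

// replayAndFollow prints every buffered event after lastSeq, then waits for
// new-event notifications until timeout elapses. Illustrative only.
func replayAndFollow(responseID string, lastSeq int, timeout time.Duration) error {
	store := openresponses.GetGlobalStore()

	if ok, err := store.IsStreamEnabled(responseID); err != nil || !ok {
		return fmt.Errorf("response %s is not streamable: %v", responseID, err)
	}

	notify, err := store.GetEventsChan(responseID)
	if err != nil {
		return err
	}

	deadline := time.After(timeout)
	for {
		// Fetch everything appended since the last sequence number we saw.
		events, err := store.GetEventsAfter(responseID, lastSeq)
		if err != nil {
			return err
		}
		for _, ev := range events {
			// Data holds the JSON-serialized ORStreamEvent appended by AppendEvent.
			var payload map[string]any
			_ = json.Unmarshal(ev.Data, &payload)
			fmt.Printf("seq=%d type=%s body=%v\n", ev.SequenceNumber, ev.EventType, payload)
			lastSeq = ev.SequenceNumber
		}
		select {
		case <-notify:
			// New events were appended; loop and fetch them.
		case <-deadline:
			return nil
		}
	}
}

func main() {
	// Placeholder response ID; in practice this comes from a prior /v1/responses call.
	if err := replayAndFollow("resp_example", -1, 5*time.Second); err != nil {
		fmt.Println("error:", err)
	}
}
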