Files
LocalAI/core/http/endpoints/openresponses/websocket.go
Ettore Di Giacinto 8818452d85 feat(ui): MCP Apps, mcp streaming and client-side support (#8947)
* Revert "fix: Add timeout-based wait for model deletion completion (#8756)"

This reverts commit 9e1b0d0c82.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: add mcp prompts and resources

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(ui): add client-side MCP

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(ui): allow to authenticate MCP servers

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(ui): add MCP Apps

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: update AGENTS

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: allow to collapse navbar, save state in storage

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(ui): add MCP button also to home page

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(chat): populate string content

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-11 07:30:49 +01:00

425 lines
13 KiB
Go

package openresponses
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/xlog"
)
// Limits applied to every WebSocket connection of the Responses API.
const (
	// wsMaxMessageSize caps the size, in bytes, of a single inbound message (10 MiB).
	wsMaxMessageSize = 10 << 20
	// wsConnectionLimit is the absolute maximum lifetime of one connection.
	wsConnectionLimit = time.Hour
)
// wsUpgrader upgrades incoming HTTP requests on the Responses endpoint to
// WebSocket connections. Requests are accepted regardless of their Origin.
//
// NOTE(review): accepting every Origin lets any web page open a socket to this
// server from a visitor's browser (cross-site WebSocket hijacking). gorilla's
// default check (CheckOrigin: nil) would still accept non-browser clients that
// send no Origin header; confirm whether cross-origin browser access is
// actually required before tightening this.
var wsUpgrader = websocket.Upgrader{
	CheckOrigin: func(*http.Request) bool { return true },
}
// lockedConn is a WebSocket connection whose JSON writes are serialised by an
// embedded mutex. Both the read loop (error replies) and the per-response
// goroutine (event forwarding) write to the same socket, and gorilla/websocket
// supports at most one concurrent writer.
type lockedConn struct {
	*websocket.Conn
	sync.Mutex
}

// writeJSON encodes v as one JSON message while holding the lock, so writes
// from different goroutines never interleave.
func (lc *lockedConn) writeJSON(v any) error {
	lc.Mutex.Lock()
	defer lc.Mutex.Unlock()

	err := lc.Conn.WriteJSON(v)
	return err
}
// WebSocketEndpoint returns an echo handler implementing the WebSocket mode of
// the Responses API. Clients connect to ws://<host>:<port>/v1/responses and send
// response.create messages; the resulting events are pushed back as JSON
// messages over the same socket instead of SSE.
//
// Each connection has an absolute lifetime of wsConnectionLimit: the read
// deadline, the write deadline and the context handed to in-flight work all
// expire at the same instant.
func WebSocketEndpoint(application *application.Application) echo.HandlerFunc {
	modelConfigs := application.ModelConfigLoader()
	modelLoader := application.ModelLoader()
	templateEvaluator := application.TemplatesEvaluator()
	appCfg := application.ApplicationConfig()

	return func(c echo.Context) error {
		rawConn, upgradeErr := wsUpgrader.Upgrade(c.Response(), c.Request(), nil)
		if upgradeErr != nil {
			return upgradeErr
		}
		defer rawConn.Close()

		rawConn.SetReadLimit(wsMaxMessageSize)

		// One absolute deadline bounds the whole connection: once it passes, a
		// blocked ReadMessage returns and pending work is cancelled.
		expiresAt := time.Now().Add(wsConnectionLimit)
		rawConn.SetReadDeadline(expiresAt)
		rawConn.SetWriteDeadline(expiresAt)

		ctx, cancel := context.WithDeadline(context.Background(), expiresAt)
		defer cancel()

		xlog.Debug("WebSocket Responses connection established", "address", rawConn.RemoteAddr().String())
		handleWebSocketConnection(ctx, &lockedConn{Conn: rawConn}, modelConfigs, modelLoader, templateEvaluator, appCfg)
		return nil
	}
}
// handleWebSocketConnection runs the read loop of one WebSocket connection.
//
// Every incoming message must be a JSON object whose "type" is
// "response.create". Anything else is answered with an "error" event, and
// reading continues. Accepted requests are processed on their own goroutine,
// and at most one may be in flight per connection: a request that arrives
// while another is running is rejected, not queued.
//
// The loop ends when connCtx is done (after telling the client why) or when
// reading fails, e.g. because the peer closed the connection or the read
// deadline passed.
func handleWebSocketConnection(connCtx context.Context, conn *lockedConn, cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) {
	// busy is held for the whole lifetime of the in-flight response.
	var busy sync.Mutex

	for connCtx.Err() == nil {
		_, frame, readErr := conn.ReadMessage()
		if readErr != nil {
			if websocket.IsUnexpectedCloseError(readErr, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
				xlog.Debug("WebSocket Responses read error", "error", readErr)
			}
			return
		}

		// Peek at the message type before decoding the full request.
		var head struct {
			Type string `json:"type"`
		}
		if json.Unmarshal(frame, &head) != nil {
			sendWSError(conn, "invalid_request", "invalid JSON message", "")
			continue
		}
		if head.Type != "response.create" {
			sendWSError(conn, "invalid_request", fmt.Sprintf("unsupported message type: %s", head.Type), "type")
			continue
		}

		var req schema.ORWebSocketMessage
		if parseErr := json.Unmarshal(frame, &req); parseErr != nil {
			sendWSError(conn, "invalid_request", fmt.Sprintf("failed to parse request: %v", parseErr), "")
			continue
		}

		// Non-blocking acquire: a concurrent request is refused immediately.
		if !busy.TryLock() {
			sendWSError(conn, "invalid_request", "a response is already in progress on this connection", "")
			continue
		}
		go func(r *schema.OpenResponsesRequest) {
			defer busy.Unlock()
			handleWSResponseCreate(connCtx, conn, r, cl, ml, evaluator, appConfig)
		}(&req.OpenResponsesRequest)
	}

	sendWSError(conn, "websocket_connection_limit_reached", "Connection exceeded maximum duration", "")
}
// handleWSResponseCreate processes a single response.create message and streams
// its events over the WebSocket.
//
// It reuses the background-stream infrastructure:
//   - the request is registered in the global response store;
//   - handleBackgroundStream processes it on a separate goroutine and buffers
//     events into the store;
//   - forwardEvents relays the buffered events to the client.
//
// The following failures are reported to the client as "error" events and
// return before anything is registered in the store:
//   - missing model or unknown model;
//   - invalid merged configuration;
//   - unknown previous_response_id;
//   - input that cannot be converted.
//
// The function does not return until the processing goroutine has exited.
// This keeps the caller's one-in-flight-response guard covering the whole
// generation.
func handleWSResponseCreate(connCtx context.Context, conn *lockedConn, input *schema.OpenResponsesRequest, cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) {
	createdAt := time.Now().Unix()
	responseID := fmt.Sprintf("resp_%s", uuid.New().String())

	if input.Model == "" {
		sendWSError(conn, "invalid_request", "model is required", "model")
		return
	}

	// Resolve model configuration (same logic as middleware.SetModelAndConfig)
	cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(input.Model, appConfig)
	if err != nil {
		xlog.Warn("WebSocket Responses: model config not found", "model", input.Model, "error", err)
		sendWSError(conn, "invalid_request", fmt.Sprintf("model not found: %s", input.Model), "model")
		return
	}
	if cfg.Model == "" {
		cfg.Model = input.Model
	}

	// Merge request params into config (same as mergeOpenResponsesRequestAndModelConfig)
	if err := middleware.MergeOpenResponsesConfig(cfg, input); err != nil {
		sendWSError(conn, "invalid_request", fmt.Sprintf("invalid configuration: %v", err), "")
		return
	}

	// Set up context with cancellation tied to connection lifetime
	reqCtx, reqCancel := context.WithCancel(connCtx)
	defer reqCancel()
	input.Context = reqCtx
	input.Cancel = reqCancel

	store := GetGlobalStore()
	if appConfig.OpenResponsesStoreTTL > 0 {
		store.SetTTL(appConfig.OpenResponsesStoreTTL)
	}

	// Responses are kept unless the client explicitly sent "store": false.
	shouldStore := true
	if input.Store != nil && !*input.Store {
		shouldStore = false
	}

	// Handle previous_response_id: replay the stored input and output as history.
	var messages []schema.Message
	if input.PreviousResponseID != "" {
		stored, err := store.Get(input.PreviousResponseID)
		if err != nil {
			sendWSErrorEvent(conn, "previous_response_not_found",
				fmt.Sprintf("previous response not found: %s", input.PreviousResponseID), "previous_response_id")
			return
		}

		previousInputMessages, err := convertORInputToMessages(stored.Request.Input, cfg)
		if err != nil {
			sendWSError(conn, "invalid_request", fmt.Sprintf("failed to convert previous input: %v", err), "")
			return
		}

		previousOutputMessages, err := convertOROutputItemsToMessages(stored.Response.Output)
		if err != nil {
			sendWSError(conn, "invalid_request", fmt.Sprintf("failed to convert previous response: %v", err), "")
			return
		}

		messages = previousInputMessages
		messages = append(messages, previousOutputMessages...)
	}

	// Convert current input to messages
	newMessages, err := convertORInputToMessages(input.Input, cfg)
	if err != nil {
		sendWSError(conn, "invalid_request", fmt.Sprintf("failed to parse input: %v", err), "")
		return
	}
	messages = append(messages, newMessages...)

	// Instructions become a leading system message.
	if input.Instructions != "" {
		messages = append([]schema.Message{{Role: "system", StringContent: input.Instructions}}, messages...)
	}

	// Handle tools
	var funcs functions.Functions
	var shouldUseFn bool
	if len(input.Tools) > 0 {
		funcs, shouldUseFn = convertORToolsToFunctions(input, cfg)
	}

	// Create OpenAI-compatible request
	openAIReq := &schema.OpenAIRequest{
		PredictionOptions: schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{Model: input.Model},
			Temperature:       input.Temperature,
			TopP:              input.TopP,
			Maxtokens:         input.MaxOutputTokens,
		},
		Messages:  messages,
		Stream:    true, // WebSocket mode always streams
		Context:   reqCtx,
		Cancel:    reqCancel,
		Functions: funcs,
	}

	if input.TextFormat != nil {
		openAIReq.ResponseFormat = convertTextFormatToResponseFormat(input.TextFormat)
	}

	// Generate grammar for function calling
	if shouldUseFn && !cfg.FunctionsConfig.GrammarConfig.NoGrammar {
		noActionName := "answer"
		noActionDescription := "use this action to answer without performing any action"

		if cfg.FunctionsConfig.NoActionFunctionName != "" {
			noActionName = cfg.FunctionsConfig.NoActionFunctionName
		}
		if cfg.FunctionsConfig.NoActionDescriptionName != "" {
			noActionDescription = cfg.FunctionsConfig.NoActionDescriptionName
		}

		noActionGrammar := functions.Function{
			Name:        noActionName,
			Description: noActionDescription,
			Parameters: map[string]interface{}{
				"properties": map[string]interface{}{
					"message": map[string]interface{}{
						"type":        "string",
						"description": "The message to reply the user with",
					},
				},
			},
		}

		// Copy so that appending the no-action entry never touches funcs,
		// which is still passed to the template evaluator below.
		funcsWithNoAction := make(functions.Functions, len(funcs))
		copy(funcsWithNoAction, funcs)

		if !cfg.FunctionsConfig.DisableNoAction {
			funcsWithNoAction = append(funcsWithNoAction, noActionGrammar)
		}

		if cfg.FunctionToCall() != "" {
			funcsWithNoAction = funcsWithNoAction.Select(cfg.FunctionToCall())
		}

		jsStruct := funcsWithNoAction.ToJSONStructure(cfg.FunctionsConfig.FunctionNameKey, cfg.FunctionsConfig.FunctionNameKey)
		g, err := jsStruct.Grammar(cfg.FunctionsConfig.GrammarOptions()...)
		if err == nil {
			cfg.Grammar = g
		} else {
			// Grammar is an optimisation: log and continue without it.
			xlog.Error("WebSocket Responses: failed generating grammar", "error", err)
		}
	}

	// Merge contiguous assistant messages
	openAIReq.Messages = mergeContiguousAssistantMessages(openAIReq.Messages)

	predInput := evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn)

	// Use the background stream infrastructure: store the request as a background task,
	// process it via handleBackgroundStream, and forward buffered events over WebSocket.
	queuedResponse := buildORResponse(responseID, createdAt, nil, schema.ORStatusQueued, input, []schema.ORItemField{}, nil, shouldStore)
	store.StoreBackground(responseID, input, queuedResponse, reqCancel, true)

	// Start processing in a goroutine; processDone is closed when it exits.
	processDone := make(chan struct{})
	go func() {
		defer close(processDone)
		store.UpdateStatus(responseID, schema.ORStatusInProgress, nil)

		finalResponse, bgErr := handleBackgroundStream(reqCtx, store, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, nil, nil)
		if bgErr != nil {
			xlog.Error("WebSocket Responses: processing failed", "response_id", responseID, "error", bgErr)
			now := time.Now().Unix()
			store.UpdateStatus(responseID, schema.ORStatusFailed, &now)

			// Buffer an error event so the client sees the failure
			failedResponse := buildORResponse(responseID, createdAt, &now, schema.ORStatusFailed, input, []schema.ORItemField{}, nil, shouldStore)
			bufferEvent(store, responseID, &schema.ORStreamEvent{
				Type:     "response.failed",
				Response: failedResponse,
				Error: &schema.ORErrorPayload{
					Type:    "server_error",
					Message: bgErr.Error(),
				},
			})
			return
		}

		if finalResponse != nil {
			store.UpdateResponse(responseID, finalResponse)
		}
	}()

	// Forward events from the store to the WebSocket connection
	forwardEvents(reqCtx, conn, store, responseID, processDone, shouldStore)

	// forwardEvents can return before processing has finished (write failure
	// or connection shutdown). Cancel the generation and wait for the worker
	// to exit so it never outlives this call; otherwise the caller's in-flight
	// guard would be released while the previous response is still running.
	reqCancel()
	<-processDone
}
// forwardEvents subscribes to the buffered events of responseID and relays them
// over the WebSocket. This mirrors handleStreamResume but writes JSON messages
// to the WebSocket instead of SSE.
//
// It stops when any of the following happens:
//   - processing finishes (done is closed); anything still buffered is flushed
//     first;
//   - ctx is cancelled;
//   - a read from the store or a write to the client fails.
//
// When shouldStore is false, the response is removed from the store once
// processing has finished. This happens on every exit path, not only the
// normal one, so aborted non-stored responses do not linger in the store.
func forwardEvents(ctx context.Context, conn *lockedConn, store *ResponseStore, responseID string, done <-chan struct{}, shouldStore bool) {
	if !shouldStore {
		defer func() {
			select {
			case <-done:
				store.Delete(responseID)
			default:
				// Processing is still winding down (the caller cancels it once
				// we return). Delete only after it stops writing to the store.
				go func() {
					<-done
					store.Delete(responseID)
				}()
			}
		}()
	}

	eventsChan, err := store.GetEventsChan(responseID)
	if err != nil {
		return
	}

	lastSeq := -1
	for {
		// Drain all available events
		lastSeq, err = forwardWSEventsAfter(conn, store, responseID, lastSeq)
		if err != nil {
			return
		}

		// Check if processing is done and all events have been sent
		select {
		case <-done:
			// Best-effort final flush; any error here ends forwarding anyway.
			_, _ = forwardWSEventsAfter(conn, store, responseID, lastSeq)
			return
		default:
		}

		// Wait for new events, completion, or context cancellation
		select {
		case <-ctx.Done():
			return
		case <-done:
			// Will drain in next iteration
		case <-eventsChan:
			// New events available
		}
	}
}

// forwardWSEventsAfter writes every event buffered for responseID after
// afterSeq to conn, and returns the sequence number of the last event consumed.
// Events whose payload cannot be decoded are skipped but still count as
// consumed, so they are not re-read on every wake-up.
func forwardWSEventsAfter(conn *lockedConn, store *ResponseStore, responseID string, afterSeq int) (int, error) {
	events, err := store.GetEventsAfter(responseID, afterSeq)
	if err != nil {
		return afterSeq, err
	}
	for _, event := range events {
		var parsed schema.ORStreamEvent
		if err := json.Unmarshal(event.Data, &parsed); err != nil {
			afterSeq = event.SequenceNumber
			continue
		}
		if err := conn.writeJSON(&parsed); err != nil {
			return afterSeq, err
		}
		afterSeq = event.SequenceNumber
	}
	return afterSeq, nil
}
// sendWSError writes a best-effort "error" event to the client.
//
// errType becomes the error's type and message its human-readable text. param,
// which may be empty, optionally names the request field the error refers to.
// A failed write is logged at debug level and otherwise ignored, because
// callers treat the notification as best-effort.
func sendWSError(conn *lockedConn, errType, message, param string) {
	event := schema.ORStreamEvent{
		Type: "error",
		Error: &schema.ORErrorPayload{
			Type:    errType,
			Message: message,
			Param:   param,
		},
	}
	if err := conn.writeJSON(&event); err != nil {
		xlog.Debug("WebSocket Responses: failed to send error event", "type", errType, "error", err)
	}
}
// sendWSErrorEvent writes a best-effort "error" event of type
// "invalid_request_error" carrying a machine-readable code.
//
// param, which may be empty, optionally names the request field the error
// refers to. A failed write is logged at debug level and otherwise ignored.
func sendWSErrorEvent(conn *lockedConn, code, message, param string) {
	event := schema.ORStreamEvent{
		Type: "error",
		Error: &schema.ORErrorPayload{
			Type:    "invalid_request_error",
			Code:    code,
			Message: message,
			Param:   param,
		},
	}
	if err := conn.writeJSON(&event); err != nil {
		xlog.Debug("WebSocket Responses: failed to send error event", "code", code, "error", err)
	}
}