diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index 6746e22d3..b36c16031 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -398,7 +398,36 @@ func registerRealtime(application *application.Application, model string) func(c case types.ConversationItemCreateEvent: xlog.Debug("recv", "message", string(msg)) - sendNotImplemented(c, "conversation.item.create") + // Add the item to the conversation + item := e.Item + // Ensure IDs are present + if item.User != nil && item.User.ID == "" { + item.User.ID = generateItemID() + } + if item.Assistant != nil && item.Assistant.ID == "" { + item.Assistant.ID = generateItemID() + } + if item.System != nil && item.System.ID == "" { + item.System.ID = generateItemID() + } + if item.FunctionCall != nil && item.FunctionCall.ID == "" { + item.FunctionCall.ID = generateItemID() + } + if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" { + item.FunctionCallOutput.ID = generateItemID() + } + + conversation.Lock.Lock() + conversation.Items = append(conversation.Items, &item) + conversation.Lock.Unlock() + + sendEvent(c, types.ConversationItemAddedEvent{ + ServerEventBase: types.ServerEventBase{ + EventID: e.EventID, + }, + PreviousItemID: e.PreviousItemID, + Item: item, + }) case types.ConversationItemDeleteEvent: sendError(c, "not_implemented", "Deleting items not implemented", "", "event_TODO") @@ -444,7 +473,34 @@ func registerRealtime(application *application.Application, model string) func(c case types.ResponseCreateEvent: xlog.Debug("recv", "message", string(msg)) - sendNotImplemented(c, "response.create") + + // Handle optional items to add to context + if len(e.Response.Input) > 0 { + conversation.Lock.Lock() + for _, item := range e.Response.Input { + // Ensure IDs are present + if item.User != nil && item.User.ID == "" { + item.User.ID = generateItemID() + } + if item.Assistant != nil && item.Assistant.ID == "" { + item.Assistant.ID = generateItemID() + } + if item.System != nil && item.System.ID == "" { + item.System.ID = generateItemID() + } + if item.FunctionCall != nil && item.FunctionCall.ID == "" { + item.FunctionCall.ID = generateItemID() + } + if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" { + item.FunctionCallOutput.ID = generateItemID() + } + + conversation.Items = append(conversation.Items, &item) + } + conversation.Lock.Unlock() + } + + go triggerResponse(session, conversation, c, &e.Response) case types.ResponseCancelEvent: xlog.Debug("recv", "message", string(msg)) @@ -825,8 +881,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]schema.VADS func generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *LockedWebsocket, mt int) { xlog.Debug("Generating realtime response...") - config := session.ModelInterface.PredictConfig() - + // Create user message item item := types.MessageItemUnion{ User: &types.MessageItemUser{ ID: generateItemID(), @@ -848,33 +903,73 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con Item: item, }) + triggerResponse(session, conv, c, nil) +} + +func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, overrides *types.ResponseCreateParams) { + config := session.ModelInterface.PredictConfig() + + // Default values + tools := session.Tools + toolChoice := session.ToolChoice + instructions := session.Instructions + // Overrides + if overrides != nil { + if overrides.Tools != nil { + tools = overrides.Tools + } + if overrides.ToolChoice != nil { + toolChoice = overrides.ToolChoice + } + if overrides.Instructions != "" { + instructions = overrides.Instructions + } + } + var conversationHistory schema.Messages conversationHistory = append(conversationHistory, schema.Message{ Role: string(types.MessageRoleSystem), - StringContent: session.Instructions, - Content: session.Instructions, + StringContent: instructions, + Content: instructions, }) + imgIndex := 0 conv.Lock.Lock() for _, item := range conv.Items { if item.User != nil { + msg := schema.Message{ + Role: string(types.MessageRoleUser), + } + textContent := "" + nrOfImgsInMessage := 0 for _, content := range item.User.Content { switch content.Type { case types.MessageContentTypeInputText: - conversationHistory = append(conversationHistory, schema.Message{ - Role: string(types.MessageRoleUser), - StringContent: content.Text, - Content: content.Text, - }) + textContent += content.Text case types.MessageContentTypeInputAudio: - conversationHistory = append(conversationHistory, schema.Message{ - Role: string(types.MessageRoleUser), - StringContent: content.Transcript, - Content: content.Transcript, - StringAudios: []string{content.Audio}, - }) + textContent += content.Transcript + case types.MessageContentTypeInputImage: + msg.StringImages = append(msg.StringImages, content.ImageURL) + imgIndex++ + nrOfImgsInMessage++ } } + if nrOfImgsInMessage > 0 { + templated, err := templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ + TotalImages: imgIndex, + ImagesInMessage: nrOfImgsInMessage, + }, textContent) + if err != nil { + xlog.Warn("Failed to apply multimodal template", "error", err) + templated = textContent + } + msg.StringContent = templated + msg.Content = templated + } else { + msg.StringContent = textContent + msg.Content = textContent + } + conversationHistory = append(conversationHistory, msg) } else if item.Assistant != nil { for _, content := range item.Assistant.Content { switch content.Type { @@ -905,6 +1000,11 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con } conv.Lock.Unlock() + var images []string + for _, m := range conversationHistory { + images = append(images, m.StringImages...) + } + responseID := generateUniqueID() sendEvent(c, types.ResponseCreatedEvent{ ServerEventBase: types.ServerEventBase{}, @@ -915,15 +1015,15 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con }, }) - predFunc, err := session.ModelInterface.Predict(context.TODO(), conversationHistory, nil, nil, nil, nil, session.Tools, session.ToolChoice, nil, nil, nil) + predFunc, err := session.ModelInterface.Predict(context.TODO(), conversationHistory, images, nil, nil, nil, tools, toolChoice, nil, nil, nil) if err != nil { - sendError(c, "inference_failed", fmt.Sprintf("backend error: %v", err), "", item.Assistant.ID) + sendError(c, "inference_failed", fmt.Sprintf("backend error: %v", err), "", "") // item.Assistant.ID is unknown here return } pred, err := predFunc() if err != nil { - sendError(c, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", item.Assistant.ID) + sendError(c, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", "") return } @@ -1171,7 +1271,6 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con Status: types.ResponseStatusCompleted, }, }) - } // Helper functions to generate unique IDs diff --git a/core/http/endpoints/openai/types/message_item.go b/core/http/endpoints/openai/types/message_item.go index 2b4f0c95f..52997fe8c 100644 --- a/core/http/endpoints/openai/types/message_item.go +++ b/core/http/endpoints/openai/types/message_item.go @@ -26,6 +26,7 @@ const ( MessageContentTypeTranscript MessageContentType = "transcript" MessageContentTypeInputText MessageContentType = "input_text" MessageContentTypeInputAudio MessageContentType = "input_audio" + MessageContentTypeInputImage MessageContentType = "input_image" MessageContentTypeOutputText MessageContentType = "output_text" MessageContentTypeOutputAudio MessageContentType = "output_audio" )