mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-04 07:01:39 -04:00
feat(realtime): Allow sending text, image and audio conversation items" (#8524)
feat(realtime): Allow sending text and image conversation items Signed-off-by: Richard Palethorpe <io@richiejp.com> Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
4a4d65f8e8
commit
f6c80a6987
@@ -398,7 +398,36 @@ func registerRealtime(application *application.Application, model string) func(c
|
||||
|
||||
case types.ConversationItemCreateEvent:
|
||||
xlog.Debug("recv", "message", string(msg))
|
||||
sendNotImplemented(c, "conversation.item.create")
|
||||
// Add the item to the conversation
|
||||
item := e.Item
|
||||
// Ensure IDs are present
|
||||
if item.User != nil && item.User.ID == "" {
|
||||
item.User.ID = generateItemID()
|
||||
}
|
||||
if item.Assistant != nil && item.Assistant.ID == "" {
|
||||
item.Assistant.ID = generateItemID()
|
||||
}
|
||||
if item.System != nil && item.System.ID == "" {
|
||||
item.System.ID = generateItemID()
|
||||
}
|
||||
if item.FunctionCall != nil && item.FunctionCall.ID == "" {
|
||||
item.FunctionCall.ID = generateItemID()
|
||||
}
|
||||
if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" {
|
||||
item.FunctionCallOutput.ID = generateItemID()
|
||||
}
|
||||
|
||||
conversation.Lock.Lock()
|
||||
conversation.Items = append(conversation.Items, &item)
|
||||
conversation.Lock.Unlock()
|
||||
|
||||
sendEvent(c, types.ConversationItemAddedEvent{
|
||||
ServerEventBase: types.ServerEventBase{
|
||||
EventID: e.EventID,
|
||||
},
|
||||
PreviousItemID: e.PreviousItemID,
|
||||
Item: item,
|
||||
})
|
||||
|
||||
case types.ConversationItemDeleteEvent:
|
||||
sendError(c, "not_implemented", "Deleting items not implemented", "", "event_TODO")
|
||||
@@ -444,7 +473,34 @@ func registerRealtime(application *application.Application, model string) func(c
|
||||
|
||||
case types.ResponseCreateEvent:
|
||||
xlog.Debug("recv", "message", string(msg))
|
||||
sendNotImplemented(c, "response.create")
|
||||
|
||||
// Handle optional items to add to context
|
||||
if len(e.Response.Input) > 0 {
|
||||
conversation.Lock.Lock()
|
||||
for _, item := range e.Response.Input {
|
||||
// Ensure IDs are present
|
||||
if item.User != nil && item.User.ID == "" {
|
||||
item.User.ID = generateItemID()
|
||||
}
|
||||
if item.Assistant != nil && item.Assistant.ID == "" {
|
||||
item.Assistant.ID = generateItemID()
|
||||
}
|
||||
if item.System != nil && item.System.ID == "" {
|
||||
item.System.ID = generateItemID()
|
||||
}
|
||||
if item.FunctionCall != nil && item.FunctionCall.ID == "" {
|
||||
item.FunctionCall.ID = generateItemID()
|
||||
}
|
||||
if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" {
|
||||
item.FunctionCallOutput.ID = generateItemID()
|
||||
}
|
||||
|
||||
conversation.Items = append(conversation.Items, &item)
|
||||
}
|
||||
conversation.Lock.Unlock()
|
||||
}
|
||||
|
||||
go triggerResponse(session, conversation, c, &e.Response)
|
||||
|
||||
case types.ResponseCancelEvent:
|
||||
xlog.Debug("recv", "message", string(msg))
|
||||
@@ -825,8 +881,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]schema.VADS
|
||||
func generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *LockedWebsocket, mt int) {
|
||||
xlog.Debug("Generating realtime response...")
|
||||
|
||||
config := session.ModelInterface.PredictConfig()
|
||||
|
||||
// Create user message item
|
||||
item := types.MessageItemUnion{
|
||||
User: &types.MessageItemUser{
|
||||
ID: generateItemID(),
|
||||
@@ -848,33 +903,73 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con
|
||||
Item: item,
|
||||
})
|
||||
|
||||
triggerResponse(session, conv, c, nil)
|
||||
}
|
||||
|
||||
func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, overrides *types.ResponseCreateParams) {
|
||||
config := session.ModelInterface.PredictConfig()
|
||||
|
||||
// Default values
|
||||
tools := session.Tools
|
||||
toolChoice := session.ToolChoice
|
||||
instructions := session.Instructions
|
||||
// Overrides
|
||||
if overrides != nil {
|
||||
if overrides.Tools != nil {
|
||||
tools = overrides.Tools
|
||||
}
|
||||
if overrides.ToolChoice != nil {
|
||||
toolChoice = overrides.ToolChoice
|
||||
}
|
||||
if overrides.Instructions != "" {
|
||||
instructions = overrides.Instructions
|
||||
}
|
||||
}
|
||||
|
||||
var conversationHistory schema.Messages
|
||||
conversationHistory = append(conversationHistory, schema.Message{
|
||||
Role: string(types.MessageRoleSystem),
|
||||
StringContent: session.Instructions,
|
||||
Content: session.Instructions,
|
||||
StringContent: instructions,
|
||||
Content: instructions,
|
||||
})
|
||||
|
||||
imgIndex := 0
|
||||
conv.Lock.Lock()
|
||||
for _, item := range conv.Items {
|
||||
if item.User != nil {
|
||||
msg := schema.Message{
|
||||
Role: string(types.MessageRoleUser),
|
||||
}
|
||||
textContent := ""
|
||||
nrOfImgsInMessage := 0
|
||||
for _, content := range item.User.Content {
|
||||
switch content.Type {
|
||||
case types.MessageContentTypeInputText:
|
||||
conversationHistory = append(conversationHistory, schema.Message{
|
||||
Role: string(types.MessageRoleUser),
|
||||
StringContent: content.Text,
|
||||
Content: content.Text,
|
||||
})
|
||||
textContent += content.Text
|
||||
case types.MessageContentTypeInputAudio:
|
||||
conversationHistory = append(conversationHistory, schema.Message{
|
||||
Role: string(types.MessageRoleUser),
|
||||
StringContent: content.Transcript,
|
||||
Content: content.Transcript,
|
||||
StringAudios: []string{content.Audio},
|
||||
})
|
||||
textContent += content.Transcript
|
||||
case types.MessageContentTypeInputImage:
|
||||
msg.StringImages = append(msg.StringImages, content.ImageURL)
|
||||
imgIndex++
|
||||
nrOfImgsInMessage++
|
||||
}
|
||||
}
|
||||
if nrOfImgsInMessage > 0 {
|
||||
templated, err := templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
|
||||
TotalImages: imgIndex,
|
||||
ImagesInMessage: nrOfImgsInMessage,
|
||||
}, textContent)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to apply multimodal template", "error", err)
|
||||
templated = textContent
|
||||
}
|
||||
msg.StringContent = templated
|
||||
msg.Content = templated
|
||||
} else {
|
||||
msg.StringContent = textContent
|
||||
msg.Content = textContent
|
||||
}
|
||||
conversationHistory = append(conversationHistory, msg)
|
||||
} else if item.Assistant != nil {
|
||||
for _, content := range item.Assistant.Content {
|
||||
switch content.Type {
|
||||
@@ -905,6 +1000,11 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con
|
||||
}
|
||||
conv.Lock.Unlock()
|
||||
|
||||
var images []string
|
||||
for _, m := range conversationHistory {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
|
||||
responseID := generateUniqueID()
|
||||
sendEvent(c, types.ResponseCreatedEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
@@ -915,15 +1015,15 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con
|
||||
},
|
||||
})
|
||||
|
||||
predFunc, err := session.ModelInterface.Predict(context.TODO(), conversationHistory, nil, nil, nil, nil, session.Tools, session.ToolChoice, nil, nil, nil)
|
||||
predFunc, err := session.ModelInterface.Predict(context.TODO(), conversationHistory, images, nil, nil, nil, tools, toolChoice, nil, nil, nil)
|
||||
if err != nil {
|
||||
sendError(c, "inference_failed", fmt.Sprintf("backend error: %v", err), "", item.Assistant.ID)
|
||||
sendError(c, "inference_failed", fmt.Sprintf("backend error: %v", err), "", "") // item.Assistant.ID is unknown here
|
||||
return
|
||||
}
|
||||
|
||||
pred, err := predFunc()
|
||||
if err != nil {
|
||||
sendError(c, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", item.Assistant.ID)
|
||||
sendError(c, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", "")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1171,7 +1271,6 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con
|
||||
Status: types.ResponseStatusCompleted,
|
||||
},
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// Helper functions to generate unique IDs
|
||||
|
||||
@@ -26,6 +26,7 @@ const (
|
||||
MessageContentTypeTranscript MessageContentType = "transcript"
|
||||
MessageContentTypeInputText MessageContentType = "input_text"
|
||||
MessageContentTypeInputAudio MessageContentType = "input_audio"
|
||||
MessageContentTypeInputImage MessageContentType = "input_image"
|
||||
MessageContentTypeOutputText MessageContentType = "output_text"
|
||||
MessageContentTypeOutputAudio MessageContentType = "output_audio"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user