fix(realtime): Better support for thinking models and setting model parameters (#8595)
* fix(realtime): Wrap functions in OpenAI chat completions format
* feat(realtime): Set max tokens from session object
* fix(realtime): Find thinking start tag for thinking extraction
* fix(realtime): Don't send buffer cleared message when we automatically drop it

Signed-off-by: Richard Palethorpe <io@richiejp.com>
Committed by GitHub
parent 2fabdc08e6
commit 86b3bc9313
@@ -23,18 +23,18 @@ import (
 	"github.com/mudler/LocalAI/core/templates"
 	laudio "github.com/mudler/LocalAI/pkg/audio"
 	"github.com/mudler/LocalAI/pkg/functions"
-	"github.com/mudler/LocalAI/pkg/utils"
 	"github.com/mudler/LocalAI/pkg/grpc/proto"
 	model "github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/reasoning"
 	"github.com/mudler/LocalAI/pkg/sound"
+	"github.com/mudler/LocalAI/pkg/utils"

 	"github.com/mudler/xlog"
 )

 const (
 	// XXX: Presently it seems all ASR/VAD backends use 16Khz. If a backend uses 24Khz then it will likely still work, but have reduced performance
 	localSampleRate         = 16000
 	defaultRemoteSampleRate = 24000
 )
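The 16 kHz / 24 kHz split in the constants above means audio arriving at the remote default rate may need converting before it reaches the local ASR/VAD path; LocalAI's pkg/sound presumably handles that. Purely as an illustration of what such a conversion involves (this is not the project's implementation, and resampleLinear is a hypothetical name), a naive linear-interpolation resampler for 16-bit PCM could look like:

```go
package main

import "fmt"

// resampleLinear converts 16-bit PCM samples from srcRate to dstRate using
// linear interpolation. Illustrative only; a production resampler would also
// low-pass filter before decimating to avoid aliasing.
func resampleLinear(in []int16, srcRate, dstRate int) []int16 {
	if srcRate == dstRate || len(in) == 0 {
		return in
	}
	outLen := len(in) * dstRate / srcRate
	out := make([]int16, outLen)
	for i := range out {
		// Position of this output sample on the input timeline.
		pos := float64(i) * float64(srcRate) / float64(dstRate)
		j := int(pos)
		frac := pos - float64(j)
		if j+1 >= len(in) {
			out[i] = in[len(in)-1]
			continue
		}
		out[i] = int16(float64(in[j])*(1-frac) + float64(in[j+1])*frac)
	}
	return out
}

func main() {
	// Downsample a short 24 kHz buffer to the 16 kHz rate used by ASR/VAD.
	in := make([]int16, 240) // 10 ms at 24 kHz
	out := resampleLinear(in, 24000, 16000)
	fmt.Println(len(out)) // 160 samples, i.e. 10 ms at 16 kHz
}
```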
@@ -74,6 +74,7 @@ type Session struct {
 	// The pipeline model config or the config for an any-to-any model
 	ModelConfig     *config.ModelConfig
 	InputSampleRate int
+	MaxOutputTokens types.IntOrInf
 }

 func (s *Session) FromClient(session *types.SessionUnion) {
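MaxOutputTokens is a types.IntOrInf, and later code calls IsInf() on it, which suggests a value that is either a plain integer or the JSON string "inf" (no cap), as in the OpenAI API. The real type lives in LocalAI's types package and may differ; the sketch below is only a guess at the general shape of such a type:

```go
package main

import (
	"encoding/json"
	"fmt"
	"math"
)

// intOrInf is a guess at the shape of types.IntOrInf: a token limit that is
// either a plain integer or the JSON string "inf" meaning unlimited.
type intOrInf int

const intOrInfMax = math.MaxInt32

func (i intOrInf) IsInf() bool { return i >= intOrInfMax }

func (i *intOrInf) UnmarshalJSON(b []byte) error {
	var s string
	if err := json.Unmarshal(b, &s); err == nil {
		if s == "inf" {
			*i = intOrInfMax
			return nil
		}
		return fmt.Errorf("unexpected string value %q", s)
	}
	var n int
	if err := json.Unmarshal(b, &n); err != nil {
		return err
	}
	*i = intOrInf(n)
	return nil
}

func main() {
	var a, b intOrInf
	_ = json.Unmarshal([]byte(`512`), &a)
	_ = json.Unmarshal([]byte(`"inf"`), &b)
	fmt.Println(a, a.IsInf()) // 512 false
	fmt.Println(b.IsInf())    // true
}
```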
@@ -95,12 +96,13 @@ func (s *Session) ToServer() types.SessionUnion {
 	} else {
 		return types.SessionUnion{
 			Realtime: &types.RealtimeSession{
-				ID:           s.ID,
-				Object:       "realtime.session",
-				Model:        s.Model,
-				Instructions: s.Instructions,
-				Tools:        s.Tools,
-				ToolChoice:   s.ToolChoice,
+				ID:              s.ID,
+				Object:          "realtime.session",
+				Model:           s.Model,
+				Instructions:    s.Instructions,
+				Tools:           s.Tools,
+				ToolChoice:      s.ToolChoice,
+				MaxOutputTokens: s.MaxOutputTokens,
 				Audio: &types.RealtimeSessionAudio{
 					Input: &types.SessionAudioInput{
 						TurnDetection: s.TurnDetection,
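With the field wired into ToServer, the session object echoed back to the client in session.created / session.updated events should now carry the limit. Roughly like the payload below; the field names follow the OpenAI realtime session shape and are illustrative rather than read from LocalAI's actual JSON tags:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Illustrative session payload a client might receive; the exact field set
// and names depend on LocalAI's JSON tags.
const sessionJSON = `{
	"id": "sess_001",
	"object": "realtime.session",
	"model": "example-model",
	"instructions": "You are a helpful assistant.",
	"tool_choice": "auto",
	"max_output_tokens": 512
}`

func main() {
	var s map[string]any
	if err := json.Unmarshal([]byte(sessionJSON), &s); err != nil {
		panic(err)
	}
	fmt.Println(s["max_output_tokens"]) // 512
}
```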
@@ -678,6 +680,10 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
 		session.ToolChoice = rt.ToolChoice
 	}

+	if rt.MaxOutputTokens != 0 {
+		session.MaxOutputTokens = rt.MaxOutputTokens
+	}
+
 	return nil
 }
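On the wire, updateSession now picks the limit up from a client session.update event. A minimal update event could look like the sketch below; the event shape follows the OpenAI realtime API, and whether LocalAI nests the field under a "session" key exactly like this is an assumption:

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Hypothetical client event asking the server to cap responses at 256 tokens.
	update := map[string]any{
		"type": "session.update",
		"session": map[string]any{
			"instructions":      "Answer briefly.",
			"max_output_tokens": 256,
		},
	}
	b, _ := json.Marshal(update)
	fmt.Println(string(b))
	// A client could also send "inf" to lift the cap, which is why the
	// server-side value needs an IsInf() check before it is applied.
}
```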
@@ -733,18 +739,18 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch
 		audioLength := float64(len(aints)) / localSampleRate

 		// TODO: When resetting the buffer we should retain a small postfix
 		// TODO: The OpenAI documentation seems to suggest that only the client decides when to clear the buffer
 		if len(segments) == 0 && audioLength > silenceThreshold {
 			session.AudioBufferLock.Lock()
 			session.InputAudioBuffer = nil
 			session.AudioBufferLock.Unlock()
-			xlog.Debug("Detected silence for a while, clearing audio buffer")
-
-			sendEvent(c, types.InputAudioBufferClearedEvent{
-				ServerEventBase: types.ServerEventBase{
-					EventID: "event_TODO",
-				},
-			})
+			// NOTE: OpenAI doesn't send this message unless the client requests it
+			// xlog.Debug("Detected silence for a while, clearing audio buffer")
+			// sendEvent(c, types.InputAudioBufferClearedEvent{
+			// 	ServerEventBase: types.ServerEventBase{
+			// 		EventID: "event_TODO",
+			// 	},
+			// })

 			continue
 		} else if len(segments) == 0 {
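After this change the server silently drops a silent buffer instead of emitting input_audio_buffer.cleared on its own; the cleared event is reserved for explicit client requests. In the OpenAI realtime protocol that exchange looks roughly like this (event names from the OpenAI docs; the event_id value is illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Client explicitly asks for the input buffer to be dropped...
	clear := map[string]any{"type": "input_audio_buffer.clear"}
	req, _ := json.Marshal(clear)
	fmt.Println(string(req)) // {"type":"input_audio_buffer.clear"}

	// ...and only then does the server confirm with a cleared event.
	cleared := map[string]any{
		"type":     "input_audio_buffer.cleared",
		"event_id": "event_123",
	}
	resp, _ := json.Marshal(cleared)
	fmt.Println(string(resp))
}
```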
@@ -914,6 +920,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o
 	tools := session.Tools
 	toolChoice := session.ToolChoice
 	instructions := session.Instructions
+	maxOutputTokens := session.MaxOutputTokens
 	// Overrides
 	if overrides != nil {
 		if overrides.Tools != nil {
@@ -925,8 +932,29 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o
 		if overrides.Instructions != "" {
 			instructions = overrides.Instructions
 		}
+		if overrides.MaxOutputTokens != 0 {
+			maxOutputTokens = overrides.MaxOutputTokens
+		}
 	}

+	// Apply MaxOutputTokens to model config if specified
+	// Save original value to restore after prediction
+	var originalMaxTokens *int
+	if config != nil {
+		originalMaxTokens = config.Maxtokens
+		if maxOutputTokens != 0 && !maxOutputTokens.IsInf() {
+			tokenValue := int(maxOutputTokens)
+			config.Maxtokens = &tokenValue
+			xlog.Debug("Applied max_output_tokens to config", "value", tokenValue)
+		}
+	}
+	// Defer restoration of original value
+	defer func() {
+		if config != nil {
+			config.Maxtokens = originalMaxTokens
+		}
+	}()
+
 	var conversationHistory schema.Messages
 	conversationHistory = append(conversationHistory, schema.Message{
 		Role: string(types.MessageRoleSystem),
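The apply-then-restore block above is the usual save/defer/restore pattern on a shared config pointer: the per-response max_output_tokens override holds for one prediction and is rolled back afterwards, so it does not leak into later responses in the same session. A stripped-down version of the pattern, with hypothetical names:

```go
package main

import "fmt"

// Config stands in for the model config; Maxtokens is a pointer so "unset"
// and "set to N" can be told apart, mirroring the code above.
type Config struct {
	Maxtokens *int
}

// withMaxTokens temporarily overrides cfg.Maxtokens while fn runs, then
// restores the original pointer even if fn panics.
func withMaxTokens(cfg *Config, n int, fn func()) {
	original := cfg.Maxtokens
	cfg.Maxtokens = &n
	defer func() { cfg.Maxtokens = original }()
	fn()
}

func main() {
	cfg := &Config{}
	withMaxTokens(cfg, 256, func() {
		fmt.Println(*cfg.Maxtokens) // 256 during the prediction
	})
	fmt.Println(cfg.Maxtokens) // <nil> again afterwards
}
```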
@@ -1034,13 +1062,34 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o
 	}

 	xlog.Debug("Function config for parsing", "function_name_key", config.FunctionsConfig.FunctionNameKey, "function_arguments_key", config.FunctionsConfig.FunctionArgumentsKey)
 	xlog.Debug("LLM raw response", "text", pred.Response, "response_length", len(pred.Response), "usage", pred.Usage)

+	// Safely dereference pointer fields for logging
+	maxTokens := "nil"
+	if config.Maxtokens != nil {
+		maxTokens = fmt.Sprintf("%d", *config.Maxtokens)
+	}
+	contextSize := "nil"
+	if config.ContextSize != nil {
+		contextSize = fmt.Sprintf("%d", *config.ContextSize)
+	}
+	xlog.Debug("Model parameters", "max_tokens", maxTokens, "context_size", contextSize, "stopwords", config.StopWords)
+
 	rawResponse := pred.Response
 	if config.TemplateConfig.ReplyPrefix != "" {
 		rawResponse = config.TemplateConfig.ReplyPrefix + rawResponse
 	}

-	reasoningText, responseWithoutReasoning := reasoning.ExtractReasoningWithConfig(rawResponse, "", config.ReasoningConfig)
+	// Detect thinking start token from template for reasoning extraction
+	var template string
+	if config.TemplateConfig.UseTokenizerTemplate {
+		template = config.GetModelTemplate()
+	} else {
+		template = config.TemplateConfig.Chat
+	}
+	thinkingStartToken := reasoning.DetectThinkingStartToken(template, &config.ReasoningConfig)
+
+	reasoningText, responseWithoutReasoning := reasoning.ExtractReasoningWithConfig(rawResponse, thinkingStartToken, config.ReasoningConfig)
 	xlog.Debug("LLM Response", "reasoning", reasoningText, "response_without_reasoning", responseWithoutReasoning)

 	textContent := functions.ParseTextContent(responseWithoutReasoning, config.FunctionsConfig)
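Previously ExtractReasoningWithConfig was called with an empty start tag; the new code first derives the thinking start token (for example something like <think>) from the chat template, so reasoning produced by thinking models can actually be separated from the reply that gets spoken. As a rough sketch of what splitting on such a tag pair involves, under the assumption of <think>/</think> style tags and without claiming this is the pkg/reasoning implementation:

```go
package main

import (
	"fmt"
	"strings"
)

// splitReasoning separates a "<think>...</think>" style block from the rest
// of the response. startTag may be empty if the model emits no opening tag
// before its reasoning, in which case everything up to endTag is reasoning.
func splitReasoning(s, startTag, endTag string) (reasoning, answer string) {
	end := strings.Index(s, endTag)
	if end < 0 {
		return "", s // no closing tag: treat the whole output as the answer
	}
	head := s[:end]
	if startTag != "" {
		if start := strings.Index(head, startTag); start >= 0 {
			head = head[start+len(startTag):]
		}
	}
	return strings.TrimSpace(head), strings.TrimSpace(s[end+len(endTag):])
}

func main() {
	out := "<think>The user wants a short answer.</think>It is 42."
	r, a := splitReasoning(out, "<think>", "</think>")
	fmt.Printf("reasoning=%q answer=%q\n", r, a)
	// reasoning="The user wants a short answer." answer="It is 42."
}
```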
@@ -194,7 +194,40 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im

 	var toolsJSON string
 	if len(tools) > 0 {
-		b, _ := json.Marshal(tools)
+		// Convert tools to OpenAI Chat Completions format (nested)
+		// as expected by most backends (including llama.cpp)
+		var chatTools []functions.Tool
+		for _, t := range tools {
+			if t.Function != nil {
+				var params map[string]interface{}
+				switch p := t.Function.Parameters.(type) {
+				case map[string]interface{}:
+					params = p
+				case string:
+					if err := json.Unmarshal([]byte(p), &params); err != nil {
+						xlog.Warn("Failed to parse parameters JSON string", "error", err, "function", t.Function.Name)
+					}
+				case nil:
+					params = map[string]interface{}{}
+				default:
+					// Try to marshal/unmarshal to get map
+					b, err := json.Marshal(p)
+					if err == nil {
+						_ = json.Unmarshal(b, &params)
+					}
+				}
+
+				chatTools = append(chatTools, functions.Tool{
+					Type: "function",
+					Function: functions.Function{
+						Name:        t.Function.Name,
+						Description: t.Function.Description,
+						Parameters:  params,
+					},
+				})
+			}
+		}
+		b, _ := json.Marshal(chatTools)
 		toolsJSON = string(b)
 	}
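The net effect of the loop above is that the flat tool entries from the realtime session are re-emitted in the nested OpenAI Chat Completions shape before being handed to the backend. Using local stand-in structs (the real ones are functions.Tool / functions.Function), the difference between the two JSON shapes is roughly:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// chatTool mirrors the nested Chat Completions tool shape produced above;
// these struct names are local stand-ins, not LocalAI's types.
type chatTool struct {
	Type     string       `json:"type"`
	Function chatFunction `json:"function"`
}

type chatFunction struct {
	Name        string         `json:"name"`
	Description string         `json:"description,omitempty"`
	Parameters  map[string]any `json:"parameters"`
}

func main() {
	// Flat, realtime-session style definition (name/parameters at top level).
	flat := `{"type":"function","name":"get_weather","description":"Look up the weather","parameters":{"type":"object","properties":{"city":{"type":"string"}}}}`

	var in struct {
		Name        string         `json:"name"`
		Description string         `json:"description"`
		Parameters  map[string]any `json:"parameters"`
	}
	if err := json.Unmarshal([]byte(flat), &in); err != nil {
		panic(err)
	}

	// Nested shape most chat-completions backends (e.g. llama.cpp) expect.
	nested := chatTool{Type: "function", Function: chatFunction{
		Name:        in.Name,
		Description: in.Description,
		Parameters:  in.Parameters,
	}}
	b, _ := json.MarshalIndent(nested, "", "  ")
	fmt.Println(string(b))
}
```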