diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go
index 383212cff..6746e22d3 100644
--- a/core/http/endpoints/openai/realtime.go
+++ b/core/http/endpoints/openai/realtime.go
@@ -32,13 +32,26 @@ import (
 )
 
 const (
+    // XXX: Presently it seems all ASR/VAD backends use 16Khz. If a backend uses 24Khz then it will likely still work, but have reduced performance
     localSampleRate = 16000
-    remoteSampleRate = 24000
+    defaultRemoteSampleRate = 24000
 )
 
 // A model can be "emulated" that is: transcribe audio to text -> feed text to the LLM -> generate audio as result
 // If the model support instead audio-to-audio, we will use the specific gRPC calls instead
 
+// LockedWebsocket wraps a websocket connection with a mutex for safe concurrent writes
+type LockedWebsocket struct {
+    *websocket.Conn
+    sync.Mutex
+}
+
+func (l *LockedWebsocket) WriteMessage(messageType int, data []byte) error {
+    l.Lock()
+    defer l.Unlock()
+    return l.Conn.WriteMessage(messageType, data)
+}
+
 // Session represents a single WebSocket connection and its state
 type Session struct {
     ID                string
@@ -58,7 +71,8 @@ type Session struct {
     DefaultConversationID string
     ModelInterface        Model
     // The pipeline model config or the config for an any-to-any model
-    ModelConfig *config.ModelConfig
+    ModelConfig     *config.ModelConfig
+    InputSampleRate int
 }
 
 func (s *Session) FromClient(session *types.SessionUnion) {
@@ -162,7 +176,8 @@ func Realtime(application *application.Application) echo.HandlerFunc {
 }
 
 func registerRealtime(application *application.Application, model string) func(c *websocket.Conn) {
-    return func(c *websocket.Conn) {
+    return func(conn *websocket.Conn) {
+        c := &LockedWebsocket{Conn: conn}
         evaluator := application.TemplatesEvaluator()
 
         xlog.Debug("Realtime WebSocket connection established", "address", c.RemoteAddr().String(), "model", model)
@@ -202,7 +217,8 @@ func registerRealtime(application *application.Application, model string) func(c
             InputAudioTranscription: &types.AudioTranscription{
                 Model: sttModel,
             },
-            Conversations: make(map[string]*Conversation),
+            Conversations:   make(map[string]*Conversation),
+            InputSampleRate: defaultRemoteSampleRate,
         }
 
         // Create a default conversation
@@ -455,7 +471,7 @@ func registerRealtime(application *application.Application, model string) func(c
 }
 
 // Helper function to send events to the client
-func sendEvent(c *websocket.Conn, event types.ServerEvent) {
+func sendEvent(c *LockedWebsocket, event types.ServerEvent) {
     eventBytes, err := json.Marshal(event)
     if err != nil {
         xlog.Error("failed to marshal event", "error", err)
@@ -467,7 +483,7 @@ func sendEvent(c *websocket.Conn, event types.ServerEvent) {
 }
 
 // Helper function to send errors to the client
-func sendError(c *websocket.Conn, code, message, param, eventID string) {
+func sendError(c *LockedWebsocket, code, message, param, eventID string) {
     errorEvent := types.ErrorEvent{
         ServerEventBase: types.ServerEventBase{
             EventID: eventID,
@@ -484,7 +500,7 @@ func sendError(c *websocket.Conn, code, message, param, eventID string) {
     sendEvent(c, errorEvent)
 }
 
-func sendNotImplemented(c *websocket.Conn, message string) {
+func sendNotImplemented(c *LockedWebsocket, message string) {
     sendError(c, "not_implemented", message, "", "event_TODO")
 }
 
@@ -529,6 +545,12 @@ func updateTransSession(session *Session, update *types.SessionUnion, cl *config
         session.TurnDetection = update.Transcription.Audio.Input.TurnDetection
     }
 
+    if update.Transcription.Audio.Input.Format != nil && update.Transcription.Audio.Input.Format.PCM != nil {
+        if update.Transcription.Audio.Input.Format.PCM.Rate > 0 {
+            session.InputSampleRate = update.Transcription.Audio.Input.Format.PCM.Rate
+        }
+    }
+
     return nil
 }
 
@@ -582,6 +604,12 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
         session.TurnDetection = rt.Audio.Input.TurnDetection
     }
 
+    if rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.Format != nil && rt.Audio.Input.Format.PCM != nil {
+        if rt.Audio.Input.Format.PCM.Rate > 0 {
+            session.InputSampleRate = rt.Audio.Input.Format.PCM.Rate
+        }
+    }
+
     if rt.Instructions != "" {
         session.Instructions = rt.Instructions
     }
@@ -598,7 +626,7 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
 
 // handleVAD is a goroutine that listens for audio data from the client,
 // runs VAD on the audio data, and commits utterances to the conversation
-func handleVAD(session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
+func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done chan struct{}) {
     vadContext, cancel := context.WithCancel(context.Background())
     go func() {
         <-done
@@ -627,12 +655,12 @@ func handleVAD(session *Session, conv *Conversation, c *websocket.Conn, done cha
             session.AudioBufferLock.Unlock()
 
             aints := sound.BytesToInt16sLE(allAudio)
-            if len(aints) == 0 || len(aints) < int(silenceThreshold)*remoteSampleRate {
+            if len(aints) == 0 || len(aints) < int(silenceThreshold)*session.InputSampleRate {
                 continue
             }
 
-            // Resample from 24kHz to 16kHz
-            aints = sound.ResampleInt16(aints, remoteSampleRate, localSampleRate)
+            // Resample from InputSampleRate to 16kHz
+            aints = sound.ResampleInt16(aints, session.InputSampleRate, localSampleRate)
 
             segments, err := runVAD(vadContext, session, aints)
             if err != nil {
@@ -712,7 +740,7 @@ func handleVAD(session *Session, conv *Conversation, c *websocket.Conn, done cha
     }
 }
 
-func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, c *websocket.Conn) {
+func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, c *LockedWebsocket) {
     if len(utt) == 0 {
         return
     }
@@ -794,7 +822,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]schema.VADS
 }
 
 // Function to generate a response based on the conversation
-func generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *websocket.Conn, mt int) {
+func generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *LockedWebsocket, mt int) {
     xlog.Debug("Generating realtime response...")
 
     config := session.ModelInterface.PredictConfig()
diff --git a/core/http/views/model-editor.html b/core/http/views/model-editor.html
index d01a91ae7..6b2133e87 100644
--- a/core/http/views/model-editor.html
+++ b/core/http/views/model-editor.html
@@ -1030,7 +1030,7 @@ parameters:
                 if (!isPipeline && !config.backend) {
                     throw new Error('Backend is required');
                 }
-                if (!config.parameters || !config.parameters.model) {
+                if (!isPipeline && (!config.parameters || !config.parameters.model)) {
                     throw new Error('Model file/path is required in parameters.model');
                 }
 
@@ -1056,10 +1056,9 @@ parameters:
 
                 if (!isPipeline && !config.backend) {
                     throw new Error('Backend is required');
                 }
-                if (!config.parameters || !config.parameters.model) {
+                if (!isPipeline && (!config.parameters || !config.parameters.model)) {
                     throw new Error('Model file/path is required in parameters.model');
                 }
-
                 const endpoint = this.isEditMode ? `/models/edit/{{.ModelName}}` : '/models/import';
                 const response = await fetch(endpoint, {
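
For context on the `LockedWebsocket` wrapper introduced above: most Go websocket libraries allow only one concurrent writer per connection, and the realtime handler now writes events from several goroutines (the VAD loop, response generation, error reporting), so writes must be serialized. Below is a minimal, self-contained sketch of the same embed-a-mutex pattern; `fakeConn` is a stand-in for the real `*websocket.Conn` so the snippet runs without the project's websocket dependency.

```go
package main

import (
	"fmt"
	"sync"
)

// fakeConn stands in for a websocket connection whose WriteMessage
// must never be called from two goroutines at once.
type fakeConn struct{}

func (c *fakeConn) WriteMessage(messageType int, data []byte) error {
	fmt.Printf("write type=%d len=%d\n", messageType, len(data))
	return nil
}

// lockedConn embeds the connection and a mutex; WriteMessage on the
// wrapper shadows the embedded method and serializes all writes.
type lockedConn struct {
	*fakeConn
	sync.Mutex
}

func (l *lockedConn) WriteMessage(messageType int, data []byte) error {
	l.Lock()
	defer l.Unlock()
	return l.fakeConn.WriteMessage(messageType, data)
}

func main() {
	c := &lockedConn{fakeConn: &fakeConn{}}

	// Several goroutines (e.g. VAD events and response deltas) can now
	// write through the same wrapper without racing on the connection.
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			_ = c.WriteMessage(1, []byte(fmt.Sprintf("event %d", n)))
		}(i)
	}
	wg.Wait()
}
```

Because the wrapper embeds the connection, read paths and other methods pass straight through; only call sites that write need the parameter type changed from `*websocket.Conn` to `*LockedWebsocket`, which is what the signature changes in the diff do.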
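The other half of the realtime.go change lets the client negotiate the PCM input rate via the session format (stored in `session.InputSampleRate`, defaulting to 24 kHz) and resamples to the 16 kHz the ASR/VAD backends expect. The helper below is not the project's `sound.ResampleInt16`; it is a rough linear-interpolation sketch, with hypothetical rates, that only illustrates the shape of that conversion.

```go
package main

import "fmt"

// resampleInt16 converts mono 16-bit PCM from inRate to outRate using
// simple linear interpolation. Illustrative only; a production resampler
// would apply proper filtering to avoid aliasing.
func resampleInt16(in []int16, inRate, outRate int) []int16 {
	if inRate == outRate || len(in) == 0 {
		return in
	}
	outLen := int(int64(len(in)) * int64(outRate) / int64(inRate))
	out := make([]int16, outLen)
	for i := range out {
		// Position of this output sample in the input signal.
		pos := float64(i) * float64(inRate) / float64(outRate)
		j := int(pos)
		if j >= len(in)-1 {
			out[i] = in[len(in)-1]
			continue
		}
		frac := pos - float64(j)
		out[i] = int16(float64(in[j])*(1-frac) + float64(in[j+1])*frac)
	}
	return out
}

func main() {
	// A client that negotiated 24 kHz input, downsampled to the 16 kHz
	// expected by VAD/transcription.
	inputSampleRate, localSampleRate := 24000, 16000
	pcm := make([]int16, inputSampleRate) // one second of silence
	down := resampleInt16(pcm, inputSampleRate, localSampleRate)
	fmt.Println(len(pcm), "->", len(down)) // 24000 -> 16000
}
```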