fix(realtime): Sampling and websocket locking (#8521)

* fix(realtime): Use locked websocket for concurrent access

Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(realtime): Use sample rate set in session

Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(config): Allow pipelines to have no model parameters

Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
Richard Palethorpe
2026-02-12 12:57:34 +00:00
committed by GitHub
parent cff972094c
commit 1479bee894
2 changed files with 43 additions and 16 deletions

View File

@@ -32,13 +32,26 @@ import (
)
const (
// XXX: Presently it seems all ASR/VAD backends use 16 kHz. If a backend uses 24 kHz it will likely still work, but with reduced performance
localSampleRate = 16000
remoteSampleRate = 24000
defaultRemoteSampleRate = 24000
)
// A model can be "emulated", that is: transcribe audio to text -> feed the text to the LLM -> generate audio as the result.
// If the model supports audio-to-audio instead, the specific gRPC calls are used.
// LockedWebsocket wraps a websocket connection with a mutex for safe concurrent writes.
//
// Both fields are embedded, so the methods of *websocket.Conn and of sync.Mutex
// are promoted onto this type. WriteMessage is redefined on LockedWebsocket to
// hold the mutex for the duration of the write.
// NOTE(review): any other write method promoted from the embedded Conn is not
// guarded by this wrapper as far as this view shows — confirm that all
// concurrent writers go through WriteMessage.
type LockedWebsocket struct {
// Conn is the underlying websocket connection.
*websocket.Conn
// Mutex is locked by WriteMessage around each write.
sync.Mutex
}
// WriteMessage sends one message on the wrapped connection while holding the
// embedded mutex, so concurrent callers are serialized. The message type, the
// payload and the resulting error are passed straight through to the
// underlying Conn.WriteMessage.
func (l *LockedWebsocket) WriteMessage(messageType int, data []byte) error {
	l.Mutex.Lock()
	defer l.Mutex.Unlock()

	writeErr := l.Conn.WriteMessage(messageType, data)
	return writeErr
}
// Session represents a single WebSocket connection and its state
type Session struct {
ID string
@@ -58,7 +71,8 @@ type Session struct {
DefaultConversationID string
ModelInterface Model
// The pipeline model config or the config for an any-to-any model
ModelConfig *config.ModelConfig
ModelConfig *config.ModelConfig
InputSampleRate int
}
func (s *Session) FromClient(session *types.SessionUnion) {
@@ -162,7 +176,8 @@ func Realtime(application *application.Application) echo.HandlerFunc {
}
func registerRealtime(application *application.Application, model string) func(c *websocket.Conn) {
return func(c *websocket.Conn) {
return func(conn *websocket.Conn) {
c := &LockedWebsocket{Conn: conn}
evaluator := application.TemplatesEvaluator()
xlog.Debug("Realtime WebSocket connection established", "address", c.RemoteAddr().String(), "model", model)
@@ -202,7 +217,8 @@ func registerRealtime(application *application.Application, model string) func(c
InputAudioTranscription: &types.AudioTranscription{
Model: sttModel,
},
Conversations: make(map[string]*Conversation),
Conversations: make(map[string]*Conversation),
InputSampleRate: defaultRemoteSampleRate,
}
// Create a default conversation
@@ -455,7 +471,7 @@ func registerRealtime(application *application.Application, model string) func(c
}
// Helper function to send events to the client
func sendEvent(c *websocket.Conn, event types.ServerEvent) {
func sendEvent(c *LockedWebsocket, event types.ServerEvent) {
eventBytes, err := json.Marshal(event)
if err != nil {
xlog.Error("failed to marshal event", "error", err)
@@ -467,7 +483,7 @@ func sendEvent(c *websocket.Conn, event types.ServerEvent) {
}
// Helper function to send errors to the client
func sendError(c *websocket.Conn, code, message, param, eventID string) {
func sendError(c *LockedWebsocket, code, message, param, eventID string) {
errorEvent := types.ErrorEvent{
ServerEventBase: types.ServerEventBase{
EventID: eventID,
@@ -484,7 +500,7 @@ func sendError(c *websocket.Conn, code, message, param, eventID string) {
sendEvent(c, errorEvent)
}
func sendNotImplemented(c *websocket.Conn, message string) {
// sendNotImplemented tells the client that a requested feature is not implemented.
// It delegates to sendError with the code "not_implemented", the given message,
// an empty param and the literal event ID "event_TODO".
func sendNotImplemented(c *LockedWebsocket, message string) {
sendError(c, "not_implemented", message, "", "event_TODO")
}
@@ -529,6 +545,12 @@ func updateTransSession(session *Session, update *types.SessionUnion, cl *config
session.TurnDetection = update.Transcription.Audio.Input.TurnDetection
}
if update.Transcription.Audio.Input.Format != nil && update.Transcription.Audio.Input.Format.PCM != nil {
if update.Transcription.Audio.Input.Format.PCM.Rate > 0 {
session.InputSampleRate = update.Transcription.Audio.Input.Format.PCM.Rate
}
}
return nil
}
@@ -582,6 +604,12 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
session.TurnDetection = rt.Audio.Input.TurnDetection
}
if rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.Format != nil && rt.Audio.Input.Format.PCM != nil {
if rt.Audio.Input.Format.PCM.Rate > 0 {
session.InputSampleRate = rt.Audio.Input.Format.PCM.Rate
}
}
if rt.Instructions != "" {
session.Instructions = rt.Instructions
}
@@ -598,7 +626,7 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
// handleVAD is a goroutine that listens for audio data from the client,
// runs VAD on the audio data, and commits utterances to the conversation
func handleVAD(session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done chan struct{}) {
vadContext, cancel := context.WithCancel(context.Background())
go func() {
<-done
@@ -627,12 +655,12 @@ func handleVAD(session *Session, conv *Conversation, c *websocket.Conn, done cha
session.AudioBufferLock.Unlock()
aints := sound.BytesToInt16sLE(allAudio)
if len(aints) == 0 || len(aints) < int(silenceThreshold)*remoteSampleRate {
if len(aints) == 0 || len(aints) < int(silenceThreshold)*session.InputSampleRate {
continue
}
// Resample from 24kHz to 16kHz
aints = sound.ResampleInt16(aints, remoteSampleRate, localSampleRate)
// Resample from InputSampleRate to 16kHz
aints = sound.ResampleInt16(aints, session.InputSampleRate, localSampleRate)
segments, err := runVAD(vadContext, session, aints)
if err != nil {
@@ -712,7 +740,7 @@ func handleVAD(session *Session, conv *Conversation, c *websocket.Conn, done cha
}
}
func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, c *websocket.Conn) {
func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, c *LockedWebsocket) {
if len(utt) == 0 {
return
}
@@ -794,7 +822,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]schema.VADS
}
// Function to generate a response based on the conversation
func generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *websocket.Conn, mt int) {
func generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *LockedWebsocket, mt int) {
xlog.Debug("Generating realtime response...")
config := session.ModelInterface.PredictConfig()

View File

@@ -1030,7 +1030,7 @@ parameters:
if (!isPipeline && !config.backend) {
throw new Error('Backend is required');
}
if (!config.parameters || !config.parameters.model) {
if (!isPipeline && (!config.parameters || !config.parameters.model)) {
throw new Error('Model file/path is required in parameters.model');
}
@@ -1056,10 +1056,9 @@ parameters:
if (!isPipeline && !config.backend) {
throw new Error('Backend is required');
}
if (!config.parameters || !config.parameters.model) {
if (!isPipeline && (!config.parameters || !config.parameters.model)) {
throw new Error('Model file/path is required in parameters.model');
}
const endpoint = this.isEditMode ? `/models/edit/{{.ModelName}}` : '/models/import';
const response = await fetch(endpoint, {