mirror of
https://github.com/mudler/LocalAI.git
synced 2026-02-13 16:14:24 -05:00
fix(realtime): Sampling and websocket locking (#8521)
* fix(realtime): Use locked websocket for concurrent access

  Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(realtime): Use sample rate set in session

  Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(config): Allow pipelines to have no model parameters

  Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
committed by
GitHub
parent
cff972094c
commit
1479bee894
@@ -32,13 +32,26 @@ import (
|
||||
)
|
||||
|
||||
const (
	// localSampleRate is the sample rate, in Hz, that client audio is
	// resampled to before it is passed to VAD.
	//
	// XXX: Presently it seems all ASR/VAD backends use 16kHz. If a backend
	// uses 24kHz then it will likely still work, but with reduced performance.
	localSampleRate = 16000
	// remoteSampleRate is the fixed client input sample rate, in Hz.
	// NOTE(review): in this diff its uses in handleVAD appear next to
	// session.InputSampleRate replacements — confirm whether it is still needed.
	remoteSampleRate = 24000
	// defaultRemoteSampleRate is the input sample rate, in Hz, that a new
	// session starts with; a session update carrying a positive PCM rate
	// overrides it.
	defaultRemoteSampleRate = 24000
)
|
||||
|
||||
// A model can be "emulated", that is: transcribe audio to text -> feed the text to the LLM -> generate audio as the result.
|
||||
// If the model supports audio-to-audio instead, we will use the specific gRPC calls.
|
||||
|
||||
// LockedWebsocket wraps a websocket connection with a mutex for safe concurrent writes
|
||||
type LockedWebsocket struct {
|
||||
*websocket.Conn
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (l *LockedWebsocket) WriteMessage(messageType int, data []byte) error {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
return l.Conn.WriteMessage(messageType, data)
|
||||
}
|
||||
|
||||
// Session represents a single WebSocket connection and its state
|
||||
type Session struct {
|
||||
ID string
|
||||
@@ -58,7 +71,8 @@ type Session struct {
|
||||
DefaultConversationID string
|
||||
ModelInterface Model
|
||||
// The pipeline model config or the config for an any-to-any model
|
||||
ModelConfig *config.ModelConfig
|
||||
ModelConfig *config.ModelConfig
|
||||
InputSampleRate int
|
||||
}
|
||||
|
||||
func (s *Session) FromClient(session *types.SessionUnion) {
|
||||
@@ -162,7 +176,8 @@ func Realtime(application *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
|
||||
func registerRealtime(application *application.Application, model string) func(c *websocket.Conn) {
|
||||
return func(c *websocket.Conn) {
|
||||
return func(conn *websocket.Conn) {
|
||||
c := &LockedWebsocket{Conn: conn}
|
||||
|
||||
evaluator := application.TemplatesEvaluator()
|
||||
xlog.Debug("Realtime WebSocket connection established", "address", c.RemoteAddr().String(), "model", model)
|
||||
@@ -202,7 +217,8 @@ func registerRealtime(application *application.Application, model string) func(c
|
||||
InputAudioTranscription: &types.AudioTranscription{
|
||||
Model: sttModel,
|
||||
},
|
||||
Conversations: make(map[string]*Conversation),
|
||||
Conversations: make(map[string]*Conversation),
|
||||
InputSampleRate: defaultRemoteSampleRate,
|
||||
}
|
||||
|
||||
// Create a default conversation
|
||||
@@ -455,7 +471,7 @@ func registerRealtime(application *application.Application, model string) func(c
|
||||
}
|
||||
|
||||
// Helper function to send events to the client
|
||||
func sendEvent(c *websocket.Conn, event types.ServerEvent) {
|
||||
func sendEvent(c *LockedWebsocket, event types.ServerEvent) {
|
||||
eventBytes, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
xlog.Error("failed to marshal event", "error", err)
|
||||
@@ -467,7 +483,7 @@ func sendEvent(c *websocket.Conn, event types.ServerEvent) {
|
||||
}
|
||||
|
||||
// Helper function to send errors to the client
|
||||
func sendError(c *websocket.Conn, code, message, param, eventID string) {
|
||||
func sendError(c *LockedWebsocket, code, message, param, eventID string) {
|
||||
errorEvent := types.ErrorEvent{
|
||||
ServerEventBase: types.ServerEventBase{
|
||||
EventID: eventID,
|
||||
@@ -484,7 +500,7 @@ func sendError(c *websocket.Conn, code, message, param, eventID string) {
|
||||
sendEvent(c, errorEvent)
|
||||
}
|
||||
|
||||
// sendNotImplemented reports an unsupported feature to the client by calling
// sendError with the code "not_implemented", the given message, an empty
// param and the placeholder event ID "event_TODO".
// The two signature lines below are the before/after forms shown by this
// diff: the connection parameter changes from *websocket.Conn to
// *LockedWebsocket.
func sendNotImplemented(c *websocket.Conn, message string) {
|
||||
func sendNotImplemented(c *LockedWebsocket, message string) {
|
||||
sendError(c, "not_implemented", message, "", "event_TODO")
|
||||
}
|
||||
|
||||
@@ -529,6 +545,12 @@ func updateTransSession(session *Session, update *types.SessionUnion, cl *config
|
||||
session.TurnDetection = update.Transcription.Audio.Input.TurnDetection
|
||||
}
|
||||
|
||||
if update.Transcription.Audio.Input.Format != nil && update.Transcription.Audio.Input.Format.PCM != nil {
|
||||
if update.Transcription.Audio.Input.Format.PCM.Rate > 0 {
|
||||
session.InputSampleRate = update.Transcription.Audio.Input.Format.PCM.Rate
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -582,6 +604,12 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
|
||||
session.TurnDetection = rt.Audio.Input.TurnDetection
|
||||
}
|
||||
|
||||
if rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.Format != nil && rt.Audio.Input.Format.PCM != nil {
|
||||
if rt.Audio.Input.Format.PCM.Rate > 0 {
|
||||
session.InputSampleRate = rt.Audio.Input.Format.PCM.Rate
|
||||
}
|
||||
}
|
||||
|
||||
if rt.Instructions != "" {
|
||||
session.Instructions = rt.Instructions
|
||||
}
|
||||
@@ -598,7 +626,7 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
|
||||
|
||||
// handleVAD is a goroutine that listens for audio data from the client,
|
||||
// runs VAD on the audio data, and commits utterances to the conversation
|
||||
func handleVAD(session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
|
||||
func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done chan struct{}) {
|
||||
vadContext, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
<-done
|
||||
@@ -627,12 +655,12 @@ func handleVAD(session *Session, conv *Conversation, c *websocket.Conn, done cha
|
||||
session.AudioBufferLock.Unlock()
|
||||
|
||||
aints := sound.BytesToInt16sLE(allAudio)
|
||||
if len(aints) == 0 || len(aints) < int(silenceThreshold)*remoteSampleRate {
|
||||
if len(aints) == 0 || len(aints) < int(silenceThreshold)*session.InputSampleRate {
|
||||
continue
|
||||
}
|
||||
|
||||
// Resample from 24kHz to 16kHz
|
||||
aints = sound.ResampleInt16(aints, remoteSampleRate, localSampleRate)
|
||||
// Resample from InputSampleRate to 16kHz
|
||||
aints = sound.ResampleInt16(aints, session.InputSampleRate, localSampleRate)
|
||||
|
||||
segments, err := runVAD(vadContext, session, aints)
|
||||
if err != nil {
|
||||
@@ -712,7 +740,7 @@ func handleVAD(session *Session, conv *Conversation, c *websocket.Conn, done cha
|
||||
}
|
||||
}
|
||||
|
||||
func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, c *websocket.Conn) {
|
||||
func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, c *LockedWebsocket) {
|
||||
if len(utt) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -794,7 +822,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]schema.VADS
|
||||
}
|
||||
|
||||
// Function to generate a response based on the conversation
|
||||
func generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *websocket.Conn, mt int) {
|
||||
func generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *LockedWebsocket, mt int) {
|
||||
xlog.Debug("Generating realtime response...")
|
||||
|
||||
config := session.ModelInterface.PredictConfig()
|
||||
|
||||
@@ -1030,7 +1030,7 @@ parameters:
|
||||
if (!isPipeline && !config.backend) {
|
||||
throw new Error('Backend is required');
|
||||
}
|
||||
if (!config.parameters || !config.parameters.model) {
|
||||
if (!isPipeline && (!config.parameters || !config.parameters.model)) {
|
||||
throw new Error('Model file/path is required in parameters.model');
|
||||
}
|
||||
|
||||
@@ -1056,10 +1056,9 @@ parameters:
|
||||
if (!isPipeline && !config.backend) {
|
||||
throw new Error('Backend is required');
|
||||
}
|
||||
if (!config.parameters || !config.parameters.model) {
|
||||
if (!isPipeline && (!config.parameters || !config.parameters.model)) {
|
||||
throw new Error('Model file/path is required in parameters.model');
|
||||
}
|
||||
|
||||
const endpoint = this.isEditMode ? `/models/edit/{{.ModelName}}` : '/models/import';
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
|
||||
Reference in New Issue
Block a user