From 68dd9765a03ea88563e5768d52e374e7eb2ca2e8 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 30 Jan 2026 11:58:01 +0100 Subject: [PATCH] feat(tts): add support for streaming mode (#8291) * feat(tts): add support for streaming mode Signed-off-by: Ettore Di Giacinto * Send first audio, make sure it's 16 Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- backend/backend.proto | 1 + backend/python/voxcpm/backend.py | 84 ++++++++++++++++++++ backend/python/voxcpm/test.py | 36 +++++++++ core/backend/tts.go | 102 +++++++++++++++++++++++++ core/http/endpoints/localai/tts.go | 25 ++++++ core/schema/localai.go | 1 + docs/content/features/text-to-audio.md | 35 +++++++++ pkg/grpc/backend.go | 1 + pkg/grpc/base/base.go | 4 + pkg/grpc/client.go | 50 ++++++++++++ pkg/grpc/embed.go | 8 ++ pkg/grpc/interface.go | 1 + pkg/grpc/server.go | 21 +++++ 13 files changed, 369 insertions(+) diff --git a/backend/backend.proto b/backend/backend.proto index 31e40c4f3..50b239a77 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -17,6 +17,7 @@ service Backend { rpc GenerateVideo(GenerateVideoRequest) returns (Result) {} rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {} rpc TTS(TTSRequest) returns (Result) {} + rpc TTSStream(TTSRequest) returns (stream Reply) {} rpc SoundGeneration(SoundGenerationRequest) returns (Result) {} rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {} rpc Status(HealthMessage) returns (StatusResponse) {} diff --git a/backend/python/voxcpm/backend.py b/backend/python/voxcpm/backend.py index 84bb99e96..0c1970648 100644 --- a/backend/python/voxcpm/backend.py +++ b/backend/python/voxcpm/backend.py @@ -207,6 +207,90 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): return backend_pb2.Result(success=True) + def TTSStream(self, request, context): + try: + # Get generation parameters from options with defaults + cfg_value = self.options.get("cfg_value", 2.0) + inference_timesteps = self.options.get("inference_timesteps", 10) + normalize = self.options.get("normalize", False) + denoise = self.options.get("denoise", False) + retry_badcase = self.options.get("retry_badcase", True) + retry_badcase_max_times = self.options.get("retry_badcase_max_times", 3) + retry_badcase_ratio_threshold = self.options.get("retry_badcase_ratio_threshold", 6.0) + + # Handle voice cloning via prompt_wav_path and prompt_text + prompt_wav_path = None + prompt_text = None + + # Priority: request.voice > AudioPath > options + if hasattr(request, 'voice') and request.voice: + # If voice is provided, try to use it as a path + if os.path.exists(request.voice): + prompt_wav_path = request.voice + elif hasattr(request, 'ModelFile') and request.ModelFile: + model_file_base = os.path.dirname(request.ModelFile) + potential_path = os.path.join(model_file_base, request.voice) + if os.path.exists(potential_path): + prompt_wav_path = potential_path + elif hasattr(request, 'ModelPath') and request.ModelPath: + potential_path = os.path.join(request.ModelPath, request.voice) + if os.path.exists(potential_path): + prompt_wav_path = potential_path + + if hasattr(request, 'AudioPath') and request.AudioPath: + if os.path.isabs(request.AudioPath): + prompt_wav_path = request.AudioPath + elif hasattr(request, 'ModelFile') and request.ModelFile: + model_file_base = os.path.dirname(request.ModelFile) + prompt_wav_path = os.path.join(model_file_base, request.AudioPath) + elif hasattr(request, 'ModelPath') and request.ModelPath: + prompt_wav_path = os.path.join(request.ModelPath, request.AudioPath) + else: + prompt_wav_path = request.AudioPath + + # Get prompt_text from options if available + if "prompt_text" in self.options: + prompt_text = self.options["prompt_text"] + + # Prepare text + text = request.text.strip() + + # Get sample rate from model (needed for WAV header) + sample_rate = self.model.tts_model.sample_rate + + print(f"Streaming audio with cfg_value: {cfg_value}, inference_timesteps: {inference_timesteps}, sample_rate: {sample_rate}", file=sys.stderr) + + # Send sample rate as first message (in message field as JSON or string) + # Format: "sample_rate:16000" so we can parse it + import json + sample_rate_info = json.dumps({"sample_rate": int(sample_rate)}) + yield backend_pb2.Reply(message=bytes(sample_rate_info, 'utf-8')) + + # Stream audio chunks + for chunk in self.model.generate_streaming( + text=text, + prompt_wav_path=prompt_wav_path, + prompt_text=prompt_text, + cfg_value=cfg_value, + inference_timesteps=inference_timesteps, + normalize=normalize, + denoise=denoise, + retry_badcase=retry_badcase, + retry_badcase_max_times=retry_badcase_max_times, + retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, + ): + # Convert numpy array to int16 PCM and then to bytes + # Ensure values are in int16 range + chunk_int16 = np.clip(chunk * 32767, -32768, 32767).astype(np.int16) + chunk_bytes = chunk_int16.tobytes() + yield backend_pb2.Reply(audio=chunk_bytes) + + except Exception as err: + print(f"Error in TTSStream: {err}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + # Yield an error reply + yield backend_pb2.Reply(message=bytes(f"Error: {err}", 'utf-8')) + def serve(address): server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ diff --git a/backend/python/voxcpm/test.py b/backend/python/voxcpm/test.py index 0a94012aa..63b691421 100644 --- a/backend/python/voxcpm/test.py +++ b/backend/python/voxcpm/test.py @@ -49,3 +49,39 @@ class TestBackendServicer(unittest.TestCase): self.fail("LoadModel service failed") finally: self.tearDown() + + def test_tts_stream(self): + """ + This method tests if TTS streaming works correctly + """ + try: + self.setUp() + print("Starting test_tts_stream") + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="openbmb/VoxCPM1.5")) + print(response) + self.assertTrue(response.success) + self.assertEqual(response.message, "Model loaded successfully") + + # Test TTSStream + tts_request = backend_pb2.TTSRequest(text="VoxCPM is an innovative end-to-end TTS model from ModelBest. This is a streaming test.", dst="test_stream.wav") + chunks_received = 0 + total_audio_bytes = 0 + + for reply in stub.TTSStream(tts_request): + # Verify that we receive audio chunks + if reply.audio: + chunks_received += 1 + total_audio_bytes += len(reply.audio) + self.assertGreater(len(reply.audio), 0, "Audio chunk should not be empty") + + # Verify that we received multiple chunks + self.assertGreater(chunks_received, 0, "Should receive at least one audio chunk") + self.assertGreater(total_audio_bytes, 0, "Total audio bytes should be greater than 0") + print(f"Received {chunks_received} chunks with {total_audio_bytes} total bytes") + except Exception as err: + print(err) + self.fail("TTSStream service failed") + finally: + self.tearDown() diff --git a/core/backend/tts.go b/core/backend/tts.go index 9c75cb37a..6e97ca8c3 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -1,12 +1,16 @@ package backend import ( + "bytes" "context" + "encoding/binary" + "encoding/json" "fmt" "os" "path/filepath" "github.com/mudler/LocalAI/core/config" + laudio "github.com/mudler/LocalAI/pkg/audio" "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" @@ -74,3 +78,101 @@ func ModelTTS( return filePath, res, err } + +func ModelTTSStream( + text, + voice, + language string, + loader *model.ModelLoader, + appConfig *config.ApplicationConfig, + modelConfig config.ModelConfig, + audioCallback func([]byte) error, +) error { + opts := ModelOptions(modelConfig, appConfig) + ttsModel, err := loader.Load(opts...) + if err != nil { + return err + } + + if ttsModel == nil { + return fmt.Errorf("could not load tts model %q", modelConfig.Model) + } + + // We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect. + // This should be addressed in a follow up PR soon. + // Copying it over nearly verbatim, as TTS backends are not functional without this. + modelPath := "" + // Checking first that it exists and is not outside ModelPath + // TODO: we should actually first check if the modelFile is looking like + // a FS path + mp := filepath.Join(loader.ModelPath, modelConfig.Model) + if _, err := os.Stat(mp); err == nil { + if err := utils.VerifyPath(mp, appConfig.SystemState.Model.ModelsPath); err != nil { + return err + } + modelPath = mp + } else { + modelPath = modelConfig.Model // skip this step if it fails????? + } + + var sampleRate uint32 = 16000 // default + headerSent := false + var callbackErr error + + err = ttsModel.TTSStream(context.Background(), &proto.TTSRequest{ + Text: text, + Model: modelPath, + Voice: voice, + Language: &language, + }, func(reply *proto.Reply) { + // First message contains sample rate info + if !headerSent && len(reply.Message) > 0 { + var info map[string]interface{} + if json.Unmarshal(reply.Message, &info) == nil { + if sr, ok := info["sample_rate"].(float64); ok { + sampleRate = uint32(sr) + } + } + // Send WAV header with placeholder size (0xFFFFFFFF for streaming) + header := laudio.WAVHeader{ + ChunkID: [4]byte{'R', 'I', 'F', 'F'}, + ChunkSize: 0xFFFFFFFF, // Unknown size for streaming + Format: [4]byte{'W', 'A', 'V', 'E'}, + Subchunk1ID: [4]byte{'f', 'm', 't', ' '}, + Subchunk1Size: 16, + AudioFormat: 1, // PCM + NumChannels: 1, // Mono + SampleRate: sampleRate, + ByteRate: sampleRate * 2, // SampleRate * BlockAlign + BlockAlign: 2, // 16-bit = 2 bytes + BitsPerSample: 16, + Subchunk2ID: [4]byte{'d', 'a', 't', 'a'}, + Subchunk2Size: 0xFFFFFFFF, // Unknown size for streaming + } + + var buf bytes.Buffer + if writeErr := binary.Write(&buf, binary.LittleEndian, header); writeErr != nil { + callbackErr = writeErr + return + } + + if writeErr := audioCallback(buf.Bytes()); writeErr != nil { + callbackErr = writeErr + return + } + headerSent = true + } + + // Stream audio chunks + if len(reply.Audio) > 0 { + if writeErr := audioCallback(reply.Audio); writeErr != nil { + callbackErr = writeErr + } + } + }) + + if callbackErr != nil { + return callbackErr + } + return err +} diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go index 9dd588ad7..01bc1cd82 100644 --- a/core/http/endpoints/localai/tts.go +++ b/core/http/endpoints/localai/tts.go @@ -50,6 +50,31 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig cfg.Voice = input.Voice } + // Handle streaming TTS + if input.Stream { + // Set headers for streaming audio + c.Response().Header().Set("Content-Type", "audio/wav") + c.Response().Header().Set("Transfer-Encoding", "chunked") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") + + // Stream audio chunks as they're generated + err := backend.ModelTTSStream(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error { + _, writeErr := c.Response().Write(audioChunk) + if writeErr != nil { + return writeErr + } + c.Response().Flush() + return nil + }) + if err != nil { + return err + } + + return nil + } + + // Non-streaming TTS (existing behavior) filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg) if err != nil { return err diff --git a/core/schema/localai.go b/core/schema/localai.go index 29e1faf3f..62373a5cc 100644 --- a/core/schema/localai.go +++ b/core/schema/localai.go @@ -53,6 +53,7 @@ type TTSRequest struct { Backend string `json:"backend" yaml:"backend"` Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format + Stream bool `json:"stream,omitempty" yaml:"stream,omitempty"` // (optional) enable streaming TTS } // @Description VAD request body diff --git a/docs/content/features/text-to-audio.md b/docs/content/features/text-to-audio.md index 7f7a9bcf2..9771057d9 100644 --- a/docs/content/features/text-to-audio.md +++ b/docs/content/features/text-to-audio.md @@ -29,6 +29,41 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{ Returns an `audio/wav` file. +## Streaming TTS + +LocalAI supports streaming TTS generation, allowing audio to be played as it's generated. This is useful for real-time applications and reduces latency. + +To enable streaming, add `"stream": true` to your request: + +```bash +curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{ + "input": "Hello world, this is a streaming test", + "model": "voxcpm", + "stream": true +}' | aplay +``` + +The audio will be streamed chunk-by-chunk as it's generated, allowing playback to start before generation completes. This is particularly useful for long texts or when you want to minimize perceived latency. + +You can also pipe the streamed audio directly to audio players like `aplay` (Linux) or save it to a file: + +```bash +# Stream to aplay (Linux) +curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{ + "input": "This is a longer text that will be streamed as it is generated", + "model": "voxcpm", + "stream": true +}' | aplay + +# Stream to a file +curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{ + "input": "Streaming audio to file", + "model": "voxcpm", + "stream": true +}' > output.wav +``` + +Note: Streaming TTS is currently supported by the `voxcpm` backend. Other backends will fall back to non-streaming mode if streaming is not supported. ## Backends diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index 8ecb818a0..c63da40a0 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -41,6 +41,7 @@ type Backend interface { GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) + TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 2d0ebc555..6a72cc95c 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -65,6 +65,10 @@ func (llm *Base) TTS(*pb.TTSRequest) error { return fmt.Errorf("unimplemented") } +func (llm *Base) TTSStream(*pb.TTSRequest, chan []byte) error { + return fmt.Errorf("unimplemented") +} + func (llm *Base) SoundGeneration(*pb.SoundGenerationRequest) error { return fmt.Errorf("unimplemented") } diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index dbdeeab24..ccc7611b3 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -270,6 +270,56 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp return client.TTS(ctx, in, opts...) } +func (c *Client) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } + c.setBusy(true) + defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB + grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB + )) + if err != nil { + return err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + + stream, err := client.TTSStream(ctx, in, opts...) + if err != nil { + return err + } + + for { + // Check if context is cancelled before receiving + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + reply, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + // Check if error is due to context cancellation + if ctx.Err() != nil { + return ctx.Err() + } + return err + } + f(reply) + } + + return nil +} + func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) { if !c.parallel { c.opMutex.Lock() diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index 03cac344f..88d7ca760 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -55,6 +55,14 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc. return e.s.TTS(ctx, in) } +func (e *embedBackend) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error { + bs := &embedBackendServerStream{ + ctx: ctx, + fn: f, + } + return e.s.TTSStream(in, bs) +} + func (e *embedBackend) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) { return e.s.SoundGeneration(ctx, in) } diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index bb22af55c..9610b817e 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -18,6 +18,7 @@ type AIModel interface { Detect(*pb.DetectOptions) (pb.DetectResponse, error) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) TTS(*pb.TTSRequest) error + TTSStream(*pb.TTSRequest, chan []byte) error SoundGeneration(*pb.SoundGenerationRequest) error TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) Status() (pb.StatusResponse, error) diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 8cc6ee43e..af771f3d2 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -99,6 +99,27 @@ func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) return &pb.Result{Message: "TTS audio generated", Success: true}, nil } +func (s *server) TTSStream(in *pb.TTSRequest, stream pb.Backend_TTSStreamServer) error { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } + audioChan := make(chan []byte) + + done := make(chan bool) + go func() { + for audioChunk := range audioChan { + stream.Send(&pb.Reply{Audio: audioChunk}) + } + done <- true + }() + + err := s.llm.TTSStream(in, audioChan) + <-done + + return err +} + func (s *server) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest) (*pb.Result, error) { if s.llm.Locking() { s.llm.Lock()