mirror of
https://github.com/mudler/LocalAI.git
synced 2026-01-31 17:52:51 -05:00
feat(tts): add support for streaming mode (#8291)
* feat(tts): add support for streaming mode Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Send first audio, make sure it's 16 Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
2c44b06a67
commit
68dd9765a0
@@ -17,6 +17,7 @@ service Backend {
|
||||
rpc GenerateVideo(GenerateVideoRequest) returns (Result) {}
|
||||
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
|
||||
rpc TTS(TTSRequest) returns (Result) {}
|
||||
rpc TTSStream(TTSRequest) returns (stream Reply) {}
|
||||
rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
|
||||
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
|
||||
rpc Status(HealthMessage) returns (StatusResponse) {}
|
||||
|
||||
@@ -207,6 +207,90 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
    def TTSStream(self, request, context):
        """Server-streaming TTS handler.

        Yields `backend_pb2.Reply` messages: first a JSON metadata message
        carrying the model's sample rate (in the `message` field), then one
        reply per generated audio chunk (16-bit PCM bytes in the `audio`
        field). Generation parameters come from `self.options` with defaults;
        a voice-cloning prompt WAV can be supplied through `request.voice` or
        `request.AudioPath`.
        """
        try:
            # Get generation parameters from options with defaults
            cfg_value = self.options.get("cfg_value", 2.0)
            inference_timesteps = self.options.get("inference_timesteps", 10)
            normalize = self.options.get("normalize", False)
            denoise = self.options.get("denoise", False)
            retry_badcase = self.options.get("retry_badcase", True)
            retry_badcase_max_times = self.options.get("retry_badcase_max_times", 3)
            retry_badcase_ratio_threshold = self.options.get("retry_badcase_ratio_threshold", 6.0)

            # Handle voice cloning via prompt_wav_path and prompt_text
            prompt_wav_path = None
            prompt_text = None

            # Priority: request.voice > AudioPath > options
            if hasattr(request, 'voice') and request.voice:
                # If voice is provided, try to use it as a path
                if os.path.exists(request.voice):
                    prompt_wav_path = request.voice
                elif hasattr(request, 'ModelFile') and request.ModelFile:
                    # Try resolving the voice relative to the model file's directory
                    model_file_base = os.path.dirname(request.ModelFile)
                    potential_path = os.path.join(model_file_base, request.voice)
                    if os.path.exists(potential_path):
                        prompt_wav_path = potential_path
                elif hasattr(request, 'ModelPath') and request.ModelPath:
                    # Fall back to resolving relative to the model path
                    potential_path = os.path.join(request.ModelPath, request.voice)
                    if os.path.exists(potential_path):
                        prompt_wav_path = potential_path

            # AudioPath, when set, overrides any path derived from request.voice
            if hasattr(request, 'AudioPath') and request.AudioPath:
                if os.path.isabs(request.AudioPath):
                    prompt_wav_path = request.AudioPath
                elif hasattr(request, 'ModelFile') and request.ModelFile:
                    model_file_base = os.path.dirname(request.ModelFile)
                    prompt_wav_path = os.path.join(model_file_base, request.AudioPath)
                elif hasattr(request, 'ModelPath') and request.ModelPath:
                    prompt_wav_path = os.path.join(request.ModelPath, request.AudioPath)
                else:
                    # No base directory available; use the relative path as-is
                    prompt_wav_path = request.AudioPath

            # Get prompt_text from options if available
            if "prompt_text" in self.options:
                prompt_text = self.options["prompt_text"]

            # Prepare text
            text = request.text.strip()

            # Get sample rate from model (needed for WAV header)
            sample_rate = self.model.tts_model.sample_rate

            print(f"Streaming audio with cfg_value: {cfg_value}, inference_timesteps: {inference_timesteps}, sample_rate: {sample_rate}", file=sys.stderr)

            # Send sample rate as first message (in message field as JSON or string)
            # Format: "sample_rate:16000" so we can parse it
            import json
            sample_rate_info = json.dumps({"sample_rate": int(sample_rate)})
            yield backend_pb2.Reply(message=bytes(sample_rate_info, 'utf-8'))

            # Stream audio chunks
            for chunk in self.model.generate_streaming(
                text=text,
                prompt_wav_path=prompt_wav_path,
                prompt_text=prompt_text,
                cfg_value=cfg_value,
                inference_timesteps=inference_timesteps,
                normalize=normalize,
                denoise=denoise,
                retry_badcase=retry_badcase,
                retry_badcase_max_times=retry_badcase_max_times,
                retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
            ):
                # Convert numpy array to int16 PCM and then to bytes
                # Ensure values are in int16 range
                # NOTE(review): assumes chunk values are floats in [-1, 1] — confirm against generate_streaming
                chunk_int16 = np.clip(chunk * 32767, -32768, 32767).astype(np.int16)
                chunk_bytes = chunk_int16.tobytes()
                yield backend_pb2.Reply(audio=chunk_bytes)

        except Exception as err:
            # Errors are logged and surfaced to the client as a textual Reply
            # rather than a gRPC error status.
            print(f"Error in TTSStream: {err}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            # Yield an error reply
            yield backend_pb2.Reply(message=bytes(f"Error: {err}", 'utf-8'))
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
|
||||
@@ -49,3 +49,39 @@ class TestBackendServicer(unittest.TestCase):
|
||||
self.fail("LoadModel service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_tts_stream(self):
|
||||
"""
|
||||
This method tests if TTS streaming works correctly
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
print("Starting test_tts_stream")
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="openbmb/VoxCPM1.5"))
|
||||
print(response)
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(response.message, "Model loaded successfully")
|
||||
|
||||
# Test TTSStream
|
||||
tts_request = backend_pb2.TTSRequest(text="VoxCPM is an innovative end-to-end TTS model from ModelBest. This is a streaming test.", dst="test_stream.wav")
|
||||
chunks_received = 0
|
||||
total_audio_bytes = 0
|
||||
|
||||
for reply in stub.TTSStream(tts_request):
|
||||
# Verify that we receive audio chunks
|
||||
if reply.audio:
|
||||
chunks_received += 1
|
||||
total_audio_bytes += len(reply.audio)
|
||||
self.assertGreater(len(reply.audio), 0, "Audio chunk should not be empty")
|
||||
|
||||
# Verify that we received multiple chunks
|
||||
self.assertGreater(chunks_received, 0, "Should receive at least one audio chunk")
|
||||
self.assertGreater(total_audio_bytes, 0, "Total audio bytes should be greater than 0")
|
||||
print(f"Received {chunks_received} chunks with {total_audio_bytes} total bytes")
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("TTSStream service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
laudio "github.com/mudler/LocalAI/pkg/audio"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -74,3 +78,101 @@ func ModelTTS(
|
||||
|
||||
return filePath, res, err
|
||||
}
|
||||
|
||||
func ModelTTSStream(
|
||||
text,
|
||||
voice,
|
||||
language string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
audioCallback func([]byte) error,
|
||||
) error {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
ttsModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ttsModel == nil {
|
||||
return fmt.Errorf("could not load tts model %q", modelConfig.Model)
|
||||
}
|
||||
|
||||
// We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
|
||||
// This should be addressed in a follow up PR soon.
|
||||
// Copying it over nearly verbatim, as TTS backends are not functional without this.
|
||||
modelPath := ""
|
||||
// Checking first that it exists and is not outside ModelPath
|
||||
// TODO: we should actually first check if the modelFile is looking like
|
||||
// a FS path
|
||||
mp := filepath.Join(loader.ModelPath, modelConfig.Model)
|
||||
if _, err := os.Stat(mp); err == nil {
|
||||
if err := utils.VerifyPath(mp, appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
return err
|
||||
}
|
||||
modelPath = mp
|
||||
} else {
|
||||
modelPath = modelConfig.Model // skip this step if it fails?????
|
||||
}
|
||||
|
||||
var sampleRate uint32 = 16000 // default
|
||||
headerSent := false
|
||||
var callbackErr error
|
||||
|
||||
err = ttsModel.TTSStream(context.Background(), &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Language: &language,
|
||||
}, func(reply *proto.Reply) {
|
||||
// First message contains sample rate info
|
||||
if !headerSent && len(reply.Message) > 0 {
|
||||
var info map[string]interface{}
|
||||
if json.Unmarshal(reply.Message, &info) == nil {
|
||||
if sr, ok := info["sample_rate"].(float64); ok {
|
||||
sampleRate = uint32(sr)
|
||||
}
|
||||
}
|
||||
// Send WAV header with placeholder size (0xFFFFFFFF for streaming)
|
||||
header := laudio.WAVHeader{
|
||||
ChunkID: [4]byte{'R', 'I', 'F', 'F'},
|
||||
ChunkSize: 0xFFFFFFFF, // Unknown size for streaming
|
||||
Format: [4]byte{'W', 'A', 'V', 'E'},
|
||||
Subchunk1ID: [4]byte{'f', 'm', 't', ' '},
|
||||
Subchunk1Size: 16,
|
||||
AudioFormat: 1, // PCM
|
||||
NumChannels: 1, // Mono
|
||||
SampleRate: sampleRate,
|
||||
ByteRate: sampleRate * 2, // SampleRate * BlockAlign
|
||||
BlockAlign: 2, // 16-bit = 2 bytes
|
||||
BitsPerSample: 16,
|
||||
Subchunk2ID: [4]byte{'d', 'a', 't', 'a'},
|
||||
Subchunk2Size: 0xFFFFFFFF, // Unknown size for streaming
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if writeErr := binary.Write(&buf, binary.LittleEndian, header); writeErr != nil {
|
||||
callbackErr = writeErr
|
||||
return
|
||||
}
|
||||
|
||||
if writeErr := audioCallback(buf.Bytes()); writeErr != nil {
|
||||
callbackErr = writeErr
|
||||
return
|
||||
}
|
||||
headerSent = true
|
||||
}
|
||||
|
||||
// Stream audio chunks
|
||||
if len(reply.Audio) > 0 {
|
||||
if writeErr := audioCallback(reply.Audio); writeErr != nil {
|
||||
callbackErr = writeErr
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if callbackErr != nil {
|
||||
return callbackErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -50,6 +50,31 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
cfg.Voice = input.Voice
|
||||
}
|
||||
|
||||
// Handle streaming TTS
|
||||
if input.Stream {
|
||||
// Set headers for streaming audio
|
||||
c.Response().Header().Set("Content-Type", "audio/wav")
|
||||
c.Response().Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Stream audio chunks as they're generated
|
||||
err := backend.ModelTTSStream(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
|
||||
_, writeErr := c.Response().Write(audioChunk)
|
||||
if writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Non-streaming TTS (existing behavior)
|
||||
filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -53,6 +53,7 @@ type TTSRequest struct {
|
||||
Backend string `json:"backend" yaml:"backend"`
|
||||
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
|
||||
Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
|
||||
Stream bool `json:"stream,omitempty" yaml:"stream,omitempty"` // (optional) enable streaming TTS
|
||||
}
|
||||
|
||||
// @Description VAD request body
|
||||
|
||||
@@ -29,6 +29,41 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
|
||||
|
||||
Returns an `audio/wav` file.
|
||||
|
||||
## Streaming TTS
|
||||
|
||||
LocalAI supports streaming TTS generation, allowing audio to be played as it's generated. This is useful for real-time applications and reduces latency.
|
||||
|
||||
To enable streaming, add `"stream": true` to your request:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
|
||||
"input": "Hello world, this is a streaming test",
|
||||
"model": "voxcpm",
|
||||
"stream": true
|
||||
}' | aplay
|
||||
```
|
||||
|
||||
The audio will be streamed chunk-by-chunk as it's generated, allowing playback to start before generation completes. This is particularly useful for long texts or when you want to minimize perceived latency.
|
||||
|
||||
You can also pipe the streamed audio directly to audio players like `aplay` (Linux) or save it to a file:
|
||||
|
||||
```bash
|
||||
# Stream to aplay (Linux)
|
||||
curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
|
||||
"input": "This is a longer text that will be streamed as it is generated",
|
||||
"model": "voxcpm",
|
||||
"stream": true
|
||||
}' | aplay
|
||||
|
||||
# Stream to a file
|
||||
curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
|
||||
"input": "Streaming audio to file",
|
||||
"model": "voxcpm",
|
||||
"stream": true
|
||||
}' > output.wav
|
||||
```
|
||||
|
||||
Note: Streaming TTS is currently supported by the `voxcpm` backend. Other backends will fall back to non-streaming mode if streaming is not supported.
|
||||
|
||||
## Backends
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ type Backend interface {
|
||||
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error
|
||||
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error)
|
||||
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)
|
||||
|
||||
@@ -65,6 +65,10 @@ func (llm *Base) TTS(*pb.TTSRequest) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
// TTSStream is the default streaming TTS implementation and always returns
// an "unimplemented" error; backends that support streaming synthesis
// override it and send audio chunks on the provided channel.
func (llm *Base) TTSStream(*pb.TTSRequest, chan []byte) error {
	return fmt.Errorf("unimplemented")
}
|
||||
|
||||
func (llm *Base) SoundGeneration(*pb.SoundGenerationRequest) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
@@ -270,6 +270,56 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
|
||||
return client.TTS(ctx, in, opts...)
|
||||
}
|
||||
|
||||
// TTSStream opens a server-streaming TTS call against the backend at
// c.address and invokes f for every Reply received, until EOF, a stream
// error, or ctx cancellation. A fresh connection is dialed per call and
// closed on return. Respects the client's parallelism/busy/watchdog
// bookkeeping like the other client methods.
func (c *Client) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
	// Serialize calls unless the client is configured for parallel operation.
	if !c.parallel {
		c.opMutex.Lock()
		defer c.opMutex.Unlock()
	}
	c.setBusy(true)
	defer c.setBusy(false)
	c.wdMark()
	defer c.wdUnMark()
	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
		))
	if err != nil {
		return err
	}
	defer conn.Close()
	client := pb.NewBackendClient(conn)

	stream, err := client.TTSStream(ctx, in, opts...)
	if err != nil {
		return err
	}

	for {
		// Check if context is cancelled before receiving
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		reply, err := stream.Recv()
		if err == io.EOF {
			// Stream finished cleanly.
			break
		}
		if err != nil {
			// Check if error is due to context cancellation
			if ctx.Err() != nil {
				return ctx.Err()
			}
			return err
		}
		f(reply)
	}

	return nil
}
|
||||
|
||||
func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
|
||||
@@ -55,6 +55,14 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.
|
||||
return e.s.TTS(ctx, in)
|
||||
}
|
||||
|
||||
// TTSStream adapts the in-process backend to the streaming client interface:
// it wraps the per-reply callback f in an embedBackendServerStream and drives
// the server-side TTSStream implementation directly, without a gRPC hop.
func (e *embedBackend) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
	bs := &embedBackendServerStream{
		ctx: ctx,
		fn:  f,
	}
	return e.s.TTSStream(in, bs)
}
|
||||
|
||||
func (e *embedBackend) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
return e.s.SoundGeneration(ctx, in)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ type AIModel interface {
|
||||
Detect(*pb.DetectOptions) (pb.DetectResponse, error)
|
||||
AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
|
||||
TTS(*pb.TTSRequest) error
|
||||
TTSStream(*pb.TTSRequest, chan []byte) error
|
||||
SoundGeneration(*pb.SoundGenerationRequest) error
|
||||
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
|
||||
Status() (pb.StatusResponse, error)
|
||||
|
||||
@@ -99,6 +99,27 @@ func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error)
|
||||
return &pb.Result{Message: "TTS audio generated", Success: true}, nil
|
||||
}
|
||||
|
||||
func (s *server) TTSStream(in *pb.TTSRequest, stream pb.Backend_TTSStreamServer) error {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
audioChan := make(chan []byte)
|
||||
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
for audioChunk := range audioChan {
|
||||
stream.Send(&pb.Reply{Audio: audioChunk})
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
err := s.llm.TTSStream(in, audioChan)
|
||||
<-done
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *server) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest) (*pb.Result, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
|
||||
Reference in New Issue
Block a user