feat(tts): add support for streaming mode (#8291)

* feat(tts): add support for streaming mode

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Send first audio, make sure it's 16

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-01-30 11:58:01 +01:00
committed by GitHub
parent 2c44b06a67
commit 68dd9765a0
13 changed files with 369 additions and 0 deletions

View File

@@ -41,6 +41,7 @@ type Backend interface {
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error)
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error)
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)

View File

@@ -65,6 +65,10 @@ func (llm *Base) TTS(*pb.TTSRequest) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) TTSStream(*pb.TTSRequest, chan []byte) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) SoundGeneration(*pb.SoundGenerationRequest) error {
return fmt.Errorf("unimplemented")
}

View File

@@ -270,6 +270,56 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
return client.TTS(ctx, in, opts...)
}
func (c *Client) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
))
if err != nil {
return err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
stream, err := client.TTSStream(ctx, in, opts...)
if err != nil {
return err
}
for {
// Check if context is cancelled before receiving
select {
case <-ctx.Done():
return ctx.Err()
default:
}
reply, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
// Check if error is due to context cancellation
if ctx.Err() != nil {
return ctx.Err()
}
return err
}
f(reply)
}
return nil
}
func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) {
if !c.parallel {
c.opMutex.Lock()

View File

@@ -55,6 +55,14 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.
return e.s.TTS(ctx, in)
}
func (e *embedBackend) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
bs := &embedBackendServerStream{
ctx: ctx,
fn: f,
}
return e.s.TTSStream(in, bs)
}
func (e *embedBackend) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) {
return e.s.SoundGeneration(ctx, in)
}

View File

@@ -18,6 +18,7 @@ type AIModel interface {
Detect(*pb.DetectOptions) (pb.DetectResponse, error)
AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
TTS(*pb.TTSRequest) error
TTSStream(*pb.TTSRequest, chan []byte) error
SoundGeneration(*pb.SoundGenerationRequest) error
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
Status() (pb.StatusResponse, error)

View File

@@ -99,6 +99,27 @@ func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error)
return &pb.Result{Message: "TTS audio generated", Success: true}, nil
}
func (s *server) TTSStream(in *pb.TTSRequest, stream pb.Backend_TTSStreamServer) error {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
audioChan := make(chan []byte)
done := make(chan bool)
go func() {
for audioChunk := range audioChan {
stream.Send(&pb.Reply{Audio: audioChunk})
}
done <- true
}()
err := s.llm.TTSStream(in, audioChan)
<-done
return err
}
func (s *server) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()