mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-08 07:56:06 -04:00
feat(tts): add support for streaming mode (#8291)
* feat(tts): add support for streaming mode Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Send first audio, make sure it's 16 Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
2c44b06a67
commit
68dd9765a0
@@ -41,6 +41,7 @@ type Backend interface {
|
||||
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error
|
||||
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error)
|
||||
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)
|
||||
|
||||
@@ -65,6 +65,10 @@ func (llm *Base) TTS(*pb.TTSRequest) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) TTSStream(*pb.TTSRequest, chan []byte) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) SoundGeneration(*pb.SoundGenerationRequest) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
@@ -270,6 +270,56 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
|
||||
return client.TTS(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
defer c.opMutex.Unlock()
|
||||
}
|
||||
c.setBusy(true)
|
||||
defer c.setBusy(false)
|
||||
c.wdMark()
|
||||
defer c.wdUnMark()
|
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithDefaultCallOptions(
|
||||
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
|
||||
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
|
||||
))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := pb.NewBackendClient(conn)
|
||||
|
||||
stream, err := client.TTSStream(ctx, in, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
// Check if context is cancelled before receiving
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
reply, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
// Check if error is due to context cancellation
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
f(reply)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
|
||||
@@ -55,6 +55,14 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.
|
||||
return e.s.TTS(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
|
||||
bs := &embedBackendServerStream{
|
||||
ctx: ctx,
|
||||
fn: f,
|
||||
}
|
||||
return e.s.TTSStream(in, bs)
|
||||
}
|
||||
|
||||
func (e *embedBackend) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
return e.s.SoundGeneration(ctx, in)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ type AIModel interface {
|
||||
Detect(*pb.DetectOptions) (pb.DetectResponse, error)
|
||||
AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
|
||||
TTS(*pb.TTSRequest) error
|
||||
TTSStream(*pb.TTSRequest, chan []byte) error
|
||||
SoundGeneration(*pb.SoundGenerationRequest) error
|
||||
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
|
||||
Status() (pb.StatusResponse, error)
|
||||
|
||||
@@ -99,6 +99,27 @@ func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error)
|
||||
return &pb.Result{Message: "TTS audio generated", Success: true}, nil
|
||||
}
|
||||
|
||||
func (s *server) TTSStream(in *pb.TTSRequest, stream pb.Backend_TTSStreamServer) error {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
audioChan := make(chan []byte)
|
||||
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
for audioChunk := range audioChan {
|
||||
stream.Send(&pb.Reply{Audio: audioChunk})
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
err := s.llm.TTSStream(in, audioChan)
|
||||
<-done
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *server) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest) (*pb.Result, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
|
||||
Reference in New Issue
Block a user