mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-18 13:38:49 -04:00
feat: include tokens usage for streamed output (#4282)
Use pb.Reply instead of []byte with Reply.GetMessage() in llama grpc to get the proper usage data in reply streaming mode at the last [DONE] frame Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
@@ -37,7 +37,7 @@ type Backend interface {
|
||||
Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
|
||||
Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
|
||||
LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
|
||||
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error
|
||||
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
|
||||
@@ -136,7 +136,7 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
|
||||
return client.LoadModel(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
|
||||
func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
defer c.opMutex.Unlock()
|
||||
@@ -158,7 +158,7 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
|
||||
}
|
||||
|
||||
for {
|
||||
feature, err := stream.Recv()
|
||||
reply, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
@@ -167,7 +167,7 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
|
||||
|
||||
return err
|
||||
}
|
||||
f(feature.GetMessage())
|
||||
f(reply)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -35,7 +35,7 @@ func (e *embedBackend) LoadModel(ctx context.Context, in *pb.ModelOptions, opts
|
||||
return e.s.LoadModel(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
|
||||
func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
|
||||
bs := &embedBackendServerStream{
|
||||
ctx: ctx,
|
||||
fn: f,
|
||||
@@ -97,11 +97,11 @@ func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsReques
|
||||
|
||||
type embedBackendServerStream struct {
|
||||
ctx context.Context
|
||||
fn func(s []byte)
|
||||
fn func(reply *pb.Reply)
|
||||
}
|
||||
|
||||
func (e *embedBackendServerStream) Send(reply *pb.Reply) error {
|
||||
e.fn(reply.GetMessage())
|
||||
e.fn(reply)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user