package grpc

import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"
	"time"

	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// maxGRPCMessageSize caps both sent and received gRPC message sizes.
const maxGRPCMessageSize = 50 * 1024 * 1024 // 50MB

// bearerToken implements credentials.PerRPCCredentials to inject a bearer token
// into every gRPC call.
type bearerToken struct {
	token string
}

// GetRequestMetadata returns an "authorization" header carrying the bearer token.
func (b bearerToken) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	return map[string]string{"authorization": "Bearer " + b.token}, nil
}

// RequireTransportSecurity returns false so the token can be attached to the
// connections created by dial, which use insecure transport credentials.
// NOTE(review): the token is therefore sent unencrypted on the wire.
func (b bearerToken) RequireTransportSecurity() bool {
	return false
}

// Client talks to a single backend address over gRPC. Every call dials a new
// connection and closes it before returning.
type Client struct {
	address string
	// busy is true while at least one call is in progress; read it via IsBusy.
	busy bool
	// parallel, when true, lets calls run concurrently (opMutex is not taken).
	parallel bool
	// token, when non-empty, is sent as a bearer token on every RPC.
	token string
	// Within this file the embedded Mutex guards busy and inflight.
	sync.Mutex
	// opMutex serializes calls when parallel is false.
	opMutex sync.Mutex
	// wd, when non-nil, is notified around most calls (see beginCall).
	wd WatchDog
	// inflight counts calls currently in progress so that, in parallel mode,
	// busy stays true until the last concurrent call has finished.
	inflight int
}

// WatchDog is notified by Client around backend calls: Mark is called before
// the RPC is issued and UnMark after it returns.
// NOTE(review): in parallel mode several calls Mark/UnMark the same address
// concurrently; whether that is tracked correctly depends on the WatchDog
// implementation — confirm.
type WatchDog interface {
	Mark(address string)
	UnMark(address string)
}

// IsBusy reports whether at least one call is currently in progress.
func (c *Client) IsBusy() bool {
	c.Lock()
	defer c.Unlock()
	return c.busy
}

// setBusy(true) registers the start of a call and setBusy(false) its end.
// Using a counter instead of overwriting a flag keeps busy accurate when
// several calls overlap (parallel mode).
func (c *Client) setBusy(v bool) {
	c.Lock()
	defer c.Unlock()
	if v {
		c.inflight++
	} else if c.inflight > 0 {
		c.inflight--
	}
	c.busy = c.inflight > 0
}

// wdMark notifies the watchdog, if any, that a call on this address started.
func (c *Client) wdMark() {
	if c.wd != nil {
		c.wd.Mark(c.address)
	}
}

// wdUnMark notifies the watchdog, if any, that a call on this address ended.
func (c *Client) wdUnMark() {
	if c.wd != nil {
		c.wd.UnMark(c.address)
	}
}

// dial creates a gRPC client connection with common options.
// If c.token is set, bearer token credentials are included.
func (c *Client) dial() (*grpc.ClientConn, error) {
	opts := []grpc.DialOption{
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(maxGRPCMessageSize),
			grpc.MaxCallSendMsgSize(maxGRPCMessageSize),
		),
	}
	if c.token != "" {
		opts = append(opts, grpc.WithPerRPCCredentials(bearerToken{token: c.token}))
	}
	return grpc.NewClient(c.address, opts...)
}

// beginCall performs the bookkeeping shared by every RPC wrapper:
//  1. takes opMutex unless the client is parallel,
//  2. marks the client busy,
//  3. notifies the watchdog when markWatchdog is true,
//  4. dials a fresh connection.
//
// On success it returns a backend client and a release func that must be
// called exactly once (typically via defer). release closes the connection and
// then undoes steps 3, 2 and 1 in that order. If dialing fails, the bookkeeping
// is undone before returning the error and release is nil.
func (c *Client) beginCall(markWatchdog bool) (pb.BackendClient, func(), error) {
	if !c.parallel {
		c.opMutex.Lock()
	}
	c.setBusy(true)
	if markWatchdog {
		c.wdMark()
	}

	// undo reverses the steps above in LIFO order.
	undo := func() {
		if markWatchdog {
			c.wdUnMark()
		}
		c.setBusy(false)
		if !c.parallel {
			c.opMutex.Unlock()
		}
	}

	conn, err := c.dial()
	if err != nil {
		undo()
		return nil, nil, err
	}

	release := func() {
		// Best effort: the call has already completed, so a close error is
		// not actionable here.
		_ = conn.Close()
		undo()
	}
	return pb.NewBackendClient(conn), release, nil
}

// recvStream calls recv repeatedly until it returns io.EOF (→ nil) or another
// error. The context is checked before every receive, and if the context has
// been cancelled when recv fails, ctx.Err() is returned instead of the receive
// error.
func recvStream(ctx context.Context, recv func() error) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		err := recv()
		if errors.Is(err, io.EOF) {
			return nil
		}
		if err != nil {
			if ctx.Err() != nil {
				return ctx.Err()
			}
			return err
		}
	}
}

// HealthCheck calls the backend Health RPC with a 10-second timeout. It returns
// true only when the reply message is exactly "OK". The watchdog is not notified.
func (c *Client) HealthCheck(ctx context.Context) (bool, error) {
	client, release, err := c.beginCall(false)
	if err != nil {
		return false, err
	}
	defer release()

	// The healthcheck call shouldn't take long time
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	res, err := client.Health(ctx, &pb.HealthMessage{})
	if err != nil {
		return false, err
	}
	if string(res.Message) == "OK" {
		return true, nil
	}
	return false, fmt.Errorf("health check failed: %s", res.Message)
}

// Embeddings calls the backend Embedding RPC.
func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.Embedding(ctx, in, opts...)
}

// Predict calls the backend Predict RPC.
func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.Predict(ctx, in, opts...)
}

// LoadModel calls the backend LoadModel RPC.
func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.LoadModel(ctx, in, opts...)
}

// PredictStream calls the backend PredictStream RPC and invokes f for every
// reply received, returning nil at end of stream. Errors are returned to the
// caller (not printed).
func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
	client, release, err := c.beginCall(true)
	if err != nil {
		return err
	}
	defer release()

	stream, err := client.PredictStream(ctx, in, opts...)
	if err != nil {
		return err
	}
	return recvStream(ctx, func() error {
		reply, recvErr := stream.Recv()
		if recvErr != nil {
			return recvErr
		}
		f(reply)
		return nil
	})
}

// GenerateImage calls the backend GenerateImage RPC.
func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.GenerateImage(ctx, in, opts...)
}

// GenerateVideo calls the backend GenerateVideo RPC.
func (c *Client) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.GenerateVideo(ctx, in, opts...)
}

// TTS calls the backend TTS RPC.
func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.TTS(ctx, in, opts...)
}

// TTSStream calls the backend TTSStream RPC and invokes f for every reply
// received, returning nil at end of stream.
func (c *Client) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
	client, release, err := c.beginCall(true)
	if err != nil {
		return err
	}
	defer release()

	stream, err := client.TTSStream(ctx, in, opts...)
	if err != nil {
		return err
	}
	return recvStream(ctx, func() error {
		reply, recvErr := stream.Recv()
		if recvErr != nil {
			return recvErr
		}
		f(reply)
		return nil
	})
}

// SoundGeneration calls the backend SoundGeneration RPC.
func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.SoundGeneration(ctx, in, opts...)
}

// AudioTranscription calls the backend AudioTranscription RPC.
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.AudioTranscription(ctx, in, opts...)
}

// TokenizeString calls the backend TokenizeString RPC. On error the response
// is always nil.
func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()

	res, err := client.TokenizeString(ctx, in, opts...)
	if err != nil {
		return nil, err
	}
	return res, nil
}

// Status calls the backend Status RPC. The watchdog is not notified.
func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) {
	client, release, err := c.beginCall(false)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.Status(ctx, &pb.HealthMessage{})
}

// StoresSet calls the backend StoresSet RPC.
func (c *Client) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ...grpc.CallOption) (*pb.Result, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.StoresSet(ctx, in, opts...)
}

// StoresDelete calls the backend StoresDelete RPC.
func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.StoresDelete(ctx, in, opts...)
}

// StoresGet calls the backend StoresGet RPC.
func (c *Client) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.StoresGet(ctx, in, opts...)
}

// StoresFind calls the backend StoresFind RPC.
func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.StoresFind(ctx, in, opts...)
}

// Rerank calls the backend Rerank RPC.
func (c *Client) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.Rerank(ctx, in, opts...)
}

// GetTokenMetrics calls the backend GetMetrics RPC.
func (c *Client) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.GetMetrics(ctx, in, opts...)
}

// VAD calls the backend VAD RPC.
func (c *Client) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.VAD(ctx, in, opts...)
}

// Detect calls the backend Detect RPC.
func (c *Client) Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.Detect(ctx, in, opts...)
}

// AudioEncode calls the backend AudioEncode RPC.
func (c *Client) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...grpc.CallOption) (*pb.AudioEncodeResult, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.AudioEncode(ctx, in, opts...)
}

// AudioDecode calls the backend AudioDecode RPC.
func (c *Client) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.AudioDecode(ctx, in, opts...)
}

// StartFineTune calls the backend StartFineTune RPC.
func (c *Client) StartFineTune(ctx context.Context, in *pb.FineTuneRequest, opts ...grpc.CallOption) (*pb.FineTuneJobResult, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.StartFineTune(ctx, in, opts...)
}

// FineTuneProgress calls the backend FineTuneProgress RPC and invokes f for
// every update received, returning nil at end of stream.
func (c *Client) FineTuneProgress(ctx context.Context, in *pb.FineTuneProgressRequest, f func(update *pb.FineTuneProgressUpdate), opts ...grpc.CallOption) error {
	client, release, err := c.beginCall(true)
	if err != nil {
		return err
	}
	defer release()

	stream, err := client.FineTuneProgress(ctx, in, opts...)
	if err != nil {
		return err
	}
	return recvStream(ctx, func() error {
		update, recvErr := stream.Recv()
		if recvErr != nil {
			return recvErr
		}
		f(update)
		return nil
	})
}

// StopFineTune calls the backend StopFineTune RPC.
func (c *Client) StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.StopFineTune(ctx, in, opts...)
}

// ListCheckpoints calls the backend ListCheckpoints RPC.
func (c *Client) ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest, opts ...grpc.CallOption) (*pb.ListCheckpointsResponse, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.ListCheckpoints(ctx, in, opts...)
}

// ExportModel calls the backend ExportModel RPC.
func (c *Client) ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.ExportModel(ctx, in, opts...)
}

// StartQuantization calls the backend StartQuantization RPC.
func (c *Client) StartQuantization(ctx context.Context, in *pb.QuantizationRequest, opts ...grpc.CallOption) (*pb.QuantizationJobResult, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.StartQuantization(ctx, in, opts...)
}

// QuantizationProgress calls the backend QuantizationProgress RPC and invokes
// f for every update received, returning nil at end of stream.
func (c *Client) QuantizationProgress(ctx context.Context, in *pb.QuantizationProgressRequest, f func(update *pb.QuantizationProgressUpdate), opts ...grpc.CallOption) error {
	client, release, err := c.beginCall(true)
	if err != nil {
		return err
	}
	defer release()

	stream, err := client.QuantizationProgress(ctx, in, opts...)
	if err != nil {
		return err
	}
	return recvStream(ctx, func() error {
		update, recvErr := stream.Recv()
		if recvErr != nil {
			return recvErr
		}
		f(update)
		return nil
	})
}

// StopQuantization calls the backend StopQuantization RPC.
func (c *Client) StopQuantization(ctx context.Context, in *pb.QuantizationStopRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.StopQuantization(ctx, in, opts...)
}

// Free calls the backend Free RPC and discards the response, returning only
// the error.
func (c *Client) Free(ctx context.Context) error {
	client, release, err := c.beginCall(true)
	if err != nil {
		return err
	}
	defer release()

	_, err = client.Free(ctx, &pb.HealthMessage{})
	return err
}

// ModelMetadata calls the backend ModelMetadata RPC.
func (c *Client) ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error) {
	client, release, err := c.beginCall(true)
	if err != nil {
		return nil, err
	}
	defer release()
	return client.ModelMetadata(ctx, in, opts...)
}