mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-19 14:17:21 -04:00
refactor(backend): propagate request ctx into biometric, detection, rerank, diarization paths
Replaces the remaining context.Background() call sites in core/backend with the caller's ctx. After this commit, every core/backend/*.go entry point threads the request ctx end-to-end to the gRPC client.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func Detection(
|
||||
ctx context.Context,
|
||||
sourceFile string,
|
||||
prompt string,
|
||||
points []float32,
|
||||
@@ -38,7 +39,7 @@ func Detection(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{
|
||||
res, err := detectionModel.Detect(ctx, &proto.DetectOptions{
|
||||
Src: sourceFile,
|
||||
Prompt: prompt,
|
||||
Points: points,
|
||||
|
||||
@@ -63,7 +63,7 @@ func loadDiarizationModel(ml *model.ModelLoader, modelConfig config.ModelConfig,
|
||||
|
||||
// ModelDiarization runs the Diarize RPC against the configured backend
|
||||
// and returns a normalized schema.DiarizationResult.
|
||||
func ModelDiarization(req DiarizationRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.DiarizationResult, error) {
|
||||
func ModelDiarization(ctx context.Context, req DiarizationRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.DiarizationResult, error) {
|
||||
m, err := loadDiarizationModel(ml, modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -74,7 +74,7 @@ func ModelDiarization(req DiarizationRequest, ml *model.ModelLoader, modelConfig
|
||||
threads = uint32(*modelConfig.Threads)
|
||||
}
|
||||
|
||||
r, err := m.Diarize(context.Background(), req.toProto(threads))
|
||||
r, err := m.Diarize(ctx, req.toProto(threads))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func FaceAnalyze(
|
||||
ctx context.Context,
|
||||
img string,
|
||||
actions []string,
|
||||
antiSpoofing bool,
|
||||
@@ -35,7 +36,7 @@ func FaceAnalyze(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := faceModel.FaceAnalyze(context.Background(), &proto.FaceAnalyzeRequest{
|
||||
res, err := faceModel.FaceAnalyze(ctx, &proto.FaceAnalyzeRequest{
|
||||
Img: img,
|
||||
Actions: actions,
|
||||
AntiSpoofing: antiSpoofing,
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
// backend picks the highest-confidence face and returns its
|
||||
// L2-normalized embedding.
|
||||
func FaceEmbed(
|
||||
ctx context.Context,
|
||||
imgBase64 string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
@@ -32,7 +33,7 @@ func FaceEmbed(
|
||||
predictOpts := gRPCPredictOpts(modelConfig, loader.ModelPath)
|
||||
predictOpts.Images = []string{imgBase64}
|
||||
|
||||
res, err := faceModel.Embeddings(context.Background(), predictOpts)
|
||||
res, err := faceModel.Embeddings(ctx, predictOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func FaceVerify(
|
||||
ctx context.Context,
|
||||
img1, img2 string,
|
||||
threshold float32,
|
||||
antiSpoofing bool,
|
||||
@@ -35,7 +36,7 @@ func FaceVerify(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := faceModel.FaceVerify(context.Background(), &proto.FaceVerifyRequest{
|
||||
res, err := faceModel.FaceVerify(ctx, &proto.FaceVerifyRequest{
|
||||
Img1: img1,
|
||||
Img2: img2,
|
||||
Threshold: threshold,
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
|
||||
func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
rerankModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
@@ -29,7 +29,7 @@ func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := rerankModel.Rerank(context.Background(), request)
|
||||
res, err := rerankModel.Rerank(ctx, request)
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func TokenMetrics(
|
||||
ctx context.Context,
|
||||
modelFile string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
@@ -26,7 +27,7 @@ func TokenMetrics(
|
||||
return nil, fmt.Errorf("could not loadmodel model")
|
||||
}
|
||||
|
||||
res, err := model.GetTokenMetrics(context.Background(), &proto.MetricsRequest{})
|
||||
res, err := model.GetTokenMetrics(ctx, &proto.MetricsRequest{})
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func VoiceAnalyze(
|
||||
ctx context.Context,
|
||||
audio string,
|
||||
actions []string,
|
||||
loader *model.ModelLoader,
|
||||
@@ -34,7 +35,7 @@ func VoiceAnalyze(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := voiceModel.VoiceAnalyze(context.Background(), &proto.VoiceAnalyzeRequest{
|
||||
res, err := voiceModel.VoiceAnalyze(ctx, &proto.VoiceAnalyzeRequest{
|
||||
Audio: audio,
|
||||
Actions: actions,
|
||||
})
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
// OpenAI-compatible and text-only), this call takes an audio path and
|
||||
// returns the backend's speaker-encoder output.
|
||||
func VoiceEmbed(
|
||||
ctx context.Context,
|
||||
audioPath string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
@@ -37,7 +38,7 @@ func VoiceEmbed(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := voiceModel.VoiceEmbed(context.Background(), &proto.VoiceEmbedRequest{
|
||||
res, err := voiceModel.VoiceEmbed(ctx, &proto.VoiceEmbedRequest{
|
||||
Audio: audioPath,
|
||||
})
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func VoiceVerify(
|
||||
ctx context.Context,
|
||||
audio1, audio2 string,
|
||||
threshold float32,
|
||||
antiSpoofing bool,
|
||||
@@ -35,7 +36,7 @@ func VoiceVerify(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := voiceModel.VoiceVerify(context.Background(), &proto.VoiceVerifyRequest{
|
||||
res, err := voiceModel.VoiceVerify(ctx, &proto.VoiceVerifyRequest{
|
||||
Audio1: audio1,
|
||||
Audio2: audio2,
|
||||
Threshold: threshold,
|
||||
|
||||
@@ -52,7 +52,7 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
Documents: input.Documents,
|
||||
}
|
||||
|
||||
results, err := backend.Rerank(request, ml, appConfig, *cfg)
|
||||
results, err := backend.Rerank(c.Request().Context(), request, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := backend.Detection(image, input.Prompt, input.Points, input.Boxes, input.Threshold, ml, appConfig, *cfg)
|
||||
res, err := backend.Detection(c.Request().Context(), image, input.Prompt, input.Points, input.Boxes, input.Threshold, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func FaceAnalyzeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ap
|
||||
}
|
||||
|
||||
xlog.Debug("FaceAnalyze", "model", cfg.Name, "backend", cfg.Backend, "actions", input.Actions)
|
||||
res, err := backend.FaceAnalyze(img, input.Actions, input.AntiSpoofing, ml, appConfig, *cfg)
|
||||
res, err := backend.FaceAnalyze(c.Request().Context(), img, input.Actions, input.AntiSpoofing, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func FaceEmbedEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
|
||||
}
|
||||
|
||||
xlog.Debug("FaceEmbed", "model", cfg.Name, "backend", cfg.Backend)
|
||||
vec, err := backend.FaceEmbed(img, ml, appConfig, *cfg)
|
||||
vec, err := backend.FaceEmbed(c.Request().Context(), img, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func FaceIdentifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
|
||||
threshold := cmp.Or(input.Threshold, defaultIdentifyThreshold)
|
||||
|
||||
xlog.Debug("FaceIdentify", "model", cfg.Name, "topK", topK, "threshold", threshold)
|
||||
probe, err := backend.FaceEmbed(img, ml, appConfig, *cfg)
|
||||
probe, err := backend.FaceEmbed(c.Request().Context(), img, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func FaceRegisterEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
|
||||
}
|
||||
|
||||
xlog.Debug("FaceRegister", "model", cfg.Name, "name", input.Name)
|
||||
embedding, err := backend.FaceEmbed(img, ml, appConfig, *cfg)
|
||||
embedding, err := backend.FaceEmbed(c.Request().Context(), img, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func FaceVerifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
}
|
||||
|
||||
xlog.Debug("FaceVerify", "model", cfg.Name, "backend", cfg.Backend)
|
||||
res, err := backend.FaceVerify(img1, img2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg)
|
||||
res, err := backend.FaceVerify(c.Request().Context(), img1, img2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
|
||||
}
|
||||
xlog.Debug("Token Metrics for model", "model", modelFile)
|
||||
|
||||
response, err := backend.TokenMetrics(modelFile, ml, appConfig, *cfg)
|
||||
response, err := backend.TokenMetrics(c.Request().Context(), modelFile, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func VoiceAnalyzeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
|
||||
defer cleanup()
|
||||
|
||||
xlog.Debug("VoiceAnalyze", "model", cfg.Name, "backend", cfg.Backend, "actions", input.Actions)
|
||||
res, err := backend.VoiceAnalyze(audio, input.Actions, ml, appConfig, *cfg)
|
||||
res, err := backend.VoiceAnalyze(c.Request().Context(), audio, input.Actions, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func VoiceEmbedEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
defer cleanup()
|
||||
|
||||
xlog.Debug("VoiceEmbed", "model", cfg.Name, "backend", cfg.Backend)
|
||||
res, err := backend.VoiceEmbed(audio, ml, appConfig, *cfg)
|
||||
res, err := backend.VoiceEmbed(c.Request().Context(), audio, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func VoiceIdentifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
|
||||
threshold := cmp.Or(input.Threshold, defaultVoiceIdentifyThreshold)
|
||||
|
||||
xlog.Debug("VoiceIdentify", "model", cfg.Name, "topK", topK, "threshold", threshold)
|
||||
embed, err := backend.VoiceEmbed(audio, ml, appConfig, *cfg)
|
||||
embed, err := backend.VoiceEmbed(c.Request().Context(), audio, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func VoiceRegisterEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
|
||||
defer cleanup()
|
||||
|
||||
xlog.Debug("VoiceRegister", "model", cfg.Name, "name", input.Name)
|
||||
res, err := backend.VoiceEmbed(audio, ml, appConfig, *cfg)
|
||||
res, err := backend.VoiceEmbed(c.Request().Context(), audio, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func VoiceVerifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ap
|
||||
defer cleanup2()
|
||||
|
||||
xlog.Debug("VoiceVerify", "model", cfg.Name, "backend", cfg.Backend)
|
||||
res, err := backend.VoiceVerify(audio1, audio2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg)
|
||||
res, err := backend.VoiceVerify(c.Request().Context(), audio1, audio2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func DiarizationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ap
|
||||
_ = dstFile.Close()
|
||||
req.Audio = dst
|
||||
|
||||
result, err := backend.ModelDiarization(req, ml, *modelConfig, appConfig)
|
||||
result, err := backend.ModelDiarization(c.Request().Context(), req, ml, *modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user