mirror of
https://github.com/mudler/LocalAI.git
synced 2026-07-03 04:46:54 -04:00
Image generation (and the tts/transcript/embeddings/vad/rerank/llm helpers) pass the request context to loader.Load so distributed routing decisions reach the request's X-LocalAI-Node holder. That context also governs cancellation of the load, so when a client disconnects mid-load the LoadModel RPC is aborted, stopLoadProcess tears down the backend process, and every retry restarts from scratch. Heavy diffusers/LLM models on a slow host (e.g. a shared-memory iGPU) take long enough to load that the request routinely ends first, so the model never finishes loading and the UI shows "NetworkError when attempting to fetch resource". Wrap the load context with context.WithoutCancel: the routing holder value still propagates, but the request's cancellation no longer aborts the load, so it runs to completion and caches for the next request. Inference keeps the cancellable request context, so a disconnect still stops generation. Adds a regression spec asserting a canceled request context does not cancel the model load while the routing holder still reaches the router. Fixes #10636 Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Assisted-by: Claude:claude-opus-4-8 [Claude Code]
219 lines
7.1 KiB
Go
219 lines
7.1 KiB
Go
package backend
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"maps"
|
|
"time"
|
|
|
|
"github.com/mudler/LocalAI/core/config"
|
|
"github.com/mudler/LocalAI/core/schema"
|
|
"github.com/mudler/LocalAI/core/trace"
|
|
|
|
grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
|
|
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
"github.com/mudler/LocalAI/pkg/model"
|
|
)
|
|
|
|
// TranscriptionRequest bundles the per-request inputs for a transcription.
// ModelTranscriptionWithOptions and ModelTranscriptionStream take it directly,
// so callers set only the fields they care about instead of threading a long
// positional argument list; ModelTranscription builds one from its positional
// arguments.
type TranscriptionRequest struct {
	// Audio is the path of the audio file sent to the backend (as Dst).
	Audio string
	// Language is forwarded to the backend unchanged.
	Language string
	// Translate is forwarded to the backend unchanged.
	Translate bool
	// Diarize is forwarded to the backend unchanged.
	Diarize bool
	// Prompt is forwarded to the backend unchanged.
	Prompt string
	// Temperature is forwarded to the backend unchanged.
	Temperature float32
	// TimestampGranularities is forwarded to the backend unchanged.
	TimestampGranularities []string
}
|
|
|
|
func (r *TranscriptionRequest) toProto(threads uint32) *proto.TranscriptRequest {
|
|
return &proto.TranscriptRequest{
|
|
Dst: r.Audio,
|
|
Language: r.Language,
|
|
Translate: r.Translate,
|
|
Diarize: r.Diarize,
|
|
Threads: threads,
|
|
Prompt: r.Prompt,
|
|
Temperature: r.Temperature,
|
|
TimestampGranularities: r.TimestampGranularities,
|
|
}
|
|
}
|
|
|
|
func loadTranscriptionModel(ctx context.Context, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (grpcPkg.Backend, error) {
|
|
if modelConfig.Backend == "" {
|
|
modelConfig.Backend = model.WhisperBackend
|
|
}
|
|
// model.WithContext carries the request context into the load so distributed
|
|
// routing decisions reach the request's X-LocalAI-Node holder via
|
|
// distributedhdr.Stamp. context.WithoutCancel keeps those values but drops
|
|
// the request's cancellation, so a slow first load still completes and
|
|
// caches if the client disconnects instead of aborting the LoadModel RPC and
|
|
// tearing down the backend process (issue #10636). Inference below keeps the
|
|
// cancellable ctx, so a disconnect still stops generation.
|
|
opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx)))
|
|
transcriptionModel, err := ml.Load(opts...)
|
|
if err != nil {
|
|
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
|
return nil, err
|
|
}
|
|
if transcriptionModel == nil {
|
|
return nil, fmt.Errorf("could not load transcription model")
|
|
}
|
|
return transcriptionModel, nil
|
|
}
|
|
|
|
func ModelTranscription(ctx context.Context, audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
|
return ModelTranscriptionWithOptions(ctx, TranscriptionRequest{
|
|
Audio: audio,
|
|
Language: language,
|
|
Translate: translate,
|
|
Diarize: diarize,
|
|
Prompt: prompt,
|
|
}, ml, modelConfig, appConfig)
|
|
}
|
|
|
|
func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
|
transcriptionModel, err := loadTranscriptionModel(ctx, ml, modelConfig, appConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var startTime time.Time
|
|
var audioSnippet map[string]any
|
|
if appConfig.EnableTracing {
|
|
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
|
startTime = time.Now()
|
|
// Capture audio before the backend call — the backend may delete the file.
|
|
audioSnippet = trace.AudioSnippet(req.Audio, appConfig.TracingMaxBodyBytes)
|
|
}
|
|
|
|
r, err := transcriptionModel.AudioTranscription(ctx, req.toProto(uint32(*modelConfig.Threads)))
|
|
if err != nil {
|
|
if appConfig.EnableTracing {
|
|
errData := map[string]any{
|
|
"audio_file": req.Audio,
|
|
"language": req.Language,
|
|
"translate": req.Translate,
|
|
"diarize": req.Diarize,
|
|
"prompt": req.Prompt,
|
|
}
|
|
if audioSnippet != nil {
|
|
maps.Copy(errData, audioSnippet)
|
|
}
|
|
trace.RecordBackendTrace(trace.BackendTrace{
|
|
Timestamp: startTime,
|
|
Duration: time.Since(startTime),
|
|
Type: trace.BackendTraceTranscription,
|
|
ModelName: modelConfig.Name,
|
|
Backend: modelConfig.Backend,
|
|
Summary: trace.TruncateString(req.Audio, 200),
|
|
Error: err.Error(),
|
|
Data: errData,
|
|
})
|
|
}
|
|
return nil, err
|
|
}
|
|
tr := transcriptResultFromProto(r)
|
|
|
|
if appConfig.EnableTracing {
|
|
data := map[string]any{
|
|
"audio_file": req.Audio,
|
|
"language": req.Language,
|
|
"translate": req.Translate,
|
|
"diarize": req.Diarize,
|
|
"prompt": req.Prompt,
|
|
"result_text": tr.Text,
|
|
"segments_count": len(tr.Segments),
|
|
}
|
|
if audioSnippet != nil {
|
|
maps.Copy(data, audioSnippet)
|
|
}
|
|
trace.RecordBackendTrace(trace.BackendTrace{
|
|
Timestamp: startTime,
|
|
Duration: time.Since(startTime),
|
|
Type: trace.BackendTraceTranscription,
|
|
ModelName: modelConfig.Name,
|
|
Backend: modelConfig.Backend,
|
|
Summary: trace.TruncateString(req.Audio+" -> "+tr.Text, 200),
|
|
Data: data,
|
|
})
|
|
}
|
|
|
|
return tr, err
|
|
}
|
|
|
|
// TranscriptionStreamChunk is a streaming event emitted by
|
|
// ModelTranscriptionStream. Either Delta carries an incremental text fragment,
|
|
// or Final carries the completed transcription as the very last event.
|
|
type TranscriptionStreamChunk struct {
|
|
Delta string
|
|
Final *schema.TranscriptionResult
|
|
}
|
|
|
|
// ModelTranscriptionStream runs the gRPC streaming transcription RPC and
|
|
// invokes onChunk for each event the backend produces. Backends that don't
|
|
// support real streaming should still emit one terminal event with Final set,
|
|
// which the HTTP layer turns into a single delta + done SSE pair.
|
|
func ModelTranscriptionStream(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
|
|
transcriptionModel, err := loadTranscriptionModel(ctx, ml, modelConfig, appConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
pbReq := req.toProto(uint32(*modelConfig.Threads))
|
|
pbReq.Stream = true
|
|
|
|
return transcriptionModel.AudioTranscriptionStream(ctx, pbReq, func(chunk *proto.TranscriptStreamResponse) {
|
|
if chunk == nil {
|
|
return
|
|
}
|
|
out := TranscriptionStreamChunk{Delta: chunk.Delta}
|
|
if chunk.FinalResult != nil {
|
|
out.Final = transcriptResultFromProto(chunk.FinalResult)
|
|
}
|
|
onChunk(out)
|
|
})
|
|
}
|
|
|
|
func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionResult {
|
|
if r == nil {
|
|
return &schema.TranscriptionResult{}
|
|
}
|
|
tr := &schema.TranscriptionResult{
|
|
Text: r.Text,
|
|
Language: r.Language,
|
|
Duration: float64(r.Duration),
|
|
Eou: r.Eou,
|
|
}
|
|
|
|
for _, s := range r.Segments {
|
|
var tks []int
|
|
for _, t := range s.Tokens {
|
|
tks = append(tks, int(t))
|
|
}
|
|
var words []schema.TranscriptionWord
|
|
for _, w := range s.Words {
|
|
var word = schema.TranscriptionWord{
|
|
Start: time.Duration(w.Start),
|
|
End: time.Duration(w.End),
|
|
Text: w.Text,
|
|
}
|
|
words = append(words, word)
|
|
tr.Words = append(tr.Words, word)
|
|
}
|
|
tr.Segments = append(tr.Segments,
|
|
schema.TranscriptionSegment{
|
|
Text: s.Text,
|
|
Id: int(s.Id),
|
|
Start: time.Duration(s.Start),
|
|
End: time.Duration(s.End),
|
|
Tokens: tks,
|
|
Speaker: s.Speaker,
|
|
Words: words,
|
|
})
|
|
}
|
|
return tr
|
|
}
|