mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-29 19:19:19 -04:00
* feat(distributed): add per-request node ID context holder Introduce pkg/distributedhdr, a leaf package carrying a per-request *atomic.Value holder for the picked worker node ID from the SmartRouter (core/services/nodes) up to the HTTP response writer wrapper (core/http/middleware). Avoids the import cycle that a shared key in either consumer would create. Exposes NewHolder, WithHolder, Holder, Stamp, Load, Inherit. The holder is atomic.Value so cross-goroutine publish from the router to the response writer wrapper is race-clean. Assisted-by: Claude:claude-opus-4-7[1m] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(distributed): add ExposeNodeHeader middleware + response writer wrapper New ApplicationConfig.ExposeNodeHeader bool + --expose-node-header CLI flag / LOCALAI_EXPOSE_NODE_HEADER env var (default off; the node ID reveals internal topology and is opt-in). The middleware creates a per-request *atomic.Value holder, attaches it to c.Request().Context() via distributedhdr.WithHolder, and wraps c.Response().Writer with a custom http.ResponseWriter that sets the X-LocalAI-Node header on first Write / WriteHeader / Flush by reading the holder. Implements http.Flusher, http.Hijacker, Unwrap so it composes cleanly with Echo and http.NewResponseController. request.go propagates the holder onto derived contexts via distributedhdr.Inherit so the holder survives the correlation-ID context replacement. Unit + race-clean concurrency + integration specs. Assisted-by: Claude:claude-opus-4-7[1m] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(distributed): stamp node ID in router and wire middleware to inference routes ModelRouterAdapter.Route stamps the picked node ID into the per-request holder via distributedhdr.Stamp(ctx, result.Node.ID) right after replica selection. 
Wire ExposeNodeHeader middleware to: - OpenAI chat/completion/embeddings + audio transcriptions/speech + image generations/inpainting - Anthropic /v1/messages - Ollama /api/chat, /api/generate, /api/embed, /api/embeddings - Jina /v1/rerank - LocalAI /v1/vad The middleware's wrapper reads the holder on first byte and sets the X-LocalAI-Node response header before delegating to the underlying writer. Per-request scope means no race under concurrent multi-replica routing. Assisted-by: Claude:claude-opus-4-7[1m] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(distributed): thread request context through backend Load + cover ctx propagation Five non-OpenAI backend helpers were silently using app.Context instead of the request context for the gRPC backend call: transcription, TTS, image generation, rerank, VAD. Effect: distributedhdr.Stamp in the router callback was a silent no-op for these paths, AND client cancellation didn't propagate to in-flight inference. Thread c.Request().Context() (or the equivalent input.Context after the request middleware has installed the correlation-ID derived context) through each helper and into ModelOptions via model.WithContext(ctx). ImageGeneration's signature gains a leading ctx parameter; in-tree callers (openai image, openai inpainting, openai inpainting_test) are updated to match. ModelEmbedding gains a leading ctx parameter for the same reason; the openai and ollama embedding handlers pass the request context through. chat_stream_workers.go defers the initial role=assistant chunk emission until the first token callback so the wrapper's lazy X-LocalAI-Node lookup against the loader runs AFTER ml.Load has stamped the per-modelID node ID; semantically identical for clients (role still arrives before any text). Regression test core/backend/ctx_propagation_test.go pins ctx propagation for all five helpers. Docs updated to enumerate the full endpoint coverage of the --expose-node-header flag. 
Assisted-by: Claude:claude-opus-4-7[1m] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
214 lines
6.8 KiB
Go
214 lines
6.8 KiB
Go
package backend
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"maps"
|
|
"time"
|
|
|
|
"github.com/mudler/LocalAI/core/config"
|
|
"github.com/mudler/LocalAI/core/schema"
|
|
"github.com/mudler/LocalAI/core/trace"
|
|
|
|
grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
|
|
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
"github.com/mudler/LocalAI/pkg/model"
|
|
)
|
|
|
|
// TranscriptionRequest groups the parameters of a single transcription call,
// as accepted by ModelTranscriptionWithOptions and ModelTranscriptionStream.
// Use this so callers don't have to pass long positional arg lists when they
// only care about a subset of fields.
//
// The positional ModelTranscription wrapper only fills Audio, Language,
// Translate, Diarize and Prompt; Temperature and TimestampGranularities stay at
// their zero values on that path.
type TranscriptionRequest struct {
	// Audio is the path of the audio file to transcribe; toProto sends it to
	// the backend as TranscriptRequest.Dst.
	Audio string

	// The remaining fields are copied one-to-one into the same-named
	// proto.TranscriptRequest fields by toProto.
	Language               string
	Translate              bool
	Diarize                bool
	Prompt                 string
	Temperature            float32
	TimestampGranularities []string
}
|
|
|
|
func (r *TranscriptionRequest) toProto(threads uint32) *proto.TranscriptRequest {
|
|
return &proto.TranscriptRequest{
|
|
Dst: r.Audio,
|
|
Language: r.Language,
|
|
Translate: r.Translate,
|
|
Diarize: r.Diarize,
|
|
Threads: threads,
|
|
Prompt: r.Prompt,
|
|
Temperature: r.Temperature,
|
|
TimestampGranularities: r.TimestampGranularities,
|
|
}
|
|
}
|
|
|
|
func loadTranscriptionModel(ctx context.Context, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (grpcPkg.Backend, error) {
|
|
if modelConfig.Backend == "" {
|
|
modelConfig.Backend = model.WhisperBackend
|
|
}
|
|
// model.WithContext(ctx) overrides the app-context default set in
|
|
// ModelOptions so distributed routing decisions reach the request's
|
|
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
|
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
|
transcriptionModel, err := ml.Load(opts...)
|
|
if err != nil {
|
|
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
|
return nil, err
|
|
}
|
|
if transcriptionModel == nil {
|
|
return nil, fmt.Errorf("could not load transcription model")
|
|
}
|
|
return transcriptionModel, nil
|
|
}
|
|
|
|
func ModelTranscription(ctx context.Context, audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
|
return ModelTranscriptionWithOptions(ctx, TranscriptionRequest{
|
|
Audio: audio,
|
|
Language: language,
|
|
Translate: translate,
|
|
Diarize: diarize,
|
|
Prompt: prompt,
|
|
}, ml, modelConfig, appConfig)
|
|
}
|
|
|
|
func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
|
transcriptionModel, err := loadTranscriptionModel(ctx, ml, modelConfig, appConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var startTime time.Time
|
|
var audioSnippet map[string]any
|
|
if appConfig.EnableTracing {
|
|
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
|
startTime = time.Now()
|
|
// Capture audio before the backend call — the backend may delete the file.
|
|
audioSnippet = trace.AudioSnippet(req.Audio, appConfig.TracingMaxBodyBytes)
|
|
}
|
|
|
|
r, err := transcriptionModel.AudioTranscription(ctx, req.toProto(uint32(*modelConfig.Threads)))
|
|
if err != nil {
|
|
if appConfig.EnableTracing {
|
|
errData := map[string]any{
|
|
"audio_file": req.Audio,
|
|
"language": req.Language,
|
|
"translate": req.Translate,
|
|
"diarize": req.Diarize,
|
|
"prompt": req.Prompt,
|
|
}
|
|
if audioSnippet != nil {
|
|
maps.Copy(errData, audioSnippet)
|
|
}
|
|
trace.RecordBackendTrace(trace.BackendTrace{
|
|
Timestamp: startTime,
|
|
Duration: time.Since(startTime),
|
|
Type: trace.BackendTraceTranscription,
|
|
ModelName: modelConfig.Name,
|
|
Backend: modelConfig.Backend,
|
|
Summary: trace.TruncateString(req.Audio, 200),
|
|
Error: err.Error(),
|
|
Data: errData,
|
|
})
|
|
}
|
|
return nil, err
|
|
}
|
|
tr := transcriptResultFromProto(r)
|
|
|
|
if appConfig.EnableTracing {
|
|
data := map[string]any{
|
|
"audio_file": req.Audio,
|
|
"language": req.Language,
|
|
"translate": req.Translate,
|
|
"diarize": req.Diarize,
|
|
"prompt": req.Prompt,
|
|
"result_text": tr.Text,
|
|
"segments_count": len(tr.Segments),
|
|
}
|
|
if audioSnippet != nil {
|
|
maps.Copy(data, audioSnippet)
|
|
}
|
|
trace.RecordBackendTrace(trace.BackendTrace{
|
|
Timestamp: startTime,
|
|
Duration: time.Since(startTime),
|
|
Type: trace.BackendTraceTranscription,
|
|
ModelName: modelConfig.Name,
|
|
Backend: modelConfig.Backend,
|
|
Summary: trace.TruncateString(req.Audio+" -> "+tr.Text, 200),
|
|
Data: data,
|
|
})
|
|
}
|
|
|
|
return tr, err
|
|
}
|
|
|
|
// TranscriptionStreamChunk is a streaming event emitted by
// ModelTranscriptionStream for each non-nil response the backend sends.
//
// Delta always carries the response's incremental text fragment, which may be
// empty. Final is non-nil only when the backend attached a completed
// transcription to that response. Presumably this is the terminal event of the
// stream; that is a backend contract, not something enforced here.
type TranscriptionStreamChunk struct {
	Delta string
	Final *schema.TranscriptionResult
}
|
|
|
|
// ModelTranscriptionStream runs the gRPC streaming transcription RPC and
|
|
// invokes onChunk for each event the backend produces. Backends that don't
|
|
// support real streaming should still emit one terminal event with Final set,
|
|
// which the HTTP layer turns into a single delta + done SSE pair.
|
|
func ModelTranscriptionStream(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
|
|
transcriptionModel, err := loadTranscriptionModel(ctx, ml, modelConfig, appConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
pbReq := req.toProto(uint32(*modelConfig.Threads))
|
|
pbReq.Stream = true
|
|
|
|
return transcriptionModel.AudioTranscriptionStream(ctx, pbReq, func(chunk *proto.TranscriptStreamResponse) {
|
|
if chunk == nil {
|
|
return
|
|
}
|
|
out := TranscriptionStreamChunk{Delta: chunk.Delta}
|
|
if chunk.FinalResult != nil {
|
|
out.Final = transcriptResultFromProto(chunk.FinalResult)
|
|
}
|
|
onChunk(out)
|
|
})
|
|
}
|
|
|
|
func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionResult {
|
|
if r == nil {
|
|
return &schema.TranscriptionResult{}
|
|
}
|
|
tr := &schema.TranscriptionResult{
|
|
Text: r.Text,
|
|
Language: r.Language,
|
|
Duration: float64(r.Duration),
|
|
}
|
|
|
|
for _, s := range r.Segments {
|
|
var tks []int
|
|
for _, t := range s.Tokens {
|
|
tks = append(tks, int(t))
|
|
}
|
|
var words []schema.TranscriptionWord
|
|
for _, w := range s.Words {
|
|
var word = schema.TranscriptionWord{
|
|
Start: time.Duration(w.Start),
|
|
End: time.Duration(w.End),
|
|
Text: w.Text,
|
|
}
|
|
words = append(words, word)
|
|
tr.Words = append(tr.Words, word)
|
|
}
|
|
tr.Segments = append(tr.Segments,
|
|
schema.TranscriptionSegment{
|
|
Text: s.Text,
|
|
Id: int(s.Id),
|
|
Start: time.Duration(s.Start),
|
|
End: time.Duration(s.End),
|
|
Tokens: tks,
|
|
Speaker: s.Speaker,
|
|
Words: words,
|
|
})
|
|
}
|
|
return tr
|
|
}
|