mirror of
https://github.com/mudler/LocalAI.git
synced 2026-07-02 20:37:03 -04:00
Image generation (and the tts/transcript/embeddings/vad/rerank/llm helpers) pass the request context to loader.Load so distributed routing decisions reach the request's X-LocalAI-Node holder. That context also governs cancellation of the load, so when a client disconnects mid-load the LoadModel RPC is aborted, stopLoadProcess tears down the backend process, and every retry restarts from scratch. Heavy diffusers/LLM models on a slow host (e.g. a shared-memory iGPU) take long enough to load that the request routinely ends first, so the model never finishes loading and the UI shows "NetworkError when attempting to fetch resource". Wrap the load context with context.WithoutCancel: the routing holder value still propagates, but the request's cancellation no longer aborts the load, so it runs to completion and caches for the next request. Inference keeps the cancellable request context, so a disconnect still stops generation. Adds a regression spec asserting a canceled request context does not cancel the model load while the routing holder still reaches the router. Fixes #10636 Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Assisted-by: Claude:claude-opus-4-8 [Claude Code]
151 lines
4.3 KiB
Go
151 lines
4.3 KiB
Go
package backend
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/mudler/LocalAI/core/config"
|
|
"github.com/mudler/LocalAI/core/trace"
|
|
|
|
"github.com/mudler/LocalAI/pkg/grpc"
|
|
model "github.com/mudler/LocalAI/pkg/model"
|
|
)
|
|
|
|
// Embedder turns a prompt into an embedding vector of fixed dimension.
// The router's L2 embedding cache consumes these vectors to find past
// decisions whose prompts were semantically similar.
type Embedder interface {
	// Embed returns the embedding vector for text.
	Embed(ctx context.Context, text string) ([]float32, error)
}
|
|
|
|
// NewEmbedder binds (loader, modelConfig, appConfig) into an Embedder.
|
|
func NewEmbedder(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Embedder {
|
|
return &modelEmbedder{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
|
|
}
|
|
|
|
type modelEmbedder struct {
|
|
loader *model.ModelLoader
|
|
modelConfig config.ModelConfig
|
|
appConfig *config.ApplicationConfig
|
|
}
|
|
|
|
func (e *modelEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
|
|
fn, err := ModelEmbedding(ctx, text, nil, e.loader, e.modelConfig, e.appConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return fn()
|
|
}
|
|
|
|
func ModelEmbedding(ctx context.Context, s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
|
|
|
|
// model.WithContext carries the request context into the load so distributed
|
|
// routing decisions reach the request's X-LocalAI-Node holder via
|
|
// distributedhdr.Stamp. context.WithoutCancel keeps those values but drops
|
|
// the request's cancellation, so a slow first load still completes and
|
|
// caches if the client disconnects instead of aborting the LoadModel RPC and
|
|
// tearing down the backend process (issue #10636). Inference below keeps the
|
|
// cancellable ctx, so a disconnect still stops generation.
|
|
opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx)))
|
|
|
|
inferenceModel, err := loader.Load(opts...)
|
|
if err != nil {
|
|
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
|
return nil, err
|
|
}
|
|
|
|
var fn func() ([]float32, error)
|
|
switch model := inferenceModel.(type) {
|
|
case grpc.Backend:
|
|
fn = func() ([]float32, error) {
|
|
predictOptions := gRPCPredictOpts(modelConfig, loader.ModelPath)
|
|
if len(tokens) > 0 {
|
|
embeds := []int32{}
|
|
|
|
for _, t := range tokens {
|
|
embeds = append(embeds, int32(t))
|
|
}
|
|
predictOptions.EmbeddingTokens = embeds
|
|
|
|
res, err := model.Embeddings(appConfig.Context, predictOptions)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return res.Embeddings, nil
|
|
}
|
|
predictOptions.Embeddings = s
|
|
|
|
res, err := model.Embeddings(appConfig.Context, predictOptions)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return res.Embeddings, nil
|
|
}
|
|
default:
|
|
fn = func() ([]float32, error) {
|
|
return nil, fmt.Errorf("embeddings not supported by the backend")
|
|
}
|
|
}
|
|
|
|
wrappedFn := func() ([]float32, error) {
|
|
embeds, err := fn()
|
|
if err != nil {
|
|
return embeds, err
|
|
}
|
|
// Return embeddings as-is to preserve full dimensionality
|
|
// Trailing zeros may be valid values in some embedding models
|
|
return embeds, nil
|
|
}
|
|
|
|
if appConfig.EnableTracing {
|
|
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
|
|
|
traceData := map[string]any{
|
|
"input_text": trace.TruncateString(s, 1000),
|
|
}
|
|
// Only present for token-mode callers (pre-tokenized override);
|
|
// emitting "0" alongside input_text would read as "consumed zero
|
|
// tokens", which is wrong.
|
|
if len(tokens) > 0 {
|
|
traceData["input_tokens_count"] = len(tokens)
|
|
}
|
|
|
|
startTime := time.Now()
|
|
originalFn := wrappedFn
|
|
wrappedFn = func() ([]float32, error) {
|
|
result, err := originalFn()
|
|
duration := time.Since(startTime)
|
|
|
|
traceData["embedding_dimensions"] = len(result)
|
|
|
|
errStr := ""
|
|
if err != nil {
|
|
errStr = err.Error()
|
|
}
|
|
|
|
summary := trace.TruncateString(s, 200)
|
|
if summary == "" {
|
|
summary = fmt.Sprintf("tokens[%d]", len(tokens))
|
|
}
|
|
|
|
trace.RecordBackendTrace(trace.BackendTrace{
|
|
Timestamp: startTime,
|
|
Duration: duration,
|
|
Type: trace.BackendTraceEmbedding,
|
|
ModelName: modelConfig.Name,
|
|
Backend: modelConfig.Backend,
|
|
Summary: summary,
|
|
Error: errStr,
|
|
Data: traceData,
|
|
})
|
|
|
|
return result, err
|
|
}
|
|
}
|
|
|
|
return wrappedFn, nil
|
|
}
|