Files
LocalAI/core/backend/ctx_propagation_test.go
LocalAI [bot] 27e63b9a78 feat(tts): support per-request instructions and params (#10172)
The OpenAI-compatible TTS endpoint accepts an `instructions` field, but it
was silently dropped at the HTTP->gRPC boundary: neither schema.TTSRequest
nor the gRPC TTSRequest proto carried it, so backends could only read such a
value from static YAML options (identical for every request). This blocked
per-line emotion/style and, for Qwen3-TTS VoiceDesign, limited a model config
to a single designed voice.

Plumb a generic per-request instruction string end to end, plus an optional
backend-specific params map:

- proto: add `optional string instructions` and `map<string,string> params`
  to TTSRequest.
- schema: add Instructions (maps OpenAI `instructions`) and Params (LocalAI
  extension) to schema.TTSRequest.
- core: thread both through ModelTTS/ModelTTSStream via a newTTSRequest helper
  that attaches instructions only when non-empty (so backends can fall back to
  YAML when unset); forward them from the /v1/audio/speech handler.
- qwen-tts: prefer the per-request instruction over the YAML `instruct` option
  (used by both mode detection and generation) and merge per-request params.
- chatterbox: merge per-request params (coerced to float/int/bool) over YAML
  options into generate() kwargs.

Fully backward compatible: empty instructions fall back to the YAML option and
backends that don't support style/voice instructions ignore the field.

Closes #10164


Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-04 11:45:02 +02:00

170 lines
6.6 KiB
Go

package backend_test
// Regression spec for X-LocalAI-Node coverage on audio/image/TTS/rerank/VAD.
//
// The X-LocalAI-Node middleware (core/http/middleware.ExposeNodeHeader)
// works end-to-end only if the per-request holder attached to the HTTP
// request context reaches the SmartRouter via ml.Load(opts...). The chain
// is:
//
// handler -> backend.Foo(ctx, ...) -> ModelOptions(cfg, app, WithContext(ctx))
// -> ml.Load(opts...) -> grpcModel(..., o.context) -> modelRouter(ctx, ...)
// -> SmartRouter -> distributedhdr.Stamp(ctx, nodeID)
//
// If any backend helper drops `ctx` and lets ModelOptions fall back to the
// app context, the router never sees the per-request holder and the
// header silently stays empty for that endpoint. These specs pin the
// request-context-reaches-router contract for the five backend helpers
// that were previously dropping ctx between the handler and Load.
import (
"context"
"sync/atomic"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
pbproto "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/distributedhdr"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// newCapturingLoader builds a ModelLoader whose model router is a stub: the
// stub records the context it is invoked with and then fails with
// errRouterShortCircuit, so the backend under test stops before dialing gRPC.
//
// The router callback is the seam where the production SmartRouter would call
// distributedhdr.Stamp, so inspecting the recorded context here stands in for
// inspecting it at the real router.
//
// It returns the loader, the atomic.Value the context is recorded into, and an
// accessor that yields the recorded context (nil when nothing was recorded).
//
// NOTE(review): atomic.Value panics when successive Store calls use different
// concrete types; this relies on the router seeing one context implementation
// per loader — confirm if a backend ever retries with a re-wrapped context.
func newCapturingLoader() (*model.ModelLoader, *atomic.Value, func() context.Context) {
	ml := model.NewModelLoader(&system.SystemState{})
	seen := new(atomic.Value)

	ml.SetModelRouter(func(ctx context.Context, _ string, _, _, _ string, _ *pbproto.ModelOptions, _ bool) (*model.Model, error) {
		seen.Store(ctx)
		// Only the context-arrival contract matters; fail fast so the
		// backend never attempts a gRPC dial.
		return nil, errRouterShortCircuit
	})

	lastCtx := func() context.Context {
		if ctx, ok := seen.Load().(context.Context); ok {
			return ctx
		}
		// Router was never invoked (or stored a non-context value).
		return nil
	}

	return ml, seen, lastCtx
}
// sentinelErr is a string-backed error type; the string itself is the
// error message.
type sentinelErr string

// Error implements the error interface by returning the underlying string.
func (e sentinelErr) Error() string {
	return string(e)
}

// errRouterShortCircuit is the error the stub model router returns so the
// backend under test aborts right after the router has been invoked.
var errRouterShortCircuit = sentinelErr("router short-circuit (test)")
// newAppCfg returns an ApplicationConfig built with an empty SystemState as
// its only option.
func newAppCfg() *config.ApplicationConfig {
	state := &system.SystemState{}
	return config.NewApplicationConfig(config.WithSystemState(state))
}
// newModelCfg returns the ModelConfig shared by every spec: a model named
// "test-model" on backend "stub-backend", one thread, model file "test.bin".
func newModelCfg() config.ModelConfig {
	singleThread := 1

	var mc config.ModelConfig
	mc.Name = "test-model"
	mc.Backend = "stub-backend"
	mc.Threads = &singleThread
	mc.Model = "test.bin"

	return mc
}
var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
	const fakeNodeID = "node-ctx-propagation-7"

	// Per-spec fixtures, rebuilt in BeforeEach.
	var (
		app          *config.ApplicationConfig
		mcfg         config.ModelConfig
		ml           *model.ModelLoader
		seenByRouter func() context.Context
		reqHolder    *atomic.Value
		ctx          context.Context
	)

	BeforeEach(func() {
		app = newAppCfg()
		mcfg = newModelCfg()
		ml, _, seenByRouter = newCapturingLoader()

		// The request context carries a fresh holder, mirroring what the
		// HTTP middleware attaches per request.
		reqHolder = distributedhdr.NewHolder()
		ctx = distributedhdr.WithHolder(context.Background(), reqHolder)
	})

	// expectShortCircuit asserts that the backend call failed with the stub
	// router's sentinel message, i.e. the router was reached and the call
	// stopped there. Offsets make failures point at the calling spec.
	expectShortCircuit := func(err error) {
		ExpectWithOffset(1, err).To(HaveOccurred())
		ExpectWithOffset(1, err.Error()).To(ContainSubstring("router short-circuit (test)"))
	}

	// stampViaRouterCtx proves the context captured by the router carries
	// the very holder attached to the request: it stamps a node ID through
	// the router-side context and reads it back through the request-side
	// holder. Had the two been different objects, the read would not yield
	// the stamped ID.
	stampViaRouterCtx := func() {
		captured := seenByRouter()
		Expect(captured).NotTo(BeNil(), "router callback must have been invoked")

		distributedhdr.Stamp(captured, fakeNodeID)

		Expect(distributedhdr.Load(reqHolder)).To(Equal(fakeNodeID),
			"stamp via router-side ctx must be observable via the request-side holder")
	}

	It("Rerank forwards the request context to the SmartRouter", func() {
		_, err := backend.Rerank(ctx, &pbproto.RerankRequest{Query: "q"}, ml, app, mcfg)
		expectShortCircuit(err)
		stampViaRouterCtx()
	})

	It("VAD forwards the request context to the SmartRouter", func() {
		_, err := backend.VAD(&schema.VADRequest{}, ctx, ml, app, mcfg)
		expectShortCircuit(err)
		stampViaRouterCtx()
	})

	It("ModelTTS forwards the request context to the SmartRouter", func() {
		_, _, err := backend.ModelTTS(ctx, "hello", "", "", "", nil, ml, app, mcfg)
		expectShortCircuit(err)
		stampViaRouterCtx()
	})

	It("ModelTTSStream forwards the request context to the SmartRouter", func() {
		discardAudio := func([]byte) error { return nil }
		err := backend.ModelTTSStream(ctx, "hello", "", "", "", nil, ml, app, mcfg, discardAudio)
		expectShortCircuit(err)
		stampViaRouterCtx()
	})

	It("ModelTranscriptionWithOptions forwards the request context to the SmartRouter", func() {
		req := backend.TranscriptionRequest{Audio: "x.wav"}
		_, err := backend.ModelTranscriptionWithOptions(ctx, req, ml, mcfg, app)
		expectShortCircuit(err)
		stampViaRouterCtx()
	})

	It("ModelTranscriptionStream forwards the request context to the SmartRouter", func() {
		req := backend.TranscriptionRequest{Audio: "x.wav"}
		discardChunk := func(backend.TranscriptionStreamChunk) {}
		err := backend.ModelTranscriptionStream(ctx, req, ml, mcfg, app, discardChunk)
		expectShortCircuit(err)
		stampViaRouterCtx()
	})

	It("ImageGeneration forwards the request context to the SmartRouter", func() {
		_, err := backend.ImageGeneration(ctx, 64, 64, 1, 0, "p", "", "", "/tmp/out.png", ml, mcfg, app, nil)
		expectShortCircuit(err)
		stampViaRouterCtx()
	})

	It("does NOT leak the holder when the app context is used instead", func() {
		// Sanity check on the failure mode these specs guard against: the
		// router receiving appCfg.Context instead of the request context.
		// The app context must carry no holder, otherwise the specs above
		// could pass even when the request context is dropped.
		Expect(distributedhdr.Holder(app.Context)).To(BeNil(),
			"the app context must not be the carrier of per-request holders")
	})
})