mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-13 03:09:03 -04:00
fix(router): production-ready request router + auto-size batch for embedding/rerank (#10104)
* fix(router): score classifier production-readiness Conversation trimming runs through the classifier model's chat template and trims by exact token count, sized to the model's n_batch which is now scaled to context so long probes can't crash the backend. Missing chat_message templates are a hard error at router build time. Router-facing factories (Embedder/Scorer/Reranker/TokenCounter) re-resolve ModelConfig per call so a model installed post-startup doesn't bind a stub Backend="" config and silently fall into the loader's auto-iterate path. New 'vector_store' backend trace recorded inside localVectorStore on every Search/Insert — including the backend-load-failure path that previously vanished into an xlog.Warn — with outcome tagging (hit/miss/empty_store/backend_load_error/find_error/insert_error/ok). Companion cleanup drops misleading similarity:0 and input_tokens_count:0 from non-hit and text-mode traces. Gallery local-store-development aliases to 'local-store' so the master image satisfies pkg/model.LocalStoreBackend lookups from the embedding cache. Misc: llama-cpp TokenizeString reads the correct 'prompt' JSON key (the original bug); ModelTokenize nil-guard; non-fatal mitm proxy startup; PII 'route_local' renamed to 'allow' with docs/UI in sync; model-editor footer no longer eats the edit area on small screens; several config-editor template/dropdown/section fixes. Tests: e2e router specs (casual/code-hint + long-conversation trim), vector_store trace specs, lazy-factory specs, gallery dev-alias resolution, Playwright trace badge + scroll regression. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(backend): auto-size batch to context for embedding and rerank models Embedding and rerank models pool over the whole input in a single physical batch (n_ubatch). With batch left at the 512 default, the backend rejects longer inputs with "input is too large to process", silently capping a large-context embedder (e.g. 8k/32k) at 512 tokens. Size n_batch to the context for these single-pass usecases, mirroring the existing FLAG_SCORE behaviour; an explicit batch: still wins. Extracts EffectiveContextSize/EffectiveBatchSize from grpcModelOpts so the effective decode window has one home for other callers to reuse. Adds an e2e-aio regression test that embeds a >512-token input. The AIO embedding model is switched to nomic-embed-text-v1.5 (2048 context) because the previous granite model was capped at 512 tokens and could not exercise the larger batch. Assisted-by: claude-code:claude-opus-4-8 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(gallery): raise arch-router scoring output cap via parallel:64 Scoring decodes the whole prompt+candidate in a single llama_decode and reads one logit row per candidate token. The vendored llama.cpp server caps causal output rows at n_parallel, so the default of 1 aborts with GGML_ASSERT(n_outputs_max <= cparams.n_outputs_max) on multi-token route labels. Set options: [parallel:64] on both arch-router quant entries to lift the cap; kv_unified (the grpc-server default) keeps the full context per sequence, so this does not split the KV cache. Assisted-by: claude-code:claude-opus-4-8 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com>

--------- Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
committed by
GitHub
parent
56cc4f63fc
commit
085fc53bbc
@@ -11,6 +11,29 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// startMITMIfConfigured brings up the cloudproxy MITM listener when an
|
||||
// address is configured, treating any startup failure as non-fatal.
|
||||
//
|
||||
// The listener is opt-in middleware whose address is persisted in runtime
|
||||
// settings (/api/settings → runtime_settings.json) and replayed on every
|
||||
// boot. A bad value — e.g. a host the process can't bind, like a LAN IP
|
||||
// inside a container — must NOT abort the whole server: doing so crash-loops
|
||||
// with no way out, because the Settings UI used to correct the address can't
|
||||
// load if startup never completes. So on failure we log loudly and carry on;
|
||||
// the admin fixes the address via /api/settings, which calls RestartMITM.
|
||||
func startMITMIfConfigured(app *Application, options *config.ApplicationConfig) {
|
||||
if options.MITMListen == "" {
|
||||
return
|
||||
}
|
||||
if err := startMITMProxy(app, options); err != nil {
|
||||
xlog.Error("mitm: cloudproxy listener failed to start — continuing without it",
|
||||
"listen", options.MITMListen,
|
||||
"error", err,
|
||||
"hint", "fix the address via Settings (e.g. \":8082\" to bind all interfaces) and the listener will restart",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func startMITMProxy(app *Application, options *config.ApplicationConfig) error {
|
||||
app.mitmMutex.Lock()
|
||||
defer app.mitmMutex.Unlock()
|
||||
|
||||
58
core/application/mitm_test.go
Normal file
58
core/application/mitm_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// minimal Application wired enough for startMITMProxy: an empty model
|
||||
// config loader (no host claims), CA written under a temp DataPath.
|
||||
func newMITMTestApp(dataPath string) (*Application, *config.ApplicationConfig) {
|
||||
state, err := system.GetSystemState()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
state.Model.ModelsPath = dataPath
|
||||
opts := config.NewApplicationConfig(
|
||||
config.WithSystemState(state),
|
||||
config.WithDataPath(dataPath),
|
||||
)
|
||||
return newApplication(opts), opts
|
||||
}
|
||||
|
||||
var _ = Describe("startMITMIfConfigured", func() {
|
||||
It("does nothing when no listen address is configured", func() {
|
||||
app, opts := newMITMTestApp(GinkgoT().TempDir())
|
||||
opts.MITMListen = ""
|
||||
|
||||
Expect(func() { startMITMIfConfigured(app, opts) }).NotTo(Panic())
|
||||
Expect(app.mitmServer.Load()).To(BeNil(), "no listener should be stored when disabled")
|
||||
})
|
||||
|
||||
// Regression: a persisted-but-unbindable MITM address (e.g. a LAN host
|
||||
// inside a container) must not abort startup. startMITMIfConfigured
|
||||
// swallows the bind error so the rest of LocalAI still comes up and the
|
||||
// admin can fix the address via the Settings UI.
|
||||
It("logs and continues when the listen address cannot be bound", func() {
|
||||
app, opts := newMITMTestApp(GinkgoT().TempDir())
|
||||
// 192.0.2.1 is TEST-NET-1 (RFC 5737): guaranteed not assigned to any
|
||||
// local interface, so bind fails deterministically without DNS.
|
||||
opts.MITMListen = "192.0.2.1:8082"
|
||||
|
||||
Expect(func() { startMITMIfConfigured(app, opts) }).NotTo(Panic())
|
||||
Expect(app.mitmServer.Load()).To(BeNil(), "failed listener must not be stored")
|
||||
})
|
||||
|
||||
It("starts and stores the listener on a bindable address", func() {
|
||||
app, opts := newMITMTestApp(GinkgoT().TempDir())
|
||||
opts.MITMListen = "127.0.0.1:0" // OS-assigned free port
|
||||
|
||||
startMITMIfConfigured(app, opts)
|
||||
|
||||
srv := app.mitmServer.Load()
|
||||
Expect(srv).NotTo(BeNil(), "listener should be stored on success")
|
||||
DeferCleanup(srv.Stop)
|
||||
Expect(srv.Addr()).NotTo(BeEmpty())
|
||||
})
|
||||
})
|
||||
@@ -1,63 +1,120 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
// adapterConfig resolves a model name to its runtime ModelConfig, or
|
||||
// nil when the name is unknown. Shared by the router-facing factories
|
||||
// below and by ModelConfigLookup.
|
||||
// adapterConfig resolves a model name to its runtime ModelConfig, or nil when
|
||||
// unknown. LoadModelConfigFileByNameDefaultOptions never returns nil — for an
|
||||
// unknown name it returns a defaults-filled stub with an empty Name (the YAML
|
||||
// `name:` field is required by Validate), which is how we tell the two apart.
|
||||
func (a *Application) adapterConfig(modelName string) *config.ModelConfig {
|
||||
cfg, err := a.backendLoader.LoadModelConfigFileByNameDefaultOptions(modelName, a.applicationConfig)
|
||||
if err != nil || cfg == nil {
|
||||
if err != nil || cfg == nil || cfg.Name == "" {
|
||||
return nil
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// ModelConfigLookup is the lookup function the router middleware's
|
||||
// classifier validator uses to confirm classifier_model declares
|
||||
// FLAG_SCORE before binding it.
|
||||
// ModelConfigLookup is the lookup the router middleware's classifier validator
|
||||
// uses to confirm classifier_model declares FLAG_SCORE before binding it.
|
||||
func (a *Application) ModelConfigLookup() func(modelName string) *config.ModelConfig {
|
||||
return a.adapterConfig
|
||||
}
|
||||
|
||||
// Scorer returns a backend.Scorer bound to the named model, or nil
|
||||
// when the model is unknown. Used as a method value (app.Scorer) by
|
||||
// router.ClassifierDeps — no factory-of-factory wrapper needed.
|
||||
// The router-facing factories below (Scorer, Embedder, Reranker, TokenCounter)
|
||||
// bind a model NAME at construction and re-resolve the CONFIG on every call.
|
||||
// Capturing the config at construction would bake in whatever state
|
||||
// adapterConfig saw first — including a stub returned before the YAML reached
|
||||
// bcl.configs (e.g. /import-model or gallery install racing startup). The
|
||||
// classifier registry caches factories by router-config fingerprint, so a
|
||||
// once-stale capture stays stale until the router config is edited.
|
||||
|
||||
func (a *Application) Scorer(modelName string) backend.Scorer {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
if a.adapterConfig(modelName) == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewScorer(a.modelLoader, *cfg, a.applicationConfig)
|
||||
return &lazyScorer{app: a, modelName: modelName}
|
||||
}
|
||||
|
||||
type lazyScorer struct {
|
||||
app *Application
|
||||
modelName string
|
||||
}
|
||||
|
||||
func (l *lazyScorer) Score(ctx context.Context, prompt string, candidates []string) ([]backend.CandidateScore, error) {
|
||||
cfg := l.app.adapterConfig(l.modelName)
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("scorer: model %q no longer available", l.modelName)
|
||||
}
|
||||
return backend.NewScorer(l.app.modelLoader, *cfg, l.app.applicationConfig).Score(ctx, prompt, candidates)
|
||||
}
|
||||
|
||||
// TokenCounter returns a func so the middleware's literal field type accepts
|
||||
// it as a method value without importing core/http/middleware from here.
|
||||
func (a *Application) TokenCounter(modelName string) func(string) (int, error) {
|
||||
if a.adapterConfig(modelName) == nil {
|
||||
return nil
|
||||
}
|
||||
return func(text string) (int, error) {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
return 0, fmt.Errorf("token counter: model %q no longer available", modelName)
|
||||
}
|
||||
resp, err := backend.ModelTokenize(text, a.modelLoader, *cfg, a.applicationConfig)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(resp.Tokens), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Reranker returns a backend.Reranker bound to the named model, or
|
||||
// nil when unknown. The reranker model's `type:` (e.g. "colbert")
|
||||
// selects the scoring head inside the rerankers backend.
|
||||
func (a *Application) Reranker(modelName string) backend.Reranker {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
if a.adapterConfig(modelName) == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewReranker(a.modelLoader, *cfg, a.applicationConfig)
|
||||
return &lazyReranker{app: a, modelName: modelName}
|
||||
}
|
||||
|
||||
type lazyReranker struct {
|
||||
app *Application
|
||||
modelName string
|
||||
}
|
||||
|
||||
func (l *lazyReranker) Rerank(ctx context.Context, query string, documents []string) ([]backend.RerankResult, error) {
|
||||
cfg := l.app.adapterConfig(l.modelName)
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("reranker: model %q no longer available", l.modelName)
|
||||
}
|
||||
return backend.NewReranker(l.app.modelLoader, *cfg, l.app.applicationConfig).Rerank(ctx, query, documents)
|
||||
}
|
||||
|
||||
// Embedder returns a backend.Embedder bound to the named model, or
|
||||
// nil when unknown. Used by the router's L2 embedding cache.
|
||||
func (a *Application) Embedder(modelName string) backend.Embedder {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
if a.adapterConfig(modelName) == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewEmbedder(a.modelLoader, *cfg, a.applicationConfig)
|
||||
return &lazyEmbedder{app: a, modelName: modelName}
|
||||
}
|
||||
|
||||
// VectorStore returns a backend.VectorStore for the named collection,
|
||||
// or nil when the name is empty. Each router model gets its own
|
||||
// backend process via the model loader's cache keyed by storeName.
|
||||
type lazyEmbedder struct {
|
||||
app *Application
|
||||
modelName string
|
||||
}
|
||||
|
||||
func (l *lazyEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
|
||||
cfg := l.app.adapterConfig(l.modelName)
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("embedder: model %q no longer available", l.modelName)
|
||||
}
|
||||
return backend.NewEmbedder(l.app.modelLoader, *cfg, l.app.applicationConfig).Embed(ctx, text)
|
||||
}
|
||||
|
||||
// VectorStore takes a store name, not a model name — no adapterConfig, no
|
||||
// staleness to avoid.
|
||||
func (a *Application) VectorStore(storeName string) backend.VectorStore {
|
||||
return backend.NewVectorStore(a.modelLoader, a.applicationConfig, storeName)
|
||||
}
|
||||
|
||||
155
core/application/router_factories_test.go
Normal file
155
core/application/router_factories_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Regression: the router-facing factories used to capture
|
||||
// *config.ModelConfig at construction. A gallery install that raced
|
||||
// startup left a stub (Backend="") bound for the lifetime of the
|
||||
// classifier registry's cache entry, bypassing the user's `backend:`
|
||||
// config. These specs pin the lazy re-resolve.
|
||||
var _ = Describe("router_factories lazy config resolution", func() {
|
||||
var (
|
||||
tmpDir string
|
||||
app *Application
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tmpDir, err = os.MkdirTemp("", "router-factories-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
appCfg := &config.ApplicationConfig{
|
||||
Context: context.Background(),
|
||||
SystemState: &system.SystemState{Model: system.Model{ModelsPath: tmpDir}},
|
||||
}
|
||||
app = &Application{
|
||||
backendLoader: config.NewModelConfigLoader(tmpDir),
|
||||
modelLoader: model.NewModelLoader(appCfg.SystemState),
|
||||
applicationConfig: appCfg,
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
})
|
||||
|
||||
// writeCfg seeds both the on-disk YAML and the in-memory cache —
|
||||
// removing only the cache would fall through to file-read.
|
||||
writeCfg := func(name, backend string) {
|
||||
yaml := "name: " + name + "\nbackend: " + backend + "\nparameters:\n model: " + name + ".bin\n"
|
||||
Expect(os.WriteFile(filepath.Join(tmpDir, name+".yaml"), []byte(yaml), 0644)).To(Succeed())
|
||||
Expect(app.backendLoader.LoadModelConfigsFromPath(tmpDir)).To(Succeed())
|
||||
cfg, ok := app.backendLoader.GetModelConfig(name)
|
||||
Expect(ok).To(BeTrue(), "config must be loaded before the spec runs")
|
||||
Expect(cfg.Backend).To(Equal(backend))
|
||||
}
|
||||
|
||||
// removeCfg purges both the cache and the YAML so LoadModelConfigFileByName
|
||||
// returns the empty-stub case and adapterConfig returns nil.
|
||||
removeCfg := func(name string) {
|
||||
app.backendLoader.RemoveModelConfig(name)
|
||||
Expect(os.Remove(filepath.Join(tmpDir, name+".yaml"))).To(Succeed())
|
||||
}
|
||||
|
||||
Context("Embedder", func() {
|
||||
It("returns nil at construction for an unknown model", func() {
|
||||
Expect(app.Embedder("missing")).To(BeNil())
|
||||
})
|
||||
|
||||
It("re-resolves the model config on each Embed call", func() {
|
||||
writeCfg("emb-test", "llama-cpp")
|
||||
emb := app.Embedder("emb-test")
|
||||
Expect(emb).NotTo(BeNil())
|
||||
|
||||
// The factory must hold the NAME, not a captured config —
|
||||
// otherwise stale captures survive cache invalidation.
|
||||
lazy, ok := emb.(*lazyEmbedder)
|
||||
Expect(ok).To(BeTrue(), "Embedder must return *lazyEmbedder")
|
||||
Expect(lazy.modelName).To(Equal("emb-test"))
|
||||
|
||||
// Mutate the cached config. A lazy implementation sees the
|
||||
// update on the next adapterConfig call; a captured-at-
|
||||
// construction implementation would still see "llama-cpp".
|
||||
app.backendLoader.UpdateModelConfig("emb-test", func(c *config.ModelConfig) {
|
||||
c.Backend = "rerankers"
|
||||
})
|
||||
Expect(lazy.app.adapterConfig("emb-test").Backend).To(Equal("rerankers"))
|
||||
|
||||
// Remove the config entirely → Embed must surface the disappearance.
|
||||
removeCfg("emb-test")
|
||||
_, err := emb.Embed(context.Background(), "anything")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no longer available"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Scorer", func() {
|
||||
It("returns nil at construction for an unknown model", func() {
|
||||
Expect(app.Scorer("missing")).To(BeNil())
|
||||
})
|
||||
|
||||
It("re-resolves the model config on each Score call", func() {
|
||||
writeCfg("score-test", "llama-cpp")
|
||||
sc := app.Scorer("score-test")
|
||||
Expect(sc).NotTo(BeNil())
|
||||
|
||||
lazy, ok := sc.(*lazyScorer)
|
||||
Expect(ok).To(BeTrue(), "Scorer must return *lazyScorer")
|
||||
Expect(lazy.modelName).To(Equal("score-test"))
|
||||
|
||||
removeCfg("score-test")
|
||||
_, err := sc.Score(context.Background(), "prompt", []string{"a"})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no longer available"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Reranker", func() {
|
||||
It("returns nil at construction for an unknown model", func() {
|
||||
Expect(app.Reranker("missing")).To(BeNil())
|
||||
})
|
||||
|
||||
It("re-resolves the model config on each Rerank call", func() {
|
||||
writeCfg("rerank-test", "rerankers")
|
||||
rr := app.Reranker("rerank-test")
|
||||
Expect(rr).NotTo(BeNil())
|
||||
|
||||
lazy, ok := rr.(*lazyReranker)
|
||||
Expect(ok).To(BeTrue(), "Reranker must return *lazyReranker")
|
||||
Expect(lazy.modelName).To(Equal("rerank-test"))
|
||||
|
||||
removeCfg("rerank-test")
|
||||
_, err := rr.Rerank(context.Background(), "q", []string{"d"})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no longer available"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("TokenCounter", func() {
|
||||
It("returns nil at construction for an unknown model", func() {
|
||||
Expect(app.TokenCounter("missing")).To(BeNil())
|
||||
})
|
||||
|
||||
It("re-resolves the model config on each call", func() {
|
||||
writeCfg("tok-test", "llama-cpp")
|
||||
tc := app.TokenCounter("tok-test")
|
||||
Expect(tc).NotTo(BeNil())
|
||||
|
||||
removeCfg("tok-test")
|
||||
_, err := tc("anything")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no longer available"))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -462,11 +462,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
// traffic doesn't need a parallel config for MITM traffic.
|
||||
// Runs after loadRuntimeSettingsFromFile so a listener configured
|
||||
// via /api/settings is brought back up across restarts.
|
||||
if options.MITMListen != "" {
|
||||
if err := startMITMProxy(application, options); err != nil {
|
||||
return nil, fmt.Errorf("mitm: startup: %w", err)
|
||||
}
|
||||
}
|
||||
startMITMIfConfigured(application, options)
|
||||
|
||||
application.ModelLoader().SetBackendLoggingEnabled(options.EnableBackendLogging)
|
||||
|
||||
|
||||
@@ -100,8 +100,13 @@ func ModelEmbedding(ctx context.Context, s string, tokens []int, loader *model.M
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
|
||||
traceData := map[string]any{
|
||||
"input_text": trace.TruncateString(s, 1000),
|
||||
"input_tokens_count": len(tokens),
|
||||
"input_text": trace.TruncateString(s, 1000),
|
||||
}
|
||||
// Only present for token-mode callers (pre-tokenized override);
|
||||
// emitting "0" alongside input_text would read as "consumed zero
|
||||
// tokens", which is wrong.
|
||||
if len(tokens) > 0 {
|
||||
traceData["input_tokens_count"] = len(tokens)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
@@ -87,11 +87,47 @@ func getSeed(c config.ModelConfig) int32 {
|
||||
return seed
|
||||
}
|
||||
|
||||
// DefaultContextSize and DefaultBatchSize are the backend's fallbacks when a
// model config leaves them unset. They are exported so callers that must
// respect the effective decode window — notably the router's prompt trimmer —
// resolve the same numbers grpcModelOpts does instead of guessing.
const (
	// DefaultContextSize is the context window used when context_size is unset.
	DefaultContextSize = 4096
	// DefaultBatchSize is the batch size used when batch is unset and no
	// larger value is derived from the context window.
	DefaultBatchSize = 512
)
|
||||
|
||||
// EffectiveContextSize is the context window the backend will run with: the
|
||||
// configured value, or DefaultContextSize when unset.
|
||||
func EffectiveContextSize(c config.ModelConfig) int {
|
||||
if c.ContextSize != nil {
|
||||
return *c.ContextSize
|
||||
}
|
||||
return DefaultContextSize
|
||||
}
|
||||
|
||||
// EffectiveBatchSize is the single-decode batch the backend will run with.
|
||||
// Score, embedding and rerank all process the whole input in one pass: score
|
||||
// decodes prompt+candidate (asserts n_tokens <= n_batch), and embedding/rerank
|
||||
// pool over the full sequence in one physical batch (n_ubatch). So the batch
|
||||
// is sized to the context — anything that fits the context fits one pass,
|
||||
// avoiding both the GGML_ASSERT crash and the "input is too large to process"
|
||||
// error. Explicit `batch:` always wins.
|
||||
func EffectiveBatchSize(c config.ModelConfig) int {
|
||||
if c.Batch != 0 {
|
||||
return c.Batch
|
||||
}
|
||||
singlePass := c.HasUsecases(config.FLAG_SCORE) ||
|
||||
c.HasUsecases(config.FLAG_EMBEDDINGS) ||
|
||||
c.HasUsecases(config.FLAG_RERANK)
|
||||
if ctx := EffectiveContextSize(c); singlePass && ctx > DefaultBatchSize {
|
||||
return ctx
|
||||
}
|
||||
return DefaultBatchSize
|
||||
}
|
||||
|
||||
func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
|
||||
ctxSize := EffectiveContextSize(c)
|
||||
b := EffectiveBatchSize(c)
|
||||
|
||||
flashAttention := "auto"
|
||||
|
||||
@@ -134,11 +170,6 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
|
||||
}
|
||||
}
|
||||
|
||||
ctxSize := 4096
|
||||
if c.ContextSize != nil {
|
||||
ctxSize = *c.ContextSize
|
||||
}
|
||||
|
||||
mmlock := false
|
||||
if c.MMlock != nil {
|
||||
mmlock = *c.MMlock
|
||||
|
||||
@@ -97,3 +97,67 @@ var _ = Describe("gRPCPredictOpts reasoning_effort metadata", func() {
|
||||
Expect(opts.Metadata).ToNot(HaveKey("reasoning_effort"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("grpcModelOpts NBatch", func() {
|
||||
scoreUsecase := config.FLAG_SCORE
|
||||
threads := 1
|
||||
ctx := 4096
|
||||
|
||||
It("defaults to 512 for an ordinary model", func() {
|
||||
cfg := config.ModelConfig{Threads: &threads, LLMConfig: config.LLMConfig{ContextSize: &ctx}}
|
||||
opts := grpcModelOpts(cfg, "/tmp/models")
|
||||
Expect(opts.NBatch).To(BeEquivalentTo(512))
|
||||
})
|
||||
|
||||
It("sizes the batch to the context window for score models", func() {
|
||||
// Score models decode the whole prompt+candidate in one
|
||||
// llama_decode; n_batch must cover it or the backend aborts.
|
||||
cfg := config.ModelConfig{Threads: &threads, LLMConfig: config.LLMConfig{ContextSize: &ctx}, KnownUsecases: &scoreUsecase}
|
||||
opts := grpcModelOpts(cfg, "/tmp/models")
|
||||
Expect(opts.NBatch).To(BeEquivalentTo(4096))
|
||||
})
|
||||
|
||||
It("keeps an explicit batch over the score default", func() {
|
||||
cfg := config.ModelConfig{Threads: &threads, LLMConfig: config.LLMConfig{ContextSize: &ctx}, KnownUsecases: &scoreUsecase}
|
||||
cfg.Batch = 1024
|
||||
opts := grpcModelOpts(cfg, "/tmp/models")
|
||||
Expect(opts.NBatch).To(BeEquivalentTo(1024))
|
||||
})
|
||||
|
||||
It("sizes the batch to the context window for embedding models", func() {
|
||||
// Embedding/rerank pool over the whole sequence in one physical batch
|
||||
// (n_ubatch); without this the input is capped at the 512 default and
|
||||
// the backend returns "input is too large to process".
|
||||
embeddings := true
|
||||
cfg := config.ModelConfig{Threads: &threads, LLMConfig: config.LLMConfig{ContextSize: &ctx}}
|
||||
cfg.Embeddings = &embeddings
|
||||
opts := grpcModelOpts(cfg, "/tmp/models")
|
||||
Expect(opts.NBatch).To(BeEquivalentTo(4096))
|
||||
})
|
||||
|
||||
It("sizes the batch to the context window for rerank models", func() {
|
||||
reranking := true
|
||||
cfg := config.ModelConfig{Threads: &threads, LLMConfig: config.LLMConfig{ContextSize: &ctx}}
|
||||
cfg.Reranking = &reranking
|
||||
opts := grpcModelOpts(cfg, "/tmp/models")
|
||||
Expect(opts.NBatch).To(BeEquivalentTo(4096))
|
||||
})
|
||||
|
||||
It("does not raise the batch when a score model's context is below the default", func() {
|
||||
small := 256
|
||||
cfg := config.ModelConfig{Threads: &threads, LLMConfig: config.LLMConfig{ContextSize: &small}, KnownUsecases: &scoreUsecase}
|
||||
opts := grpcModelOpts(cfg, "/tmp/models")
|
||||
Expect(opts.NBatch).To(BeEquivalentTo(512))
|
||||
})
|
||||
|
||||
It("sizes the batch to the effective 4096 default for a score model with no explicit context_size", func() {
|
||||
// The crash case: the backend defaults n_ctx to 4096, so n_batch must
|
||||
// follow even when context_size is unset — otherwise n_batch stays 512
|
||||
// against a 4096 window and the score decode hits the GGML_ASSERT.
|
||||
cfg := config.ModelConfig{Threads: &threads, KnownUsecases: &scoreUsecase}
|
||||
Expect(cfg.ContextSize).To(BeNil())
|
||||
opts := grpcModelOpts(cfg, "/tmp/models")
|
||||
Expect(opts.NBatch).To(BeEquivalentTo(4096))
|
||||
Expect(opts.ContextSize).To(BeEquivalentTo(4096), "n_batch must match the effective n_ctx the backend receives")
|
||||
})
|
||||
})
|
||||
|
||||
@@ -3,9 +3,10 @@ package backend
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -39,34 +40,85 @@ func (s *localVectorStore) backend(_ context.Context) (grpc.Backend, error) {
|
||||
return StoreBackend(s.loader, s.appConfig, s.storeName, "")
|
||||
}
|
||||
|
||||
func (s *localVectorStore) Search(ctx context.Context, vec []float32) (float64, []byte, bool, error) {
|
||||
be, err := s.backend(ctx)
|
||||
if err != nil {
|
||||
return 0, nil, false, fmt.Errorf("vector store load: %w", err)
|
||||
func (s *localVectorStore) Search(ctx context.Context, vec []float32) (sim float64, payload []byte, ok bool, err error) {
|
||||
start := time.Now()
|
||||
outcome := "hit"
|
||||
defer func() {
|
||||
s.recordTrace(start, "search", len(vec), sim, outcome, err)
|
||||
}()
|
||||
be, berr := s.backend(ctx)
|
||||
if berr != nil {
|
||||
outcome = "backend_load_error"
|
||||
return 0, nil, false, fmt.Errorf("vector store load: %w", berr)
|
||||
}
|
||||
_, values, similarities, err := store.Find(ctx, be, vec, 1)
|
||||
if err != nil {
|
||||
// local-store's Find returns "existing length is -1" before
|
||||
// any keys are inserted. Surface that as a clean miss so the
|
||||
// cache layer treats it as an empty store and proceeds to
|
||||
// Insert rather than skipping.
|
||||
if strings.Contains(err.Error(), "existing length is -1") {
|
||||
return 0, nil, false, nil
|
||||
}
|
||||
return 0, nil, false, fmt.Errorf("vector store find: %w", err)
|
||||
_, values, similarities, ferr := store.Find(ctx, be, vec, 1)
|
||||
if ferr != nil {
|
||||
outcome = "find_error"
|
||||
return 0, nil, false, fmt.Errorf("vector store find: %w", ferr)
|
||||
}
|
||||
if len(values) == 0 || len(similarities) == 0 {
|
||||
outcome = "miss"
|
||||
return 0, nil, false, nil
|
||||
}
|
||||
return float64(similarities[0]), values[0], true, nil
|
||||
}
|
||||
|
||||
func (s *localVectorStore) Insert(ctx context.Context, vec []float32, payload []byte) error {
|
||||
be, err := s.backend(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("vector store load: %w", err)
|
||||
func (s *localVectorStore) Insert(ctx context.Context, vec []float32, payload []byte) (err error) {
|
||||
start := time.Now()
|
||||
outcome := "ok"
|
||||
defer func() {
|
||||
s.recordTrace(start, "insert", len(vec), 0, outcome, err)
|
||||
}()
|
||||
be, berr := s.backend(ctx)
|
||||
if berr != nil {
|
||||
outcome = "backend_load_error"
|
||||
return fmt.Errorf("vector store load: %w", berr)
|
||||
}
|
||||
return store.SetSingle(ctx, be, vec, payload)
|
||||
if serr := store.SetSingle(ctx, be, vec, payload); serr != nil {
|
||||
outcome = "insert_error"
|
||||
return serr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// recordTrace surfaces vector-store calls in /api/backend-traces, including
|
||||
// the backend-load-failure path that otherwise vanishes into an xlog.Warn.
|
||||
// modelName uses the store namespace (e.g. "router-cache-smart-router") so
|
||||
// admins can tell which router's cache misbehaved; the backend is always
|
||||
// "local-store" and can't disambiguate.
|
||||
func (s *localVectorStore) recordTrace(start time.Time, op string, vecDim int, sim float64, outcome string, err error) {
|
||||
if s.appConfig == nil || !s.appConfig.EnableTracing {
|
||||
return
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(s.appConfig.TracingMaxItems, s.appConfig.TracingMaxBodyBytes)
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
summary := op + " " + outcome
|
||||
if op == "search" && outcome == "hit" {
|
||||
summary = fmt.Sprintf("search hit (sim=%.3f)", sim)
|
||||
}
|
||||
data := map[string]any{
|
||||
"op": op,
|
||||
"outcome": outcome,
|
||||
"vector_dim": vecDim,
|
||||
}
|
||||
// Only include similarity for a real neighbor — miss/empty_store would
|
||||
// otherwise render "similarity: 0" and read as a measured value.
|
||||
if op == "search" && outcome == "hit" {
|
||||
data["similarity"] = sim
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: start,
|
||||
Duration: time.Since(start),
|
||||
Type: trace.BackendTraceVectorStore,
|
||||
ModelName: s.storeName,
|
||||
Backend: model.LocalStoreBackend,
|
||||
Summary: summary,
|
||||
Error: errStr,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string, backend string) (grpc.Backend, error) {
|
||||
|
||||
88
core/backend/stores_test.go
Normal file
88
core/backend/stores_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// findVectorStoreTrace returns the most recent vector_store trace whose
|
||||
// model_name matches storeName, or nil if none was recorded. Used by
|
||||
// the specs below to assert the trace landed without relying on
|
||||
// ring-buffer ordering across other tests in the suite.
|
||||
func findVectorStoreTrace(storeName string) *trace.BackendTrace {
|
||||
traces := trace.GetBackendTraces()
|
||||
for i := range traces {
|
||||
bt := &traces[i]
|
||||
if bt.Type == trace.BackendTraceVectorStore && bt.ModelName == storeName {
|
||||
return bt
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ = Describe("localVectorStore tracing", func() {
|
||||
// Pin the trace surface admins read from /api/backend-traces.
|
||||
// The original failure mode that motivated these specs — the
|
||||
// local-store backend not installed — was silent on every surface
|
||||
// except a per-call xlog.Warn. With tracing wired in, the row
|
||||
// appears next to the embedder/score traces for the same request.
|
||||
BeforeEach(func() {
|
||||
trace.ClearBackendTraces()
|
||||
})
|
||||
|
||||
It("records a vector_store trace with outcome=backend_load_error when the backend can't be loaded", func() {
|
||||
// nil ModelLoader → s.backend → StoreBackend → panics on load.
|
||||
// Use a real-but-empty loader so the failure surfaces as an
|
||||
// error instead, exercising the load-failure trace path the
|
||||
// admin would hit when local-store isn't installed.
|
||||
appCfg := &config.ApplicationConfig{
|
||||
EnableTracing: true,
|
||||
TracingMaxItems: 16,
|
||||
TracingMaxBodyBytes: 1024,
|
||||
}
|
||||
s := &localVectorStore{
|
||||
loader: model.NewModelLoader(&system.SystemState{}),
|
||||
appConfig: appCfg,
|
||||
storeName: "router-cache-test",
|
||||
}
|
||||
|
||||
// Search must surface the error AND record a trace describing it.
|
||||
_, _, _, err := s.Search(context.Background(), []float32{0.1, 0.2, 0.3})
|
||||
Expect(err).To(HaveOccurred())
|
||||
|
||||
Eventually(func() *trace.BackendTrace {
|
||||
return findVectorStoreTrace("router-cache-test")
|
||||
}).ShouldNot(BeNil())
|
||||
|
||||
bt := findVectorStoreTrace("router-cache-test")
|
||||
Expect(bt.Backend).To(Equal(model.LocalStoreBackend))
|
||||
Expect(bt.Data["op"]).To(Equal("search"))
|
||||
Expect(bt.Data["outcome"]).To(Equal("backend_load_error"))
|
||||
Expect(bt.Data["vector_dim"]).To(Equal(3))
|
||||
// Error is the wrapped "vector store load: …" surfaced to the caller.
|
||||
Expect(bt.Error).To(ContainSubstring("vector store load"))
|
||||
})
|
||||
|
||||
It("does not record a trace when tracing is disabled", func() {
|
||||
// Opt-out path: appConfig.EnableTracing=false must short-circuit
|
||||
// before InitBackendTracingIfEnabled, so a workload with tracing
|
||||
// turned off doesn't pay the channel-send cost per cache call.
|
||||
appCfg := &config.ApplicationConfig{EnableTracing: false}
|
||||
s := &localVectorStore{
|
||||
loader: model.NewModelLoader(&system.SystemState{}),
|
||||
appConfig: appCfg,
|
||||
storeName: "router-cache-disabled",
|
||||
}
|
||||
_, _, _, _ = s.Search(context.Background(), []float32{1})
|
||||
Consistently(func() *trace.BackendTrace {
|
||||
return findVectorStoreTrace("router-cache-disabled")
|
||||
}).Should(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -7,9 +7,23 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// tokenizeTokenCount returns the number of tokens in a backend response,
|
||||
// treating a nil response as zero. The gRPC client returns (nil, err) on
|
||||
// failure, and the tracing block below runs before that error is returned —
|
||||
// so the count must be read nil-safely here. Reading resp.Tokens on a nil
|
||||
// resp previously panicked the whole HTTP handler when tracing was enabled
|
||||
// (e.g. a transient tokenize failure during router probe-budget sizing).
|
||||
func tokenizeTokenCount(resp *pb.TokenizationResponse) int {
|
||||
if resp == nil {
|
||||
return 0
|
||||
}
|
||||
return len(resp.Tokens)
|
||||
}
|
||||
|
||||
func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {
|
||||
|
||||
var inferenceModel grpc.Backend
|
||||
@@ -40,10 +54,7 @@ func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.Model
|
||||
errStr = err.Error()
|
||||
}
|
||||
|
||||
tokenCount := 0
|
||||
if resp.Tokens != nil {
|
||||
tokenCount = len(resp.Tokens)
|
||||
}
|
||||
tokenCount := tokenizeTokenCount(resp)
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
@@ -64,8 +75,8 @@ func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.Model
|
||||
return schema.TokenizeResponse{}, err
|
||||
}
|
||||
|
||||
if resp.Tokens == nil {
|
||||
resp.Tokens = make([]int32, 0)
|
||||
if resp == nil || resp.Tokens == nil {
|
||||
return schema.TokenizeResponse{Tokens: make([]int32, 0)}, nil
|
||||
}
|
||||
|
||||
return schema.TokenizeResponse{
|
||||
|
||||
27
core/backend/tokenize_test.go
Normal file
27
core/backend/tokenize_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("tokenizeTokenCount", func() {
|
||||
// Regression: the gRPC client returns (nil, err) when a tokenize call
|
||||
// fails, and ModelTokenize's tracing block reads the token count before
|
||||
// the error is returned. Dereferencing a nil response there panicked the
|
||||
// HTTP handler (nil pointer dereference) — e.g. a transient tokenize
|
||||
// failure while the router sized its probe-token budget.
|
||||
It("returns zero for a nil response instead of panicking", func() {
|
||||
Expect(tokenizeTokenCount(nil)).To(Equal(0))
|
||||
})
|
||||
|
||||
It("returns zero when the response carries no tokens", func() {
|
||||
Expect(tokenizeTokenCount(&pb.TokenizationResponse{})).To(Equal(0))
|
||||
})
|
||||
|
||||
It("counts the tokens present on the response", func() {
|
||||
Expect(tokenizeTokenCount(&pb.TokenizationResponse{Tokens: []int32{1, 2, 3}})).To(Equal(3))
|
||||
})
|
||||
})
|
||||
@@ -65,7 +65,7 @@ type ApplicationConfig struct {
|
||||
//
|
||||
// patterns:
|
||||
// - id: email
|
||||
// action: route_local # downgrade default mask -> route_local
|
||||
// action: allow # downgrade default mask -> allow (log only)
|
||||
// - id: ssn
|
||||
// action: block # upgrade default mask -> block
|
||||
//
|
||||
|
||||
@@ -93,6 +93,9 @@ func applyOverride(f *FieldMeta, o FieldMetaOverride) {
|
||||
if o.Component != "" {
|
||||
f.Component = o.Component
|
||||
}
|
||||
if o.Language != "" {
|
||||
f.Language = o.Language
|
||||
}
|
||||
if o.Placeholder != "" {
|
||||
f.Placeholder = o.Placeholder
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ const (
|
||||
ProviderModelsTTS = "models:tts"
|
||||
ProviderModelsTranscript = "models:transcript"
|
||||
ProviderModelsVAD = "models:vad"
|
||||
ProviderModelsScore = "models:score"
|
||||
)
|
||||
|
||||
// Static option lists embedded directly in field metadata.
|
||||
|
||||
@@ -226,6 +226,7 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Label: "Chat Template",
|
||||
Description: "Go template for chat completion requests",
|
||||
Component: "code-editor",
|
||||
Language: "gotemplate",
|
||||
Order: 40,
|
||||
},
|
||||
"template.chat_message": {
|
||||
@@ -233,6 +234,7 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Label: "Chat Message Template",
|
||||
Description: "Go template for individual chat messages",
|
||||
Component: "code-editor",
|
||||
Language: "gotemplate",
|
||||
Order: 41,
|
||||
},
|
||||
"template.completion": {
|
||||
@@ -240,13 +242,22 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Label: "Completion Template",
|
||||
Description: "Go template for completion requests",
|
||||
Component: "code-editor",
|
||||
Language: "gotemplate",
|
||||
Order: 42,
|
||||
},
|
||||
"template.function": {
|
||||
Section: "templates",
|
||||
Label: "Functions Template",
|
||||
Description: "Go template applied when tools/functions are present in the request",
|
||||
Component: "code-editor",
|
||||
Language: "gotemplate",
|
||||
Order: 43,
|
||||
},
|
||||
"template.use_tokenizer_template": {
|
||||
Section: "templates",
|
||||
Label: "Use Tokenizer Template",
|
||||
Description: "Use the chat template from the model's tokenizer config",
|
||||
Order: 43,
|
||||
Order: 44,
|
||||
},
|
||||
// Router section template — kept in the templates UI section
|
||||
// (rather than the router section under "other") so operators
|
||||
@@ -257,7 +268,8 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Label: "Router Classifier System Prompt",
|
||||
Description: "Go text/template (with sprig functions) for the routing system prompt the score classifier feeds to its classifier_model. Executed with `.Policies` ([]{Label, Description}). Empty falls back to the built-in Arch-Router-shaped prompt (route-listing block + JSON output schema). Override when the classifier model was trained on a different schema or you need the routing instructions in a different language. The candidate format scored against the model is fixed at `{\"route\": \"<label>\"}` — keep your override's output schema instruction matching that.",
|
||||
Component: "code-editor",
|
||||
Order: 44,
|
||||
Language: "gotemplate",
|
||||
Order: 45,
|
||||
},
|
||||
|
||||
// --- Pipeline ---
|
||||
@@ -400,14 +412,14 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
|
||||
// --- PII filtering (per-model) ---
|
||||
"pii.enabled": {
|
||||
Section: "other",
|
||||
Section: "pii",
|
||||
Label: "PII Filtering Enabled",
|
||||
Description: "Enable PII redaction middleware for this model. Unset means use the default (off for local backends, on for proxy-* / cloud-hosted backends).",
|
||||
Component: "toggle",
|
||||
Order: 200,
|
||||
},
|
||||
"pii.patterns": {
|
||||
Section: "other",
|
||||
Section: "pii",
|
||||
Label: "PII Pattern Overrides",
|
||||
Description: "Override the global default action for specific patterns on this model. Patterns not listed here inherit the global action (Settings → Middleware → Filtering).",
|
||||
Component: "pii-pattern-list",
|
||||
@@ -420,7 +432,7 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
// fails closed — the chat handler does NOT silently fall back
|
||||
// to the local gRPC pipeline.
|
||||
"proxy.mode": {
|
||||
Section: "other",
|
||||
Section: "proxy",
|
||||
Label: "Proxy Mode",
|
||||
Description: "passthrough forwards the client's OpenAI body verbatim — point upstream_url at an OpenAI-compatible endpoint (incl. Anthropic's /v1/chat/completions compat layer). translate converts OpenAI ↔ Anthropic Messages so you can target a native API (/v1/messages); tool_calls and usage tokens survive the round-trip.",
|
||||
Component: "select",
|
||||
@@ -432,7 +444,7 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Order: 208,
|
||||
},
|
||||
"proxy.provider": {
|
||||
Section: "other",
|
||||
Section: "proxy",
|
||||
Label: "Proxy Provider",
|
||||
Description: "Upstream API family. Drives auth header shape (Bearer vs x-api-key + anthropic-version) and, in translate mode, which request/response codec is used.",
|
||||
Component: "select",
|
||||
@@ -444,28 +456,28 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Order: 209,
|
||||
},
|
||||
"proxy.upstream_url": {
|
||||
Section: "other",
|
||||
Section: "proxy",
|
||||
Label: "Proxy Upstream URL",
|
||||
Description: "Full POST endpoint of the upstream provider (e.g. https://api.openai.com/v1/chat/completions). Only used when Backend is cloud-proxy.",
|
||||
Component: "input",
|
||||
Order: 210,
|
||||
},
|
||||
"proxy.api_key_env": {
|
||||
Section: "other",
|
||||
Section: "proxy",
|
||||
Label: "Proxy API Key Env Var",
|
||||
Description: "Name of the environment variable holding the upstream API key. Reading from env keeps the secret out of the YAML and the admin UI.",
|
||||
Component: "input",
|
||||
Order: 211,
|
||||
},
|
||||
"proxy.upstream_model": {
|
||||
Section: "other",
|
||||
Section: "proxy",
|
||||
Label: "Proxy Upstream Model",
|
||||
Description: "Model name sent to the upstream. Leave empty to forward the client's model field unchanged. Useful when the LocalAI alias differs from the upstream's canonical name.",
|
||||
Component: "input",
|
||||
Order: 212,
|
||||
},
|
||||
"proxy.request_timeout_seconds": {
|
||||
Section: "other",
|
||||
Section: "proxy",
|
||||
Label: "Proxy Request Timeout (seconds)",
|
||||
Description: "Caps the upstream HTTP request duration. 0 disables the deadline; the request still ends when the client disconnects.",
|
||||
Component: "number",
|
||||
@@ -480,7 +492,7 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
// A host claimed by two configs is a critical error — the
|
||||
// listener refuses to start until resolved.
|
||||
"mitm.hosts": {
|
||||
Section: "other",
|
||||
Section: "mitm",
|
||||
Label: "MITM Intercept Hosts",
|
||||
Description: "Hostnames the cloudproxy MITM proxy terminates TLS for on behalf of this model config. PII filtering and pattern overrides flow from this model when the host is intercepted. Each host must be unique across all configs.",
|
||||
Component: "string-list",
|
||||
@@ -495,7 +507,7 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
// the middleware admin page surfaces every model with a router
|
||||
// block.
|
||||
"router.classifier": {
|
||||
Section: "other",
|
||||
Section: "router",
|
||||
Label: "Classifier",
|
||||
Description: "Picks a candidate by scoring every policy label against the prompt. Only \"score\" is shipped today; it asks the classifier_model to rank each label and reads off the softmax. Empty defaults to \"score\".",
|
||||
Component: "select",
|
||||
@@ -505,15 +517,15 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Order: 230,
|
||||
},
|
||||
"router.classifier_model": {
|
||||
Section: "other",
|
||||
Section: "router",
|
||||
Label: "Classifier Model",
|
||||
Description: "Loaded LocalAI model the score classifier asks to rank each policy label as a continuation. Must support the Score gRPC primitive (today: llama-cpp, vLLM) and use the ChatML template. Arch-Router-1.5B Q4_K_M is the canonical choice; any small ChatML instruct model also works at a higher activation_threshold.",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModelsChat,
|
||||
AutocompleteProvider: ProviderModelsScore,
|
||||
Order: 231,
|
||||
},
|
||||
"router.fallback": {
|
||||
Section: "other",
|
||||
Section: "router",
|
||||
Label: "Fallback Model",
|
||||
Description: "Model used when no candidate's labels cover the classifier's active label set, or when the classifier errors. Empty means router failures bubble up as HTTP 500 — fail-fast, not silent-bypass.",
|
||||
Component: "model-select",
|
||||
@@ -521,7 +533,7 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Order: 232,
|
||||
},
|
||||
"router.activation_threshold": {
|
||||
Section: "other",
|
||||
Section: "router",
|
||||
Label: "Activation Threshold",
|
||||
Description: "Softmax-probability floor a policy must clear to join the active label set for a request. Higher → single-label dominant routes; lower → more multi-label activations. 0 picks the package default (0.15). On Arch-Router-1.5B a value around 0.40 keeps the dominant label clean without losing genuine compound activations.",
|
||||
Component: "slider",
|
||||
@@ -531,7 +543,7 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Order: 233,
|
||||
},
|
||||
"router.classifier_cache_size": {
|
||||
Section: "other",
|
||||
Section: "router",
|
||||
Label: "Classifier L1 Cache Size",
|
||||
Description: "Bounded LRU keyed on (case-folded, whitespace-trimmed) prompt — amortises the classifier round-trip across verbatim repeats common in agent loops. 0 here means \"use the default\" (1024); the cache cannot be disabled from YAML.",
|
||||
Component: "number",
|
||||
@@ -539,21 +551,21 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Order: 234,
|
||||
},
|
||||
"router.policies": {
|
||||
Section: "other",
|
||||
Section: "router",
|
||||
Label: "Policies",
|
||||
Description: "Label vocabulary the classifier scores over. Each policy has a label and a short natural-language description fed verbatim to the classifier model. Short action-oriented sentences work best (\"writing or debugging code\"; \"small talk\").",
|
||||
Component: "router-policies",
|
||||
Order: 235,
|
||||
},
|
||||
"router.candidates": {
|
||||
Section: "other",
|
||||
Section: "router",
|
||||
Label: "Candidates",
|
||||
Description: "Routing table: each entry binds a downstream model to a set of policy labels it can serve. Order matters — the middleware picks the FIRST candidate whose labels are a superset of the active set, so list candidates smallest → largest.",
|
||||
Component: "router-candidates",
|
||||
Order: 236,
|
||||
},
|
||||
"router.score_normalization": {
|
||||
Section: "other",
|
||||
Section: "router",
|
||||
Label: "Score Normalization",
|
||||
Description: "How the score classifier collapses per-candidate joint log-probs into the softmax input. \"raw\" (default) feeds joint log-prob as-is — on-distribution for Arch-Router (the route the model would actually emit if decoded freely). \"mean\" divides by candidate token count — fairer to long labels but off-distribution for models trained to emit fixed-format outputs.",
|
||||
Component: "select",
|
||||
@@ -565,7 +577,7 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Order: 240,
|
||||
},
|
||||
"router.embedding_cache.embedding_model": {
|
||||
Section: "other",
|
||||
Section: "router",
|
||||
Label: "L2 Cache: Embedding Model",
|
||||
Description: "Embedding model used by the L2 decision cache. Embeds incoming probes and looks them up in the per-router local-store collection. Empty disables the cache entirely. nomic-embed-text-v1.5 is the recommended default.",
|
||||
Component: "model-select",
|
||||
@@ -573,7 +585,7 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Order: 237,
|
||||
},
|
||||
"router.embedding_cache.similarity_threshold": {
|
||||
Section: "other",
|
||||
Section: "router",
|
||||
Label: "L2 Cache: Similarity Threshold",
|
||||
Description: "Cosine-similarity floor a cache candidate must clear to count as a hit. 0 picks the package default (0.80). Re-tune per embedding model — the histogram on the Routing tab shows where the cosine distribution actually sits.",
|
||||
Component: "slider",
|
||||
@@ -583,7 +595,7 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Order: 238,
|
||||
},
|
||||
"router.embedding_cache.confidence_threshold": {
|
||||
Section: "other",
|
||||
Section: "router",
|
||||
Label: "L2 Cache: Confidence Threshold",
|
||||
Description: "Minimum top-label probability a classifier decision must have to be inserted into the cache. 0 picks the package default (0.60). Uncertain decisions are skipped so they can't poison future paraphrases.",
|
||||
Component: "slider",
|
||||
@@ -593,7 +605,7 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Order: 239,
|
||||
},
|
||||
"router.embedding_cache.store_name": {
|
||||
Section: "other",
|
||||
Section: "router",
|
||||
Label: "L2 Cache: Store Name",
|
||||
Description: "Optional override for the local-store collection used by this router's cache. Empty defaults to \"router-cache-<router-model-name>\". Two routers sharing a store_name share their cache (rare).",
|
||||
Component: "input",
|
||||
|
||||
@@ -240,7 +240,6 @@ var grandfatheredUnregistered = []string{
|
||||
"swap_space",
|
||||
"system_prompt",
|
||||
"template.edit",
|
||||
"template.function",
|
||||
"template.join_chat_messages_by_character",
|
||||
"template.multimodal",
|
||||
"template.reply_prefix",
|
||||
|
||||
@@ -11,6 +11,7 @@ type FieldMeta struct {
|
||||
Label string `json:"label"` // human-readable label
|
||||
Description string `json:"description,omitempty"` // help text
|
||||
Component string `json:"component"` // "input", "number", "toggle", "select", "slider", etc.
|
||||
Language string `json:"language,omitempty"` // syntax mode for code-editor fields: "yaml" (default), "gotemplate"
|
||||
Placeholder string `json:"placeholder,omitempty"`
|
||||
Default any `json:"default,omitempty"`
|
||||
Min *float64 `json:"min,omitempty"`
|
||||
@@ -51,6 +52,7 @@ type FieldMetaOverride struct {
|
||||
Label string
|
||||
Description string
|
||||
Component string
|
||||
Language string
|
||||
Placeholder string
|
||||
Default any
|
||||
Min *float64
|
||||
@@ -78,6 +80,10 @@ func DefaultSections() []Section {
|
||||
{ID: "grpc", Label: "gRPC", Icon: "server", Order: 65},
|
||||
{ID: "agent", Label: "Agent", Icon: "bot", Order: 70},
|
||||
{ID: "mcp", Label: "MCP", Icon: "plug", Order: 75},
|
||||
{ID: "router", Label: "Router", Icon: "git-merge", Order: 78},
|
||||
{ID: "proxy", Label: "Proxy", Icon: "cloud", Order: 80},
|
||||
{ID: "mitm", Label: "MITM Proxy", Icon: "shield", Order: 82},
|
||||
{ID: "pii", Label: "PII", Icon: "shield", Order: 84},
|
||||
{ID: "other", Label: "Other", Icon: "more-horizontal", Order: 100},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,7 +385,7 @@ type PIIConfig struct {
|
||||
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
|
||||
|
||||
// Patterns lets a model upgrade or downgrade individual pattern
|
||||
// actions (mask | block | route_local) relative to the global
|
||||
// actions (mask | block | allow) relative to the global
|
||||
// defaults loaded from --pii-config / DefaultPatterns. Pattern IDs
|
||||
// not listed inherit the global action. The regex itself stays
|
||||
// global — only the action is settable per-model.
|
||||
@@ -1274,14 +1274,20 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||
}
|
||||
|
||||
if (u & FLAG_CHAT) == FLAG_CHAT {
|
||||
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate {
|
||||
return false
|
||||
}
|
||||
if slices.Contains(nonTextGenBackends, c.Backend) {
|
||||
return false
|
||||
}
|
||||
if c.Embeddings != nil && *c.Embeddings {
|
||||
return false
|
||||
// A router model is a chat dispatcher: it carries no chat
|
||||
// template of its own (those live on the candidates it routes
|
||||
// to) and is invoked through the chat endpoint, so the router
|
||||
// block stands in for chat capability.
|
||||
if !c.HasRouter() {
|
||||
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate {
|
||||
return false
|
||||
}
|
||||
if slices.Contains(nonTextGenBackends, c.Backend) {
|
||||
return false
|
||||
}
|
||||
if c.Embeddings != nil && *c.Embeddings {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
if (u & FLAG_COMPLETION) == FLAG_COMPLETION {
|
||||
|
||||
@@ -283,6 +283,18 @@ parameters:
|
||||
Expect(e.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
||||
Expect(e.HasUsecases(FLAG_EMBEDDINGS)).To(BeTrue())
|
||||
|
||||
// Router models are chat dispatchers: no chat template of their
|
||||
// own, but invoked through the chat endpoint, so they default to
|
||||
// chat-capable.
|
||||
r := ModelConfig{
|
||||
Name: "r",
|
||||
Router: RouterConfig{
|
||||
Candidates: []RouterCandidate{{Model: "downstream", Labels: []string{"general"}}},
|
||||
},
|
||||
}
|
||||
Expect(r.HasUsecases(FLAG_ANY)).To(BeTrue())
|
||||
Expect(r.HasUsecases(FLAG_CHAT)).To(BeTrue())
|
||||
|
||||
f := ModelConfig{
|
||||
Name: "f",
|
||||
Backend: "piper",
|
||||
|
||||
@@ -50,7 +50,14 @@ var _ = Describe("Runtime capability-based backend selection", func() {
|
||||
must(os.WriteFile(filepath.Join(cudaDir, "metadata.json"), b, 0o644))
|
||||
must(os.WriteFile(filepath.Join(cudaDir, "run.sh"), []byte(""), 0o755))
|
||||
|
||||
// Default system: alias should point to CPU
|
||||
// Default system: alias should point to CPU. Force the capability to
|
||||
// "cpu" so this is hermetic on hosts that actually have a GPU: backend
|
||||
// preference keys off getSystemCapabilities() (env → real nvidia-smi
|
||||
// detection), not GPUVendor, so without this a GPU dev box reports
|
||||
// "nvidia" and the cuda alias wins. The NVIDIA case below overrides it.
|
||||
must(os.Setenv("LOCALAI_FORCE_META_BACKEND_CAPABILITY", "cpu"))
|
||||
defer func() { _ = os.Unsetenv("LOCALAI_FORCE_META_BACKEND_CAPABILITY") }()
|
||||
|
||||
sysDefault, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
|
||||
@@ -353,7 +353,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
overrides = make(map[string]pii.Action, len(raw))
|
||||
for ovid, action := range raw {
|
||||
switch pii.Action(action) {
|
||||
case pii.ActionMask, pii.ActionBlock, pii.ActionRouteLocal:
|
||||
case pii.ActionMask, pii.ActionBlock, pii.ActionAllow:
|
||||
overrides[ovid] = pii.Action(action)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ var instructionDefs = []instructionDef{
|
||||
Name: "pii-filtering",
|
||||
Description: "Inspect and tune the regex PII filter applied to chat requests",
|
||||
Tags: []string{"pii"},
|
||||
Intro: "GET /api/pii/patterns lists the active pattern set with each one's action (mask, block, route_local). GET /api/pii/events returns recent redaction events filtered by correlation_id / user_id / pattern_id (admin or local-user only). POST /api/pii/test dry-runs the redactor against an admin-supplied string. POST /api/pii/decide is the programmatic decision oracle for external routers: send `{text}`, receive `{findings, suggested_action, redacted_preview}` without LocalAI mutating, recording, or acting on the call — caller composes the action with its own policy. Default patterns: email, phone, SSN, credit card (Luhn), IPv4, common API key prefixes (sk-, pk-, ghp_, github_pat_). PII is per-model: by default it is OFF for non-proxy backends and ON for backends starting with proxy-* (cloud passthroughs). Opt in with `pii: { enabled: true }` in a model's YAML; use `pii: { patterns: [{id, action}] }` to upgrade or downgrade individual actions for that model. Override global default actions via --pii-config pii.yaml; --disable-pii turns the filter off entirely.",
|
||||
Intro: "GET /api/pii/patterns lists the active pattern set with each one's action (mask, block, allow). GET /api/pii/events returns recent redaction events filtered by correlation_id / user_id / pattern_id (admin or local-user only). POST /api/pii/test dry-runs the redactor against an admin-supplied string. POST /api/pii/decide is the programmatic decision oracle for external routers: send `{text}`, receive `{findings, suggested_action, redacted_preview}` without LocalAI mutating, recording, or acting on the call — caller composes the action with its own policy. Default patterns: email, phone, SSN, credit card (Luhn), IPv4, common API key prefixes (sk-, pk-, ghp_, github_pat_). PII is per-model: by default it is OFF for non-proxy backends and ON for backends starting with proxy-* (cloud passthroughs). Opt in with `pii: { enabled: true }` in a model's YAML; use `pii: { patterns: [{id, action}] }` to upgrade or downgrade individual actions for that model. Override global default actions via --pii-config pii.yaml; --disable-pii turns the filter off entirely.",
|
||||
},
|
||||
{
|
||||
Name: "middleware-admin",
|
||||
|
||||
@@ -124,6 +124,8 @@ func AutocompleteEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_VAD)
|
||||
case config.UsecaseTranscript:
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)
|
||||
case "score": // router classifier usecase (FLAG_SCORE); not in UsecaseInfoMap
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_SCORE)
|
||||
default:
|
||||
filterFn = config.NoFilterFn
|
||||
}
|
||||
|
||||
@@ -15,9 +15,9 @@ import (
|
||||
//
|
||||
// External routers (e.g. the localai-org/platform router) call this
|
||||
// before dispatching to learn whether to mask the prompt in place,
|
||||
// route to a local-only backend, block the request, or pass it
|
||||
// through. LocalAI's in-band PII middleware is the alternative path
|
||||
// for direct-to-LocalAI clients — same Redactor, different framing.
|
||||
// block the request, or pass it through. LocalAI's in-band PII
|
||||
// middleware is the alternative path for direct-to-LocalAI clients —
|
||||
// same Redactor, different framing.
|
||||
//
|
||||
// Takes the *pii.Redactor directly rather than the whole
|
||||
// *application.Application so the handler stays unit-testable with a
|
||||
@@ -62,24 +62,18 @@ func PIIDecideEndpoint(redactor *pii.Redactor) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// actionAllow is the wire-only value for "no findings". The other
|
||||
// three map to existing pii.Action* constants; allow has no in-band
|
||||
// counterpart because the in-band middleware simply passes through.
|
||||
const actionAllow = "allow"
|
||||
|
||||
// suggestedAction collapses the Redactor's Result flags onto a single
|
||||
// wire-format action using the in-band ordering (block > route_local
|
||||
// > mask > allow). Spans-without-Blocked-or-LocalOnly means every
|
||||
// match resolved to ActionMask.
|
||||
// wire-format action using the in-band ordering (block > mask >
|
||||
// allow). "allow" covers both "nothing matched" and "matched but every
|
||||
// span resolved to the allow action" — in both cases the caller may
|
||||
// dispatch unchanged, with the Findings list reporting what was seen.
|
||||
func suggestedAction(res pii.Result) string {
|
||||
switch {
|
||||
case res.Blocked:
|
||||
return string(pii.ActionBlock)
|
||||
case res.LocalOnly:
|
||||
return string(pii.ActionRouteLocal)
|
||||
case len(res.Spans) > 0:
|
||||
case res.Masked:
|
||||
return string(pii.ActionMask)
|
||||
default:
|
||||
return actionAllow
|
||||
return string(pii.ActionAllow)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
|
||||
// PIIDecideEndpoint exposes the redactor as a decision oracle. These
|
||||
// specs pin the validation surface and the suggested_action mapping
|
||||
// across all four actions (allow/mask/route_local/block). The redactor
|
||||
// itself is covered in core/services/routing/pii/redactor_test.go.
|
||||
// across the three actions (allow/mask/block). The redactor itself is
|
||||
// covered in core/services/routing/pii/redactor_test.go.
|
||||
|
||||
var _ = Describe("PIIDecideEndpoint", func() {
|
||||
var redactor *pii.Redactor
|
||||
@@ -68,16 +68,17 @@ var _ = Describe("PIIDecideEndpoint", func() {
|
||||
Expect(len(body.Findings)).To(BeNumerically(">=", 1))
|
||||
})
|
||||
|
||||
It("returns route_local when an override sets that action", func() {
|
||||
// Promote the email pattern to route_local for this test —
|
||||
// exercises the route_local branch of suggestedAction without
|
||||
// needing a custom pattern set.
|
||||
Expect(redactor.SetAction("email", pii.ActionRouteLocal)).To(Succeed())
|
||||
It("returns allow when a matched pattern's action is allow", func() {
|
||||
// Downgrade the email pattern to allow for this test —
|
||||
// exercises the allow branch of suggestedAction: a match is
|
||||
// found, but the strongest action is allow so the suggestion
|
||||
// is "allow" and the text is left intact.
|
||||
Expect(redactor.SetAction("email", pii.ActionAllow)).To(Succeed())
|
||||
rec, body := invokePIIDecide(redactor, `{"text":"contact alice@example.com"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.SuggestedAction).To(Equal("route_local"))
|
||||
// route_local leaves the original text intact — caller decides
|
||||
// whether to forward it to a local-only backend.
|
||||
Expect(body.SuggestedAction).To(Equal("allow"))
|
||||
Expect(body.Findings).To(HaveLen(1), "allow still reports the finding")
|
||||
// allow leaves the original text intact.
|
||||
Expect(body.RedactedPreview).To(ContainSubstring("alice@example.com"))
|
||||
})
|
||||
|
||||
|
||||
@@ -130,7 +130,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
overrides = make(map[string]pii.Action, len(raw))
|
||||
for ovid, action := range raw {
|
||||
switch pii.Action(action) {
|
||||
case pii.ActionMask, pii.ActionBlock, pii.ActionRouteLocal:
|
||||
case pii.ActionMask, pii.ActionBlock, pii.ActionAllow:
|
||||
overrides[ovid] = pii.Action(action)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -451,13 +451,14 @@ func buildRealtimeRoutingContext(a *application.Application, sessionID string) *
|
||||
return nil
|
||||
}
|
||||
deps := &middleware.ClassifierDeps{
|
||||
Scorer: a.Scorer,
|
||||
Embedder: a.Embedder,
|
||||
VectorStore: a.VectorStore,
|
||||
Reranker: a.Reranker,
|
||||
ModelLookup: a.ModelConfigLookup(),
|
||||
Registry: a.RouterClassifierRegistry(),
|
||||
Evaluator: a.TemplatesEvaluator(),
|
||||
Scorer: a.Scorer,
|
||||
TokenCounter: a.TokenCounter,
|
||||
Embedder: a.Embedder,
|
||||
VectorStore: a.VectorStore,
|
||||
Reranker: a.Reranker,
|
||||
ModelLookup: a.ModelConfigLookup(),
|
||||
Registry: a.RouterClassifierRegistry(),
|
||||
Evaluator: a.TemplatesEvaluator(),
|
||||
}
|
||||
userID := ""
|
||||
if u := a.FallbackUser(); u != nil {
|
||||
|
||||
139
core/http/middleware/probe_trim_test.go
Normal file
139
core/http/middleware/probe_trim_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("routerConfigFingerprint", func() {
|
||||
rc := config.RouterConfig{Classifier: "score", ClassifierModel: "arch-router"}
|
||||
ctx4096 := 4096
|
||||
ctx8192 := 8192
|
||||
|
||||
// Regression: the score classifier bakes context_size into its token
|
||||
// budget at build time, and the built classifier is cached by this
|
||||
// fingerprint. If context_size weren't hashed, editing it and reloading
|
||||
// would return a classifier carrying the stale budget.
|
||||
It("changes when the classifier model's context_size changes", func() {
|
||||
cfgA := &config.ModelConfig{LLMConfig: config.LLMConfig{ContextSize: &ctx4096}}
|
||||
cfgB := &config.ModelConfig{LLMConfig: config.LLMConfig{ContextSize: &ctx8192}}
|
||||
Expect(routerConfigFingerprint(rc, cfgA)).NotTo(Equal(routerConfigFingerprint(rc, cfgB)))
|
||||
})
|
||||
|
||||
It("is stable for identical classifier configs", func() {
|
||||
cfgA := &config.ModelConfig{LLMConfig: config.LLMConfig{ContextSize: &ctx4096}}
|
||||
cfgB := &config.ModelConfig{LLMConfig: config.LLMConfig{ContextSize: &ctx4096}}
|
||||
Expect(routerConfigFingerprint(rc, cfgA)).To(Equal(routerConfigFingerprint(rc, cfgB)))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("routing probe extraction and trimming", func() {
|
||||
Describe("OpenAIProbeFromRequest", func() {
|
||||
It("keeps a short conversation intact, newline-terminated per message", func() {
|
||||
req := &schema.OpenAIRequest{Messages: []schema.Message{
|
||||
{Role: "user", Content: "first"},
|
||||
{Role: "assistant", Content: "second"},
|
||||
{Role: "user", Content: "third"},
|
||||
}}
|
||||
Expect(OpenAIProbeFromRequest(req).Prompt).To(Equal("first\nsecond\nthird\n"))
|
||||
})
|
||||
|
||||
It("flattens text blocks and skips image-only messages", func() {
|
||||
req := &schema.OpenAIRequest{Messages: []schema.Message{
|
||||
{Role: "user", Content: []any{
|
||||
map[string]any{"type": "text", "text": "describe this"},
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:..."}},
|
||||
}},
|
||||
{Role: "user", Content: []any{
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:..."}},
|
||||
}},
|
||||
}}
|
||||
// Second message contributes no text, so it neither adds a blank
|
||||
// line nor a stray newline.
|
||||
Expect(OpenAIProbeFromRequest(req).Prompt).To(Equal("describe this\n"))
|
||||
})
|
||||
|
||||
It("carries the full conversation untrimmed — trimming is each classifier's job", func() {
|
||||
// The middleware no longer caps the probe by a fixed rune budget;
|
||||
// every turn reaches the Probe and each classifier trims to its own
|
||||
// model's context (see modelTokenTrim / promptTrimmer).
|
||||
block := strings.Repeat("x", 999)
|
||||
msgs := make([]schema.Message, 0, 20)
|
||||
msgs = append(msgs, schema.Message{Role: "user", Content: "OLDEST" + strings.Repeat("o", 994)})
|
||||
for range 18 {
|
||||
msgs = append(msgs, schema.Message{Role: "user", Content: block})
|
||||
}
|
||||
msgs = append(msgs, schema.Message{Role: "user", Content: "NEWEST" + strings.Repeat("n", 994)})
|
||||
|
||||
probe := OpenAIProbeFromRequest(&schema.OpenAIRequest{Messages: msgs})
|
||||
Expect(probe.Prompt).To(ContainSubstring("OLDEST"), "no turn is dropped at probe-build time")
|
||||
Expect(probe.Prompt).To(ContainSubstring("NEWEST"))
|
||||
// Messages preserves the per-turn split the classifier trims from.
|
||||
Expect(probe.Messages).To(HaveLen(20))
|
||||
Expect(probe.Messages[0]).To(ContainSubstring("OLDEST"))
|
||||
Expect(probe.Messages[19]).To(ContainSubstring("NEWEST"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("AnthropicProbe", func() {
	// Probes are no longer trimmed when they are built, so this spec pins
	// extraction only: string content and structured text blocks flatten to
	// one newline-terminated line per turn, and Messages keeps the per-turn
	// split, mirroring the OpenAI path.
	It("extracts the same way as the OpenAI path", func() {
		req := &schema.AnthropicRequest{Messages: []schema.AnthropicMessage{
			{Role: "user", Content: "alpha"},
			{Role: "assistant", Content: []any{
				map[string]any{"type": "text", "text": "beta"},
			}},
		}}
		probe, ok := AnthropicProbe(req)
		Expect(ok).To(BeTrue())
		Expect(probe.Prompt).To(Equal("alpha\nbeta\n"))
		// One entry per non-empty turn, in conversation order.
		Expect(probe.Messages).To(HaveLen(2))
		Expect(probe.Messages[0]).To(Equal("alpha"))
		Expect(probe.Messages[1]).To(Equal("beta"))
	})

	It("returns ok=false for a non-Anthropic payload", func() {
		_, ok := AnthropicProbe(&schema.OpenAIRequest{})
		Expect(ok).To(BeFalse())
	})
})
|
||||
|
||||
Describe("modelTokenTrim", func() {
|
||||
tok := func(string) (int, error) { return 1, nil }
|
||||
depsFor := func(cfg *config.ModelConfig) ClassifierDeps {
|
||||
return ClassifierDeps{
|
||||
ModelLookup: func(string) *config.ModelConfig { return cfg },
|
||||
TokenCounter: func(string) func(string) (int, error) { return tok },
|
||||
}
|
||||
}
|
||||
|
||||
It("still trims to the backend default when context_size is unset", func() {
|
||||
// Regression: with the fixed middleware rune cap gone, an unset
|
||||
// context_size must NOT disable trimming — otherwise a non-trivial
|
||||
// prompt overflows the default 4096 window and every score fails.
|
||||
score := config.FLAG_SCORE
|
||||
cfg := &config.ModelConfig{KnownUsecases: &score} // FLAG_SCORE → batch follows context
|
||||
count, ceiling := modelTokenTrim("classifier", depsFor(cfg))
|
||||
Expect(count).NotTo(BeNil())
|
||||
Expect(ceiling).To(Equal(4096), "unset context_size falls back to the backend default, not 0")
|
||||
})
|
||||
|
||||
It("is bounded by the batch when the batch is smaller than the context", func() {
|
||||
// The probe is one decode (n_tokens <= n_batch). A model with a
|
||||
// large context but a small batch can only process the batch — the
|
||||
// ceiling must follow it, not the context.
|
||||
ctx8k := 8192
|
||||
cfg := &config.ModelConfig{LLMConfig: config.LLMConfig{ContextSize: &ctx8k}}
|
||||
cfg.Batch = 512
|
||||
_, ceiling := modelTokenTrim("embedder", depsFor(cfg))
|
||||
Expect(ceiling).To(Equal(512), "batch is the binding single-decode limit")
|
||||
})
|
||||
|
||||
It("disables trimming only when no tokenizer is available", func() {
|
||||
count, ceiling := modelTokenTrim("x", ClassifierDeps{ModelLookup: func(string) *config.ModelConfig { return &config.ModelConfig{} }})
|
||||
Expect(count).To(BeNil())
|
||||
Expect(ceiling).To(Equal(0))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -86,6 +87,12 @@ type ClassifierDeps struct {
|
||||
// templates.Evaluator so any model the operator points at gets
|
||||
// its own chat template applied.
|
||||
Evaluator *templates.Evaluator
|
||||
|
||||
// TokenCounter binds the classifier model's tokenizer for the score
|
||||
// classifier's token-trim path. Optional; nil falls back to the
|
||||
// backend's n_ctx guard. Plain func type so core/application supplies
|
||||
// it as a method value without importing this package.
|
||||
TokenCounter func(modelName string) func(text string) (int, error)
|
||||
}
|
||||
|
||||
// ProbeExtractor pulls the prompt content out of a parsed request so
|
||||
@@ -212,7 +219,6 @@ func recordHTTPDecision(c echo.Context, store router.DecisionStore, result *rout
|
||||
_ = store.Record(context.Background(), result.ToDecisionRecord(newDecisionID(), correlationID, userID, source))
|
||||
}
|
||||
|
||||
|
||||
// GetOrBuildClassifier looks up a built Classifier for the named router
|
||||
// model in the registry and builds it on miss. Exported so the
|
||||
// /api/router/decide decision-oracle endpoint can share the same
|
||||
@@ -262,9 +268,10 @@ func routerConfigFingerprint(rc config.RouterConfig, classifierCfg *config.Model
|
||||
h := fnv.New64a()
|
||||
h.Write(bytes)
|
||||
if classifierCfg != nil {
|
||||
// Narrow projection: only the fields newTemplateRenderer and
|
||||
// firstStopWord actually read. Hashing the whole ModelConfig
|
||||
// would invalidate the cache on irrelevant parameter changes.
|
||||
// Narrow projection: only the fields buildClassifier reads (renderer,
|
||||
// stop tokens, context_size → MaxContextTokens). Hashing the whole
|
||||
// ModelConfig would invalidate the cache on irrelevant changes;
|
||||
// omitting context_size would let a reload leave a stale token budget.
|
||||
h.Write([]byte{0}) // separator so empty fields don't collide
|
||||
h.Write([]byte(classifierCfg.TemplateConfig.Chat))
|
||||
h.Write([]byte{0})
|
||||
@@ -274,6 +281,10 @@ func routerConfigFingerprint(rc config.RouterConfig, classifierCfg *config.Model
|
||||
h.Write([]byte(sw))
|
||||
h.Write([]byte{0})
|
||||
}
|
||||
h.Write([]byte{0})
|
||||
if classifierCfg.ContextSize != nil {
|
||||
h.Write([]byte(strconv.Itoa(*classifierCfg.ContextSize)))
|
||||
}
|
||||
}
|
||||
return h.Sum64()
|
||||
}
|
||||
@@ -319,11 +330,30 @@ func buildClassifier(cfg *config.ModelConfig, deps ClassifierDeps) (router.Class
|
||||
if deps.ModelLookup != nil {
|
||||
if classifierCfg := deps.ModelLookup(rc.ClassifierModel); classifierCfg != nil {
|
||||
if deps.Evaluator != nil {
|
||||
opts.PromptRenderer = newTemplateRenderer(deps.Evaluator, classifierCfg)
|
||||
// The router renders the scoring prompt client-side, so the
|
||||
// classifier model MUST carry a chat template — refusing
|
||||
// here beats silently falling back to a generic ChatML
|
||||
// envelope the model may not have been trained on.
|
||||
renderer := newTemplateRenderer(deps.Evaluator, classifierCfg)
|
||||
if renderer == nil {
|
||||
return nil, fmt.Errorf(
|
||||
"router classifier score: classifier_model %q has no chat template "+
|
||||
"(set template.chat and template.chat_message in its config). The router "+
|
||||
"renders the scoring prompt with the classifier model's own template; "+
|
||||
"without it the prompt format would not match the model",
|
||||
rc.ClassifierModel)
|
||||
}
|
||||
opts.PromptRenderer = renderer
|
||||
}
|
||||
if st := pickAssistantTurnEnd(classifierCfg.StopWords, classifierCfg.TemplateConfig.ChatMessage); st != "" {
|
||||
opts.StopToken = st
|
||||
}
|
||||
// Token-exact conversation trim — score classifier drops the
|
||||
// oldest turns using the model's own tokenizer.
|
||||
if count, ctxTokens := modelTokenTrim(rc.ClassifierModel, deps); count != nil {
|
||||
opts.TokenCounter = count
|
||||
opts.MaxContextTokens = ctxTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
inner = router.NewScoreClassifier(policies, scorer, opts)
|
||||
@@ -335,7 +365,11 @@ func buildClassifier(cfg *config.ModelConfig, deps ClassifierDeps) (router.Class
|
||||
if reranker == nil {
|
||||
return nil, fmt.Errorf("router classifier colbert: classifier_model %q not loadable", rc.ClassifierModel)
|
||||
}
|
||||
inner = router.NewRerankClassifier(policies, reranker, cacheCap, rc.ActivationThreshold)
|
||||
rerankClassifier := router.NewRerankClassifier(policies, reranker, cacheCap, rc.ActivationThreshold)
|
||||
if count, ctxTokens := modelTokenTrim(rc.ClassifierModel, deps); count != nil {
|
||||
rerankClassifier = rerankClassifier.WithTokenTrim(count, ctxTokens)
|
||||
}
|
||||
inner = rerankClassifier
|
||||
default:
|
||||
return nil, fmt.Errorf("router: unknown classifier %q (supported: %s)", name, strings.Join([]string{router.ClassifierScore, router.ClassifierColbert}, ", "))
|
||||
}
|
||||
@@ -523,7 +557,41 @@ func wrapWithEmbeddingCache(cfg *config.ModelConfig, inner router.Classifier, de
|
||||
if vstore == nil {
|
||||
return nil, fmt.Errorf("vector store %q not loadable", storeName)
|
||||
}
|
||||
return router.NewEmbeddingCacheClassifier(inner, embedder, vstore, ec.SimilarityThreshold, ec.ConfidenceThreshold), nil
|
||||
cache := router.NewEmbeddingCacheClassifier(inner, embedder, vstore, ec.SimilarityThreshold, ec.ConfidenceThreshold)
|
||||
// Trim the probe to the embedder model's own context (e.g. nomic-embed at
|
||||
// 8k) rather than a fixed guess — otherwise the cache key is an embedding
|
||||
// of a silently-truncated conversation.
|
||||
if count, ctxTokens := modelTokenTrim(ec.EmbeddingModel, deps); count != nil {
|
||||
cache = cache.WithTokenTrim(count, ctxTokens)
|
||||
}
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
// modelTokenTrim returns a model's own tokenizer and the token ceiling its
|
||||
// probe must fit, or (nil, 0) when no tokenizer is available (only then can we
|
||||
// not trim exactly). The ceiling is min(effective context, effective batch):
|
||||
// score/embed/rerank all decode the whole prompt in one pass, so it must fit
|
||||
// both the context window and a single batch. Using the backend's *effective*
|
||||
// values — not the raw config fields — means trimming still works when
|
||||
// context_size and batch are unset; otherwise a non-trivial prompt overflows
|
||||
// the default window and every classification fails.
|
||||
func modelTokenTrim(modelName string, deps ClassifierDeps) (func(string) (int, error), int) {
|
||||
if deps.TokenCounter == nil || deps.ModelLookup == nil {
|
||||
return nil, 0
|
||||
}
|
||||
cfg := deps.ModelLookup(modelName)
|
||||
if cfg == nil {
|
||||
return nil, 0
|
||||
}
|
||||
count := deps.TokenCounter(modelName)
|
||||
if count == nil {
|
||||
return nil, 0
|
||||
}
|
||||
ceiling := backend.EffectiveContextSize(*cfg)
|
||||
if b := backend.EffectiveBatchSize(*cfg); b < ceiling {
|
||||
ceiling = b
|
||||
}
|
||||
return count, ceiling
|
||||
}
|
||||
|
||||
func newDecisionID() string {
|
||||
@@ -545,6 +613,41 @@ func OpenAIProbe(parsed any) (router.Probe, bool) {
|
||||
return OpenAIProbeFromRequest(req), true
|
||||
}
|
||||
|
||||
// messageText flattens a chat message's Content to plain text.
//
// String content is returned verbatim. Structured content ([]any) contributes
// only its blocks of type "text" whose "text" value is a non-empty string,
// joined with a single '\n'. Every other content shape, including nil, and
// every non-text block (e.g. image_url) yields nothing.
//
// Empty text blocks are skipped so they can never introduce a dangling
// separator: previously {"a", ""} flattened to "a\n" while {"", "a"}
// flattened to "a".
func messageText(content any) string {
	switch ct := content.(type) {
	case string:
		return ct
	case []any:
		var b strings.Builder
		for _, block := range ct {
			bm, ok := block.(map[string]any)
			if !ok || bm["type"] != "text" {
				continue
			}
			t, ok := bm["text"].(string)
			if !ok || t == "" {
				continue
			}
			// Separator goes between blocks, never before the first or
			// after the last.
			if b.Len() > 0 {
				b.WriteByte('\n')
			}
			b.WriteString(t)
		}
		return b.String()
	}
	return ""
}
|
||||
|
||||
// messageProbeParts returns the non-empty entries of texts in their original
// order. Messages that flattened to "" (for example image-only ones) are left
// out so they neither use up budget nor produce blank lines. The result is
// always a fresh, non-nil slice.
func messageProbeParts(texts []string) []string {
	kept := make([]string, 0, len(texts))
	for i := range texts {
		if texts[i] == "" {
			continue
		}
		kept = append(kept, texts[i])
	}
	return kept
}
|
||||
|
||||
// OpenAIProbeFromRequest is the typed counterpart of OpenAIProbe — same
|
||||
// extraction logic, but takes the request struct directly. Realtime and
|
||||
// other non-HTTP callers use it to feed a probe to router.Resolve
|
||||
@@ -553,24 +656,15 @@ func OpenAIProbeFromRequest(req *schema.OpenAIRequest) router.Probe {
|
||||
if req == nil {
|
||||
return router.Probe{}
|
||||
}
|
||||
var b strings.Builder
|
||||
texts := make([]string, len(req.Messages))
|
||||
for i := range req.Messages {
|
||||
switch ct := req.Messages[i].Content.(type) {
|
||||
case string:
|
||||
b.WriteString(ct)
|
||||
b.WriteByte('\n')
|
||||
case []any:
|
||||
for _, block := range ct {
|
||||
if bm, ok := block.(map[string]any); ok && bm["type"] == "text" {
|
||||
if t, ok := bm["text"].(string); ok {
|
||||
b.WriteString(t)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
texts[i] = messageText(req.Messages[i].Content)
|
||||
}
|
||||
return router.Probe{Prompt: b.String()}
|
||||
parts := messageProbeParts(texts)
|
||||
// Prompt carries the full conversation; each classifier trims it to its own
|
||||
// model's context (see modelTokenTrim). Messages preserves the per-turn
|
||||
// split the trimmer drops oldest-first.
|
||||
return router.Probe{Prompt: router.JoinTurns(parts), Messages: parts}
|
||||
}
|
||||
|
||||
// AnthropicProbe is the AnthropicRequest analogue of OpenAIProbe.
|
||||
@@ -579,25 +673,10 @@ func AnthropicProbe(parsed any) (router.Probe, bool) {
|
||||
if !ok || req == nil {
|
||||
return router.Probe{}, false
|
||||
}
|
||||
var b strings.Builder
|
||||
texts := make([]string, len(req.Messages))
|
||||
for i := range req.Messages {
|
||||
switch ct := req.Messages[i].Content.(type) {
|
||||
case string:
|
||||
b.WriteString(ct)
|
||||
b.WriteByte('\n')
|
||||
case []any:
|
||||
for _, block := range ct {
|
||||
if bm, ok := block.(map[string]any); ok && bm["type"] == "text" {
|
||||
if t, ok := bm["text"].(string); ok {
|
||||
b.WriteString(t)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
texts[i] = messageText(req.Messages[i].Content)
|
||||
}
|
||||
return router.Probe{
|
||||
Prompt: b.String(),
|
||||
}, true
|
||||
parts := messageProbeParts(texts)
|
||||
return router.Probe{Prompt: router.JoinTurns(parts), Messages: parts}, true
|
||||
}
|
||||
|
||||
|
||||
@@ -246,11 +246,12 @@ var _ = Describe("RouteModel rendered classifier prompt", func() {
|
||||
"rendered prompt must end at assistant-open marker. got: %q", s.lastPrompt)
|
||||
})
|
||||
|
||||
It("falls back to chatMLRenderer when the classifier model has no chat_message template", func() {
|
||||
// Partial template config: only outer Chat, no per-role
|
||||
// piece. The renderer must refuse rather than emit a prompt
|
||||
// that drops the system turn, so the score classifier's
|
||||
// built-in ChatML default takes over.
|
||||
It("refuses to build the router when the classifier model has no chat_message template", func() {
|
||||
// Partial template config: only the outer Chat, no per-role piece.
|
||||
// The router renders the scoring prompt client-side from the
|
||||
// classifier model's own template, so a missing template is a hard
|
||||
// error rather than a silent fall back to a generic ChatML envelope
|
||||
// the model may not have been trained on.
|
||||
writePartialClassifierModel(modelDir, "arch-router")
|
||||
routerCfg := newScoreRouterModel(modelDir, "smart-router")
|
||||
|
||||
@@ -266,19 +267,9 @@ var _ = Describe("RouteModel rendered classifier prompt", func() {
|
||||
ModelLookup: loaderLookup(loader, appConfig),
|
||||
Evaluator: eval,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// chatMLRenderer fallback emits its own envelope and still
|
||||
// embeds the routing system prompt. OpenAIProbeFromRequest
|
||||
// appends "\n" after each message body, so the user content
|
||||
// reaches the renderer as "hello world\n" — the substring
|
||||
// match accounts for that.
|
||||
Expect(s.lastPrompt).To(ContainSubstring("<routes>"),
|
||||
"fallback renderer also dropped the system prompt")
|
||||
Expect(s.lastPrompt).To(ContainSubstring("<|im_start|>system\n"))
|
||||
Expect(s.lastPrompt).To(ContainSubstring("<|im_start|>user\nhello world\n<|im_end|>"))
|
||||
Expect(strings.HasSuffix(s.lastPrompt, "<|im_start|>assistant\n")).To(BeTrue(),
|
||||
"chatMLRenderer fallback must end at assistant-open marker. got: %q", s.lastPrompt)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no chat template"),
|
||||
"missing classifier template must surface as a clear config error. got: %v", err)
|
||||
})
|
||||
|
||||
It("uses the classifier model's first stopword as the candidate suffix", func() {
|
||||
@@ -533,8 +524,8 @@ template:
|
||||
|
||||
// writePartialClassifierModel writes a classifier model that has the
|
||||
// outer Chat template but no ChatMessage — exercises the
|
||||
// newTemplateRenderer "refuse partial templating" branch that hands
|
||||
// off to chatMLRenderer.
|
||||
// newTemplateRenderer "refuse partial templating" branch, which makes
|
||||
// buildClassifier reject the router with a missing-template error.
|
||||
func writePartialClassifierModel(modelDir, name string) {
|
||||
body := `name: ` + name + `
|
||||
backend: llama-cpp
|
||||
|
||||
@@ -224,4 +224,38 @@ test.describe('Model Editor - Interactive Tab', () => {
|
||||
expect(estimateCalled).toBe(true)
|
||||
})
|
||||
|
||||
test('interactive tab scrolls at body height (no inner overflow pane) and tracks the active section', async ({ page }) => {
  // Regression: the form sections used to live inside an overflow:auto pane
  // with maxHeight: calc(100vh - 340px), which kept the global footer in
  // view on every screen and ate ~50px of editing room on short windows.
  // Pin two pieces of the fix:
  //   1. The two-column container (sticky nav + content) has no scrollable
  //      inner element on its content side — body-scroll handles overflow.
  //   2. The active-section tracker now listens to window scroll. Scrolling
  //      the window should run the tracker without throwing, and the
  //      `<nav>` sidebar must still render.
  const contentOverflowY = await page.evaluate(() => {
    const sidebar = document.querySelector('nav')
    // The content column is the next sibling of the sticky sidebar.
    const content = sidebar?.nextElementSibling
    return content ? getComputedStyle(content).overflowY : 'no-content'
  })
  // Single allow-list: the content column must not be a scroll container, so
  // neither 'auto' nor 'scroll' is acceptable. 'no-content' stays tolerated
  // for layouts where the sidebar has no following sibling.
  expect(['visible', 'normal', 'no-content']).toContain(contentOverflowY)

  // Add a field to give the page a touch more height, then force a window
  // scroll. The tracker should run; the sidebar should remain visible.
  const searchInput = page.locator('input[placeholder="Search fields to add..."]')
  await searchInput.fill('Temperature')
  const dropdown = searchInput.locator('..').locator('..')
  await dropdown.locator('div', { hasText: 'Temperature' }).first().click()
  await page.evaluate(() => window.scrollTo(0, 200))
  // Brief pause so the window-scroll listener gets a chance to run.
  await page.waitForTimeout(50)
  await expect(page.locator('nav').first()).toBeVisible()
})
|
||||
|
||||
})
|
||||
|
||||
94
core/http/react-ui/e2e/model-editor-back-nav.spec.js
Normal file
94
core/http/react-ui/e2e/model-editor-back-nav.spec.js
Normal file
@@ -0,0 +1,94 @@
|
||||
import { test, expect } from './coverage-fixtures.js'
|
||||
|
||||
// Exercises the "Back to <page>" navigation convention: whichever page links
|
||||
// into the Model Editor stamps its origin as react-router location state, and
|
||||
// the editor's Back button returns there (captioned with the origin) instead
|
||||
// of a hardcoded route. Also covers the Middleware page's ?tab= persistence,
|
||||
// which is what lets the editor return you to the exact tab you came from.
|
||||
|
||||
// Minimal config-metadata payload: a single section holding a single field,
// which is all the editor needs in order to render its header.
const MOCK_METADATA = {
  sections: [
    { id: 'general', label: 'General', icon: 'settings', order: 0 },
  ],
  fields: [
    {
      path: 'name',
      yaml_key: 'name',
      go_type: 'string',
      ui_type: 'string',
      section: 'general',
      label: 'Model Name',
      description: 'id',
      component: 'input',
      order: 0,
    },
  ],
}

// YAML body served by the mocked edit endpoint.
const MOCK_YAML = ['name: mock-model', 'backend: mock-backend', ''].join('\n')

// Middleware status with PII filtering off and exactly one router model, so
// the Routing tab renders an editable model link to click through to the
// editor.
const MOCK_MIDDLEWARE_STATUS = {
  pii: {
    enabled_globally: false,
    default_enabled_for_backends: [],
    patterns: [],
    models: [],
    recent_event_count: 0,
  },
  router: {
    configured: true,
    models: [
      { name: 'smart-router', classifier: 'score', fallback: 'qwen-7b', policies: [], candidates: [] },
    ],
    recent_decision_count: 0,
    available_classifiers: ['score'],
  },
}
|
||||
|
||||
// Stubs the three endpoints the Model Editor loads, for every model name.
// The editor header (and with it the Back button) only shows up once both
// metadata and config have loaded, so each spec needs these in place.
async function mockEditorEndpoints(page) {
  // Builds a route handler that answers with a JSON body; the body is
  // produced when the request is served, not when the route is registered.
  const answerWith = (makeBody) => (route) =>
    route.fulfill({ contentType: 'application/json', body: makeBody() })

  await page.route('**/api/models/config-metadata*', answerWith(() => JSON.stringify(MOCK_METADATA)))
  await page.route(
    '**/api/models/edit/**',
    answerWith(() => JSON.stringify({ config: MOCK_YAML, name: 'mock-model' })),
  )
  await page.route('**/api/models/config-json/**', answerWith(() => '{}'))
}
|
||||
|
||||
test.describe('Model Editor — Back navigation', () => {
  // Registers a route that always answers with `payload` serialised as JSON.
  const stubJSON = (page, pattern, payload) =>
    page.route(pattern, (route) =>
      route.fulfill({ contentType: 'application/json', body: JSON.stringify(payload) }))

  test.beforeEach(async ({ page }) => {
    // Auth disabled, plus the editor's own endpoints, for every spec.
    await stubJSON(page, '**/api/auth/status', {
      authEnabled: false,
      staticApiKeyRequired: false,
      providers: [],
    })
    await mockEditorEndpoints(page)
  })

  test('Back returns to Manage with a "Back to Manage" caption', async ({ page }) => {
    await page.goto('/app/manage')
    await expect(page.locator('.table')).toBeVisible({ timeout: 10_000 })

    // First row's action menu → "Edit configuration".
    const menuTrigger = page.locator('button.action-menu__trigger').first()
    await expect(menuTrigger).toBeVisible()
    await menuTrigger.click()
    await page.getByRole('menuitem', { name: 'Edit configuration' }).click()

    await expect(page).toHaveURL(/\/app\/model-editor\//)
    const backButton = page.getByRole('button', { name: /Back to Manage/ })
    await expect(backButton).toBeVisible({ timeout: 10_000 })

    await backButton.click()
    await expect(page).toHaveURL(/\/app\/manage/)
  })

  test('returns to the originating Middleware tab (?tab=routing) it was opened from', async ({ page }) => {
    await stubJSON(page, '**/api/middleware/status', MOCK_MIDDLEWARE_STATUS)
    await stubJSON(page, '**/api/pii/events?**', { events: [] })
    await stubJSON(page, '**/api/router/decisions?**', { decisions: [] })

    await page.goto('/app/middleware')
    // Selecting the Routing tab has to be reflected in the URL.
    await page.getByRole('button', { name: /Routing/i }).click()
    await expect(page).toHaveURL(/[?&]tab=routing/)

    // Follow the router model's link into the editor, then go back.
    await page.getByRole('link', { name: 'smart-router' }).click()
    await expect(page).toHaveURL(/\/app\/model-editor\/smart-router/)
    const backButton = page.getByRole('button', { name: /Back to Middleware/ })
    await expect(backButton).toBeVisible({ timeout: 10_000 })

    await backButton.click()
    // Lands on the very tab it came from rather than the default Filtering tab.
    await expect(page).toHaveURL(/\/app\/middleware\?tab=routing/)
    await expect(page.getByText('smart-router').first()).toBeVisible()
  })

  test('falls back to "Back to Manage" on a direct visit with no origin state', async ({ page }) => {
    await page.goto('/app/model-editor/mock-model')
    const backButton = page.getByRole('button', { name: /Back to Manage/ })
    await expect(backButton).toBeVisible({ timeout: 10_000 })
  })
})
|
||||
@@ -48,3 +48,77 @@ test.describe('Traces - Error Display', () => {
|
||||
await expect(page.locator('th', { hasText: 'Type' })).toBeVisible()
|
||||
})
|
||||
})
|
||||
|
||||
// Pins the BackendTraceDetail expansion path for a `vector_store` trace —
// the trace type that surfaces the router's embedding-cache plumbing.
// Clicking a row renders the detail view, which exercises typeBadgeStyle
// (including the vector_store badge color), the DataFields component
// (op / outcome / vector_dim / similarity) and the "View backend logs"
// link that resolves to the store namespace. Without this spec the new
// color entry and the data-field render branches stay uncovered, dragging
// UI line coverage below the regression gate.
test.describe('Traces - vector_store backend trace detail', () => {
  test.beforeEach(async ({ page }) => {
    // The API-trace list is stubbed empty; only backend traces matter here.
    // The handlers return the fulfill() promise so a failed fulfill surfaces
    // in the test instead of becoming an unhandled rejection.
    await page.route('**/api/traces', (route) =>
      route.fulfill({ contentType: 'application/json', body: '[]' }))

    // Two canned vector_store rows: a cache hit (carries `similarity`) and a
    // miss (no `similarity` field).
    await page.route('**/api/backend-traces', (route) =>
      route.fulfill({
        contentType: 'application/json',
        body: JSON.stringify([
          {
            type: 'vector_store',
            timestamp: '2026-05-28T13:56:25.558Z',
            model_name: 'router-cache-smart-router',
            backend: 'local-store',
            summary: 'search hit (sim=0.989)',
            duration: 160_000_000,
            error: '',
            data: {
              op: 'search',
              outcome: 'hit',
              vector_dim: 768,
              similarity: 0.9899752140045166,
            },
          },
          {
            type: 'vector_store',
            timestamp: '2026-05-28T13:49:07.545Z',
            model_name: 'router-cache-smart-router',
            backend: 'local-store',
            summary: 'search miss',
            duration: 100_000_000,
            error: '',
            data: {
              op: 'search',
              outcome: 'miss',
              vector_dim: 768,
            },
          },
        ]),
      }))

    // Wait for the page to render before switching to the backend tab.
    await page.goto('/app/traces')
    await expect(page.locator('text=Tracing is')).toBeVisible({ timeout: 10_000 })
    await page.locator('button', { hasText: 'Backend Traces' }).click()
  })

  test('renders type badge and expands data fields on row click', async ({ page }) => {
    // The vector_store badge appears in the type column.
    await expect(page.locator('td span', { hasText: 'vector_store' }).first()).toBeVisible()

    // Expand BackendTraceDetail for the hit row. The "search hit" summary
    // is the anchor that tells it apart from the miss row below.
    await page.locator('tr', { hasText: 'search hit' }).first().click()

    // DataFields renders op/outcome/vector_dim/similarity as label/value
    // pairs; the `outcome` label only exists once the detail is expanded.
    await expect(page.locator('text=outcome').first()).toBeVisible()
    // NOTE(review): `text=hit` is a substring match, so the "search hit"
    // summary cell (visible before the click) already satisfies it. The
    // `outcome` check above is what proves the expansion; tighten this one
    // once the DataFields value markup is confirmed.
    await expect(page.locator('text=hit').first()).toBeVisible()

    // The model_name → /app/backend-logs link is the BackendTraceDetail
    // affordance for jumping to logs for the store namespace.
    await expect(page.locator('a', { hasText: 'View backend logs' })).toBeVisible()
  })
})
|
||||
|
||||
@@ -4,6 +4,12 @@ export default defineConfig({
|
||||
testDir: './e2e',
|
||||
timeout: 30_000,
|
||||
retries: process.env.CI ? 2 : 0,
|
||||
// TEMPORARY: cap parallelism. Playwright's default (cores/2) oversubscribes
|
||||
// high-core dev machines and intermittently starves the page-teardown
|
||||
// coverage harvest past the 30s test timeout (flaky "Tearing down page"
|
||||
// failures, different specs each run). Capped at 8 pending a proper
|
||||
// root-cause fix; override with PW_WORKERS.
|
||||
workers: process.env.PW_WORKERS ? Number(process.env.PW_WORKERS) : 8,
|
||||
reporter: process.env.CI ? 'html' : 'list',
|
||||
use: {
|
||||
baseURL: 'http://127.0.0.1:8089',
|
||||
|
||||
@@ -13,6 +13,7 @@ import { useCodeMirror } from '../hooks/useCodeMirror'
|
||||
import { useTheme } from '../contexts/ThemeContext'
|
||||
import { getThemeExtension } from '../utils/cmTheme'
|
||||
import { createYamlCompletionSource } from '../utils/cmYamlComplete'
|
||||
import { goTemplate } from '../utils/cmGoTemplate'
|
||||
|
||||
function yamlIssueToDiagnostic(issue, cmDoc, severity) {
|
||||
const len = cmDoc.length
|
||||
@@ -43,14 +44,17 @@ const yamlLinter = linter(view => {
|
||||
return diagnostics
|
||||
})
|
||||
|
||||
export default function CodeEditor({ value, onChange, disabled, minHeight = '500px', fields }) {
|
||||
export default function CodeEditor({ value, onChange, disabled, minHeight = '500px', fields, language = 'yaml' }) {
|
||||
const containerRef = useRef(null)
|
||||
const { theme } = useTheme()
|
||||
const isGoTemplate = language === 'gotemplate'
|
||||
|
||||
// Static extensions — only recreate when fields change
|
||||
// Static extensions — only recreate when fields/language change
|
||||
const extensions = useMemo(() => {
|
||||
const exts = [
|
||||
yaml(),
|
||||
// Go templates aren't YAML — skip the YAML mode/linter so valid
|
||||
// `{{ ... }}` syntax isn't flagged as a YAML parse error.
|
||||
isGoTemplate ? goTemplate : yaml(),
|
||||
lineNumbers(),
|
||||
highlightActiveLineGutter(),
|
||||
highlightActiveLine(),
|
||||
@@ -59,8 +63,6 @@ export default function CodeEditor({ value, onChange, disabled, minHeight = '500
|
||||
indentOnInput(),
|
||||
bracketMatching(),
|
||||
highlightSelectionMatches(),
|
||||
yamlLinter,
|
||||
lintGutter(),
|
||||
history(),
|
||||
indentUnit.of(' '),
|
||||
EditorState.tabSize.of(2),
|
||||
@@ -77,15 +79,18 @@ export default function CodeEditor({ value, onChange, disabled, minHeight = '500
|
||||
}),
|
||||
]
|
||||
|
||||
if (fields && fields.length > 0) {
|
||||
exts.push(autocompletion({
|
||||
override: [createYamlCompletionSource(fields)],
|
||||
activateOnTyping: true,
|
||||
}))
|
||||
if (!isGoTemplate) {
|
||||
exts.push(yamlLinter, lintGutter())
|
||||
if (fields && fields.length > 0) {
|
||||
exts.push(autocompletion({
|
||||
override: [createYamlCompletionSource(fields)],
|
||||
activateOnTyping: true,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
return exts
|
||||
}, [minHeight, fields])
|
||||
}, [minHeight, fields, isGoTemplate])
|
||||
|
||||
// Dynamic extensions — reconfigured via Compartments (preserves undo/cursor/scroll)
|
||||
const dynamicExtensions = useMemo(() => ({
|
||||
|
||||
@@ -16,6 +16,7 @@ const PROVIDER_TO_CAPABILITY = {
|
||||
'models:tts': 'FLAG_TTS',
|
||||
'models:transcript': 'FLAG_TRANSCRIPT',
|
||||
'models:vad': 'FLAG_VAD',
|
||||
'models:score': 'FLAG_SCORE',
|
||||
}
|
||||
|
||||
function coerceValue(raw, uiType) {
|
||||
@@ -325,7 +326,7 @@ export default function ConfigFieldRenderer({ field, value, onChange, onRemove,
|
||||
</div>
|
||||
{isStructured
|
||||
? <StructuredCodeEditor value={value} onChange={handleChange} minHeight="80px" />
|
||||
: <CodeEditor value={value || ''} onChange={handleChange} minHeight="80px" />}
|
||||
: <CodeEditor value={value || ''} onChange={handleChange} minHeight="80px" language={field.language} />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import SearchableSelect from './SearchableSelect'
|
||||
const ACTION_OPTIONS = [
|
||||
{ value: 'mask', label: 'Mask — replace with a [REDACTED:id] placeholder' },
|
||||
{ value: 'block', label: 'Block — reject the request (request side) / mask in stream' },
|
||||
{ value: 'route_local', label: 'Route local — keep text, force local-only routing' },
|
||||
{ value: 'allow', label: 'Allow — detect & log, leave text unchanged' },
|
||||
]
|
||||
|
||||
export default function PIIPatternListEditor({ value, onChange }) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useState, useEffect, useCallback, useRef } from 'react'
|
||||
import { useNavigate, useOutletContext } from 'react-router-dom'
|
||||
import { useNavigate, useOutletContext, useLocation } from 'react-router-dom'
|
||||
import { agentJobsApi, modelsApi } from '../utils/api'
|
||||
import { fromState } from '../utils/editorNav'
|
||||
import { useModels } from '../hooks/useModels'
|
||||
import { useAuth } from '../context/AuthContext'
|
||||
import { useUserMap } from '../hooks/useUserMap'
|
||||
@@ -13,6 +14,7 @@ import ConfirmDialog from '../components/ConfirmDialog'
|
||||
export default function AgentJobs() {
|
||||
const { addToast } = useOutletContext()
|
||||
const navigate = useNavigate()
|
||||
const location = useLocation()
|
||||
const { models } = useModels()
|
||||
const { isAdmin, authEnabled, user } = useAuth()
|
||||
const userMap = useUserMap()
|
||||
@@ -338,7 +340,7 @@ export default function AgentJobs() {
|
||||
</td>
|
||||
<td>
|
||||
{task.model ? (
|
||||
<a onClick={() => navigate(`/app/model-editor/${encodeURIComponent(task.model)}`)} style={{ cursor: 'pointer', color: 'var(--color-primary)', fontSize: '0.8125rem' }}>
|
||||
<a onClick={() => navigate(`/app/model-editor/${encodeURIComponent(task.model)}`, { state: fromState(location, 'Agent Jobs') })} style={{ cursor: 'pointer', color: 'var(--color-primary)', fontSize: '0.8125rem' }}>
|
||||
{task.model}
|
||||
</a>
|
||||
) : '-'}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useState, useEffect, useRef, useCallback, useMemo } from 'react'
|
||||
import { useParams, useOutletContext, useNavigate } from 'react-router-dom'
|
||||
import { useParams, useOutletContext, useNavigate, useLocation } from 'react-router-dom'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { fromState } from '../utils/editorNav'
|
||||
import { useChat } from '../hooks/useChat'
|
||||
import ModelSelector from '../components/ModelSelector'
|
||||
import { renderMarkdown, highlightAll } from '../utils/markdown'
|
||||
@@ -285,6 +286,7 @@ export default function Chat() {
|
||||
const { model: urlModel } = useParams()
|
||||
const { addToast } = useOutletContext()
|
||||
const navigate = useNavigate()
|
||||
const location = useLocation()
|
||||
const { t } = useTranslation('chat')
|
||||
const { isAdmin } = useAuth()
|
||||
const { operations } = useOperations()
|
||||
@@ -904,7 +906,7 @@ export default function Chat() {
|
||||
<button
|
||||
type="button"
|
||||
className="btn btn-secondary btn-sm"
|
||||
onClick={() => navigate(`/app/model-editor/${encodeURIComponent(activeChat.model)}`)}
|
||||
onClick={() => navigate(`/app/model-editor/${encodeURIComponent(activeChat.model)}`, { state: fromState(location, 'Chat') })}
|
||||
title={t('header.editConfig')}
|
||||
>
|
||||
<i className="fas fa-pen-to-square" /> {t('header.editConfig')}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useState, useEffect, useCallback } from 'react'
|
||||
import { useNavigate, useOutletContext, useSearchParams } from 'react-router-dom'
|
||||
import { useNavigate, useOutletContext, useSearchParams, useLocation } from 'react-router-dom'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { fromState } from '../utils/editorNav'
|
||||
import ResourceMonitor from '../components/ResourceMonitor'
|
||||
import ConfirmDialog from '../components/ConfirmDialog'
|
||||
import NodeDistributionChip from '../components/NodeDistributionChip'
|
||||
@@ -121,6 +122,7 @@ function formatBackendVersion(metadata) {
|
||||
export default function Manage() {
|
||||
const { addToast } = useOutletContext()
|
||||
const navigate = useNavigate()
|
||||
const location = useLocation()
|
||||
const { t } = useTranslation('admin')
|
||||
const [searchParams, setSearchParams] = useSearchParams()
|
||||
const initialTab = searchParams.get('tab') || localStorage.getItem('manage-tab') || 'models'
|
||||
@@ -673,7 +675,7 @@ export default function Manage() {
|
||||
onClick: () => handleTogglePinned(model.id, model.pinned),
|
||||
disabled: pinningModels.has(model.id) || !!model.disabled },
|
||||
{ key: 'edit', icon: 'fa-pen-to-square', label: 'Edit configuration',
|
||||
onClick: () => navigate(`/app/model-editor/${encodeURIComponent(model.id)}`) },
|
||||
onClick: () => navigate(`/app/model-editor/${encodeURIComponent(model.id)}`, { state: fromState(location, 'Manage') }) },
|
||||
{ key: 'logs', icon: 'fa-terminal', label: 'Backend logs',
|
||||
onClick: () => navigate(`/app/backend-logs/${encodeURIComponent(model.id)}`) },
|
||||
{ divider: true },
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useState, useEffect, useCallback, useRef, useMemo, Fragment } from 'react'
|
||||
import { useOutletContext, Link, useNavigate } from 'react-router-dom'
|
||||
import { useOutletContext, Link, useNavigate, useLocation, useSearchParams } from 'react-router-dom'
|
||||
import { apiUrl } from '../utils/basePath'
|
||||
import { fromState } from '../utils/editorNav'
|
||||
import { settingsApi } from '../utils/api'
|
||||
import LoadingSpinner from '../components/LoadingSpinner'
|
||||
|
||||
@@ -26,13 +27,13 @@ const TABS = [
|
||||
{ id: 'events', label: 'Events', icon: 'fa-list-ul' },
|
||||
]
|
||||
|
||||
const ACTIONS = ['mask', 'block', 'route_local']
|
||||
const ACTIONS = ['mask', 'block', 'allow']
|
||||
|
||||
function actionBadge(action) {
|
||||
const colors = {
|
||||
mask: 'var(--color-primary)',
|
||||
block: 'var(--color-error)',
|
||||
route_local: 'var(--color-warning)',
|
||||
allow: 'var(--color-warning)',
|
||||
}
|
||||
return (
|
||||
<span style={{
|
||||
@@ -75,9 +76,20 @@ export default function Middleware() {
|
||||
const [events, setEvents] = useState([])
|
||||
const [decisions, setDecisions] = useState([])
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [activeTab, setActiveTab] = useState('filtering')
|
||||
// The active tab lives in the URL (?tab=) so deep links and the model-editor
|
||||
// Back button (which captures location.search) return to the same tab; a
|
||||
// localStorage fallback restores it on a bare visit. Mirrors the Manage page.
|
||||
const [searchParams, setSearchParams] = useSearchParams()
|
||||
const initialTab = searchParams.get('tab') || localStorage.getItem('middleware-tab') || 'filtering'
|
||||
const [activeTab, setActiveTab] = useState(TABS.some(t => t.id === initialTab) ? initialTab : 'filtering')
|
||||
const [pendingPattern, setPendingPattern] = useState(null) // id while a PUT is in flight
|
||||
|
||||
// Activates a tab and records the choice in three places: component state
// (drives rendering), localStorage under 'middleware-tab' (restored on a
// bare visit) and the ?tab= query parameter (deep links and the editor's
// Back button). setSearchParams receives a fresh object, so the resulting
// query string carries only `tab`.
const selectTab = (tabId) => {
  setActiveTab(tabId)
  localStorage.setItem('middleware-tab', tabId)
  setSearchParams({ tab: tabId })
}
|
||||
|
||||
// silent=true on background polls: skips the loading spinner and
|
||||
// suppresses toast spam if the server is briefly unreachable.
|
||||
const fetchAll = useCallback(async (silent = false) => {
|
||||
@@ -178,7 +190,7 @@ export default function Middleware() {
|
||||
<button
|
||||
key={tab.id}
|
||||
className={`btn btn-sm ${activeTab === tab.id ? 'btn-primary' : 'btn-secondary'}`}
|
||||
onClick={() => setActiveTab(tab.id)}
|
||||
onClick={() => selectTab(tab.id)}
|
||||
>
|
||||
<i className={`fas ${tab.icon}`} style={{ marginRight: 4 }} />
|
||||
{tab.label}
|
||||
@@ -215,6 +227,7 @@ export default function Middleware() {
|
||||
}
|
||||
|
||||
function FilteringTab({ status, pendingPattern, onSetAction, onSetDisabled, onPersist, persisting }) {
|
||||
const location = useLocation()
|
||||
if (!status?.pii) return null
|
||||
const pii = status.pii
|
||||
|
||||
@@ -353,6 +366,7 @@ function FilteringTab({ status, pendingPattern, onSetAction, onSetDisabled, onPe
|
||||
<td>
|
||||
<Link
|
||||
to={`/app/model-editor/${encodeURIComponent(m.name)}`}
|
||||
state={fromState(location, 'Middleware')}
|
||||
className="btn btn-secondary btn-sm"
|
||||
style={{ fontSize: '0.6875rem', padding: '2px 8px' }}
|
||||
title={`Edit ${m.name}.yaml`}
|
||||
@@ -485,6 +499,7 @@ function DecisionDetail({ d }) {
|
||||
|
||||
function RoutingTab({ status, decisions }) {
|
||||
const navigate = useNavigate()
|
||||
const location = useLocation()
|
||||
const router = status?.router || { configured: false }
|
||||
const [expanded, setExpanded] = useState(() => new Set())
|
||||
|
||||
@@ -519,7 +534,7 @@ function RoutingTab({ status, decisions }) {
|
||||
<button
|
||||
className="btn btn-primary"
|
||||
style={{ marginTop: 'var(--spacing-md)' }}
|
||||
onClick={() => navigate('/app/model-editor?template=router')}
|
||||
onClick={() => navigate('/app/model-editor?template=router', { state: fromState(location, 'Middleware') })}
|
||||
>
|
||||
<i className="fas fa-plus" /> Create routing model
|
||||
</button>
|
||||
@@ -539,7 +554,7 @@ function RoutingTab({ status, decisions }) {
|
||||
</span>
|
||||
<button
|
||||
className="btn btn-secondary btn-sm"
|
||||
onClick={() => navigate('/app/model-editor?template=router')}
|
||||
onClick={() => navigate('/app/model-editor?template=router', { state: fromState(location, 'Middleware') })}
|
||||
title="Open the model editor with the Routing Model template pre-selected"
|
||||
>
|
||||
<i className="fas fa-plus" /> Add routing model
|
||||
@@ -560,7 +575,9 @@ function RoutingTab({ status, decisions }) {
|
||||
<tbody>
|
||||
{router.models.map(m => (
|
||||
<tr key={m.name}>
|
||||
<td style={{ fontFamily: 'var(--font-mono)', fontSize: '0.8125rem', fontWeight: 600 }}>{m.name}</td>
|
||||
<td style={{ fontFamily: 'var(--font-mono)', fontSize: '0.8125rem', fontWeight: 600 }}>
|
||||
<Link to={`/app/model-editor/${encodeURIComponent(m.name)}`} state={fromState(location, 'Middleware')} title="Edit this router model's config">{m.name}</Link>
|
||||
</td>
|
||||
<td style={{ fontFamily: 'var(--font-mono)', fontSize: '0.75rem' }}>{m.classifier}</td>
|
||||
<td style={{ fontSize: '0.75rem' }}>
|
||||
{(m.candidates || []).map((c, i) => (
|
||||
@@ -657,6 +674,7 @@ function RoutingTab({ status, decisions }) {
|
||||
|
||||
function ProxyTab({ status, addToast, onChanged }) {
|
||||
const navigate = useNavigate()
|
||||
const location = useLocation()
|
||||
const mitm = status?.mitm
|
||||
const serverListen = mitm?.configured_addr || ''
|
||||
|
||||
@@ -722,7 +740,7 @@ function ProxyTab({ status, addToast, onChanged }) {
|
||||
<code style={{ fontFamily: 'var(--font-mono)' }}>{h}</code>
|
||||
{' claimed by: '}
|
||||
{(conflicts[h] || []).map(name => (
|
||||
<Link key={name} to={`/app/model-editor/${encodeURIComponent(name)}`} style={{ marginRight: 6, fontFamily: 'var(--font-mono)' }}>
|
||||
<Link key={name} to={`/app/model-editor/${encodeURIComponent(name)}`} state={fromState(location, 'Middleware')} style={{ marginRight: 6, fontFamily: 'var(--font-mono)' }}>
|
||||
{name}
|
||||
</Link>
|
||||
))}
|
||||
@@ -754,7 +772,7 @@ function ProxyTab({ status, addToast, onChanged }) {
|
||||
<ul style={{ margin: 0, paddingLeft: 20, fontFamily: 'var(--font-mono)' }}>
|
||||
{ownerEntries.map(([host, name]) => (
|
||||
<li key={host}>
|
||||
{host} → <Link to={`/app/model-editor/${encodeURIComponent(name)}`}>{name}</Link>
|
||||
{host} → <Link to={`/app/model-editor/${encodeURIComponent(name)}`} state={fromState(location, 'Middleware')}>{name}</Link>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
@@ -784,7 +802,7 @@ function ProxyTab({ status, addToast, onChanged }) {
|
||||
<h2 style={{ fontSize: '1rem', fontWeight: 600, margin: 0 }}>MITM Models</h2>
|
||||
<button
|
||||
className="btn btn-secondary btn-sm"
|
||||
onClick={() => navigate('/app/model-editor?template=mitm')}
|
||||
onClick={() => navigate('/app/model-editor?template=mitm', { state: fromState(location, 'Middleware') })}
|
||||
title="Open the model editor with the MITM Intercept template pre-selected"
|
||||
>
|
||||
<i className="fas fa-plus" /> Add MITM model
|
||||
@@ -815,6 +833,7 @@ function ProxyTab({ status, addToast, onChanged }) {
|
||||
<td>
|
||||
<Link
|
||||
to={`/app/model-editor/${encodeURIComponent(m.name)}`}
|
||||
state={fromState(location, 'Middleware')}
|
||||
className="btn btn-secondary btn-sm"
|
||||
style={{ fontSize: '0.6875rem', padding: '2px 8px' }}
|
||||
>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useState, useEffect, useRef, useMemo, useCallback } from 'react'
|
||||
import { useParams, useNavigate, useOutletContext, useSearchParams } from 'react-router-dom'
|
||||
import { useParams, useNavigate, useOutletContext, useSearchParams, useLocation } from 'react-router-dom'
|
||||
import YAML from 'yaml'
|
||||
import { modelsApi } from '../utils/api'
|
||||
import { apiUrl } from '../utils/basePath'
|
||||
@@ -17,7 +17,8 @@ const SECTION_ICONS = {
|
||||
general: 'fa-cog', llm: 'fa-microchip', parameters: 'fa-sliders',
|
||||
templates: 'fa-file-code', functions: 'fa-wrench', reasoning: 'fa-brain',
|
||||
diffusers: 'fa-image', tts: 'fa-volume-up', pipeline: 'fa-code-branch',
|
||||
grpc: 'fa-server', agent: 'fa-robot', mcp: 'fa-plug', other: 'fa-ellipsis-h',
|
||||
grpc: 'fa-server', agent: 'fa-robot', mcp: 'fa-plug', router: 'fa-route', proxy: 'fa-cloud',
|
||||
mitm: 'fa-user-secret', pii: 'fa-user-shield', other: 'fa-ellipsis-h',
|
||||
}
|
||||
|
||||
const SECTION_COLORS = {
|
||||
@@ -25,7 +26,8 @@ const SECTION_COLORS = {
|
||||
templates: 'var(--color-warning)', functions: 'var(--color-info, var(--color-primary))',
|
||||
reasoning: 'var(--color-accent)', diffusers: 'var(--color-warning)', tts: 'var(--color-success)',
|
||||
pipeline: 'var(--color-accent)', grpc: 'var(--color-text-muted)', agent: 'var(--color-primary)',
|
||||
mcp: 'var(--color-accent)', other: 'var(--color-text-muted)',
|
||||
mcp: 'var(--color-accent)', router: 'var(--color-accent)', proxy: 'var(--color-info, var(--color-primary))',
|
||||
mitm: 'var(--color-warning)', pii: 'var(--color-error)', other: 'var(--color-text-muted)',
|
||||
}
|
||||
|
||||
function flattenConfig(obj, prefix = '') {
|
||||
@@ -71,6 +73,10 @@ export default function ModelEditor() {
|
||||
const { name } = useParams()
|
||||
const [searchParams] = useSearchParams()
|
||||
const navigate = useNavigate()
|
||||
const location = useLocation()
|
||||
// Where the Back button returns to. Set by whichever page linked here (see
|
||||
// utils/editorNav); falls back to the historical defaults for direct visits.
|
||||
const backState = location.state && location.state.from ? location.state : null
|
||||
const { addToast } = useOutletContext()
|
||||
const { sections, fields, loading: metaLoading, error: metaError } = useConfigMetadata()
|
||||
|
||||
@@ -89,7 +95,6 @@ export default function ModelEditor() {
|
||||
const [activeSection, setActiveSection] = useState(null)
|
||||
const [tabSwitchWarning, setTabSwitchWarning] = useState(false)
|
||||
|
||||
const contentRef = useRef(null)
|
||||
const sectionRefs = useRef({})
|
||||
|
||||
const vramEstimate = useVramEstimate({
|
||||
@@ -187,25 +192,29 @@ export default function ModelEditor() {
|
||||
}
|
||||
}, [activeSection, activeSections])
|
||||
|
||||
// Scroll tracking
|
||||
// Scroll tracking — the editor used to have its own overflow:auto pane
|
||||
// and listened to that container's scroll; the pane has been removed so
|
||||
// small screens don't have the global footer always clipping into the
|
||||
// form. Scrolling now happens at the window level, and the anchor for
|
||||
// "which section is at the top" is a fixed viewport offset (the sticky
|
||||
// sidebar sits roughly at the top of the editor area).
|
||||
useEffect(() => {
|
||||
const container = contentRef.current
|
||||
if (!container || tab !== 'interactive') return
|
||||
if (tab !== 'interactive') return
|
||||
const onScroll = () => {
|
||||
const containerTop = container.getBoundingClientRect().top
|
||||
const anchorY = 80 // viewport px below which a section is "active"
|
||||
let closest = activeSections[0]?.id
|
||||
let closestDist = Infinity
|
||||
for (const s of activeSections) {
|
||||
const el = sectionRefs.current[s.id]
|
||||
if (el) {
|
||||
const dist = Math.abs(el.getBoundingClientRect().top - containerTop - 8)
|
||||
const dist = Math.abs(el.getBoundingClientRect().top - anchorY)
|
||||
if (dist < closestDist) { closestDist = dist; closest = s.id }
|
||||
}
|
||||
}
|
||||
if (closest) setActiveSection(closest)
|
||||
}
|
||||
container.addEventListener('scroll', onScroll, { passive: true })
|
||||
return () => container.removeEventListener('scroll', onScroll)
|
||||
window.addEventListener('scroll', onScroll, { passive: true })
|
||||
return () => window.removeEventListener('scroll', onScroll)
|
||||
}, [activeSections, configLoading, metaLoading, tab])
|
||||
|
||||
const scrollTo = (id) => {
|
||||
@@ -263,7 +272,9 @@ export default function ModelEditor() {
|
||||
if (!/^[a-zA-Z0-9_.-]+$/.test(modelName.trim())) { addToast('Invalid model name — use only letters, numbers, hyphens, underscores, and dots', 'error'); setSaving(false); return }
|
||||
await modelsApi.importConfig(JSON.stringify(config), 'application/json')
|
||||
addToast('Model created successfully', 'success')
|
||||
navigate(`/app/model-editor/${encodeURIComponent(modelName.trim())}`)
|
||||
// replace: the transient create URL shouldn't sit in history, so
|
||||
// Back (browser or in-page) skips it and returns to the linking page.
|
||||
navigate(`/app/model-editor/${encodeURIComponent(modelName.trim())}`, { replace: true, state: backState })
|
||||
} else {
|
||||
await modelsApi.patchConfig(name, config)
|
||||
setInitialValues(structuredClone(values))
|
||||
@@ -293,9 +304,9 @@ export default function ModelEditor() {
|
||||
addToast('Model created successfully', 'success')
|
||||
try {
|
||||
const parsed = YAML.parse(yamlText)
|
||||
if (parsed?.name) navigate(`/app/model-editor/${encodeURIComponent(parsed.name)}`)
|
||||
else navigate('/app/manage')
|
||||
} catch { navigate('/app/manage') }
|
||||
if (parsed?.name) navigate(`/app/model-editor/${encodeURIComponent(parsed.name)}`, { replace: true, state: backState })
|
||||
else navigate(backState ? backState.from : '/app/manage')
|
||||
} catch { navigate(backState ? backState.from : '/app/manage') }
|
||||
} else {
|
||||
const response = await fetch(apiUrl(`/models/edit/${encodeURIComponent(name)}`), {
|
||||
method: 'POST',
|
||||
@@ -323,7 +334,7 @@ export default function ModelEditor() {
|
||||
// editor URL points at a name that no longer exists on the backend.
|
||||
// Redirect so refreshes and subsequent saves hit the new name.
|
||||
if (parsedName && parsedName !== name) {
|
||||
navigate(`/app/model-editor/${encodeURIComponent(parsedName)}`, { replace: true })
|
||||
navigate(`/app/model-editor/${encodeURIComponent(parsedName)}`, { replace: true, state: backState })
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
@@ -405,9 +416,14 @@ export default function ModelEditor() {
|
||||
<div style={{ display: 'flex', gap: 'var(--spacing-sm)' }}>
|
||||
<button className="btn btn-secondary" onClick={() => {
|
||||
if (isCreateMode && selectedTemplate) { setSelectedTemplate(null); setValues({}); setActiveFieldPaths(new Set()) }
|
||||
else if (backState) navigate(backState.from)
|
||||
else navigate(isCreateMode ? '/app/models' : '/app/manage')
|
||||
}}>
|
||||
<i className="fas fa-arrow-left" /> Back
|
||||
<i className="fas fa-arrow-left" /> Back to {
|
||||
isCreateMode && selectedTemplate ? 'Templates'
|
||||
: backState ? backState.fromLabel
|
||||
: isCreateMode ? 'Models' : 'Manage'
|
||||
}
|
||||
</button>
|
||||
{!showTemplateSelector && tab === 'interactive' && (
|
||||
<button className={`btn ${isDirty ? 'btn-primary' : 'btn-secondary'}`} onClick={handleInteractiveSave} disabled={saving || !isDirty}>
|
||||
@@ -543,12 +559,15 @@ export default function ModelEditor() {
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Two-column layout */}
|
||||
<div style={{ display: 'flex', gap: 0, minHeight: 'calc(100vh - 340px)' }}>
|
||||
{/* Sidebar */}
|
||||
{/* Two-column layout. Both columns flow at body-scroll height —
|
||||
no inner overflow:auto here, so the global footer ends up
|
||||
below the content (like every other page) instead of pinned
|
||||
to the viewport bottom, eating editing space on short screens. */}
|
||||
<div style={{ display: 'flex', gap: 0 }}>
|
||||
{/* Sidebar — sticks to the top of the viewport as the body scrolls. */}
|
||||
<nav style={{
|
||||
width: 180, flexShrink: 0, padding: '0 var(--spacing-sm)',
|
||||
position: 'sticky', top: 0, alignSelf: 'flex-start',
|
||||
position: 'sticky', top: 'var(--spacing-md)', alignSelf: 'flex-start',
|
||||
}}>
|
||||
{activeSections.map(s => (
|
||||
<button
|
||||
@@ -584,10 +603,8 @@ export default function ModelEditor() {
|
||||
|
||||
{/* Content */}
|
||||
<div
|
||||
ref={contentRef}
|
||||
style={{
|
||||
flex: 1, overflow: 'auto', padding: '0 var(--spacing-lg) var(--spacing-xl) var(--spacing-md)',
|
||||
maxHeight: 'calc(100vh - 340px)',
|
||||
flex: 1, padding: '0 var(--spacing-lg) var(--spacing-xl) var(--spacing-md)',
|
||||
}}
|
||||
>
|
||||
{activeSections.length === 0 && (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useState, useCallback, useEffect } from 'react'
|
||||
import { useNavigate, useOutletContext } from 'react-router-dom'
|
||||
import { useNavigate, useOutletContext, useLocation } from 'react-router-dom'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { fromState } from '../utils/editorNav'
|
||||
import { modelsApi } from '../utils/api'
|
||||
import { safeHref } from '../utils/url'
|
||||
import { useDebouncedCallback } from '../hooks/useDebounce'
|
||||
@@ -40,6 +41,7 @@ const FILTERS = [
|
||||
export default function Models() {
|
||||
const { addToast } = useOutletContext()
|
||||
const navigate = useNavigate()
|
||||
const location = useLocation()
|
||||
const { t } = useTranslation('models')
|
||||
const { operations } = useOperations()
|
||||
const { resources } = useResources()
|
||||
@@ -286,7 +288,7 @@ export default function Models() {
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
<button className="btn btn-primary btn-sm" onClick={() => navigate('/app/model-editor')}>
|
||||
<button className="btn btn-primary btn-sm" onClick={() => navigate('/app/model-editor', { state: fromState(location, 'Models') })}>
|
||||
<i className="fas fa-plus" /> {t('actions.addModel')}
|
||||
</button>
|
||||
<button className="btn btn-secondary btn-sm" onClick={() => navigate('/app/import-model')}>
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useState, useRef, useEffect, useCallback, useMemo } from 'react'
|
||||
import { useOutletContext, useNavigate } from 'react-router-dom'
|
||||
import { useOutletContext, useNavigate, useLocation } from 'react-router-dom'
|
||||
import { realtimeApi } from '../utils/api'
|
||||
import { fromState } from '../utils/editorNav'
|
||||
import ModelSelector from '../components/ModelSelector'
|
||||
import ClientMCPDropdown from '../components/ClientMCPDropdown'
|
||||
import { useMCPClient } from '../hooks/useMCPClient'
|
||||
@@ -38,6 +39,7 @@ function upsertAssistant(prev, itemId, text, mode) {
|
||||
export default function Talk() {
|
||||
const { addToast } = useOutletContext()
|
||||
const navigate = useNavigate()
|
||||
const location = useLocation()
|
||||
|
||||
// Pipeline models
|
||||
const [pipelineModels, setPipelineModels] = useState([])
|
||||
@@ -644,7 +646,7 @@ export default function Talk() {
|
||||
disabled={isConnected}
|
||||
searchPlaceholder="Search pipeline models..."
|
||||
/>
|
||||
<button className="btn btn-secondary btn-sm" onClick={() => navigate('/app/model-editor?template=pipeline')}
|
||||
<button className="btn btn-secondary btn-sm" onClick={() => navigate('/app/model-editor?template=pipeline', { state: fromState(location, 'Talk') })}
|
||||
style={{ marginTop: 'var(--spacing-xs)' }}>
|
||||
<i className="fas fa-plus" style={{ marginRight: 'var(--spacing-xs)' }} /> Create Pipeline Model
|
||||
</button>
|
||||
@@ -724,7 +726,7 @@ export default function Talk() {
|
||||
)}
|
||||
{selectedModelInfo && !isConnected && (
|
||||
<div style={{ marginBottom: 'var(--spacing-md)' }}>
|
||||
<button className="btn btn-secondary btn-sm" onClick={() => navigate(`/app/model-editor/${encodeURIComponent(selectedModel)}`)}>
|
||||
<button className="btn btn-secondary btn-sm" onClick={() => navigate(`/app/model-editor/${encodeURIComponent(selectedModel)}`, { state: fromState(location, 'Talk') })}>
|
||||
<i className="fas fa-pen-to-square" style={{ marginRight: 'var(--spacing-xs)' }} />
|
||||
{selectedModelInfo.self_contained ? ' Edit Model Config' : ' Edit Pipeline'}
|
||||
</button>
|
||||
|
||||
@@ -74,6 +74,7 @@ const TYPE_COLORS = {
|
||||
tokenize: { bg: 'var(--color-secondary-light)', color: 'var(--color-text-muted)' },
|
||||
detection: { bg: 'var(--color-info-light)', color: 'var(--color-data-8)' },
|
||||
model_load: { bg: 'var(--color-error-light)', color: 'var(--color-data-2)' },
|
||||
vector_store: { bg: 'var(--color-accent-light)', color: 'var(--color-data-7)' },
|
||||
}
|
||||
|
||||
function typeBadgeStyle(type) {
|
||||
|
||||
46
core/http/react-ui/src/utils/cmGoTemplate.js
vendored
Normal file
46
core/http/react-ui/src/utils/cmGoTemplate.js
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
import { StreamLanguage } from '@codemirror/language'
|
||||
|
||||
// Go text/template keywords valid inside an action `{{ ... }}`.
// Identifiers found in this set are tagged 'keyword'; all others 'variable'.
const KEYWORDS = new Set([
  'if', 'else', 'end', 'range', 'with', 'define', 'template',
  'block', 'break', 'continue', 'nil', 'true', 'false',
])

// Minimal Go text/template highlighter: distinguishes literal text from
// action bodies inside `{{ ... }}`. Highlighting only — it does not
// validate template grammar.
//
// State: `inAction` is true while the stream is between `{{` and `}}`.
// It lives in the stream-parser state, so an action that spans several
// lines keeps its highlighting on the following lines.
export const goTemplate = StreamLanguage.define({
  // Every document starts in literal text, outside any action.
  startState() {
    return { inAction: false }
  },

  // token consumes one token from `stream` and returns its style tag,
  // or null for text that should stay unstyled.
  token(stream, state) {
    if (!state.inAction) {
      // Opening delimiter: switch to action mode.
      if (stream.match('{{')) {
        state.inAction = true
        return 'meta'
      }
      // Literal text: consume up to (but not including) the next `{{`,
      // or to end of line. The `false` argument peeks without consuming.
      while (!stream.eol()) {
        if (stream.match('{{', false)) break
        stream.next()
      }
      return null
    }

    // Closing delimiter: back to literal text.
    if (stream.match('}}')) {
      state.inAction = false
      return 'meta'
    }
    if (stream.eatSpace()) return null
    // Trim markers and pipeline/grouping punctuation. A trim marker is a
    // dash followed by whitespace (`{{- `), by the end of the line (the
    // whitespace is the newline, which the line-based stream never shows),
    // or directly by the closing delimiter (` -}}`).
    if (stream.match(/^-(?=\s|$|\}\})/) || stream.match(/^[|()]/)) return 'operator'
    // Interpreted ("...") and raw (`...`) string literals.
    if (stream.match(/^"(?:[^"\\]|\\.)*"/)) return 'string'
    if (stream.match(/^`[^`]*`/)) return 'string'
    // $variables, including the bare `$`.
    if (stream.match(/^\$[a-zA-Z0-9_]*/)) return 'variable-2'
    // Field chains such as `.Foo.Bar`, including the bare dot.
    if (stream.match(/^\.[a-zA-Z0-9_.]*/)) return 'property'
    if (stream.match(/^[0-9]+(\.[0-9]+)?/)) return 'number'
    // Bare identifiers: template keywords vs. function/other names.
    if (stream.match(/^[a-zA-Z_][a-zA-Z0-9_]*/)) {
      return KEYWORDS.has(stream.current()) ? 'keyword' : 'variable'
    }
    // Anything unrecognised: advance one character so the tokenizer
    // always makes progress, and leave it unstyled.
    stream.next()
    return null
  },
})
|
||||
15
core/http/react-ui/src/utils/editorNav.js
vendored
Normal file
15
core/http/react-ui/src/utils/editorNav.js
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
// Navigation context for the Model Editor.
|
||||
//
|
||||
// Many pages link into the Model Editor (Models, Manage, Chat, Talk, Agent
|
||||
// Jobs, Middleware…). Its in-page Back button used to navigate to a hardcoded
|
||||
// route, so it always dumped you on the same page regardless of where you came
|
||||
// from. To fix that, every linker passes this object as react-router location
|
||||
// state; the editor reads it and returns you to the exact page that linked
|
||||
// here, labelled "Back to <label>".
|
||||
//
|
||||
// `location` is the source page's useLocation() value, so `from` captures the
|
||||
// full path including any sub-route or query string — returning lands you
|
||||
// where you actually were, not just on the section root.
|
||||
/**
 * Build the react-router location state a page passes when it links into
 * the Model Editor.
 *
 * @param {{pathname: string, search: string}} location - the linking page's
 *   useLocation() value.
 * @param {string} label - name of the linking page, stored as `fromLabel`.
 * @returns {{from: string, fromLabel: string}} the return target (path plus
 *   query string) and its label.
 */
export function fromState(location, label) {
  const { pathname, search } = location
  const from = pathname + search
  return { from, fromLabel: label }
}
|
||||
@@ -58,13 +58,14 @@ func RegisterAnthropicRoutes(app *echo.Echo,
|
||||
middleware.AnthropicProbe,
|
||||
router.SourceAnthropic,
|
||||
middleware.ClassifierDeps{
|
||||
Scorer: application.Scorer,
|
||||
Embedder: application.Embedder,
|
||||
VectorStore: application.VectorStore,
|
||||
Reranker: application.Reranker,
|
||||
ModelLookup: application.ModelConfigLookup(),
|
||||
Registry: application.RouterClassifierRegistry(),
|
||||
Evaluator: application.TemplatesEvaluator(),
|
||||
Scorer: application.Scorer,
|
||||
TokenCounter: application.TokenCounter,
|
||||
Embedder: application.Embedder,
|
||||
VectorStore: application.VectorStore,
|
||||
Reranker: application.Reranker,
|
||||
ModelLookup: application.ModelConfigLookup(),
|
||||
Registry: application.RouterClassifierRegistry(),
|
||||
Evaluator: application.TemplatesEvaluator(),
|
||||
},
|
||||
),
|
||||
middleware.AdmissionControl(application.AdmissionLimiter(), application.PIIEvents()),
|
||||
|
||||
@@ -135,13 +135,14 @@ func RegisterMiddlewareRoutes(e *echo.Echo, app *application.Application) {
|
||||
app.ModelConfigLoader(),
|
||||
app.ApplicationConfig(),
|
||||
middleware.ClassifierDeps{
|
||||
Scorer: app.Scorer,
|
||||
Embedder: app.Embedder,
|
||||
VectorStore: app.VectorStore,
|
||||
Reranker: app.Reranker,
|
||||
ModelLookup: app.ModelConfigLookup(),
|
||||
Registry: app.RouterClassifierRegistry(),
|
||||
Evaluator: app.TemplatesEvaluator(),
|
||||
Scorer: app.Scorer,
|
||||
TokenCounter: app.TokenCounter,
|
||||
Embedder: app.Embedder,
|
||||
VectorStore: app.VectorStore,
|
||||
Reranker: app.Reranker,
|
||||
ModelLookup: app.ModelConfigLookup(),
|
||||
Registry: app.RouterClassifierRegistry(),
|
||||
Evaluator: app.TemplatesEvaluator(),
|
||||
},
|
||||
)
|
||||
e.POST("/api/router/decide", func(c echo.Context) error {
|
||||
@@ -220,8 +221,8 @@ func buildRouterStatus(app *application.Application) map[string]any {
|
||||
}
|
||||
|
||||
out := map[string]any{
|
||||
"configured": hasAny,
|
||||
"models": models,
|
||||
"configured": hasAny,
|
||||
"models": models,
|
||||
"recent_decision_count": recentCount,
|
||||
"available_classifiers": []string{router.ClassifierScore},
|
||||
}
|
||||
|
||||
@@ -71,13 +71,14 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
middleware.OpenAIProbe,
|
||||
router.SourceChat,
|
||||
middleware.ClassifierDeps{
|
||||
Scorer: application.Scorer,
|
||||
Embedder: application.Embedder,
|
||||
VectorStore: application.VectorStore,
|
||||
Reranker: application.Reranker,
|
||||
ModelLookup: application.ModelConfigLookup(),
|
||||
Registry: application.RouterClassifierRegistry(),
|
||||
Evaluator: application.TemplatesEvaluator(),
|
||||
Scorer: application.Scorer,
|
||||
TokenCounter: application.TokenCounter,
|
||||
Embedder: application.Embedder,
|
||||
VectorStore: application.VectorStore,
|
||||
Reranker: application.Reranker,
|
||||
ModelLookup: application.ModelConfigLookup(),
|
||||
Registry: application.RouterClassifierRegistry(),
|
||||
Evaluator: application.TemplatesEvaluator(),
|
||||
},
|
||||
),
|
||||
// Admission control runs after RouteModel so the SERVED
|
||||
|
||||
@@ -117,10 +117,10 @@ func RegisterPIIRoutes(e *echo.Echo, app *application.Application) {
|
||||
}
|
||||
res := app.PIIRedactor().Redact(body.Text)
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"redacted": res.Redacted,
|
||||
"spans": res.Spans,
|
||||
"blocked": res.Blocked,
|
||||
"local_only": res.LocalOnly,
|
||||
"redacted": res.Redacted,
|
||||
"spans": res.Spans,
|
||||
"blocked": res.Blocked,
|
||||
"masked": res.Masked,
|
||||
})
|
||||
})
|
||||
|
||||
@@ -142,12 +142,12 @@ func RegisterPIIRoutes(e *echo.Echo, app *application.Application) {
|
||||
|
||||
// PutPIIPatternActionEndpoint godoc
|
||||
// @Summary Change a pattern's action in-process
|
||||
// @Description Mutates the named pattern's action (mask|block|route_local). Transient — restored to YAML defaults on restart. Admin-only.
|
||||
// @Description Mutates the named pattern's action (mask|block|allow). Transient — restored to YAML defaults on restart. Admin-only.
|
||||
// @Tags pii
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "Pattern id"
|
||||
// @Param body body map[string]string true "JSON {\"action\":\"mask|block|route_local\"}"
|
||||
// @Param body body map[string]string true "JSON {\"action\":\"mask|block|allow\"}"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/pii/patterns/{id} [put]
|
||||
e.PUT("/api/pii/patterns/:id", func(c echo.Context) error {
|
||||
|
||||
@@ -499,7 +499,7 @@ type RouterDecideResponse struct {
|
||||
// inspects the text and returns findings + a suggested action; it
|
||||
// does NOT mutate the input, record an audit event, or rewrite any
|
||||
// downstream request. The caller composes the decision with its own
|
||||
// policy (mask, block, route to local-only backends, allow).
|
||||
// policy (mask, block, or allow).
|
||||
type PIIDecideRequest struct {
|
||||
// Text is the user-visible content to inspect. Required.
|
||||
Text string `json:"text"`
|
||||
@@ -507,19 +507,20 @@ type PIIDecideRequest struct {
|
||||
|
||||
// PIIDecideResponse carries the redactor's findings.
|
||||
// SuggestedAction is derived from the action ordering used by the
|
||||
// internal redactor (block > route_local > mask > allow) so callers
|
||||
// don't need to replicate that logic.
|
||||
// internal redactor (block > mask > allow) so callers don't need to
|
||||
// replicate that logic.
|
||||
type PIIDecideResponse struct {
|
||||
// Findings is one entry per matched span — pattern id, byte
|
||||
// range, and audit-safe hash prefix (never the matched value).
|
||||
Findings []PIIFinding `json:"findings"`
|
||||
// SuggestedAction is the strongest action across all findings:
|
||||
// "block", "route_local", "mask", or "allow" (no findings).
|
||||
// "block", "mask", or "allow" (no findings, or all findings
|
||||
// resolved to the allow action).
|
||||
SuggestedAction string `json:"suggested_action"`
|
||||
// RedactedPreview is the input with mask-action spans replaced
|
||||
// by their placeholders. Identical to Text when no findings or
|
||||
// when the strongest action is block/route_local (which don't
|
||||
// rewrite content).
|
||||
// when the strongest action is block/allow (which don't rewrite
|
||||
// content).
|
||||
RedactedPreview string `json:"redacted_preview"`
|
||||
}
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ func BuildStreamFilter(c echo.Context, cfg *config.ModelConfig, isStream bool, p
|
||||
overrides = make(map[string]pii.Action, len(raw))
|
||||
for ovid, action := range raw {
|
||||
switch pii.Action(action) {
|
||||
case pii.ActionMask, pii.ActionBlock, pii.ActionRouteLocal:
|
||||
case pii.ActionMask, pii.ActionBlock, pii.ActionAllow:
|
||||
overrides[ovid] = pii.Action(action)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func LoadConfig(path string) ([]Pattern, error) {
|
||||
continue
|
||||
}
|
||||
switch p.Action {
|
||||
case ActionMask, ActionBlock, ActionRouteLocal:
|
||||
case ActionMask, ActionBlock, ActionAllow:
|
||||
overrides[p.ID] = p.Action
|
||||
default:
|
||||
return nil, fmt.Errorf("pii: invalid action %q for pattern %q", p.Action, p.ID)
|
||||
|
||||
@@ -22,7 +22,7 @@ var _ = Describe("LoadConfig", func() {
|
||||
- id: email
|
||||
action: block
|
||||
- id: ssn
|
||||
action: route_local
|
||||
action: allow
|
||||
`)
|
||||
Expect(os.WriteFile(path, body, 0o600)).To(Succeed())
|
||||
patterns, err := LoadConfig(path)
|
||||
@@ -33,7 +33,7 @@ var _ = Describe("LoadConfig", func() {
|
||||
got[p.ID] = p.Action
|
||||
}
|
||||
Expect(got["email"]).To(Equal(ActionBlock))
|
||||
Expect(got["ssn"]).To(Equal(ActionRouteLocal))
|
||||
Expect(got["ssn"]).To(Equal(ActionAllow))
|
||||
// Unmentioned patterns keep their default action.
|
||||
Expect(got["credit_card"]).To(Equal(ActionMask), "credit_card default action lost")
|
||||
})
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
const (
|
||||
ctxKeyCorrelationID = "routing.correlation_id"
|
||||
ctxKeyPIIEventID = "routing.pii_event_id"
|
||||
ctxKeyLocalOnly = "routing.local_only"
|
||||
// Must match the constants in core/http/middleware/request.go.
|
||||
// Echoing them across packages would create an import cycle
|
||||
// (http/middleware imports this package). Drift is caught by
|
||||
@@ -37,7 +36,7 @@ const (
|
||||
//
|
||||
// Consumers of the override map: the action returned from PIIPatternOverrides
|
||||
// is the raw YAML string (e.g. "block"). Validation against the canonical
|
||||
// ActionMask/Block/RouteLocal constants happens here, so a typo in a model
|
||||
// ActionMask/Block/Allow constants happens here, so a typo in a model
|
||||
// YAML logs and is ignored rather than panicking.
|
||||
type ModelPIIConfig interface {
|
||||
PIIIsEnabled() bool
|
||||
@@ -77,9 +76,8 @@ type Adapter struct {
|
||||
// to the client.
|
||||
// - On match with action=mask: the redacted text replaces the
|
||||
// original on the parsed request. PIIEvents are recorded.
|
||||
// - On match with action=route_local: the original text is left
|
||||
// intact, but the echo context is annotated so the (future) router
|
||||
// middleware refuses cloud-proxy candidates.
|
||||
// - On match with action=allow: the original text is left intact; a
|
||||
// PIIEvent is still recorded so the detection is auditable.
|
||||
//
|
||||
// recorder is the Recorder on which to record events; nil disables
|
||||
// recording (the redaction still happens). fallbackUser supplies the
|
||||
@@ -138,7 +136,7 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
overrides = make(map[string]Action, len(raw))
|
||||
for id, action := range raw {
|
||||
switch Action(action) {
|
||||
case ActionMask, ActionBlock, ActionRouteLocal:
|
||||
case ActionMask, ActionBlock, ActionAllow:
|
||||
overrides[id] = Action(action)
|
||||
default:
|
||||
xlog.Warn("pii: ignoring unknown action in per-model override",
|
||||
@@ -151,7 +149,6 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
texts := adapter.Scan(parsed)
|
||||
updates := make([]ScannedText, 0, len(texts))
|
||||
var blocked bool
|
||||
var localOnly bool
|
||||
var firstEventID string
|
||||
|
||||
for _, st := range texts {
|
||||
@@ -201,9 +198,6 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
if res.Blocked {
|
||||
blocked = true
|
||||
}
|
||||
if res.LocalOnly {
|
||||
localOnly = true
|
||||
}
|
||||
updates = append(updates, ScannedText{Index: st.Index, Text: res.Redacted})
|
||||
}
|
||||
|
||||
@@ -224,10 +218,6 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
if firstEventID != "" {
|
||||
c.Set(ctxKeyPIIEventID, firstEventID)
|
||||
}
|
||||
if localOnly {
|
||||
c.Set(ctxKeyLocalOnly, true)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,9 +153,9 @@ var _ = Describe("RequestMiddleware", func() {
|
||||
Expect(errBlock["type"]).To(Equal("pii_blocked"))
|
||||
})
|
||||
|
||||
It("route_local sets context flag", func() {
|
||||
It("allow leaves text intact but still records an event", func() {
|
||||
patterns, _ := Compile([]Pattern{{
|
||||
ID: "email", Description: "Email", Action: ActionRouteLocal, MaxMatchLength: 254,
|
||||
ID: "email", Description: "Email", Action: ActionAllow, MaxMatchLength: 254,
|
||||
}})
|
||||
red := NewRedactor(patterns)
|
||||
store := NewMemoryEventStore(0)
|
||||
@@ -165,10 +165,7 @@ var _ = Describe("RequestMiddleware", func() {
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
|
||||
e := echo.New()
|
||||
var observedLocalOnly bool
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
v, _ := c.Get(ctxKeyLocalOnly).(bool)
|
||||
observedLocalOnly = v
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
|
||||
|
||||
@@ -177,9 +174,12 @@ var _ = Describe("RequestMiddleware", func() {
|
||||
e.ServeHTTP(w, req)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(observedLocalOnly).To(BeTrue(), "ctxKeyLocalOnly should be true on route_local match")
|
||||
// route_local does NOT mutate the body — the model still sees the email.
|
||||
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "route_local should leave text intact")
|
||||
// allow does NOT mutate the body — the model still sees the email.
|
||||
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "allow should leave text intact")
|
||||
// ...but the detection is still recorded for audit.
|
||||
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(1), "allow should still record a PIIEvent")
|
||||
Expect(events[0].Action).To(Equal(ActionAllow))
|
||||
})
|
||||
|
||||
It("no match passes through", func() {
|
||||
|
||||
@@ -65,8 +65,8 @@ func (r *Redactor) Patterns() []Pattern {
|
||||
// older snapshot don't race on the per-element Action string (Go
|
||||
// strings are not atomic two-word values).
|
||||
func (r *Redactor) SetAction(id string, action Action) error {
|
||||
if action != ActionMask && action != ActionBlock && action != ActionRouteLocal {
|
||||
return fmt.Errorf("unknown action %q (must be mask, block, or route_local)", action)
|
||||
if action != ActionMask && action != ActionBlock && action != ActionAllow {
|
||||
return fmt.Errorf("unknown action %q (must be mask, block, or allow)", action)
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
@@ -114,8 +114,9 @@ func (r *Redactor) Redact(text string) Result {
|
||||
// and applies the resolved Action:
|
||||
// - block: sets Result.Blocked, leaves text intact (caller decides
|
||||
// whether to surface the redacted form).
|
||||
// - mask: replaces the span with maskFor(pattern.ID).
|
||||
// - route_local: sets Result.LocalOnly, leaves text intact.
|
||||
// - mask: replaces the span with maskFor(pattern.ID), sets Result.Masked.
|
||||
// - allow: leaves text intact and sets no flag (the span is still
|
||||
// recorded so the match is auditable).
|
||||
//
|
||||
// Spans are returned in the original input's coordinate system so the
|
||||
// PIIEvent record can be written without re-running the scan.
|
||||
@@ -254,7 +255,7 @@ func mergeAndEmit(text string, hits []rawHit) Result {
|
||||
// Sort and deduplicate overlapping hits — when two patterns claim
|
||||
// the same span (e.g., a credit-card-shaped value also scans as
|
||||
// digits, or NER tags a span the regex also caught), keep the one
|
||||
// with the strongest action. Order: block > route_local > mask.
|
||||
// with the strongest action. Order: block > mask > allow.
|
||||
sort.Slice(hits, func(i, j int) bool {
|
||||
if hits[i].start != hits[j].start {
|
||||
return hits[i].start < hits[j].start
|
||||
@@ -298,10 +299,11 @@ func mergeAndEmit(text string, hits []rawHit) Result {
|
||||
case ActionBlock:
|
||||
res.Blocked = true
|
||||
out.WriteString(matched)
|
||||
case ActionRouteLocal:
|
||||
res.LocalOnly = true
|
||||
case ActionAllow:
|
||||
// Detect-and-log only: leave the matched text in place.
|
||||
out.WriteString(matched)
|
||||
default:
|
||||
res.Masked = true
|
||||
out.WriteString(maskFor(h.patternID))
|
||||
}
|
||||
cursor = h.end
|
||||
@@ -333,9 +335,9 @@ func actionRank(a Action) int {
|
||||
switch a {
|
||||
case ActionBlock:
|
||||
return 3
|
||||
case ActionRouteLocal:
|
||||
return 2
|
||||
case ActionMask:
|
||||
return 2
|
||||
case ActionAllow:
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
|
||||
@@ -96,7 +96,7 @@ var _ = Describe("Redactor", func() {
|
||||
res := r.Redact("")
|
||||
Expect(res.Redacted).To(BeEmpty())
|
||||
Expect(res.Blocked).To(BeFalse())
|
||||
Expect(res.LocalOnly).To(BeFalse())
|
||||
Expect(res.Masked).To(BeFalse())
|
||||
Expect(res.Spans).To(BeEmpty())
|
||||
})
|
||||
|
||||
@@ -165,10 +165,12 @@ var _ = Describe("RedactWithOverrides", func() {
|
||||
var _ = Describe("SetAction", func() {
|
||||
It("swaps in place", func() {
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
Expect(r.SetAction("email", ActionRouteLocal)).To(Succeed())
|
||||
Expect(r.SetAction("email", ActionAllow)).To(Succeed())
|
||||
res := r.Redact("contact alice@example.com")
|
||||
Expect(res.LocalOnly).To(BeTrue(), "expected LocalOnly after SetAction(route_local)")
|
||||
Expect(res.Blocked).To(BeFalse(), "SetAction(route_local) should not block")
|
||||
Expect(res.Masked).To(BeFalse(), "allow leaves text intact, so nothing is masked")
|
||||
Expect(res.Redacted).To(ContainSubstring("alice@example.com"), "allow should leave the match in place")
|
||||
Expect(res.Spans).To(HaveLen(1), "allow still records the match")
|
||||
Expect(res.Blocked).To(BeFalse(), "SetAction(allow) should not block")
|
||||
})
|
||||
|
||||
It("rejects unknown id", func() {
|
||||
|
||||
@@ -27,8 +27,9 @@ import (
|
||||
// reject the request. We remap block → mask for redaction purposes
|
||||
// while still recording PIIEvent rows with action="block" so audits
|
||||
// surface the original intent ("the model would have leaked X here,
|
||||
// suppressed in flight"). route_local on the output side is a no-op
|
||||
// (the dispatch decision was already made on the request side).
|
||||
// suppressed in flight"). allow on the output side is a no-op — the
|
||||
// text is left intact, matching its request-side detect-and-log
|
||||
// behaviour.
|
||||
//
|
||||
// StreamFilter is NOT safe for concurrent use across goroutines; one
|
||||
// instance per response stream.
|
||||
|
||||
@@ -11,13 +11,14 @@
|
||||
// drops in without changing call sites.
|
||||
//
|
||||
// Configuration model: each pattern has an Action (block | mask |
|
||||
// route_local). Actions are evaluated in this order:
|
||||
// allow). Actions are evaluated in this order:
|
||||
// - block: short-circuits the request with an error (the middleware
|
||||
// returns 400 to the client).
|
||||
// - mask: replaces the matched span with ReplacementFor(pattern).
|
||||
// - route_local: leaves the text alone but sets a context flag the
|
||||
// router (subsystem 2) treats as "this request must stay on a local
|
||||
// model" — never crosses the boundary to a cloud proxy backend.
|
||||
// - allow: detect-and-log only — the span is left intact and a
|
||||
// PIIEvent is still recorded, but the text passes through
|
||||
// unchanged. Useful to downgrade a pattern's default while keeping
|
||||
// it visible in the audit log.
|
||||
package pii
|
||||
|
||||
import "time"
|
||||
@@ -36,11 +37,13 @@ const (
|
||||
// the matched value).
|
||||
ActionBlock Action = "block"
|
||||
|
||||
// ActionRouteLocal leaves the text intact but flags the request so
|
||||
// the content router will refuse to dispatch it to a cloud proxy
|
||||
// backend. Useful when a deployment trusts local models with
|
||||
// sensitive data but not external providers.
|
||||
ActionRouteLocal Action = "route_local"
|
||||
// ActionAllow detects and logs the match but leaves the text
|
||||
// intact — no masking, no blocking. A PIIEvent is still recorded,
|
||||
// so the detection is auditable and forms the basis for surfacing
|
||||
// detected-PII labels to the router (a future router-model
|
||||
// feature). Use it to downgrade a pattern's default action for a
|
||||
// model while keeping the pattern visible.
|
||||
ActionAllow Action = "allow"
|
||||
)
|
||||
|
||||
// Direction tags whether a PIIEvent fired on input (request body before
|
||||
@@ -74,14 +77,15 @@ type Span struct {
|
||||
// the call site must enforce this by returning a 400 / refusing to
|
||||
// dispatch.
|
||||
//
|
||||
// LocalOnly is true iff at least one matched pattern had
|
||||
// Action=route_local. The router middleware reads this and constrains
|
||||
// candidate selection.
|
||||
// Masked is true iff at least one matched span was replaced with a
|
||||
// placeholder (Action=mask). Spans with Action=allow are recorded but
|
||||
// leave Masked false. Lets callers (e.g. the decision oracle)
|
||||
// distinguish "matched and redacted" from "matched but passed through".
|
||||
type Result struct {
|
||||
Redacted string
|
||||
Spans []Span
|
||||
Blocked bool
|
||||
LocalOnly bool
|
||||
Redacted string
|
||||
Spans []Span
|
||||
Blocked bool
|
||||
Masked bool
|
||||
}
|
||||
|
||||
// Pattern is one configurable rule. Description is shown in the admin
|
||||
|
||||
@@ -52,6 +52,10 @@ type EmbeddingCacheClassifier struct {
|
||||
similarityThreshold float64
|
||||
confidenceThreshold float64
|
||||
|
||||
// budget trims the conversation to the embedder model's own context
|
||||
// before embedding; nil embeds Probe.Prompt as built by the caller.
|
||||
budget *lazyBudget
|
||||
|
||||
hits atomic.Uint64
|
||||
misses atomic.Uint64
|
||||
nearMisses atomic.Uint64
|
||||
@@ -100,6 +104,15 @@ func NewEmbeddingCacheClassifier(inner Classifier, embedder backend.Embedder, st
|
||||
}
|
||||
}
|
||||
|
||||
// WithTokenTrim wires the embedder model's own tokenizer and context so the
|
||||
// probe embeds the most recent turns that fit instead of a caller-chosen size.
|
||||
// nil tokenizer / non-positive context leaves trimming off. Returns the
|
||||
// receiver for chaining at construction.
|
||||
func (c *EmbeddingCacheClassifier) WithTokenTrim(tokenize func(string) (int, error), maxContextTokens int) *EmbeddingCacheClassifier {
|
||||
c.budget = &lazyBudget{tokenize: tokenize, maxContext: maxContextTokens}
|
||||
return c
|
||||
}
|
||||
|
||||
// Name is the inner classifier's name — the decision-log "classifier"
|
||||
// field should reflect *what* made the decision, not the caching
|
||||
// transport. Cache hits set Decision.Cached separately so admins can
|
||||
@@ -127,7 +140,7 @@ func (c *EmbeddingCacheClassifier) Stats() EmbeddingCacheStats {
|
||||
func (c *EmbeddingCacheClassifier) Classify(ctx context.Context, p Probe) (Decision, error) {
|
||||
start := time.Now()
|
||||
|
||||
vec, err := c.embedder.Embed(ctx, p.Prompt)
|
||||
vec, err := c.embedder.Embed(ctx, trimmedProbeText(p, c.budget, identityRender))
|
||||
if err != nil {
|
||||
c.embedderErrors.Add(1)
|
||||
xlog.Warn("router: embedding cache embed failed", "error", err)
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -13,6 +15,20 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// capturingEmbedder is a test double for the embedder: it remembers the text
// of the most recent Embed call and always answers with the same vector, so a
// spec can assert exactly what the cache handed to the embedder.
type capturingEmbedder struct {
	mu       sync.Mutex
	lastText string
}

// Embed stores text as the last-seen input (under the mutex) and returns the
// fixed vector {1, 2, 3} with a nil error. The context is ignored.
func (e *capturingEmbedder) Embed(_ context.Context, text string) ([]float32, error) {
	e.mu.Lock()
	e.lastText = text
	e.mu.Unlock()

	fixed := []float32{1, 2, 3}
	return fixed, nil
}
|
||||
|
||||
// fakeEmbedder returns a vector keyed by a lookup table; this lets the
|
||||
// test exercise hit/miss control without depending on a real model.
|
||||
type fakeEmbedder struct {
|
||||
@@ -294,6 +310,45 @@ var _ = Describe("EmbeddingCache", func() {
|
||||
})
|
||||
})
|
||||
|
||||
// Specs for EmbeddingCacheClassifier.WithTokenTrim: with a token budget wired
// in, Classify must hand the embedder a trimmed conversation that keeps the
// newest turn; without one, it must hand over Probe.Prompt untouched.
var _ = Describe("EmbeddingCache WithTokenTrim", func() {
	ctx := context.Background()
	// wordCount stands in for a real tokenizer: one "token" per
	// whitespace-separated field, never an error.
	wordCount := func(s string) (int, error) { return len(strings.Fields(s)), nil }

	It("embeds the most recent turns that fit the embedder context, not the full prompt", func() {
		emb := &capturingEmbedder{}
		store := &memVectorStore{}
		inner := &stubInner{name: "score", decision: router.Decision{Labels: []string{"x"}, Score: 0.1}}
		// context_size 50 → budget 50−16 margin ≈ 34 tokens, far under the
		// ~120-word transcript below, so the oldest turns must be dropped.
		cache := router.NewEmbeddingCacheClassifier(inner, emb, store, 0.92, 0.6).
			WithTokenTrim(wordCount, 50)

		// 30 old turns of four words each, followed by one newest turn.
		msgs := make([]string, 0, 31)
		for i := range 30 {
			msgs = append(msgs, fmt.Sprintf("OLDturn%d filler filler filler", i))
		}
		msgs = append(msgs, "NEWESTTURN final words here")
		full := strings.Join(msgs, "\n")

		_, err := cache.Classify(ctx, router.Probe{Prompt: full, Messages: msgs})
		Expect(err).NotTo(HaveOccurred())
		Expect(emb.lastText).To(ContainSubstring("NEWESTTURN"), "newest turn must survive")
		// The trailing space keeps "OLDturn0 " from matching OLDturn10, OLDturn20, …
		Expect(emb.lastText).NotTo(ContainSubstring("OLDturn0 "), "oldest turns trimmed to fit context")
		Expect(emb.lastText).NotTo(Equal(full), "must not embed the untrimmed prompt")
	})

	It("embeds Probe.Prompt unchanged when no trim is wired", func() {
		emb := &capturingEmbedder{}
		store := &memVectorStore{}
		inner := &stubInner{name: "score", decision: router.Decision{Labels: []string{"x"}, Score: 0.1}}
		// No WithTokenTrim call: the classifier has no budget.
		cache := router.NewEmbeddingCacheClassifier(inner, emb, store, 0.92, 0.6)

		// Messages deliberately differs from Prompt so the assertion proves
		// the embedder received Prompt, not a rebuild from Messages.
		_, err := cache.Classify(ctx, router.Probe{Prompt: "PROMPTASIS", Messages: []string{"ignored-no-tokenizer"}})
		Expect(err).NotTo(HaveOccurred())
		Expect(emb.lastText).To(Equal("PROMPTASIS"))
	})
})
|
||||
|
||||
var _ = Describe("EmbeddingCache latency", func() {
|
||||
It("is populated on hits", func() {
|
||||
embedder := &fakeEmbedder{table: map[string][]float32{"p": {1}}}
|
||||
|
||||
@@ -23,6 +23,11 @@ type RerankClassifier struct {
|
||||
labels []string
|
||||
documents []string
|
||||
cache *labelSetCache
|
||||
|
||||
// budget trims the query to the reranker model's context minus the
|
||||
// longest policy description (paired with the query per rerank call);
|
||||
// nil reranks Probe.Prompt as built by the caller.
|
||||
budget *lazyBudget
|
||||
}
|
||||
|
||||
// defaultRerankActivationThreshold is the relevance floor a label
|
||||
@@ -64,16 +69,26 @@ func NewRerankClassifier(policies []ScorePolicy, reranker backend.Reranker, cach
|
||||
}
|
||||
}
|
||||
|
||||
// WithTokenTrim wires the reranker model's own tokenizer and context so the
|
||||
// query is trimmed to the most recent turns that fit alongside the longest
|
||||
// policy description. nil tokenizer / non-positive context leaves trimming
|
||||
// off. Returns the receiver for chaining at construction.
|
||||
func (c *RerankClassifier) WithTokenTrim(tokenize func(string) (int, error), maxContextTokens int) *RerankClassifier {
|
||||
c.budget = &lazyBudget{tokenize: tokenize, maxContext: maxContextTokens, extras: c.documents}
|
||||
return c
|
||||
}
|
||||
|
||||
// Name returns ClassifierColbert, the identifier for this classifier.
func (c *RerankClassifier) Name() string {
	return ClassifierColbert
}
|
||||
|
||||
func (c *RerankClassifier) Classify(ctx context.Context, p Probe) (Decision, error) {
|
||||
start := time.Now()
|
||||
key := cacheKey(p.Prompt)
|
||||
query := trimmedProbeText(p, c.budget, identityRender)
|
||||
key := cacheKey(query)
|
||||
if hit, ok := c.cache.get(key); ok {
|
||||
return Decision{Labels: hit, Score: 1.0, Latency: time.Since(start)}, nil
|
||||
}
|
||||
|
||||
results, err := c.reranker.Rerank(ctx, p.Prompt, c.documents)
|
||||
results, err := c.reranker.Rerank(ctx, query, c.documents)
|
||||
if err != nil {
|
||||
return errDecision(start, fmt.Errorf("rerank classify: %w", err))
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package router
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
@@ -43,6 +45,31 @@ var _ = Describe("RerankClassifier", func() {
|
||||
Expect(d.Score).To(BeNumerically(">=", 0.9))
|
||||
})
|
||||
|
||||
// Verifies that WithTokenTrim makes Classify rerank a trimmed query that
// keeps the newest turn, rather than the full Probe.Prompt.
It("trims the query to the reranker context, keeping the newest turns", func() {
	// Canned rerank scores; only the query passed to the stub matters here.
	r := &stubReranker{results: []backend.RerankResult{
		{Index: 0, RelevanceScore: 0.92},
		{Index: 1, RelevanceScore: 0.10},
		{Index: 2, RelevanceScore: 0.05},
	}}
	// wordCount stands in for a real tokenizer: one "token" per
	// whitespace-separated field, never an error.
	wordCount := func(s string) (int, error) { return len(strings.Fields(s)), nil }
	// budget = 60 − longest policy description − 16 margin; still well under
	// the ~120-word transcript, so the oldest turns drop.
	c := NewRerankClassifier(testPolicies(), r, 0, 0).WithTokenTrim(wordCount, 60)

	// 30 old turns of four words each, followed by one newest turn.
	msgs := make([]string, 0, 31)
	for i := range 30 {
		msgs = append(msgs, fmt.Sprintf("OLDturn%d aaa bbb ccc", i))
	}
	msgs = append(msgs, "NEWESTTURN zzz")
	full := strings.Join(msgs, "\n")

	_, err := c.Classify(context.Background(), Probe{Prompt: full, Messages: msgs})
	Expect(err).NotTo(HaveOccurred())
	// NOTE(review): r.lastQ is presumably the query the stub was last asked
	// to rerank — stubReranker is defined outside this view; confirm.
	Expect(r.lastQ).To(ContainSubstring("NEWESTTURN"), "newest turn must survive")
	// The trailing space keeps "OLDturn0 " from matching OLDturn10, OLDturn20, …
	Expect(r.lastQ).NotTo(ContainSubstring("OLDturn0 "), "oldest turns trimmed to fit context")
	Expect(r.lastQ).NotTo(Equal(full), "must not rerank the untrimmed prompt")
})
|
||||
|
||||
It("activates multiple labels when several descriptions clear threshold", func() {
|
||||
r := &stubReranker{results: []backend.RerankResult{
|
||||
{Index: 0, RelevanceScore: 0.85},
|
||||
|
||||
@@ -91,6 +91,13 @@ type ScoreClassifierOptions struct {
|
||||
// override that instructs the model to emit a different schema
|
||||
// would silently desync from what the scorer actually scores.
|
||||
SystemPromptTemplate string
|
||||
|
||||
// TokenCounter + MaxContextTokens drive conversation trimming: when
|
||||
// both are set, Classify drops the oldest turns until the rendered
|
||||
// prompt fits the classifier's context. Nil/0 disables — Classify
|
||||
// sends Probe.Prompt as-is and relies on the backend's n_ctx guard.
|
||||
TokenCounter func(string) (int, error)
|
||||
MaxContextTokens int
|
||||
}
|
||||
|
||||
// ScoreClassifier scores every policy label as the model's actual
|
||||
@@ -127,6 +134,10 @@ type ScoreClassifier struct {
|
||||
// log-prob. Built once at construction; same list every call.
|
||||
candidates []string
|
||||
|
||||
// budget caps the rendered prompt at the classifier's context minus the
|
||||
// longest candidate; nil/disabled sends Probe.Prompt as-is.
|
||||
budget *lazyBudget
|
||||
|
||||
cache *labelSetCache
|
||||
}
|
||||
|
||||
@@ -191,6 +202,7 @@ func NewScoreClassifier(policies []ScorePolicy, scorer backend.Scorer, opts Scor
|
||||
systemPrompt: systemPrompt,
|
||||
labelOrder: labels,
|
||||
candidates: candidates,
|
||||
budget: &lazyBudget{tokenize: opts.TokenCounter, maxContext: opts.MaxContextTokens, extras: candidates},
|
||||
cache: newLabelSetCache(opts.CacheCap),
|
||||
}
|
||||
}
|
||||
@@ -218,11 +230,19 @@ func (c *ScoreClassifier) Name() string { return ClassifierScore }
|
||||
|
||||
func (c *ScoreClassifier) Classify(ctx context.Context, p Probe) (Decision, error) {
|
||||
start := time.Now()
|
||||
key := cacheKey(p.Prompt)
|
||||
|
||||
// Trim oldest turns until the rendered prompt fits the classifier's
|
||||
// context. Cache-keyed on the trimmed text so conversations that
|
||||
// trim to the same tail share an entry.
|
||||
userText := trimmedProbeText(p, c.budget, func(joined string) (string, error) {
|
||||
return c.renderer(c.systemPrompt, joined)
|
||||
})
|
||||
|
||||
key := cacheKey(userText)
|
||||
if hit, ok := c.cache.get(key); ok {
|
||||
return Decision{Labels: hit, Score: 1.0, Latency: time.Since(start)}, nil
|
||||
}
|
||||
prompt, err := c.renderer(c.systemPrompt, p.Prompt)
|
||||
prompt, err := c.renderer(c.systemPrompt, userText)
|
||||
if err != nil {
|
||||
return errDecision(start, fmt.Errorf("score classify: render prompt: %w", err))
|
||||
}
|
||||
@@ -331,6 +351,12 @@ func softmax(logProbs []float64) []float64 {
|
||||
|
||||
// CacheLen reports the number of entries currently held by the classifier's
// label-set cache.
func (c *ScoreClassifier) CacheLen() int {
	return c.cache.len()
}
|
||||
|
||||
// probeTokenBudget returns the token ceiling for the rendered prompt (context
|
||||
// − longest candidate − margin), computed once via the shared lazyBudget. 0
|
||||
// means trimming is off (no tokenizer/context) or impossible (candidates fill
|
||||
// the context).
|
||||
func (c *ScoreClassifier) probeTokenBudget() int { return c.budget.get() }
|
||||
|
||||
// buildScoreSystemPrompt renders the Arch-Router-style routing
|
||||
// instructions: routes listed in a structured block, output schema
|
||||
// declared as JSON {"route": "<name>"}. Candidates are scored as
|
||||
|
||||
@@ -3,8 +3,10 @@ package router
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
@@ -335,3 +337,138 @@ Reply: {"route": "<name>"}`
|
||||
Expect(c.Name()).To(Equal(ClassifierScore))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("ScoreClassifier conversation trimming", func() {
|
||||
wordCount := func(s string) (int, error) { return len(strings.Fields(s)), nil }
|
||||
threeScores := []backend.CandidateScore{
|
||||
{LogProb: -0.05, NumTokens: 3},
|
||||
{LogProb: -3.0, NumTokens: 3},
|
||||
{LogProb: -4.0, NumTokens: 3},
|
||||
}
|
||||
|
||||
It("drops the oldest turns when the conversation exceeds the context budget", func() {
|
||||
s := &stubScorer{results: threeScores}
|
||||
c := NewScoreClassifier(testPolicies(), s, ScoreClassifierOptions{
|
||||
TokenCounter: wordCount,
|
||||
MaxContextTokens: 10000,
|
||||
})
|
||||
Expect(c.probeTokenBudget()).To(BeNumerically(">", 0), "budget should be positive for a 10k context")
|
||||
|
||||
msgs := make([]string, 0, 200)
|
||||
msgs = append(msgs, "OLDESTMARKER "+strings.Repeat("x ", 99)) // 100 words
|
||||
for range 198 {
|
||||
msgs = append(msgs, strings.Repeat("y ", 100))
|
||||
}
|
||||
msgs = append(msgs, "NEWESTMARKER "+strings.Repeat("z ", 99)) // 100 words; ~20k words total
|
||||
|
||||
_, err := c.Classify(context.Background(), Probe{Messages: msgs, Prompt: strings.Join(msgs, "\n")})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s.lastP).To(ContainSubstring("NEWESTMARKER"), "newest turn must survive the trim")
|
||||
Expect(s.lastP).NotTo(ContainSubstring("OLDESTMARKER"), "oldest turn must be dropped")
|
||||
Expect(len(strings.Fields(s.lastP))).To(BeNumerically("<", 20000), "must be trimmed, not the full transcript")
|
||||
})
|
||||
|
||||
It("keeps the newest turn whole even when it alone exceeds the budget", func() {
|
||||
s := &stubScorer{results: threeScores}
|
||||
c := NewScoreClassifier(testPolicies(), s, ScoreClassifierOptions{
|
||||
TokenCounter: wordCount,
|
||||
MaxContextTokens: 10000,
|
||||
})
|
||||
msgs := []string{
|
||||
"OLDMARKER short",
|
||||
"NEWESTMARKER " + strings.Repeat("z ", 12000), // far over budget
|
||||
}
|
||||
_, err := c.Classify(context.Background(), Probe{Messages: msgs})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s.lastP).To(ContainSubstring("NEWESTMARKER"))
|
||||
Expect(s.lastP).NotTo(ContainSubstring("OLDMARKER"), "older turn drops once the newest fills the budget")
|
||||
})
|
||||
|
||||
It("does not tokenize per message and bounds what it tokenizes for a long conversation", func() {
|
||||
// Regression: the original trim tokenized one message at a time,
|
||||
// newest-first, so a 500-turn conversation produced hundreds of
|
||||
// tokenize RPCs. The render-once design must tokenize the candidates
|
||||
// (budget setup) plus a small constant for the measurement/confirm
|
||||
// passes — and the rune pre-trim must keep the tokenized prompt far
|
||||
// smaller than the full transcript.
|
||||
calls := 0
|
||||
maxRunes := 0
|
||||
counting := func(s string) (int, error) {
|
||||
calls++
|
||||
if r := utf8.RuneCountInString(s); r > maxRunes {
|
||||
maxRunes = r
|
||||
}
|
||||
return len(strings.Fields(s)), nil
|
||||
}
|
||||
s := &stubScorer{results: threeScores}
|
||||
c := NewScoreClassifier(testPolicies(), s, ScoreClassifierOptions{
|
||||
TokenCounter: counting,
|
||||
MaxContextTokens: 4000,
|
||||
})
|
||||
|
||||
msgs := make([]string, 500)
|
||||
totalRunes := 0
|
||||
for i := range msgs {
|
||||
msgs[i] = fmt.Sprintf("msg%d %s", i, strings.Repeat("w ", 50))
|
||||
totalRunes += utf8.RuneCountInString(msgs[i])
|
||||
}
|
||||
|
||||
_, err := c.Classify(context.Background(), Probe{Messages: msgs})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s.lastP).To(ContainSubstring("msg499"), "newest turn must survive")
|
||||
Expect(s.lastP).NotTo(ContainSubstring("msg0 "), "oldest turns must be dropped")
|
||||
Expect(calls).To(BeNumerically("<", 20),
|
||||
"tokenizer must not be called per message (got %d calls for 500 messages)", calls)
|
||||
Expect(maxRunes).To(BeNumerically("<", totalRunes/2),
|
||||
"rune pre-trim must keep the tokenized prompt well under the full transcript")
|
||||
})
|
||||
|
||||
It("uses Probe.Prompt unchanged when no tokenizer is wired", func() {
|
||||
s := &stubScorer{results: threeScores}
|
||||
c := NewScoreClassifier(testPolicies(), s, ScoreClassifierOptions{})
|
||||
Expect(c.probeTokenBudget()).To(Equal(0))
|
||||
|
||||
_, err := c.Classify(context.Background(), Probe{
|
||||
Prompt: "PROMPTONLYMARKER",
|
||||
Messages: []string{"ignored-because-no-tokenizer"},
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s.lastP).To(ContainSubstring("PROMPTONLYMARKER"))
|
||||
Expect(s.lastP).NotTo(ContainSubstring("ignored-because-no-tokenizer"))
|
||||
})
|
||||
|
||||
It("disables trimming (budget 0) when the tokenizer errors", func() {
|
||||
s := &stubScorer{results: threeScores}
|
||||
boom := func(string) (int, error) { return 0, errors.New("tokenizer down") }
|
||||
c := NewScoreClassifier(testPolicies(), s, ScoreClassifierOptions{
|
||||
TokenCounter: boom,
|
||||
MaxContextTokens: 10000,
|
||||
})
|
||||
Expect(c.probeTokenBudget()).To(Equal(0), "a tokenizer error must disable trimming, not panic")
|
||||
|
||||
_, err := c.Classify(context.Background(), Probe{Prompt: "FALLBACKMARKER", Messages: []string{"a", "b"}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s.lastP).To(ContainSubstring("FALLBACKMARKER"))
|
||||
})
|
||||
|
||||
It("retries the budget after a TRANSIENT tokenizer error instead of disabling permanently", func() {
|
||||
// Regression: a sync.Once would memoize the first failure and never
|
||||
// recompute. The first call (model still loading) errors; a later
|
||||
// call must succeed and yield a real budget.
|
||||
s := &stubScorer{results: threeScores}
|
||||
calls := 0
|
||||
flaky := func(text string) (int, error) {
|
||||
calls++
|
||||
if calls == 1 {
|
||||
return 0, errors.New("model still loading")
|
||||
}
|
||||
return len(strings.Fields(text)), nil
|
||||
}
|
||||
c := NewScoreClassifier(testPolicies(), s, ScoreClassifierOptions{
|
||||
TokenCounter: flaky,
|
||||
MaxContextTokens: 10000,
|
||||
})
|
||||
Expect(c.probeTokenBudget()).To(Equal(0), "first call: tokenizer error leaves budget uncomputed")
|
||||
Expect(c.probeTokenBudget()).To(BeNumerically(">", 0), "retry: budget computes once the tokenizer recovers")
|
||||
})
|
||||
})
|
||||
|
||||
178
core/services/routing/router/trim.go
Normal file
178
core/services/routing/router/trim.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// Tuning knobs shared by the prompt trimmer and the lazy budget.
const (
	// pretrimRunesPerToken is the runes-per-token ratio assumed by the cheap
	// rune pre-trim. It is deliberately high (most text is 3–5 runes/token and
	// tokenizers rarely exceed 6) so the pre-trim keeps a superset of what
	// fits before any tokenize call is made.
	pretrimRunesPerToken = 6

	// tokenBudgetMargin is subtracted from the budget to absorb BPE-boundary
	// drift and the framing tokens a renderer adds, so a prompt measured at
	// exactly the budget still fits n_ctx.
	tokenBudgetMargin = 16
)
|
||||
|
||||
// JoinTurns joins per-turn texts oldest→newest, terminating every turn
// (including the last one) with a single '\n'. The probe builder, the trimmer,
// and every classifier share this so the text a model sees has one canonical
// shape. A nil or empty slice yields "".
func JoinTurns(turns []string) string {
	// The final size is known exactly, so reserve it once instead of letting
	// the builder grow step by step; the trimmer re-joins transcripts on
	// every measurement pass.
	size := len(turns) // one '\n' per turn
	for _, m := range turns {
		size += len(m)
	}

	var b strings.Builder
	b.Grow(size)
	for _, m := range turns {
		b.WriteString(m)
		b.WriteByte('\n')
	}
	return b.String()
}
|
||||
|
||||
// promptTrimmer fits an oldest→newest list of turns into one model's token
// budget. The strategy is: an optimistic rune-based pre-trim, a single
// tokenize call to measure, then a recalibration with the measured
// runes/token ratio, dropping whole turns oldest-first until the rendered
// prompt fits. The newest turn is never dropped — if it overflows on its own
// it is sent whole and the backend's n_ctx guard is the backstop.
type promptTrimmer struct {
	// tokenize returns the token count of a rendered prompt.
	tokenize func(string) (int, error)
	// render wraps the joined turns into what the model actually tokenizes:
	// a chat template for the scorer, identityRender for an embedder or
	// reranker working on raw text.
	render func(joined string) (string, error)
	// budget is the maximum number of tokens the rendered prompt may use.
	budget int
}
|
||||
|
||||
// identityRender is the no-op renderer: it hands the joined text back
// untouched and never fails.
func identityRender(s string) (string, error) {
	return s, nil
}
|
||||
|
||||
func (t promptTrimmer) fit(turns []string) string {
|
||||
if len(turns) == 0 {
|
||||
return ""
|
||||
}
|
||||
kept := turns[runePretrimStart(turns, t.budget*pretrimRunesPerToken):]
|
||||
|
||||
joined := JoinTurns(kept)
|
||||
rendered, err := t.render(joined)
|
||||
if err != nil {
|
||||
return joined
|
||||
}
|
||||
total, err := t.tokenize(rendered)
|
||||
if err != nil || total <= t.budget {
|
||||
return joined
|
||||
}
|
||||
|
||||
runesPerToken := float64(utf8.RuneCountInString(rendered)) / float64(total)
|
||||
if runesPerToken <= 0 {
|
||||
runesPerToken = 1
|
||||
}
|
||||
est := total
|
||||
keep := 0
|
||||
for keep < len(kept)-1 && est > t.budget {
|
||||
est -= int(math.Ceil(float64(utf8.RuneCountInString(kept[keep])) / runesPerToken))
|
||||
keep++
|
||||
}
|
||||
|
||||
for {
|
||||
tail := JoinTurns(kept[keep:])
|
||||
rendered, err := t.render(tail)
|
||||
if err != nil {
|
||||
return tail
|
||||
}
|
||||
n, err := t.tokenize(rendered)
|
||||
if err != nil || n <= t.budget {
|
||||
return tail
|
||||
}
|
||||
if keep >= len(kept)-1 {
|
||||
xlog.Warn("router: newest turn alone exceeds model context; sending it whole — backend n_ctx guard is the backstop",
|
||||
"tokens", n, "budget", t.budget)
|
||||
return tail
|
||||
}
|
||||
keep++
|
||||
}
|
||||
}
|
||||
|
||||
// runePretrimStart returns the oldest index to keep so that the joined tail
// (turns[start:] as produced by JoinTurns) stays within budgetRunes. The
// newest turn is always kept, even when it alone is over budget; older turns
// are added newest-first while they still fit, so the result is always a
// contiguous tail.
//
// Each turn is charged its rune count plus one for the '\n' JoinTurns appends
// after it. Without that, a conversation of many short or empty turns could
// join to far more runes than the budget allows.
//
// A non-positive budgetRunes or an empty turn list returns 0 (keep everything).
func runePretrimStart(turns []string, budgetRunes int) int {
	if budgetRunes <= 0 || len(turns) == 0 {
		return 0
	}

	const joinerRunes = 1 // the trailing '\n' added per turn when joining

	start := len(turns) - 1
	total := utf8.RuneCountInString(turns[start]) + joinerRunes
	for i := start - 1; i >= 0; i-- {
		r := utf8.RuneCountInString(turns[i]) + joinerRunes
		if total+r > budgetRunes {
			break
		}
		total += r
		start = i
	}
	return start
}
|
||||
|
||||
// lazyBudget works out a model's probe token budget on first use and caches
// it. The budget is maxContext minus the longest per-call extra (scorer
// candidates, reranker documents; none for a plain embed) minus
// tokenBudgetMargin.
//
// A tokenizer error leaves the budget uncomputed, so a transient failure
// (for example a model that is still loading) recovers on a later call.
// Extras that already fill the context are cached as "disabled".
type lazyBudget struct {
	// tokenize counts the tokens in a string; nil turns the budget off.
	tokenize func(string) (int, error)
	// maxContext is the model's context size in tokens; <= 0 turns the budget off.
	maxContext int
	// extras are the per-call texts whose longest member must fit next to the probe.
	extras []string

	// mu serializes the one-time computation.
	mu sync.Mutex
	// value is the cached result: 0=unset, >0=budget, -1=disabled.
	value atomic.Int64
}
|
||||
|
||||
func (l *lazyBudget) get() int {
|
||||
if l == nil || l.tokenize == nil || l.maxContext <= 0 {
|
||||
return 0
|
||||
}
|
||||
if v := l.value.Load(); v != 0 {
|
||||
if v < 0 {
|
||||
return 0
|
||||
}
|
||||
return int(v)
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if v := l.value.Load(); v != 0 {
|
||||
if v < 0 {
|
||||
return 0
|
||||
}
|
||||
return int(v)
|
||||
}
|
||||
longest := 0
|
||||
for _, e := range l.extras {
|
||||
n, err := l.tokenize(e)
|
||||
if err != nil {
|
||||
return 0 // transient: leave unset so a later call retries
|
||||
}
|
||||
if n > longest {
|
||||
longest = n
|
||||
}
|
||||
}
|
||||
b := l.maxContext - longest - tokenBudgetMargin
|
||||
if b <= 0 {
|
||||
l.value.Store(-1)
|
||||
return 0
|
||||
}
|
||||
l.value.Store(int64(b))
|
||||
return b
|
||||
}
|
||||
|
||||
// trimmedProbeText returns the text to feed a model: the most recent turns
|
||||
// that fit its token budget, or p.Prompt when trimming is disabled (no
|
||||
// tokenizer/context wired, or a single-input probe with no Messages).
|
||||
func trimmedProbeText(p Probe, b *lazyBudget, render func(string) (string, error)) string {
|
||||
if len(p.Messages) > 0 {
|
||||
if budget := b.get(); budget > 0 {
|
||||
return promptTrimmer{tokenize: b.tokenize, render: render, budget: budget}.fit(p.Messages)
|
||||
}
|
||||
}
|
||||
return p.Prompt
|
||||
}
|
||||
@@ -31,6 +31,15 @@ type Probe struct {
|
||||
// is the concatenation of message contents (separated by newlines);
|
||||
// for plain completions it is the raw prompt.
|
||||
Prompt string
|
||||
|
||||
// Messages carries the per-turn texts (oldest→newest) when the probe
// came from a multi-message chat request. A classifier wired with a
// tokenizer and a context size (the score classifier, or a rerank
// classifier configured via WithTokenTrim) uses these to trim an
// over-long conversation to the classifier model's context window on
// turn boundaries, keeping the most recent turns. Empty for
// single-input probes (plain completions, /router/decide), in which
// case the classifier falls back to Prompt verbatim.
|
||||
Messages []string
|
||||
}
|
||||
|
||||
// Decision is the classifier's output. Labels carries the SET of
|
||||
|
||||
@@ -33,6 +33,7 @@ const (
|
||||
BackendTraceAudioTransform BackendTraceType = "audio_transform"
|
||||
BackendTraceModelLoad BackendTraceType = "model_load"
|
||||
BackendTraceScore BackendTraceType = "score"
|
||||
BackendTraceVectorStore BackendTraceType = "vector_store"
|
||||
)
|
||||
|
||||
type BackendTrace struct {
|
||||
|
||||
Reference in New Issue
Block a user