fix(router): production-ready request router + auto-size batch for embedding/rerank (#10104)

* fix(router): score classifier production-readiness

Conversation trimming runs through the classifier model's chat template
and trims by exact token count, sized to the model's n_batch which is
now scaled to context so long probes can't crash the backend. Missing
chat_message templates are a hard error at router build time. Router-
facing factories (Embedder/Scorer/Reranker/TokenCounter) re-resolve
ModelConfig per call so a model installed post-startup doesn't bind a
stub Backend="" config and silently fall into the loader's auto-
iterate path.

New 'vector_store' backend trace recorded inside localVectorStore on
every Search/Insert — including the backend-load-failure path that
previously vanished into an xlog.Warn — with outcome tagging
(hit/miss/empty_store/backend_load_error/find_error/insert_error/ok).
Companion cleanup drops misleading similarity:0 and input_tokens_count:0
from non-hit and text-mode traces.

Gallery local-store-development aliases to 'local-store' so the master
image satisfies pkg/model.LocalStoreBackend lookups from the embedding
cache.

Misc: llama-cpp TokenizeString reads the correct 'prompt' JSON key
(the original bug); ModelTokenize nil-guard; non-fatal mitm proxy
startup; PII 'route_local' renamed to 'allow' with docs/UI in sync;
model-editor footer no longer eats the edit area on small screens;
several config-editor template/dropdown/section fixes.

Tests: e2e router specs (casual/code-hint + long-conversation trim),
vector_store trace specs, lazy-factory specs, gallery dev-alias
resolution, Playwright trace badge + scroll regression.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(backend): auto-size batch to context for embedding and rerank models

Embedding and rerank models pool over the whole input in a single physical batch (n_ubatch). With batch left at the 512 default, the backend rejects longer inputs with "input is too large to process", silently capping a large-context embedder (e.g. 8k/32k) at 512 tokens. Size n_batch to the context for these single-pass usecases, mirroring the existing FLAG_SCORE behaviour; an explicit batch: still wins.

Extracts EffectiveContextSize/EffectiveBatchSize from grpcModelOpts so the effective decode window has one home for other callers to reuse.

Adds an e2e-aio regression test that embeds a >512-token input. The AIO embedding model is switched to nomic-embed-text-v1.5 (2048 context) because the previous granite model was capped at 512 tokens and could not exercise the larger batch.

Assisted-by: claude-code:claude-opus-4-8 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(gallery): raise arch-router scoring output cap via parallel:64

Scoring decodes the whole prompt+candidate in a single llama_decode and
reads one logit row per candidate token. The vendored llama.cpp server
caps causal output rows at n_parallel, so the default of 1 aborts with
GGML_ASSERT(n_outputs_max <= cparams.n_outputs_max) on multi-token route
labels. Set options: [parallel:64] on both arch-router quant entries to
lift the cap; kv_unified (the grpc-server default) keeps the full context
per sequence, so this does not split the KV cache.

Assisted-by: claude-code:claude-opus-4-8 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
Richard Palethorpe
2026-06-12 15:21:15 +01:00
committed by GitHub
parent 56cc4f63fc
commit 085fc53bbc
86 changed files with 2305 additions and 387 deletions

View File

@@ -38,7 +38,7 @@ func BuildStreamFilter(c echo.Context, cfg *config.ModelConfig, isStream bool, p
overrides = make(map[string]pii.Action, len(raw))
for ovid, action := range raw {
switch pii.Action(action) {
case pii.ActionMask, pii.ActionBlock, pii.ActionRouteLocal:
case pii.ActionMask, pii.ActionBlock, pii.ActionAllow:
overrides[ovid] = pii.Action(action)
}
}

View File

@@ -53,7 +53,7 @@ func LoadConfig(path string) ([]Pattern, error) {
continue
}
switch p.Action {
case ActionMask, ActionBlock, ActionRouteLocal:
case ActionMask, ActionBlock, ActionAllow:
overrides[p.ID] = p.Action
default:
return nil, fmt.Errorf("pii: invalid action %q for pattern %q", p.Action, p.ID)

View File

@@ -22,7 +22,7 @@ var _ = Describe("LoadConfig", func() {
- id: email
action: block
- id: ssn
action: route_local
action: allow
`)
Expect(os.WriteFile(path, body, 0o600)).To(Succeed())
patterns, err := LoadConfig(path)
@@ -33,7 +33,7 @@ var _ = Describe("LoadConfig", func() {
got[p.ID] = p.Action
}
Expect(got["email"]).To(Equal(ActionBlock))
Expect(got["ssn"]).To(Equal(ActionRouteLocal))
Expect(got["ssn"]).To(Equal(ActionAllow))
// Unmentioned patterns keep their default action.
Expect(got["credit_card"]).To(Equal(ActionMask), "credit_card default action lost")
})

View File

@@ -21,7 +21,6 @@ import (
const (
ctxKeyCorrelationID = "routing.correlation_id"
ctxKeyPIIEventID = "routing.pii_event_id"
ctxKeyLocalOnly = "routing.local_only"
// Must match the constants in core/http/middleware/request.go.
// Echoing them across packages would create an import cycle
// (http/middleware imports this package). Drift is caught by
@@ -37,7 +36,7 @@ const (
//
// Consumers of the override map: the action returned from PIIPatternOverrides
// is the raw YAML string (e.g. "block"). Validation against the canonical
// ActionMask/Block/RouteLocal constants happens here, so a typo in a model
// ActionMask/Block/Allow constants happens here, so a typo in a model
// YAML logs and is ignored rather than panicking.
type ModelPIIConfig interface {
PIIIsEnabled() bool
@@ -77,9 +76,8 @@ type Adapter struct {
// to the client.
// - On match with action=mask: the redacted text replaces the
// original on the parsed request. PIIEvents are recorded.
// - On match with action=route_local: the original text is left
// intact, but the echo context is annotated so the (future) router
// middleware refuses cloud-proxy candidates.
// - On match with action=allow: the original text is left intact; a
// PIIEvent is still recorded so the detection is auditable.
//
// recorder is the Recorder on which to record events; nil disables
// recording (the redaction still happens). fallbackUser supplies the
@@ -138,7 +136,7 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
overrides = make(map[string]Action, len(raw))
for id, action := range raw {
switch Action(action) {
case ActionMask, ActionBlock, ActionRouteLocal:
case ActionMask, ActionBlock, ActionAllow:
overrides[id] = Action(action)
default:
xlog.Warn("pii: ignoring unknown action in per-model override",
@@ -151,7 +149,6 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
texts := adapter.Scan(parsed)
updates := make([]ScannedText, 0, len(texts))
var blocked bool
var localOnly bool
var firstEventID string
for _, st := range texts {
@@ -201,9 +198,6 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
if res.Blocked {
blocked = true
}
if res.LocalOnly {
localOnly = true
}
updates = append(updates, ScannedText{Index: st.Index, Text: res.Redacted})
}
@@ -224,10 +218,6 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
if firstEventID != "" {
c.Set(ctxKeyPIIEventID, firstEventID)
}
if localOnly {
c.Set(ctxKeyLocalOnly, true)
}
return next(c)
}
}

View File

@@ -153,9 +153,9 @@ var _ = Describe("RequestMiddleware", func() {
Expect(errBlock["type"]).To(Equal("pii_blocked"))
})
It("route_local sets context flag", func() {
It("allow leaves text intact but still records an event", func() {
patterns, _ := Compile([]Pattern{{
ID: "email", Description: "Email", Action: ActionRouteLocal, MaxMatchLength: 254,
ID: "email", Description: "Email", Action: ActionAllow, MaxMatchLength: 254,
}})
red := NewRedactor(patterns)
store := NewMemoryEventStore(0)
@@ -165,10 +165,7 @@ var _ = Describe("RequestMiddleware", func() {
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
e := echo.New()
var observedLocalOnly bool
e.POST("/chat", func(c echo.Context) error {
v, _ := c.Get(ctxKeyLocalOnly).(bool)
observedLocalOnly = v
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
@@ -177,9 +174,12 @@ var _ = Describe("RequestMiddleware", func() {
e.ServeHTTP(w, req)
Expect(w.Code).To(Equal(http.StatusOK))
Expect(observedLocalOnly).To(BeTrue(), "ctxKeyLocalOnly should be true on route_local match")
// route_local does NOT mutate the body — the model still sees the email.
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "route_local should leave text intact")
// allow does NOT mutate the body — the model still sees the email.
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "allow should leave text intact")
// ...but the detection is still recorded for audit.
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
Expect(events).To(HaveLen(1), "allow should still record a PIIEvent")
Expect(events[0].Action).To(Equal(ActionAllow))
})
It("no match passes through", func() {

View File

@@ -65,8 +65,8 @@ func (r *Redactor) Patterns() []Pattern {
// older snapshot don't race on the per-element Action string (Go
// strings are not atomic two-word values).
func (r *Redactor) SetAction(id string, action Action) error {
if action != ActionMask && action != ActionBlock && action != ActionRouteLocal {
return fmt.Errorf("unknown action %q (must be mask, block, or route_local)", action)
if action != ActionMask && action != ActionBlock && action != ActionAllow {
return fmt.Errorf("unknown action %q (must be mask, block, or allow)", action)
}
r.mu.Lock()
defer r.mu.Unlock()
@@ -114,8 +114,9 @@ func (r *Redactor) Redact(text string) Result {
// and applies the resolved Action:
// - block: sets Result.Blocked, leaves text intact (caller decides
// whether to surface the redacted form).
// - mask: replaces the span with maskFor(pattern.ID).
// - route_local: sets Result.LocalOnly, leaves text intact.
// - mask: replaces the span with maskFor(pattern.ID), sets Result.Masked.
// - allow: leaves text intact and sets no flag (the span is still
// recorded so the match is auditable).
//
// Spans are returned in the original input's coordinate system so the
// PIIEvent record can be written without re-running the scan.
@@ -254,7 +255,7 @@ func mergeAndEmit(text string, hits []rawHit) Result {
// Sort and deduplicate overlapping hits — when two patterns claim
// the same span (e.g., a credit-card-shaped value also scans as
// digits, or NER tags a span the regex also caught), keep the one
// with the strongest action. Order: block > route_local > mask.
// with the strongest action. Order: block > mask > allow.
sort.Slice(hits, func(i, j int) bool {
if hits[i].start != hits[j].start {
return hits[i].start < hits[j].start
@@ -298,10 +299,11 @@ func mergeAndEmit(text string, hits []rawHit) Result {
case ActionBlock:
res.Blocked = true
out.WriteString(matched)
case ActionRouteLocal:
res.LocalOnly = true
case ActionAllow:
// Detect-and-log only: leave the matched text in place.
out.WriteString(matched)
default:
res.Masked = true
out.WriteString(maskFor(h.patternID))
}
cursor = h.end
@@ -333,9 +335,9 @@ func actionRank(a Action) int {
switch a {
case ActionBlock:
return 3
case ActionRouteLocal:
return 2
case ActionMask:
return 2
case ActionAllow:
return 1
}
return 0

View File

@@ -96,7 +96,7 @@ var _ = Describe("Redactor", func() {
res := r.Redact("")
Expect(res.Redacted).To(BeEmpty())
Expect(res.Blocked).To(BeFalse())
Expect(res.LocalOnly).To(BeFalse())
Expect(res.Masked).To(BeFalse())
Expect(res.Spans).To(BeEmpty())
})
@@ -165,10 +165,12 @@ var _ = Describe("RedactWithOverrides", func() {
var _ = Describe("SetAction", func() {
It("swaps in place", func() {
r := NewRedactor(mustCompile("email"))
Expect(r.SetAction("email", ActionRouteLocal)).To(Succeed())
Expect(r.SetAction("email", ActionAllow)).To(Succeed())
res := r.Redact("contact alice@example.com")
Expect(res.LocalOnly).To(BeTrue(), "expected LocalOnly after SetAction(route_local)")
Expect(res.Blocked).To(BeFalse(), "SetAction(route_local) should not block")
Expect(res.Masked).To(BeFalse(), "allow leaves text intact, so nothing is masked")
Expect(res.Redacted).To(ContainSubstring("alice@example.com"), "allow should leave the match in place")
Expect(res.Spans).To(HaveLen(1), "allow still records the match")
Expect(res.Blocked).To(BeFalse(), "SetAction(allow) should not block")
})
It("rejects unknown id", func() {

View File

@@ -27,8 +27,9 @@ import (
// reject the request. We remap block → mask for redaction purposes
// while still recording PIIEvent rows with action="block" so audits
// surface the original intent ("the model would have leaked X here,
// suppressed in flight"). route_local on the output side is a no-op
// (the dispatch decision was already made on the request side).
// suppressed in flight"). allow on the output side is a no-op — the
// text is left intact, matching its request-side detect-and-log
// behaviour.
//
// StreamFilter is NOT safe for concurrent use across goroutines; one
// instance per response stream.

View File

@@ -11,13 +11,14 @@
// drops in without changing call sites.
//
// Configuration model: each pattern has an Action (block | mask |
// route_local). Actions are evaluated in this order:
// allow). Actions are evaluated in this order:
// - block: short-circuits the request with an error (the middleware
// returns 400 to the client).
// - mask: replaces the matched span with ReplacementFor(pattern).
// - route_local: leaves the text alone but sets a context flag the
// router (subsystem 2) treats as "this request must stay on a local
// model" — never crosses the boundary to a cloud proxy backend.
// - allow: detect-and-log only — the span is left intact and a
// PIIEvent is still recorded, but the text passes through
// unchanged. Useful to downgrade a pattern's default while keeping
// it visible in the audit log.
package pii
import "time"
@@ -36,11 +37,13 @@ const (
// the matched value).
ActionBlock Action = "block"
// ActionRouteLocal leaves the text intact but flags the request so
// the content router will refuse to dispatch it to a cloud proxy
// backend. Useful when a deployment trusts local models with
// sensitive data but not external providers.
ActionRouteLocal Action = "route_local"
// ActionAllow detects and logs the match but leaves the text
// intact — no masking, no blocking. A PIIEvent is still recorded,
// so the detection is auditable and forms the basis for surfacing
// detected-PII labels to the router (a future router-model
// feature). Use it to downgrade a pattern's default action for a
// model while keeping the pattern visible.
ActionAllow Action = "allow"
)
// Direction tags whether a PIIEvent fired on input (request body before
@@ -74,14 +77,15 @@ type Span struct {
// the call site must enforce this by returning a 400 / refusing to
// dispatch.
//
// LocalOnly is true iff at least one matched pattern had
// Action=route_local. The router middleware reads this and constrains
// candidate selection.
// Masked is true iff at least one matched span was replaced with a
// placeholder (Action=mask). Spans with Action=allow are recorded but
// leave Masked false. Lets callers (e.g. the decision oracle)
// distinguish "matched and redacted" from "matched but passed through".
type Result struct {
Redacted string
Spans []Span
Blocked bool
LocalOnly bool
Redacted string
Spans []Span
Blocked bool
Masked bool
}
// Pattern is one configurable rule. Description is shown in the admin

View File

@@ -52,6 +52,10 @@ type EmbeddingCacheClassifier struct {
similarityThreshold float64
confidenceThreshold float64
// budget trims the conversation to the embedder model's own context
// before embedding; nil embeds Probe.Prompt as built by the caller.
budget *lazyBudget
hits atomic.Uint64
misses atomic.Uint64
nearMisses atomic.Uint64
@@ -100,6 +104,15 @@ func NewEmbeddingCacheClassifier(inner Classifier, embedder backend.Embedder, st
}
}
// WithTokenTrim wires the embedder model's own tokenizer and context so the
// probe embeds the most recent turns that fit instead of a caller-chosen size.
// nil tokenizer / non-positive context leaves trimming off. Returns the
// receiver for chaining at construction.
func (c *EmbeddingCacheClassifier) WithTokenTrim(tokenize func(string) (int, error), maxContextTokens int) *EmbeddingCacheClassifier {
c.budget = &lazyBudget{tokenize: tokenize, maxContext: maxContextTokens}
return c
}
// Name is the inner classifier's name — the decision-log "classifier"
// field should reflect *what* made the decision, not the caching
// transport. Cache hits set Decision.Cached separately so admins can
@@ -127,7 +140,7 @@ func (c *EmbeddingCacheClassifier) Stats() EmbeddingCacheStats {
func (c *EmbeddingCacheClassifier) Classify(ctx context.Context, p Probe) (Decision, error) {
start := time.Now()
vec, err := c.embedder.Embed(ctx, p.Prompt)
vec, err := c.embedder.Embed(ctx, trimmedProbeText(p, c.budget, identityRender))
if err != nil {
c.embedderErrors.Add(1)
xlog.Warn("router: embedding cache embed failed", "error", err)

View File

@@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
@@ -13,6 +15,20 @@ import (
. "github.com/onsi/gomega"
)
// capturingEmbedder records the text it was last asked to embed and returns a
// fixed vector, so a test can assert what the cache fed the embedder.
type capturingEmbedder struct {
mu sync.Mutex
lastText string
}
func (e *capturingEmbedder) Embed(_ context.Context, text string) ([]float32, error) {
e.mu.Lock()
defer e.mu.Unlock()
e.lastText = text
return []float32{1, 2, 3}, nil
}
// fakeEmbedder returns a vector keyed by a lookup table; this lets the
// test exercise hit/miss control without depending on a real model.
type fakeEmbedder struct {
@@ -294,6 +310,45 @@ var _ = Describe("EmbeddingCache", func() {
})
})
var _ = Describe("EmbeddingCache WithTokenTrim", func() {
ctx := context.Background()
wordCount := func(s string) (int, error) { return len(strings.Fields(s)), nil }
It("embeds the most recent turns that fit the embedder context, not the full prompt", func() {
emb := &capturingEmbedder{}
store := &memVectorStore{}
inner := &stubInner{name: "score", decision: router.Decision{Labels: []string{"x"}, Score: 0.1}}
// context_size 50 → budget 5016 margin ≈ 34 tokens, far under the
// ~120-word transcript below, so the oldest turns must be dropped.
cache := router.NewEmbeddingCacheClassifier(inner, emb, store, 0.92, 0.6).
WithTokenTrim(wordCount, 50)
msgs := make([]string, 0, 31)
for i := range 30 {
msgs = append(msgs, fmt.Sprintf("OLDturn%d filler filler filler", i))
}
msgs = append(msgs, "NEWESTTURN final words here")
full := strings.Join(msgs, "\n")
_, err := cache.Classify(ctx, router.Probe{Prompt: full, Messages: msgs})
Expect(err).NotTo(HaveOccurred())
Expect(emb.lastText).To(ContainSubstring("NEWESTTURN"), "newest turn must survive")
Expect(emb.lastText).NotTo(ContainSubstring("OLDturn0 "), "oldest turns trimmed to fit context")
Expect(emb.lastText).NotTo(Equal(full), "must not embed the untrimmed prompt")
})
It("embeds Probe.Prompt unchanged when no trim is wired", func() {
emb := &capturingEmbedder{}
store := &memVectorStore{}
inner := &stubInner{name: "score", decision: router.Decision{Labels: []string{"x"}, Score: 0.1}}
cache := router.NewEmbeddingCacheClassifier(inner, emb, store, 0.92, 0.6)
_, err := cache.Classify(ctx, router.Probe{Prompt: "PROMPTASIS", Messages: []string{"ignored-no-tokenizer"}})
Expect(err).NotTo(HaveOccurred())
Expect(emb.lastText).To(Equal("PROMPTASIS"))
})
})
var _ = Describe("EmbeddingCache latency", func() {
It("is populated on hits", func() {
embedder := &fakeEmbedder{table: map[string][]float32{"p": {1}}}

View File

@@ -23,6 +23,11 @@ type RerankClassifier struct {
labels []string
documents []string
cache *labelSetCache
// budget trims the query to the reranker model's context minus the
// longest policy description (paired with the query per rerank call);
// nil reranks Probe.Prompt as built by the caller.
budget *lazyBudget
}
// defaultRerankActivationThreshold is the relevance floor a label
@@ -64,16 +69,26 @@ func NewRerankClassifier(policies []ScorePolicy, reranker backend.Reranker, cach
}
}
// WithTokenTrim wires the reranker model's own tokenizer and context so the
// query is trimmed to the most recent turns that fit alongside the longest
// policy description. nil tokenizer / non-positive context leaves trimming
// off. Returns the receiver for chaining at construction.
func (c *RerankClassifier) WithTokenTrim(tokenize func(string) (int, error), maxContextTokens int) *RerankClassifier {
c.budget = &lazyBudget{tokenize: tokenize, maxContext: maxContextTokens, extras: c.documents}
return c
}
func (c *RerankClassifier) Name() string { return ClassifierColbert }
func (c *RerankClassifier) Classify(ctx context.Context, p Probe) (Decision, error) {
start := time.Now()
key := cacheKey(p.Prompt)
query := trimmedProbeText(p, c.budget, identityRender)
key := cacheKey(query)
if hit, ok := c.cache.get(key); ok {
return Decision{Labels: hit, Score: 1.0, Latency: time.Since(start)}, nil
}
results, err := c.reranker.Rerank(ctx, p.Prompt, c.documents)
results, err := c.reranker.Rerank(ctx, query, c.documents)
if err != nil {
return errDecision(start, fmt.Errorf("rerank classify: %w", err))
}

View File

@@ -3,6 +3,8 @@ package router
import (
"context"
"errors"
"fmt"
"strings"
"github.com/mudler/LocalAI/core/backend"
. "github.com/onsi/ginkgo/v2"
@@ -43,6 +45,31 @@ var _ = Describe("RerankClassifier", func() {
Expect(d.Score).To(BeNumerically(">=", 0.9))
})
It("trims the query to the reranker context, keeping the newest turns", func() {
r := &stubReranker{results: []backend.RerankResult{
{Index: 0, RelevanceScore: 0.92},
{Index: 1, RelevanceScore: 0.10},
{Index: 2, RelevanceScore: 0.05},
}}
wordCount := func(s string) (int, error) { return len(strings.Fields(s)), nil }
// budget = 60 longest policy description 16 margin; still well under
// the ~120-word transcript, so the oldest turns drop.
c := NewRerankClassifier(testPolicies(), r, 0, 0).WithTokenTrim(wordCount, 60)
msgs := make([]string, 0, 31)
for i := range 30 {
msgs = append(msgs, fmt.Sprintf("OLDturn%d aaa bbb ccc", i))
}
msgs = append(msgs, "NEWESTTURN zzz")
full := strings.Join(msgs, "\n")
_, err := c.Classify(context.Background(), Probe{Prompt: full, Messages: msgs})
Expect(err).NotTo(HaveOccurred())
Expect(r.lastQ).To(ContainSubstring("NEWESTTURN"), "newest turn must survive")
Expect(r.lastQ).NotTo(ContainSubstring("OLDturn0 "), "oldest turns trimmed to fit context")
Expect(r.lastQ).NotTo(Equal(full), "must not rerank the untrimmed prompt")
})
It("activates multiple labels when several descriptions clear threshold", func() {
r := &stubReranker{results: []backend.RerankResult{
{Index: 0, RelevanceScore: 0.85},

View File

@@ -91,6 +91,13 @@ type ScoreClassifierOptions struct {
// override that instructs the model to emit a different schema
// would silently desync from what the scorer actually scores.
SystemPromptTemplate string
// TokenCounter + MaxContextTokens drive conversation trimming: when
// both are set, Classify drops the oldest turns until the rendered
// prompt fits the classifier's context. Nil/0 disables — Classify
// sends Probe.Prompt as-is and relies on the backend's n_ctx guard.
TokenCounter func(string) (int, error)
MaxContextTokens int
}
// ScoreClassifier scores every policy label as the model's actual
@@ -127,6 +134,10 @@ type ScoreClassifier struct {
// log-prob. Built once at construction; same list every call.
candidates []string
// budget caps the rendered prompt at the classifier's context minus the
// longest candidate; nil/disabled sends Probe.Prompt as-is.
budget *lazyBudget
cache *labelSetCache
}
@@ -191,6 +202,7 @@ func NewScoreClassifier(policies []ScorePolicy, scorer backend.Scorer, opts Scor
systemPrompt: systemPrompt,
labelOrder: labels,
candidates: candidates,
budget: &lazyBudget{tokenize: opts.TokenCounter, maxContext: opts.MaxContextTokens, extras: candidates},
cache: newLabelSetCache(opts.CacheCap),
}
}
@@ -218,11 +230,19 @@ func (c *ScoreClassifier) Name() string { return ClassifierScore }
func (c *ScoreClassifier) Classify(ctx context.Context, p Probe) (Decision, error) {
start := time.Now()
key := cacheKey(p.Prompt)
// Trim oldest turns until the rendered prompt fits the classifier's
// context. Cache-keyed on the trimmed text so conversations that
// trim to the same tail share an entry.
userText := trimmedProbeText(p, c.budget, func(joined string) (string, error) {
return c.renderer(c.systemPrompt, joined)
})
key := cacheKey(userText)
if hit, ok := c.cache.get(key); ok {
return Decision{Labels: hit, Score: 1.0, Latency: time.Since(start)}, nil
}
prompt, err := c.renderer(c.systemPrompt, p.Prompt)
prompt, err := c.renderer(c.systemPrompt, userText)
if err != nil {
return errDecision(start, fmt.Errorf("score classify: render prompt: %w", err))
}
@@ -331,6 +351,12 @@ func softmax(logProbs []float64) []float64 {
func (c *ScoreClassifier) CacheLen() int { return c.cache.len() }
// probeTokenBudget returns the token ceiling for the rendered prompt (context
// longest candidate margin), computed once via the shared lazyBudget. 0
// means trimming is off (no tokenizer/context) or impossible (candidates fill
// the context).
func (c *ScoreClassifier) probeTokenBudget() int { return c.budget.get() }
// buildScoreSystemPrompt renders the Arch-Router-style routing
// instructions: routes listed in a structured block, output schema
// declared as JSON {"route": "<name>"}. Candidates are scored as

View File

@@ -3,8 +3,10 @@ package router
import (
"context"
"errors"
"fmt"
"sort"
"strings"
"unicode/utf8"
"github.com/mudler/LocalAI/core/backend"
. "github.com/onsi/ginkgo/v2"
@@ -335,3 +337,138 @@ Reply: {"route": "<name>"}`
Expect(c.Name()).To(Equal(ClassifierScore))
})
})
var _ = Describe("ScoreClassifier conversation trimming", func() {
wordCount := func(s string) (int, error) { return len(strings.Fields(s)), nil }
threeScores := []backend.CandidateScore{
{LogProb: -0.05, NumTokens: 3},
{LogProb: -3.0, NumTokens: 3},
{LogProb: -4.0, NumTokens: 3},
}
It("drops the oldest turns when the conversation exceeds the context budget", func() {
s := &stubScorer{results: threeScores}
c := NewScoreClassifier(testPolicies(), s, ScoreClassifierOptions{
TokenCounter: wordCount,
MaxContextTokens: 10000,
})
Expect(c.probeTokenBudget()).To(BeNumerically(">", 0), "budget should be positive for a 10k context")
msgs := make([]string, 0, 200)
msgs = append(msgs, "OLDESTMARKER "+strings.Repeat("x ", 99)) // 100 words
for range 198 {
msgs = append(msgs, strings.Repeat("y ", 100))
}
msgs = append(msgs, "NEWESTMARKER "+strings.Repeat("z ", 99)) // 100 words; ~20k words total
_, err := c.Classify(context.Background(), Probe{Messages: msgs, Prompt: strings.Join(msgs, "\n")})
Expect(err).NotTo(HaveOccurred())
Expect(s.lastP).To(ContainSubstring("NEWESTMARKER"), "newest turn must survive the trim")
Expect(s.lastP).NotTo(ContainSubstring("OLDESTMARKER"), "oldest turn must be dropped")
Expect(len(strings.Fields(s.lastP))).To(BeNumerically("<", 20000), "must be trimmed, not the full transcript")
})
It("keeps the newest turn whole even when it alone exceeds the budget", func() {
s := &stubScorer{results: threeScores}
c := NewScoreClassifier(testPolicies(), s, ScoreClassifierOptions{
TokenCounter: wordCount,
MaxContextTokens: 10000,
})
msgs := []string{
"OLDMARKER short",
"NEWESTMARKER " + strings.Repeat("z ", 12000), // far over budget
}
_, err := c.Classify(context.Background(), Probe{Messages: msgs})
Expect(err).NotTo(HaveOccurred())
Expect(s.lastP).To(ContainSubstring("NEWESTMARKER"))
Expect(s.lastP).NotTo(ContainSubstring("OLDMARKER"), "older turn drops once the newest fills the budget")
})
It("does not tokenize per message and bounds what it tokenizes for a long conversation", func() {
// Regression: the original trim tokenized one message at a time,
// newest-first, so a 500-turn conversation produced hundreds of
// tokenize RPCs. The render-once design must tokenize the candidates
// (budget setup) plus a small constant for the measurement/confirm
// passes — and the rune pre-trim must keep the tokenized prompt far
// smaller than the full transcript.
calls := 0
maxRunes := 0
counting := func(s string) (int, error) {
calls++
if r := utf8.RuneCountInString(s); r > maxRunes {
maxRunes = r
}
return len(strings.Fields(s)), nil
}
s := &stubScorer{results: threeScores}
c := NewScoreClassifier(testPolicies(), s, ScoreClassifierOptions{
TokenCounter: counting,
MaxContextTokens: 4000,
})
msgs := make([]string, 500)
totalRunes := 0
for i := range msgs {
msgs[i] = fmt.Sprintf("msg%d %s", i, strings.Repeat("w ", 50))
totalRunes += utf8.RuneCountInString(msgs[i])
}
_, err := c.Classify(context.Background(), Probe{Messages: msgs})
Expect(err).NotTo(HaveOccurred())
Expect(s.lastP).To(ContainSubstring("msg499"), "newest turn must survive")
Expect(s.lastP).NotTo(ContainSubstring("msg0 "), "oldest turns must be dropped")
Expect(calls).To(BeNumerically("<", 20),
"tokenizer must not be called per message (got %d calls for 500 messages)", calls)
Expect(maxRunes).To(BeNumerically("<", totalRunes/2),
"rune pre-trim must keep the tokenized prompt well under the full transcript")
})
It("uses Probe.Prompt unchanged when no tokenizer is wired", func() {
s := &stubScorer{results: threeScores}
c := NewScoreClassifier(testPolicies(), s, ScoreClassifierOptions{})
Expect(c.probeTokenBudget()).To(Equal(0))
_, err := c.Classify(context.Background(), Probe{
Prompt: "PROMPTONLYMARKER",
Messages: []string{"ignored-because-no-tokenizer"},
})
Expect(err).NotTo(HaveOccurred())
Expect(s.lastP).To(ContainSubstring("PROMPTONLYMARKER"))
Expect(s.lastP).NotTo(ContainSubstring("ignored-because-no-tokenizer"))
})
It("disables trimming (budget 0) when the tokenizer errors", func() {
s := &stubScorer{results: threeScores}
boom := func(string) (int, error) { return 0, errors.New("tokenizer down") }
c := NewScoreClassifier(testPolicies(), s, ScoreClassifierOptions{
TokenCounter: boom,
MaxContextTokens: 10000,
})
Expect(c.probeTokenBudget()).To(Equal(0), "a tokenizer error must disable trimming, not panic")
_, err := c.Classify(context.Background(), Probe{Prompt: "FALLBACKMARKER", Messages: []string{"a", "b"}})
Expect(err).NotTo(HaveOccurred())
Expect(s.lastP).To(ContainSubstring("FALLBACKMARKER"))
})
It("retries the budget after a TRANSIENT tokenizer error instead of disabling permanently", func() {
// Regression: a sync.Once would memoize the first failure and never
// recompute. The first call (model still loading) errors; a later
// call must succeed and yield a real budget.
s := &stubScorer{results: threeScores}
calls := 0
flaky := func(text string) (int, error) {
calls++
if calls == 1 {
return 0, errors.New("model still loading")
}
return len(strings.Fields(text)), nil
}
c := NewScoreClassifier(testPolicies(), s, ScoreClassifierOptions{
TokenCounter: flaky,
MaxContextTokens: 10000,
})
Expect(c.probeTokenBudget()).To(Equal(0), "first call: tokenizer error leaves budget uncomputed")
Expect(c.probeTokenBudget()).To(BeNumerically(">", 0), "retry: budget computes once the tokenizer recovers")
})
})

View File

@@ -0,0 +1,178 @@
package router
import (
"math"
"strings"
"sync"
"sync/atomic"
"unicode/utf8"
"github.com/mudler/xlog"
)
// pretrimRunesPerToken is deliberately high (most text is 35 runes/token,
// tokenisers rarely exceed 6) so the cheap rune pre-trim keeps a superset of
// what fits before any tokenize call.
const pretrimRunesPerToken = 6
// tokenBudgetMargin absorbs BPE-boundary drift and the framing tokens a
// renderer adds, so a prompt measured at exactly the budget still fits n_ctx.
const tokenBudgetMargin = 16
// JoinTurns joins per-turn texts oldest→newest with a trailing newline each.
// The probe builder, the trimmer, and every classifier share this so the text
// a model sees has one canonical shape.
func JoinTurns(turns []string) string {
var b strings.Builder
for _, m := range turns {
b.WriteString(m)
b.WriteByte('\n')
}
return b.String()
}
// promptTrimmer fits an oldest→newest turn list into a token budget for one
// model: optimistic rune pre-trim, tokenize once, then recalibrate with the
// real runes/token and drop whole turns oldest-first until the rendered prompt
// fits. The newest turn is never dropped — if it alone overflows it's sent
// whole and the backend's n_ctx guard is the backstop.
//
// render wraps the joined turns into what the model actually tokenizes: a chat
// template for the scorer, identityRender for an embedder/reranker on raw text.
type promptTrimmer struct {
tokenize func(string) (int, error)
render func(joined string) (string, error)
budget int
}
func identityRender(s string) (string, error) { return s, nil }
func (t promptTrimmer) fit(turns []string) string {
if len(turns) == 0 {
return ""
}
kept := turns[runePretrimStart(turns, t.budget*pretrimRunesPerToken):]
joined := JoinTurns(kept)
rendered, err := t.render(joined)
if err != nil {
return joined
}
total, err := t.tokenize(rendered)
if err != nil || total <= t.budget {
return joined
}
runesPerToken := float64(utf8.RuneCountInString(rendered)) / float64(total)
if runesPerToken <= 0 {
runesPerToken = 1
}
est := total
keep := 0
for keep < len(kept)-1 && est > t.budget {
est -= int(math.Ceil(float64(utf8.RuneCountInString(kept[keep])) / runesPerToken))
keep++
}
for {
tail := JoinTurns(kept[keep:])
rendered, err := t.render(tail)
if err != nil {
return tail
}
n, err := t.tokenize(rendered)
if err != nil || n <= t.budget {
return tail
}
if keep >= len(kept)-1 {
xlog.Warn("router: newest turn alone exceeds model context; sending it whole — backend n_ctx guard is the backstop",
"tokens", n, "budget", t.budget)
return tail
}
keep++
}
}
// runePretrimStart returns the oldest index to keep so the joined tail stays
// within budgetRunes. The newest turn is always kept; older ones are added
// while they fit.
func runePretrimStart(turns []string, budgetRunes int) int {
if budgetRunes <= 0 || len(turns) == 0 {
return 0
}
start := len(turns) - 1
total := utf8.RuneCountInString(turns[start])
for i := len(turns) - 2; i >= 0; i-- {
r := utf8.RuneCountInString(turns[i])
if total+r > budgetRunes {
break
}
total += r
start = i
}
return start
}
// lazyBudget computes a model's probe token budget once, on first use, caching
// the result: maxContext minus the longest per-call extra (scorer candidates,
// reranker documents; none for a plain embed) minus tokenBudgetMargin. A
// tokenizer error leaves it uncomputed so a transient failure (model still
// loading) recovers on a later call; extras that already fill the context are
// cached as disabled.
type lazyBudget struct {
tokenize func(string) (int, error)
maxContext int
extras []string
mu sync.Mutex
value atomic.Int64 // 0=unset, >0=budget, -1=disabled
}
func (l *lazyBudget) get() int {
if l == nil || l.tokenize == nil || l.maxContext <= 0 {
return 0
}
if v := l.value.Load(); v != 0 {
if v < 0 {
return 0
}
return int(v)
}
l.mu.Lock()
defer l.mu.Unlock()
if v := l.value.Load(); v != 0 {
if v < 0 {
return 0
}
return int(v)
}
longest := 0
for _, e := range l.extras {
n, err := l.tokenize(e)
if err != nil {
return 0 // transient: leave unset so a later call retries
}
if n > longest {
longest = n
}
}
b := l.maxContext - longest - tokenBudgetMargin
if b <= 0 {
l.value.Store(-1)
return 0
}
l.value.Store(int64(b))
return b
}
// trimmedProbeText returns the text to feed a model: the most recent turns
// that fit its token budget, or p.Prompt when trimming is disabled (no
// tokenizer/context wired, or a single-input probe with no Messages).
func trimmedProbeText(p Probe, b *lazyBudget, render func(string) (string, error)) string {
if len(p.Messages) > 0 {
if budget := b.get(); budget > 0 {
return promptTrimmer{tokenize: b.tokenize, render: render, budget: budget}.fit(p.Messages)
}
}
return p.Prompt
}

View File

@@ -31,6 +31,15 @@ type Probe struct {
// is the concatenation of message contents (separated by newlines);
// for plain completions it is the raw prompt.
Prompt string
// Messages carries the per-turn texts (oldest→newest) when the probe
// came from a multi-message chat request. A classifier with a real
// tokenizer (the score classifier) uses these to trim an over-long
// conversation to the classifier model's context window on turn
// boundaries, keeping the most recent turns. Empty for single-input
// probes (plain completions, /router/decide), in which case the
// classifier falls back to Prompt verbatim.
Messages []string
}
// Decision is the classifier's output. Labels carries the SET of