mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-30 03:25:42 -04:00
Add a routing middleware stack and a cloud-proxy backend. * cloud-proxy: a Go gRPC backend that forwards OpenAI- and Anthropic-shaped chat requests to upstream providers, with an optional translate mode (OpenAI request -> Anthropic /v1/messages -> OpenAI response) and full tool-calling support. * routing: admission control, content-aware model routing (embedding cache + classifier + rerank + Arch-Router score), PII detection/redaction (regex + NER) with streaming filter and OpenAI/Anthropic adapters, and a per-user/per-key billing recorder backed by GORM or in-memory storage. * middleware: UsageMiddleware records usage via the billing recorder, plus admission, route-model, usage-stamp and trace middlewares. * observability: BackendTrace ring buffer stores full request bodies (capped), MITM proxy emits structured trace events, and router classifier decisions surface at /api/router/decide. * gallery: Arch-Router-1.5B (Q4_K_M and Q8_0). * UI: cloud-proxy model-editor fields, classifier system-prompt and score-normalization config, and a Traces page rendering request bodies. Assisted-by: claude-code:claude-opus-4-7 [Read] [Edit] [Bash] Signed-off-by: Richard Palethorpe <io@richiejp.com>
604 lines
24 KiB
Go
604 lines
24 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"hash/fnv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/mudler/LocalAI/core/backend"
|
|
"github.com/mudler/LocalAI/core/config"
|
|
"github.com/mudler/LocalAI/core/http/auth"
|
|
"github.com/mudler/LocalAI/core/schema"
|
|
"github.com/mudler/LocalAI/core/services/routing/router"
|
|
"github.com/mudler/LocalAI/core/templates"
|
|
"github.com/mudler/xlog"
|
|
"gopkg.in/yaml.v3"
|
|
)
|
|
|
|
// ScorerFactory returns a backend.Scorer bound to a named classifier
|
|
// model. The score classifier uses it to compute joint log-prob of
|
|
// every policy label against the routing prompt.
|
|
type ScorerFactory func(modelName string) backend.Scorer
|
|
|
|
// EmbedderFactory returns a backend.Embedder bound to a named model.
|
|
// Used by the L2 embedding cache. Returning nil signals "model not
|
|
// loadable" — the middleware then falls back to the uncached
|
|
// classifier so routing still happens.
|
|
type EmbedderFactory func(modelName string) backend.Embedder
|
|
|
|
// VectorStoreFactory returns a backend.VectorStore bound to a named
|
|
// collection. Each router model's cache lives in its own collection
|
|
// so two routers can't poison each other's hits.
|
|
type VectorStoreFactory func(storeName string) backend.VectorStore
|
|
|
|
// RerankerFactory returns a backend.Reranker bound to a named model.
|
|
// Used by the colbert classifier to score policy descriptions against
|
|
// the prompt via LocalAI's rerankers backend. Returning nil signals
|
|
// "model not loadable" — buildClassifier reports a config error.
|
|
type RerankerFactory func(modelName string) backend.Reranker
|
|
|
|
// ModelConfigLookup resolves a model name to its config, or nil when
|
|
// unknown. Used by buildClassifier to confirm the classifier_model
|
|
// declared the score usecase — the actual usecase-conflict check
|
|
// lives in ModelConfig.Validate() and runs at config load/save time.
|
|
type ModelConfigLookup func(modelName string) *config.ModelConfig
|
|
|
|
// ClassifierDeps bundles the backend factories the router middleware
|
|
// needs to build a classifier and its optional L2 cache. Bundled into
|
|
// one struct because RouteModel already takes many positional
|
|
// arguments — additions to the dependency surface go here instead of
|
|
// growing the signature.
|
|
//
|
|
// Embedder and VectorStore are optional: when both are non-nil and the
|
|
// router config declares an embedding_cache block, the score
|
|
// classifier is wrapped in EmbeddingCacheClassifier. Otherwise the
|
|
// score classifier runs unwrapped and the embedding-cache YAML is
|
|
// ignored with a warning.
|
|
type ClassifierDeps struct {
|
|
Scorer ScorerFactory
|
|
Embedder EmbedderFactory
|
|
VectorStore VectorStoreFactory
|
|
Reranker RerankerFactory
|
|
|
|
// ModelLookup resolves the classifier_model name to its config so
|
|
// buildClassifier can reject misconfigurations that would
|
|
// otherwise crash the llama-cpp backend at request time. Optional
|
|
// — when nil, the check is skipped (tests, embedded callers that
|
|
// haven't wired the loader).
|
|
ModelLookup ModelConfigLookup
|
|
|
|
// Registry is the shared classifier cache. Both the OpenAI and
|
|
// Anthropic routes pass the same registry so the admin stats
|
|
// endpoint sees every live classifier. Nil falls back to a local
|
|
// registry — tests that don't need cross-route stats use this.
|
|
Registry *router.Registry
|
|
|
|
// Evaluator renders the classifier model's chat template around
|
|
// the routing system + user prompt. Optional — when nil, the
|
|
// score classifier falls back to a built-in ChatML envelope,
|
|
// which is correct for Arch-Router/Qwen but wrong for non-ChatML
|
|
// routing models. Production wiring passes the app-wide
|
|
// templates.Evaluator so any model the operator points at gets
|
|
// its own chat template applied.
|
|
Evaluator *templates.Evaluator
|
|
}
|
|
|
|
// ProbeExtractor pulls the prompt content out of a parsed request so
|
|
// the classifier can inspect it without taking a dependency on the
|
|
// schema package. One extractor per request shape — wired by the
|
|
// route registration site (mirrors the piiadapter pattern).
|
|
//
|
|
// Returns ok=false when the parsed value isn't the expected type — the
|
|
// middleware then passes through without engaging the router.
|
|
type ProbeExtractor func(parsed any) (router.Probe, bool)
|
|
|
|
// RouteModel runs after SetModelAndConfig and the schema-specific
|
|
// SetXRequest, looks at the resolved model's Router config, and (when
|
|
// present) reclassifies the request to one of the candidates.
|
|
//
|
|
// The middleware:
|
|
//
|
|
// 1. Loads MODEL_CONFIG from the echo context. If nil or HasRouter()
|
|
// is false, passes through.
|
|
// 2. Extracts the probe via the supplied ProbeExtractor.
|
|
// 3. Invokes the classifier matching cfg.Router.Classifier
|
|
// ("score" or "colbert"). If the classifier can't be built —
|
|
// missing classifier_model, misconfigured policies, etc. — the
|
|
// request fails with 503. cfg.Router.Fallback only catches
|
|
// Classify-time errors and label-coverage misses, not config
|
|
// bugs that would otherwise be silent.
|
|
// 4. Resolves the chosen candidate to its model name. Reloads the
|
|
// ModelConfig for that model and asserts depth-1 (the candidate
|
|
// must NOT itself have a Router). Violation returns 500 — config
|
|
// bug, not a request bug.
|
|
// 5. Updates input.Model in place, replaces MODEL_CONFIG with the
|
|
// candidate's config, and stamps RequestedModel/ServedModel on the
|
|
// context so UsageMiddleware records the routing.
|
|
// 6. Writes a DecisionRecord to the store for the admin page.
|
|
//
|
|
// store may be nil when --disable-stats turns off the routing log;
|
|
// classification still runs.
|
|
//
|
|
// Composition with SmartRouter (distributed mode): this middleware
|
|
// only does *model* selection. Node selection still happens in
|
|
// SmartRouter.Route() downstream of this middleware.
|
|
// RouteModel wires the router middleware. source is the value written to
|
|
// DecisionRecord.Source (router.SourceChat / SourceAnthropic / ...) so
|
|
// the admin page can split decisions by entry point. Pass
|
|
// router.SourceChat for the OpenAI chat endpoint, router.SourceAnthropic
|
|
// for the Anthropic messages endpoint.
|
|
func RouteModel(loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig, store router.DecisionStore, fallbackUser *auth.User, extractor ProbeExtractor, source string, deps ClassifierDeps) echo.MiddlewareFunc {
|
|
registry := deps.Registry
|
|
if registry == nil {
|
|
registry = router.NewRegistry()
|
|
}
|
|
candidateLoader := func(name string) (*config.ModelConfig, error) {
|
|
return loader.LoadModelConfigFileByNameDefaultOptions(name, appConfig)
|
|
}
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
|
if !ok || cfg == nil || !cfg.HasRouter() {
|
|
return next(c)
|
|
}
|
|
|
|
parsed := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST)
|
|
if parsed == nil {
|
|
return next(c)
|
|
}
|
|
|
|
probe, probeOK := extractor(parsed)
|
|
if !probeOK {
|
|
return next(c)
|
|
}
|
|
|
|
classifier, err := GetOrBuildClassifier(registry, cfg, deps)
|
|
if err != nil {
|
|
// Build-time failures are config bugs (missing
|
|
// classifier_model, undeclared usecase, policy
|
|
// validation, ...). Silently falling back would hide
|
|
// them and make the router look "working" while the
|
|
// classifier model is never invoked — surface as 503
|
|
// with the underlying reason so operators see it.
|
|
xlog.Warn("router: classifier build failed",
|
|
"router_model", cfg.Name, "classifier", cfg.Router.Classifier, "error", err)
|
|
return echo.NewHTTPError(503, "router classifier unavailable: "+err.Error())
|
|
}
|
|
|
|
result, err := router.Resolve(c.Request().Context(), cfg, classifier, candidateLoader, probe)
|
|
if err != nil {
|
|
xlog.Warn("router: resolve failed", "router_model", cfg.Name, "error", err)
|
|
return echo.NewHTTPError(500, err.Error())
|
|
}
|
|
|
|
if req, ok := parsed.(schema.LocalAIRequest); ok {
|
|
chosen := result.ChosenModel
|
|
req.ModelName(&chosen)
|
|
}
|
|
|
|
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, result.ChosenConfig)
|
|
c.Set(ContextKeyRequestedModel, result.RouterModel)
|
|
c.Set(ContextKeyServedModel, result.ChosenModel)
|
|
|
|
if store != nil {
|
|
recordHTTPDecision(c, store, result, fallbackUser, source)
|
|
}
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
// recordHTTPDecision writes the resolved routing decision to store, together
// with HTTP-derived audit metadata:
//   - correlation id: the ContextKeyCorrelationID context value, falling
//     back to the X-Correlation-ID response header;
//   - user id: the authenticated user from the auth middleware, falling back
//     to fallbackUser (the synthetic local user) when none is attached.
//
// Realtime has its own recorder that supplies session-derived metadata
// instead.
//
// Recording is best-effort. A store failure is logged but never fails the
// request. The write uses context.Background() rather than the request
// context, so cancellation of the request does not cancel the audit write.
func recordHTTPDecision(c echo.Context, store router.DecisionStore, result *router.ResolveResult, fallbackUser *auth.User, source string) {
	correlationID, _ := c.Get(ContextKeyCorrelationID).(string)
	if correlationID == "" {
		correlationID = c.Response().Header().Get("X-Correlation-ID")
	}

	userID := ""
	if u := auth.GetUser(c); u != nil {
		userID = u.ID
	} else if fallbackUser != nil {
		userID = fallbackUser.ID
	}

	record := result.ToDecisionRecord(newDecisionID(), correlationID, userID, source)
	if err := store.Record(context.Background(), record); err != nil {
		// Previously discarded with `_ =`; a broken decision store would
		// silently empty the admin routing log. Log it, but keep serving.
		xlog.Warn("router: decision record failed",
			"router_model", result.RouterModel, "served_model", result.ChosenModel,
			"correlation_id", correlationID, "error", err)
	}
}
|
|
|
|
|
|
// GetOrBuildClassifier looks up a built Classifier for the named router
|
|
// model in the registry and builds it on miss. Exported so the
|
|
// /api/router/decide decision-oracle endpoint can share the same
|
|
// build-once cache that the in-band RouteModel middleware uses.
|
|
func GetOrBuildClassifier(registry *router.Registry, cfg *config.ModelConfig, deps ClassifierDeps) (router.Classifier, error) {
|
|
// Fingerprint folds the classifier model's renderer-affecting
|
|
// fields (chat templates + stopwords) in alongside the router
|
|
// config. Without this, hot-reloading the classifier model's
|
|
// YAML (via ReloadModelsEndpoint, /import-model, or the MCP
|
|
// reload_models tool) wouldn't rebuild the cached classifier —
|
|
// the candidates slice and renderer closure are baked at build
|
|
// time from those fields and would silently keep the stale
|
|
// stop token / template until process restart.
|
|
var classifierCfg *config.ModelConfig
|
|
if deps.ModelLookup != nil {
|
|
classifierCfg = deps.ModelLookup(cfg.Router.ClassifierModel)
|
|
}
|
|
fp := routerConfigFingerprint(cfg.Router, classifierCfg)
|
|
if cached, ok := registry.Get(cfg.Name, fp); ok {
|
|
return cached, nil
|
|
}
|
|
c, err := buildClassifier(cfg, deps)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
registry.Put(cfg.Name, fp, c)
|
|
return c, nil
|
|
}
|
|
|
|
// routerConfigFingerprint derives a cache key for the (router config,
// classifier model config) pair. It is a 64-bit FNV-1a hash, intended only
// for equality checks and not for security.
//
// The router block is hashed via its YAML encoding, so fields added to
// RouterConfig later are covered automatically. From the classifier model,
// only TemplateConfig.Chat, TemplateConfig.ChatMessage and StopWords are
// hashed: the prompt renderer and stop-token selection read those fields.
// Unrelated edits (parameters, files, ...) therefore don't invalidate the
// cache. A NUL byte precedes the model fields and follows each one, so
// adjacent fields cannot run together.
//
// A nil classifierCfg (no lookup wired) yields a router-only fingerprint.
func routerConfigFingerprint(rc config.RouterConfig, classifierCfg *config.ModelConfig) uint64 {
	encoded, err := yaml.Marshal(rc)
	if err != nil {
		// Not expected for a plain value type. Return a per-call value so
		// distinct configs never end up sharing a cache entry by accident.
		return uint64(time.Now().UnixNano())
	}

	hasher := fnv.New64a()
	hasher.Write(encoded)
	if classifierCfg == nil {
		return hasher.Sum64()
	}

	writeField := func(s string) {
		hasher.Write([]byte(s))
		hasher.Write([]byte{0})
	}
	hasher.Write([]byte{0}) // separates the router YAML from the model fields
	writeField(classifierCfg.TemplateConfig.Chat)
	writeField(classifierCfg.TemplateConfig.ChatMessage)
	for _, stop := range classifierCfg.StopWords {
		writeField(stop)
	}
	return hasher.Sum64()
}
|
|
|
|
// buildClassifier constructs, without caching, the classifier described by
// cfg.Router. GetOrBuildClassifier is the cached entry point.
//
// An empty classifier name defaults to router.ClassifierScore. Unknown names
// are rejected before any other validation, so a typo is reported as an
// unknown classifier rather than as whichever policy problem happens to be
// checked first. The shared policy/candidate invariants are then checked by
// validateRouterPolicies.
//
//   - score: requires deps.Scorer. When deps.ModelLookup is wired, the
//     classifier_model must declare the score usecase. If the model's config
//     can be looked up, the prompt renderer (when deps.Evaluator is set) and
//     stop token come from it. Otherwise the score classifier's built-in
//     ChatML defaults apply.
//   - colbert: requires deps.Reranker.
//
// A ClassifierCacheSize of 0 selects the default capacity of 1024.
//
// When cfg.Router.EmbeddingCache is set, the classifier is wrapped in an
// embedding cache. Failures while setting up that cache are logged and the
// uncached classifier is returned, so caching problems never break routing.
func buildClassifier(cfg *config.ModelConfig, deps ClassifierDeps) (router.Classifier, error) {
	rc := cfg.Router
	name := rc.Classifier
	if name == "" {
		name = router.ClassifierScore
	}
	// Check the name first. Previously an unknown name was only reported
	// after policy validation, so e.g. a typo'd classifier with a missing
	// policy surfaced as a policy error and hid the real mistake.
	if name != router.ClassifierScore && name != router.ClassifierColbert {
		return nil, fmt.Errorf("router: unknown classifier %q (supported: %s)", name, strings.Join([]string{router.ClassifierScore, router.ClassifierColbert}, ", "))
	}

	policies, err := validateRouterPolicies(name, rc)
	if err != nil {
		return nil, err
	}
	cacheCap := rc.ClassifierCacheSize
	if cacheCap == 0 {
		cacheCap = 1024
	}

	var inner router.Classifier
	switch name {
	case router.ClassifierScore:
		if deps.Scorer == nil {
			return nil, fmt.Errorf("router classifier score unavailable: no scorer factory wired")
		}
		if err := assertClassifierDeclaresScore(rc.ClassifierModel, deps.ModelLookup); err != nil {
			return nil, err
		}
		scorer := deps.Scorer(rc.ClassifierModel)
		if scorer == nil {
			return nil, fmt.Errorf("router classifier score: classifier_model %q not loadable", rc.ClassifierModel)
		}
		opts := router.ScoreClassifierOptions{
			CacheCap:             cacheCap,
			ActivationThreshold:  rc.ActivationThreshold,
			Normalization:        rc.ScoreNormalization,
			SystemPromptTemplate: rc.ClassifierSystemTemplate,
		}
		// Derive the prompt renderer and stop token from the classifier
		// model's own config when it can be looked up. Without ModelLookup
		// (tests, embedded callers) the score classifier's built-in ChatML
		// defaults apply, which is correct for Arch-Router.
		if deps.ModelLookup != nil {
			if classifierCfg := deps.ModelLookup(rc.ClassifierModel); classifierCfg != nil {
				if deps.Evaluator != nil {
					opts.PromptRenderer = newTemplateRenderer(deps.Evaluator, classifierCfg)
				}
				if st := pickAssistantTurnEnd(classifierCfg.StopWords, classifierCfg.TemplateConfig.ChatMessage); st != "" {
					opts.StopToken = st
				}
			}
		}
		inner = router.NewScoreClassifier(policies, scorer, opts)
	case router.ClassifierColbert:
		if deps.Reranker == nil {
			return nil, fmt.Errorf("router classifier colbert unavailable: no reranker factory wired")
		}
		reranker := deps.Reranker(rc.ClassifierModel)
		if reranker == nil {
			return nil, fmt.Errorf("router classifier colbert: classifier_model %q not loadable", rc.ClassifierModel)
		}
		inner = router.NewRerankClassifier(policies, reranker, cacheCap, rc.ActivationThreshold)
	}
	// name was validated above, so exactly one case ran and inner is set.

	if rc.EmbeddingCache == nil {
		return inner, nil
	}
	wrapped, err := wrapWithEmbeddingCache(cfg, inner, deps)
	if err != nil {
		// Cache plumbing problems must not break routing. Log why the cache
		// layer is missing and serve the uncached classifier.
		xlog.Warn("router: embedding cache disabled",
			"router_model", cfg.Name, "error", err)
		return inner, nil
	}
	return wrapped, nil
}
|
|
|
|
// assertClassifierDeclaresScore blocks the score classifier from binding to
// a classifier_model whose config does not declare FLAG_SCORE.
//
// The usecase-conflict check (score combined with chat/completion/embeddings
// on llama-cpp) is described as living in ModelConfig.Validate(), which runs
// at config load and save time. This guard covers the gap that check cannot
// see: a model that never declared the score usecase at all. That could be,
// for example, a chat model the operator pointed at by mistake.
//
// It returns nil when lookup is nil (test wiring). In that case the C++
// backend's runtime tripwire is the remaining safeguard. It also returns nil
// when the model is unknown, leaving the Scorer factory to report a clearer
// "not loadable" error.
func assertClassifierDeclaresScore(classifierModel string, lookup ModelConfigLookup) error {
	if lookup == nil {
		return nil
	}

	modelCfg := lookup(classifierModel)
	switch {
	case modelCfg == nil:
		return nil
	case modelCfg.HasUsecases(config.FLAG_SCORE):
		return nil
	}

	return fmt.Errorf(
		"router classifier score: classifier_model %q does not declare the "+
			"score usecase. Add `known_usecases: [score]` to its config so "+
			"the loader can reject conflicting usecase combinations",
		classifierModel)
}
|
|
|
|
// validateRouterPolicies enforces the invariants shared by the score and
// colbert classifiers and converts rc.Policies into []router.ScorePolicy,
// the policy shape both classifiers consume.
//
// Checks run in this order, and the first failure is returned:
//   - classifier_model is set and at least one policy exists;
//   - every policy has a label and a description;
//   - every candidate has a model and at least one label, and each of its
//     labels names a declared policy.
func validateRouterPolicies(classifierName string, rc config.RouterConfig) ([]router.ScorePolicy, error) {
	switch {
	case rc.ClassifierModel == "":
		return nil, fmt.Errorf("router classifier %s requires classifier_model", classifierName)
	case len(rc.Policies) == 0:
		return nil, fmt.Errorf("router classifier %s requires at least one policy", classifierName)
	}

	policies := make([]router.ScorePolicy, 0, len(rc.Policies))
	declared := make(map[string]struct{}, len(rc.Policies))
	for _, p := range rc.Policies {
		switch {
		case p.Label == "":
			return nil, fmt.Errorf("router classifier %s: policy with empty label", classifierName)
		case p.Description == "":
			return nil, fmt.Errorf("router classifier %s: policy %q has no description", classifierName, p.Label)
		}
		sp := router.ScorePolicy{Label: p.Label, Description: p.Description}
		policies = append(policies, sp)
		declared[sp.Label] = struct{}{}
	}

	for _, cand := range rc.Candidates {
		if cand.Model == "" {
			return nil, fmt.Errorf("router classifier %s: candidate has empty model field", classifierName)
		}
		if len(cand.Labels) == 0 {
			return nil, fmt.Errorf("router classifier %s: candidate %q has no labels", classifierName, cand.Model)
		}
		for _, label := range cand.Labels {
			if _, known := declared[label]; !known {
				return nil, fmt.Errorf("router classifier %s: candidate %q references unknown label %q (not in policies)", classifierName, cand.Model, label)
			}
		}
	}
	return policies, nil
}
|
|
|
|
// newTemplateRenderer turns the classifier model's chat templates into a
// router.PromptRenderer. The routing system prompt and the user prompt are
// sent as a two-message conversation through eval.TemplateMessages. Each
// role is formatted by TemplateConfig.ChatMessage and the result is wrapped
// by TemplateConfig.Chat. Non-ChatML routing models therefore render
// correctly without the router package knowing their template format.
//
// TemplateMessages is used instead of EvaluateTemplateForPrompt because the
// outer Chat templates in the gallery typically reference only {{.Input}},
// not {{.SystemPrompt}}. A system prompt passed via .SystemPrompt would be
// dropped silently, since text/template ignores fields a template never
// references. Rendering per message builds the .Input the outer template
// actually reads.
//
// It returns nil when either Chat or ChatMessage is empty, because partial
// templating would still lose content. The score classifier then uses its
// ChatML fallback. The returned renderer reports an error when rendering
// produces an empty string.
func newTemplateRenderer(eval *templates.Evaluator, classifierCfg *config.ModelConfig) router.PromptRenderer {
	tpl := classifierCfg.TemplateConfig
	if tpl.Chat == "" || tpl.ChatMessage == "" {
		return nil
	}

	// The renderer keeps its own copy of the config rather than aliasing
	// *classifierCfg.
	snapshot := *classifierCfg
	return func(system, user string) (string, error) {
		conversation := []schema.Message{
			{Role: "system", StringContent: system},
			{Role: "user", StringContent: user},
		}
		out := eval.TemplateMessages(schema.OpenAIRequest{}, conversation, &snapshot, nil, false)
		if out != "" {
			return out, nil
		}
		return "", fmt.Errorf("router: classifier %q chat template produced empty output", snapshot.Name)
	}
}
|
|
|
|
// pickAssistantTurnEnd chooses the classifier model's assistant turn-end
// token. Candidates are suffixed with this token so that the model's
// end-of-turn signal is included in each candidate's joint log-prob.
//
// Selection, scanning words in order and ignoring empty entries:
//  1. The first stop word that appears literally in chatMessageTemplate.
//     The per-message template is the source of truth for how a turn ends
//     (e.g. "<|im_end|>" for ChatML, "<|eot_id|>" for Llama-3). This also
//     handles configs that defensively list "<|im_end|>" ahead of the real
//     "<|eot_id|>".
//  2. Otherwise, the first non-empty stop word. This is correct for
//     conventionally ordered configs.
//  3. Otherwise "". The score classifier then uses its own default stop
//     token.
func pickAssistantTurnEnd(words []string, chatMessageTemplate string) string {
	fallback := ""
	for _, word := range words {
		if word == "" {
			continue
		}
		if chatMessageTemplate != "" && strings.Contains(chatMessageTemplate, word) {
			return word
		}
		if fallback == "" {
			fallback = word
		}
	}
	return fallback
}
|
|
|
|
// wrapWithEmbeddingCache decorates inner with the L2 embedding cache
// described by cfg.Router.EmbeddingCache. The caller must ensure that block
// is non-nil. The cache collection is EmbeddingCache.StoreName, or
// "router-cache-<router model name>" when that is empty.
//
// It returns an error, and no classifier, when embedding_model is unset,
// either cache factory is not wired, or the embedder or vector store cannot
// be loaded. buildClassifier treats any such error as "serve uncached".
func wrapWithEmbeddingCache(cfg *config.ModelConfig, inner router.Classifier, deps ClassifierDeps) (router.Classifier, error) {
	cacheCfg := cfg.Router.EmbeddingCache
	switch {
	case cacheCfg.EmbeddingModel == "":
		return nil, fmt.Errorf("embedding_cache requires embedding_model")
	case deps.Embedder == nil || deps.VectorStore == nil:
		return nil, fmt.Errorf("embedding cache factories not wired")
	}

	embedder := deps.Embedder(cacheCfg.EmbeddingModel)
	if embedder == nil {
		return nil, fmt.Errorf("embedding_model %q not loadable", cacheCfg.EmbeddingModel)
	}

	collection := cacheCfg.StoreName
	if collection == "" {
		collection = "router-cache-" + cfg.Name
	}
	vstore := deps.VectorStore(collection)
	if vstore == nil {
		return nil, fmt.Errorf("vector store %q not loadable", collection)
	}

	return router.NewEmbeddingCacheClassifier(inner, embedder, vstore, cacheCfg.SimilarityThreshold, cacheCfg.ConfidenceThreshold), nil
}
|
|
|
|
// newDecisionID returns a random routing-decision identifier: "rd_"
// followed by 24 lowercase hex characters (12 bytes from crypto/rand).
func newDecisionID() string {
	const randomBytes = 12
	raw := make([]byte, randomBytes)
	// The error is deliberately discarded. crypto/rand.Read is not expected
	// to fail on supported platforms, and from Go 1.24 it never returns a
	// non-nil error.
	_, _ = rand.Read(raw)
	return "rd_" + hex.EncodeToString(raw)
}
|
|
|
|
// OpenAIProbe extracts a router.Probe from a parsed *schema.OpenAIRequest.
|
|
// Concatenates message contents (string-form or text blocks of the
|
|
// structured `[]any` content) so the classifier sees a single corpus
|
|
// for length and content-shape rules. Image blocks are skipped — a
|
|
// future multimodal classifier can take a different route.
|
|
func OpenAIProbe(parsed any) (router.Probe, bool) {
|
|
req, ok := parsed.(*schema.OpenAIRequest)
|
|
if !ok || req == nil {
|
|
return router.Probe{}, false
|
|
}
|
|
return OpenAIProbeFromRequest(req), true
|
|
}
|
|
|
|
// OpenAIProbeFromRequest builds a router.Probe directly from a typed
// request. Realtime and other non-HTTP callers use it to feed
// router.Resolve without going through an echo.Context.
//
// Every message contributes in order. A string content is appended as-is.
// For a []any content, each map block whose "type" is "text" and whose
// "text" is a string is appended. Each appended piece is followed by a
// newline. A nil request yields an empty probe.
func OpenAIProbeFromRequest(req *schema.OpenAIRequest) router.Probe {
	if req == nil {
		return router.Probe{}
	}

	var corpus strings.Builder
	appendLine := func(s string) {
		corpus.WriteString(s)
		corpus.WriteByte('\n')
	}

	for i := range req.Messages {
		content := req.Messages[i].Content
		if text, isString := content.(string); isString {
			appendLine(text)
			continue
		}
		blocks, isList := content.([]any)
		if !isList {
			continue
		}
		for _, block := range blocks {
			part, isMap := block.(map[string]any)
			if !isMap || part["type"] != "text" {
				continue
			}
			if text, isString := part["text"].(string); isString {
				appendLine(text)
			}
		}
	}
	return router.Probe{Prompt: corpus.String()}
}
|
|
|
|
// AnthropicProbe is the ProbeExtractor for Anthropic-shaped requests,
// mirroring OpenAIProbe. It accepts a non-nil *schema.AnthropicRequest and
// joins, newline-terminated, each message's string content and the "text"
// of its structured ([]any) content blocks whose "type" is "text". Any other
// value yields ok=false.
func AnthropicProbe(parsed any) (router.Probe, bool) {
	req, isAnthropic := parsed.(*schema.AnthropicRequest)
	if !isAnthropic || req == nil {
		return router.Probe{}, false
	}

	var prompt strings.Builder
	for idx := range req.Messages {
		raw := req.Messages[idx].Content
		if s, isString := raw.(string); isString {
			prompt.WriteString(s)
			prompt.WriteByte('\n')
			continue
		}
		parts, isList := raw.([]any)
		if !isList {
			continue
		}
		for _, item := range parts {
			fields, isMap := item.(map[string]any)
			if !isMap || fields["type"] != "text" {
				continue
			}
			if s, isString := fields["text"].(string); isString {
				prompt.WriteString(s)
				prompt.WriteByte('\n')
			}
		}
	}

	return router.Probe{
		Prompt: prompt.String(),
	}, true
}
|
|
|