Files
LocalAI/core/config/model_config.go
Richard Palethorpe 6a80e23733 feat(middleware): Model routing, PII filtering, Cloud model proxies (#9802)
Add a routing middleware stack and a cloud-proxy backend.

* cloud-proxy: a Go gRPC backend that forwards OpenAI- and
  Anthropic-shaped chat requests to upstream providers, with an
  optional translate mode (OpenAI request -> Anthropic /v1/messages
  -> OpenAI response) and full tool-calling support.

* routing: admission control, content-aware model routing
  (embedding cache + classifier + rerank + Arch-Router score),
  PII detection/redaction (regex + NER) with streaming filter and
  OpenAI/Anthropic adapters, and a per-user/per-key billing recorder
  backed by GORM or in-memory storage.

* middleware: UsageMiddleware records usage via the billing recorder,
  plus admission, route-model, usage-stamp and trace middlewares.

* observability: BackendTrace ring buffer stores full request bodies
  (capped), MITM proxy emits structured trace events, and router
  classifier decisions surface at /api/router/decide.

* gallery: Arch-Router-1.5B (Q4_K_M and Q8_0).

* UI: cloud-proxy model-editor fields, classifier system-prompt and
  score-normalization config, and a Traces page rendering request
  bodies.

Assisted-by: claude-code:claude-opus-4-7 [Read] [Edit] [Bash]

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-05-25 09:28:27 +02:00

1370 lines
55 KiB
Go

package config
import (
"encoding/json"
"fmt"
"os"
"regexp"
"slices"
"strings"
"text/template"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/reasoning"
"github.com/mudler/cogito"
"gopkg.in/yaml.v3"
)
// RAND_SEED is the sentinel seed value -1.
// NOTE(review): presumably means "choose a random seed" — confirm against
// the code that consumes it; no use is visible in this part of the file.
const (
RAND_SEED = -1
)
// @Description TTS configuration
//
// TTSConfig is embedded in ModelConfig under the "tts" key.
type TTSConfig struct {
// Voice wav path or id
Voice string `yaml:"voice,omitempty" json:"voice,omitempty"`
// AudioPath is read from the "audio_path" key.
// NOTE(review): how the path is consumed is not visible here — confirm
// with the TTS backend code.
AudioPath string `yaml:"audio_path,omitempty" json:"audio_path,omitempty"`
}
// @Description ModelConfig represents a model configuration
//
// ModelConfig is decoded from YAML through UnmarshalYAML and also carries
// JSON tags. Fields tagged "-" are excluded from both formats and only
// exist at runtime.
type ModelConfig struct {
// Unexported bookkeeping fields; never serialised.
modelConfigFile string `yaml:"-" json:"-"`
modelTemplate string `yaml:"-" json:"-"`
// Embedded prediction options, serialised under the "parameters" key.
schema.PredictionOptions `yaml:"parameters,omitempty" json:"parameters,omitempty"`
Name string `yaml:"name,omitempty" json:"name,omitempty"`
F16 *bool `yaml:"f16,omitempty" json:"f16,omitempty"`
Threads *int `yaml:"threads,omitempty" json:"threads,omitempty"`
Debug *bool `yaml:"debug,omitempty" json:"debug,omitempty"`
Roles map[string]string `yaml:"roles,omitempty" json:"roles,omitempty"`
Embeddings *bool `yaml:"embeddings,omitempty" json:"embeddings,omitempty"`
// Backend names the backend serving this model. The value "cloud-proxy"
// is special-cased by IsCloudProxyBackendPassthrough and PIIIsEnabled.
Backend string `yaml:"backend,omitempty" json:"backend,omitempty"`
TemplateConfig TemplateConfig `yaml:"template,omitempty" json:"template,omitempty"`
// KnownUsecaseStrings is the serialised form ("known_usecases").
// KnownUsecases is derived from it by syncKnownUsecasesFromString, which
// also rewrites KnownUsecaseStrings with the recognised names.
KnownUsecaseStrings []string `yaml:"known_usecases,omitempty" json:"known_usecases,omitempty"`
KnownUsecases *ModelConfigUsecase `yaml:"-" json:"-"`
Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`
// Runtime-only request state; not serialised.
PromptStrings, InputStrings []string `yaml:"-" json:"-"`
InputToken [][]int `yaml:"-" json:"-"`
// Written by SetFunctionCallString / SetFunctionCallNameString and read by
// ShouldUseFunctions / ShouldCallSpecificFunction.
functionCallString, functionCallNameString string `yaml:"-" json:"-"`
ResponseFormat string `yaml:"-" json:"-"`
ResponseFormatMap map[string]any `yaml:"-" json:"-"`
// MediaMarker is the runtime-discovered multimodal marker the backend expects
// in the prompt (e.g. "<__media__>" or a random "<__media_<rand>__>" picked by
// llama.cpp). Populated on first successful ModelMetadata call. Empty until
// then — callers must fall back to templates.DefaultMultiMediaMarker.
MediaMarker string `yaml:"-" json:"-"`
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
// LLM configs (GPT4ALL, Llama.cpp, ...)
// Embedded and inlined: its keys appear at the top level of the document.
LLMConfig `yaml:",inline" json:",inline"`
// Diffusers
Diffusers Diffusers `yaml:"diffusers,omitempty" json:"diffusers,omitempty"`
Step int `yaml:"step,omitempty" json:"step,omitempty"`
// GRPC Options
GRPC GRPC `yaml:"grpc,omitempty" json:"grpc,omitempty"`
// TTS specifics
// Embedded, but serialised as a nested block under the "tts" key.
TTSConfig `yaml:"tts,omitempty" json:"tts,omitempty"`
// CUDA
// Explicitly enable CUDA or not (some backends might need it)
CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"`
DownloadFiles []File `yaml:"download_files,omitempty" json:"download_files,omitempty"`
Description string `yaml:"description,omitempty" json:"description,omitempty"`
Usage string `yaml:"usage,omitempty" json:"usage,omitempty"`
// Disabled and Pinned are pointers so an absent key (nil) can be told
// apart from an explicit false.
Disabled *bool `yaml:"disabled,omitempty" json:"disabled,omitempty"`
Pinned *bool `yaml:"pinned,omitempty" json:"pinned,omitempty"`
// ConcurrencyGroups declares per-node mutual-exclusion groups: the model
// cannot be loaded alongside another model that shares any group name.
// See docs/content/advanced/vram-management.md for usage.
ConcurrencyGroups []string `yaml:"concurrency_groups,omitempty" json:"concurrency_groups,omitempty"`
Options []string `yaml:"options,omitempty" json:"options,omitempty"`
Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`
MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"`
Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"`
// Routing-middleware and cloud-proxy settings; each block is documented
// on its own type.
PII PIIConfig `yaml:"pii,omitempty" json:"pii,omitempty"`
Router RouterConfig `yaml:"router,omitempty" json:"router,omitempty"`
Proxy ProxyConfig `yaml:"proxy,omitempty" json:"proxy,omitempty"`
MITM MITMModelConfig `yaml:"mitm,omitempty" json:"mitm,omitempty"`
Limits LimitsConfig `yaml:"limits,omitempty" json:"limits,omitempty"`
}
// @Description Admission-control limits applied per request. The
// admission middleware enforces these before invoking the handler;
// requests that exceed a limit get 503 with a Retry-After hint so
// clients back off rather than pile on. Per-model so cloud passthroughs
// can have a stricter ceiling than local models.
//
// LimitsConfig is the type of ModelConfig.Limits (key "limits"). Both
// fields are plain ints with omitempty, so an explicit 0 and an absent key
// are indistinguishable after decoding.
type LimitsConfig struct {
// MaxConcurrent caps simultaneous in-flight requests for this
// model. 0 = unlimited (default). Useful for cloud-passthrough
// configs where the upstream rate-limits aggressively, or for
// local backends whose memory budget tops out before LocalAI's
// queue depth would.
MaxConcurrent int `yaml:"max_concurrent,omitempty" json:"max_concurrent,omitempty"`
// RetryAfterSeconds advises clients how long to wait before
// retrying when admission rejects. 0 defaults to 1s — enough to
// let an in-flight request finish on a busy local model. The
// value is sent verbatim in the Retry-After response header.
RetryAfterSeconds int `yaml:"retry_after_seconds,omitempty" json:"retry_after_seconds,omitempty"`
}
// @Description MITM intercept binding for the model. When the cloudproxy
// MITM listener is enabled and any host listed here appears in a CONNECT,
// the proxy uses THIS model config's pii: settings to filter the
// intercepted body. Strict 1-to-1: a host claimed by two configs is a
// configuration error and disables the MITM listener until resolved.
//
// Lets an admin pair a host (api.anthropic.com) with the model's
// PII overrides without maintaining a parallel per-host map.
//
// MITMModelConfig is the type of ModelConfig.MITM (key "mitm").
// NOTE(review): the uniqueness rule above is not enforced by this type;
// presumably the config loader or the MITM listener checks it — confirm.
type MITMModelConfig struct {
// Hosts is the list of hostnames this model claims for MITM
// interception. Each entry must be unique across all model configs.
Hosts []string `yaml:"hosts,omitempty" json:"hosts,omitempty"`
}
// @Description Cloud proxy configuration. The cloud-proxy backend
// forwards a model's traffic to an external provider. Two modes:
//
// - mode: passthrough — client and upstream must speak the same wire
// format; the backend ships the raw request body to the upstream
// URL and streams the response back untouched. The streaming PII
// filter still runs because it operates on extracted token text.
//
// - mode: translate — the backend converts LocalAI's internal proto
// to the provider's wire format and back. Unlocks cross-provider
// routing (OpenAI client → Anthropic upstream, etc.) at the cost
// of dropping provider-specific extensions that the internal proto
// doesn't model.
//
// ProxyConfig is the type of ModelConfig.Proxy (key "proxy"). Mode and
// Provider are free-form strings in this type; the values the package names
// are the ProxyMode* and ProxyProvider* constants. Validation of the fields
// (required URL, mutually exclusive key sources) is not part of this type.
type ProxyConfig struct {
// UpstreamURL is the full POST endpoint, e.g.
// https://api.openai.com/v1/chat/completions or
// https://api.anthropic.com/v1/messages. Required.
UpstreamURL string `yaml:"upstream_url,omitempty" json:"upstream_url,omitempty"`
// Mode selects passthrough (wire-perfect) or translate (full
// control via internal proto). Empty defaults to passthrough.
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
// Provider identifies the upstream's wire format for translate
// mode (openai, anthropic). Ignored in passthrough mode — the
// wire format there is whatever the client sent.
Provider string `yaml:"provider,omitempty" json:"provider,omitempty"`
// APIKeyEnv names the environment variable holding the upstream
// API key. Mutually exclusive with APIKeyFile. Both empty is
// allowed (no-auth upstreams).
APIKeyEnv string `yaml:"api_key_env,omitempty" json:"api_key_env,omitempty"`
// APIKeyFile is a path to a file whose contents are the upstream
// API key. Trailing whitespace is trimmed. Mutually exclusive
// with APIKeyEnv. The integration point for K8s secret mounts,
// Vault agent files, and similar external-secret workflows.
APIKeyFile string `yaml:"api_key_file,omitempty" json:"api_key_file,omitempty"`
// UpstreamModel overrides the model name sent to the upstream.
// Useful when the LocalAI-facing model alias differs from the
// upstream's canonical name (e.g. local "claude-strict" maps to
// upstream "claude-3-5-sonnet-20241022"). Empty means forward
// the client's model field unchanged.
UpstreamModel string `yaml:"upstream_model,omitempty" json:"upstream_model,omitempty"`
// RequestTimeoutSeconds caps the upstream request duration. 0
// means no per-request timeout (only the request context, which
// is bound to the client connection, applies).
RequestTimeoutSeconds int `yaml:"request_timeout_seconds,omitempty" json:"request_timeout_seconds,omitempty"`
}
// Proxy mode names. Validate() normalises an empty Mode to
// ProxyModePassthrough so downstream code only sees concrete values.
// NOTE(review): the doc comment on IsCloudProxyBackendPassthrough credits
// SetDefaults with this normalisation instead — confirm which is accurate.
// These are the values compared against ProxyConfig.Mode.
const (
ProxyModePassthrough = "passthrough"
ProxyModeTranslate = "translate"
)
// Proxy provider names. Only meaningful in translate mode, where the
// cloud-proxy backend picks the wire format to use against the
// upstream URL.
// These are the values intended for ProxyConfig.Provider.
const (
ProxyProviderOpenAI = "openai"
ProxyProviderAnthropic = "anthropic"
)
// IsCloudProxyBackendPassthrough reports whether the model is served by the
// "cloud-proxy" backend with the proxy in passthrough mode. An empty
// Proxy.Mode is treated the same as ProxyModePassthrough, so the answer does
// not depend on whether mode normalisation (SetDefaults / Validate) has
// already run on this config.
func (c *ModelConfig) IsCloudProxyBackendPassthrough() bool {
	if c.Backend != "cloud-proxy" {
		return false
	}
	switch c.Proxy.Mode {
	case "", ProxyModePassthrough:
		return true
	default:
		return false
	}
}
// @Description Intelligent routing configuration. When a model declares
// a Router block, requests addressed to it are reclassified at runtime
// and dispatched to one of the named candidates. The router rewrites
// input.Model in-place, then the standard model-resolution path picks
// up the resolved config — meaning ACL checks, disabled-state, and
// per-model PII still run against the chosen target.
//
// Depth-1 invariant: candidates must NOT themselves carry a Router
// block. The router's "smart-router → claude-strict → cloud-proxy"
// chain is fine, but "router-A → router-B → claude" is rejected at
// config load to keep the dispatch graph acyclic and predictable. The
// middleware also asserts depth ≤ 1 at runtime as a defensive check.
//
// RouterConfig is the type of ModelConfig.Router (key "router").
// ModelConfig.HasRouter considers the block active only when Candidates is
// non-empty. The defaults described on the fields below are applied by the
// consumers of this config, not by this type.
type RouterConfig struct {
// Classifier picks the implementation. Only "score" ships today:
// it asks the classifier model to score every Policy label as a
// continuation of the routing prompt and reads off the
// distribution. Empty defaults to "score".
Classifier string `yaml:"classifier,omitempty" json:"classifier,omitempty"`
// Policies is the label vocabulary the classifier scores over.
// Each policy carries a natural-language description that ends up
// in the system prompt the classifier model sees — short, action-
// oriented sentences work best ("writing or debugging code",
// "small talk", ...). The Score classifier picks the subset of
// labels whose softmax probability passes ActivationThreshold.
Policies []RouterPolicy `yaml:"policies,omitempty" json:"policies,omitempty"`
// Candidates is the routing table — each entry binds a downstream
// model to a set of labels it can serve. The middleware picks the
// FIRST candidate whose Labels are a superset of the active label
// set from the classifier. Admins order this list smallest →
// largest so a query that needs one label routes to the smallest
// capable model, while a query that needs multiple falls to a
// bigger candidate that covers them all.
Candidates []RouterCandidate `yaml:"candidates,omitempty" json:"candidates,omitempty"`
// Fallback is the model used when no candidate matches the active
// label set, or when the classifier returns nothing above
// threshold. Empty fallback means router failures bubble up as
// 500 — fail-fast, not silent-bypass.
Fallback string `yaml:"fallback,omitempty" json:"fallback,omitempty"`
// ClassifierModel names the model the Score classifier scores
// against (Arch-Router-1.5B is the canonical choice).
ClassifierModel string `yaml:"classifier_model,omitempty" json:"classifier_model,omitempty"`
// ClassifierCacheSize bounds the per-prompt memo cache that
// amortises the classifier round-trip across repeat probes.
// 0 disables the cache. Default 1024.
// NOTE(review): with a plain int and omitempty, an explicit 0 cannot be
// told apart from an absent key here — confirm how "0 disables" and
// "default 1024" are reconciled by the consumer.
ClassifierCacheSize int `yaml:"classifier_cache_size,omitempty" json:"classifier_cache_size,omitempty"`
// ActivationThreshold is the softmax-probability floor a policy
// must clear to be considered "active" for the request. 0
// defaults to a sensible value (~0.15) inside the classifier.
// Higher → narrower routes (single-label dominant); lower →
// more multi-label activations.
ActivationThreshold float64 `yaml:"activation_threshold,omitempty" json:"activation_threshold,omitempty"`
// ClassifierSystemTemplate overrides the routing system prompt
// the score classifier feeds to its classifier_model. Go
// text/template + Sprig, executed with `.Policies []ScorePolicy`
// (Label + Description fields). Empty falls back to the built-in
// Arch-Router-shaped template (route-listing block + JSON output
// schema). Override when the classifier model was trained on a
// different schema (e.g. bare label output, XML route block) or
// when the routing instructions need to be in a different
// language. The candidate format scored against the model is
// fixed at `{"route": "<label>"}` and IS NOT templated — keep
// your override's output schema instruction matching that, or
// the per-candidate scores degenerate.
ClassifierSystemTemplate string `yaml:"classifier_system_template,omitempty" json:"classifier_system_template,omitempty"`
// ScoreNormalization picks how the score classifier collapses
// per-candidate joint log-probs into the softmax input.
// - ""/"raw": use joint log-prob as-is (default). Matches the
// distribution the classifier model was trained against — the
// route the model would actually emit if decoded freely.
// - "mean": divide by candidate token count. Fairer to long
// labels (their joint log-prob is mechanically smaller because
// it sums more negatives), but off-distribution for models
// trained to emit fixed-format outputs like Arch-Router's
// {"route": "name"}.
// Future modes (e.g. "weighted_mean") will land here too.
ScoreNormalization string `yaml:"score_normalization,omitempty" json:"score_normalization,omitempty"`
// EmbeddingCache configures the L2 cache that maps prompt
// embeddings to past decisions, so semantically-similar prompts
// reuse a classification instead of re-running the classifier
// model. Omit the block to disable. See router/embedding_cache.go.
// A pointer, so an omitted block decodes to nil.
EmbeddingCache *EmbeddingCacheConfig `yaml:"embedding_cache,omitempty" json:"embedding_cache,omitempty"`
}
// EmbeddingCacheConfig configures the L2 embedding-similarity decision
// cache. Pairs naturally with a larger / slower classifier model: the
// classifier round-trip is amortised across paraphrases of the same
// intent. The cache uses the standard /v1/embeddings backend for
// vector generation and the local-store gRPC surface for KNN search.
//
// It is referenced by pointer from RouterConfig.EmbeddingCache. The
// defaults quoted on the fields below are applied by the cache package, not
// by this type.
type EmbeddingCacheConfig struct {
// EmbeddingModel names the loaded LocalAI model used to embed
// router prompts. Required when the cache is enabled. Any model
// that supports the Embeddings gRPC primitive works;
// nomic-embed-text-v1.5 is the recommended default.
// The tag has no omitempty, so the key is always written when the block
// is serialised.
EmbeddingModel string `yaml:"embedding_model" json:"embedding_model"`
// SimilarityThreshold is the cosine-similarity floor a cache
// candidate must clear to be treated as a hit. 0 picks the
// package default (0.80). Higher → fewer false hits, higher miss
// rate; lower → more aggressive sharing across paraphrases.
SimilarityThreshold float64 `yaml:"similarity_threshold,omitempty" json:"similarity_threshold,omitempty"`
// ConfidenceThreshold is the minimum classifier top-label
// probability for a decision to be inserted into the cache. 0
// picks the package default (0.60). Uncertain decisions are not
// cached so they can't poison future paraphrases.
ConfidenceThreshold float64 `yaml:"confidence_threshold,omitempty" json:"confidence_threshold,omitempty"`
// StoreName overrides the local-store collection name used for
// this router's cache. Empty defaults to "router-cache-<router>"
// where <router> is the parent model name. Useful when two
// router models should share a cache (rare).
StoreName string `yaml:"store_name,omitempty" json:"store_name,omitempty"`
}
// RouterPolicy is one entry in the label vocabulary. The label string
// is what the classifier model emits and what candidates reference in
// their Labels field; the description is the natural-language hint
// fed to the classifier so it can match user intent against the label
// space.
//
// Neither tag uses omitempty, so both keys are always written on
// serialisation, even when empty.
type RouterPolicy struct {
// Label is the policy name, matched against RouterCandidate.Labels.
Label string `yaml:"label" json:"label"`
// Description is the free-text hint for the classifier.
Description string `yaml:"description" json:"description"`
}
// RouterCandidate names a downstream model and the policy labels it
// is willing to serve. Labels are matched as a set: the middleware
// picks the first candidate whose Labels is a superset of the
// classifier's active set.
//
// Neither tag uses omitempty, so both keys are always written on
// serialisation.
type RouterCandidate struct {
// Model is the name of the downstream model to dispatch to.
Model string `yaml:"model" json:"model"`
// Labels lists the RouterPolicy labels this candidate covers.
Labels []string `yaml:"labels" json:"labels"`
}
// HasRouter reports whether this model is a router, i.e. its Router block
// lists at least one candidate. A Router block with policies but no
// candidates does not count. The RouteModel middleware uses this to decide
// whether the classifier should be engaged.
func (c *ModelConfig) HasRouter() bool {
	return len(c.Router.Candidates) != 0
}
// @Description PII filtering configuration. PII redaction is per-model so
// that local models don't pay the latency or behaviour change of regex
// scanning, while cloud-bound traffic (cloud-proxy backend) can default to
// on. Setting Enabled explicitly always wins over the backend default.
//
// PIIConfig is the type of ModelConfig.PII (key "pii"). The effective
// on/off state is resolved by ModelConfig.PIIIsEnabled, and Patterns is
// flattened into a map by ModelConfig.PIIPatternOverrides.
type PIIConfig struct {
// Enabled toggles redaction for this model. When unset (zero value),
// the resolved default depends on Backend: cloud-proxy defaults to
// true, everything else to false. A pointer is used so the absence of
// the YAML key is distinguishable from explicit false.
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
// Patterns lets a model upgrade or downgrade individual pattern
// actions (mask | block | route_local) relative to the global
// defaults loaded from --pii-config / DefaultPatterns. Pattern IDs
// not listed inherit the global action. The regex itself stays
// global — only the action is settable per-model.
Patterns []PIIPatternOverride `yaml:"patterns,omitempty" json:"patterns,omitempty"`
}
// @Description Per-model action override for a single PII pattern.
//
// PIIPatternOverride entries are the elements of PIIConfig.Patterns.
type PIIPatternOverride struct {
// ID identifies the pattern being overridden. Entries with an empty ID
// are skipped by ModelConfig.PIIPatternOverrides.
ID string `yaml:"id" json:"id"`
// Action is the raw action string; it is not validated in this package.
Action string `yaml:"action" json:"action"`
}
// PIIIsEnabled resolves whether PII filtering applies to this model. An
// explicit pii.enabled setting always wins; when it is absent, filtering
// defaults to on for the "cloud-proxy" backend and off for every other
// backend. Keeping the decision in one place lets the middleware and the
// /api/middleware/status admin view agree.
func (c *ModelConfig) PIIIsEnabled() bool {
	if explicit := c.PII.Enabled; explicit != nil {
		return *explicit
	}
	return c.Backend == "cloud-proxy"
}
// PIIPatternOverrides flattens the model's per-pattern overrides into a map
// from pattern ID to raw action string. The action strings are passed
// through untouched; validating and converting them is left to the pii
// package.
//
// It returns nil when no overrides are configured. Entries with an empty ID
// are dropped, and when an ID occurs more than once the last entry wins.
//
// The method backs the modelPIIConfig interface documented in
// core/services/routing/pii/middleware.go, which lets that package consume
// the overrides without depending on this one.
func (c *ModelConfig) PIIPatternOverrides() map[string]string {
	overrides := c.PII.Patterns
	if len(overrides) == 0 {
		return nil
	}

	byID := make(map[string]string, len(overrides))
	for i := range overrides {
		if id := overrides[i].ID; id != "" {
			byID[id] = overrides[i].Action
		}
	}
	return byID
}
// @Description MCP configuration
//
// Both fields hold a YAML document as a string; MCPConfigFromYAML parses
// them into typed server maps.
type MCPConfig struct {
// Servers holds the remote-server document. Note the key is "remote",
// not "servers".
Servers string `yaml:"remote,omitempty" json:"remote,omitempty"`
// Stdio holds the stdio-server document.
Stdio string `yaml:"stdio,omitempty" json:"stdio,omitempty"`
}
// @Description Agent configuration
//
// AgentConfig is the type of ModelConfig.Agent (key "agent"). Every field
// uses omitempty, so zero values are not written on serialisation.
// NOTE(review): the exact effect of each knob is defined by the agent
// runtime that consumes this struct and is not visible here — confirm
// there before relying on a field's name alone.
type AgentConfig struct {
MaxAttempts int `yaml:"max_attempts,omitempty" json:"max_attempts,omitempty"`
MaxIterations int `yaml:"max_iterations,omitempty" json:"max_iterations,omitempty"`
EnableReasoning bool `yaml:"enable_reasoning,omitempty" json:"enable_reasoning,omitempty"`
EnablePlanning bool `yaml:"enable_planning,omitempty" json:"enable_planning,omitempty"`
EnableMCPPrompts bool `yaml:"enable_mcp_prompts,omitempty" json:"enable_mcp_prompts,omitempty"`
EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator,omitempty" json:"enable_plan_re_evaluator,omitempty"`
DisableSinkState bool `yaml:"disable_sink_state,omitempty" json:"disable_sink_state,omitempty"`
LoopDetection int `yaml:"loop_detection,omitempty" json:"loop_detection,omitempty"`
MaxAdjustmentAttempts int `yaml:"max_adjustment_attempts,omitempty" json:"max_adjustment_attempts,omitempty"`
ForceReasoningTool bool `yaml:"force_reasoning_tool,omitempty" json:"force_reasoning_tool,omitempty"`
}
// HasMCPServers reports whether at least one of the two MCP server
// documents (remote or stdio) is non-empty. It only checks for presence;
// the documents are not parsed here.
func (c MCPConfig) HasMCPServers() bool {
	if c.Servers != "" {
		return true
	}
	return c.Stdio != ""
}
// MCPConfigFromYAML parses the two YAML documents stored in the config: the
// remote-server document (Servers) and the stdio-server document (Stdio).
//
// It returns the parsed remote config, the parsed stdio config, and an
// error. The remote document is parsed first; if either document fails to
// parse, the error says which one and wraps the underlying YAML error, so
// callers can still inspect it with errors.Is / errors.As. On error the
// returned configs must not be relied upon.
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
	var remote MCPGenericConfig[MCPRemoteServers]
	var stdio MCPGenericConfig[MCPSTDIOServers]

	if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil {
		return remote, stdio, fmt.Errorf("parsing remote MCP servers: %w", err)
	}
	if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil {
		return remote, stdio, fmt.Errorf("parsing stdio MCP servers: %w", err)
	}
	return remote, stdio, nil
}
// @Description MCP generic configuration
//
// MCPGenericConfig is the top-level shape of an MCP server document: a
// single "mcpServers" key whose value has type T. MCPConfigFromYAML
// instantiates it with MCPRemoteServers and MCPSTDIOServers.
type MCPGenericConfig[T any] struct {
Servers T `yaml:"mcpServers,omitempty" json:"mcpServers,omitempty"`
}
// MCPRemoteServers is the set of remote MCP servers, keyed by string
// (presumably the server name — confirm against the consumer).
type MCPRemoteServers map[string]MCPRemoteServer
// MCPSTDIOServers is the set of stdio MCP servers, keyed by string
// (presumably the server name — confirm against the consumer).
type MCPSTDIOServers map[string]MCPSTDIOServer
// @Description MCP remote server configuration
//
// Only JSON tags are declared. MCPConfigFromYAML decodes this type with the
// YAML decoder, which then has to fall back to its default field-name
// mapping.
// NOTE(review): yaml.v3 is documented to default to the lowercased field
// name ("url", "token"), which matches the JSON names — confirm.
type MCPRemoteServer struct {
URL string `json:"url,omitempty"`
Token string `json:"token,omitempty"`
}
// @Description MCP STDIO server configuration
//
// Only JSON tags are declared. MCPConfigFromYAML decodes this type with the
// YAML decoder, which then has to fall back to its default field-name
// mapping.
// NOTE(review): yaml.v3 is documented to default to the lowercased field
// name ("args", "env", "command"), which matches the JSON names — confirm.
type MCPSTDIOServer struct {
Args []string `json:"args,omitempty"`
Env map[string]string `json:"env,omitempty"`
Command string `json:"command,omitempty"`
}
// @Description Pipeline defines other models to use for audio-to-audio
//
// Pipeline is the type of ModelConfig.Pipeline (key "pipeline"). Each field
// is a string selecting the model for one stage.
// NOTE(review): presumably each value is another model config's name —
// confirm against the pipeline consumer.
type Pipeline struct {
TTS string `yaml:"tts,omitempty" json:"tts,omitempty"`
LLM string `yaml:"llm,omitempty" json:"llm,omitempty"`
Transcription string `yaml:"transcription,omitempty" json:"transcription,omitempty"`
VAD string `yaml:"vad,omitempty" json:"vad,omitempty"`
}
// @Description File configuration for model downloads
//
// File entries are the elements of ModelConfig.DownloadFiles
// (key "download_files").
type File struct {
// Filename is read from the "filename" key.
// NOTE(review): presumably the local name to store the download under —
// confirm against the downloader.
Filename string `yaml:"filename,omitempty" json:"filename,omitempty"`
// SHA256 is read from the "sha256" key.
// NOTE(review): presumably the expected checksum of the file — confirm.
SHA256 string `yaml:"sha256,omitempty" json:"sha256,omitempty"`
// URI is the source location, typed as downloader.URI.
URI downloader.URI `yaml:"uri,omitempty" json:"uri,omitempty"`
}
// FeatureFlag is a registry of named boolean flags. Values are pointers, so
// a flag can be present with no value (nil) as well as true or false.
type FeatureFlag map[string]*bool

// Enabled reports whether the flag named s is set to true. A flag that is
// missing from the registry, or present with a nil value, counts as
// disabled. Calling it on a nil FeatureFlag is safe and returns false.
func (ff FeatureFlag) Enabled(s string) bool {
	// Looking up a missing key yields a nil pointer, so one nil check
	// covers both "absent" and "present but unset".
	flag := ff[s]
	return flag != nil && *flag
}
// @Description GRPC configuration
//
// GRPC is the type of ModelConfig.GRPC (key "grpc").
// NOTE(review): presumably Attempts is a retry count and AttemptsSleepTime
// the pause between attempts; the time unit is not visible here — confirm
// against the gRPC client code.
type GRPC struct {
Attempts int `yaml:"attempts,omitempty" json:"attempts,omitempty"`
AttemptsSleepTime int `yaml:"attempts_sleep_time,omitempty" json:"attempts_sleep_time,omitempty"`
}
// @Description Diffusers configuration
//
// Diffusers is the type of ModelConfig.Diffusers (key "diffusers"). It has
// its own CUDA flag, separate from ModelConfig.CUDA.
// NOTE(review): the accepted values for the string fields are defined by
// the diffusers backend and are not visible here.
type Diffusers struct {
CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"`
PipelineType string `yaml:"pipeline_type,omitempty" json:"pipeline_type,omitempty"`
SchedulerType string `yaml:"scheduler_type,omitempty" json:"scheduler_type,omitempty"`
EnableParameters string `yaml:"enable_parameters,omitempty" json:"enable_parameters,omitempty"` // A list of comma separated parameters to specify
IMG2IMG bool `yaml:"img2img,omitempty" json:"img2img,omitempty"` // Image to Image Diffuser
ClipSkip int `yaml:"clip_skip,omitempty" json:"clip_skip,omitempty"` // Skip every N frames
ClipModel string `yaml:"clip_model,omitempty" json:"clip_model,omitempty"` // Clip model to use
ClipSubFolder string `yaml:"clip_subfolder,omitempty" json:"clip_subfolder,omitempty"` // Subfolder to use for clip model
ControlNet string `yaml:"control_net,omitempty" json:"control_net,omitempty"`
}
// @Description LLMConfig is a struct that holds the configuration that are generic for most of the LLM backends.
//
// LLMConfig is embedded inline in ModelConfig, so its keys sit at the top
// level of a model's YAML document rather than under a nested block.
// Pointer-typed fields can distinguish an absent key (nil) from an explicit
// zero value; plain-typed fields cannot.
// A few keys differ from their Go field names: NGPULayers is "gpu_layers",
// ModelType is "type", and DisableLogStatus is "disable_log_stats".
type LLMConfig struct {
SystemPrompt string `yaml:"system_prompt,omitempty" json:"system_prompt,omitempty"`
TensorSplit string `yaml:"tensor_split,omitempty" json:"tensor_split,omitempty"`
MainGPU string `yaml:"main_gpu,omitempty" json:"main_gpu,omitempty"`
RMSNormEps float32 `yaml:"rms_norm_eps,omitempty" json:"rms_norm_eps,omitempty"`
NGQA int32 `yaml:"ngqa,omitempty" json:"ngqa,omitempty"`
PromptCachePath string `yaml:"prompt_cache_path,omitempty" json:"prompt_cache_path,omitempty"`
PromptCacheAll *bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"`
PromptCacheRO bool `yaml:"prompt_cache_ro,omitempty" json:"prompt_cache_ro,omitempty"`
MirostatETA *float64 `yaml:"mirostat_eta,omitempty" json:"mirostat_eta,omitempty"`
MirostatTAU *float64 `yaml:"mirostat_tau,omitempty" json:"mirostat_tau,omitempty"`
Mirostat *int `yaml:"mirostat,omitempty" json:"mirostat,omitempty"`
NGPULayers *int `yaml:"gpu_layers,omitempty" json:"gpu_layers,omitempty"`
MMap *bool `yaml:"mmap,omitempty" json:"mmap,omitempty"`
MMlock *bool `yaml:"mmlock,omitempty" json:"mmlock,omitempty"`
LowVRAM *bool `yaml:"low_vram,omitempty" json:"low_vram,omitempty"`
Reranking *bool `yaml:"reranking,omitempty" json:"reranking,omitempty"`
Grammar string `yaml:"grammar,omitempty" json:"grammar,omitempty"`
StopWords []string `yaml:"stopwords,omitempty" json:"stopwords,omitempty"`
Cutstrings []string `yaml:"cutstrings,omitempty" json:"cutstrings,omitempty"`
ExtractRegex []string `yaml:"extract_regex,omitempty" json:"extract_regex,omitempty"`
TrimSpace []string `yaml:"trimspace,omitempty" json:"trimspace,omitempty"`
TrimSuffix []string `yaml:"trimsuffix,omitempty" json:"trimsuffix,omitempty"`
ContextSize *int `yaml:"context_size,omitempty" json:"context_size,omitempty"`
NUMA bool `yaml:"numa,omitempty" json:"numa,omitempty"`
LoraAdapter string `yaml:"lora_adapter,omitempty" json:"lora_adapter,omitempty"`
LoraBase string `yaml:"lora_base,omitempty" json:"lora_base,omitempty"`
LoraAdapters []string `yaml:"lora_adapters,omitempty" json:"lora_adapters,omitempty"`
LoraScales []float32 `yaml:"lora_scales,omitempty" json:"lora_scales,omitempty"`
LoraScale float32 `yaml:"lora_scale,omitempty" json:"lora_scale,omitempty"`
NoMulMatQ bool `yaml:"no_mulmatq,omitempty" json:"no_mulmatq,omitempty"`
DraftModel string `yaml:"draft_model,omitempty" json:"draft_model,omitempty"`
NDraft int32 `yaml:"n_draft,omitempty" json:"n_draft,omitempty"`
Quantization string `yaml:"quantization,omitempty" json:"quantization,omitempty"`
LoadFormat string `yaml:"load_format,omitempty" json:"load_format,omitempty"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization,omitempty" json:"gpu_memory_utilization,omitempty"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code,omitempty" json:"trust_remote_code,omitempty"` // vLLM
EnforceEager bool `yaml:"enforce_eager,omitempty" json:"enforce_eager,omitempty"` // vLLM
SwapSpace int `yaml:"swap_space,omitempty" json:"swap_space,omitempty"` // vLLM
MaxModelLen int `yaml:"max_model_len,omitempty" json:"max_model_len,omitempty"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size,omitempty" json:"tensor_parallel_size,omitempty"` // vLLM
DisableLogStatus bool `yaml:"disable_log_stats,omitempty" json:"disable_log_stats,omitempty"` // vLLM
DType string `yaml:"dtype,omitempty" json:"dtype,omitempty"` // vLLM
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt,omitempty" json:"limit_mm_per_prompt,omitempty"` // vLLM
// EngineArgs is a backend-native passthrough applied to the engine constructor
// (e.g. vLLM AsyncEngineArgs). Values may be primitives or nested maps; nested
// maps materialise into the backend's nested config dataclasses (e.g.
// SpeculativeConfig, KVTransferConfig, CompilationConfig). Unknown keys cause
// the backend to fail LoadModel with a list of valid names.
EngineArgs map[string]any `yaml:"engine_args,omitempty" json:"engine_args,omitempty"`
// MMProj is a local path or a URL; see ModelConfig.MMProjFileName and
// ModelConfig.IsMMProjURL.
MMProj string `yaml:"mmproj,omitempty" json:"mmproj,omitempty"`
FlashAttention *string `yaml:"flash_attention,omitempty" json:"flash_attention,omitempty"`
NoKVOffloading bool `yaml:"no_kv_offloading,omitempty" json:"no_kv_offloading,omitempty"`
CacheTypeK string `yaml:"cache_type_k,omitempty" json:"cache_type_k,omitempty"`
CacheTypeV string `yaml:"cache_type_v,omitempty" json:"cache_type_v,omitempty"`
RopeScaling string `yaml:"rope_scaling,omitempty" json:"rope_scaling,omitempty"`
ModelType string `yaml:"type,omitempty" json:"type,omitempty"`
YarnExtFactor float32 `yaml:"yarn_ext_factor,omitempty" json:"yarn_ext_factor,omitempty"`
YarnAttnFactor float32 `yaml:"yarn_attn_factor,omitempty" json:"yarn_attn_factor,omitempty"`
YarnBetaFast float32 `yaml:"yarn_beta_fast,omitempty" json:"yarn_beta_fast,omitempty"`
YarnBetaSlow float32 `yaml:"yarn_beta_slow,omitempty" json:"yarn_beta_slow,omitempty"`
CFGScale float32 `yaml:"cfg_scale,omitempty" json:"cfg_scale,omitempty"` // Classifier-Free Guidance Scale
}
// @Description LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
//
// It is the type of LLMConfig.LimitMMPerPrompt (key "limit_mm_per_prompt").
// The serialised keys are the short modality names "image", "video" and
// "audio", not the Go field names.
type LimitMMPerPrompt struct {
LimitImagePerPrompt int `yaml:"image,omitempty" json:"image,omitempty"`
LimitVideoPerPrompt int `yaml:"video,omitempty" json:"video,omitempty"`
LimitAudioPerPrompt int `yaml:"audio,omitempty" json:"audio,omitempty"`
}
// @Description TemplateConfig is a struct that holds the configuration of the templating system
//
// TemplateConfig is the type of ModelConfig.TemplateConfig (key "template").
type TemplateConfig struct {
// Chat is the template used in the chat completion endpoint
Chat string `yaml:"chat,omitempty" json:"chat,omitempty"`
// ChatMessage is the template used for chat messages
ChatMessage string `yaml:"chat_message,omitempty" json:"chat_message,omitempty"`
// Completion is the template used for completion requests
Completion string `yaml:"completion,omitempty" json:"completion,omitempty"`
// Edit is the template used for edit completion requests
Edit string `yaml:"edit,omitempty" json:"edit,omitempty"`
// Functions is the template used when tools are present in the client requests
// Note the key is the singular "function".
Functions string `yaml:"function,omitempty" json:"function,omitempty"`
// UseTokenizerTemplate is a flag that indicates if the tokenizer template should be used.
// Note: this is mostly consumed for backends such as vllm and transformers
// that can use the tokenizers specified in the JSON config files of the models
UseTokenizerTemplate bool `yaml:"use_tokenizer_template,omitempty" json:"use_tokenizer_template,omitempty"`
// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
// It defaults to \n
// A pointer, so an absent key (nil) differs from an explicit empty string.
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character,omitempty" json:"join_chat_messages_by_character,omitempty"`
// Multimodal is read from the "multimodal" key.
// NOTE(review): presumably the template for multimodal prompts — confirm.
Multimodal string `yaml:"multimodal,omitempty" json:"multimodal,omitempty"`
// ReplyPrefix is read from the "reply_prefix" key.
// NOTE(review): presumably text prepended to the model's reply — confirm.
ReplyPrefix string `yaml:"reply_prefix,omitempty" json:"reply_prefix,omitempty"`
}
// syncKnownUsecasesFromString recomputes KnownUsecases from
// KnownUsecaseStrings and then rewrites KnownUsecaseStrings with the
// canonical flag names of every usecase the config satisfies.
//
// The names are sorted so the rewritten list is deterministic: the flag
// table is a map, and ranging over it directly would produce a different
// order on every call (and therefore unstable serialized configs).
func (c *ModelConfig) syncKnownUsecasesFromString() {
	c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)

	// Make sure the usecases are valid, we rewrite with what we identified.
	all := GetAllModelConfigUsecases()
	// Non-nil even when empty, matching the previous `[]string{}` reset.
	names := make([]string, 0, len(all))
	for name, usecase := range all {
		if c.HasUsecases(usecase) {
			names = append(names, name)
		}
	}
	slices.Sort(names)
	c.KnownUsecaseStrings = names
}
// UnmarshalYAML decodes a YAML node into the config and then refreshes the
// known-usecase fields from the decoded strings.
//
// Decoding goes through a locally defined type with the same fields but no
// methods, so yaml does not call back into this method recursively. The
// receiver is overwritten wholesale with the decoded value.
func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
	type plainConfig ModelConfig

	var decoded plainConfig
	err := value.Decode(&decoded)
	if err != nil {
		return err
	}

	*c = ModelConfig(decoded)
	c.syncKnownUsecasesFromString()
	return nil
}
// SetFunctionCallString stores s in the unexported functionCallString field,
// which ShouldUseFunctions and FunctionToCall read ("none" disables function
// use unless a specific function name is set).
func (c *ModelConfig) SetFunctionCallString(s string) {
c.functionCallString = s
}
// SetFunctionCallNameString stores s in the unexported functionCallNameString
// field; a non-empty value makes ShouldCallSpecificFunction report true.
func (c *ModelConfig) SetFunctionCallNameString(s string) {
c.functionCallNameString = s
}
// ShouldUseFunctions reports whether function calling applies: either a
// specific function name was set, or the function-call mode is anything
// other than "none" (an unset/empty mode counts as enabled).
func (c *ModelConfig) ShouldUseFunctions() bool {
	if c.ShouldCallSpecificFunction() {
		return true
	}
	return c.functionCallString != "none"
}
// ShouldCallSpecificFunction reports whether a function name has been set
// via SetFunctionCallNameString (any non-empty string counts).
func (c *ModelConfig) ShouldCallSpecificFunction() bool {
	return c.functionCallNameString != ""
}
// MMProjFileName returns the filename of the MMProj file.
// When MMProj looks like a URL the name is whatever
// downloader.URI.FilenameFromUrl derives from it; otherwise MMProj is
// returned unchanged.
// NOTE(review): the derived name was previously described as the MD5 of the
// URL — confirm in pkg/downloader.
func (c *ModelConfig) MMProjFileName() string {
	location := downloader.URI(c.MMProj)
	if !location.LooksLikeURL() {
		return c.MMProj
	}
	// The error from FilenameFromUrl is discarded; the returned name is
	// passed through as-is in that case.
	name, _ := location.FilenameFromUrl()
	return name
}
// IsMMProjURL reports whether the MMProj value looks like a URL according
// to downloader.URI.LooksLikeURL.
func (c *ModelConfig) IsMMProjURL() bool {
	return downloader.URI(c.MMProj).LooksLikeURL()
}
// IsModelURL reports whether the Model value looks like a URL according to
// downloader.URI.LooksLikeURL.
func (c *ModelConfig) IsModelURL() bool {
	return downloader.URI(c.Model).LooksLikeURL()
}
// ModelFileName returns the filename of the model.
// When Model looks like a URL the name is whatever
// downloader.URI.FilenameFromUrl derives from it; otherwise Model is
// returned unchanged.
// NOTE(review): the derived name was previously described as the MD5 of the
// URL — confirm in pkg/downloader.
func (c *ModelConfig) ModelFileName() string {
	location := downloader.URI(c.Model)
	if !location.LooksLikeURL() {
		return c.Model
	}
	// The error from FilenameFromUrl is discarded; the returned name is
	// passed through as-is in that case.
	name, _ := location.FilenameFromUrl()
	return name
}
// FunctionToCall returns the specific function name when one is set and is
// not one of the reserved values "none" or "auto"; in every other case it
// returns the function-call mode string instead.
func (c *ModelConfig) FunctionToCall() string {
	switch name := c.functionCallNameString; name {
	case "", "none", "auto":
		return c.functionCallString
	default:
		return name
	}
}
// SetDefaults fills every unset (nil) tunable on the config with its default
// value, using the supplied loader options for threads, f16, debug and
// context size.
//
// Order matters: the proxy mode is normalised first, then model-family
// inference defaults are applied, then the generic fallbacks, then the
// LoadOptions context size, and finally the backend hooks run and the
// known-usecase fields are re-synchronised.
//
// Each defaulted pointer field receives its own freshly allocated value.
// Pointing several fields at one shared local would alias them, so a later
// in-place write through one field would silently change the others.
func (c *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
	lo := &LoadOptions{}
	lo.Apply(opts...)

	ctx := lo.ctxSize
	threads := lo.threads
	f16 := lo.f16
	debug := lo.debug

	// Allocation helpers: every call returns a pointer to a distinct value.
	boolPtr := func(v bool) *bool { return &v }
	intPtr := func(v int) *int { return &v }
	floatPtr := func(v float64) *float64 { return &v }

	// Cloud-proxy: normalise empty Mode so downstream consumers
	// switch on two concrete values only. Validate accepts empty too,
	// but SetDefaults is the chokepoint that runs before any
	// inference path reads c.Proxy.Mode.
	if c.Proxy.Mode == "" {
		c.Proxy.Mode = ProxyModePassthrough
	}

	// Apply model-family-specific inference defaults before generic fallbacks.
	// This ensures gallery-installed and runtime-loaded models get optimal parameters.
	ApplyInferenceDefaults(c, c.Name, c.Model)

	if c.Seed == nil {
		// random number generator seed
		defaultSeed := RAND_SEED
		c.Seed = &defaultSeed
	}

	// Sampling defaults, see
	// https://github.com/ggerganov/llama.cpp/blob/75cd4c77292034ecec587ecb401366f57338f7c0/common/sampling.h#L22
	if c.TopK == nil {
		c.TopK = intPtr(40)
	}
	if c.MinP == nil {
		c.MinP = floatPtr(0.0)
	}
	if c.TypicalP == nil {
		c.TypicalP = floatPtr(1.0)
	}
	if c.TFZ == nil {
		c.TFZ = floatPtr(1.0)
	}

	if c.MMap == nil {
		// MMap is enabled by default
		// Only exception is for Intel GPUs
		c.MMap = boolPtr(os.Getenv("XPU") == "")
	}
	if c.MMlock == nil {
		// MMlock is disabled by default
		c.MMlock = boolPtr(false)
	}

	if c.TopP == nil {
		c.TopP = floatPtr(0.95)
	}
	if c.Temperature == nil {
		c.Temperature = floatPtr(0.9)
	}
	if c.Maxtokens == nil {
		c.Maxtokens = intPtr(0)
	}

	// Mirostat defaults, see https://github.com/mudler/LocalAI/issues/2780
	if c.Mirostat == nil {
		c.Mirostat = intPtr(0)
	}
	if c.MirostatETA == nil {
		c.MirostatETA = floatPtr(0.1)
	}
	if c.MirostatTAU == nil {
		c.MirostatTAU = floatPtr(5.0)
	}

	if c.LowVRAM == nil {
		c.LowVRAM = boolPtr(false)
	}
	if c.Embeddings == nil {
		c.Embeddings = boolPtr(false)
	}
	if c.Reranking == nil {
		c.Reranking = boolPtr(false)
	}
	if c.PromptCacheAll == nil {
		// Match upstream llama.cpp's default (common/common.h: cache_prompt = true)
		// and let cache_idle_slots / kv_unified actually do useful work; users can
		// opt out with an explicit `prompt_cache_all: false` in the model YAML.
		c.PromptCacheAll = boolPtr(true)
	}

	if threads == 0 {
		// Threads can't be 0
		threads = 4
	}
	if c.Threads == nil {
		c.Threads = &threads
	}
	if c.F16 == nil {
		c.F16 = &f16
	}

	if c.Debug == nil {
		c.Debug = boolPtr(false)
	}
	// The loader-level debug switch overrides whatever the config says.
	if debug {
		c.Debug = boolPtr(true)
	}

	// If a context size was provided via LoadOptions, apply it before hooks so they
	// don't override it with their own defaults.
	if ctx != 0 && c.ContextSize == nil {
		c.ContextSize = &ctx
	}

	runBackendHooks(c, lo.modelPath)
	c.syncKnownUsecasesFromString()
}
// modelConfigBackendNameRE matches a backend name made only of ASCII
// letters, digits, '-' and '_'. It is compiled once at package
// initialisation instead of on every Validate call.
var modelConfigBackendNameRE = regexp.MustCompile(`^[a-zA-Z0-9-_]+$`)

// Validate checks that the configuration can be loaded safely.
//
// It returns (true, nil) when every check passes and (false, err) on the
// first failure. The checks run in this order: path safety of backend,
// model, mmproj and download filenames; backend name characters; MCP
// configuration; engine_args JSON-serialisability; cloud-proxy settings;
// the llama-cpp score/generation conflict; router score normalisation; and
// the router classifier system template.
func (c *ModelConfig) Validate() (bool, error) {
	validationTargets := make([]string, 0, 3+len(c.DownloadFiles))
	validationTargets = append(validationTargets, c.Backend, c.Model, c.MMProj)
	for _, f := range c.DownloadFiles {
		validationTargets = append(validationTargets, f.Filename)
	}

	// Simple validation to make sure the model can be correctly loaded:
	// reject absolute paths and anything containing "..".
	for _, n := range validationTargets {
		if n == "" {
			continue
		}
		if strings.HasPrefix(n, string(os.PathSeparator)) ||
			strings.Contains(n, "..") {
			return false, fmt.Errorf("invalid file path: %s", n)
		}
	}

	// The backend must be a plain name with no special characters,
	// except '-' and '_'.
	if c.Backend != "" && !modelConfigBackendNameRE.MatchString(c.Backend) {
		return false, fmt.Errorf("invalid backend name: %s", c.Backend)
	}

	// Validate MCP configuration if present
	if c.MCP.Servers != "" || c.MCP.Stdio != "" {
		if _, _, err := c.MCP.MCPConfigFromYAML(); err != nil {
			return false, fmt.Errorf("invalid MCP configuration: %w", err)
		}
	}

	// engine_args crosses the gRPC boundary as a JSON-encoded string. Reject
	// unmarshalable values here so a config that would silently lose user-set
	// options at load time is rejected at parse time instead.
	if len(c.EngineArgs) > 0 {
		if _, err := json.Marshal(c.EngineArgs); err != nil {
			return false, fmt.Errorf("engine_args is not JSON-serialisable: %w", err)
		}
	}

	// Cloud-proxy: at most one of api_key_env / api_key_file may be
	// set. Both empty means no Authorization header (no-auth upstream
	// or a development passthrough). The mode field accepts the empty
	// string (defaults to passthrough), "passthrough", or "translate".
	if c.Proxy.APIKeyEnv != "" && c.Proxy.APIKeyFile != "" {
		return false, fmt.Errorf("proxy: api_key_env and api_key_file are mutually exclusive")
	}
	switch c.Proxy.Mode {
	case "", ProxyModePassthrough, ProxyModeTranslate:
		// Empty is accepted at validate-time and normalised to
		// passthrough by SetDefaults so it never reaches runtime.
	default:
		return false, fmt.Errorf("proxy: unknown mode %q (expected %s or %s)",
			c.Proxy.Mode, ProxyModePassthrough, ProxyModeTranslate)
	}
	if c.Proxy.Mode == ProxyModeTranslate && c.Proxy.Provider == "" {
		return false, fmt.Errorf("proxy: translate mode requires provider (%s, %s)",
			ProxyProviderOpenAI, ProxyProviderAnthropic)
	}

	// Score on llama-cpp bypasses the slot loop and races the
	// llama_context against concurrent generation/embedding traffic
	// (see backend/cpp/llama-cpp/grpc-server.cpp on Score). Reject the
	// combination here so operators are forced to split the model.
	const scoreConflicts = FLAG_CHAT | FLAG_COMPLETION | FLAG_EMBEDDINGS
	if (c.Backend == "llama-cpp" || c.Backend == "llama") &&
		c.HasUsecases(FLAG_SCORE) && c.KnownUsecases != nil &&
		*c.KnownUsecases&scoreConflicts != 0 {
		return false, fmt.Errorf(
			"known_usecases conflict on llama-cpp: score is incompatible " +
				"with chat/completion/embeddings — split into separate model configs")
	}

	// router.score_normalization is consumed lazily by the score
	// classifier at first-request time; without load-time validation
	// a typo wouldn't surface until the first router request panicked
	// inside NewScoreClassifier. Reject unknown values here so the
	// operator sees the offending key at startup.
	switch c.Router.ScoreNormalization {
	case "", ScoreNormalizationRaw, ScoreNormalizationMean:
		// ok
	default:
		return false, fmt.Errorf("router: unknown score_normalization %q (expected %q or %q)",
			c.Router.ScoreNormalization, ScoreNormalizationRaw, ScoreNormalizationMean)
	}

	// router.classifier_system_template parses as Go text/template
	// (Sprig funcs available at execution time). Reject malformed
	// templates at load time so the operator sees the parse error
	// at startup rather than as a 500 on the first router request.
	if c.Router.ClassifierSystemTemplate != "" {
		if _, err := template.New("classifier_system").Parse(c.Router.ClassifierSystemTemplate); err != nil {
			return false, fmt.Errorf("router: classifier_system_template parse error: %w", err)
		}
	}

	return true, nil
}
// Score normalisation modes mirror router.ScoreNormalization* —
// duplicated as constants on the config package so ModelConfig.Validate
// can reject unknown values without taking a dependency on the router
// package (which already depends on config).
//
// Validate accepts exactly these two values, plus the empty string, for
// router.score_normalization.
const (
// ScoreNormalizationRaw is the "raw" mode.
ScoreNormalizationRaw = "raw"
// ScoreNormalizationMean is the "mean" mode.
ScoreNormalizationMean = "mean"
)
// HasTemplate reports whether any of the completion, edit, chat or
// chat-message templates is set, or the tokenizer template is enabled.
func (c *ModelConfig) HasTemplate() bool {
	tmpl := &c.TemplateConfig
	if tmpl.UseTokenizerTemplate {
		return true
	}
	for _, text := range []string{tmpl.Completion, tmpl.Edit, tmpl.Chat, tmpl.ChatMessage} {
		if text != "" {
			return true
		}
	}
	return false
}
// GetModelConfigFile returns the value of the unexported modelConfigFile field.
// NOTE(review): presumably the path of the file this config was loaded from — confirm against the loader.
func (c *ModelConfig) GetModelConfigFile() string {
return c.modelConfigFile
}
// GetModelTemplate returns the model's chat template if available.
// It is the value of the unexported modelTemplate field, which is empty
// when no template was recorded.
func (c *ModelConfig) GetModelTemplate() string {
return c.modelTemplate
}
// IsDisabled reports whether the model is disabled. An unset (nil) Disabled
// field counts as not disabled.
func (c *ModelConfig) IsDisabled() bool {
	if c.Disabled == nil {
		return false
	}
	return *c.Disabled
}
// IsPinned reports whether the model is pinned (excluded from idle unloading
// and eviction). An unset (nil) Pinned field counts as not pinned.
func (c *ModelConfig) IsPinned() bool {
	if c.Pinned == nil {
		return false
	}
	return *c.Pinned
}
// GetConcurrencyGroups returns the model's concurrency groups in normalized
// form: each entry is whitespace-trimmed, blank entries are skipped and
// repeats are dropped, keeping first-seen order. It returns nil when nothing
// remains. The returned slice is newly allocated, so the caller may modify
// it without touching the config.
func (c *ModelConfig) GetConcurrencyGroups() []string {
	total := len(c.ConcurrencyGroups)
	if total == 0 {
		return nil
	}

	groups := make([]string, 0, total)
	seen := make(map[string]struct{}, total)
	for _, raw := range c.ConcurrencyGroups {
		name := strings.TrimSpace(raw)
		if name == "" {
			continue
		}
		if _, dup := seen[name]; dup {
			continue
		}
		seen[name] = struct{}{}
		groups = append(groups, name)
	}

	if len(groups) == 0 {
		return nil
	}
	return groups
}
// ModelConfigUsecase is a set of usecase flags stored in an int. Values are
// combined with bitwise OR and tested with bitwise AND (see the FLAG_*
// constants and HasUsecases).
type ModelConfigUsecase int
// Usecase flags. FLAG_ANY is zero (no bits set); the other named flags are
// written as binary literals with a single 1 digit each, and FLAG_LLM is a
// union of three of them.
const (
FLAG_ANY ModelConfigUsecase = 0b000000000000
FLAG_CHAT ModelConfigUsecase = 0b000000000001
FLAG_COMPLETION ModelConfigUsecase = 0b000000000010
FLAG_EDIT ModelConfigUsecase = 0b000000000100
FLAG_EMBEDDINGS ModelConfigUsecase = 0b000000001000
FLAG_RERANK ModelConfigUsecase = 0b000000010000
FLAG_IMAGE ModelConfigUsecase = 0b000000100000
FLAG_TRANSCRIPT ModelConfigUsecase = 0b000001000000
FLAG_TTS ModelConfigUsecase = 0b000010000000
FLAG_SOUND_GENERATION ModelConfigUsecase = 0b000100000000
FLAG_TOKENIZE ModelConfigUsecase = 0b001000000000
FLAG_VAD ModelConfigUsecase = 0b010000000000
FLAG_VIDEO ModelConfigUsecase = 0b100000000000
FLAG_DETECTION ModelConfigUsecase = 0b1000000000000
FLAG_VISION ModelConfigUsecase = 0b10000000000000
FLAG_FACE_RECOGNITION ModelConfigUsecase = 0b100000000000000
FLAG_SPEAKER_RECOGNITION ModelConfigUsecase = 0b1000000000000000
FLAG_AUDIO_TRANSFORM ModelConfigUsecase = 0b10000000000000000
FLAG_DIARIZATION ModelConfigUsecase = 0b100000000000000000
FLAG_REALTIME_AUDIO ModelConfigUsecase = 0b1000000000000000000
// Marks a model as wired for the Score gRPC primitive (joint
// log-prob of candidate continuations under a shared prompt). Must
// be declared explicitly via `known_usecases: [score]` — there's
// no heuristic for it. On the llama-cpp backend, Score bypasses
// the slot loop and races the llama_context, so Validate() refuses
// to load a llama-cpp config that combines FLAG_SCORE with
// chat/completion/embeddings.
FLAG_SCORE ModelConfigUsecase = 0b10000000000000000000
// Common Subsets
// FLAG_LLM is the union of the chat, completion and edit flags.
FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
)
// ModalityGroups defines groups of usecases that belong to the same modality.
// Flags within the same group are NOT orthogonal (e.g., chat and completion are
// both text/language). A model is multimodal when its usecases span 2+ groups.
//
// Flags that appear in no group (for example FLAG_EMBEDDINGS or FLAG_RERANK)
// never contribute to the count in IsMultimodal. FLAG_REALTIME_AUDIO is
// listed in two groups on purpose, as the inline comments explain.
var ModalityGroups = []ModelConfigUsecase{
FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT, // text/language
FLAG_VISION | FLAG_DETECTION, // visual understanding
FLAG_TRANSCRIPT | FLAG_REALTIME_AUDIO, // speech input — realtime_audio is any-to-any, so it counts here too
FLAG_TTS | FLAG_SOUND_GENERATION | FLAG_REALTIME_AUDIO, // audio output — and here, so a lone realtime_audio flag still reads as multimodal
FLAG_AUDIO_TRANSFORM, // audio in/out transforms
FLAG_IMAGE | FLAG_VIDEO, // visual generation
}
// IsMultimodal reports whether the given usecases touch at least two of the
// entries in ModalityGroups. chat+vision is multimodal; chat+completion is
// not, because both flags sit in the same text/language group.
func IsMultimodal(usecases ModelConfigUsecase) bool {
	matched := 0
	for i := range ModalityGroups {
		if usecases&ModalityGroups[i] == 0 {
			continue
		}
		matched++
		// Stop as soon as a second group is hit.
		if matched >= 2 {
			return true
		}
	}
	return false
}
// GetAllModelConfigUsecases returns a freshly built map from flag name
// (e.g. "FLAG_CHAT") to flag value for every named usecase, including the
// composite FLAG_LLM.
//
// FLAG_ANY is intentionally left out: it is 0 and would always match in
// HasUsecases checks.
func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
	entries := []struct {
		name string
		flag ModelConfigUsecase
	}{
		{"FLAG_CHAT", FLAG_CHAT},
		{"FLAG_COMPLETION", FLAG_COMPLETION},
		{"FLAG_EDIT", FLAG_EDIT},
		{"FLAG_EMBEDDINGS", FLAG_EMBEDDINGS},
		{"FLAG_RERANK", FLAG_RERANK},
		{"FLAG_IMAGE", FLAG_IMAGE},
		{"FLAG_TRANSCRIPT", FLAG_TRANSCRIPT},
		{"FLAG_TTS", FLAG_TTS},
		{"FLAG_SOUND_GENERATION", FLAG_SOUND_GENERATION},
		{"FLAG_TOKENIZE", FLAG_TOKENIZE},
		{"FLAG_VAD", FLAG_VAD},
		{"FLAG_LLM", FLAG_LLM},
		{"FLAG_VIDEO", FLAG_VIDEO},
		{"FLAG_DETECTION", FLAG_DETECTION},
		{"FLAG_VISION", FLAG_VISION},
		{"FLAG_FACE_RECOGNITION", FLAG_FACE_RECOGNITION},
		{"FLAG_SPEAKER_RECOGNITION", FLAG_SPEAKER_RECOGNITION},
		{"FLAG_AUDIO_TRANSFORM", FLAG_AUDIO_TRANSFORM},
		{"FLAG_DIARIZATION", FLAG_DIARIZATION},
		{"FLAG_REALTIME_AUDIO", FLAG_REALTIME_AUDIO},
		{"FLAG_SCORE", FLAG_SCORE},
	}

	byName := make(map[string]ModelConfigUsecase, len(entries))
	for _, e := range entries {
		byName[e.name] = e.flag
	}
	return byName
}
// stringToFlag converts a short usecase name such as "chat" into the
// corresponding flag name ("FLAG_CHAT") by upper-casing it and adding the
// "FLAG_" prefix. The input is not validated.
func stringToFlag(s string) string {
	const prefix = "FLAG_"
	upper := strings.ToUpper(s)
	return prefix + upper
}
// GetUsecasesFromYAML folds a list of usecase names into a single flag set.
//
// Each entry is looked up twice in the flag table: once in its "FLAG_"-
// prefixed upper-case form (so "chat" works) and once verbatim (so
// "FLAG_CHAT" works). Entries that match neither are ignored. It returns nil
// for an empty input, otherwise a pointer to the combined flags (which may
// be FLAG_ANY if nothing matched).
func GetUsecasesFromYAML(input []string) *ModelConfigUsecase {
	if len(input) == 0 {
		return nil
	}

	known := GetAllModelConfigUsecases()
	combined := FLAG_ANY
	for _, name := range input {
		if f, ok := known[stringToFlag(name)]; ok {
			combined |= f
		}
		if f, ok := known[name]; ok {
			combined |= f
		}
	}
	return &combined
}
// HasUsecases reports whether the config is expected to support every
// usecase bit in u.
//
// Declared known_usecases are normally additive: if they already cover u
// the answer is true, and otherwise the GuessUsecases heuristic gets a say.
// The single exception is FLAG_SCORE. A config that declares score has been
// reserved for the router classifier, so its declared list is treated as
// authoritative and the heuristic is not consulted. Letting the heuristic
// add chat/completion would surface the model in chat pickers it was kept
// out of and, on llama-cpp, bring back the slot contention that the
// score/chat conflict check exists to prevent.
func (c *ModelConfig) HasUsecases(u ModelConfigUsecase) bool {
	if known := c.KnownUsecases; known != nil {
		switch {
		case (*known & u) == u:
			return true
		case (*known & FLAG_SCORE) == FLAG_SCORE:
			return false
		}
	}
	return c.GuessUsecases(u)
}
// GuessUsecases is a heuristic: the backend may not be loaded yet and the
// config may not record what it is good at, so the answer is inferred from
// templates, options and the backend name.
//
// It returns true only when every usecase bit set in u passes its check;
// the first failing check yields false. Ideally these checks would look at
// config properties such as templates rather than backend names, which
// would avoid updating the lists for each new backend, but for several
// services the backend name is currently the best available signal.
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
	// wants reports whether every bit of flag is requested in u.
	wants := func(flag ModelConfigUsecase) bool { return u&flag == flag }
	// backendIn reports whether the configured backend is one of names.
	backendIn := func(names ...string) bool { return slices.Contains(names, c.Backend) }

	// Backends that are clearly not text-generation.
	nonTextGen := backendIn(
		"whisper", "piper", "kokoro",
		"diffusers", "stablediffusion", "stablediffusion-ggml",
		"rerankers", "silero-vad", "rfdetr", "insightface", "speaker-recognition",
		"transformers-musicgen", "ace-step", "acestep-cpp",
	)
	tmpl := &c.TemplateConfig
	embeddingsOn := c.Embeddings != nil && *c.Embeddings
	rerankingOn := c.Reranking != nil && *c.Reranking
	vadOnly := slices.Contains(c.Options, "vad_only")
	diffusersNoPipeline := c.Backend == "diffusers" && c.Diffusers.PipelineType == ""

	// Checks that depend on more than the backend name. Every case is a
	// failure condition, so reaching any of them means "not supported".
	switch {
	case wants(FLAG_CHAT) && (nonTextGen || embeddingsOn ||
		(tmpl.Chat == "" && tmpl.ChatMessage == "" && !tmpl.UseTokenizerTemplate)):
		return false
	case wants(FLAG_COMPLETION) && (tmpl.Completion == "" || nonTextGen):
		return false
	case wants(FLAG_EDIT) && tmpl.Edit == "":
		return false
	case wants(FLAG_EMBEDDINGS) && !embeddingsOn:
		return false
	case wants(FLAG_IMAGE) &&
		(!backendIn("diffusers", "stablediffusion", "stablediffusion-ggml") || diffusersNoPipeline):
		return false
	case wants(FLAG_VIDEO) &&
		(!backendIn("diffusers", "stablediffusion", "vllm-omni") || diffusersNoPipeline):
		return false
	case wants(FLAG_RERANK) && c.Backend != "rerankers" && !rerankingOn:
		return false
	case wants(FLAG_TRANSCRIPT) && (c.Backend != "whisper" || vadOnly):
		// whisper models with vad_only option are VAD, not transcription
		return false
	case wants(FLAG_VAD) && !backendIn("silero-vad", "sherpa-onnx") &&
		!(c.Backend == "whisper" && vadOnly):
		return false
	case wants(FLAG_SCORE):
		// No heuristic: Score-intent is a deliberate operator choice
		// (it reserves the model from generation traffic on llama-cpp),
		// so HasUsecases(FLAG_SCORE) is true only when KnownUsecases
		// declares it explicitly.
		return false
	}

	// Usecases gated purely on the backend name.
	backendGated := []struct {
		flag     ModelConfigUsecase
		backends []string
	}{
		{FLAG_TTS, []string{"piper", "transformers-musicgen", "kokoro"}},
		{FLAG_DETECTION, []string{"rfdetr", "sam3-cpp", "insightface"}},
		{FLAG_FACE_RECOGNITION, []string{"insightface"}},
		{FLAG_SPEAKER_RECOGNITION, []string{"speaker-recognition"}},
		{FLAG_AUDIO_TRANSFORM, []string{"localvqe"}},
		{FLAG_SOUND_GENERATION, []string{"transformers-musicgen", "ace-step", "acestep-cpp", "mock-backend"}},
		// NOTE(review): "llama.cpp" differs from the "llama-cpp" spelling
		// that Validate checks — confirm which name the backend uses.
		{FLAG_TOKENIZE, []string{"llama.cpp", "rwkv"}},
		// vibevoice-cpp emits speaker-labelled segments natively from its
		// ASR pass; sherpa-onnx pipes pyannote segmentation + speaker
		// embeddings + clustering. Both surface as a Diarize gRPC.
		{FLAG_DIARIZATION, []string{"vibevoice-cpp", "sherpa-onnx"}},
		// Backends that own a single any-to-any loop and implement
		// AudioToAudioStream — listed here so models without an explicit
		// known_usecases still surface on the Talk page.
		{FLAG_REALTIME_AUDIO, []string{"liquid-audio"}},
	}
	for _, gate := range backendGated {
		if wants(gate.flag) && !slices.Contains(gate.backends, c.Backend) {
			return false
		}
	}

	return true
}
// BuildCogitoOptions generates cogito options from the model configuration.
//
// It takes no arguments: the list starts with the defaults (3 iterations,
// 3 attempts, forced reasoning) and then appends one option per enabled or
// non-zero Agent setting. Overrides are appended after the defaults they
// replace, so the order of the returned slice is significant.
func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
	agent := &c.Agent

	options := []cogito.Option{
		cogito.WithIterations(3),  // default to 3 iterations
		cogito.WithMaxAttempts(3), // default to 3 attempts
		cogito.WithForceReasoning(),
	}

	// Boolean agent switches.
	// NOTE(review): WithForceReasoning is already in the defaults above, so
	// EnableReasoning only appends it a second time — confirm whether forced
	// reasoning is meant to be unconditional.
	if agent.EnableReasoning {
		options = append(options, cogito.WithForceReasoning())
	}
	if agent.EnablePlanning {
		options = append(options, cogito.EnableAutoPlan)
	}
	if agent.EnableMCPPrompts {
		options = append(options, cogito.EnableMCPPrompts)
	}
	if agent.EnablePlanReEvaluator {
		options = append(options, cogito.EnableAutoPlanReEvaluator)
	}

	// Numeric overrides: zero means "keep the default".
	if n := agent.MaxIterations; n != 0 {
		options = append(options, cogito.WithIterations(n))
	}
	if n := agent.MaxAttempts; n != 0 {
		options = append(options, cogito.WithMaxAttempts(n))
	}

	if agent.DisableSinkState {
		options = append(options, cogito.DisableSinkState)
	}
	if n := agent.LoopDetection; n != 0 {
		options = append(options, cogito.WithLoopDetection(n))
	}
	if n := agent.MaxAdjustmentAttempts; n != 0 {
		options = append(options, cogito.WithMaxAdjustmentAttempts(n))
	}
	if agent.ForceReasoningTool {
		options = append(options, cogito.WithForceReasoningTool())
	}

	return options
}