LocalAI/core/config/model_config.go
Ettore Di Giacinto e86ade54a6 feat(api): add /v1/audio/diarization endpoint with sherpa-onnx + vibevoice.cpp (#9654)
* feat(api): add /v1/audio/diarization endpoint with sherpa-onnx + vibevoice.cpp

Closes #1648.

OpenAI-style multipart endpoint that returns "who spoke when". Single
endpoint instead of the issue's three-endpoint sketch (refactor /vad,
/vad/embedding, /diarization) — the typical client wants one call, and
embeddings can land later as a sibling without breaking this surface.

Response shape borrows from Pyannote/Deepgram: segments carry a
normalised SPEAKER_NN id (zero-padded, stable across the response) plus
the raw backend label, optional per-segment text when the backend bundles
ASR, and a speakers summary in verbose_json. response_format also accepts
rttm so consumers can pipe straight into pyannote.metrics / dscore.
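
Each RTTM line follows the standard NIST shape (`SPEAKER <file-id> 1
<onset> <duration> <NA> <NA> <speaker-id> <NA> <NA>`). A minimal Go
sketch of the rendering described above, with a hypothetical helper
name and the clamping behaviour the endpoint tests below pin down:

    // renderRTTMLine formats one diarization segment as an RTTM record,
    // clamping negative durations to zero.
    func renderRTTMLine(fileID string, start, end float64, speaker string) string {
        dur := end - start
        if dur < 0 {
            dur = 0
        }
        return fmt.Sprintf("SPEAKER %s 1 %.3f %.3f <NA> <NA> %s <NA> <NA>",
            fileID, start, dur, speaker)
    }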

Backends:

* vibevoice-cpp — Diarize() reuses the existing vv_capi_asr pass.
  vibevoice's ASR prompt asks the model to emit
  [{Start,End,Speaker,Content}] natively, so diarization is a by-product
  of the same pass; include_text=true preserves the transcript per
  segment, otherwise we drop it.

* sherpa-onnx — wraps the upstream SherpaOnnxOfflineSpeakerDiarization
  C API (pyannote segmentation + speaker-embedding extractor + fast
  clustering). libsherpa-shim grew config builders, a SetClustering
  wrapper for per-call num_clusters/threshold overrides, and a
  segment_at accessor (purego can't read field arrays out of
  SherpaOnnxOfflineSpeakerDiarizationSegment[] directly).
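
A rough sketch of the purego pattern behind that accessor (assuming
github.com/ebitengine/purego; the shim symbol name here is illustrative,
not the real libsherpa-shim export):

    // purego can bind exported C functions by name, but it cannot index
    // into a C array of structs returned by value, so the shim exposes a
    // per-index accessor instead.
    var segmentAt func(result uintptr, i int32) uintptr

    func bindShim(lib uintptr) {
        // Panics at bind time if the symbol is missing, which is exactly
        // the failure mode the follow-up fix below addresses.
        purego.RegisterLibFunc(&segmentAt, lib, "SherpaShimDiarizationSegmentAt")
    }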

Plumbing: new Diarize gRPC RPC + DiarizeRequest / DiarizeSegment /
DiarizeResponse messages, threaded through interface.go, base, server,
client, embed. Default Base impl returns unimplemented.
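
In Go terms the surface is roughly the following (interface name, package
alias and exact signature are assumptions for illustration; the real
interface.go may differ):

    // Diarize is the new backend RPC; the embedded Base default returns
    // an "unimplemented" error until a backend overrides it.
    type DiarizationBackend interface {
        Diarize(ctx context.Context, req *pb.DiarizeRequest) (*pb.DiarizeResponse, error)
    }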

Capability surfaces all updated: FLAG_DIARIZATION usecase,
FeatureAudioDiarization permission (default-on), RouteFeatureRegistry
entries for /v1/audio/diarization and /audio/diarization, audio
instruction-def description widened, CAP_DIARIZATION JS symbol,
swagger regenerated, /api/instructions discovery map updated.

Tests:

* core/backend: speaker-label normalisation (first-seen → SPEAKER_NN,
  per-speaker totals, nil-safety, fallback to backend NumSpeakers when
  no segments); the remapping is sketched after this list.

* core/http/endpoints/openai: RTTM rendering (file-id basename, negative
  duration clamping, fallback id).

* tests/e2e: mock-backend grew a deterministic Diarize that emits
  raw labels "5","2","5" so the e2e suite verifies SPEAKER_NN
  remapping, verbose_json speakers summary + transcript pass-through
  (gated by include_text), RTTM bytes content-type, and rejection of
  unknown response_format. mock-diarize model config registered with
  known_usecases=[FLAG_DIARIZATION] to bypass the backend-name guard.
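
The remapping those tests pin down, as a self-contained sketch
(hypothetical helper, not the shipped function):

    // normalizeSpeakers maps raw backend labels to zero-padded SPEAKER_NN
    // ids in first-seen order, so "5","2","5" becomes
    // SPEAKER_00, SPEAKER_01, SPEAKER_00.
    func normalizeSpeakers(raw []string) []string {
        seen := map[string]int{}
        out := make([]string, len(raw))
        for i, label := range raw {
            n, ok := seen[label]
            if !ok {
                n = len(seen)
                seen[label] = n
            }
            out[i] = fmt.Sprintf("SPEAKER_%02d", n)
        }
        return out
    }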

Docs: new features/audio-diarization.md (request/response, RTTM example,
sherpa-onnx + vibevoice setup), cross-link from audio-to-text.md, entry
in whats-new.md.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]

* fix(diarization): correct sherpa-onnx symbol name + lint cleanup

CI failures on #9654:

* sherpa-onnx-grpc-{tts,transcription} and sherpa-onnx-realtime panicked
  at backend startup with `undefined symbol: SherpaOnnxDestroyOfflineSpeakerDiarizationResult`.
  Upstream's actual symbol is SherpaOnnxOfflineSpeakerDiarizationDestroyResult
  (Destroy in the middle, not the prefix); the rest of the diarization
  surface follows the same naming pattern. The mismatched name made
  purego.RegisterLibFunc fail at dlopen time and crashed the gRPC server
  before the BeforeAll could probe Health, taking down every sherpa-onnx
  test job — not just the diarization-related ones.

* golangci-lint flagged 5 errcheck violations on new defer cleanups
  (os.RemoveAll / Close / conn.Close); each is now wrapped in a
  `defer func() { _ = X() }()` closure, as sketched after this list.
  This matches the pattern other LocalAI files use for new code, since
  pre-existing bare defers are grandfathered in via new-from-merge-base.

* golangci-lint also flagged forbidigo violations: the new
  diarization_test.go files used testing.T-style `t.Errorf` / `t.Fatalf`,
  which are forbidden by the project's coding-style policy
  (.agents/coding-style.md). Convert both files to Ginkgo/Gomega
  Describe/It with Expect(...) — they get picked up by the existing
  TestBackend / TestOpenAI suites, no new suite plumbing needed.

* modernize linter: tightened the diarization segment loop to
  `for i := range int(numSegments)` (Go 1.22+ idiom; also shown in the
  sketch after this list).
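
Both mechanical fixes in one place (tmpDir, numSegments, segmentAt and
process are placeholders):

    // errcheck: the closure makes the deliberately discarded error explicit.
    defer func() { _ = os.RemoveAll(tmpDir) }()

    // modernize: the Go 1.22+ integer range replaces a C-style counter loop.
    for i := range int(numSegments) {
        process(segmentAt(result, int32(i)))
    }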

Verified locally: golangci-lint with new-from-merge-base=origin/master
reports 0 issues across all touched packages, and the four mocked
diarization e2e specs in tests/e2e/mock_backend_test.go still pass.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]

* fix(vibevoice-cpp): convert non-WAV input via ffmpeg + raise ASR token budget

Confirmed end-to-end against a real LocalAI instance with vibevoice-asr-q4_k
loaded and the multi-speaker MP3 sample at vibevoice.cpp/samples/2p_argument.mp3:
both /v1/audio/transcriptions and /v1/audio/diarization now succeed and
return correctly attributed speaker turns for the full clip.

Two latent issues surfaced once the diarization endpoint actually
exercised the backend with a non-trivial input, plus one defensive
follow-up:

1. vv_capi_asr only accepts WAV via load_wav_24k_mono. The previous code
   passed the uploaded path straight through, so anything that wasn't
   already a 24 kHz mono s16le WAV failed at the C side with rc=-8 and
   the very unhelpful "vv_capi_asr failed". prepareWavInput shells out
   to ffmpeg ("-ar 24000 -ac 1 -acodec pcm_s16le") in a per-call temp
   dir, matching the rate the model was trained on; both AudioTranscription
   and Diarize now route through it (see the sketch after this list). This
   is the same shape sherpa-onnx uses (utils.AudioToWav), but vibevoice
   needs 24 kHz rather than 16 kHz, so we don't reuse that helper.

2. The C ABI's max_new_tokens defaults to 256 when 0 is passed. That's
   fine for a five-second clip but not for anything past ~10 s — vibevoice
   stops mid-JSON, the parse fails, and the caller sees a hard error.
   Pass a much larger budget (16 384 ≈ 9 minutes of speech at the
   model's ~30 tok/s rate); generation stops at EOS so this is a cap
   rather than a target.

3. As a defensive belt-and-braces measure, mirror AudioTranscription's existing
   "fall back to a single segment if the model emits non-JSON text"
   pattern in Diarize, so partial / unusual model output never produces
   a 500. This kept the endpoint usable while diagnosing (1) and (2),
   and is the right behaviour to keep.
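
A sketch of the conversion helper from (1), with the budget from (2)
noted inline; naming and error handling are simplified relative to the
shipped code:

    // prepareWavInput converts arbitrary input audio into the 24 kHz mono
    // s16le WAV that vv_capi_asr's load_wav_24k_mono expects, in a per-call
    // temp dir the returned cleanup removes.
    func prepareWavInput(path string) (string, func(), error) {
        dir, err := os.MkdirTemp("", "vibevoice-")
        if err != nil {
            return "", nil, err
        }
        out := filepath.Join(dir, "input.wav")
        cmd := exec.Command("ffmpeg", "-y", "-i", path,
            "-ar", "24000", "-ac", "1", "-acodec", "pcm_s16le", out)
        if err := cmd.Run(); err != nil {
            _ = os.RemoveAll(dir)
            return "", nil, fmt.Errorf("ffmpeg convert to 24k mono wav: %w", err)
        }
        return out, func() { _ = os.RemoveAll(dir) }, nil
    }

    // Callers then pass max_new_tokens=16384 instead of 0, so the C-side
    // default of 256 never truncates long-clip JSON output.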

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]

* fix(vibevoice-cpp): pass valid WAVs through directly so ffmpeg is not required at runtime

Spotted by tests-e2e-backend (1.25.x): the previous fix forced every
incoming audio file through `ffmpeg -ar 24000 ...`, which meant the
backend container — which does not ship ffmpeg — failed even for the
existing happy path where the caller already uploads a WAV. The
container-side error was:

    rpc error: code = Unknown desc = vibevoice-cpp: ffmpeg convert to
    24k mono wav: exec: "ffmpeg": executable file not found in $PATH

Reading vibevoice.cpp's audio_io.cpp, `load_wav_24k_mono` uses drwav and
already accepts any PCM/IEEE-float WAV at any sample rate, downmixes
multi-channel input to mono, and resamples to 24 kHz internally. So the
only inputs that genuinely need an external converter are non-WAV
formats (MP3, OGG, FLAC, ...).

Detect WAVs by RIFF/WAVE magic at bytes 0..3 / 8..11 and pass them
straight through with a no-op cleanup; everything else still goes
through ffmpeg with the same 24 kHz mono s16le target (sketched after
the list below). The result:

* Container builds without ffmpeg keep working for WAV uploads
  (the e2e-backends fixture is jfk.wav at 16 kHz mono s16le).
* MP3 and other non-WAV inputs still get the new ffmpeg conversion
  path so the diarization endpoint stays useful.
* If the caller uploads a non-WAV but ffmpeg isn't on PATH, the
  surfaced error is still descriptive enough to act on.
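
The magic-byte check, sketched (function name hypothetical; assumes the
usual os/io/bytes imports):

    // isWav reports whether the file carries a RIFF/WAVE header:
    // "RIFF" at bytes 0..3 and "WAVE" at bytes 8..11.
    func isWav(path string) bool {
        f, err := os.Open(path)
        if err != nil {
            return false
        }
        defer func() { _ = f.Close() }()
        var hdr [12]byte
        if _, err := io.ReadFull(f, hdr[:]); err != nil {
            return false
        }
        return bytes.Equal(hdr[0:4], []byte("RIFF")) && bytes.Equal(hdr[8:12], []byte("WAVE"))
    }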

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]

* fix(ci): make gcc-14 install in Dockerfile.golang best-effort for jammy bases

The LocalVQE PR (bb033b16) made `gcc-14 g++-14` an unconditional apt
install in backend/Dockerfile.golang and pointed update-alternatives at
them. That works on the default `BASE_IMAGE=ubuntu:24.04` (noble has
gcc-14 in main), but every Go backend that builds on
`nvcr.io/nvidia/l4t-jetpack:r36.4.0` — jammy under the hood — now fails
at the apt step:

    E: Unable to locate package gcc-14

This blocked unrelated jobs:
backend-jobs(*-nvidia-l4t-arm64-{stablediffusion-ggml, sam3-cpp, whisper,
acestep-cpp, qwen3-tts-cpp, vibevoice-cpp}). LocalVQE itself is only
matrix-built on ubuntu:24.04 (CPU + Vulkan), so it doesn't actually
need gcc-14 anywhere else.

Make the gcc-14 install conditional on the package being available in
the configured apt repos (see the sketch below). On noble: identical
behaviour to today (gcc-14 installed, update-alternatives points at it).
On jammy: skip the gcc-14 stanza entirely and let build-essential's
default gcc take over, which is what the other Go backends compile with
anyway.
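
One way to express the conditional stanza (a sketch, not the exact
backend/Dockerfile.golang wiring):

    RUN apt-get update && \
        if apt-cache show gcc-14 >/dev/null 2>&1; then \
            apt-get install -y --no-install-recommends gcc-14 g++-14 && \
            update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 && \
            update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100; \
        fi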

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-05-05 15:10:13 +02:00


package config

import (
	"encoding/json"
	"fmt"
	"os"
	"regexp"
	"slices"
	"strings"

	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/downloader"
	"github.com/mudler/LocalAI/pkg/functions"
	"github.com/mudler/LocalAI/pkg/reasoning"
	"github.com/mudler/cogito"
	"gopkg.in/yaml.v3"
)

const (
	RAND_SEED = -1
)

// @Description TTS configuration
type TTSConfig struct {
	// Voice wav path or id
	Voice string `yaml:"voice,omitempty" json:"voice,omitempty"`

	AudioPath string `yaml:"audio_path,omitempty" json:"audio_path,omitempty"`
}

// @Description ModelConfig represents a model configuration
type ModelConfig struct {
	modelConfigFile string `yaml:"-" json:"-"`
	modelTemplate string `yaml:"-" json:"-"`

	schema.PredictionOptions `yaml:"parameters,omitempty" json:"parameters,omitempty"`

	Name string `yaml:"name,omitempty" json:"name,omitempty"`
	F16 *bool `yaml:"f16,omitempty" json:"f16,omitempty"`
	Threads *int `yaml:"threads,omitempty" json:"threads,omitempty"`
	Debug *bool `yaml:"debug,omitempty" json:"debug,omitempty"`
	Roles map[string]string `yaml:"roles,omitempty" json:"roles,omitempty"`
	Embeddings *bool `yaml:"embeddings,omitempty" json:"embeddings,omitempty"`
	Backend string `yaml:"backend,omitempty" json:"backend,omitempty"`
	TemplateConfig TemplateConfig `yaml:"template,omitempty" json:"template,omitempty"`
	KnownUsecaseStrings []string `yaml:"known_usecases,omitempty" json:"known_usecases,omitempty"`
	KnownUsecases *ModelConfigUsecase `yaml:"-" json:"-"`
	Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`

	PromptStrings, InputStrings []string `yaml:"-" json:"-"`
	InputToken [][]int `yaml:"-" json:"-"`
	functionCallString, functionCallNameString string `yaml:"-" json:"-"`
	ResponseFormat string `yaml:"-" json:"-"`
	ResponseFormatMap map[string]any `yaml:"-" json:"-"`

	// MediaMarker is the runtime-discovered multimodal marker the backend expects
	// in the prompt (e.g. "<__media__>" or a random "<__media_<rand>__>" picked by
	// llama.cpp). Populated on first successful ModelMetadata call. Empty until
	// then — callers must fall back to templates.DefaultMultiMediaMarker.
	MediaMarker string `yaml:"-" json:"-"`

	FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
	ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`

	FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.

	// LLM configs (GPT4ALL, Llama.cpp, ...)
	LLMConfig `yaml:",inline" json:",inline"`

	// Diffusers
	Diffusers Diffusers `yaml:"diffusers,omitempty" json:"diffusers,omitempty"`
	Step int `yaml:"step,omitempty" json:"step,omitempty"`

	// GRPC Options
	GRPC GRPC `yaml:"grpc,omitempty" json:"grpc,omitempty"`

	// TTS specifics
	TTSConfig `yaml:"tts,omitempty" json:"tts,omitempty"`

	// CUDA
	// Explicitly enable CUDA or not (some backends might need it)
	CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"`

	DownloadFiles []File `yaml:"download_files,omitempty" json:"download_files,omitempty"`

	Description string `yaml:"description,omitempty" json:"description,omitempty"`
	Usage string `yaml:"usage,omitempty" json:"usage,omitempty"`
	Disabled *bool `yaml:"disabled,omitempty" json:"disabled,omitempty"`
	Pinned *bool `yaml:"pinned,omitempty" json:"pinned,omitempty"`

	// ConcurrencyGroups declares per-node mutual-exclusion groups: the model
	// cannot be loaded alongside another model that shares any group name.
	// See docs/content/advanced/vram-management.md for usage.
	ConcurrencyGroups []string `yaml:"concurrency_groups,omitempty" json:"concurrency_groups,omitempty"`

	Options []string `yaml:"options,omitempty" json:"options,omitempty"`
	Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`

	MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"`
	Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"`
}

// @Description MCP configuration
type MCPConfig struct {
	Servers string `yaml:"remote,omitempty" json:"remote,omitempty"`
	Stdio string `yaml:"stdio,omitempty" json:"stdio,omitempty"`
}

// @Description Agent configuration
type AgentConfig struct {
	MaxAttempts int `yaml:"max_attempts,omitempty" json:"max_attempts,omitempty"`
	MaxIterations int `yaml:"max_iterations,omitempty" json:"max_iterations,omitempty"`
	EnableReasoning bool `yaml:"enable_reasoning,omitempty" json:"enable_reasoning,omitempty"`
	EnablePlanning bool `yaml:"enable_planning,omitempty" json:"enable_planning,omitempty"`
	EnableMCPPrompts bool `yaml:"enable_mcp_prompts,omitempty" json:"enable_mcp_prompts,omitempty"`
	EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator,omitempty" json:"enable_plan_re_evaluator,omitempty"`
	DisableSinkState bool `yaml:"disable_sink_state,omitempty" json:"disable_sink_state,omitempty"`
	LoopDetection int `yaml:"loop_detection,omitempty" json:"loop_detection,omitempty"`
	MaxAdjustmentAttempts int `yaml:"max_adjustment_attempts,omitempty" json:"max_adjustment_attempts,omitempty"`
	ForceReasoningTool bool `yaml:"force_reasoning_tool,omitempty" json:"force_reasoning_tool,omitempty"`
}

// HasMCPServers returns true if any MCP servers (remote or stdio) are configured.
func (c MCPConfig) HasMCPServers() bool {
	return c.Servers != "" || c.Stdio != ""
}

func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
	var remote MCPGenericConfig[MCPRemoteServers]
	var stdio MCPGenericConfig[MCPSTDIOServers]

	if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil {
		return remote, stdio, err
	}
	if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil {
		return remote, stdio, err
	}

	return remote, stdio, nil
}

// @Description MCP generic configuration
type MCPGenericConfig[T any] struct {
	Servers T `yaml:"mcpServers,omitempty" json:"mcpServers,omitempty"`
}

type MCPRemoteServers map[string]MCPRemoteServer
type MCPSTDIOServers map[string]MCPSTDIOServer

// @Description MCP remote server configuration
type MCPRemoteServer struct {
	URL string `json:"url,omitempty"`
	Token string `json:"token,omitempty"`
}

// @Description MCP STDIO server configuration
type MCPSTDIOServer struct {
	Args []string `json:"args,omitempty"`
	Env map[string]string `json:"env,omitempty"`
	Command string `json:"command,omitempty"`
}

// @Description Pipeline defines other models to use for audio-to-audio
type Pipeline struct {
	TTS string `yaml:"tts,omitempty" json:"tts,omitempty"`
	LLM string `yaml:"llm,omitempty" json:"llm,omitempty"`
	Transcription string `yaml:"transcription,omitempty" json:"transcription,omitempty"`
	VAD string `yaml:"vad,omitempty" json:"vad,omitempty"`
}

// @Description File configuration for model downloads
type File struct {
	Filename string `yaml:"filename,omitempty" json:"filename,omitempty"`
	SHA256 string `yaml:"sha256,omitempty" json:"sha256,omitempty"`
	URI downloader.URI `yaml:"uri,omitempty" json:"uri,omitempty"`
}

type FeatureFlag map[string]*bool

func (ff FeatureFlag) Enabled(s string) bool {
	if v, exists := ff[s]; exists && v != nil {
		return *v
	}
	return false
}

// @Description GRPC configuration
type GRPC struct {
	Attempts int `yaml:"attempts,omitempty" json:"attempts,omitempty"`
	AttemptsSleepTime int `yaml:"attempts_sleep_time,omitempty" json:"attempts_sleep_time,omitempty"`
}

// @Description Diffusers configuration
type Diffusers struct {
	CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"`
	PipelineType string `yaml:"pipeline_type,omitempty" json:"pipeline_type,omitempty"`
	SchedulerType string `yaml:"scheduler_type,omitempty" json:"scheduler_type,omitempty"`
	EnableParameters string `yaml:"enable_parameters,omitempty" json:"enable_parameters,omitempty"` // A list of comma separated parameters to specify
	IMG2IMG bool `yaml:"img2img,omitempty" json:"img2img,omitempty"` // Image to Image Diffuser
	ClipSkip int `yaml:"clip_skip,omitempty" json:"clip_skip,omitempty"` // Skip every N frames
	ClipModel string `yaml:"clip_model,omitempty" json:"clip_model,omitempty"` // Clip model to use
	ClipSubFolder string `yaml:"clip_subfolder,omitempty" json:"clip_subfolder,omitempty"` // Subfolder to use for clip model
	ControlNet string `yaml:"control_net,omitempty" json:"control_net,omitempty"`
}

// @Description LLMConfig is a struct that holds the configuration that are generic for most of the LLM backends.
type LLMConfig struct {
	SystemPrompt string `yaml:"system_prompt,omitempty" json:"system_prompt,omitempty"`
	TensorSplit string `yaml:"tensor_split,omitempty" json:"tensor_split,omitempty"`
	MainGPU string `yaml:"main_gpu,omitempty" json:"main_gpu,omitempty"`
	RMSNormEps float32 `yaml:"rms_norm_eps,omitempty" json:"rms_norm_eps,omitempty"`
	NGQA int32 `yaml:"ngqa,omitempty" json:"ngqa,omitempty"`
	PromptCachePath string `yaml:"prompt_cache_path,omitempty" json:"prompt_cache_path,omitempty"`
	PromptCacheAll bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"`
	PromptCacheRO bool `yaml:"prompt_cache_ro,omitempty" json:"prompt_cache_ro,omitempty"`
	MirostatETA *float64 `yaml:"mirostat_eta,omitempty" json:"mirostat_eta,omitempty"`
	MirostatTAU *float64 `yaml:"mirostat_tau,omitempty" json:"mirostat_tau,omitempty"`
	Mirostat *int `yaml:"mirostat,omitempty" json:"mirostat,omitempty"`
	NGPULayers *int `yaml:"gpu_layers,omitempty" json:"gpu_layers,omitempty"`
	MMap *bool `yaml:"mmap,omitempty" json:"mmap,omitempty"`
	MMlock *bool `yaml:"mmlock,omitempty" json:"mmlock,omitempty"`
	LowVRAM *bool `yaml:"low_vram,omitempty" json:"low_vram,omitempty"`
	Reranking *bool `yaml:"reranking,omitempty" json:"reranking,omitempty"`
	Grammar string `yaml:"grammar,omitempty" json:"grammar,omitempty"`
	StopWords []string `yaml:"stopwords,omitempty" json:"stopwords,omitempty"`
	Cutstrings []string `yaml:"cutstrings,omitempty" json:"cutstrings,omitempty"`
	ExtractRegex []string `yaml:"extract_regex,omitempty" json:"extract_regex,omitempty"`
	TrimSpace []string `yaml:"trimspace,omitempty" json:"trimspace,omitempty"`
	TrimSuffix []string `yaml:"trimsuffix,omitempty" json:"trimsuffix,omitempty"`
	ContextSize *int `yaml:"context_size,omitempty" json:"context_size,omitempty"`
	NUMA bool `yaml:"numa,omitempty" json:"numa,omitempty"`
	LoraAdapter string `yaml:"lora_adapter,omitempty" json:"lora_adapter,omitempty"`
	LoraBase string `yaml:"lora_base,omitempty" json:"lora_base,omitempty"`
	LoraAdapters []string `yaml:"lora_adapters,omitempty" json:"lora_adapters,omitempty"`
	LoraScales []float32 `yaml:"lora_scales,omitempty" json:"lora_scales,omitempty"`
	LoraScale float32 `yaml:"lora_scale,omitempty" json:"lora_scale,omitempty"`
	NoMulMatQ bool `yaml:"no_mulmatq,omitempty" json:"no_mulmatq,omitempty"`
	DraftModel string `yaml:"draft_model,omitempty" json:"draft_model,omitempty"`
	NDraft int32 `yaml:"n_draft,omitempty" json:"n_draft,omitempty"`
	Quantization string `yaml:"quantization,omitempty" json:"quantization,omitempty"`
	LoadFormat string `yaml:"load_format,omitempty" json:"load_format,omitempty"`
	GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization,omitempty" json:"gpu_memory_utilization,omitempty"` // vLLM
	TrustRemoteCode bool `yaml:"trust_remote_code,omitempty" json:"trust_remote_code,omitempty"` // vLLM
	EnforceEager bool `yaml:"enforce_eager,omitempty" json:"enforce_eager,omitempty"` // vLLM
	SwapSpace int `yaml:"swap_space,omitempty" json:"swap_space,omitempty"` // vLLM
	MaxModelLen int `yaml:"max_model_len,omitempty" json:"max_model_len,omitempty"` // vLLM
	TensorParallelSize int `yaml:"tensor_parallel_size,omitempty" json:"tensor_parallel_size,omitempty"` // vLLM
	DisableLogStatus bool `yaml:"disable_log_stats,omitempty" json:"disable_log_stats,omitempty"` // vLLM
	DType string `yaml:"dtype,omitempty" json:"dtype,omitempty"` // vLLM
	LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt,omitempty" json:"limit_mm_per_prompt,omitempty"` // vLLM

	// EngineArgs is a backend-native passthrough applied to the engine constructor
	// (e.g. vLLM AsyncEngineArgs). Values may be primitives or nested maps; nested
	// maps materialise into the backend's nested config dataclasses (e.g.
	// SpeculativeConfig, KVTransferConfig, CompilationConfig). Unknown keys cause
	// the backend to fail LoadModel with a list of valid names.
	EngineArgs map[string]any `yaml:"engine_args,omitempty" json:"engine_args,omitempty"`

	MMProj string `yaml:"mmproj,omitempty" json:"mmproj,omitempty"`
	FlashAttention *string `yaml:"flash_attention,omitempty" json:"flash_attention,omitempty"`
	NoKVOffloading bool `yaml:"no_kv_offloading,omitempty" json:"no_kv_offloading,omitempty"`
	CacheTypeK string `yaml:"cache_type_k,omitempty" json:"cache_type_k,omitempty"`
	CacheTypeV string `yaml:"cache_type_v,omitempty" json:"cache_type_v,omitempty"`
	RopeScaling string `yaml:"rope_scaling,omitempty" json:"rope_scaling,omitempty"`
	ModelType string `yaml:"type,omitempty" json:"type,omitempty"`
	YarnExtFactor float32 `yaml:"yarn_ext_factor,omitempty" json:"yarn_ext_factor,omitempty"`
	YarnAttnFactor float32 `yaml:"yarn_attn_factor,omitempty" json:"yarn_attn_factor,omitempty"`
	YarnBetaFast float32 `yaml:"yarn_beta_fast,omitempty" json:"yarn_beta_fast,omitempty"`
	YarnBetaSlow float32 `yaml:"yarn_beta_slow,omitempty" json:"yarn_beta_slow,omitempty"`
	CFGScale float32 `yaml:"cfg_scale,omitempty" json:"cfg_scale,omitempty"` // Classifier-Free Guidance Scale
}

// @Description LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
type LimitMMPerPrompt struct {
	LimitImagePerPrompt int `yaml:"image,omitempty" json:"image,omitempty"`
	LimitVideoPerPrompt int `yaml:"video,omitempty" json:"video,omitempty"`
	LimitAudioPerPrompt int `yaml:"audio,omitempty" json:"audio,omitempty"`
}

// @Description TemplateConfig is a struct that holds the configuration of the templating system
type TemplateConfig struct {
	// Chat is the template used in the chat completion endpoint
	Chat string `yaml:"chat,omitempty" json:"chat,omitempty"`

	// ChatMessage is the template used for chat messages
	ChatMessage string `yaml:"chat_message,omitempty" json:"chat_message,omitempty"`

	// Completion is the template used for completion requests
	Completion string `yaml:"completion,omitempty" json:"completion,omitempty"`

	// Edit is the template used for edit completion requests
	Edit string `yaml:"edit,omitempty" json:"edit,omitempty"`

	// Functions is the template used when tools are present in the client requests
	Functions string `yaml:"function,omitempty" json:"function,omitempty"`

	// UseTokenizerTemplate is a flag that indicates if the tokenizer template should be used.
	// Note: this is mostly consumed for backends such as vllm and transformers
	// that can use the tokenizers specified in the JSON config files of the models
	UseTokenizerTemplate bool `yaml:"use_tokenizer_template,omitempty" json:"use_tokenizer_template,omitempty"`

	// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
	// It defaults to \n
	JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character,omitempty" json:"join_chat_messages_by_character,omitempty"`

	Multimodal string `yaml:"multimodal,omitempty" json:"multimodal,omitempty"`

	ReplyPrefix string `yaml:"reply_prefix,omitempty" json:"reply_prefix,omitempty"`
}

func (c *ModelConfig) syncKnownUsecasesFromString() {
	c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
	// Make sure the usecases are valid, we rewrite with what we identified
	c.KnownUsecaseStrings = []string{}
	for k, usecase := range GetAllModelConfigUsecases() {
		if c.HasUsecases(usecase) {
			c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k)
		}
	}
}

func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
	type BCAlias ModelConfig
	var aux BCAlias
	if err := value.Decode(&aux); err != nil {
		return err
	}

	mc := ModelConfig(aux)
	*c = mc
	c.syncKnownUsecasesFromString()
	return nil
}

func (c *ModelConfig) SetFunctionCallString(s string) {
	c.functionCallString = s
}

func (c *ModelConfig) SetFunctionCallNameString(s string) {
	c.functionCallNameString = s
}

func (c *ModelConfig) ShouldUseFunctions() bool {
	return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
}

func (c *ModelConfig) ShouldCallSpecificFunction() bool {
	return len(c.functionCallNameString) > 0
}

// MMProjFileName returns the filename of the MMProj file
// If the MMProj is a URL, it will return the MD5 of the URL which is the filename
func (c *ModelConfig) MMProjFileName() string {
	uri := downloader.URI(c.MMProj)
	if uri.LooksLikeURL() {
		f, _ := uri.FilenameFromUrl()
		return f
	}
	return c.MMProj
}

func (c *ModelConfig) IsMMProjURL() bool {
	uri := downloader.URI(c.MMProj)
	return uri.LooksLikeURL()
}

func (c *ModelConfig) IsModelURL() bool {
	uri := downloader.URI(c.Model)
	return uri.LooksLikeURL()
}

// ModelFileName returns the filename of the model
// If the model is a URL, it will return the MD5 of the URL which is the filename
func (c *ModelConfig) ModelFileName() string {
	uri := downloader.URI(c.Model)
	if uri.LooksLikeURL() {
		f, _ := uri.FilenameFromUrl()
		return f
	}
	return c.Model
}

func (c *ModelConfig) FunctionToCall() string {
	if c.functionCallNameString != "" &&
		c.functionCallNameString != "none" && c.functionCallNameString != "auto" {
		return c.functionCallNameString
	}
	return c.functionCallString
}

func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
	lo := &LoadOptions{}
	lo.Apply(opts...)

	ctx := lo.ctxSize
	threads := lo.threads
	f16 := lo.f16
	debug := lo.debug

	// Apply model-family-specific inference defaults before generic fallbacks.
	// This ensures gallery-installed and runtime-loaded models get optimal parameters.
	ApplyInferenceDefaults(cfg, cfg.Name, cfg.Model)

	// https://github.com/ggerganov/llama.cpp/blob/75cd4c77292034ecec587ecb401366f57338f7c0/common/sampling.h#L22
	defaultTopP := 0.95
	defaultTopK := 40
	defaultMinP := 0.0
	defaultTemp := 0.9

	// https://github.com/mudler/LocalAI/issues/2780
	defaultMirostat := 0
	defaultMirostatTAU := 5.0
	defaultMirostatETA := 0.1
	defaultTypicalP := 1.0
	defaultTFZ := 1.0
	defaultZero := 0

	trueV := true
	falseV := false

	if cfg.Seed == nil {
		// random number generator seed
		defaultSeed := RAND_SEED
		cfg.Seed = &defaultSeed
	}
	if cfg.TopK == nil {
		cfg.TopK = &defaultTopK
	}
	if cfg.MinP == nil {
		cfg.MinP = &defaultMinP
	}
	if cfg.TypicalP == nil {
		cfg.TypicalP = &defaultTypicalP
	}
	if cfg.TFZ == nil {
		cfg.TFZ = &defaultTFZ
	}
	if cfg.MMap == nil {
		// MMap is enabled by default
		// Only exception is for Intel GPUs
		if os.Getenv("XPU") != "" {
			cfg.MMap = &falseV
		} else {
			cfg.MMap = &trueV
		}
	}
	if cfg.MMlock == nil {
		// MMlock is disabled by default
		cfg.MMlock = &falseV
	}
	if cfg.TopP == nil {
		cfg.TopP = &defaultTopP
	}
	if cfg.Temperature == nil {
		cfg.Temperature = &defaultTemp
	}
	if cfg.Maxtokens == nil {
		cfg.Maxtokens = &defaultZero
	}
	if cfg.Mirostat == nil {
		cfg.Mirostat = &defaultMirostat
	}
	if cfg.MirostatETA == nil {
		cfg.MirostatETA = &defaultMirostatETA
	}
	if cfg.MirostatTAU == nil {
		cfg.MirostatTAU = &defaultMirostatTAU
	}
	if cfg.LowVRAM == nil {
		cfg.LowVRAM = &falseV
	}
	if cfg.Embeddings == nil {
		cfg.Embeddings = &falseV
	}
	if cfg.Reranking == nil {
		cfg.Reranking = &falseV
	}

	if threads == 0 {
		// Threads can't be 0
		threads = 4
	}
	if cfg.Threads == nil {
		cfg.Threads = &threads
	}
	if cfg.F16 == nil {
		cfg.F16 = &f16
	}
	if cfg.Debug == nil {
		cfg.Debug = &falseV
	}
	if debug {
		cfg.Debug = &trueV
	}

	// If a context size was provided via LoadOptions, apply it before hooks so they
	// don't override it with their own defaults.
	if ctx != 0 && cfg.ContextSize == nil {
		cfg.ContextSize = &ctx
	}

	runBackendHooks(cfg, lo.modelPath)

	cfg.syncKnownUsecasesFromString()
}

func (c *ModelConfig) Validate() (bool, error) {
	downloadedFileNames := []string{}
	for _, f := range c.DownloadFiles {
		downloadedFileNames = append(downloadedFileNames, f.Filename)
	}

	validationTargets := []string{c.Backend, c.Model, c.MMProj}
	validationTargets = append(validationTargets, downloadedFileNames...)

	// Simple validation to make sure the model can be correctly loaded
	for _, n := range validationTargets {
		if n == "" {
			continue
		}
		if strings.HasPrefix(n, string(os.PathSeparator)) ||
			strings.Contains(n, "..") {
			return false, fmt.Errorf("invalid file path: %s", n)
		}
	}

	if c.Backend != "" {
		// a regex that checks that is a string name with no special characters, except '-' and '_'
		re := regexp.MustCompile(`^[a-zA-Z0-9-_]+$`)
		if !re.MatchString(c.Backend) {
			return false, fmt.Errorf("invalid backend name: %s", c.Backend)
		}
	}

	// Validate MCP configuration if present
	if c.MCP.Servers != "" || c.MCP.Stdio != "" {
		if _, _, err := c.MCP.MCPConfigFromYAML(); err != nil {
			return false, fmt.Errorf("invalid MCP configuration: %w", err)
		}
	}

	// engine_args crosses the gRPC boundary as a JSON-encoded string. Reject
	// unmarshalable values here so a config that would silently lose user-set
	// options at load time is rejected at parse time instead.
	if len(c.EngineArgs) > 0 {
		if _, err := json.Marshal(c.EngineArgs); err != nil {
			return false, fmt.Errorf("engine_args is not JSON-serialisable: %w", err)
		}
	}

	return true, nil
}

func (c *ModelConfig) HasTemplate() bool {
	return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != "" || c.TemplateConfig.UseTokenizerTemplate
}

func (c *ModelConfig) GetModelConfigFile() string {
	return c.modelConfigFile
}

// GetModelTemplate returns the model's chat template if available
func (c *ModelConfig) GetModelTemplate() string {
	return c.modelTemplate
}

// IsDisabled returns true if the model is disabled
func (c *ModelConfig) IsDisabled() bool {
	return c.Disabled != nil && *c.Disabled
}

// IsPinned returns true if the model is pinned (excluded from idle unloading and eviction)
func (c *ModelConfig) IsPinned() bool {
	return c.Pinned != nil && *c.Pinned
}

// GetConcurrencyGroups returns the model's concurrency groups, normalized:
// trimmed of whitespace, empty entries dropped, deduped. Returns nil when no
// effective groups remain. The result is a fresh slice; the caller may
// mutate it without affecting the config.
func (c *ModelConfig) GetConcurrencyGroups() []string {
	if len(c.ConcurrencyGroups) == 0 {
		return nil
	}
	out := make([]string, 0, len(c.ConcurrencyGroups))
	for _, g := range c.ConcurrencyGroups {
		g = strings.TrimSpace(g)
		if g == "" || slices.Contains(out, g) {
			continue
		}
		out = append(out, g)
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

type ModelConfigUsecase int

const (
	FLAG_ANY ModelConfigUsecase = 0b000000000000
	FLAG_CHAT ModelConfigUsecase = 0b000000000001
	FLAG_COMPLETION ModelConfigUsecase = 0b000000000010
	FLAG_EDIT ModelConfigUsecase = 0b000000000100
	FLAG_EMBEDDINGS ModelConfigUsecase = 0b000000001000
	FLAG_RERANK ModelConfigUsecase = 0b000000010000
	FLAG_IMAGE ModelConfigUsecase = 0b000000100000
	FLAG_TRANSCRIPT ModelConfigUsecase = 0b000001000000
	FLAG_TTS ModelConfigUsecase = 0b000010000000
	FLAG_SOUND_GENERATION ModelConfigUsecase = 0b000100000000
	FLAG_TOKENIZE ModelConfigUsecase = 0b001000000000
	FLAG_VAD ModelConfigUsecase = 0b010000000000
	FLAG_VIDEO ModelConfigUsecase = 0b100000000000
	FLAG_DETECTION ModelConfigUsecase = 0b1000000000000
	FLAG_FACE_RECOGNITION ModelConfigUsecase = 0b10000000000000
	FLAG_SPEAKER_RECOGNITION ModelConfigUsecase = 0b100000000000000
	FLAG_AUDIO_TRANSFORM ModelConfigUsecase = 0b1000000000000000
	FLAG_DIARIZATION ModelConfigUsecase = 0b10000000000000000

	// Common Subsets
	FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
)

func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
	return map[string]ModelConfigUsecase{
		// Note: FLAG_ANY is intentionally excluded from this map
		// because it's 0 and would always match in HasUsecases checks
		"FLAG_CHAT": FLAG_CHAT,
		"FLAG_COMPLETION": FLAG_COMPLETION,
		"FLAG_EDIT": FLAG_EDIT,
		"FLAG_EMBEDDINGS": FLAG_EMBEDDINGS,
		"FLAG_RERANK": FLAG_RERANK,
		"FLAG_IMAGE": FLAG_IMAGE,
		"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
		"FLAG_TTS": FLAG_TTS,
		"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
		"FLAG_TOKENIZE": FLAG_TOKENIZE,
		"FLAG_VAD": FLAG_VAD,
		"FLAG_LLM": FLAG_LLM,
		"FLAG_VIDEO": FLAG_VIDEO,
		"FLAG_DETECTION": FLAG_DETECTION,
		"FLAG_FACE_RECOGNITION": FLAG_FACE_RECOGNITION,
		"FLAG_SPEAKER_RECOGNITION": FLAG_SPEAKER_RECOGNITION,
		"FLAG_AUDIO_TRANSFORM": FLAG_AUDIO_TRANSFORM,
		"FLAG_DIARIZATION": FLAG_DIARIZATION,
	}
}

func stringToFlag(s string) string {
	return "FLAG_" + strings.ToUpper(s)
}

func GetUsecasesFromYAML(input []string) *ModelConfigUsecase {
	if len(input) == 0 {
		return nil
	}

	result := FLAG_ANY
	flags := GetAllModelConfigUsecases()
	for _, str := range input {
		for _, flag := range []string{stringToFlag(str), str} {
			f, exists := flags[flag]
			if exists {
				result |= f
			}
		}
	}
	return &result
}

// HasUsecases examines a ModelConfig and determines which endpoints have a chance of success.
func (c *ModelConfig) HasUsecases(u ModelConfigUsecase) bool {
	if (c.KnownUsecases != nil) && ((u & *c.KnownUsecases) == u) {
		return true
	}
	return c.GuessUsecases(u)
}

// GuessUsecases is a **heuristic based** function, as the backend in question may not be loaded yet, and the config may not record what it's useful at.
// In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half.
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
	// Backends that are clearly not text-generation
	nonTextGenBackends := []string{
		"whisper", "piper", "kokoro",
		"diffusers", "stablediffusion", "stablediffusion-ggml",
		"rerankers", "silero-vad", "rfdetr", "insightface", "speaker-recognition",
		"transformers-musicgen", "ace-step", "acestep-cpp",
	}

	if (u & FLAG_CHAT) == FLAG_CHAT {
		if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate {
			return false
		}
		if slices.Contains(nonTextGenBackends, c.Backend) {
			return false
		}
		if c.Embeddings != nil && *c.Embeddings {
			return false
		}
	}
	if (u & FLAG_COMPLETION) == FLAG_COMPLETION {
		if c.TemplateConfig.Completion == "" {
			return false
		}
		if slices.Contains(nonTextGenBackends, c.Backend) {
			return false
		}
	}
	if (u & FLAG_EDIT) == FLAG_EDIT {
		if c.TemplateConfig.Edit == "" {
			return false
		}
	}
	if (u & FLAG_EMBEDDINGS) == FLAG_EMBEDDINGS {
		if c.Embeddings == nil || !*c.Embeddings {
			return false
		}
	}
	if (u & FLAG_IMAGE) == FLAG_IMAGE {
		imageBackends := []string{"diffusers", "stablediffusion", "stablediffusion-ggml"}
		if !slices.Contains(imageBackends, c.Backend) {
			return false
		}
		if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" {
			return false
		}
	}
	if (u & FLAG_VIDEO) == FLAG_VIDEO {
		videoBackends := []string{"diffusers", "stablediffusion", "vllm-omni"}
		if !slices.Contains(videoBackends, c.Backend) {
			return false
		}
		if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" {
			return false
		}
	}
	if (u & FLAG_RERANK) == FLAG_RERANK {
		if c.Backend != "rerankers" && (c.Reranking == nil || !*c.Reranking) {
			return false
		}
	}
	if (u & FLAG_TRANSCRIPT) == FLAG_TRANSCRIPT {
		if c.Backend != "whisper" {
			return false
		}
		// whisper models with vad_only option are VAD, not transcription
		if slices.Contains(c.Options, "vad_only") {
			return false
		}
	}
	if (u & FLAG_TTS) == FLAG_TTS {
		ttsBackends := []string{"piper", "transformers-musicgen", "kokoro"}
		if !slices.Contains(ttsBackends, c.Backend) {
			return false
		}
	}
	if (u & FLAG_DETECTION) == FLAG_DETECTION {
		detectionBackends := []string{"rfdetr", "sam3-cpp", "insightface"}
		if !slices.Contains(detectionBackends, c.Backend) {
			return false
		}
	}
	if (u & FLAG_FACE_RECOGNITION) == FLAG_FACE_RECOGNITION {
		faceBackends := []string{"insightface"}
		if !slices.Contains(faceBackends, c.Backend) {
			return false
		}
	}
	if (u & FLAG_SPEAKER_RECOGNITION) == FLAG_SPEAKER_RECOGNITION {
		speakerBackends := []string{"speaker-recognition"}
		if !slices.Contains(speakerBackends, c.Backend) {
			return false
		}
	}
	if (u & FLAG_AUDIO_TRANSFORM) == FLAG_AUDIO_TRANSFORM {
		audioTransformBackends := []string{"localvqe"}
		if !slices.Contains(audioTransformBackends, c.Backend) {
			return false
		}
	}
	if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION {
		soundGenBackends := []string{"transformers-musicgen", "ace-step", "acestep-cpp", "mock-backend"}
		if !slices.Contains(soundGenBackends, c.Backend) {
			return false
		}
	}
	if (u & FLAG_TOKENIZE) == FLAG_TOKENIZE {
		tokenizeCapableBackends := []string{"llama.cpp", "rwkv"}
		if !slices.Contains(tokenizeCapableBackends, c.Backend) {
			return false
		}
	}
	if (u & FLAG_VAD) == FLAG_VAD {
		if c.Backend != "silero-vad" && c.Backend != "sherpa-onnx" && !(c.Backend == "whisper" && slices.Contains(c.Options, "vad_only")) {
			return false
		}
	}
	if (u & FLAG_DIARIZATION) == FLAG_DIARIZATION {
		// vibevoice-cpp emits speaker-labelled segments natively from its
		// ASR pass; sherpa-onnx pipes pyannote segmentation + speaker
		// embeddings + clustering. Both surface as a Diarize gRPC.
		diarizationBackends := []string{"vibevoice-cpp", "sherpa-onnx"}
		if !slices.Contains(diarizationBackends, c.Backend) {
			return false
		}
	}

	return true
}

// BuildCogitoOptions generates cogito options from the model configuration.
// The returned options seed defaults (3 iterations, 3 attempts, forced
// reasoning) and then layer the Agent config on top.
func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
	cogitoOpts := []cogito.Option{
		cogito.WithIterations(3),  // default to 3 iterations
		cogito.WithMaxAttempts(3), // default to 3 attempts
		cogito.WithForceReasoning(),
	}

	// Apply agent configuration options
	if c.Agent.EnableReasoning {
		cogitoOpts = append(cogitoOpts, cogito.WithForceReasoning())
	}
	if c.Agent.EnablePlanning {
		cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
	}
	if c.Agent.EnableMCPPrompts {
		cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
	}
	if c.Agent.EnablePlanReEvaluator {
		cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
	}
	if c.Agent.MaxIterations != 0 {
		cogitoOpts = append(cogitoOpts, cogito.WithIterations(c.Agent.MaxIterations))
	}
	if c.Agent.MaxAttempts != 0 {
		cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(c.Agent.MaxAttempts))
	}
	if c.Agent.DisableSinkState {
		cogitoOpts = append(cogitoOpts, cogito.DisableSinkState)
	}
	if c.Agent.LoopDetection != 0 {
		cogitoOpts = append(cogitoOpts, cogito.WithLoopDetection(c.Agent.LoopDetection))
	}
	if c.Agent.MaxAdjustmentAttempts != 0 {
		cogitoOpts = append(cogitoOpts, cogito.WithMaxAdjustmentAttempts(c.Agent.MaxAdjustmentAttempts))
	}
	if c.Agent.ForceReasoningTool {
		cogitoOpts = append(cogitoOpts, cogito.WithForceReasoningTool())
	}

	return cogitoOpts
}