Compare commits

...

50 Commits

Author SHA1 Message Date
Ettore Di Giacinto
e4fec35772 fix(review): address self-review findings on the distributed install fixes
Three findings from an adversarial review of this branch:

1. CRITICAL - OpCache.GetStatus crashed under concurrent load. m.Map() returns
   the live internal map by reference, so deleting from it on the read path was
   an unsynchronized write to a map four HTTP handlers poll every ~1s -> a
   'concurrent map writes' fatal. Rewritten to iterate a Keys() snapshot, build
   a fresh result map, and apply evictions via the locked DeleteUUID after the
   loop. Added a -race concurrency regression guard.

2. HIGH - GetStatus evicted failed ops too, hiding them from /api/operations
   and breaking the dismiss-failed-op flow (the panel keeps Error != nil ops so
   the admin can read the error and click Dismiss). Eviction now fires only for
   terminal ops with Error == nil (success/cancelled); failures are retained.

3. MEDIUM - DeleteStalePendingBackendOps missed StatusUnhealthy nodes. A node
   marked unhealthy on a NATS ErrNoResponders never transitions to offline
   (health.go skips re-marking it), so its pending ops leaked exactly like the
   offline case. Unhealthy is now reaped via the same stale-heartbeat grace path
   (a fresh-heartbeat node is recovering and keeps its op).

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-07 23:31:53 +00:00
Ettore Di Giacinto
05c0a08e24 fix(galleryop): persist cancellation + periodically reap orphaned ops
Two distributed gaps surfaced when a replica was killed mid-upgrade on a live
cluster, leaving the backend stuck 'processing' in the UI forever:

1. CancelOperation flipped the in-memory status to cancelled and broadcast a
   NATS event but never persisted the terminal status. On the next replica
   restart the still-active row re-hydrated straight back into
   processingBackends and the UI spun again. It now calls store.Cancel(id) so
   the cancel survives a restart.

2. CleanStale (which marks abandoned active ops failed) only ran once on
   startup, so an op orphaned AFTER startup - its owning replica's foreground
   handler goroutine gone - was never reaped until the next restart. Add
   GalleryService.ReapStaleOperations and run it on a 15m ticker (CleanStale
   now returns the reaped count for observability).

Neither is covered by the OpCache self-evict fix: an orphaned op never reaches
Processed, so it would never self-evict.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-07 23:10:53 +00:00
Ettore Di Giacinto
7824105a31 fix(nodes): stream per-node progress during backend upgrade
The install dispatch subscribed to a per-op progress subject and streamed
per-node download ticks; the upgrade dispatch did a bare 15-minute blocking
NATS round-trip with no subscription, so the UI showed progress:0 the whole
time (the 'reinstalling but nothing happens' report on a slow node).

Thread the op ID through BackendManager.UpgradeBackend -> the distributed
manager -> the adapter, and have the adapter subscribe to the per-op progress
subject before the request (extracted into a shared subscribeProgress helper
reused by install/upgrade/force-fallback). The worker's upgradeBackend now
creates the same DebouncedInstallProgressPublisher installBackend uses. An
upgrade is a force-reinstall, so it reuses SubjectNodeBackendInstallProgress
rather than minting a new subject - no new NATS permission, no new
rolling-update compat surface. Reconciler-driven retries pass empty
opID/onProgress and stay on the silent path.

Reproduced on a live cluster: upgrade of llama-cpp-development on agx-orin-slow
sat at progress:0 for 4+ minutes with no per-node feedback.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-07 22:58:09 +00:00
Ettore Di Giacinto
47fa847d55 fix(nodes): clear pending backend ops behind offline/draining nodes
ListDuePendingBackendOps filters status=healthy, so a backend op queued against
a node that went offline (stale heartbeat) or draining (admin action) was never
retried, aged out, or deleted - it leaked forever and kept the UI operation
spinning. Add DeleteStalePendingBackendOps and run it each reconcile pass:
draining nodes are cleared immediately (model rows already purged), offline
nodes once their heartbeat is older than a grace window (blip protection).

Reproduced on a live cluster: orphaned llama-cpp install rows targeting an
offline (nvidia-thor) and a draining (mac-mini-m4) node sat at attempts=0
indefinitely.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-07 22:37:24 +00:00
Ettore Di Giacinto
9a7ebc1151 fix(galleryop): self-evict terminal ops from OpCache.GetStatus
The processingBackends map (the UI 'reinstalling' spinner source) only cleared
an op when a client polled /api/backends/job/:uid. The Manage-page Reinstall and
Upgrade buttons never poll, so completed installs leaked into processingBackends
forever and the backend card spun 'reinstalling' even though the install had
finished. Evict terminal ops on the list read instead; DeleteUUID already
broadcasts the eviction so peer replicas converge.

Reproduced on a live 5-node distributed cluster: 5 backends sat in
processingBackends with underlying jobs reporting completed:true,progress:100.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-07 22:31:04 +00:00
LocalAI [bot]
f6cc90d258 chore: ⬆️ Update mudler/parakeet.cpp to e270af73b94c9a5c37ec516230219ed4580e1db6 (#10212)
⬆️ Update mudler/parakeet.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-07 23:52:44 +02:00
Adira
2c804bef5a fix(config): skip vocab arrays and mmap GGUF headers to speed up startup (#10213)
When the models directory holds many GGUF files, startup parsed every
model's full GGUF — including the tokenizer vocab arrays
(tokenizer.ggml.tokens/scores/merges, often >100k entries) — once per
model while guessing defaults. On slow storage (e.g. a models directory
on a Docker volume) those hundreds of thousands of tiny reads dominate
boot time before the HTTP server comes up.

The default-guessing path and the VRAM metadata reader only consume
scalar metadata and array lengths, never the array contents. Parse with
SkipLargeMetadata (seek past large arrays) and UseMMap (fault in a few
header pages instead of issuing per-element read() syscalls). For a
256k-token vocab this cuts the parse from ~524k read() syscalls to 8.
The mapping is released when ParseGGUFFile returns.

Fixes #9790

Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Signed-off-by: Adira Denis Muhando <dennisadira@gmail.com>
2026-06-07 23:33:52 +02:00
LocalAI [bot]
6070402477 chore(model gallery): 🤖 add 1 new models via gallery agent (#10209)
chore(model gallery): 🤖 add new models via gallery agent

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-07 22:09:32 +02:00
LocalAI [bot]
67f80a152b fix(mtp): don't auto-enable self-spec MTP for draft-only assistant GGUFs (#10208)
Gemma4 MTP (ggml-org/llama.cpp#23398) registers the prediction head as a
separate `gemma4-assistant` architecture. That assistant GGUF still carries
`<arch>.nextn_predict_layers`, so the architecture-agnostic detection in
HasEmbeddedMTPHead matched it and appended the `spec_type:draft-mtp` defaults.

Unlike the DeepSeek/Qwen embedded-head models, an assistant checkpoint cannot
self-speculate: it is a draft model that requires a paired target context
(`ctx_other`) and throws if loaded alone. Auto-applying the self-spec defaults
to a standalone assistant import therefore produces a broken config.

Guard the detection against draft-only assistant architectures (the `-assistant`
suffix is upstream's naming convention) so importing one no longer yields a
self-speculation config. Two-model target+draft pairing remains expressible
manually via `draft_model:` and is left to a follow-up.


Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-07 22:09:02 +02:00
LocalAI [bot]
a7cb587d96 feat(parakeet-cpp): real segment timestamps (NeMo-faithful) (#10207)
* feat(parakeet-cpp): real segment timestamps (NeMo-faithful)

Offline: replace the single synthetic whole-clip segment with multiple
segments grouped exactly like NeMo's get_segment_offsets - a new segment
after sentence-ending punctuation ('. ? !'), each carrying start/end and
its time-window token ids. The optional model option segment_gap_threshold
(NeMo's unit: encoder FRAMES, default 0=off) adds NeMo's silence-gap split,
converted to seconds via the JSON frame_sec the engine now reports.
Per-segment words are still gated behind timestamp_granularities=["word"];
a zero-word document falls back to a single text segment.

Streaming: when libparakeet.so exposes the ABI v4 JSON entry points
(probed), drive parakeet_capi_stream_feed_json / _finalize_json and
accumulate the streamed per-word timestamps into per-utterance segments
(EOU stays the boundary), so streaming FinalResult segments now carry
start/end. Falls back to the text-only feed against an older library.

Pure-Go specs cover splitWordsIntoSegments (punctuation + gap rules, NeMo
elif order, fallback), transcriptResultFromDoc (multi-segment, token
windows, word-granularity gate), and the streaming segmenter.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* docs(audio): document parakeet-cpp segment timestamps + segment_gap_threshold

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* test(parakeet-cpp): update model-gated specs for multi-segment output

The offline AudioTranscription specs asserted the old single synthetic
segment (Segments HaveLen(1), Segments[0].Text == res.Text). With
NeMo-faithful segmentation a multi-sentence clip now yields multiple
punctuation-delimited segments, so assert the new contract instead:
one-or-more time-ordered segments, each with text and (under word
granularity) per-segment words whose span tracks the segment start/end.
Caught by running the model-gated suite on the dgx (GB10) against the
real tdt_ctc-110m + realtime_eou models.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-07 22:08:24 +02:00
LocalAI [bot]
f7c74ad2da chore: ⬆️ Update ggml-org/llama.cpp to 31e82494c0a3913c919c1027fa70500fbf4c07dd (#10191)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-07 10:43:17 +02:00
LocalAI [bot]
7402d1fd20 chore(turboquant): bump to 7d9715f1 + fix compilation against rebased fork (#10205)
* chore(turboquant): bump TheTom/llama-cpp-turboquant to 7d9715f1

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* fix(turboquant): drop obsolete legacy-spec shim after fork rebased

The TheTom/llama-cpp-turboquant fork (pin c9aa86a) rebased past the
upstream common_params_speculative refactor (ggml-org/llama.cpp
#22397/#22838/#22964), the model_tgt rename (#22838) and get_media_marker
(#21962). The old fork-compat shim forced now-wrong legacy code paths,
breaking the build with errors like 'struct common_params_speculative has
no member named mparams_dft / type' and 'server_context_impl has no member
named model'.

Remove the obsolete LOCALAI_LEGACY_LLAMA_CPP_SPEC branches from the shared
grpc-server.cpp (stock llama-cpp and the modern fork both take the modern
path now), and narrow the one remaining gap (the fork still lacks
common_params::checkpoint_min_step) to a dedicated
LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP guard injected by
patch-grpc-server.sh. The patch script now only adds the turbo2/3/4
KV-cache types and injects that one macro.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* fix(turboquant): HIP-port the fork's CUDA additions (copy2d 3D-peer + cudaEventCreate)

The turboquant fork adds/modifies a few ggml-cuda.cu spots with CUDA APIs that
ggml's HIP/MUSA shim does not provide, breaking the -gpu-rocm-hipblas-turboquant
build. patches/0001-hip-guard-copy2d-peer-fastpath.patch (applied by
apply-patches.sh) ports them:

- Guard ggml_cuda_copy2d_across_devices's 3D-peer copy fast path with
  #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) so HIP/MUSA fall through
  to the existing cudaMemcpyAsync staging fallback (HIP genuinely lacks
  cudaMemcpy3DPeerAsync, per the fork's own comment).
- Create the device event in ggml_backend_cuda_device_event_new with the
  HIP-aliased cudaEventCreateWithFlags(.., cudaEventDisableTiming) instead of the
  un-aliased plain cudaEventCreate, matching this file's own usage elsewhere.

CUDA builds are unaffected.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* ci(turboquant): drop the ROCm/hipblas build flavor

The TheTom/llama-cpp-turboquant fork is not ROCm-clean at the current pin:
beyond the CUDA-API gaps already patched (3D-peer copy, cudaEventCreate),
its llama.cpp base fails to compile the flash-attention MMA f16 kernels for
head-dim 640 under HIP (cols_per_warp evaluates to 0 -> division-by-zero /
non-constant static asserts in fattn-mma-f16.cuh). That is a deep
ggml-on-ROCm kernel issue, not something a small fork patch can paper over.

Drop -gpu-rocm-hipblas-turboquant from the build matrix so turboquant still
ships for cpu / cublas / vulkan / sycl. Re-add it once the fork's HIP path
compiles (or upstream ggml fixes the large-head-dim MMA kernels for ROCm).

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-07 10:42:06 +02:00
LocalAI [bot]
8c42695ef8 chore: ⬆️ Update ggml-org/whisper.cpp to a8ec021f2750a473ff4a8f3883bc9fdf5feafa84 (#10202)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-07 08:37:42 +02:00
LocalAI [bot]
72e3241431 chore: ⬆️ Update mudler/parakeet.cpp to abd0087dcc92ec5ad1f96f9fd86c49eb26a5ce67 (#10204)
⬆️ Update mudler/parakeet.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-07 00:37:28 +02:00
LocalAI [bot]
cd2bf95862 fix(docs): use relearn notice shortcode instead of unsupported alert (#10206)
The Hugo relearn theme does not provide an "alert" shortcode, so the
docs deploy failed at the Build site step:

  failed to extract shortcode: template for shortcode "alert" not found
  docs/content/features/distributed-mode.md:136

Convert the warning block to the theme-supported notice shortcode used
everywhere else in the docs.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-07 00:37:12 +02:00
LocalAI [bot]
f64b72dd7d feat: support Ideogram4 in stablediffusion-ggml backend + gallery (#10201)
* feat(stablediffusion-ggml): support Ideogram4 unconditional diffusion model

Bump stable-diffusion.cpp from 1f9ee88 to b9254dd, the upstream commit that
adds Ideogram4 support (leejet/stable-diffusion.cpp#1609). Ideogram4 derives
its classifier-free guidance from a separate unconditional diffusion model,
exposed upstream through the new sd_ctx_params_t.uncond_diffusion_model_path
field.

Wire that field into the gosd wrapper via a new uncond_diffusion_model_path
option. The _path suffix is deliberate: the Go loader only resolves options
whose name contains "path" to an absolute path under the model directory, so
this keeps the option consistent with diffusion_model_path and
high_noise_diffusion_model_path.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* feat(gallery): add Ideogram4 stablediffusion-ggml models

Single-file GGUF weights for Ideogram4 are now published
(stduhpf/ideogram-4-gguf), so add the model to the gallery. Ideogram4 is a
text-to-image model with strong, accurate in-image text rendering, driven by
a Qwen3-VL-8B text encoder and real classifier-free guidance from a separate
unconditional diffusion model (the uncond_diffusion_model_path support added
in the preceding commit).

Two index entries, both built on gallery/virtual.yaml with the full config
inlined in overrides (same pattern as the other models, no dedicated template
file):
- ideogram-4-iq4nl-ggml (4-bit, ~11.6GB diffusion)
- ideogram-4-q8_0-ggml  (8-bit, ~20GB diffusion)

Each bundles the diffusion + unconditional GGUF (stduhpf), the
Qwen3-VL-8B-Instruct text encoder (unsloth), and the FLUX.2 VAE (Comfy-Org
mirror, non-gated). cfg_scale is 7 to match the upstream Ideogram4 default,
since it performs real CFG unlike the guidance-distilled Flux/Z-Image models.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-06 22:50:12 +02:00
LocalAI [bot]
03c84cff28 feat(parakeet-cpp): nemotron-3.5-asr multilingual streaming model + request language support (#10199)
* feat(parakeet-cpp): honor request language (multilingual nemotron) on batched + streaming paths

Reads opts.GetLanguage() and threads it through to the new
parakeet_capi_transcribe_pcm_batch_json_lang and parakeet_capi_stream_begin_lang
C-API entry points, both probed with Dlsym so the backend still loads against an
older libparakeet.so (falling back to the non-lang paths, i.e. model default).

parakeet.cpp's batched C-API takes a single target_lang for the whole batch, so
the dispatcher only coalesces same-language requests: a request whose language
differs from the batch leader is held as a single carry-over and becomes the
leader of the next batch, never dropped and never left waiting (including on
shutdown). A new batcher test asserts no dispatched batch is ever mixed-language
and that every submitted request still receives a reply.

Assisted-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* feat(gallery): add parakeet-cpp-nemotron-3.5-asr-streaming-0.6b; bump parakeet.cpp pin

Adds the multilingual prompt-conditioned streaming model to the gallery (q8_0
default, OpenMDW-1.1) and bumps the parakeet-cpp backend pin to the parakeet.cpp
commit that ships nemotron support plus batched causal subsampling and the
batched target_lang C-API.

Assisted-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-06 13:53:10 +02:00
LocalAI [bot]
9bc69c9e5f chore(model gallery): 🤖 add 1 new models via gallery agent (#10200)
chore(model gallery): 🤖 add new models via gallery agent

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-06 13:52:46 +02:00
LocalAI [bot]
1e6c9cfd60 chore: ⬆️ Update ikawrakow/ik_llama.cpp to 6b9de3dbaa21ae95ea80638e5ee836795cc48c93 (#10190)
⬆️ Update ikawrakow/ik_llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-06 09:42:43 +02:00
LocalAI [bot]
0e6712f734 chore: ⬆️ Update mudler/parakeet.cpp to 843600590f96a31467a5199f827c253f34c110f7 (#10198)
chore(parakeet-cpp): bump pin to banded long-audio attention (843600590)

Update PARAKEET_VERSION to mudler/parakeet.cpp@843600590f
(merge of parakeet.cpp#9). Brings NeMo rel_pos_local_attn banded/Longformer
attention with the chunk-matmul construction: long audio now uses O(T*window)
attention instead of global O(T^2), fixing the encoder OOM on long clips
(~16.6-min clip: 54GB->9.4GB peak, ~4x faster) at NeMo's full [128,128] window.
Short clips are unchanged (global path). No C-ABI change.


Assisted-by: Claude:claude-opus-4-8

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-06 09:25:25 +02:00
LocalAI [bot]
0e4cee9a97 chore: bump LocalAGI + localrecall (fix pgvector hybrid search seqscan, #10186) (#10192)
chore: bump LocalAGI and localrecall (index-backed RRF hybrid search)

Bumps the agent stack to pull in the PostgreSQL hybrid-search fix:

- mudler/localrecall -> v0.6.3-...-9a3b3321a9cd (mudler/LocalRecall#46, merged)
- mudler/LocalAGI    -> ...-14aed1ae4336 (mudler/LocalAGI#477, merged)

localrecall's hybrid search previously sorted on a wrapped scalar
similarity expression, which blinded the planner into a full sequential
scan over every row and exceeded the statement timeout on large
collections, returning an empty result set. It now uses the canonical
Reciprocal Rank Fusion pattern (index-backed candidate retrieval + FULL
OUTER JOIN + weighted RRF).

Fixes #10186

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-06 09:16:59 +02:00
Copilot
352b7ec604 Harden gallery-agent Hugging Face fetches against transient rate limiting (#10187)
* Initial plan

* fix: retry HuggingFace trending fetch on transient rate limits

* fix: handle body close/write errors in huggingface retry paths

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-06-05 23:43:06 +02:00
LocalAI [bot]
ba706422fb chore: ⬆️ Update vllm-project/vllm cu130 wheel to 0.22.1 (#10188)
⬆️ Update vllm-project/vllm cu130 wheel

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-05 23:42:50 +02:00
LocalAI [bot]
e837921c2c feat: forward reasoning_effort to the backend so jinja models honor it (#10184)
* feat: forward reasoning_effort to the backend so jinja models honor it

reasoning_effort was only mapped to the binary enable_thinking toggle and
otherwise reached Go-side templates — it was never sent to the backend. So
jinja-templated models whose chat template keys on reasoning_effort (gpt-oss
Harmony, LFM2.5) could not be driven by it: LFM2.5 ignores enable_thinking and
kept emitting <think>.

Forward the effective reasoning_effort to the backend as a chat_template_kwarg
(mirroring enable_thinking) in grpc-server.cpp, and put it in PredictOptions
metadata (gRPCPredictOpts). Add a config-level default: ModelConfig.reasoning_effort
and Pipeline.reasoning_effort, resolved by ModelConfig.ApplyReasoningEffort
(request value overrides config default, none->disable / level->enable, an
operator's reasoning.disable wins). request.go now uses that helper.

Assisted-by: Claude:claude-opus-4-8 go test, golangci-lint
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(realtime): set the pipeline LLM's reasoning_effort

Apply Pipeline.ReasoningEffort to the pipeline's LLM config when the realtime
model is built (per-session copy, overrides the LLM's own reasoning_effort),
and surface the resolved effort on the template input so Go-templated models
get it too. jinja models receive it via the backend metadata. This lets a
realtime pipeline disable thinking on models that only honor reasoning_effort
(e.g. LFM2.5), which enable_thinking can't.

Assisted-by: Claude:claude-opus-4-8 go test, golangci-lint
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-05 13:45:43 +00:00
Richard Palethorpe
73385713ca feat(distributed): enforce registration token for worker file transfer (#10183)
The worker HTTP file-transfer server is authenticated by the registration
token via checkBearerToken, which fails open on an empty token: every
/v1/files, /v1/files-list and /v1/backend-logs request is then served
unauthenticated, granting read/write to the worker's models/staging/data
directories. The fail-open was also silent (the only auth log sat on the
unreachable reject branch), and the worker process never runs
DistributedConfig.Validate(), so the existing frontend warning did not
cover the component that exposes the server.

Mirror the NatsRequireAuth pattern: keep anonymous as the default but make
it loud and opt-in enforceable.

- Log a prominent warning when the file-transfer server starts tokenless.
- Add LOCALAI_REGISTRATION_REQUIRE_AUTH: DistributedConfig.Validate() errors
  on an empty token (frontend) and the worker refuses to start (fail-fast,
  before registration), so production can fail closed. Also satisfies the
  F-003 suggestion to fail Validate() on distributed + empty token.
- Add LOCALAI_DISTRIBUTED_REQUIRE_AUTH umbrella switch implying both
  RegistrationRequireAuth and NatsRequireAuth — one production knob locking
  down the registration/file-transfer layer and the NATS bus together; the
  granular flags remain available as single-layer overrides. Wired into the
  frontend, supervisor worker, and agent worker (vLLM worker has neither a
  NATS connection nor a file-transfer server, so it is left untouched).
- Document in distributed-mode.md (warning callout + flag tables).

Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-06-05 14:34:28 +02:00
LocalAI [bot]
a4e671779a chore: ⬆️ Update ggml-org/whisper.cpp to 99613cb720b65036237d44b52f753b51f75c2797 (#10178)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-05 09:04:25 +02:00
LocalAI [bot]
7051b2e0a1 chore: ⬆️ Update ggml-org/llama.cpp to 7c158fbb4aec1bdc9c81d6ca0e785139f4826fae (#10179)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-05 09:04:10 +02:00
LocalAI [bot]
469737101a chore: ⬆️ Update ikawrakow/ik_llama.cpp to 1520eda980564241434b791ce2bbbd128c4be9ea (#10180)
⬆️ Update ikawrakow/ik_llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-05 09:03:08 +02:00
LocalAI [bot]
858257eaf0 fix(distributed): self-heal stale 'model not loaded' routing (#10181)
* fix(distributed): self-heal stale 'model not loaded' routing

In distributed mode the registry can list a model as loaded on a node
while the worker has evicted it (autonomous LRU eviction, an out-of-band
unload, etc.) yet the backend process survives. The router's cached-node
check only verifies the process is alive (probeHealth), so it routes there
and inference fails with "<backend>: model not loaded" — and stays broken
until the controller restarts and rebuilds its registry.

InFlightTrackingClient now reconciles this: when a tracked inference call
returns a model-not-loaded error, it drops the stale replica row
(RemoveNodeModel) so the next request reloads the model on a healthy node
instead of routing back to the evicted one. The original error is returned
unchanged; only the registry is corrected.

Assisted-by: Claude:claude-opus-4-8 go vet
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(distributed): typed model-not-loaded error via gRPC status code

Replace the controller-side error-string match with a shared, code-aware
helper. Go error types don't survive the gRPC boundary, so the signal is
carried as a status code (FailedPrecondition):

- pkg/grpc/grpcerrors: ModelNotLoaded(backend) constructor +
  IsModelNotLoaded(err) checker (status-code first, message fallback for
  backends not yet migrated).
- InFlightTrackingClient.reconcile now uses grpcerrors.IsModelNotLoaded.
- Migrate the Go backends that emit this error (parakeet-cpp, cloud-proxy,
  rfdetr-cpp) to the typed constructor.

Acting on a false positive is harmless (the model is just reloaded).

Assisted-by: Claude:claude-opus-4-8 go vet
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-05 09:01:36 +02:00
Adira
ef80a0e825 fix(config): add face/speaker recognition constants and register insightface + speaker-recognition (#10110)
FLAG_FACE_RECOGNITION and FLAG_SPEAKER_RECOGNITION already existed as
ModelConfigUsecase bitmask flags, and GuessUsecases already gate-checks
both backends by name — but BackendCapabilities had no entries for
either, so the UI could not classify them.

Also missing were the Method* constants for the five proto-defined RPCs
these backends implement (FaceVerify, FaceAnalyze, VoiceVerify,
VoiceEmbed, VoiceAnalyze) and the corresponding Usecase* strings
and UsecaseInfoMap entries needed to wire them into the rest of the
capability system.

Changes:
- Add MethodFaceVerify, MethodFaceAnalyze, MethodVoiceVerify,
  MethodVoiceEmbed, MethodVoiceAnalyze GRPCMethod constants
- Add UsecaseFaceRecognition ("face_recognition") and
  UsecaseSpeakerRecognition ("speaker_recognition") Usecase constants
- Add UsecaseInfoMap entries for both new usecases, referencing the
  existing FLAG_FACE_RECOGNITION and FLAG_SPEAKER_RECOGNITION flags
- Register insightface: Embedding + Detect + FaceVerify + FaceAnalyze
- Register speaker-recognition: VoiceVerify + VoiceEmbed + VoiceAnalyze

Follows up on #10107 which left these two out because they needed new
constants first.

Assisted-by: Claude Sonnet 4.6 <noreply@anthropic.com>

Signed-off-by: Adira Denis Muhando <dennisadira@gmail.com>
2026-06-04 21:48:01 +02:00
LocalAI [bot]
92726f7631 fix(distributed): stage directory-based models to remote nodes (#10175)
Distributed file-staging treated every model path field (ModelFile, etc.)
as a single regular file: it os.Open'd the path and streamed its fd as the
HTTP PUT body. For directory-based models — e.g. qwen3-tts-cpp, whose
weights and tokenizer ggufs live under one directory referenced by
parameters.model — opening the directory succeeds but reading its fd
returns EISDIR, so routing the model to a remote NATS worker failed with
"read /models/<model>: is a directory". Single-file models were unaffected,
so only multi-file pipelines (e.g. the realtime TTS stage) broke.

stageModelFiles now detects a directory path field and stages each
contained file individually (via the new stageDirectory helper), preserving
structure with the existing StagingKeyMapper and rewriting the field to the
remote directory (deriving ModelPath as before). countStageableFiles makes
the progress total count a directory's files so the staging tracker stays
accurate.

Assisted-by: Claude:claude-opus-4-8 go vet

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-04 18:05:38 +02:00
LocalAI [bot]
994063ba9a feat(qwen3-tts-cpp): normalize request language for flexible matching (#10174)
The qwen3-tts.cpp backend honored the request `language` field only via exact lowercase two-letter codes in the C++ language_to_id table, silently defaulting to English for anything else (en-US, EN, english, ...).

Add normalizeLanguage() in the Go handler: lowercase + trim, strip the region/locale suffix (en-US, pt_BR, zh-Hans -> en/pt/zh), and resolve common English full names (english -> en). The canonical codes match the existing C++ table, so no C++ change is needed. Covered by a pure-Go Ginkgo spec. Also document the language field and accepted forms under the Qwen3-TTS docs.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-04 17:26:31 +02:00
LocalAI [bot]
c1a55cf72d chore: ⬆️ Update mudler/parakeet.cpp to b11fe5bca78ad8b342dd559a43d76df3984bb447 (#10167)
⬆️ Update mudler/parakeet.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-04 12:07:09 +02:00
LocalAI [bot]
96758841d8 chore: ⬆️ Update predict-woo/qwen3-tts.cpp to 136e5d36c17083da0321fd96512dc7b263f94a44 (#10165)
⬆️ Update predict-woo/qwen3-tts.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-04 12:06:55 +02:00
LocalAI [bot]
7a59260621 chore: ⬆️ Update CrispStrobe/CrispASR to 13d54e110e1538e0f0bc3af0680b9ab246cfb48d (#10145)
⬆️ Update CrispStrobe/CrispASR

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-04 12:06:32 +02:00
LocalAI [bot]
27e63b9a78 feat(tts): support per-request instructions and params (#10172)
The OpenAI-compatible TTS endpoint accepts an `instructions` field, but it
was silently dropped at the HTTP->gRPC boundary: neither schema.TTSRequest
nor the gRPC TTSRequest proto carried it, so backends could only read such a
value from static YAML options (identical for every request). This blocked
per-line emotion/style and, for Qwen3-TTS VoiceDesign, limited a model config
to a single designed voice.

Plumb a generic per-request instruction string end to end, plus an optional
backend-specific params map:

- proto: add `optional string instructions` and `map<string,string> params`
  to TTSRequest.
- schema: add Instructions (maps OpenAI `instructions`) and Params (LocalAI
  extension) to schema.TTSRequest.
- core: thread both through ModelTTS/ModelTTSStream via a newTTSRequest helper
  that attaches instructions only when non-empty (so backends can fall back to
  YAML when unset); forward them from the /v1/audio/speech handler.
- qwen-tts: prefer the per-request instruction over the YAML `instruct` option
  (used by both mode detection and generation) and merge per-request params.
- chatterbox: merge per-request params (coerced to float/int/bool) over YAML
  options into generate() kwargs.

Fully backward compatible: empty instructions fall back to the YAML option and
backends that don't support style/voice instructions ignore the field.

Closes #10164


Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-04 11:45:02 +02:00
LocalAI [bot]
55c0911c23 chore: ⬆️ Update leejet/stable-diffusion.cpp to 1f9ee88e09c258053fa59d5e05e23dfb10fa0b13 (#10166)
⬆️ Update leejet/stable-diffusion.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-04 09:34:34 +02:00
LocalAI [bot]
f6cb6ab6d9 chore: ⬆️ Update ggml-org/llama.cpp to 94a220cd6745e6e3f8de62870b66fd5b9bc92700 (#10168)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-04 09:34:13 +02:00
LocalAI [bot]
9f11b09c6a chore(model-gallery): ⬆️ update checksum (#10169)
⬆️ Checksum updates in gallery/index.yaml

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-04 00:32:15 +02:00
LocalAI [bot]
a5c4f822f0 chore: ⬆️ Update antirez/ds4 to 477c0e82e2699b35a65fd0a1ed6fe66b41087dfe (#10142)
⬆️ Update antirez/ds4

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-03 19:45:23 +02:00
LocalAI [bot]
fb36c262fe chore(model gallery): 🤖 add 1 new models via gallery agent (#10163)
chore(model gallery): 🤖 add new models via gallery agent

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-03 19:44:51 +02:00
LocalAI [bot]
0e4e8980e6 chore: ⬆️ Update ggml-org/llama.cpp to 5c394fdc8b564eff6faacc50a139529d875f0e36 (#10143)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-03 19:44:21 +02:00
Richard Palethorpe
3a932a9803 feat(distributed): Add NATS JWT authentication and TLS/mTLS options (#10159)
* feat(distributed): NATS JWT auth, TLS/mTLS options, and e2e coverage

Mint per-node NATS user JWTs at registration when LOCALAI_NATS_ACCOUNT_SEED
is set, and connect workers with scoped credentials from the register response.
Add optional LOCALAI_NATS_TLS_CA/CERT/KEY for private CA and mTLS alongside
tls:// URLs, plus test-e2e-distributed and NatsJWT container e2e specs.

Document JWT setup (nats-auth-setup.sh) and TLS env vars in distributed-mode.

Assisted-by: Grok:grok grok-build
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(distributed): correct NATS JWT scoping and harden client auth

The JWT-auth path added in 46467cc7 had several gaps that fail silently
under LOCALAI_NATS_REQUIRE_AUTH:

- Agent-worker minted JWTs did not allow the subjects the agent worker
  actually subscribes to (jobs.mcp-ci.new and nodes.<id>.backend.stop),
  so MCP-CI jobs and backend-stop session cleanup were silently dropped.
  Scope the agent permission set to those subjects.
- NATS subscription permission violations were swallowed (Subscribe
  returned a live-but-dead subscription). Confirm subscriptions with a
  server round-trip so a denial surfaces synchronously, and log async
  permission errors.
- The backend worker connected anonymously when given a JWT without its
  paired seed; reject the unpaired credential instead.
- The documented service-user permissions in nats-auth-setup.sh omitted
  prefixcache.>, which the frontend publishes and subscribes; add it.

Also: add a credential-provider hook to the messaging client (consumed by
the follow-up credential-lifecycle change), drop the always-nil error from
NatsMessagingOptions, run go mod tidy (jwt/v2 and nkeys are now direct),
and gofmt the feature's files.

Tests: an agent-JWT e2e spec that connects to the enforcing NATS server
and exercises every subscription the agent worker makes, plus permission
allow-list coverage unit tests.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(distributed): acquire and auto-refresh worker NATS credentials

Workers fetched NATS credentials once at startup, which broke two cases
under JWT auth: a worker that registered while still pending admin
approval never received a minted JWT (it connected unauthenticated and
gave up), and a long-running worker's 24h JWT expired with no way to renew
it.

Introduce workerregistry.NATSCredentialManager, built on idempotent
re-registration (the frontend preserves the node row and mints a fresh JWT
each call):

- Acquire re-registers through admin approval until the node is approved
  and credentials are minted (or returns the first success when auth is
  not required, preserving anonymous-NATS behavior).
- RefreshLoop re-registers before the JWT expires (~75% of its lifetime),
  updating the credentials served to the connection.
- Both are bounded (default 100 attempts / consecutive failures) and
  return an error on exhaustion, so an unapprovable or unrenewable worker
  exits non-zero and surfaces the problem instead of hanging or drifting
  toward an expired credential.

The messaging client gains WithUserJWTProvider, fetching credentials on
each (re)connect so the connection transparently adopts a refreshed JWT
when the server expires the old one. RegisterFull exposes the approval
status and full response; Register delegates to it.

Both the backend worker and the agent worker are wired to this: explicit
env credentials are used as-is, minted credentials are acquired-with-wait
and refreshed, and a permanent refresh failure shuts the worker down so it
restarts and re-acquires.

Tests cover Acquire (wait-through-pending, bounded give-up, context
cancel), RefreshLoop (refresh-before-expiry, bounded failure, no-expiry
exit) and jwtExpiry decoding. Docs updated in distributed-mode.md.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-06-03 19:43:56 +02:00
LocalAI [bot]
9d10418593 fix(parakeet-cpp): convert audio before the non-batched transcribe path (#10161)
The direct (non-batched) transcription path handed the original upload
path straight to the C library via parakeet_capi_transcribe_path_json.
That loader only understands 16 kHz mono WAV/PCM, so any other format
(MP3, etc.) failed with "parakeet: failed to load audio: <file>".

Only the batched path converted the input (via decodeWavMono16k ->
utils.AudioToWav). Every other audio backend (whisper, crispasr)
converts unconditionally with utils.AudioToWav before handing the file
to its engine; the parakeet-cpp fallback was the lone exception.

Extract a convertToWavMono16k helper (reused by decodeWavMono16k) that
produces a 16 kHz mono WAV in a temp dir, and run the non-batched path
through it before calling the C loader. WAV inputs already in the target
format are passed through without ffmpeg.

Add specs covering the helper (decodable copy + cleanup, and an error on
a missing input) that need neither the model, the C library, nor ffmpeg.


Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-03 15:06:57 +02:00
dependabot[bot]
5470051d4d chore(deps): bump grpcio from 1.80.0 to 1.81.0 in /backend/python/transformers (#10158)
chore(deps): bump grpcio in /backend/python/transformers

Bumps [grpcio](https://github.com/grpc/grpc) from 1.80.0 to 1.81.0.
- [Release notes](https://github.com/grpc/grpc/releases)
- [Commits](https://github.com/grpc/grpc/compare/v1.80.0...v1.81.0)

---
updated-dependencies:
- dependency-name: grpcio
  dependency-version: 1.81.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-03 10:38:43 +02:00
LocalAI [bot]
68c5eeebc3 chore: ⬆️ Update ggml-org/whisper.cpp to 610e664ba7cfe3af46125ed1b5a1184fccb51bcd (#10140)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-03 10:38:28 +02:00
dependabot[bot]
1531fabe23 chore(deps): bump securego/gosec from 2.22.9 to 2.27.1 (#10147)
Bumps [securego/gosec](https://github.com/securego/gosec) from 2.22.9 to 2.27.1.
- [Release notes](https://github.com/securego/gosec/releases)
- [Commits](https://github.com/securego/gosec/compare/v2.22.9...v2.27.1)

---
updated-dependencies:
- dependency-name: securego/gosec
  dependency-version: 2.27.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-03 10:38:07 +02:00
LocalAI [bot]
b7673d5b76 chore: ⬆️ Update leejet/stable-diffusion.cpp to 2d40a8b2adcdf8b5b0ca0535f3bb7801b6ba13e5 (#10144)
⬆️ Update leejet/stable-diffusion.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-06-03 10:37:51 +02:00
dependabot[bot]
b64bdaf406 chore(deps): bump github.com/google/go-containerregistry from 0.21.5 to 0.21.6 (#10149)
chore(deps): bump github.com/google/go-containerregistry

Bumps [github.com/google/go-containerregistry](https://github.com/google/go-containerregistry) from 0.21.5 to 0.21.6.
- [Release notes](https://github.com/google/go-containerregistry/releases)
- [Commits](https://github.com/google/go-containerregistry/compare/v0.21.5...v0.21.6)

---
updated-dependencies:
- dependency-name: github.com/google/go-containerregistry
  dependency-version: 0.21.6
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-03 10:37:33 +02:00
dependabot[bot]
eebf08ff1d chore(deps): bump grpcio from 1.80.0 to 1.81.0 in /backend/python/vllm (#10157)
Bumps [grpcio](https://github.com/grpc/grpc) from 1.80.0 to 1.81.0.
- [Release notes](https://github.com/grpc/grpc/releases)
- [Commits](https://github.com/grpc/grpc/compare/v1.80.0...v1.81.0)

---
updated-dependencies:
- dependency-name: grpcio
  dependency-version: 1.81.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-03 10:37:16 +02:00
127 changed files with 5273 additions and 553 deletions

View File

@@ -1766,20 +1766,6 @@ include:
dockerfile: "./backend/Dockerfile.llama-cpp"
context: "./"
ubuntu-version: '2404'
- build-type: 'hipblas'
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-rocm-hipblas-turboquant'
builder-base-image: 'quay.io/go-skynet/ci-cache:base-grpc-rocm-amd64'
runs-on: 'ubuntu-latest'
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
skip-drivers: 'false'
backend: "turboquant"
dockerfile: "./backend/Dockerfile.turboquant"
context: "./"
ubuntu-version: '2404'
- build-type: 'hipblas'
cuda-major-version: ""
cuda-minor-version: ""

View File

@@ -3,6 +3,7 @@ package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
@@ -113,6 +114,17 @@ func main() {
fmt.Println("Searching for trending models on HuggingFace...")
rawModels, err := client.GetTrending(searchTerm, limit)
if err != nil {
if errors.Is(err, hfapi.ErrRateLimited) {
fmt.Printf("HuggingFace API is rate limited after retries, skipping this run: %v\n", err)
writeSummary(AddedModelSummary{
SearchTerm: searchTerm,
TotalFound: 0,
ModelsAdded: 0,
Quantization: quantization,
ProcessingTime: time.Since(startTime).String(),
})
return
}
fmt.Fprintf(os.Stderr, "Error fetching models: %v\n", err)
os.Exit(1)
}
@@ -277,4 +289,3 @@ func truncateString(s string, maxLen int) string {
}
return s[:maxLen] + "..."
}

View File

@@ -18,7 +18,7 @@ jobs:
if: ${{ github.actor != 'dependabot[bot]' }}
- name: Run Gosec Security Scanner
if: ${{ github.actor != 'dependabot[bot]' }}
uses: securego/gosec@v2.22.9
uses: securego/gosec@v2.27.1
with:
# we let the report trigger content trigger a failure using the GitHub Security features.
args: '-no-fail -fmt sarif -out results.sarif ./...'

View File

@@ -309,13 +309,20 @@ run-e2e-aio: protogen-go
@echo 'Running e2e AIO tests'
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
# Distributed architecture e2e (PostgreSQL + NATS via testcontainers).
# Includes NatsJWT specs (JWT-enabled NATS). Requires Docker.
# VLLMMultinode is excluded here; use test-e2e-vllm-multinode for that.
test-e2e-distributed: protogen-go
@echo 'Running distributed e2e tests (label Distributed, incl. NatsJWT)'
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter='Distributed && !VLLMMultinode' --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e/distributed
# vLLM multi-node DP smoke (CPU). Builds local-ai:tests and the
# cpu-vllm backend from the current working tree, then drives a
# head + headless follower via testcontainers-go and asserts a chat
# completion. BuildKit caches both images, so re-runs only rebuild
# what changed. The test lives under tests/e2e/distributed and is
# selected by the VLLMMultinode label so it doesn't run alongside
# the other distributed-suite tests by default.
# test-e2e-distributed.
test-e2e-vllm-multinode: docker-build-e2e extract-backend-vllm protogen-go
@echo 'Running e2e vLLM multi-node DP test'
LOCALAI_IMAGE=local-ai \

View File

@@ -537,6 +537,15 @@ message TTSRequest {
string dst = 3;
string voice = 4;
optional string language = 5;
// instructions is a free-form, per-request style/voice description (maps to
// the OpenAI `instructions` field). Backends that support expressive synthesis
// (e.g. Qwen3-TTS CustomVoice/VoiceDesign) prefer this over the static YAML
// option when set; backends that don't simply ignore it.
optional string instructions = 6;
// params carries optional, backend-specific per-request generation parameters
// (e.g. Chatterbox exaggeration/cfg_weight/temperature). Values are strings and
// coerced by the backend; unset leaves the backend's configured defaults.
map<string, string> params = 7;
}
message VADRequest {

View File

@@ -1,10 +1,10 @@
# ds4 backend Makefile.
#
# Upstream pin lives below as DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
# Upstream pin lives below as DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
# (.github/bump_deps.sh) can find and update it - matches the
# llama-cpp / ik-llama-cpp / turboquant convention.
DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
DS4_REPO?=https://github.com/antirez/ds4
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))

View File

@@ -1,5 +1,5 @@
IK_LLAMA_VERSION?=3f40e73c367ad9f0c1b1819f28c7348c26aa340d
IK_LLAMA_VERSION?=6b9de3dbaa21ae95ea80638e5ee836795cc48c93
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
CMAKE_ARGS?=

View File

@@ -1,5 +1,5 @@
LLAMA_VERSION?=5dcb71166686799f0d873eab7386234302d05ecf
LLAMA_VERSION?=31e82494c0a3913c919c1027fa70500fbf4c07dd
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
CMAKE_ARGS?=

View File

@@ -482,23 +482,13 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
if (!request->draftmodel().empty()) {
params.speculative.draft.mparams.path = request->draftmodel();
// Default to draft type if a draft model is set but no explicit type.
// Upstream (post ggml-org/llama.cpp#22838) made the speculative type a
// vector; the turboquant fork still uses the legacy scalar. The
// LOCALAI_LEGACY_LLAMA_CPP_SPEC macro is injected by
// backend/cpp/turboquant/patch-grpc-server.sh for fork builds only.
// Upstream renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE
// in ggml-org/llama.cpp#22964; the fork still uses the old name.
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
}
#else
// Upstream made the speculative type a vector (ggml-org/llama.cpp#22838)
// and renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE (#22964).
const bool no_spec_type = params.speculative.types.empty() ||
(params.speculative.types.size() == 1 && params.speculative.types[0] == COMMON_SPECULATIVE_TYPE_NONE);
if (no_spec_type) {
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE };
}
#endif
}
// params.model_alias ??
@@ -574,9 +564,10 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
// tokens (0 disables the minimum). Match upstream's default (256). This
// field was renamed from `checkpoint_every_nt` in llama.cpp; the semantics
// also shifted from a fixed cadence to a minimum spacing. The turboquant
// fork branched before the field existed, so skip it on the legacy path
// (LOCALAI_LEGACY_LLAMA_CPP_SPEC is injected by patch-grpc-server.sh).
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
// fork still lacks common_params::checkpoint_min_step, so skip it there
// (LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP is injected by
// backend/cpp/turboquant/patch-grpc-server.sh).
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
params.checkpoint_min_step = 256;
#endif
@@ -752,7 +743,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
params.cache_idle_slots = false;
}
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
// --- minimum context-checkpoint spacing (upstream -cms / --checkpoint-min-step) ---
// 0 disables the minimum-spacing gate. Old option names (`checkpoint_every_nt`,
// `checkpoint_every_n_tokens`) are kept as aliases for backward compatibility
@@ -906,17 +897,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
// Speculative decoding options
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
// Fork only knows a single scalar `type`. Take the first comma-
// separated value and assign it via the singular helper.
std::string first = optval_str;
const auto comma = first.find(',');
if (comma != std::string::npos) first = first.substr(0, comma);
auto type = common_speculative_type_from_name(first);
if (type != COMMON_SPECULATIVE_TYPE_COUNT) {
params.speculative.type = type;
}
#else
// Upstream switched to a vector of types (comma-separated for multi-type
// chaining via common_speculative_types_from_names). We keep accepting a
// single value here, but also tolerate comma-separated lists.
@@ -945,7 +925,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
if (!parsed.empty()) {
params.speculative.types = parsed;
}
#endif
} else if (!strcmp(optname, "spec_n_max") || !strcmp(optname, "draft_max")) {
if (optval != NULL) {
try { params.speculative.draft.n_max = std::stoi(optval_str); } catch (...) {}
@@ -983,21 +962,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
// shares the target context size. Accept the option for backward
// compatibility but silently ignore it.
// Everything below relies on struct shape introduced in ggml-org/llama.cpp#22838
// (parallel drafting): `ngram_mod`, `ngram_map_k`, `ngram_map_k4v`,
// `ngram_cache`, and the `draft.{cache_type_*, cpuparams*, tensor_buft_overrides}`
// fields. The turboquant fork branched before that, so its build defines
// LOCALAI_LEGACY_LLAMA_CPP_SPEC via patch-grpc-server.sh and these option
// keys become unrecognized (silently dropped, like any unknown opt) for it.
//
// The `#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC` / `#else` split below sits at the
// closing-brace position of the `draft_ctx_size` branch on purpose: in the
// legacy build the chain ends here (the brace closes draft_ctx_size), and in
// the modern build the chain continues with `} else if (...)` instead, so the
// brace count stays balanced under both branches of the preprocessor.
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
}
#else
// --- ngram_mod family (upstream --spec-ngram-mod-*) ---
} else if (!strcmp(optname, "spec_ngram_mod_n_min")) {
if (optval != NULL) {
@@ -1127,7 +1091,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
}
if (!cur.empty()) flush(cur);
}
#endif // LOCALAI_LEGACY_LLAMA_CPP_SPEC — closes the `else`/`#ifdef` opened at draft_ctx_size
}
// Set params.n_parallel from environment variable if not set via options (fallback)
@@ -1177,15 +1140,11 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}
}
// The draft tensor_buft_overrides are only populated under the modern
// (post-#22838) layout, whose population code is itself gated by
// LOCALAI_LEGACY_LLAMA_CPP_SPEC above. The turboquant fork lacks
// common_params_speculative::draft entirely, so skip the sentinel there too.
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
// Terminate the draft tensor_buft_overrides list with a sentinel, mirroring
// the main-model handling above.
if (!params.speculative.draft.tensor_buft_overrides.empty()) {
params.speculative.draft.tensor_buft_overrides.push_back({nullptr, nullptr});
}
#endif
// TODO: Add yarn
@@ -1944,6 +1903,17 @@ public:
body_json["chat_template_kwargs"]["enable_thinking"] = (et_it->second == "true");
}
// Pass reasoning_effort via chat_template_kwargs too: the lever
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
// from enable_thinking which those templates ignore.
auto re_it = metadata.find("reasoning_effort");
if (re_it != metadata.end() && !re_it->second.empty()) {
if (!body_json.contains("chat_template_kwargs")) {
body_json["chat_template_kwargs"] = json::object();
}
body_json["chat_template_kwargs"]["reasoning_effort"] = re_it->second;
}
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
@@ -2737,6 +2707,17 @@ public:
body_json["chat_template_kwargs"]["enable_thinking"] = (predict_et_it->second == "true");
}
// Pass reasoning_effort via chat_template_kwargs too: the lever
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
// from enable_thinking which those templates ignore.
auto predict_re_it = predict_metadata.find("reasoning_effort");
if (predict_re_it != predict_metadata.end() && !predict_re_it->second.empty()) {
if (!body_json.contains("chat_template_kwargs")) {
body_json["chat_template_kwargs"] = json::object();
}
body_json["chat_template_kwargs"]["reasoning_effort"] = predict_re_it->second;
}
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
SRV_DBG("[CONVERSATION DEBUG] Predict: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());

View File

@@ -1,7 +1,7 @@
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
TURBOQUANT_VERSION?=5aeb2fdbe26cd4c534c6fa15de73cb5749bd0403
TURBOQUANT_VERSION?=7d9715f1f071fa07c7b2ad3dbfd320b314139e65
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
CMAKE_ARGS?=

View File

@@ -4,21 +4,19 @@
#
# 1. Augment the kv_cache_types[] allow-list so `LoadModel` accepts the
# fork-specific `turbo2` / `turbo3` / `turbo4` cache types.
# 2. Replace `get_media_marker()` (added upstream in ggml-org/llama.cpp#21962,
# server-side random per-instance marker) with the legacy "<__media__>"
# literal. The fork branched before that PR, so server-common.cpp has no
# get_media_marker symbol. The fork's mtmd_default_marker() still returns
# "<__media__>", and Go-side tooling falls back to that sentinel when the
# backend does not expose media_marker, so substituting the literal keeps
# behavior identical on the turboquant path.
# 3. Revert the `common_params_speculative` field references to the
# pre-refactor flat layout. Upstream ggml-org/llama.cpp#22397 split the
# struct into nested `draft` / `ngram_simple` / `ngram_mod` / etc. members;
# the turboquant fork branched before that PR and still exposes the flat
# `n_max`, `mparams_dft`, `ngram_size_n`, ... fields. The substitutions
# below map the new nested paths back to the legacy flat names so the
# shared grpc-server.cpp keeps compiling against the fork's common.h.
# Drop this block once the fork rebases past #22397.
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file
# so the grpc-server option parser skips the two references to
# common_params::checkpoint_min_step (the default and the option handler).
# That field does not exist in the fork yet; drop this once it does.
#
# The fork used to lag upstream on the whole common_params_speculative refactor
# (ggml-org/llama.cpp#22397/#22838/#22964), the model_tgt rename (#22838) and
# get_media_marker (#21962), which required a much larger compat shim here
# (flat-field sed renames + a coarse LOCALAI_LEGACY_LLAMA_CPP_SPEC define). The
# fork has since rebased past all of those, so the only remaining gap is
# checkpoint_min_step. If a future bump reintroduces a divergence, add a narrow
# guard in grpc-server.cpp keyed on a fork-specific macro and inject it here
# rather than resurrecting the coarse one.
#
# We patch the *copy* sitting in turboquant-<flavor>-build/, never the original
# under backend/cpp/llama-cpp/, so the stock llama-cpp build keeps compiling
@@ -72,72 +70,20 @@ else
echo "==> KV allow-list patch OK"
fi
if grep -q 'get_media_marker()' "$SRC"; then
echo "==> patching $SRC to replace get_media_marker() with legacy \"<__media__>\" literal"
# Only one call site today (ModelMetadata), but replace all occurrences to
# stay robust if upstream adds more. Use a temp file to avoid relying on
# sed -i portability (the builder image uses GNU sed, but keeping this
# consistent with the awk block above).
sed 's/get_media_marker()/"<__media__>"/g' "$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> get_media_marker() substitution OK"
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file so
# the grpc-server option parser skips the two references to
# common_params::checkpoint_min_step (the default assignment and the option
# handler). That field does not exist in the fork yet. Drop this block once
# the fork rebases past the bump that added checkpoint_min_step.
if grep -q '^#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP' "$SRC"; then
echo "==> $SRC already defines LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP, skipping"
else
echo "==> $SRC has no get_media_marker() call, skipping media-marker patch"
fi
if grep -q 'params\.speculative\.draft\.\|params\.speculative\.ngram_simple\.' "$SRC"; then
echo "==> patching $SRC to revert common_params_speculative refs to pre-#22397 flat layout"
# Each substitution is the exact post-refactor path → legacy flat field.
# Order doesn't matter because the source paths are disjoint, but we keep
# the most-specific (mparams.path) first for readability.
sed -E \
-e 's/params\.speculative\.draft\.mparams\.path/params.speculative.mparams_dft.path/g' \
-e 's/params\.speculative\.draft\.n_max/params.speculative.n_max/g' \
-e 's/params\.speculative\.draft\.n_min/params.speculative.n_min/g' \
-e 's/params\.speculative\.draft\.p_min/params.speculative.p_min/g' \
-e 's/params\.speculative\.draft\.p_split/params.speculative.p_split/g' \
-e 's/params\.speculative\.draft\.n_gpu_layers/params.speculative.n_gpu_layers/g' \
-e 's/params\.speculative\.draft\.n_ctx/params.speculative.n_ctx/g' \
-e 's/params\.speculative\.ngram_simple\.size_n/params.speculative.ngram_size_n/g' \
-e 's/params\.speculative\.ngram_simple\.size_m/params.speculative.ngram_size_m/g' \
-e 's/params\.speculative\.ngram_simple\.min_hits/params.speculative.ngram_min_hits/g' \
"$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> speculative field rename OK"
else
echo "==> $SRC has no post-#22397 speculative field refs, skipping spec rename patch"
fi
# 4. Revert the `ctx_server.impl->model_tgt` rename introduced by upstream
# ggml-org/llama.cpp#22838 (parallel drafting). The turboquant fork still
# exposes the field as `model` on `server_context_impl`. The two call sites
# are in the Rerank and ModelMetadata RPC handlers.
if grep -q 'ctx_server\.impl->model_tgt' "$SRC"; then
echo "==> patching $SRC to revert ctx_server.impl->model_tgt -> ctx_server.impl->model"
sed -E 's/ctx_server\.impl->model_tgt/ctx_server.impl->model/g' "$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> model_tgt rename OK"
else
echo "==> $SRC has no ctx_server.impl->model_tgt refs, skipping model_tgt rename patch"
fi
# 5. Define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top of the file so the
# grpc-server option parser skips the new option-handler blocks (ngram_mod,
# ngram_map_k, ngram_map_k4v, ngram_cache, draft.cache_type_*, draft.cpuparams*,
# draft.tensor_buft_overrides) introduced for the post-#22838 layout, the
# draft.tensor_buft_overrides sentinel termination, and the
# common_params::checkpoint_min_step default/option (added with the
# 35c9b1f3 bump). Those blocks reference struct fields that simply do not
# exist in the fork.
if grep -q '^#define LOCALAI_LEGACY_LLAMA_CPP_SPEC' "$SRC"; then
echo "==> $SRC already defines LOCALAI_LEGACY_LLAMA_CPP_SPEC, skipping"
else
echo "==> patching $SRC to define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top"
# Insert the define before the very first `#include` so it precedes all the
# speculative-decoding code paths.
echo "==> patching $SRC to define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top"
# Insert the define before the very first `#include` so it precedes the
# checkpoint_min_step references.
awk '
!done && /^#include/ {
print "#define LOCALAI_LEGACY_LLAMA_CPP_SPEC 1"
print "#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP 1"
print "// ^ injected by backend/cpp/turboquant/patch-grpc-server.sh"
print ""
done = 1
@@ -145,13 +91,13 @@ else
{ print }
END {
if (!done) {
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_LEGACY_LLAMA_CPP_SPEC" > "/dev/stderr"
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP" > "/dev/stderr"
exit 1
}
}
' "$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> LOCALAI_LEGACY_LLAMA_CPP_SPEC define OK"
echo "==> LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP define OK"
fi
echo "==> all patches applied"

View File

@@ -0,0 +1,55 @@
hip: port the turboquant CUDA additions that ggml's HIP shim doesn't cover
The turboquant fork adds/modifies a few ggml-cuda.cu spots with CUDA APIs
that ggml's HIP (and MUSA) compatibility layer does not provide, breaking
the -gpu-rocm-hipblas-turboquant build:
1. ggml_cuda_copy2d_across_devices() (host-staged cross-device copy for
split mul_mat output) uses the CUDA 3D-peer copy APIs
cudaMemcpy3DPeerParms / make_cudaPitchedPtr / make_cudaExtent /
cudaMemcpy3DPeerAsync. HIP genuinely does not support these (see the
fork's own comment "HIP does not support cudaMemcpy3DPeerAsync"), so
guard the peer fast path with #if !defined(GGML_USE_HIP) &&
!defined(GGML_USE_MUSA) -- matching how the fork already guards the
same API for the sibling 2D copy -- and fall through to the existing
cudaMemcpyAsync staging fallback below (functionally identical,
slightly slower on multi-GPU ROCm).
2. ggml_backend_cuda_device_event_new() creates its event with plain
cudaEventCreate, which ggml's HIP shim does not alias (it only aliases
cudaEventCreateWithFlags). Use cudaEventCreateWithFlags(...,
cudaEventDisableTiming) -- exactly what the rest of this file already
does (cf. lines ~1034, ~3461) and HIP-safe.
CUDA builds are unaffected. Drop the relevant hunk once the fork HIP-ports
these; apply-patches.sh fails fast if an anchor goes stale.
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 0427e6b..6352e6a 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1933,6 +1933,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
size_t width, size_t height, cudaStream_t dst_stream, cudaStream_t src_stream) {
const auto & info = ggml_cuda_info();
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) // 3D-peer copy types unmapped by ggml's HIP/MUSA shim; use staging fallback below
if (info.peer_access[src_device][dst_device]) {
cudaMemcpy3DPeerParms p = {};
p.dstDevice = dst_device;
@@ -1942,6 +1943,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
p.extent = make_cudaExtent(width, height, 1);
return cudaMemcpy3DPeerAsync(&p, dst_stream);
}
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
// Fallback: stage all rows through a single contiguous pinned buffer
int prev_device = ggml_cuda_get_device();
@@ -5714,7 +5716,7 @@ static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_
ggml_cuda_set_device(dev_ctx->device);
cudaEvent_t event;
- CUDA_CHECK(cudaEventCreate(&event));
+ CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
return new ggml_backend_event {
/* .device = */ dev,

View File

@@ -14,6 +14,7 @@ import (
"github.com/mudler/xlog"
"github.com/mudler/LocalAI/pkg/grpc/base"
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/httpclient"
)
@@ -145,7 +146,7 @@ func resolveAPIKey(envName, filePath string) (string, error) {
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
cfg := c.cfg.Load()
if cfg == nil {
return nil, errors.New("cloud-proxy: model not loaded")
return nil, grpcerrors.ModelNotLoaded("cloud-proxy")
}
if cfg.mode != modeTranslate {
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
@@ -175,7 +176,7 @@ func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
cfg := c.cfg.Load()
if cfg == nil {
return errors.New("cloud-proxy: model not loaded")
return grpcerrors.ModelNotLoaded("cloud-proxy")
}
if cfg.mode != modeTranslate {
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
@@ -269,7 +270,7 @@ func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest,
cfg := c.cfg.Load()
if cfg == nil {
return errors.New("cloud-proxy: model not loaded")
return grpcerrors.ModelNotLoaded("cloud-proxy")
}
if cfg.mode != modePassthrough {
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# CrispASR version (release tag)
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
CRISPASR_VERSION?=05e60432bcb5bc2113f8c395a41e86497c11504a
CRISPASR_VERSION?=13d54e110e1538e0f0bc3af0680b9ab246cfb48d
SO_TARGET?=libgocrispasr.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -1,6 +1,6 @@
# parakeet-cpp backend Makefile.
#
# Upstream pin lives below as PARAKEET_VERSION?=9edf17c3ada66e0f881dcff155492867db7ac4cf
# Upstream pin lives below as PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
# (.github/bump_deps.sh) can find and update it - matches the
# whisper.cpp / ds4 / vibevoice-cpp convention.
#
@@ -15,7 +15,7 @@
# That's what the L0 smoke test uses. The default target below does the
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
PARAKEET_VERSION?=9edf17c3ada66e0f881dcff155492867db7ac4cf
PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
GOCMD?=go

View File

@@ -7,8 +7,12 @@ import "time"
type batchRequest struct {
pcm []float32
decoder int32
tag string
reply chan batchReply
// language is the per-request target locale ("" means the model default).
// parakeet.cpp's batched C-API takes ONE target_lang for the whole batch,
// so the dispatcher only coalesces requests that share a language.
language string
tag string
reply chan batchReply
}
// batchReply carries one per-item JSON object string (an element of the C-API's
@@ -43,13 +47,25 @@ func newBatcher(maxSize int, maxWait time.Duration, runBatch func([]*batchReques
// run is the dispatcher loop: accumulate submitted requests until either maxSize
// is reached or maxWait elapses since the first queued request, then dispatch.
// Exits when stop is closed (draining any partially-filled batch first).
//
// A batch carries ONE language (parakeet.cpp's batched C-API takes a single
// target_lang), so a request whose language differs from the batch leader is
// not coalesced: it is held in carry and becomes the leader of the next batch.
// carry is therefore never dropped and its caller never deadlocks: every batch
// (including a lone carry on stop) is dispatched, and runBatch replies to all.
func (b *batcher) run(stop <-chan struct{}) {
var carry *batchRequest
for {
var first *batchRequest
select {
case first = <-b.submit:
case <-stop:
return
if carry != nil {
// A mismatched request from the previous fill leads this batch.
first, carry = carry, nil
} else {
select {
case first = <-b.submit:
case <-stop:
return
}
}
batch := []*batchRequest{first}
@@ -64,12 +80,22 @@ func (b *batcher) run(stop <-chan struct{}) {
for len(batch) < b.maxSize {
select {
case r := <-b.submit:
if r.language != first.language {
// Different language: carry it to the next batch so this
// batch stays single-language, then dispatch what we have.
carry = r
break fill
}
batch = append(batch, r)
case <-timer.C:
break fill
case <-stop:
timer.Stop()
b.runBatch(batch)
// Don't strand a carried request's caller on shutdown.
if carry != nil {
b.runBatch([]*batchRequest{carry})
}
return
}
}

View File

@@ -105,4 +105,60 @@ var _ = Describe("batcher", func() {
go func() { <-rep }()
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
})
It("never coalesces requests with different languages into one batch", func() {
// parakeet.cpp's batched C-API takes ONE target_lang per batch, so the
// dispatcher must keep every dispatched batch single-language. Submit a
// mix of languages and assert (a) no batch ever carries more than one
// distinct language and (b) every submitted request still gets a reply
// (the mismatched carry-over is never dropped).
var mu sync.Mutex
var langsPerBatch [][]string
run := func(reqs []*batchRequest) {
seen := map[string]struct{}{}
var distinct []string
for _, r := range reqs {
if _, ok := seen[r.language]; !ok {
seen[r.language] = struct{}{}
distinct = append(distinct, r.language)
}
}
mu.Lock()
langsPerBatch = append(langsPerBatch, distinct)
mu.Unlock()
echoReply(reqs)
}
// Large window + size so the fill loop stays open across submits and the
// language constraint (not the timer) is what splits the batches.
b := newBatcher(16, 200*time.Millisecond, run)
stop := make(chan struct{})
go b.run(stop)
defer close(stop)
langs := []string{"en", "en", "de", "de", "en", "fr", "fr"}
const N = 7
var wg sync.WaitGroup
got := make([]string, N)
for i := 0; i < N; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
rep := make(chan batchReply, 1)
b.submit <- &batchRequest{tag: string(rune('a' + i)), language: langs[i], reply: rep}
got[i] = (<-rep).json
}(i)
}
wg.Wait()
mu.Lock()
defer mu.Unlock()
// Invariant: every dispatched batch is single-language.
for _, distinct := range langsPerBatch {
Expect(len(distinct)).To(Equal(1), "a batch coalesced more than one language: %v", distinct)
}
// Liveness: every request got a reply (carry-over never stranded).
for i := 0; i < N; i++ {
Expect(got[i]).To(Equal(string(rune('a' + i))))
}
})
})

View File

@@ -15,6 +15,7 @@ import (
"github.com/go-audio/wav"
"github.com/mudler/LocalAI/pkg/grpc/base"
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/xlog"
@@ -47,6 +48,13 @@ var (
// side reads them as const float*/const int*.
CppTranscribePcmBatchJSON func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32) uintptr
// CppTranscribePcmBatchJSONLang is the multilingual variant of the batched
// JSON entry point: identical, plus a trailing target_lang. "" (the model
// default, "auto") is passed for non-prompt models, which ignore it; an
// unknown locale on a prompt model returns 0 and sets last_error. Present
// only in newer libparakeet.so; nil falls back to CppTranscribePcmBatchJSON.
CppTranscribePcmBatchJSONLang func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32, targetLang string) uintptr
// Cache-aware streaming (RNN-T) entry points. stream_begin returns 0 for
// non-streaming models. feed/finalize return a malloc'd char* (uintptr,
// freed via CppFreeString); feed writes 1 to *eouOut on an <EOU>/<EOB>.
@@ -54,6 +62,18 @@ var (
CppStreamFeed func(s uintptr, pcm []float32, nSamples int32, eouOut unsafe.Pointer) uintptr
CppStreamFinalize func(s uintptr) uintptr
CppStreamFree func(s uintptr)
// CppStreamBeginLang is the multilingual variant of stream_begin: identical,
// plus a trailing target_lang ("" means the model default). Present only in
// newer libparakeet.so; nil falls back to CppStreamBegin.
CppStreamBeginLang func(ctx uintptr, targetLang string) uintptr
// Streaming JSON variants (ABI v4): feed/finalize returning a malloc'd char*
// JSON document {text,eou,frame_sec,words} (uintptr, freed via CppFreeString)
// so streaming segments can carry per-word timestamps. Present only in newer
// libparakeet.so; nil falls back to the text-only CppStreamFeed/Finalize path.
CppStreamFeedJSON func(s uintptr, pcm []float32, nSamples int32) uintptr
CppStreamFinalizeJSON func(s uintptr) uintptr
)
// streamChunkSamples is how much 16 kHz mono PCM we hand to stream_feed per
@@ -71,9 +91,26 @@ const streamChunkSamples = 16000
//
// "start"/"end"/"t" are seconds; "conf" is confidence in (0,1].
type transcriptJSON struct {
Text string `json:"text"`
Words []transcriptWord `json:"words"`
Tokens []transcriptToken `json:"tokens"`
Text string `json:"text"`
FrameSec float64 `json:"frame_sec"`
Words []transcriptWord `json:"words"`
Tokens []transcriptToken `json:"tokens"`
}
// streamFeedJSON mirrors the document returned by
// parakeet_capi_stream_feed_json / parakeet_capi_stream_finalize_json (ABI v4):
//
// {"text":"...","eou":0,"frame_sec":0.080000,
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...]}
//
// "text" is the newly-finalized text since the last call; "eou" is 1 when an
// <EOU>/<EOB> fired this feed; "words" are the words finalized this call with
// absolute (stream-relative) start/end seconds.
type streamFeedJSON struct {
Text string `json:"text"`
Eou int `json:"eou"`
FrameSec float64 `json:"frame_sec"`
Words []transcriptWord `json:"words"`
}
type transcriptWord struct {
@@ -102,6 +139,10 @@ type ParakeetCpp struct {
engineMu sync.Mutex // sole guard of the one C engine (dispatcher + streaming)
bat *batcher
batStop chan struct{}
// segmentGapFrames is NeMo's segment_gap_threshold in ENCODER FRAMES (model
// YAML option, default 0=off). When >0 it adds NeMo's silence-gap split on
// top of the punctuation split; converted to seconds via the JSON frame_sec.
segmentGapFrames int
}
// Load is the LocalAI gRPC entry point for LoadModel: it calls
@@ -131,6 +172,11 @@ func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
if maxWaitMs < 0 {
maxWaitMs = 0
}
// NeMo's segment_gap_threshold (encoder frames, default 0=off). Off by
// default matches NeMo's default (punctuation-only segments); when set it
// additionally splits segments on inter-word silence (see transcriptResultFromDoc).
p.segmentGapFrames = optInt(opts, "segment_gap_threshold", 0)
if CppTranscribePcmBatchJSON != nil {
p.batStop = make(chan struct{})
p.bat = newBatcher(maxSize, time.Duration(maxWaitMs)*time.Millisecond, p.runBatch)
@@ -186,8 +232,19 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
if len(reqs) > 0 {
dec = reqs[0].decoder
}
// All requests in a batch share one language (the batcher coalesces only
// same-language requests), so any element's language describes the batch.
lang := ""
if len(reqs) > 0 {
lang = reqs[0].language
}
p.engineMu.Lock()
cstr := CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
var cstr uintptr
if CppTranscribePcmBatchJSONLang != nil {
cstr = CppTranscribePcmBatchJSONLang(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec, lang)
} else {
cstr = CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
}
p.engineMu.Unlock()
if cstr == 0 {
err := fmt.Errorf("parakeet-cpp: batch transcribe failed: %s", CppLastError(p.ctxPtr))
@@ -225,21 +282,31 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
// OpenAI API, whose default is segment-level); token ids always populate
// Segment.Tokens.
//
// translate/diarize/prompt/temperature/language/threads are not applicable to
// parakeet and are ignored; streaming is handled by AudioTranscriptionStream
// translate/diarize/prompt/temperature/threads are not applicable to parakeet
// and are ignored; language is honored on the batched + streaming paths (see
// opts.GetLanguage() below); streaming is handled by AudioTranscriptionStream
// (L2).
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
if p.ctxPtr == 0 {
return pb.TranscriptResult{}, errors.New("parakeet-cpp: model not loaded")
return pb.TranscriptResult{}, grpcerrors.ModelNotLoaded("parakeet-cpp")
}
if opts.Dst == "" {
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
}
// Fallback when the batched C-API is unavailable: transcribe directly from
// the file path (original behavior, no batching).
// Fallback when the batched C-API is unavailable: transcribe from a file
// path (original behavior, no batching). The C library's audio loader only
// understands 16 kHz mono WAV/PCM, so convert the input first - otherwise
// any non-WAV upload (MP3, etc.) fails with "failed to load audio". This
// mirrors what every other audio backend (whisper, crispasr) does via
// utils.AudioToWav before handing the file to the engine.
if p.bat == nil {
cstr := CppTranscribePathJSON(p.ctxPtr, opts.Dst, 0)
converted, cleanup, err := convertToWavMono16k(opts.Dst)
if err != nil {
return pb.TranscriptResult{}, err
}
defer cleanup()
cstr := CppTranscribePathJSON(p.ctxPtr, converted, 0)
if cstr == 0 {
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", CppLastError(p.ctxPtr))
}
@@ -249,7 +316,7 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
}
return transcriptResultFromDoc(doc, opts), nil
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
}
// Batched path: decode to PCM, submit to the batcher, wait for this request's
@@ -261,7 +328,7 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
}
rep := make(chan batchReply, 1)
select {
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, reply: rep}:
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, language: opts.GetLanguage(), reply: rep}:
case <-ctx.Done():
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
}
@@ -278,34 +345,169 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
if err := json.Unmarshal([]byte(res.json), &doc); err != nil {
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
}
return transcriptResultFromDoc(doc, opts), nil
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
}
// segmentSeparators is NeMo's default segment_seperators (sentence-ending
// punctuation). Splitting on these matches NeMo's default segment timestamps.
var segmentSeparators = []rune{'.', '?', '!'}
// transcriptResultFromDoc maps a decoded transcriptJSON to a TranscriptResult,
// synthesising a single whole-clip segment and attaching word timings only when
// the caller requested word granularity. Shared by the batched and direct paths.
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest) pb.TranscriptResult {
// grouping words into NeMo-faithful segments (see splitWordsIntoSegments). The
// optional gapFrames (NeMo's segment_gap_threshold, in encoder FRAMES; 0=off)
// additionally splits on inter-word silence; it is converted to a seconds gap
// with the document's frame_sec. Per-segment word timings are attached only when
// the caller requested word granularity; token ids populate each segment's
// Tokens by time-window membership. Shared by the batched and direct paths.
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gapFrames int) pb.TranscriptResult {
text := strings.TrimSpace(doc.Text)
words := make([]*pb.TranscriptWord, 0, len(doc.Words))
for _, w := range doc.Words {
words = append(words, &pb.TranscriptWord{Start: secondsToNanos(w.Start), End: secondsToNanos(w.End), Text: w.W})
// Frame-unit gap threshold -> seconds (NeMo segment_gap_threshold). 0 = off.
gapSeconds := 0.0
if gapFrames > 0 {
if doc.FrameSec > 0 {
gapSeconds = float64(gapFrames) * doc.FrameSec
} else {
xlog.Warn("parakeet-cpp: segment_gap_threshold set but libparakeet.so " +
"did not report frame_sec; falling back to punctuation-only segments")
}
}
tokens := make([]int32, 0, len(doc.Tokens))
for _, t := range doc.Tokens {
tokens = append(tokens, t.ID)
groups := splitWordsIntoSegments(doc.Words, segmentSeparators, gapSeconds)
if len(groups) == 0 {
// No words (edge case): single whole-clip text segment.
return pb.TranscriptResult{
Text: text,
Segments: []*pb.TranscriptSegment{{Id: 0, Text: text}},
}
}
var segStart, segEnd int64
if len(words) > 0 {
segStart = words[0].Start
segEnd = words[len(words)-1].End
wantWords := wordsRequested(opts.TimestampGranularities)
segments := make([]*pb.TranscriptSegment, 0, len(groups))
for id, group := range groups {
parts := make([]string, len(group))
for i, gw := range group {
parts[i] = gw.W
}
seg := &pb.TranscriptSegment{
Id: int32(id),
Start: secondsToNanos(group[0].Start),
End: secondsToNanos(group[len(group)-1].End),
Text: strings.TrimSpace(strings.Join(parts, " ")),
Tokens: tokensInWindow(doc.Tokens, group[0].Start, group[len(group)-1].End),
}
if wantWords {
ws := make([]*pb.TranscriptWord, len(group))
for i, gw := range group {
ws[i] = &pb.TranscriptWord{Start: secondsToNanos(gw.Start), End: secondsToNanos(gw.End), Text: gw.W}
}
seg.Words = ws
}
segments = append(segments, seg)
}
seg := &pb.TranscriptSegment{Id: 0, Start: segStart, End: segEnd, Text: text, Tokens: tokens}
if wordsRequested(opts.TimestampGranularities) {
seg.Words = words
}
return pb.TranscriptResult{Text: text, Segments: []*pb.TranscriptSegment{seg}}
return pb.TranscriptResult{Text: text, Segments: segments}
}
// splitWordsIntoSegments groups words into segments exactly as NeMo's
// get_segment_offsets does (nemo/collections/asr/parts/utils/timestamp_utils.py).
// Walking the words, it closes a segment when (1) the gap rule is enabled
// (gapSeconds > 0) and the segment already has words and the gap from the
// previous word's end to this word's start is >= gapSeconds - the current word
// then STARTS a new segment - or, checked only when the gap rule did not apply
// (NeMo's elif), (2) the word ends with (or is) a separator, which closes the
// segment INCLUDING that word. Trailing words flush into a final segment.
// gapSeconds <= 0 disables the gap rule, matching NeMo's default
// segment_gap_threshold=None (punctuation-only segments).
func splitWordsIntoSegments(words []transcriptWord, separators []rune, gapSeconds float64) [][]transcriptWord {
var segments [][]transcriptWord
var cur []transcriptWord
for i, word := range words {
gapActive := gapSeconds > 0 && len(cur) > 0
if gapActive && (word.Start-words[i-1].End) >= gapSeconds {
segments = append(segments, cur)
cur = []transcriptWord{word}
continue
}
if !gapActive && endsWithSeparator(word.W, separators) {
cur = append(cur, word)
segments = append(segments, cur)
cur = nil
continue
}
cur = append(cur, word)
}
if len(cur) > 0 {
segments = append(segments, cur)
}
return segments
}
// endsWithSeparator reports whether w's last rune is in separators (matching
// NeMo's `word[-1] in delims or word in delims`).
func endsWithSeparator(w string, separators []rune) bool {
r := []rune(strings.TrimSpace(w))
if len(r) == 0 {
return false
}
last := r[len(r)-1]
for _, s := range separators {
if last == s {
return true
}
}
return false
}
// tokensInWindow returns the ids of tokens whose timestamp t falls in
// [start, end] (inclusive), assigning each token to the segment that spans its
// time. The last segment's end is the last word end, so the final token is
// included.
func tokensInWindow(tokens []transcriptToken, start, end float64) []int32 {
var ids []int32
for _, t := range tokens {
if t.T >= start && t.T <= end {
ids = append(ids, t.ID)
}
}
return ids
}
// streamSegmenter accumulates streaming words into per-utterance segments. EOU
// is the model's own utterance boundary; each closed segment takes its start/end
// from its first/last accumulated word.
type streamSegmenter struct {
segs []*pb.TranscriptSegment
cur []transcriptWord
nextID int32
}
func (s *streamSegmenter) add(doc streamFeedJSON) {
s.cur = append(s.cur, doc.Words...)
if doc.Eou != 0 {
s.flush()
}
}
func (s *streamSegmenter) flush() {
if len(s.cur) == 0 {
return
}
parts := make([]string, len(s.cur))
for i, w := range s.cur {
parts[i] = w.W
}
s.segs = append(s.segs, &pb.TranscriptSegment{
Id: s.nextID,
Start: secondsToNanos(s.cur[0].Start),
End: secondsToNanos(s.cur[len(s.cur)-1].End),
Text: strings.TrimSpace(strings.Join(parts, " ")),
})
s.nextID++
s.cur = nil
}
func (s *streamSegmenter) segments() []*pb.TranscriptSegment { return s.segs }
// wordsRequested reports whether the caller asked for word-level timestamps.
// The OpenAI transcription API gates word timings behind
// timestamp_granularities[] containing "word" and defaults to segment-level
@@ -342,7 +544,7 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
defer close(results)
if p.ctxPtr == 0 {
return errors.New("parakeet-cpp: model not loaded")
return grpcerrors.ModelNotLoaded("parakeet-cpp")
}
if opts.Dst == "" {
return errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
@@ -351,7 +553,12 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
return status.Error(codes.Canceled, "transcription cancelled")
}
stream := CppStreamBegin(p.ctxPtr)
var stream uintptr
if CppStreamBeginLang != nil {
stream = CppStreamBeginLang(p.ctxPtr, opts.GetLanguage())
} else {
stream = CppStreamBegin(p.ctxPtr)
}
if stream == 0 {
// Not a cache-aware streaming model: run a normal offline
// transcription and emit it as one delta + a closing final result.
@@ -380,6 +587,14 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
return err
}
// ABI v4: when the streaming JSON entry points are present, drive them so the
// per-utterance segments carry per-word start/end timestamps. Falls through to
// the text-only loop below against an older libparakeet.so. Runs under the
// engineMu already held above.
if CppStreamFeedJSON != nil {
return p.streamJSON(ctx, stream, data, duration, results)
}
var (
full strings.Builder
segText strings.Builder
@@ -456,21 +671,102 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
return nil
}
// streamJSON drives the ABI v4 streaming JSON entry points: each feed/finalize
// returns a {text,eou,frame_sec,words} document. The newly-finalized text is
// emitted as a delta (unchanged streaming contract) while words are accumulated
// into per-utterance segments (closed on EOU) so the closing FinalResult carries
// timestamped segments. Runs under engineMu (already held by the caller).
func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []float32,
duration float32, results chan *pb.TranscriptStreamResponse) error {
var (
full strings.Builder
seg streamSegmenter
)
// consume frees the malloc'd char* (a 0 return is an error), parses the JSON,
// emits the delta, and routes words through the segmenter.
consume := func(ret uintptr) error {
if ret == 0 {
msg := CppLastError(p.ctxPtr)
if msg == "" {
msg = "unknown error"
}
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
}
raw := goStringFromCPtr(ret)
CppFreeString(ret)
var doc streamFeedJSON
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
return fmt.Errorf("parakeet-cpp: decode stream json: %w", err)
}
if doc.Text != "" {
full.WriteString(doc.Text)
results <- &pb.TranscriptStreamResponse{Delta: doc.Text}
}
seg.add(doc)
return nil
}
for off := 0; off < len(data); off += streamChunkSamples {
if err := ctx.Err(); err != nil {
return status.Error(codes.Canceled, "transcription cancelled")
}
end := min(off+streamChunkSamples, len(data))
chunk := data[off:end]
if err := consume(CppStreamFeedJSON(stream, chunk, int32(len(chunk)))); err != nil {
return err
}
}
if err := consume(CppStreamFinalizeJSON(stream)); err != nil {
return err
}
seg.flush() // close any trailing utterance that never saw an EOU
text := strings.TrimSpace(full.String())
segments := seg.segments()
if len(segments) == 0 && text != "" {
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
}
results <- &pb.TranscriptStreamResponse{
FinalResult: &pb.TranscriptResult{
Text: text,
Segments: segments,
Duration: duration,
},
}
return nil
}
// decodeWavMono16k converts any input audio to 16 kHz mono PCM and returns the
// float samples plus the clip duration in seconds. Mirrors the whisper
// backend: utils.AudioToWav (ffmpeg) normalises rate/channels, go-audio
// decodes the PCM.
func decodeWavMono16k(path string) ([]float32, float32, error) {
// convertToWavMono16k converts an arbitrary audio file to a 16 kHz mono WAV in
// a fresh temp dir and returns the path together with a cleanup func the caller
// must defer. WAV inputs already at 16 kHz/mono/16-bit are passed through by
// utils.AudioToWav (hardlink/copy), everything else is transcoded via ffmpeg.
// Used by the direct (non-batched) transcription path, which hands a file path
// to the C library's WAV-only audio loader.
func convertToWavMono16k(path string) (string, func(), error) {
dir, err := os.MkdirTemp("", "parakeet")
if err != nil {
return nil, 0, err
return "", func() {}, err
}
defer func() { _ = os.RemoveAll(dir) }()
cleanup := func() { _ = os.RemoveAll(dir) }
converted := filepath.Join(dir, "converted.wav")
if err := utils.AudioToWav(path, converted); err != nil {
cleanup()
return "", func() {}, err
}
return converted, cleanup, nil
}
func decodeWavMono16k(path string) ([]float32, float32, error) {
converted, cleanup, err := convertToWavMono16k(path)
if err != nil {
return nil, 0, err
}
defer cleanup()
fh, err := os.Open(converted)
if err != nil {

View File

@@ -3,11 +3,14 @@ package main
import (
"context"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"github.com/ebitengine/purego"
"github.com/go-audio/audio"
"github.com/go-audio/wav"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -50,6 +53,10 @@ func ensureLibLoaded() {
purego.RegisterLibFunc(&CppStreamFeed, lib, "parakeet_capi_stream_feed")
purego.RegisterLibFunc(&CppStreamFinalize, lib, "parakeet_capi_stream_finalize")
purego.RegisterLibFunc(&CppStreamFree, lib, "parakeet_capi_stream_free")
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
}
purego.RegisterLibFunc(&CppFreeString, lib, "parakeet_capi_free_string")
purego.RegisterLibFunc(&CppLastError, lib, "parakeet_capi_last_error")
})
@@ -70,6 +77,24 @@ func fixturesOrSkip() (string, string) {
return modelPath, audioPath
}
// writeMono16kWav writes `samples` frames of 16 kHz mono 16-bit silence to
// path. The result is already in AudioToWav's target format, so the conversion
// helper copies it through without invoking ffmpeg.
func writeMono16kWav(path string, samples int) {
GinkgoHelper()
f, err := os.Create(path)
Expect(err).ToNot(HaveOccurred())
enc := wav.NewEncoder(f, 16000, 16, 1, 1)
buf := &audio.IntBuffer{
Format: &audio.Format{NumChannels: 1, SampleRate: 16000},
SourceBitDepth: 16,
Data: make([]int, samples),
}
Expect(enc.Write(buf)).To(Succeed())
Expect(enc.Close()).To(Succeed())
Expect(f.Close()).To(Succeed())
}
var _ = Describe("ParakeetCpp", func() {
Context("AudioTranscription", func() {
It("transcribes a WAV via the parakeet C-API", func() {
@@ -86,13 +111,22 @@ var _ = Describe("ParakeetCpp", func() {
Expect(err).ToNot(HaveOccurred())
Expect(strings.TrimSpace(res.Text)).ToNot(BeEmpty(),
"expected non-empty transcript for %s", audioPath)
Expect(res.Segments).To(HaveLen(1),
"synthesises a single whole-clip segment")
Expect(res.Segments[0].Text).To(Equal(res.Text),
"single segment text must equal the top-level text")
// Default (no granularities) is segment-level: no per-word timings.
Expect(res.Segments[0].Words).To(BeEmpty(),
"word timings are opt-in via timestamp_granularities")
// NeMo-faithful segmentation: one or more punctuation-delimited
// segments, each with text and a monotonically-advancing time span.
Expect(res.Segments).ToNot(BeEmpty(), "expected at least one segment")
var prevEnd int64
for i, seg := range res.Segments {
Expect(strings.TrimSpace(seg.Text)).ToNot(BeEmpty(),
"segment %d must have text", i)
Expect(seg.End).To(BeNumerically(">=", seg.Start),
"segment %d end must not precede its start", i)
Expect(seg.Start).To(BeNumerically(">=", prevEnd),
"segments must be in time order")
prevEnd = seg.End
// Default (no granularities) is segment-level: no per-word timings.
Expect(seg.Words).To(BeEmpty(),
"word timings are opt-in via timestamp_granularities")
}
})
It("emits word-level timestamps when granularity=word", func() {
@@ -108,15 +142,61 @@ var _ = Describe("ParakeetCpp", func() {
TimestampGranularities: []string{"word"},
})
Expect(err).ToNot(HaveOccurred())
Expect(res.Segments).To(HaveLen(1))
seg := res.Segments[0]
Expect(seg.Words).ToNot(BeEmpty(),
"expected per-word timestamps with granularity=word")
// Monotonic, non-negative timings spanning the segment.
Expect(seg.Words[0].Start).To(BeNumerically(">=", int64(0)))
Expect(seg.End).To(BeNumerically(">=", seg.Start))
Expect(seg.Words[len(seg.Words)-1].End).To(Equal(seg.End),
"segment end tracks the last word")
Expect(res.Segments).ToNot(BeEmpty())
// With word granularity every segment carries its own words, and each
// segment's span tracks its first/last word; word starts advance
// monotonically across the whole transcript.
totalWords := 0
var prevStart int64 = -1
for i, seg := range res.Segments {
Expect(seg.Words).ToNot(BeEmpty(),
"segment %d must carry per-word timestamps with granularity=word", i)
Expect(seg.Start).To(Equal(seg.Words[0].Start),
"segment %d start tracks its first word", i)
Expect(seg.End).To(Equal(seg.Words[len(seg.Words)-1].End),
"segment %d end tracks its last word", i)
for _, w := range seg.Words {
Expect(w.End).To(BeNumerically(">=", w.Start))
Expect(w.Start).To(BeNumerically(">=", prevStart))
prevStart = w.Start
totalWords++
}
}
Expect(totalWords).To(BeNumerically(">", 0))
Expect(res.Segments[0].Words[0].Start).To(BeNumerically(">=", int64(0)))
})
})
Context("convertToWavMono16k", func() {
// The non-batched transcription path hands a file path to the C
// library's WAV-only audio loader, so it must convert first.
// utils.AudioToWav passes an already-16kHz/mono/16-bit WAV through
// without ffmpeg, which lets us exercise the helper (and the
// regression: the direct path used to skip conversion entirely)
// without a model, the C library, or ffmpeg.
It("returns a decodable 16kHz mono WAV copy and cleans it up", func() {
dir := GinkgoT().TempDir()
src := filepath.Join(dir, "input.wav")
writeMono16kWav(src, 16000) // 1s of silence at 16 kHz
converted, cleanup, err := convertToWavMono16k(src)
Expect(err).ToNot(HaveOccurred())
// It must produce a fresh temp file, not return the original path.
Expect(converted).ToNot(Equal(src))
Expect(converted).To(BeAnExistingFile())
pcm, _, err := decodeWavMono16k(converted)
Expect(err).ToNot(HaveOccurred())
Expect(pcm).To(HaveLen(16000), "round-trips the sample count")
cleanup()
Expect(converted).ToNot(BeAnExistingFile(), "cleanup removes the temp dir")
})
It("errors on a non-existent input rather than passing the path through", func() {
_, _, err := convertToWavMono16k(filepath.Join(GinkgoT().TempDir(), "missing.mp3"))
Expect(err).To(HaveOccurred())
})
})

View File

@@ -65,6 +65,25 @@ func main() {
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
}
// Per-request language variants (multilingual nemotron). Same probe pattern:
// present only in libparakeet.so built with multilingual support, so the
// backend still loads against an older library and falls back to the
// non-lang batched + streaming entry points (model default / "auto").
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json_lang"); err == nil && sym != 0 {
purego.RegisterLibFunc(&CppTranscribePcmBatchJSONLang, lib, "parakeet_capi_transcribe_pcm_batch_json_lang")
}
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_begin_lang"); err == nil && sym != 0 {
purego.RegisterLibFunc(&CppStreamBeginLang, lib, "parakeet_capi_stream_begin_lang")
}
// Streaming JSON entry points (ABI v4): surface per-word timestamps on the
// streaming path. Same probe pattern; absent in older libparakeet.so, where
// the backend falls back to the text-only streaming feed.
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
}
fmt.Fprintf(os.Stderr, "[parakeet-cpp] ABI=%d\n", CppAbiVersion())
flag.Parse()

View File

@@ -0,0 +1,127 @@
package main
import (
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func tw(text string, start, end float64) transcriptWord {
return transcriptWord{W: text, Start: start, End: end}
}
var _ = Describe("splitWordsIntoSegments (NeMo get_segment_offsets parity)", func() {
seps := []rune{'.', '?', '!'}
It("splits on sentence-ending punctuation, including the delimiter word", func() {
words := []transcriptWord{tw("hello", 0, 0.4), tw("world.", 0.4, 0.8), tw("bye", 1.0, 1.3)}
segs := splitWordsIntoSegments(words, seps, 0)
Expect(segs).To(HaveLen(2))
Expect(segs[0]).To(HaveLen(2))
Expect(segs[0][1].W).To(Equal("world."))
Expect(segs[1]).To(HaveLen(1))
Expect(segs[1][0].W).To(Equal("bye"))
})
It("keeps a single segment with no terminal punctuation and gap off", func() {
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
segs := splitWordsIntoSegments(words, seps, 0)
Expect(segs).To(HaveLen(1))
})
It("splits on the gap rule when enabled, the gapped word starting the next segment", func() {
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
segs := splitWordsIntoSegments(words, seps, 1.0) // c is 4.6s after b
Expect(segs).To(HaveLen(2))
Expect(segs[0]).To(HaveLen(2)) // a b
Expect(segs[1]).To(HaveLen(1)) // c
Expect(segs[1][0].W).To(Equal("c"))
})
It("checks the gap rule before punctuation (NeMo elif order)", func() {
// "b." would terminate, but c is far after it -> gap closes [a b.] at b.
words := []transcriptWord{tw("a", 0, 0.2), tw("b.", 0.2, 0.4), tw("c", 9.0, 9.2)}
segs := splitWordsIntoSegments(words, seps, 1.0)
Expect(segs).To(HaveLen(2))
Expect(segs[0]).To(HaveLen(2))
Expect(segs[1][0].W).To(Equal("c"))
})
It("still splits on punctuation when the gap rule is enabled but does not fire", func() {
words := []transcriptWord{tw("hi.", 0, 0.4), tw("bye", 0.4, 0.8)}
segs := splitWordsIntoSegments(words, seps, 5.0) // gap never reached
Expect(segs).To(HaveLen(2))
Expect(segs[0][0].W).To(Equal("hi."))
})
It("returns nothing for empty input", func() {
Expect(splitWordsIntoSegments(nil, seps, 0)).To(BeEmpty())
})
})
var _ = Describe("transcriptResultFromDoc (multi-segment)", func() {
doc := transcriptJSON{
Text: "hello world. bye now",
FrameSec: 0.08,
Words: []transcriptWord{
{W: "hello", Start: 0.0, End: 0.4},
{W: "world.", Start: 0.4, End: 0.8},
{W: "bye", Start: 1.0, End: 1.3},
{W: "now", Start: 1.3, End: 1.6},
},
Tokens: []transcriptToken{{ID: 1, T: 0.1}, {ID: 2, T: 0.5}, {ID: 3, T: 1.1}, {ID: 4, T: 1.4}},
}
It("emits one segment per punctuation-delimited group with start/end", func() {
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
Expect(res.Segments).To(HaveLen(2))
Expect(res.Segments[0].Text).To(Equal("hello world."))
Expect(res.Segments[0].Start).To(Equal(int64(0)))
Expect(res.Segments[0].End).To(Equal(secondsToNanos(0.8)))
Expect(res.Segments[1].Text).To(Equal("bye now"))
Expect(res.Segments[1].Start).To(Equal(secondsToNanos(1.0)))
Expect(res.Segments[1].Id).To(Equal(int32(1)))
})
It("assigns tokens to the segment whose time window contains them", func() {
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
Expect(res.Segments[0].Tokens).To(Equal([]int32{1, 2}))
Expect(res.Segments[1].Tokens).To(Equal([]int32{3, 4}))
})
It("attaches per-segment words only when word granularity requested", func() {
plain := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
Expect(plain.Segments[0].Words).To(BeEmpty())
withWords := transcriptResultFromDoc(doc, &pb.TranscriptRequest{TimestampGranularities: []string{"word"}}, 0)
Expect(withWords.Segments[0].Words).To(HaveLen(2))
})
It("falls back to a single text segment when there are no words", func() {
res := transcriptResultFromDoc(transcriptJSON{Text: "hi"}, &pb.TranscriptRequest{}, 0)
Expect(res.Segments).To(HaveLen(1))
Expect(res.Segments[0].Text).To(Equal("hi"))
})
})
var _ = Describe("streaming segment assembly", func() {
It("closes a segment with start/end from its words on EOU", func() {
acc := &streamSegmenter{}
acc.add(streamFeedJSON{Text: "hello world", Eou: 1, Words: []transcriptWord{
{W: "hello", Start: 0.0, End: 0.4}, {W: "world", Start: 0.4, End: 0.9},
}})
segs := acc.segments()
Expect(segs).To(HaveLen(1))
Expect(segs[0].Text).To(Equal("hello world"))
Expect(segs[0].Start).To(Equal(int64(0)))
Expect(segs[0].End).To(Equal(secondsToNanos(0.9)))
})
It("buffers words across feeds until EOU", func() {
acc := &streamSegmenter{}
acc.add(streamFeedJSON{Text: "hi", Eou: 0, Words: []transcriptWord{{W: "hi", Start: 0, End: 0.3}}})
Expect(acc.segments()).To(BeEmpty())
acc.add(streamFeedJSON{Text: "there", Eou: 1, Words: []transcriptWord{{W: "there", Start: 0.3, End: 0.7}}})
Expect(acc.segments()).To(HaveLen(1))
Expect(acc.segments()[0].Text).To(Equal("hi there"))
})
})

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# qwen3-tts.cpp version
QWEN3TTS_REPO?=https://github.com/predict-woo/qwen3-tts.cpp
QWEN3TTS_CPP_VERSION?=7a762e2ad4bacc6fdda81d81bf10a09ffb546f29
QWEN3TTS_CPP_VERSION?=136e5d36c17083da0321fd96512dc7b263f94a44
SO_TARGET?=libgoqwen3ttscpp.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
@@ -21,6 +22,43 @@ type Qwen3TtsCpp struct {
threads int
}
// languageNameAliases maps common full language names to the canonical
// two-letter code understood by the C++ language_to_id table.
var languageNameAliases = map[string]string{
"english": "en",
"russian": "ru",
"chinese": "zh",
"japanese": "ja",
"korean": "ko",
"german": "de",
"french": "fr",
"spanish": "es",
"italian": "it",
"portuguese": "pt",
}
// normalizeLanguage coerces a caller-supplied language into the canonical code
// the model expects. It lowercases, trims, strips any region/locale suffix
// (en-US, en_US, ja.JP -> en/ja), and resolves common full names (english -> en).
// An empty input stays empty so the C++ side applies its English default; an
// unrecognized value is returned normalized so C++ can log it and default.
func normalizeLanguage(lang string) string {
lang = strings.ToLower(strings.TrimSpace(lang))
if lang == "" {
return ""
}
// Strip region/locale suffix: keep the segment before the first separator.
if i := strings.IndexAny(lang, "-_."); i >= 0 {
lang = lang[:i]
}
if code, ok := languageNameAliases[lang]; ok {
return code
}
return lang
}
func (q *Qwen3TtsCpp) Load(opts *pb.ModelOptions) error {
// ModelFile is the model directory path (containing GGUF files)
modelDir := opts.ModelFile
@@ -54,7 +92,7 @@ func (q *Qwen3TtsCpp) TTS(req *pb.TTSRequest) error {
dst := req.Dst
language := ""
if req.Language != nil {
language = *req.Language
language = normalizeLanguage(*req.Language)
}
// Synthesis parameters with sensible defaults

View File

@@ -0,0 +1,53 @@
package main
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestLanguageNormalization(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "qwen3-tts-cpp language normalization")
}
var _ = Describe("normalizeLanguage", func() {
DescribeTable("maps caller input to the canonical model language code",
func(input, expected string) {
Expect(normalizeLanguage(input)).To(Equal(expected))
},
// Canonical codes pass through unchanged
Entry("canonical en", "en", "en"),
Entry("canonical zh", "zh", "zh"),
Entry("canonical pt", "pt", "pt"),
// Case-insensitive
Entry("uppercase", "EN", "en"),
Entry("mixed case", "Ja", "ja"),
// Surrounding whitespace
Entry("trims whitespace", " en ", "en"),
// Region/locale stripping
Entry("BCP-47 region", "en-US", "en"),
Entry("underscore region", "en_US", "en"),
Entry("dotted locale", "ja.JP", "ja"),
Entry("region + case", "ZH-CN", "zh"),
// Full-name aliases
Entry("english name", "english", "en"),
Entry("chinese name cased", "Chinese", "zh"),
Entry("japanese name", "japanese", "ja"),
Entry("russian name", "russian", "ru"),
Entry("portuguese name", "portuguese", "pt"),
// Empty stays empty (C++ applies the English default)
Entry("empty", "", ""),
Entry("whitespace only", " ", ""),
// Unknown values pass through normalized so C++ can log + default
Entry("unknown code", "klingon", "klingon"),
Entry("unknown with region", "xx-YY", "xx"),
)
})

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# stablediffusion.cpp (ggml)
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
STABLEDIFFUSION_GGML_VERSION?=7948df8ac1070f5f6881b8d34675821893eb97d6
STABLEDIFFUSION_GGML_VERSION?=b9254dda0d10b91ee6f17fb7f4420097dd29824b
CMAKE_ARGS+=-DGGML_MAX_NAME=128

View File

@@ -386,6 +386,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
const char *llm_vision_path = "";
const char *diffusion_model_path = stableDiffusionModel;
const char *high_noise_diffusion_model_path = "";
const char *uncond_diffusion_model_path = "";
const char *taesd_path = "";
const char *control_net_path = "";
const char *embedding_dir = "";
@@ -472,6 +473,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
if (!strcmp(optname, "llm_vision_path")) llm_vision_path = strdup(optval);
if (!strcmp(optname, "diffusion_model_path")) diffusion_model_path = strdup(optval);
if (!strcmp(optname, "high_noise_diffusion_model_path")) high_noise_diffusion_model_path = strdup(optval);
if (!strcmp(optname, "uncond_diffusion_model_path")) uncond_diffusion_model_path = strdup(optval);
if (!strcmp(optname, "taesd_path")) taesd_path = strdup(optval);
if (!strcmp(optname, "control_net_path")) control_net_path = strdup(optval);
if (!strcmp(optname, "embedding_dir")) {
@@ -571,6 +573,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
ctx_params.llm_vision_path = llm_vision_path;
ctx_params.diffusion_model_path = diffusion_model_path;
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
ctx_params.uncond_diffusion_model_path = uncond_diffusion_model_path;
ctx_params.vae_path = vae_path;
ctx_params.audio_vae_path = audio_vae_path;
ctx_params.embeddings_connectors_path = embeddings_connectors_path;

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
WHISPER_CPP_VERSION?=23ee03506a91ac3d3f0071b40e66a430eebdfa1d
WHISPER_CPP_VERSION?=a8ec021f2750a473ff4a8f3883bc9fdf5feafa84
SO_TARGET?=libgowhisper.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -37,6 +37,20 @@ def is_int(s):
except ValueError:
return False
def coerce_param_value(value):
"""Coerce a TTSRequest.params value (string on the wire) to the type the
Chatterbox generate() kwargs expect (float/int/bool), matching how static
YAML options are coerced at load time. Non-string values pass through."""
if not isinstance(value, str):
return value
if is_float(value):
return float(value)
if is_int(value):
return int(value)
if value.lower() in ["true", "false"]:
return value.lower() == "true"
return value
def split_text_at_word_boundary(text, max_length=250):
"""
Split text at word boundaries without truncating words.
@@ -191,6 +205,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# add options to kwargs
kwargs.update(self.options)
# Merge per-request params (TTSRequest.params), overriding the static
# YAML options. This exposes Chatterbox generation knobs (e.g.
# exaggeration, cfg_weight, temperature) per request. Values arrive as
# strings on the wire and are coerced to float/int/bool.
if hasattr(request, "params") and request.params:
for key, value in request.params.items():
kwargs[key] = coerce_param_value(value)
# Check if text exceeds 250 characters
# (chatterbox does not support long text)
# https://github.com/resemble-ai/chatterbox/issues/60

View File

@@ -47,6 +47,26 @@ def is_int(s):
return False
def coerce_param_value(value):
"""Coerce a string param value (from the TTSRequest.params map, which is
string-typed on the wire) into the most specific Python type the model
generation kwargs expect: bool, int, float, else the original string."""
if not isinstance(value, str):
return value
lowered = value.strip().lower()
if lowered in ("true", "false"):
return lowered == "true"
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
@@ -322,6 +342,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(message="Model loaded successfully", success=True)
def _effective_instruct(self, request):
"""Resolve the instruction/style string for this request, preferring the
per-request TTSRequest.instructions value and falling back to the static
YAML `instruct` option. Empty string means "no instruction"."""
req_instruct = (
request.instructions
if hasattr(request, "instructions") and request.instructions
else ""
)
if req_instruct:
return req_instruct
return self.options.get("instruct", "") or ""
def _detect_mode(self, request):
"""Detect which mode to use based on request parameters."""
# Priority: VoiceClone > VoiceDesign > CustomVoice
@@ -338,8 +371,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if self.audio_path or self.voices:
return "VoiceClone"
# VoiceDesign: instruct option is provided
if "instruct" in self.options and self.options["instruct"]:
# VoiceDesign: instruct provided per-request or via YAML option
if self._effective_instruct(request):
return "VoiceDesign"
# Default to CustomVoice
@@ -690,10 +723,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if do_sample is not None:
generation_kwargs["do_sample"] = do_sample
instruct = self.options.get("instruct", "")
# Prefer the per-request instruction (TTSRequest.instructions) over the
# static YAML `instruct` option. This lets clients set a different style
# (CustomVoice emotion) or designed voice (VoiceDesign) per request.
instruct = self._effective_instruct(request)
if instruct is not None and instruct != "":
generation_kwargs["instruct"] = instruct
# Merge any per-request backend-specific params (TTSRequest.params).
# Values arrive as strings on the wire; coerce to int/float/bool so the
# model receives the types it expects. These override YAML-derived kwargs.
if hasattr(request, "params") and request.params:
for key, value in request.params.items():
generation_kwargs[key] = coerce_param_value(value)
# Generate audio based on mode
if mode == "VoiceClone":
# VoiceClone mode

View File

@@ -1,4 +1,4 @@
grpcio==1.80.0
grpcio==1.81.0
protobuf==7.35.0
certifi
setuptools

View File

@@ -3,5 +3,5 @@
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
# so uv consults this index alongside PyPI.
--extra-index-url https://wheels.vllm.ai/0.22.0/cu130
vllm==0.22.0
--extra-index-url https://wheels.vllm.ai/0.22.1/cu130
vllm==0.22.1

View File

@@ -1,4 +1,4 @@
grpcio==1.80.0
grpcio==1.81.0
protobuf
certifi
setuptools

View File

@@ -102,7 +102,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)
// Connect to NATS
natsClient, err := messaging.New(cfg.Distributed.NatsURL)
natsAuth := cfg.Distributed.NatsAuthConfig()
if natsAuth.RequireAuth && (natsAuth.ServiceUserJWT == "" || natsAuth.ServiceUserSeed == "") {
return nil, fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED")
}
natsOpts := cfg.Distributed.NatsMessagingOptions("", "")
natsClient, err := messaging.New(cfg.Distributed.NatsURL, natsOpts...)
if err != nil {
return nil, fmt.Errorf("connecting to NATS: %w", err)
}

View File

@@ -23,9 +23,9 @@ import (
"github.com/mudler/LocalAI/core/services/routing/pii"
"github.com/mudler/LocalAI/core/services/routing/router"
"github.com/mudler/LocalAI/core/services/storage"
"github.com/mudler/LocalAI/pkg/signals"
coreStartup "github.com/mudler/LocalAI/core/startup"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/signals"
"github.com/mudler/LocalAI/pkg/vram"
"github.com/mudler/LocalAI/pkg/model"
@@ -308,10 +308,31 @@ func New(opts ...config.AppOption) (*Application, error) {
application.galleryService.SetNATSClient(distSvc.Nats)
if distSvc.DistStores != nil && distSvc.DistStores.Gallery != nil {
// Clean up stale in-progress operations from previous crashed instances
if err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
if _, err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
xlog.Warn("Failed to clean stale gallery operations", "error", err)
}
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
// Reap stale ops periodically, not just at boot: an op orphaned by
// a replica that died mid-install (its foreground handler goroutine
// gone) would otherwise linger "processing" in the UI until the next
// restart. 30m matches the install/upgrade ceiling so a genuinely
// slow op is never reaped out from under itself.
gsvc := application.galleryService
go func() {
ticker := time.NewTicker(15 * time.Minute)
defer ticker.Stop()
for {
select {
case <-options.Context.Done():
return
case <-ticker.C:
if _, err := gsvc.ReapStaleOperations(30 * time.Minute); err != nil {
xlog.Warn("Failed to reap stale gallery operations", "error", err)
}
}
}
}()
}
// Hydrate from the store first so the wildcard subscriber finds an
// already-populated statuses map for any operations still in flight

View File

@@ -214,7 +214,9 @@ func (uc *UpgradeChecker) runCheck(ctx context.Context) {
"from", info.InstalledVersion, "to", info.AvailableVersion)
var err error
if bm != nil {
err = bm.UpgradeBackend(ctx, name, nil)
// Background auto-upgrade: no live admin watching a progress bar,
// so opID is empty and the distributed path skips progress streaming.
err = bm.UpgradeBackend(ctx, "", name, nil)
} else {
err = gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
uc.galleries, name, nil, uc.appConfig.RequireBackendIntegrity)

View File

@@ -123,14 +123,14 @@ var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
})
It("ModelTTS forwards the request context to the SmartRouter", func() {
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", loader, appCfg, modelCfg)
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
stampViaRouterCtx()
})
It("ModelTTSStream forwards the request context to the SmartRouter", func() {
err := backend.ModelTTSStream(reqCtx, "hello", "", "", loader, appCfg, modelCfg, func([]byte) error { return nil })
err := backend.ModelTTSStream(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg, func([]byte) error { return nil })
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
stampViaRouterCtx()

View File

@@ -239,13 +239,13 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
if c.Backend == "cloud-proxy" {
opts.Proxy = &pb.ProxyOptions{
UpstreamUrl: c.Proxy.UpstreamURL,
Mode: c.Proxy.Mode,
Provider: c.Proxy.Provider,
ApiKeyEnv: c.Proxy.APIKeyEnv,
ApiKeyFile: c.Proxy.APIKeyFile,
UpstreamModel: c.Proxy.UpstreamModel,
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
UpstreamUrl: c.Proxy.UpstreamURL,
Mode: c.Proxy.Mode,
Provider: c.Proxy.Provider,
ApiKeyEnv: c.Proxy.APIKeyEnv,
ApiKeyFile: c.Proxy.APIKeyFile,
UpstreamModel: c.Proxy.UpstreamModel,
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
}
}
@@ -323,6 +323,12 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
metadata["enable_thinking"] = "true"
}
}
// Forward the effective reasoning effort so the backend can pass it to the
// jinja chat template (chat_template_kwargs.reasoning_effort) — the lever
// models like gpt-oss / LFM2.5 actually read, distinct from enable_thinking.
if c.ReasoningEffort != "" {
metadata["reasoning_effort"] = c.ReasoningEffort
}
pbOpts.Metadata = metadata
// Logprobs and TopLogprobs are set by the caller if provided

View File

@@ -75,3 +75,25 @@ var _ = Describe("gRPCPredictOpts enable_thinking metadata", func() {
Expect(opts.Metadata).ToNot(HaveKey("enable_thinking"))
})
})
// Guards forwarding the effective reasoning_effort into PredictOptions.Metadata,
// where the backend passes it to the jinja chat template (chat_template_kwargs)
// so models like gpt-oss / LFM2.5 honor it.
var _ = Describe("gRPCPredictOpts reasoning_effort metadata", func() {
withEffort := func(effort string) config.ModelConfig {
cfg := config.ModelConfig{}
cfg.SetDefaults()
cfg.ReasoningEffort = effort
return cfg
}
It("forwards reasoning_effort when set", func() {
opts := gRPCPredictOpts(withEffort("none"), "/tmp/models")
Expect(opts.Metadata).To(HaveKeyWithValue("reasoning_effort", "none"))
})
It("omits reasoning_effort when empty", func() {
opts := gRPCPredictOpts(withEffort(""), "/tmp/models")
Expect(opts.Metadata).ToNot(HaveKey("reasoning_effort"))
})
})

View File

@@ -20,11 +20,32 @@ import (
"github.com/mudler/LocalAI/pkg/utils"
)
// newTTSRequest assembles the gRPC TTSRequest from the per-request inputs. The
// optional instructions string is only attached when non-empty so backends can
// distinguish "no per-request instruction" (fall back to YAML) from an explicit
// empty one. params is forwarded as-is (nil when unset).
func newTTSRequest(text, modelPath, voice, dst, language, instructions string, params map[string]string) *proto.TTSRequest {
req := &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
Dst: dst,
Language: &language,
Params: params,
}
if instructions != "" {
req.Instructions = &instructions
}
return req
}
func ModelTTS(
ctx context.Context,
text,
voice,
language string,
language,
instructions string,
params map[string]string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
modelConfig config.ModelConfig,
@@ -74,13 +95,9 @@ func ModelTTS(
startTime = time.Now()
}
res, err := ttsModel.TTS(ctx, &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
Dst: filePath,
Language: &language,
})
ttsRequest := newTTSRequest(text, modelPath, voice, filePath, language, instructions, params)
res, err := ttsModel.TTS(ctx, ttsRequest)
if appConfig.EnableTracing {
errStr := ""
@@ -128,7 +145,9 @@ func ModelTTSStream(
ctx context.Context,
text,
voice,
language string,
language,
instructions string,
params map[string]string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
modelConfig config.ModelConfig,
@@ -177,12 +196,10 @@ func ModelTTSStream(
var totalPCMBytes int
snippetCapped := false
err = ttsModel.TTSStream(ctx, &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
Language: &language,
}, func(reply *proto.Reply) {
// Streaming TTS writes to the HTTP response, not a file, so dst is empty.
ttsRequest := newTTSRequest(text, modelPath, voice, "", language, instructions, params)
err = ttsModel.TTSStream(ctx, ttsRequest, func(reply *proto.Reply) {
// First message contains sample rate info
if !headerSent && len(reply.Message) > 0 {
var info map[string]any

42
core/backend/tts_test.go Normal file
View File

@@ -0,0 +1,42 @@
package backend
// Specs for the TTSRequest assembly that carries the per-request
// instructions/params from the OpenAI `instructions` field (and the LocalAI
// `params` extension) through to the gRPC boundary. Before this plumbing the
// instruction value was dropped before reaching the backend; these specs pin
// that it now survives, and that the empty case stays backward compatible.
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("newTTSRequest", func() {
It("attaches the instructions when a per-request value is set", func() {
req := newTTSRequest("hi", "/m", "alloy", "/out.wav", "en", "cheerful narrator", nil)
Expect(req.Instructions).ToNot(BeNil())
Expect(req.GetInstructions()).To(Equal("cheerful narrator"))
Expect(req.GetText()).To(Equal("hi"))
Expect(req.GetVoice()).To(Equal("alloy"))
Expect(req.GetDst()).To(Equal("/out.wav"))
Expect(req.GetLanguage()).To(Equal("en"))
})
It("leaves instructions unset when empty so backends fall back to YAML", func() {
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
Expect(req.Instructions).To(BeNil())
Expect(req.GetInstructions()).To(Equal(""))
})
It("forwards per-request params through to the backend", func() {
params := map[string]string{"exaggeration": "0.7", "cfg_weight": "0.3"}
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", params)
Expect(req.GetParams()).To(HaveKeyWithValue("exaggeration", "0.7"))
Expect(req.GetParams()).To(HaveKeyWithValue("cfg_weight", "0.3"))
})
It("leaves params nil when none are supplied", func() {
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
Expect(req.GetParams()).To(BeNil())
})
})

View File

@@ -52,10 +52,28 @@ type AgentWorkerCMD struct {
Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
Queue string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`
NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (defaults to nats_jwt from registration)" group:"distributed"`
NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user seed override (defaults to nats_user_seed from registration)" group:"distributed"`
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"Fallback NATS service JWT when registration does not mint agent JWT" group:"distributed"`
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"Fallback NATS service seed paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed to connect" group:"distributed"`
// DistributedRequireAuth is the umbrella switch; for the agent worker (which
// has no file-transfer server) it implies NATS auth is required.
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch implying --nats-require-auth (agent workers have no file-transfer server)" group:"distributed"`
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
// Timeouts
MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
}
// natsAuthRequired reports whether NATS JWT credentials must be present — the
// granular flag or the umbrella (LOCALAI_DISTRIBUTED_REQUIRE_AUTH).
func (cmd *AgentWorkerCMD) natsAuthRequired() bool {
return cmd.NatsRequireAuth || cmd.DistributedRequireAuth
}
func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
xlog.Info("Starting agent worker", "nats", sanitize.URL(cmd.NatsURL), "register_to", cmd.RegisterTo)
@@ -81,15 +99,30 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
registrationBody["token"] = cmd.RegistrationToken
}
nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
// Context cancelled on shutdown — used by registration waits, heartbeat, and
// other background goroutines.
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
defer shutdownCancel()
// Acquire credentials via (re)registration. When the bus requires auth and no
// static fallback is configured, wait through admin approval until the
// frontend mints credentials rather than starting unauthenticated.
credMgr := workerregistry.NewNATSCredentialManager(
func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
return regClient.RegisterFull(ctx, registrationBody)
},
cmd.natsAuthRequired() && cmd.NatsJWT == "" && cmd.NatsServiceJWT == "",
)
res, err := credMgr.Acquire(shutdownCtx)
if err != nil {
return fmt.Errorf("registration failed: %w", err)
}
nodeID := res.ID
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
// Use provisioned API token if none was set
if cmd.APIToken == "" {
cmd.APIToken = apiToken
cmd.APIToken = res.APIToken
}
// Start heartbeat
@@ -98,14 +131,40 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
}
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
// Context cancelled on shutdown — used by heartbeat and other background goroutines
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
defer shutdownCancel()
go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })
// Connect to NATS
natsClient, err := messaging.New(cmd.NatsURL)
// Resolve NATS credentials with precedence: explicit env override, then
// frontend-minted (auto-refreshed before expiry), then service fallback.
// Each static source must supply JWT and seed together.
natsTLS := messaging.TLSFiles{CA: cmd.NatsTLSCA, Cert: cmd.NatsTLSCert, Key: cmd.NatsTLSKey}
var natsOpts []messaging.Option
switch {
case cmd.NatsJWT != "" || cmd.NatsUserSeed != "":
if (cmd.NatsJWT == "") != (cmd.NatsUserSeed == "") {
return fmt.Errorf("LOCALAI_NATS_JWT and LOCALAI_NATS_USER_SEED must be set together")
}
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsJWT, cmd.NatsUserSeed))
case credMgr.HasCredentials():
natsOpts = append(natsOpts, messaging.WithUserJWTProvider(credMgr.Provider()))
go func() {
if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
xlog.Error("NATS credential refresh permanently failed; shutting down agent worker", "error", err)
shutdownCancel()
}
}()
case cmd.NatsServiceJWT != "" || cmd.NatsServiceSeed != "":
if (cmd.NatsServiceJWT == "") != (cmd.NatsServiceSeed == "") {
return fmt.Errorf("LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED must be set together")
}
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsServiceJWT, cmd.NatsServiceSeed))
case cmd.natsAuthRequired():
return fmt.Errorf("NATS JWT+seed required: enable frontend minting or set LOCALAI_NATS_* env vars")
}
if natsTLS.Enabled() {
natsOpts = append(natsOpts, messaging.WithTLS(natsTLS))
}
natsClient, err := messaging.New(cmd.NatsURL, natsOpts...)
if err != nil {
return fmt.Errorf("connecting to NATS: %w", err)
}
@@ -183,17 +242,25 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)
// Wait for shutdown
// Wait for an OS signal or an internal fatal condition (e.g. NATS
// credentials became unrenewable), so the worker restarts and re-acquires
// rather than lingering unable to serve.
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
var runErr error
select {
case <-sigCh:
case <-shutdownCtx.Done():
runErr = fmt.Errorf("agent worker shutting down: NATS credentials unavailable")
xlog.Error("Internal shutdown requested", "error", runErr)
}
xlog.Info("Shutting down agent worker")
shutdownCancel() // stop heartbeat loop immediately
dispatcher.Stop()
mcpTools.CloseAllMCPSessions()
regClient.GracefulDeregister(nodeID)
return nil
return runErr
}
// handleMCPToolRequest handles a NATS request-reply for MCP tool execution.

View File

@@ -154,11 +154,21 @@ type RunCMD struct {
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
RegistrationRequireAuth bool `env:"LOCALAI_REGISTRATION_REQUIRE_AUTH" default:"false" help:"Fail startup when distributed mode is enabled but LOCALAI_REGISTRATION_TOKEN is empty (node endpoints and worker file-transfer server would otherwise be unauthenticated)" group:"distributed"`
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch: require BOTH NATS JWT credentials and a registration token when distributed mode is enabled (implies --nats-require-auth and --registration-require-auth)" group:"distributed"`
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
DistributedPrefixCache bool `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE" default:"true" help:"Enable prefix-cache-aware routing in distributed mode (default true). When false, routing falls back to round-robin." group:"distributed"`
DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
NatsAccountSeed string `env:"LOCALAI_NATS_ACCOUNT_SEED" help:"NATS account signing seed (SU...) used to mint per-node worker JWTs at registration" group:"distributed"`
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"NATS user JWT for the frontend (and agent workers) to publish control-plane messages" group:"distributed"`
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"NATS user signing seed (SU...) paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
NatsWorkerJWTTTL string `env:"LOCALAI_NATS_WORKER_JWT_TTL" help:"Lifetime of minted per-node NATS JWTs (e.g. 24h, default 24h)" group:"distributed"`
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT credentials (service JWT + account seed) when distributed mode is enabled" group:"distributed"`
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI); use with tls:// in --nats-url" group:"distributed"`
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
Version bool
@@ -283,6 +293,40 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
if r.RegistrationToken != "" {
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
}
if r.RegistrationRequireAuth {
opts = append(opts, config.EnableRegistrationRequireAuth)
}
if r.DistributedRequireAuth {
opts = append(opts, config.EnableDistributedRequireAuth)
}
if r.NatsAccountSeed != "" {
opts = append(opts, config.WithNatsAccountSeed(r.NatsAccountSeed))
}
if r.NatsServiceJWT != "" {
opts = append(opts, config.WithNatsServiceJWT(r.NatsServiceJWT))
}
if r.NatsServiceSeed != "" {
opts = append(opts, config.WithNatsServiceSeed(r.NatsServiceSeed))
}
if r.NatsWorkerJWTTTL != "" {
d, err := time.ParseDuration(r.NatsWorkerJWTTTL)
if err != nil {
return fmt.Errorf("invalid LOCALAI_NATS_WORKER_JWT_TTL %q: %w", r.NatsWorkerJWTTTL, err)
}
opts = append(opts, config.WithNatsWorkerJWTTTL(d))
}
if r.NatsRequireAuth {
opts = append(opts, config.EnableNatsRequireAuth)
}
if r.NatsTLSCA != "" {
opts = append(opts, config.WithNatsTLSCA(r.NatsTLSCA))
}
if r.NatsTLSCert != "" {
opts = append(opts, config.WithNatsTLSCert(r.NatsTLSCert))
}
if r.NatsTLSKey != "" {
opts = append(opts, config.WithNatsTLSKey(r.NatsTLSKey))
}
if r.AutoApproveNodes {
opts = append(opts, config.EnableAutoApproveNodes)
}

View File

@@ -62,7 +62,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
options.Backend = t.Backend
options.Model = t.Model
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, ml, opts, options)
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, "", nil, ml, opts, options)
if err != nil {
return err
}

View File

@@ -96,7 +96,7 @@ func (r *VLLMDistributed) Run(ctx *cliContext.Context) error {
FrontendURL: r.RegisterTo,
RegistrationToken: r.RegistrationToken,
}
nodeID, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
nodeID, _, _, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
if regErr != nil {
return fmt.Errorf("registering with frontend: %w", regErr)
}

View File

@@ -58,65 +58,77 @@ func (c *RegistrationClient) setAuth(req *http.Request) {
// RegisterResponse is the JSON body returned by /api/node/register.
type RegisterResponse struct {
ID string `json:"id"`
APIToken string `json:"api_token,omitempty"`
ID string `json:"id"`
Status string `json:"status,omitempty"` // "pending" until an admin approves the node
APIToken string `json:"api_token,omitempty"`
NatsJWT string `json:"nats_jwt,omitempty"`
NatsUserSeed string `json:"nats_user_seed,omitempty"`
}
// Register sends a single registration request and returns the node ID and
// (optionally) an auto-provisioned API token.
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) {
// RegisterFull sends a single registration request and returns the full
// response (node ID, approval status, and optional API token / NATS creds).
// Re-registration is idempotent: the frontend preserves the node row and mints
// a fresh NATS JWT each call, so this doubles as the credential-refresh call.
func (c *RegistrationClient) RegisterFull(ctx context.Context, body map[string]any) (*RegisterResponse, error) {
jsonBody, _ := json.Marshal(body)
url := c.baseURL() + "/api/node/register"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
if err != nil {
return "", "", fmt.Errorf("creating request: %w", err)
return nil, fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
c.setAuth(req)
resp, err := c.httpClient().Do(req)
if err != nil {
return "", "", fmt.Errorf("posting to %s: %w", url, err)
return nil, fmt.Errorf("posting to %s: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
return nil, fmt.Errorf("registration failed with status %d", resp.StatusCode)
}
var result RegisterResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", "", fmt.Errorf("decoding response: %w", err)
return nil, fmt.Errorf("decoding response: %w", err)
}
return result.ID, result.APIToken, nil
return &result, nil
}
// Register sends a single registration request and returns the node ID and
// optional credentials (API token for agent workers, NATS JWT when configured).
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
res, err := c.RegisterFull(ctx, body)
if err != nil {
return "", "", "", "", err
}
return res.ID, res.APIToken, res.NatsJWT, res.NatsUserSeed, nil
}
// RegisterWithRetry retries registration with exponential backoff.
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) {
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
backoff := 2 * time.Second
maxBackoff := 30 * time.Second
var nodeID, apiToken string
var err error
for attempt := 1; attempt <= maxRetries; attempt++ {
nodeID, apiToken, err = c.Register(ctx, body)
nodeID, apiToken, natsJWT, natsSeed, err = c.Register(ctx, body)
if err == nil {
return nodeID, apiToken, nil
return nodeID, apiToken, natsJWT, natsSeed, nil
}
if attempt == maxRetries {
return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
return "", "", "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
}
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
select {
case <-ctx.Done():
return "", "", ctx.Err()
return "", "", "", "", ctx.Err()
case <-time.After(backoff):
}
backoff = min(backoff*2, maxBackoff)
}
return nodeID, apiToken, err
return nodeID, apiToken, natsJWT, natsSeed, err
}
// Heartbeat sends a single heartbeat POST with the given body.

View File

@@ -0,0 +1,200 @@
package workerregistry
import (
"context"
"fmt"
"sync"
"time"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/mudler/xlog"
)
// statusPending mirrors nodes.StatusPending. It is duplicated rather than
// imported so the lightweight registration client does not pull in the nodes
// package (and its gorm/DB dependencies).
const statusPending = "pending"
// defaultMaxAttempts bounds how many times Acquire registers (and how many
// consecutive times RefreshLoop may fail) before giving up. It is high enough
// to ride out a slow admin approval or a transient frontend outage, but finite
// so an unauthorized/unapprovable worker exits and surfaces the problem (via a
// non-zero exit and the resulting restart) rather than waiting forever.
const defaultMaxAttempts = 100
// RegisterFunc performs one idempotent registration round-trip.
type RegisterFunc func(ctx context.Context) (*RegisterResponse, error)
// NATSCredentialManager acquires NATS credentials at startup — waiting through
// admin approval when required — and refreshes them before the minted JWT
// expires, by re-registering (which mints a fresh JWT). The live NATS
// connection adopts a refreshed JWT on its next reconnect via Provider. Safe
// for concurrent use.
//
// It addresses two failure modes: a worker that needs credentials but registers
// while still pending approval (it would otherwise give up and never connect),
// and a long-running worker whose 24h JWT expires with no way to renew it.
type NATSCredentialManager struct {
register RegisterFunc
requireCreds bool // block until credentials are present (frontend minting in use)
// Tunables; defaults set by NewNATSCredentialManager, overridable in tests.
initialBackoff time.Duration
maxBackoff time.Duration
maxAttempts int // bound on Acquire attempts / consecutive refresh failures (<=0 = unlimited)
refreshLead float64 // refresh once this fraction of the JWT lifetime has elapsed
refreshRetry time.Duration
expiryOf func(jwt string) (time.Time, bool)
mu sync.RWMutex
jwt string
seed string
nodeID string
}
// NewNATSCredentialManager builds a manager over register. When requireCreds is
// true, Acquire blocks until the node is approved and credentials are minted.
func NewNATSCredentialManager(register RegisterFunc, requireCreds bool) *NATSCredentialManager {
return &NATSCredentialManager{
register: register,
requireCreds: requireCreds,
initialBackoff: 2 * time.Second,
maxBackoff: 30 * time.Second,
maxAttempts: defaultMaxAttempts,
refreshLead: 0.75,
refreshRetry: 30 * time.Second,
expiryOf: jwtExpiry,
}
}
// jwtExpiry decodes the expiry of a minted user JWT. ok is false when the token
// is empty/undecodable or carries no expiry (e.g. a non-expiring service JWT).
func jwtExpiry(token string) (time.Time, bool) {
if token == "" {
return time.Time{}, false
}
uc, err := natsauth.DecodeUserClaims(token)
if err != nil || uc.Expires == 0 {
return time.Time{}, false
}
return time.Unix(uc.Expires, 0), true
}
func (m *NATSCredentialManager) store(res *RegisterResponse) {
m.mu.Lock()
defer m.mu.Unlock()
m.nodeID = res.ID
if res.NatsJWT != "" && res.NatsUserSeed != "" {
m.jwt, m.seed = res.NatsJWT, res.NatsUserSeed
}
}
// Current returns the latest NATS credentials (both empty until acquired).
func (m *NATSCredentialManager) Current() (jwt, seed string) {
m.mu.RLock()
defer m.mu.RUnlock()
return m.jwt, m.seed
}
// NodeID returns the node ID from the most recent registration.
func (m *NATSCredentialManager) NodeID() string {
m.mu.RLock()
defer m.mu.RUnlock()
return m.nodeID
}
// Provider returns a callback compatible with messaging.WithUserJWTProvider,
// supplying the current credentials on each (re)connect.
func (m *NATSCredentialManager) Provider() func() (string, string) {
return m.Current
}
// HasCredentials reports whether complete NATS credentials have been obtained.
func (m *NATSCredentialManager) HasCredentials() bool {
jwt, seed := m.Current()
return jwt != "" && seed != ""
}
// Acquire registers and, when requireCreds is set, keeps re-registering with
// exponential backoff until the node is approved (status != pending) and
// credentials are minted. Without requireCreds it returns the first successful
// response (the historical one-shot behavior, preserved for anonymous NATS).
func (m *NATSCredentialManager) Acquire(ctx context.Context) (*RegisterResponse, error) {
backoff := m.initialBackoff
var lastReason error
for attempt := 1; m.maxAttempts <= 0 || attempt <= m.maxAttempts; attempt++ {
res, err := m.register(ctx)
switch {
case err != nil:
lastReason = err
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
case !m.requireCreds:
m.store(res)
return res, nil
case res.Status == statusPending:
lastReason = fmt.Errorf("node %s still pending admin approval", res.ID)
xlog.Info("Node pending admin approval; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
case res.NatsJWT == "" || res.NatsUserSeed == "":
lastReason = fmt.Errorf("node %s approved but NATS credentials not minted", res.ID)
xlog.Info("Node approved but NATS credentials not yet minted; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
default:
m.store(res)
return res, nil
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(backoff):
}
backoff = min(backoff*2, m.maxBackoff)
}
return nil, fmt.Errorf("giving up acquiring NATS credentials after %d attempts: %w", m.maxAttempts, lastReason)
}
// RefreshLoop re-registers to mint a fresh JWT before the current one expires,
// updating the credentials returned by Current/Provider so the NATS connection
// adopts them on its next reconnect. It returns nil when ctx is cancelled or
// when the current credential has no expiry (nothing to refresh), and a non-nil
// error after maxAttempts consecutive refresh failures — letting the caller
// exit the worker so it restarts and re-acquires (or surfaces the outage)
// rather than silently drifting toward an expired, unrenewable JWT.
func (m *NATSCredentialManager) RefreshLoop(ctx context.Context) error {
failures := 0
for {
jwt, _ := m.Current()
exp, ok := m.expiryOf(jwt)
if !ok {
xlog.Debug("NATS credential has no expiry; refresh loop exiting")
return nil
}
wait := max(time.Duration(float64(time.Until(exp))*m.refreshLead), 0)
select {
case <-ctx.Done():
return nil
case <-time.After(wait):
}
res, err := m.register(ctx)
if err == nil && res.NatsJWT != "" && res.NatsUserSeed != "" {
m.store(res)
failures = 0
xlog.Info("Refreshed NATS credentials", "node", res.ID)
continue
}
failures++
if err != nil {
xlog.Warn("NATS credential refresh failed; will retry", "attempt", failures, "error", err)
} else {
xlog.Warn("NATS credential refresh returned no credentials; will retry", "attempt", failures)
}
if m.maxAttempts > 0 && failures >= m.maxAttempts {
return fmt.Errorf("NATS credential refresh failed %d times in a row", failures)
}
// Back off before retrying so a persistent failure near expiry does not spin.
select {
case <-ctx.Done():
return nil
case <-time.After(m.refreshRetry):
}
}
}

View File

@@ -0,0 +1,198 @@
package workerregistry
import (
"context"
"sync"
"testing"
"time"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/nats-io/nkeys"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestWorkerRegistry(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "WorkerRegistry")
}
// fakeRegister returns a sequence of canned responses/errors, one per call, and
// records how many times it was invoked. The last entry repeats once exhausted.
type fakeRegister struct {
mu sync.Mutex
steps []step
calls int
}
type step struct {
res *RegisterResponse
err error
}
func (f *fakeRegister) fn() RegisterFunc {
return func(context.Context) (*RegisterResponse, error) {
f.mu.Lock()
defer f.mu.Unlock()
i := f.calls
f.calls++
if i >= len(f.steps) {
i = len(f.steps) - 1
}
return f.steps[i].res, f.steps[i].err
}
}
func (f *fakeRegister) count() int {
f.mu.Lock()
defer f.mu.Unlock()
return f.calls
}
var _ = Describe("NATSCredentialManager", func() {
approved := func(jwt, seed string) *RegisterResponse {
return &RegisterResponse{ID: "node-1", Status: "healthy", NatsJWT: jwt, NatsUserSeed: seed}
}
pending := &RegisterResponse{ID: "node-1", Status: "pending"}
Describe("Acquire (#4 — wait through admin approval)", func() {
It("keeps re-registering until the node is approved and credentials are minted", func() {
f := &fakeRegister{steps: []step{
{res: pending}, // not approved yet
{res: approved("", "")}, // approved but JWT not minted yet
{res: approved("jwt-1", "seed-1")}, // finally minted
}}
m := NewNATSCredentialManager(f.fn(), true /* requireCreds */)
m.initialBackoff = time.Millisecond
m.maxBackoff = time.Millisecond
res, err := m.Acquire(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(res.ID).To(Equal("node-1"))
Expect(f.count()).To(Equal(3))
jwt, seed := m.Current()
Expect(jwt).To(Equal("jwt-1"))
Expect(seed).To(Equal("seed-1"))
Expect(m.HasCredentials()).To(BeTrue())
Expect(m.NodeID()).To(Equal("node-1"))
})
It("returns immediately on the first success when credentials are not required (anonymous NATS)", func() {
f := &fakeRegister{steps: []step{{res: pending}}}
m := NewNATSCredentialManager(f.fn(), false /* requireCreds */)
res, err := m.Acquire(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(res.Status).To(Equal("pending"))
Expect(f.count()).To(Equal(1))
Expect(m.HasCredentials()).To(BeFalse())
})
It("aborts when the context is cancelled while waiting for approval", func() {
f := &fakeRegister{steps: []step{{res: pending}}}
m := NewNATSCredentialManager(f.fn(), true)
m.initialBackoff = 10 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := m.Acquire(ctx)
Expect(err).To(MatchError(context.Canceled))
})
It("gives up after a bounded number of attempts so the worker exits and alerts", func() {
f := &fakeRegister{steps: []step{{res: pending}}} // never approved
m := NewNATSCredentialManager(f.fn(), true)
m.initialBackoff = time.Millisecond
m.maxBackoff = time.Millisecond
m.maxAttempts = 5
_, err := m.Acquire(context.Background())
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("after 5 attempts"))
Expect(err.Error()).To(ContainSubstring("pending admin approval"))
Expect(f.count()).To(Equal(5))
})
})
Describe("RefreshLoop (#5 — renew before the JWT expires)", func() {
It("re-registers before expiry and updates the credentials served to new connections", func() {
f := &fakeRegister{steps: []step{{res: approved("jwt-2", "seed-2")}}}
m := NewNATSCredentialManager(f.fn(), true)
m.refreshLead = 0.5
m.refreshRetry = time.Millisecond
// jwt-1 expires soon; jwt-2 is long-lived so the loop then idles.
m.expiryOf = func(jwt string) (time.Time, bool) {
switch jwt {
case "jwt-1":
return time.Now().Add(40 * time.Millisecond), true
case "jwt-2":
return time.Now().Add(time.Hour), true
default:
return time.Time{}, false
}
}
m.store(approved("jwt-1", "seed-1"))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = m.RefreshLoop(ctx) }()
Eventually(func() string {
jwt, _ := m.Current()
return jwt
}, "2s", "10ms").Should(Equal("jwt-2"))
})
It("returns an error after the bounded number of consecutive failures so the caller can exit", func() {
f := &fakeRegister{steps: []step{{err: context.DeadlineExceeded}}} // refresh always fails
m := NewNATSCredentialManager(f.fn(), true)
m.refreshLead = 0.5
m.refreshRetry = time.Millisecond
m.maxAttempts = 3
m.expiryOf = func(string) (time.Time, bool) { return time.Now().Add(time.Millisecond), true }
m.store(approved("jwt-1", "seed-1"))
errCh := make(chan error, 1)
go func() { errCh <- m.RefreshLoop(context.Background()) }()
Eventually(errCh, "2s").Should(Receive(MatchError(ContainSubstring("3 times in a row"))))
})
It("exits promptly when the current credential has no expiry (nothing to refresh)", func() {
f := &fakeRegister{steps: []step{{res: approved("x", "y")}}}
m := NewNATSCredentialManager(f.fn(), true)
m.expiryOf = func(string) (time.Time, bool) { return time.Time{}, false }
m.store(approved("static", "seed"))
done := make(chan struct{})
go func() { _ = m.RefreshLoop(context.Background()); close(done) }()
Eventually(done, "1s").Should(BeClosed())
Expect(f.count()).To(Equal(0)) // never tried to re-register
})
})
Describe("jwtExpiry default", func() {
It("decodes the expiry of a real minted worker JWT", func() {
akp, err := nkeys.CreateAccount()
Expect(err).ToNot(HaveOccurred())
seed, err := akp.Seed()
Expect(err).ToNot(HaveOccurred())
cfg := natsauth.Config{AccountSeed: string(seed), WorkerJWTTTL: time.Hour}
token, _, err := cfg.MintWorkerJWT("node-1", "backend")
Expect(err).ToNot(HaveOccurred())
exp, ok := jwtExpiry(token)
Expect(ok).To(BeTrue())
Expect(exp).To(BeTemporally("~", time.Now().Add(time.Hour), 2*time.Minute))
})
It("reports no expiry for an empty or undecodable token", func() {
_, ok := jwtExpiry("")
Expect(ok).To(BeFalse())
_, ok = jwtExpiry("not-a-jwt")
Expect(ok).To(BeFalse())
})
})
})

View File

@@ -22,9 +22,11 @@ const (
UsecaseRerank = "rerank"
UsecaseDetection = "detection"
UsecaseVAD = "vad"
UsecaseAudioTransform = "audio_transform"
UsecaseDiarization = "diarization"
UsecaseRealtimeAudio = "realtime_audio"
UsecaseAudioTransform = "audio_transform"
UsecaseDiarization = "diarization"
UsecaseRealtimeAudio = "realtime_audio"
UsecaseFaceRecognition = "face_recognition"
UsecaseSpeakerRecognition = "speaker_recognition"
)
// GRPCMethod identifies a Backend service RPC from backend.proto.
@@ -47,6 +49,11 @@ const (
MethodAudioTransform GRPCMethod = "AudioTransform"
MethodDiarize GRPCMethod = "Diarize"
MethodAudioToAudioStream GRPCMethod = "AudioToAudioStream"
MethodFaceVerify GRPCMethod = "FaceVerify"
MethodFaceAnalyze GRPCMethod = "FaceAnalyze"
MethodVoiceVerify GRPCMethod = "VoiceVerify"
MethodVoiceEmbed GRPCMethod = "VoiceEmbed"
MethodVoiceAnalyze GRPCMethod = "VoiceAnalyze"
)
// UsecaseInfo describes a single known_usecase value and how it maps
@@ -154,6 +161,16 @@ var UsecaseInfoMap = map[string]UsecaseInfo{
GRPCMethod: MethodAudioToAudioStream,
Description: "Self-contained any-to-any audio model for the Realtime API — accepts microphone audio and emits speech + transcript (+ optional function calls) from a single backend via the AudioToAudioStream RPC.",
},
UsecaseFaceRecognition: {
Flag: FLAG_FACE_RECOGNITION,
GRPCMethod: MethodFaceVerify,
Description: "Face recognition — verify identity, analyze attributes (age/gender/emotion) via FaceVerify and FaceAnalyze RPCs.",
},
UsecaseSpeakerRecognition: {
Flag: FLAG_SPEAKER_RECOGNITION,
GRPCMethod: MethodVoiceVerify,
Description: "Speaker recognition — verify identity, embed and analyze voice via VoiceVerify, VoiceEmbed and VoiceAnalyze RPCs.",
},
}
// BackendCapability describes which gRPC methods and usecases a backend supports.
@@ -471,6 +488,21 @@ var BackendCapabilities = map[string]BackendCapability{
DefaultUsecases: []string{UsecaseDetection},
Description: "RF-DETR C++ object detection",
},
// --- Face and speaker recognition backends ---
"insightface": {
GRPCMethods: []GRPCMethod{MethodEmbedding, MethodDetect, MethodFaceVerify, MethodFaceAnalyze},
PossibleUsecases: []string{UsecaseEmbeddings, UsecaseDetection, UsecaseFaceRecognition},
DefaultUsecases: []string{UsecaseFaceRecognition},
AcceptsImages: true,
Description: "InsightFace — face detection, embedding, verification and attribute analysis",
},
"speaker-recognition": {
GRPCMethods: []GRPCMethod{MethodVoiceVerify, MethodVoiceEmbed, MethodVoiceAnalyze},
PossibleUsecases: []string{UsecaseSpeakerRecognition},
DefaultUsecases: []string{UsecaseSpeakerRecognition},
Description: "Speaker recognition — voice identity verification and analysis",
},
"silero-vad": {
GRPCMethods: []GRPCMethod{MethodVAD},
PossibleUsecases: []string{UsecaseVAD},

View File

@@ -5,6 +5,8 @@ import (
"fmt"
"time"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/mudler/xlog"
)
@@ -16,7 +18,29 @@ type DistributedConfig struct {
NatsURL string // --nats-url / LOCALAI_NATS_URL
StorageURL string // --storage-url / LOCALAI_STORAGE_URL (S3 endpoint)
RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration)
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
// RegistrationRequireAuth fails startup when distributed mode is enabled but
// RegistrationToken is empty. The default (false) keeps the historical
// fail-open behavior with a loud warning; production should set it so the
// node-register endpoints and the worker file-transfer server cannot run
// unauthenticated. Mirrors NatsRequireAuth for the NATS bus.
RegistrationRequireAuth bool // LOCALAI_REGISTRATION_REQUIRE_AUTH
// RequireAuth is the umbrella switch (LOCALAI_DISTRIBUTED_REQUIRE_AUTH) for
// distributed-mode auth: when true it implies BOTH NatsRequireAuth and
// RegistrationRequireAuth, so a single knob locks down the bus and the
// registration/file-transfer layer together. The granular flags remain
// available to enforce just one layer.
RequireAuth bool // LOCALAI_DISTRIBUTED_REQUIRE_AUTH
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
// NATS JWT auth (optional; see pkg/natsauth and docs/features/distributed-mode.md)
NatsAccountSeed string // LOCALAI_NATS_ACCOUNT_SEED — account signing seed to mint per-node worker JWTs
NatsServiceJWT string // LOCALAI_NATS_SERVICE_JWT — user JWT for frontends / agent workers
NatsServiceSeed string // LOCALAI_NATS_SERVICE_SEED — signing seed paired with service JWT
NatsWorkerJWTTTL time.Duration // LOCALAI_NATS_WORKER_JWT_TTL — minted worker JWT lifetime (default 24h)
NatsRequireAuth bool // LOCALAI_NATS_REQUIRE_AUTH — fail startup if NATS credentials are missing
NatsTLSCA string // LOCALAI_NATS_TLS_CA — PEM file for private CA (server verify)
NatsTLSCert string // LOCALAI_NATS_TLS_CERT — client cert for NATS mTLS
NatsTLSKey string // LOCALAI_NATS_TLS_KEY — client key paired with NatsTLSCert
// S3 configuration (used when StorageURL is set)
StorageBucket string // --storage-bucket / LOCALAI_STORAGE_BUCKET
@@ -76,10 +100,23 @@ func (c DistributedConfig) Validate() error {
(c.StorageAccessKey == "" && c.StorageSecretKey != "") {
return fmt.Errorf("storage-access-key and storage-secret-key must both be set or both empty")
}
// Warn about missing registration token (not an error)
// The registration token guards both the node HTTP register/heartbeat
// endpoints and the worker file-transfer server (which fails open on an
// empty token). Enforce it when registration auth is required (the granular
// flag or the umbrella); otherwise warn.
if c.RegistrationToken == "" {
xlog.Warn("distributed mode running without registration token — node endpoints are unprotected")
if c.RegistrationAuthRequired() {
return fmt.Errorf("registration auth is required (LOCALAI_REGISTRATION_REQUIRE_AUTH or LOCALAI_DISTRIBUTED_REQUIRE_AUTH) but LOCALAI_REGISTRATION_TOKEN is empty")
}
xlog.Warn("distributed mode running without registration token — node endpoints and the worker file-transfer server are unprotected; set LOCALAI_REGISTRATION_TOKEN, or LOCALAI_DISTRIBUTED_REQUIRE_AUTH=true to fail closed")
}
if err := c.NatsAuthConfig().Validate(); err != nil {
return err
}
if err := c.NatsTLSFiles().Validate(); err != nil {
return err
}
c.NatsAuthConfig().WarnIfInsecure(true)
// Check for negative durations
for name, d := range map[string]time.Duration{
FlagMCPToolTimeout: c.MCPToolTimeout,
@@ -123,6 +160,76 @@ func WithRegistrationToken(token string) AppOption {
}
}
func WithNatsAccountSeed(seed string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsAccountSeed = seed
}
}
func WithNatsServiceJWT(jwt string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsServiceJWT = jwt
}
}
func WithNatsServiceSeed(seed string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsServiceSeed = seed
}
}
func WithNatsWorkerJWTTTL(d time.Duration) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsWorkerJWTTTL = d
}
}
var EnableNatsRequireAuth = func(o *ApplicationConfig) {
o.Distributed.NatsRequireAuth = true
}
// EnableRegistrationRequireAuth makes an empty registration token a hard error
// in distributed mode (see DistributedConfig.RegistrationRequireAuth).
var EnableRegistrationRequireAuth = func(o *ApplicationConfig) {
o.Distributed.RegistrationRequireAuth = true
}
// EnableDistributedRequireAuth is the umbrella switch implying both
// NatsRequireAuth and RegistrationRequireAuth (see DistributedConfig.RequireAuth).
var EnableDistributedRequireAuth = func(o *ApplicationConfig) {
o.Distributed.RequireAuth = true
}
// RegistrationAuthRequired reports whether an empty registration token must be
// treated as a fatal misconfiguration — the granular flag or the umbrella.
func (c DistributedConfig) RegistrationAuthRequired() bool {
return c.RegistrationRequireAuth || c.RequireAuth
}
// NatsAuthRequired reports whether NATS JWT credentials must be present — the
// granular flag or the umbrella.
func (c DistributedConfig) NatsAuthRequired() bool {
return c.NatsRequireAuth || c.RequireAuth
}
func WithNatsTLSCA(path string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsTLSCA = path
}
}
func WithNatsTLSCert(path string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsTLSCert = path
}
}
func WithNatsTLSKey(path string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsTLSKey = path
}
}
func WithStorageURL(url string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.StorageURL = url
@@ -217,6 +324,44 @@ const (
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
const DefaultMaxUploadSize int64 = 50 << 30
// NatsTLSFiles returns NATS TLS/mTLS PEM paths for the messaging client.
func (c DistributedConfig) NatsTLSFiles() messaging.TLSFiles {
return messaging.TLSFiles{
CA: c.NatsTLSCA,
Cert: c.NatsTLSCert,
Key: c.NatsTLSKey,
}
}
// NatsMessagingOptions builds messaging client options (JWT + TLS) for distributed components.
// Pass explicit userJWT/userSeed when set (e.g. worker overrides); empty uses service JWT from config.
func (c DistributedConfig) NatsMessagingOptions(userJWT, userSeed string) []messaging.Option {
var opts []messaging.Option
jwt, seed := userJWT, userSeed
if jwt == "" && seed == "" {
auth := c.NatsAuthConfig()
jwt, seed = auth.ServiceUserJWT, auth.ServiceUserSeed
}
if jwt != "" && seed != "" {
opts = append(opts, messaging.WithUserJWT(jwt, seed))
}
if tls := c.NatsTLSFiles(); tls.Enabled() {
opts = append(opts, messaging.WithTLS(tls))
}
return opts
}
// NatsAuthConfig builds pkg/natsauth settings from distributed configuration.
func (c DistributedConfig) NatsAuthConfig() natsauth.Config {
return natsauth.Config{
AccountSeed: c.NatsAccountSeed,
ServiceUserJWT: c.NatsServiceJWT,
ServiceUserSeed: c.NatsServiceSeed,
WorkerJWTTTL: c.NatsWorkerJWTTTL,
RequireAuth: c.NatsAuthRequired(),
}
}
// BackendInstallTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration {
return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout)

View File

@@ -88,3 +88,66 @@ var _ = Describe("DistributedConfig.Validate negative-duration errors", func() {
Expect(c.Validate()).To(Succeed())
})
})
var _ = Describe("DistributedConfig.Validate registration auth", func() {
It("rejects an empty registration token when RequireAuth is set", func() {
c := config.DistributedConfig{
Enabled: true,
NatsURL: "nats://localhost:4222",
RegistrationRequireAuth: true,
}
err := c.Validate()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_REQUIRE_AUTH"))
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_TOKEN"))
})
It("accepts a set registration token when RequireAuth is set", func() {
c := config.DistributedConfig{
Enabled: true,
NatsURL: "nats://localhost:4222",
RegistrationToken: "s3cret",
RegistrationRequireAuth: true,
}
Expect(c.Validate()).To(Succeed())
})
It("warns but succeeds with an empty token when RequireAuth is unset", func() {
c := config.DistributedConfig{
Enabled: true,
NatsURL: "nats://localhost:4222",
}
Expect(c.Validate()).To(Succeed())
})
It("rejects an empty token when the umbrella RequireAuth is set", func() {
c := config.DistributedConfig{
Enabled: true,
NatsURL: "nats://localhost:4222",
RequireAuth: true,
// Provide NATS creds so only the registration-token gap remains.
NatsServiceJWT: "jwt",
NatsServiceSeed: "seed",
NatsAccountSeed: "acct",
}
err := c.Validate()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("LOCALAI_DISTRIBUTED_REQUIRE_AUTH"))
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_TOKEN"))
})
It("the umbrella implies NATS auth is required", func() {
c := config.DistributedConfig{
Enabled: true,
NatsURL: "nats://localhost:4222",
RegistrationToken: "tok", // registration layer satisfied
RequireAuth: true, // umbrella → NATS creds now required
}
Expect(c.NatsAuthRequired()).To(BeTrue())
Expect(c.RegistrationAuthRequired()).To(BeTrue())
// Missing NATS service JWT/seed must now be fatal.
err := c.Validate()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("LOCALAI_NATS_REQUIRE_AUTH"))
})
})

View File

@@ -39,7 +39,21 @@ func llamaCppDefaults(cfg *ModelConfig, modelPath string) {
}
}()
f, err := gguf.ParseGGUFFile(guessPath)
// Startup parses every model's GGUF header to guess defaults. We only need
// scalar metadata (architecture, head/ff counts, chat_template, token IDs,
// MTP head) plus array *lengths* — never the array *contents*. Two options
// keep this cheap, which matters when many models live on slow storage such
// as a Docker volume (see https://github.com/mudler/LocalAI/issues/9790):
//
// - SkipLargeMetadata: seek past large array-valued metadata (the tokenizer
// vocab: tokenizer.ggml.tokens/scores/merges, often >100k entries) instead
// of reading and allocating every element. Lengths stay populated.
// - UseMMap: read the header via a memory map so faulting in a few pages
// replaces hundreds of thousands of tiny read() syscalls (measured ~524k
// -> 8 for a 256k-token vocab), the dominant cost on slow filesystems.
//
// The mapping is released when ParseGGUFFile returns.
f, err := gguf.ParseGGUFFile(guessPath, gguf.UseMMap(), gguf.SkipLargeMetadata())
if err == nil {
guessGGUFFromFile(cfg, f, 0)
}

View File

@@ -1,13 +1,76 @@
package config_test
import (
"bytes"
"encoding/binary"
"os"
"path/filepath"
. "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
gguf "github.com/gpustack/gguf-parser-go"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// GGUF metadata value type tags (see github.com/gpustack/gguf-parser-go).
const (
ggufTypeUint32 uint32 = 4
ggufTypeString uint32 = 8
ggufTypeArray uint32 = 9
)
// writeTestGGUF emits a minimal but valid little-endian GGUF v3 header carrying
// the scalar metadata the llama-cpp hook guesses from plus a large string vocab
// array (tokenizer.ggml.tokens). The big array is exactly what SkipLargeMetadata
// + UseMMap are expected to avoid reading element-by-element, so it must survive a
// round-trip through the real hook without corrupting the guessed defaults.
func writeTestGGUF(path, chatTemplate string, vocab int) error {
wStr := func(b *bytes.Buffer, s string) {
binary.Write(b, binary.LittleEndian, uint64(len(s)))
b.WriteString(s)
}
kvStr := func(b *bytes.Buffer, k, v string) {
wStr(b, k)
binary.Write(b, binary.LittleEndian, ggufTypeString)
wStr(b, v)
}
kvU32 := func(b *bytes.Buffer, k string, v uint32) {
wStr(b, k)
binary.Write(b, binary.LittleEndian, ggufTypeUint32)
binary.Write(b, binary.LittleEndian, v)
}
var meta bytes.Buffer
kvStr(&meta, "general.architecture", "llama")
kvStr(&meta, "general.name", "ReproModel")
kvU32(&meta, "llama.context_length", 4096)
kvU32(&meta, "llama.attention.head_count", 32)
kvU32(&meta, "llama.feed_forward_length", 11008)
kvU32(&meta, "llama.block_count", 32)
kvU32(&meta, "tokenizer.ggml.bos_token_id", 1)
kvStr(&meta, "tokenizer.chat_template", chatTemplate)
// large array value — the one the optimization skips reading
wStr(&meta, "tokenizer.ggml.tokens")
binary.Write(&meta, binary.LittleEndian, ggufTypeArray)
binary.Write(&meta, binary.LittleEndian, ggufTypeString)
binary.Write(&meta, binary.LittleEndian, uint64(vocab))
for i := 0; i < vocab; i++ {
wStr(&meta, "token")
}
var out bytes.Buffer
binary.Write(&out, binary.LittleEndian, gguf.GGUFMagicGGUFLe)
binary.Write(&out, binary.LittleEndian, uint32(3)) // version
binary.Write(&out, binary.LittleEndian, uint64(0)) // tensor count
binary.Write(&out, binary.LittleEndian, uint64(9)) // metadata kv count
out.Write(meta.Bytes())
return os.WriteFile(path, out.Bytes(), 0o644)
}
var _ = Describe("Backend hooks and parser defaults", func() {
Context("MatchParserDefaults", func() {
It("matches Qwen3 family", func() {
@@ -137,6 +200,58 @@ var _ = Describe("Backend hooks and parser defaults", func() {
})
})
Context("llamaCppDefaults GGUF guessing", func() {
// Regression coverage for https://github.com/mudler/LocalAI/issues/9790:
// the hook reads GGUF headers with SkipLargeMetadata + UseMMap to avoid
// pulling the whole tokenizer vocab off (slow) disk on every startup. This
// verifies that skipping the vocab array still yields the correct guessed
// defaults from the remaining scalar metadata.
const chatTemplate = "{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}"
It("guesses defaults from a GGUF whose large vocab is skipped", func() {
dir := GinkgoT().TempDir()
modelFile := "repro.gguf"
Expect(writeTestGGUF(filepath.Join(dir, modelFile), chatTemplate, 50000)).To(Succeed())
// A pre-set context size short-circuits the GGUF run-estimate, which
// needs full tensor info this header-only fixture deliberately omits;
// the metadata-reading path the optimization touches is unaffected.
ctxSize := 4096
cfg := &ModelConfig{
Backend: "llama-cpp",
LLMConfig: LLMConfig{ContextSize: &ctxSize},
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{Model: modelFile},
},
}
cfg.SetDefaults(ModelPath(dir))
// chat_template is a scalar string, not part of the skipped array,
// so it must be captured verbatim.
Expect(cfg.GetModelTemplate()).To(Equal(chatTemplate))
// scalar-derived defaults are still applied
Expect(cfg.ContextSize).NotTo(BeNil())
Expect(cfg.NGPULayers).NotTo(BeNil())
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
Expect(cfg.KnownUsecaseStrings).To(ContainElement("FLAG_CHAT"))
})
It("falls back to the default context size when the GGUF is unreadable", func() {
dir := GinkgoT().TempDir()
Expect(os.WriteFile(filepath.Join(dir, "bad.gguf"), []byte("not a gguf"), 0o644)).To(Succeed())
cfg := &ModelConfig{
Backend: "llama-cpp",
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{Model: "bad.gguf"},
},
}
cfg.SetDefaults(ModelPath(dir))
Expect(cfg.ContextSize).NotTo(BeNil())
})
})
Context("PromptCacheAll default", func() {
It("defaults to true when omitted from YAML", func() {
cfg := &ModelConfig{}

View File

@@ -128,6 +128,22 @@ func DefaultRegistry() map[string]FieldMetaOverride {
Advanced: true,
Order: 21,
},
"reasoning_effort": {
Section: "llm",
Label: "Reasoning Effort",
Description: "Default reasoning effort, forwarded to the backend as the reasoning_effort chat_template_kwarg (jinja models like gpt-oss / LFM2.5 honor it). A per-request reasoning_effort overrides it. 'none' also turns thinking off.",
Component: "select",
Options: []FieldOption{
{Value: "", Label: "Unset (model default)"},
{Value: "none", Label: "none (disable thinking)"},
{Value: "minimal", Label: "minimal"},
{Value: "low", Label: "low"},
{Value: "medium", Label: "medium"},
{Value: "high", Label: "high"},
},
Advanced: true,
Order: 22,
},
"cache_type_k": {
Section: "llm",
Label: "KV Cache Type (K)",
@@ -277,6 +293,21 @@ func DefaultRegistry() map[string]FieldMetaOverride {
AutocompleteProvider: ProviderModelsVAD,
Order: 63,
},
"pipeline.reasoning_effort": {
Section: "pipeline",
Label: "Reasoning Effort",
Description: "Reasoning effort for the pipeline's LLM, forwarded to the backend as the reasoning_effort chat_template_kwarg (jinja models like gpt-oss / LFM2.5 honor it). Overrides the LLM model's own reasoning_effort. 'none' also turns thinking off.",
Component: "select",
Options: []FieldOption{
{Value: "", Label: "Default (model config)"},
{Value: "none", Label: "none (disable thinking)"},
{Value: "minimal", Label: "minimal"},
{Value: "low", Label: "low"},
{Value: "medium", Label: "medium"},
{Value: "high", Label: "high"},
},
Order: 64,
},
// --- Functions ---
"function.grammar.parallel_calls": {

View File

@@ -63,6 +63,13 @@ type ModelConfig struct {
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
// ReasoningEffort is the default reasoning effort (none|minimal|low|medium|high)
// for this model. A per-request reasoning_effort overrides it. It is forwarded
// to the backend as the reasoning_effort chat_template_kwarg (see
// gRPCPredictOpts), so jinja-templated models that key on it — e.g. gpt-oss
// (Harmony) or LFM2.5 — honor it; "none" also toggles enable_thinking off.
ReasoningEffort string `yaml:"reasoning_effort,omitempty" json:"reasoning_effort,omitempty"`
FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
// LLM configs (GPT4ALL, Llama.cpp, ...)
LLMConfig `yaml:",inline" json:",inline"`
@@ -487,6 +494,40 @@ type Pipeline struct {
LLM string `yaml:"llm,omitempty" json:"llm,omitempty"`
Transcription string `yaml:"transcription,omitempty" json:"transcription,omitempty"`
VAD string `yaml:"vad,omitempty" json:"vad,omitempty"`
// ReasoningEffort sets the reasoning effort (none|minimal|low|medium|high) for
// the pipeline's LLM without editing the LLM model config. Overrides the LLM's
// own reasoning_effort. Unset leaves the LLM model config in charge.
ReasoningEffort string `yaml:"reasoning_effort,omitempty" json:"reasoning_effort,omitempty"`
}
// ApplyReasoningEffort resolves the effective reasoning effort — a per-request
// value (requestEffort) overrides the config's own ReasoningEffort default —
// stores it on the config so gRPCPredictOpts forwards it to the backend as the
// reasoning_effort chat_template_kwarg, and maps it onto the enable_thinking
// toggle the backend also reads:
// - "none" always disables thinking.
// - any explicit level enables it, UNLESS the config already disabled reasoning
// (an operator's explicit disable wins over a request asking to think).
//
// An empty requestEffort keeps the config's own default. With no effort set
// anywhere it is a no-op, leaving the model's reasoning settings untouched.
func (c *ModelConfig) ApplyReasoningEffort(requestEffort string) {
effort := requestEffort
if effort == "" {
effort = c.ReasoningEffort
}
c.ReasoningEffort = effort
switch strings.ToLower(effort) {
case "none":
disable := true
c.ReasoningConfig.DisableReasoning = &disable
case "minimal", "low", "medium", "high":
if c.ReasoningConfig.DisableReasoning == nil || !*c.ReasoningConfig.DisableReasoning {
enable := false
c.ReasoningConfig.DisableReasoning = &enable
}
}
}
// @Description File configuration for model downloads

View File

@@ -30,11 +30,26 @@ func MTPSpecOptions() []string {
return out
}
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a Multi-Token
// Prediction head. Detection reads `<arch>.nextn_predict_layers`, which is
// what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
// isDraftOnlyAssistantArch reports whether an architecture names a standalone
// MTP *draft* model rather than a self-speculating trunk. Upstream's Gemma4 MTP
// (ggml-org/llama.cpp#23398) registers the head as a separate `gemma4-assistant`
// architecture whose GGUF still carries `nextn_predict_layers`, but which cannot
// run alone: it requires a paired target context (`ctx_other`). Such archs must
// not trigger the embedded-head self-speculation defaults. The `-assistant`
// suffix is upstream's naming convention for these draft-only checkpoints.
func isDraftOnlyAssistantArch(arch string) bool {
return strings.HasSuffix(arch, "-assistant")
}
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a self-speculating
// Multi-Token Prediction head. Detection reads `<arch>.nextn_predict_layers`,
// which is what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
// `conversion/qwen.py` MTP mixin. A positive layer count means the head is
// present in the same GGUF as the trunk.
//
// Draft-only assistant architectures (e.g. Gemma4's `gemma4-assistant`) carry
// the same key but are separate draft checkpoints meant to be paired with a
// target model, so they are deliberately excluded here.
func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
if f == nil {
return 0, false
@@ -43,6 +58,9 @@ func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
if arch == "" {
return 0, false
}
if isDraftOnlyAssistantArch(arch) {
return 0, false
}
v, ok := f.Header.MetadataKV.Get(arch + ".nextn_predict_layers")
if !ok {
return 0, false

View File

@@ -3,10 +3,33 @@ package config_test
import (
. "github.com/mudler/LocalAI/core/config"
gguf "github.com/gpustack/gguf-parser-go"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// ggufWithArch fabricates a minimal in-memory GGUF carrying the given
// `general.architecture` and a positive `<arch>.nextn_predict_layers` count,
// so HasEmbeddedMTPHead can be exercised without a real model file.
func ggufWithArch(arch string, nextn uint32) *gguf.GGUFFile {
return &gguf.GGUFFile{
Header: gguf.GGUFHeader{
MetadataKV: gguf.GGUFMetadataKVs{
{
Key: "general.architecture",
ValueType: gguf.GGUFMetadataValueTypeString,
Value: arch,
},
{
Key: arch + ".nextn_predict_layers",
ValueType: gguf.GGUFMetadataValueTypeUint32,
Value: nextn,
},
},
},
}
}
var _ = Describe("MTP auto-defaults", func() {
Context("MTPSpecOptions", func() {
It("returns the upstream-recommended speculative tuple", func() {
@@ -82,5 +105,20 @@ var _ = Describe("MTP auto-defaults", func() {
Expect(ok).To(BeFalse())
Expect(n).To(BeZero())
})
It("detects a same-GGUF embedded head (DeepSeek/Qwen style)", func() {
n, ok := HasEmbeddedMTPHead(ggufWithArch("qwen3moe", 1))
Expect(ok).To(BeTrue())
Expect(n).To(Equal(uint32(1)))
})
It("ignores a gemma4-assistant draft-only model", func() {
// The assistant GGUF carries nextn_predict_layers but is a separate
// draft model that requires a paired target (ctx_other); it cannot
// self-speculate, so it must not trigger the embedded-head defaults.
n, ok := HasEmbeddedMTPHead(ggufWithArch("gemma4-assistant", 48))
Expect(ok).To(BeFalse())
Expect(n).To(BeZero())
})
})
})

View File

@@ -0,0 +1,52 @@
package config_test
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/config"
)
// ApplyReasoningEffort resolves the effective reasoning effort (request value
// overrides the model config default), stores it on the config so it reaches the
// backend, and maps it onto the enable_thinking toggle.
var _ = Describe("ModelConfig.ApplyReasoningEffort", func() {
It("uses the request value over the config default", func() {
c := &config.ModelConfig{ReasoningEffort: "high"}
c.ApplyReasoningEffort("none")
Expect(c.ReasoningEffort).To(Equal("none"))
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
})
It("falls back to the config default when the request omits it", func() {
c := &config.ModelConfig{ReasoningEffort: "none"}
c.ApplyReasoningEffort("")
Expect(c.ReasoningEffort).To(Equal("none"))
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
})
It("enables thinking for an explicit effort level", func() {
c := &config.ModelConfig{}
c.ApplyReasoningEffort("medium")
Expect(c.ReasoningEffort).To(Equal("medium"))
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
Expect(*c.ReasoningConfig.DisableReasoning).To(BeFalse())
})
It("does not let a level override an operator's config-level disable", func() {
disabled := true
c := &config.ModelConfig{}
c.ReasoningConfig.DisableReasoning = &disabled
c.ApplyReasoningEffort("high")
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
})
It("is a no-op on the toggle when no effort is set anywhere", func() {
c := &config.ModelConfig{}
c.ApplyReasoningEffort("")
Expect(c.ReasoningEffort).To(Equal(""))
Expect(c.ReasoningConfig.DisableReasoning).To(BeNil())
})
})

View File

@@ -420,8 +420,9 @@ func API(application *application.Application) (*echo.Echo, error) {
remoteUnloader = d.Router.Unloader()
}
}
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
natsCfg := distCfg.NatsAuthConfig()
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, natsCfg)
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken, natsCfg)
// Distributed SSE routes (job progress + agent events via NATS)
if d := application.Distributed(); d != nil {

View File

@@ -37,7 +37,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
xlog.Debug("elevenlabs TTS request received", "modelName", input.ModelID)
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, "", nil, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@@ -28,6 +28,7 @@ import (
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
"github.com/mudler/LocalAI/pkg/httpclient"
"github.com/mudler/LocalAI/pkg/natsauth"
)
// nodeError builds a schema.ErrorResponse for node endpoints.
@@ -89,7 +90,7 @@ type RegisterNodeRequest struct {
// RegisterNodeEndpoint registers a new backend node.
// expectedToken is the registration token configured on the frontend (may be empty to disable auth).
// autoApprove controls whether new nodes go directly to "healthy" or require admin approval.
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
return func(c echo.Context) error {
var req RegisterNodeRequest
if err := c.Bind(&req); err != nil {
@@ -217,13 +218,15 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au
}
}
attachNatsJWT(response, node, natsCfg)
return c.JSON(http.StatusCreated, response)
}
}
// ApproveNodeEndpoint approves a pending node, setting its status to healthy.
// For agent workers, it also provisions an API key so they can call the inference API.
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
id := c.Param("id")
@@ -253,10 +256,26 @@ func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecr
}
}
attachNatsJWT(response, node, natsCfg)
return c.JSON(http.StatusOK, response)
}
}
// attachNatsJWT adds a per-node NATS user JWT to a register/approve response when minting is enabled.
func attachNatsJWT(response map[string]any, node *nodes.BackendNode, natsCfg natsauth.Config) {
if !natsCfg.CanMintWorkers() || node == nil || node.Status == nodes.StatusPending {
return
}
jwt, seed, err := natsCfg.MintWorkerJWT(node.ID, node.NodeType)
if err != nil {
xlog.Warn("Failed to mint NATS JWT for node", "node", node.Name, "id", node.ID, "error", err)
return
}
response["nats_jwt"] = jwt
response["nats_user_seed"] = seed
}
// provisionAgentWorkerKey creates a dedicated user and API key for an agent worker node.
// Returns the plaintext API key on success.
func provisionAgentWorkerKey(ctx context.Context, authDB *gorm.DB, registry *nodes.NodeRegistry, node *nodes.BackendNode, hmacSecret string) (string, error) {

View File

@@ -12,6 +12,8 @@ import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/testutil"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/nats-io/nkeys"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -63,7 +65,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusCreated))
@@ -74,6 +76,29 @@ var _ = Describe("Node HTTP handlers", func() {
Expect(resp["status"]).To(Equal(nodes.StatusHealthy))
})
It("returns nats_jwt when account seed is configured", func() {
akp, err := nkeys.CreateAccount()
Expect(err).ToNot(HaveOccurred())
seed, err := akp.Seed()
Expect(err).ToNot(HaveOccurred())
e := echo.New()
body := `{"name":"worker-nats","address":"10.0.0.2:50051"}`
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
natsCfg := natsauth.Config{AccountSeed: string(seed)}
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsCfg)
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusCreated))
var resp map[string]any
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
Expect(resp["nats_jwt"]).ToNot(BeEmpty())
})
It("returns 400 when name is missing", func() {
e := echo.New()
body := `{"address":"10.0.0.1:50051"}`
@@ -82,7 +107,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusBadRequest))
@@ -102,7 +127,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusBadRequest))
@@ -121,7 +146,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusBadRequest))
@@ -140,7 +165,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusBadRequest))
@@ -159,7 +184,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "")
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
})
@@ -172,7 +197,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", false, nil, "")
handler := RegisterNodeEndpoint(registry, "", false, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusCreated))
@@ -195,7 +220,7 @@ var _ = Describe("Node HTTP handlers", func() {
req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body1))
req1.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
rec1 := httptest.NewRecorder()
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
Expect(handler(e.NewContext(req1, rec1))).To(Succeed())
Expect(rec1.Code).To(Equal(http.StatusCreated))

View File

@@ -59,7 +59,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
c.Response().Header().Set("Connection", "keep-alive")
// Stream audio chunks as they're generated
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, input.Instructions, input.Params, ml, appConfig, *cfg, func(audioChunk []byte) error {
_, writeErr := c.Response().Write(audioChunk)
if writeErr != nil {
return writeErr
@@ -75,7 +75,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
}
// Non-streaming TTS (existing behavior)
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, input.Instructions, input.Params, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@@ -44,10 +44,10 @@ type wrappedModel struct {
// deps in. nil-safe: with classifierRegistry == nil the per-turn
// routing block in Predict is skipped, preserving today's "one LLM
// for the whole session" behaviour.
routerDeps *middleware.ClassifierDeps
routerStore router.DecisionStore
routerSessionID string
routerUserID string
routerDeps *middleware.ClassifierDeps
routerStore router.DecisionStore
routerSessionID string
routerUserID string
}
// anyToAnyModel represent a model which supports Any-to-Any operations
@@ -119,6 +119,11 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im
}
}
// Surface the resolved reasoning effort to the Go-side template path too
// (jinja models get it via backend metadata in gRPCPredictOpts; Go-templated
// models like gpt-oss read it from the template's .ReasoningEffort).
input.ReasoningEffort = turnCfg.ReasoningEffort
var predInput string
var funcs []functions.Function
if !turnCfg.TemplateConfig.UseTokenizerTemplate {
@@ -313,7 +318,7 @@ func newRealtimeDecisionID() string {
}
func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) {
return backend.ModelTTS(ctx, text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig)
return backend.ModelTTS(ctx, text, voice, language, "", nil, m.modelLoader, m.appConfig, *m.TTSConfig)
}
func (m *wrappedModel) PredictConfig() *config.ModelConfig {
@@ -449,6 +454,9 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
return nil, fmt.Errorf("failed to validate config: %w", err)
}
// Let the pipeline set the LLM's reasoning effort (cfgLLM is a per-session copy).
applyPipelineReasoning(cfgLLM, *pipeline)
cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath)
if err != nil {

View File

@@ -0,0 +1,16 @@
package openai
import "github.com/mudler/LocalAI/core/config"
// applyPipelineReasoning sets the reasoning effort for a realtime pipeline's LLM
// from the pipeline config, without editing the underlying LLM model config. The
// pipeline value overrides the LLM's own reasoning_effort; when the pipeline does
// not set it, the LLM model config's reasoning_effort (if any) is used. The LLM
// config passed in is the per-session copy returned by the config loader, so this
// does not affect other users of the same model.
func applyPipelineReasoning(llm *config.ModelConfig, pipeline config.Pipeline) {
if llm == nil {
return
}
llm.ApplyReasoningEffort(pipeline.ReasoningEffort)
}

View File

@@ -0,0 +1,33 @@
package openai
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/config"
)
// applyPipelineReasoning lets a realtime pipeline set the reasoning effort for
// its LLM (forwarded to the backend as reasoning_effort) without editing the LLM
// model config. The pipeline value overrides the LLM's own reasoning_effort.
var _ = Describe("applyPipelineReasoning", func() {
It("applies the pipeline reasoning_effort to the LLM config", func() {
llm := &config.ModelConfig{}
applyPipelineReasoning(llm, config.Pipeline{ReasoningEffort: "none"})
Expect(llm.ReasoningEffort).To(Equal("none"))
Expect(llm.ReasoningConfig.DisableReasoning).ToNot(BeNil())
Expect(*llm.ReasoningConfig.DisableReasoning).To(BeTrue())
})
It("falls back to the LLM's own reasoning_effort when the pipeline is unset", func() {
llm := &config.ModelConfig{ReasoningEffort: "high"}
applyPipelineReasoning(llm, config.Pipeline{})
Expect(llm.ReasoningEffort).To(Equal("high"))
Expect(llm.ReasoningConfig.DisableReasoning).ToNot(BeNil())
Expect(*llm.ReasoningConfig.DisableReasoning).To(BeFalse())
})
It("is nil-safe", func() {
applyPipelineReasoning(nil, config.Pipeline{ReasoningEffort: "low"})
})
})

View File

@@ -310,25 +310,13 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.
config.Temperature = input.Temperature
}
// Map the per-request reasoning_effort onto the reasoning toggle the
// backend reads (enable_thinking metadata, set in gRPCPredictOpts).
// "none" disables thinking for this request - the use case from #10072,
// running a single Qwen3-style model and turning reasoning off per
// request. Any explicit effort level enables thinking, UNLESS the model
// config explicitly disabled it (DisableReasoning==true wins): an
// operator who deliberately turned reasoning off should not be overridden
// by a request. A value of "none" always disables, since that never
// conflicts with a config that also disables.
switch strings.ToLower(input.ReasoningEffort) {
case "none":
disable := true
config.ReasoningConfig.DisableReasoning = &disable
case "minimal", "low", "medium", "high":
if config.ReasoningConfig.DisableReasoning == nil || !*config.ReasoningConfig.DisableReasoning {
enable := false
config.ReasoningConfig.DisableReasoning = &enable
}
}
// Resolve the effective reasoning effort (request overrides the model config
// default), store it so gRPCPredictOpts forwards it to the backend as the
// reasoning_effort chat_template_kwarg (what gpt-oss / LFM2.5 read), and map
// it onto the enable_thinking toggle. "none" disables thinking (the #10072
// use case); a level enables it unless the config already disabled reasoning
// (an operator's explicit disable wins over a request asking to think).
config.ApplyReasoningEffort(input.ReasoningEffort)
// Collapse the modern max_completion_tokens alias into the
// legacy Maxtokens field so downstream code reads exactly one.

View File

@@ -10,6 +10,7 @@ import (
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/pkg/natsauth"
"gorm.io/gorm"
)
@@ -35,7 +36,7 @@ func nodeReadyMiddleware(registry *nodes.NodeRegistry) echo.MiddlewareFunc {
// token but do not verify per-node identity. A compromised worker can heartbeat/drain/
// deregister other nodes. Future: issue per-node JWT at registration, validate node
// identity on subsequent requests (compare :id param with token subject).
func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, registrationToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) {
func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, registrationToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) {
if registry == nil {
return
}
@@ -44,7 +45,7 @@ func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, r
tokenAuthMw := nodeTokenAuth(registrationToken)
node := e.Group("/api/node", readyMw, tokenAuthMw)
node.POST("/register", localai.RegisterNodeEndpoint(registry, registrationToken, autoApprove, authDB, hmacSecret))
node.POST("/register", localai.RegisterNodeEndpoint(registry, registrationToken, autoApprove, authDB, hmacSecret, natsCfg))
node.POST("/:id/heartbeat", localai.HeartbeatEndpoint(registry))
node.POST("/:id/drain", localai.DrainNodeEndpoint(registry))
node.POST("/:id/resume", localai.ResumeNodeEndpoint(registry))
@@ -60,7 +61,7 @@ func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, r
// backend install path (POST /:id/backends/install). That handler enqueues a
// ManagementOp on the gallery channel rather than blocking on a NATS reply, so
// the browser gets HTTP 202 + jobID immediately instead of waiting up to 3 minutes.
func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string) {
func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string, natsCfg natsauth.Config) {
if registry == nil {
return
}
@@ -81,7 +82,7 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
admin.DELETE("/:id", localai.DeregisterNodeEndpoint(registry))
admin.POST("/:id/drain", localai.DrainNodeEndpoint(registry))
admin.POST("/:id/resume", localai.ResumeNodeEndpoint(registry))
admin.POST("/:id/approve", localai.ApproveNodeEndpoint(registry, authDB, hmacSecret))
admin.POST("/:id/approve", localai.ApproveNodeEndpoint(registry, authDB, hmacSecret, natsCfg))
// Backend management on workers
admin.GET("/:id/backends", localai.ListBackendsOnNodeEndpoint(unloader))

View File

@@ -60,6 +60,14 @@ type TTSRequest struct {
Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
Stream bool `json:"stream,omitempty" yaml:"stream,omitempty"` // (optional) enable streaming TTS
SampleRate int `json:"sample_rate,omitempty" yaml:"sample_rate,omitempty"` // (optional) desired output sample rate
// Instructions is a free-form, per-request style/voice description. It maps to
// the OpenAI `instructions` field and is forwarded to the backend so expressive
// TTS models (e.g. Qwen3-TTS CustomVoice/VoiceDesign) can vary tone or designed
// voice per request instead of only via the static YAML option.
Instructions string `json:"instructions,omitempty" yaml:"instructions,omitempty"`
// Params carries optional, backend-specific per-request generation parameters
// (LocalAI extension, e.g. Chatterbox exaggeration/cfg_weight/temperature).
Params map[string]string `json:"params,omitempty" yaml:"params,omitempty"`
}
// @Description VAD request body

View File

@@ -180,18 +180,21 @@ func (s *GalleryStore) Cancel(id string) error {
return s.UpdateStatus(id, "cancelled", "")
}
// CleanStale marks abandoned in-progress operations as failed.
// Should be called on startup to recover from crashed instances that
// left records in pending/downloading/processing state.
func (s *GalleryStore) CleanStale(age time.Duration) error {
// CleanStale marks abandoned in-progress operations as failed and returns the
// number of rows reaped. Called on startup AND periodically to recover from
// crashed/restarted instances that left records in pending/downloading/
// processing state — an op orphaned after startup would otherwise linger
// "processing" until the next restart.
func (s *GalleryStore) CleanStale(age time.Duration) (int64, error) {
cutoff := time.Now().Add(-age)
return s.db.Model(&GalleryOperationRecord{}).
res := s.db.Model(&GalleryOperationRecord{}).
Where("updated_at < ? AND status IN ?", cutoff, activeStatuses).
Updates(map[string]any{
"status": "failed",
"error": "stale operation cleaned up on startup",
"error": "stale operation reaped (abandoned by a crashed or restarted instance)",
"updated_at": time.Now(),
}).Error
})
return res.RowsAffected, res.Error
}
// CleanOld removes operations older than the given duration.

View File

@@ -71,7 +71,7 @@ func (g *GalleryService) backendHandler(op *ManagementOp[gallery.GalleryBackend,
var err error
if op.Upgrade {
err = g.backendManager.UpgradeBackend(ctx, op.GalleryElementName, progressCallback)
err = g.backendManager.UpgradeBackend(ctx, op.ID, op.GalleryElementName, progressCallback)
} else if op.Delete {
err = g.backendManager.DeleteBackend(op.GalleryElementName)
} else {

View File

@@ -0,0 +1,106 @@
package galleryop_test
import (
"context"
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services/distributed"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/services/testutil"
)
// Reproduces "a cancelled/orphaned op resurrects as 'processing' after a pod
// restart". CancelOperation flipped the in-memory status to cancelled and
// broadcast a NATS event, but never persisted the terminal status to the
// gallery store. On the next replica restart the still-"pending" row hydrated
// straight back into processingBackends and the UI spun again. CancelOperation
// must persist the cancellation so it survives a restart.
var _ = Describe("GalleryService.CancelOperation persistence", func() {
It("persists the cancelled status to the gallery store", func() {
db := testutil.SetupTestDB()
store, err := distributed.NewGalleryStore(db)
Expect(err).ToNot(HaveOccurred())
// Seed an in-flight op as if a replica was mid-install.
Expect(store.Create(&distributed.GalleryOperationRecord{
ID: "op-cancel",
GalleryElementName: "llama-cpp-development",
OpType: "backend_install",
Status: "pending",
Progress: 0,
})).To(Succeed())
svc := galleryop.NewGalleryService(&config.ApplicationConfig{}, nil)
svc.SetGalleryStore(store)
// Make the op locally cancellable so CancelOperation proceeds.
svc.StoreCancellation("op-cancel", context.CancelFunc(func() {}))
Expect(svc.CancelOperation("op-cancel")).To(Succeed())
// The persisted row must now be terminal — otherwise it re-hydrates as
// pending on the next restart.
rec, err := store.Get("op-cancel")
Expect(err).ToNot(HaveOccurred())
Expect(rec.Status).To(Equal("cancelled"))
// And a fresh service hydrating from the store must NOT see it as active.
fresh := galleryop.NewGalleryService(&config.ApplicationConfig{}, nil)
fresh.SetGalleryStore(store)
Expect(fresh.Hydrate()).To(Succeed())
Expect(fresh.GetStatus("op-cancel")).To(BeNil(),
"a cancelled op must not hydrate back as active after a restart")
})
})
// Reproduces "an op orphaned by a replica that died mid-flight stays 'pending'
// forever". CleanStale (which marks abandoned active ops failed) only ran once
// on startup, so an op orphaned AFTER startup was never reaped until the next
// restart. The service must reap stale ops on an interval, not just at boot.
var _ = Describe("GalleryService.ReapStaleOperations", func() {
It("marks abandoned active ops terminal once they pass the age cutoff", func() {
db := testutil.SetupTestDB()
store, err := distributed.NewGalleryStore(db)
Expect(err).ToNot(HaveOccurred())
Expect(store.Create(&distributed.GalleryOperationRecord{
ID: "orphan-op",
GalleryElementName: "llama-cpp-development",
OpType: "backend_install",
Status: "pending",
Progress: 0,
})).To(Succeed())
// Force the row's updated_at into the past so it is older than the cutoff.
Expect(db.Exec(
"UPDATE gallery_operations SET updated_at = ? WHERE id = ?",
time.Now().Add(-1*time.Hour), "orphan-op",
).Error).To(Succeed())
// A fresh, still-progressing op must NOT be reaped.
Expect(store.Create(&distributed.GalleryOperationRecord{
ID: "live-op",
GalleryElementName: "vllm-development",
OpType: "backend_install",
Status: "downloading",
Progress: 50,
})).To(Succeed())
svc := galleryop.NewGalleryService(&config.ApplicationConfig{}, nil)
svc.SetGalleryStore(store)
reaped, err := svc.ReapStaleOperations(30 * time.Minute)
Expect(err).ToNot(HaveOccurred())
Expect(reaped).To(Equal(int64(1)))
orphan, err := store.Get("orphan-op")
Expect(err).ToNot(HaveOccurred())
Expect(orphan.Status).To(Equal("failed"))
live, err := store.Get("live-op")
Expect(err).ToNot(HaveOccurred())
Expect(live.Status).To(Equal("downloading"), "a recently-updated op must not be reaped")
})
})

View File

@@ -20,7 +20,7 @@ type BackendManager interface {
InstallBackend(ctx context.Context, op *ManagementOp[gallery.GalleryBackend, any], progressCb ProgressCallback) error
DeleteBackend(name string) error
ListBackends() (gallery.SystemBackends, error)
UpgradeBackend(ctx context.Context, name string, progressCb ProgressCallback) error
UpgradeBackend(ctx context.Context, opID, name string, progressCb ProgressCallback) error
CheckUpgrades(ctx context.Context) (map[string]gallery.UpgradeInfo, error)
// IsDistributed reports whether installs fan out across worker nodes.
// The HTTP layer uses this to refuse hardware-specific (non-meta) installs

View File

@@ -96,7 +96,10 @@ func (b *LocalBackendManager) ListBackends() (gallery.SystemBackends, error) {
return gallery.ListSystemBackends(b.systemState)
}
func (b *LocalBackendManager) UpgradeBackend(ctx context.Context, name string, progressCb ProgressCallback) error {
// UpgradeBackend ignores opID: a single-node install reports progress through
// the local progressCb already; opID only matters for distributed per-node
// streaming (see DistributedBackendManager.UpgradeBackend).
func (b *LocalBackendManager) UpgradeBackend(ctx context.Context, _ string, name string, progressCb ProgressCallback) error {
return gallery.UpgradeBackend(ctx, b.systemState, b.modelLoader, b.backendGalleries, name, progressCb, b.requireBackendIntegrity)
}

View File

@@ -0,0 +1,92 @@
package galleryop_test
import (
"errors"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services/galleryop"
)
// These specs reproduce the distributed "Reinstall spins forever" bug:
// processingBackends (the UI spinner source) is built from OpCache.GetStatus,
// which historically returned every cached op unconditionally. Cleanup only
// happened when a client polled /api/backends/job/:uid, but the Manage-page
// Reinstall/Upgrade buttons never poll, so a completed install stayed in
// processingBackends forever. GetStatus must self-evict terminal ops.
var _ = Describe("OpCache.GetStatus eviction", func() {
var (
svc *galleryop.GalleryService
cache *galleryop.OpCache
)
BeforeEach(func() {
svc = galleryop.NewGalleryService(&config.ApplicationConfig{}, nil)
cache = galleryop.NewOpCache(svc)
})
It("keeps an op that is still processing", func() {
cache.SetBackend("llama-cpp", "uuid-1")
svc.UpdateStatus("uuid-1", &galleryop.OpStatus{Message: "processing backend: llama-cpp", Progress: 0})
processing, _ := cache.GetStatus()
Expect(processing).To(HaveKeyWithValue("llama-cpp", "uuid-1"))
Expect(cache.Exists("llama-cpp")).To(BeTrue())
})
It("evicts a completed op so it no longer shows as processing", func() {
cache.SetBackend("llama-cpp", "uuid-1")
svc.UpdateStatus("uuid-1", &galleryop.OpStatus{Processed: true, Progress: 100, Message: "completed"})
processing, _ := cache.GetStatus()
Expect(processing).NotTo(HaveKey("llama-cpp"))
Expect(cache.Exists("llama-cpp")).To(BeFalse())
})
It("keeps a failed op so the operations panel can surface the error and offer Dismiss", func() {
cache.SetBackend("piper", "uuid-2")
svc.UpdateStatus("uuid-2", &galleryop.OpStatus{Processed: true, Error: errors.New("boom")})
processing, _ := cache.GetStatus()
Expect(processing).To(HaveKeyWithValue("piper", "uuid-2"))
Expect(cache.Exists("piper")).To(BeTrue())
})
It("evicts a cancelled op", func() {
cache.SetBackend("vllm", "uuid-3")
svc.UpdateStatus("uuid-3", &galleryop.OpStatus{Processed: true, Cancelled: true, Message: "cancelled"})
processing, _ := cache.GetStatus()
Expect(processing).NotTo(HaveKey("vllm"))
})
It("does not evict an op with no status yet (queued)", func() {
cache.SetBackend("whisper", "uuid-4")
processing, taskTypes := cache.GetStatus()
Expect(processing).To(HaveKeyWithValue("whisper", "uuid-4"))
Expect(taskTypes).To(HaveKeyWithValue("whisper", "Waiting"))
})
// Regression guard: GetStatus is called concurrently by four HTTP handlers
// (~1s poll). An earlier version evicted by deleting from m.Map() — which
// returns the live internal map by reference — causing a fatal
// "concurrent map writes" crash. Run under -race; must not panic or race.
It("is safe under concurrent GetStatus + Set/complete", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
for i := 0; i < 2000; i++ {
_, _ = cache.GetStatus()
}
close(done)
}()
for i := 0; i < 2000; i++ {
id := "uuid-c"
cache.SetBackend("concurrent-backend", id)
// Half the time mark it completed so GetStatus evicts it.
if i%2 == 0 {
svc.UpdateStatus(id, &galleryop.OpStatus{Processed: true, Progress: 100, Message: "completed"})
}
_, _ = cache.GetStatus()
}
<-done
})
})

View File

@@ -408,12 +408,34 @@ func (m *OpCache) Exists(key string) bool {
}
func (m *OpCache) GetStatus() (map[string]string, map[string]string) {
processingModelsData := m.Map()
taskTypes := map[string]string{}
processingModelsData := map[string]string{}
for k, v := range processingModelsData {
// Iterate a snapshot (Keys() copies) and build a fresh result map. We must
// NOT delete from m.Map() during the range: Map() returns the live internal
// map by reference, so a bare delete here would be an unsynchronized write
// to a map four HTTP handlers read every ~1s — a concurrent-map-write crash.
// Collect evictions and apply them via the locked DeleteUUID after the loop.
var evict []string
for _, k := range m.status.Keys() {
v := m.status.Get(k)
if v == "" {
continue // raced with a concurrent Delete
}
status := m.galleryService.GetStatus(v)
// Terminal ops must not keep showing as "processing". Cleanup was
// previously only triggered by a client polling /api/backends/job/:uid,
// but the Manage-page Reinstall/Upgrade buttons never poll, so completed
// ops leaked into processingBackends forever and the card spun
// "reinstalling" indefinitely. Evict here on the list read (the UI always
// calls this). We only evict SUCCESS/cancelled terminals (Error == nil):
// failed ops are kept so /api/operations can surface the error and offer
// Dismiss. DeleteUUID broadcasts the eviction so peer replicas converge.
if status != nil && status.Processed && status.Error == nil {
evict = append(evict, v)
continue
}
processingModelsData[k] = v
taskTypes[k] = "Installation"
if status != nil && status.Deletion {
taskTypes[k] = "Deletion"
@@ -422,6 +444,10 @@ func (m *OpCache) GetStatus() (map[string]string, map[string]string) {
}
}
for _, v := range evict {
m.DeleteUUID(v)
}
return processingModelsData, taskTypes
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"sync"
"time"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
@@ -31,9 +32,9 @@ type GalleryService struct {
// natsClient is the wider MessagingClient (Publisher + subscribe methods)
// when wired by the distributed startup path; broadcastSubs holds the
// progress + cancel subscriptions opened by SubscribeBroadcasts.
natsClient messaging.MessagingClient
galleryStore *distributed.GalleryStore
broadcastSubs []messaging.Subscription
natsClient messaging.MessagingClient
galleryStore *distributed.GalleryStore
broadcastSubs []messaging.Subscription
// OnBackendOpCompleted is fired after every successful install/upgrade/delete
// on the backend channel. The Application wires this to UpgradeChecker.TriggerCheck
@@ -274,6 +275,29 @@ func (g *GalleryService) GetAllStatus() map[string]*OpStatus {
return g.statuses
}
// ReapStaleOperations marks abandoned in-progress operations (pending/
// downloading/processing) older than `age` as failed, so an op orphaned by a
// replica that died mid-flight does not linger as "processing" forever. The
// store's CleanStale runs once on startup; this exposes it for periodic
// invocation (a post-startup orphan is otherwise not reaped until the next
// restart). No-op when no gallery store is wired. Returns rows reaped.
func (g *GalleryService) ReapStaleOperations(age time.Duration) (int64, error) {
g.Lock()
store := g.galleryStore
g.Unlock()
if store == nil {
return 0, nil
}
n, err := store.CleanStale(age)
if err != nil {
return 0, err
}
if n > 0 {
xlog.Info("Reaped stale gallery operations", "count", n)
}
return n, nil
}
// CancelOperation cancels an in-progress operation by its ID.
//
// In distributed mode the UI's cancel click may land on a different replica
@@ -295,6 +319,7 @@ func (g *GalleryService) CancelOperation(id string) error {
}
nc := g.natsClient
store := g.galleryStore
if !localExists && nc == nil {
g.Unlock()
@@ -315,6 +340,17 @@ func (g *GalleryService) CancelOperation(id string) error {
}
g.Unlock()
// Persist the terminal status so the cancel survives a restart. Without
// this the row stays in its active state and re-hydrates straight back into
// processingBackends on the next replica boot — the UI spins again on an op
// the admin already cancelled. The peer that broadcasts wins the write; a
// no-op when standalone (store nil).
if store != nil {
if err := store.Cancel(id); err != nil {
xlog.Warn("Failed to persist gallery operation cancellation", "op_id", id, "error", err)
}
}
// I/O and user-provided callback after Unlock — the cancel-wildcard
// subscriber loops back into applyCancel on this same replica, which
// would otherwise deadlock on g.Mutex.

View File

@@ -2,15 +2,22 @@ package messaging
import (
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/mudler/LocalAI/pkg/sanitize"
"github.com/mudler/xlog"
"github.com/nats-io/nats.go"
"github.com/nats-io/nkeys"
)
// subscribeConfirmTimeout bounds the server round-trip used to detect whether a
// subscription was rejected (e.g. by JWT permissions) before returning to the caller.
const subscribeConfirmTimeout = 5 * time.Second
// Client wraps a NATS connection and provides helpers for pub/sub and queue subscriptions.
type Client struct {
conn *nats.Conn
@@ -18,8 +25,13 @@ type Client struct {
}
// New creates a new NATS client with auto-reconnect.
func New(url string) (*Client, error) {
nc, err := nats.Connect(url,
func New(url string, opts ...Option) (*Client, error) {
var cfg connectConfig
for _, o := range opts {
o(&cfg)
}
natsOpts := []nats.Option{
nats.RetryOnFailedConnect(true),
nats.MaxReconnects(-1),
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
@@ -33,7 +45,60 @@ func New(url string) (*Client, error) {
nats.ClosedHandler(func(_ *nats.Conn) {
xlog.Info("NATS connection closed")
}),
)
// Surface async errors (notably permission violations) that NATS would
// otherwise deliver silently. A subscription the server rejects for a
// JWT permission means the worker never receives those messages, so make
// it loud rather than letting the feature fail invisibly.
nats.ErrorHandler(func(_ *nats.Conn, sub *nats.Subscription, err error) {
subject := ""
if sub != nil {
subject = sub.Subject
}
if errors.Is(err, nats.ErrPermissionViolation) {
xlog.Error("NATS permission violation — check JWT pub/sub allow lists", "subject", subject, "error", err)
return
}
xlog.Warn("NATS async error", "subject", subject, "error", err)
}),
}
switch {
case cfg.jwtProvider != nil:
// Fetch creds on every (re)connect so a refresh loop can rotate the JWT
// before expiry; the server expiring the old JWT triggers a reconnect
// that transparently picks up the new one.
natsOpts = append(natsOpts, nats.UserJWT(
func() (string, error) {
jwt, _ := cfg.jwtProvider()
if jwt == "" {
return "", fmt.Errorf("no NATS user JWT available")
}
return jwt, nil
},
func(nonce []byte) ([]byte, error) {
_, seed := cfg.jwtProvider()
kp, err := nkeys.FromSeed([]byte(seed))
if err != nil {
return nil, fmt.Errorf("loading NATS user seed: %w", err)
}
defer kp.Wipe()
return kp.Sign(nonce)
},
))
case cfg.userJWT != "" && cfg.userSeed != "":
natsOpts = append(natsOpts, nats.UserJWTAndSeed(cfg.userJWT, cfg.userSeed))
}
if cfg.tls.Enabled() {
if err := cfg.tls.Validate(); err != nil {
return nil, err
}
tlsOpts, err := cfg.tls.natsOptions()
if err != nil {
return nil, err
}
natsOpts = append(natsOpts, tlsOpts...)
}
nc, err := nats.Connect(url, natsOpts...)
if err != nil {
return nil, fmt.Errorf("connecting to NATS at %s: %w", sanitize.URL(url), err)
}
@@ -54,23 +119,67 @@ func (c *Client) Publish(subject string, data any) error {
// Subscribe creates a subscription on the given subject. All subscribers receive every message.
func (c *Client) Subscribe(subject string, handler func([]byte)) (Subscription, error) {
c.mu.RLock()
defer c.mu.RUnlock()
return c.conn.Subscribe(subject, func(msg *nats.Msg) {
handler(msg.Data)
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
return conn.Subscribe(subject, func(msg *nats.Msg) {
handler(msg.Data)
})
})
}
// QueueSubscribe creates a queue subscription. Within the same queue group,
// only one subscriber receives each message (load-balanced).
func (c *Client) QueueSubscribe(subject, queue string, handler func([]byte)) (Subscription, error) {
c.mu.RLock()
defer c.mu.RUnlock()
return c.conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
handler(msg.Data)
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
return conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
handler(msg.Data)
})
})
}
// confirmSubscription creates a subscription via mk and forces a server
// round-trip so that a permissions violation — which NATS otherwise reports
// only asynchronously — is returned to the caller synchronously. The server
// emits the "-ERR Permissions Violation" for a rejected SUB before the PONG
// that satisfies the flush, so by the time FlushTimeout returns the violation
// is recorded as the connection's last error. Without this, a worker whose JWT
// lacks a subject gets a non-nil subscription that never receives a message,
// turning a permission misconfiguration into a silent failure.
func (c *Client) confirmSubscription(subject string, mk func(*nats.Conn) (*nats.Subscription, error)) (Subscription, error) {
c.mu.RLock()
conn := c.conn
c.mu.RUnlock()
if conn == nil {
return nil, fmt.Errorf("subscribe to %s: nil NATS connection", subject)
}
sub, err := mk(conn)
if err != nil {
return nil, err
}
// A failed flush here means we could not round-trip to the server (not yet
// connected, reconnecting, slow link). RetryOnFailedConnect intentionally
// buffers subscriptions across that gap, so do NOT fail — keep the
// subscription and let it replay on (re)connect; a later permission
// violation is still logged by the async error handler in New.
if err := conn.FlushTimeout(subscribeConfirmTimeout); err != nil {
xlog.Debug("Could not confirm NATS subscription (will replay on connect)", "subject", subject, "error", err)
return sub, nil
}
// Flush succeeded, so any permission violation for this SUB has already been
// recorded as the connection's last error (the server emits it before the
// PONG). LastError is per-connection; match the exact quoted subject the
// server echoes ("Subscription to \"<subject>\"") so a stale violation for
// another subject can't be mis-attributed here.
if lerr := conn.LastError(); lerr != nil &&
errors.Is(lerr, nats.ErrPermissionViolation) &&
strings.Contains(lerr.Error(), `Subscription to "`+subject+`"`) {
_ = sub.Unsubscribe()
return nil, fmt.Errorf("subscription to %s denied by NATS server (check JWT sub allow list): %w", subject, lerr)
}
return sub, nil
}
// Request sends a request and waits for a reply (request-reply pattern).
// Returns the raw reply data.
func (c *Client) Request(subject string, data []byte, timeout time.Duration) ([]byte, error) {
@@ -86,15 +195,15 @@ func (c *Client) Request(subject string, data []byte, timeout time.Duration) ([]
// SubscribeReply creates a subscription that supports replying to requests.
// The handler receives the raw request data and the reply subject.
func (c *Client) SubscribeReply(subject string, handler func(data []byte, reply func([]byte))) (Subscription, error) {
c.mu.RLock()
defer c.mu.RUnlock()
return c.conn.Subscribe(subject, func(msg *nats.Msg) {
handler(msg.Data, func(replyData []byte) {
if msg.Reply != "" {
if err := msg.Respond(replyData); err != nil {
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
return conn.Subscribe(subject, func(msg *nats.Msg) {
handler(msg.Data, func(replyData []byte) {
if msg.Reply != "" {
if err := msg.Respond(replyData); err != nil {
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
}
}
}
})
})
})
}
@@ -102,15 +211,15 @@ func (c *Client) SubscribeReply(subject string, handler func(data []byte, reply
// QueueSubscribeReply creates a queue subscription that supports replying to requests.
// Load-balanced across subscribers in the same queue group, with request-reply support.
func (c *Client) QueueSubscribeReply(subject, queue string, handler func(data []byte, reply func([]byte))) (Subscription, error) {
c.mu.RLock()
defer c.mu.RUnlock()
return c.conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
handler(msg.Data, func(replyData []byte) {
if msg.Reply != "" {
if err := msg.Respond(replyData); err != nil {
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
return conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
handler(msg.Data, func(replyData []byte) {
if msg.Reply != "" {
if err := msg.Respond(replyData); err != nil {
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
}
}
}
})
})
})
}

View File

@@ -0,0 +1,34 @@
package messaging
// Option configures NATS client connection behavior.
type Option func(*connectConfig)
// CredentialProvider returns the NATS user JWT and signing seed to use for the
// next (re)connect. It is consulted on every connection attempt, so a refresh
// loop can rotate credentials before they expire and the connection picks them
// up automatically when the server expires the old JWT and triggers a reconnect.
type CredentialProvider func() (jwt, seed string)
type connectConfig struct {
userJWT string
userSeed string
jwtProvider CredentialProvider
tls TLSFiles
}
// WithUserJWT connects using a static NATS user JWT and signing seed (UserJWTAndSeed).
func WithUserJWT(jwt, seed string) Option {
return func(c *connectConfig) {
c.userJWT = jwt
c.userSeed = seed
}
}
// WithUserJWTProvider connects using credentials fetched from provider on each
// (re)connect, enabling JWT rotation without dropping the client. Takes
// precedence over WithUserJWT when both are set.
func WithUserJWTProvider(provider CredentialProvider) Option {
return func(c *connectConfig) {
c.jwtProvider = provider
}
}

View File

@@ -194,6 +194,14 @@ type BackendUpgradeRequest struct {
// but the field lets future per-replica metadata (e.g. progress reporting
// scoped to a slot) ride the same wire without a v3 type.
ReplicaIndex int32 `json:"replica_index,omitempty"`
// OpID identifies the admin-side operation. When non-empty the worker
// publishes BackendInstallProgressEvent values to
// SubjectNodeBackendInstallProgress(nodeID, OpID) while the force-reinstall
// runs, so the master can stream per-node progress for upgrades exactly as
// it already does for installs (an upgrade IS a force-reinstall, so the
// install-progress subject is reused rather than minting a new one — no new
// NATS permission or rolling-update compat surface). Empty on legacy callers.
OpID string `json:"op_id,omitempty"`
}
// BackendUpgradeReply mirrors BackendInstallReply minus Address — upgrade does

View File

@@ -0,0 +1,68 @@
package messaging
import (
"fmt"
"os"
"github.com/nats-io/nats.go"
)
// TLSFiles holds PEM paths for NATS TLS / mTLS. Cert and key must be set together.
// Use tls:// in LOCALAI_NATS_URL; CA and client cert paths are optional extras.
type TLSFiles struct {
CA string // LOCALAI_NATS_TLS_CA — private CA for server verification
Cert string // LOCALAI_NATS_TLS_CERT — client certificate (mTLS)
Key string // LOCALAI_NATS_TLS_KEY — client private key
}
// Enabled reports whether any TLS file path is configured.
func (f TLSFiles) Enabled() bool {
return f.CA != "" || f.Cert != "" || f.Key != ""
}
// Validate checks path pairing and that files exist.
func (f TLSFiles) Validate() error {
if f.Cert != "" && f.Key == "" {
return fmt.Errorf("LOCALAI_NATS_TLS_KEY is required when LOCALAI_NATS_TLS_CERT is set")
}
if f.Key != "" && f.Cert == "" {
return fmt.Errorf("LOCALAI_NATS_TLS_CERT is required when LOCALAI_NATS_TLS_KEY is set")
}
for _, path := range []struct {
name, path string
}{
{"LOCALAI_NATS_TLS_CA", f.CA},
{"LOCALAI_NATS_TLS_CERT", f.Cert},
{"LOCALAI_NATS_TLS_KEY", f.Key},
} {
if path.path == "" {
continue
}
if _, err := os.Stat(path.path); err != nil {
return fmt.Errorf("%s: %w", path.name, err)
}
}
return nil
}
// natsOptions builds nats-go TLS options. Call Validate first.
func (f TLSFiles) natsOptions() ([]nats.Option, error) {
if !f.Enabled() {
return nil, nil
}
opts := []nats.Option{nats.Secure()}
if f.CA != "" {
opts = append(opts, nats.RootCAs(f.CA))
}
if f.Cert != "" {
opts = append(opts, nats.ClientCert(f.Cert, f.Key))
}
return opts, nil
}
// WithTLS configures CA and/or client certificate paths for the NATS connection.
func WithTLS(files TLSFiles) Option {
return func(c *connectConfig) {
c.tls = files
}
}

View File

@@ -0,0 +1,25 @@
package messaging_test
import (
"os"
"path/filepath"
"github.com/mudler/LocalAI/core/services/messaging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("TLSFiles", func() {
It("requires cert and key together", func() {
Expect((messaging.TLSFiles{Cert: "/tmp/c.pem"}).Validate()).To(HaveOccurred())
Expect((messaging.TLSFiles{Key: "/tmp/k.pem"}).Validate()).To(HaveOccurred())
})
It("validates files exist", func() {
dir := GinkgoT().TempDir()
ca := filepath.Join(dir, "ca.pem")
Expect(os.WriteFile(ca, []byte("x"), 0600)).To(Succeed())
Expect((messaging.TLSFiles{CA: ca}).Validate()).To(Succeed())
})
})

View File

@@ -61,6 +61,17 @@ func StartFileTransferServerWithListener(lis net.Listener, stagingDir, modelsDir
return nil, fmt.Errorf("creating staging dir %s: %w", stagingDir, err)
}
// An empty token makes checkBearerToken fail open: every /v1/files,
// /v1/files-list and /v1/backend-logs request is served unauthenticated,
// granting read/write to the staging/models/data directories to anyone who
// can reach this port. Surface that loudly — the worker process does not
// run DistributedConfig.Validate(), so this is the only signal an operator
// gets. Set LOCALAI_REGISTRATION_TOKEN (and LOCALAI_REGISTRATION_REQUIRE_AUTH
// to fail closed) to protect it.
if token == "" {
xlog.Warn("HTTP file transfer server starting WITHOUT a registration token — read/write to models/staging/data is unauthenticated for anyone who can reach this port; set LOCALAI_REGISTRATION_TOKEN")
}
mux := http.NewServeMux()
// PUT /v1/files/{key} — upload file

View File

@@ -7,6 +7,7 @@ import (
"encoding/hex"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
@@ -893,3 +894,50 @@ func sha256Hex(data []byte) string {
h := sha256.Sum256(data)
return hex.EncodeToString(h[:])
}
var _ = Describe("StartFileTransferServerWithListener", func() {
start := func(token string) (string, func()) {
lis, err := net.Listen("tcp", "127.0.0.1:0")
Expect(err).NotTo(HaveOccurred())
staging := GinkgoT().TempDir()
models := GinkgoT().TempDir()
data := GinkgoT().TempDir()
srv, err := StartFileTransferServerWithListener(lis, staging, models, data, token, 0)
Expect(err).NotTo(HaveOccurred())
base := "http://" + lis.Addr().String()
return base, func() { ShutdownFileTransferServer(srv) }
}
// Exercises the empty-token fail-open warning branch: the server serves
// file requests with no Authorization header at all.
It("serves unauthenticated when started without a token", func() {
base, stop := start("")
defer stop()
resp, err := http.Get(base + "/v1/files/missing.bin")
Expect(err).NotTo(HaveOccurred())
defer func() { _ = resp.Body.Close() }()
// No 401 — the empty token fails open. The file is absent so we get 404.
Expect(resp.StatusCode).To(Equal(http.StatusNotFound))
})
It("rejects requests without the bearer token when a token is set", func() {
base, stop := start("s3cret")
defer stop()
resp, err := http.Get(base + "/v1/files/missing.bin")
Expect(err).NotTo(HaveOccurred())
defer func() { _ = resp.Body.Close() }()
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
})
It("serves the unauthenticated health endpoints regardless of token", func() {
base, stop := start("s3cret")
defer stop()
resp, err := http.Get(base + "/healthz")
Expect(err).NotTo(HaveOccurred())
defer func() { _ = resp.Body.Close() }()
Expect(resp.StatusCode).To(Equal(http.StatusOK))
})
})

View File

@@ -6,6 +6,7 @@ import (
"time"
"github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/xlog"
ggrpc "google.golang.org/grpc"
@@ -64,64 +65,95 @@ func (c *InFlightTrackingClient) track(ctx context.Context) func() {
}
}
// reconcile self-heals stale routing: when a backend reports that the model is
// no longer loaded (the process survived but the model was evicted, while the
// registry still lists it as loaded), it drops the replica row so the next
// request triggers a fresh load instead of routing back here. Without this the
// model stays unreachable until the controller restarts. The original error is
// returned unchanged.
func (c *InFlightTrackingClient) reconcile(err error) error {
if !grpcerrors.IsModelNotLoaded(err) {
return err
}
rmCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if rmErr := c.registry.RemoveNodeModel(rmCtx, c.nodeID, c.modelName, c.replicaIndex); rmErr != nil {
xlog.Warn("Failed to drop stale replica after model-not-loaded",
"node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex, "error", rmErr)
} else {
xlog.Warn("Backend reports model not loaded; dropped stale replica so the next request reloads",
"node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex)
}
return err
}
// --- Tracked inference methods ---
func (c *InFlightTrackingClient) Predict(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.Reply, error) {
defer c.track(ctx)()
return c.Backend.Predict(ctx, in, opts...)
reply, err := c.Backend.Predict(ctx, in, opts...)
return reply, c.reconcile(err)
}
func (c *InFlightTrackingClient) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
defer c.track(ctx)()
return c.Backend.PredictStream(ctx, in, f, opts...)
return c.reconcile(c.Backend.PredictStream(ctx, in, f, opts...))
}
func (c *InFlightTrackingClient) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
defer c.track(ctx)()
return c.Backend.Embeddings(ctx, in, opts...)
res, err := c.Backend.Embeddings(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
return c.Backend.GenerateImage(ctx, in, opts...)
res, err := c.Backend.GenerateImage(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
return c.Backend.GenerateVideo(ctx, in, opts...)
res, err := c.Backend.GenerateVideo(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) TTS(ctx context.Context, in *pb.TTSRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
return c.Backend.TTS(ctx, in, opts...)
res, err := c.Backend.TTS(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
defer c.track(ctx)()
return c.Backend.TTSStream(ctx, in, f, opts...)
return c.reconcile(c.Backend.TTSStream(ctx, in, f, opts...))
}
func (c *InFlightTrackingClient) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
return c.Backend.SoundGeneration(ctx, in, opts...)
res, err := c.Backend.SoundGeneration(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
defer c.track(ctx)()
return c.Backend.AudioTranscription(ctx, in, opts...)
res, err := c.Backend.AudioTranscription(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error {
defer c.track(ctx)()
return c.Backend.AudioTranscriptionStream(ctx, in, f, opts...)
return c.reconcile(c.Backend.AudioTranscriptionStream(ctx, in, f, opts...))
}
func (c *InFlightTrackingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) {
defer c.track(ctx)()
return c.Backend.Detect(ctx, in, opts...)
res, err := c.Backend.Detect(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...ggrpc.CallOption) (*pb.RerankResult, error) {
defer c.track(ctx)()
return c.Backend.Rerank(ctx, in, opts...)
res, err := c.Backend.Rerank(ctx, in, opts...)
return res, c.reconcile(err)
}

View File

@@ -20,9 +20,17 @@ type fakeInFlightTracker struct {
mu sync.Mutex
increments int
decrements int
removed int
incrementErr error
}
func (f *fakeInFlightTracker) RemoveNodeModel(_ context.Context, _, _ string, _ int) error {
f.mu.Lock()
defer f.mu.Unlock()
f.removed++
return nil
}
func (f *fakeInFlightTracker) IncrementInFlight(_ context.Context, _, _ string, _ int) error {
f.mu.Lock()
defer f.mu.Unlock()
@@ -295,4 +303,33 @@ var _ = Describe("InFlightTrackingClient", func() {
Expect(tracker.decrements).To(Equal(1))
})
})
Describe("stale model reload (self-heal)", func() {
It("removes the replica when the backend reports the model is not loaded", func() {
backend.predictErr = fmt.Errorf("parakeet-cpp: model not loaded")
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
Expect(err).To(HaveOccurred())
Expect(tracker.removed).To(Equal(1))
})
It("keeps the replica on an unrelated error", func() {
backend.predictErr = fmt.Errorf("context deadline exceeded")
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
Expect(err).To(HaveOccurred())
Expect(tracker.removed).To(Equal(0))
})
It("does not remove on success", func() {
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
Expect(err).ToNot(HaveOccurred())
Expect(tracker.removed).To(Equal(0))
})
It("self-heals on a streamed call too", func() {
backend.streamErr = fmt.Errorf("whisper: model not loaded")
err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(*pb.Reply) {})
Expect(err).To(HaveOccurred())
Expect(tracker.removed).To(Equal(1))
})
})
})

View File

@@ -78,6 +78,9 @@ type ModelLookup interface {
type InFlightTracker interface {
IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
// RemoveNodeModel drops a stale replica row so the next request reloads the
// model instead of routing back to a node where it is no longer loaded.
RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error
}
// NodeManager is used by HTTP endpoints for node registration and lifecycle.

View File

@@ -533,7 +533,7 @@ func (d *DistributedBackendManager) InstallBackend(ctx context.Context, op *gall
// backend.upgrade, we try the legacy backend.install Force=true path so a
// new master + old worker still converges. Drop the fallback once every
// worker in the fleet is on 2026-05-08 or newer.
func (d *DistributedBackendManager) UpgradeBackend(ctx context.Context, name string, progressCb galleryop.ProgressCallback) error {
func (d *DistributedBackendManager) UpgradeBackend(ctx context.Context, opID, name string, progressCb galleryop.ProgressCallback) error {
galleriesJSON, _ := json.Marshal(d.backendGalleries)
installed, err := d.ListBackends()
@@ -549,17 +549,39 @@ func (d *DistributedBackendManager) UpgradeBackend(ctx context.Context, name str
targetNodeIDs[n.NodeID] = true
}
// Empty opID: the caller (galleryop) doesn't thread an op ID into
// UpgradeBackend today, so we can't tag per-node sink writes with the
// right OpStatus key. Until the upgrade path takes a ManagementOp the
// way InstallBackend does, the sink stays no-op here.
result, err := d.enqueueAndDrainBackendOp(ctx, "", OpBackendUpgrade, name, galleriesJSON, targetNodeIDs, func(node BackendNode) error {
reply, err := d.adapter.UpgradeBackend(node.ID, name, string(galleriesJSON), "", "", "", 0)
result, err := d.enqueueAndDrainBackendOp(ctx, opID, OpBackendUpgrade, name, galleriesJSON, targetNodeIDs, func(node BackendNode) error {
// Per-node progress sink: fan each worker download tick into the legacy
// single-bar progressCb and the per-node OpStatus.Nodes view, exactly as
// InstallBackend does. Defined per-node so each closure captures its own
// node.Name. Without this an upgrade blocks opaque at progress 0 for the
// whole 15m round-trip (the original "reinstalling but nothing happens").
onProgress := func(ev messaging.BackendInstallProgressEvent) {
if progressCb != nil {
progressCb(ev.FileName, ev.Current, ev.Total, ev.Percentage)
}
if d.progressSink != nil && opID != "" {
d.progressSink.UpdateNodeProgress(opID, ev.NodeID, galleryop.NodeProgress{
NodeID: ev.NodeID,
NodeName: node.Name,
Status: galleryop.NodeStatusDownloading,
FileName: ev.FileName,
Current: ev.Current,
Total: ev.Total,
Percentage: ev.Percentage,
Phase: ev.Phase,
})
}
}
var onProgressArg func(messaging.BackendInstallProgressEvent)
if progressCb != nil || d.progressSink != nil {
onProgressArg = onProgress
}
reply, err := d.adapter.UpgradeBackend(node.ID, name, string(galleriesJSON), "", "", "", 0, opID, onProgressArg)
if err != nil {
// Rolling-update fallback: an older worker doesn't know
// backend.upgrade. Try the legacy install-with-force path.
if errors.Is(err, nats.ErrNoResponders) {
instReply, instErr := d.adapter.installWithForceFallback(node.ID, name, string(galleriesJSON), "", "", "", 0)
instReply, instErr := d.adapter.installWithForceFallback(node.ID, name, string(galleriesJSON), "", "", "", 0, opID, onProgressArg)
if instErr != nil {
return instErr
}

View File

@@ -317,7 +317,7 @@ func (stubLocalBackendManager) DeleteBackend(_ string) error { return gallery.Er
func (stubLocalBackendManager) ListBackends() (gallery.SystemBackends, error) {
return gallery.SystemBackends{}, nil
}
func (stubLocalBackendManager) UpgradeBackend(_ context.Context, _ string, _ galleryop.ProgressCallback) error {
func (stubLocalBackendManager) UpgradeBackend(_ context.Context, _ string, _ string, _ galleryop.ProgressCallback) error {
return nil
}
func (stubLocalBackendManager) CheckUpgrades(_ context.Context) (map[string]gallery.UpgradeInfo, error) {
@@ -782,7 +782,7 @@ var _ = Describe("DistributedBackendManager", func() {
mc.scriptReply(messaging.SubjectNodeBackendUpgrade(n2.ID),
messaging.BackendUpgradeReply{Success: false, Error: "registry unauthorized"})
err := mgr.UpgradeBackend(ctx, "vllm-development", nil)
err := mgr.UpgradeBackend(ctx, "", "vllm-development", nil)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("worker-a"))
Expect(err.Error()).To(ContainSubstring("image manifest not found"))
@@ -797,7 +797,7 @@ var _ = Describe("DistributedBackendManager", func() {
scriptInstalled("vllm-development", n1.ID)
mc.scriptReply(messaging.SubjectNodeBackendUpgrade(n1.ID),
messaging.BackendUpgradeReply{Success: true})
Expect(mgr.UpgradeBackend(ctx, "vllm-development", nil)).To(Succeed())
Expect(mgr.UpgradeBackend(ctx, "", "vllm-development", nil)).To(Succeed())
})
})
@@ -819,7 +819,7 @@ var _ = Describe("DistributedBackendManager", func() {
// if the manager attempts it, the scripted-client default returns
// fakeNoRespondersErr and the assertion below fails loudly.
Expect(mgr.UpgradeBackend(ctx, "cpu-insightface-development", nil)).To(Succeed())
Expect(mgr.UpgradeBackend(ctx, "", "cpu-insightface-development", nil)).To(Succeed())
mc.mu.Lock()
defer mc.mu.Unlock()
@@ -835,7 +835,7 @@ var _ = Describe("DistributedBackendManager", func() {
n1 := registerHealthyBackend("worker-a", "10.0.0.1:50051")
scriptNoBackends(n1.ID)
err := mgr.UpgradeBackend(ctx, "vllm-development", nil)
err := mgr.UpgradeBackend(ctx, "", "vllm-development", nil)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("not installed on any node"))
@@ -865,7 +865,7 @@ var _ = Describe("DistributedBackendManager", func() {
func(req messaging.BackendInstallRequest) bool { return req.Force },
messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:50100"})
Expect(mgr.UpgradeBackend(ctx, "vllm-development", nil)).To(Succeed())
Expect(mgr.UpgradeBackend(ctx, "", "vllm-development", nil)).To(Succeed())
})
It("returns the upgrade error when it is not ErrNoResponders", func() {
@@ -875,7 +875,7 @@ var _ = Describe("DistributedBackendManager", func() {
mc.scriptReply(messaging.SubjectNodeBackendUpgrade(n.ID),
messaging.BackendUpgradeReply{Success: false, Error: "disk full"})
err := mgr.UpgradeBackend(ctx, "vllm-development", nil)
err := mgr.UpgradeBackend(ctx, "", "vllm-development", nil)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("disk full"))
})

View File

@@ -0,0 +1,135 @@
package nodes
import (
"context"
"runtime"
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/services/testutil"
)
// These specs reproduce the distributed "pending ops behind dead nodes leak
// forever" bug. ListDuePendingBackendOps only returns rows whose node is
// StatusHealthy, so an op queued against a node that goes offline (heartbeat
// stale) or draining (admin action) is never retried, never aged out, and
// never deleted. On a live cluster these rows sat at attempts=0 indefinitely
// and kept the UI operation alive. DeleteStalePendingBackendOps garbage-collects
// them: draining nodes immediately (models already purged), offline nodes only
// after a grace window so a brief heartbeat blip does not nuke in-flight work.
var _ = Describe("DeleteStalePendingBackendOps", func() {
var (
registry *NodeRegistry
ctx context.Context
)
BeforeEach(func() {
if runtime.GOOS == "darwin" {
Skip("testcontainers requires Docker, not available on macOS CI")
}
db := testutil.SetupTestDB()
var err error
registry, err = NewNodeRegistry(db)
Expect(err).ToNot(HaveOccurred())
ctx = context.Background()
})
// registerBackend registers an auto-approved backend node and returns its ID.
registerBackend := func(name, address string) string {
node := &BackendNode{Name: name, NodeType: NodeTypeBackend, Address: address}
Expect(registry.Register(ctx, node, true)).To(Succeed())
fetched, err := registry.GetByName(ctx, name)
Expect(err).ToNot(HaveOccurred())
return fetched.ID
}
// setHeartbeat forces a node's last_heartbeat (Register/MarkOffline leave it
// at "now"; we age it to simulate a node that went silent a while ago).
setHeartbeat := func(nodeID string, t time.Time) {
Expect(registry.db.WithContext(ctx).Model(&BackendNode{}).
Where("id = ?", nodeID).
Update("last_heartbeat", t).Error).To(Succeed())
}
pendingCountFor := func(nodeID string) int64 {
var n int64
Expect(registry.db.WithContext(ctx).Model(&PendingBackendOp{}).
Where("node_id = ?", nodeID).Count(&n).Error).To(Succeed())
return n
}
It("clears ops behind an offline node whose heartbeat is past the grace window", func() {
dead := registerBackend("nvidia-thor", "10.0.0.9:50051")
Expect(registry.UpsertPendingBackendOp(ctx, dead, "llama-cpp-development", OpBackendInstall, nil)).To(Succeed())
Expect(registry.MarkOffline(ctx, dead)).To(Succeed())
setHeartbeat(dead, time.Now().Add(-1*time.Hour))
removed, err := registry.DeleteStalePendingBackendOps(ctx, 10*time.Minute)
Expect(err).ToNot(HaveOccurred())
Expect(removed).To(Equal(int64(1)))
Expect(pendingCountFor(dead)).To(Equal(int64(0)))
})
It("clears ops behind a draining node immediately, even with a fresh heartbeat", func() {
// Mirrors the live mac-mini-m4 case: draining but still heartbeating.
drain := registerBackend("mac-mini-m4", "10.0.0.3:50051")
Expect(registry.UpsertPendingBackendOp(ctx, drain, "llama-cpp-development", OpBackendInstall, nil)).To(Succeed())
Expect(registry.MarkDraining(ctx, drain)).To(Succeed())
setHeartbeat(drain, time.Now()) // fresh heartbeat
removed, err := registry.DeleteStalePendingBackendOps(ctx, 10*time.Minute)
Expect(err).ToNot(HaveOccurred())
Expect(removed).To(Equal(int64(1)))
Expect(pendingCountFor(drain)).To(Equal(int64(0)))
})
It("clears ops behind an unhealthy node with a stale heartbeat (never ages to offline)", func() {
// A node marked unhealthy on a NATS ErrNoResponders never transitions to
// offline, so its ops must be reaped via the same stale-heartbeat path.
sick := registerBackend("agx-orin-sick", "10.0.0.7:50051")
Expect(registry.UpsertPendingBackendOp(ctx, sick, "llama-cpp-development", OpBackendUpgrade, nil)).To(Succeed())
Expect(registry.MarkUnhealthy(ctx, sick)).To(Succeed())
setHeartbeat(sick, time.Now().Add(-1*time.Hour))
removed, err := registry.DeleteStalePendingBackendOps(ctx, 10*time.Minute)
Expect(err).ToNot(HaveOccurred())
Expect(removed).To(Equal(int64(1)))
Expect(pendingCountFor(sick)).To(Equal(int64(0)))
})
It("keeps ops behind an unhealthy node that is still heartbeating (recovering)", func() {
recovering := registerBackend("agx-orin-flap", "10.0.0.8:50051")
Expect(registry.UpsertPendingBackendOp(ctx, recovering, "llama-cpp-development", OpBackendUpgrade, nil)).To(Succeed())
Expect(registry.MarkUnhealthy(ctx, recovering)).To(Succeed())
setHeartbeat(recovering, time.Now()) // fresh heartbeat → recovering
removed, err := registry.DeleteStalePendingBackendOps(ctx, 10*time.Minute)
Expect(err).ToNot(HaveOccurred())
Expect(removed).To(Equal(int64(0)))
Expect(pendingCountFor(recovering)).To(Equal(int64(1)))
})
It("keeps ops behind a node that only just went offline (within grace)", func() {
blip := registerBackend("agx-orin", "10.0.0.4:50051")
Expect(registry.UpsertPendingBackendOp(ctx, blip, "parakeet-cpp-development", OpBackendInstall, nil)).To(Succeed())
Expect(registry.MarkOffline(ctx, blip)).To(Succeed())
setHeartbeat(blip, time.Now().Add(-1*time.Minute)) // gone only 1m, grace 10m
removed, err := registry.DeleteStalePendingBackendOps(ctx, 10*time.Minute)
Expect(err).ToNot(HaveOccurred())
Expect(removed).To(Equal(int64(0)))
Expect(pendingCountFor(blip)).To(Equal(int64(1)))
})
It("keeps ops behind a healthy node", func() {
healthy := registerBackend("dgx-spark", "10.0.0.1:50051")
Expect(registry.UpsertPendingBackendOp(ctx, healthy, "llama-cpp-development", OpBackendUpgrade, nil)).To(Succeed())
removed, err := registry.DeleteStalePendingBackendOps(ctx, 10*time.Minute)
Expect(err).ToNot(HaveOccurred())
Expect(removed).To(Equal(int64(0)))
Expect(pendingCountFor(healthy)).To(Equal(int64(1)))
})
})

View File

@@ -189,6 +189,13 @@ func (rc *ReplicaReconciler) reconcileState(ctx context.Context) {
// passed on nodes that are currently healthy. On success the row is deleted;
// on failure attempts++ and next_retry_at moves out via exponential backoff.
func (rc *ReplicaReconciler) drainPendingBackendOps(ctx context.Context) {
// Garbage-collect ops behind nodes that went offline/draining. These are
// invisible to ListDuePendingBackendOps (which filters status=healthy), so
// without this sweep they leak forever and keep the UI operation spinning.
if _, err := rc.registry.DeleteStalePendingBackendOps(ctx, stalePendingBackendOpGrace); err != nil {
xlog.Warn("Reconciler: failed to clear stale pending backend ops", "error", err)
}
ops, err := rc.registry.ListDuePendingBackendOps(ctx)
if err != nil {
xlog.Warn("Reconciler: failed to list pending backend ops", "error", err)
@@ -223,10 +230,13 @@ func (rc *ReplicaReconciler) drainPendingBackendOps(ctx context.Context) {
// the same worker. Falls back to the legacy backend.install
// Force=true path on nats.ErrNoResponders for old workers that
// don't subscribe to backend.upgrade yet (rolling-update window).
reply, err := rc.adapter.UpgradeBackend(op.NodeID, op.Backend, string(op.Galleries), "", "", "", 0)
// Reconciler retries are background reconciliation with no live
// admin watching a progress bar, so opID/onProgress are empty —
// the adapter skips the progress subscription entirely.
reply, err := rc.adapter.UpgradeBackend(op.NodeID, op.Backend, string(op.Galleries), "", "", "", 0, "", nil)
if err != nil {
if errors.Is(err, nats.ErrNoResponders) {
instReply, instErr := rc.adapter.installWithForceFallback(op.NodeID, op.Backend, string(op.Galleries), "", "", "", 0)
instReply, instErr := rc.adapter.installWithForceFallback(op.NodeID, op.Backend, string(op.Galleries), "", "", "", 0, "", nil)
if instErr != nil {
applyErr = instErr
} else if !instReply.Success {
@@ -293,6 +303,13 @@ func (rc *ReplicaReconciler) drainPendingBackendOps(ctx context.Context) {
// amount of further retrying will help.
const maxPendingBackendOpAttempts = 10
// stalePendingBackendOpGrace is how long a node may be offline before its
// pending backend ops are garbage-collected. Draining nodes are cleared
// immediately regardless of this window (see DeleteStalePendingBackendOps).
// ListDuePendingBackendOps never surfaces ops behind non-healthy nodes, so
// without this sweep they would leak forever and keep the UI op spinning.
const stalePendingBackendOpGrace = 15 * time.Minute
// probeLoadedModels gRPC-health-checks model addresses that the DB says are
// loaded. If a model's backend process is gone (OOM, crash, manual restart)
// we remove the row so ghosts don't linger. Only probes rows older than

View File

@@ -1776,6 +1776,38 @@ func (r *NodeRegistry) DeletePendingBackendOp(ctx context.Context, id uint) erro
return nil
}
// DeleteStalePendingBackendOps garbage-collects pending backend ops whose target
// node can never drain them. ListDuePendingBackendOps only returns rows behind a
// StatusHealthy node, so ops behind a node that went offline or draining are
// otherwise never retried, aged out, or deleted — they leak forever and keep the
// UI operation spinning. Draining nodes are cleared immediately (an explicit
// admin action; their model rows are already purged). Offline nodes are cleared
// only once their last heartbeat is older than `grace`, so a brief heartbeat blip
// does not nuke an install that is still legitimately in flight. Returns the
// number of rows deleted.
func (r *NodeRegistry) DeleteStalePendingBackendOps(ctx context.Context, grace time.Duration) (int64, error) {
cutoff := time.Now().Add(-grace)
// Draining nodes are cleared immediately (admin action; model rows already
// purged). Offline AND unhealthy nodes are cleared only once their heartbeat
// is older than the grace window: a node marked unhealthy on a NATS
// ErrNoResponders never transitions to offline (health.go skips re-marking
// it), so without including unhealthy here its ops would leak exactly like
// the offline case. A node with a fresh heartbeat (last_heartbeat > cutoff)
// is recovering and keeps its op for retry.
res := r.db.WithContext(ctx).
Where(`node_id IN (SELECT id FROM backend_nodes WHERE status = ?)
OR node_id IN (SELECT id FROM backend_nodes WHERE status IN ? AND last_heartbeat <= ?)`,
StatusDraining, []string{StatusOffline, StatusUnhealthy}, cutoff).
Delete(&PendingBackendOp{})
if res.Error != nil {
return 0, fmt.Errorf("deleting stale pending backend ops: %w", res.Error)
}
if res.RowsAffected > 0 {
xlog.Info("Cleared pending backend ops behind non-healthy nodes", "deleted", res.RowsAffected)
}
return res.RowsAffected, nil
}
// RecordPendingBackendOpFailure bumps Attempts, captures the error, and
// pushes NextRetryAt out with exponential backoff capped at 15 minutes.
func (r *NodeRegistry) RecordPendingBackendOpFailure(ctx context.Context, id uint, errMsg string) error {

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
@@ -932,13 +933,12 @@ func (r *SmartRouter) stageModelFiles(ctx context.Context, node *BackendNode, op
{"AudioPath", &opts.AudioPath},
}
// Count stageable files for progress tracking
// Count stageable files for progress tracking. Directory models expand to
// the number of files they contain, matching what stageDirectory uploads.
totalFiles := 0
for _, f := range fields {
if *f.val != "" {
if _, err := os.Stat(*f.val); err == nil {
totalFiles++
}
totalFiles += countStageableFiles(*f.val)
}
}
for _, adapter := range opts.LoraAdapters {
@@ -969,8 +969,33 @@ func (r *SmartRouter) stageModelFiles(ctx context.Context, node *BackendNode, op
*f.val = ""
continue
}
fileIdx++
localPath := *f.val
// Directory models (e.g. qwen3-tts-cpp ships its weights and tokenizer
// ggufs under one directory) can't be uploaded as a single file — the
// stager would open the directory and read its fd, failing with
// "is a directory" (EISDIR). Expand the directory and stage each
// contained file, then rewrite the field to the remote directory.
if fi, statErr := os.Stat(localPath); statErr == nil && fi.IsDir() {
remoteDir, dirErr := r.stageDirectory(ctx, node, trackingKey, localPath, keyMapper, &fileIdx, totalFiles)
if dirErr != nil {
if f.name == "ModelFile" {
xlog.Error("Failed to stage model directory for remote node", "node", node.Name, "field", f.name, "path", localPath, "error", dirErr)
return nil, fmt.Errorf("staging model file: %w", dirErr)
}
xlog.Warn("Failed to stage model directory, clearing field", "field", f.name, "path", localPath, "error", dirErr)
*f.val = ""
continue
}
*f.val = remoteDir
if f.name == "ModelFile" && opts.Model != "" {
opts.ModelPath = DeriveRemoteModelPath(remoteDir, opts.Model)
xlog.Debug("Derived remote ModelPath", "modelPath", opts.ModelPath)
}
continue
}
fileIdx++
key := keyMapper.Key(localPath)
// Attach progress callback to context for byte-level tracking
@@ -1074,6 +1099,77 @@ func (r *SmartRouter) withStagingCallback(ctx context.Context, trackingKey, file
})
}
// countStageableFiles returns the number of regular files a model path expands
// to for staging: 1 for a regular file, the contained file count for a
// directory, and 0 if the path does not exist.
func countStageableFiles(path string) int {
fi, err := os.Stat(path)
if err != nil {
return 0
}
if !fi.IsDir() {
return 1
}
n := 0
_ = filepath.WalkDir(path, func(_ string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return nil
}
if !d.IsDir() {
n++
}
return nil
})
return n
}
// stageDirectory stages every file under a directory-based model (e.g.
// qwen3-tts-cpp, whose weights and tokenizer ggufs live in one directory).
// Each file is uploaded individually with a structure-preserving key; the
// returned path is the remote directory that contained them, suitable for the
// backend's ModelFile/ModelPath. fileIdx is advanced per staged file so the
// staging progress tracker stays accurate.
func (r *SmartRouter) stageDirectory(ctx context.Context, node *BackendNode, trackingKey, dir string, keyMapper *StagingKeyMapper, fileIdx *int, totalFiles int) (string, error) {
var remoteDir string
err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if d.IsDir() {
return nil
}
*fileIdx++
fileName := filepath.Base(path)
stageCtx := r.withStagingCallback(ctx, trackingKey, fileName, *fileIdx, totalFiles)
xlog.Info("Staging file", "model", trackingKey, "node", node.Name, "field", "ModelDir", "file", fileName, "fileIndex", *fileIdx, "totalFiles", totalFiles)
remoteFile, err := r.fileStager.EnsureRemote(stageCtx, node.ID, path, keyMapper.Key(path))
if err != nil {
return fmt.Errorf("staging %s: %w", path, err)
}
r.stagingTracker.FileComplete(trackingKey, *fileIdx, totalFiles)
// Every file under dir shares the same remote parent directory; derive
// it from this file's staged path and its path relative to dir.
rel, relErr := filepath.Rel(dir, path)
if relErr != nil {
return relErr
}
remoteDir = DeriveRemoteModelPath(remoteFile, rel)
r.stageCompanionFiles(ctx, node, path, keyMapper.Key)
return nil
})
if err != nil {
return "", err
}
if remoteDir == "" {
return "", fmt.Errorf("model directory %s contains no files", dir)
}
return remoteDir, nil
}
// stageCompanionFiles stages known companion files that exist alongside
// localPath. For example, piper TTS implicitly loads ".onnx.json" next to
// the ".onnx" model file. Errors are logged but not propagated.

View File

@@ -0,0 +1,64 @@
package nodes
import (
"context"
"os"
"path/filepath"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
// These tests cover staging of "directory models" — models whose ModelFile is a
// directory containing multiple files (e.g. qwen3-tts-cpp ships weights +
// tokenizer ggufs under one directory). The HTTP file stager uploads a single
// regular file per path, so a directory ModelFile must be expanded into its
// constituent files; otherwise the upload reads a directory fd and fails with
// "is a directory" (EISDIR) on remote NATS worker nodes.
var _ = Describe("stageModelFiles directory models", func() {
var (
stager *fakeFileStager
router *SmartRouter
node *BackendNode
tmp string
modelID = "qwen3-tts-cpp"
)
BeforeEach(func() {
stager = &fakeFileStager{}
router = &SmartRouter{
fileStager: stager,
stagingTracker: NewStagingTracker(),
}
node = &BackendNode{ID: "node-1", Name: "node-1", Address: "10.0.0.1:50051"}
tmp = GinkgoT().TempDir()
})
It("stages every file inside a directory ModelFile instead of the directory path", func() {
modelDir := filepath.Join(tmp, "models", modelID)
Expect(os.MkdirAll(modelDir, 0o755)).To(Succeed())
weights := filepath.Join(modelDir, "qwen3-tts-0.6b-f16.gguf")
tokenizer := filepath.Join(modelDir, "qwen3-tts-tokenizer-f16.gguf")
Expect(os.WriteFile(weights, []byte("weights"), 0o644)).To(Succeed())
Expect(os.WriteFile(tokenizer, []byte("tokenizer"), 0o644)).To(Succeed())
opts := &pb.ModelOptions{
Model: modelID,
ModelFile: modelDir,
}
_, err := router.stageModelFiles(context.Background(), node, opts, "track-key")
Expect(err).ToNot(HaveOccurred())
staged := make([]string, 0, len(stager.ensureCalls))
for _, c := range stager.ensureCalls {
staged = append(staged, c.localPath)
}
// Each contained file is staged individually; the directory path itself
// is never handed to the stager (which would read a directory fd).
Expect(staged).To(ConsistOf(weights, tokenizer))
Expect(staged).ToNot(ContainElement(modelDir))
})
})

View File

@@ -365,7 +365,7 @@ func (f *fakeUnloader) InstallBackend(nodeID, backend, modelID, _, _, _, _ strin
return f.installReply, f.installErr
}
func (f *fakeUnloader) UpgradeBackend(nodeID, backend, _, _, _, _ string, replica int) (*messaging.BackendUpgradeReply, error) {
func (f *fakeUnloader) UpgradeBackend(nodeID, backend, _, _, _, _ string, replica int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendUpgradeReply, error) {
f.mu.Lock()
f.upgradeCalls = append(f.upgradeCalls, upgradeCall{nodeID, backend, replica})
f.mu.Unlock()

View File

@@ -35,7 +35,7 @@ type backendStopRequest struct {
// backend.upgrade subject.
type NodeCommandSender interface {
InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int, opID string, onProgress func(messaging.BackendInstallProgressEvent)) (*messaging.BackendInstallReply, error)
UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendUpgradeReply, error)
UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int, opID string, onProgress func(messaging.BackendInstallProgressEvent)) (*messaging.BackendUpgradeReply, error)
DeleteBackend(nodeID, backendName string) (*messaging.BackendDeleteReply, error)
ListBackends(nodeID string) (*messaging.BackendListReply, error)
StopBackend(nodeID, backend string) error
@@ -127,38 +127,8 @@ func (a *RemoteUnloaderAdapter) InstallBackend(
xlog.Info("Sending NATS backend.install", "nodeID", nodeID, "backend", backendType, "modelID", modelID, "replica", replicaIndex, "opID", opID)
// Subscribe to the per-op progress subject BEFORE publishing the install
// request so we don't miss early events. When onProgress is nil OR opID
// is empty (the reconciler-driven retry path), skip subscription entirely:
// silent installs cost nothing extra.
var sub messaging.Subscription
if onProgress != nil && opID != "" {
progressSubject := messaging.SubjectNodeBackendInstallProgress(nodeID, opID)
s, subErr := a.nats.Subscribe(progressSubject, func(raw []byte) {
var ev messaging.BackendInstallProgressEvent
if err := json.Unmarshal(raw, &ev); err != nil {
xlog.Debug("malformed install progress event", "subject", progressSubject, "error", err)
return
}
// Goroutine guard: a slow onProgress callback must not stall
// the NATS reader thread.
//
// NOTE: events spawn one goroutine each, so ordering at the
// consumer is best-effort. In practice the worker debounces to
// ~250ms which is far larger than goroutine scheduling jitter,
// so reordering is rare. The worker's final Flush() event is
// intended to win as the terminal tick. A future hardening pass
// could add a Seq uint64 field to BackendInstallProgressEvent
// and drop stale-by-seq at the bridge if reordering becomes a
// real UX issue.
go onProgress(ev)
})
if subErr != nil {
xlog.Warn("Failed to subscribe to install progress subject; proceeding without progress streaming",
"subject", progressSubject, "error", subErr)
} else {
sub = s
}
}
// request so we don't miss early events.
sub := a.subscribeProgress(nodeID, opID, onProgress)
reply, err := messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{
Backend: backendType,
@@ -182,18 +152,58 @@ func (a *RemoteUnloaderAdapter) InstallBackend(
return reply, err
}
// subscribeProgress subscribes to the per-op backend-install progress subject
// so the master can stream per-node download ticks while a worker installs or
// upgrades. Returns nil (and subscribes to nothing) when onProgress is nil or
// opID is empty — the reconciler-driven retry path and legacy callers stay
// silent at no cost. Shared by InstallBackend, UpgradeBackend, and the legacy
// force-install fallback: an upgrade is a force-reinstall, so it reuses the
// install-progress subject rather than minting a new one (no new NATS
// permission, no new rolling-update compat surface). Caller must Unsubscribe
// the returned subscription after the request completes.
func (a *RemoteUnloaderAdapter) subscribeProgress(nodeID, opID string, onProgress func(messaging.BackendInstallProgressEvent)) messaging.Subscription {
if onProgress == nil || opID == "" {
return nil
}
progressSubject := messaging.SubjectNodeBackendInstallProgress(nodeID, opID)
s, subErr := a.nats.Subscribe(progressSubject, func(raw []byte) {
var ev messaging.BackendInstallProgressEvent
if err := json.Unmarshal(raw, &ev); err != nil {
xlog.Debug("malformed backend progress event", "subject", progressSubject, "error", err)
return
}
// Goroutine guard: a slow onProgress callback must not stall the NATS
// reader thread. Events spawn one goroutine each, so ordering at the
// consumer is best-effort; the worker debounces to ~250ms which dwarfs
// goroutine scheduling jitter, and its final Flush() is the terminal tick.
go onProgress(ev)
})
if subErr != nil {
xlog.Warn("Failed to subscribe to backend progress subject; proceeding without progress streaming",
"subject", progressSubject, "error", subErr)
return nil
}
return s
}
// UpgradeBackend sends a backend.upgrade request-reply to a worker node.
// The worker stops every live process for this backend, force-reinstalls
// from the gallery (overwriting the on-disk artifact), and replies. The
// next routine InstallBackend call spawns a fresh process with the new
// binary - upgrade itself does not start a process.
//
// When opID is non-empty and onProgress is set, the master subscribes to the
// per-op progress subject before firing the request so a long force-reinstall
// streams per-node download ticks instead of blocking opaque at progress 0.
//
// Timeout: configured via DistributedConfig.BackendUpgradeTimeoutOrDefault
// (default 15m). Real-world worst case observed: 8-10 minutes for large
// CUDA-l4t backend images on Jetson over WiFi.
func (a *RemoteUnloaderAdapter) UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendUpgradeReply, error) {
func (a *RemoteUnloaderAdapter) UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int, opID string, onProgress func(messaging.BackendInstallProgressEvent)) (*messaging.BackendUpgradeReply, error) {
subject := messaging.SubjectNodeBackendUpgrade(nodeID)
xlog.Info("Sending NATS backend.upgrade", "nodeID", nodeID, "backend", backendType, "replica", replicaIndex)
xlog.Info("Sending NATS backend.upgrade", "nodeID", nodeID, "backend", backendType, "replica", replicaIndex, "opID", opID)
sub := a.subscribeProgress(nodeID, opID, onProgress)
reply, err := messaging.RequestJSON[messaging.BackendUpgradeRequest, messaging.BackendUpgradeReply](a.nats, subject, messaging.BackendUpgradeRequest{
Backend: backendType,
@@ -202,7 +212,13 @@ func (a *RemoteUnloaderAdapter) UpgradeBackend(nodeID, backendType, galleriesJSO
Name: name,
Alias: alias,
ReplicaIndex: int32(replicaIndex),
OpID: opID,
}, a.upgradeTimeout)
if sub != nil {
_ = sub.Unsubscribe()
}
if err != nil && isNATSTimeout(err) {
return nil, fmt.Errorf("%w (subject=%s nodeID=%s backend=%s): %v",
galleryop.ErrWorkerStillInstalling, subject, nodeID, backendType, err)
@@ -216,10 +232,12 @@ func (a *RemoteUnloaderAdapter) UpgradeBackend(nodeID, backendType, galleriesJSO
// doesn't subscribe to the new subject). It re-fires the legacy
// backend.install with Force=true. Drop this once every worker is on
// 2026-05-08 or newer.
func (a *RemoteUnloaderAdapter) installWithForceFallback(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendInstallReply, error) {
func (a *RemoteUnloaderAdapter) installWithForceFallback(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int, opID string, onProgress func(messaging.BackendInstallProgressEvent)) (*messaging.BackendInstallReply, error) {
subject := messaging.SubjectNodeBackendInstall(nodeID)
xlog.Warn("Falling back to legacy backend.install Force=true (old worker)", "nodeID", nodeID, "backend", backendType)
sub := a.subscribeProgress(nodeID, opID, onProgress)
reply, err := messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{
Backend: backendType,
BackendGalleries: galleriesJSON,
@@ -228,7 +246,13 @@ func (a *RemoteUnloaderAdapter) installWithForceFallback(nodeID, backendType, ga
Alias: alias,
ReplicaIndex: int32(replicaIndex),
Force: true,
OpID: opID,
}, a.upgradeTimeout)
if sub != nil {
_ = sub.Unsubscribe()
}
if err != nil && isNATSTimeout(err) {
return nil, fmt.Errorf("%w (subject=%s nodeID=%s backend=%s): %v",
galleryop.ErrWorkerStillInstalling, subject, nodeID, backendType, err)

View File

@@ -282,7 +282,7 @@ var _ = Describe("RemoteUnloaderAdapter timeout configuration", func() {
mc.scriptReply(messaging.SubjectNodeBackendUpgrade("n1"), messaging.BackendUpgradeReply{Success: true})
adapter := NewRemoteUnloaderAdapter(nil, mc, 7*time.Minute, 11*time.Minute)
_, err := adapter.UpgradeBackend("n1", "llama-cpp", "[]", "", "", "", 0)
_, err := adapter.UpgradeBackend("n1", "llama-cpp", "[]", "", "", "", 0, "", nil)
Expect(err).ToNot(HaveOccurred())
Expect(mc.calls).To(HaveLen(1))

View File

@@ -1,6 +1,7 @@
package nodes
import (
"sync"
"time"
. "github.com/onsi/ginkgo/v2"
@@ -18,7 +19,7 @@ var _ = Describe("RemoteUnloaderAdapter.UpgradeBackend", func() {
messaging.BackendUpgradeReply{Success: true})
adapter := NewRemoteUnloaderAdapter(nil, mc, 3*time.Minute, 15*time.Minute)
reply, err := adapter.UpgradeBackend(nodeID, "llama-cpp", `[{"name":"x"}]`, "", "", "", 0)
reply, err := adapter.UpgradeBackend(nodeID, "llama-cpp", `[{"name":"x"}]`, "", "", "", 0, "", nil)
Expect(err).ToNot(HaveOccurred())
Expect(reply.Success).To(BeTrue())
})
@@ -27,7 +28,55 @@ var _ = Describe("RemoteUnloaderAdapter.UpgradeBackend", func() {
mc := newScriptedMessagingClient() // unscripted subject => fakeNoRespondersErr by harness convention
adapter := NewRemoteUnloaderAdapter(nil, mc, 3*time.Minute, 15*time.Minute)
_, err := adapter.UpgradeBackend("missing-node", "llama-cpp", "", "", "", "", 0)
_, err := adapter.UpgradeBackend("missing-node", "llama-cpp", "", "", "", "", 0, "", nil)
Expect(err).To(HaveOccurred())
})
// Reproducer for "upgrade reports progress:0 the whole time" (Bug B). The
// install path streamed per-node download ticks; the upgrade path did a bare
// request→single-reply with no progress subscription, so a long force-reinstall
// blocked opaque. The adapter must subscribe to the per-op progress subject
// (reused from install) BEFORE the request and deliver each tick to onProgress.
It("streams per-node progress ticks during the upgrade", func() {
mc := newScriptedMessagingClient()
nodeID := "node-slow"
opID := "op-upgrade-1"
mc.scriptReply(messaging.SubjectNodeBackendUpgrade(nodeID),
messaging.BackendUpgradeReply{Success: true})
// The worker would publish these while force-reinstalling. The harness
// replays them as soon as the adapter subscribes to the per-op subject.
mc.scheduleProgressPublish(nodeID, opID, []messaging.BackendInstallProgressEvent{
{NodeID: nodeID, FileName: "llama-cpp.tar", Current: "10 MB", Total: "100 MB", Percentage: 10},
{NodeID: nodeID, FileName: "llama-cpp.tar", Current: "100 MB", Total: "100 MB", Percentage: 100},
})
var mu sync.Mutex
var got []messaging.BackendInstallProgressEvent
onProgress := func(ev messaging.BackendInstallProgressEvent) {
mu.Lock()
got = append(got, ev)
mu.Unlock()
}
adapter := NewRemoteUnloaderAdapter(nil, mc, 3*time.Minute, 15*time.Minute)
reply, err := adapter.UpgradeBackend(nodeID, "llama-cpp", `[{"name":"x"}]`, "", "", "", 0, opID, onProgress)
Expect(err).ToNot(HaveOccurred())
Expect(reply.Success).To(BeTrue())
// Confirm it subscribed to the (reused) install-progress subject for this op.
Expect(mc.subscribeCalls()).To(ContainElement(messaging.SubjectNodeBackendInstallProgress(nodeID, opID)))
// Progress events are delivered asynchronously (goroutine-per-event), so
// poll for both and assert on the set — ordering is best-effort by design.
Eventually(func() []float64 {
mu.Lock()
defer mu.Unlock()
pcts := make([]float64, 0, len(got))
for _, e := range got {
pcts = append(pcts, e.Percentage)
}
return pcts
}, 2*time.Second, 20*time.Millisecond).Should(ConsistOf(float64(10), float64(100)))
})
})

View File

@@ -0,0 +1,30 @@
package worker
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Worker auth-required helpers", func() {
DescribeTable("NatsAuthRequired",
func(nats, umbrella, want bool) {
cfg := &Config{NatsRequireAuth: nats, DistributedRequireAuth: umbrella}
Expect(cfg.NatsAuthRequired()).To(Equal(want))
},
Entry("neither", false, false, false),
Entry("granular only", true, false, true),
Entry("umbrella only", false, true, true),
Entry("both", true, true, true),
)
DescribeTable("RegistrationAuthRequired",
func(reg, umbrella, want bool) {
cfg := &Config{RegistrationRequireAuth: reg, DistributedRequireAuth: umbrella}
Expect(cfg.RegistrationAuthRequired()).To(Equal(want))
},
Entry("neither", false, false, false),
Entry("granular only", true, false, true),
Entry("umbrella only", false, true, true),
Entry("both", true, true, true),
)
})

View File

@@ -44,12 +44,14 @@ type Config struct {
AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server" hidden:""`
// Registration (required)
AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration" hidden:""`
RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
NodeLabels string `env:"LOCALAI_NODE_LABELS" help:"Comma-separated key=value labels for this node (e.g. tier=fast,gpu=a100)" group:"registration"`
AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration" hidden:""`
RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
RegistrationRequireAuth bool `env:"LOCALAI_REGISTRATION_REQUIRE_AUTH" default:"false" help:"Refuse to start the HTTP file-transfer server when no registration token is set (otherwise it fails open and serves read/write to models/staging/data unauthenticated)" group:"registration"`
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch implying both --nats-require-auth and --registration-require-auth" group:"distributed"`
HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
NodeLabels string `env:"LOCALAI_NODE_LABELS" help:"Comma-separated key=value labels for this node (e.g. tier=fast,gpu=a100)" group:"registration"`
// MaxReplicasPerModel caps how many replicas of any one model can run on
// this worker concurrently. Default 1 = historical single-replica
// behavior. Set higher when a node has enough VRAM to host multiple
@@ -60,7 +62,13 @@ type Config struct {
MaxReplicasPerModel int `env:"LOCALAI_MAX_REPLICAS_PER_MODEL" default:"1" help:"Max replicas of any single model on this worker. Default 1 preserves single-replica behavior; set higher to allow stacking replicas on a fat node." group:"registration"`
// NATS (required)
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (normally from registration nats_jwt)" group:"distributed"`
NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user signing seed override (normally from registration nats_user_seed)" group:"distributed"`
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed from registration or env" group:"distributed"`
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
// S3 storage for distributed file transfer
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3 endpoint URL" group:"distributed"`
@@ -69,3 +77,15 @@ type Config struct {
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key" group:"distributed"`
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret key" group:"distributed"`
}
// NatsAuthRequired reports whether NATS JWT credentials must be present — the
// granular flag or the umbrella (LOCALAI_DISTRIBUTED_REQUIRE_AUTH).
func (c Config) NatsAuthRequired() bool {
return c.NatsRequireAuth || c.DistributedRequireAuth
}
// RegistrationAuthRequired reports whether a registration token must be set
// before the file-transfer server may start — the granular flag or the umbrella.
func (c Config) RegistrationAuthRequired() bool {
return c.RegistrationRequireAuth || c.DistributedRequireAuth
}

View File

@@ -186,17 +186,29 @@ func (s *backendSupervisor) upgradeBackend(req messaging.BackendUpgradeRequest)
}
}
// When the master tagged this upgrade with an OpID, stream gallery download
// progress back on the per-op subject (reused from install — an upgrade is a
// force-reinstall). Old masters omit OpID and stay on the silent path. The
// deferred Flush guarantees a terminal-percentage event even if the upgrade
// errors out, so the master's per-node bar never hangs mid-download.
var downloadCb func(file, current, total string, percentage float64)
if req.OpID != "" && s.nats != nil {
publisher := nodes.NewDebouncedInstallProgressPublisher(s.nats, s.nodeID, req.OpID, req.Backend, installProgressDebounce)
downloadCb = publisher.OnDownload
defer publisher.Flush()
}
if req.URI != "" {
xlog.Info("Upgrading backend from external URI", "backend", req.Backend, "uri", req.URI)
if err := galleryop.InstallExternalBackend(
context.Background(), galleries, s.systemState, s.ml, nil, req.URI, req.Name, req.Alias, s.cfg.RequireBackendIntegrity,
context.Background(), galleries, s.systemState, s.ml, downloadCb, req.URI, req.Name, req.Alias, s.cfg.RequireBackendIntegrity,
); err != nil {
return fmt.Errorf("upgrading backend from external URI: %w", err)
}
} else {
xlog.Info("Upgrading backend from gallery", "backend", req.Backend)
if err := gallery.InstallBackendFromGallery(
context.Background(), galleries, s.systemState, s.ml, req.Backend, nil, true, /* force */
context.Background(), galleries, s.systemState, s.ml, req.Backend, downloadCb, true, /* force */
s.cfg.RequireBackendIntegrity,
); err != nil {
return fmt.Errorf("upgrading backend from gallery: %w", err)

Some files were not shown because too many files have changed in this diff Show More