mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-08 00:36:37 -04:00
Compare commits
50 Commits
feat/p2p-f
...
fix/distri
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e4fec35772 | ||
|
|
05c0a08e24 | ||
|
|
7824105a31 | ||
|
|
47fa847d55 | ||
|
|
9a7ebc1151 | ||
|
|
f6cc90d258 | ||
|
|
2c804bef5a | ||
|
|
6070402477 | ||
|
|
67f80a152b | ||
|
|
a7cb587d96 | ||
|
|
f7c74ad2da | ||
|
|
7402d1fd20 | ||
|
|
8c42695ef8 | ||
|
|
72e3241431 | ||
|
|
cd2bf95862 | ||
|
|
f64b72dd7d | ||
|
|
03c84cff28 | ||
|
|
9bc69c9e5f | ||
|
|
1e6c9cfd60 | ||
|
|
0e6712f734 | ||
|
|
0e4cee9a97 | ||
|
|
352b7ec604 | ||
|
|
ba706422fb | ||
|
|
e837921c2c | ||
|
|
73385713ca | ||
|
|
a4e671779a | ||
|
|
7051b2e0a1 | ||
|
|
469737101a | ||
|
|
858257eaf0 | ||
|
|
ef80a0e825 | ||
|
|
92726f7631 | ||
|
|
994063ba9a | ||
|
|
c1a55cf72d | ||
|
|
96758841d8 | ||
|
|
7a59260621 | ||
|
|
27e63b9a78 | ||
|
|
55c0911c23 | ||
|
|
f6cb6ab6d9 | ||
|
|
9f11b09c6a | ||
|
|
a5c4f822f0 | ||
|
|
fb36c262fe | ||
|
|
0e4e8980e6 | ||
|
|
3a932a9803 | ||
|
|
9d10418593 | ||
|
|
5470051d4d | ||
|
|
68c5eeebc3 | ||
|
|
1531fabe23 | ||
|
|
b7673d5b76 | ||
|
|
b64bdaf406 | ||
|
|
eebf08ff1d |
14
.github/backend-matrix.yml
vendored
14
.github/backend-matrix.yml
vendored
@@ -1766,20 +1766,6 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-turboquant'
|
||||
builder-base-image: 'quay.io/go-skynet/ci-cache:base-grpc-rocm-amd64'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
|
||||
13
.github/gallery-agent/main.go
vendored
13
.github/gallery-agent/main.go
vendored
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -113,6 +114,17 @@ func main() {
|
||||
fmt.Println("Searching for trending models on HuggingFace...")
|
||||
rawModels, err := client.GetTrending(searchTerm, limit)
|
||||
if err != nil {
|
||||
if errors.Is(err, hfapi.ErrRateLimited) {
|
||||
fmt.Printf("HuggingFace API is rate limited after retries, skipping this run: %v\n", err)
|
||||
writeSummary(AddedModelSummary{
|
||||
SearchTerm: searchTerm,
|
||||
TotalFound: 0,
|
||||
ModelsAdded: 0,
|
||||
Quantization: quantization,
|
||||
ProcessingTime: time.Since(startTime).String(),
|
||||
})
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Error fetching models: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -277,4 +289,3 @@ func truncateString(s string, maxLen int) string {
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
|
||||
2
.github/workflows/secscan.yaml
vendored
2
.github/workflows/secscan.yaml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||
- name: Run Gosec Security Scanner
|
||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||
uses: securego/gosec@v2.22.9
|
||||
uses: securego/gosec@v2.27.1
|
||||
with:
|
||||
# we let the report trigger content trigger a failure using the GitHub Security features.
|
||||
args: '-no-fail -fmt sarif -out results.sarif ./...'
|
||||
|
||||
9
Makefile
9
Makefile
@@ -309,13 +309,20 @@ run-e2e-aio: protogen-go
|
||||
@echo 'Running e2e AIO tests'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
|
||||
|
||||
# Distributed architecture e2e (PostgreSQL + NATS via testcontainers).
|
||||
# Includes NatsJWT specs (JWT-enabled NATS). Requires Docker.
|
||||
# VLLMMultinode is excluded here; use test-e2e-vllm-multinode for that.
|
||||
test-e2e-distributed: protogen-go
|
||||
@echo 'Running distributed e2e tests (label Distributed, incl. NatsJWT)'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter='Distributed && !VLLMMultinode' --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e/distributed
|
||||
|
||||
# vLLM multi-node DP smoke (CPU). Builds local-ai:tests and the
|
||||
# cpu-vllm backend from the current working tree, then drives a
|
||||
# head + headless follower via testcontainers-go and asserts a chat
|
||||
# completion. BuildKit caches both images, so re-runs only rebuild
|
||||
# what changed. The test lives under tests/e2e/distributed and is
|
||||
# selected by the VLLMMultinode label so it doesn't run alongside
|
||||
# the other distributed-suite tests by default.
|
||||
# test-e2e-distributed.
|
||||
test-e2e-vllm-multinode: docker-build-e2e extract-backend-vllm protogen-go
|
||||
@echo 'Running e2e vLLM multi-node DP test'
|
||||
LOCALAI_IMAGE=local-ai \
|
||||
|
||||
@@ -537,6 +537,15 @@ message TTSRequest {
|
||||
string dst = 3;
|
||||
string voice = 4;
|
||||
optional string language = 5;
|
||||
// instructions is a free-form, per-request style/voice description (maps to
|
||||
// the OpenAI `instructions` field). Backends that support expressive synthesis
|
||||
// (e.g. Qwen3-TTS CustomVoice/VoiceDesign) prefer this over the static YAML
|
||||
// option when set; backends that don't simply ignore it.
|
||||
optional string instructions = 6;
|
||||
// params carries optional, backend-specific per-request generation parameters
|
||||
// (e.g. Chatterbox exaggeration/cfg_weight/temperature). Values are strings and
|
||||
// coerced by the backend; unset leaves the backend's configured defaults.
|
||||
map<string, string> params = 7;
|
||||
}
|
||||
|
||||
message VADRequest {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
|
||||
# Upstream pin lives below as DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
|
||||
DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=3f40e73c367ad9f0c1b1819f28c7348c26aa340d
|
||||
IK_LLAMA_VERSION?=6b9de3dbaa21ae95ea80638e5ee836795cc48c93
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=5dcb71166686799f0d873eab7386234302d05ecf
|
||||
LLAMA_VERSION?=31e82494c0a3913c919c1027fa70500fbf4c07dd
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -482,23 +482,13 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!request->draftmodel().empty()) {
|
||||
params.speculative.draft.mparams.path = request->draftmodel();
|
||||
// Default to draft type if a draft model is set but no explicit type.
|
||||
// Upstream (post ggml-org/llama.cpp#22838) made the speculative type a
|
||||
// vector; the turboquant fork still uses the legacy scalar. The
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC macro is injected by
|
||||
// backend/cpp/turboquant/patch-grpc-server.sh for fork builds only.
|
||||
// Upstream renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE
|
||||
// in ggml-org/llama.cpp#22964; the fork still uses the old name.
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
||||
}
|
||||
#else
|
||||
// Upstream made the speculative type a vector (ggml-org/llama.cpp#22838)
|
||||
// and renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE (#22964).
|
||||
const bool no_spec_type = params.speculative.types.empty() ||
|
||||
(params.speculative.types.size() == 1 && params.speculative.types[0] == COMMON_SPECULATIVE_TYPE_NONE);
|
||||
if (no_spec_type) {
|
||||
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE };
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// params.model_alias ??
|
||||
@@ -574,9 +564,10 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// tokens (0 disables the minimum). Match upstream's default (256). This
|
||||
// field was renamed from `checkpoint_every_nt` in llama.cpp; the semantics
|
||||
// also shifted from a fixed cadence to a minimum spacing. The turboquant
|
||||
// fork branched before the field existed, so skip it on the legacy path
|
||||
// (LOCALAI_LEGACY_LLAMA_CPP_SPEC is injected by patch-grpc-server.sh).
|
||||
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
// fork still lacks common_params::checkpoint_min_step, so skip it there
|
||||
// (LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP is injected by
|
||||
// backend/cpp/turboquant/patch-grpc-server.sh).
|
||||
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
|
||||
params.checkpoint_min_step = 256;
|
||||
#endif
|
||||
|
||||
@@ -752,7 +743,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.cache_idle_slots = false;
|
||||
}
|
||||
|
||||
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
|
||||
// --- minimum context-checkpoint spacing (upstream -cms / --checkpoint-min-step) ---
|
||||
// 0 disables the minimum-spacing gate. Old option names (`checkpoint_every_nt`,
|
||||
// `checkpoint_every_n_tokens`) are kept as aliases for backward compatibility
|
||||
@@ -906,17 +897,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
|
||||
// Speculative decoding options
|
||||
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
// Fork only knows a single scalar `type`. Take the first comma-
|
||||
// separated value and assign it via the singular helper.
|
||||
std::string first = optval_str;
|
||||
const auto comma = first.find(',');
|
||||
if (comma != std::string::npos) first = first.substr(0, comma);
|
||||
auto type = common_speculative_type_from_name(first);
|
||||
if (type != COMMON_SPECULATIVE_TYPE_COUNT) {
|
||||
params.speculative.type = type;
|
||||
}
|
||||
#else
|
||||
// Upstream switched to a vector of types (comma-separated for multi-type
|
||||
// chaining via common_speculative_types_from_names). We keep accepting a
|
||||
// single value here, but also tolerate comma-separated lists.
|
||||
@@ -945,7 +925,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!parsed.empty()) {
|
||||
params.speculative.types = parsed;
|
||||
}
|
||||
#endif
|
||||
} else if (!strcmp(optname, "spec_n_max") || !strcmp(optname, "draft_max")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.draft.n_max = std::stoi(optval_str); } catch (...) {}
|
||||
@@ -983,21 +962,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// shares the target context size. Accept the option for backward
|
||||
// compatibility but silently ignore it.
|
||||
|
||||
// Everything below relies on struct shape introduced in ggml-org/llama.cpp#22838
|
||||
// (parallel drafting): `ngram_mod`, `ngram_map_k`, `ngram_map_k4v`,
|
||||
// `ngram_cache`, and the `draft.{cache_type_*, cpuparams*, tensor_buft_overrides}`
|
||||
// fields. The turboquant fork branched before that, so its build defines
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC via patch-grpc-server.sh and these option
|
||||
// keys become unrecognized (silently dropped, like any unknown opt) for it.
|
||||
//
|
||||
// The `#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC` / `#else` split below sits at the
|
||||
// closing-brace position of the `draft_ctx_size` branch on purpose: in the
|
||||
// legacy build the chain ends here (the brace closes draft_ctx_size), and in
|
||||
// the modern build the chain continues with `} else if (...)` instead, so the
|
||||
// brace count stays balanced under both branches of the preprocessor.
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
}
|
||||
#else
|
||||
// --- ngram_mod family (upstream --spec-ngram-mod-*) ---
|
||||
} else if (!strcmp(optname, "spec_ngram_mod_n_min")) {
|
||||
if (optval != NULL) {
|
||||
@@ -1127,7 +1091,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
}
|
||||
if (!cur.empty()) flush(cur);
|
||||
}
|
||||
#endif // LOCALAI_LEGACY_LLAMA_CPP_SPEC — closes the `else`/`#ifdef` opened at draft_ctx_size
|
||||
}
|
||||
|
||||
// Set params.n_parallel from environment variable if not set via options (fallback)
|
||||
@@ -1177,15 +1140,11 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
}
|
||||
// The draft tensor_buft_overrides are only populated under the modern
|
||||
// (post-#22838) layout, whose population code is itself gated by
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC above. The turboquant fork lacks
|
||||
// common_params_speculative::draft entirely, so skip the sentinel there too.
|
||||
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
// Terminate the draft tensor_buft_overrides list with a sentinel, mirroring
|
||||
// the main-model handling above.
|
||||
if (!params.speculative.draft.tensor_buft_overrides.empty()) {
|
||||
params.speculative.draft.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
#endif
|
||||
|
||||
// TODO: Add yarn
|
||||
|
||||
@@ -1944,6 +1903,17 @@ public:
|
||||
body_json["chat_template_kwargs"]["enable_thinking"] = (et_it->second == "true");
|
||||
}
|
||||
|
||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
||||
// from enable_thinking which those templates ignore.
|
||||
auto re_it = metadata.find("reasoning_effort");
|
||||
if (re_it != metadata.end() && !re_it->second.empty()) {
|
||||
if (!body_json.contains("chat_template_kwargs")) {
|
||||
body_json["chat_template_kwargs"] = json::object();
|
||||
}
|
||||
body_json["chat_template_kwargs"]["reasoning_effort"] = re_it->second;
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
@@ -2737,6 +2707,17 @@ public:
|
||||
body_json["chat_template_kwargs"]["enable_thinking"] = (predict_et_it->second == "true");
|
||||
}
|
||||
|
||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
||||
// from enable_thinking which those templates ignore.
|
||||
auto predict_re_it = predict_metadata.find("reasoning_effort");
|
||||
if (predict_re_it != predict_metadata.end() && !predict_re_it->second.empty()) {
|
||||
if (!body_json.contains("chat_template_kwargs")) {
|
||||
body_json["chat_template_kwargs"] = json::object();
|
||||
}
|
||||
body_json["chat_template_kwargs"]["reasoning_effort"] = predict_re_it->second;
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] Predict: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
|
||||
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
|
||||
TURBOQUANT_VERSION?=5aeb2fdbe26cd4c534c6fa15de73cb5749bd0403
|
||||
TURBOQUANT_VERSION?=7d9715f1f071fa07c7b2ad3dbfd320b314139e65
|
||||
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -4,21 +4,19 @@
|
||||
#
|
||||
# 1. Augment the kv_cache_types[] allow-list so `LoadModel` accepts the
|
||||
# fork-specific `turbo2` / `turbo3` / `turbo4` cache types.
|
||||
# 2. Replace `get_media_marker()` (added upstream in ggml-org/llama.cpp#21962,
|
||||
# server-side random per-instance marker) with the legacy "<__media__>"
|
||||
# literal. The fork branched before that PR, so server-common.cpp has no
|
||||
# get_media_marker symbol. The fork's mtmd_default_marker() still returns
|
||||
# "<__media__>", and Go-side tooling falls back to that sentinel when the
|
||||
# backend does not expose media_marker, so substituting the literal keeps
|
||||
# behavior identical on the turboquant path.
|
||||
# 3. Revert the `common_params_speculative` field references to the
|
||||
# pre-refactor flat layout. Upstream ggml-org/llama.cpp#22397 split the
|
||||
# struct into nested `draft` / `ngram_simple` / `ngram_mod` / etc. members;
|
||||
# the turboquant fork branched before that PR and still exposes the flat
|
||||
# `n_max`, `mparams_dft`, `ngram_size_n`, ... fields. The substitutions
|
||||
# below map the new nested paths back to the legacy flat names so the
|
||||
# shared grpc-server.cpp keeps compiling against the fork's common.h.
|
||||
# Drop this block once the fork rebases past #22397.
|
||||
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file
|
||||
# so the grpc-server option parser skips the two references to
|
||||
# common_params::checkpoint_min_step (the default and the option handler).
|
||||
# That field does not exist in the fork yet; drop this once it does.
|
||||
#
|
||||
# The fork used to lag upstream on the whole common_params_speculative refactor
|
||||
# (ggml-org/llama.cpp#22397/#22838/#22964), the model_tgt rename (#22838) and
|
||||
# get_media_marker (#21962), which required a much larger compat shim here
|
||||
# (flat-field sed renames + a coarse LOCALAI_LEGACY_LLAMA_CPP_SPEC define). The
|
||||
# fork has since rebased past all of those, so the only remaining gap is
|
||||
# checkpoint_min_step. If a future bump reintroduces a divergence, add a narrow
|
||||
# guard in grpc-server.cpp keyed on a fork-specific macro and inject it here
|
||||
# rather than resurrecting the coarse one.
|
||||
#
|
||||
# We patch the *copy* sitting in turboquant-<flavor>-build/, never the original
|
||||
# under backend/cpp/llama-cpp/, so the stock llama-cpp build keeps compiling
|
||||
@@ -72,72 +70,20 @@ else
|
||||
echo "==> KV allow-list patch OK"
|
||||
fi
|
||||
|
||||
if grep -q 'get_media_marker()' "$SRC"; then
|
||||
echo "==> patching $SRC to replace get_media_marker() with legacy \"<__media__>\" literal"
|
||||
# Only one call site today (ModelMetadata), but replace all occurrences to
|
||||
# stay robust if upstream adds more. Use a temp file to avoid relying on
|
||||
# sed -i portability (the builder image uses GNU sed, but keeping this
|
||||
# consistent with the awk block above).
|
||||
sed 's/get_media_marker()/"<__media__>"/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> get_media_marker() substitution OK"
|
||||
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file so
|
||||
# the grpc-server option parser skips the two references to
|
||||
# common_params::checkpoint_min_step (the default assignment and the option
|
||||
# handler). That field does not exist in the fork yet. Drop this block once
|
||||
# the fork rebases past the bump that added checkpoint_min_step.
|
||||
if grep -q '^#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP' "$SRC"; then
|
||||
echo "==> $SRC already defines LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP, skipping"
|
||||
else
|
||||
echo "==> $SRC has no get_media_marker() call, skipping media-marker patch"
|
||||
fi
|
||||
|
||||
if grep -q 'params\.speculative\.draft\.\|params\.speculative\.ngram_simple\.' "$SRC"; then
|
||||
echo "==> patching $SRC to revert common_params_speculative refs to pre-#22397 flat layout"
|
||||
# Each substitution is the exact post-refactor path → legacy flat field.
|
||||
# Order doesn't matter because the source paths are disjoint, but we keep
|
||||
# the most-specific (mparams.path) first for readability.
|
||||
sed -E \
|
||||
-e 's/params\.speculative\.draft\.mparams\.path/params.speculative.mparams_dft.path/g' \
|
||||
-e 's/params\.speculative\.draft\.n_max/params.speculative.n_max/g' \
|
||||
-e 's/params\.speculative\.draft\.n_min/params.speculative.n_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_min/params.speculative.p_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_split/params.speculative.p_split/g' \
|
||||
-e 's/params\.speculative\.draft\.n_gpu_layers/params.speculative.n_gpu_layers/g' \
|
||||
-e 's/params\.speculative\.draft\.n_ctx/params.speculative.n_ctx/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_n/params.speculative.ngram_size_n/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_m/params.speculative.ngram_size_m/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.min_hits/params.speculative.ngram_min_hits/g' \
|
||||
"$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> speculative field rename OK"
|
||||
else
|
||||
echo "==> $SRC has no post-#22397 speculative field refs, skipping spec rename patch"
|
||||
fi
|
||||
|
||||
# 4. Revert the `ctx_server.impl->model_tgt` rename introduced by upstream
|
||||
# ggml-org/llama.cpp#22838 (parallel drafting). The turboquant fork still
|
||||
# exposes the field as `model` on `server_context_impl`. The two call sites
|
||||
# are in the Rerank and ModelMetadata RPC handlers.
|
||||
if grep -q 'ctx_server\.impl->model_tgt' "$SRC"; then
|
||||
echo "==> patching $SRC to revert ctx_server.impl->model_tgt -> ctx_server.impl->model"
|
||||
sed -E 's/ctx_server\.impl->model_tgt/ctx_server.impl->model/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> model_tgt rename OK"
|
||||
else
|
||||
echo "==> $SRC has no ctx_server.impl->model_tgt refs, skipping model_tgt rename patch"
|
||||
fi
|
||||
|
||||
# 5. Define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top of the file so the
|
||||
# grpc-server option parser skips the new option-handler blocks (ngram_mod,
|
||||
# ngram_map_k, ngram_map_k4v, ngram_cache, draft.cache_type_*, draft.cpuparams*,
|
||||
# draft.tensor_buft_overrides) introduced for the post-#22838 layout, the
|
||||
# draft.tensor_buft_overrides sentinel termination, and the
|
||||
# common_params::checkpoint_min_step default/option (added with the
|
||||
# 35c9b1f3 bump). Those blocks reference struct fields that simply do not
|
||||
# exist in the fork.
|
||||
if grep -q '^#define LOCALAI_LEGACY_LLAMA_CPP_SPEC' "$SRC"; then
|
||||
echo "==> $SRC already defines LOCALAI_LEGACY_LLAMA_CPP_SPEC, skipping"
|
||||
else
|
||||
echo "==> patching $SRC to define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top"
|
||||
# Insert the define before the very first `#include` so it precedes all the
|
||||
# speculative-decoding code paths.
|
||||
echo "==> patching $SRC to define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top"
|
||||
# Insert the define before the very first `#include` so it precedes the
|
||||
# checkpoint_min_step references.
|
||||
awk '
|
||||
!done && /^#include/ {
|
||||
print "#define LOCALAI_LEGACY_LLAMA_CPP_SPEC 1"
|
||||
print "#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP 1"
|
||||
print "// ^ injected by backend/cpp/turboquant/patch-grpc-server.sh"
|
||||
print ""
|
||||
done = 1
|
||||
@@ -145,13 +91,13 @@ else
|
||||
{ print }
|
||||
END {
|
||||
if (!done) {
|
||||
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_LEGACY_LLAMA_CPP_SPEC" > "/dev/stderr"
|
||||
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP" > "/dev/stderr"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> LOCALAI_LEGACY_LLAMA_CPP_SPEC define OK"
|
||||
echo "==> LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP define OK"
|
||||
fi
|
||||
|
||||
echo "==> all patches applied"
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
hip: port the turboquant CUDA additions that ggml's HIP shim doesn't cover
|
||||
|
||||
The turboquant fork adds/modifies a few ggml-cuda.cu spots with CUDA APIs
|
||||
that ggml's HIP (and MUSA) compatibility layer does not provide, breaking
|
||||
the -gpu-rocm-hipblas-turboquant build:
|
||||
|
||||
1. ggml_cuda_copy2d_across_devices() (host-staged cross-device copy for
|
||||
split mul_mat output) uses the CUDA 3D-peer copy APIs
|
||||
cudaMemcpy3DPeerParms / make_cudaPitchedPtr / make_cudaExtent /
|
||||
cudaMemcpy3DPeerAsync. HIP genuinely does not support these (see the
|
||||
fork's own comment "HIP does not support cudaMemcpy3DPeerAsync"), so
|
||||
guard the peer fast path with #if !defined(GGML_USE_HIP) &&
|
||||
!defined(GGML_USE_MUSA) -- matching how the fork already guards the
|
||||
same API for the sibling 2D copy -- and fall through to the existing
|
||||
cudaMemcpyAsync staging fallback below (functionally identical,
|
||||
slightly slower on multi-GPU ROCm).
|
||||
|
||||
2. ggml_backend_cuda_device_event_new() creates its event with plain
|
||||
cudaEventCreate, which ggml's HIP shim does not alias (it only aliases
|
||||
cudaEventCreateWithFlags). Use cudaEventCreateWithFlags(...,
|
||||
cudaEventDisableTiming) -- exactly what the rest of this file already
|
||||
does (cf. lines ~1034, ~3461) and HIP-safe.
|
||||
|
||||
CUDA builds are unaffected. Drop the relevant hunk once the fork HIP-ports
|
||||
these; apply-patches.sh fails fast if an anchor goes stale.
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 0427e6b..6352e6a 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -1933,6 +1933,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
|
||||
size_t width, size_t height, cudaStream_t dst_stream, cudaStream_t src_stream) {
|
||||
|
||||
const auto & info = ggml_cuda_info();
|
||||
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) // 3D-peer copy types unmapped by ggml's HIP/MUSA shim; use staging fallback below
|
||||
if (info.peer_access[src_device][dst_device]) {
|
||||
cudaMemcpy3DPeerParms p = {};
|
||||
p.dstDevice = dst_device;
|
||||
@@ -1942,6 +1943,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
|
||||
p.extent = make_cudaExtent(width, height, 1);
|
||||
return cudaMemcpy3DPeerAsync(&p, dst_stream);
|
||||
}
|
||||
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
|
||||
// Fallback: stage all rows through a single contiguous pinned buffer
|
||||
int prev_device = ggml_cuda_get_device();
|
||||
@@ -5714,7 +5716,7 @@ static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_
|
||||
ggml_cuda_set_device(dev_ctx->device);
|
||||
|
||||
cudaEvent_t event;
|
||||
- CUDA_CHECK(cudaEventCreate(&event));
|
||||
+ CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
|
||||
|
||||
return new ggml_backend_event {
|
||||
/* .device = */ dev,
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
@@ -145,7 +146,7 @@ func resolveAPIKey(envName, filePath string) (string, error) {
|
||||
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cloud-proxy: model not loaded")
|
||||
return nil, grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -175,7 +176,7 @@ func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err
|
||||
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -269,7 +270,7 @@ func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest,
|
||||
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modePassthrough {
|
||||
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# CrispASR version (release tag)
|
||||
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
||||
CRISPASR_VERSION?=05e60432bcb5bc2113f8c395a41e86497c11504a
|
||||
CRISPASR_VERSION?=13d54e110e1538e0f0bc3af0680b9ab246cfb48d
|
||||
SO_TARGET?=libgocrispasr.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# parakeet-cpp backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=9edf17c3ada66e0f881dcff155492867db7ac4cf
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# whisper.cpp / ds4 / vibevoice-cpp convention.
|
||||
#
|
||||
@@ -15,7 +15,7 @@
|
||||
# That's what the L0 smoke test uses. The default target below does the
|
||||
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
|
||||
|
||||
PARAKEET_VERSION?=9edf17c3ada66e0f881dcff155492867db7ac4cf
|
||||
PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
|
||||
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
|
||||
|
||||
GOCMD?=go
|
||||
|
||||
@@ -7,8 +7,12 @@ import "time"
|
||||
type batchRequest struct {
|
||||
pcm []float32
|
||||
decoder int32
|
||||
tag string
|
||||
reply chan batchReply
|
||||
// language is the per-request target locale ("" means the model default).
|
||||
// parakeet.cpp's batched C-API takes ONE target_lang for the whole batch,
|
||||
// so the dispatcher only coalesces requests that share a language.
|
||||
language string
|
||||
tag string
|
||||
reply chan batchReply
|
||||
}
|
||||
|
||||
// batchReply carries one per-item JSON object string (an element of the C-API's
|
||||
@@ -43,13 +47,25 @@ func newBatcher(maxSize int, maxWait time.Duration, runBatch func([]*batchReques
|
||||
// run is the dispatcher loop: accumulate submitted requests until either maxSize
|
||||
// is reached or maxWait elapses since the first queued request, then dispatch.
|
||||
// Exits when stop is closed (draining any partially-filled batch first).
|
||||
//
|
||||
// A batch carries ONE language (parakeet.cpp's batched C-API takes a single
|
||||
// target_lang), so a request whose language differs from the batch leader is
|
||||
// not coalesced: it is held in carry and becomes the leader of the next batch.
|
||||
// carry is therefore never dropped and its caller never deadlocks: every batch
|
||||
// (including a lone carry on stop) is dispatched, and runBatch replies to all.
|
||||
func (b *batcher) run(stop <-chan struct{}) {
|
||||
var carry *batchRequest
|
||||
for {
|
||||
var first *batchRequest
|
||||
select {
|
||||
case first = <-b.submit:
|
||||
case <-stop:
|
||||
return
|
||||
if carry != nil {
|
||||
// A mismatched request from the previous fill leads this batch.
|
||||
first, carry = carry, nil
|
||||
} else {
|
||||
select {
|
||||
case first = <-b.submit:
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
batch := []*batchRequest{first}
|
||||
|
||||
@@ -64,12 +80,22 @@ func (b *batcher) run(stop <-chan struct{}) {
|
||||
for len(batch) < b.maxSize {
|
||||
select {
|
||||
case r := <-b.submit:
|
||||
if r.language != first.language {
|
||||
// Different language: carry it to the next batch so this
|
||||
// batch stays single-language, then dispatch what we have.
|
||||
carry = r
|
||||
break fill
|
||||
}
|
||||
batch = append(batch, r)
|
||||
case <-timer.C:
|
||||
break fill
|
||||
case <-stop:
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
// Don't strand a carried request's caller on shutdown.
|
||||
if carry != nil {
|
||||
b.runBatch([]*batchRequest{carry})
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,4 +105,60 @@ var _ = Describe("batcher", func() {
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
|
||||
It("never coalesces requests with different languages into one batch", func() {
|
||||
// parakeet.cpp's batched C-API takes ONE target_lang per batch, so the
|
||||
// dispatcher must keep every dispatched batch single-language. Submit a
|
||||
// mix of languages and assert (a) no batch ever carries more than one
|
||||
// distinct language and (b) every submitted request still gets a reply
|
||||
// (the mismatched carry-over is never dropped).
|
||||
var mu sync.Mutex
|
||||
var langsPerBatch [][]string
|
||||
run := func(reqs []*batchRequest) {
|
||||
seen := map[string]struct{}{}
|
||||
var distinct []string
|
||||
for _, r := range reqs {
|
||||
if _, ok := seen[r.language]; !ok {
|
||||
seen[r.language] = struct{}{}
|
||||
distinct = append(distinct, r.language)
|
||||
}
|
||||
}
|
||||
mu.Lock()
|
||||
langsPerBatch = append(langsPerBatch, distinct)
|
||||
mu.Unlock()
|
||||
echoReply(reqs)
|
||||
}
|
||||
// Large window + size so the fill loop stays open across submits and the
|
||||
// language constraint (not the timer) is what splits the batches.
|
||||
b := newBatcher(16, 200*time.Millisecond, run)
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
|
||||
langs := []string{"en", "en", "de", "de", "en", "fr", "fr"}
|
||||
const N = 7
|
||||
var wg sync.WaitGroup
|
||||
got := make([]string, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: string(rune('a' + i)), language: langs[i], reply: rep}
|
||||
got[i] = (<-rep).json
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
// Invariant: every dispatched batch is single-language.
|
||||
for _, distinct := range langsPerBatch {
|
||||
Expect(len(distinct)).To(Equal(1), "a batch coalesced more than one language: %v", distinct)
|
||||
}
|
||||
// Liveness: every request got a reply (carry-over never stranded).
|
||||
for i := 0; i < N; i++ {
|
||||
Expect(got[i]).To(Equal(string(rune('a' + i))))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -47,6 +48,13 @@ var (
|
||||
// side reads them as const float*/const int*.
|
||||
CppTranscribePcmBatchJSON func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32) uintptr
|
||||
|
||||
// CppTranscribePcmBatchJSONLang is the multilingual variant of the batched
|
||||
// JSON entry point: identical, plus a trailing target_lang. "" (the model
|
||||
// default, "auto") is passed for non-prompt models, which ignore it; an
|
||||
// unknown locale on a prompt model returns 0 and sets last_error. Present
|
||||
// only in newer libparakeet.so; nil falls back to CppTranscribePcmBatchJSON.
|
||||
CppTranscribePcmBatchJSONLang func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32, targetLang string) uintptr
|
||||
|
||||
// Cache-aware streaming (RNN-T) entry points. stream_begin returns 0 for
|
||||
// non-streaming models. feed/finalize return a malloc'd char* (uintptr,
|
||||
// freed via CppFreeString); feed writes 1 to *eouOut on an <EOU>/<EOB>.
|
||||
@@ -54,6 +62,18 @@ var (
|
||||
CppStreamFeed func(s uintptr, pcm []float32, nSamples int32, eouOut unsafe.Pointer) uintptr
|
||||
CppStreamFinalize func(s uintptr) uintptr
|
||||
CppStreamFree func(s uintptr)
|
||||
|
||||
// CppStreamBeginLang is the multilingual variant of stream_begin: identical,
|
||||
// plus a trailing target_lang ("" means the model default). Present only in
|
||||
// newer libparakeet.so; nil falls back to CppStreamBegin.
|
||||
CppStreamBeginLang func(ctx uintptr, targetLang string) uintptr
|
||||
|
||||
// Streaming JSON variants (ABI v4): feed/finalize returning a malloc'd char*
|
||||
// JSON document {text,eou,frame_sec,words} (uintptr, freed via CppFreeString)
|
||||
// so streaming segments can carry per-word timestamps. Present only in newer
|
||||
// libparakeet.so; nil falls back to the text-only CppStreamFeed/Finalize path.
|
||||
CppStreamFeedJSON func(s uintptr, pcm []float32, nSamples int32) uintptr
|
||||
CppStreamFinalizeJSON func(s uintptr) uintptr
|
||||
)
|
||||
|
||||
// streamChunkSamples is how much 16 kHz mono PCM we hand to stream_feed per
|
||||
@@ -71,9 +91,26 @@ const streamChunkSamples = 16000
|
||||
//
|
||||
// "start"/"end"/"t" are seconds; "conf" is confidence in (0,1].
|
||||
type transcriptJSON struct {
|
||||
Text string `json:"text"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
Tokens []transcriptToken `json:"tokens"`
|
||||
Text string `json:"text"`
|
||||
FrameSec float64 `json:"frame_sec"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
Tokens []transcriptToken `json:"tokens"`
|
||||
}
|
||||
|
||||
// streamFeedJSON mirrors the document returned by
|
||||
// parakeet_capi_stream_feed_json / parakeet_capi_stream_finalize_json (ABI v4):
|
||||
//
|
||||
// {"text":"...","eou":0,"frame_sec":0.080000,
|
||||
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...]}
|
||||
//
|
||||
// "text" is the newly-finalized text since the last call; "eou" is 1 when an
|
||||
// <EOU>/<EOB> fired this feed; "words" are the words finalized this call with
|
||||
// absolute (stream-relative) start/end seconds.
|
||||
type streamFeedJSON struct {
|
||||
Text string `json:"text"`
|
||||
Eou int `json:"eou"`
|
||||
FrameSec float64 `json:"frame_sec"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
}
|
||||
|
||||
type transcriptWord struct {
|
||||
@@ -102,6 +139,10 @@ type ParakeetCpp struct {
|
||||
engineMu sync.Mutex // sole guard of the one C engine (dispatcher + streaming)
|
||||
bat *batcher
|
||||
batStop chan struct{}
|
||||
// segmentGapFrames is NeMo's segment_gap_threshold in ENCODER FRAMES (model
|
||||
// YAML option, default 0=off). When >0 it adds NeMo's silence-gap split on
|
||||
// top of the punctuation split; converted to seconds via the JSON frame_sec.
|
||||
segmentGapFrames int
|
||||
}
|
||||
|
||||
// Load is the LocalAI gRPC entry point for LoadModel: it calls
|
||||
@@ -131,6 +172,11 @@ func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
|
||||
if maxWaitMs < 0 {
|
||||
maxWaitMs = 0
|
||||
}
|
||||
|
||||
// NeMo's segment_gap_threshold (encoder frames, default 0=off). Off by
|
||||
// default matches NeMo's default (punctuation-only segments); when set it
|
||||
// additionally splits segments on inter-word silence (see transcriptResultFromDoc).
|
||||
p.segmentGapFrames = optInt(opts, "segment_gap_threshold", 0)
|
||||
if CppTranscribePcmBatchJSON != nil {
|
||||
p.batStop = make(chan struct{})
|
||||
p.bat = newBatcher(maxSize, time.Duration(maxWaitMs)*time.Millisecond, p.runBatch)
|
||||
@@ -186,8 +232,19 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
if len(reqs) > 0 {
|
||||
dec = reqs[0].decoder
|
||||
}
|
||||
// All requests in a batch share one language (the batcher coalesces only
|
||||
// same-language requests), so any element's language describes the batch.
|
||||
lang := ""
|
||||
if len(reqs) > 0 {
|
||||
lang = reqs[0].language
|
||||
}
|
||||
p.engineMu.Lock()
|
||||
cstr := CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
|
||||
var cstr uintptr
|
||||
if CppTranscribePcmBatchJSONLang != nil {
|
||||
cstr = CppTranscribePcmBatchJSONLang(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec, lang)
|
||||
} else {
|
||||
cstr = CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
|
||||
}
|
||||
p.engineMu.Unlock()
|
||||
if cstr == 0 {
|
||||
err := fmt.Errorf("parakeet-cpp: batch transcribe failed: %s", CppLastError(p.ctxPtr))
|
||||
@@ -225,21 +282,31 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
// OpenAI API, whose default is segment-level); token ids always populate
|
||||
// Segment.Tokens.
|
||||
//
|
||||
// translate/diarize/prompt/temperature/language/threads are not applicable to
|
||||
// parakeet and are ignored; streaming is handled by AudioTranscriptionStream
|
||||
// translate/diarize/prompt/temperature/threads are not applicable to parakeet
|
||||
// and are ignored; language is honored on the batched + streaming paths (see
|
||||
// opts.GetLanguage() below); streaming is handled by AudioTranscriptionStream
|
||||
// (L2).
|
||||
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if p.ctxPtr == 0 {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: model not loaded")
|
||||
return pb.TranscriptResult{}, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
|
||||
// Fallback when the batched C-API is unavailable: transcribe directly from
|
||||
// the file path (original behavior, no batching).
|
||||
// Fallback when the batched C-API is unavailable: transcribe from a file
|
||||
// path (original behavior, no batching). The C library's audio loader only
|
||||
// understands 16 kHz mono WAV/PCM, so convert the input first - otherwise
|
||||
// any non-WAV upload (MP3, etc.) fails with "failed to load audio". This
|
||||
// mirrors what every other audio backend (whisper, crispasr) does via
|
||||
// utils.AudioToWav before handing the file to the engine.
|
||||
if p.bat == nil {
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, opts.Dst, 0)
|
||||
converted, cleanup, err := convertToWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer cleanup()
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, converted, 0)
|
||||
if cstr == 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", CppLastError(p.ctxPtr))
|
||||
}
|
||||
@@ -249,7 +316,7 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
|
||||
}
|
||||
|
||||
// Batched path: decode to PCM, submit to the batcher, wait for this request's
|
||||
@@ -261,7 +328,7 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
|
||||
}
|
||||
rep := make(chan batchReply, 1)
|
||||
select {
|
||||
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, reply: rep}:
|
||||
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, language: opts.GetLanguage(), reply: rep}:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
@@ -278,34 +345,169 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
|
||||
if err := json.Unmarshal([]byte(res.json), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
|
||||
}
|
||||
|
||||
// segmentSeparators is NeMo's default segment_seperators (sentence-ending
|
||||
// punctuation). Splitting on these matches NeMo's default segment timestamps.
|
||||
var segmentSeparators = []rune{'.', '?', '!'}
|
||||
|
||||
// transcriptResultFromDoc maps a decoded transcriptJSON to a TranscriptResult,
|
||||
// synthesising a single whole-clip segment and attaching word timings only when
|
||||
// the caller requested word granularity. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest) pb.TranscriptResult {
|
||||
// grouping words into NeMo-faithful segments (see splitWordsIntoSegments). The
|
||||
// optional gapFrames (NeMo's segment_gap_threshold, in encoder FRAMES; 0=off)
|
||||
// additionally splits on inter-word silence; it is converted to a seconds gap
|
||||
// with the document's frame_sec. Per-segment word timings are attached only when
|
||||
// the caller requested word granularity; token ids populate each segment's
|
||||
// Tokens by time-window membership. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gapFrames int) pb.TranscriptResult {
|
||||
text := strings.TrimSpace(doc.Text)
|
||||
words := make([]*pb.TranscriptWord, 0, len(doc.Words))
|
||||
for _, w := range doc.Words {
|
||||
words = append(words, &pb.TranscriptWord{Start: secondsToNanos(w.Start), End: secondsToNanos(w.End), Text: w.W})
|
||||
|
||||
// Frame-unit gap threshold -> seconds (NeMo segment_gap_threshold). 0 = off.
|
||||
gapSeconds := 0.0
|
||||
if gapFrames > 0 {
|
||||
if doc.FrameSec > 0 {
|
||||
gapSeconds = float64(gapFrames) * doc.FrameSec
|
||||
} else {
|
||||
xlog.Warn("parakeet-cpp: segment_gap_threshold set but libparakeet.so " +
|
||||
"did not report frame_sec; falling back to punctuation-only segments")
|
||||
}
|
||||
}
|
||||
tokens := make([]int32, 0, len(doc.Tokens))
|
||||
for _, t := range doc.Tokens {
|
||||
tokens = append(tokens, t.ID)
|
||||
|
||||
groups := splitWordsIntoSegments(doc.Words, segmentSeparators, gapSeconds)
|
||||
if len(groups) == 0 {
|
||||
// No words (edge case): single whole-clip text segment.
|
||||
return pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: []*pb.TranscriptSegment{{Id: 0, Text: text}},
|
||||
}
|
||||
}
|
||||
var segStart, segEnd int64
|
||||
if len(words) > 0 {
|
||||
segStart = words[0].Start
|
||||
segEnd = words[len(words)-1].End
|
||||
|
||||
wantWords := wordsRequested(opts.TimestampGranularities)
|
||||
segments := make([]*pb.TranscriptSegment, 0, len(groups))
|
||||
for id, group := range groups {
|
||||
parts := make([]string, len(group))
|
||||
for i, gw := range group {
|
||||
parts[i] = gw.W
|
||||
}
|
||||
seg := &pb.TranscriptSegment{
|
||||
Id: int32(id),
|
||||
Start: secondsToNanos(group[0].Start),
|
||||
End: secondsToNanos(group[len(group)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
Tokens: tokensInWindow(doc.Tokens, group[0].Start, group[len(group)-1].End),
|
||||
}
|
||||
if wantWords {
|
||||
ws := make([]*pb.TranscriptWord, len(group))
|
||||
for i, gw := range group {
|
||||
ws[i] = &pb.TranscriptWord{Start: secondsToNanos(gw.Start), End: secondsToNanos(gw.End), Text: gw.W}
|
||||
}
|
||||
seg.Words = ws
|
||||
}
|
||||
segments = append(segments, seg)
|
||||
}
|
||||
seg := &pb.TranscriptSegment{Id: 0, Start: segStart, End: segEnd, Text: text, Tokens: tokens}
|
||||
if wordsRequested(opts.TimestampGranularities) {
|
||||
seg.Words = words
|
||||
}
|
||||
return pb.TranscriptResult{Text: text, Segments: []*pb.TranscriptSegment{seg}}
|
||||
return pb.TranscriptResult{Text: text, Segments: segments}
|
||||
}
|
||||
|
||||
// splitWordsIntoSegments groups words into segments exactly as NeMo's
|
||||
// get_segment_offsets does (nemo/collections/asr/parts/utils/timestamp_utils.py).
|
||||
// Walking the words, it closes a segment when (1) the gap rule is enabled
|
||||
// (gapSeconds > 0) and the segment already has words and the gap from the
|
||||
// previous word's end to this word's start is >= gapSeconds - the current word
|
||||
// then STARTS a new segment - or, checked only when the gap rule did not apply
|
||||
// (NeMo's elif), (2) the word ends with (or is) a separator, which closes the
|
||||
// segment INCLUDING that word. Trailing words flush into a final segment.
|
||||
// gapSeconds <= 0 disables the gap rule, matching NeMo's default
|
||||
// segment_gap_threshold=None (punctuation-only segments).
|
||||
func splitWordsIntoSegments(words []transcriptWord, separators []rune, gapSeconds float64) [][]transcriptWord {
|
||||
var segments [][]transcriptWord
|
||||
var cur []transcriptWord
|
||||
for i, word := range words {
|
||||
gapActive := gapSeconds > 0 && len(cur) > 0
|
||||
if gapActive && (word.Start-words[i-1].End) >= gapSeconds {
|
||||
segments = append(segments, cur)
|
||||
cur = []transcriptWord{word}
|
||||
continue
|
||||
}
|
||||
if !gapActive && endsWithSeparator(word.W, separators) {
|
||||
cur = append(cur, word)
|
||||
segments = append(segments, cur)
|
||||
cur = nil
|
||||
continue
|
||||
}
|
||||
cur = append(cur, word)
|
||||
}
|
||||
if len(cur) > 0 {
|
||||
segments = append(segments, cur)
|
||||
}
|
||||
return segments
|
||||
}
|
||||
|
||||
// endsWithSeparator reports whether w's last rune is in separators (matching
|
||||
// NeMo's `word[-1] in delims or word in delims`).
|
||||
func endsWithSeparator(w string, separators []rune) bool {
|
||||
r := []rune(strings.TrimSpace(w))
|
||||
if len(r) == 0 {
|
||||
return false
|
||||
}
|
||||
last := r[len(r)-1]
|
||||
for _, s := range separators {
|
||||
if last == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// tokensInWindow returns the ids of tokens whose timestamp t falls in
|
||||
// [start, end] (inclusive), assigning each token to the segment that spans its
|
||||
// time. The last segment's end is the last word end, so the final token is
|
||||
// included.
|
||||
func tokensInWindow(tokens []transcriptToken, start, end float64) []int32 {
|
||||
var ids []int32
|
||||
for _, t := range tokens {
|
||||
if t.T >= start && t.T <= end {
|
||||
ids = append(ids, t.ID)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// streamSegmenter accumulates streaming words into per-utterance segments. EOU
|
||||
// is the model's own utterance boundary; each closed segment takes its start/end
|
||||
// from its first/last accumulated word.
|
||||
type streamSegmenter struct {
|
||||
segs []*pb.TranscriptSegment
|
||||
cur []transcriptWord
|
||||
nextID int32
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) add(doc streamFeedJSON) {
|
||||
s.cur = append(s.cur, doc.Words...)
|
||||
if doc.Eou != 0 {
|
||||
s.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) flush() {
|
||||
if len(s.cur) == 0 {
|
||||
return
|
||||
}
|
||||
parts := make([]string, len(s.cur))
|
||||
for i, w := range s.cur {
|
||||
parts[i] = w.W
|
||||
}
|
||||
s.segs = append(s.segs, &pb.TranscriptSegment{
|
||||
Id: s.nextID,
|
||||
Start: secondsToNanos(s.cur[0].Start),
|
||||
End: secondsToNanos(s.cur[len(s.cur)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
})
|
||||
s.nextID++
|
||||
s.cur = nil
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) segments() []*pb.TranscriptSegment { return s.segs }
|
||||
|
||||
// wordsRequested reports whether the caller asked for word-level timestamps.
|
||||
// The OpenAI transcription API gates word timings behind
|
||||
// timestamp_granularities[] containing "word" and defaults to segment-level
|
||||
@@ -342,7 +544,7 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
defer close(results)
|
||||
|
||||
if p.ctxPtr == 0 {
|
||||
return errors.New("parakeet-cpp: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
@@ -351,7 +553,12 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
stream := CppStreamBegin(p.ctxPtr)
|
||||
var stream uintptr
|
||||
if CppStreamBeginLang != nil {
|
||||
stream = CppStreamBeginLang(p.ctxPtr, opts.GetLanguage())
|
||||
} else {
|
||||
stream = CppStreamBegin(p.ctxPtr)
|
||||
}
|
||||
if stream == 0 {
|
||||
// Not a cache-aware streaming model: run a normal offline
|
||||
// transcription and emit it as one delta + a closing final result.
|
||||
@@ -380,6 +587,14 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
return err
|
||||
}
|
||||
|
||||
// ABI v4: when the streaming JSON entry points are present, drive them so the
|
||||
// per-utterance segments carry per-word start/end timestamps. Falls through to
|
||||
// the text-only loop below against an older libparakeet.so. Runs under the
|
||||
// engineMu already held above.
|
||||
if CppStreamFeedJSON != nil {
|
||||
return p.streamJSON(ctx, stream, data, duration, results)
|
||||
}
|
||||
|
||||
var (
|
||||
full strings.Builder
|
||||
segText strings.Builder
|
||||
@@ -456,21 +671,102 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamJSON drives the ABI v4 streaming JSON entry points: each feed/finalize
|
||||
// returns a {text,eou,frame_sec,words} document. The newly-finalized text is
|
||||
// emitted as a delta (unchanged streaming contract) while words are accumulated
|
||||
// into per-utterance segments (closed on EOU) so the closing FinalResult carries
|
||||
// timestamped segments. Runs under engineMu (already held by the caller).
|
||||
func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []float32,
|
||||
duration float32, results chan *pb.TranscriptStreamResponse) error {
|
||||
var (
|
||||
full strings.Builder
|
||||
seg streamSegmenter
|
||||
)
|
||||
// consume frees the malloc'd char* (a 0 return is an error), parses the JSON,
|
||||
// emits the delta, and routes words through the segmenter.
|
||||
consume := func(ret uintptr) error {
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
raw := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
var doc streamFeedJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return fmt.Errorf("parakeet-cpp: decode stream json: %w", err)
|
||||
}
|
||||
if doc.Text != "" {
|
||||
full.WriteString(doc.Text)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: doc.Text}
|
||||
}
|
||||
seg.add(doc)
|
||||
return nil
|
||||
}
|
||||
|
||||
for off := 0; off < len(data); off += streamChunkSamples {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(data))
|
||||
chunk := data[off:end]
|
||||
if err := consume(CppStreamFeedJSON(stream, chunk, int32(len(chunk)))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := consume(CppStreamFinalizeJSON(stream)); err != nil {
|
||||
return err
|
||||
}
|
||||
seg.flush() // close any trailing utterance that never saw an EOU
|
||||
|
||||
text := strings.TrimSpace(full.String())
|
||||
segments := seg.segments()
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{
|
||||
FinalResult: &pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: segments,
|
||||
Duration: duration,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeWavMono16k converts any input audio to 16 kHz mono PCM and returns the
|
||||
// float samples plus the clip duration in seconds. Mirrors the whisper
|
||||
// backend: utils.AudioToWav (ffmpeg) normalises rate/channels, go-audio
|
||||
// decodes the PCM.
|
||||
func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
// convertToWavMono16k converts an arbitrary audio file to a 16 kHz mono WAV in
|
||||
// a fresh temp dir and returns the path together with a cleanup func the caller
|
||||
// must defer. WAV inputs already at 16 kHz/mono/16-bit are passed through by
|
||||
// utils.AudioToWav (hardlink/copy), everything else is transcoded via ffmpeg.
|
||||
// Used by the direct (non-batched) transcription path, which hands a file path
|
||||
// to the C library's WAV-only audio loader.
|
||||
func convertToWavMono16k(path string) (string, func(), error) {
|
||||
dir, err := os.MkdirTemp("", "parakeet")
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return "", func() {}, err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
cleanup := func() { _ = os.RemoveAll(dir) }
|
||||
|
||||
converted := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(path, converted); err != nil {
|
||||
cleanup()
|
||||
return "", func() {}, err
|
||||
}
|
||||
return converted, cleanup, nil
|
||||
}
|
||||
|
||||
func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
converted, cleanup, err := convertToWavMono16k(path)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
fh, err := os.Open(converted)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,11 +3,14 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -50,6 +53,10 @@ func ensureLibLoaded() {
|
||||
purego.RegisterLibFunc(&CppStreamFeed, lib, "parakeet_capi_stream_feed")
|
||||
purego.RegisterLibFunc(&CppStreamFinalize, lib, "parakeet_capi_stream_finalize")
|
||||
purego.RegisterLibFunc(&CppStreamFree, lib, "parakeet_capi_stream_free")
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
|
||||
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
|
||||
}
|
||||
purego.RegisterLibFunc(&CppFreeString, lib, "parakeet_capi_free_string")
|
||||
purego.RegisterLibFunc(&CppLastError, lib, "parakeet_capi_last_error")
|
||||
})
|
||||
@@ -70,6 +77,24 @@ func fixturesOrSkip() (string, string) {
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
// writeMono16kWav writes `samples` frames of 16 kHz mono 16-bit silence to
|
||||
// path. The result is already in AudioToWav's target format, so the conversion
|
||||
// helper copies it through without invoking ffmpeg.
|
||||
func writeMono16kWav(path string, samples int) {
|
||||
GinkgoHelper()
|
||||
f, err := os.Create(path)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
enc := wav.NewEncoder(f, 16000, 16, 1, 1)
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: 16000},
|
||||
SourceBitDepth: 16,
|
||||
Data: make([]int, samples),
|
||||
}
|
||||
Expect(enc.Write(buf)).To(Succeed())
|
||||
Expect(enc.Close()).To(Succeed())
|
||||
Expect(f.Close()).To(Succeed())
|
||||
}
|
||||
|
||||
var _ = Describe("ParakeetCpp", func() {
|
||||
Context("AudioTranscription", func() {
|
||||
It("transcribes a WAV via the parakeet C-API", func() {
|
||||
@@ -86,13 +111,22 @@ var _ = Describe("ParakeetCpp", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(strings.TrimSpace(res.Text)).ToNot(BeEmpty(),
|
||||
"expected non-empty transcript for %s", audioPath)
|
||||
Expect(res.Segments).To(HaveLen(1),
|
||||
"synthesises a single whole-clip segment")
|
||||
Expect(res.Segments[0].Text).To(Equal(res.Text),
|
||||
"single segment text must equal the top-level text")
|
||||
// Default (no granularities) is segment-level: no per-word timings.
|
||||
Expect(res.Segments[0].Words).To(BeEmpty(),
|
||||
"word timings are opt-in via timestamp_granularities")
|
||||
// NeMo-faithful segmentation: one or more punctuation-delimited
|
||||
// segments, each with text and a monotonically-advancing time span.
|
||||
Expect(res.Segments).ToNot(BeEmpty(), "expected at least one segment")
|
||||
var prevEnd int64
|
||||
for i, seg := range res.Segments {
|
||||
Expect(strings.TrimSpace(seg.Text)).ToNot(BeEmpty(),
|
||||
"segment %d must have text", i)
|
||||
Expect(seg.End).To(BeNumerically(">=", seg.Start),
|
||||
"segment %d end must not precede its start", i)
|
||||
Expect(seg.Start).To(BeNumerically(">=", prevEnd),
|
||||
"segments must be in time order")
|
||||
prevEnd = seg.End
|
||||
// Default (no granularities) is segment-level: no per-word timings.
|
||||
Expect(seg.Words).To(BeEmpty(),
|
||||
"word timings are opt-in via timestamp_granularities")
|
||||
}
|
||||
})
|
||||
|
||||
It("emits word-level timestamps when granularity=word", func() {
|
||||
@@ -108,15 +142,61 @@ var _ = Describe("ParakeetCpp", func() {
|
||||
TimestampGranularities: []string{"word"},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.Segments).To(HaveLen(1))
|
||||
seg := res.Segments[0]
|
||||
Expect(seg.Words).ToNot(BeEmpty(),
|
||||
"expected per-word timestamps with granularity=word")
|
||||
// Monotonic, non-negative timings spanning the segment.
|
||||
Expect(seg.Words[0].Start).To(BeNumerically(">=", int64(0)))
|
||||
Expect(seg.End).To(BeNumerically(">=", seg.Start))
|
||||
Expect(seg.Words[len(seg.Words)-1].End).To(Equal(seg.End),
|
||||
"segment end tracks the last word")
|
||||
Expect(res.Segments).ToNot(BeEmpty())
|
||||
// With word granularity every segment carries its own words, and each
|
||||
// segment's span tracks its first/last word; word starts advance
|
||||
// monotonically across the whole transcript.
|
||||
totalWords := 0
|
||||
var prevStart int64 = -1
|
||||
for i, seg := range res.Segments {
|
||||
Expect(seg.Words).ToNot(BeEmpty(),
|
||||
"segment %d must carry per-word timestamps with granularity=word", i)
|
||||
Expect(seg.Start).To(Equal(seg.Words[0].Start),
|
||||
"segment %d start tracks its first word", i)
|
||||
Expect(seg.End).To(Equal(seg.Words[len(seg.Words)-1].End),
|
||||
"segment %d end tracks its last word", i)
|
||||
for _, w := range seg.Words {
|
||||
Expect(w.End).To(BeNumerically(">=", w.Start))
|
||||
Expect(w.Start).To(BeNumerically(">=", prevStart))
|
||||
prevStart = w.Start
|
||||
totalWords++
|
||||
}
|
||||
}
|
||||
Expect(totalWords).To(BeNumerically(">", 0))
|
||||
Expect(res.Segments[0].Words[0].Start).To(BeNumerically(">=", int64(0)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("convertToWavMono16k", func() {
|
||||
// The non-batched transcription path hands a file path to the C
|
||||
// library's WAV-only audio loader, so it must convert first.
|
||||
// utils.AudioToWav passes an already-16kHz/mono/16-bit WAV through
|
||||
// without ffmpeg, which lets us exercise the helper (and the
|
||||
// regression: the direct path used to skip conversion entirely)
|
||||
// without a model, the C library, or ffmpeg.
|
||||
It("returns a decodable 16kHz mono WAV copy and cleans it up", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
src := filepath.Join(dir, "input.wav")
|
||||
writeMono16kWav(src, 16000) // 1s of silence at 16 kHz
|
||||
|
||||
converted, cleanup, err := convertToWavMono16k(src)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// It must produce a fresh temp file, not return the original path.
|
||||
Expect(converted).ToNot(Equal(src))
|
||||
Expect(converted).To(BeAnExistingFile())
|
||||
|
||||
pcm, _, err := decodeWavMono16k(converted)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(pcm).To(HaveLen(16000), "round-trips the sample count")
|
||||
|
||||
cleanup()
|
||||
Expect(converted).ToNot(BeAnExistingFile(), "cleanup removes the temp dir")
|
||||
})
|
||||
|
||||
It("errors on a non-existent input rather than passing the path through", func() {
|
||||
_, _, err := convertToWavMono16k(filepath.Join(GinkgoT().TempDir(), "missing.mp3"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -65,6 +65,25 @@ func main() {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
|
||||
// Per-request language variants (multilingual nemotron). Same probe pattern:
|
||||
// present only in libparakeet.so built with multilingual support, so the
|
||||
// backend still loads against an older library and falls back to the
|
||||
// non-lang batched + streaming entry points (model default / "auto").
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json_lang"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSONLang, lib, "parakeet_capi_transcribe_pcm_batch_json_lang")
|
||||
}
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_begin_lang"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamBeginLang, lib, "parakeet_capi_stream_begin_lang")
|
||||
}
|
||||
|
||||
// Streaming JSON entry points (ABI v4): surface per-word timestamps on the
|
||||
// streaming path. Same probe pattern; absent in older libparakeet.so, where
|
||||
// the backend falls back to the text-only streaming feed.
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
|
||||
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[parakeet-cpp] ABI=%d\n", CppAbiVersion())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
127
backend/go/parakeet-cpp/segments_test.go
Normal file
127
backend/go/parakeet-cpp/segments_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func tw(text string, start, end float64) transcriptWord {
|
||||
return transcriptWord{W: text, Start: start, End: end}
|
||||
}
|
||||
|
||||
var _ = Describe("splitWordsIntoSegments (NeMo get_segment_offsets parity)", func() {
|
||||
seps := []rune{'.', '?', '!'}
|
||||
|
||||
It("splits on sentence-ending punctuation, including the delimiter word", func() {
|
||||
words := []transcriptWord{tw("hello", 0, 0.4), tw("world.", 0.4, 0.8), tw("bye", 1.0, 1.3)}
|
||||
segs := splitWordsIntoSegments(words, seps, 0)
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2))
|
||||
Expect(segs[0][1].W).To(Equal("world."))
|
||||
Expect(segs[1]).To(HaveLen(1))
|
||||
Expect(segs[1][0].W).To(Equal("bye"))
|
||||
})
|
||||
|
||||
It("keeps a single segment with no terminal punctuation and gap off", func() {
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 0)
|
||||
Expect(segs).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("splits on the gap rule when enabled, the gapped word starting the next segment", func() {
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 1.0) // c is 4.6s after b
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2)) // a b
|
||||
Expect(segs[1]).To(HaveLen(1)) // c
|
||||
Expect(segs[1][0].W).To(Equal("c"))
|
||||
})
|
||||
|
||||
It("checks the gap rule before punctuation (NeMo elif order)", func() {
|
||||
// "b." would terminate, but c is far after it -> gap closes [a b.] at b.
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b.", 0.2, 0.4), tw("c", 9.0, 9.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 1.0)
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2))
|
||||
Expect(segs[1][0].W).To(Equal("c"))
|
||||
})
|
||||
|
||||
It("still splits on punctuation when the gap rule is enabled but does not fire", func() {
|
||||
words := []transcriptWord{tw("hi.", 0, 0.4), tw("bye", 0.4, 0.8)}
|
||||
segs := splitWordsIntoSegments(words, seps, 5.0) // gap never reached
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0][0].W).To(Equal("hi."))
|
||||
})
|
||||
|
||||
It("returns nothing for empty input", func() {
|
||||
Expect(splitWordsIntoSegments(nil, seps, 0)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("transcriptResultFromDoc (multi-segment)", func() {
|
||||
doc := transcriptJSON{
|
||||
Text: "hello world. bye now",
|
||||
FrameSec: 0.08,
|
||||
Words: []transcriptWord{
|
||||
{W: "hello", Start: 0.0, End: 0.4},
|
||||
{W: "world.", Start: 0.4, End: 0.8},
|
||||
{W: "bye", Start: 1.0, End: 1.3},
|
||||
{W: "now", Start: 1.3, End: 1.6},
|
||||
},
|
||||
Tokens: []transcriptToken{{ID: 1, T: 0.1}, {ID: 2, T: 0.5}, {ID: 3, T: 1.1}, {ID: 4, T: 1.4}},
|
||||
}
|
||||
|
||||
It("emits one segment per punctuation-delimited group with start/end", func() {
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments).To(HaveLen(2))
|
||||
Expect(res.Segments[0].Text).To(Equal("hello world."))
|
||||
Expect(res.Segments[0].Start).To(Equal(int64(0)))
|
||||
Expect(res.Segments[0].End).To(Equal(secondsToNanos(0.8)))
|
||||
Expect(res.Segments[1].Text).To(Equal("bye now"))
|
||||
Expect(res.Segments[1].Start).To(Equal(secondsToNanos(1.0)))
|
||||
Expect(res.Segments[1].Id).To(Equal(int32(1)))
|
||||
})
|
||||
|
||||
It("assigns tokens to the segment whose time window contains them", func() {
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments[0].Tokens).To(Equal([]int32{1, 2}))
|
||||
Expect(res.Segments[1].Tokens).To(Equal([]int32{3, 4}))
|
||||
})
|
||||
|
||||
It("attaches per-segment words only when word granularity requested", func() {
|
||||
plain := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(plain.Segments[0].Words).To(BeEmpty())
|
||||
withWords := transcriptResultFromDoc(doc, &pb.TranscriptRequest{TimestampGranularities: []string{"word"}}, 0)
|
||||
Expect(withWords.Segments[0].Words).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("falls back to a single text segment when there are no words", func() {
|
||||
res := transcriptResultFromDoc(transcriptJSON{Text: "hi"}, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments).To(HaveLen(1))
|
||||
Expect(res.Segments[0].Text).To(Equal("hi"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("streaming segment assembly", func() {
|
||||
It("closes a segment with start/end from its words on EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hello world", Eou: 1, Words: []transcriptWord{
|
||||
{W: "hello", Start: 0.0, End: 0.4}, {W: "world", Start: 0.4, End: 0.9},
|
||||
}})
|
||||
segs := acc.segments()
|
||||
Expect(segs).To(HaveLen(1))
|
||||
Expect(segs[0].Text).To(Equal("hello world"))
|
||||
Expect(segs[0].Start).To(Equal(int64(0)))
|
||||
Expect(segs[0].End).To(Equal(secondsToNanos(0.9)))
|
||||
})
|
||||
|
||||
It("buffers words across feeds until EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hi", Eou: 0, Words: []transcriptWord{{W: "hi", Start: 0, End: 0.3}}})
|
||||
Expect(acc.segments()).To(BeEmpty())
|
||||
acc.add(streamFeedJSON{Text: "there", Eou: 1, Words: []transcriptWord{{W: "there", Start: 0.3, End: 0.7}}})
|
||||
Expect(acc.segments()).To(HaveLen(1))
|
||||
Expect(acc.segments()[0].Text).To(Equal("hi there"))
|
||||
})
|
||||
})
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# qwen3-tts.cpp version
|
||||
QWEN3TTS_REPO?=https://github.com/predict-woo/qwen3-tts.cpp
|
||||
QWEN3TTS_CPP_VERSION?=7a762e2ad4bacc6fdda81d81bf10a09ffb546f29
|
||||
QWEN3TTS_CPP_VERSION?=136e5d36c17083da0321fd96512dc7b263f94a44
|
||||
SO_TARGET?=libgoqwen3ttscpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -21,6 +22,43 @@ type Qwen3TtsCpp struct {
|
||||
threads int
|
||||
}
|
||||
|
||||
// languageNameAliases maps common full language names to the canonical
|
||||
// two-letter code understood by the C++ language_to_id table.
|
||||
var languageNameAliases = map[string]string{
|
||||
"english": "en",
|
||||
"russian": "ru",
|
||||
"chinese": "zh",
|
||||
"japanese": "ja",
|
||||
"korean": "ko",
|
||||
"german": "de",
|
||||
"french": "fr",
|
||||
"spanish": "es",
|
||||
"italian": "it",
|
||||
"portuguese": "pt",
|
||||
}
|
||||
|
||||
// normalizeLanguage coerces a caller-supplied language into the canonical code
|
||||
// the model expects. It lowercases, trims, strips any region/locale suffix
|
||||
// (en-US, en_US, ja.JP -> en/ja), and resolves common full names (english -> en).
|
||||
// An empty input stays empty so the C++ side applies its English default; an
|
||||
// unrecognized value is returned normalized so C++ can log it and default.
|
||||
func normalizeLanguage(lang string) string {
|
||||
lang = strings.ToLower(strings.TrimSpace(lang))
|
||||
if lang == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strip region/locale suffix: keep the segment before the first separator.
|
||||
if i := strings.IndexAny(lang, "-_."); i >= 0 {
|
||||
lang = lang[:i]
|
||||
}
|
||||
|
||||
if code, ok := languageNameAliases[lang]; ok {
|
||||
return code
|
||||
}
|
||||
return lang
|
||||
}
|
||||
|
||||
func (q *Qwen3TtsCpp) Load(opts *pb.ModelOptions) error {
|
||||
// ModelFile is the model directory path (containing GGUF files)
|
||||
modelDir := opts.ModelFile
|
||||
@@ -54,7 +92,7 @@ func (q *Qwen3TtsCpp) TTS(req *pb.TTSRequest) error {
|
||||
dst := req.Dst
|
||||
language := ""
|
||||
if req.Language != nil {
|
||||
language = *req.Language
|
||||
language = normalizeLanguage(*req.Language)
|
||||
}
|
||||
|
||||
// Synthesis parameters with sensible defaults
|
||||
|
||||
53
backend/go/qwen3-tts-cpp/language_test.go
Normal file
53
backend/go/qwen3-tts-cpp/language_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestLanguageNormalization(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "qwen3-tts-cpp language normalization")
|
||||
}
|
||||
|
||||
var _ = Describe("normalizeLanguage", func() {
|
||||
DescribeTable("maps caller input to the canonical model language code",
|
||||
func(input, expected string) {
|
||||
Expect(normalizeLanguage(input)).To(Equal(expected))
|
||||
},
|
||||
// Canonical codes pass through unchanged
|
||||
Entry("canonical en", "en", "en"),
|
||||
Entry("canonical zh", "zh", "zh"),
|
||||
Entry("canonical pt", "pt", "pt"),
|
||||
|
||||
// Case-insensitive
|
||||
Entry("uppercase", "EN", "en"),
|
||||
Entry("mixed case", "Ja", "ja"),
|
||||
|
||||
// Surrounding whitespace
|
||||
Entry("trims whitespace", " en ", "en"),
|
||||
|
||||
// Region/locale stripping
|
||||
Entry("BCP-47 region", "en-US", "en"),
|
||||
Entry("underscore region", "en_US", "en"),
|
||||
Entry("dotted locale", "ja.JP", "ja"),
|
||||
Entry("region + case", "ZH-CN", "zh"),
|
||||
|
||||
// Full-name aliases
|
||||
Entry("english name", "english", "en"),
|
||||
Entry("chinese name cased", "Chinese", "zh"),
|
||||
Entry("japanese name", "japanese", "ja"),
|
||||
Entry("russian name", "russian", "ru"),
|
||||
Entry("portuguese name", "portuguese", "pt"),
|
||||
|
||||
// Empty stays empty (C++ applies the English default)
|
||||
Entry("empty", "", ""),
|
||||
Entry("whitespace only", " ", ""),
|
||||
|
||||
// Unknown values pass through normalized so C++ can log + default
|
||||
Entry("unknown code", "klingon", "klingon"),
|
||||
Entry("unknown with region", "xx-YY", "xx"),
|
||||
)
|
||||
})
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=7948df8ac1070f5f6881b8d34675821893eb97d6
|
||||
STABLEDIFFUSION_GGML_VERSION?=b9254dda0d10b91ee6f17fb7f4420097dd29824b
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -386,6 +386,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
const char *llm_vision_path = "";
|
||||
const char *diffusion_model_path = stableDiffusionModel;
|
||||
const char *high_noise_diffusion_model_path = "";
|
||||
const char *uncond_diffusion_model_path = "";
|
||||
const char *taesd_path = "";
|
||||
const char *control_net_path = "";
|
||||
const char *embedding_dir = "";
|
||||
@@ -472,6 +473,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
if (!strcmp(optname, "llm_vision_path")) llm_vision_path = strdup(optval);
|
||||
if (!strcmp(optname, "diffusion_model_path")) diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "high_noise_diffusion_model_path")) high_noise_diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "uncond_diffusion_model_path")) uncond_diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "taesd_path")) taesd_path = strdup(optval);
|
||||
if (!strcmp(optname, "control_net_path")) control_net_path = strdup(optval);
|
||||
if (!strcmp(optname, "embedding_dir")) {
|
||||
@@ -571,6 +573,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
ctx_params.llm_vision_path = llm_vision_path;
|
||||
ctx_params.diffusion_model_path = diffusion_model_path;
|
||||
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
|
||||
ctx_params.uncond_diffusion_model_path = uncond_diffusion_model_path;
|
||||
ctx_params.vae_path = vae_path;
|
||||
ctx_params.audio_vae_path = audio_vae_path;
|
||||
ctx_params.embeddings_connectors_path = embeddings_connectors_path;
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=23ee03506a91ac3d3f0071b40e66a430eebdfa1d
|
||||
WHISPER_CPP_VERSION?=a8ec021f2750a473ff4a8f3883bc9fdf5feafa84
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -37,6 +37,20 @@ def is_int(s):
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def coerce_param_value(value):
|
||||
"""Coerce a TTSRequest.params value (string on the wire) to the type the
|
||||
Chatterbox generate() kwargs expect (float/int/bool), matching how static
|
||||
YAML options are coerced at load time. Non-string values pass through."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
if is_float(value):
|
||||
return float(value)
|
||||
if is_int(value):
|
||||
return int(value)
|
||||
if value.lower() in ["true", "false"]:
|
||||
return value.lower() == "true"
|
||||
return value
|
||||
|
||||
def split_text_at_word_boundary(text, max_length=250):
|
||||
"""
|
||||
Split text at word boundaries without truncating words.
|
||||
@@ -191,6 +205,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# add options to kwargs
|
||||
kwargs.update(self.options)
|
||||
|
||||
# Merge per-request params (TTSRequest.params), overriding the static
|
||||
# YAML options. This exposes Chatterbox generation knobs (e.g.
|
||||
# exaggeration, cfg_weight, temperature) per request. Values arrive as
|
||||
# strings on the wire and are coerced to float/int/bool.
|
||||
if hasattr(request, "params") and request.params:
|
||||
for key, value in request.params.items():
|
||||
kwargs[key] = coerce_param_value(value)
|
||||
|
||||
# Check if text exceeds 250 characters
|
||||
# (chatterbox does not support long text)
|
||||
# https://github.com/resemble-ai/chatterbox/issues/60
|
||||
|
||||
@@ -47,6 +47,26 @@ def is_int(s):
|
||||
return False
|
||||
|
||||
|
||||
def coerce_param_value(value):
|
||||
"""Coerce a string param value (from the TTSRequest.params map, which is
|
||||
string-typed on the wire) into the most specific Python type the model
|
||||
generation kwargs expect: bool, int, float, else the original string."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
lowered = value.strip().lower()
|
||||
if lowered in ("true", "false"):
|
||||
return lowered == "true"
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
return value
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
@@ -322,6 +342,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def _effective_instruct(self, request):
|
||||
"""Resolve the instruction/style string for this request, preferring the
|
||||
per-request TTSRequest.instructions value and falling back to the static
|
||||
YAML `instruct` option. Empty string means "no instruction"."""
|
||||
req_instruct = (
|
||||
request.instructions
|
||||
if hasattr(request, "instructions") and request.instructions
|
||||
else ""
|
||||
)
|
||||
if req_instruct:
|
||||
return req_instruct
|
||||
return self.options.get("instruct", "") or ""
|
||||
|
||||
def _detect_mode(self, request):
|
||||
"""Detect which mode to use based on request parameters."""
|
||||
# Priority: VoiceClone > VoiceDesign > CustomVoice
|
||||
@@ -338,8 +371,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if self.audio_path or self.voices:
|
||||
return "VoiceClone"
|
||||
|
||||
# VoiceDesign: instruct option is provided
|
||||
if "instruct" in self.options and self.options["instruct"]:
|
||||
# VoiceDesign: instruct provided per-request or via YAML option
|
||||
if self._effective_instruct(request):
|
||||
return "VoiceDesign"
|
||||
|
||||
# Default to CustomVoice
|
||||
@@ -690,10 +723,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if do_sample is not None:
|
||||
generation_kwargs["do_sample"] = do_sample
|
||||
|
||||
instruct = self.options.get("instruct", "")
|
||||
# Prefer the per-request instruction (TTSRequest.instructions) over the
|
||||
# static YAML `instruct` option. This lets clients set a different style
|
||||
# (CustomVoice emotion) or designed voice (VoiceDesign) per request.
|
||||
instruct = self._effective_instruct(request)
|
||||
if instruct is not None and instruct != "":
|
||||
generation_kwargs["instruct"] = instruct
|
||||
|
||||
# Merge any per-request backend-specific params (TTSRequest.params).
|
||||
# Values arrive as strings on the wire; coerce to int/float/bool so the
|
||||
# model receives the types it expects. These override YAML-derived kwargs.
|
||||
if hasattr(request, "params") and request.params:
|
||||
for key, value in request.params.items():
|
||||
generation_kwargs[key] = coerce_param_value(value)
|
||||
|
||||
# Generate audio based on mode
|
||||
if mode == "VoiceClone":
|
||||
# VoiceClone mode
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.81.0
|
||||
protobuf==7.35.0
|
||||
certifi
|
||||
setuptools
|
||||
|
||||
@@ -3,5 +3,5 @@
|
||||
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
|
||||
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
|
||||
# so uv consults this index alongside PyPI.
|
||||
--extra-index-url https://wheels.vllm.ai/0.22.0/cu130
|
||||
vllm==0.22.0
|
||||
--extra-index-url https://wheels.vllm.ai/0.22.1/cu130
|
||||
vllm==0.22.1
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.81.0
|
||||
protobuf
|
||||
certifi
|
||||
setuptools
|
||||
|
||||
@@ -102,7 +102,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)
|
||||
|
||||
// Connect to NATS
|
||||
natsClient, err := messaging.New(cfg.Distributed.NatsURL)
|
||||
natsAuth := cfg.Distributed.NatsAuthConfig()
|
||||
if natsAuth.RequireAuth && (natsAuth.ServiceUserJWT == "" || natsAuth.ServiceUserSeed == "") {
|
||||
return nil, fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED")
|
||||
}
|
||||
natsOpts := cfg.Distributed.NatsMessagingOptions("", "")
|
||||
natsClient, err := messaging.New(cfg.Distributed.NatsURL, natsOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connecting to NATS: %w", err)
|
||||
}
|
||||
|
||||
@@ -23,9 +23,9 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -308,10 +308,31 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
application.galleryService.SetNATSClient(distSvc.Nats)
|
||||
if distSvc.DistStores != nil && distSvc.DistStores.Gallery != nil {
|
||||
// Clean up stale in-progress operations from previous crashed instances
|
||||
if err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
|
||||
if _, err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
|
||||
xlog.Warn("Failed to clean stale gallery operations", "error", err)
|
||||
}
|
||||
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
|
||||
|
||||
// Reap stale ops periodically, not just at boot: an op orphaned by
|
||||
// a replica that died mid-install (its foreground handler goroutine
|
||||
// gone) would otherwise linger "processing" in the UI until the next
|
||||
// restart. 30m matches the install/upgrade ceiling so a genuinely
|
||||
// slow op is never reaped out from under itself.
|
||||
gsvc := application.galleryService
|
||||
go func() {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-options.Context.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if _, err := gsvc.ReapStaleOperations(30 * time.Minute); err != nil {
|
||||
xlog.Warn("Failed to reap stale gallery operations", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
// Hydrate from the store first so the wildcard subscriber finds an
|
||||
// already-populated statuses map for any operations still in flight
|
||||
|
||||
@@ -214,7 +214,9 @@ func (uc *UpgradeChecker) runCheck(ctx context.Context) {
|
||||
"from", info.InstalledVersion, "to", info.AvailableVersion)
|
||||
var err error
|
||||
if bm != nil {
|
||||
err = bm.UpgradeBackend(ctx, name, nil)
|
||||
// Background auto-upgrade: no live admin watching a progress bar,
|
||||
// so opID is empty and the distributed path skips progress streaming.
|
||||
err = bm.UpgradeBackend(ctx, "", name, nil)
|
||||
} else {
|
||||
err = gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
|
||||
uc.galleries, name, nil, uc.appConfig.RequireBackendIntegrity)
|
||||
|
||||
@@ -123,14 +123,14 @@ var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
|
||||
})
|
||||
|
||||
It("ModelTTS forwards the request context to the SmartRouter", func() {
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", loader, appCfg, modelCfg)
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTTSStream forwards the request context to the SmartRouter", func() {
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
|
||||
@@ -239,13 +239,13 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
|
||||
|
||||
if c.Backend == "cloud-proxy" {
|
||||
opts.Proxy = &pb.ProxyOptions{
|
||||
UpstreamUrl: c.Proxy.UpstreamURL,
|
||||
Mode: c.Proxy.Mode,
|
||||
Provider: c.Proxy.Provider,
|
||||
ApiKeyEnv: c.Proxy.APIKeyEnv,
|
||||
ApiKeyFile: c.Proxy.APIKeyFile,
|
||||
UpstreamModel: c.Proxy.UpstreamModel,
|
||||
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
|
||||
UpstreamUrl: c.Proxy.UpstreamURL,
|
||||
Mode: c.Proxy.Mode,
|
||||
Provider: c.Proxy.Provider,
|
||||
ApiKeyEnv: c.Proxy.APIKeyEnv,
|
||||
ApiKeyFile: c.Proxy.APIKeyFile,
|
||||
UpstreamModel: c.Proxy.UpstreamModel,
|
||||
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,6 +323,12 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
|
||||
metadata["enable_thinking"] = "true"
|
||||
}
|
||||
}
|
||||
// Forward the effective reasoning effort so the backend can pass it to the
|
||||
// jinja chat template (chat_template_kwargs.reasoning_effort) — the lever
|
||||
// models like gpt-oss / LFM2.5 actually read, distinct from enable_thinking.
|
||||
if c.ReasoningEffort != "" {
|
||||
metadata["reasoning_effort"] = c.ReasoningEffort
|
||||
}
|
||||
pbOpts.Metadata = metadata
|
||||
|
||||
// Logprobs and TopLogprobs are set by the caller if provided
|
||||
|
||||
@@ -75,3 +75,25 @@ var _ = Describe("gRPCPredictOpts enable_thinking metadata", func() {
|
||||
Expect(opts.Metadata).ToNot(HaveKey("enable_thinking"))
|
||||
})
|
||||
})
|
||||
|
||||
// Guards forwarding the effective reasoning_effort into PredictOptions.Metadata,
|
||||
// where the backend passes it to the jinja chat template (chat_template_kwargs)
|
||||
// so models like gpt-oss / LFM2.5 honor it.
|
||||
var _ = Describe("gRPCPredictOpts reasoning_effort metadata", func() {
|
||||
withEffort := func(effort string) config.ModelConfig {
|
||||
cfg := config.ModelConfig{}
|
||||
cfg.SetDefaults()
|
||||
cfg.ReasoningEffort = effort
|
||||
return cfg
|
||||
}
|
||||
|
||||
It("forwards reasoning_effort when set", func() {
|
||||
opts := gRPCPredictOpts(withEffort("none"), "/tmp/models")
|
||||
Expect(opts.Metadata).To(HaveKeyWithValue("reasoning_effort", "none"))
|
||||
})
|
||||
|
||||
It("omits reasoning_effort when empty", func() {
|
||||
opts := gRPCPredictOpts(withEffort(""), "/tmp/models")
|
||||
Expect(opts.Metadata).ToNot(HaveKey("reasoning_effort"))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -20,11 +20,32 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
// newTTSRequest assembles the gRPC TTSRequest from the per-request inputs. The
|
||||
// optional instructions string is only attached when non-empty so backends can
|
||||
// distinguish "no per-request instruction" (fall back to YAML) from an explicit
|
||||
// empty one. params is forwarded as-is (nil when unset).
|
||||
func newTTSRequest(text, modelPath, voice, dst, language, instructions string, params map[string]string) *proto.TTSRequest {
|
||||
req := &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Dst: dst,
|
||||
Language: &language,
|
||||
Params: params,
|
||||
}
|
||||
if instructions != "" {
|
||||
req.Instructions = &instructions
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func ModelTTS(
|
||||
ctx context.Context,
|
||||
text,
|
||||
voice,
|
||||
language string,
|
||||
language,
|
||||
instructions string,
|
||||
params map[string]string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
@@ -74,13 +95,9 @@ func ModelTTS(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := ttsModel.TTS(ctx, &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Dst: filePath,
|
||||
Language: &language,
|
||||
})
|
||||
ttsRequest := newTTSRequest(text, modelPath, voice, filePath, language, instructions, params)
|
||||
|
||||
res, err := ttsModel.TTS(ctx, ttsRequest)
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
@@ -128,7 +145,9 @@ func ModelTTSStream(
|
||||
ctx context.Context,
|
||||
text,
|
||||
voice,
|
||||
language string,
|
||||
language,
|
||||
instructions string,
|
||||
params map[string]string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
@@ -177,12 +196,10 @@ func ModelTTSStream(
|
||||
var totalPCMBytes int
|
||||
snippetCapped := false
|
||||
|
||||
err = ttsModel.TTSStream(ctx, &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Language: &language,
|
||||
}, func(reply *proto.Reply) {
|
||||
// Streaming TTS writes to the HTTP response, not a file, so dst is empty.
|
||||
ttsRequest := newTTSRequest(text, modelPath, voice, "", language, instructions, params)
|
||||
|
||||
err = ttsModel.TTSStream(ctx, ttsRequest, func(reply *proto.Reply) {
|
||||
// First message contains sample rate info
|
||||
if !headerSent && len(reply.Message) > 0 {
|
||||
var info map[string]any
|
||||
|
||||
42
core/backend/tts_test.go
Normal file
42
core/backend/tts_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package backend
|
||||
|
||||
// Specs for the TTSRequest assembly that carries the per-request
|
||||
// instructions/params from the OpenAI `instructions` field (and the LocalAI
|
||||
// `params` extension) through to the gRPC boundary. Before this plumbing the
|
||||
// instruction value was dropped before reaching the backend; these specs pin
|
||||
// that it now survives, and that the empty case stays backward compatible.
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("newTTSRequest", func() {
|
||||
It("attaches the instructions when a per-request value is set", func() {
|
||||
req := newTTSRequest("hi", "/m", "alloy", "/out.wav", "en", "cheerful narrator", nil)
|
||||
Expect(req.Instructions).ToNot(BeNil())
|
||||
Expect(req.GetInstructions()).To(Equal("cheerful narrator"))
|
||||
Expect(req.GetText()).To(Equal("hi"))
|
||||
Expect(req.GetVoice()).To(Equal("alloy"))
|
||||
Expect(req.GetDst()).To(Equal("/out.wav"))
|
||||
Expect(req.GetLanguage()).To(Equal("en"))
|
||||
})
|
||||
|
||||
It("leaves instructions unset when empty so backends fall back to YAML", func() {
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
|
||||
Expect(req.Instructions).To(BeNil())
|
||||
Expect(req.GetInstructions()).To(Equal(""))
|
||||
})
|
||||
|
||||
It("forwards per-request params through to the backend", func() {
|
||||
params := map[string]string{"exaggeration": "0.7", "cfg_weight": "0.3"}
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", params)
|
||||
Expect(req.GetParams()).To(HaveKeyWithValue("exaggeration", "0.7"))
|
||||
Expect(req.GetParams()).To(HaveKeyWithValue("cfg_weight", "0.3"))
|
||||
})
|
||||
|
||||
It("leaves params nil when none are supplied", func() {
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
|
||||
Expect(req.GetParams()).To(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -52,10 +52,28 @@ type AgentWorkerCMD struct {
|
||||
Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
|
||||
Queue string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`
|
||||
|
||||
NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (defaults to nats_jwt from registration)" group:"distributed"`
|
||||
NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user seed override (defaults to nats_user_seed from registration)" group:"distributed"`
|
||||
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"Fallback NATS service JWT when registration does not mint agent JWT" group:"distributed"`
|
||||
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"Fallback NATS service seed paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
|
||||
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed to connect" group:"distributed"`
|
||||
// DistributedRequireAuth is the umbrella switch; for the agent worker (which
|
||||
// has no file-transfer server) it implies NATS auth is required.
|
||||
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch implying --nats-require-auth (agent workers have no file-transfer server)" group:"distributed"`
|
||||
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
|
||||
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
|
||||
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
|
||||
|
||||
// Timeouts
|
||||
MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
|
||||
}
|
||||
|
||||
// natsAuthRequired reports whether NATS JWT credentials must be present — the
|
||||
// granular flag or the umbrella (LOCALAI_DISTRIBUTED_REQUIRE_AUTH).
|
||||
func (cmd *AgentWorkerCMD) natsAuthRequired() bool {
|
||||
return cmd.NatsRequireAuth || cmd.DistributedRequireAuth
|
||||
}
|
||||
|
||||
func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
xlog.Info("Starting agent worker", "nats", sanitize.URL(cmd.NatsURL), "register_to", cmd.RegisterTo)
|
||||
|
||||
@@ -81,15 +99,30 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
registrationBody["token"] = cmd.RegistrationToken
|
||||
}
|
||||
|
||||
nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
|
||||
// Context cancelled on shutdown — used by registration waits, heartbeat, and
|
||||
// other background goroutines.
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
// Acquire credentials via (re)registration. When the bus requires auth and no
|
||||
// static fallback is configured, wait through admin approval until the
|
||||
// frontend mints credentials rather than starting unauthenticated.
|
||||
credMgr := workerregistry.NewNATSCredentialManager(
|
||||
func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
|
||||
return regClient.RegisterFull(ctx, registrationBody)
|
||||
},
|
||||
cmd.natsAuthRequired() && cmd.NatsJWT == "" && cmd.NatsServiceJWT == "",
|
||||
)
|
||||
res, err := credMgr.Acquire(shutdownCtx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("registration failed: %w", err)
|
||||
}
|
||||
nodeID := res.ID
|
||||
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
|
||||
|
||||
// Use provisioned API token if none was set
|
||||
if cmd.APIToken == "" {
|
||||
cmd.APIToken = apiToken
|
||||
cmd.APIToken = res.APIToken
|
||||
}
|
||||
|
||||
// Start heartbeat
|
||||
@@ -98,14 +131,40 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
|
||||
}
|
||||
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
|
||||
// Context cancelled on shutdown — used by heartbeat and other background goroutines
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })
|
||||
|
||||
// Connect to NATS
|
||||
natsClient, err := messaging.New(cmd.NatsURL)
|
||||
// Resolve NATS credentials with precedence: explicit env override, then
|
||||
// frontend-minted (auto-refreshed before expiry), then service fallback.
|
||||
// Each static source must supply JWT and seed together.
|
||||
natsTLS := messaging.TLSFiles{CA: cmd.NatsTLSCA, Cert: cmd.NatsTLSCert, Key: cmd.NatsTLSKey}
|
||||
var natsOpts []messaging.Option
|
||||
switch {
|
||||
case cmd.NatsJWT != "" || cmd.NatsUserSeed != "":
|
||||
if (cmd.NatsJWT == "") != (cmd.NatsUserSeed == "") {
|
||||
return fmt.Errorf("LOCALAI_NATS_JWT and LOCALAI_NATS_USER_SEED must be set together")
|
||||
}
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsJWT, cmd.NatsUserSeed))
|
||||
case credMgr.HasCredentials():
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWTProvider(credMgr.Provider()))
|
||||
go func() {
|
||||
if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
|
||||
xlog.Error("NATS credential refresh permanently failed; shutting down agent worker", "error", err)
|
||||
shutdownCancel()
|
||||
}
|
||||
}()
|
||||
case cmd.NatsServiceJWT != "" || cmd.NatsServiceSeed != "":
|
||||
if (cmd.NatsServiceJWT == "") != (cmd.NatsServiceSeed == "") {
|
||||
return fmt.Errorf("LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED must be set together")
|
||||
}
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsServiceJWT, cmd.NatsServiceSeed))
|
||||
case cmd.natsAuthRequired():
|
||||
return fmt.Errorf("NATS JWT+seed required: enable frontend minting or set LOCALAI_NATS_* env vars")
|
||||
}
|
||||
if natsTLS.Enabled() {
|
||||
natsOpts = append(natsOpts, messaging.WithTLS(natsTLS))
|
||||
}
|
||||
natsClient, err := messaging.New(cmd.NatsURL, natsOpts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connecting to NATS: %w", err)
|
||||
}
|
||||
@@ -183,17 +242,25 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
|
||||
xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)
|
||||
|
||||
// Wait for shutdown
|
||||
// Wait for an OS signal or an internal fatal condition (e.g. NATS
|
||||
// credentials became unrenewable), so the worker restarts and re-acquires
|
||||
// rather than lingering unable to serve.
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
var runErr error
|
||||
select {
|
||||
case <-sigCh:
|
||||
case <-shutdownCtx.Done():
|
||||
runErr = fmt.Errorf("agent worker shutting down: NATS credentials unavailable")
|
||||
xlog.Error("Internal shutdown requested", "error", runErr)
|
||||
}
|
||||
|
||||
xlog.Info("Shutting down agent worker")
|
||||
shutdownCancel() // stop heartbeat loop immediately
|
||||
dispatcher.Stop()
|
||||
mcpTools.CloseAllMCPSessions()
|
||||
regClient.GracefulDeregister(nodeID)
|
||||
return nil
|
||||
return runErr
|
||||
}
|
||||
|
||||
// handleMCPToolRequest handles a NATS request-reply for MCP tool execution.
|
||||
|
||||
@@ -154,11 +154,21 @@ type RunCMD struct {
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
RegistrationRequireAuth bool `env:"LOCALAI_REGISTRATION_REQUIRE_AUTH" default:"false" help:"Fail startup when distributed mode is enabled but LOCALAI_REGISTRATION_TOKEN is empty (node endpoints and worker file-transfer server would otherwise be unauthenticated)" group:"distributed"`
|
||||
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch: require BOTH NATS JWT credentials and a registration token when distributed mode is enabled (implies --nats-require-auth and --registration-require-auth)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
DistributedPrefixCache bool `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE" default:"true" help:"Enable prefix-cache-aware routing in distributed mode (default true). When false, routing falls back to round-robin." group:"distributed"`
|
||||
DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
NatsAccountSeed string `env:"LOCALAI_NATS_ACCOUNT_SEED" help:"NATS account signing seed (SU...) used to mint per-node worker JWTs at registration" group:"distributed"`
|
||||
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"NATS user JWT for the frontend (and agent workers) to publish control-plane messages" group:"distributed"`
|
||||
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"NATS user signing seed (SU...) paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
|
||||
NatsWorkerJWTTTL string `env:"LOCALAI_NATS_WORKER_JWT_TTL" help:"Lifetime of minted per-node NATS JWTs (e.g. 24h, default 24h)" group:"distributed"`
|
||||
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT credentials (service JWT + account seed) when distributed mode is enabled" group:"distributed"`
|
||||
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI); use with tls:// in --nats-url" group:"distributed"`
|
||||
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
|
||||
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
|
||||
Version bool
|
||||
@@ -283,6 +293,40 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.RegistrationToken != "" {
|
||||
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
|
||||
}
|
||||
if r.RegistrationRequireAuth {
|
||||
opts = append(opts, config.EnableRegistrationRequireAuth)
|
||||
}
|
||||
if r.DistributedRequireAuth {
|
||||
opts = append(opts, config.EnableDistributedRequireAuth)
|
||||
}
|
||||
if r.NatsAccountSeed != "" {
|
||||
opts = append(opts, config.WithNatsAccountSeed(r.NatsAccountSeed))
|
||||
}
|
||||
if r.NatsServiceJWT != "" {
|
||||
opts = append(opts, config.WithNatsServiceJWT(r.NatsServiceJWT))
|
||||
}
|
||||
if r.NatsServiceSeed != "" {
|
||||
opts = append(opts, config.WithNatsServiceSeed(r.NatsServiceSeed))
|
||||
}
|
||||
if r.NatsWorkerJWTTTL != "" {
|
||||
d, err := time.ParseDuration(r.NatsWorkerJWTTTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_NATS_WORKER_JWT_TTL %q: %w", r.NatsWorkerJWTTTL, err)
|
||||
}
|
||||
opts = append(opts, config.WithNatsWorkerJWTTTL(d))
|
||||
}
|
||||
if r.NatsRequireAuth {
|
||||
opts = append(opts, config.EnableNatsRequireAuth)
|
||||
}
|
||||
if r.NatsTLSCA != "" {
|
||||
opts = append(opts, config.WithNatsTLSCA(r.NatsTLSCA))
|
||||
}
|
||||
if r.NatsTLSCert != "" {
|
||||
opts = append(opts, config.WithNatsTLSCert(r.NatsTLSCert))
|
||||
}
|
||||
if r.NatsTLSKey != "" {
|
||||
opts = append(opts, config.WithNatsTLSKey(r.NatsTLSKey))
|
||||
}
|
||||
if r.AutoApproveNodes {
|
||||
opts = append(opts, config.EnableAutoApproveNodes)
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
|
||||
options.Backend = t.Backend
|
||||
options.Model = t.Model
|
||||
|
||||
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, ml, opts, options)
|
||||
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, "", nil, ml, opts, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func (r *VLLMDistributed) Run(ctx *cliContext.Context) error {
|
||||
FrontendURL: r.RegisterTo,
|
||||
RegistrationToken: r.RegistrationToken,
|
||||
}
|
||||
nodeID, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
|
||||
nodeID, _, _, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
|
||||
if regErr != nil {
|
||||
return fmt.Errorf("registering with frontend: %w", regErr)
|
||||
}
|
||||
|
||||
@@ -58,65 +58,77 @@ func (c *RegistrationClient) setAuth(req *http.Request) {
|
||||
|
||||
// RegisterResponse is the JSON body returned by /api/node/register.
|
||||
type RegisterResponse struct {
|
||||
ID string `json:"id"`
|
||||
APIToken string `json:"api_token,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status,omitempty"` // "pending" until an admin approves the node
|
||||
APIToken string `json:"api_token,omitempty"`
|
||||
NatsJWT string `json:"nats_jwt,omitempty"`
|
||||
NatsUserSeed string `json:"nats_user_seed,omitempty"`
|
||||
}
|
||||
|
||||
// Register sends a single registration request and returns the node ID and
|
||||
// (optionally) an auto-provisioned API token.
|
||||
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) {
|
||||
// RegisterFull sends a single registration request and returns the full
|
||||
// response (node ID, approval status, and optional API token / NATS creds).
|
||||
// Re-registration is idempotent: the frontend preserves the node row and mints
|
||||
// a fresh NATS JWT each call, so this doubles as the credential-refresh call.
|
||||
func (c *RegistrationClient) RegisterFull(ctx context.Context, body map[string]any) (*RegisterResponse, error) {
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
url := c.baseURL() + "/api/node/register"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("creating request: %w", err)
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.setAuth(req)
|
||||
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("posting to %s: %w", url, err)
|
||||
return nil, fmt.Errorf("posting to %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
|
||||
return nil, fmt.Errorf("registration failed with status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result RegisterResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", "", fmt.Errorf("decoding response: %w", err)
|
||||
return nil, fmt.Errorf("decoding response: %w", err)
|
||||
}
|
||||
return result.ID, result.APIToken, nil
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Register sends a single registration request and returns the node ID and
|
||||
// optional credentials (API token for agent workers, NATS JWT when configured).
|
||||
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
|
||||
res, err := c.RegisterFull(ctx, body)
|
||||
if err != nil {
|
||||
return "", "", "", "", err
|
||||
}
|
||||
return res.ID, res.APIToken, res.NatsJWT, res.NatsUserSeed, nil
|
||||
}
|
||||
|
||||
// RegisterWithRetry retries registration with exponential backoff.
|
||||
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) {
|
||||
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
|
||||
backoff := 2 * time.Second
|
||||
maxBackoff := 30 * time.Second
|
||||
|
||||
var nodeID, apiToken string
|
||||
var err error
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
nodeID, apiToken, err = c.Register(ctx, body)
|
||||
nodeID, apiToken, natsJWT, natsSeed, err = c.Register(ctx, body)
|
||||
if err == nil {
|
||||
return nodeID, apiToken, nil
|
||||
return nodeID, apiToken, natsJWT, natsSeed, nil
|
||||
}
|
||||
if attempt == maxRetries {
|
||||
return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
|
||||
return "", "", "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", "", ctx.Err()
|
||||
return "", "", "", "", ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
backoff = min(backoff*2, maxBackoff)
|
||||
}
|
||||
return nodeID, apiToken, err
|
||||
return nodeID, apiToken, natsJWT, natsSeed, err
|
||||
}
|
||||
|
||||
// Heartbeat sends a single heartbeat POST with the given body.
|
||||
|
||||
200
core/cli/workerregistry/credentials.go
Normal file
200
core/cli/workerregistry/credentials.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package workerregistry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// statusPending mirrors nodes.StatusPending. It is duplicated rather than
|
||||
// imported so the lightweight registration client does not pull in the nodes
|
||||
// package (and its gorm/DB dependencies).
|
||||
const statusPending = "pending"
|
||||
|
||||
// defaultMaxAttempts bounds how many times Acquire registers (and how many
|
||||
// consecutive times RefreshLoop may fail) before giving up. It is high enough
|
||||
// to ride out a slow admin approval or a transient frontend outage, but finite
|
||||
// so an unauthorized/unapprovable worker exits and surfaces the problem (via a
|
||||
// non-zero exit and the resulting restart) rather than waiting forever.
|
||||
const defaultMaxAttempts = 100
|
||||
|
||||
// RegisterFunc performs one idempotent registration round-trip.
|
||||
type RegisterFunc func(ctx context.Context) (*RegisterResponse, error)
|
||||
|
||||
// NATSCredentialManager acquires NATS credentials at startup — waiting through
|
||||
// admin approval when required — and refreshes them before the minted JWT
|
||||
// expires, by re-registering (which mints a fresh JWT). The live NATS
|
||||
// connection adopts a refreshed JWT on its next reconnect via Provider. Safe
|
||||
// for concurrent use.
|
||||
//
|
||||
// It addresses two failure modes: a worker that needs credentials but registers
|
||||
// while still pending approval (it would otherwise give up and never connect),
|
||||
// and a long-running worker whose 24h JWT expires with no way to renew it.
|
||||
type NATSCredentialManager struct {
|
||||
register RegisterFunc
|
||||
requireCreds bool // block until credentials are present (frontend minting in use)
|
||||
|
||||
// Tunables; defaults set by NewNATSCredentialManager, overridable in tests.
|
||||
initialBackoff time.Duration
|
||||
maxBackoff time.Duration
|
||||
maxAttempts int // bound on Acquire attempts / consecutive refresh failures (<=0 = unlimited)
|
||||
refreshLead float64 // refresh once this fraction of the JWT lifetime has elapsed
|
||||
refreshRetry time.Duration
|
||||
expiryOf func(jwt string) (time.Time, bool)
|
||||
|
||||
mu sync.RWMutex
|
||||
jwt string
|
||||
seed string
|
||||
nodeID string
|
||||
}
|
||||
|
||||
// NewNATSCredentialManager builds a manager over register. When requireCreds is
|
||||
// true, Acquire blocks until the node is approved and credentials are minted.
|
||||
func NewNATSCredentialManager(register RegisterFunc, requireCreds bool) *NATSCredentialManager {
|
||||
return &NATSCredentialManager{
|
||||
register: register,
|
||||
requireCreds: requireCreds,
|
||||
initialBackoff: 2 * time.Second,
|
||||
maxBackoff: 30 * time.Second,
|
||||
maxAttempts: defaultMaxAttempts,
|
||||
refreshLead: 0.75,
|
||||
refreshRetry: 30 * time.Second,
|
||||
expiryOf: jwtExpiry,
|
||||
}
|
||||
}
|
||||
|
||||
// jwtExpiry decodes the expiry of a minted user JWT. ok is false when the token
|
||||
// is empty/undecodable or carries no expiry (e.g. a non-expiring service JWT).
|
||||
func jwtExpiry(token string) (time.Time, bool) {
|
||||
if token == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
uc, err := natsauth.DecodeUserClaims(token)
|
||||
if err != nil || uc.Expires == 0 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return time.Unix(uc.Expires, 0), true
|
||||
}
|
||||
|
||||
func (m *NATSCredentialManager) store(res *RegisterResponse) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.nodeID = res.ID
|
||||
if res.NatsJWT != "" && res.NatsUserSeed != "" {
|
||||
m.jwt, m.seed = res.NatsJWT, res.NatsUserSeed
|
||||
}
|
||||
}
|
||||
|
||||
// Current returns the latest NATS credentials (both empty until acquired).
|
||||
func (m *NATSCredentialManager) Current() (jwt, seed string) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.jwt, m.seed
|
||||
}
|
||||
|
||||
// NodeID returns the node ID from the most recent registration.
|
||||
func (m *NATSCredentialManager) NodeID() string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.nodeID
|
||||
}
|
||||
|
||||
// Provider returns a callback compatible with messaging.WithUserJWTProvider,
|
||||
// supplying the current credentials on each (re)connect.
|
||||
func (m *NATSCredentialManager) Provider() func() (string, string) {
|
||||
return m.Current
|
||||
}
|
||||
|
||||
// HasCredentials reports whether complete NATS credentials have been obtained.
|
||||
func (m *NATSCredentialManager) HasCredentials() bool {
|
||||
jwt, seed := m.Current()
|
||||
return jwt != "" && seed != ""
|
||||
}
|
||||
|
||||
// Acquire registers and, when requireCreds is set, keeps re-registering with
|
||||
// exponential backoff until the node is approved (status != pending) and
|
||||
// credentials are minted. Without requireCreds it returns the first successful
|
||||
// response (the historical one-shot behavior, preserved for anonymous NATS).
|
||||
func (m *NATSCredentialManager) Acquire(ctx context.Context) (*RegisterResponse, error) {
|
||||
backoff := m.initialBackoff
|
||||
var lastReason error
|
||||
for attempt := 1; m.maxAttempts <= 0 || attempt <= m.maxAttempts; attempt++ {
|
||||
res, err := m.register(ctx)
|
||||
switch {
|
||||
case err != nil:
|
||||
lastReason = err
|
||||
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
|
||||
case !m.requireCreds:
|
||||
m.store(res)
|
||||
return res, nil
|
||||
case res.Status == statusPending:
|
||||
lastReason = fmt.Errorf("node %s still pending admin approval", res.ID)
|
||||
xlog.Info("Node pending admin approval; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
|
||||
case res.NatsJWT == "" || res.NatsUserSeed == "":
|
||||
lastReason = fmt.Errorf("node %s approved but NATS credentials not minted", res.ID)
|
||||
xlog.Info("Node approved but NATS credentials not yet minted; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
|
||||
default:
|
||||
m.store(res)
|
||||
return res, nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
backoff = min(backoff*2, m.maxBackoff)
|
||||
}
|
||||
return nil, fmt.Errorf("giving up acquiring NATS credentials after %d attempts: %w", m.maxAttempts, lastReason)
|
||||
}
|
||||
|
||||
// RefreshLoop re-registers to mint a fresh JWT before the current one expires,
|
||||
// updating the credentials returned by Current/Provider so the NATS connection
|
||||
// adopts them on its next reconnect. It returns nil when ctx is cancelled or
|
||||
// when the current credential has no expiry (nothing to refresh), and a non-nil
|
||||
// error after maxAttempts consecutive refresh failures — letting the caller
|
||||
// exit the worker so it restarts and re-acquires (or surfaces the outage)
|
||||
// rather than silently drifting toward an expired, unrenewable JWT.
|
||||
func (m *NATSCredentialManager) RefreshLoop(ctx context.Context) error {
|
||||
failures := 0
|
||||
for {
|
||||
jwt, _ := m.Current()
|
||||
exp, ok := m.expiryOf(jwt)
|
||||
if !ok {
|
||||
xlog.Debug("NATS credential has no expiry; refresh loop exiting")
|
||||
return nil
|
||||
}
|
||||
wait := max(time.Duration(float64(time.Until(exp))*m.refreshLead), 0)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(wait):
|
||||
}
|
||||
|
||||
res, err := m.register(ctx)
|
||||
if err == nil && res.NatsJWT != "" && res.NatsUserSeed != "" {
|
||||
m.store(res)
|
||||
failures = 0
|
||||
xlog.Info("Refreshed NATS credentials", "node", res.ID)
|
||||
continue
|
||||
}
|
||||
failures++
|
||||
if err != nil {
|
||||
xlog.Warn("NATS credential refresh failed; will retry", "attempt", failures, "error", err)
|
||||
} else {
|
||||
xlog.Warn("NATS credential refresh returned no credentials; will retry", "attempt", failures)
|
||||
}
|
||||
if m.maxAttempts > 0 && failures >= m.maxAttempts {
|
||||
return fmt.Errorf("NATS credential refresh failed %d times in a row", failures)
|
||||
}
|
||||
// Back off before retrying so a persistent failure near expiry does not spin.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(m.refreshRetry):
|
||||
}
|
||||
}
|
||||
}
|
||||
198
core/cli/workerregistry/credentials_test.go
Normal file
198
core/cli/workerregistry/credentials_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package workerregistry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/nats-io/nkeys"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestWorkerRegistry(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "WorkerRegistry")
|
||||
}
|
||||
|
||||
// fakeRegister returns a sequence of canned responses/errors, one per call, and
|
||||
// records how many times it was invoked. The last entry repeats once exhausted.
|
||||
type fakeRegister struct {
|
||||
mu sync.Mutex
|
||||
steps []step
|
||||
calls int
|
||||
}
|
||||
|
||||
type step struct {
|
||||
res *RegisterResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeRegister) fn() RegisterFunc {
|
||||
return func(context.Context) (*RegisterResponse, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
i := f.calls
|
||||
f.calls++
|
||||
if i >= len(f.steps) {
|
||||
i = len(f.steps) - 1
|
||||
}
|
||||
return f.steps[i].res, f.steps[i].err
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeRegister) count() int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.calls
|
||||
}
|
||||
|
||||
var _ = Describe("NATSCredentialManager", func() {
|
||||
approved := func(jwt, seed string) *RegisterResponse {
|
||||
return &RegisterResponse{ID: "node-1", Status: "healthy", NatsJWT: jwt, NatsUserSeed: seed}
|
||||
}
|
||||
pending := &RegisterResponse{ID: "node-1", Status: "pending"}
|
||||
|
||||
Describe("Acquire (#4 — wait through admin approval)", func() {
|
||||
It("keeps re-registering until the node is approved and credentials are minted", func() {
|
||||
f := &fakeRegister{steps: []step{
|
||||
{res: pending}, // not approved yet
|
||||
{res: approved("", "")}, // approved but JWT not minted yet
|
||||
{res: approved("jwt-1", "seed-1")}, // finally minted
|
||||
}}
|
||||
m := NewNATSCredentialManager(f.fn(), true /* requireCreds */)
|
||||
m.initialBackoff = time.Millisecond
|
||||
m.maxBackoff = time.Millisecond
|
||||
|
||||
res, err := m.Acquire(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.ID).To(Equal("node-1"))
|
||||
Expect(f.count()).To(Equal(3))
|
||||
|
||||
jwt, seed := m.Current()
|
||||
Expect(jwt).To(Equal("jwt-1"))
|
||||
Expect(seed).To(Equal("seed-1"))
|
||||
Expect(m.HasCredentials()).To(BeTrue())
|
||||
Expect(m.NodeID()).To(Equal("node-1"))
|
||||
})
|
||||
|
||||
It("returns immediately on the first success when credentials are not required (anonymous NATS)", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}}
|
||||
m := NewNATSCredentialManager(f.fn(), false /* requireCreds */)
|
||||
|
||||
res, err := m.Acquire(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.Status).To(Equal("pending"))
|
||||
Expect(f.count()).To(Equal(1))
|
||||
Expect(m.HasCredentials()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("aborts when the context is cancelled while waiting for approval", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.initialBackoff = 10 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err := m.Acquire(ctx)
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
})
|
||||
|
||||
It("gives up after a bounded number of attempts so the worker exits and alerts", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}} // never approved
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.initialBackoff = time.Millisecond
|
||||
m.maxBackoff = time.Millisecond
|
||||
m.maxAttempts = 5
|
||||
|
||||
_, err := m.Acquire(context.Background())
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("after 5 attempts"))
|
||||
Expect(err.Error()).To(ContainSubstring("pending admin approval"))
|
||||
Expect(f.count()).To(Equal(5))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("RefreshLoop (#5 — renew before the JWT expires)", func() {
|
||||
It("re-registers before expiry and updates the credentials served to new connections", func() {
|
||||
f := &fakeRegister{steps: []step{{res: approved("jwt-2", "seed-2")}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.refreshLead = 0.5
|
||||
m.refreshRetry = time.Millisecond
|
||||
// jwt-1 expires soon; jwt-2 is long-lived so the loop then idles.
|
||||
m.expiryOf = func(jwt string) (time.Time, bool) {
|
||||
switch jwt {
|
||||
case "jwt-1":
|
||||
return time.Now().Add(40 * time.Millisecond), true
|
||||
case "jwt-2":
|
||||
return time.Now().Add(time.Hour), true
|
||||
default:
|
||||
return time.Time{}, false
|
||||
}
|
||||
}
|
||||
m.store(approved("jwt-1", "seed-1"))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go func() { _ = m.RefreshLoop(ctx) }()
|
||||
|
||||
Eventually(func() string {
|
||||
jwt, _ := m.Current()
|
||||
return jwt
|
||||
}, "2s", "10ms").Should(Equal("jwt-2"))
|
||||
})
|
||||
|
||||
It("returns an error after the bounded number of consecutive failures so the caller can exit", func() {
|
||||
f := &fakeRegister{steps: []step{{err: context.DeadlineExceeded}}} // refresh always fails
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.refreshLead = 0.5
|
||||
m.refreshRetry = time.Millisecond
|
||||
m.maxAttempts = 3
|
||||
m.expiryOf = func(string) (time.Time, bool) { return time.Now().Add(time.Millisecond), true }
|
||||
m.store(approved("jwt-1", "seed-1"))
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- m.RefreshLoop(context.Background()) }()
|
||||
Eventually(errCh, "2s").Should(Receive(MatchError(ContainSubstring("3 times in a row"))))
|
||||
})
|
||||
|
||||
It("exits promptly when the current credential has no expiry (nothing to refresh)", func() {
|
||||
f := &fakeRegister{steps: []step{{res: approved("x", "y")}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.expiryOf = func(string) (time.Time, bool) { return time.Time{}, false }
|
||||
m.store(approved("static", "seed"))
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() { _ = m.RefreshLoop(context.Background()); close(done) }()
|
||||
Eventually(done, "1s").Should(BeClosed())
|
||||
Expect(f.count()).To(Equal(0)) // never tried to re-register
|
||||
})
|
||||
})
|
||||
|
||||
Describe("jwtExpiry default", func() {
|
||||
It("decodes the expiry of a real minted worker JWT", func() {
|
||||
akp, err := nkeys.CreateAccount()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
seed, err := akp.Seed()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
cfg := natsauth.Config{AccountSeed: string(seed), WorkerJWTTTL: time.Hour}
|
||||
token, _, err := cfg.MintWorkerJWT("node-1", "backend")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
exp, ok := jwtExpiry(token)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(exp).To(BeTemporally("~", time.Now().Add(time.Hour), 2*time.Minute))
|
||||
})
|
||||
|
||||
It("reports no expiry for an empty or undecodable token", func() {
|
||||
_, ok := jwtExpiry("")
|
||||
Expect(ok).To(BeFalse())
|
||||
_, ok = jwtExpiry("not-a-jwt")
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -22,9 +22,11 @@ const (
|
||||
UsecaseRerank = "rerank"
|
||||
UsecaseDetection = "detection"
|
||||
UsecaseVAD = "vad"
|
||||
UsecaseAudioTransform = "audio_transform"
|
||||
UsecaseDiarization = "diarization"
|
||||
UsecaseRealtimeAudio = "realtime_audio"
|
||||
UsecaseAudioTransform = "audio_transform"
|
||||
UsecaseDiarization = "diarization"
|
||||
UsecaseRealtimeAudio = "realtime_audio"
|
||||
UsecaseFaceRecognition = "face_recognition"
|
||||
UsecaseSpeakerRecognition = "speaker_recognition"
|
||||
)
|
||||
|
||||
// GRPCMethod identifies a Backend service RPC from backend.proto.
|
||||
@@ -47,6 +49,11 @@ const (
|
||||
MethodAudioTransform GRPCMethod = "AudioTransform"
|
||||
MethodDiarize GRPCMethod = "Diarize"
|
||||
MethodAudioToAudioStream GRPCMethod = "AudioToAudioStream"
|
||||
MethodFaceVerify GRPCMethod = "FaceVerify"
|
||||
MethodFaceAnalyze GRPCMethod = "FaceAnalyze"
|
||||
MethodVoiceVerify GRPCMethod = "VoiceVerify"
|
||||
MethodVoiceEmbed GRPCMethod = "VoiceEmbed"
|
||||
MethodVoiceAnalyze GRPCMethod = "VoiceAnalyze"
|
||||
)
|
||||
|
||||
// UsecaseInfo describes a single known_usecase value and how it maps
|
||||
@@ -154,6 +161,16 @@ var UsecaseInfoMap = map[string]UsecaseInfo{
|
||||
GRPCMethod: MethodAudioToAudioStream,
|
||||
Description: "Self-contained any-to-any audio model for the Realtime API — accepts microphone audio and emits speech + transcript (+ optional function calls) from a single backend via the AudioToAudioStream RPC.",
|
||||
},
|
||||
UsecaseFaceRecognition: {
|
||||
Flag: FLAG_FACE_RECOGNITION,
|
||||
GRPCMethod: MethodFaceVerify,
|
||||
Description: "Face recognition — verify identity, analyze attributes (age/gender/emotion) via FaceVerify and FaceAnalyze RPCs.",
|
||||
},
|
||||
UsecaseSpeakerRecognition: {
|
||||
Flag: FLAG_SPEAKER_RECOGNITION,
|
||||
GRPCMethod: MethodVoiceVerify,
|
||||
Description: "Speaker recognition — verify identity, embed and analyze voice via VoiceVerify, VoiceEmbed and VoiceAnalyze RPCs.",
|
||||
},
|
||||
}
|
||||
|
||||
// BackendCapability describes which gRPC methods and usecases a backend supports.
|
||||
@@ -471,6 +488,21 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseDetection},
|
||||
Description: "RF-DETR C++ object detection",
|
||||
},
|
||||
|
||||
// --- Face and speaker recognition backends ---
|
||||
"insightface": {
|
||||
GRPCMethods: []GRPCMethod{MethodEmbedding, MethodDetect, MethodFaceVerify, MethodFaceAnalyze},
|
||||
PossibleUsecases: []string{UsecaseEmbeddings, UsecaseDetection, UsecaseFaceRecognition},
|
||||
DefaultUsecases: []string{UsecaseFaceRecognition},
|
||||
AcceptsImages: true,
|
||||
Description: "InsightFace — face detection, embedding, verification and attribute analysis",
|
||||
},
|
||||
"speaker-recognition": {
|
||||
GRPCMethods: []GRPCMethod{MethodVoiceVerify, MethodVoiceEmbed, MethodVoiceAnalyze},
|
||||
PossibleUsecases: []string{UsecaseSpeakerRecognition},
|
||||
DefaultUsecases: []string{UsecaseSpeakerRecognition},
|
||||
Description: "Speaker recognition — voice identity verification and analysis",
|
||||
},
|
||||
"silero-vad": {
|
||||
GRPCMethods: []GRPCMethod{MethodVAD},
|
||||
PossibleUsecases: []string{UsecaseVAD},
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -16,7 +18,29 @@ type DistributedConfig struct {
|
||||
NatsURL string // --nats-url / LOCALAI_NATS_URL
|
||||
StorageURL string // --storage-url / LOCALAI_STORAGE_URL (S3 endpoint)
|
||||
RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration)
|
||||
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
|
||||
// RegistrationRequireAuth fails startup when distributed mode is enabled but
|
||||
// RegistrationToken is empty. The default (false) keeps the historical
|
||||
// fail-open behavior with a loud warning; production should set it so the
|
||||
// node-register endpoints and the worker file-transfer server cannot run
|
||||
// unauthenticated. Mirrors NatsRequireAuth for the NATS bus.
|
||||
RegistrationRequireAuth bool // LOCALAI_REGISTRATION_REQUIRE_AUTH
|
||||
// RequireAuth is the umbrella switch (LOCALAI_DISTRIBUTED_REQUIRE_AUTH) for
|
||||
// distributed-mode auth: when true it implies BOTH NatsRequireAuth and
|
||||
// RegistrationRequireAuth, so a single knob locks down the bus and the
|
||||
// registration/file-transfer layer together. The granular flags remain
|
||||
// available to enforce just one layer.
|
||||
RequireAuth bool // LOCALAI_DISTRIBUTED_REQUIRE_AUTH
|
||||
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
|
||||
|
||||
// NATS JWT auth (optional; see pkg/natsauth and docs/features/distributed-mode.md)
|
||||
NatsAccountSeed string // LOCALAI_NATS_ACCOUNT_SEED — account signing seed to mint per-node worker JWTs
|
||||
NatsServiceJWT string // LOCALAI_NATS_SERVICE_JWT — user JWT for frontends / agent workers
|
||||
NatsServiceSeed string // LOCALAI_NATS_SERVICE_SEED — signing seed paired with service JWT
|
||||
NatsWorkerJWTTTL time.Duration // LOCALAI_NATS_WORKER_JWT_TTL — minted worker JWT lifetime (default 24h)
|
||||
NatsRequireAuth bool // LOCALAI_NATS_REQUIRE_AUTH — fail startup if NATS credentials are missing
|
||||
NatsTLSCA string // LOCALAI_NATS_TLS_CA — PEM file for private CA (server verify)
|
||||
NatsTLSCert string // LOCALAI_NATS_TLS_CERT — client cert for NATS mTLS
|
||||
NatsTLSKey string // LOCALAI_NATS_TLS_KEY — client key paired with NatsTLSCert
|
||||
|
||||
// S3 configuration (used when StorageURL is set)
|
||||
StorageBucket string // --storage-bucket / LOCALAI_STORAGE_BUCKET
|
||||
@@ -76,10 +100,23 @@ func (c DistributedConfig) Validate() error {
|
||||
(c.StorageAccessKey == "" && c.StorageSecretKey != "") {
|
||||
return fmt.Errorf("storage-access-key and storage-secret-key must both be set or both empty")
|
||||
}
|
||||
// Warn about missing registration token (not an error)
|
||||
// The registration token guards both the node HTTP register/heartbeat
|
||||
// endpoints and the worker file-transfer server (which fails open on an
|
||||
// empty token). Enforce it when registration auth is required (the granular
|
||||
// flag or the umbrella); otherwise warn.
|
||||
if c.RegistrationToken == "" {
|
||||
xlog.Warn("distributed mode running without registration token — node endpoints are unprotected")
|
||||
if c.RegistrationAuthRequired() {
|
||||
return fmt.Errorf("registration auth is required (LOCALAI_REGISTRATION_REQUIRE_AUTH or LOCALAI_DISTRIBUTED_REQUIRE_AUTH) but LOCALAI_REGISTRATION_TOKEN is empty")
|
||||
}
|
||||
xlog.Warn("distributed mode running without registration token — node endpoints and the worker file-transfer server are unprotected; set LOCALAI_REGISTRATION_TOKEN, or LOCALAI_DISTRIBUTED_REQUIRE_AUTH=true to fail closed")
|
||||
}
|
||||
if err := c.NatsAuthConfig().Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.NatsTLSFiles().Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.NatsAuthConfig().WarnIfInsecure(true)
|
||||
// Check for negative durations
|
||||
for name, d := range map[string]time.Duration{
|
||||
FlagMCPToolTimeout: c.MCPToolTimeout,
|
||||
@@ -123,6 +160,76 @@ func WithRegistrationToken(token string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsAccountSeed(seed string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsAccountSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsServiceJWT(jwt string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsServiceJWT = jwt
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsServiceSeed(seed string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsServiceSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsWorkerJWTTTL(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsWorkerJWTTTL = d
|
||||
}
|
||||
}
|
||||
|
||||
var EnableNatsRequireAuth = func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsRequireAuth = true
|
||||
}
|
||||
|
||||
// EnableRegistrationRequireAuth makes an empty registration token a hard error
|
||||
// in distributed mode (see DistributedConfig.RegistrationRequireAuth).
|
||||
var EnableRegistrationRequireAuth = func(o *ApplicationConfig) {
|
||||
o.Distributed.RegistrationRequireAuth = true
|
||||
}
|
||||
|
||||
// EnableDistributedRequireAuth is the umbrella switch implying both
|
||||
// NatsRequireAuth and RegistrationRequireAuth (see DistributedConfig.RequireAuth).
|
||||
var EnableDistributedRequireAuth = func(o *ApplicationConfig) {
|
||||
o.Distributed.RequireAuth = true
|
||||
}
|
||||
|
||||
// RegistrationAuthRequired reports whether an empty registration token must be
|
||||
// treated as a fatal misconfiguration — the granular flag or the umbrella.
|
||||
func (c DistributedConfig) RegistrationAuthRequired() bool {
|
||||
return c.RegistrationRequireAuth || c.RequireAuth
|
||||
}
|
||||
|
||||
// NatsAuthRequired reports whether NATS JWT credentials must be present — the
|
||||
// granular flag or the umbrella.
|
||||
func (c DistributedConfig) NatsAuthRequired() bool {
|
||||
return c.NatsRequireAuth || c.RequireAuth
|
||||
}
|
||||
|
||||
func WithNatsTLSCA(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSCA = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsTLSCert(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSCert = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsTLSKey(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSKey = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithStorageURL(url string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.StorageURL = url
|
||||
@@ -217,6 +324,44 @@ const (
|
||||
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
|
||||
const DefaultMaxUploadSize int64 = 50 << 30
|
||||
|
||||
// NatsTLSFiles returns NATS TLS/mTLS PEM paths for the messaging client.
|
||||
func (c DistributedConfig) NatsTLSFiles() messaging.TLSFiles {
|
||||
return messaging.TLSFiles{
|
||||
CA: c.NatsTLSCA,
|
||||
Cert: c.NatsTLSCert,
|
||||
Key: c.NatsTLSKey,
|
||||
}
|
||||
}
|
||||
|
||||
// NatsMessagingOptions builds messaging client options (JWT + TLS) for distributed components.
|
||||
// Pass explicit userJWT/userSeed when set (e.g. worker overrides); empty uses service JWT from config.
|
||||
func (c DistributedConfig) NatsMessagingOptions(userJWT, userSeed string) []messaging.Option {
|
||||
var opts []messaging.Option
|
||||
jwt, seed := userJWT, userSeed
|
||||
if jwt == "" && seed == "" {
|
||||
auth := c.NatsAuthConfig()
|
||||
jwt, seed = auth.ServiceUserJWT, auth.ServiceUserSeed
|
||||
}
|
||||
if jwt != "" && seed != "" {
|
||||
opts = append(opts, messaging.WithUserJWT(jwt, seed))
|
||||
}
|
||||
if tls := c.NatsTLSFiles(); tls.Enabled() {
|
||||
opts = append(opts, messaging.WithTLS(tls))
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// NatsAuthConfig builds pkg/natsauth settings from distributed configuration.
|
||||
func (c DistributedConfig) NatsAuthConfig() natsauth.Config {
|
||||
return natsauth.Config{
|
||||
AccountSeed: c.NatsAccountSeed,
|
||||
ServiceUserJWT: c.NatsServiceJWT,
|
||||
ServiceUserSeed: c.NatsServiceSeed,
|
||||
WorkerJWTTTL: c.NatsWorkerJWTTTL,
|
||||
RequireAuth: c.NatsAuthRequired(),
|
||||
}
|
||||
}
|
||||
|
||||
// BackendInstallTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout)
|
||||
|
||||
@@ -88,3 +88,66 @@ var _ = Describe("DistributedConfig.Validate negative-duration errors", func() {
|
||||
Expect(c.Validate()).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("DistributedConfig.Validate registration auth", func() {
|
||||
It("rejects an empty registration token when RequireAuth is set", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
RegistrationRequireAuth: true,
|
||||
}
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_REQUIRE_AUTH"))
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_TOKEN"))
|
||||
})
|
||||
|
||||
It("accepts a set registration token when RequireAuth is set", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
RegistrationToken: "s3cret",
|
||||
RegistrationRequireAuth: true,
|
||||
}
|
||||
Expect(c.Validate()).To(Succeed())
|
||||
})
|
||||
|
||||
It("warns but succeeds with an empty token when RequireAuth is unset", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
}
|
||||
Expect(c.Validate()).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects an empty token when the umbrella RequireAuth is set", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
RequireAuth: true,
|
||||
// Provide NATS creds so only the registration-token gap remains.
|
||||
NatsServiceJWT: "jwt",
|
||||
NatsServiceSeed: "seed",
|
||||
NatsAccountSeed: "acct",
|
||||
}
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_DISTRIBUTED_REQUIRE_AUTH"))
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_TOKEN"))
|
||||
})
|
||||
|
||||
It("the umbrella implies NATS auth is required", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
RegistrationToken: "tok", // registration layer satisfied
|
||||
RequireAuth: true, // umbrella → NATS creds now required
|
||||
}
|
||||
Expect(c.NatsAuthRequired()).To(BeTrue())
|
||||
Expect(c.RegistrationAuthRequired()).To(BeTrue())
|
||||
// Missing NATS service JWT/seed must now be fatal.
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_NATS_REQUIRE_AUTH"))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -39,7 +39,21 @@ func llamaCppDefaults(cfg *ModelConfig, modelPath string) {
|
||||
}
|
||||
}()
|
||||
|
||||
f, err := gguf.ParseGGUFFile(guessPath)
|
||||
// Startup parses every model's GGUF header to guess defaults. We only need
|
||||
// scalar metadata (architecture, head/ff counts, chat_template, token IDs,
|
||||
// MTP head) plus array *lengths* — never the array *contents*. Two options
|
||||
// keep this cheap, which matters when many models live on slow storage such
|
||||
// as a Docker volume (see https://github.com/mudler/LocalAI/issues/9790):
|
||||
//
|
||||
// - SkipLargeMetadata: seek past large array-valued metadata (the tokenizer
|
||||
// vocab: tokenizer.ggml.tokens/scores/merges, often >100k entries) instead
|
||||
// of reading and allocating every element. Lengths stay populated.
|
||||
// - UseMMap: read the header via a memory map so faulting in a few pages
|
||||
// replaces hundreds of thousands of tiny read() syscalls (measured ~524k
|
||||
// -> 8 for a 256k-token vocab), the dominant cost on slow filesystems.
|
||||
//
|
||||
// The mapping is released when ParseGGUFFile returns.
|
||||
f, err := gguf.ParseGGUFFile(guessPath, gguf.UseMMap(), gguf.SkipLargeMetadata())
|
||||
if err == nil {
|
||||
guessGGUFFromFile(cfg, f, 0)
|
||||
}
|
||||
|
||||
@@ -1,13 +1,76 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// GGUF metadata value type tags (see github.com/gpustack/gguf-parser-go).
|
||||
const (
|
||||
ggufTypeUint32 uint32 = 4
|
||||
ggufTypeString uint32 = 8
|
||||
ggufTypeArray uint32 = 9
|
||||
)
|
||||
|
||||
// writeTestGGUF emits a minimal but valid little-endian GGUF v3 header carrying
|
||||
// the scalar metadata the llama-cpp hook guesses from plus a large string vocab
|
||||
// array (tokenizer.ggml.tokens). The big array is exactly what SkipLargeMetadata
|
||||
// + UseMMap are expected to avoid reading element-by-element, so it must survive a
|
||||
// round-trip through the real hook without corrupting the guessed defaults.
|
||||
func writeTestGGUF(path, chatTemplate string, vocab int) error {
|
||||
wStr := func(b *bytes.Buffer, s string) {
|
||||
binary.Write(b, binary.LittleEndian, uint64(len(s)))
|
||||
b.WriteString(s)
|
||||
}
|
||||
kvStr := func(b *bytes.Buffer, k, v string) {
|
||||
wStr(b, k)
|
||||
binary.Write(b, binary.LittleEndian, ggufTypeString)
|
||||
wStr(b, v)
|
||||
}
|
||||
kvU32 := func(b *bytes.Buffer, k string, v uint32) {
|
||||
wStr(b, k)
|
||||
binary.Write(b, binary.LittleEndian, ggufTypeUint32)
|
||||
binary.Write(b, binary.LittleEndian, v)
|
||||
}
|
||||
|
||||
var meta bytes.Buffer
|
||||
kvStr(&meta, "general.architecture", "llama")
|
||||
kvStr(&meta, "general.name", "ReproModel")
|
||||
kvU32(&meta, "llama.context_length", 4096)
|
||||
kvU32(&meta, "llama.attention.head_count", 32)
|
||||
kvU32(&meta, "llama.feed_forward_length", 11008)
|
||||
kvU32(&meta, "llama.block_count", 32)
|
||||
kvU32(&meta, "tokenizer.ggml.bos_token_id", 1)
|
||||
kvStr(&meta, "tokenizer.chat_template", chatTemplate)
|
||||
|
||||
// large array value — the one the optimization skips reading
|
||||
wStr(&meta, "tokenizer.ggml.tokens")
|
||||
binary.Write(&meta, binary.LittleEndian, ggufTypeArray)
|
||||
binary.Write(&meta, binary.LittleEndian, ggufTypeString)
|
||||
binary.Write(&meta, binary.LittleEndian, uint64(vocab))
|
||||
for i := 0; i < vocab; i++ {
|
||||
wStr(&meta, "token")
|
||||
}
|
||||
|
||||
var out bytes.Buffer
|
||||
binary.Write(&out, binary.LittleEndian, gguf.GGUFMagicGGUFLe)
|
||||
binary.Write(&out, binary.LittleEndian, uint32(3)) // version
|
||||
binary.Write(&out, binary.LittleEndian, uint64(0)) // tensor count
|
||||
binary.Write(&out, binary.LittleEndian, uint64(9)) // metadata kv count
|
||||
out.Write(meta.Bytes())
|
||||
|
||||
return os.WriteFile(path, out.Bytes(), 0o644)
|
||||
}
|
||||
|
||||
var _ = Describe("Backend hooks and parser defaults", func() {
|
||||
Context("MatchParserDefaults", func() {
|
||||
It("matches Qwen3 family", func() {
|
||||
@@ -137,6 +200,58 @@ var _ = Describe("Backend hooks and parser defaults", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("llamaCppDefaults GGUF guessing", func() {
|
||||
// Regression coverage for https://github.com/mudler/LocalAI/issues/9790:
|
||||
// the hook reads GGUF headers with SkipLargeMetadata + UseMMap to avoid
|
||||
// pulling the whole tokenizer vocab off (slow) disk on every startup. This
|
||||
// verifies that skipping the vocab array still yields the correct guessed
|
||||
// defaults from the remaining scalar metadata.
|
||||
const chatTemplate = "{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}"
|
||||
|
||||
It("guesses defaults from a GGUF whose large vocab is skipped", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
modelFile := "repro.gguf"
|
||||
Expect(writeTestGGUF(filepath.Join(dir, modelFile), chatTemplate, 50000)).To(Succeed())
|
||||
|
||||
// A pre-set context size short-circuits the GGUF run-estimate, which
|
||||
// needs full tensor info this header-only fixture deliberately omits;
|
||||
// the metadata-reading path the optimization touches is unaffected.
|
||||
ctxSize := 4096
|
||||
cfg := &ModelConfig{
|
||||
Backend: "llama-cpp",
|
||||
LLMConfig: LLMConfig{ContextSize: &ctxSize},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: modelFile},
|
||||
},
|
||||
}
|
||||
cfg.SetDefaults(ModelPath(dir))
|
||||
|
||||
// chat_template is a scalar string, not part of the skipped array,
|
||||
// so it must be captured verbatim.
|
||||
Expect(cfg.GetModelTemplate()).To(Equal(chatTemplate))
|
||||
// scalar-derived defaults are still applied
|
||||
Expect(cfg.ContextSize).NotTo(BeNil())
|
||||
Expect(cfg.NGPULayers).NotTo(BeNil())
|
||||
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
|
||||
Expect(cfg.KnownUsecaseStrings).To(ContainElement("FLAG_CHAT"))
|
||||
})
|
||||
|
||||
It("falls back to the default context size when the GGUF is unreadable", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
Expect(os.WriteFile(filepath.Join(dir, "bad.gguf"), []byte("not a gguf"), 0o644)).To(Succeed())
|
||||
|
||||
cfg := &ModelConfig{
|
||||
Backend: "llama-cpp",
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: "bad.gguf"},
|
||||
},
|
||||
}
|
||||
cfg.SetDefaults(ModelPath(dir))
|
||||
|
||||
Expect(cfg.ContextSize).NotTo(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("PromptCacheAll default", func() {
|
||||
It("defaults to true when omitted from YAML", func() {
|
||||
cfg := &ModelConfig{}
|
||||
|
||||
@@ -128,6 +128,22 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Advanced: true,
|
||||
Order: 21,
|
||||
},
|
||||
"reasoning_effort": {
|
||||
Section: "llm",
|
||||
Label: "Reasoning Effort",
|
||||
Description: "Default reasoning effort, forwarded to the backend as the reasoning_effort chat_template_kwarg (jinja models like gpt-oss / LFM2.5 honor it). A per-request reasoning_effort overrides it. 'none' also turns thinking off.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "", Label: "Unset (model default)"},
|
||||
{Value: "none", Label: "none (disable thinking)"},
|
||||
{Value: "minimal", Label: "minimal"},
|
||||
{Value: "low", Label: "low"},
|
||||
{Value: "medium", Label: "medium"},
|
||||
{Value: "high", Label: "high"},
|
||||
},
|
||||
Advanced: true,
|
||||
Order: 22,
|
||||
},
|
||||
"cache_type_k": {
|
||||
Section: "llm",
|
||||
Label: "KV Cache Type (K)",
|
||||
@@ -277,6 +293,21 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
AutocompleteProvider: ProviderModelsVAD,
|
||||
Order: 63,
|
||||
},
|
||||
"pipeline.reasoning_effort": {
|
||||
Section: "pipeline",
|
||||
Label: "Reasoning Effort",
|
||||
Description: "Reasoning effort for the pipeline's LLM, forwarded to the backend as the reasoning_effort chat_template_kwarg (jinja models like gpt-oss / LFM2.5 honor it). Overrides the LLM model's own reasoning_effort. 'none' also turns thinking off.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "", Label: "Default (model config)"},
|
||||
{Value: "none", Label: "none (disable thinking)"},
|
||||
{Value: "minimal", Label: "minimal"},
|
||||
{Value: "low", Label: "low"},
|
||||
{Value: "medium", Label: "medium"},
|
||||
{Value: "high", Label: "high"},
|
||||
},
|
||||
Order: 64,
|
||||
},
|
||||
|
||||
// --- Functions ---
|
||||
"function.grammar.parallel_calls": {
|
||||
|
||||
@@ -63,6 +63,13 @@ type ModelConfig struct {
|
||||
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
|
||||
ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
|
||||
|
||||
// ReasoningEffort is the default reasoning effort (none|minimal|low|medium|high)
|
||||
// for this model. A per-request reasoning_effort overrides it. It is forwarded
|
||||
// to the backend as the reasoning_effort chat_template_kwarg (see
|
||||
// gRPCPredictOpts), so jinja-templated models that key on it — e.g. gpt-oss
|
||||
// (Harmony) or LFM2.5 — honor it; "none" also toggles enable_thinking off.
|
||||
ReasoningEffort string `yaml:"reasoning_effort,omitempty" json:"reasoning_effort,omitempty"`
|
||||
|
||||
FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
|
||||
// LLM configs (GPT4ALL, Llama.cpp, ...)
|
||||
LLMConfig `yaml:",inline" json:",inline"`
|
||||
@@ -487,6 +494,40 @@ type Pipeline struct {
|
||||
LLM string `yaml:"llm,omitempty" json:"llm,omitempty"`
|
||||
Transcription string `yaml:"transcription,omitempty" json:"transcription,omitempty"`
|
||||
VAD string `yaml:"vad,omitempty" json:"vad,omitempty"`
|
||||
|
||||
// ReasoningEffort sets the reasoning effort (none|minimal|low|medium|high) for
|
||||
// the pipeline's LLM without editing the LLM model config. Overrides the LLM's
|
||||
// own reasoning_effort. Unset leaves the LLM model config in charge.
|
||||
ReasoningEffort string `yaml:"reasoning_effort,omitempty" json:"reasoning_effort,omitempty"`
|
||||
}
|
||||
|
||||
// ApplyReasoningEffort resolves the effective reasoning effort — a per-request
|
||||
// value (requestEffort) overrides the config's own ReasoningEffort default —
|
||||
// stores it on the config so gRPCPredictOpts forwards it to the backend as the
|
||||
// reasoning_effort chat_template_kwarg, and maps it onto the enable_thinking
|
||||
// toggle the backend also reads:
|
||||
// - "none" always disables thinking.
|
||||
// - any explicit level enables it, UNLESS the config already disabled reasoning
|
||||
// (an operator's explicit disable wins over a request asking to think).
|
||||
//
|
||||
// An empty requestEffort keeps the config's own default. With no effort set
|
||||
// anywhere it is a no-op, leaving the model's reasoning settings untouched.
|
||||
func (c *ModelConfig) ApplyReasoningEffort(requestEffort string) {
|
||||
effort := requestEffort
|
||||
if effort == "" {
|
||||
effort = c.ReasoningEffort
|
||||
}
|
||||
c.ReasoningEffort = effort
|
||||
switch strings.ToLower(effort) {
|
||||
case "none":
|
||||
disable := true
|
||||
c.ReasoningConfig.DisableReasoning = &disable
|
||||
case "minimal", "low", "medium", "high":
|
||||
if c.ReasoningConfig.DisableReasoning == nil || !*c.ReasoningConfig.DisableReasoning {
|
||||
enable := false
|
||||
c.ReasoningConfig.DisableReasoning = &enable
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// @Description File configuration for model downloads
|
||||
|
||||
@@ -30,11 +30,26 @@ func MTPSpecOptions() []string {
|
||||
return out
|
||||
}
|
||||
|
||||
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a Multi-Token
|
||||
// Prediction head. Detection reads `<arch>.nextn_predict_layers`, which is
|
||||
// what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
|
||||
// isDraftOnlyAssistantArch reports whether an architecture names a standalone
|
||||
// MTP *draft* model rather than a self-speculating trunk. Upstream's Gemma4 MTP
|
||||
// (ggml-org/llama.cpp#23398) registers the head as a separate `gemma4-assistant`
|
||||
// architecture whose GGUF still carries `nextn_predict_layers`, but which cannot
|
||||
// run alone: it requires a paired target context (`ctx_other`). Such archs must
|
||||
// not trigger the embedded-head self-speculation defaults. The `-assistant`
|
||||
// suffix is upstream's naming convention for these draft-only checkpoints.
|
||||
func isDraftOnlyAssistantArch(arch string) bool {
|
||||
return strings.HasSuffix(arch, "-assistant")
|
||||
}
|
||||
|
||||
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a self-speculating
|
||||
// Multi-Token Prediction head. Detection reads `<arch>.nextn_predict_layers`,
|
||||
// which is what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
|
||||
// `conversion/qwen.py` MTP mixin. A positive layer count means the head is
|
||||
// present in the same GGUF as the trunk.
|
||||
//
|
||||
// Draft-only assistant architectures (e.g. Gemma4's `gemma4-assistant`) carry
|
||||
// the same key but are separate draft checkpoints meant to be paired with a
|
||||
// target model, so they are deliberately excluded here.
|
||||
func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
|
||||
if f == nil {
|
||||
return 0, false
|
||||
@@ -43,6 +58,9 @@ func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
|
||||
if arch == "" {
|
||||
return 0, false
|
||||
}
|
||||
if isDraftOnlyAssistantArch(arch) {
|
||||
return 0, false
|
||||
}
|
||||
v, ok := f.Header.MetadataKV.Get(arch + ".nextn_predict_layers")
|
||||
if !ok {
|
||||
return 0, false
|
||||
|
||||
@@ -3,10 +3,33 @@ package config_test
|
||||
import (
|
||||
. "github.com/mudler/LocalAI/core/config"
|
||||
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// ggufWithArch fabricates a minimal in-memory GGUF carrying the given
|
||||
// `general.architecture` and a positive `<arch>.nextn_predict_layers` count,
|
||||
// so HasEmbeddedMTPHead can be exercised without a real model file.
|
||||
func ggufWithArch(arch string, nextn uint32) *gguf.GGUFFile {
|
||||
return &gguf.GGUFFile{
|
||||
Header: gguf.GGUFHeader{
|
||||
MetadataKV: gguf.GGUFMetadataKVs{
|
||||
{
|
||||
Key: "general.architecture",
|
||||
ValueType: gguf.GGUFMetadataValueTypeString,
|
||||
Value: arch,
|
||||
},
|
||||
{
|
||||
Key: arch + ".nextn_predict_layers",
|
||||
ValueType: gguf.GGUFMetadataValueTypeUint32,
|
||||
Value: nextn,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("MTP auto-defaults", func() {
|
||||
Context("MTPSpecOptions", func() {
|
||||
It("returns the upstream-recommended speculative tuple", func() {
|
||||
@@ -82,5 +105,20 @@ var _ = Describe("MTP auto-defaults", func() {
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(n).To(BeZero())
|
||||
})
|
||||
|
||||
It("detects a same-GGUF embedded head (DeepSeek/Qwen style)", func() {
|
||||
n, ok := HasEmbeddedMTPHead(ggufWithArch("qwen3moe", 1))
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(n).To(Equal(uint32(1)))
|
||||
})
|
||||
|
||||
It("ignores a gemma4-assistant draft-only model", func() {
|
||||
// The assistant GGUF carries nextn_predict_layers but is a separate
|
||||
// draft model that requires a paired target (ctx_other); it cannot
|
||||
// self-speculate, so it must not trigger the embedded-head defaults.
|
||||
n, ok := HasEmbeddedMTPHead(ggufWithArch("gemma4-assistant", 48))
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(n).To(BeZero())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
52
core/config/reasoning_effort_test.go
Normal file
52
core/config/reasoning_effort_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
// ApplyReasoningEffort resolves the effective reasoning effort (request value
|
||||
// overrides the model config default), stores it on the config so it reaches the
|
||||
// backend, and maps it onto the enable_thinking toggle.
|
||||
var _ = Describe("ModelConfig.ApplyReasoningEffort", func() {
|
||||
It("uses the request value over the config default", func() {
|
||||
c := &config.ModelConfig{ReasoningEffort: "high"}
|
||||
c.ApplyReasoningEffort("none")
|
||||
Expect(c.ReasoningEffort).To(Equal("none"))
|
||||
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
|
||||
})
|
||||
|
||||
It("falls back to the config default when the request omits it", func() {
|
||||
c := &config.ModelConfig{ReasoningEffort: "none"}
|
||||
c.ApplyReasoningEffort("")
|
||||
Expect(c.ReasoningEffort).To(Equal("none"))
|
||||
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
|
||||
})
|
||||
|
||||
It("enables thinking for an explicit effort level", func() {
|
||||
c := &config.ModelConfig{}
|
||||
c.ApplyReasoningEffort("medium")
|
||||
Expect(c.ReasoningEffort).To(Equal("medium"))
|
||||
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*c.ReasoningConfig.DisableReasoning).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does not let a level override an operator's config-level disable", func() {
|
||||
disabled := true
|
||||
c := &config.ModelConfig{}
|
||||
c.ReasoningConfig.DisableReasoning = &disabled
|
||||
c.ApplyReasoningEffort("high")
|
||||
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is a no-op on the toggle when no effort is set anywhere", func() {
|
||||
c := &config.ModelConfig{}
|
||||
c.ApplyReasoningEffort("")
|
||||
Expect(c.ReasoningEffort).To(Equal(""))
|
||||
Expect(c.ReasoningConfig.DisableReasoning).To(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -420,8 +420,9 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
remoteUnloader = d.Router.Unloader()
|
||||
}
|
||||
}
|
||||
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
|
||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
|
||||
natsCfg := distCfg.NatsAuthConfig()
|
||||
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, natsCfg)
|
||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken, natsCfg)
|
||||
|
||||
// Distributed SSE routes (job progress + agent events via NATS)
|
||||
if d := application.Distributed(); d != nil {
|
||||
|
||||
@@ -37,7 +37,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
|
||||
xlog.Debug("elevenlabs TTS request received", "modelName", input.ModelID)
|
||||
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, "", nil, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
)
|
||||
|
||||
// nodeError builds a schema.ErrorResponse for node endpoints.
|
||||
@@ -89,7 +90,7 @@ type RegisterNodeRequest struct {
|
||||
// RegisterNodeEndpoint registers a new backend node.
|
||||
// expectedToken is the registration token configured on the frontend (may be empty to disable auth).
|
||||
// autoApprove controls whether new nodes go directly to "healthy" or require admin approval.
|
||||
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
|
||||
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req RegisterNodeRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
@@ -217,13 +218,15 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au
|
||||
}
|
||||
}
|
||||
|
||||
attachNatsJWT(response, node, natsCfg)
|
||||
|
||||
return c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
}
|
||||
|
||||
// ApproveNodeEndpoint approves a pending node, setting its status to healthy.
|
||||
// For agent workers, it also provisions an API key so they can call the inference API.
|
||||
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
|
||||
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
id := c.Param("id")
|
||||
@@ -253,10 +256,26 @@ func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecr
|
||||
}
|
||||
}
|
||||
|
||||
attachNatsJWT(response, node, natsCfg)
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
}
|
||||
|
||||
// attachNatsJWT adds a per-node NATS user JWT to a register/approve response when minting is enabled.
|
||||
func attachNatsJWT(response map[string]any, node *nodes.BackendNode, natsCfg natsauth.Config) {
|
||||
if !natsCfg.CanMintWorkers() || node == nil || node.Status == nodes.StatusPending {
|
||||
return
|
||||
}
|
||||
jwt, seed, err := natsCfg.MintWorkerJWT(node.ID, node.NodeType)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to mint NATS JWT for node", "node", node.Name, "id", node.ID, "error", err)
|
||||
return
|
||||
}
|
||||
response["nats_jwt"] = jwt
|
||||
response["nats_user_seed"] = seed
|
||||
}
|
||||
|
||||
// provisionAgentWorkerKey creates a dedicated user and API key for an agent worker node.
|
||||
// Returns the plaintext API key on success.
|
||||
func provisionAgentWorkerKey(ctx context.Context, authDB *gorm.DB, registry *nodes.NodeRegistry, node *nodes.BackendNode, hmacSecret string) (string, error) {
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/nats-io/nkeys"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -63,7 +65,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
@@ -74,6 +76,29 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
Expect(resp["status"]).To(Equal(nodes.StatusHealthy))
|
||||
})
|
||||
|
||||
It("returns nats_jwt when account seed is configured", func() {
|
||||
akp, err := nkeys.CreateAccount()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
seed, err := akp.Seed()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
e := echo.New()
|
||||
body := `{"name":"worker-nats","address":"10.0.0.2:50051"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
natsCfg := natsauth.Config{AccountSeed: string(seed)}
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsCfg)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp["nats_jwt"]).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns 400 when name is missing", func() {
|
||||
e := echo.New()
|
||||
body := `{"address":"10.0.0.1:50051"}`
|
||||
@@ -82,7 +107,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -102,7 +127,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -121,7 +146,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -140,7 +165,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -159,7 +184,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
|
||||
})
|
||||
@@ -172,7 +197,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", false, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", false, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
@@ -195,7 +220,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body1))
|
||||
req1.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(e.NewContext(req1, rec1))).To(Succeed())
|
||||
Expect(rec1.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Stream audio chunks as they're generated
|
||||
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
|
||||
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, input.Instructions, input.Params, ml, appConfig, *cfg, func(audioChunk []byte) error {
|
||||
_, writeErr := c.Response().Write(audioChunk)
|
||||
if writeErr != nil {
|
||||
return writeErr
|
||||
@@ -75,7 +75,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
}
|
||||
|
||||
// Non-streaming TTS (existing behavior)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, input.Instructions, input.Params, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -44,10 +44,10 @@ type wrappedModel struct {
|
||||
// deps in. nil-safe: with classifierRegistry == nil the per-turn
|
||||
// routing block in Predict is skipped, preserving today's "one LLM
|
||||
// for the whole session" behaviour.
|
||||
routerDeps *middleware.ClassifierDeps
|
||||
routerStore router.DecisionStore
|
||||
routerSessionID string
|
||||
routerUserID string
|
||||
routerDeps *middleware.ClassifierDeps
|
||||
routerStore router.DecisionStore
|
||||
routerSessionID string
|
||||
routerUserID string
|
||||
}
|
||||
|
||||
// anyToAnyModel represent a model which supports Any-to-Any operations
|
||||
@@ -119,6 +119,11 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im
|
||||
}
|
||||
}
|
||||
|
||||
// Surface the resolved reasoning effort to the Go-side template path too
|
||||
// (jinja models get it via backend metadata in gRPCPredictOpts; Go-templated
|
||||
// models like gpt-oss read it from the template's .ReasoningEffort).
|
||||
input.ReasoningEffort = turnCfg.ReasoningEffort
|
||||
|
||||
var predInput string
|
||||
var funcs []functions.Function
|
||||
if !turnCfg.TemplateConfig.UseTokenizerTemplate {
|
||||
@@ -313,7 +318,7 @@ func newRealtimeDecisionID() string {
|
||||
}
|
||||
|
||||
func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) {
|
||||
return backend.ModelTTS(ctx, text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig)
|
||||
return backend.ModelTTS(ctx, text, voice, language, "", nil, m.modelLoader, m.appConfig, *m.TTSConfig)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) PredictConfig() *config.ModelConfig {
|
||||
@@ -449,6 +454,9 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
// Let the pipeline set the LLM's reasoning effort (cfgLLM is a per-session copy).
|
||||
applyPipelineReasoning(cfgLLM, *pipeline)
|
||||
|
||||
cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
|
||||
16
core/http/endpoints/openai/realtime_reasoning.go
Normal file
16
core/http/endpoints/openai/realtime_reasoning.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package openai
|
||||
|
||||
import "github.com/mudler/LocalAI/core/config"
|
||||
|
||||
// applyPipelineReasoning sets the reasoning effort for a realtime pipeline's LLM
|
||||
// from the pipeline config, without editing the underlying LLM model config. The
|
||||
// pipeline value overrides the LLM's own reasoning_effort; when the pipeline does
|
||||
// not set it, the LLM model config's reasoning_effort (if any) is used. The LLM
|
||||
// config passed in is the per-session copy returned by the config loader, so this
|
||||
// does not affect other users of the same model.
|
||||
func applyPipelineReasoning(llm *config.ModelConfig, pipeline config.Pipeline) {
|
||||
if llm == nil {
|
||||
return
|
||||
}
|
||||
llm.ApplyReasoningEffort(pipeline.ReasoningEffort)
|
||||
}
|
||||
33
core/http/endpoints/openai/realtime_reasoning_test.go
Normal file
33
core/http/endpoints/openai/realtime_reasoning_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
// applyPipelineReasoning lets a realtime pipeline set the reasoning effort for
|
||||
// its LLM (forwarded to the backend as reasoning_effort) without editing the LLM
|
||||
// model config. The pipeline value overrides the LLM's own reasoning_effort.
|
||||
var _ = Describe("applyPipelineReasoning", func() {
|
||||
It("applies the pipeline reasoning_effort to the LLM config", func() {
|
||||
llm := &config.ModelConfig{}
|
||||
applyPipelineReasoning(llm, config.Pipeline{ReasoningEffort: "none"})
|
||||
Expect(llm.ReasoningEffort).To(Equal("none"))
|
||||
Expect(llm.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*llm.ReasoningConfig.DisableReasoning).To(BeTrue())
|
||||
})
|
||||
|
||||
It("falls back to the LLM's own reasoning_effort when the pipeline is unset", func() {
|
||||
llm := &config.ModelConfig{ReasoningEffort: "high"}
|
||||
applyPipelineReasoning(llm, config.Pipeline{})
|
||||
Expect(llm.ReasoningEffort).To(Equal("high"))
|
||||
Expect(llm.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*llm.ReasoningConfig.DisableReasoning).To(BeFalse())
|
||||
})
|
||||
|
||||
It("is nil-safe", func() {
|
||||
applyPipelineReasoning(nil, config.Pipeline{ReasoningEffort: "low"})
|
||||
})
|
||||
})
|
||||
@@ -310,25 +310,13 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
// Map the per-request reasoning_effort onto the reasoning toggle the
|
||||
// backend reads (enable_thinking metadata, set in gRPCPredictOpts).
|
||||
// "none" disables thinking for this request - the use case from #10072,
|
||||
// running a single Qwen3-style model and turning reasoning off per
|
||||
// request. Any explicit effort level enables thinking, UNLESS the model
|
||||
// config explicitly disabled it (DisableReasoning==true wins): an
|
||||
// operator who deliberately turned reasoning off should not be overridden
|
||||
// by a request. A value of "none" always disables, since that never
|
||||
// conflicts with a config that also disables.
|
||||
switch strings.ToLower(input.ReasoningEffort) {
|
||||
case "none":
|
||||
disable := true
|
||||
config.ReasoningConfig.DisableReasoning = &disable
|
||||
case "minimal", "low", "medium", "high":
|
||||
if config.ReasoningConfig.DisableReasoning == nil || !*config.ReasoningConfig.DisableReasoning {
|
||||
enable := false
|
||||
config.ReasoningConfig.DisableReasoning = &enable
|
||||
}
|
||||
}
|
||||
// Resolve the effective reasoning effort (request overrides the model config
|
||||
// default), store it so gRPCPredictOpts forwards it to the backend as the
|
||||
// reasoning_effort chat_template_kwarg (what gpt-oss / LFM2.5 read), and map
|
||||
// it onto the enable_thinking toggle. "none" disables thinking (the #10072
|
||||
// use case); a level enables it unless the config already disabled reasoning
|
||||
// (an operator's explicit disable wins over a request asking to think).
|
||||
config.ApplyReasoningEffort(input.ReasoningEffort)
|
||||
|
||||
// Collapse the modern max_completion_tokens alias into the
|
||||
// legacy Maxtokens field so downstream code reads exactly one.
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -35,7 +36,7 @@ func nodeReadyMiddleware(registry *nodes.NodeRegistry) echo.MiddlewareFunc {
|
||||
// token but do not verify per-node identity. A compromised worker can heartbeat/drain/
|
||||
// deregister other nodes. Future: issue per-node JWT at registration, validate node
|
||||
// identity on subsequent requests (compare :id param with token subject).
|
||||
func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, registrationToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) {
|
||||
func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, registrationToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) {
|
||||
if registry == nil {
|
||||
return
|
||||
}
|
||||
@@ -44,7 +45,7 @@ func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, r
|
||||
tokenAuthMw := nodeTokenAuth(registrationToken)
|
||||
|
||||
node := e.Group("/api/node", readyMw, tokenAuthMw)
|
||||
node.POST("/register", localai.RegisterNodeEndpoint(registry, registrationToken, autoApprove, authDB, hmacSecret))
|
||||
node.POST("/register", localai.RegisterNodeEndpoint(registry, registrationToken, autoApprove, authDB, hmacSecret, natsCfg))
|
||||
node.POST("/:id/heartbeat", localai.HeartbeatEndpoint(registry))
|
||||
node.POST("/:id/drain", localai.DrainNodeEndpoint(registry))
|
||||
node.POST("/:id/resume", localai.ResumeNodeEndpoint(registry))
|
||||
@@ -60,7 +61,7 @@ func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, r
|
||||
// backend install path (POST /:id/backends/install). That handler enqueues a
|
||||
// ManagementOp on the gallery channel rather than blocking on a NATS reply, so
|
||||
// the browser gets HTTP 202 + jobID immediately instead of waiting up to 3 minutes.
|
||||
func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string) {
|
||||
func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string, natsCfg natsauth.Config) {
|
||||
if registry == nil {
|
||||
return
|
||||
}
|
||||
@@ -81,7 +82,7 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
|
||||
admin.DELETE("/:id", localai.DeregisterNodeEndpoint(registry))
|
||||
admin.POST("/:id/drain", localai.DrainNodeEndpoint(registry))
|
||||
admin.POST("/:id/resume", localai.ResumeNodeEndpoint(registry))
|
||||
admin.POST("/:id/approve", localai.ApproveNodeEndpoint(registry, authDB, hmacSecret))
|
||||
admin.POST("/:id/approve", localai.ApproveNodeEndpoint(registry, authDB, hmacSecret, natsCfg))
|
||||
|
||||
// Backend management on workers
|
||||
admin.GET("/:id/backends", localai.ListBackendsOnNodeEndpoint(unloader))
|
||||
|
||||
@@ -60,6 +60,14 @@ type TTSRequest struct {
|
||||
Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
|
||||
Stream bool `json:"stream,omitempty" yaml:"stream,omitempty"` // (optional) enable streaming TTS
|
||||
SampleRate int `json:"sample_rate,omitempty" yaml:"sample_rate,omitempty"` // (optional) desired output sample rate
|
||||
// Instructions is a free-form, per-request style/voice description. It maps to
|
||||
// the OpenAI `instructions` field and is forwarded to the backend so expressive
|
||||
// TTS models (e.g. Qwen3-TTS CustomVoice/VoiceDesign) can vary tone or designed
|
||||
// voice per request instead of only via the static YAML option.
|
||||
Instructions string `json:"instructions,omitempty" yaml:"instructions,omitempty"`
|
||||
// Params carries optional, backend-specific per-request generation parameters
|
||||
// (LocalAI extension, e.g. Chatterbox exaggeration/cfg_weight/temperature).
|
||||
Params map[string]string `json:"params,omitempty" yaml:"params,omitempty"`
|
||||
}
|
||||
|
||||
// @Description VAD request body
|
||||
|
||||
@@ -180,18 +180,21 @@ func (s *GalleryStore) Cancel(id string) error {
|
||||
return s.UpdateStatus(id, "cancelled", "")
|
||||
}
|
||||
|
||||
// CleanStale marks abandoned in-progress operations as failed.
|
||||
// Should be called on startup to recover from crashed instances that
|
||||
// left records in pending/downloading/processing state.
|
||||
func (s *GalleryStore) CleanStale(age time.Duration) error {
|
||||
// CleanStale marks abandoned in-progress operations as failed and returns the
|
||||
// number of rows reaped. Called on startup AND periodically to recover from
|
||||
// crashed/restarted instances that left records in pending/downloading/
|
||||
// processing state — an op orphaned after startup would otherwise linger
|
||||
// "processing" until the next restart.
|
||||
func (s *GalleryStore) CleanStale(age time.Duration) (int64, error) {
|
||||
cutoff := time.Now().Add(-age)
|
||||
return s.db.Model(&GalleryOperationRecord{}).
|
||||
res := s.db.Model(&GalleryOperationRecord{}).
|
||||
Where("updated_at < ? AND status IN ?", cutoff, activeStatuses).
|
||||
Updates(map[string]any{
|
||||
"status": "failed",
|
||||
"error": "stale operation cleaned up on startup",
|
||||
"error": "stale operation reaped (abandoned by a crashed or restarted instance)",
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
})
|
||||
return res.RowsAffected, res.Error
|
||||
}
|
||||
|
||||
// CleanOld removes operations older than the given duration.
|
||||
|
||||
@@ -71,7 +71,7 @@ func (g *GalleryService) backendHandler(op *ManagementOp[gallery.GalleryBackend,
|
||||
|
||||
var err error
|
||||
if op.Upgrade {
|
||||
err = g.backendManager.UpgradeBackend(ctx, op.GalleryElementName, progressCallback)
|
||||
err = g.backendManager.UpgradeBackend(ctx, op.ID, op.GalleryElementName, progressCallback)
|
||||
} else if op.Delete {
|
||||
err = g.backendManager.DeleteBackend(op.GalleryElementName)
|
||||
} else {
|
||||
|
||||
106
core/services/galleryop/cancel_persist_test.go
Normal file
106
core/services/galleryop/cancel_persist_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package galleryop_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/distributed"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
)
|
||||
|
||||
// Reproduces "a cancelled/orphaned op resurrects as 'processing' after a pod
|
||||
// restart". CancelOperation flipped the in-memory status to cancelled and
|
||||
// broadcast a NATS event, but never persisted the terminal status to the
|
||||
// gallery store. On the next replica restart the still-"pending" row hydrated
|
||||
// straight back into processingBackends and the UI spun again. CancelOperation
|
||||
// must persist the cancellation so it survives a restart.
|
||||
var _ = Describe("GalleryService.CancelOperation persistence", func() {
|
||||
It("persists the cancelled status to the gallery store", func() {
|
||||
db := testutil.SetupTestDB()
|
||||
store, err := distributed.NewGalleryStore(db)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Seed an in-flight op as if a replica was mid-install.
|
||||
Expect(store.Create(&distributed.GalleryOperationRecord{
|
||||
ID: "op-cancel",
|
||||
GalleryElementName: "llama-cpp-development",
|
||||
OpType: "backend_install",
|
||||
Status: "pending",
|
||||
Progress: 0,
|
||||
})).To(Succeed())
|
||||
|
||||
svc := galleryop.NewGalleryService(&config.ApplicationConfig{}, nil)
|
||||
svc.SetGalleryStore(store)
|
||||
// Make the op locally cancellable so CancelOperation proceeds.
|
||||
svc.StoreCancellation("op-cancel", context.CancelFunc(func() {}))
|
||||
|
||||
Expect(svc.CancelOperation("op-cancel")).To(Succeed())
|
||||
|
||||
// The persisted row must now be terminal — otherwise it re-hydrates as
|
||||
// pending on the next restart.
|
||||
rec, err := store.Get("op-cancel")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rec.Status).To(Equal("cancelled"))
|
||||
|
||||
// And a fresh service hydrating from the store must NOT see it as active.
|
||||
fresh := galleryop.NewGalleryService(&config.ApplicationConfig{}, nil)
|
||||
fresh.SetGalleryStore(store)
|
||||
Expect(fresh.Hydrate()).To(Succeed())
|
||||
Expect(fresh.GetStatus("op-cancel")).To(BeNil(),
|
||||
"a cancelled op must not hydrate back as active after a restart")
|
||||
})
|
||||
})
|
||||
|
||||
// Reproduces "an op orphaned by a replica that died mid-flight stays 'pending'
|
||||
// forever". CleanStale (which marks abandoned active ops failed) only ran once
|
||||
// on startup, so an op orphaned AFTER startup was never reaped until the next
|
||||
// restart. The service must reap stale ops on an interval, not just at boot.
|
||||
var _ = Describe("GalleryService.ReapStaleOperations", func() {
|
||||
It("marks abandoned active ops terminal once they pass the age cutoff", func() {
|
||||
db := testutil.SetupTestDB()
|
||||
store, err := distributed.NewGalleryStore(db)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(store.Create(&distributed.GalleryOperationRecord{
|
||||
ID: "orphan-op",
|
||||
GalleryElementName: "llama-cpp-development",
|
||||
OpType: "backend_install",
|
||||
Status: "pending",
|
||||
Progress: 0,
|
||||
})).To(Succeed())
|
||||
// Force the row's updated_at into the past so it is older than the cutoff.
|
||||
Expect(db.Exec(
|
||||
"UPDATE gallery_operations SET updated_at = ? WHERE id = ?",
|
||||
time.Now().Add(-1*time.Hour), "orphan-op",
|
||||
).Error).To(Succeed())
|
||||
|
||||
// A fresh, still-progressing op must NOT be reaped.
|
||||
Expect(store.Create(&distributed.GalleryOperationRecord{
|
||||
ID: "live-op",
|
||||
GalleryElementName: "vllm-development",
|
||||
OpType: "backend_install",
|
||||
Status: "downloading",
|
||||
Progress: 50,
|
||||
})).To(Succeed())
|
||||
|
||||
svc := galleryop.NewGalleryService(&config.ApplicationConfig{}, nil)
|
||||
svc.SetGalleryStore(store)
|
||||
|
||||
reaped, err := svc.ReapStaleOperations(30 * time.Minute)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(reaped).To(Equal(int64(1)))
|
||||
|
||||
orphan, err := store.Get("orphan-op")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(orphan.Status).To(Equal("failed"))
|
||||
|
||||
live, err := store.Get("live-op")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(live.Status).To(Equal("downloading"), "a recently-updated op must not be reaped")
|
||||
})
|
||||
})
|
||||
@@ -20,7 +20,7 @@ type BackendManager interface {
|
||||
InstallBackend(ctx context.Context, op *ManagementOp[gallery.GalleryBackend, any], progressCb ProgressCallback) error
|
||||
DeleteBackend(name string) error
|
||||
ListBackends() (gallery.SystemBackends, error)
|
||||
UpgradeBackend(ctx context.Context, name string, progressCb ProgressCallback) error
|
||||
UpgradeBackend(ctx context.Context, opID, name string, progressCb ProgressCallback) error
|
||||
CheckUpgrades(ctx context.Context) (map[string]gallery.UpgradeInfo, error)
|
||||
// IsDistributed reports whether installs fan out across worker nodes.
|
||||
// The HTTP layer uses this to refuse hardware-specific (non-meta) installs
|
||||
|
||||
@@ -96,7 +96,10 @@ func (b *LocalBackendManager) ListBackends() (gallery.SystemBackends, error) {
|
||||
return gallery.ListSystemBackends(b.systemState)
|
||||
}
|
||||
|
||||
func (b *LocalBackendManager) UpgradeBackend(ctx context.Context, name string, progressCb ProgressCallback) error {
|
||||
// UpgradeBackend ignores opID: a single-node install reports progress through
|
||||
// the local progressCb already; opID only matters for distributed per-node
|
||||
// streaming (see DistributedBackendManager.UpgradeBackend).
|
||||
func (b *LocalBackendManager) UpgradeBackend(ctx context.Context, _ string, name string, progressCb ProgressCallback) error {
|
||||
return gallery.UpgradeBackend(ctx, b.systemState, b.modelLoader, b.backendGalleries, name, progressCb, b.requireBackendIntegrity)
|
||||
}
|
||||
|
||||
|
||||
92
core/services/galleryop/opcache_evict_test.go
Normal file
92
core/services/galleryop/opcache_evict_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package galleryop_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
)
|
||||
|
||||
// These specs reproduce the distributed "Reinstall spins forever" bug:
|
||||
// processingBackends (the UI spinner source) is built from OpCache.GetStatus,
|
||||
// which historically returned every cached op unconditionally. Cleanup only
|
||||
// happened when a client polled /api/backends/job/:uid, but the Manage-page
|
||||
// Reinstall/Upgrade buttons never poll, so a completed install stayed in
|
||||
// processingBackends forever. GetStatus must self-evict terminal ops.
|
||||
var _ = Describe("OpCache.GetStatus eviction", func() {
|
||||
var (
|
||||
svc *galleryop.GalleryService
|
||||
cache *galleryop.OpCache
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
svc = galleryop.NewGalleryService(&config.ApplicationConfig{}, nil)
|
||||
cache = galleryop.NewOpCache(svc)
|
||||
})
|
||||
|
||||
It("keeps an op that is still processing", func() {
|
||||
cache.SetBackend("llama-cpp", "uuid-1")
|
||||
svc.UpdateStatus("uuid-1", &galleryop.OpStatus{Message: "processing backend: llama-cpp", Progress: 0})
|
||||
processing, _ := cache.GetStatus()
|
||||
Expect(processing).To(HaveKeyWithValue("llama-cpp", "uuid-1"))
|
||||
Expect(cache.Exists("llama-cpp")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("evicts a completed op so it no longer shows as processing", func() {
|
||||
cache.SetBackend("llama-cpp", "uuid-1")
|
||||
svc.UpdateStatus("uuid-1", &galleryop.OpStatus{Processed: true, Progress: 100, Message: "completed"})
|
||||
processing, _ := cache.GetStatus()
|
||||
Expect(processing).NotTo(HaveKey("llama-cpp"))
|
||||
Expect(cache.Exists("llama-cpp")).To(BeFalse())
|
||||
})
|
||||
|
||||
It("keeps a failed op so the operations panel can surface the error and offer Dismiss", func() {
|
||||
cache.SetBackend("piper", "uuid-2")
|
||||
svc.UpdateStatus("uuid-2", &galleryop.OpStatus{Processed: true, Error: errors.New("boom")})
|
||||
processing, _ := cache.GetStatus()
|
||||
Expect(processing).To(HaveKeyWithValue("piper", "uuid-2"))
|
||||
Expect(cache.Exists("piper")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("evicts a cancelled op", func() {
|
||||
cache.SetBackend("vllm", "uuid-3")
|
||||
svc.UpdateStatus("uuid-3", &galleryop.OpStatus{Processed: true, Cancelled: true, Message: "cancelled"})
|
||||
processing, _ := cache.GetStatus()
|
||||
Expect(processing).NotTo(HaveKey("vllm"))
|
||||
})
|
||||
|
||||
It("does not evict an op with no status yet (queued)", func() {
|
||||
cache.SetBackend("whisper", "uuid-4")
|
||||
processing, taskTypes := cache.GetStatus()
|
||||
Expect(processing).To(HaveKeyWithValue("whisper", "uuid-4"))
|
||||
Expect(taskTypes).To(HaveKeyWithValue("whisper", "Waiting"))
|
||||
})
|
||||
|
||||
// Regression guard: GetStatus is called concurrently by four HTTP handlers
|
||||
// (~1s poll). An earlier version evicted by deleting from m.Map() — which
|
||||
// returns the live internal map by reference — causing a fatal
|
||||
// "concurrent map writes" crash. Run under -race; must not panic or race.
|
||||
It("is safe under concurrent GetStatus + Set/complete", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for i := 0; i < 2000; i++ {
|
||||
_, _ = cache.GetStatus()
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
for i := 0; i < 2000; i++ {
|
||||
id := "uuid-c"
|
||||
cache.SetBackend("concurrent-backend", id)
|
||||
// Half the time mark it completed so GetStatus evicts it.
|
||||
if i%2 == 0 {
|
||||
svc.UpdateStatus(id, &galleryop.OpStatus{Processed: true, Progress: 100, Message: "completed"})
|
||||
}
|
||||
_, _ = cache.GetStatus()
|
||||
}
|
||||
<-done
|
||||
})
|
||||
})
|
||||
@@ -408,12 +408,34 @@ func (m *OpCache) Exists(key string) bool {
|
||||
}
|
||||
|
||||
func (m *OpCache) GetStatus() (map[string]string, map[string]string) {
|
||||
processingModelsData := m.Map()
|
||||
|
||||
taskTypes := map[string]string{}
|
||||
processingModelsData := map[string]string{}
|
||||
|
||||
for k, v := range processingModelsData {
|
||||
// Iterate a snapshot (Keys() copies) and build a fresh result map. We must
|
||||
// NOT delete from m.Map() during the range: Map() returns the live internal
|
||||
// map by reference, so a bare delete here would be an unsynchronized write
|
||||
// to a map four HTTP handlers read every ~1s — a concurrent-map-write crash.
|
||||
// Collect evictions and apply them via the locked DeleteUUID after the loop.
|
||||
var evict []string
|
||||
for _, k := range m.status.Keys() {
|
||||
v := m.status.Get(k)
|
||||
if v == "" {
|
||||
continue // raced with a concurrent Delete
|
||||
}
|
||||
status := m.galleryService.GetStatus(v)
|
||||
// Terminal ops must not keep showing as "processing". Cleanup was
|
||||
// previously only triggered by a client polling /api/backends/job/:uid,
|
||||
// but the Manage-page Reinstall/Upgrade buttons never poll, so completed
|
||||
// ops leaked into processingBackends forever and the card spun
|
||||
// "reinstalling" indefinitely. Evict here on the list read (the UI always
|
||||
// calls this). We only evict SUCCESS/cancelled terminals (Error == nil):
|
||||
// failed ops are kept so /api/operations can surface the error and offer
|
||||
// Dismiss. DeleteUUID broadcasts the eviction so peer replicas converge.
|
||||
if status != nil && status.Processed && status.Error == nil {
|
||||
evict = append(evict, v)
|
||||
continue
|
||||
}
|
||||
processingModelsData[k] = v
|
||||
taskTypes[k] = "Installation"
|
||||
if status != nil && status.Deletion {
|
||||
taskTypes[k] = "Deletion"
|
||||
@@ -422,6 +444,10 @@ func (m *OpCache) GetStatus() (map[string]string, map[string]string) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range evict {
|
||||
m.DeleteUUID(v)
|
||||
}
|
||||
|
||||
return processingModelsData, taskTypes
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
@@ -31,9 +32,9 @@ type GalleryService struct {
|
||||
// natsClient is the wider MessagingClient (Publisher + subscribe methods)
|
||||
// when wired by the distributed startup path; broadcastSubs holds the
|
||||
// progress + cancel subscriptions opened by SubscribeBroadcasts.
|
||||
natsClient messaging.MessagingClient
|
||||
galleryStore *distributed.GalleryStore
|
||||
broadcastSubs []messaging.Subscription
|
||||
natsClient messaging.MessagingClient
|
||||
galleryStore *distributed.GalleryStore
|
||||
broadcastSubs []messaging.Subscription
|
||||
|
||||
// OnBackendOpCompleted is fired after every successful install/upgrade/delete
|
||||
// on the backend channel. The Application wires this to UpgradeChecker.TriggerCheck
|
||||
@@ -274,6 +275,29 @@ func (g *GalleryService) GetAllStatus() map[string]*OpStatus {
|
||||
return g.statuses
|
||||
}
|
||||
|
||||
// ReapStaleOperations marks abandoned in-progress operations (pending/
|
||||
// downloading/processing) older than `age` as failed, so an op orphaned by a
|
||||
// replica that died mid-flight does not linger as "processing" forever. The
|
||||
// store's CleanStale runs once on startup; this exposes it for periodic
|
||||
// invocation (a post-startup orphan is otherwise not reaped until the next
|
||||
// restart). No-op when no gallery store is wired. Returns rows reaped.
|
||||
func (g *GalleryService) ReapStaleOperations(age time.Duration) (int64, error) {
|
||||
g.Lock()
|
||||
store := g.galleryStore
|
||||
g.Unlock()
|
||||
if store == nil {
|
||||
return 0, nil
|
||||
}
|
||||
n, err := store.CleanStale(age)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n > 0 {
|
||||
xlog.Info("Reaped stale gallery operations", "count", n)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// CancelOperation cancels an in-progress operation by its ID.
|
||||
//
|
||||
// In distributed mode the UI's cancel click may land on a different replica
|
||||
@@ -295,6 +319,7 @@ func (g *GalleryService) CancelOperation(id string) error {
|
||||
}
|
||||
|
||||
nc := g.natsClient
|
||||
store := g.galleryStore
|
||||
|
||||
if !localExists && nc == nil {
|
||||
g.Unlock()
|
||||
@@ -315,6 +340,17 @@ func (g *GalleryService) CancelOperation(id string) error {
|
||||
}
|
||||
g.Unlock()
|
||||
|
||||
// Persist the terminal status so the cancel survives a restart. Without
|
||||
// this the row stays in its active state and re-hydrates straight back into
|
||||
// processingBackends on the next replica boot — the UI spins again on an op
|
||||
// the admin already cancelled. The peer that broadcasts wins the write; a
|
||||
// no-op when standalone (store nil).
|
||||
if store != nil {
|
||||
if err := store.Cancel(id); err != nil {
|
||||
xlog.Warn("Failed to persist gallery operation cancellation", "op_id", id, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// I/O and user-provided callback after Unlock — the cancel-wildcard
|
||||
// subscriber loops back into applyCancel on this same replica, which
|
||||
// would otherwise deadlock on g.Mutex.
|
||||
|
||||
@@ -2,15 +2,22 @@ package messaging
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nkeys"
|
||||
)
|
||||
|
||||
// subscribeConfirmTimeout bounds the server round-trip used to detect whether a
|
||||
// subscription was rejected (e.g. by JWT permissions) before returning to the caller.
|
||||
const subscribeConfirmTimeout = 5 * time.Second
|
||||
|
||||
// Client wraps a NATS connection and provides helpers for pub/sub and queue subscriptions.
|
||||
type Client struct {
|
||||
conn *nats.Conn
|
||||
@@ -18,8 +25,13 @@ type Client struct {
|
||||
}
|
||||
|
||||
// New creates a new NATS client with auto-reconnect.
|
||||
func New(url string) (*Client, error) {
|
||||
nc, err := nats.Connect(url,
|
||||
func New(url string, opts ...Option) (*Client, error) {
|
||||
var cfg connectConfig
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
|
||||
natsOpts := []nats.Option{
|
||||
nats.RetryOnFailedConnect(true),
|
||||
nats.MaxReconnects(-1),
|
||||
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
|
||||
@@ -33,7 +45,60 @@ func New(url string) (*Client, error) {
|
||||
nats.ClosedHandler(func(_ *nats.Conn) {
|
||||
xlog.Info("NATS connection closed")
|
||||
}),
|
||||
)
|
||||
// Surface async errors (notably permission violations) that NATS would
|
||||
// otherwise deliver silently. A subscription the server rejects for a
|
||||
// JWT permission means the worker never receives those messages, so make
|
||||
// it loud rather than letting the feature fail invisibly.
|
||||
nats.ErrorHandler(func(_ *nats.Conn, sub *nats.Subscription, err error) {
|
||||
subject := ""
|
||||
if sub != nil {
|
||||
subject = sub.Subject
|
||||
}
|
||||
if errors.Is(err, nats.ErrPermissionViolation) {
|
||||
xlog.Error("NATS permission violation — check JWT pub/sub allow lists", "subject", subject, "error", err)
|
||||
return
|
||||
}
|
||||
xlog.Warn("NATS async error", "subject", subject, "error", err)
|
||||
}),
|
||||
}
|
||||
switch {
|
||||
case cfg.jwtProvider != nil:
|
||||
// Fetch creds on every (re)connect so a refresh loop can rotate the JWT
|
||||
// before expiry; the server expiring the old JWT triggers a reconnect
|
||||
// that transparently picks up the new one.
|
||||
natsOpts = append(natsOpts, nats.UserJWT(
|
||||
func() (string, error) {
|
||||
jwt, _ := cfg.jwtProvider()
|
||||
if jwt == "" {
|
||||
return "", fmt.Errorf("no NATS user JWT available")
|
||||
}
|
||||
return jwt, nil
|
||||
},
|
||||
func(nonce []byte) ([]byte, error) {
|
||||
_, seed := cfg.jwtProvider()
|
||||
kp, err := nkeys.FromSeed([]byte(seed))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading NATS user seed: %w", err)
|
||||
}
|
||||
defer kp.Wipe()
|
||||
return kp.Sign(nonce)
|
||||
},
|
||||
))
|
||||
case cfg.userJWT != "" && cfg.userSeed != "":
|
||||
natsOpts = append(natsOpts, nats.UserJWTAndSeed(cfg.userJWT, cfg.userSeed))
|
||||
}
|
||||
if cfg.tls.Enabled() {
|
||||
if err := cfg.tls.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsOpts, err := cfg.tls.natsOptions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
natsOpts = append(natsOpts, tlsOpts...)
|
||||
}
|
||||
|
||||
nc, err := nats.Connect(url, natsOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connecting to NATS at %s: %w", sanitize.URL(url), err)
|
||||
}
|
||||
@@ -54,23 +119,67 @@ func (c *Client) Publish(subject string, data any) error {
|
||||
|
||||
// Subscribe creates a subscription on the given subject. All subscribers receive every message.
|
||||
func (c *Client) Subscribe(subject string, handler func([]byte)) (Subscription, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.conn.Subscribe(subject, func(msg *nats.Msg) {
|
||||
handler(msg.Data)
|
||||
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
|
||||
return conn.Subscribe(subject, func(msg *nats.Msg) {
|
||||
handler(msg.Data)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// QueueSubscribe creates a queue subscription. Within the same queue group,
|
||||
// only one subscriber receives each message (load-balanced).
|
||||
func (c *Client) QueueSubscribe(subject, queue string, handler func([]byte)) (Subscription, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
|
||||
handler(msg.Data)
|
||||
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
|
||||
return conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
|
||||
handler(msg.Data)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// confirmSubscription creates a subscription via mk and forces a server
|
||||
// round-trip so that a permissions violation — which NATS otherwise reports
|
||||
// only asynchronously — is returned to the caller synchronously. The server
|
||||
// emits the "-ERR Permissions Violation" for a rejected SUB before the PONG
|
||||
// that satisfies the flush, so by the time FlushTimeout returns the violation
|
||||
// is recorded as the connection's last error. Without this, a worker whose JWT
|
||||
// lacks a subject gets a non-nil subscription that never receives a message,
|
||||
// turning a permission misconfiguration into a silent failure.
|
||||
func (c *Client) confirmSubscription(subject string, mk func(*nats.Conn) (*nats.Subscription, error)) (Subscription, error) {
|
||||
c.mu.RLock()
|
||||
conn := c.conn
|
||||
c.mu.RUnlock()
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("subscribe to %s: nil NATS connection", subject)
|
||||
}
|
||||
|
||||
sub, err := mk(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// A failed flush here means we could not round-trip to the server (not yet
|
||||
// connected, reconnecting, slow link). RetryOnFailedConnect intentionally
|
||||
// buffers subscriptions across that gap, so do NOT fail — keep the
|
||||
// subscription and let it replay on (re)connect; a later permission
|
||||
// violation is still logged by the async error handler in New.
|
||||
if err := conn.FlushTimeout(subscribeConfirmTimeout); err != nil {
|
||||
xlog.Debug("Could not confirm NATS subscription (will replay on connect)", "subject", subject, "error", err)
|
||||
return sub, nil
|
||||
}
|
||||
// Flush succeeded, so any permission violation for this SUB has already been
|
||||
// recorded as the connection's last error (the server emits it before the
|
||||
// PONG). LastError is per-connection; match the exact quoted subject the
|
||||
// server echoes ("Subscription to \"<subject>\"") so a stale violation for
|
||||
// another subject can't be mis-attributed here.
|
||||
if lerr := conn.LastError(); lerr != nil &&
|
||||
errors.Is(lerr, nats.ErrPermissionViolation) &&
|
||||
strings.Contains(lerr.Error(), `Subscription to "`+subject+`"`) {
|
||||
_ = sub.Unsubscribe()
|
||||
return nil, fmt.Errorf("subscription to %s denied by NATS server (check JWT sub allow list): %w", subject, lerr)
|
||||
}
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// Request sends a request and waits for a reply (request-reply pattern).
|
||||
// Returns the raw reply data.
|
||||
func (c *Client) Request(subject string, data []byte, timeout time.Duration) ([]byte, error) {
|
||||
@@ -86,15 +195,15 @@ func (c *Client) Request(subject string, data []byte, timeout time.Duration) ([]
|
||||
// SubscribeReply creates a subscription that supports replying to requests.
|
||||
// The handler receives the raw request data and the reply subject.
|
||||
func (c *Client) SubscribeReply(subject string, handler func(data []byte, reply func([]byte))) (Subscription, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.conn.Subscribe(subject, func(msg *nats.Msg) {
|
||||
handler(msg.Data, func(replyData []byte) {
|
||||
if msg.Reply != "" {
|
||||
if err := msg.Respond(replyData); err != nil {
|
||||
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
|
||||
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
|
||||
return conn.Subscribe(subject, func(msg *nats.Msg) {
|
||||
handler(msg.Data, func(replyData []byte) {
|
||||
if msg.Reply != "" {
|
||||
if err := msg.Respond(replyData); err != nil {
|
||||
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -102,15 +211,15 @@ func (c *Client) SubscribeReply(subject string, handler func(data []byte, reply
|
||||
// QueueSubscribeReply creates a queue subscription that supports replying to requests.
|
||||
// Load-balanced across subscribers in the same queue group, with request-reply support.
|
||||
func (c *Client) QueueSubscribeReply(subject, queue string, handler func(data []byte, reply func([]byte))) (Subscription, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
|
||||
handler(msg.Data, func(replyData []byte) {
|
||||
if msg.Reply != "" {
|
||||
if err := msg.Respond(replyData); err != nil {
|
||||
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
|
||||
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
|
||||
return conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
|
||||
handler(msg.Data, func(replyData []byte) {
|
||||
if msg.Reply != "" {
|
||||
if err := msg.Respond(replyData); err != nil {
|
||||
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
34
core/services/messaging/options.go
Normal file
34
core/services/messaging/options.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package messaging
|
||||
|
||||
// Option configures NATS client connection behavior.
|
||||
type Option func(*connectConfig)
|
||||
|
||||
// CredentialProvider returns the NATS user JWT and signing seed to use for the
|
||||
// next (re)connect. It is consulted on every connection attempt, so a refresh
|
||||
// loop can rotate credentials before they expire and the connection picks them
|
||||
// up automatically when the server expires the old JWT and triggers a reconnect.
|
||||
type CredentialProvider func() (jwt, seed string)
|
||||
|
||||
type connectConfig struct {
|
||||
userJWT string
|
||||
userSeed string
|
||||
jwtProvider CredentialProvider
|
||||
tls TLSFiles
|
||||
}
|
||||
|
||||
// WithUserJWT connects using a static NATS user JWT and signing seed (UserJWTAndSeed).
|
||||
func WithUserJWT(jwt, seed string) Option {
|
||||
return func(c *connectConfig) {
|
||||
c.userJWT = jwt
|
||||
c.userSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
// WithUserJWTProvider connects using credentials fetched from provider on each
|
||||
// (re)connect, enabling JWT rotation without dropping the client. Takes
|
||||
// precedence over WithUserJWT when both are set.
|
||||
func WithUserJWTProvider(provider CredentialProvider) Option {
|
||||
return func(c *connectConfig) {
|
||||
c.jwtProvider = provider
|
||||
}
|
||||
}
|
||||
@@ -194,6 +194,14 @@ type BackendUpgradeRequest struct {
|
||||
// but the field lets future per-replica metadata (e.g. progress reporting
|
||||
// scoped to a slot) ride the same wire without a v3 type.
|
||||
ReplicaIndex int32 `json:"replica_index,omitempty"`
|
||||
// OpID identifies the admin-side operation. When non-empty the worker
|
||||
// publishes BackendInstallProgressEvent values to
|
||||
// SubjectNodeBackendInstallProgress(nodeID, OpID) while the force-reinstall
|
||||
// runs, so the master can stream per-node progress for upgrades exactly as
|
||||
// it already does for installs (an upgrade IS a force-reinstall, so the
|
||||
// install-progress subject is reused rather than minting a new one — no new
|
||||
// NATS permission or rolling-update compat surface). Empty on legacy callers.
|
||||
OpID string `json:"op_id,omitempty"`
|
||||
}
|
||||
|
||||
// BackendUpgradeReply mirrors BackendInstallReply minus Address — upgrade does
|
||||
|
||||
68
core/services/messaging/tls.go
Normal file
68
core/services/messaging/tls.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package messaging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
// TLSFiles holds PEM paths for NATS TLS / mTLS. Cert and key must be set together.
|
||||
// Use tls:// in LOCALAI_NATS_URL; CA and client cert paths are optional extras.
|
||||
type TLSFiles struct {
|
||||
CA string // LOCALAI_NATS_TLS_CA — private CA for server verification
|
||||
Cert string // LOCALAI_NATS_TLS_CERT — client certificate (mTLS)
|
||||
Key string // LOCALAI_NATS_TLS_KEY — client private key
|
||||
}
|
||||
|
||||
// Enabled reports whether any TLS file path is configured.
|
||||
func (f TLSFiles) Enabled() bool {
|
||||
return f.CA != "" || f.Cert != "" || f.Key != ""
|
||||
}
|
||||
|
||||
// Validate checks path pairing and that files exist.
|
||||
func (f TLSFiles) Validate() error {
|
||||
if f.Cert != "" && f.Key == "" {
|
||||
return fmt.Errorf("LOCALAI_NATS_TLS_KEY is required when LOCALAI_NATS_TLS_CERT is set")
|
||||
}
|
||||
if f.Key != "" && f.Cert == "" {
|
||||
return fmt.Errorf("LOCALAI_NATS_TLS_CERT is required when LOCALAI_NATS_TLS_KEY is set")
|
||||
}
|
||||
for _, path := range []struct {
|
||||
name, path string
|
||||
}{
|
||||
{"LOCALAI_NATS_TLS_CA", f.CA},
|
||||
{"LOCALAI_NATS_TLS_CERT", f.Cert},
|
||||
{"LOCALAI_NATS_TLS_KEY", f.Key},
|
||||
} {
|
||||
if path.path == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(path.path); err != nil {
|
||||
return fmt.Errorf("%s: %w", path.name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// natsOptions builds nats-go TLS options. Call Validate first.
|
||||
func (f TLSFiles) natsOptions() ([]nats.Option, error) {
|
||||
if !f.Enabled() {
|
||||
return nil, nil
|
||||
}
|
||||
opts := []nats.Option{nats.Secure()}
|
||||
if f.CA != "" {
|
||||
opts = append(opts, nats.RootCAs(f.CA))
|
||||
}
|
||||
if f.Cert != "" {
|
||||
opts = append(opts, nats.ClientCert(f.Cert, f.Key))
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// WithTLS configures CA and/or client certificate paths for the NATS connection.
|
||||
func WithTLS(files TLSFiles) Option {
|
||||
return func(c *connectConfig) {
|
||||
c.tls = files
|
||||
}
|
||||
}
|
||||
25
core/services/messaging/tls_test.go
Normal file
25
core/services/messaging/tls_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package messaging_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("TLSFiles", func() {
|
||||
It("requires cert and key together", func() {
|
||||
Expect((messaging.TLSFiles{Cert: "/tmp/c.pem"}).Validate()).To(HaveOccurred())
|
||||
Expect((messaging.TLSFiles{Key: "/tmp/k.pem"}).Validate()).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("validates files exist", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
ca := filepath.Join(dir, "ca.pem")
|
||||
Expect(os.WriteFile(ca, []byte("x"), 0600)).To(Succeed())
|
||||
Expect((messaging.TLSFiles{CA: ca}).Validate()).To(Succeed())
|
||||
})
|
||||
})
|
||||
@@ -61,6 +61,17 @@ func StartFileTransferServerWithListener(lis net.Listener, stagingDir, modelsDir
|
||||
return nil, fmt.Errorf("creating staging dir %s: %w", stagingDir, err)
|
||||
}
|
||||
|
||||
// An empty token makes checkBearerToken fail open: every /v1/files,
|
||||
// /v1/files-list and /v1/backend-logs request is served unauthenticated,
|
||||
// granting read/write to the staging/models/data directories to anyone who
|
||||
// can reach this port. Surface that loudly — the worker process does not
|
||||
// run DistributedConfig.Validate(), so this is the only signal an operator
|
||||
// gets. Set LOCALAI_REGISTRATION_TOKEN (and LOCALAI_REGISTRATION_REQUIRE_AUTH
|
||||
// to fail closed) to protect it.
|
||||
if token == "" {
|
||||
xlog.Warn("HTTP file transfer server starting WITHOUT a registration token — read/write to models/staging/data is unauthenticated for anyone who can reach this port; set LOCALAI_REGISTRATION_TOKEN")
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// PUT /v1/files/{key} — upload file
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -893,3 +894,50 @@ func sha256Hex(data []byte) string {
|
||||
h := sha256.Sum256(data)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
var _ = Describe("StartFileTransferServerWithListener", func() {
|
||||
start := func(token string) (string, func()) {
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
staging := GinkgoT().TempDir()
|
||||
models := GinkgoT().TempDir()
|
||||
data := GinkgoT().TempDir()
|
||||
srv, err := StartFileTransferServerWithListener(lis, staging, models, data, token, 0)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
base := "http://" + lis.Addr().String()
|
||||
return base, func() { ShutdownFileTransferServer(srv) }
|
||||
}
|
||||
|
||||
// Exercises the empty-token fail-open warning branch: the server serves
|
||||
// file requests with no Authorization header at all.
|
||||
It("serves unauthenticated when started without a token", func() {
|
||||
base, stop := start("")
|
||||
defer stop()
|
||||
|
||||
resp, err := http.Get(base + "/v1/files/missing.bin")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
// No 401 — the empty token fails open. The file is absent so we get 404.
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusNotFound))
|
||||
})
|
||||
|
||||
It("rejects requests without the bearer token when a token is set", func() {
|
||||
base, stop := start("s3cret")
|
||||
defer stop()
|
||||
|
||||
resp, err := http.Get(base + "/v1/files/missing.bin")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
|
||||
})
|
||||
|
||||
It("serves the unauthenticated health endpoints regardless of token", func() {
|
||||
base, stop := start("s3cret")
|
||||
defer stop()
|
||||
|
||||
resp, err := http.Get(base + "/healthz")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusOK))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
ggrpc "google.golang.org/grpc"
|
||||
@@ -64,64 +65,95 @@ func (c *InFlightTrackingClient) track(ctx context.Context) func() {
|
||||
}
|
||||
}
|
||||
|
||||
// reconcile self-heals stale routing: when a backend reports that the model is
|
||||
// no longer loaded (the process survived but the model was evicted, while the
|
||||
// registry still lists it as loaded), it drops the replica row so the next
|
||||
// request triggers a fresh load instead of routing back here. Without this the
|
||||
// model stays unreachable until the controller restarts. The original error is
|
||||
// returned unchanged.
|
||||
func (c *InFlightTrackingClient) reconcile(err error) error {
|
||||
if !grpcerrors.IsModelNotLoaded(err) {
|
||||
return err
|
||||
}
|
||||
rmCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if rmErr := c.registry.RemoveNodeModel(rmCtx, c.nodeID, c.modelName, c.replicaIndex); rmErr != nil {
|
||||
xlog.Warn("Failed to drop stale replica after model-not-loaded",
|
||||
"node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex, "error", rmErr)
|
||||
} else {
|
||||
xlog.Warn("Backend reports model not loaded; dropped stale replica so the next request reloads",
|
||||
"node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// --- Tracked inference methods ---
|
||||
|
||||
func (c *InFlightTrackingClient) Predict(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.Reply, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.Predict(ctx, in, opts...)
|
||||
reply, err := c.Backend.Predict(ctx, in, opts...)
|
||||
return reply, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.PredictStream(ctx, in, f, opts...)
|
||||
return c.reconcile(c.Backend.PredictStream(ctx, in, f, opts...))
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.Embeddings(ctx, in, opts...)
|
||||
res, err := c.Backend.Embeddings(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.GenerateImage(ctx, in, opts...)
|
||||
res, err := c.Backend.GenerateImage(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.GenerateVideo(ctx, in, opts...)
|
||||
res, err := c.Backend.GenerateVideo(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) TTS(ctx context.Context, in *pb.TTSRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.TTS(ctx, in, opts...)
|
||||
res, err := c.Backend.TTS(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.TTSStream(ctx, in, f, opts...)
|
||||
return c.reconcile(c.Backend.TTSStream(ctx, in, f, opts...))
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.SoundGeneration(ctx, in, opts...)
|
||||
res, err := c.Backend.SoundGeneration(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.AudioTranscription(ctx, in, opts...)
|
||||
res, err := c.Backend.AudioTranscription(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.AudioTranscriptionStream(ctx, in, f, opts...)
|
||||
return c.reconcile(c.Backend.AudioTranscriptionStream(ctx, in, f, opts...))
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.Detect(ctx, in, opts...)
|
||||
res, err := c.Backend.Detect(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...ggrpc.CallOption) (*pb.RerankResult, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.Rerank(ctx, in, opts...)
|
||||
res, err := c.Backend.Rerank(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
@@ -20,9 +20,17 @@ type fakeInFlightTracker struct {
|
||||
mu sync.Mutex
|
||||
increments int
|
||||
decrements int
|
||||
removed int
|
||||
incrementErr error
|
||||
}
|
||||
|
||||
func (f *fakeInFlightTracker) RemoveNodeModel(_ context.Context, _, _ string, _ int) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.removed++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeInFlightTracker) IncrementInFlight(_ context.Context, _, _ string, _ int) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
@@ -295,4 +303,33 @@ var _ = Describe("InFlightTrackingClient", func() {
|
||||
Expect(tracker.decrements).To(Equal(1))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("stale model reload (self-heal)", func() {
|
||||
It("removes the replica when the backend reports the model is not loaded", func() {
|
||||
backend.predictErr = fmt.Errorf("parakeet-cpp: model not loaded")
|
||||
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(tracker.removed).To(Equal(1))
|
||||
})
|
||||
|
||||
It("keeps the replica on an unrelated error", func() {
|
||||
backend.predictErr = fmt.Errorf("context deadline exceeded")
|
||||
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(tracker.removed).To(Equal(0))
|
||||
})
|
||||
|
||||
It("does not remove on success", func() {
|
||||
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(tracker.removed).To(Equal(0))
|
||||
})
|
||||
|
||||
It("self-heals on a streamed call too", func() {
|
||||
backend.streamErr = fmt.Errorf("whisper: model not loaded")
|
||||
err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(*pb.Reply) {})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(tracker.removed).To(Equal(1))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -78,6 +78,9 @@ type ModelLookup interface {
|
||||
type InFlightTracker interface {
|
||||
IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
|
||||
DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
|
||||
// RemoveNodeModel drops a stale replica row so the next request reloads the
|
||||
// model instead of routing back to a node where it is no longer loaded.
|
||||
RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error
|
||||
}
|
||||
|
||||
// NodeManager is used by HTTP endpoints for node registration and lifecycle.
|
||||
|
||||
@@ -533,7 +533,7 @@ func (d *DistributedBackendManager) InstallBackend(ctx context.Context, op *gall
|
||||
// backend.upgrade, we try the legacy backend.install Force=true path so a
|
||||
// new master + old worker still converges. Drop the fallback once every
|
||||
// worker in the fleet is on 2026-05-08 or newer.
|
||||
func (d *DistributedBackendManager) UpgradeBackend(ctx context.Context, name string, progressCb galleryop.ProgressCallback) error {
|
||||
func (d *DistributedBackendManager) UpgradeBackend(ctx context.Context, opID, name string, progressCb galleryop.ProgressCallback) error {
|
||||
galleriesJSON, _ := json.Marshal(d.backendGalleries)
|
||||
|
||||
installed, err := d.ListBackends()
|
||||
@@ -549,17 +549,39 @@ func (d *DistributedBackendManager) UpgradeBackend(ctx context.Context, name str
|
||||
targetNodeIDs[n.NodeID] = true
|
||||
}
|
||||
|
||||
// Empty opID: the caller (galleryop) doesn't thread an op ID into
|
||||
// UpgradeBackend today, so we can't tag per-node sink writes with the
|
||||
// right OpStatus key. Until the upgrade path takes a ManagementOp the
|
||||
// way InstallBackend does, the sink stays no-op here.
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, "", OpBackendUpgrade, name, galleriesJSON, targetNodeIDs, func(node BackendNode) error {
|
||||
reply, err := d.adapter.UpgradeBackend(node.ID, name, string(galleriesJSON), "", "", "", 0)
|
||||
result, err := d.enqueueAndDrainBackendOp(ctx, opID, OpBackendUpgrade, name, galleriesJSON, targetNodeIDs, func(node BackendNode) error {
|
||||
// Per-node progress sink: fan each worker download tick into the legacy
|
||||
// single-bar progressCb and the per-node OpStatus.Nodes view, exactly as
|
||||
// InstallBackend does. Defined per-node so each closure captures its own
|
||||
// node.Name. Without this an upgrade blocks opaque at progress 0 for the
|
||||
// whole 15m round-trip (the original "reinstalling but nothing happens").
|
||||
onProgress := func(ev messaging.BackendInstallProgressEvent) {
|
||||
if progressCb != nil {
|
||||
progressCb(ev.FileName, ev.Current, ev.Total, ev.Percentage)
|
||||
}
|
||||
if d.progressSink != nil && opID != "" {
|
||||
d.progressSink.UpdateNodeProgress(opID, ev.NodeID, galleryop.NodeProgress{
|
||||
NodeID: ev.NodeID,
|
||||
NodeName: node.Name,
|
||||
Status: galleryop.NodeStatusDownloading,
|
||||
FileName: ev.FileName,
|
||||
Current: ev.Current,
|
||||
Total: ev.Total,
|
||||
Percentage: ev.Percentage,
|
||||
Phase: ev.Phase,
|
||||
})
|
||||
}
|
||||
}
|
||||
var onProgressArg func(messaging.BackendInstallProgressEvent)
|
||||
if progressCb != nil || d.progressSink != nil {
|
||||
onProgressArg = onProgress
|
||||
}
|
||||
reply, err := d.adapter.UpgradeBackend(node.ID, name, string(galleriesJSON), "", "", "", 0, opID, onProgressArg)
|
||||
if err != nil {
|
||||
// Rolling-update fallback: an older worker doesn't know
|
||||
// backend.upgrade. Try the legacy install-with-force path.
|
||||
if errors.Is(err, nats.ErrNoResponders) {
|
||||
instReply, instErr := d.adapter.installWithForceFallback(node.ID, name, string(galleriesJSON), "", "", "", 0)
|
||||
instReply, instErr := d.adapter.installWithForceFallback(node.ID, name, string(galleriesJSON), "", "", "", 0, opID, onProgressArg)
|
||||
if instErr != nil {
|
||||
return instErr
|
||||
}
|
||||
|
||||
@@ -317,7 +317,7 @@ func (stubLocalBackendManager) DeleteBackend(_ string) error { return gallery.Er
|
||||
func (stubLocalBackendManager) ListBackends() (gallery.SystemBackends, error) {
|
||||
return gallery.SystemBackends{}, nil
|
||||
}
|
||||
func (stubLocalBackendManager) UpgradeBackend(_ context.Context, _ string, _ galleryop.ProgressCallback) error {
|
||||
func (stubLocalBackendManager) UpgradeBackend(_ context.Context, _ string, _ string, _ galleryop.ProgressCallback) error {
|
||||
return nil
|
||||
}
|
||||
func (stubLocalBackendManager) CheckUpgrades(_ context.Context) (map[string]gallery.UpgradeInfo, error) {
|
||||
@@ -782,7 +782,7 @@ var _ = Describe("DistributedBackendManager", func() {
|
||||
mc.scriptReply(messaging.SubjectNodeBackendUpgrade(n2.ID),
|
||||
messaging.BackendUpgradeReply{Success: false, Error: "registry unauthorized"})
|
||||
|
||||
err := mgr.UpgradeBackend(ctx, "vllm-development", nil)
|
||||
err := mgr.UpgradeBackend(ctx, "", "vllm-development", nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("worker-a"))
|
||||
Expect(err.Error()).To(ContainSubstring("image manifest not found"))
|
||||
@@ -797,7 +797,7 @@ var _ = Describe("DistributedBackendManager", func() {
|
||||
scriptInstalled("vllm-development", n1.ID)
|
||||
mc.scriptReply(messaging.SubjectNodeBackendUpgrade(n1.ID),
|
||||
messaging.BackendUpgradeReply{Success: true})
|
||||
Expect(mgr.UpgradeBackend(ctx, "vllm-development", nil)).To(Succeed())
|
||||
Expect(mgr.UpgradeBackend(ctx, "", "vllm-development", nil)).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -819,7 +819,7 @@ var _ = Describe("DistributedBackendManager", func() {
|
||||
// if the manager attempts it, the scripted-client default returns
|
||||
// fakeNoRespondersErr and the assertion below fails loudly.
|
||||
|
||||
Expect(mgr.UpgradeBackend(ctx, "cpu-insightface-development", nil)).To(Succeed())
|
||||
Expect(mgr.UpgradeBackend(ctx, "", "cpu-insightface-development", nil)).To(Succeed())
|
||||
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
@@ -835,7 +835,7 @@ var _ = Describe("DistributedBackendManager", func() {
|
||||
n1 := registerHealthyBackend("worker-a", "10.0.0.1:50051")
|
||||
scriptNoBackends(n1.ID)
|
||||
|
||||
err := mgr.UpgradeBackend(ctx, "vllm-development", nil)
|
||||
err := mgr.UpgradeBackend(ctx, "", "vllm-development", nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("not installed on any node"))
|
||||
|
||||
@@ -865,7 +865,7 @@ var _ = Describe("DistributedBackendManager", func() {
|
||||
func(req messaging.BackendInstallRequest) bool { return req.Force },
|
||||
messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:50100"})
|
||||
|
||||
Expect(mgr.UpgradeBackend(ctx, "vllm-development", nil)).To(Succeed())
|
||||
Expect(mgr.UpgradeBackend(ctx, "", "vllm-development", nil)).To(Succeed())
|
||||
})
|
||||
|
||||
It("returns the upgrade error when it is not ErrNoResponders", func() {
|
||||
@@ -875,7 +875,7 @@ var _ = Describe("DistributedBackendManager", func() {
|
||||
mc.scriptReply(messaging.SubjectNodeBackendUpgrade(n.ID),
|
||||
messaging.BackendUpgradeReply{Success: false, Error: "disk full"})
|
||||
|
||||
err := mgr.UpgradeBackend(ctx, "vllm-development", nil)
|
||||
err := mgr.UpgradeBackend(ctx, "", "vllm-development", nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("disk full"))
|
||||
})
|
||||
|
||||
135
core/services/nodes/pending_op_cleanup_test.go
Normal file
135
core/services/nodes/pending_op_cleanup_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
)
|
||||
|
||||
// These specs reproduce the distributed "pending ops behind dead nodes leak
|
||||
// forever" bug. ListDuePendingBackendOps only returns rows whose node is
|
||||
// StatusHealthy, so an op queued against a node that goes offline (heartbeat
|
||||
// stale) or draining (admin action) is never retried, never aged out, and
|
||||
// never deleted. On a live cluster these rows sat at attempts=0 indefinitely
|
||||
// and kept the UI operation alive. DeleteStalePendingBackendOps garbage-collects
|
||||
// them: draining nodes immediately (models already purged), offline nodes only
|
||||
// after a grace window so a brief heartbeat blip does not nuke in-flight work.
|
||||
var _ = Describe("DeleteStalePendingBackendOps", func() {
|
||||
var (
|
||||
registry *NodeRegistry
|
||||
ctx context.Context
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
if runtime.GOOS == "darwin" {
|
||||
Skip("testcontainers requires Docker, not available on macOS CI")
|
||||
}
|
||||
db := testutil.SetupTestDB()
|
||||
var err error
|
||||
registry, err = NewNodeRegistry(db)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ctx = context.Background()
|
||||
})
|
||||
|
||||
// registerBackend registers an auto-approved backend node and returns its ID.
|
||||
registerBackend := func(name, address string) string {
|
||||
node := &BackendNode{Name: name, NodeType: NodeTypeBackend, Address: address}
|
||||
Expect(registry.Register(ctx, node, true)).To(Succeed())
|
||||
fetched, err := registry.GetByName(ctx, name)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return fetched.ID
|
||||
}
|
||||
|
||||
// setHeartbeat forces a node's last_heartbeat (Register/MarkOffline leave it
|
||||
// at "now"; we age it to simulate a node that went silent a while ago).
|
||||
setHeartbeat := func(nodeID string, t time.Time) {
|
||||
Expect(registry.db.WithContext(ctx).Model(&BackendNode{}).
|
||||
Where("id = ?", nodeID).
|
||||
Update("last_heartbeat", t).Error).To(Succeed())
|
||||
}
|
||||
|
||||
pendingCountFor := func(nodeID string) int64 {
|
||||
var n int64
|
||||
Expect(registry.db.WithContext(ctx).Model(&PendingBackendOp{}).
|
||||
Where("node_id = ?", nodeID).Count(&n).Error).To(Succeed())
|
||||
return n
|
||||
}
|
||||
|
||||
It("clears ops behind an offline node whose heartbeat is past the grace window", func() {
|
||||
dead := registerBackend("nvidia-thor", "10.0.0.9:50051")
|
||||
Expect(registry.UpsertPendingBackendOp(ctx, dead, "llama-cpp-development", OpBackendInstall, nil)).To(Succeed())
|
||||
Expect(registry.MarkOffline(ctx, dead)).To(Succeed())
|
||||
setHeartbeat(dead, time.Now().Add(-1*time.Hour))
|
||||
|
||||
removed, err := registry.DeleteStalePendingBackendOps(ctx, 10*time.Minute)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(removed).To(Equal(int64(1)))
|
||||
Expect(pendingCountFor(dead)).To(Equal(int64(0)))
|
||||
})
|
||||
|
||||
It("clears ops behind a draining node immediately, even with a fresh heartbeat", func() {
|
||||
// Mirrors the live mac-mini-m4 case: draining but still heartbeating.
|
||||
drain := registerBackend("mac-mini-m4", "10.0.0.3:50051")
|
||||
Expect(registry.UpsertPendingBackendOp(ctx, drain, "llama-cpp-development", OpBackendInstall, nil)).To(Succeed())
|
||||
Expect(registry.MarkDraining(ctx, drain)).To(Succeed())
|
||||
setHeartbeat(drain, time.Now()) // fresh heartbeat
|
||||
|
||||
removed, err := registry.DeleteStalePendingBackendOps(ctx, 10*time.Minute)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(removed).To(Equal(int64(1)))
|
||||
Expect(pendingCountFor(drain)).To(Equal(int64(0)))
|
||||
})
|
||||
|
||||
It("clears ops behind an unhealthy node with a stale heartbeat (never ages to offline)", func() {
|
||||
// A node marked unhealthy on a NATS ErrNoResponders never transitions to
|
||||
// offline, so its ops must be reaped via the same stale-heartbeat path.
|
||||
sick := registerBackend("agx-orin-sick", "10.0.0.7:50051")
|
||||
Expect(registry.UpsertPendingBackendOp(ctx, sick, "llama-cpp-development", OpBackendUpgrade, nil)).To(Succeed())
|
||||
Expect(registry.MarkUnhealthy(ctx, sick)).To(Succeed())
|
||||
setHeartbeat(sick, time.Now().Add(-1*time.Hour))
|
||||
|
||||
removed, err := registry.DeleteStalePendingBackendOps(ctx, 10*time.Minute)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(removed).To(Equal(int64(1)))
|
||||
Expect(pendingCountFor(sick)).To(Equal(int64(0)))
|
||||
})
|
||||
|
||||
It("keeps ops behind an unhealthy node that is still heartbeating (recovering)", func() {
|
||||
recovering := registerBackend("agx-orin-flap", "10.0.0.8:50051")
|
||||
Expect(registry.UpsertPendingBackendOp(ctx, recovering, "llama-cpp-development", OpBackendUpgrade, nil)).To(Succeed())
|
||||
Expect(registry.MarkUnhealthy(ctx, recovering)).To(Succeed())
|
||||
setHeartbeat(recovering, time.Now()) // fresh heartbeat → recovering
|
||||
|
||||
removed, err := registry.DeleteStalePendingBackendOps(ctx, 10*time.Minute)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(removed).To(Equal(int64(0)))
|
||||
Expect(pendingCountFor(recovering)).To(Equal(int64(1)))
|
||||
})
|
||||
|
||||
It("keeps ops behind a node that only just went offline (within grace)", func() {
|
||||
blip := registerBackend("agx-orin", "10.0.0.4:50051")
|
||||
Expect(registry.UpsertPendingBackendOp(ctx, blip, "parakeet-cpp-development", OpBackendInstall, nil)).To(Succeed())
|
||||
Expect(registry.MarkOffline(ctx, blip)).To(Succeed())
|
||||
setHeartbeat(blip, time.Now().Add(-1*time.Minute)) // gone only 1m, grace 10m
|
||||
|
||||
removed, err := registry.DeleteStalePendingBackendOps(ctx, 10*time.Minute)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(removed).To(Equal(int64(0)))
|
||||
Expect(pendingCountFor(blip)).To(Equal(int64(1)))
|
||||
})
|
||||
|
||||
It("keeps ops behind a healthy node", func() {
|
||||
healthy := registerBackend("dgx-spark", "10.0.0.1:50051")
|
||||
Expect(registry.UpsertPendingBackendOp(ctx, healthy, "llama-cpp-development", OpBackendUpgrade, nil)).To(Succeed())
|
||||
|
||||
removed, err := registry.DeleteStalePendingBackendOps(ctx, 10*time.Minute)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(removed).To(Equal(int64(0)))
|
||||
Expect(pendingCountFor(healthy)).To(Equal(int64(1)))
|
||||
})
|
||||
})
|
||||
@@ -189,6 +189,13 @@ func (rc *ReplicaReconciler) reconcileState(ctx context.Context) {
|
||||
// passed on nodes that are currently healthy. On success the row is deleted;
|
||||
// on failure attempts++ and next_retry_at moves out via exponential backoff.
|
||||
func (rc *ReplicaReconciler) drainPendingBackendOps(ctx context.Context) {
|
||||
// Garbage-collect ops behind nodes that went offline/draining. These are
|
||||
// invisible to ListDuePendingBackendOps (which filters status=healthy), so
|
||||
// without this sweep they leak forever and keep the UI operation spinning.
|
||||
if _, err := rc.registry.DeleteStalePendingBackendOps(ctx, stalePendingBackendOpGrace); err != nil {
|
||||
xlog.Warn("Reconciler: failed to clear stale pending backend ops", "error", err)
|
||||
}
|
||||
|
||||
ops, err := rc.registry.ListDuePendingBackendOps(ctx)
|
||||
if err != nil {
|
||||
xlog.Warn("Reconciler: failed to list pending backend ops", "error", err)
|
||||
@@ -223,10 +230,13 @@ func (rc *ReplicaReconciler) drainPendingBackendOps(ctx context.Context) {
|
||||
// the same worker. Falls back to the legacy backend.install
|
||||
// Force=true path on nats.ErrNoResponders for old workers that
|
||||
// don't subscribe to backend.upgrade yet (rolling-update window).
|
||||
reply, err := rc.adapter.UpgradeBackend(op.NodeID, op.Backend, string(op.Galleries), "", "", "", 0)
|
||||
// Reconciler retries are background reconciliation with no live
|
||||
// admin watching a progress bar, so opID/onProgress are empty —
|
||||
// the adapter skips the progress subscription entirely.
|
||||
reply, err := rc.adapter.UpgradeBackend(op.NodeID, op.Backend, string(op.Galleries), "", "", "", 0, "", nil)
|
||||
if err != nil {
|
||||
if errors.Is(err, nats.ErrNoResponders) {
|
||||
instReply, instErr := rc.adapter.installWithForceFallback(op.NodeID, op.Backend, string(op.Galleries), "", "", "", 0)
|
||||
instReply, instErr := rc.adapter.installWithForceFallback(op.NodeID, op.Backend, string(op.Galleries), "", "", "", 0, "", nil)
|
||||
if instErr != nil {
|
||||
applyErr = instErr
|
||||
} else if !instReply.Success {
|
||||
@@ -293,6 +303,13 @@ func (rc *ReplicaReconciler) drainPendingBackendOps(ctx context.Context) {
|
||||
// amount of further retrying will help.
|
||||
const maxPendingBackendOpAttempts = 10
|
||||
|
||||
// stalePendingBackendOpGrace is how long a node may be offline before its
|
||||
// pending backend ops are garbage-collected. Draining nodes are cleared
|
||||
// immediately regardless of this window (see DeleteStalePendingBackendOps).
|
||||
// ListDuePendingBackendOps never surfaces ops behind non-healthy nodes, so
|
||||
// without this sweep they would leak forever and keep the UI op spinning.
|
||||
const stalePendingBackendOpGrace = 15 * time.Minute
|
||||
|
||||
// probeLoadedModels gRPC-health-checks model addresses that the DB says are
|
||||
// loaded. If a model's backend process is gone (OOM, crash, manual restart)
|
||||
// we remove the row so ghosts don't linger. Only probes rows older than
|
||||
|
||||
@@ -1776,6 +1776,38 @@ func (r *NodeRegistry) DeletePendingBackendOp(ctx context.Context, id uint) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteStalePendingBackendOps garbage-collects pending backend ops whose target
|
||||
// node can never drain them. ListDuePendingBackendOps only returns rows behind a
|
||||
// StatusHealthy node, so ops behind a node that went offline or draining are
|
||||
// otherwise never retried, aged out, or deleted — they leak forever and keep the
|
||||
// UI operation spinning. Draining nodes are cleared immediately (an explicit
|
||||
// admin action; their model rows are already purged). Offline nodes are cleared
|
||||
// only once their last heartbeat is older than `grace`, so a brief heartbeat blip
|
||||
// does not nuke an install that is still legitimately in flight. Returns the
|
||||
// number of rows deleted.
|
||||
func (r *NodeRegistry) DeleteStalePendingBackendOps(ctx context.Context, grace time.Duration) (int64, error) {
|
||||
cutoff := time.Now().Add(-grace)
|
||||
// Draining nodes are cleared immediately (admin action; model rows already
|
||||
// purged). Offline AND unhealthy nodes are cleared only once their heartbeat
|
||||
// is older than the grace window: a node marked unhealthy on a NATS
|
||||
// ErrNoResponders never transitions to offline (health.go skips re-marking
|
||||
// it), so without including unhealthy here its ops would leak exactly like
|
||||
// the offline case. A node with a fresh heartbeat (last_heartbeat > cutoff)
|
||||
// is recovering and keeps its op for retry.
|
||||
res := r.db.WithContext(ctx).
|
||||
Where(`node_id IN (SELECT id FROM backend_nodes WHERE status = ?)
|
||||
OR node_id IN (SELECT id FROM backend_nodes WHERE status IN ? AND last_heartbeat <= ?)`,
|
||||
StatusDraining, []string{StatusOffline, StatusUnhealthy}, cutoff).
|
||||
Delete(&PendingBackendOp{})
|
||||
if res.Error != nil {
|
||||
return 0, fmt.Errorf("deleting stale pending backend ops: %w", res.Error)
|
||||
}
|
||||
if res.RowsAffected > 0 {
|
||||
xlog.Info("Cleared pending backend ops behind non-healthy nodes", "deleted", res.RowsAffected)
|
||||
}
|
||||
return res.RowsAffected, nil
|
||||
}
|
||||
|
||||
// RecordPendingBackendOpFailure bumps Attempts, captures the error, and
|
||||
// pushes NextRetryAt out with exponential backoff capped at 15 minutes.
|
||||
func (r *NodeRegistry) RecordPendingBackendOpFailure(ctx context.Context, id uint, errMsg string) error {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -932,13 +933,12 @@ func (r *SmartRouter) stageModelFiles(ctx context.Context, node *BackendNode, op
|
||||
{"AudioPath", &opts.AudioPath},
|
||||
}
|
||||
|
||||
// Count stageable files for progress tracking
|
||||
// Count stageable files for progress tracking. Directory models expand to
|
||||
// the number of files they contain, matching what stageDirectory uploads.
|
||||
totalFiles := 0
|
||||
for _, f := range fields {
|
||||
if *f.val != "" {
|
||||
if _, err := os.Stat(*f.val); err == nil {
|
||||
totalFiles++
|
||||
}
|
||||
totalFiles += countStageableFiles(*f.val)
|
||||
}
|
||||
}
|
||||
for _, adapter := range opts.LoraAdapters {
|
||||
@@ -969,8 +969,33 @@ func (r *SmartRouter) stageModelFiles(ctx context.Context, node *BackendNode, op
|
||||
*f.val = ""
|
||||
continue
|
||||
}
|
||||
fileIdx++
|
||||
localPath := *f.val
|
||||
|
||||
// Directory models (e.g. qwen3-tts-cpp ships its weights and tokenizer
|
||||
// ggufs under one directory) can't be uploaded as a single file — the
|
||||
// stager would open the directory and read its fd, failing with
|
||||
// "is a directory" (EISDIR). Expand the directory and stage each
|
||||
// contained file, then rewrite the field to the remote directory.
|
||||
if fi, statErr := os.Stat(localPath); statErr == nil && fi.IsDir() {
|
||||
remoteDir, dirErr := r.stageDirectory(ctx, node, trackingKey, localPath, keyMapper, &fileIdx, totalFiles)
|
||||
if dirErr != nil {
|
||||
if f.name == "ModelFile" {
|
||||
xlog.Error("Failed to stage model directory for remote node", "node", node.Name, "field", f.name, "path", localPath, "error", dirErr)
|
||||
return nil, fmt.Errorf("staging model file: %w", dirErr)
|
||||
}
|
||||
xlog.Warn("Failed to stage model directory, clearing field", "field", f.name, "path", localPath, "error", dirErr)
|
||||
*f.val = ""
|
||||
continue
|
||||
}
|
||||
*f.val = remoteDir
|
||||
if f.name == "ModelFile" && opts.Model != "" {
|
||||
opts.ModelPath = DeriveRemoteModelPath(remoteDir, opts.Model)
|
||||
xlog.Debug("Derived remote ModelPath", "modelPath", opts.ModelPath)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
fileIdx++
|
||||
key := keyMapper.Key(localPath)
|
||||
|
||||
// Attach progress callback to context for byte-level tracking
|
||||
@@ -1074,6 +1099,77 @@ func (r *SmartRouter) withStagingCallback(ctx context.Context, trackingKey, file
|
||||
})
|
||||
}
|
||||
|
||||
// countStageableFiles returns the number of regular files a model path expands
|
||||
// to for staging: 1 for a regular file, the contained file count for a
|
||||
// directory, and 0 if the path does not exist.
|
||||
func countStageableFiles(path string) int {
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
if !fi.IsDir() {
|
||||
return 1
|
||||
}
|
||||
n := 0
|
||||
_ = filepath.WalkDir(path, func(_ string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return nil
|
||||
}
|
||||
if !d.IsDir() {
|
||||
n++
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return n
|
||||
}
|
||||
|
||||
// stageDirectory stages every file under a directory-based model (e.g.
|
||||
// qwen3-tts-cpp, whose weights and tokenizer ggufs live in one directory).
|
||||
// Each file is uploaded individually with a structure-preserving key; the
|
||||
// returned path is the remote directory that contained them, suitable for the
|
||||
// backend's ModelFile/ModelPath. fileIdx is advanced per staged file so the
|
||||
// staging progress tracker stays accurate.
|
||||
func (r *SmartRouter) stageDirectory(ctx context.Context, node *BackendNode, trackingKey, dir string, keyMapper *StagingKeyMapper, fileIdx *int, totalFiles int) (string, error) {
|
||||
var remoteDir string
|
||||
err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
*fileIdx++
|
||||
fileName := filepath.Base(path)
|
||||
stageCtx := r.withStagingCallback(ctx, trackingKey, fileName, *fileIdx, totalFiles)
|
||||
xlog.Info("Staging file", "model", trackingKey, "node", node.Name, "field", "ModelDir", "file", fileName, "fileIndex", *fileIdx, "totalFiles", totalFiles)
|
||||
|
||||
remoteFile, err := r.fileStager.EnsureRemote(stageCtx, node.ID, path, keyMapper.Key(path))
|
||||
if err != nil {
|
||||
return fmt.Errorf("staging %s: %w", path, err)
|
||||
}
|
||||
r.stagingTracker.FileComplete(trackingKey, *fileIdx, totalFiles)
|
||||
|
||||
// Every file under dir shares the same remote parent directory; derive
|
||||
// it from this file's staged path and its path relative to dir.
|
||||
rel, relErr := filepath.Rel(dir, path)
|
||||
if relErr != nil {
|
||||
return relErr
|
||||
}
|
||||
remoteDir = DeriveRemoteModelPath(remoteFile, rel)
|
||||
|
||||
r.stageCompanionFiles(ctx, node, path, keyMapper.Key)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if remoteDir == "" {
|
||||
return "", fmt.Errorf("model directory %s contains no files", dir)
|
||||
}
|
||||
return remoteDir, nil
|
||||
}
|
||||
|
||||
// stageCompanionFiles stages known companion files that exist alongside
|
||||
// localPath. For example, piper TTS implicitly loads ".onnx.json" next to
|
||||
// the ".onnx" model file. Errors are logged but not propagated.
|
||||
|
||||
64
core/services/nodes/router_dirstage_test.go
Normal file
64
core/services/nodes/router_dirstage_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// These tests cover staging of "directory models" — models whose ModelFile is a
|
||||
// directory containing multiple files (e.g. qwen3-tts-cpp ships weights +
|
||||
// tokenizer ggufs under one directory). The HTTP file stager uploads a single
|
||||
// regular file per path, so a directory ModelFile must be expanded into its
|
||||
// constituent files; otherwise the upload reads a directory fd and fails with
|
||||
// "is a directory" (EISDIR) on remote NATS worker nodes.
|
||||
var _ = Describe("stageModelFiles directory models", func() {
|
||||
var (
|
||||
stager *fakeFileStager
|
||||
router *SmartRouter
|
||||
node *BackendNode
|
||||
tmp string
|
||||
modelID = "qwen3-tts-cpp"
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
stager = &fakeFileStager{}
|
||||
router = &SmartRouter{
|
||||
fileStager: stager,
|
||||
stagingTracker: NewStagingTracker(),
|
||||
}
|
||||
node = &BackendNode{ID: "node-1", Name: "node-1", Address: "10.0.0.1:50051"}
|
||||
tmp = GinkgoT().TempDir()
|
||||
})
|
||||
|
||||
It("stages every file inside a directory ModelFile instead of the directory path", func() {
|
||||
modelDir := filepath.Join(tmp, "models", modelID)
|
||||
Expect(os.MkdirAll(modelDir, 0o755)).To(Succeed())
|
||||
weights := filepath.Join(modelDir, "qwen3-tts-0.6b-f16.gguf")
|
||||
tokenizer := filepath.Join(modelDir, "qwen3-tts-tokenizer-f16.gguf")
|
||||
Expect(os.WriteFile(weights, []byte("weights"), 0o644)).To(Succeed())
|
||||
Expect(os.WriteFile(tokenizer, []byte("tokenizer"), 0o644)).To(Succeed())
|
||||
|
||||
opts := &pb.ModelOptions{
|
||||
Model: modelID,
|
||||
ModelFile: modelDir,
|
||||
}
|
||||
|
||||
_, err := router.stageModelFiles(context.Background(), node, opts, "track-key")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
staged := make([]string, 0, len(stager.ensureCalls))
|
||||
for _, c := range stager.ensureCalls {
|
||||
staged = append(staged, c.localPath)
|
||||
}
|
||||
// Each contained file is staged individually; the directory path itself
|
||||
// is never handed to the stager (which would read a directory fd).
|
||||
Expect(staged).To(ConsistOf(weights, tokenizer))
|
||||
Expect(staged).ToNot(ContainElement(modelDir))
|
||||
})
|
||||
})
|
||||
@@ -365,7 +365,7 @@ func (f *fakeUnloader) InstallBackend(nodeID, backend, modelID, _, _, _, _ strin
|
||||
return f.installReply, f.installErr
|
||||
}
|
||||
|
||||
func (f *fakeUnloader) UpgradeBackend(nodeID, backend, _, _, _, _ string, replica int) (*messaging.BackendUpgradeReply, error) {
|
||||
func (f *fakeUnloader) UpgradeBackend(nodeID, backend, _, _, _, _ string, replica int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendUpgradeReply, error) {
|
||||
f.mu.Lock()
|
||||
f.upgradeCalls = append(f.upgradeCalls, upgradeCall{nodeID, backend, replica})
|
||||
f.mu.Unlock()
|
||||
|
||||
@@ -35,7 +35,7 @@ type backendStopRequest struct {
|
||||
// backend.upgrade subject.
|
||||
type NodeCommandSender interface {
|
||||
InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int, opID string, onProgress func(messaging.BackendInstallProgressEvent)) (*messaging.BackendInstallReply, error)
|
||||
UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendUpgradeReply, error)
|
||||
UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int, opID string, onProgress func(messaging.BackendInstallProgressEvent)) (*messaging.BackendUpgradeReply, error)
|
||||
DeleteBackend(nodeID, backendName string) (*messaging.BackendDeleteReply, error)
|
||||
ListBackends(nodeID string) (*messaging.BackendListReply, error)
|
||||
StopBackend(nodeID, backend string) error
|
||||
@@ -127,38 +127,8 @@ func (a *RemoteUnloaderAdapter) InstallBackend(
|
||||
xlog.Info("Sending NATS backend.install", "nodeID", nodeID, "backend", backendType, "modelID", modelID, "replica", replicaIndex, "opID", opID)
|
||||
|
||||
// Subscribe to the per-op progress subject BEFORE publishing the install
|
||||
// request so we don't miss early events. When onProgress is nil OR opID
|
||||
// is empty (the reconciler-driven retry path), skip subscription entirely:
|
||||
// silent installs cost nothing extra.
|
||||
var sub messaging.Subscription
|
||||
if onProgress != nil && opID != "" {
|
||||
progressSubject := messaging.SubjectNodeBackendInstallProgress(nodeID, opID)
|
||||
s, subErr := a.nats.Subscribe(progressSubject, func(raw []byte) {
|
||||
var ev messaging.BackendInstallProgressEvent
|
||||
if err := json.Unmarshal(raw, &ev); err != nil {
|
||||
xlog.Debug("malformed install progress event", "subject", progressSubject, "error", err)
|
||||
return
|
||||
}
|
||||
// Goroutine guard: a slow onProgress callback must not stall
|
||||
// the NATS reader thread.
|
||||
//
|
||||
// NOTE: events spawn one goroutine each, so ordering at the
|
||||
// consumer is best-effort. In practice the worker debounces to
|
||||
// ~250ms which is far larger than goroutine scheduling jitter,
|
||||
// so reordering is rare. The worker's final Flush() event is
|
||||
// intended to win as the terminal tick. A future hardening pass
|
||||
// could add a Seq uint64 field to BackendInstallProgressEvent
|
||||
// and drop stale-by-seq at the bridge if reordering becomes a
|
||||
// real UX issue.
|
||||
go onProgress(ev)
|
||||
})
|
||||
if subErr != nil {
|
||||
xlog.Warn("Failed to subscribe to install progress subject; proceeding without progress streaming",
|
||||
"subject", progressSubject, "error", subErr)
|
||||
} else {
|
||||
sub = s
|
||||
}
|
||||
}
|
||||
// request so we don't miss early events.
|
||||
sub := a.subscribeProgress(nodeID, opID, onProgress)
|
||||
|
||||
reply, err := messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{
|
||||
Backend: backendType,
|
||||
@@ -182,18 +152,58 @@ func (a *RemoteUnloaderAdapter) InstallBackend(
|
||||
return reply, err
|
||||
}
|
||||
|
||||
// subscribeProgress subscribes to the per-op backend-install progress subject
|
||||
// so the master can stream per-node download ticks while a worker installs or
|
||||
// upgrades. Returns nil (and subscribes to nothing) when onProgress is nil or
|
||||
// opID is empty — the reconciler-driven retry path and legacy callers stay
|
||||
// silent at no cost. Shared by InstallBackend, UpgradeBackend, and the legacy
|
||||
// force-install fallback: an upgrade is a force-reinstall, so it reuses the
|
||||
// install-progress subject rather than minting a new one (no new NATS
|
||||
// permission, no new rolling-update compat surface). Caller must Unsubscribe
|
||||
// the returned subscription after the request completes.
|
||||
func (a *RemoteUnloaderAdapter) subscribeProgress(nodeID, opID string, onProgress func(messaging.BackendInstallProgressEvent)) messaging.Subscription {
|
||||
if onProgress == nil || opID == "" {
|
||||
return nil
|
||||
}
|
||||
progressSubject := messaging.SubjectNodeBackendInstallProgress(nodeID, opID)
|
||||
s, subErr := a.nats.Subscribe(progressSubject, func(raw []byte) {
|
||||
var ev messaging.BackendInstallProgressEvent
|
||||
if err := json.Unmarshal(raw, &ev); err != nil {
|
||||
xlog.Debug("malformed backend progress event", "subject", progressSubject, "error", err)
|
||||
return
|
||||
}
|
||||
// Goroutine guard: a slow onProgress callback must not stall the NATS
|
||||
// reader thread. Events spawn one goroutine each, so ordering at the
|
||||
// consumer is best-effort; the worker debounces to ~250ms which dwarfs
|
||||
// goroutine scheduling jitter, and its final Flush() is the terminal tick.
|
||||
go onProgress(ev)
|
||||
})
|
||||
if subErr != nil {
|
||||
xlog.Warn("Failed to subscribe to backend progress subject; proceeding without progress streaming",
|
||||
"subject", progressSubject, "error", subErr)
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// UpgradeBackend sends a backend.upgrade request-reply to a worker node.
|
||||
// The worker stops every live process for this backend, force-reinstalls
|
||||
// from the gallery (overwriting the on-disk artifact), and replies. The
|
||||
// next routine InstallBackend call spawns a fresh process with the new
|
||||
// binary - upgrade itself does not start a process.
|
||||
//
|
||||
// When opID is non-empty and onProgress is set, the master subscribes to the
|
||||
// per-op progress subject before firing the request so a long force-reinstall
|
||||
// streams per-node download ticks instead of blocking opaque at progress 0.
|
||||
//
|
||||
// Timeout: configured via DistributedConfig.BackendUpgradeTimeoutOrDefault
|
||||
// (default 15m). Real-world worst case observed: 8-10 minutes for large
|
||||
// CUDA-l4t backend images on Jetson over WiFi.
|
||||
func (a *RemoteUnloaderAdapter) UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendUpgradeReply, error) {
|
||||
func (a *RemoteUnloaderAdapter) UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int, opID string, onProgress func(messaging.BackendInstallProgressEvent)) (*messaging.BackendUpgradeReply, error) {
|
||||
subject := messaging.SubjectNodeBackendUpgrade(nodeID)
|
||||
xlog.Info("Sending NATS backend.upgrade", "nodeID", nodeID, "backend", backendType, "replica", replicaIndex)
|
||||
xlog.Info("Sending NATS backend.upgrade", "nodeID", nodeID, "backend", backendType, "replica", replicaIndex, "opID", opID)
|
||||
|
||||
sub := a.subscribeProgress(nodeID, opID, onProgress)
|
||||
|
||||
reply, err := messaging.RequestJSON[messaging.BackendUpgradeRequest, messaging.BackendUpgradeReply](a.nats, subject, messaging.BackendUpgradeRequest{
|
||||
Backend: backendType,
|
||||
@@ -202,7 +212,13 @@ func (a *RemoteUnloaderAdapter) UpgradeBackend(nodeID, backendType, galleriesJSO
|
||||
Name: name,
|
||||
Alias: alias,
|
||||
ReplicaIndex: int32(replicaIndex),
|
||||
OpID: opID,
|
||||
}, a.upgradeTimeout)
|
||||
|
||||
if sub != nil {
|
||||
_ = sub.Unsubscribe()
|
||||
}
|
||||
|
||||
if err != nil && isNATSTimeout(err) {
|
||||
return nil, fmt.Errorf("%w (subject=%s nodeID=%s backend=%s): %v",
|
||||
galleryop.ErrWorkerStillInstalling, subject, nodeID, backendType, err)
|
||||
@@ -216,10 +232,12 @@ func (a *RemoteUnloaderAdapter) UpgradeBackend(nodeID, backendType, galleriesJSO
|
||||
// doesn't subscribe to the new subject). It re-fires the legacy
|
||||
// backend.install with Force=true. Drop this once every worker is on
|
||||
// 2026-05-08 or newer.
|
||||
func (a *RemoteUnloaderAdapter) installWithForceFallback(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendInstallReply, error) {
|
||||
func (a *RemoteUnloaderAdapter) installWithForceFallback(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int, opID string, onProgress func(messaging.BackendInstallProgressEvent)) (*messaging.BackendInstallReply, error) {
|
||||
subject := messaging.SubjectNodeBackendInstall(nodeID)
|
||||
xlog.Warn("Falling back to legacy backend.install Force=true (old worker)", "nodeID", nodeID, "backend", backendType)
|
||||
|
||||
sub := a.subscribeProgress(nodeID, opID, onProgress)
|
||||
|
||||
reply, err := messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{
|
||||
Backend: backendType,
|
||||
BackendGalleries: galleriesJSON,
|
||||
@@ -228,7 +246,13 @@ func (a *RemoteUnloaderAdapter) installWithForceFallback(nodeID, backendType, ga
|
||||
Alias: alias,
|
||||
ReplicaIndex: int32(replicaIndex),
|
||||
Force: true,
|
||||
OpID: opID,
|
||||
}, a.upgradeTimeout)
|
||||
|
||||
if sub != nil {
|
||||
_ = sub.Unsubscribe()
|
||||
}
|
||||
|
||||
if err != nil && isNATSTimeout(err) {
|
||||
return nil, fmt.Errorf("%w (subject=%s nodeID=%s backend=%s): %v",
|
||||
galleryop.ErrWorkerStillInstalling, subject, nodeID, backendType, err)
|
||||
|
||||
@@ -282,7 +282,7 @@ var _ = Describe("RemoteUnloaderAdapter timeout configuration", func() {
|
||||
mc.scriptReply(messaging.SubjectNodeBackendUpgrade("n1"), messaging.BackendUpgradeReply{Success: true})
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 7*time.Minute, 11*time.Minute)
|
||||
|
||||
_, err := adapter.UpgradeBackend("n1", "llama-cpp", "[]", "", "", "", 0)
|
||||
_, err := adapter.UpgradeBackend("n1", "llama-cpp", "[]", "", "", "", 0, "", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(mc.calls).To(HaveLen(1))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
@@ -18,7 +19,7 @@ var _ = Describe("RemoteUnloaderAdapter.UpgradeBackend", func() {
|
||||
messaging.BackendUpgradeReply{Success: true})
|
||||
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 3*time.Minute, 15*time.Minute)
|
||||
reply, err := adapter.UpgradeBackend(nodeID, "llama-cpp", `[{"name":"x"}]`, "", "", "", 0)
|
||||
reply, err := adapter.UpgradeBackend(nodeID, "llama-cpp", `[{"name":"x"}]`, "", "", "", 0, "", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(reply.Success).To(BeTrue())
|
||||
})
|
||||
@@ -27,7 +28,55 @@ var _ = Describe("RemoteUnloaderAdapter.UpgradeBackend", func() {
|
||||
mc := newScriptedMessagingClient() // unscripted subject => fakeNoRespondersErr by harness convention
|
||||
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 3*time.Minute, 15*time.Minute)
|
||||
_, err := adapter.UpgradeBackend("missing-node", "llama-cpp", "", "", "", "", 0)
|
||||
_, err := adapter.UpgradeBackend("missing-node", "llama-cpp", "", "", "", "", 0, "", nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
// Reproducer for "upgrade reports progress:0 the whole time" (Bug B). The
|
||||
// install path streamed per-node download ticks; the upgrade path did a bare
|
||||
// request→single-reply with no progress subscription, so a long force-reinstall
|
||||
// blocked opaque. The adapter must subscribe to the per-op progress subject
|
||||
// (reused from install) BEFORE the request and deliver each tick to onProgress.
|
||||
It("streams per-node progress ticks during the upgrade", func() {
|
||||
mc := newScriptedMessagingClient()
|
||||
nodeID := "node-slow"
|
||||
opID := "op-upgrade-1"
|
||||
|
||||
mc.scriptReply(messaging.SubjectNodeBackendUpgrade(nodeID),
|
||||
messaging.BackendUpgradeReply{Success: true})
|
||||
// The worker would publish these while force-reinstalling. The harness
|
||||
// replays them as soon as the adapter subscribes to the per-op subject.
|
||||
mc.scheduleProgressPublish(nodeID, opID, []messaging.BackendInstallProgressEvent{
|
||||
{NodeID: nodeID, FileName: "llama-cpp.tar", Current: "10 MB", Total: "100 MB", Percentage: 10},
|
||||
{NodeID: nodeID, FileName: "llama-cpp.tar", Current: "100 MB", Total: "100 MB", Percentage: 100},
|
||||
})
|
||||
|
||||
var mu sync.Mutex
|
||||
var got []messaging.BackendInstallProgressEvent
|
||||
onProgress := func(ev messaging.BackendInstallProgressEvent) {
|
||||
mu.Lock()
|
||||
got = append(got, ev)
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
adapter := NewRemoteUnloaderAdapter(nil, mc, 3*time.Minute, 15*time.Minute)
|
||||
reply, err := adapter.UpgradeBackend(nodeID, "llama-cpp", `[{"name":"x"}]`, "", "", "", 0, opID, onProgress)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(reply.Success).To(BeTrue())
|
||||
|
||||
// Confirm it subscribed to the (reused) install-progress subject for this op.
|
||||
Expect(mc.subscribeCalls()).To(ContainElement(messaging.SubjectNodeBackendInstallProgress(nodeID, opID)))
|
||||
|
||||
// Progress events are delivered asynchronously (goroutine-per-event), so
|
||||
// poll for both and assert on the set — ordering is best-effort by design.
|
||||
Eventually(func() []float64 {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
pcts := make([]float64, 0, len(got))
|
||||
for _, e := range got {
|
||||
pcts = append(pcts, e.Percentage)
|
||||
}
|
||||
return pcts
|
||||
}, 2*time.Second, 20*time.Millisecond).Should(ConsistOf(float64(10), float64(100)))
|
||||
})
|
||||
})
|
||||
|
||||
30
core/services/worker/auth_required_test.go
Normal file
30
core/services/worker/auth_required_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Worker auth-required helpers", func() {
|
||||
DescribeTable("NatsAuthRequired",
|
||||
func(nats, umbrella, want bool) {
|
||||
cfg := &Config{NatsRequireAuth: nats, DistributedRequireAuth: umbrella}
|
||||
Expect(cfg.NatsAuthRequired()).To(Equal(want))
|
||||
},
|
||||
Entry("neither", false, false, false),
|
||||
Entry("granular only", true, false, true),
|
||||
Entry("umbrella only", false, true, true),
|
||||
Entry("both", true, true, true),
|
||||
)
|
||||
|
||||
DescribeTable("RegistrationAuthRequired",
|
||||
func(reg, umbrella, want bool) {
|
||||
cfg := &Config{RegistrationRequireAuth: reg, DistributedRequireAuth: umbrella}
|
||||
Expect(cfg.RegistrationAuthRequired()).To(Equal(want))
|
||||
},
|
||||
Entry("neither", false, false, false),
|
||||
Entry("granular only", true, false, true),
|
||||
Entry("umbrella only", false, true, true),
|
||||
Entry("both", true, true, true),
|
||||
)
|
||||
})
|
||||
@@ -44,12 +44,14 @@ type Config struct {
|
||||
AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server" hidden:""`
|
||||
|
||||
// Registration (required)
|
||||
AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration" hidden:""`
|
||||
RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
|
||||
NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
|
||||
HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
|
||||
NodeLabels string `env:"LOCALAI_NODE_LABELS" help:"Comma-separated key=value labels for this node (e.g. tier=fast,gpu=a100)" group:"registration"`
|
||||
AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration" hidden:""`
|
||||
RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
|
||||
NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
|
||||
RegistrationRequireAuth bool `env:"LOCALAI_REGISTRATION_REQUIRE_AUTH" default:"false" help:"Refuse to start the HTTP file-transfer server when no registration token is set (otherwise it fails open and serves read/write to models/staging/data unauthenticated)" group:"registration"`
|
||||
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch implying both --nats-require-auth and --registration-require-auth" group:"distributed"`
|
||||
HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
|
||||
NodeLabels string `env:"LOCALAI_NODE_LABELS" help:"Comma-separated key=value labels for this node (e.g. tier=fast,gpu=a100)" group:"registration"`
|
||||
// MaxReplicasPerModel caps how many replicas of any one model can run on
|
||||
// this worker concurrently. Default 1 = historical single-replica
|
||||
// behavior. Set higher when a node has enough VRAM to host multiple
|
||||
@@ -60,7 +62,13 @@ type Config struct {
|
||||
MaxReplicasPerModel int `env:"LOCALAI_MAX_REPLICAS_PER_MODEL" default:"1" help:"Max replicas of any single model on this worker. Default 1 preserves single-replica behavior; set higher to allow stacking replicas on a fat node." group:"registration"`
|
||||
|
||||
// NATS (required)
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
|
||||
NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (normally from registration nats_jwt)" group:"distributed"`
|
||||
NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user signing seed override (normally from registration nats_user_seed)" group:"distributed"`
|
||||
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed from registration or env" group:"distributed"`
|
||||
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
|
||||
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
|
||||
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
|
||||
|
||||
// S3 storage for distributed file transfer
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3 endpoint URL" group:"distributed"`
|
||||
@@ -69,3 +77,15 @@ type Config struct {
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret key" group:"distributed"`
|
||||
}
|
||||
|
||||
// NatsAuthRequired reports whether NATS JWT credentials must be present — the
|
||||
// granular flag or the umbrella (LOCALAI_DISTRIBUTED_REQUIRE_AUTH).
|
||||
func (c Config) NatsAuthRequired() bool {
|
||||
return c.NatsRequireAuth || c.DistributedRequireAuth
|
||||
}
|
||||
|
||||
// RegistrationAuthRequired reports whether a registration token must be set
|
||||
// before the file-transfer server may start — the granular flag or the umbrella.
|
||||
func (c Config) RegistrationAuthRequired() bool {
|
||||
return c.RegistrationRequireAuth || c.DistributedRequireAuth
|
||||
}
|
||||
|
||||
@@ -186,17 +186,29 @@ func (s *backendSupervisor) upgradeBackend(req messaging.BackendUpgradeRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// When the master tagged this upgrade with an OpID, stream gallery download
|
||||
// progress back on the per-op subject (reused from install — an upgrade is a
|
||||
// force-reinstall). Old masters omit OpID and stay on the silent path. The
|
||||
// deferred Flush guarantees a terminal-percentage event even if the upgrade
|
||||
// errors out, so the master's per-node bar never hangs mid-download.
|
||||
var downloadCb func(file, current, total string, percentage float64)
|
||||
if req.OpID != "" && s.nats != nil {
|
||||
publisher := nodes.NewDebouncedInstallProgressPublisher(s.nats, s.nodeID, req.OpID, req.Backend, installProgressDebounce)
|
||||
downloadCb = publisher.OnDownload
|
||||
defer publisher.Flush()
|
||||
}
|
||||
|
||||
if req.URI != "" {
|
||||
xlog.Info("Upgrading backend from external URI", "backend", req.Backend, "uri", req.URI)
|
||||
if err := galleryop.InstallExternalBackend(
|
||||
context.Background(), galleries, s.systemState, s.ml, nil, req.URI, req.Name, req.Alias, s.cfg.RequireBackendIntegrity,
|
||||
context.Background(), galleries, s.systemState, s.ml, downloadCb, req.URI, req.Name, req.Alias, s.cfg.RequireBackendIntegrity,
|
||||
); err != nil {
|
||||
return fmt.Errorf("upgrading backend from external URI: %w", err)
|
||||
}
|
||||
} else {
|
||||
xlog.Info("Upgrading backend from gallery", "backend", req.Backend)
|
||||
if err := gallery.InstallBackendFromGallery(
|
||||
context.Background(), galleries, s.systemState, s.ml, req.Backend, nil, true, /* force */
|
||||
context.Background(), galleries, s.systemState, s.ml, req.Backend, downloadCb, true, /* force */
|
||||
s.cfg.RequireBackendIntegrity,
|
||||
); err != nil {
|
||||
return fmt.Errorf("upgrading backend from gallery: %w", err)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user