mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-06 07:46:15 -04:00
Compare commits
68 Commits
v4.3.6
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2342c9348e | ||
|
|
352b7ec604 | ||
|
|
ba706422fb | ||
|
|
e837921c2c | ||
|
|
73385713ca | ||
|
|
a4e671779a | ||
|
|
7051b2e0a1 | ||
|
|
469737101a | ||
|
|
858257eaf0 | ||
|
|
ef80a0e825 | ||
|
|
92726f7631 | ||
|
|
994063ba9a | ||
|
|
c1a55cf72d | ||
|
|
96758841d8 | ||
|
|
7a59260621 | ||
|
|
27e63b9a78 | ||
|
|
55c0911c23 | ||
|
|
f6cb6ab6d9 | ||
|
|
9f11b09c6a | ||
|
|
a5c4f822f0 | ||
|
|
fb36c262fe | ||
|
|
0e4e8980e6 | ||
|
|
3a932a9803 | ||
|
|
9d10418593 | ||
|
|
5470051d4d | ||
|
|
68c5eeebc3 | ||
|
|
1531fabe23 | ||
|
|
b7673d5b76 | ||
|
|
b64bdaf406 | ||
|
|
eebf08ff1d | ||
|
|
42e51894c3 | ||
|
|
d9ae6481fb | ||
|
|
f1c495a748 | ||
|
|
415b561947 | ||
|
|
e6a0d4c375 | ||
|
|
7e59a5c7c5 | ||
|
|
aea954a482 | ||
|
|
595e448714 | ||
|
|
860f9d63ad | ||
|
|
a5a0b3dc4e | ||
|
|
94eca04c60 | ||
|
|
35bd485d6a | ||
|
|
1fe96f8d9a | ||
|
|
c508e9d7c6 | ||
|
|
55e754fd05 | ||
|
|
a17753f7d1 | ||
|
|
c61838dba6 | ||
|
|
7013e13f05 | ||
|
|
5a0013defe | ||
|
|
c01ed631d6 | ||
|
|
d47464cb06 | ||
|
|
63f176346e | ||
|
|
af94d08729 | ||
|
|
6795d38f50 | ||
|
|
718223f33b | ||
|
|
39e050d9e2 | ||
|
|
c222161291 | ||
|
|
aa80d4681b | ||
|
|
0d57957ebb | ||
|
|
76fe0bb929 | ||
|
|
baa11133f1 | ||
|
|
1bdd3338a6 | ||
|
|
e08492a2c3 | ||
|
|
d5d8fe909d | ||
|
|
8a82753277 | ||
|
|
51ca109067 | ||
|
|
07f6c15a37 | ||
|
|
a44bdb29d4 |
@@ -38,9 +38,12 @@ The React UI (`core/http/react-ui/`) has **no component/unit tests** — its onl
|
||||
- **Browser:** the flake dev shell ships `chromium` and exports `PLAYWRIGHT_CHROMIUM_PATH`; `playwright.config.js` uses it via `launchOptions.executablePath`, and the Makefile skips `playwright install` when it's set. This avoids Playwright's downloaded browser, which can't resolve system libs (`libglib-2.0`, …) on NixOS. In CI (no `PLAYWRIGHT_CHROMIUM_PATH`) the Makefile falls back to `playwright install --with-deps chromium`.
|
||||
- The app is a React SPA, so coverage accumulates across in-app navigation within a test; a full `page.goto`/reload resets it.
|
||||
- `.nycrc.json` uses `all: true`, so **every `src/**` file is in the report**, including 0%-coverage ones — that's how you spot features with no test at all (sort the HTML report or `coverage-summary.json` by line% ascending).
|
||||
- **UI coverage gate:** `make test-ui-coverage-check` runs the suite then `scripts/ui-coverage-check.sh`, failing if total line coverage drops more than `UI_COVERAGE_TOLERANCE` (default **1.0pp**) below `core/http/react-ui/coverage-baseline.txt`. `make test-ui-coverage-baseline` regenerates the baseline. **Why a tolerance (unlike the strict Go gate):** UI e2e line coverage is *non-deterministic* — async/debounced paths (e.g. the VRAM estimate's 500ms debounce) make identical specs vary ~0.5pp run-to-run, so a zero-tolerance gate would flake. Keep the tolerance just above the observed jitter. Run in CI (`tests-ui-e2e.yml`) and pre-commit on `core/http/react-ui/` changes.
|
||||
- **UI coverage gate:** `make test-ui-coverage-check` runs the suite then `scripts/ui-coverage-check.sh`, failing if total line coverage drops more than `UI_COVERAGE_TOLERANCE` below `core/http/react-ui/coverage-baseline.txt`. `make test-ui-coverage-baseline` regenerates the baseline. Runs in CI (`tests-ui-e2e.yml`) and pre-commit on `core/http/react-ui/` changes.
|
||||
- **Why it has a tolerance (unlike the strict Go gate):** UI e2e coverage is *non-deterministic*. Specs that assert on state and end while async/lazy render work is still in flight collect those lines only when the render beats the coverage teardown — so the total drifts with machine speed/load (a fast local box reads higher than a slow CI runner), diffusely across many specs. The tolerance absorbs that drift, so set the baseline *below* the slow-CI floor, never to a fast-local `make test-ui-coverage-baseline` number, or CI flaps.
|
||||
- **Raising coverage is cheap:** a *render-smoke* spec (navigate to a route, assert its header renders) mounts a lazy page and runs its full render + initial effects, capturing most of its lines in a few lines of test — see `e2e/page-render-smoke.spec.js`. Auth is disabled in the test server (`isAdmin=true`), so `RequireAdmin`/`RequireFeature` routes render without a mock. The most *deterministic* win is removing a race: make a spec `await` a rendered element before ending (see `e2e/agents.spec.js` → AgentCreate) so its lines count every run.
|
||||
|
||||
Rules:
|
||||
- The gate is **strict — there is no tolerance**. Any decrease fails, regardless of how many lines a PR adds or deletes. `covermode=atomic` makes line coverage deterministic, so there's no run-to-run jitter to excuse.
|
||||
- When a change legitimately **raises** coverage, run `make test-coverage-baseline` and **commit** the updated `coverage-baseline.txt` so the ratchet moves up. Never lower the baseline by hand.
|
||||
- If you can't get coverage back to baseline, the fix is to **add tests**, not to edit the baseline.
|
||||
Rules (both gates):
|
||||
- **Install the hooks:** `make install-hooks` once per clone so lint + coverage run pre-commit. Don't lean on CI for what the hook catches.
|
||||
- **Don't work around the gate:** never `git commit --no-verify`, and never hand-lower a baseline or widen a tolerance to turn a red gate green. The ratchet only moves up.
|
||||
- If a change drops coverage, **add tests** (sort `coverage-summary.json` by line% ascending to find untested code) rather than editing the baseline. When coverage legitimately rises, commit the regenerated baseline (`make test-coverage-baseline` / `test-ui-coverage-baseline`).
|
||||
- The Go gate is **strict — no tolerance**; `covermode=atomic` keeps it deterministic. The UI gate keeps a small tolerance only because its e2e coverage isn't.
|
||||
|
||||
@@ -68,6 +68,34 @@ go test -count=1 -timeout=30m -v ./tests/e2e-backends/...
|
||||
|
||||
CI does not load the model; the suite is opt-in via env vars.
|
||||
|
||||
## Distributed mode
|
||||
|
||||
ds4 supports **layer-split** distributed inference (a model too big for one host,
|
||||
split by transformer layer; the GGUF must be present on every machine, each loads
|
||||
only its slice). Topology is **inverted** vs llama.cpp: the coordinator listens,
|
||||
workers dial in.
|
||||
|
||||
- **`ds4-worker` binary**: built and packaged next to `grpc-server` (`package.sh`
|
||||
copies it into `package/`). Links the same engine objects plus `ds4_distributed.o`;
|
||||
**no gRPC/protobuf dependency** (speaks ds4's own TCP transport), so it builds
|
||||
even where `grpc-server` can't. Runs the worker serving loop (`ds4_dist_run`).
|
||||
- **Coordinator wiring**: the ds4 `grpc-server` acts as coordinator when `LoadModel`
|
||||
`ModelOptions.Options` (from model-YAML `options:`) carry:
|
||||
- `ds4_role:coordinator` (enables distributed mode; absent → single-node, back-compat)
|
||||
- `ds4_layers:0:19` (coordinator's own slice, inclusive; `N:output` includes the head)
|
||||
- `ds4_listen:0.0.0.0:1234` (address workers dial into)
|
||||
- `ds4_route_timeout:60` (optional; seconds Predict/PredictStream wait for the route
|
||||
to form before returning gRPC `UNAVAILABLE`; default 60)
|
||||
- **Worker CLI**: `local-ai worker ds4-distributed -- <ds4-worker args>` resolves the
|
||||
ds4 backend and execs the packaged `ds4-worker` (raw passthrough), e.g.
|
||||
`--role worker --model /models/ds4flash.gguf --layers 20:output --coordinator <host> 1234`.
|
||||
|
||||
Opt-in e2e in `tests/e2e-backends/backend_test.go`, gated by
|
||||
`BACKEND_TEST_DS4_DISTRIBUTED=1` (plus `BACKEND_TEST_DS4_WORKER_BINARY`,
|
||||
`BACKEND_TEST_DS4_WORKER_LAYERS`, `BACKEND_TEST_DS4_COORDINATOR_LAYERS`,
|
||||
`BACKEND_TEST_DS4_LISTEN`). Design spec:
|
||||
`docs/superpowers/specs/2026-05-30-ds4-distributed-inference-design.md`.
|
||||
|
||||
## Importer
|
||||
|
||||
`core/gallery/importers/ds4.go` (`DS4Importer`) auto-detects ds4 weights by
|
||||
|
||||
151
.github/backend-matrix.yml
vendored
151
.github/backend-matrix.yml
vendored
@@ -716,6 +716,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -1569,6 +1582,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1595,6 +1621,19 @@ include:
|
||||
backend: "whisper"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-crispasr'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -2889,6 +2928,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2903,6 +2956,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-crispasr'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2916,6 +2983,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2929,6 +3009,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2943,6 +3036,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2957,6 +3064,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-crispasr'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
@@ -2970,6 +3091,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-crispasr'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2983,6 +3117,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-crispasr'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
runs-on: 'ubuntu-latest'
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# parakeet-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -4124,6 +4271,10 @@ includeDarwin:
|
||||
tag-suffix: "-metal-darwin-arm64-whisper"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "crispasr"
|
||||
tag-suffix: "-metal-darwin-arm64-crispasr"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "parakeet-cpp"
|
||||
tag-suffix: "-metal-darwin-arm64-parakeet-cpp"
|
||||
build-type: "metal"
|
||||
|
||||
13
.github/gallery-agent/main.go
vendored
13
.github/gallery-agent/main.go
vendored
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -113,6 +114,17 @@ func main() {
|
||||
fmt.Println("Searching for trending models on HuggingFace...")
|
||||
rawModels, err := client.GetTrending(searchTerm, limit)
|
||||
if err != nil {
|
||||
if errors.Is(err, hfapi.ErrRateLimited) {
|
||||
fmt.Printf("HuggingFace API is rate limited after retries, skipping this run: %v\n", err)
|
||||
writeSummary(AddedModelSummary{
|
||||
SearchTerm: searchTerm,
|
||||
TotalFound: 0,
|
||||
ModelsAdded: 0,
|
||||
Quantization: quantization,
|
||||
ProcessingTime: time.Since(startTime).String(),
|
||||
})
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Error fetching models: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -277,4 +289,3 @@ func truncateString(s string, maxLen int) string {
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
|
||||
4
.github/workflows/bump_deps.yaml
vendored
4
.github/workflows/bump_deps.yaml
vendored
@@ -30,6 +30,10 @@ jobs:
|
||||
variable: "WHISPER_CPP_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/whisper/Makefile"
|
||||
- repository: "CrispStrobe/CrispASR"
|
||||
variable: "CRISPASR_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/crispasr/Makefile"
|
||||
- repository: "mudler/parakeet.cpp"
|
||||
variable: "PARAKEET_VERSION"
|
||||
branch: "master"
|
||||
|
||||
2
.github/workflows/secscan.yaml
vendored
2
.github/workflows/secscan.yaml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||
- name: Run Gosec Security Scanner
|
||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||
uses: securego/gosec@v2.22.9
|
||||
uses: securego/gosec@v2.27.1
|
||||
with:
|
||||
# we let the report trigger content trigger a failure using the GitHub Security features.
|
||||
args: '-no-fail -fmt sarif -out results.sarif ./...'
|
||||
|
||||
@@ -35,6 +35,7 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
||||
|
||||
## Quick Reference
|
||||
|
||||
- **Git hooks & coverage gates**: Run `make install-hooks` once per clone so the pre-commit lint + coverage gates run. **Never bypass them with `git commit --no-verify`, and never lower a coverage baseline or widen a gate's tolerance to turn a red gate green** — the coverage ratchet only moves up. If a change drops coverage, add tests to raise it (e.g. render-smoke specs). See [.agents/building-and-testing.md](.agents/building-and-testing.md).
|
||||
- **Logging**: Use `github.com/mudler/xlog` (same API as slog)
|
||||
- **Go style**: Prefer `any` over `interface{}`
|
||||
- **Comments**: Explain *why*, not *what*
|
||||
|
||||
@@ -266,6 +266,12 @@ The e2e tests run LocalAI in a Docker container and exercise the API:
|
||||
make test-e2e
|
||||
```
|
||||
|
||||
### React UI tests and coverage
|
||||
|
||||
The React UI (`core/http/react-ui/`) is covered by Playwright e2e specs, gated by a **monotonic line-coverage ratchet** (`make test-ui-coverage-check`, run in CI and pre-commit). The metric is non-deterministic — a fast local box reads higher than a slow CI runner for the same code — so a small tolerance is unavoidable.
|
||||
|
||||
**If your change lowers UI coverage, raise it back by adding specs — do not widen the tolerance or hand-lower the baseline.** A *render-smoke* spec (navigate to a page, assert its header is visible) cheaply covers an entire lazy page. See `core/http/react-ui/e2e/page-render-smoke.spec.js` and the full policy in [.agents/building-and-testing.md](.agents/building-and-testing.md#react-ui-coverage).
|
||||
|
||||
### Running E2E container tests
|
||||
|
||||
These tests build a standard LocalAI Docker image and run it with pre-configured model configs to verify that most endpoints work correctly:
|
||||
|
||||
15
Makefile
15
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -309,13 +309,20 @@ run-e2e-aio: protogen-go
|
||||
@echo 'Running e2e AIO tests'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
|
||||
|
||||
# Distributed architecture e2e (PostgreSQL + NATS via testcontainers).
|
||||
# Includes NatsJWT specs (JWT-enabled NATS). Requires Docker.
|
||||
# VLLMMultinode is excluded here; use test-e2e-vllm-multinode for that.
|
||||
test-e2e-distributed: protogen-go
|
||||
@echo 'Running distributed e2e tests (label Distributed, incl. NatsJWT)'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter='Distributed && !VLLMMultinode' --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e/distributed
|
||||
|
||||
# vLLM multi-node DP smoke (CPU). Builds local-ai:tests and the
|
||||
# cpu-vllm backend from the current working tree, then drives a
|
||||
# head + headless follower via testcontainers-go and asserts a chat
|
||||
# completion. BuildKit caches both images, so re-runs only rebuild
|
||||
# what changed. The test lives under tests/e2e/distributed and is
|
||||
# selected by the VLLMMultinode label so it doesn't run alongside
|
||||
# the other distributed-suite tests by default.
|
||||
# test-e2e-distributed.
|
||||
test-e2e-vllm-multinode: docker-build-e2e extract-backend-vllm protogen-go
|
||||
@echo 'Running e2e vLLM multi-node DP test'
|
||||
LOCALAI_IMAGE=local-ai \
|
||||
@@ -1162,6 +1169,7 @@ BACKEND_HUGGINGFACE = huggingface|golang|.|false|true
|
||||
BACKEND_SILERO_VAD = silero-vad|golang|.|false|true
|
||||
BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true
|
||||
BACKEND_WHISPER = whisper|golang|.|false|true
|
||||
BACKEND_CRISPASR = crispasr|golang|.|false|true
|
||||
BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true
|
||||
BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
||||
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
|
||||
@@ -1250,6 +1258,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CRISPASR)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
|
||||
@@ -1300,7 +1309,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-crispasr docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
|
||||
18
README.md
18
README.md
@@ -31,12 +31,18 @@
|
||||
|
||||
**LocalAI** is the open-source AI engine. Run any model - LLMs, vision, voice, image, video - on any hardware. No GPU required.
|
||||
|
||||
- **Drop-in API compatibility** — OpenAI, Anthropic, ElevenLabs APIs
|
||||
- **36+ backends** — llama.cpp, vLLM, transformers, whisper, diffusers, MLX...
|
||||
- **Any hardware** — NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready** — API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents** — autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first** — your data never leaves your infrastructure
|
||||
**A small core, not a bundle.** Each backend wraps a best-in-class engine (llama.cpp, vLLM, whisper.cpp, stable-diffusion, MLX...) in its own image, pulled only when a model needs it. You install nothing you don't use.
|
||||
|
||||
- **Composable by design**: backends are separate and pulled on demand, so you install only what your model needs
|
||||
- **Open and extensible**: load any model, or build your own backend in any language against an open interface
|
||||
- **Drop-in API compatibility**: OpenAI, Anthropic, and ElevenLabs APIs across every backend
|
||||
- **Any model, any modality**: LLMs, vision, voice, image, and video behind one API
|
||||
- **Any hardware**: NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready**: API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents**: autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first**: your data never leaves your infrastructure
|
||||
|
||||

|
||||
|
||||
Created by [Ettore Di Giacinto](https://github.com/mudler) and maintained by the [LocalAI team](#team).
|
||||
|
||||
|
||||
@@ -537,6 +537,15 @@ message TTSRequest {
|
||||
string dst = 3;
|
||||
string voice = 4;
|
||||
optional string language = 5;
|
||||
// instructions is a free-form, per-request style/voice description (maps to
|
||||
// the OpenAI `instructions` field). Backends that support expressive synthesis
|
||||
// (e.g. Qwen3-TTS CustomVoice/VoiceDesign) prefer this over the static YAML
|
||||
// option when set; backends that don't simply ignore it.
|
||||
optional string instructions = 6;
|
||||
// params carries optional, backend-specific per-request generation parameters
|
||||
// (e.g. Chatterbox exaggeration/cfg_weight/temperature). Values are strings and
|
||||
// coerced by the backend; unset leaves the backend's configured defaults.
|
||||
map<string, string> params = 7;
|
||||
}
|
||||
|
||||
message VADRequest {
|
||||
|
||||
1
backend/cpp/ds4/.gitignore
vendored
1
backend/cpp/ds4/.gitignore
vendored
@@ -2,6 +2,7 @@ ds4/
|
||||
build/
|
||||
package/
|
||||
grpc-server
|
||||
ds4-worker
|
||||
*.o
|
||||
backend.pb.cc
|
||||
backend.pb.h
|
||||
|
||||
@@ -104,3 +104,36 @@ if(DS4_NATIVE)
|
||||
target_compile_options(${TARGET} PRIVATE -march=native)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# ds4-worker: standalone distributed worker. Links the same ds4 engine objects
|
||||
# (including ds4_distributed.o) but has NO gRPC/protobuf dependency - it speaks
|
||||
# ds4's own TCP transport via ds4_dist_run(). Buildable wherever the engine
|
||||
# objects build, even on hosts without protobuf/grpc dev headers.
|
||||
add_executable(ds4-worker worker_main.c)
|
||||
target_include_directories(ds4-worker PRIVATE ${DS4_DIR})
|
||||
foreach(obj ${DS4_OBJS})
|
||||
target_sources(ds4-worker PRIVATE ${obj})
|
||||
set_source_files_properties(${obj} PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE)
|
||||
endforeach()
|
||||
# worker_main.c is C, but the engine objects built by nvcc (ds4_cuda.o) and the
|
||||
# Metal path (ds4_metal.o, Obj-C++) reference the C++ runtime (libstdc++). Force
|
||||
# the C++ linker driver so those symbols resolve; the C driver would not link
|
||||
# libstdc++ and the CUDA/Metal builds fail with undefined std:: references.
|
||||
set_target_properties(ds4-worker PROPERTIES LINKER_LANGUAGE CXX)
|
||||
target_link_libraries(ds4-worker PRIVATE Threads::Threads m)
|
||||
|
||||
if(DS4_GPU STREQUAL "cuda")
|
||||
target_link_libraries(ds4-worker PRIVATE CUDA::cudart CUDA::cublas)
|
||||
elseif(DS4_GPU STREQUAL "metal")
|
||||
target_link_libraries(ds4-worker PRIVATE ${FOUNDATION_LIB} ${METAL_LIB})
|
||||
elseif(DS4_GPU STREQUAL "cpu")
|
||||
target_compile_definitions(ds4-worker PRIVATE DS4_NO_GPU)
|
||||
endif()
|
||||
|
||||
if(DS4_NATIVE)
|
||||
if(APPLE)
|
||||
target_compile_options(ds4-worker PRIVATE -mcpu=native)
|
||||
else()
|
||||
target_compile_options(ds4-worker PRIVATE -march=native)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=e16ead1e29c81a67bbb64e5b001117679cf9ce6e
|
||||
# Upstream pin lives below as DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=e16ead1e29c81a67bbb64e5b001117679cf9ce6e
|
||||
DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
@@ -66,6 +66,7 @@ grpc-server: ds4/ds4.o
|
||||
mkdir -p $(BUILD_DIR)
|
||||
cd $(BUILD_DIR) && cmake $(CMAKE_ARGS) $(CURRENT_MAKEFILE_DIR) && cmake --build . --config Release -j $(JOBS)
|
||||
cp $(BUILD_DIR)/grpc-server grpc-server
|
||||
cp $(BUILD_DIR)/ds4-worker ds4-worker
|
||||
|
||||
package: grpc-server
|
||||
bash package.sh
|
||||
@@ -74,7 +75,7 @@ test:
|
||||
@echo "ds4 backend: e2e coverage at tests/e2e-backends/ (BACKEND_BINARY mode)"
|
||||
|
||||
clean:
|
||||
rm -rf $(BUILD_DIR) grpc-server package
|
||||
rm -rf $(BUILD_DIR) grpc-server ds4-worker package
|
||||
if [ -d ds4 ]; then $(MAKE) -C ds4 clean; fi
|
||||
|
||||
purge: clean
|
||||
|
||||
@@ -23,8 +23,11 @@ extern "C" {
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <climits>
|
||||
#include <csignal>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
@@ -51,6 +54,12 @@ ds4_session *g_session = nullptr;
|
||||
int g_ctx_size = 32768;
|
||||
std::string g_kv_cache_dir; // empty disables disk cache
|
||||
|
||||
// Distributed coordinator state. g_distributed is set true when LoadModel is
|
||||
// given 'ds4_role:coordinator'; generation then waits for the worker route to
|
||||
// form before running. Single-node behavior is unchanged when unset.
|
||||
bool g_distributed = false;
|
||||
int g_route_timeout_sec = 60;
|
||||
|
||||
std::atomic<Server *> g_server{nullptr};
|
||||
|
||||
// Parse a "key:value" option string. Returns empty when no colon.
|
||||
@@ -60,6 +69,77 @@ static std::pair<std::string, std::string> split_option(const std::string &opt)
|
||||
return {opt.substr(0, colon), opt.substr(colon + 1)};
|
||||
}
|
||||
|
||||
// Parse a positive base-10 integer. Returns false (without throwing) on empty,
|
||||
// trailing garbage, non-positive, or overflow - unlike std::stoi.
|
||||
static bool parse_positive_int(const std::string &s, int *out) {
|
||||
if (s.empty()) return false;
|
||||
char *end = nullptr;
|
||||
long v = std::strtol(s.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || v <= 0 || v > INT_MAX) return false;
|
||||
*out = static_cast<int>(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parse a ds4 layer spec "START:END" or "START:output" into the engine's
|
||||
// distributed layer fields. Returns false on malformed input.
|
||||
static bool parse_layers_spec(const std::string &spec, ds4_distributed_layers *out) {
|
||||
auto colon = spec.find(':');
|
||||
if (colon == std::string::npos) return false;
|
||||
std::string lhs = spec.substr(0, colon);
|
||||
std::string rhs = spec.substr(colon + 1);
|
||||
if (lhs.empty() || rhs.empty()) return false;
|
||||
char *end = nullptr;
|
||||
long start = std::strtol(lhs.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || start < 0) return false;
|
||||
out->start = static_cast<uint32_t>(start);
|
||||
out->has_output = false;
|
||||
if (rhs == "output") {
|
||||
out->has_output = true;
|
||||
out->end = out->start; // engine treats has_output as "through final layer"
|
||||
} else {
|
||||
long e = std::strtol(rhs.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || e < start) return false;
|
||||
out->end = static_cast<uint32_t>(e);
|
||||
}
|
||||
out->set = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
// When acting as a distributed coordinator, block until the worker route
|
||||
// covers all layers (ds4_session_distributed_route_ready == 1) or the timeout
|
||||
// elapses. Returns an empty string on success, or an error message to return
|
||||
// to the client. No-op when not distributed.
|
||||
//
|
||||
// Takes the g_engine_mu lock by reference and RELEASES it during each poll
|
||||
// sleep. The wait can span up to g_route_timeout_sec seconds while workers
|
||||
// connect; holding g_engine_mu the whole time would block the Status/Health
|
||||
// readiness probes (they also lock g_engine_mu), making LocalAI's loader treat
|
||||
// a still-starting worker as hung.
|
||||
static std::string wait_route_ready(std::unique_lock<std::mutex> &lock) {
|
||||
if (!g_distributed) return "";
|
||||
char err[256] = {0};
|
||||
const int deadline_polls = g_route_timeout_sec * 10; // 100ms per poll
|
||||
for (int i = 0; i <= deadline_polls; ++i) {
|
||||
int ready = ds4_session_distributed_route_ready(g_session, err, sizeof(err));
|
||||
if (ready == 1) return "";
|
||||
if (ready < 0) {
|
||||
return std::string("ds4 distributed route error: ") +
|
||||
(err[0] ? err : "unknown");
|
||||
}
|
||||
// Release the lock while sleeping so Status/Health and other RPCs can
|
||||
// interleave during worker startup.
|
||||
lock.unlock();
|
||||
struct timespec ts = {0, 100L * 1000L * 1000L}; // 100ms
|
||||
nanosleep(&ts, nullptr);
|
||||
lock.lock();
|
||||
// A concurrent Free() may have torn down the engine while we slept.
|
||||
if (!g_engine || !g_session) {
|
||||
return "ds4: model unloaded while waiting for distributed route";
|
||||
}
|
||||
}
|
||||
return "ds4 distributed route incomplete: workers not connected (layers uncovered)";
|
||||
}
|
||||
|
||||
static void append_token_text(ds4_engine *engine, int token, std::string &out) {
|
||||
size_t len = 0;
|
||||
const char *text = ds4_token_text(engine, token, &len);
|
||||
@@ -377,6 +457,11 @@ public:
|
||||
backend::Result *result) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
|
||||
// Reset distributed state so a model swap (a second LoadModel without
|
||||
// ds4_role) doesn't inherit a stale coordinator configuration.
|
||||
g_distributed = false;
|
||||
g_route_timeout_sec = 60;
|
||||
|
||||
if (g_engine) {
|
||||
if (g_session) { ds4_session_free(g_session); g_session = nullptr; }
|
||||
ds4_engine_close(g_engine);
|
||||
@@ -394,12 +479,23 @@ public:
|
||||
std::string mtp_path;
|
||||
int mtp_draft = 0;
|
||||
float mtp_margin = 3.0f;
|
||||
std::string ds4_role, ds4_layers, ds4_listen;
|
||||
for (const auto &opt : request->options()) {
|
||||
auto [k, v] = split_option(opt);
|
||||
if (k == "mtp_path") mtp_path = v;
|
||||
else if (k == "mtp_draft") mtp_draft = std::stoi(v);
|
||||
else if (k == "mtp_margin") mtp_margin = std::stof(v);
|
||||
else if (k == "kv_cache_dir") g_kv_cache_dir = v;
|
||||
else if (k == "ds4_role") ds4_role = v;
|
||||
else if (k == "ds4_layers") ds4_layers = v;
|
||||
else if (k == "ds4_listen") ds4_listen = v;
|
||||
else if (k == "ds4_route_timeout") {
|
||||
if (!parse_positive_int(v, &g_route_timeout_sec)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_route_timeout must be a positive integer");
|
||||
return GStatus::OK;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
g_kv_cache.SetDir(g_kv_cache_dir);
|
||||
@@ -422,6 +518,49 @@ public:
|
||||
opt.backend = DS4_BACKEND_CUDA;
|
||||
#endif
|
||||
|
||||
// Coordinator wiring. 'ds4_role:coordinator' enables layer-split
|
||||
// distributed inference: this process listens on ds4_listen and owns
|
||||
// the ds4_layers slice; workers dial in (see `local-ai worker
|
||||
// ds4-distributed`). Absent ds4_role => unchanged single-node path.
|
||||
// Must be static: opt.distributed.listen_host is a const char* the
|
||||
// engine retains past this call, so it cannot point at a local that
|
||||
// goes out of scope (otherwise a future "simplify to local" refactor
|
||||
// reintroduces a dangling pointer).
|
||||
static std::string s_listen_host;
|
||||
if (ds4_role == "coordinator") {
|
||||
if (ds4_layers.empty() || ds4_listen.empty()) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_role:coordinator requires ds4_layers and ds4_listen");
|
||||
return GStatus::OK;
|
||||
}
|
||||
// host:port for IPv4/hostname; IPv6 literals are unsupported (the
|
||||
// first colon would split inside the address).
|
||||
auto host_port = split_option(ds4_listen); // "host:port" -> {host, port}
|
||||
if (host_port.second.empty()) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_listen must be host:port");
|
||||
return GStatus::OK;
|
||||
}
|
||||
int listen_port = 0;
|
||||
if (!parse_positive_int(host_port.second, &listen_port)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_listen port must be a positive integer");
|
||||
return GStatus::OK;
|
||||
}
|
||||
ds4_distributed_layers layers = {};
|
||||
if (!parse_layers_spec(ds4_layers, &layers)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: invalid ds4_layers (want START:END or START:output)");
|
||||
return GStatus::OK;
|
||||
}
|
||||
s_listen_host = host_port.first;
|
||||
opt.distributed.role = DS4_DISTRIBUTED_COORDINATOR;
|
||||
opt.distributed.layers = layers;
|
||||
opt.distributed.listen_host = s_listen_host.c_str();
|
||||
opt.distributed.listen_port = listen_port;
|
||||
g_distributed = true;
|
||||
}
|
||||
|
||||
int rc = ds4_engine_open(&g_engine, &opt);
|
||||
if (rc != 0 || !g_engine) {
|
||||
result->set_success(false);
|
||||
@@ -458,10 +597,13 @@ public:
|
||||
|
||||
GStatus Predict(ServerContext *, const backend::PredictOptions *request,
|
||||
backend::Reply *reply) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
std::unique_lock<std::mutex> lock(g_engine_mu);
|
||||
if (!g_engine || !g_session) {
|
||||
return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
|
||||
}
|
||||
if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
|
||||
return GStatus(StatusCode::UNAVAILABLE, route_err);
|
||||
}
|
||||
ds4_tokens prompt = {};
|
||||
build_prompt(g_engine, request, &prompt);
|
||||
int n_predict = request->tokens() > 0 ? request->tokens() : 256;
|
||||
@@ -554,10 +696,13 @@ public:
|
||||
|
||||
GStatus PredictStream(ServerContext *, const backend::PredictOptions *request,
|
||||
ServerWriter<backend::Reply> *writer) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
std::unique_lock<std::mutex> lock(g_engine_mu);
|
||||
if (!g_engine || !g_session) {
|
||||
return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
|
||||
}
|
||||
if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
|
||||
return GStatus(StatusCode::UNAVAILABLE, route_err);
|
||||
}
|
||||
ds4_tokens prompt = {};
|
||||
build_prompt(g_engine, request, &prompt);
|
||||
int n_predict = request->tokens() > 0 ? request->tokens() : 256;
|
||||
|
||||
@@ -5,7 +5,8 @@ REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
cp -avf "$CURDIR/grpc-server" "$CURDIR/package/"
|
||||
cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/ds4-worker" "$CURDIR/package/"
|
||||
cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
UNAME_S=$(uname -s)
|
||||
if [ "$UNAME_S" = "Darwin" ]; then
|
||||
|
||||
126
backend/cpp/ds4/worker_main.c
Normal file
126
backend/cpp/ds4/worker_main.c
Normal file
@@ -0,0 +1,126 @@
|
||||
// ds4-worker: standalone distributed worker for the LocalAI ds4 backend.
|
||||
//
|
||||
// A ds4 distributed worker owns a slice of the model's transformer layers,
|
||||
// dials the coordinator, and serves activations for its slice. It does NOT
|
||||
// speak backend.proto - it speaks ds4's own TCP transport via ds4_dist_run().
|
||||
// This binary is intentionally minimal (no HTTP/web/kvstore/linenoise): it
|
||||
// only needs the engine objects + ds4_distributed.o, which the backend already
|
||||
// builds. It is launched by `local-ai worker ds4-distributed`.
|
||||
//
|
||||
// Usage:
|
||||
// ds4-worker --role worker --model <gguf> --layers 20:output \
|
||||
// --coordinator <host> <port> [--cpu|--cuda|--metal] [-c CTX] [-t N]
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <signal.h>
|
||||
#include <limits.h>
|
||||
|
||||
#include "ds4.h"
|
||||
#include "ds4_distributed.h"
|
||||
|
||||
static const char *need_arg(int *i, int argc, char **argv, const char *flag) {
|
||||
if (*i + 1 >= argc) {
|
||||
fprintf(stderr, "ds4-worker: missing value for %s\n", flag);
|
||||
exit(2);
|
||||
}
|
||||
return argv[++(*i)];
|
||||
}
|
||||
|
||||
static int parse_int_arg(const char *s, const char *flag) {
|
||||
char *end = NULL;
|
||||
long v = strtol(s, &end, 10);
|
||||
if (!s[0] || *end || v <= 0 || v > INT_MAX) {
|
||||
fprintf(stderr, "ds4-worker: invalid value for %s: %s\n", flag, s);
|
||||
exit(2);
|
||||
}
|
||||
return (int)v;
|
||||
}
|
||||
|
||||
static ds4_backend default_backend(void) {
|
||||
#if defined(DS4_NO_GPU)
|
||||
return DS4_BACKEND_CPU;
|
||||
#elif defined(__APPLE__)
|
||||
return DS4_BACKEND_METAL;
|
||||
#else
|
||||
return DS4_BACKEND_CUDA;
|
||||
#endif
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
signal(SIGPIPE, SIG_IGN);
|
||||
|
||||
ds4_engine_options opt = {0};
|
||||
opt.backend = default_backend();
|
||||
int ctx_size = 32768;
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
const char *arg = argv[i];
|
||||
if (!strcmp(arg, "-h") || !strcmp(arg, "--help")) {
|
||||
fprintf(stdout, "ds4-worker: standalone ds4 distributed worker\n");
|
||||
ds4_dist_usage(stdout);
|
||||
fprintf(stdout, " -m, --model PATH model GGUF (the worker loads only its --layers slice)\n");
|
||||
fprintf(stdout, " -c, --ctx N context size (default 32768)\n");
|
||||
fprintf(stdout, " -t, --threads N CPU threads\n");
|
||||
fprintf(stdout, " --cpu|--cuda|--metal backend override\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
char dist_err[256] = {0};
|
||||
ds4_dist_cli_parse_result dist_parse =
|
||||
ds4_dist_parse_cli_arg(arg, &i, argc, argv, &opt.distributed,
|
||||
dist_err, sizeof(dist_err));
|
||||
if (dist_parse == DS4_DIST_CLI_ERROR) {
|
||||
fprintf(stderr, "ds4-worker: %s\n",
|
||||
dist_err[0] ? dist_err : "invalid distributed option");
|
||||
return 2;
|
||||
}
|
||||
if (dist_parse == DS4_DIST_CLI_MATCHED) continue;
|
||||
|
||||
if (!strcmp(arg, "-m") || !strcmp(arg, "--model")) {
|
||||
opt.model_path = need_arg(&i, argc, argv, arg);
|
||||
} else if (!strcmp(arg, "-c") || !strcmp(arg, "--ctx")) {
|
||||
ctx_size = parse_int_arg(need_arg(&i, argc, argv, arg), arg);
|
||||
} else if (!strcmp(arg, "-t") || !strcmp(arg, "--threads")) {
|
||||
opt.n_threads = parse_int_arg(need_arg(&i, argc, argv, arg), arg);
|
||||
} else if (!strcmp(arg, "--cpu")) {
|
||||
opt.backend = DS4_BACKEND_CPU;
|
||||
} else if (!strcmp(arg, "--cuda")) {
|
||||
opt.backend = DS4_BACKEND_CUDA;
|
||||
} else if (!strcmp(arg, "--metal")) {
|
||||
opt.backend = DS4_BACKEND_METAL;
|
||||
} else {
|
||||
fprintf(stderr, "ds4-worker: unknown option: %s\n", arg);
|
||||
return 2;
|
||||
}
|
||||
}
|
||||
|
||||
if (opt.distributed.role != DS4_DISTRIBUTED_WORKER) {
|
||||
fprintf(stderr, "ds4-worker: --role worker is required\n");
|
||||
return 2;
|
||||
}
|
||||
if (!opt.model_path) {
|
||||
fprintf(stderr, "ds4-worker: --model is required\n");
|
||||
return 2;
|
||||
}
|
||||
|
||||
char prep_err[256] = {0};
|
||||
if (ds4_dist_prepare_engine_options(&opt.distributed, &opt,
|
||||
prep_err, sizeof(prep_err)) != 0) {
|
||||
fprintf(stderr, "ds4-worker: %s\n", prep_err);
|
||||
return 2;
|
||||
}
|
||||
|
||||
ds4_engine *engine = NULL;
|
||||
if (ds4_engine_open(&engine, &opt) != 0 || !engine) {
|
||||
fprintf(stderr, "ds4-worker: failed to open engine\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
ds4_dist_generation_options gen = {0};
|
||||
gen.ctx_size = ctx_size;
|
||||
int rc = ds4_dist_run(engine, &opt.distributed, &gen);
|
||||
ds4_engine_close(engine);
|
||||
return rc;
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=8960c5ba5ee9db30ba838304373aa4dbec9f7cbd
|
||||
IK_LLAMA_VERSION?=1520eda980564241434b791ce2bbbd128c4be9ea
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=22d66b567eef11cf2e9832f04db64ee0323a0fd0
|
||||
LLAMA_VERSION?=7c158fbb4aec1bdc9c81d6ca0e785139f4826fae
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1944,6 +1944,17 @@ public:
|
||||
body_json["chat_template_kwargs"]["enable_thinking"] = (et_it->second == "true");
|
||||
}
|
||||
|
||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
||||
// from enable_thinking which those templates ignore.
|
||||
auto re_it = metadata.find("reasoning_effort");
|
||||
if (re_it != metadata.end() && !re_it->second.empty()) {
|
||||
if (!body_json.contains("chat_template_kwargs")) {
|
||||
body_json["chat_template_kwargs"] = json::object();
|
||||
}
|
||||
body_json["chat_template_kwargs"]["reasoning_effort"] = re_it->second;
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
@@ -2204,7 +2215,15 @@ public:
|
||||
// content element — attaching to both would duplicate the first
|
||||
// token since oaicompat_msg_diffs is the same for both.
|
||||
json first_res_json = first_result->to_json();
|
||||
if (first_res_json.is_array()) {
|
||||
// Upstream llama.cpp (ggml-org/llama.cpp#23884) now emits an initial
|
||||
// "begin" partial whose to_json() returns null, used only to signal the
|
||||
// HTTP layer to flush 200 status headers before any token. gRPC has no
|
||||
// such concept, so there is nothing to emit — the real tokens arrive in
|
||||
// the loop below. Feeding this null into build_reply_from_json would
|
||||
// throw (uncaught) and surface as a generic RPC error.
|
||||
if (first_res_json.is_null()) {
|
||||
// skip the begin-of-stream marker
|
||||
} else if (first_res_json.is_array()) {
|
||||
for (const auto & res : first_res_json) {
|
||||
auto reply = build_reply_from_json(res, first_result.get());
|
||||
// Skip chat deltas for role-init elements (have "role" in
|
||||
@@ -2234,7 +2253,10 @@ public:
|
||||
}
|
||||
|
||||
json res_json = result->to_json();
|
||||
if (res_json.is_array()) {
|
||||
if (res_json.is_null()) {
|
||||
// begin-of-stream marker (see note above) — nothing to emit
|
||||
continue;
|
||||
} else if (res_json.is_array()) {
|
||||
for (const auto & res : res_json) {
|
||||
auto reply = build_reply_from_json(res, result.get());
|
||||
bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
|
||||
@@ -2726,6 +2748,17 @@ public:
|
||||
body_json["chat_template_kwargs"]["enable_thinking"] = (predict_et_it->second == "true");
|
||||
}
|
||||
|
||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
||||
// from enable_thinking which those templates ignore.
|
||||
auto predict_re_it = predict_metadata.find("reasoning_effort");
|
||||
if (predict_re_it != predict_metadata.end() && !predict_re_it->second.empty()) {
|
||||
if (!body_json.contains("chat_template_kwargs")) {
|
||||
body_json["chat_template_kwargs"] = json::object();
|
||||
}
|
||||
body_json["chat_template_kwargs"]["reasoning_effort"] = predict_re_it->second;
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] Predict: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
@@ -145,7 +146,7 @@ func resolveAPIKey(envName, filePath string) (string, error) {
|
||||
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cloud-proxy: model not loaded")
|
||||
return nil, grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -175,7 +176,7 @@ func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err
|
||||
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -269,7 +270,7 @@ func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest,
|
||||
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modePassthrough {
|
||||
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
|
||||
|
||||
5
backend/go/crispasr/.gitignore
vendored
Normal file
5
backend/go/crispasr/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
sources
|
||||
build*
|
||||
libgocrispasr*.so
|
||||
crispasr
|
||||
package
|
||||
30
backend/go/crispasr/CMakeLists.txt
Normal file
30
backend/go/crispasr/CMakeLists.txt
Normal file
@@ -0,0 +1,30 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
project(gocrispasr LANGUAGES C CXX)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
add_subdirectory(./sources/CrispASR)
|
||||
|
||||
add_library(gocrispasr MODULE cpp/crispasr_shim.cpp)
|
||||
target_include_directories(gocrispasr PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sources/CrispASR/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sources/CrispASR/ggml/include)
|
||||
# Link the same backend set as crispasr-cli (examples/cli/CMakeLists.txt) so
|
||||
# the session API can dispatch to every compiled-in architecture, not just
|
||||
# whisper. crispasr is the referencer; the backend static libs supply the
|
||||
# per-architecture symbols; ggml is the math/runtime base.
|
||||
target_link_libraries(gocrispasr PRIVATE
|
||||
crispasr
|
||||
parakeet canary canary_ctc cohere granite_speech granite_nle
|
||||
voxtral voxtral4b qwen3_asr qwen3_tts orpheus chatterbox indextts
|
||||
kokoro voxcpm2_tts m2m100 t5_translate wav2vec2-ggml vibevoice
|
||||
silero-lid pyannote-seg funasr paraformer sensevoice
|
||||
crisp_audio
|
||||
ggml)
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
|
||||
target_link_libraries(gocrispasr PRIVATE stdc++fs)
|
||||
endif()
|
||||
|
||||
set_property(TARGET gocrispasr PROPERTY CXX_STANDARD 17)
|
||||
set_target_properties(gocrispasr PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||
132
backend/go/crispasr/Makefile
Normal file
132
backend/go/crispasr/Makefile
Normal file
@@ -0,0 +1,132 @@
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# CrispASR version (release tag)
|
||||
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
||||
CRISPASR_VERSION?=13d54e110e1538e0f0bc3af0680b9ab246cfb48d
|
||||
SO_TARGET?=libgocrispasr.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
# Keep the build lean: no tests/examples/server/SDL2/curl/ffmpeg (the FROM scratch
|
||||
# image cannot satisfy those runtime deps). All ASR/TTS model backends stay enabled.
|
||||
CMAKE_ARGS+=-DCRISPASR_BUILD_TESTS=OFF -DCRISPASR_BUILD_EXAMPLES=OFF -DCRISPASR_BUILD_SERVER=OFF
|
||||
CMAKE_ARGS+=-DCRISPASR_SDL2=OFF -DCRISPASR_CURL=OFF -DCRISPASR_FFMPEG=OFF
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DGGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
ifneq ($(BUILD_TYPE),metal)
|
||||
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||
else
|
||||
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
-DGGML_SYCL_F16=ON
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
endif
|
||||
|
||||
sources/CrispASR:
|
||||
mkdir -p sources/CrispASR
|
||||
cd sources/CrispASR && \
|
||||
git init && \
|
||||
git remote add origin $(CRISPASR_REPO) && \
|
||||
git fetch origin && \
|
||||
git checkout $(CRISPASR_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
# CrispASR's src/CMakeLists.txt locates its vendored llama.cpp
|
||||
# (crispasr-llama-core, used by the chat C-ABI) via ${CMAKE_SOURCE_DIR},
|
||||
# which assumes CrispASR is the top-level CMake project. We add_subdirectory
|
||||
# it, so ${CMAKE_SOURCE_DIR} is THIS backend dir and the talk-llama sources
|
||||
# aren't found. Rewrite to ${PROJECT_SOURCE_DIR} (the crispasr project root),
|
||||
# which is correct both standalone and as a subproject. Idempotent.
|
||||
sed -i 's#\$${CMAKE_SOURCE_DIR}/examples/talk-llama#\$${PROJECT_SOURCE_DIR}/examples/talk-llama#' sources/CrispASR/src/CMakeLists.txt
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
VARIANT_TARGETS = libgocrispasr-avx.so libgocrispasr-avx2.so libgocrispasr-avx512.so libgocrispasr-fallback.so
|
||||
else
|
||||
VARIANT_TARGETS = libgocrispasr-fallback.so
|
||||
endif
|
||||
|
||||
crispasr: main.go gocrispasr.go $(VARIANT_TARGETS)
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o crispasr ./
|
||||
|
||||
package: crispasr
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
clean: purge
|
||||
rm -rf libgocrispasr*.so package sources/CrispASR crispasr
|
||||
|
||||
purge:
|
||||
rm -rf build*
|
||||
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
libgocrispasr-avx.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx${RESET})
|
||||
SO_TARGET=libgocrispasr-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-avx2.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx2${RESET})
|
||||
SO_TARGET=libgocrispasr-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-avx512.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx512${RESET})
|
||||
SO_TARGET=libgocrispasr-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
endif
|
||||
|
||||
libgocrispasr-fallback.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:fallback${RESET})
|
||||
SO_TARGET=libgocrispasr-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-custom: CMakeLists.txt cpp/crispasr_shim.cpp cpp/crispasr_shim.h
|
||||
mkdir -p build-$(SO_TARGET) && \
|
||||
cd build-$(SO_TARGET) && \
|
||||
cmake .. $(CMAKE_ARGS) && \
|
||||
cmake --build . --config Release -j$(JOBS) && \
|
||||
cd .. && \
|
||||
mv build-$(SO_TARGET)/libgocrispasr.so ./$(SO_TARGET)
|
||||
|
||||
test: crispasr
|
||||
CGO_ENABLED=0 $(GOCMD) test -v ./...
|
||||
|
||||
all: crispasr package
|
||||
253
backend/go/crispasr/cpp/crispasr_shim.cpp
Normal file
253
backend/go/crispasr/cpp/crispasr_shim.cpp
Normal file
@@ -0,0 +1,253 @@
|
||||
#include "crispasr_shim.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "crispasr.h"
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
// Opaque session types. crispasr.h declares `struct crispasr_session;` but not
|
||||
// the result type nor the open/transcribe/result accessors — those are
|
||||
// CA_EXPORT extern "C" symbols in src/crispasr_c_api.cpp, so we forward-declare
|
||||
// exactly the ones we use. Signatures verified against
|
||||
// sources/CrispASR/src/crispasr_c_api.cpp.
|
||||
struct crispasr_session_result;
|
||||
extern "C" {
|
||||
crispasr_session *crispasr_session_open(const char *model_path, int n_threads);
|
||||
crispasr_session *crispasr_session_open_explicit(const char *model_path,
|
||||
const char *backend_name,
|
||||
int n_threads);
|
||||
int crispasr_session_set_codec_path(crispasr_session *s, const char *path);
|
||||
void crispasr_session_close(crispasr_session *s);
|
||||
const char *crispasr_session_backend(crispasr_session *s);
|
||||
int crispasr_session_set_translate(crispasr_session *s, int enable);
|
||||
crispasr_session_result *crispasr_session_transcribe_lang(
|
||||
crispasr_session *s, const float *pcm, int n_samples, const char *language);
|
||||
int crispasr_session_result_n_segments(crispasr_session_result *r);
|
||||
const char *crispasr_session_result_segment_text(crispasr_session_result *r,
|
||||
int i);
|
||||
int64_t crispasr_session_result_segment_t0(crispasr_session_result *r, int i);
|
||||
int64_t crispasr_session_result_segment_t1(crispasr_session_result *r, int i);
|
||||
void crispasr_session_result_free(crispasr_session_result *r);
|
||||
float *crispasr_session_synthesize(crispasr_session *s, const char *text,
|
||||
int *out_n_samples);
|
||||
void crispasr_pcm_free(float *pcm);
|
||||
int crispasr_session_set_speaker_name(crispasr_session *s, const char *name);
|
||||
int crispasr_session_set_voice(crispasr_session *s, const char *path,
|
||||
const char *ref_text_or_null);
|
||||
}
|
||||
|
||||
static crispasr_session *g_session = nullptr;
|
||||
static crispasr_session_result *g_result = nullptr;
|
||||
|
||||
static struct whisper_vad_context *vctx;
|
||||
static std::vector<float> flat_segs;
|
||||
|
||||
static std::atomic<int> g_abort{0};
|
||||
|
||||
extern "C" void set_abort(int v) {
|
||||
g_abort.store(v, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
static void ggml_log_cb(enum ggml_log_level level, const char *log,
|
||||
void *data) {
|
||||
const char *level_str;
|
||||
|
||||
if (!log) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (level) {
|
||||
case GGML_LOG_LEVEL_DEBUG:
|
||||
level_str = "DEBUG";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_INFO:
|
||||
level_str = "INFO";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_WARN:
|
||||
level_str = "WARN";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_ERROR:
|
||||
level_str = "ERROR";
|
||||
break;
|
||||
default: /* Potential future-proofing */
|
||||
level_str = "?????";
|
||||
break;
|
||||
}
|
||||
|
||||
fprintf(stderr, "[%-5s] ", level_str);
|
||||
fputs(log, stderr);
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
int load_model(const char *const model_path, int threads,
|
||||
const char *backend_name) {
|
||||
whisper_log_set(ggml_log_cb, nullptr);
|
||||
ggml_backend_load_all();
|
||||
|
||||
if (backend_name && *backend_name) {
|
||||
g_session =
|
||||
crispasr_session_open_explicit(model_path, backend_name, threads);
|
||||
} else {
|
||||
g_session = crispasr_session_open(model_path, threads);
|
||||
}
|
||||
if (g_session == nullptr) {
|
||||
fprintf(stderr, "error: failed to open CrispASR session for model\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "info: CrispASR backend selected: %s\n",
|
||||
crispasr_session_backend(g_session));
|
||||
return 0;
|
||||
}
|
||||
|
||||
// set_codec_path forwards a companion file (qwen3-tts codec, orpheus SNAC,
|
||||
// chatterbox s3gen, or mimo-asr tokenizer) to the active session. Returns 0 on
|
||||
// success or when the active backend needs no companion, negative on failure,
|
||||
// and -1 when no session is open.
|
||||
int set_codec_path(const char *path) {
|
||||
return g_session ? crispasr_session_set_codec_path(g_session, path) : -1;
|
||||
}
|
||||
|
||||
int load_model_vad(const char *const model_path) {
|
||||
whisper_log_set(ggml_log_cb, nullptr);
|
||||
ggml_backend_load_all();
|
||||
|
||||
struct whisper_vad_context_params vcparams =
|
||||
whisper_vad_default_context_params();
|
||||
|
||||
// XXX: Overridden to false in upstream due to performance?
|
||||
// vcparams.use_gpu = true;
|
||||
|
||||
vctx = whisper_vad_init_from_file_with_params(model_path, vcparams);
|
||||
if (vctx == nullptr) {
|
||||
fprintf(stderr, "error: Failed to init model as VAD\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
|
||||
size_t *segs_out_len) {
|
||||
if (!whisper_vad_detect_speech(vctx, pcmf32, pcmf32_len)) {
|
||||
fprintf(stderr, "error: failed to detect speech\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct whisper_vad_params params = whisper_vad_default_params();
|
||||
struct whisper_vad_segments *segs =
|
||||
whisper_vad_segments_from_probs(vctx, params);
|
||||
size_t segn = whisper_vad_segments_n_segments(segs);
|
||||
|
||||
// fprintf(stderr, "Got segments %zd\n", segn);
|
||||
|
||||
flat_segs.clear();
|
||||
|
||||
for (int i = 0; i < segn; i++) {
|
||||
flat_segs.push_back(whisper_vad_segments_get_segment_t0(segs, i));
|
||||
flat_segs.push_back(whisper_vad_segments_get_segment_t1(segs, i));
|
||||
}
|
||||
|
||||
// fprintf(stderr, "setting out variables: %p=%p -> %p, %p=%zx -> %zx\n",
|
||||
// segs_out, *segs_out, flat_segs.data(), segs_out_len, *segs_out_len,
|
||||
// flat_segs.size());
|
||||
*segs_out = flat_segs.data();
|
||||
*segs_out_len = flat_segs.size();
|
||||
|
||||
// fprintf(stderr, "freeing segs\n");
|
||||
whisper_vad_free_segments(segs);
|
||||
|
||||
// fprintf(stderr, "returning\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// threads, diarize and prompt are accepted for Go-side API parity but unused
|
||||
// in Phase 1: the thread count is fixed at session open, and diarization and
|
||||
// the initial prompt are separate CrispASR features not yet wired through the
|
||||
// session ASR path.
|
||||
int transcribe(uint32_t threads, char *lang, bool translate, bool diarize,
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
|
||||
char *prompt) {
|
||||
(void)threads;
|
||||
(void)diarize;
|
||||
(void)prompt;
|
||||
|
||||
if (!g_session) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Reset stale abort flag from any prior cancelled call. set_abort remains
|
||||
// best-effort: the session transcribe call is blocking and exposes no abort
|
||||
// hook, so a mid-decode abort cannot interrupt it.
|
||||
g_abort.store(0, std::memory_order_relaxed);
|
||||
|
||||
crispasr_session_set_translate(g_session, translate ? 1 : 0);
|
||||
|
||||
if (g_result) {
|
||||
crispasr_session_result_free(g_result);
|
||||
g_result = nullptr;
|
||||
}
|
||||
|
||||
const char *language = (lang && *lang) ? lang : nullptr;
|
||||
g_result = crispasr_session_transcribe_lang(g_session, pcmf32, (int)pcmf32_len,
|
||||
language);
|
||||
if (!g_result) {
|
||||
fprintf(stderr, "error: transcription failed\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
*segs_out_len = crispasr_session_result_n_segments(g_result);
|
||||
return 0;
|
||||
}
|
||||
|
||||
const char *get_segment_text(int i) {
|
||||
if (!g_result) {
|
||||
return "";
|
||||
}
|
||||
return crispasr_session_result_segment_text(g_result, i);
|
||||
}
|
||||
|
||||
int64_t get_segment_t0(int i) {
|
||||
if (!g_result) {
|
||||
return 0;
|
||||
}
|
||||
return crispasr_session_result_segment_t0(g_result, i);
|
||||
}
|
||||
|
||||
int64_t get_segment_t1(int i) {
|
||||
if (!g_result) {
|
||||
return 0;
|
||||
}
|
||||
return crispasr_session_result_segment_t1(g_result, i);
|
||||
}
|
||||
|
||||
const char *get_backend(void) {
|
||||
return g_session ? crispasr_session_backend(g_session) : "";
|
||||
}
|
||||
|
||||
// TTS uses the already-open session (crispasr_session_open auto-detects a TTS
|
||||
// model). Output is 24 kHz mono float PCM (upstream CrispASR convention),
|
||||
// malloc'd by the C API; the caller must release it via tts_free.
|
||||
float *tts_synthesize(const char *text, int *out_n_samples) {
|
||||
if (out_n_samples) *out_n_samples = 0;
|
||||
if (!g_session || !text) return nullptr;
|
||||
return crispasr_session_synthesize(g_session, text, out_n_samples);
|
||||
}
|
||||
|
||||
void tts_free(float *pcm) {
|
||||
if (pcm) crispasr_pcm_free(pcm);
|
||||
}
|
||||
|
||||
int tts_set_voice(const char *name) {
|
||||
if (!g_session || !name || !*name) return 0;
|
||||
return crispasr_session_set_speaker_name(g_session, name);
|
||||
}
|
||||
|
||||
// tts_set_voice_file loads a voice from a file: a .gguf path selects a voice
|
||||
// pack, a .wav path with a non-empty ref_text performs zero-shot voice cloning
|
||||
// (the C API returns -2 when ref_text is required but missing). Returns -1 when
|
||||
// no session is open or path is null.
|
||||
int tts_set_voice_file(const char *path, const char *ref_text) {
|
||||
if (!g_session || !path) return -1;
|
||||
const char *ref = (ref_text && *ref_text) ? ref_text : nullptr;
|
||||
return crispasr_session_set_voice(g_session, path, ref);
|
||||
}
|
||||
23
backend/go/crispasr/cpp/crispasr_shim.h
Normal file
23
backend/go/crispasr/cpp/crispasr_shim.h
Normal file
@@ -0,0 +1,23 @@
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
extern "C" {
|
||||
int load_model(const char *const model_path, int threads,
|
||||
const char *backend_name);
|
||||
int set_codec_path(const char *path);
|
||||
int load_model_vad(const char *const model_path);
|
||||
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
|
||||
size_t *segs_out_len);
|
||||
int transcribe(uint32_t threads, char *lang, bool translate, bool diarize,
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
|
||||
char *prompt);
|
||||
const char *get_segment_text(int i);
|
||||
int64_t get_segment_t0(int i);
|
||||
int64_t get_segment_t1(int i);
|
||||
const char *get_backend(void);
|
||||
void set_abort(int v);
|
||||
float *tts_synthesize(const char *text, int *out_n_samples); // 24kHz mono float, malloc'd; NULL on failure
|
||||
void tts_free(float *pcm);
|
||||
int tts_set_voice(const char *name); // best-effort speaker selection; 0 ok
|
||||
int tts_set_voice_file(const char *path, const char *ref_text); // load voice pack (.gguf) or zero-shot clone (.wav + ref_text)
|
||||
}
|
||||
497
backend/go/crispasr/gocrispasr.go
Normal file
497
backend/go/crispasr/gocrispasr.go
Normal file
@@ -0,0 +1,497 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var (
|
||||
CppLoadModel func(modelPath string, threads int, backendName string) int
|
||||
CppSetCodecPath func(path string) int
|
||||
CppLoadModelVAD func(modelPath string) int
|
||||
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
|
||||
CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer, prompt string) int
|
||||
CppGetSegmentText func(i int) string
|
||||
CppGetSegmentStart func(i int) int64
|
||||
CppGetSegmentEnd func(i int) int64
|
||||
CppGetBackend func() string
|
||||
CppSetAbort func(v int)
|
||||
CppTTSSynthesize func(text string, outNSamples unsafe.Pointer) uintptr
|
||||
CppTTSFree func(ptr uintptr)
|
||||
CppTTSSetVoice func(name string) int
|
||||
CppTTSSetVoiceFile func(path string, refText string) int
|
||||
)
|
||||
|
||||
type CrispASR struct {
|
||||
base.SingleThread
|
||||
}
|
||||
|
||||
// splitOption splits a "prefix:value" model option into its key and value,
|
||||
// matching the convention used by other backends (see sherpa-onnx). It returns
|
||||
// ok=false when the option carries no ':' separator.
|
||||
func splitOption(oo string) (key, value string, ok bool) {
|
||||
parts := strings.SplitN(oo, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", "", false
|
||||
}
|
||||
return parts[0], parts[1], true
|
||||
}
|
||||
|
||||
func (w *CrispASR) Load(opts *pb.ModelOptions) error {
|
||||
vadOnly := false
|
||||
backendName := ""
|
||||
codecPath := ""
|
||||
speakerName := ""
|
||||
voicePath := ""
|
||||
voiceRefText := ""
|
||||
|
||||
for _, oo := range opts.Options {
|
||||
if oo == "vad_only" {
|
||||
vadOnly = true
|
||||
continue
|
||||
}
|
||||
switch key, value, ok := splitOption(oo); {
|
||||
case ok && key == "backend":
|
||||
backendName = value
|
||||
case ok && key == "codec":
|
||||
codecPath = value
|
||||
case ok && key == "speaker":
|
||||
speakerName = value
|
||||
case ok && key == "voice":
|
||||
voicePath = value
|
||||
case ok && key == "voice_text":
|
||||
voiceRefText = value
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||
}
|
||||
}
|
||||
|
||||
if vadOnly {
|
||||
if ret := CppLoadModelVAD(opts.ModelFile); ret != 0 {
|
||||
return fmt.Errorf("Failed to load CrispASR VAD model")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve a relative companion path against the model directory so a config
|
||||
// can reference a sibling codec/tokenizer file by name alone.
|
||||
if codecPath != "" && !filepath.IsAbs(codecPath) {
|
||||
codecPath = filepath.Join(filepath.Dir(opts.ModelFile), codecPath)
|
||||
}
|
||||
|
||||
// A voice file (.gguf pack or .wav prompt) is resolved against the model
|
||||
// directory just like the codec, so a config can reference a sibling file.
|
||||
if voicePath != "" && !filepath.IsAbs(voicePath) {
|
||||
voicePath = filepath.Join(filepath.Dir(opts.ModelFile), voicePath)
|
||||
}
|
||||
|
||||
if ret := CppLoadModel(opts.ModelFile, int(opts.Threads), backendName); ret != 0 {
|
||||
return fmt.Errorf("Failed to load CrispASR transcription model")
|
||||
}
|
||||
|
||||
// Load the companion file (codec/tokenizer/s3gen) after the session is open.
|
||||
// rc==0 means success or "not applicable" for the active backend; only a
|
||||
// negative code is fatal.
|
||||
if codecPath != "" {
|
||||
if rc := CppSetCodecPath(codecPath); rc < 0 {
|
||||
return fmt.Errorf("crispasr: failed to load companion file %q (rc=%d)", codecPath, rc)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "CrispASR companion file loaded: %s\n", codecPath)
|
||||
}
|
||||
|
||||
// Apply the Load-time default voice. A baked speaker (speaker:) is selected
|
||||
// by name and is best-effort: a backend that can't honor it is logged, not
|
||||
// fatal. A voice file (voice:) is a hard requirement once configured, so a
|
||||
// negative rc fails Load.
|
||||
if speakerName != "" {
|
||||
if rc := CppTTSSetVoice(speakerName); rc != 0 {
|
||||
fmt.Fprintf(os.Stderr, "crispasr: speaker %q not applied (rc=%d)\n", speakerName, rc)
|
||||
}
|
||||
}
|
||||
if voicePath != "" {
|
||||
if rc := CppTTSSetVoiceFile(voicePath, voiceRefText); rc < 0 {
|
||||
return fmt.Errorf("crispasr: failed to load voice %q (rc=%d)", voicePath, rc)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "CrispASR voice loaded: %s\n", voicePath)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "CrispASR backend selected: %s\n", CppGetBackend())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *CrispASR) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
|
||||
audio := req.Audio
|
||||
// We expect 0xdeadbeef to be overwritten and if we see it in a stack trace we know it wasn't
|
||||
segsPtr, segsLen := uintptr(0xdeadbeef), uintptr(0xdeadbeef)
|
||||
segsPtrPtr, segsLenPtr := unsafe.Pointer(&segsPtr), unsafe.Pointer(&segsLen)
|
||||
|
||||
if ret := CppVAD(audio, uintptr(len(audio)), segsPtrPtr, segsLenPtr); ret != 0 {
|
||||
return pb.VADResponse{}, fmt.Errorf("Failed VAD")
|
||||
}
|
||||
|
||||
// Happens when CPP vector has not had any elements pushed to it
|
||||
if segsPtr == 0 {
|
||||
return pb.VADResponse{
|
||||
Segments: []*pb.VADSegment{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// unsafeptr warning is caused by segsPtr being on the stack and therefor being subject to stack copying AFAICT
|
||||
// however the stack shouldn't have grown between setting segsPtr and now, also the memory pointed to is allocated by C++
|
||||
segs := unsafe.Slice((*float32)(unsafe.Pointer(segsPtr)), segsLen) //nolint:govet // segsPtr addresses C++-owned heap memory passed back through the cgo-free purego boundary; the uintptr->Pointer round-trip is intentional and the buffer outlives this read.
|
||||
|
||||
vadSegments := []*pb.VADSegment{}
|
||||
for i := range len(segs) >> 1 {
|
||||
s := segs[2*i] / 100
|
||||
t := segs[2*i+1] / 100
|
||||
vadSegments = append(vadSegments, &pb.VADSegment{
|
||||
Start: s,
|
||||
End: t,
|
||||
})
|
||||
}
|
||||
|
||||
return pb.VADResponse{
|
||||
Segments: vadSegments,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *CrispASR) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "crispasr")
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
convertedPath := filepath.Join(dir, "converted.wav")
|
||||
|
||||
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
|
||||
// Watcher: flips the C-side abort flag when ctx is cancelled. The
|
||||
// goroutine is joined synchronously (close(done) signals it to exit,
|
||||
// wg.Wait() blocks until it has) so a late CppSetAbort(1) cannot fire
|
||||
// after the function returns and corrupt the next transcription call.
|
||||
done := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
CppSetAbort(1)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||
if ret == 2 {
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if ret != 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
segments := []*pb.TranscriptSegment{}
|
||||
text := ""
|
||||
for i := range int(segsLen) {
|
||||
// segment start/end conversion factor taken from https://github.com/ggml-org/whisper.cpp/blob/master/examples/cli/cli.cpp#L895
|
||||
s := CppGetSegmentStart(i) * (10000000)
|
||||
t := CppGetSegmentEnd(i) * (10000000)
|
||||
// The session result can emit bytes that aren't valid UTF-8 (e.g. a
|
||||
// multibyte codepoint split across token boundaries); protobuf string
|
||||
// fields reject those at marshal time. Scrub before the value escapes
|
||||
// cgo. The session result is segment+word based and exposes no token
|
||||
// IDs, so Tokens is left empty.
|
||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
||||
|
||||
segment := &pb.TranscriptSegment{
|
||||
Id: int32(i),
|
||||
Text: txt,
|
||||
Start: s, End: t,
|
||||
}
|
||||
|
||||
segments = append(segments, segment)
|
||||
|
||||
text += " " + strings.TrimSpace(txt)
|
||||
}
|
||||
|
||||
return pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: strings.TrimSpace(text),
|
||||
Language: opts.Language,
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AudioTranscriptionStream runs the session transcribe to completion and then
|
||||
// emits one delta per non-empty segment, followed by a final TranscriptResult.
|
||||
// Progressive/real-time streaming isn't available via the session API (there
|
||||
// is no per-decode callback), so deltas are emitted per-segment after the
|
||||
// blocking decode returns rather than as segments are produced. The offline
|
||||
// AudioTranscription is unchanged; both paths share the session and the
|
||||
// SingleThread concurrency model.
|
||||
func (w *CrispASR) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
defer close(results)
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "crispasr")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
convertedPath := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
|
||||
// Same abort-watcher pattern as AudioTranscription. Joined synchronously
|
||||
// so a late CppSetAbort(1) cannot fire after this function returns.
|
||||
// Best-effort only: the session transcribe is blocking with no abort hook.
|
||||
done := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
CppSetAbort(1)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||
if ret == 2 {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
// Walk the segments once: emit a delta per non-empty segment and build the
|
||||
// final TranscriptResult.Segments alongside. The first delta has no leading
|
||||
// space and subsequent ones are prefixed with a single space, so
|
||||
// concat(deltas) == final.Text exactly, matching the e2e contract.
|
||||
segments := []*pb.TranscriptSegment{}
|
||||
var assembled strings.Builder
|
||||
for i := range int(segsLen) {
|
||||
s := CppGetSegmentStart(i) * 10000000
|
||||
t := CppGetSegmentEnd(i) * 10000000
|
||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
||||
segments = append(segments, &pb.TranscriptSegment{
|
||||
Id: int32(i),
|
||||
Text: txt,
|
||||
Start: s, End: t,
|
||||
})
|
||||
|
||||
trimmed := strings.TrimSpace(txt)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
var delta string
|
||||
if assembled.Len() == 0 {
|
||||
delta = trimmed
|
||||
} else {
|
||||
delta = " " + trimmed
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{Delta: delta}
|
||||
assembled.WriteString(delta)
|
||||
}
|
||||
|
||||
final := &pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: assembled.String(),
|
||||
Language: opts.Language,
|
||||
Duration: duration,
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{FinalResult: final}
|
||||
return nil
|
||||
}
|
||||
|
||||
// synthesize returns 24 kHz mono float32 PCM for text via the open session.
|
||||
func (w *CrispASR) synthesize(text string) ([]float32, error) {
|
||||
if text == "" {
|
||||
return nil, fmt.Errorf("crispasr: TTS requires non-empty text")
|
||||
}
|
||||
var n int32
|
||||
ptr := CppTTSSynthesize(text, unsafe.Pointer(&n))
|
||||
if ptr == 0 || n <= 0 {
|
||||
return nil, fmt.Errorf("crispasr: synthesis failed (the loaded model may not be a supported TTS backend, or needs extra config e.g. orpheus SNAC codec)")
|
||||
}
|
||||
defer CppTTSFree(ptr)
|
||||
src := unsafe.Slice((*float32)(unsafe.Pointer(ptr)), int(n)) //nolint:govet // ptr addresses C-allocated PCM returned across the purego boundary; copied out immediately below, before tts_free.
|
||||
out := make([]float32, int(n)) // copy out of C memory before free
|
||||
copy(out, src)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// setVoice applies a per-call speaker/voice override (best effort). CrispASR
|
||||
// returns a negative code when the active backend can't honor the name; we log
|
||||
// it rather than fail, so an unknown voice falls back to the default speaker.
|
||||
func setVoice(voice string) {
|
||||
v := strings.TrimSpace(voice)
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
if rc := CppTTSSetVoice(v); rc != 0 {
|
||||
fmt.Fprintf(os.Stderr, "crispasr: voice %q not applied by the active TTS backend (rc=%d); using default\n", v, rc)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *CrispASR) TTS(req *pb.TTSRequest) error {
|
||||
if req.Dst == "" {
|
||||
return fmt.Errorf("crispasr: TTS requires a destination path")
|
||||
}
|
||||
setVoice(req.Voice)
|
||||
pcm, err := w.synthesize(req.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWAV24k(req.Dst, pcm)
|
||||
}
|
||||
|
||||
// TTSStream is the streaming counterpart to TTS. CrispASR has no progressive
|
||||
// (native streaming) synth, so we synthesize the whole utterance, encode it to
|
||||
// a 24 kHz WAV, and emit the encoded bytes as a single chunk. The gRPC server
|
||||
// wrapper (pkg/grpc/server.go:TTSStream) ranges over the channel until it is
|
||||
// closed, so this method owns the close - mirrors vibevoice-cpp's TTSStream.
|
||||
func (w *CrispASR) TTSStream(req *pb.TTSRequest, results chan []byte) error {
|
||||
defer close(results)
|
||||
|
||||
if req.Text == "" {
|
||||
return fmt.Errorf("crispasr: TTSStream requires text")
|
||||
}
|
||||
setVoice(req.Voice)
|
||||
pcm, err := w.synthesize(req.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmp, err := os.CreateTemp("", "crispasr-tts-stream-*.wav")
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: tempfile: %w", err)
|
||||
}
|
||||
dst := tmp.Name()
|
||||
if err := tmp.Close(); err != nil {
|
||||
return fmt.Errorf("crispasr: close tempfile: %w", err)
|
||||
}
|
||||
defer func() { _ = os.Remove(dst) }()
|
||||
|
||||
if err := writeWAV24k(dst, pcm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encoded, err := os.ReadFile(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: read tempfile: %w", err)
|
||||
}
|
||||
results <- encoded
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeWAV24k writes pcm as a 24000 Hz, mono, 16-bit PCM WAV at dst.
|
||||
func writeWAV24k(dst string, pcm []float32) error {
|
||||
f, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: create %q: %w", dst, err)
|
||||
}
|
||||
|
||||
enc := wav.NewEncoder(f, 24000, 16, 1, 1)
|
||||
ints := make([]int, len(pcm))
|
||||
for i, s := range pcm {
|
||||
if s > 1 {
|
||||
s = 1
|
||||
} else if s < -1 {
|
||||
s = -1
|
||||
}
|
||||
ints[i] = int(s * 32767)
|
||||
}
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: 24000},
|
||||
Data: ints,
|
||||
SourceBitDepth: 16,
|
||||
}
|
||||
if err := enc.Write(buf); err != nil {
|
||||
_ = enc.Close()
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("crispasr: encode WAV: %w", err)
|
||||
}
|
||||
if err := enc.Close(); err != nil {
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("crispasr: finalize WAV: %w", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return fmt.Errorf("crispasr: close %q: %w", dst, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
193
backend/go/crispasr/gocrispasr_test.go
Normal file
193
backend/go/crispasr/gocrispasr_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestCrispASR(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "CrispASR Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
|
||||
// bridge without spinning up the gRPC server. Skips the current spec when the
|
||||
// shared library isn't present (e.g. running before `make backends/whisper`).
|
||||
func ensureLibLoaded() {
|
||||
libLoadOnce.Do(func() {
|
||||
libName := os.Getenv("CRISPASR_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgocrispasr-fallback.so"
|
||||
}
|
||||
if _, err := os.Stat(libName); err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&CppLoadModel, gosd, "load_model")
|
||||
purego.RegisterLibFunc(&CppSetCodecPath, gosd, "set_codec_path")
|
||||
purego.RegisterLibFunc(&CppTranscribe, gosd, "transcribe")
|
||||
purego.RegisterLibFunc(&CppGetSegmentText, gosd, "get_segment_text")
|
||||
purego.RegisterLibFunc(&CppGetSegmentStart, gosd, "get_segment_t0")
|
||||
purego.RegisterLibFunc(&CppGetSegmentEnd, gosd, "get_segment_t1")
|
||||
purego.RegisterLibFunc(&CppGetBackend, gosd, "get_backend")
|
||||
purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort")
|
||||
purego.RegisterLibFunc(&CppTTSSynthesize, gosd, "tts_synthesize")
|
||||
purego.RegisterLibFunc(&CppTTSFree, gosd, "tts_free")
|
||||
purego.RegisterLibFunc(&CppTTSSetVoice, gosd, "tts_set_voice")
|
||||
purego.RegisterLibFunc(&CppTTSSetVoiceFile, gosd, "tts_set_voice_file")
|
||||
})
|
||||
if libLoadErr != nil {
|
||||
Skip("whisper library not loadable: " + libLoadErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// fixturesOrSkip returns the model + audio paths or skips the spec if either
|
||||
// env var is unset. The test never runs in default CI — it requires a real
|
||||
// whisper model and a long audio file (~3 minutes) on disk.
|
||||
func fixturesOrSkip() (string, string) {
|
||||
modelPath := os.Getenv("CRISPASR_MODEL_PATH")
|
||||
audioPath := os.Getenv("CRISPASR_AUDIO_PATH")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set CRISPASR_MODEL_PATH and CRISPASR_AUDIO_PATH to run this spec")
|
||||
}
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
// ttsModelOrSkip returns the TTS model path or skips the spec when the env var
|
||||
// is unset. Like the transcription fixtures, this never runs in default CI — it
|
||||
// needs a real TTS model (e.g. a vibevoice GGUF) on disk.
|
||||
func ttsModelOrSkip() string {
|
||||
modelPath := os.Getenv("CRISPASR_TTS_MODEL_PATH")
|
||||
if modelPath == "" {
|
||||
Skip("set CRISPASR_TTS_MODEL_PATH to run this spec")
|
||||
}
|
||||
return modelPath
|
||||
}
|
||||
|
||||
var _ = Describe("CrispASR", func() {
|
||||
Context("AudioTranscription cancellation", func() {
|
||||
It("returns codes.Canceled on a pre-cancelled context and still succeeds afterwards", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
|
||||
// The session transcribe is blocking and exposes no abort hook, so
|
||||
// a mid-decode cancel can't interrupt it. The contract we can rely
|
||||
// on is the pre-call ctx.Err() check: a context cancelled before
|
||||
// the call must yield codes.Canceled without starting a decode.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := w.AudioTranscription(ctx, &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "expected pre-cancelled context to fail")
|
||||
st, ok := status.FromError(err)
|
||||
Expect(ok).To(BeTrue(), "expected gRPC status error, got %v", err)
|
||||
Expect(st.Code()).To(Equal(codes.Canceled), "expected codes.Canceled, got %v", err)
|
||||
|
||||
// Subsequent transcription must succeed — proves g_abort reset.
|
||||
res, err := w.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred(), "post-cancel transcription failed")
|
||||
Expect(res.Text).ToNot(BeEmpty(), "post-cancel transcription returned empty text")
|
||||
})
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("emits multiple deltas progressively for a multi-segment clip", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
|
||||
results := make(chan *pb.TranscriptStreamResponse, 64)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- w.AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
Stream: true,
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var deltas []string
|
||||
var assembled strings.Builder
|
||||
var finalText string
|
||||
var finalSegmentCount int
|
||||
for chunk := range results {
|
||||
if d := chunk.GetDelta(); d != "" {
|
||||
deltas = append(deltas, d)
|
||||
assembled.WriteString(d)
|
||||
}
|
||||
if final := chunk.GetFinalResult(); final != nil {
|
||||
finalText = final.GetText()
|
||||
finalSegmentCount = len(final.GetSegments())
|
||||
}
|
||||
}
|
||||
Expect(<-done).ToNot(HaveOccurred())
|
||||
|
||||
// One delta per non-empty segment is emitted after the blocking
|
||||
// decode returns (the session API has no per-decode callback), so a
|
||||
// multi-segment clip MUST produce >=2 delta events, and
|
||||
// concat(deltas) MUST equal final.Text exactly.
|
||||
Expect(len(deltas)).To(BeNumerically(">=", 2),
|
||||
"expected multiple deltas from a multi-segment clip, got %d (assembled=%q)",
|
||||
len(deltas), assembled.String())
|
||||
Expect(finalSegmentCount).To(BeNumerically(">=", 2),
|
||||
"expected final to carry multiple segments")
|
||||
Expect(assembled.String()).To(Equal(finalText),
|
||||
"concat(deltas) must equal final.Text")
|
||||
})
|
||||
})
|
||||
|
||||
Context("TTS", func() {
|
||||
It("synthesizes a non-empty WAV", func() {
|
||||
ttsModel := ttsModelOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: ttsModel})).To(Succeed())
|
||||
|
||||
dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
|
||||
Expect(w.TTS(&pb.TTSRequest{Text: "Hello from CrispASR.", Dst: dst})).To(Succeed())
|
||||
|
||||
info, err := os.Stat(dst)
|
||||
Expect(err).ToNot(HaveOccurred(), "synthesized WAV should exist at %q", dst)
|
||||
// A real 24 kHz mono WAV is a 44-byte header plus samples; anything
|
||||
// this small would mean an empty/failed synth.
|
||||
Expect(info.Size()).To(BeNumerically(">", 1024),
|
||||
"expected a non-trivial WAV, got %d bytes", info.Size())
|
||||
})
|
||||
})
|
||||
})
|
||||
58
backend/go/crispasr/main.go
Normal file
58
backend/go/crispasr/main.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("CRISPASR_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgocrispasr-fallback.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppLoadModel, "load_model"},
|
||||
{&CppSetCodecPath, "set_codec_path"},
|
||||
{&CppLoadModelVAD, "load_model_vad"},
|
||||
{&CppVAD, "vad"},
|
||||
{&CppTranscribe, "transcribe"},
|
||||
{&CppGetSegmentText, "get_segment_text"},
|
||||
{&CppGetSegmentStart, "get_segment_t0"},
|
||||
{&CppGetSegmentEnd, "get_segment_t1"},
|
||||
{&CppGetBackend, "get_backend"},
|
||||
{&CppSetAbort, "set_abort"},
|
||||
{&CppTTSSynthesize, "tts_synthesize"},
|
||||
{&CppTTSFree, "tts_free"},
|
||||
{&CppTTSSetVoice, "tts_set_voice"},
|
||||
{&CppTTSSetVoiceFile, "tts_set_voice_file"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &CrispASR{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
65
backend/go/crispasr/package.sh
Executable file
65
backend/go/crispasr/package.sh
Executable file
@@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
# This script is used in the final stage of the Dockerfile
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/crispasr $CURDIR/package/
|
||||
cp -fv $CURDIR/libgocrispasr-*.so $CURDIR/package/
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ $(uname -s) = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
# The GPU library packaging script will detect BUILD_TYPE and copy appropriate GPU libraries
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
52
backend/go/crispasr/run.sh
Executable file
52
backend/go/crispasr/run.sh
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
fi
|
||||
|
||||
LIBRARY="$CURDIR/libgocrispasr-fallback.so"
|
||||
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx2.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx2.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx512.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx512.so"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
export CRISPASR_LIBRARY=$LIBRARY
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/crispasr "$@"
|
||||
fi
|
||||
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/crispasr "$@"
|
||||
@@ -9,7 +9,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
# LocalVQE upstream version pin. Bump to a specific commit when picking up
|
||||
# a new release; `main` works for development but is not reproducible.
|
||||
LOCALVQE_REPO?=https://github.com/localai-org/LocalVQE
|
||||
LOCALVQE_VERSION?=72bfb4c6
|
||||
LOCALVQE_VERSION?=b0f0378a450e87c871b85689554801601ca56d98
|
||||
|
||||
# LocalVQE handles CPU feature selection internally (it ships the multiple
|
||||
# libggml-cpu-*.so variants and its loader picks the best one at runtime
|
||||
@@ -27,7 +27,8 @@ endif
|
||||
|
||||
# LocalVQE upstream supports CPU + Vulkan only. Other BUILD_TYPE values
|
||||
# fall through to the default CPU build — Vulkan is already as fast as the
|
||||
# specialised GPU paths would be on this 1.3 M-parameter model.
|
||||
# specialised GPU paths would be on these small (1.3 M–4.8 M parameter)
|
||||
# models.
|
||||
ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON -DLOCALVQE_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
|
||||
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -11,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -46,24 +46,24 @@ const (
|
||||
// through the options builder (CppOptionsNew + setters + CppNewWithOptions)
|
||||
// — the bare localvqe_new path doesn't expose backend / device selection.
|
||||
var (
|
||||
CppOptionsNew func() uintptr
|
||||
CppOptionsFree func(opts uintptr)
|
||||
CppOptionsSetModelPath func(opts uintptr, modelPath string) int32
|
||||
CppOptionsSetBackend func(opts uintptr, backend string) int32
|
||||
CppOptionsSetDevice func(opts uintptr, device int32) int32
|
||||
CppNewWithOptions func(opts uintptr) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppProcessF32 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessS16 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessFrameF32 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppProcessFrameS16 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppReset func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppSampleRate func(ctx uintptr) int32
|
||||
CppHopLength func(ctx uintptr) int32
|
||||
CppFFTSize func(ctx uintptr) int32
|
||||
CppSetNoiseGate func(ctx uintptr, enabled int32, thresholdDBFS float32) int32
|
||||
CppGetNoiseGate func(ctx uintptr, enabledOut, thresholdDBFSOut uintptr) int32
|
||||
CppOptionsNew func() uintptr
|
||||
CppOptionsFree func(opts uintptr)
|
||||
CppOptionsSetModelPath func(opts uintptr, modelPath string) int32
|
||||
CppOptionsSetBackend func(opts uintptr, backend string) int32
|
||||
CppOptionsSetDevice func(opts uintptr, device int32) int32
|
||||
CppNewWithOptions func(opts uintptr) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppProcessF32 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessS16 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessFrameF32 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppProcessFrameS16 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppReset func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppSampleRate func(ctx uintptr) int32
|
||||
CppHopLength func(ctx uintptr) int32
|
||||
CppFFTSize func(ctx uintptr) int32
|
||||
CppSetNoiseGate func(ctx uintptr, enabled int32, thresholdDBFS float32) int32
|
||||
CppGetNoiseGate func(ctx uintptr, enabledOut, thresholdDBFSOut uintptr) int32
|
||||
)
|
||||
|
||||
// LocalVQE speaks gRPC against LocalVQE's flat C ABI. The streaming
|
||||
@@ -490,11 +490,14 @@ func (v *LocalVQE) applyStreamConfig(cfg *pb.AudioTransformStreamConfig) error {
|
||||
|
||||
// ---- WAV I/O ----------------------------------------------------------
|
||||
//
|
||||
// Minimal mono PCM WAV reader/writer. Only handles the subset LocalVQE
|
||||
// cares about (mono, 16-bit signed, no extensible chunks). For broader
|
||||
// audio support the HTTP layer's `audio.NormalizeAudioFile` already
|
||||
// converts arbitrary input to a canonical WAV before we see it; this
|
||||
// reader just decodes the canonical shape.
|
||||
// Reader/writer for the mono 16-bit PCM shape LocalVQE works with. Decoding
|
||||
// goes through the shared go-audio/wav decoder (as the whisper and parakeet
|
||||
// backends do) so RIFF chunk walking is handled robustly — an 18/40-byte
|
||||
// extensible `fmt ` chunk, or JUNK/bext/LIST metadata before or after `data`
|
||||
// (e.g. ffmpeg's trailing "Lavf" tag), is skipped rather than spliced into
|
||||
// the PCM stream as an audible click. The HTTP layer normalises arbitrary
|
||||
// input to WAV before we see it, but that WAV is ffmpeg output and is not
|
||||
// guaranteed to be the canonical 44-byte layout.
|
||||
|
||||
func readMonoWAVf32(path string) ([]float32, int, error) {
|
||||
f, err := os.Open(path)
|
||||
@@ -502,35 +505,26 @@ func readMonoWAVf32(path string) ([]float32, int, error) {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
header := make([]byte, 44)
|
||||
if _, err := io.ReadFull(f, header); err != nil {
|
||||
return nil, 0, err
|
||||
|
||||
buf, err := wav.NewDecoder(f).FullPCMBuffer()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("decode WAV: %w", err)
|
||||
}
|
||||
if string(header[0:4]) != "RIFF" || string(header[8:12]) != "WAVE" {
|
||||
if buf == nil || buf.Format == nil {
|
||||
return nil, 0, fmt.Errorf("not a WAV file")
|
||||
}
|
||||
channels := binary.LittleEndian.Uint16(header[22:24])
|
||||
sampleRate := binary.LittleEndian.Uint32(header[24:28])
|
||||
bitsPerSample := binary.LittleEndian.Uint16(header[34:36])
|
||||
|
||||
if channels != 1 {
|
||||
return nil, 0, fmt.Errorf("only mono WAV supported (got %d channels)", channels)
|
||||
if buf.Format.NumChannels != 1 {
|
||||
return nil, 0, fmt.Errorf("only mono WAV supported (got %d channels)", buf.Format.NumChannels)
|
||||
}
|
||||
if bitsPerSample != 16 {
|
||||
return nil, 0, fmt.Errorf("only 16-bit PCM supported (got %d bits)", bitsPerSample)
|
||||
if buf.SourceBitDepth != 16 {
|
||||
return nil, 0, fmt.Errorf("only 16-bit PCM supported (got %d bits)", buf.SourceBitDepth)
|
||||
}
|
||||
|
||||
rest, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
if len(buf.Data) == 0 {
|
||||
return nil, 0, fmt.Errorf("WAV has no audio data")
|
||||
}
|
||||
n := len(rest) / 2
|
||||
out := make([]float32, n)
|
||||
for i := 0; i < n; i++ {
|
||||
s := int16(binary.LittleEndian.Uint16(rest[i*2 : i*2+2]))
|
||||
out[i] = float32(s) / 32768.0
|
||||
}
|
||||
return out, int(sampleRate), nil
|
||||
// AsFloat32Buffer normalises by 2^(bitDepth-1) == /32768 for 16-bit,
|
||||
// matching the model's expected [-1, 1) input range.
|
||||
return buf.AsFloat32Buffer().Data, buf.Format.SampleRate, nil
|
||||
}
|
||||
|
||||
func writeMonoWAVf32(path string, samples []float32, sampleRate int) error {
|
||||
@@ -546,13 +540,13 @@ func writeMonoWAVf32(path string, samples []float32, sampleRate int) error {
|
||||
binary.LittleEndian.PutUint32(header[4:8], 36+dataLen)
|
||||
copy(header[8:12], []byte("WAVE"))
|
||||
copy(header[12:16], []byte("fmt "))
|
||||
binary.LittleEndian.PutUint32(header[16:20], 16) // fmt chunk size
|
||||
binary.LittleEndian.PutUint16(header[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(header[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(header[16:20], 16) // fmt chunk size
|
||||
binary.LittleEndian.PutUint16(header[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(header[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(header[24:28], uint32(sampleRate))
|
||||
binary.LittleEndian.PutUint32(header[28:32], uint32(sampleRate*2)) // byte rate
|
||||
binary.LittleEndian.PutUint16(header[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(header[34:36], 16) // bits per sample
|
||||
binary.LittleEndian.PutUint16(header[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(header[34:36], 16) // bits per sample
|
||||
copy(header[36:40], []byte("data"))
|
||||
binary.LittleEndian.PutUint32(header[40:44], dataLen)
|
||||
if _, err := f.Write(header); err != nil {
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -92,6 +94,147 @@ var _ = Describe("LocalVQE-cpp", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("readMonoWAVf32 chunk parsing", func() {
|
||||
// chunk builds a word-aligned RIFF sub-chunk (id + size + body + pad).
|
||||
chunk := func(id string, body []byte) []byte {
|
||||
out := append([]byte(id), 0, 0, 0, 0)
|
||||
binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
|
||||
out = append(out, body...)
|
||||
if len(body)&1 == 1 {
|
||||
out = append(out, 0) // pad byte for odd-sized chunks
|
||||
}
|
||||
return out
|
||||
}
|
||||
// fmtBody returns a PCM `fmt ` chunk body. extra bytes simulate the
|
||||
// 18/40-byte extensible form (cbSize + extension).
|
||||
fmtBody := func(channels, bits uint16, rate uint32, extra int) []byte {
|
||||
b := make([]byte, 16+extra)
|
||||
binary.LittleEndian.PutUint16(b[0:2], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(b[2:4], channels)
|
||||
binary.LittleEndian.PutUint32(b[4:8], rate)
|
||||
binary.LittleEndian.PutUint32(b[8:12], rate*uint32(channels)*uint32(bits)/8)
|
||||
binary.LittleEndian.PutUint16(b[12:14], channels*bits/8)
|
||||
binary.LittleEndian.PutUint16(b[14:16], bits)
|
||||
if extra >= 2 {
|
||||
binary.LittleEndian.PutUint16(b[16:18], uint16(extra-2)) // cbSize
|
||||
}
|
||||
return b
|
||||
}
|
||||
// pcm encodes int16 samples little-endian.
|
||||
pcm := func(samples ...int16) []byte {
|
||||
b := make([]byte, len(samples)*2)
|
||||
for i, s := range samples {
|
||||
binary.LittleEndian.PutUint16(b[i*2:i*2+2], uint16(s))
|
||||
}
|
||||
return b
|
||||
}
|
||||
riff := func(chunks ...[]byte) []byte {
|
||||
body := []byte("WAVE")
|
||||
for _, c := range chunks {
|
||||
body = append(body, c...)
|
||||
}
|
||||
out := append([]byte("RIFF"), 0, 0, 0, 0)
|
||||
binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
|
||||
return append(out, body...)
|
||||
}
|
||||
writeWAV := func(b []byte) string {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "in.wav")
|
||||
Expect(os.WriteFile(p, b, 0o600)).To(Succeed())
|
||||
return p
|
||||
}
|
||||
// A canonical sample run with distinct values so any off-by-one /
|
||||
// misalignment shows up as wrong numbers, not just wrong length.
|
||||
samples := []int16{1000, -2000, 3000, -4000, 5000, -6000}
|
||||
expectSamples := func(got []float32) {
|
||||
Expect(got).To(HaveLen(len(samples)))
|
||||
for i, s := range samples {
|
||||
Expect(got[i]).To(BeNumerically("~", float32(s)/32768.0, 1e-6))
|
||||
}
|
||||
}
|
||||
|
||||
It("reads a canonical 44-byte WAV", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out)
|
||||
})
|
||||
|
||||
It("ignores a LIST/JUNK chunk placed before data (no leading-impulse splice)", func() {
|
||||
p := writeWAV(riff(
|
||||
chunk("fmt ", fmtBody(1, 16, 16000, 0)),
|
||||
chunk("JUNK", []byte("padding-bytes-here!")), // odd length → exercises pad
|
||||
chunk("LIST", []byte("INFOISFTLavf60.0")),
|
||||
chunk("data", pcm(samples...)),
|
||||
))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out) // not corrupted by the preceding chunks
|
||||
})
|
||||
|
||||
It("honours the data chunk size and drops a trailing metadata chunk", func() {
|
||||
p := writeWAV(riff(
|
||||
chunk("fmt ", fmtBody(1, 16, 16000, 0)),
|
||||
chunk("data", pcm(samples...)),
|
||||
chunk("LIST", []byte("INFOISFTLavf60.16.100")), // ffmpeg trailer tag
|
||||
))
|
||||
out, _, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectSamples(out) // trailing LIST bytes not decoded as PCM
|
||||
})
|
||||
|
||||
It("handles the 18-byte extensible fmt chunk", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 2)), chunk("data", pcm(samples...))))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out)
|
||||
})
|
||||
|
||||
It("rejects non-mono input", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(2, 16, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("mono"))
|
||||
})
|
||||
|
||||
It("rejects non-16-bit input", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 8, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("16-bit"))
|
||||
})
|
||||
|
||||
It("rejects a non-WAV file", func() {
|
||||
p := writeWAV([]byte("not a riff file at all"))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors when the data chunk is missing", func() {
|
||||
// fmt but no data: the decoder must fail rather than return an
|
||||
// empty (or garbage) sample slice. The exact message is the
|
||||
// decoder's, so just assert it errors.
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 0))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("round-trips through writeMonoWAVf32", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "rt.wav")
|
||||
in := []float32{0.1, -0.2, 0.3, -0.4}
|
||||
Expect(writeMonoWAVf32(p, in, 16000)).To(Succeed())
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
Expect(out).To(HaveLen(len(in)))
|
||||
for i := range in {
|
||||
Expect(out[i]).To(BeNumerically("~", in[i], 1e-4))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("model-gated integration (LOCALVQE_MODEL_PATH)", func() {
|
||||
It("load + sample rate + hop + fft", func() {
|
||||
path := modelPathOrSkip()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# parakeet-cpp backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=30a307553f1965ceb38a1a922069a71e7dd67bf3
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=b11fe5bca78ad8b342dd559a43d76df3984bb447
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# whisper.cpp / ds4 / vibevoice-cpp convention.
|
||||
#
|
||||
@@ -15,7 +15,7 @@
|
||||
# That's what the L0 smoke test uses. The default target below does the
|
||||
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
|
||||
|
||||
PARAKEET_VERSION?=30a307553f1965ceb38a1a922069a71e7dd67bf3
|
||||
PARAKEET_VERSION?=b11fe5bca78ad8b342dd559a43d76df3984bb447
|
||||
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
|
||||
|
||||
GOCMD?=go
|
||||
@@ -34,14 +34,18 @@ ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# parakeet.cpp gates its GGML backends behind PARAKEET_GGML_* options and does
|
||||
# set(GGML_CUDA ${PARAKEET_GGML_CUDA} CACHE BOOL "" FORCE), so a bare -DGGML_CUDA=ON
|
||||
# is overwritten back to OFF and the build silently falls back to CPU. Forward the
|
||||
# PARAKEET_GGML_* options instead. (openblas is not gated, so -DGGML_BLAS passes through.)
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DGGML_CUDA=ON
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DGGML_HIP=ON
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_HIP=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_VULKAN=ON
|
||||
endif
|
||||
|
||||
.PHONY: parakeet-cpp-grpc package build clean purge test all
|
||||
|
||||
79
backend/go/parakeet-cpp/batcher.go
Normal file
79
backend/go/parakeet-cpp/batcher.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package main
|
||||
|
||||
import "time"
|
||||
|
||||
// batchRequest is one in-flight unary transcription waiting to be batched.
|
||||
// In production pcm/decoder are set; tag is an opaque marker used by tests.
|
||||
type batchRequest struct {
|
||||
pcm []float32
|
||||
decoder int32
|
||||
tag string
|
||||
reply chan batchReply
|
||||
}
|
||||
|
||||
// batchReply carries one per-item JSON object string (an element of the C-API's
|
||||
// JSON array) or an error back to the waiting handler goroutine.
|
||||
type batchReply struct {
|
||||
json string
|
||||
err error
|
||||
}
|
||||
|
||||
// batcher coalesces concurrent batchRequests into batched runBatch calls. A
|
||||
// single run() goroutine is the sole caller of runBatch, so runBatch (which in
|
||||
// production calls the thread-unsafe C engine) is never entered concurrently.
|
||||
type batcher struct {
|
||||
submit chan *batchRequest
|
||||
maxSize int
|
||||
maxWait time.Duration
|
||||
runBatch func(reqs []*batchRequest) // must deliver a reply to every req
|
||||
}
|
||||
|
||||
func newBatcher(maxSize int, maxWait time.Duration, runBatch func([]*batchRequest)) *batcher {
|
||||
if maxSize < 1 {
|
||||
maxSize = 1
|
||||
}
|
||||
return &batcher{
|
||||
submit: make(chan *batchRequest),
|
||||
maxSize: maxSize,
|
||||
maxWait: maxWait,
|
||||
runBatch: runBatch,
|
||||
}
|
||||
}
|
||||
|
||||
// run is the dispatcher loop: accumulate submitted requests until either maxSize
|
||||
// is reached or maxWait elapses since the first queued request, then dispatch.
|
||||
// Exits when stop is closed (draining any partially-filled batch first).
|
||||
func (b *batcher) run(stop <-chan struct{}) {
|
||||
for {
|
||||
var first *batchRequest
|
||||
select {
|
||||
case first = <-b.submit:
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
batch := []*batchRequest{first}
|
||||
|
||||
// maxSize==1 disables batching: dispatch immediately (passthrough).
|
||||
if b.maxSize == 1 {
|
||||
b.runBatch(batch)
|
||||
continue
|
||||
}
|
||||
|
||||
timer := time.NewTimer(b.maxWait)
|
||||
fill:
|
||||
for len(batch) < b.maxSize {
|
||||
select {
|
||||
case r := <-b.submit:
|
||||
batch = append(batch, r)
|
||||
case <-timer.C:
|
||||
break fill
|
||||
case <-stop:
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
return
|
||||
}
|
||||
}
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
}
|
||||
}
|
||||
108
backend/go/parakeet-cpp/batcher_test.go
Normal file
108
backend/go/parakeet-cpp/batcher_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("batcher", func() {
|
||||
echoReply := func(reqs []*batchRequest) {
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{json: r.tag}
|
||||
}
|
||||
}
|
||||
|
||||
It("coalesces concurrent submits into batches", func() {
|
||||
var mu sync.Mutex
|
||||
var sizes []int
|
||||
run := func(reqs []*batchRequest) {
|
||||
mu.Lock()
|
||||
sizes = append(sizes, len(reqs))
|
||||
mu.Unlock()
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(4, 50*time.Millisecond, run)
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
|
||||
const N = 4
|
||||
var wg sync.WaitGroup
|
||||
got := make([]string, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: string(rune('a' + i)), reply: rep}
|
||||
got[i] = (<-rep).json
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
total, maxBatch := 0, 0
|
||||
for _, s := range sizes {
|
||||
total += s
|
||||
if s > maxBatch {
|
||||
maxBatch = s
|
||||
}
|
||||
}
|
||||
Expect(total).To(Equal(N))
|
||||
Expect(maxBatch).To(BeNumerically(">=", 2), "expected at least one batch to coalesce >1 request")
|
||||
})
|
||||
|
||||
It("dispatches when max size is reached", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(2, time.Hour, run) // huge window: only size can trigger
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
for i := 0; i < 2; i++ {
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func(rep chan batchReply) { <-rep }(rep)
|
||||
}
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(2)))
|
||||
})
|
||||
|
||||
It("dispatches when the wait window elapses", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(8, 20*time.Millisecond, run) // size unreachable; window fires
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
|
||||
It("bypasses batching when max size is 1", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(1, time.Hour, run) // size 1 => immediate dispatch
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
})
|
||||
@@ -7,13 +7,18 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
@@ -34,6 +39,15 @@ var (
|
||||
CppFreeString func(s uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
|
||||
// Batched JSON transcription: takes a concatenated float buffer of clips
|
||||
// plus their per-clip sample counts (sum(nSamples)==len(samplesConcat))
|
||||
// and returns a malloc'd char* JSON ARRAY of per-clip {"text","words",
|
||||
// "tokens"} objects (uintptr, freed via CppFreeString). purego passes the
|
||||
// Go slices as the base pointer of their backing array (kept alive for the
|
||||
// call), matching the CppStreamFeed pcm []float32 binding pattern; the C
|
||||
// side reads them as const float*/const int*.
|
||||
CppTranscribePcmBatchJSON func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32) uintptr
|
||||
|
||||
// Cache-aware streaming (RNN-T) entry points. stream_begin returns 0 for
|
||||
// non-streaming models. feed/finalize return a malloc'd char* (uintptr,
|
||||
// freed via CppFreeString); feed writes 1 to *eouOut on an <EOU>/<EOB>.
|
||||
@@ -77,11 +91,18 @@ type transcriptToken struct {
|
||||
}
|
||||
|
||||
// ParakeetCpp owns a single loaded parakeet_ctx. The C engine is a
|
||||
// thread-unsafe singleton (mirrors whisper.cpp / vibevoice.cpp), so we
|
||||
// serialize calls through base.SingleThread.
|
||||
// thread-unsafe singleton (mirrors whisper.cpp / vibevoice.cpp). Rather than
|
||||
// serialize every call through base.SingleThread, we route unary
|
||||
// transcription through an in-process batcher (its sole dispatcher goroutine
|
||||
// is the only caller of the engine on that path) and guard the shared engine
|
||||
// with engineMu so a streaming session and a batched-unary dispatch never
|
||||
// touch it concurrently.
|
||||
type ParakeetCpp struct {
|
||||
base.SingleThread
|
||||
ctxPtr uintptr
|
||||
base.Base
|
||||
ctxPtr uintptr
|
||||
engineMu sync.Mutex // sole guard of the one C engine (dispatcher + streaming)
|
||||
bat *batcher
|
||||
batStop chan struct{}
|
||||
}
|
||||
|
||||
// Load is the LocalAI gRPC entry point for LoadModel: it calls
|
||||
@@ -100,13 +121,103 @@ func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
|
||||
return fmt.Errorf("parakeet-cpp: parakeet_capi_load failed for %q", opts.ModelFile)
|
||||
}
|
||||
p.ctxPtr = ctx
|
||||
|
||||
// Dynamic batching knobs (model YAML options:, key:value form). Batching is
|
||||
// OFF by default (batch_max_size:1): each request runs on its own. On GPU,
|
||||
// raising batch_max_size coalesces concurrent requests into one batched
|
||||
// engine call and improves throughput under load; leave it at 1 on CPU and
|
||||
// for low-concurrency setups, where batching only adds latency.
|
||||
maxSize := optInt(opts, "batch_max_size", 1)
|
||||
maxWaitMs := optInt(opts, "batch_max_wait_ms", 15)
|
||||
if maxWaitMs < 0 {
|
||||
maxWaitMs = 0
|
||||
}
|
||||
if CppTranscribePcmBatchJSON != nil {
|
||||
p.batStop = make(chan struct{})
|
||||
p.bat = newBatcher(maxSize, time.Duration(maxWaitMs)*time.Millisecond, p.runBatch)
|
||||
go p.bat.run(p.batStop) // dispatcher runs until Free closes batStop
|
||||
if maxSize > 1 {
|
||||
xlog.Info("parakeet-cpp: dynamic batching enabled",
|
||||
"batch_max_size", maxSize, "batch_max_wait_ms", maxWaitMs)
|
||||
} else {
|
||||
xlog.Info("parakeet-cpp: dynamic batching off (batch_max_size=1); " +
|
||||
"set batch_max_size>1 to coalesce concurrent requests on GPU")
|
||||
}
|
||||
} else {
|
||||
xlog.Info("parakeet-cpp: batched C-API not present in libparakeet.so; " +
|
||||
"batching disabled, using per-request transcription")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AudioTranscription runs parakeet_capi_transcribe_path_json on the wav at
|
||||
// opts.Dst with the default decoder (decoder=0, which selects the right head
|
||||
// per architecture: transducer for tdt/rnnt/hybrid, CTC for ctc) and shapes
|
||||
// the per-word timestamps into a LocalAI TranscriptResult.
|
||||
// optInt reads an integer model option (key:value form) from ModelOptions,
|
||||
// returning def when absent or unparseable. The options array carries the
|
||||
// model YAML's options: entries (see core/config; siblings such as
|
||||
// acestep-cpp parse the same key:value form via strings.Cut on ":").
|
||||
func optInt(opts *pb.ModelOptions, key string, def int) int {
|
||||
for _, o := range opts.GetOptions() {
|
||||
k, v, ok := strings.Cut(o, ":")
|
||||
if ok && strings.TrimSpace(k) == key {
|
||||
if n, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// runBatch is the dispatcher's batch handler and the ONLY caller of the C
|
||||
// engine on the unary path. It concatenates the batch PCM, calls the batched
|
||||
// JSON C-API under engineMu, splits the JSON array, and replies to each request.
|
||||
func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
// Observability: the actual coalesced batch size per engine call. Debug-level
|
||||
// so it stays silent in normal operation but lets operators confirm/tune batching.
|
||||
xlog.Debug("parakeet-cpp: dispatching batch", "size", len(reqs))
|
||||
nSamples := make([]int32, len(reqs))
|
||||
total := 0
|
||||
for i, r := range reqs {
|
||||
nSamples[i] = int32(len(r.pcm))
|
||||
total += len(r.pcm)
|
||||
}
|
||||
concat := make([]float32, 0, total)
|
||||
for _, r := range reqs {
|
||||
concat = append(concat, r.pcm...)
|
||||
}
|
||||
var dec int32
|
||||
if len(reqs) > 0 {
|
||||
dec = reqs[0].decoder
|
||||
}
|
||||
p.engineMu.Lock()
|
||||
cstr := CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
|
||||
p.engineMu.Unlock()
|
||||
if cstr == 0 {
|
||||
err := fmt.Errorf("parakeet-cpp: batch transcribe failed: %s", CppLastError(p.ctxPtr))
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{err: err}
|
||||
}
|
||||
return
|
||||
}
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
var docs []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(raw), &docs); err != nil || len(docs) != len(reqs) {
|
||||
e := fmt.Errorf("parakeet-cpp: batch json: got %d results for %d reqs (%v)", len(docs), len(reqs), err)
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{err: e}
|
||||
}
|
||||
return
|
||||
}
|
||||
for i, r := range reqs {
|
||||
r.reply <- batchReply{json: string(docs[i])}
|
||||
}
|
||||
}
|
||||
|
||||
// AudioTranscription decodes the wav at opts.Dst to 16 kHz mono PCM and
|
||||
// submits it to the in-process batcher, which coalesces concurrent requests
|
||||
// into a single batched engine call (parakeet_capi_transcribe_pcm_batch_json)
|
||||
// with the default decoder (decoder=0, which selects the right head per
|
||||
// architecture: transducer for tdt/rnnt/hybrid, CTC for ctc) and shapes the
|
||||
// per-word timestamps into a LocalAI TranscriptResult.
|
||||
//
|
||||
// Parakeet emits word- and token-level timestamps but no native segment
|
||||
// boundaries, so we synthesise a single whole-clip segment spanning the first
|
||||
@@ -118,69 +229,91 @@ func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
|
||||
// translate/diarize/prompt/temperature/language/threads are not applicable to
|
||||
// parakeet and are ignored; streaming is handled by AudioTranscriptionStream
|
||||
// (L2).
|
||||
func (p *ParakeetCpp) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if p.ctxPtr == 0 {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: model not loaded")
|
||||
return pb.TranscriptResult{}, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, opts.Dst, 0)
|
||||
if cstr == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
// Fallback when the batched C-API is unavailable: transcribe from a file
|
||||
// path (original behavior, no batching). The C library's audio loader only
|
||||
// understands 16 kHz mono WAV/PCM, so convert the input first - otherwise
|
||||
// any non-WAV upload (MP3, etc.) fails with "failed to load audio". This
|
||||
// mirrors what every other audio backend (whisper, crispasr) does via
|
||||
// utils.AudioToWav before handing the file to the engine.
|
||||
if p.bat == nil {
|
||||
converted, cleanup, err := convertToWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", msg)
|
||||
defer cleanup()
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, converted, 0)
|
||||
if cstr == 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", CppLastError(p.ctxPtr))
|
||||
}
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
var doc transcriptJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
}
|
||||
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
|
||||
// Batched path: decode to PCM, submit to the batcher, wait for this request's
|
||||
// JSON element. The dispatcher is the sole engine caller on this path; both
|
||||
// sends honour ctx cancellation.
|
||||
pcm, _, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
rep := make(chan batchReply, 1)
|
||||
select {
|
||||
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, reply: rep}:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
var res batchReply
|
||||
select {
|
||||
case res = <-rep:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if res.err != nil {
|
||||
return pb.TranscriptResult{}, res.err
|
||||
}
|
||||
var doc transcriptJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
if err := json.Unmarshal([]byte(res.json), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
}
|
||||
|
||||
// transcriptResultFromDoc maps a decoded transcriptJSON to a TranscriptResult,
|
||||
// synthesising a single whole-clip segment and attaching word timings only when
|
||||
// the caller requested word granularity. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest) pb.TranscriptResult {
|
||||
text := strings.TrimSpace(doc.Text)
|
||||
|
||||
words := make([]*pb.TranscriptWord, 0, len(doc.Words))
|
||||
for _, w := range doc.Words {
|
||||
words = append(words, &pb.TranscriptWord{
|
||||
Start: secondsToNanos(w.Start),
|
||||
End: secondsToNanos(w.End),
|
||||
Text: w.W,
|
||||
})
|
||||
words = append(words, &pb.TranscriptWord{Start: secondsToNanos(w.Start), End: secondsToNanos(w.End), Text: w.W})
|
||||
}
|
||||
|
||||
tokens := make([]int32, 0, len(doc.Tokens))
|
||||
for _, t := range doc.Tokens {
|
||||
tokens = append(tokens, t.ID)
|
||||
}
|
||||
|
||||
// Single whole-clip segment, spanning the first word start to the last
|
||||
// word end (0/0 when the clip produced no words).
|
||||
var segStart, segEnd int64
|
||||
if len(words) > 0 {
|
||||
segStart = words[0].Start
|
||||
segEnd = words[len(words)-1].End
|
||||
}
|
||||
seg := &pb.TranscriptSegment{
|
||||
Id: 0,
|
||||
Start: segStart,
|
||||
End: segEnd,
|
||||
Text: text,
|
||||
Tokens: tokens,
|
||||
}
|
||||
seg := &pb.TranscriptSegment{Id: 0, Start: segStart, End: segEnd, Text: text, Tokens: tokens}
|
||||
if wordsRequested(opts.TimestampGranularities) {
|
||||
seg.Words = words
|
||||
}
|
||||
|
||||
return pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: []*pb.TranscriptSegment{seg},
|
||||
}, nil
|
||||
return pb.TranscriptResult{Text: text, Segments: []*pb.TranscriptSegment{seg}}
|
||||
}
|
||||
|
||||
// wordsRequested reports whether the caller asked for word-level timestamps.
|
||||
@@ -219,7 +352,7 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
defer close(results)
|
||||
|
||||
if p.ctxPtr == 0 {
|
||||
return errors.New("parakeet-cpp: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
@@ -243,6 +376,14 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
return nil
|
||||
}
|
||||
defer CppStreamFree(stream)
|
||||
// The C engine is a single shared context: a streaming session and a batched
|
||||
// unary dispatch must never touch it at once, so hold engineMu for the whole
|
||||
// stream. This lock is intentionally taken AFTER the non-streaming fallback
|
||||
// above returns: that fallback goes through AudioTranscription -> the batcher
|
||||
// -> runBatch, which itself acquires engineMu, so locking here first would
|
||||
// deadlock. Do not hoist this lock above the fallback.
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
|
||||
data, duration, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
@@ -329,17 +470,33 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
// float samples plus the clip duration in seconds. Mirrors the whisper
|
||||
// backend: utils.AudioToWav (ffmpeg) normalises rate/channels, go-audio
|
||||
// decodes the PCM.
|
||||
func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
// convertToWavMono16k converts an arbitrary audio file to a 16 kHz mono WAV in
|
||||
// a fresh temp dir and returns the path together with a cleanup func the caller
|
||||
// must defer. WAV inputs already at 16 kHz/mono/16-bit are passed through by
|
||||
// utils.AudioToWav (hardlink/copy), everything else is transcoded via ffmpeg.
|
||||
// Used by the direct (non-batched) transcription path, which hands a file path
|
||||
// to the C library's WAV-only audio loader.
|
||||
func convertToWavMono16k(path string) (string, func(), error) {
|
||||
dir, err := os.MkdirTemp("", "parakeet")
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return "", func() {}, err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
cleanup := func() { _ = os.RemoveAll(dir) }
|
||||
|
||||
converted := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(path, converted); err != nil {
|
||||
cleanup()
|
||||
return "", func() {}, err
|
||||
}
|
||||
return converted, cleanup, nil
|
||||
}
|
||||
|
||||
func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
converted, cleanup, err := convertToWavMono16k(path)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
fh, err := os.Open(converted)
|
||||
if err != nil {
|
||||
@@ -362,6 +519,12 @@ func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
// Free releases the underlying parakeet_ctx. Called by LocalAI when the
|
||||
// model is unloaded.
|
||||
func (p *ParakeetCpp) Free() error {
|
||||
// Stop the dispatcher before releasing the engine so no in-flight runBatch
|
||||
// can touch a freed ctx (close leak / use-after-free on reload).
|
||||
if p.batStop != nil {
|
||||
close(p.batStop)
|
||||
p.batStop = nil
|
||||
}
|
||||
if p.ctxPtr != 0 {
|
||||
CppFree(p.ctxPtr)
|
||||
p.ctxPtr = 0
|
||||
|
||||
@@ -3,11 +3,14 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -43,6 +46,9 @@ func ensureLibLoaded() {
|
||||
purego.RegisterLibFunc(&CppFree, lib, "parakeet_capi_free")
|
||||
purego.RegisterLibFunc(&CppTranscribePath, lib, "parakeet_capi_transcribe_path")
|
||||
purego.RegisterLibFunc(&CppTranscribePathJSON, lib, "parakeet_capi_transcribe_path_json")
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
purego.RegisterLibFunc(&CppStreamBegin, lib, "parakeet_capi_stream_begin")
|
||||
purego.RegisterLibFunc(&CppStreamFeed, lib, "parakeet_capi_stream_feed")
|
||||
purego.RegisterLibFunc(&CppStreamFinalize, lib, "parakeet_capi_stream_finalize")
|
||||
@@ -67,6 +73,24 @@ func fixturesOrSkip() (string, string) {
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
// writeMono16kWav writes `samples` frames of 16 kHz mono 16-bit silence to
|
||||
// path. The result is already in AudioToWav's target format, so the conversion
|
||||
// helper copies it through without invoking ffmpeg.
|
||||
func writeMono16kWav(path string, samples int) {
|
||||
GinkgoHelper()
|
||||
f, err := os.Create(path)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
enc := wav.NewEncoder(f, 16000, 16, 1, 1)
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: 16000},
|
||||
SourceBitDepth: 16,
|
||||
Data: make([]int, samples),
|
||||
}
|
||||
Expect(enc.Write(buf)).To(Succeed())
|
||||
Expect(enc.Close()).To(Succeed())
|
||||
Expect(f.Close()).To(Succeed())
|
||||
}
|
||||
|
||||
var _ = Describe("ParakeetCpp", func() {
|
||||
Context("AudioTranscription", func() {
|
||||
It("transcribes a WAV via the parakeet C-API", func() {
|
||||
@@ -117,6 +141,39 @@ var _ = Describe("ParakeetCpp", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("convertToWavMono16k", func() {
|
||||
// The non-batched transcription path hands a file path to the C
|
||||
// library's WAV-only audio loader, so it must convert first.
|
||||
// utils.AudioToWav passes an already-16kHz/mono/16-bit WAV through
|
||||
// without ffmpeg, which lets us exercise the helper (and the
|
||||
// regression: the direct path used to skip conversion entirely)
|
||||
// without a model, the C library, or ffmpeg.
|
||||
It("returns a decodable 16kHz mono WAV copy and cleans it up", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
src := filepath.Join(dir, "input.wav")
|
||||
writeMono16kWav(src, 16000) // 1s of silence at 16 kHz
|
||||
|
||||
converted, cleanup, err := convertToWavMono16k(src)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// It must produce a fresh temp file, not return the original path.
|
||||
Expect(converted).ToNot(Equal(src))
|
||||
Expect(converted).To(BeAnExistingFile())
|
||||
|
||||
pcm, _, err := decodeWavMono16k(converted)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(pcm).To(HaveLen(16000), "round-trips the sample count")
|
||||
|
||||
cleanup()
|
||||
Expect(converted).ToNot(BeAnExistingFile(), "cleanup removes the temp dir")
|
||||
})
|
||||
|
||||
It("errors on a non-existent input rather than passing the path through", func() {
|
||||
_, _, err := convertToWavMono16k(filepath.Join(GinkgoT().TempDir(), "missing.mp3"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("streams deltas and a closing FinalResult from a cache-aware model", func() {
|
||||
// Streaming needs a cache-aware streaming model (e.g.
|
||||
|
||||
@@ -58,6 +58,13 @@ func main() {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
// The batched-JSON entry point exists only in newer libparakeet.so (ABI >= 2).
|
||||
// Probe with Dlsym and register only if present, so the backend still loads
|
||||
// against an older library (it falls back to per-request transcription).
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[parakeet-cpp] ABI=%d\n", CppAbiVersion())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# qwen3-tts.cpp version
|
||||
QWEN3TTS_REPO?=https://github.com/predict-woo/qwen3-tts.cpp
|
||||
QWEN3TTS_CPP_VERSION?=7a762e2ad4bacc6fdda81d81bf10a09ffb546f29
|
||||
QWEN3TTS_CPP_VERSION?=136e5d36c17083da0321fd96512dc7b263f94a44
|
||||
SO_TARGET?=libgoqwen3ttscpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -21,6 +22,43 @@ type Qwen3TtsCpp struct {
|
||||
threads int
|
||||
}
|
||||
|
||||
// languageNameAliases maps common full language names to the canonical
|
||||
// two-letter code understood by the C++ language_to_id table.
|
||||
var languageNameAliases = map[string]string{
|
||||
"english": "en",
|
||||
"russian": "ru",
|
||||
"chinese": "zh",
|
||||
"japanese": "ja",
|
||||
"korean": "ko",
|
||||
"german": "de",
|
||||
"french": "fr",
|
||||
"spanish": "es",
|
||||
"italian": "it",
|
||||
"portuguese": "pt",
|
||||
}
|
||||
|
||||
// normalizeLanguage coerces a caller-supplied language into the canonical code
|
||||
// the model expects. It lowercases, trims, strips any region/locale suffix
|
||||
// (en-US, en_US, ja.JP -> en/ja), and resolves common full names (english -> en).
|
||||
// An empty input stays empty so the C++ side applies its English default; an
|
||||
// unrecognized value is returned normalized so C++ can log it and default.
|
||||
func normalizeLanguage(lang string) string {
|
||||
lang = strings.ToLower(strings.TrimSpace(lang))
|
||||
if lang == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strip region/locale suffix: keep the segment before the first separator.
|
||||
if i := strings.IndexAny(lang, "-_."); i >= 0 {
|
||||
lang = lang[:i]
|
||||
}
|
||||
|
||||
if code, ok := languageNameAliases[lang]; ok {
|
||||
return code
|
||||
}
|
||||
return lang
|
||||
}
|
||||
|
||||
func (q *Qwen3TtsCpp) Load(opts *pb.ModelOptions) error {
|
||||
// ModelFile is the model directory path (containing GGUF files)
|
||||
modelDir := opts.ModelFile
|
||||
@@ -54,7 +92,7 @@ func (q *Qwen3TtsCpp) TTS(req *pb.TTSRequest) error {
|
||||
dst := req.Dst
|
||||
language := ""
|
||||
if req.Language != nil {
|
||||
language = *req.Language
|
||||
language = normalizeLanguage(*req.Language)
|
||||
}
|
||||
|
||||
// Synthesis parameters with sensible defaults
|
||||
|
||||
53
backend/go/qwen3-tts-cpp/language_test.go
Normal file
53
backend/go/qwen3-tts-cpp/language_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestLanguageNormalization(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "qwen3-tts-cpp language normalization")
|
||||
}
|
||||
|
||||
var _ = Describe("normalizeLanguage", func() {
|
||||
DescribeTable("maps caller input to the canonical model language code",
|
||||
func(input, expected string) {
|
||||
Expect(normalizeLanguage(input)).To(Equal(expected))
|
||||
},
|
||||
// Canonical codes pass through unchanged
|
||||
Entry("canonical en", "en", "en"),
|
||||
Entry("canonical zh", "zh", "zh"),
|
||||
Entry("canonical pt", "pt", "pt"),
|
||||
|
||||
// Case-insensitive
|
||||
Entry("uppercase", "EN", "en"),
|
||||
Entry("mixed case", "Ja", "ja"),
|
||||
|
||||
// Surrounding whitespace
|
||||
Entry("trims whitespace", " en ", "en"),
|
||||
|
||||
// Region/locale stripping
|
||||
Entry("BCP-47 region", "en-US", "en"),
|
||||
Entry("underscore region", "en_US", "en"),
|
||||
Entry("dotted locale", "ja.JP", "ja"),
|
||||
Entry("region + case", "ZH-CN", "zh"),
|
||||
|
||||
// Full-name aliases
|
||||
Entry("english name", "english", "en"),
|
||||
Entry("chinese name cased", "Chinese", "zh"),
|
||||
Entry("japanese name", "japanese", "ja"),
|
||||
Entry("russian name", "russian", "ru"),
|
||||
Entry("portuguese name", "portuguese", "pt"),
|
||||
|
||||
// Empty stays empty (C++ applies the English default)
|
||||
Entry("empty", "", ""),
|
||||
Entry("whitespace only", " ", ""),
|
||||
|
||||
// Unknown values pass through normalized so C++ can log + default
|
||||
Entry("unknown code", "klingon", "klingon"),
|
||||
Entry("unknown with region", "xx-YY", "xx"),
|
||||
)
|
||||
})
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=0e4ee04488159b81d95a9ffcd983a077fd5dcb77
|
||||
STABLEDIFFUSION_GGML_VERSION?=1f9ee88e09c258053fa59d5e05e23dfb10fa0b13
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=f24588a272ae8e23280d9c220536437164e6ed28
|
||||
WHISPER_CPP_VERSION?=99613cb720b65036237d44b52f753b51f75c2797
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -122,6 +122,33 @@
|
||||
nvidia-cuda-12: "cuda12-whisper"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-whisper"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-whisper"
|
||||
- &crispasr
|
||||
name: "crispasr"
|
||||
alias: "crispasr"
|
||||
license: mit
|
||||
icon: https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg
|
||||
description: |
|
||||
CrispASR unified speech engine (whisper.cpp fork on ggml) supporting many ASR architectures (Parakeet, Canary, Voxtral, Qwen3-ASR, Granite, Wav2Vec2, Moonshine, OmniASR, FireRedASR, and more).
|
||||
urls:
|
||||
- https://github.com/CrispStrobe/CrispASR
|
||||
tags:
|
||||
- audio-transcription
|
||||
- CPU
|
||||
- GPU
|
||||
- CUDA
|
||||
- HIP
|
||||
capabilities:
|
||||
default: "cpu-crispasr"
|
||||
nvidia: "cuda12-crispasr"
|
||||
intel: "intel-sycl-f16-crispasr"
|
||||
metal: "metal-crispasr"
|
||||
amd: "rocm-crispasr"
|
||||
vulkan: "vulkan-crispasr"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-crispasr"
|
||||
nvidia-cuda-13: "cuda13-crispasr"
|
||||
nvidia-cuda-12: "cuda12-crispasr"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-crispasr"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-crispasr"
|
||||
- ¶keetcpp
|
||||
name: "parakeet-cpp"
|
||||
alias: "parakeet-cpp"
|
||||
@@ -1957,6 +1984,131 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-whisper
|
||||
## crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "crispasr-development"
|
||||
capabilities:
|
||||
default: "cpu-crispasr-development"
|
||||
nvidia: "cuda12-crispasr-development"
|
||||
intel: "intel-sycl-f16-crispasr-development"
|
||||
metal: "metal-crispasr-development"
|
||||
amd: "rocm-crispasr-development"
|
||||
vulkan: "vulkan-crispasr-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-crispasr-development"
|
||||
nvidia-cuda-13: "cuda13-crispasr-development"
|
||||
nvidia-cuda-12: "cuda12-crispasr-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-crispasr-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-crispasr-development"
|
||||
- !!merge <<: *crispasr
|
||||
name: "nvidia-l4t-arm64-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "nvidia-l4t-arm64-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-nvidia-l4t-arm64-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-nvidia-l4t-arm64-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cpu-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cpu-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda12-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "rocm-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f32-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f16-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f16-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "vulkan-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "vulkan-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda12-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "rocm-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f32-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f16-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f16-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-crispasr
|
||||
## parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "parakeet-cpp-development"
|
||||
|
||||
@@ -37,6 +37,20 @@ def is_int(s):
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def coerce_param_value(value):
|
||||
"""Coerce a TTSRequest.params value (string on the wire) to the type the
|
||||
Chatterbox generate() kwargs expect (float/int/bool), matching how static
|
||||
YAML options are coerced at load time. Non-string values pass through."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
if is_float(value):
|
||||
return float(value)
|
||||
if is_int(value):
|
||||
return int(value)
|
||||
if value.lower() in ["true", "false"]:
|
||||
return value.lower() == "true"
|
||||
return value
|
||||
|
||||
def split_text_at_word_boundary(text, max_length=250):
|
||||
"""
|
||||
Split text at word boundaries without truncating words.
|
||||
@@ -191,6 +205,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# add options to kwargs
|
||||
kwargs.update(self.options)
|
||||
|
||||
# Merge per-request params (TTSRequest.params), overriding the static
|
||||
# YAML options. This exposes Chatterbox generation knobs (e.g.
|
||||
# exaggeration, cfg_weight, temperature) per request. Values arrive as
|
||||
# strings on the wire and are coerced to float/int/bool.
|
||||
if hasattr(request, "params") and request.params:
|
||||
for key, value in request.params.items():
|
||||
kwargs[key] = coerce_param_value(value)
|
||||
|
||||
# Check if text exceeds 250 characters
|
||||
# (chatterbox does not support long text)
|
||||
# https://github.com/resemble-ai/chatterbox/issues/60
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch
|
||||
texterrors==1.1.6
|
||||
nemo_toolkit[asr]
|
||||
|
||||
@@ -47,6 +47,26 @@ def is_int(s):
|
||||
return False
|
||||
|
||||
|
||||
def coerce_param_value(value):
|
||||
"""Coerce a string param value (from the TTSRequest.params map, which is
|
||||
string-typed on the wire) into the most specific Python type the model
|
||||
generation kwargs expect: bool, int, float, else the original string."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
lowered = value.strip().lower()
|
||||
if lowered in ("true", "false"):
|
||||
return lowered == "true"
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
return value
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
@@ -322,6 +342,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def _effective_instruct(self, request):
|
||||
"""Resolve the instruction/style string for this request, preferring the
|
||||
per-request TTSRequest.instructions value and falling back to the static
|
||||
YAML `instruct` option. Empty string means "no instruction"."""
|
||||
req_instruct = (
|
||||
request.instructions
|
||||
if hasattr(request, "instructions") and request.instructions
|
||||
else ""
|
||||
)
|
||||
if req_instruct:
|
||||
return req_instruct
|
||||
return self.options.get("instruct", "") or ""
|
||||
|
||||
def _detect_mode(self, request):
|
||||
"""Detect which mode to use based on request parameters."""
|
||||
# Priority: VoiceClone > VoiceDesign > CustomVoice
|
||||
@@ -338,8 +371,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if self.audio_path or self.voices:
|
||||
return "VoiceClone"
|
||||
|
||||
# VoiceDesign: instruct option is provided
|
||||
if "instruct" in self.options and self.options["instruct"]:
|
||||
# VoiceDesign: instruct provided per-request or via YAML option
|
||||
if self._effective_instruct(request):
|
||||
return "VoiceDesign"
|
||||
|
||||
# Default to CustomVoice
|
||||
@@ -690,10 +723,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if do_sample is not None:
|
||||
generation_kwargs["do_sample"] = do_sample
|
||||
|
||||
instruct = self.options.get("instruct", "")
|
||||
# Prefer the per-request instruction (TTSRequest.instructions) over the
|
||||
# static YAML `instruct` option. This lets clients set a different style
|
||||
# (CustomVoice emotion) or designed voice (VoiceDesign) per request.
|
||||
instruct = self._effective_instruct(request)
|
||||
if instruct is not None and instruct != "":
|
||||
generation_kwargs["instruct"] = instruct
|
||||
|
||||
# Merge any per-request backend-specific params (TTSRequest.params).
|
||||
# Values arrive as strings on the wire; coerce to int/float/bool so the
|
||||
# model receives the types it expects. These override YAML-derived kwargs.
|
||||
if hasattr(request, "params") and request.params:
|
||||
for key, value in request.params.items():
|
||||
generation_kwargs[key] = coerce_param_value(value)
|
||||
|
||||
# Generate audio based on mode
|
||||
if mode == "VoiceClone":
|
||||
# VoiceClone mode
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
torch==2.7.1
|
||||
torch==2.7.1+xpu
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
accelerate
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
torch==2.7.1
|
||||
torch==2.7.1+xpu
|
||||
accelerate
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch==2.9.0
|
||||
torch==2.7.1+xpu
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers>=5.9.0
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm7.0
|
||||
torch==2.10.0+rocm7.0
|
||||
torch==2.7.1+xpu
|
||||
accelerate
|
||||
transformers>=5.9.0
|
||||
llvmlite==0.43.0
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
torch==2.7.1
|
||||
torch==2.7.1+xpu
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
accelerate
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.81.0
|
||||
protobuf==7.35.0
|
||||
certifi
|
||||
setuptools
|
||||
|
||||
@@ -3,5 +3,5 @@
|
||||
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
|
||||
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
|
||||
# so uv consults this index alongside PyPI.
|
||||
--extra-index-url https://wheels.vllm.ai/0.22.0/cu130
|
||||
vllm==0.22.0
|
||||
--extra-index-url https://wheels.vllm.ai/0.22.1/cu130
|
||||
vllm==0.22.1
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.81.0
|
||||
protobuf
|
||||
certifi
|
||||
setuptools
|
||||
|
||||
@@ -16,7 +16,9 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
@@ -100,7 +102,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)
|
||||
|
||||
// Connect to NATS
|
||||
natsClient, err := messaging.New(cfg.Distributed.NatsURL)
|
||||
natsAuth := cfg.Distributed.NatsAuthConfig()
|
||||
if natsAuth.RequireAuth && (natsAuth.ServiceUserJWT == "" || natsAuth.ServiceUserSeed == "") {
|
||||
return nil, fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED")
|
||||
}
|
||||
natsOpts := cfg.Distributed.NatsMessagingOptions("", "")
|
||||
natsClient, err := messaging.New(cfg.Distributed.NatsURL, natsOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connecting to NATS: %w", err)
|
||||
}
|
||||
@@ -240,6 +247,84 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
|
||||
)
|
||||
|
||||
// Prefix-cache-aware routing. Enabled by default; an operator can opt out
|
||||
// with --distributed-prefix-cache=false, which leaves prefixProvider and
|
||||
// pressure nil so the SmartRouter and reconciler behave exactly as the
|
||||
// round-robin floor (true no-op). When enabled we build the local index,
|
||||
// wrap it in a NATS-backed Sync (publishes our observations, applies peers'
|
||||
// via the subscriptions below), install the extraction hook used by
|
||||
// core/backend/llm.go, and run a background eviction ticker on the app ctx.
|
||||
var prefixProvider prefixcache.Provider
|
||||
var pressure *prefixcache.Pressure
|
||||
var prefixCfg prefixcache.Config
|
||||
if !cfg.Distributed.PrefixCacheDisabled {
|
||||
prefixCfg = prefixcache.DefaultConfig()
|
||||
if cfg.Distributed.PrefixCacheTTL > 0 {
|
||||
prefixCfg.TTL = cfg.Distributed.PrefixCacheTTL
|
||||
}
|
||||
if err := prefixCfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid prefix-cache configuration: %w", err)
|
||||
}
|
||||
idx := prefixcache.NewIndex(prefixCfg)
|
||||
prefixSync := prefixcache.NewSync(idx, natsClient)
|
||||
pressure = prefixcache.NewPressure(prefixCfg.PressureWindow)
|
||||
prefixProvider = prefixSync
|
||||
|
||||
// Invalidate the prefix-cache index whenever a replica row is removed.
|
||||
// SetReplicaRemovedHook fires from the single chokepoint all removal paths
|
||||
// funnel through (RemoveNodeModel / RemoveAllNodeModelReplicas), so this
|
||||
// one hook covers every path: reconciler scale-down, probe reaper,
|
||||
// health-monitor reap, RemoteUnloaderAdapter, and the router. Registering
|
||||
// it only inside this enabled block keeps the disabled path a true no-op
|
||||
// (the registry stays hook-less).
|
||||
registry.SetReplicaRemovedHook(func(model, node string, replica int) {
|
||||
if replica < 0 {
|
||||
prefixSync.InvalidateNode(model, node)
|
||||
} else {
|
||||
prefixSync.Invalidate(model, prefixcache.ReplicaKey{NodeID: node, Replica: replica})
|
||||
}
|
||||
})
|
||||
|
||||
distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 {
|
||||
return prefixcache.ExtractChain(model, prompt, prefixCfg)
|
||||
}
|
||||
|
||||
// Apply peers' observations/invalidations to the same Sync. ApplyObserve
|
||||
// and ApplyInvalidate update only the local index and do not re-publish,
|
||||
// so there is no broadcast loop.
|
||||
if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheObserve, func(ev messaging.PrefixCacheObserveEvent) {
|
||||
prefixSync.ApplyObserve(ev, time.Now())
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheObserve, err)
|
||||
}
|
||||
if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheInvalidate, func(ev messaging.PrefixCacheInvalidateEvent) {
|
||||
prefixSync.ApplyInvalidate(ev)
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheInvalidate, err)
|
||||
}
|
||||
|
||||
// Background eviction: sweep idle entries on the app context. Stopped
|
||||
// when the app context is cancelled (mirrors the reconciler loop which
|
||||
// also runs on options.Context). TTL/2 keeps stale entries from
|
||||
// outliving their idle window by more than half a TTL.
|
||||
evictInterval := prefixCfg.TTL / 2
|
||||
go func() {
|
||||
ticker := time.NewTicker(evictInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-cfg.Context.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
prefixSync.Evict(time.Now())
|
||||
}
|
||||
}
|
||||
}()
|
||||
xlog.Info("Prefix-cache-aware routing enabled", "ttl", prefixCfg.TTL, "evictInterval", evictInterval)
|
||||
} else {
|
||||
xlog.Info("Prefix-cache-aware routing disabled: using round-robin routing")
|
||||
}
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
var conflictResolver nodes.ConcurrencyConflictResolver
|
||||
if configLoader != nil {
|
||||
@@ -252,6 +337,9 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
AuthToken: routerAuthToken,
|
||||
DB: authDB,
|
||||
ConflictResolver: conflictResolver,
|
||||
PrefixProvider: prefixProvider,
|
||||
PrefixConfig: prefixCfg,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
// Create ReplicaReconciler for auto-scaling model replicas. Adapter +
|
||||
@@ -268,6 +356,8 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
Interval: 30 * time.Second,
|
||||
ScaleDownDelay: 5 * time.Minute,
|
||||
ProbeStaleAfter: 2 * time.Minute,
|
||||
Pressure: pressure,
|
||||
PressureThreshold: prefixCfg.PressureScaleThreshold,
|
||||
})
|
||||
|
||||
// Create ModelRouterAdapter to wire into ModelLoader
|
||||
|
||||
@@ -123,14 +123,14 @@ var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
|
||||
})
|
||||
|
||||
It("ModelTTS forwards the request context to the SmartRouter", func() {
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", loader, appCfg, modelCfg)
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTTSStream forwards the request context to the SmartRouter", func() {
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
@@ -94,6 +95,22 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
}
|
||||
|
||||
// Make the rendered prompt's prefix chain available to the distributed router
|
||||
// for prefix-cache-aware node selection. No-op in single-process mode. The
|
||||
// model id MUST match the id ModelOptions feeds to model.WithModelID, so both
|
||||
// use the shared config.ModelConfig.ModelID() helper (Name with a fallback to
|
||||
// Model) or the chain salt and the tracking key would diverge.
|
||||
//
|
||||
// s is empty for UseTokenizerTemplate models (the backend tokenizes the
|
||||
// structured messages itself), so fall back to a prefix-stable serialization
|
||||
// of the messages - otherwise prefix routing would silently degrade to
|
||||
// round-robin for the bulk of modern chat models.
|
||||
chainSource := s
|
||||
if chainSource == "" {
|
||||
chainSource = messagesPrefixSource(messages)
|
||||
}
|
||||
ctx = distributedhdr.MaybeWithPrefixChain(ctx, c.ModelID(), chainSource)
|
||||
|
||||
opts := ModelOptions(*c, o, model.WithContext(ctx))
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
|
||||
@@ -34,16 +34,11 @@ func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, back
|
||||
}
|
||||
|
||||
func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
|
||||
name := c.Name
|
||||
if name == "" {
|
||||
name = c.Model
|
||||
}
|
||||
|
||||
defOpts := []model.Option{
|
||||
model.WithBackendString(c.Backend),
|
||||
model.WithModel(c.Model),
|
||||
model.WithContext(so.Context),
|
||||
model.WithModelID(name),
|
||||
model.WithModelID(c.ModelID()),
|
||||
}
|
||||
|
||||
threads := 1
|
||||
@@ -244,13 +239,13 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
|
||||
|
||||
if c.Backend == "cloud-proxy" {
|
||||
opts.Proxy = &pb.ProxyOptions{
|
||||
UpstreamUrl: c.Proxy.UpstreamURL,
|
||||
Mode: c.Proxy.Mode,
|
||||
Provider: c.Proxy.Provider,
|
||||
ApiKeyEnv: c.Proxy.APIKeyEnv,
|
||||
ApiKeyFile: c.Proxy.APIKeyFile,
|
||||
UpstreamModel: c.Proxy.UpstreamModel,
|
||||
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
|
||||
UpstreamUrl: c.Proxy.UpstreamURL,
|
||||
Mode: c.Proxy.Mode,
|
||||
Provider: c.Proxy.Provider,
|
||||
ApiKeyEnv: c.Proxy.APIKeyEnv,
|
||||
ApiKeyFile: c.Proxy.APIKeyFile,
|
||||
UpstreamModel: c.Proxy.UpstreamModel,
|
||||
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -328,6 +323,12 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
|
||||
metadata["enable_thinking"] = "true"
|
||||
}
|
||||
}
|
||||
// Forward the effective reasoning effort so the backend can pass it to the
|
||||
// jinja chat template (chat_template_kwargs.reasoning_effort) — the lever
|
||||
// models like gpt-oss / LFM2.5 actually read, distinct from enable_thinking.
|
||||
if c.ReasoningEffort != "" {
|
||||
metadata["reasoning_effort"] = c.ReasoningEffort
|
||||
}
|
||||
pbOpts.Metadata = metadata
|
||||
|
||||
// Logprobs and TopLogprobs are set by the caller if provided
|
||||
|
||||
@@ -75,3 +75,25 @@ var _ = Describe("gRPCPredictOpts enable_thinking metadata", func() {
|
||||
Expect(opts.Metadata).ToNot(HaveKey("enable_thinking"))
|
||||
})
|
||||
})
|
||||
|
||||
// Guards forwarding the effective reasoning_effort into PredictOptions.Metadata,
|
||||
// where the backend passes it to the jinja chat template (chat_template_kwargs)
|
||||
// so models like gpt-oss / LFM2.5 honor it.
|
||||
var _ = Describe("gRPCPredictOpts reasoning_effort metadata", func() {
|
||||
withEffort := func(effort string) config.ModelConfig {
|
||||
cfg := config.ModelConfig{}
|
||||
cfg.SetDefaults()
|
||||
cfg.ReasoningEffort = effort
|
||||
return cfg
|
||||
}
|
||||
|
||||
It("forwards reasoning_effort when set", func() {
|
||||
opts := gRPCPredictOpts(withEffort("none"), "/tmp/models")
|
||||
Expect(opts.Metadata).To(HaveKeyWithValue("reasoning_effort", "none"))
|
||||
})
|
||||
|
||||
It("omits reasoning_effort when empty", func() {
|
||||
opts := gRPCPredictOpts(withEffort(""), "/tmp/models")
|
||||
Expect(opts.Metadata).ToNot(HaveKey("reasoning_effort"))
|
||||
})
|
||||
})
|
||||
|
||||
36
core/backend/prefix_source.go
Normal file
36
core/backend/prefix_source.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
)
|
||||
|
||||
// messagesPrefixSource builds a deterministic, prefix-stable serialization of a
|
||||
// chat conversation for prefix-cache-aware routing. It is the fallback used when
|
||||
// the frontend did not render a prompt string: models with
|
||||
// config.TemplateConfig.UseTokenizerTemplate tokenize the structured messages
|
||||
// backend-side, so the frontend's rendered prompt is empty and a chain built
|
||||
// from it would always be empty - silently degrading prefix routing to
|
||||
// round-robin for the bulk of modern chat models.
|
||||
//
|
||||
// Messages are emitted head-first in turn order (role line + content line per
|
||||
// message), so two conversations sharing a leading system prompt and early turns
|
||||
// share a leading byte prefix. That is exactly what ExtractChain hashes into a
|
||||
// shared chain prefix, landing both requests on the same cache-warm replica.
|
||||
func messagesPrefixSource(messages schema.Messages) string {
|
||||
var b strings.Builder
|
||||
for _, m := range messages {
|
||||
b.WriteString(m.Role)
|
||||
b.WriteByte('\n')
|
||||
content := m.StringContent
|
||||
if content == "" {
|
||||
if s, ok := m.Content.(string); ok {
|
||||
content = s
|
||||
}
|
||||
}
|
||||
b.WriteString(content)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
53
core/backend/prefix_source_internal_test.go
Normal file
53
core/backend/prefix_source_internal_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("messagesPrefixSource", func() {
|
||||
mk := func(role, content string) schema.Message {
|
||||
return schema.Message{Role: role, StringContent: content}
|
||||
}
|
||||
|
||||
It("serializes messages head-first in turn order", func() {
|
||||
got := messagesPrefixSource(schema.Messages{
|
||||
mk("system", "You are helpful."),
|
||||
mk("user", "Hi"),
|
||||
})
|
||||
Expect(got).To(Equal("system\nYou are helpful.\nuser\nHi\n"))
|
||||
})
|
||||
|
||||
It("is deterministic across calls for the same conversation", func() {
|
||||
conv := schema.Messages{mk("system", "S"), mk("user", "U")}
|
||||
Expect(messagesPrefixSource(conv)).To(Equal(messagesPrefixSource(conv)))
|
||||
})
|
||||
|
||||
It("shares a leading byte prefix when the system prompt is shared", func() {
|
||||
shared := "system\nShared system prompt.\nuser\n"
|
||||
a := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question A")})
|
||||
b := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question B")})
|
||||
Expect(strings.HasPrefix(a, shared)).To(BeTrue())
|
||||
Expect(strings.HasPrefix(b, shared)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does NOT share a prefix when the system prompt differs", func() {
|
||||
a := messagesPrefixSource(schema.Messages{mk("system", "Prompt A"), mk("user", "Q")})
|
||||
b := messagesPrefixSource(schema.Messages{mk("system", "Prompt B"), mk("user", "Q")})
|
||||
Expect(strings.HasPrefix(a, "system\nPrompt A")).To(BeTrue())
|
||||
Expect(strings.HasPrefix(b, "system\nPrompt B")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("returns empty for no messages", func() {
|
||||
Expect(messagesPrefixSource(nil)).To(Equal(""))
|
||||
})
|
||||
|
||||
It("falls back to Content when StringContent is empty", func() {
|
||||
got := messagesPrefixSource(schema.Messages{{Role: "user", Content: "plain"}})
|
||||
Expect(got).To(Equal("user\nplain\n"))
|
||||
})
|
||||
})
|
||||
@@ -20,11 +20,32 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
// newTTSRequest assembles the gRPC TTSRequest from the per-request inputs. The
|
||||
// optional instructions string is only attached when non-empty so backends can
|
||||
// distinguish "no per-request instruction" (fall back to YAML) from an explicit
|
||||
// empty one. params is forwarded as-is (nil when unset).
|
||||
func newTTSRequest(text, modelPath, voice, dst, language, instructions string, params map[string]string) *proto.TTSRequest {
|
||||
req := &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Dst: dst,
|
||||
Language: &language,
|
||||
Params: params,
|
||||
}
|
||||
if instructions != "" {
|
||||
req.Instructions = &instructions
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func ModelTTS(
|
||||
ctx context.Context,
|
||||
text,
|
||||
voice,
|
||||
language string,
|
||||
language,
|
||||
instructions string,
|
||||
params map[string]string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
@@ -74,13 +95,9 @@ func ModelTTS(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := ttsModel.TTS(ctx, &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Dst: filePath,
|
||||
Language: &language,
|
||||
})
|
||||
ttsRequest := newTTSRequest(text, modelPath, voice, filePath, language, instructions, params)
|
||||
|
||||
res, err := ttsModel.TTS(ctx, ttsRequest)
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
@@ -128,7 +145,9 @@ func ModelTTSStream(
|
||||
ctx context.Context,
|
||||
text,
|
||||
voice,
|
||||
language string,
|
||||
language,
|
||||
instructions string,
|
||||
params map[string]string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
@@ -177,12 +196,10 @@ func ModelTTSStream(
|
||||
var totalPCMBytes int
|
||||
snippetCapped := false
|
||||
|
||||
err = ttsModel.TTSStream(ctx, &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Language: &language,
|
||||
}, func(reply *proto.Reply) {
|
||||
// Streaming TTS writes to the HTTP response, not a file, so dst is empty.
|
||||
ttsRequest := newTTSRequest(text, modelPath, voice, "", language, instructions, params)
|
||||
|
||||
err = ttsModel.TTSStream(ctx, ttsRequest, func(reply *proto.Reply) {
|
||||
// First message contains sample rate info
|
||||
if !headerSent && len(reply.Message) > 0 {
|
||||
var info map[string]any
|
||||
|
||||
42
core/backend/tts_test.go
Normal file
42
core/backend/tts_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package backend
|
||||
|
||||
// Specs for the TTSRequest assembly that carries the per-request
|
||||
// instructions/params from the OpenAI `instructions` field (and the LocalAI
|
||||
// `params` extension) through to the gRPC boundary. Before this plumbing the
|
||||
// instruction value was dropped before reaching the backend; these specs pin
|
||||
// that it now survives, and that the empty case stays backward compatible.
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("newTTSRequest", func() {
|
||||
It("attaches the instructions when a per-request value is set", func() {
|
||||
req := newTTSRequest("hi", "/m", "alloy", "/out.wav", "en", "cheerful narrator", nil)
|
||||
Expect(req.Instructions).ToNot(BeNil())
|
||||
Expect(req.GetInstructions()).To(Equal("cheerful narrator"))
|
||||
Expect(req.GetText()).To(Equal("hi"))
|
||||
Expect(req.GetVoice()).To(Equal("alloy"))
|
||||
Expect(req.GetDst()).To(Equal("/out.wav"))
|
||||
Expect(req.GetLanguage()).To(Equal("en"))
|
||||
})
|
||||
|
||||
It("leaves instructions unset when empty so backends fall back to YAML", func() {
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
|
||||
Expect(req.Instructions).To(BeNil())
|
||||
Expect(req.GetInstructions()).To(Equal(""))
|
||||
})
|
||||
|
||||
It("forwards per-request params through to the backend", func() {
|
||||
params := map[string]string{"exaggeration": "0.7", "cfg_weight": "0.3"}
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", params)
|
||||
Expect(req.GetParams()).To(HaveKeyWithValue("exaggeration", "0.7"))
|
||||
Expect(req.GetParams()).To(HaveKeyWithValue("cfg_weight", "0.3"))
|
||||
})
|
||||
|
||||
It("leaves params nil when none are supplied", func() {
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
|
||||
Expect(req.GetParams()).To(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -52,10 +52,28 @@ type AgentWorkerCMD struct {
|
||||
Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
|
||||
Queue string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`
|
||||
|
||||
NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (defaults to nats_jwt from registration)" group:"distributed"`
|
||||
NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user seed override (defaults to nats_user_seed from registration)" group:"distributed"`
|
||||
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"Fallback NATS service JWT when registration does not mint agent JWT" group:"distributed"`
|
||||
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"Fallback NATS service seed paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
|
||||
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed to connect" group:"distributed"`
|
||||
// DistributedRequireAuth is the umbrella switch; for the agent worker (which
|
||||
// has no file-transfer server) it implies NATS auth is required.
|
||||
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch implying --nats-require-auth (agent workers have no file-transfer server)" group:"distributed"`
|
||||
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
|
||||
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
|
||||
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
|
||||
|
||||
// Timeouts
|
||||
MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
|
||||
}
|
||||
|
||||
// natsAuthRequired reports whether NATS JWT credentials must be present — the
|
||||
// granular flag or the umbrella (LOCALAI_DISTRIBUTED_REQUIRE_AUTH).
|
||||
func (cmd *AgentWorkerCMD) natsAuthRequired() bool {
|
||||
return cmd.NatsRequireAuth || cmd.DistributedRequireAuth
|
||||
}
|
||||
|
||||
func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
xlog.Info("Starting agent worker", "nats", sanitize.URL(cmd.NatsURL), "register_to", cmd.RegisterTo)
|
||||
|
||||
@@ -81,15 +99,30 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
registrationBody["token"] = cmd.RegistrationToken
|
||||
}
|
||||
|
||||
nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
|
||||
// Context cancelled on shutdown — used by registration waits, heartbeat, and
|
||||
// other background goroutines.
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
// Acquire credentials via (re)registration. When the bus requires auth and no
|
||||
// static fallback is configured, wait through admin approval until the
|
||||
// frontend mints credentials rather than starting unauthenticated.
|
||||
credMgr := workerregistry.NewNATSCredentialManager(
|
||||
func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
|
||||
return regClient.RegisterFull(ctx, registrationBody)
|
||||
},
|
||||
cmd.natsAuthRequired() && cmd.NatsJWT == "" && cmd.NatsServiceJWT == "",
|
||||
)
|
||||
res, err := credMgr.Acquire(shutdownCtx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("registration failed: %w", err)
|
||||
}
|
||||
nodeID := res.ID
|
||||
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
|
||||
|
||||
// Use provisioned API token if none was set
|
||||
if cmd.APIToken == "" {
|
||||
cmd.APIToken = apiToken
|
||||
cmd.APIToken = res.APIToken
|
||||
}
|
||||
|
||||
// Start heartbeat
|
||||
@@ -98,14 +131,40 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
|
||||
}
|
||||
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
|
||||
// Context cancelled on shutdown — used by heartbeat and other background goroutines
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })
|
||||
|
||||
// Connect to NATS
|
||||
natsClient, err := messaging.New(cmd.NatsURL)
|
||||
// Resolve NATS credentials with precedence: explicit env override, then
|
||||
// frontend-minted (auto-refreshed before expiry), then service fallback.
|
||||
// Each static source must supply JWT and seed together.
|
||||
natsTLS := messaging.TLSFiles{CA: cmd.NatsTLSCA, Cert: cmd.NatsTLSCert, Key: cmd.NatsTLSKey}
|
||||
var natsOpts []messaging.Option
|
||||
switch {
|
||||
case cmd.NatsJWT != "" || cmd.NatsUserSeed != "":
|
||||
if (cmd.NatsJWT == "") != (cmd.NatsUserSeed == "") {
|
||||
return fmt.Errorf("LOCALAI_NATS_JWT and LOCALAI_NATS_USER_SEED must be set together")
|
||||
}
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsJWT, cmd.NatsUserSeed))
|
||||
case credMgr.HasCredentials():
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWTProvider(credMgr.Provider()))
|
||||
go func() {
|
||||
if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
|
||||
xlog.Error("NATS credential refresh permanently failed; shutting down agent worker", "error", err)
|
||||
shutdownCancel()
|
||||
}
|
||||
}()
|
||||
case cmd.NatsServiceJWT != "" || cmd.NatsServiceSeed != "":
|
||||
if (cmd.NatsServiceJWT == "") != (cmd.NatsServiceSeed == "") {
|
||||
return fmt.Errorf("LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED must be set together")
|
||||
}
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsServiceJWT, cmd.NatsServiceSeed))
|
||||
case cmd.natsAuthRequired():
|
||||
return fmt.Errorf("NATS JWT+seed required: enable frontend minting or set LOCALAI_NATS_* env vars")
|
||||
}
|
||||
if natsTLS.Enabled() {
|
||||
natsOpts = append(natsOpts, messaging.WithTLS(natsTLS))
|
||||
}
|
||||
natsClient, err := messaging.New(cmd.NatsURL, natsOpts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connecting to NATS: %w", err)
|
||||
}
|
||||
@@ -183,17 +242,25 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
|
||||
xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)
|
||||
|
||||
// Wait for shutdown
|
||||
// Wait for an OS signal or an internal fatal condition (e.g. NATS
|
||||
// credentials became unrenewable), so the worker restarts and re-acquires
|
||||
// rather than lingering unable to serve.
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
var runErr error
|
||||
select {
|
||||
case <-sigCh:
|
||||
case <-shutdownCtx.Done():
|
||||
runErr = fmt.Errorf("agent worker shutting down: NATS credentials unavailable")
|
||||
xlog.Error("Internal shutdown requested", "error", runErr)
|
||||
}
|
||||
|
||||
xlog.Info("Shutting down agent worker")
|
||||
shutdownCancel() // stop heartbeat loop immediately
|
||||
dispatcher.Stop()
|
||||
mcpTools.CloseAllMCPSessions()
|
||||
regClient.GracefulDeregister(nodeID)
|
||||
return nil
|
||||
return runErr
|
||||
}
|
||||
|
||||
// handleMCPToolRequest handles a NATS request-reply for MCP tool execution.
|
||||
|
||||
@@ -145,19 +145,31 @@ type RunCMD struct {
|
||||
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
||||
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
RegistrationRequireAuth bool `env:"LOCALAI_REGISTRATION_REQUIRE_AUTH" default:"false" help:"Fail startup when distributed mode is enabled but LOCALAI_REGISTRATION_TOKEN is empty (node endpoints and worker file-transfer server would otherwise be unauthenticated)" group:"distributed"`
|
||||
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch: require BOTH NATS JWT credentials and a registration token when distributed mode is enabled (implies --nats-require-auth and --registration-require-auth)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
DistributedPrefixCache bool `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE" default:"true" help:"Enable prefix-cache-aware routing in distributed mode (default true). When false, routing falls back to round-robin." group:"distributed"`
|
||||
DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
NatsAccountSeed string `env:"LOCALAI_NATS_ACCOUNT_SEED" help:"NATS account signing seed (SU...) used to mint per-node worker JWTs at registration" group:"distributed"`
|
||||
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"NATS user JWT for the frontend (and agent workers) to publish control-plane messages" group:"distributed"`
|
||||
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"NATS user signing seed (SU...) paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
|
||||
NatsWorkerJWTTTL string `env:"LOCALAI_NATS_WORKER_JWT_TTL" help:"Lifetime of minted per-node NATS JWTs (e.g. 24h, default 24h)" group:"distributed"`
|
||||
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT credentials (service JWT + account seed) when distributed mode is enabled" group:"distributed"`
|
||||
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI); use with tls:// in --nats-url" group:"distributed"`
|
||||
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
|
||||
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
|
||||
Version bool
|
||||
|
||||
@@ -281,9 +293,53 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.RegistrationToken != "" {
|
||||
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
|
||||
}
|
||||
if r.RegistrationRequireAuth {
|
||||
opts = append(opts, config.EnableRegistrationRequireAuth)
|
||||
}
|
||||
if r.DistributedRequireAuth {
|
||||
opts = append(opts, config.EnableDistributedRequireAuth)
|
||||
}
|
||||
if r.NatsAccountSeed != "" {
|
||||
opts = append(opts, config.WithNatsAccountSeed(r.NatsAccountSeed))
|
||||
}
|
||||
if r.NatsServiceJWT != "" {
|
||||
opts = append(opts, config.WithNatsServiceJWT(r.NatsServiceJWT))
|
||||
}
|
||||
if r.NatsServiceSeed != "" {
|
||||
opts = append(opts, config.WithNatsServiceSeed(r.NatsServiceSeed))
|
||||
}
|
||||
if r.NatsWorkerJWTTTL != "" {
|
||||
d, err := time.ParseDuration(r.NatsWorkerJWTTTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_NATS_WORKER_JWT_TTL %q: %w", r.NatsWorkerJWTTTL, err)
|
||||
}
|
||||
opts = append(opts, config.WithNatsWorkerJWTTTL(d))
|
||||
}
|
||||
if r.NatsRequireAuth {
|
||||
opts = append(opts, config.EnableNatsRequireAuth)
|
||||
}
|
||||
if r.NatsTLSCA != "" {
|
||||
opts = append(opts, config.WithNatsTLSCA(r.NatsTLSCA))
|
||||
}
|
||||
if r.NatsTLSCert != "" {
|
||||
opts = append(opts, config.WithNatsTLSCert(r.NatsTLSCert))
|
||||
}
|
||||
if r.NatsTLSKey != "" {
|
||||
opts = append(opts, config.WithNatsTLSKey(r.NatsTLSKey))
|
||||
}
|
||||
if r.AutoApproveNodes {
|
||||
opts = append(opts, config.EnableAutoApproveNodes)
|
||||
}
|
||||
if !r.DistributedPrefixCache {
|
||||
opts = append(opts, config.DisablePrefixCache)
|
||||
}
|
||||
if r.DistributedPrefixCacheTTL != "" {
|
||||
d, err := time.ParseDuration(r.DistributedPrefixCacheTTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL %q: %w", r.DistributedPrefixCacheTTL, err)
|
||||
}
|
||||
opts = append(opts, config.WithPrefixCacheTTL(d))
|
||||
}
|
||||
if r.ExposeNodeHeader {
|
||||
opts = append(opts, config.WithExposeNodeHeader(true))
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
|
||||
options.Backend = t.Backend
|
||||
options.Model = t.Model
|
||||
|
||||
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, ml, opts, options)
|
||||
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, "", nil, ml, opts, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -14,4 +14,5 @@ type Worker struct {
|
||||
LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"`
|
||||
MLXDistributed MLXDistributed `cmd:"" name:"mlx-distributed" help:"Starts an MLX distributed worker in standalone mode (requires --hostfile and --rank)"`
|
||||
VLLMDistributed VLLMDistributed `cmd:"" name:"vllm" help:"Starts a vLLM data-parallel follower process. Multi-node DP for a single model: head runs the existing vllm backend with engine_args.data_parallel_size>1, followers run this command."`
|
||||
DS4Distributed DS4Distributed `cmd:"" name:"ds4-distributed" help:"Starts a ds4 distributed worker in standalone mode: owns a layer slice and dials the coordinator (pass ds4-worker args after --)"`
|
||||
}
|
||||
|
||||
108
core/cli/worker/worker_ds4.go
Normal file
108
core/cli/worker/worker_ds4.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
type DS4Distributed struct {
|
||||
WorkerFlags `embed:""`
|
||||
ExtraDS4Args string `name:"ds4-args" env:"LOCALAI_EXTRA_DS4_ARGS,EXTRA_DS4_ARGS" help:"Arguments passed to ds4-worker (e.g. '--role worker --model m.gguf --layers 20:output --coordinator HOST PORT')"`
|
||||
}
|
||||
|
||||
const (
|
||||
ds4WorkerBinaryName = "ds4-worker"
|
||||
ds4GalleryName = "ds4"
|
||||
)
|
||||
|
||||
// ds4WorkerArgs builds the argv for syscall.Exec when launching ds4-worker
|
||||
// directly: the binary path followed by the space-split extra args. An empty
|
||||
// extra string yields a bare invocation.
|
||||
func ds4WorkerArgs(binary, extra string) []string {
|
||||
args := []string{binary}
|
||||
args = append(args, strings.Fields(extra)...)
|
||||
return args
|
||||
}
|
||||
|
||||
func findDS4Backend(galleries string, systemState *system.SystemState, requireIntegrity bool) (string, error) {
|
||||
backends, err := gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed listing system backends", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
backend, ok := backends.Get(ds4GalleryName)
|
||||
if !ok {
|
||||
ml := model.NewModelLoader(systemState)
|
||||
var gals []config.Gallery
|
||||
if err := json.Unmarshal([]byte(galleries), &gals); err != nil {
|
||||
xlog.Error("failed loading galleries", "error", err)
|
||||
return "", err
|
||||
}
|
||||
if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, ds4GalleryName, nil, true, requireIntegrity); err != nil {
|
||||
xlog.Error("ds4 backend not found, failed to install it", "error", err)
|
||||
return "", err
|
||||
}
|
||||
backends, err = gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
backend, ok = backends.Get(ds4GalleryName)
|
||||
if !ok {
|
||||
return "", errors.New("ds4 backend not found after install")
|
||||
}
|
||||
}
|
||||
|
||||
backendPath := filepath.Dir(backend.RunFile)
|
||||
if backendPath == "" {
|
||||
return "", errors.New("ds4 backend not found, install it first")
|
||||
}
|
||||
return filepath.Join(backendPath, ds4WorkerBinaryName), nil
|
||||
}
|
||||
|
||||
func (r *DS4Distributed) Run(ctx *cliContext.Context) error {
|
||||
if r.ExtraDS4Args == "" && len(os.Args) < 4 {
|
||||
return fmt.Errorf("usage: local-ai worker ds4-distributed -- --role worker --model <gguf> --layers <START:END|START:output> --coordinator <host> <port>")
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(r.BackendsPath),
|
||||
system.WithBackendSystemPath(r.BackendsSystemPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
worker, err := findDS4Backend(r.BackendGalleries, systemState, r.RequireBackendIntegrity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ds4 bundles its own dynamic loader (lib/ld.so) for glibc compatibility,
|
||||
// like backend/cpp/ds4/run.sh does for grpc-server. Launch ds4-worker via
|
||||
// that loader when present; otherwise exec it directly. (This is a
|
||||
// deliberate divergence from worker_llamacpp.go, which has no bundled loader.)
|
||||
backendPath := filepath.Dir(worker)
|
||||
env := os.Environ()
|
||||
loader := filepath.Join(backendPath, "lib", "ld.so")
|
||||
if _, statErr := os.Stat(loader); statErr == nil {
|
||||
env = append(env, "LD_LIBRARY_PATH="+filepath.Join(backendPath, "lib")+":"+os.Getenv("LD_LIBRARY_PATH"))
|
||||
args := append([]string{loader}, ds4WorkerArgs(worker, r.ExtraDS4Args)...)
|
||||
return syscall.Exec(loader, args, env)
|
||||
}
|
||||
|
||||
return syscall.Exec(worker, ds4WorkerArgs(worker, r.ExtraDS4Args), env)
|
||||
}
|
||||
28
core/cli/worker/worker_ds4_test.go
Normal file
28
core/cli/worker/worker_ds4_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ds4 worker CLI", func() {
|
||||
It("uses the ds4 backend gallery name and worker binary name", func() {
|
||||
Expect(ds4GalleryName).To(Equal("ds4"))
|
||||
Expect(ds4WorkerBinaryName).To(Equal("ds4-worker"))
|
||||
})
|
||||
|
||||
It("assembles direct exec args as [binary, extra-split...]", func() {
|
||||
args := ds4WorkerArgs("/b/ds4-worker", "--role worker --model m.gguf --layers 20:output --coordinator 10.0.0.1 1234")
|
||||
Expect(args).To(Equal([]string{
|
||||
"/b/ds4-worker",
|
||||
"--role", "worker",
|
||||
"--model", "m.gguf",
|
||||
"--layers", "20:output",
|
||||
"--coordinator", "10.0.0.1", "1234",
|
||||
}))
|
||||
})
|
||||
|
||||
It("drops empty extra args to a bare binary invocation", func() {
|
||||
Expect(ds4WorkerArgs("/b/ds4-worker", "")).To(Equal([]string{"/b/ds4-worker"}))
|
||||
})
|
||||
})
|
||||
@@ -96,7 +96,7 @@ func (r *VLLMDistributed) Run(ctx *cliContext.Context) error {
|
||||
FrontendURL: r.RegisterTo,
|
||||
RegistrationToken: r.RegistrationToken,
|
||||
}
|
||||
nodeID, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
|
||||
nodeID, _, _, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
|
||||
if regErr != nil {
|
||||
return fmt.Errorf("registering with frontend: %w", regErr)
|
||||
}
|
||||
|
||||
@@ -58,65 +58,77 @@ func (c *RegistrationClient) setAuth(req *http.Request) {
|
||||
|
||||
// RegisterResponse is the JSON body returned by /api/node/register.
|
||||
type RegisterResponse struct {
|
||||
ID string `json:"id"`
|
||||
APIToken string `json:"api_token,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status,omitempty"` // "pending" until an admin approves the node
|
||||
APIToken string `json:"api_token,omitempty"`
|
||||
NatsJWT string `json:"nats_jwt,omitempty"`
|
||||
NatsUserSeed string `json:"nats_user_seed,omitempty"`
|
||||
}
|
||||
|
||||
// Register sends a single registration request and returns the node ID and
|
||||
// (optionally) an auto-provisioned API token.
|
||||
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) {
|
||||
// RegisterFull sends a single registration request and returns the full
|
||||
// response (node ID, approval status, and optional API token / NATS creds).
|
||||
// Re-registration is idempotent: the frontend preserves the node row and mints
|
||||
// a fresh NATS JWT each call, so this doubles as the credential-refresh call.
|
||||
func (c *RegistrationClient) RegisterFull(ctx context.Context, body map[string]any) (*RegisterResponse, error) {
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
url := c.baseURL() + "/api/node/register"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("creating request: %w", err)
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.setAuth(req)
|
||||
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("posting to %s: %w", url, err)
|
||||
return nil, fmt.Errorf("posting to %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
|
||||
return nil, fmt.Errorf("registration failed with status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result RegisterResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", "", fmt.Errorf("decoding response: %w", err)
|
||||
return nil, fmt.Errorf("decoding response: %w", err)
|
||||
}
|
||||
return result.ID, result.APIToken, nil
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Register sends a single registration request and returns the node ID and
|
||||
// optional credentials (API token for agent workers, NATS JWT when configured).
|
||||
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
|
||||
res, err := c.RegisterFull(ctx, body)
|
||||
if err != nil {
|
||||
return "", "", "", "", err
|
||||
}
|
||||
return res.ID, res.APIToken, res.NatsJWT, res.NatsUserSeed, nil
|
||||
}
|
||||
|
||||
// RegisterWithRetry retries registration with exponential backoff.
|
||||
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) {
|
||||
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
|
||||
backoff := 2 * time.Second
|
||||
maxBackoff := 30 * time.Second
|
||||
|
||||
var nodeID, apiToken string
|
||||
var err error
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
nodeID, apiToken, err = c.Register(ctx, body)
|
||||
nodeID, apiToken, natsJWT, natsSeed, err = c.Register(ctx, body)
|
||||
if err == nil {
|
||||
return nodeID, apiToken, nil
|
||||
return nodeID, apiToken, natsJWT, natsSeed, nil
|
||||
}
|
||||
if attempt == maxRetries {
|
||||
return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
|
||||
return "", "", "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", "", ctx.Err()
|
||||
return "", "", "", "", ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
backoff = min(backoff*2, maxBackoff)
|
||||
}
|
||||
return nodeID, apiToken, err
|
||||
return nodeID, apiToken, natsJWT, natsSeed, err
|
||||
}
|
||||
|
||||
// Heartbeat sends a single heartbeat POST with the given body.
|
||||
|
||||
200
core/cli/workerregistry/credentials.go
Normal file
200
core/cli/workerregistry/credentials.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package workerregistry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// statusPending mirrors nodes.StatusPending. It is duplicated rather than
|
||||
// imported so the lightweight registration client does not pull in the nodes
|
||||
// package (and its gorm/DB dependencies).
|
||||
const statusPending = "pending"
|
||||
|
||||
// defaultMaxAttempts bounds how many times Acquire registers (and how many
|
||||
// consecutive times RefreshLoop may fail) before giving up. It is high enough
|
||||
// to ride out a slow admin approval or a transient frontend outage, but finite
|
||||
// so an unauthorized/unapprovable worker exits and surfaces the problem (via a
|
||||
// non-zero exit and the resulting restart) rather than waiting forever.
|
||||
const defaultMaxAttempts = 100
|
||||
|
||||
// RegisterFunc performs one idempotent registration round-trip.
|
||||
type RegisterFunc func(ctx context.Context) (*RegisterResponse, error)
|
||||
|
||||
// NATSCredentialManager acquires NATS credentials at startup — waiting through
|
||||
// admin approval when required — and refreshes them before the minted JWT
|
||||
// expires, by re-registering (which mints a fresh JWT). The live NATS
|
||||
// connection adopts a refreshed JWT on its next reconnect via Provider. Safe
|
||||
// for concurrent use.
|
||||
//
|
||||
// It addresses two failure modes: a worker that needs credentials but registers
|
||||
// while still pending approval (it would otherwise give up and never connect),
|
||||
// and a long-running worker whose 24h JWT expires with no way to renew it.
|
||||
type NATSCredentialManager struct {
|
||||
register RegisterFunc
|
||||
requireCreds bool // block until credentials are present (frontend minting in use)
|
||||
|
||||
// Tunables; defaults set by NewNATSCredentialManager, overridable in tests.
|
||||
initialBackoff time.Duration
|
||||
maxBackoff time.Duration
|
||||
maxAttempts int // bound on Acquire attempts / consecutive refresh failures (<=0 = unlimited)
|
||||
refreshLead float64 // refresh once this fraction of the JWT lifetime has elapsed
|
||||
refreshRetry time.Duration
|
||||
expiryOf func(jwt string) (time.Time, bool)
|
||||
|
||||
mu sync.RWMutex
|
||||
jwt string
|
||||
seed string
|
||||
nodeID string
|
||||
}
|
||||
|
||||
// NewNATSCredentialManager builds a manager over register. When requireCreds is
|
||||
// true, Acquire blocks until the node is approved and credentials are minted.
|
||||
func NewNATSCredentialManager(register RegisterFunc, requireCreds bool) *NATSCredentialManager {
|
||||
return &NATSCredentialManager{
|
||||
register: register,
|
||||
requireCreds: requireCreds,
|
||||
initialBackoff: 2 * time.Second,
|
||||
maxBackoff: 30 * time.Second,
|
||||
maxAttempts: defaultMaxAttempts,
|
||||
refreshLead: 0.75,
|
||||
refreshRetry: 30 * time.Second,
|
||||
expiryOf: jwtExpiry,
|
||||
}
|
||||
}
|
||||
|
||||
// jwtExpiry decodes the expiry of a minted user JWT. ok is false when the token
|
||||
// is empty/undecodable or carries no expiry (e.g. a non-expiring service JWT).
|
||||
func jwtExpiry(token string) (time.Time, bool) {
|
||||
if token == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
uc, err := natsauth.DecodeUserClaims(token)
|
||||
if err != nil || uc.Expires == 0 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return time.Unix(uc.Expires, 0), true
|
||||
}
|
||||
|
||||
func (m *NATSCredentialManager) store(res *RegisterResponse) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.nodeID = res.ID
|
||||
if res.NatsJWT != "" && res.NatsUserSeed != "" {
|
||||
m.jwt, m.seed = res.NatsJWT, res.NatsUserSeed
|
||||
}
|
||||
}
|
||||
|
||||
// Current returns the latest NATS credentials (both empty until acquired).
|
||||
func (m *NATSCredentialManager) Current() (jwt, seed string) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.jwt, m.seed
|
||||
}
|
||||
|
||||
// NodeID returns the node ID from the most recent registration.
|
||||
func (m *NATSCredentialManager) NodeID() string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.nodeID
|
||||
}
|
||||
|
||||
// Provider returns a callback compatible with messaging.WithUserJWTProvider,
|
||||
// supplying the current credentials on each (re)connect.
|
||||
func (m *NATSCredentialManager) Provider() func() (string, string) {
|
||||
return m.Current
|
||||
}
|
||||
|
||||
// HasCredentials reports whether complete NATS credentials have been obtained.
|
||||
func (m *NATSCredentialManager) HasCredentials() bool {
|
||||
jwt, seed := m.Current()
|
||||
return jwt != "" && seed != ""
|
||||
}
|
||||
|
||||
// Acquire registers and, when requireCreds is set, keeps re-registering with
|
||||
// exponential backoff until the node is approved (status != pending) and
|
||||
// credentials are minted. Without requireCreds it returns the first successful
|
||||
// response (the historical one-shot behavior, preserved for anonymous NATS).
|
||||
func (m *NATSCredentialManager) Acquire(ctx context.Context) (*RegisterResponse, error) {
|
||||
backoff := m.initialBackoff
|
||||
var lastReason error
|
||||
for attempt := 1; m.maxAttempts <= 0 || attempt <= m.maxAttempts; attempt++ {
|
||||
res, err := m.register(ctx)
|
||||
switch {
|
||||
case err != nil:
|
||||
lastReason = err
|
||||
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
|
||||
case !m.requireCreds:
|
||||
m.store(res)
|
||||
return res, nil
|
||||
case res.Status == statusPending:
|
||||
lastReason = fmt.Errorf("node %s still pending admin approval", res.ID)
|
||||
xlog.Info("Node pending admin approval; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
|
||||
case res.NatsJWT == "" || res.NatsUserSeed == "":
|
||||
lastReason = fmt.Errorf("node %s approved but NATS credentials not minted", res.ID)
|
||||
xlog.Info("Node approved but NATS credentials not yet minted; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
|
||||
default:
|
||||
m.store(res)
|
||||
return res, nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
backoff = min(backoff*2, m.maxBackoff)
|
||||
}
|
||||
return nil, fmt.Errorf("giving up acquiring NATS credentials after %d attempts: %w", m.maxAttempts, lastReason)
|
||||
}
|
||||
|
||||
// RefreshLoop re-registers to mint a fresh JWT before the current one expires,
|
||||
// updating the credentials returned by Current/Provider so the NATS connection
|
||||
// adopts them on its next reconnect. It returns nil when ctx is cancelled or
|
||||
// when the current credential has no expiry (nothing to refresh), and a non-nil
|
||||
// error after maxAttempts consecutive refresh failures — letting the caller
|
||||
// exit the worker so it restarts and re-acquires (or surfaces the outage)
|
||||
// rather than silently drifting toward an expired, unrenewable JWT.
|
||||
func (m *NATSCredentialManager) RefreshLoop(ctx context.Context) error {
|
||||
failures := 0
|
||||
for {
|
||||
jwt, _ := m.Current()
|
||||
exp, ok := m.expiryOf(jwt)
|
||||
if !ok {
|
||||
xlog.Debug("NATS credential has no expiry; refresh loop exiting")
|
||||
return nil
|
||||
}
|
||||
wait := max(time.Duration(float64(time.Until(exp))*m.refreshLead), 0)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(wait):
|
||||
}
|
||||
|
||||
res, err := m.register(ctx)
|
||||
if err == nil && res.NatsJWT != "" && res.NatsUserSeed != "" {
|
||||
m.store(res)
|
||||
failures = 0
|
||||
xlog.Info("Refreshed NATS credentials", "node", res.ID)
|
||||
continue
|
||||
}
|
||||
failures++
|
||||
if err != nil {
|
||||
xlog.Warn("NATS credential refresh failed; will retry", "attempt", failures, "error", err)
|
||||
} else {
|
||||
xlog.Warn("NATS credential refresh returned no credentials; will retry", "attempt", failures)
|
||||
}
|
||||
if m.maxAttempts > 0 && failures >= m.maxAttempts {
|
||||
return fmt.Errorf("NATS credential refresh failed %d times in a row", failures)
|
||||
}
|
||||
// Back off before retrying so a persistent failure near expiry does not spin.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(m.refreshRetry):
|
||||
}
|
||||
}
|
||||
}
|
||||
198
core/cli/workerregistry/credentials_test.go
Normal file
198
core/cli/workerregistry/credentials_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package workerregistry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/nats-io/nkeys"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestWorkerRegistry(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "WorkerRegistry")
|
||||
}
|
||||
|
||||
// fakeRegister returns a sequence of canned responses/errors, one per call, and
|
||||
// records how many times it was invoked. The last entry repeats once exhausted.
|
||||
type fakeRegister struct {
|
||||
mu sync.Mutex
|
||||
steps []step
|
||||
calls int
|
||||
}
|
||||
|
||||
type step struct {
|
||||
res *RegisterResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeRegister) fn() RegisterFunc {
|
||||
return func(context.Context) (*RegisterResponse, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
i := f.calls
|
||||
f.calls++
|
||||
if i >= len(f.steps) {
|
||||
i = len(f.steps) - 1
|
||||
}
|
||||
return f.steps[i].res, f.steps[i].err
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeRegister) count() int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.calls
|
||||
}
|
||||
|
||||
var _ = Describe("NATSCredentialManager", func() {
|
||||
approved := func(jwt, seed string) *RegisterResponse {
|
||||
return &RegisterResponse{ID: "node-1", Status: "healthy", NatsJWT: jwt, NatsUserSeed: seed}
|
||||
}
|
||||
pending := &RegisterResponse{ID: "node-1", Status: "pending"}
|
||||
|
||||
Describe("Acquire (#4 — wait through admin approval)", func() {
|
||||
It("keeps re-registering until the node is approved and credentials are minted", func() {
|
||||
f := &fakeRegister{steps: []step{
|
||||
{res: pending}, // not approved yet
|
||||
{res: approved("", "")}, // approved but JWT not minted yet
|
||||
{res: approved("jwt-1", "seed-1")}, // finally minted
|
||||
}}
|
||||
m := NewNATSCredentialManager(f.fn(), true /* requireCreds */)
|
||||
m.initialBackoff = time.Millisecond
|
||||
m.maxBackoff = time.Millisecond
|
||||
|
||||
res, err := m.Acquire(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.ID).To(Equal("node-1"))
|
||||
Expect(f.count()).To(Equal(3))
|
||||
|
||||
jwt, seed := m.Current()
|
||||
Expect(jwt).To(Equal("jwt-1"))
|
||||
Expect(seed).To(Equal("seed-1"))
|
||||
Expect(m.HasCredentials()).To(BeTrue())
|
||||
Expect(m.NodeID()).To(Equal("node-1"))
|
||||
})
|
||||
|
||||
It("returns immediately on the first success when credentials are not required (anonymous NATS)", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}}
|
||||
m := NewNATSCredentialManager(f.fn(), false /* requireCreds */)
|
||||
|
||||
res, err := m.Acquire(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.Status).To(Equal("pending"))
|
||||
Expect(f.count()).To(Equal(1))
|
||||
Expect(m.HasCredentials()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("aborts when the context is cancelled while waiting for approval", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.initialBackoff = 10 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err := m.Acquire(ctx)
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
})
|
||||
|
||||
It("gives up after a bounded number of attempts so the worker exits and alerts", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}} // never approved
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.initialBackoff = time.Millisecond
|
||||
m.maxBackoff = time.Millisecond
|
||||
m.maxAttempts = 5
|
||||
|
||||
_, err := m.Acquire(context.Background())
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("after 5 attempts"))
|
||||
Expect(err.Error()).To(ContainSubstring("pending admin approval"))
|
||||
Expect(f.count()).To(Equal(5))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("RefreshLoop (#5 — renew before the JWT expires)", func() {
|
||||
It("re-registers before expiry and updates the credentials served to new connections", func() {
|
||||
f := &fakeRegister{steps: []step{{res: approved("jwt-2", "seed-2")}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.refreshLead = 0.5
|
||||
m.refreshRetry = time.Millisecond
|
||||
// jwt-1 expires soon; jwt-2 is long-lived so the loop then idles.
|
||||
m.expiryOf = func(jwt string) (time.Time, bool) {
|
||||
switch jwt {
|
||||
case "jwt-1":
|
||||
return time.Now().Add(40 * time.Millisecond), true
|
||||
case "jwt-2":
|
||||
return time.Now().Add(time.Hour), true
|
||||
default:
|
||||
return time.Time{}, false
|
||||
}
|
||||
}
|
||||
m.store(approved("jwt-1", "seed-1"))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go func() { _ = m.RefreshLoop(ctx) }()
|
||||
|
||||
Eventually(func() string {
|
||||
jwt, _ := m.Current()
|
||||
return jwt
|
||||
}, "2s", "10ms").Should(Equal("jwt-2"))
|
||||
})
|
||||
|
||||
It("returns an error after the bounded number of consecutive failures so the caller can exit", func() {
|
||||
f := &fakeRegister{steps: []step{{err: context.DeadlineExceeded}}} // refresh always fails
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.refreshLead = 0.5
|
||||
m.refreshRetry = time.Millisecond
|
||||
m.maxAttempts = 3
|
||||
m.expiryOf = func(string) (time.Time, bool) { return time.Now().Add(time.Millisecond), true }
|
||||
m.store(approved("jwt-1", "seed-1"))
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- m.RefreshLoop(context.Background()) }()
|
||||
Eventually(errCh, "2s").Should(Receive(MatchError(ContainSubstring("3 times in a row"))))
|
||||
})
|
||||
|
||||
It("exits promptly when the current credential has no expiry (nothing to refresh)", func() {
|
||||
f := &fakeRegister{steps: []step{{res: approved("x", "y")}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.expiryOf = func(string) (time.Time, bool) { return time.Time{}, false }
|
||||
m.store(approved("static", "seed"))
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() { _ = m.RefreshLoop(context.Background()); close(done) }()
|
||||
Eventually(done, "1s").Should(BeClosed())
|
||||
Expect(f.count()).To(Equal(0)) // never tried to re-register
|
||||
})
|
||||
})
|
||||
|
||||
Describe("jwtExpiry default", func() {
|
||||
It("decodes the expiry of a real minted worker JWT", func() {
|
||||
akp, err := nkeys.CreateAccount()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
seed, err := akp.Seed()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
cfg := natsauth.Config{AccountSeed: string(seed), WorkerJWTTTL: time.Hour}
|
||||
token, _, err := cfg.MintWorkerJWT("node-1", "backend")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
exp, ok := jwtExpiry(token)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(exp).To(BeTemporally("~", time.Now().Add(time.Hour), 2*time.Minute))
|
||||
})
|
||||
|
||||
It("reports no expiry for an empty or undecodable token", func() {
|
||||
_, ok := jwtExpiry("")
|
||||
Expect(ok).To(BeFalse())
|
||||
_, ok = jwtExpiry("not-a-jwt")
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -22,9 +22,11 @@ const (
|
||||
UsecaseRerank = "rerank"
|
||||
UsecaseDetection = "detection"
|
||||
UsecaseVAD = "vad"
|
||||
UsecaseAudioTransform = "audio_transform"
|
||||
UsecaseDiarization = "diarization"
|
||||
UsecaseRealtimeAudio = "realtime_audio"
|
||||
UsecaseAudioTransform = "audio_transform"
|
||||
UsecaseDiarization = "diarization"
|
||||
UsecaseRealtimeAudio = "realtime_audio"
|
||||
UsecaseFaceRecognition = "face_recognition"
|
||||
UsecaseSpeakerRecognition = "speaker_recognition"
|
||||
)
|
||||
|
||||
// GRPCMethod identifies a Backend service RPC from backend.proto.
|
||||
@@ -47,6 +49,11 @@ const (
|
||||
MethodAudioTransform GRPCMethod = "AudioTransform"
|
||||
MethodDiarize GRPCMethod = "Diarize"
|
||||
MethodAudioToAudioStream GRPCMethod = "AudioToAudioStream"
|
||||
MethodFaceVerify GRPCMethod = "FaceVerify"
|
||||
MethodFaceAnalyze GRPCMethod = "FaceAnalyze"
|
||||
MethodVoiceVerify GRPCMethod = "VoiceVerify"
|
||||
MethodVoiceEmbed GRPCMethod = "VoiceEmbed"
|
||||
MethodVoiceAnalyze GRPCMethod = "VoiceAnalyze"
|
||||
)
|
||||
|
||||
// UsecaseInfo describes a single known_usecase value and how it maps
|
||||
@@ -154,6 +161,16 @@ var UsecaseInfoMap = map[string]UsecaseInfo{
|
||||
GRPCMethod: MethodAudioToAudioStream,
|
||||
Description: "Self-contained any-to-any audio model for the Realtime API — accepts microphone audio and emits speech + transcript (+ optional function calls) from a single backend via the AudioToAudioStream RPC.",
|
||||
},
|
||||
UsecaseFaceRecognition: {
|
||||
Flag: FLAG_FACE_RECOGNITION,
|
||||
GRPCMethod: MethodFaceVerify,
|
||||
Description: "Face recognition — verify identity, analyze attributes (age/gender/emotion) via FaceVerify and FaceAnalyze RPCs.",
|
||||
},
|
||||
UsecaseSpeakerRecognition: {
|
||||
Flag: FLAG_SPEAKER_RECOGNITION,
|
||||
GRPCMethod: MethodVoiceVerify,
|
||||
Description: "Speaker recognition — verify identity, embed and analyze voice via VoiceVerify, VoiceEmbed and VoiceAnalyze RPCs.",
|
||||
},
|
||||
}
|
||||
|
||||
// BackendCapability describes which gRPC methods and usecases a backend supports.
|
||||
@@ -198,6 +215,13 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
AcceptsVideos: true,
|
||||
Description: "vLLM engine — high-throughput LLM serving with optional multimodal",
|
||||
},
|
||||
"sglang": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodTokenizeString},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseTokenize, UsecaseVision},
|
||||
DefaultUsecases: []string{UsecaseChat},
|
||||
AcceptsImages: true,
|
||||
Description: "SGLang — fast LLM inference with structured generation and optional vision",
|
||||
},
|
||||
"vllm-omni": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodGenerateImage, MethodGenerateVideo, MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseImage, UsecaseVideo, UsecaseTTS, UsecaseVision},
|
||||
@@ -291,6 +315,12 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "NVIDIA NeMo speech recognition",
|
||||
},
|
||||
"parakeet-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription},
|
||||
PossibleUsecases: []string{UsecaseTranscript},
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "NVIDIA NeMo Parakeet ASR (parakeet.cpp)",
|
||||
},
|
||||
"qwen-asr": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription},
|
||||
PossibleUsecases: []string{UsecaseTranscript},
|
||||
@@ -309,6 +339,18 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseTranscript, UsecaseTTS},
|
||||
Description: "VibeVoice — bidirectional speech (transcription and synthesis)",
|
||||
},
|
||||
"vibevoice-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription, MethodTTS, MethodTTSStream},
|
||||
PossibleUsecases: []string{UsecaseTranscript, UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTranscript, UsecaseTTS},
|
||||
Description: "VibeVoice C++ — bidirectional speech, C++ backend with streaming TTS",
|
||||
},
|
||||
"sherpa-onnx": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription, MethodTTS, MethodTTSStream, MethodVAD},
|
||||
PossibleUsecases: []string{UsecaseTranscript, UsecaseTTS, UsecaseVAD},
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "Sherpa-ONNX — multi-model speech toolkit (ASR, TTS, VAD)",
|
||||
},
|
||||
|
||||
// --- TTS backends ---
|
||||
"piper": {
|
||||
@@ -353,6 +395,12 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Qwen TTS",
|
||||
},
|
||||
"qwen3-tts-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Qwen3 TTS C++ — text-to-speech, C++ backend",
|
||||
},
|
||||
"faster-qwen3-tts": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
@@ -434,6 +482,27 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseDetection},
|
||||
Description: "RF-DETR object detection",
|
||||
},
|
||||
"rfdetr-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodDetect},
|
||||
PossibleUsecases: []string{UsecaseDetection},
|
||||
DefaultUsecases: []string{UsecaseDetection},
|
||||
Description: "RF-DETR C++ object detection",
|
||||
},
|
||||
|
||||
// --- Face and speaker recognition backends ---
|
||||
"insightface": {
|
||||
GRPCMethods: []GRPCMethod{MethodEmbedding, MethodDetect, MethodFaceVerify, MethodFaceAnalyze},
|
||||
PossibleUsecases: []string{UsecaseEmbeddings, UsecaseDetection, UsecaseFaceRecognition},
|
||||
DefaultUsecases: []string{UsecaseFaceRecognition},
|
||||
AcceptsImages: true,
|
||||
Description: "InsightFace — face detection, embedding, verification and attribute analysis",
|
||||
},
|
||||
"speaker-recognition": {
|
||||
GRPCMethods: []GRPCMethod{MethodVoiceVerify, MethodVoiceEmbed, MethodVoiceAnalyze},
|
||||
PossibleUsecases: []string{UsecaseSpeakerRecognition},
|
||||
DefaultUsecases: []string{UsecaseSpeakerRecognition},
|
||||
Description: "Speaker recognition — voice identity verification and analysis",
|
||||
},
|
||||
"silero-vad": {
|
||||
GRPCMethods: []GRPCMethod{MethodVAD},
|
||||
PossibleUsecases: []string{UsecaseVAD},
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -16,7 +18,29 @@ type DistributedConfig struct {
|
||||
NatsURL string // --nats-url / LOCALAI_NATS_URL
|
||||
StorageURL string // --storage-url / LOCALAI_STORAGE_URL (S3 endpoint)
|
||||
RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration)
|
||||
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
|
||||
// RegistrationRequireAuth fails startup when distributed mode is enabled but
|
||||
// RegistrationToken is empty. The default (false) keeps the historical
|
||||
// fail-open behavior with a loud warning; production should set it so the
|
||||
// node-register endpoints and the worker file-transfer server cannot run
|
||||
// unauthenticated. Mirrors NatsRequireAuth for the NATS bus.
|
||||
RegistrationRequireAuth bool // LOCALAI_REGISTRATION_REQUIRE_AUTH
|
||||
// RequireAuth is the umbrella switch (LOCALAI_DISTRIBUTED_REQUIRE_AUTH) for
|
||||
// distributed-mode auth: when true it implies BOTH NatsRequireAuth and
|
||||
// RegistrationRequireAuth, so a single knob locks down the bus and the
|
||||
// registration/file-transfer layer together. The granular flags remain
|
||||
// available to enforce just one layer.
|
||||
RequireAuth bool // LOCALAI_DISTRIBUTED_REQUIRE_AUTH
|
||||
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
|
||||
|
||||
// NATS JWT auth (optional; see pkg/natsauth and docs/features/distributed-mode.md)
|
||||
NatsAccountSeed string // LOCALAI_NATS_ACCOUNT_SEED — account signing seed to mint per-node worker JWTs
|
||||
NatsServiceJWT string // LOCALAI_NATS_SERVICE_JWT — user JWT for frontends / agent workers
|
||||
NatsServiceSeed string // LOCALAI_NATS_SERVICE_SEED — signing seed paired with service JWT
|
||||
NatsWorkerJWTTTL time.Duration // LOCALAI_NATS_WORKER_JWT_TTL — minted worker JWT lifetime (default 24h)
|
||||
NatsRequireAuth bool // LOCALAI_NATS_REQUIRE_AUTH — fail startup if NATS credentials are missing
|
||||
NatsTLSCA string // LOCALAI_NATS_TLS_CA — PEM file for private CA (server verify)
|
||||
NatsTLSCert string // LOCALAI_NATS_TLS_CERT — client cert for NATS mTLS
|
||||
NatsTLSKey string // LOCALAI_NATS_TLS_KEY — client key paired with NatsTLSCert
|
||||
|
||||
// S3 configuration (used when StorageURL is set)
|
||||
StorageBucket string // --storage-bucket / LOCALAI_STORAGE_BUCKET
|
||||
@@ -49,6 +73,17 @@ type DistributedConfig struct {
|
||||
|
||||
AgentWorkerConcurrency int `yaml:"agent_worker_concurrency" json:"agent_worker_concurrency" env:"LOCALAI_AGENT_WORKER_CONCURRENCY"`
|
||||
JobWorkerConcurrency int `yaml:"job_worker_concurrency" json:"job_worker_concurrency" env:"LOCALAI_JOB_WORKER_CONCURRENCY"`
|
||||
|
||||
// PrefixCacheDisabled turns off prefix-cache-aware routing, falling back to
|
||||
// round-robin (the floor). Prefix-cache routing is ON by default in
|
||||
// distributed mode; this flag exists so operators can opt out. The CLI
|
||||
// surfaces a default-true --distributed-prefix-cache enable flag and sets
|
||||
// this when the operator passes --distributed-prefix-cache=false.
|
||||
PrefixCacheDisabled bool
|
||||
// PrefixCacheTTL is the idle-timeout for prefix-cache index entries and
|
||||
// drives the background eviction cadence (eviction runs every TTL/2). Zero
|
||||
// means use the prefixcache package default (5m).
|
||||
PrefixCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// Validate checks that the distributed configuration is internally consistent.
|
||||
@@ -65,10 +100,23 @@ func (c DistributedConfig) Validate() error {
|
||||
(c.StorageAccessKey == "" && c.StorageSecretKey != "") {
|
||||
return fmt.Errorf("storage-access-key and storage-secret-key must both be set or both empty")
|
||||
}
|
||||
// Warn about missing registration token (not an error)
|
||||
// The registration token guards both the node HTTP register/heartbeat
|
||||
// endpoints and the worker file-transfer server (which fails open on an
|
||||
// empty token). Enforce it when registration auth is required (the granular
|
||||
// flag or the umbrella); otherwise warn.
|
||||
if c.RegistrationToken == "" {
|
||||
xlog.Warn("distributed mode running without registration token — node endpoints are unprotected")
|
||||
if c.RegistrationAuthRequired() {
|
||||
return fmt.Errorf("registration auth is required (LOCALAI_REGISTRATION_REQUIRE_AUTH or LOCALAI_DISTRIBUTED_REQUIRE_AUTH) but LOCALAI_REGISTRATION_TOKEN is empty")
|
||||
}
|
||||
xlog.Warn("distributed mode running without registration token — node endpoints and the worker file-transfer server are unprotected; set LOCALAI_REGISTRATION_TOKEN, or LOCALAI_DISTRIBUTED_REQUIRE_AUTH=true to fail closed")
|
||||
}
|
||||
if err := c.NatsAuthConfig().Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.NatsTLSFiles().Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.NatsAuthConfig().WarnIfInsecure(true)
|
||||
// Check for negative durations
|
||||
for name, d := range map[string]time.Duration{
|
||||
FlagMCPToolTimeout: c.MCPToolTimeout,
|
||||
@@ -112,6 +160,76 @@ func WithRegistrationToken(token string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsAccountSeed(seed string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsAccountSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsServiceJWT(jwt string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsServiceJWT = jwt
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsServiceSeed(seed string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsServiceSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsWorkerJWTTTL(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsWorkerJWTTTL = d
|
||||
}
|
||||
}
|
||||
|
||||
var EnableNatsRequireAuth = func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsRequireAuth = true
|
||||
}
|
||||
|
||||
// EnableRegistrationRequireAuth makes an empty registration token a hard error
|
||||
// in distributed mode (see DistributedConfig.RegistrationRequireAuth).
|
||||
var EnableRegistrationRequireAuth = func(o *ApplicationConfig) {
|
||||
o.Distributed.RegistrationRequireAuth = true
|
||||
}
|
||||
|
||||
// EnableDistributedRequireAuth is the umbrella switch implying both
|
||||
// NatsRequireAuth and RegistrationRequireAuth (see DistributedConfig.RequireAuth).
|
||||
var EnableDistributedRequireAuth = func(o *ApplicationConfig) {
|
||||
o.Distributed.RequireAuth = true
|
||||
}
|
||||
|
||||
// RegistrationAuthRequired reports whether an empty registration token must be
|
||||
// treated as a fatal misconfiguration — the granular flag or the umbrella.
|
||||
func (c DistributedConfig) RegistrationAuthRequired() bool {
|
||||
return c.RegistrationRequireAuth || c.RequireAuth
|
||||
}
|
||||
|
||||
// NatsAuthRequired reports whether NATS JWT credentials must be present — the
|
||||
// granular flag or the umbrella.
|
||||
func (c DistributedConfig) NatsAuthRequired() bool {
|
||||
return c.NatsRequireAuth || c.RequireAuth
|
||||
}
|
||||
|
||||
func WithNatsTLSCA(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSCA = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsTLSCert(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSCert = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsTLSKey(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSKey = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithStorageURL(url string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.StorageURL = url
|
||||
@@ -158,6 +276,20 @@ var EnableAutoApproveNodes = func(o *ApplicationConfig) {
|
||||
o.Distributed.AutoApproveNodes = true
|
||||
}
|
||||
|
||||
// DisablePrefixCache turns off prefix-cache-aware routing (falls back to
|
||||
// round-robin). Prefix-cache routing is enabled by default in distributed mode.
|
||||
var DisablePrefixCache = func(o *ApplicationConfig) {
|
||||
o.Distributed.PrefixCacheDisabled = true
|
||||
}
|
||||
|
||||
// WithPrefixCacheTTL sets the prefix-cache index idle-timeout (and the
|
||||
// background eviction cadence, which runs every TTL/2).
|
||||
func WithPrefixCacheTTL(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.PrefixCacheTTL = d
|
||||
}
|
||||
}
|
||||
|
||||
// Flag names for distributed timeout / interval configuration. These are
|
||||
// the kebab-case identifiers kong derives from the matching RunCMD struct
|
||||
// fields; they appear in Validate error messages and any other operator-
|
||||
@@ -192,6 +324,44 @@ const (
|
||||
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
|
||||
const DefaultMaxUploadSize int64 = 50 << 30
|
||||
|
||||
// NatsTLSFiles returns NATS TLS/mTLS PEM paths for the messaging client.
|
||||
func (c DistributedConfig) NatsTLSFiles() messaging.TLSFiles {
|
||||
return messaging.TLSFiles{
|
||||
CA: c.NatsTLSCA,
|
||||
Cert: c.NatsTLSCert,
|
||||
Key: c.NatsTLSKey,
|
||||
}
|
||||
}
|
||||
|
||||
// NatsMessagingOptions builds messaging client options (JWT + TLS) for distributed components.
|
||||
// Pass explicit userJWT/userSeed when set (e.g. worker overrides); empty uses service JWT from config.
|
||||
func (c DistributedConfig) NatsMessagingOptions(userJWT, userSeed string) []messaging.Option {
|
||||
var opts []messaging.Option
|
||||
jwt, seed := userJWT, userSeed
|
||||
if jwt == "" && seed == "" {
|
||||
auth := c.NatsAuthConfig()
|
||||
jwt, seed = auth.ServiceUserJWT, auth.ServiceUserSeed
|
||||
}
|
||||
if jwt != "" && seed != "" {
|
||||
opts = append(opts, messaging.WithUserJWT(jwt, seed))
|
||||
}
|
||||
if tls := c.NatsTLSFiles(); tls.Enabled() {
|
||||
opts = append(opts, messaging.WithTLS(tls))
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// NatsAuthConfig builds pkg/natsauth settings from distributed configuration.
|
||||
func (c DistributedConfig) NatsAuthConfig() natsauth.Config {
|
||||
return natsauth.Config{
|
||||
AccountSeed: c.NatsAccountSeed,
|
||||
ServiceUserJWT: c.NatsServiceJWT,
|
||||
ServiceUserSeed: c.NatsServiceSeed,
|
||||
WorkerJWTTTL: c.NatsWorkerJWTTTL,
|
||||
RequireAuth: c.NatsAuthRequired(),
|
||||
}
|
||||
}
|
||||
|
||||
// BackendInstallTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout)
|
||||
|
||||
@@ -88,3 +88,66 @@ var _ = Describe("DistributedConfig.Validate negative-duration errors", func() {
|
||||
Expect(c.Validate()).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("DistributedConfig.Validate registration auth", func() {
|
||||
It("rejects an empty registration token when RequireAuth is set", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
RegistrationRequireAuth: true,
|
||||
}
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_REQUIRE_AUTH"))
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_TOKEN"))
|
||||
})
|
||||
|
||||
It("accepts a set registration token when RequireAuth is set", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
RegistrationToken: "s3cret",
|
||||
RegistrationRequireAuth: true,
|
||||
}
|
||||
Expect(c.Validate()).To(Succeed())
|
||||
})
|
||||
|
||||
It("warns but succeeds with an empty token when RequireAuth is unset", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
}
|
||||
Expect(c.Validate()).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects an empty token when the umbrella RequireAuth is set", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
RequireAuth: true,
|
||||
// Provide NATS creds so only the registration-token gap remains.
|
||||
NatsServiceJWT: "jwt",
|
||||
NatsServiceSeed: "seed",
|
||||
NatsAccountSeed: "acct",
|
||||
}
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_DISTRIBUTED_REQUIRE_AUTH"))
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_TOKEN"))
|
||||
})
|
||||
|
||||
It("the umbrella implies NATS auth is required", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
RegistrationToken: "tok", // registration layer satisfied
|
||||
RequireAuth: true, // umbrella → NATS creds now required
|
||||
}
|
||||
Expect(c.NatsAuthRequired()).To(BeTrue())
|
||||
Expect(c.RegistrationAuthRequired()).To(BeTrue())
|
||||
// Missing NATS service JWT/seed must now be fatal.
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_NATS_REQUIRE_AUTH"))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -128,6 +128,22 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Advanced: true,
|
||||
Order: 21,
|
||||
},
|
||||
"reasoning_effort": {
|
||||
Section: "llm",
|
||||
Label: "Reasoning Effort",
|
||||
Description: "Default reasoning effort, forwarded to the backend as the reasoning_effort chat_template_kwarg (jinja models like gpt-oss / LFM2.5 honor it). A per-request reasoning_effort overrides it. 'none' also turns thinking off.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "", Label: "Unset (model default)"},
|
||||
{Value: "none", Label: "none (disable thinking)"},
|
||||
{Value: "minimal", Label: "minimal"},
|
||||
{Value: "low", Label: "low"},
|
||||
{Value: "medium", Label: "medium"},
|
||||
{Value: "high", Label: "high"},
|
||||
},
|
||||
Advanced: true,
|
||||
Order: 22,
|
||||
},
|
||||
"cache_type_k": {
|
||||
Section: "llm",
|
||||
Label: "KV Cache Type (K)",
|
||||
@@ -277,6 +293,21 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
AutocompleteProvider: ProviderModelsVAD,
|
||||
Order: 63,
|
||||
},
|
||||
"pipeline.reasoning_effort": {
|
||||
Section: "pipeline",
|
||||
Label: "Reasoning Effort",
|
||||
Description: "Reasoning effort for the pipeline's LLM, forwarded to the backend as the reasoning_effort chat_template_kwarg (jinja models like gpt-oss / LFM2.5 honor it). Overrides the LLM model's own reasoning_effort. 'none' also turns thinking off.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "", Label: "Default (model config)"},
|
||||
{Value: "none", Label: "none (disable thinking)"},
|
||||
{Value: "minimal", Label: "minimal"},
|
||||
{Value: "low", Label: "low"},
|
||||
{Value: "medium", Label: "medium"},
|
||||
{Value: "high", Label: "high"},
|
||||
},
|
||||
Order: 64,
|
||||
},
|
||||
|
||||
// --- Functions ---
|
||||
"function.grammar.parallel_calls": {
|
||||
|
||||
@@ -63,6 +63,13 @@ type ModelConfig struct {
|
||||
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
|
||||
ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
|
||||
|
||||
// ReasoningEffort is the default reasoning effort (none|minimal|low|medium|high)
|
||||
// for this model. A per-request reasoning_effort overrides it. It is forwarded
|
||||
// to the backend as the reasoning_effort chat_template_kwarg (see
|
||||
// gRPCPredictOpts), so jinja-templated models that key on it — e.g. gpt-oss
|
||||
// (Harmony) or LFM2.5 — honor it; "none" also toggles enable_thinking off.
|
||||
ReasoningEffort string `yaml:"reasoning_effort,omitempty" json:"reasoning_effort,omitempty"`
|
||||
|
||||
FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
|
||||
// LLM configs (GPT4ALL, Llama.cpp, ...)
|
||||
LLMConfig `yaml:",inline" json:",inline"`
|
||||
@@ -487,6 +494,40 @@ type Pipeline struct {
|
||||
LLM string `yaml:"llm,omitempty" json:"llm,omitempty"`
|
||||
Transcription string `yaml:"transcription,omitempty" json:"transcription,omitempty"`
|
||||
VAD string `yaml:"vad,omitempty" json:"vad,omitempty"`
|
||||
|
||||
// ReasoningEffort sets the reasoning effort (none|minimal|low|medium|high) for
|
||||
// the pipeline's LLM without editing the LLM model config. Overrides the LLM's
|
||||
// own reasoning_effort. Unset leaves the LLM model config in charge.
|
||||
ReasoningEffort string `yaml:"reasoning_effort,omitempty" json:"reasoning_effort,omitempty"`
|
||||
}
|
||||
|
||||
// ApplyReasoningEffort resolves the effective reasoning effort — a per-request
|
||||
// value (requestEffort) overrides the config's own ReasoningEffort default —
|
||||
// stores it on the config so gRPCPredictOpts forwards it to the backend as the
|
||||
// reasoning_effort chat_template_kwarg, and maps it onto the enable_thinking
|
||||
// toggle the backend also reads:
|
||||
// - "none" always disables thinking.
|
||||
// - any explicit level enables it, UNLESS the config already disabled reasoning
|
||||
// (an operator's explicit disable wins over a request asking to think).
|
||||
//
|
||||
// An empty requestEffort keeps the config's own default. With no effort set
|
||||
// anywhere it is a no-op, leaving the model's reasoning settings untouched.
|
||||
func (c *ModelConfig) ApplyReasoningEffort(requestEffort string) {
|
||||
effort := requestEffort
|
||||
if effort == "" {
|
||||
effort = c.ReasoningEffort
|
||||
}
|
||||
c.ReasoningEffort = effort
|
||||
switch strings.ToLower(effort) {
|
||||
case "none":
|
||||
disable := true
|
||||
c.ReasoningConfig.DisableReasoning = &disable
|
||||
case "minimal", "low", "medium", "high":
|
||||
if c.ReasoningConfig.DisableReasoning == nil || !*c.ReasoningConfig.DisableReasoning {
|
||||
enable := false
|
||||
c.ReasoningConfig.DisableReasoning = &enable
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// @Description File configuration for model downloads
|
||||
@@ -694,6 +735,18 @@ func (c *ModelConfig) IsModelURL() bool {
|
||||
return uri.LooksLikeURL()
|
||||
}
|
||||
|
||||
// ModelID returns the identifier used to reference this model across the
|
||||
// system: the configured Name, falling back to Model when Name is empty.
|
||||
// This is the single source of truth for the id fed to model.WithModelID and
|
||||
// the prefix-cache chain salt; both MUST agree with the router's tracking key
|
||||
// or the prefix-cache salt diverges silently.
|
||||
func (c ModelConfig) ModelID() string {
|
||||
if c.Name != "" {
|
||||
return c.Name
|
||||
}
|
||||
return c.Model
|
||||
}
|
||||
|
||||
// ModelFileName returns the filename of the model
|
||||
// If the model is a URL, it will return the MD5 of the URL which is the filename
|
||||
func (c *ModelConfig) ModelFileName() string {
|
||||
|
||||
@@ -10,6 +10,23 @@ import (
|
||||
)
|
||||
|
||||
var _ = Describe("Test cases for config related functions", func() {
|
||||
Context("ModelID", func() {
|
||||
It("returns Name when set", func() {
|
||||
c := ModelConfig{Name: "my-name"}
|
||||
c.Model = "my-model"
|
||||
Expect(c.ModelID()).To(Equal("my-name"))
|
||||
})
|
||||
It("falls back to Model when Name is empty", func() {
|
||||
c := ModelConfig{}
|
||||
c.Model = "my-model"
|
||||
Expect(c.ModelID()).To(Equal("my-model"))
|
||||
})
|
||||
It("returns empty string when both are empty", func() {
|
||||
c := ModelConfig{}
|
||||
Expect(c.ModelID()).To(Equal(""))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Test Read configuration functions", func() {
|
||||
It("Test Validate", func() {
|
||||
tmp, err := os.CreateTemp("", "config.yaml")
|
||||
|
||||
52
core/config/reasoning_effort_test.go
Normal file
52
core/config/reasoning_effort_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
// ApplyReasoningEffort resolves the effective reasoning effort (request value
|
||||
// overrides the model config default), stores it on the config so it reaches the
|
||||
// backend, and maps it onto the enable_thinking toggle.
|
||||
var _ = Describe("ModelConfig.ApplyReasoningEffort", func() {
|
||||
It("uses the request value over the config default", func() {
|
||||
c := &config.ModelConfig{ReasoningEffort: "high"}
|
||||
c.ApplyReasoningEffort("none")
|
||||
Expect(c.ReasoningEffort).To(Equal("none"))
|
||||
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
|
||||
})
|
||||
|
||||
It("falls back to the config default when the request omits it", func() {
|
||||
c := &config.ModelConfig{ReasoningEffort: "none"}
|
||||
c.ApplyReasoningEffort("")
|
||||
Expect(c.ReasoningEffort).To(Equal("none"))
|
||||
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
|
||||
})
|
||||
|
||||
It("enables thinking for an explicit effort level", func() {
|
||||
c := &config.ModelConfig{}
|
||||
c.ApplyReasoningEffort("medium")
|
||||
Expect(c.ReasoningEffort).To(Equal("medium"))
|
||||
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*c.ReasoningConfig.DisableReasoning).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does not let a level override an operator's config-level disable", func() {
|
||||
disabled := true
|
||||
c := &config.ModelConfig{}
|
||||
c.ReasoningConfig.DisableReasoning = &disabled
|
||||
c.ApplyReasoningEffort("high")
|
||||
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is a no-op on the toggle when no effort is set anywhere", func() {
|
||||
c := &config.ModelConfig{}
|
||||
c.ApplyReasoningEffort("")
|
||||
Expect(c.ReasoningEffort).To(Equal(""))
|
||||
Expect(c.ReasoningConfig.DisableReasoning).To(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -420,8 +420,9 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
remoteUnloader = d.Router.Unloader()
|
||||
}
|
||||
}
|
||||
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
|
||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
|
||||
natsCfg := distCfg.NatsAuthConfig()
|
||||
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, natsCfg)
|
||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken, natsCfg)
|
||||
|
||||
// Distributed SSE routes (job progress + agent events via NATS)
|
||||
if d := application.Distributed(); d != nil {
|
||||
|
||||
@@ -37,7 +37,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
|
||||
xlog.Debug("elevenlabs TTS request received", "modelName", input.ModelID)
|
||||
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, "", nil, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ var knownPrefOnlyBackends = []schema.KnownBackend{
|
||||
{Name: "mlx-vlm", Modality: "text", AutoDetect: false, Description: "MLX vision-language models (preference-only)"},
|
||||
// ASR
|
||||
{Name: "whisperx", Modality: "asr", AutoDetect: false, Description: "WhisperX transcription (preference-only)"},
|
||||
{Name: "crispasr", Modality: "asr", AutoDetect: false, Description: "CrispASR multi-architecture transcription (preference-only)"},
|
||||
// TTS
|
||||
{Name: "kokoros", Modality: "tts", AutoDetect: false, Description: "Kokoros TTS (preference-only)"},
|
||||
{Name: "qwen-tts", Modality: "tts", AutoDetect: false, Description: "Qwen TTS (preference-only)"},
|
||||
|
||||
@@ -140,6 +140,7 @@ var _ = Describe("Backend Endpoints", func() {
|
||||
expectPrefOnly("trl", "text")
|
||||
expectPrefOnly("mlx-vlm", "text")
|
||||
expectPrefOnly("whisperx", "asr")
|
||||
expectPrefOnly("crispasr", "asr")
|
||||
expectPrefOnly("kokoros", "tts")
|
||||
expectPrefOnly("qwen-tts", "tts")
|
||||
expectPrefOnly("qwen3-tts-cpp", "tts")
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -25,7 +26,9 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
)
|
||||
|
||||
// nodeError builds a schema.ErrorResponse for node endpoints.
|
||||
@@ -87,7 +90,7 @@ type RegisterNodeRequest struct {
|
||||
// RegisterNodeEndpoint registers a new backend node.
|
||||
// expectedToken is the registration token configured on the frontend (may be empty to disable auth).
|
||||
// autoApprove controls whether new nodes go directly to "healthy" or require admin approval.
|
||||
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
|
||||
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req RegisterNodeRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
@@ -215,13 +218,15 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au
|
||||
}
|
||||
}
|
||||
|
||||
attachNatsJWT(response, node, natsCfg)
|
||||
|
||||
return c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
}
|
||||
|
||||
// ApproveNodeEndpoint approves a pending node, setting its status to healthy.
|
||||
// For agent workers, it also provisions an API key so they can call the inference API.
|
||||
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
|
||||
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
id := c.Param("id")
|
||||
@@ -251,10 +256,26 @@ func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecr
|
||||
}
|
||||
}
|
||||
|
||||
attachNatsJWT(response, node, natsCfg)
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
}
|
||||
|
||||
// attachNatsJWT adds a per-node NATS user JWT to a register/approve response when minting is enabled.
|
||||
func attachNatsJWT(response map[string]any, node *nodes.BackendNode, natsCfg natsauth.Config) {
|
||||
if !natsCfg.CanMintWorkers() || node == nil || node.Status == nodes.StatusPending {
|
||||
return
|
||||
}
|
||||
jwt, seed, err := natsCfg.MintWorkerJWT(node.ID, node.NodeType)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to mint NATS JWT for node", "node", node.Name, "id", node.ID, "error", err)
|
||||
return
|
||||
}
|
||||
response["nats_jwt"] = jwt
|
||||
response["nats_user_seed"] = seed
|
||||
}
|
||||
|
||||
// provisionAgentWorkerKey creates a dedicated user and API key for an agent worker node.
|
||||
// Returns the plaintext API key on success.
|
||||
func provisionAgentWorkerKey(ctx context.Context, authDB *gorm.DB, registry *nodes.NodeRegistry, node *nodes.BackendNode, hmacSecret string) (string, error) {
|
||||
@@ -911,14 +932,56 @@ func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
}
|
||||
|
||||
// SetSchedulingRequest is the request body for creating/updating a scheduling config.
|
||||
//
|
||||
// The four prefix-cache fields are POINTERS so an omitted field is
|
||||
// distinguishable from an explicit zero. On update, an omitted prefix-cache
|
||||
// field preserves the model's previously-configured value instead of resetting
|
||||
// it (see SetSchedulingEndpoint's PATCH-style merge). ModelName, NodeSelector,
|
||||
// MinReplicas and MaxReplicas keep their full-replace PUT semantics.
|
||||
type SetSchedulingRequest struct {
|
||||
ModelName string `json:"model_name"`
|
||||
NodeSelector map[string]string `json:"node_selector,omitempty"`
|
||||
MinReplicas int `json:"min_replicas"`
|
||||
MaxReplicas int `json:"max_replicas"`
|
||||
ModelName string `json:"model_name"`
|
||||
NodeSelector map[string]string `json:"node_selector,omitempty"`
|
||||
MinReplicas int `json:"min_replicas"`
|
||||
MaxReplicas int `json:"max_replicas"`
|
||||
RoutePolicy *string `json:"route_policy,omitempty"`
|
||||
BalanceAbsThreshold *int `json:"balance_abs_threshold,omitempty"`
|
||||
BalanceRelThreshold *float64 `json:"balance_rel_threshold,omitempty"`
|
||||
MinPrefixMatch *float64 `json:"min_prefix_match,omitempty"`
|
||||
}
|
||||
|
||||
// validateSchedulingRequest enforces the invariants of a scheduling config.
|
||||
// The prefix-cache bounds are delegated to prefixcache.ValidateThresholds (the
|
||||
// single source of truth), and are checked against the RESOLVED values passed
|
||||
// in (provided-or-preserved), so validation only rejects bad values the caller
|
||||
// actually supplied. It returns nil when valid, or an error with a user-facing
|
||||
// message describing the first violation.
|
||||
func validateSchedulingRequest(req SetSchedulingRequest, routePolicy string, absThr int, relThr, minMatch float64) error {
|
||||
if req.ModelName == "" {
|
||||
return errors.New("model_name is required")
|
||||
}
|
||||
if req.MinReplicas < 0 {
|
||||
return errors.New("min_replicas must be >= 0")
|
||||
}
|
||||
if req.MaxReplicas < 0 {
|
||||
return errors.New("max_replicas must be >= 0")
|
||||
}
|
||||
if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas {
|
||||
return errors.New("min_replicas must be <= max_replicas")
|
||||
}
|
||||
if err := prefixcache.ValidateThresholds(routePolicy, absThr, relThr, minMatch); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetSchedulingEndpoint creates or updates a model scheduling config.
|
||||
//
|
||||
// The registry upsert full-replaces all columns, so a request that omits the
|
||||
// prefix-cache fields would otherwise wipe a model's previously-configured
|
||||
// routing settings. To avoid that footgun the four prefix-cache fields are
|
||||
// merged PATCH-style: a non-nil request pointer wins; a nil one preserves the
|
||||
// existing config's value (or the zero default when no config exists yet). The
|
||||
// non-prefix fields keep their full-replace PUT behavior.
|
||||
func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
@@ -926,17 +989,45 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
||||
}
|
||||
if req.ModelName == "" {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name is required"))
|
||||
|
||||
// Fetch the existing config (may be nil) so omitted prefix-cache fields
|
||||
// can fall back to the stored value rather than resetting to zero.
|
||||
var existing *nodes.ModelSchedulingConfig
|
||||
if req.ModelName != "" {
|
||||
var err error
|
||||
existing, err = registry.GetModelScheduling(ctx, req.ModelName)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to load existing scheduling config"))
|
||||
}
|
||||
}
|
||||
if req.MinReplicas < 0 {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be >= 0"))
|
||||
|
||||
// Resolve each prefix-cache field: provided pointer wins, otherwise keep
|
||||
// the existing value (zero/default when there is no existing config).
|
||||
routePolicy := ""
|
||||
absThr := 0
|
||||
relThr := 0.0
|
||||
minMatch := 0.0
|
||||
if existing != nil {
|
||||
routePolicy = existing.RoutePolicy
|
||||
absThr = existing.BalanceAbsThreshold
|
||||
relThr = existing.BalanceRelThreshold
|
||||
minMatch = existing.MinPrefixMatch
|
||||
}
|
||||
if req.MaxReplicas < 0 {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "max_replicas must be >= 0"))
|
||||
if req.RoutePolicy != nil {
|
||||
routePolicy = *req.RoutePolicy
|
||||
}
|
||||
if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be <= max_replicas"))
|
||||
if req.BalanceAbsThreshold != nil {
|
||||
absThr = *req.BalanceAbsThreshold
|
||||
}
|
||||
if req.BalanceRelThreshold != nil {
|
||||
relThr = *req.BalanceRelThreshold
|
||||
}
|
||||
if req.MinPrefixMatch != nil {
|
||||
minMatch = *req.MinPrefixMatch
|
||||
}
|
||||
|
||||
if err := validateSchedulingRequest(req, routePolicy, absThr, relThr, minMatch); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, err.Error()))
|
||||
}
|
||||
|
||||
// Serialize node selector to JSON
|
||||
@@ -950,10 +1041,14 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
}
|
||||
|
||||
config := &nodes.ModelSchedulingConfig{
|
||||
ModelName: req.ModelName,
|
||||
NodeSelector: selectorJSON,
|
||||
MinReplicas: req.MinReplicas,
|
||||
MaxReplicas: req.MaxReplicas,
|
||||
ModelName: req.ModelName,
|
||||
NodeSelector: selectorJSON,
|
||||
MinReplicas: req.MinReplicas,
|
||||
MaxReplicas: req.MaxReplicas,
|
||||
RoutePolicy: routePolicy,
|
||||
BalanceAbsThreshold: absThr,
|
||||
BalanceRelThreshold: relThr,
|
||||
MinPrefixMatch: minMatch,
|
||||
}
|
||||
if err := registry.SetModelScheduling(ctx, config); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set scheduling config"))
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("validateSchedulingRequest", func() {
|
||||
base := func() SetSchedulingRequest {
|
||||
return SetSchedulingRequest{ModelName: "m"}
|
||||
}
|
||||
|
||||
It("accepts an empty route policy (inherit) with valid thresholds", func() {
|
||||
Expect(validateSchedulingRequest(base(), "", 3, 0, 0.4)).To(Succeed())
|
||||
})
|
||||
|
||||
It("accepts the prefix_cache policy", func() {
|
||||
Expect(validateSchedulingRequest(base(), "prefix_cache", 0, 0, 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("accepts the round_robin policy", func() {
|
||||
Expect(validateSchedulingRequest(base(), "round_robin", 0, 0, 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("accepts balance_rel_threshold >= 1", func() {
|
||||
Expect(validateSchedulingRequest(base(), "", 0, 1.5, 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects a missing model_name", func() {
|
||||
req := base()
|
||||
req.ModelName = ""
|
||||
err := validateSchedulingRequest(req, "", 0, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("model_name is required"))
|
||||
})
|
||||
|
||||
It("rejects an unknown route_policy (no silent default)", func() {
|
||||
err := validateSchedulingRequest(base(), "bogus", 0, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("route_policy"))
|
||||
})
|
||||
|
||||
It("rejects min_prefix_match above 1", func() {
|
||||
err := validateSchedulingRequest(base(), "", 0, 0, 2)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("rejects a negative min_prefix_match", func() {
|
||||
err := validateSchedulingRequest(base(), "", 0, 0, -0.1)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("rejects a negative balance_abs_threshold", func() {
|
||||
err := validateSchedulingRequest(base(), "", -1, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("balance_abs_threshold"))
|
||||
})
|
||||
|
||||
It("rejects balance_rel_threshold between 0 and 1 exclusive", func() {
|
||||
err := validateSchedulingRequest(base(), "", 0, 0.5, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("balance_rel_threshold"))
|
||||
})
|
||||
})
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/nats-io/nkeys"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -63,7 +65,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
@@ -74,6 +76,29 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
Expect(resp["status"]).To(Equal(nodes.StatusHealthy))
|
||||
})
|
||||
|
||||
It("returns nats_jwt when account seed is configured", func() {
|
||||
akp, err := nkeys.CreateAccount()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
seed, err := akp.Seed()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
e := echo.New()
|
||||
body := `{"name":"worker-nats","address":"10.0.0.2:50051"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
natsCfg := natsauth.Config{AccountSeed: string(seed)}
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsCfg)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp["nats_jwt"]).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns 400 when name is missing", func() {
|
||||
e := echo.New()
|
||||
body := `{"address":"10.0.0.1:50051"}`
|
||||
@@ -82,7 +107,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -102,7 +127,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -121,7 +146,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -140,7 +165,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -159,7 +184,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
|
||||
})
|
||||
@@ -172,7 +197,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", false, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", false, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
@@ -195,7 +220,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body1))
|
||||
req1.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(e.NewContext(req1, rec1))).To(Succeed())
|
||||
Expect(rec1.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
@@ -230,6 +255,114 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Describe("SetSchedulingEndpoint", func() {
|
||||
postScheduling := func(body string) *httptest.ResponseRecorder {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
handler := SetSchedulingEndpoint(registry)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
return rec
|
||||
}
|
||||
|
||||
It("persists prefix-cache fields and round-trips them via GET", func() {
|
||||
ctx := context.Background()
|
||||
rec := postScheduling(`{"model_name":"pc-model","route_policy":"prefix_cache","balance_abs_threshold":3,"min_prefix_match":0.4}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
cfg, err := registry.GetModelScheduling(ctx, "pc-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg).ToNot(BeNil())
|
||||
Expect(cfg.RoutePolicy).To(Equal("prefix_cache"))
|
||||
Expect(cfg.BalanceAbsThreshold).To(Equal(3))
|
||||
Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9))
|
||||
|
||||
e := echo.New()
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
getRec := httptest.NewRecorder()
|
||||
gc := e.NewContext(getReq, getRec)
|
||||
gc.SetParamNames("model")
|
||||
gc.SetParamValues("pc-model")
|
||||
Expect(GetSchedulingEndpoint(registry)(gc)).To(Succeed())
|
||||
Expect(getRec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var got nodes.ModelSchedulingConfig
|
||||
Expect(json.Unmarshal(getRec.Body.Bytes(), &got)).To(Succeed())
|
||||
Expect(got.RoutePolicy).To(Equal("prefix_cache"))
|
||||
Expect(got.BalanceAbsThreshold).To(Equal(3))
|
||||
Expect(got.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9))
|
||||
})
|
||||
|
||||
It("returns 400 for an out-of-range min_prefix_match", func() {
|
||||
rec := postScheduling(`{"model_name":"bad-mpm","min_prefix_match":2}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
errObj, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(errObj["message"]).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("returns 400 for an unknown route_policy", func() {
|
||||
rec := postScheduling(`{"model_name":"bad-policy","route_policy":"bogus"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
errObj, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(errObj["message"]).To(ContainSubstring("route_policy"))
|
||||
})
|
||||
|
||||
It("returns 400 for a balance_rel_threshold between 0 and 1", func() {
|
||||
rec := postScheduling(`{"model_name":"bad-rel","balance_rel_threshold":0.5}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
errObj, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(errObj["message"]).To(ContainSubstring("balance_rel_threshold"))
|
||||
})
|
||||
|
||||
// Regression for the partial-update footgun: a min/max-only POST used to
|
||||
// full-replace every column and silently reset the prefix-cache settings
|
||||
// to empty/zero. The pointer-merge must preserve omitted prefix fields.
|
||||
It("preserves prefix-cache settings across a min_replicas-only update", func() {
|
||||
ctx := context.Background()
|
||||
|
||||
rec := postScheduling(`{"model_name":"merge-model","route_policy":"prefix_cache","min_prefix_match":0.4}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
// Update only min_replicas - omits all prefix-cache fields.
|
||||
rec = postScheduling(`{"model_name":"merge-model","min_replicas":2}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
cfg, err := registry.GetModelScheduling(ctx, "merge-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg).ToNot(BeNil())
|
||||
Expect(cfg.MinReplicas).To(Equal(2), "the provided non-prefix field must update")
|
||||
Expect(cfg.RoutePolicy).To(Equal("prefix_cache"), "omitted route_policy must be preserved")
|
||||
Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9), "omitted min_prefix_match must be preserved")
|
||||
})
|
||||
|
||||
It("updates a prefix-cache field when it is explicitly provided", func() {
|
||||
ctx := context.Background()
|
||||
|
||||
rec := postScheduling(`{"model_name":"update-model","route_policy":"prefix_cache","min_prefix_match":0.4}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
rec = postScheduling(`{"model_name":"update-model","route_policy":"round_robin"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
cfg, err := registry.GetModelScheduling(ctx, "update-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg).ToNot(BeNil())
|
||||
Expect(cfg.RoutePolicy).To(Equal("round_robin"), "explicitly provided route_policy must update")
|
||||
Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9), "omitted min_prefix_match must still be preserved")
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ListNodesEndpoint", func() {
|
||||
It("returns an empty list when no nodes are registered", func() {
|
||||
e := echo.New()
|
||||
|
||||
@@ -59,7 +59,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Stream audio chunks as they're generated
|
||||
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
|
||||
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, input.Instructions, input.Params, ml, appConfig, *cfg, func(audioChunk []byte) error {
|
||||
_, writeErr := c.Response().Write(audioChunk)
|
||||
if writeErr != nil {
|
||||
return writeErr
|
||||
@@ -75,7 +75,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
}
|
||||
|
||||
// Non-streaming TTS (existing behavior)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, input.Instructions, input.Params, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -44,10 +44,10 @@ type wrappedModel struct {
|
||||
// deps in. nil-safe: with classifierRegistry == nil the per-turn
|
||||
// routing block in Predict is skipped, preserving today's "one LLM
|
||||
// for the whole session" behaviour.
|
||||
routerDeps *middleware.ClassifierDeps
|
||||
routerStore router.DecisionStore
|
||||
routerSessionID string
|
||||
routerUserID string
|
||||
routerDeps *middleware.ClassifierDeps
|
||||
routerStore router.DecisionStore
|
||||
routerSessionID string
|
||||
routerUserID string
|
||||
}
|
||||
|
||||
// anyToAnyModel represent a model which supports Any-to-Any operations
|
||||
@@ -119,6 +119,11 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im
|
||||
}
|
||||
}
|
||||
|
||||
// Surface the resolved reasoning effort to the Go-side template path too
|
||||
// (jinja models get it via backend metadata in gRPCPredictOpts; Go-templated
|
||||
// models like gpt-oss read it from the template's .ReasoningEffort).
|
||||
input.ReasoningEffort = turnCfg.ReasoningEffort
|
||||
|
||||
var predInput string
|
||||
var funcs []functions.Function
|
||||
if !turnCfg.TemplateConfig.UseTokenizerTemplate {
|
||||
@@ -313,7 +318,7 @@ func newRealtimeDecisionID() string {
|
||||
}
|
||||
|
||||
func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) {
|
||||
return backend.ModelTTS(ctx, text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig)
|
||||
return backend.ModelTTS(ctx, text, voice, language, "", nil, m.modelLoader, m.appConfig, *m.TTSConfig)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) PredictConfig() *config.ModelConfig {
|
||||
@@ -449,6 +454,9 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
// Let the pipeline set the LLM's reasoning effort (cfgLLM is a per-session copy).
|
||||
applyPipelineReasoning(cfgLLM, *pipeline)
|
||||
|
||||
cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
|
||||
16
core/http/endpoints/openai/realtime_reasoning.go
Normal file
16
core/http/endpoints/openai/realtime_reasoning.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package openai
|
||||
|
||||
import "github.com/mudler/LocalAI/core/config"
|
||||
|
||||
// applyPipelineReasoning sets the reasoning effort for a realtime pipeline's LLM
|
||||
// from the pipeline config, without editing the underlying LLM model config. The
|
||||
// pipeline value overrides the LLM's own reasoning_effort; when the pipeline does
|
||||
// not set it, the LLM model config's reasoning_effort (if any) is used. The LLM
|
||||
// config passed in is the per-session copy returned by the config loader, so this
|
||||
// does not affect other users of the same model.
|
||||
func applyPipelineReasoning(llm *config.ModelConfig, pipeline config.Pipeline) {
|
||||
if llm == nil {
|
||||
return
|
||||
}
|
||||
llm.ApplyReasoningEffort(pipeline.ReasoningEffort)
|
||||
}
|
||||
33
core/http/endpoints/openai/realtime_reasoning_test.go
Normal file
33
core/http/endpoints/openai/realtime_reasoning_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
// applyPipelineReasoning lets a realtime pipeline set the reasoning effort for
|
||||
// its LLM (forwarded to the backend as reasoning_effort) without editing the LLM
|
||||
// model config. The pipeline value overrides the LLM's own reasoning_effort.
|
||||
var _ = Describe("applyPipelineReasoning", func() {
|
||||
It("applies the pipeline reasoning_effort to the LLM config", func() {
|
||||
llm := &config.ModelConfig{}
|
||||
applyPipelineReasoning(llm, config.Pipeline{ReasoningEffort: "none"})
|
||||
Expect(llm.ReasoningEffort).To(Equal("none"))
|
||||
Expect(llm.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*llm.ReasoningConfig.DisableReasoning).To(BeTrue())
|
||||
})
|
||||
|
||||
It("falls back to the LLM's own reasoning_effort when the pipeline is unset", func() {
|
||||
llm := &config.ModelConfig{ReasoningEffort: "high"}
|
||||
applyPipelineReasoning(llm, config.Pipeline{})
|
||||
Expect(llm.ReasoningEffort).To(Equal("high"))
|
||||
Expect(llm.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*llm.ReasoningConfig.DisableReasoning).To(BeFalse())
|
||||
})
|
||||
|
||||
It("is nil-safe", func() {
|
||||
applyPipelineReasoning(nil, config.Pipeline{ReasoningEffort: "low"})
|
||||
})
|
||||
})
|
||||
@@ -310,25 +310,13 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
// Map the per-request reasoning_effort onto the reasoning toggle the
|
||||
// backend reads (enable_thinking metadata, set in gRPCPredictOpts).
|
||||
// "none" disables thinking for this request - the use case from #10072,
|
||||
// running a single Qwen3-style model and turning reasoning off per
|
||||
// request. Any explicit effort level enables thinking, UNLESS the model
|
||||
// config explicitly disabled it (DisableReasoning==true wins): an
|
||||
// operator who deliberately turned reasoning off should not be overridden
|
||||
// by a request. A value of "none" always disables, since that never
|
||||
// conflicts with a config that also disables.
|
||||
switch strings.ToLower(input.ReasoningEffort) {
|
||||
case "none":
|
||||
disable := true
|
||||
config.ReasoningConfig.DisableReasoning = &disable
|
||||
case "minimal", "low", "medium", "high":
|
||||
if config.ReasoningConfig.DisableReasoning == nil || !*config.ReasoningConfig.DisableReasoning {
|
||||
enable := false
|
||||
config.ReasoningConfig.DisableReasoning = &enable
|
||||
}
|
||||
}
|
||||
// Resolve the effective reasoning effort (request overrides the model config
|
||||
// default), store it so gRPCPredictOpts forwards it to the backend as the
|
||||
// reasoning_effort chat_template_kwarg (what gpt-oss / LFM2.5 read), and map
|
||||
// it onto the enable_thinking toggle. "none" disables thinking (the #10072
|
||||
// use case); a level enables it unless the config already disabled reasoning
|
||||
// (an operator's explicit disable wins over a request asking to think).
|
||||
config.ApplyReasoningEffort(input.ReasoningEffort)
|
||||
|
||||
// Collapse the modern max_completion_tokens alias into the
|
||||
// legacy Maxtokens field so downstream code reads exactly one.
|
||||
|
||||
@@ -17,7 +17,10 @@ func SecurityHeaders() echo.MiddlewareFunc {
|
||||
"img-src 'self' data: blob: https:; " +
|
||||
"media-src 'self' data: blob:; " +
|
||||
"font-src 'self' data:; " +
|
||||
"connect-src 'self' ws: wss: https:; " +
|
||||
// blob: lets the waveform renderer XHR/fetch a freshly-created object
|
||||
// URL (e.g. an uploaded clip before it has a server URL). XHR/fetch of
|
||||
// blob: falls under connect-src, not media-src.
|
||||
"connect-src 'self' ws: wss: https: blob:; " +
|
||||
"frame-src 'self' blob:; " +
|
||||
"worker-src 'self' blob:; " +
|
||||
"object-src 'none'; " +
|
||||
|
||||
@@ -32,6 +32,9 @@ var _ = Describe("SecurityHeaders", func() {
|
||||
Expect(csp).To(ContainSubstring("frame-ancestors 'self'"))
|
||||
Expect(csp).To(ContainSubstring("object-src 'none'"))
|
||||
Expect(csp).To(ContainSubstring("base-uri 'self'"))
|
||||
// blob: must be in connect-src so the waveform renderer can XHR/fetch
|
||||
// a freshly-created object URL (uploaded/enhanced clip).
|
||||
Expect(csp).To(ContainSubstring("connect-src 'self' ws: wss: https: blob:"))
|
||||
})
|
||||
|
||||
It("sets X-Content-Type-Options: nosniff", func() {
|
||||
|
||||
@@ -1 +1 @@
|
||||
38.29
|
||||
40.0
|
||||
@@ -20,5 +20,10 @@ test.describe('Agents page', () => {
|
||||
page.waitForURL(/\/app\/agents\/new$/),
|
||||
create.click(),
|
||||
])
|
||||
// Wait for AgentCreate.jsx to actually render, not just for the URL to
|
||||
// change. Ending the test the instant the route matched let the component
|
||||
// mount race the coverage teardown — its ~400 lines were collected only
|
||||
// when the render won, swinging total UI coverage ~1pp run-to-run.
|
||||
await expect(page.getByRole('heading', { name: 'Create Agent' })).toBeVisible()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -66,6 +66,33 @@ function makeFakeWav(name) {
|
||||
return { name, mimeType: 'audio/wav', buffer: buf }
|
||||
}
|
||||
|
||||
// Build a WAV carrying a real sine tone, long enough that the spectrogram
|
||||
// STFT produces several frames (a few thousand samples). Used to exercise the
|
||||
// FFT / heatmap path, which the 4-sample silent fixture can't.
|
||||
function makeToneWav(name, freq = 1000, seconds = 0.4, sampleRate = 16000) {
|
||||
const samples = Math.floor(seconds * sampleRate)
|
||||
const dataLen = samples * 2
|
||||
const buf = Buffer.alloc(44 + dataLen)
|
||||
buf.write('RIFF', 0)
|
||||
buf.writeUInt32LE(36 + dataLen, 4)
|
||||
buf.write('WAVE', 8)
|
||||
buf.write('fmt ', 12)
|
||||
buf.writeUInt32LE(16, 16)
|
||||
buf.writeUInt16LE(1, 20)
|
||||
buf.writeUInt16LE(1, 22)
|
||||
buf.writeUInt32LE(sampleRate, 24)
|
||||
buf.writeUInt32LE(sampleRate * 2, 28)
|
||||
buf.writeUInt16LE(2, 32)
|
||||
buf.writeUInt16LE(16, 34)
|
||||
buf.write('data', 36)
|
||||
buf.writeUInt32LE(dataLen, 40)
|
||||
for (let i = 0; i < samples; i++) {
|
||||
const v = Math.round(Math.sin((2 * Math.PI * freq * i) / sampleRate) * 16000)
|
||||
buf.writeInt16LE(v, 44 + i * 2)
|
||||
}
|
||||
return { name, mimeType: 'audio/wav', buffer: buf }
|
||||
}
|
||||
|
||||
test.describe('Audio Transform', () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await mockCapabilities(page, [
|
||||
@@ -169,6 +196,26 @@ test.describe('Audio Transform', () => {
|
||||
await expect(page.getByTestId('media-history-item')).toHaveCount(1)
|
||||
})
|
||||
|
||||
test('renders an input spectrogram on upload and an output one after transform', async ({ page }) => {
|
||||
mockAudioTransform(page, 'enhanced.wav')
|
||||
|
||||
await page.goto('/app/transform')
|
||||
await expect(page.getByRole('button', { name: 'localvqe' })).toBeVisible({ timeout: 10_000 })
|
||||
|
||||
// Choosing a clip should render its input spectrogram immediately — no
|
||||
// backend round-trip needed (it's computed client-side from the bytes).
|
||||
await page.locator('input[type="file"]').first().setInputFiles(makeToneWav('tone.wav'))
|
||||
await expect(page.getByTestId('spectrogram-input')).toBeVisible({ timeout: 10_000 })
|
||||
|
||||
// Until a transform runs the output side shows a "compare" placeholder.
|
||||
await expect(page.getByText(/Transform to compare/)).toBeVisible()
|
||||
|
||||
await page.getByRole('button', { name: /Transform/ }).last().click()
|
||||
|
||||
// After processing, the output spectrum panel appears alongside the input.
|
||||
await expect(page.getByText('Output spectrum')).toBeVisible({ timeout: 10_000 })
|
||||
})
|
||||
|
||||
test('shows an error banner when the backend returns 4xx', async ({ page }) => {
|
||||
await page.route('**/audio/transformations', (route) => {
|
||||
if (route.request().method() !== 'POST') return route.continue()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user