Mirror of https://github.com/mudler/LocalAI.git (synced 2026-06-06 07:46:15 -04:00)
Compare commits: v4.3.6...dependabot (45 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 85f02497f2 | |
| | 5470051d4d | |
| | 68c5eeebc3 | |
| | 1531fabe23 | |
| | b7673d5b76 | |
| | b64bdaf406 | |
| | eebf08ff1d | |
| | 42e51894c3 | |
| | d9ae6481fb | |
| | f1c495a748 | |
| | 415b561947 | |
| | e6a0d4c375 | |
| | 7e59a5c7c5 | |
| | aea954a482 | |
| | 595e448714 | |
| | 860f9d63ad | |
| | a5a0b3dc4e | |
| | 94eca04c60 | |
| | 35bd485d6a | |
| | 1fe96f8d9a | |
| | c508e9d7c6 | |
| | 55e754fd05 | |
| | a17753f7d1 | |
| | c61838dba6 | |
| | 7013e13f05 | |
| | 5a0013defe | |
| | c01ed631d6 | |
| | d47464cb06 | |
| | 63f176346e | |
| | af94d08729 | |
| | 6795d38f50 | |
| | 718223f33b | |
| | 39e050d9e2 | |
| | c222161291 | |
| | aa80d4681b | |
| | 0d57957ebb | |
| | 76fe0bb929 | |
| | baa11133f1 | |
| | 1bdd3338a6 | |
| | e08492a2c3 | |
| | d5d8fe909d | |
| | 8a82753277 | |
| | 51ca109067 | |
| | 07f6c15a37 | |
| | a44bdb29d4 | |
@@ -38,9 +38,12 @@ The React UI (`core/http/react-ui/`) has **no component/unit tests** — its onl
 - **Browser:** the flake dev shell ships `chromium` and exports `PLAYWRIGHT_CHROMIUM_PATH`; `playwright.config.js` uses it via `launchOptions.executablePath`, and the Makefile skips `playwright install` when it's set. This avoids Playwright's downloaded browser, which can't resolve system libs (`libglib-2.0`, …) on NixOS. In CI (no `PLAYWRIGHT_CHROMIUM_PATH`) the Makefile falls back to `playwright install --with-deps chromium`.
 - The app is a React SPA, so coverage accumulates across in-app navigation within a test; a full `page.goto`/reload resets it.
 - `.nycrc.json` uses `all: true`, so **every `src/**` file is in the report**, including 0%-coverage ones — that's how you spot features with no test at all (sort the HTML report or `coverage-summary.json` by line% ascending).
-- **UI coverage gate:** `make test-ui-coverage-check` runs the suite then `scripts/ui-coverage-check.sh`, failing if total line coverage drops more than `UI_COVERAGE_TOLERANCE` (default **1.0pp**) below `core/http/react-ui/coverage-baseline.txt`. `make test-ui-coverage-baseline` regenerates the baseline. **Why a tolerance (unlike the strict Go gate):** UI e2e line coverage is *non-deterministic* — async/debounced paths (e.g. the VRAM estimate's 500ms debounce) make identical specs vary ~0.5pp run-to-run, so a zero-tolerance gate would flake. Keep the tolerance just above the observed jitter. Run in CI (`tests-ui-e2e.yml`) and pre-commit on `core/http/react-ui/` changes.
+- **UI coverage gate:** `make test-ui-coverage-check` runs the suite then `scripts/ui-coverage-check.sh`, failing if total line coverage drops more than `UI_COVERAGE_TOLERANCE` below `core/http/react-ui/coverage-baseline.txt`. `make test-ui-coverage-baseline` regenerates the baseline. Runs in CI (`tests-ui-e2e.yml`) and pre-commit on `core/http/react-ui/` changes.
+- **Why it has a tolerance (unlike the strict Go gate):** UI e2e coverage is *non-deterministic*. Specs that assert on state and end while async/lazy render work is still in flight collect those lines only when the render beats the coverage teardown — so the total drifts with machine speed/load (a fast local box reads higher than a slow CI runner), diffusely across many specs. The tolerance absorbs that drift, so set the baseline *below* the slow-CI floor, never to a fast-local `make test-ui-coverage-baseline` number, or CI flaps.
+- **Raising coverage is cheap:** a *render-smoke* spec (navigate to a route, assert its header renders) mounts a lazy page and runs its full render + initial effects, capturing most of its lines in a few lines of test — see `e2e/page-render-smoke.spec.js`. Auth is disabled in the test server (`isAdmin=true`), so `RequireAdmin`/`RequireFeature` routes render without a mock. The most *deterministic* win is removing a race: make a spec `await` a rendered element before ending (see `e2e/agents.spec.js` → AgentCreate) so its lines count every run.

-Rules:
-- The gate is **strict — there is no tolerance**. Any decrease fails, regardless of how many lines a PR adds or deletes. `covermode=atomic` makes line coverage deterministic, so there's no run-to-run jitter to excuse.
-- When a change legitimately **raises** coverage, run `make test-coverage-baseline` and **commit** the updated `coverage-baseline.txt` so the ratchet moves up. Never lower the baseline by hand.
-- If you can't get coverage back to baseline, the fix is to **add tests**, not to edit the baseline.
+Rules (both gates):
+- **Install the hooks:** `make install-hooks` once per clone so lint + coverage run pre-commit. Don't lean on CI for what the hook catches.
+- **Don't work around the gate:** never `git commit --no-verify`, and never hand-lower a baseline or widen a tolerance to turn a red gate green. The ratchet only moves up.
+- If a change drops coverage, **add tests** (sort `coverage-summary.json` by line% ascending to find untested code) rather than editing the baseline. When coverage legitimately rises, commit the regenerated baseline (`make test-coverage-baseline` / `test-ui-coverage-baseline`).
+- The Go gate is **strict — no tolerance**; `covermode=atomic` keeps it deterministic. The UI gate keeps a small tolerance only because its e2e coverage isn't.
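The gate rule described in the hunk above (fail when total line coverage drops more than the tolerance below the baseline) comes down to one comparison. The following is a minimal sketch of that arithmetic only, not the contents of `scripts/ui-coverage-check.sh`; the function name `coverage_gate_passes` and its parameters are hypothetical, and a tolerance of `0.0` is used here to model the strict Go gate.

```cpp
// Hypothetical helper (an assumption for illustration, not project code).
// baseline_pct : total line coverage recorded in the baseline file, in percent
// current_pct  : total line coverage measured by the current run, in percent
// tolerance_pp : allowed drop in percentage points; 0.0 models a strict gate
// Returns true when the gate passes, i.e. the drop does not exceed the tolerance.
bool coverage_gate_passes(double baseline_pct, double current_pct, double tolerance_pp) {
    const double drop = baseline_pct - current_pct;  // negative when coverage rose
    return drop <= tolerance_pp;
}
```

With a baseline of 80.0 and a tolerance of 1.0, a run at 79.5 would pass and a run at 78.5 would fail; with a tolerance of 0.0, any decrease fails, which is the ratchet behaviour the rules describe.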
@@ -68,6 +68,34 @@ go test -count=1 -timeout=30m -v ./tests/e2e-backends/...

 CI does not load the model; the suite is opt-in via env vars.

+## Distributed mode
+
+ds4 supports **layer-split** distributed inference (a model too big for one host,
+split by transformer layer; the GGUF must be present on every machine, each loads
+only its slice). Topology is **inverted** vs llama.cpp: the coordinator listens,
+workers dial in.
+
+- **`ds4-worker` binary**: built and packaged next to `grpc-server` (`package.sh`
+copies it into `package/`). Links the same engine objects plus `ds4_distributed.o`;
+**no gRPC/protobuf dependency** (speaks ds4's own TCP transport), so it builds
+even where `grpc-server` can't. Runs the worker serving loop (`ds4_dist_run`).
+- **Coordinator wiring**: the ds4 `grpc-server` acts as coordinator when `LoadModel`
+`ModelOptions.Options` (from model-YAML `options:`) carry:
+- `ds4_role:coordinator` (enables distributed mode; absent → single-node, back-compat)
+- `ds4_layers:0:19` (coordinator's own slice, inclusive; `N:output` includes the head)
+- `ds4_listen:0.0.0.0:1234` (address workers dial into)
+- `ds4_route_timeout:60` (optional; seconds Predict/PredictStream wait for the route
+to form before returning gRPC `UNAVAILABLE`; default 60)
+- **Worker CLI**: `local-ai worker ds4-distributed -- <ds4-worker args>` resolves the
+ds4 backend and execs the packaged `ds4-worker` (raw passthrough), e.g.
+`--role worker --model /models/ds4flash.gguf --layers 20:output --coordinator <host> 1234`.
+
+Opt-in e2e in `tests/e2e-backends/backend_test.go`, gated by
+`BACKEND_TEST_DS4_DISTRIBUTED=1` (plus `BACKEND_TEST_DS4_WORKER_BINARY`,
+`BACKEND_TEST_DS4_WORKER_LAYERS`, `BACKEND_TEST_DS4_COORDINATOR_LAYERS`,
+`BACKEND_TEST_DS4_LISTEN`). Design spec:
+`docs/superpowers/specs/2026-05-30-ds4-distributed-inference-design.md`.
+
 ## Importer

 `core/gallery/importers/ds4.go` (`DS4Importer`) auto-detects ds4 weights by
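The coordinator options listed in the Distributed mode hunk above are `key:value` strings whose values may themselves contain colons (`ds4_layers:0:19`, `ds4_listen:0.0.0.0:1234`), so the key has to be cut at the first colon only. The sketch below illustrates that reading under stated assumptions: `CoordinatorConfig`, `split_first_colon`, and `read_ds4_options` are hypothetical names, the fallback to 60 seconds on an invalid timeout is an assumption, and this is not the backend's actual `LoadModel` code.

```cpp
#include <climits>
#include <cstdlib>
#include <string>
#include <utility>
#include <vector>

// Hypothetical container for the documented coordinator options.
struct CoordinatorConfig {
    bool distributed = false;     // true only when ds4_role:coordinator is present
    std::string layers;           // raw slice spec, e.g. "0:19" or "20:output"
    std::string listen;           // address workers dial into, e.g. "0.0.0.0:1234"
    int route_timeout_sec = 60;   // documented default
};

// Split at the FIRST colon so the value keeps any further colons.
// Returns an empty pair when the option has no colon.
static std::pair<std::string, std::string> split_first_colon(const std::string &opt) {
    const std::string::size_type colon = opt.find(':');
    if (colon == std::string::npos) return std::pair<std::string, std::string>();
    return std::make_pair(opt.substr(0, colon), opt.substr(colon + 1));
}

// Read the documented ds4_* keys; unknown keys are ignored.
CoordinatorConfig read_ds4_options(const std::vector<std::string> &options) {
    CoordinatorConfig cfg;
    for (const std::string &opt : options) {
        const std::pair<std::string, std::string> kv = split_first_colon(opt);
        if (kv.first == "ds4_role") {
            cfg.distributed = (kv.second == "coordinator");
        } else if (kv.first == "ds4_layers") {
            cfg.layers = kv.second;
        } else if (kv.first == "ds4_listen") {
            cfg.listen = kv.second;
        } else if (kv.first == "ds4_route_timeout") {
            // Assumption: keep the default when the value is not a positive int.
            char *end = nullptr;
            const long v = std::strtol(kv.second.c_str(), &end, 10);
            if (!kv.second.empty() && end && *end == '\0' && v > 0 && v <= INT_MAX) {
                cfg.route_timeout_sec = static_cast<int>(v);
            }
        }
    }
    return cfg;
}
```

Passing the three example options from the list would leave `layers` as `0:19` and `listen` as `0.0.0.0:1234`; an option list without `ds4_role:coordinator` leaves `distributed` false, which matches the documented single-node default.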
151
.github/backend-matrix.yml
vendored
@@ -716,6 +716,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -1569,6 +1582,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1595,6 +1621,19 @@ include:
|
||||
backend: "whisper"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-crispasr'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -2889,6 +2928,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2903,6 +2956,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-crispasr'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2916,6 +2983,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2929,6 +3009,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2943,6 +3036,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2957,6 +3064,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-crispasr'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
@@ -2970,6 +3091,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-crispasr'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2983,6 +3117,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-crispasr'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
runs-on: 'ubuntu-latest'
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# parakeet-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -4124,6 +4271,10 @@ includeDarwin:
|
||||
tag-suffix: "-metal-darwin-arm64-whisper"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "crispasr"
|
||||
tag-suffix: "-metal-darwin-arm64-crispasr"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "parakeet-cpp"
|
||||
tag-suffix: "-metal-darwin-arm64-parakeet-cpp"
|
||||
build-type: "metal"
|
||||
|
||||
4
.github/workflows/bump_deps.yaml
vendored
@@ -30,6 +30,10 @@ jobs:
 variable: "WHISPER_CPP_VERSION"
 branch: "master"
 file: "backend/go/whisper/Makefile"
+- repository: "CrispStrobe/CrispASR"
+variable: "CRISPASR_VERSION"
+branch: "main"
+file: "backend/go/crispasr/Makefile"
 - repository: "mudler/parakeet.cpp"
 variable: "PARAKEET_VERSION"
 branch: "master"
2
.github/workflows/secscan.yaml
vendored
@@ -18,7 +18,7 @@ jobs:
 if: ${{ github.actor != 'dependabot[bot]' }}
 - name: Run Gosec Security Scanner
 if: ${{ github.actor != 'dependabot[bot]' }}
-uses: securego/gosec@v2.22.9
+uses: securego/gosec@v2.27.1
 with:
 # we let the report trigger content trigger a failure using the GitHub Security features.
 args: '-no-fail -fmt sarif -out results.sarif ./...'
@@ -35,6 +35,7 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]

 ## Quick Reference

+- **Git hooks & coverage gates**: Run `make install-hooks` once per clone so the pre-commit lint + coverage gates run. **Never bypass them with `git commit --no-verify`, and never lower a coverage baseline or widen a gate's tolerance to turn a red gate green** — the coverage ratchet only moves up. If a change drops coverage, add tests to raise it (e.g. render-smoke specs). See [.agents/building-and-testing.md](.agents/building-and-testing.md).
 - **Logging**: Use `github.com/mudler/xlog` (same API as slog)
 - **Go style**: Prefer `any` over `interface{}`
 - **Comments**: Explain *why*, not *what*
@@ -266,6 +266,12 @@ The e2e tests run LocalAI in a Docker container and exercise the API:
 make test-e2e
 ```

+### React UI tests and coverage
+
+The React UI (`core/http/react-ui/`) is covered by Playwright e2e specs, gated by a **monotonic line-coverage ratchet** (`make test-ui-coverage-check`, run in CI and pre-commit). The metric is non-deterministic — a fast local box reads higher than a slow CI runner for the same code — so a small tolerance is unavoidable.
+
+**If your change lowers UI coverage, raise it back by adding specs — do not widen the tolerance or hand-lower the baseline.** A *render-smoke* spec (navigate to a page, assert its header is visible) cheaply covers an entire lazy page. See `core/http/react-ui/e2e/page-render-smoke.spec.js` and the full policy in [.agents/building-and-testing.md](.agents/building-and-testing.md#react-ui-coverage).
+
 ### Running E2E container tests

 These tests build a standard LocalAI Docker image and run it with pre-configured model configs to verify that most endpoints work correctly:
6
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -1162,6 +1162,7 @@ BACKEND_HUGGINGFACE = huggingface|golang|.|false|true
|
||||
BACKEND_SILERO_VAD = silero-vad|golang|.|false|true
|
||||
BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true
|
||||
BACKEND_WHISPER = whisper|golang|.|false|true
|
||||
BACKEND_CRISPASR = crispasr|golang|.|false|true
|
||||
BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true
|
||||
BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
||||
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
|
||||
@@ -1250,6 +1251,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CRISPASR)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
|
||||
@@ -1300,7 +1302,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-crispasr docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
|
||||
18
README.md
@@ -31,12 +31,18 @@
|
||||
|
||||
**LocalAI** is the open-source AI engine. Run any model - LLMs, vision, voice, image, video - on any hardware. No GPU required.
|
||||
|
||||
- **Drop-in API compatibility** — OpenAI, Anthropic, ElevenLabs APIs
|
||||
- **36+ backends** — llama.cpp, vLLM, transformers, whisper, diffusers, MLX...
|
||||
- **Any hardware** — NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready** — API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents** — autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first** — your data never leaves your infrastructure
|
||||
**A small core, not a bundle.** Each backend wraps a best-in-class engine (llama.cpp, vLLM, whisper.cpp, stable-diffusion, MLX...) in its own image, pulled only when a model needs it. You install nothing you don't use.
|
||||
|
||||
- **Composable by design**: backends are separate and pulled on demand, so you install only what your model needs
|
||||
- **Open and extensible**: load any model, or build your own backend in any language against an open interface
|
||||
- **Drop-in API compatibility**: OpenAI, Anthropic, and ElevenLabs APIs across every backend
|
||||
- **Any model, any modality**: LLMs, vision, voice, image, and video behind one API
|
||||
- **Any hardware**: NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready**: API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents**: autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first**: your data never leaves your infrastructure
|
||||
|
||||

|
||||
|
||||
Created by [Ettore Di Giacinto](https://github.com/mudler) and maintained by the [LocalAI team](#team).
|
||||
|
||||
|
||||
1
backend/cpp/ds4/.gitignore
vendored
@@ -2,6 +2,7 @@ ds4/
 build/
 package/
 grpc-server
+ds4-worker
 *.o
 backend.pb.cc
 backend.pb.h
@@ -104,3 +104,36 @@ if(DS4_NATIVE)
         target_compile_options(${TARGET} PRIVATE -march=native)
     endif()
 endif()
+
+# ds4-worker: standalone distributed worker. Links the same ds4 engine objects
+# (including ds4_distributed.o) but has NO gRPC/protobuf dependency - it speaks
+# ds4's own TCP transport via ds4_dist_run(). Buildable wherever the engine
+# objects build, even on hosts without protobuf/grpc dev headers.
+add_executable(ds4-worker worker_main.c)
+target_include_directories(ds4-worker PRIVATE ${DS4_DIR})
+foreach(obj ${DS4_OBJS})
+    target_sources(ds4-worker PRIVATE ${obj})
+    set_source_files_properties(${obj} PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE)
+endforeach()
+# worker_main.c is C, but the engine objects built by nvcc (ds4_cuda.o) and the
+# Metal path (ds4_metal.o, Obj-C++) reference the C++ runtime (libstdc++). Force
+# the C++ linker driver so those symbols resolve; the C driver would not link
+# libstdc++ and the CUDA/Metal builds fail with undefined std:: references.
+set_target_properties(ds4-worker PROPERTIES LINKER_LANGUAGE CXX)
+target_link_libraries(ds4-worker PRIVATE Threads::Threads m)
+
+if(DS4_GPU STREQUAL "cuda")
+    target_link_libraries(ds4-worker PRIVATE CUDA::cudart CUDA::cublas)
+elseif(DS4_GPU STREQUAL "metal")
+    target_link_libraries(ds4-worker PRIVATE ${FOUNDATION_LIB} ${METAL_LIB})
+elseif(DS4_GPU STREQUAL "cpu")
+    target_compile_definitions(ds4-worker PRIVATE DS4_NO_GPU)
+endif()
+
+if(DS4_NATIVE)
+    if(APPLE)
+        target_compile_options(ds4-worker PRIVATE -mcpu=native)
+    else()
+        target_compile_options(ds4-worker PRIVATE -march=native)
+    endif()
+endif()
@@ -1,10 +1,10 @@
 # ds4 backend Makefile.
 #
-# Upstream pin lives below as DS4_VERSION?=e16ead1e29c81a67bbb64e5b001117679cf9ce6e
+# Upstream pin lives below as DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
 # (.github/bump_deps.sh) can find and update it - matches the
 # llama-cpp / ik-llama-cpp / turboquant convention.

-DS4_VERSION?=e16ead1e29c81a67bbb64e5b001117679cf9ce6e
+DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
 DS4_REPO?=https://github.com/antirez/ds4

 CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
@@ -66,6 +66,7 @@ grpc-server: ds4/ds4.o
 	mkdir -p $(BUILD_DIR)
 	cd $(BUILD_DIR) && cmake $(CMAKE_ARGS) $(CURRENT_MAKEFILE_DIR) && cmake --build . --config Release -j $(JOBS)
 	cp $(BUILD_DIR)/grpc-server grpc-server
+	cp $(BUILD_DIR)/ds4-worker ds4-worker

 package: grpc-server
 	bash package.sh
@@ -74,7 +75,7 @@ test:
 	@echo "ds4 backend: e2e coverage at tests/e2e-backends/ (BACKEND_BINARY mode)"

 clean:
-	rm -rf $(BUILD_DIR) grpc-server package
+	rm -rf $(BUILD_DIR) grpc-server ds4-worker package
 	if [ -d ds4 ]; then $(MAKE) -C ds4 clean; fi

 purge: clean
@@ -23,8 +23,11 @@ extern "C" {
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <climits>
|
||||
#include <csignal>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
@@ -51,6 +54,12 @@ ds4_session *g_session = nullptr;
 int g_ctx_size = 32768;
 std::string g_kv_cache_dir; // empty disables disk cache

+// Distributed coordinator state. g_distributed is set true when LoadModel is
+// given 'ds4_role:coordinator'; generation then waits for the worker route to
+// form before running. Single-node behavior is unchanged when unset.
+bool g_distributed = false;
+int g_route_timeout_sec = 60;
+
 std::atomic<Server *> g_server{nullptr};

 // Parse a "key:value" option string. Returns empty when no colon.
@@ -60,6 +69,77 @@ static std::pair<std::string, std::string> split_option(const std::string &opt)
|
||||
return {opt.substr(0, colon), opt.substr(colon + 1)};
|
||||
}
|
||||
|
||||
// Parse a positive base-10 integer. Returns false (without throwing) on empty,
|
||||
// trailing garbage, non-positive, or overflow - unlike std::stoi.
|
||||
static bool parse_positive_int(const std::string &s, int *out) {
|
||||
if (s.empty()) return false;
|
||||
char *end = nullptr;
|
||||
long v = std::strtol(s.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || v <= 0 || v > INT_MAX) return false;
|
||||
*out = static_cast<int>(v);
|
||||
return true;
|
||||
}

// Parse a ds4 layer spec "START:END" or "START:output" into the engine's
// distributed layer fields. Returns false on malformed input.
static bool parse_layers_spec(const std::string &spec, ds4_distributed_layers *out) {
    auto colon = spec.find(':');
    if (colon == std::string::npos) return false;
    std::string lhs = spec.substr(0, colon);
    std::string rhs = spec.substr(colon + 1);
    if (lhs.empty() || rhs.empty()) return false;
    char *end = nullptr;
    long start = std::strtol(lhs.c_str(), &end, 10);
    if (!end || *end != '\0' || start < 0) return false;
    out->start = static_cast<uint32_t>(start);
    out->has_output = false;
    if (rhs == "output") {
        out->has_output = true;
        out->end = out->start; // engine treats has_output as "through final layer"
    } else {
        long e = std::strtol(rhs.c_str(), &end, 10);
        if (!end || *end != '\0' || e < start) return false;
        out->end = static_cast<uint32_t>(e);
    }
    out->set = true;
    return true;
}
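Review note: the layer spec grammar (`START:END` or `START:output`) can be exercised without the engine headers. The sketch below is a self-contained approximation, not the code in this diff. `LayerSlice` is a hypothetical stand-in for `ds4_distributed_layers` that copies only the four fields assigned above, and the two helper names are made up. Unlike the original, it bounds indices to `uint32_t` and writes the output only after the whole spec has validated.

```cpp
#include <cassert>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <string>

// Hypothetical stand-in for the engine's ds4_distributed_layers struct.
struct LayerSlice {
    uint32_t start = 0;
    uint32_t end = 0;
    bool has_output = false;
    bool set = false;
};

// Parse a non-negative layer index that fits in uint32_t.
static bool parse_layer_index(const std::string &s, uint32_t *out) {
    assert(out != nullptr);
    if (s.empty() || s[0] < '0' || s[0] > '9') return false;
    errno = 0;
    char *end = nullptr;
    unsigned long long v = std::strtoull(s.c_str(), &end, 10);
    if (errno == ERANGE || *end != '\0' || v > UINT32_MAX) return false;
    *out = static_cast<uint32_t>(v);
    return true;
}

// Parse "START:END" or "START:output"; leaves *out untouched on failure.
static bool parse_layer_slice(const std::string &spec, LayerSlice *out) {
    assert(out != nullptr);
    auto colon = spec.find(':');
    if (colon == std::string::npos) return false;
    LayerSlice tmp;
    if (!parse_layer_index(spec.substr(0, colon), &tmp.start)) return false;
    const std::string rhs = spec.substr(colon + 1);
    if (rhs == "output") {
        tmp.has_output = true;
        tmp.end = tmp.start;  // mirrors the diff: has_output means "through final layer"
    } else if (!parse_layer_index(rhs, &tmp.end) || tmp.end < tmp.start) {
        return false;
    }
    tmp.set = true;
    *out = tmp;
    return true;
}
```

With this shape, `"0:19"` yields start 0 and end 19, `"20:output"` sets `has_output`, and `"5:3"`, `":4"`, `"4:"` and `"-1:4"` are all rejected.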

// When acting as a distributed coordinator, block until the worker route
// covers all layers (ds4_session_distributed_route_ready == 1) or the timeout
// elapses. Returns an empty string on success, or an error message to return
// to the client. No-op when not distributed.
//
// Takes the g_engine_mu lock by reference and RELEASES it during each poll
// sleep. The wait can span up to g_route_timeout_sec seconds while workers
// connect; holding g_engine_mu the whole time would block the Status/Health
// readiness probes (they also lock g_engine_mu), making LocalAI's loader treat
// a still-starting worker as hung.
static std::string wait_route_ready(std::unique_lock<std::mutex> &lock) {
    if (!g_distributed) return "";
    char err[256] = {0};
    const int deadline_polls = g_route_timeout_sec * 10; // 100ms per poll
    for (int i = 0; i <= deadline_polls; ++i) {
        int ready = ds4_session_distributed_route_ready(g_session, err, sizeof(err));
        if (ready == 1) return "";
        if (ready < 0) {
            return std::string("ds4 distributed route error: ") +
                   (err[0] ? err : "unknown");
        }
        // Release the lock while sleeping so Status/Health and other RPCs can
        // interleave during worker startup.
        lock.unlock();
        struct timespec ts = {0, 100L * 1000L * 1000L}; // 100ms
        nanosleep(&ts, nullptr);
        lock.lock();
        // A concurrent Free() may have torn down the engine while we slept.
        if (!g_engine || !g_session) {
            return "ds4: model unloaded while waiting for distributed route";
        }
    }
    return "ds4 distributed route incomplete: workers not connected (layers uncovered)";
}
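Review note: the pattern in `wait_route_ready` (hold the mutex while probing, drop it while sleeping, re-take it before looking at shared state) can be isolated from the engine. The sketch below is a generic illustration under stated assumptions, not the backend's code. `poll_until_ready` and its parameters are hypothetical names, the probe is injected as a callback in place of `ds4_session_distributed_route_ready`, and `std::this_thread::sleep_for` replaces `nanosleep`.

```cpp
#include <cassert>
#include <chrono>
#include <functional>
#include <mutex>
#include <string>
#include <thread>

// Poll `probe` (1 = ready, 0 = not yet, <0 = error) up to max_polls + 1 times.
// The caller must hold `lock` on entry. The lock is released only while
// sleeping and is held again on every return path.
static std::string poll_until_ready(std::unique_lock<std::mutex> &lock,
                                    const std::function<int()> &probe,
                                    int max_polls,
                                    std::chrono::milliseconds interval) {
    assert(lock.owns_lock());
    for (int i = 0; i <= max_polls; ++i) {
        int ready = probe();
        if (ready == 1) return "";
        if (ready < 0) return "probe reported an error";
        lock.unlock();                          // let other lock holders run
        std::this_thread::sleep_for(interval);
        lock.lock();                            // re-acquire before touching shared state
    }
    return "timed out waiting for readiness";
}
```

Taking `std::unique_lock` by reference is what makes the unlock/lock pair possible; a `std::lock_guard` cannot be released early. Anything guarded by the mutex may have changed during the sleep, so real code needs to re-validate it after `lock.lock()`, as the diff does for `g_engine` and `g_session`.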

static void append_token_text(ds4_engine *engine, int token, std::string &out) {
    size_t len = 0;
    const char *text = ds4_token_text(engine, token, &len);
@@ -377,6 +457,11 @@ public:
                      backend::Result *result) override {
        std::lock_guard<std::mutex> lock(g_engine_mu);

        // Reset distributed state so a model swap (a second LoadModel without
        // ds4_role) doesn't inherit a stale coordinator configuration.
        g_distributed = false;
        g_route_timeout_sec = 60;

        if (g_engine) {
            if (g_session) { ds4_session_free(g_session); g_session = nullptr; }
            ds4_engine_close(g_engine);
@@ -394,12 +479,23 @@ public:
        std::string mtp_path;
        int mtp_draft = 0;
        float mtp_margin = 3.0f;
        std::string ds4_role, ds4_layers, ds4_listen;
        for (const auto &opt : request->options()) {
            auto [k, v] = split_option(opt);
            if (k == "mtp_path") mtp_path = v;
            else if (k == "mtp_draft") mtp_draft = std::stoi(v);
            else if (k == "mtp_margin") mtp_margin = std::stof(v);
            else if (k == "kv_cache_dir") g_kv_cache_dir = v;
            else if (k == "ds4_role") ds4_role = v;
            else if (k == "ds4_layers") ds4_layers = v;
            else if (k == "ds4_listen") ds4_listen = v;
            else if (k == "ds4_route_timeout") {
                if (!parse_positive_int(v, &g_route_timeout_sec)) {
                    result->set_success(false);
                    result->set_message("ds4: ds4_route_timeout must be a positive integer");
                    return GStatus::OK;
                }
            }
        }

        g_kv_cache.SetDir(g_kv_cache_dir);
@@ -422,6 +518,49 @@ public:
        opt.backend = DS4_BACKEND_CUDA;
#endif

        // Coordinator wiring. 'ds4_role:coordinator' enables layer-split
        // distributed inference: this process listens on ds4_listen and owns
        // the ds4_layers slice; workers dial in (see `local-ai worker
        // ds4-distributed`). Absent ds4_role => unchanged single-node path.
        // Must be static: opt.distributed.listen_host is a const char* the
        // engine retains past this call, so it cannot point at a local that
        // goes out of scope (otherwise a future "simplify to local" refactor
        // reintroduces a dangling pointer).
        static std::string s_listen_host;
        if (ds4_role == "coordinator") {
            if (ds4_layers.empty() || ds4_listen.empty()) {
                result->set_success(false);
                result->set_message("ds4: ds4_role:coordinator requires ds4_layers and ds4_listen");
                return GStatus::OK;
            }
            // host:port for IPv4/hostname; IPv6 literals are unsupported (the
            // first colon would split inside the address).
            auto host_port = split_option(ds4_listen); // "host:port" -> {host, port}
            if (host_port.second.empty()) {
                result->set_success(false);
                result->set_message("ds4: ds4_listen must be host:port");
                return GStatus::OK;
            }
            int listen_port = 0;
            if (!parse_positive_int(host_port.second, &listen_port)) {
                result->set_success(false);
                result->set_message("ds4: ds4_listen port must be a positive integer");
                return GStatus::OK;
            }
            ds4_distributed_layers layers = {};
            if (!parse_layers_spec(ds4_layers, &layers)) {
                result->set_success(false);
                result->set_message("ds4: invalid ds4_layers (want START:END or START:output)");
                return GStatus::OK;
            }
            s_listen_host = host_port.first;
            opt.distributed.role = DS4_DISTRIBUTED_COORDINATOR;
            opt.distributed.layers = layers;
            opt.distributed.listen_host = s_listen_host.c_str();
            opt.distributed.listen_port = listen_port;
            g_distributed = true;
        }
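Review note: because options are split on the first colon, `ds4_listen:HOST:PORT` and `ds4_layers:START:END` each go through two splits, one for the option key and one for the value. The sketch below walks through that on made-up values; it is an illustration, not the backend's code. `split_first_colon` re-creates the visible tail of `split_option`, and returning two empty strings when there is no colon is an assumption based on its comment. `is_valid_tcp_port` is a hypothetical extra check: the diff accepts any positive `int` as the port and appears to leave range errors to the engine.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Split on the FIRST colon. The no-colon result ({"", ""}) is an assumption.
static std::pair<std::string, std::string> split_first_colon(const std::string &opt) {
    auto colon = opt.find(':');
    if (colon == std::string::npos) return {"", ""};
    return {opt.substr(0, colon), opt.substr(colon + 1)};
}

// Hypothetical extra check: a TCP port must lie in 1..65535.
static bool is_valid_tcp_port(int port) {
    return port >= 1 && port <= 65535;
}

// Usage sketch; the address and port are made-up example values.
static void option_examples() {
    // First split: option key vs. value.
    auto opt = split_first_colon("ds4_listen:0.0.0.0:9000");
    assert(opt.first == "ds4_listen" && opt.second == "0.0.0.0:9000");

    // Second split: host vs. port.
    auto host_port = split_first_colon(opt.second);
    assert(host_port.first == "0.0.0.0" && host_port.second == "9000");

    // The layer spec keeps its inner colon after the first split.
    auto layers = split_first_colon("ds4_layers:0:19");
    assert(layers.first == "ds4_layers" && layers.second == "0:19");

    // Why IPv6 literals are unsupported: the first colon is inside the address.
    auto v6 = split_first_colon("[::1]:9000");
    assert(v6.first == "[" && v6.second == ":1]:9000");

    assert(is_valid_tcp_port(9000) && !is_valid_tcp_port(70000));
}
```

The IPv6 case shows the failure mode the in-code comment warns about: the "port" becomes `:1]:9000`, which integer parsing then rejects.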

        int rc = ds4_engine_open(&g_engine, &opt);
        if (rc != 0 || !g_engine) {
            result->set_success(false);
@@ -458,10 +597,13 @@ public:

     GStatus Predict(ServerContext *, const backend::PredictOptions *request,
                     backend::Reply *reply) override {
-        std::lock_guard<std::mutex> lock(g_engine_mu);
+        std::unique_lock<std::mutex> lock(g_engine_mu);
         if (!g_engine || !g_session) {
             return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
         }
+        if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
+            return GStatus(StatusCode::UNAVAILABLE, route_err);
+        }
         ds4_tokens prompt = {};
         build_prompt(g_engine, request, &prompt);
         int n_predict = request->tokens() > 0 ? request->tokens() : 256;
@@ -554,10 +696,13 @@ public:

     GStatus PredictStream(ServerContext *, const backend::PredictOptions *request,
                           ServerWriter<backend::Reply> *writer) override {
-        std::lock_guard<std::mutex> lock(g_engine_mu);
+        std::unique_lock<std::mutex> lock(g_engine_mu);
         if (!g_engine || !g_session) {
             return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
         }
+        if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
+            return GStatus(StatusCode::UNAVAILABLE, route_err);
+        }
         ds4_tokens prompt = {};
         build_prompt(g_engine, request, &prompt);
         int n_predict = request->tokens() > 0 ? request->tokens() : 256;
@@ -5,7 +5,8 @@ REPO_ROOT="${CURDIR}/../../.."

 mkdir -p "$CURDIR/package/lib"
 cp -avf "$CURDIR/grpc-server" "$CURDIR/package/"
-cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
+cp -avf "$CURDIR/ds4-worker" "$CURDIR/package/"
+cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"

 UNAME_S=$(uname -s)
 if [ "$UNAME_S" = "Darwin" ]; then
126
backend/cpp/ds4/worker_main.c
Normal file
@@ -0,0 +1,126 @@
// ds4-worker: standalone distributed worker for the LocalAI ds4 backend.
//
// A ds4 distributed worker owns a slice of the model's transformer layers,
// dials the coordinator, and serves activations for its slice. It does NOT
// speak backend.proto - it speaks ds4's own TCP transport via ds4_dist_run().
// This binary is intentionally minimal (no HTTP/web/kvstore/linenoise): it
// only needs the engine objects + ds4_distributed.o, which the backend already
// builds. It is launched by `local-ai worker ds4-distributed`.
//
// Usage:
//   ds4-worker --role worker --model <gguf> --layers 20:output \
//       --coordinator <host> <port> [--cpu|--cuda|--metal] [-c CTX] [-t N]

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <signal.h>
#include <limits.h>

#include "ds4.h"
#include "ds4_distributed.h"

static const char *need_arg(int *i, int argc, char **argv, const char *flag) {
    if (*i + 1 >= argc) {
        fprintf(stderr, "ds4-worker: missing value for %s\n", flag);
        exit(2);
    }
    return argv[++(*i)];
}

static int parse_int_arg(const char *s, const char *flag) {
    char *end = NULL;
    long v = strtol(s, &end, 10);
    if (!s[0] || *end || v <= 0 || v > INT_MAX) {
        fprintf(stderr, "ds4-worker: invalid value for %s: %s\n", flag, s);
        exit(2);
    }
    return (int)v;
}

static ds4_backend default_backend(void) {
#if defined(DS4_NO_GPU)
    return DS4_BACKEND_CPU;
#elif defined(__APPLE__)
    return DS4_BACKEND_METAL;
#else
    return DS4_BACKEND_CUDA;
#endif
}

int main(int argc, char **argv) {
    signal(SIGPIPE, SIG_IGN);

    ds4_engine_options opt = {0};
    opt.backend = default_backend();
    int ctx_size = 32768;

    for (int i = 1; i < argc; i++) {
        const char *arg = argv[i];
        if (!strcmp(arg, "-h") || !strcmp(arg, "--help")) {
            fprintf(stdout, "ds4-worker: standalone ds4 distributed worker\n");
            ds4_dist_usage(stdout);
            fprintf(stdout, " -m, --model PATH model GGUF (the worker loads only its --layers slice)\n");
            fprintf(stdout, " -c, --ctx N context size (default 32768)\n");
            fprintf(stdout, " -t, --threads N CPU threads\n");
            fprintf(stdout, " --cpu|--cuda|--metal backend override\n");
            return 0;
        }

        char dist_err[256] = {0};
        ds4_dist_cli_parse_result dist_parse =
            ds4_dist_parse_cli_arg(arg, &i, argc, argv, &opt.distributed,
                                   dist_err, sizeof(dist_err));
        if (dist_parse == DS4_DIST_CLI_ERROR) {
            fprintf(stderr, "ds4-worker: %s\n",
                    dist_err[0] ? dist_err : "invalid distributed option");
            return 2;
        }
        if (dist_parse == DS4_DIST_CLI_MATCHED) continue;

        if (!strcmp(arg, "-m") || !strcmp(arg, "--model")) {
            opt.model_path = need_arg(&i, argc, argv, arg);
        } else if (!strcmp(arg, "-c") || !strcmp(arg, "--ctx")) {
            ctx_size = parse_int_arg(need_arg(&i, argc, argv, arg), arg);
        } else if (!strcmp(arg, "-t") || !strcmp(arg, "--threads")) {
            opt.n_threads = parse_int_arg(need_arg(&i, argc, argv, arg), arg);
        } else if (!strcmp(arg, "--cpu")) {
            opt.backend = DS4_BACKEND_CPU;
        } else if (!strcmp(arg, "--cuda")) {
            opt.backend = DS4_BACKEND_CUDA;
        } else if (!strcmp(arg, "--metal")) {
            opt.backend = DS4_BACKEND_METAL;
        } else {
            fprintf(stderr, "ds4-worker: unknown option: %s\n", arg);
            return 2;
        }
    }

    if (opt.distributed.role != DS4_DISTRIBUTED_WORKER) {
        fprintf(stderr, "ds4-worker: --role worker is required\n");
        return 2;
    }
    if (!opt.model_path) {
        fprintf(stderr, "ds4-worker: --model is required\n");
        return 2;
    }

    char prep_err[256] = {0};
    if (ds4_dist_prepare_engine_options(&opt.distributed, &opt,
                                        prep_err, sizeof(prep_err)) != 0) {
        fprintf(stderr, "ds4-worker: %s\n", prep_err);
        return 2;
    }

    ds4_engine *engine = NULL;
    if (ds4_engine_open(&engine, &opt) != 0 || !engine) {
        fprintf(stderr, "ds4-worker: failed to open engine\n");
        return 1;
    }

    ds4_dist_generation_options gen = {0};
    gen.ctx_size = ctx_size;
    int rc = ds4_dist_run(engine, &opt.distributed, &gen);
    ds4_engine_close(engine);
    return rc;
}
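Review note: the usage comment at the top of `worker_main.c` fixes the argument order a launcher has to produce, including `--coordinator` taking the host and port as two separate arguments. The sketch below shows one way to assemble that argument vector. It is not the `local-ai worker ds4-distributed` implementation, which is not part of this diff; `WorkerSpec`, `build_worker_argv` and all example values are hypothetical. Only the flag names come from the usage comment.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical launcher-side description of one worker process.
struct WorkerSpec {
    std::string model_path;
    std::string layers;      // e.g. "20:output"
    std::string coord_host;
    int coord_port = 0;
    int ctx_size = 0;        // 0 = omit -c and keep the worker default (32768)
    int threads = 0;         // 0 = omit -t
};

// Build argv in the order shown in the worker's usage comment.
static std::vector<std::string> build_worker_argv(const WorkerSpec &w) {
    assert(!w.model_path.empty() && !w.layers.empty() && !w.coord_host.empty());
    assert(w.coord_port > 0);
    std::vector<std::string> argv = {
        "ds4-worker", "--role", "worker",
        "--model", w.model_path,
        "--layers", w.layers,
        // Host and port are two arguments, not a single "host:port" string.
        "--coordinator", w.coord_host, std::to_string(w.coord_port),
    };
    if (w.ctx_size > 0) {
        argv.push_back("-c");
        argv.push_back(std::to_string(w.ctx_size));
    }
    if (w.threads > 0) {
        argv.push_back("-t");
        argv.push_back(std::to_string(w.threads));
    }
    return argv;
}
```

Omitting `-c` and `-t` when unset matters because `parse_int_arg` in the worker exits with status 2 on a non-positive value.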
@@ -1,5 +1,5 @@

-IK_LLAMA_VERSION?=8960c5ba5ee9db30ba838304373aa4dbec9f7cbd
+IK_LLAMA_VERSION?=3f40e73c367ad9f0c1b1819f28c7348c26aa340d
 LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp

 CMAKE_ARGS?=
@@ -1,5 +1,5 @@

-LLAMA_VERSION?=22d66b567eef11cf2e9832f04db64ee0323a0fd0
+LLAMA_VERSION?=5dcb71166686799f0d873eab7386234302d05ecf
 LLAMA_REPO?=https://github.com/ggerganov/llama.cpp

 CMAKE_ARGS?=
@@ -2204,7 +2204,15 @@ public:
             // content element — attaching to both would duplicate the first
             // token since oaicompat_msg_diffs is the same for both.
             json first_res_json = first_result->to_json();
-            if (first_res_json.is_array()) {
+            // Upstream llama.cpp (ggml-org/llama.cpp#23884) now emits an initial
+            // "begin" partial whose to_json() returns null, used only to signal the
+            // HTTP layer to flush 200 status headers before any token. gRPC has no
+            // such concept, so there is nothing to emit — the real tokens arrive in
+            // the loop below. Feeding this null into build_reply_from_json would
+            // throw (uncaught) and surface as a generic RPC error.
+            if (first_res_json.is_null()) {
+                // skip the begin-of-stream marker
+            } else if (first_res_json.is_array()) {
                 for (const auto & res : first_res_json) {
                     auto reply = build_reply_from_json(res, first_result.get());
                     // Skip chat deltas for role-init elements (have "role" in
@@ -2234,7 +2242,10 @@ public:
                 }

                 json res_json = result->to_json();
-                if (res_json.is_array()) {
+                if (res_json.is_null()) {
+                    // begin-of-stream marker (see note above) — nothing to emit
+                    continue;
+                } else if (res_json.is_array()) {
                     for (const auto & res : res_json) {
                         auto reply = build_reply_from_json(res, result.get());
                         bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
5
backend/go/crispasr/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
sources
build*
libgocrispasr*.so
crispasr
package
30
backend/go/crispasr/CMakeLists.txt
Normal file
@@ -0,0 +1,30 @@
cmake_minimum_required(VERSION 3.12)
project(gocrispasr LANGUAGES C CXX)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

add_subdirectory(./sources/CrispASR)

add_library(gocrispasr MODULE cpp/crispasr_shim.cpp)
target_include_directories(gocrispasr PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/sources/CrispASR/include
    ${CMAKE_CURRENT_SOURCE_DIR}/sources/CrispASR/ggml/include)
# Link the same backend set as crispasr-cli (examples/cli/CMakeLists.txt) so
# the session API can dispatch to every compiled-in architecture, not just
# whisper. crispasr is the referencer; the backend static libs supply the
# per-architecture symbols; ggml is the math/runtime base.
target_link_libraries(gocrispasr PRIVATE
    crispasr
    parakeet canary canary_ctc cohere granite_speech granite_nle
    voxtral voxtral4b qwen3_asr qwen3_tts orpheus chatterbox indextts
    kokoro voxcpm2_tts m2m100 t5_translate wav2vec2-ggml vibevoice
    silero-lid pyannote-seg funasr paraformer sensevoice
    crisp_audio
    ggml)

if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
    target_link_libraries(gocrispasr PRIVATE stdc++fs)
endif()

set_property(TARGET gocrispasr PROPERTY CXX_STANDARD 17)
set_target_properties(gocrispasr PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
132
backend/go/crispasr/Makefile
Normal file
@@ -0,0 +1,132 @@
CMAKE_ARGS?=
BUILD_TYPE?=
NATIVE?=false

GOCMD?=go
GO_TAGS?=
JOBS?=$(shell nproc --ignore=1)

# CrispASR version (release tag)
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
CRISPASR_VERSION?=05e60432bcb5bc2113f8c395a41e86497c11504a
SO_TARGET?=libgocrispasr.so

CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
# Keep the build lean: no tests/examples/server/SDL2/curl/ffmpeg (the FROM scratch
# image cannot satisfy those runtime deps). All ASR/TTS model backends stay enabled.
CMAKE_ARGS+=-DCRISPASR_BUILD_TESTS=OFF -DCRISPASR_BUILD_EXAMPLES=OFF -DCRISPASR_BUILD_SERVER=OFF
CMAKE_ARGS+=-DCRISPASR_SDL2=OFF -DCRISPASR_CURL=OFF -DCRISPASR_FFMPEG=OFF

ifeq ($(NATIVE),false)
	CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif

ifeq ($(BUILD_TYPE),cublas)
	CMAKE_ARGS+=-DGGML_CUDA=ON
else ifeq ($(BUILD_TYPE),openblas)
	CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
else ifeq ($(BUILD_TYPE),clblas)
	CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
else ifeq ($(BUILD_TYPE),hipblas)
	CMAKE_ARGS+=-DGGML_HIPBLAS=ON
else ifeq ($(BUILD_TYPE),vulkan)
	CMAKE_ARGS+=-DGGML_VULKAN=ON
else ifeq ($(OS),Darwin)
	ifneq ($(BUILD_TYPE),metal)
		CMAKE_ARGS+=-DGGML_METAL=OFF
	else
		CMAKE_ARGS+=-DGGML_METAL=ON
		CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
	endif
endif

ifeq ($(BUILD_TYPE),sycl_f16)
	CMAKE_ARGS+=-DGGML_SYCL=ON \
		-DCMAKE_C_COMPILER=icx \
		-DCMAKE_CXX_COMPILER=icpx \
		-DGGML_SYCL_F16=ON
endif

ifeq ($(BUILD_TYPE),sycl_f32)
	CMAKE_ARGS+=-DGGML_SYCL=ON \
		-DCMAKE_C_COMPILER=icx \
		-DCMAKE_CXX_COMPILER=icpx
endif

sources/CrispASR:
	mkdir -p sources/CrispASR
	cd sources/CrispASR && \
	git init && \
	git remote add origin $(CRISPASR_REPO) && \
	git fetch origin && \
	git checkout $(CRISPASR_VERSION) && \
	git submodule update --init --recursive --depth 1 --single-branch
	# CrispASR's src/CMakeLists.txt locates its vendored llama.cpp
	# (crispasr-llama-core, used by the chat C-ABI) via ${CMAKE_SOURCE_DIR},
	# which assumes CrispASR is the top-level CMake project. We add_subdirectory
	# it, so ${CMAKE_SOURCE_DIR} is THIS backend dir and the talk-llama sources
	# aren't found. Rewrite to ${PROJECT_SOURCE_DIR} (the crispasr project root),
	# which is correct both standalone and as a subproject. Idempotent.
	sed -i 's#\$${CMAKE_SOURCE_DIR}/examples/talk-llama#\$${PROJECT_SOURCE_DIR}/examples/talk-llama#' sources/CrispASR/src/CMakeLists.txt

# Detect OS
UNAME_S := $(shell uname -s)

ifeq ($(UNAME_S),Linux)
	VARIANT_TARGETS = libgocrispasr-avx.so libgocrispasr-avx2.so libgocrispasr-avx512.so libgocrispasr-fallback.so
else
	VARIANT_TARGETS = libgocrispasr-fallback.so
endif

crispasr: main.go gocrispasr.go $(VARIANT_TARGETS)
	CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o crispasr ./

package: crispasr
	bash package.sh

build: package

clean: purge
	rm -rf libgocrispasr*.so package sources/CrispASR crispasr

purge:
	rm -rf build*

ifeq ($(UNAME_S),Linux)
libgocrispasr-avx.so: sources/CrispASR
	$(MAKE) purge
	$(info ${GREEN}I crispasr build info:avx${RESET})
	SO_TARGET=libgocrispasr-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgocrispasr-custom
	rm -rfv build*

libgocrispasr-avx2.so: sources/CrispASR
	$(MAKE) purge
	$(info ${GREEN}I crispasr build info:avx2${RESET})
	SO_TARGET=libgocrispasr-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgocrispasr-custom
	rm -rfv build*

libgocrispasr-avx512.so: sources/CrispASR
	$(MAKE) purge
	$(info ${GREEN}I crispasr build info:avx512${RESET})
	SO_TARGET=libgocrispasr-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgocrispasr-custom
	rm -rfv build*
endif

libgocrispasr-fallback.so: sources/CrispASR
	$(MAKE) purge
	$(info ${GREEN}I crispasr build info:fallback${RESET})
	SO_TARGET=libgocrispasr-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgocrispasr-custom
	rm -rfv build*

libgocrispasr-custom: CMakeLists.txt cpp/crispasr_shim.cpp cpp/crispasr_shim.h
	mkdir -p build-$(SO_TARGET) && \
	cd build-$(SO_TARGET) && \
	cmake .. $(CMAKE_ARGS) && \
	cmake --build . --config Release -j$(JOBS) && \
	cd .. && \
	mv build-$(SO_TARGET)/libgocrispasr.so ./$(SO_TARGET)

test: crispasr
	CGO_ENABLED=0 $(GOCMD) test -v ./...

all: crispasr package
253
backend/go/crispasr/cpp/crispasr_shim.cpp
Normal file
@@ -0,0 +1,253 @@
#include "crispasr_shim.h"
#include "ggml-backend.h"
#include "crispasr.h"
#include <atomic>
#include <vector>

// Opaque session types. crispasr.h declares `struct crispasr_session;` but not
// the result type nor the open/transcribe/result accessors — those are
// CA_EXPORT extern "C" symbols in src/crispasr_c_api.cpp, so we forward-declare
// exactly the ones we use. Signatures verified against
// sources/CrispASR/src/crispasr_c_api.cpp.
struct crispasr_session_result;
extern "C" {
crispasr_session *crispasr_session_open(const char *model_path, int n_threads);
crispasr_session *crispasr_session_open_explicit(const char *model_path,
                                                 const char *backend_name,
                                                 int n_threads);
int crispasr_session_set_codec_path(crispasr_session *s, const char *path);
void crispasr_session_close(crispasr_session *s);
const char *crispasr_session_backend(crispasr_session *s);
int crispasr_session_set_translate(crispasr_session *s, int enable);
crispasr_session_result *crispasr_session_transcribe_lang(
    crispasr_session *s, const float *pcm, int n_samples, const char *language);
int crispasr_session_result_n_segments(crispasr_session_result *r);
const char *crispasr_session_result_segment_text(crispasr_session_result *r,
                                                 int i);
int64_t crispasr_session_result_segment_t0(crispasr_session_result *r, int i);
int64_t crispasr_session_result_segment_t1(crispasr_session_result *r, int i);
void crispasr_session_result_free(crispasr_session_result *r);
float *crispasr_session_synthesize(crispasr_session *s, const char *text,
                                   int *out_n_samples);
void crispasr_pcm_free(float *pcm);
int crispasr_session_set_speaker_name(crispasr_session *s, const char *name);
int crispasr_session_set_voice(crispasr_session *s, const char *path,
                               const char *ref_text_or_null);
}

static crispasr_session *g_session = nullptr;
static crispasr_session_result *g_result = nullptr;

static struct whisper_vad_context *vctx;
static std::vector<float> flat_segs;

static std::atomic<int> g_abort{0};

extern "C" void set_abort(int v) {
  g_abort.store(v, std::memory_order_relaxed);
}

static void ggml_log_cb(enum ggml_log_level level, const char *log,
                        void *data) {
  const char *level_str;

  if (!log) {
    return;
  }

  switch (level) {
  case GGML_LOG_LEVEL_DEBUG:
    level_str = "DEBUG";
    break;
  case GGML_LOG_LEVEL_INFO:
    level_str = "INFO";
    break;
  case GGML_LOG_LEVEL_WARN:
    level_str = "WARN";
    break;
  case GGML_LOG_LEVEL_ERROR:
    level_str = "ERROR";
    break;
  default: /* Potential future-proofing */
    level_str = "?????";
    break;
  }

  fprintf(stderr, "[%-5s] ", level_str);
  fputs(log, stderr);
  fflush(stderr);
}

int load_model(const char *const model_path, int threads,
               const char *backend_name) {
  whisper_log_set(ggml_log_cb, nullptr);
  ggml_backend_load_all();

  if (backend_name && *backend_name) {
    g_session =
        crispasr_session_open_explicit(model_path, backend_name, threads);
  } else {
    g_session = crispasr_session_open(model_path, threads);
  }
  if (g_session == nullptr) {
    fprintf(stderr, "error: failed to open CrispASR session for model\n");
    return 1;
  }

  fprintf(stderr, "info: CrispASR backend selected: %s\n",
          crispasr_session_backend(g_session));
  return 0;
}

// set_codec_path forwards a companion file (qwen3-tts codec, orpheus SNAC,
// chatterbox s3gen, or mimo-asr tokenizer) to the active session. Returns 0 on
// success or when the active backend needs no companion, negative on failure,
// and -1 when no session is open.
int set_codec_path(const char *path) {
  return g_session ? crispasr_session_set_codec_path(g_session, path) : -1;
}

int load_model_vad(const char *const model_path) {
  whisper_log_set(ggml_log_cb, nullptr);
  ggml_backend_load_all();

  struct whisper_vad_context_params vcparams =
      whisper_vad_default_context_params();

  // XXX: Overridden to false in upstream due to performance?
  // vcparams.use_gpu = true;

  vctx = whisper_vad_init_from_file_with_params(model_path, vcparams);
  if (vctx == nullptr) {
    fprintf(stderr, "error: Failed to init model as VAD\n");
    return 1;
  }

  return 0;
}

int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
        size_t *segs_out_len) {
  if (!whisper_vad_detect_speech(vctx, pcmf32, pcmf32_len)) {
    fprintf(stderr, "error: failed to detect speech\n");
    return 1;
  }

  struct whisper_vad_params params = whisper_vad_default_params();
  struct whisper_vad_segments *segs =
      whisper_vad_segments_from_probs(vctx, params);
  size_t segn = whisper_vad_segments_n_segments(segs);

  // fprintf(stderr, "Got segments %zd\n", segn);

  flat_segs.clear();

  for (int i = 0; i < segn; i++) {
    flat_segs.push_back(whisper_vad_segments_get_segment_t0(segs, i));
    flat_segs.push_back(whisper_vad_segments_get_segment_t1(segs, i));
  }

  // fprintf(stderr, "setting out variables: %p=%p -> %p, %p=%zx -> %zx\n",
  // segs_out, *segs_out, flat_segs.data(), segs_out_len, *segs_out_len,
  // flat_segs.size());
  *segs_out = flat_segs.data();
  *segs_out_len = flat_segs.size();

  // fprintf(stderr, "freeing segs\n");
  whisper_vad_free_segments(segs);

  // fprintf(stderr, "returning\n");
  return 0;
}
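Review note: `vad()` hands back a pointer into the static `flat_segs` vector, laid out as `[t0, t1, t0, t1, ...]`, so the buffer is only valid until the next `vad()` call clears and refills it. The sketch below shows one way a caller could copy that buffer into pairs before calling again. It is an illustration with a hypothetical helper name, not how the Go side of this backend consumes the buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Copy a flat [t0, t1, t0, t1, ...] buffer into (start, end) pairs.
// Copying matters because the shim reuses one static buffer across calls.
static std::vector<std::pair<float, float>> unflatten_segments(const float *flat,
                                                               std::size_t n) {
    assert(n % 2 == 0);                 // vad() always pushes t0 and t1 together
    assert(flat != nullptr || n == 0);
    std::vector<std::pair<float, float>> out;
    out.reserve(n / 2);
    for (std::size_t i = 0; i + 1 < n; i += 2) {
        out.emplace_back(flat[i], flat[i + 1]);
    }
    return out;
}
```

Note that `segs_out_len` is the number of floats, not the number of segments, so the segment count is half of it.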

// threads, diarize and prompt are accepted for Go-side API parity but unused
// in Phase 1: the thread count is fixed at session open, and diarization and
// the initial prompt are separate CrispASR features not yet wired through the
// session ASR path.
int transcribe(uint32_t threads, char *lang, bool translate, bool diarize,
               float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
               char *prompt) {
  (void)threads;
  (void)diarize;
  (void)prompt;

  if (!g_session) {
    return 1;
  }

  // Reset stale abort flag from any prior cancelled call. set_abort remains
  // best-effort: the session transcribe call is blocking and exposes no abort
  // hook, so a mid-decode abort cannot interrupt it.
  g_abort.store(0, std::memory_order_relaxed);

  crispasr_session_set_translate(g_session, translate ? 1 : 0);

  if (g_result) {
    crispasr_session_result_free(g_result);
    g_result = nullptr;
  }

  const char *language = (lang && *lang) ? lang : nullptr;
  g_result = crispasr_session_transcribe_lang(g_session, pcmf32, (int)pcmf32_len,
                                              language);
  if (!g_result) {
    fprintf(stderr, "error: transcription failed\n");
    return 1;
  }

  *segs_out_len = crispasr_session_result_n_segments(g_result);
  return 0;
}

const char *get_segment_text(int i) {
  if (!g_result) {
    return "";
  }
  return crispasr_session_result_segment_text(g_result, i);
}

int64_t get_segment_t0(int i) {
  if (!g_result) {
    return 0;
  }
  return crispasr_session_result_segment_t0(g_result, i);
}

int64_t get_segment_t1(int i) {
  if (!g_result) {
    return 0;
  }
  return crispasr_session_result_segment_t1(g_result, i);
}

const char *get_backend(void) {
  return g_session ? crispasr_session_backend(g_session) : "";
}

// TTS uses the already-open session (crispasr_session_open auto-detects a TTS
// model). Output is 24 kHz mono float PCM (upstream CrispASR convention),
// malloc'd by the C API; the caller must release it via tts_free.
float *tts_synthesize(const char *text, int *out_n_samples) {
  if (out_n_samples) *out_n_samples = 0;
  if (!g_session || !text) return nullptr;
  return crispasr_session_synthesize(g_session, text, out_n_samples);
}

void tts_free(float *pcm) {
  if (pcm) crispasr_pcm_free(pcm);
}
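Review note: `tts_synthesize` returns 24 kHz mono float PCM, which usually has to be converted to 16-bit integers before it can be written as a WAV file. The sketch below shows a conventional conversion with clamping. It is an illustration only: the helper names are hypothetical, the assumption that samples are nominally in [-1, 1] is not stated in this diff, and this is not the conversion the Go backend performs.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int kTtsSampleRate = 24000;  // per the tts_synthesize comment

// Convert float samples (assumed nominally in [-1, 1]) to 16-bit PCM,
// clamping out-of-range values instead of letting them wrap.
static std::vector<int16_t> float_to_pcm16(const float *pcm, std::size_t n) {
    assert(pcm != nullptr || n == 0);
    std::vector<int16_t> out(n);
    for (std::size_t i = 0; i < n; ++i) {
        float s = pcm[i];
        if (std::isnan(s)) s = 0.0f;  // treat NaN as silence
        if (s > 1.0f) s = 1.0f;
        if (s < -1.0f) s = -1.0f;
        out[i] = static_cast<int16_t>(std::lround(s * 32767.0f));
    }
    return out;
}

// Duration in seconds of n mono samples at the TTS sample rate.
static double pcm_duration_seconds(std::size_t n) {
    return static_cast<double>(n) / kTtsSampleRate;
}
```

A caller would copy or convert the samples first and only then hand the original pointer to `tts_free`, since the buffer is owned by the C API until released.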

int tts_set_voice(const char *name) {
  if (!g_session || !name || !*name) return 0;
  return crispasr_session_set_speaker_name(g_session, name);
}

// tts_set_voice_file loads a voice from a file: a .gguf path selects a voice
// pack, a .wav path with a non-empty ref_text performs zero-shot voice cloning
// (the C API returns -2 when ref_text is required but missing). Returns -1 when
// no session is open or path is null.
int tts_set_voice_file(const char *path, const char *ref_text) {
  if (!g_session || !path) return -1;
  const char *ref = (ref_text && *ref_text) ? ref_text : nullptr;
  return crispasr_session_set_voice(g_session, path, ref);
}
23
backend/go/crispasr/cpp/crispasr_shim.h
Normal file
@@ -0,0 +1,23 @@
#include <cstddef>
#include <cstdint>

extern "C" {
int load_model(const char *const model_path, int threads,
               const char *backend_name);
int set_codec_path(const char *path);
int load_model_vad(const char *const model_path);
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
        size_t *segs_out_len);
int transcribe(uint32_t threads, char *lang, bool translate, bool diarize,
               float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
               char *prompt);
const char *get_segment_text(int i);
int64_t get_segment_t0(int i);
int64_t get_segment_t1(int i);
const char *get_backend(void);
void set_abort(int v);
float *tts_synthesize(const char *text, int *out_n_samples); // 24kHz mono float, malloc'd; NULL on failure
void tts_free(float *pcm);
int tts_set_voice(const char *name); // best-effort speaker selection; 0 ok
int tts_set_voice_file(const char *path, const char *ref_text); // load voice pack (.gguf) or zero-shot clone (.wav + ref_text)
}
497
backend/go/crispasr/gocrispasr.go
Normal file
@@ -0,0 +1,497 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var (
|
||||
CppLoadModel func(modelPath string, threads int, backendName string) int
|
||||
CppSetCodecPath func(path string) int
|
||||
CppLoadModelVAD func(modelPath string) int
|
||||
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
|
||||
CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer, prompt string) int
|
||||
CppGetSegmentText func(i int) string
|
||||
CppGetSegmentStart func(i int) int64
|
||||
CppGetSegmentEnd func(i int) int64
|
||||
CppGetBackend func() string
|
||||
CppSetAbort func(v int)
|
||||
CppTTSSynthesize func(text string, outNSamples unsafe.Pointer) uintptr
|
||||
CppTTSFree func(ptr uintptr)
|
||||
CppTTSSetVoice func(name string) int
|
||||
CppTTSSetVoiceFile func(path string, refText string) int
|
||||
)
|
||||
|
||||
type CrispASR struct {
|
||||
base.SingleThread
|
||||
}
|
||||
|
||||
// splitOption splits a "prefix:value" model option into its key and value,
|
||||
// matching the convention used by other backends (see sherpa-onnx). It returns
|
||||
// ok=false when the option carries no ':' separator.
|
||||
func splitOption(oo string) (key, value string, ok bool) {
|
||||
parts := strings.SplitN(oo, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", "", false
|
||||
}
|
||||
return parts[0], parts[1], true
|
||||
}
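A minimal, standalone sketch of the `prefix:value` convention described above. It repeats the helper so the snippet runs on its own. The option strings in `main` are made-up illustrations, not a list of options the backend is known to accept. The point it shows is that only the first `:` separates key from value, so a value may itself contain colons.

```go
package main

import (
	"fmt"
	"strings"
)

// splitOption splits on the first ':' only, so the value keeps any later colons.
func splitOption(opt string) (key, value string, ok bool) {
	parts := strings.SplitN(opt, ":", 2)
	if len(parts) != 2 {
		return "", "", false
	}
	return parts[0], parts[1], true
}

func main() {
	// Illustrative inputs only (assumed for this example).
	for _, opt := range []string{"codec:tokenizer.gguf", "voice_text:Hello: world", "vad_only"} {
		key, value, ok := splitOption(opt)
		fmt.Printf("%q -> key=%q value=%q ok=%v\n", opt, key, value, ok)
	}
}
```

A bare flag such as `vad_only` has no separator and therefore comes back with `ok=false`, which is why `Load` checks for it before calling the helper.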
|
||||
|
||||
func (w *CrispASR) Load(opts *pb.ModelOptions) error {
|
||||
vadOnly := false
|
||||
backendName := ""
|
||||
codecPath := ""
|
||||
speakerName := ""
|
||||
voicePath := ""
|
||||
voiceRefText := ""
|
||||
|
||||
for _, oo := range opts.Options {
|
||||
if oo == "vad_only" {
|
||||
vadOnly = true
|
||||
continue
|
||||
}
|
||||
switch key, value, ok := splitOption(oo); {
|
||||
case ok && key == "backend":
|
||||
backendName = value
|
||||
case ok && key == "codec":
|
||||
codecPath = value
|
||||
case ok && key == "speaker":
|
||||
speakerName = value
|
||||
case ok && key == "voice":
|
||||
voicePath = value
|
||||
case ok && key == "voice_text":
|
||||
voiceRefText = value
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||
}
|
||||
}
|
||||
|
||||
if vadOnly {
|
||||
if ret := CppLoadModelVAD(opts.ModelFile); ret != 0 {
|
||||
return fmt.Errorf("Failed to load CrispASR VAD model")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve a relative companion path against the model directory so a config
|
||||
// can reference a sibling codec/tokenizer file by name alone.
|
||||
if codecPath != "" && !filepath.IsAbs(codecPath) {
|
||||
codecPath = filepath.Join(filepath.Dir(opts.ModelFile), codecPath)
|
||||
}
|
||||
|
||||
// A voice file (.gguf pack or .wav prompt) is resolved against the model
|
||||
// directory just like the codec, so a config can reference a sibling file.
|
||||
if voicePath != "" && !filepath.IsAbs(voicePath) {
|
||||
voicePath = filepath.Join(filepath.Dir(opts.ModelFile), voicePath)
|
||||
}
|
||||
|
||||
if ret := CppLoadModel(opts.ModelFile, int(opts.Threads), backendName); ret != 0 {
|
||||
return fmt.Errorf("Failed to load CrispASR transcription model")
|
||||
}
|
||||
|
||||
// Load the companion file (codec/tokenizer/s3gen) after the session is open.
|
||||
// rc==0 means success or "not applicable" for the active backend; only a
|
||||
// negative code is fatal.
|
||||
if codecPath != "" {
|
||||
if rc := CppSetCodecPath(codecPath); rc < 0 {
|
||||
return fmt.Errorf("crispasr: failed to load companion file %q (rc=%d)", codecPath, rc)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "CrispASR companion file loaded: %s\n", codecPath)
|
||||
}
|
||||
|
||||
// Apply the Load-time default voice. A baked speaker (speaker:) is selected
|
||||
// by name and is best-effort: a backend that can't honor it is logged, not
|
||||
// fatal. A voice file (voice:) is a hard requirement once configured, so a
|
||||
// negative rc fails Load.
|
||||
if speakerName != "" {
|
||||
if rc := CppTTSSetVoice(speakerName); rc != 0 {
|
||||
fmt.Fprintf(os.Stderr, "crispasr: speaker %q not applied (rc=%d)\n", speakerName, rc)
|
||||
}
|
||||
}
|
||||
if voicePath != "" {
|
||||
if rc := CppTTSSetVoiceFile(voicePath, voiceRefText); rc < 0 {
|
||||
return fmt.Errorf("crispasr: failed to load voice %q (rc=%d)", voicePath, rc)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "CrispASR voice loaded: %s\n", voicePath)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "CrispASR backend selected: %s\n", CppGetBackend())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *CrispASR) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
|
||||
audio := req.Audio
|
||||
// We expect 0xdeadbeef to be overwritten and if we see it in a stack trace we know it wasn't
|
||||
segsPtr, segsLen := uintptr(0xdeadbeef), uintptr(0xdeadbeef)
|
||||
segsPtrPtr, segsLenPtr := unsafe.Pointer(&segsPtr), unsafe.Pointer(&segsLen)
|
||||
|
||||
if ret := CppVAD(audio, uintptr(len(audio)), segsPtrPtr, segsLenPtr); ret != 0 {
|
||||
return pb.VADResponse{}, fmt.Errorf("Failed VAD")
|
||||
}
|
||||
|
||||
// Happens when CPP vector has not had any elements pushed to it
|
||||
if segsPtr == 0 {
|
||||
return pb.VADResponse{
|
||||
Segments: []*pb.VADSegment{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// unsafeptr warning is caused by segsPtr being on the stack and therefore being subject to stack copying AFAICT
|
||||
// however the stack shouldn't have grown between setting segsPtr and now, also the memory pointed to is allocated by C++
|
||||
segs := unsafe.Slice((*float32)(unsafe.Pointer(segsPtr)), segsLen) //nolint:govet // segsPtr addresses C++-owned heap memory passed back through the cgo-free purego boundary; the uintptr->Pointer round-trip is intentional and the buffer outlives this read.
|
||||
|
||||
vadSegments := []*pb.VADSegment{}
|
||||
for i := range len(segs) >> 1 {
|
||||
s := segs[2*i] / 100
|
||||
t := segs[2*i+1] / 100
|
||||
vadSegments = append(vadSegments, &pb.VADSegment{
|
||||
Start: s,
|
||||
End: t,
|
||||
})
|
||||
}
|
||||
|
||||
return pb.VADResponse{
|
||||
Segments: vadSegments,
|
||||
}, nil
|
||||
}
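The segment decoding at the end of `VAD` can be illustrated without the C++ library. This sketch assumes what the loop above implies: the buffer is a flat list of `start, end` pairs, and each value is divided by 100 (presumably centiseconds to seconds, which is an inference rather than something the source states). `segment` and `pairsToSegments` are names invented for the example.

```go
package main

import "fmt"

// segment is a stand-in for pb.VADSegment in this standalone example.
type segment struct {
	Start, End float32
}

// pairsToSegments (hypothetical name) turns a flat [start0, end0, start1, end1, ...]
// buffer into segments, dividing each value by 100 as the backend does.
// A trailing unpaired value is ignored, matching the len(segs)>>1 loop bound.
func pairsToSegments(flat []float32) []segment {
	segs := make([]segment, 0, len(flat)/2)
	for i := 0; i+1 < len(flat); i += 2 {
		segs = append(segs, segment{Start: flat[i] / 100, End: flat[i+1] / 100})
	}
	return segs
}

func main() {
	fmt.Println(pairsToSegments([]float32{50, 150, 300, 425}))
}
```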
|
||||
|
||||
func (w *CrispASR) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "crispasr")
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
convertedPath := filepath.Join(dir, "converted.wav")
|
||||
|
||||
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
|
||||
// Watcher: flips the C-side abort flag when ctx is cancelled. The
|
||||
// goroutine is joined synchronously (close(done) signals it to exit,
|
||||
// wg.Wait() blocks until it has) so a late CppSetAbort(1) cannot fire
|
||||
// after the function returns and corrupt the next transcription call.
|
||||
done := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
CppSetAbort(1)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||
if ret == 2 {
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if ret != 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
segments := []*pb.TranscriptSegment{}
|
||||
text := ""
|
||||
for i := range int(segsLen) {
|
||||
// segment start/end conversion factor taken from https://github.com/ggml-org/whisper.cpp/blob/master/examples/cli/cli.cpp#L895
|
||||
s := CppGetSegmentStart(i) * (10000000)
|
||||
t := CppGetSegmentEnd(i) * (10000000)
|
||||
// The session result can emit bytes that aren't valid UTF-8 (e.g. a
|
||||
// multibyte codepoint split across token boundaries); protobuf string
|
||||
// fields reject those at marshal time. Scrub before the value escapes
|
||||
// cgo. The session result is segment+word based and exposes no token
|
||||
// IDs, so Tokens is left empty.
|
||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "\uFFFD")
|
||||
|
||||
segment := &pb.TranscriptSegment{
|
||||
Id: int32(i),
|
||||
Text: txt,
|
||||
Start: s, End: t,
|
||||
}
|
||||
|
||||
segments = append(segments, segment)
|
||||
|
||||
text += " " + strings.TrimSpace(txt)
|
||||
}
|
||||
|
||||
return pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: strings.TrimSpace(text),
|
||||
Language: opts.Language,
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AudioTranscriptionStream runs the session transcribe to completion and then
|
||||
// emits one delta per non-empty segment, followed by a final TranscriptResult.
|
||||
// Progressive/real-time streaming isn't available via the session API (there
|
||||
// is no per-decode callback), so deltas are emitted per-segment after the
|
||||
// blocking decode returns rather than as segments are produced. The offline
|
||||
// AudioTranscription is unchanged; both paths share the session and the
|
||||
// SingleThread concurrency model.
|
||||
func (w *CrispASR) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
defer close(results)
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "crispasr")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
convertedPath := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
|
||||
// Same abort-watcher pattern as AudioTranscription. Joined synchronously
|
||||
// so a late CppSetAbort(1) cannot fire after this function returns.
|
||||
// Best-effort only: the session transcribe is blocking with no abort hook.
|
||||
done := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
CppSetAbort(1)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||
if ret == 2 {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
// Walk the segments once: emit a delta per non-empty segment and build the
|
||||
// final TranscriptResult.Segments alongside. The first delta has no leading
|
||||
// space and subsequent ones are prefixed with a single space, so
|
||||
// concat(deltas) == final.Text exactly, matching the e2e contract.
|
||||
segments := []*pb.TranscriptSegment{}
|
||||
var assembled strings.Builder
|
||||
for i := range int(segsLen) {
|
||||
s := CppGetSegmentStart(i) * 10000000
|
||||
t := CppGetSegmentEnd(i) * 10000000
|
||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "\uFFFD")
|
||||
segments = append(segments, &pb.TranscriptSegment{
|
||||
Id: int32(i),
|
||||
Text: txt,
|
||||
Start: s, End: t,
|
||||
})
|
||||
|
||||
trimmed := strings.TrimSpace(txt)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
var delta string
|
||||
if assembled.Len() == 0 {
|
||||
delta = trimmed
|
||||
} else {
|
||||
delta = " " + trimmed
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{Delta: delta}
|
||||
assembled.WriteString(delta)
|
||||
}
|
||||
|
||||
final := &pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: assembled.String(),
|
||||
Language: opts.Language,
|
||||
Duration: duration,
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{FinalResult: final}
|
||||
return nil
|
||||
}
|
||||
|
||||
// synthesize returns 24 kHz mono float32 PCM for text via the open session.
|
||||
func (w *CrispASR) synthesize(text string) ([]float32, error) {
|
||||
if text == "" {
|
||||
return nil, fmt.Errorf("crispasr: TTS requires non-empty text")
|
||||
}
|
||||
var n int32
|
||||
ptr := CppTTSSynthesize(text, unsafe.Pointer(&n))
|
||||
if ptr == 0 || n <= 0 {
|
||||
return nil, fmt.Errorf("crispasr: synthesis failed (the loaded model may not be a supported TTS backend, or needs extra config e.g. orpheus SNAC codec)")
|
||||
}
|
||||
defer CppTTSFree(ptr)
|
||||
src := unsafe.Slice((*float32)(unsafe.Pointer(ptr)), int(n)) //nolint:govet // ptr addresses C-allocated PCM returned across the purego boundary; copied out immediately below, before tts_free.
|
||||
out := make([]float32, int(n)) // copy out of C memory before free
|
||||
copy(out, src)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// setVoice applies a per-call speaker/voice override (best effort). CrispASR
|
||||
// returns a negative code when the active backend can't honor the name; we log
|
||||
// it rather than fail, so an unknown voice falls back to the default speaker.
|
||||
func setVoice(voice string) {
|
||||
v := strings.TrimSpace(voice)
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
if rc := CppTTSSetVoice(v); rc != 0 {
|
||||
fmt.Fprintf(os.Stderr, "crispasr: voice %q not applied by the active TTS backend (rc=%d); using default\n", v, rc)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *CrispASR) TTS(req *pb.TTSRequest) error {
|
||||
if req.Dst == "" {
|
||||
return fmt.Errorf("crispasr: TTS requires a destination path")
|
||||
}
|
||||
setVoice(req.Voice)
|
||||
pcm, err := w.synthesize(req.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWAV24k(req.Dst, pcm)
|
||||
}
|
||||
|
||||
// TTSStream is the streaming counterpart to TTS. CrispASR has no progressive
|
||||
// (native streaming) synth, so we synthesize the whole utterance, encode it to
|
||||
// a 24 kHz WAV, and emit the encoded bytes as a single chunk. The gRPC server
|
||||
// wrapper (pkg/grpc/server.go:TTSStream) ranges over the channel until it is
|
||||
// closed, so this method owns the close - mirrors vibevoice-cpp's TTSStream.
|
||||
func (w *CrispASR) TTSStream(req *pb.TTSRequest, results chan []byte) error {
|
||||
defer close(results)
|
||||
|
||||
if req.Text == "" {
|
||||
return fmt.Errorf("crispasr: TTSStream requires text")
|
||||
}
|
||||
setVoice(req.Voice)
|
||||
pcm, err := w.synthesize(req.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmp, err := os.CreateTemp("", "crispasr-tts-stream-*.wav")
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: tempfile: %w", err)
|
||||
}
|
||||
dst := tmp.Name()
|
||||
if err := tmp.Close(); err != nil {
|
||||
return fmt.Errorf("crispasr: close tempfile: %w", err)
|
||||
}
|
||||
defer func() { _ = os.Remove(dst) }()
|
||||
|
||||
if err := writeWAV24k(dst, pcm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encoded, err := os.ReadFile(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: read tempfile: %w", err)
|
||||
}
|
||||
results <- encoded
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeWAV24k writes pcm as a 24000 Hz, mono, 16-bit PCM WAV at dst.
|
||||
func writeWAV24k(dst string, pcm []float32) error {
|
||||
f, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: create %q: %w", dst, err)
|
||||
}
|
||||
|
||||
enc := wav.NewEncoder(f, 24000, 16, 1, 1)
|
||||
ints := make([]int, len(pcm))
|
||||
for i, s := range pcm {
|
||||
if s > 1 {
|
||||
s = 1
|
||||
} else if s < -1 {
|
||||
s = -1
|
||||
}
|
||||
ints[i] = int(s * 32767)
|
||||
}
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: 24000},
|
||||
Data: ints,
|
||||
SourceBitDepth: 16,
|
||||
}
|
||||
if err := enc.Write(buf); err != nil {
|
||||
_ = enc.Close()
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("crispasr: encode WAV: %w", err)
|
||||
}
|
||||
if err := enc.Close(); err != nil {
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("crispasr: finalize WAV: %w", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return fmt.Errorf("crispasr: close %q: %w", dst, err)
|
||||
}
|
||||
return nil
|
||||
}
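The sample conversion inside `writeWAV24k` is easy to test separately from the WAV encoder. The sketch below pulls it into a hypothetical `floatToPCM16` function: clamp to [-1, 1], scale by 32767, and truncate toward zero, as the loop above does.

```go
package main

import "fmt"

// floatToPCM16 (hypothetical name) maps a float sample to a 16-bit PCM value.
// Out-of-range input is clamped first; int() truncates toward zero.
func floatToPCM16(s float32) int {
	if s > 1 {
		s = 1
	} else if s < -1 {
		s = -1
	}
	return int(s * 32767)
}

func main() {
	for _, s := range []float32{-2, -1, 0, 0.5, 1, 2} {
		fmt.Printf("%v -> %d\n", s, floatToPCM16(s))
	}
}
```

Scaling by 32767 keeps the result symmetric, so the most negative int16 value (-32768) is never produced.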
|
||||
193
backend/go/crispasr/gocrispasr_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestCrispASR(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "CrispASR Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
|
||||
// bridge without spinning up the gRPC server. Skips the current spec when the
|
||||
// shared library isn't present (e.g. running before the crispasr backend has been built).
|
||||
func ensureLibLoaded() {
|
||||
libLoadOnce.Do(func() {
|
||||
libName := os.Getenv("CRISPASR_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgocrispasr-fallback.so"
|
||||
}
|
||||
if _, err := os.Stat(libName); err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&CppLoadModel, gosd, "load_model")
|
||||
purego.RegisterLibFunc(&CppSetCodecPath, gosd, "set_codec_path")
|
||||
purego.RegisterLibFunc(&CppTranscribe, gosd, "transcribe")
|
||||
purego.RegisterLibFunc(&CppGetSegmentText, gosd, "get_segment_text")
|
||||
purego.RegisterLibFunc(&CppGetSegmentStart, gosd, "get_segment_t0")
|
||||
purego.RegisterLibFunc(&CppGetSegmentEnd, gosd, "get_segment_t1")
|
||||
purego.RegisterLibFunc(&CppGetBackend, gosd, "get_backend")
|
||||
purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort")
|
||||
purego.RegisterLibFunc(&CppTTSSynthesize, gosd, "tts_synthesize")
|
||||
purego.RegisterLibFunc(&CppTTSFree, gosd, "tts_free")
|
||||
purego.RegisterLibFunc(&CppTTSSetVoice, gosd, "tts_set_voice")
|
||||
purego.RegisterLibFunc(&CppTTSSetVoiceFile, gosd, "tts_set_voice_file")
|
||||
})
|
||||
if libLoadErr != nil {
|
||||
Skip("crispasr library not loadable: " + libLoadErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// fixturesOrSkip returns the model + audio paths or skips the spec if either
|
||||
// env var is unset. The test never runs in default CI — it requires a real
|
||||
// whisper model and a long audio file (~3 minutes) on disk.
|
||||
func fixturesOrSkip() (string, string) {
|
||||
modelPath := os.Getenv("CRISPASR_MODEL_PATH")
|
||||
audioPath := os.Getenv("CRISPASR_AUDIO_PATH")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set CRISPASR_MODEL_PATH and CRISPASR_AUDIO_PATH to run this spec")
|
||||
}
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
// ttsModelOrSkip returns the TTS model path or skips the spec when the env var
|
||||
// is unset. Like the transcription fixtures, this never runs in default CI — it
|
||||
// needs a real TTS model (e.g. a vibevoice GGUF) on disk.
|
||||
func ttsModelOrSkip() string {
|
||||
modelPath := os.Getenv("CRISPASR_TTS_MODEL_PATH")
|
||||
if modelPath == "" {
|
||||
Skip("set CRISPASR_TTS_MODEL_PATH to run this spec")
|
||||
}
|
||||
return modelPath
|
||||
}
|
||||
|
||||
var _ = Describe("CrispASR", func() {
|
||||
Context("AudioTranscription cancellation", func() {
|
||||
It("returns codes.Canceled on a pre-cancelled context and still succeeds afterwards", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
|
||||
// The session transcribe is blocking and exposes no abort hook, so
|
||||
// a mid-decode cancel can't interrupt it. The contract we can rely
|
||||
// on is the pre-call ctx.Err() check: a context cancelled before
|
||||
// the call must yield codes.Canceled without starting a decode.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := w.AudioTranscription(ctx, &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "expected pre-cancelled context to fail")
|
||||
st, ok := status.FromError(err)
|
||||
Expect(ok).To(BeTrue(), "expected gRPC status error, got %v", err)
|
||||
Expect(st.Code()).To(Equal(codes.Canceled), "expected codes.Canceled, got %v", err)
|
||||
|
||||
// Subsequent transcription must succeed — proves g_abort reset.
|
||||
res, err := w.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred(), "post-cancel transcription failed")
|
||||
Expect(res.Text).ToNot(BeEmpty(), "post-cancel transcription returned empty text")
|
||||
})
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("emits one delta per non-empty segment for a multi-segment clip", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
|
||||
results := make(chan *pb.TranscriptStreamResponse, 64)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- w.AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
Stream: true,
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var deltas []string
|
||||
var assembled strings.Builder
|
||||
var finalText string
|
||||
var finalSegmentCount int
|
||||
for chunk := range results {
|
||||
if d := chunk.GetDelta(); d != "" {
|
||||
deltas = append(deltas, d)
|
||||
assembled.WriteString(d)
|
||||
}
|
||||
if final := chunk.GetFinalResult(); final != nil {
|
||||
finalText = final.GetText()
|
||||
finalSegmentCount = len(final.GetSegments())
|
||||
}
|
||||
}
|
||||
Expect(<-done).ToNot(HaveOccurred())
|
||||
|
||||
// One delta per non-empty segment is emitted after the blocking
|
||||
// decode returns (the session API has no per-decode callback), so a
|
||||
// multi-segment clip MUST produce >=2 delta events, and
|
||||
// concat(deltas) MUST equal final.Text exactly.
|
||||
Expect(len(deltas)).To(BeNumerically(">=", 2),
|
||||
"expected multiple deltas from a multi-segment clip, got %d (assembled=%q)",
|
||||
len(deltas), assembled.String())
|
||||
Expect(finalSegmentCount).To(BeNumerically(">=", 2),
|
||||
"expected final to carry multiple segments")
|
||||
Expect(assembled.String()).To(Equal(finalText),
|
||||
"concat(deltas) must equal final.Text")
|
||||
})
|
||||
})
|
||||
|
||||
Context("TTS", func() {
|
||||
It("synthesizes a non-empty WAV", func() {
|
||||
ttsModel := ttsModelOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: ttsModel})).To(Succeed())
|
||||
|
||||
dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
|
||||
Expect(w.TTS(&pb.TTSRequest{Text: "Hello from CrispASR.", Dst: dst})).To(Succeed())
|
||||
|
||||
info, err := os.Stat(dst)
|
||||
Expect(err).ToNot(HaveOccurred(), "synthesized WAV should exist at %q", dst)
|
||||
// A real 24 kHz mono WAV is a 44-byte header plus samples; anything
|
||||
// this small would mean an empty/failed synth.
|
||||
Expect(info.Size()).To(BeNumerically(">", 1024),
|
||||
"expected a non-trivial WAV, got %d bytes", info.Size())
|
||||
})
|
||||
})
|
||||
})
|
||||
58
backend/go/crispasr/main.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("CRISPASR_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgocrispasr-fallback.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppLoadModel, "load_model"},
|
||||
{&CppSetCodecPath, "set_codec_path"},
|
||||
{&CppLoadModelVAD, "load_model_vad"},
|
||||
{&CppVAD, "vad"},
|
||||
{&CppTranscribe, "transcribe"},
|
||||
{&CppGetSegmentText, "get_segment_text"},
|
||||
{&CppGetSegmentStart, "get_segment_t0"},
|
||||
{&CppGetSegmentEnd, "get_segment_t1"},
|
||||
{&CppGetBackend, "get_backend"},
|
||||
{&CppSetAbort, "set_abort"},
|
||||
{&CppTTSSynthesize, "tts_synthesize"},
|
||||
{&CppTTSFree, "tts_free"},
|
||||
{&CppTTSSetVoice, "tts_set_voice"},
|
||||
{&CppTTSSetVoiceFile, "tts_set_voice_file"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &CrispASR{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
65
backend/go/crispasr/package.sh
Executable file
@@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
# This script is used in the final stage of the Dockerfile
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/crispasr $CURDIR/package/
|
||||
cp -fv $CURDIR/libgocrispasr-*.so $CURDIR/package/
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ "$(uname -s)" = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
# The GPU library packaging script will detect BUILD_TYPE and copy appropriate GPU libraries
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
52
backend/go/crispasr/run.sh
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
fi
|
||||
|
||||
LIBRARY="$CURDIR/libgocrispasr-fallback.so"
|
||||
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx2.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx2.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx512.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx512.so"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
export CRISPASR_LIBRARY=$LIBRARY
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/crispasr "$@"
|
||||
fi
|
||||
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/crispasr "$@"
|
||||
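The library selection in run.sh above can be restated as a small Go sketch. This is only an illustration of the ordering, not code from the repository: `pickLibrary` and its `exists` callback are hypothetical names, and splitting the flags on whitespace only approximates the script's `grep` patterns. On Darwin the script skips these checks and keeps the fallback library.

```go
package main

import (
	"fmt"
	"strings"
)

// pickLibrary mirrors the selection order in run.sh: start from the fallback
// build and let each later, wider SIMD level override the previous choice,
// but only when both the CPU flag and the matching .so file are present.
// (Hypothetical helper, written for illustration only.)
func pickLibrary(cpuFlags string, exists func(name string) bool) string {
	have := map[string]bool{}
	for _, f := range strings.Fields(cpuFlags) {
		have[f] = true
	}
	lib := "libgocrispasr-fallback.so"
	for _, v := range []struct{ flag, file string }{
		{"avx", "libgocrispasr-avx.so"},
		{"avx2", "libgocrispasr-avx2.so"},
		{"avx512f", "libgocrispasr-avx512.so"},
	} {
		if have[v.flag] && exists(v.file) {
			lib = v.file
		}
	}
	return lib
}

func main() {
	// Pretend only the AVX and AVX2 builds were packaged.
	shipped := map[string]bool{
		"libgocrispasr-avx.so":  true,
		"libgocrispasr-avx2.so": true,
	}
	exists := func(name string) bool { return shipped[name] }

	fmt.Println(pickLibrary("fpu sse4_2 avx avx2 avx512f", exists))
	fmt.Println(pickLibrary("fpu sse4_2", exists))
}
```

Because later checks overwrite earlier ones, a CPU that reports AVX512F still ends up on the AVX2 build when no AVX512 library was shipped.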
@@ -9,7 +9,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
 # LocalVQE upstream version pin. Bump to a specific commit when picking up
 # a new release; `main` works for development but is not reproducible.
 LOCALVQE_REPO?=https://github.com/localai-org/LocalVQE
-LOCALVQE_VERSION?=72bfb4c6
+LOCALVQE_VERSION?=b0f0378a450e87c871b85689554801601ca56d98

 # LocalVQE handles CPU feature selection internally (it ships the multiple
 # libggml-cpu-*.so variants and its loader picks the best one at runtime
@@ -27,7 +27,8 @@ endif
|
||||
|
||||
 # LocalVQE upstream supports CPU + Vulkan only. Other BUILD_TYPE values
 # fall through to the default CPU build — Vulkan is already as fast as the
-# specialised GPU paths would be on this 1.3 M-parameter model.
+# specialised GPU paths would be on these small (1.3 M–4.8 M parameter)
+# models.
 ifeq ($(BUILD_TYPE),vulkan)
 CMAKE_ARGS+=-DGGML_VULKAN=ON -DLOCALVQE_VULKAN=ON
 else ifeq ($(OS),Darwin)
@@ -3,7 +3,6 @@ package main
|
||||
 import (
 	"encoding/binary"
 	"fmt"
-	"io"
 	"os"
 	"path/filepath"
 	"runtime"
@@ -11,6 +10,7 @@ import (
 	"strings"
 	"unsafe"

+	"github.com/go-audio/wav"
 	"github.com/mudler/LocalAI/pkg/grpc/base"
 	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
 	"github.com/mudler/xlog"
@@ -46,24 +46,24 @@ const (
|
||||
// through the options builder (CppOptionsNew + setters + CppNewWithOptions)
|
||||
// — the bare localvqe_new path doesn't expose backend / device selection.
|
||||
var (
|
||||
CppOptionsNew func() uintptr
|
||||
CppOptionsFree func(opts uintptr)
|
||||
CppOptionsSetModelPath func(opts uintptr, modelPath string) int32
|
||||
CppOptionsSetBackend func(opts uintptr, backend string) int32
|
||||
CppOptionsSetDevice func(opts uintptr, device int32) int32
|
||||
CppNewWithOptions func(opts uintptr) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppProcessF32 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessS16 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessFrameF32 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppProcessFrameS16 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppReset func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppSampleRate func(ctx uintptr) int32
|
||||
CppHopLength func(ctx uintptr) int32
|
||||
CppFFTSize func(ctx uintptr) int32
|
||||
CppSetNoiseGate func(ctx uintptr, enabled int32, thresholdDBFS float32) int32
|
||||
CppGetNoiseGate func(ctx uintptr, enabledOut, thresholdDBFSOut uintptr) int32
|
||||
)
|
||||
|
||||
// LocalVQE speaks gRPC against LocalVQE's flat C ABI. The streaming
|
||||
@@ -490,11 +490,14 @@ func (v *LocalVQE) applyStreamConfig(cfg *pb.AudioTransformStreamConfig) error {
|
||||
|
||||
 // ---- WAV I/O ----------------------------------------------------------
 //
-// Minimal mono PCM WAV reader/writer. Only handles the subset LocalVQE
-// cares about (mono, 16-bit signed, no extensible chunks). For broader
-// audio support the HTTP layer's `audio.NormalizeAudioFile` already
-// converts arbitrary input to a canonical WAV before we see it; this
-// reader just decodes the canonical shape.
+// Reader/writer for the mono 16-bit PCM shape LocalVQE works with. Decoding
+// goes through the shared go-audio/wav decoder (as the whisper and parakeet
+// backends do) so RIFF chunk walking is handled robustly — an 18/40-byte
+// extensible `fmt ` chunk, or JUNK/bext/LIST metadata before or after `data`
+// (e.g. ffmpeg's trailing "Lavf" tag), is skipped rather than spliced into
+// the PCM stream as an audible click. The HTTP layer normalises arbitrary
+// input to WAV before we see it, but that WAV is ffmpeg output and is not
+// guaranteed to be the canonical 44-byte layout.

 func readMonoWAVf32(path string) ([]float32, int, error) {
 	f, err := os.Open(path)
@@ -502,35 +505,26 @@ func readMonoWAVf32(path string) ([]float32, int, error) {
 		return nil, 0, err
 	}
 	defer func() { _ = f.Close() }()
-	header := make([]byte, 44)
-	if _, err := io.ReadFull(f, header); err != nil {
-		return nil, 0, err
+
+	buf, err := wav.NewDecoder(f).FullPCMBuffer()
+	if err != nil {
+		return nil, 0, fmt.Errorf("decode WAV: %w", err)
 	}
-	if string(header[0:4]) != "RIFF" || string(header[8:12]) != "WAVE" {
+	if buf == nil || buf.Format == nil {
 		return nil, 0, fmt.Errorf("not a WAV file")
 	}
-	channels := binary.LittleEndian.Uint16(header[22:24])
-	sampleRate := binary.LittleEndian.Uint32(header[24:28])
-	bitsPerSample := binary.LittleEndian.Uint16(header[34:36])
-
-	if channels != 1 {
-		return nil, 0, fmt.Errorf("only mono WAV supported (got %d channels)", channels)
+	if buf.Format.NumChannels != 1 {
+		return nil, 0, fmt.Errorf("only mono WAV supported (got %d channels)", buf.Format.NumChannels)
 	}
-	if bitsPerSample != 16 {
-		return nil, 0, fmt.Errorf("only 16-bit PCM supported (got %d bits)", bitsPerSample)
+	if buf.SourceBitDepth != 16 {
+		return nil, 0, fmt.Errorf("only 16-bit PCM supported (got %d bits)", buf.SourceBitDepth)
 	}
-
-	rest, err := io.ReadAll(f)
-	if err != nil {
-		return nil, 0, err
+	if len(buf.Data) == 0 {
+		return nil, 0, fmt.Errorf("WAV has no audio data")
 	}
-	n := len(rest) / 2
-	out := make([]float32, n)
-	for i := 0; i < n; i++ {
-		s := int16(binary.LittleEndian.Uint16(rest[i*2 : i*2+2]))
-		out[i] = float32(s) / 32768.0
-	}
-	return out, int(sampleRate), nil
+	// AsFloat32Buffer normalises by 2^(bitDepth-1) == /32768 for 16-bit,
+	// matching the model's expected [-1, 1) input range.
+	return buf.AsFloat32Buffer().Data, buf.Format.SampleRate, nil
 }

 func writeMonoWAVf32(path string, samples []float32, sampleRate int) error {
@@ -546,13 +540,13 @@ func writeMonoWAVf32(path string, samples []float32, sampleRate int) error {
|
||||
binary.LittleEndian.PutUint32(header[4:8], 36+dataLen)
|
||||
copy(header[8:12], []byte("WAVE"))
|
||||
copy(header[12:16], []byte("fmt "))
|
||||
binary.LittleEndian.PutUint32(header[16:20], 16) // fmt chunk size
|
||||
binary.LittleEndian.PutUint16(header[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(header[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(header[24:28], uint32(sampleRate))
|
||||
binary.LittleEndian.PutUint32(header[28:32], uint32(sampleRate*2)) // byte rate
|
||||
binary.LittleEndian.PutUint16(header[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(header[34:36], 16) // bits per sample
|
||||
copy(header[36:40], []byte("data"))
|
||||
binary.LittleEndian.PutUint32(header[40:44], dataLen)
|
||||
if _, err := f.Write(header); err != nil {
|
||||
|
||||
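To make the motivation for the decoder switch concrete, here is a minimal, standard-library-only sketch of RIFF chunk walking. It is not the go-audio/wav implementation and not code from this change; `findDataChunk`, `chunk` and `riff` are hypothetical helpers written for illustration. It shows why skipping a fixed 44 bytes is only correct when `fmt ` is exactly 16 bytes long and `data` follows it immediately.

```go
package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

// findDataChunk walks the RIFF sub-chunks of an in-memory WAV file and returns
// the body of the first "data" chunk, skipping every other chunk by its
// declared size. (Illustrative helper, not part of the backend.)
func findDataChunk(b []byte) ([]byte, error) {
	if len(b) < 12 || string(b[0:4]) != "RIFF" || string(b[8:12]) != "WAVE" {
		return nil, errors.New("not a RIFF/WAVE file")
	}
	off := 12
	for off+8 <= len(b) {
		id := string(b[off : off+4])
		size := int(binary.LittleEndian.Uint32(b[off+4 : off+8]))
		body := off + 8
		if size < 0 || size > len(b)-body {
			return nil, fmt.Errorf("chunk %q overruns the file", id)
		}
		if id == "data" {
			return b[body : body+size], nil
		}
		// Chunks are word-aligned: an odd-sized body is followed by one pad byte.
		off = body + size + (size & 1)
	}
	return nil, errors.New("no data chunk found")
}

// chunk builds one sub-chunk: 4-byte id, little-endian size, body, optional pad.
func chunk(id string, body []byte) []byte {
	out := make([]byte, 8, 8+len(body)+1)
	copy(out[0:4], id)
	binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
	out = append(out, body...)
	if len(body)%2 == 1 {
		out = append(out, 0)
	}
	return out
}

// riff wraps sub-chunks in a RIFF/WAVE container.
func riff(chunks ...[]byte) []byte {
	body := []byte("WAVE")
	for _, c := range chunks {
		body = append(body, c...)
	}
	out := make([]byte, 8, 8+len(body))
	copy(out[0:4], "RIFF")
	binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
	return append(out, body...)
}

func main() {
	pcm := []byte{0xE8, 0x03, 0x30, 0xF8} // int16 samples 1000 and -2000, little-endian
	file := riff(
		chunk("fmt ", make([]byte, 16)),
		chunk("JUNK", []byte("odd")), // odd length exercises the pad byte
		chunk("data", pcm),
		chunk("LIST", []byte("INFOISFTLavf")), // trailing metadata, as ffmpeg may emit
	)
	data, err := findDataChunk(file)
	fmt.Println(len(data), err)
	// A reader that skips a fixed 44 bytes and reads to EOF would treat this
	// many bytes as PCM, including the JUNK body and the LIST trailer:
	fmt.Println(len(file) - 44)
}
```

In this layout the fixed-offset reader starts inside the JUNK body and runs past the end of `data`, which is the kind of splice the new comment describes as an audible click.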
@@ -1,7 +1,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -92,6 +94,147 @@ var _ = Describe("LocalVQE-cpp", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("readMonoWAVf32 chunk parsing", func() {
|
||||
// chunk builds a word-aligned RIFF sub-chunk (id + size + body + pad).
|
||||
chunk := func(id string, body []byte) []byte {
|
||||
out := append([]byte(id), 0, 0, 0, 0)
|
||||
binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
|
||||
out = append(out, body...)
|
||||
if len(body)&1 == 1 {
|
||||
out = append(out, 0) // pad byte for odd-sized chunks
|
||||
}
|
||||
return out
|
||||
}
|
||||
// fmtBody returns a PCM `fmt ` chunk body. extra bytes simulate the
|
||||
// 18/40-byte extensible form (cbSize + extension).
|
||||
fmtBody := func(channels, bits uint16, rate uint32, extra int) []byte {
|
||||
b := make([]byte, 16+extra)
|
||||
binary.LittleEndian.PutUint16(b[0:2], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(b[2:4], channels)
|
||||
binary.LittleEndian.PutUint32(b[4:8], rate)
|
||||
binary.LittleEndian.PutUint32(b[8:12], rate*uint32(channels)*uint32(bits)/8)
|
||||
binary.LittleEndian.PutUint16(b[12:14], channels*bits/8)
|
||||
binary.LittleEndian.PutUint16(b[14:16], bits)
|
||||
if extra >= 2 {
|
||||
binary.LittleEndian.PutUint16(b[16:18], uint16(extra-2)) // cbSize
|
||||
}
|
||||
return b
|
||||
}
|
||||
// pcm encodes int16 samples little-endian.
|
||||
pcm := func(samples ...int16) []byte {
|
||||
b := make([]byte, len(samples)*2)
|
||||
for i, s := range samples {
|
||||
binary.LittleEndian.PutUint16(b[i*2:i*2+2], uint16(s))
|
||||
}
|
||||
return b
|
||||
}
|
||||
riff := func(chunks ...[]byte) []byte {
|
||||
body := []byte("WAVE")
|
||||
for _, c := range chunks {
|
||||
body = append(body, c...)
|
||||
}
|
||||
out := append([]byte("RIFF"), 0, 0, 0, 0)
|
||||
binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
|
||||
return append(out, body...)
|
||||
}
|
||||
writeWAV := func(b []byte) string {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "in.wav")
|
||||
Expect(os.WriteFile(p, b, 0o600)).To(Succeed())
|
||||
return p
|
||||
}
|
||||
// A canonical sample run with distinct values so any off-by-one /
|
||||
// misalignment shows up as wrong numbers, not just wrong length.
|
||||
samples := []int16{1000, -2000, 3000, -4000, 5000, -6000}
|
||||
expectSamples := func(got []float32) {
|
||||
Expect(got).To(HaveLen(len(samples)))
|
||||
for i, s := range samples {
|
||||
Expect(got[i]).To(BeNumerically("~", float32(s)/32768.0, 1e-6))
|
||||
}
|
||||
}
|
||||
|
||||
It("reads a canonical 44-byte WAV", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out)
|
||||
})
|
||||
|
||||
It("ignores a LIST/JUNK chunk placed before data (no leading-impulse splice)", func() {
|
||||
p := writeWAV(riff(
|
||||
chunk("fmt ", fmtBody(1, 16, 16000, 0)),
|
||||
chunk("JUNK", []byte("padding-bytes-here!")), // odd length → exercises pad
|
||||
chunk("LIST", []byte("INFOISFTLavf60.0")),
|
||||
chunk("data", pcm(samples...)),
|
||||
))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out) // not corrupted by the preceding chunks
|
||||
})
|
||||
|
||||
It("honours the data chunk size and drops a trailing metadata chunk", func() {
|
||||
p := writeWAV(riff(
|
||||
chunk("fmt ", fmtBody(1, 16, 16000, 0)),
|
||||
chunk("data", pcm(samples...)),
|
||||
chunk("LIST", []byte("INFOISFTLavf60.16.100")), // ffmpeg trailer tag
|
||||
))
|
||||
out, _, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectSamples(out) // trailing LIST bytes not decoded as PCM
|
||||
})
|
||||
|
||||
It("handles the 18-byte extensible fmt chunk", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 2)), chunk("data", pcm(samples...))))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out)
|
||||
})
|
||||
|
||||
It("rejects non-mono input", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(2, 16, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("mono"))
|
||||
})
|
||||
|
||||
It("rejects non-16-bit input", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 8, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("16-bit"))
|
||||
})
|
||||
|
||||
It("rejects a non-WAV file", func() {
|
||||
p := writeWAV([]byte("not a riff file at all"))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors when the data chunk is missing", func() {
|
||||
// fmt but no data: the decoder must fail rather than return an
|
||||
// empty (or garbage) sample slice. The exact message is the
|
||||
// decoder's, so just assert it errors.
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 0))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("round-trips through writeMonoWAVf32", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "rt.wav")
|
||||
in := []float32{0.1, -0.2, 0.3, -0.4}
|
||||
Expect(writeMonoWAVf32(p, in, 16000)).To(Succeed())
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
Expect(out).To(HaveLen(len(in)))
|
||||
for i := range in {
|
||||
Expect(out[i]).To(BeNumerically("~", in[i], 1e-4))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("model-gated integration (LOCALVQE_MODEL_PATH)", func() {
|
||||
It("load + sample rate + hop + fft", func() {
|
||||
path := modelPathOrSkip()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
 # parakeet-cpp backend Makefile.
 #
-# Upstream pin lives below as PARAKEET_VERSION?=30a307553f1965ceb38a1a922069a71e7dd67bf3
+# Upstream pin lives below as PARAKEET_VERSION?=9edf17c3ada66e0f881dcff155492867db7ac4cf
 # (.github/bump_deps.sh) can find and update it - matches the
 # whisper.cpp / ds4 / vibevoice-cpp convention.
 #
@@ -15,7 +15,7 @@
|
||||
 # That's what the L0 smoke test uses. The default target below does the
 # proper clone-at-pin + cmake build so CI doesn't need a side-checkout.

-PARAKEET_VERSION?=30a307553f1965ceb38a1a922069a71e7dd67bf3
+PARAKEET_VERSION?=9edf17c3ada66e0f881dcff155492867db7ac4cf
 PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp

 GOCMD?=go
@@ -34,14 +34,18 @@ ifeq ($(NATIVE),false)
|
||||
 CMAKE_ARGS+=-DGGML_NATIVE=OFF
 endif

+# parakeet.cpp gates its GGML backends behind PARAKEET_GGML_* options and does
+# set(GGML_CUDA ${PARAKEET_GGML_CUDA} CACHE BOOL "" FORCE), so a bare -DGGML_CUDA=ON
+# is overwritten back to OFF and the build silently falls back to CPU. Forward the
+# PARAKEET_GGML_* options instead. (openblas is not gated, so -DGGML_BLAS passes through.)
 ifeq ($(BUILD_TYPE),cublas)
-CMAKE_ARGS+=-DGGML_CUDA=ON
+CMAKE_ARGS+=-DPARAKEET_GGML_CUDA=ON
 else ifeq ($(BUILD_TYPE),openblas)
 CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
 else ifeq ($(BUILD_TYPE),hipblas)
-CMAKE_ARGS+=-DGGML_HIP=ON
+CMAKE_ARGS+=-DPARAKEET_GGML_HIP=ON
 else ifeq ($(BUILD_TYPE),vulkan)
-CMAKE_ARGS+=-DGGML_VULKAN=ON
+CMAKE_ARGS+=-DPARAKEET_GGML_VULKAN=ON
 endif

 .PHONY: parakeet-cpp-grpc package build clean purge test all
79
backend/go/parakeet-cpp/batcher.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package main
|
||||
|
||||
import "time"
|
||||
|
||||
// batchRequest is one in-flight unary transcription waiting to be batched.
|
||||
// In production pcm/decoder are set; tag is an opaque marker used by tests.
|
||||
type batchRequest struct {
|
||||
pcm []float32
|
||||
decoder int32
|
||||
tag string
|
||||
reply chan batchReply
|
||||
}
|
||||
|
||||
// batchReply carries one per-item JSON object string (an element of the C-API's
|
||||
// JSON array) or an error back to the waiting handler goroutine.
|
||||
type batchReply struct {
|
||||
json string
|
||||
err error
|
||||
}
|
||||
|
||||
// batcher coalesces concurrent batchRequests into batched runBatch calls. A
|
||||
// single run() goroutine is the sole caller of runBatch, so runBatch (which in
|
||||
// production calls the thread-unsafe C engine) is never entered concurrently.
|
||||
type batcher struct {
|
||||
submit chan *batchRequest
|
||||
maxSize int
|
||||
maxWait time.Duration
|
||||
runBatch func(reqs []*batchRequest) // must deliver a reply to every req
|
||||
}
|
||||
|
||||
func newBatcher(maxSize int, maxWait time.Duration, runBatch func([]*batchRequest)) *batcher {
|
||||
if maxSize < 1 {
|
||||
maxSize = 1
|
||||
}
|
||||
return &batcher{
|
||||
submit: make(chan *batchRequest),
|
||||
maxSize: maxSize,
|
||||
maxWait: maxWait,
|
||||
runBatch: runBatch,
|
||||
}
|
||||
}
|
||||
|
||||
// run is the dispatcher loop: accumulate submitted requests until either maxSize
|
||||
// is reached or maxWait elapses since the first queued request, then dispatch.
|
||||
// Exits when stop is closed (draining any partially-filled batch first).
|
||||
func (b *batcher) run(stop <-chan struct{}) {
|
||||
for {
|
||||
var first *batchRequest
|
||||
select {
|
||||
case first = <-b.submit:
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
batch := []*batchRequest{first}
|
||||
|
||||
// maxSize==1 disables batching: dispatch immediately (passthrough).
|
||||
if b.maxSize == 1 {
|
||||
b.runBatch(batch)
|
||||
continue
|
||||
}
|
||||
|
||||
timer := time.NewTimer(b.maxWait)
|
||||
fill:
|
||||
for len(batch) < b.maxSize {
|
||||
select {
|
||||
case r := <-b.submit:
|
||||
batch = append(batch, r)
|
||||
case <-timer.C:
|
||||
break fill
|
||||
case <-stop:
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
return
|
||||
}
|
||||
}
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
}
|
||||
}
|
||||
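As a rough sketch of how a request handler might use such a dispatcher, the example below pairs a reduced copy of the loop above with a caller that submits one request and waits for its reply while honouring context cancellation. The handler side is an assumption made for illustration, not the backend's code: `request`, `dispatch`, `echoBatch`, `submitAndWait` and `runDemo` are hypothetical names.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// request is a hypothetical stand-in for batchRequest.
type request struct {
	tag   string
	reply chan string // capacity 1, so the dispatcher never blocks on a caller that gave up
}

// dispatch is a reduced copy of the batcher loop: collect until maxSize or
// maxWait (measured from the first queued request), then hand off the batch.
func dispatch(submit <-chan *request, stop <-chan struct{}, maxSize int, maxWait time.Duration, runBatch func([]*request)) {
	for {
		var first *request
		select {
		case first = <-submit:
		case <-stop:
			return
		}
		batch := []*request{first}
		timer := time.NewTimer(maxWait)
	fill:
		for len(batch) < maxSize {
			select {
			case r := <-submit:
				batch = append(batch, r)
			case <-timer.C:
				break fill
			case <-stop:
				timer.Stop()
				runBatch(batch)
				return
			}
		}
		timer.Stop()
		runBatch(batch)
	}
}

// echoBatch replies to every request with "<tag>/<batch size>".
func echoBatch(batch []*request) {
	for _, r := range batch {
		r.reply <- fmt.Sprintf("%s/%d", r.tag, len(batch))
	}
}

// submitAndWait enqueues one request and waits for its reply, giving up if
// ctx is cancelled at either step.
func submitAndWait(ctx context.Context, submit chan<- *request, tag string) (string, error) {
	req := &request{tag: tag, reply: make(chan string, 1)}
	select {
	case submit <- req:
	case <-ctx.Done():
		return "", ctx.Err()
	}
	select {
	case out := <-req.reply:
		return out, nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

// runDemo submits n concurrent requests and returns their replies by index.
func runDemo(n, maxSize, maxWaitMs int) []string {
	submit := make(chan *request)
	stop := make(chan struct{})
	defer close(stop)
	go dispatch(submit, stop, maxSize, time.Duration(maxWaitMs)*time.Millisecond, echoBatch)

	results := make([]string, n)
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			out, err := submitAndWait(context.Background(), submit, string(rune('a'+i)))
			if err != nil {
				out = "error: " + err.Error()
			}
			results[i] = out
		}(i)
	}
	wg.Wait()
	return results
}

func main() {
	// How the three requests are grouped depends on scheduling and the 20 ms window.
	fmt.Println(runDemo(3, 4, 20))
}
```

The reply channel is buffered with capacity 1 so the dispatcher can always deliver a result, even when the caller has already returned after a cancelled context; the tests in the next file use the same capacity.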
108
backend/go/parakeet-cpp/batcher_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("batcher", func() {
|
||||
echoReply := func(reqs []*batchRequest) {
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{json: r.tag}
|
||||
}
|
||||
}
|
||||
|
||||
It("coalesces concurrent submits into batches", func() {
|
||||
var mu sync.Mutex
|
||||
var sizes []int
|
||||
run := func(reqs []*batchRequest) {
|
||||
mu.Lock()
|
||||
sizes = append(sizes, len(reqs))
|
||||
mu.Unlock()
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(4, 50*time.Millisecond, run)
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
|
||||
const N = 4
|
||||
var wg sync.WaitGroup
|
||||
got := make([]string, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: string(rune('a' + i)), reply: rep}
|
||||
got[i] = (<-rep).json
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
total, maxBatch := 0, 0
|
||||
for _, s := range sizes {
|
||||
total += s
|
||||
if s > maxBatch {
|
||||
maxBatch = s
|
||||
}
|
||||
}
|
||||
Expect(total).To(Equal(N))
|
||||
Expect(maxBatch).To(BeNumerically(">=", 2), "expected at least one batch to coalesce >1 request")
|
||||
})
|
||||
|
||||
It("dispatches when max size is reached", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(2, time.Hour, run) // huge window: only size can trigger
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
for i := 0; i < 2; i++ {
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func(rep chan batchReply) { <-rep }(rep)
|
||||
}
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(2)))
|
||||
})
|
||||
|
||||
It("dispatches when the wait window elapses", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(8, 20*time.Millisecond, run) // size unreachable; window fires
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
|
||||
It("bypasses batching when max size is 1", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(1, time.Hour, run) // size 1 => immediate dispatch
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
})
|
||||
@@ -7,13 +7,17 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
@@ -34,6 +38,15 @@ var (
|
||||
 	CppFreeString func(s uintptr)
 	CppLastError func(ctx uintptr) string

+	// Batched JSON transcription: takes a concatenated float buffer of clips
+	// plus their per-clip sample counts (sum(nSamples)==len(samplesConcat))
+	// and returns a malloc'd char* JSON ARRAY of per-clip {"text","words",
+	// "tokens"} objects (uintptr, freed via CppFreeString). purego passes the
+	// Go slices as the base pointer of their backing array (kept alive for the
+	// call), matching the CppStreamFeed pcm []float32 binding pattern; the C
+	// side reads them as const float*/const int*.
+	CppTranscribePcmBatchJSON func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32) uintptr
+
 	// Cache-aware streaming (RNN-T) entry points. stream_begin returns 0 for
 	// non-streaming models. feed/finalize return a malloc'd char* (uintptr,
 	// freed via CppFreeString); feed writes 1 to *eouOut on an <EOU>/<EOB>.
@@ -77,11 +90,18 @@ type transcriptToken struct {
|
||||
 }

 // ParakeetCpp owns a single loaded parakeet_ctx. The C engine is a
-// thread-unsafe singleton (mirrors whisper.cpp / vibevoice.cpp), so we
-// serialize calls through base.SingleThread.
+// thread-unsafe singleton (mirrors whisper.cpp / vibevoice.cpp). Rather than
+// serialize every call through base.SingleThread, we route unary
+// transcription through an in-process batcher (its sole dispatcher goroutine
+// is the only caller of the engine on that path) and guard the shared engine
+// with engineMu so a streaming session and a batched-unary dispatch never
+// touch it concurrently.
 type ParakeetCpp struct {
-	base.SingleThread
-	ctxPtr uintptr
+	base.Base
+	ctxPtr uintptr
+	engineMu sync.Mutex // sole guard of the one C engine (dispatcher + streaming)
+	bat *batcher
+	batStop chan struct{}
 }

 // Load is the LocalAI gRPC entry point for LoadModel: it calls
@@ -100,13 +120,103 @@ func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
|
||||
return fmt.Errorf("parakeet-cpp: parakeet_capi_load failed for %q", opts.ModelFile)
|
||||
}
|
||||
p.ctxPtr = ctx
|
||||
|
||||
// Dynamic batching knobs (model YAML options:, key:value form). Batching is
|
||||
// OFF by default (batch_max_size:1): each request runs on its own. On GPU,
|
||||
// raising batch_max_size coalesces concurrent requests into one batched
|
||||
// engine call and improves throughput under load; leave it at 1 on CPU and
|
||||
// for low-concurrency setups, where batching only adds latency.
|
||||
maxSize := optInt(opts, "batch_max_size", 1)
|
||||
maxWaitMs := optInt(opts, "batch_max_wait_ms", 15)
|
||||
if maxWaitMs < 0 {
|
||||
maxWaitMs = 0
|
||||
}
|
||||
if CppTranscribePcmBatchJSON != nil {
|
||||
p.batStop = make(chan struct{})
|
||||
p.bat = newBatcher(maxSize, time.Duration(maxWaitMs)*time.Millisecond, p.runBatch)
|
||||
go p.bat.run(p.batStop) // dispatcher runs until Free closes batStop
|
||||
if maxSize > 1 {
|
||||
xlog.Info("parakeet-cpp: dynamic batching enabled",
|
||||
"batch_max_size", maxSize, "batch_max_wait_ms", maxWaitMs)
|
||||
} else {
|
||||
xlog.Info("parakeet-cpp: dynamic batching off (batch_max_size=1); " +
|
||||
"set batch_max_size>1 to coalesce concurrent requests on GPU")
|
||||
}
|
||||
} else {
|
||||
xlog.Info("parakeet-cpp: batched C-API not present in libparakeet.so; " +
|
||||
"batching disabled, using per-request transcription")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AudioTranscription runs parakeet_capi_transcribe_path_json on the wav at
|
||||
// opts.Dst with the default decoder (decoder=0, which selects the right head
|
||||
// per architecture: transducer for tdt/rnnt/hybrid, CTC for ctc) and shapes
|
||||
// the per-word timestamps into a LocalAI TranscriptResult.
|
||||
// optInt reads an integer model option (key:value form) from ModelOptions,
|
||||
// returning def when absent or unparseable. The options array carries the
|
||||
// model YAML's options: entries (see core/config; siblings such as
|
||||
// acestep-cpp parse the same key:value form via strings.Cut on ":").
|
||||
func optInt(opts *pb.ModelOptions, key string, def int) int {
|
||||
for _, o := range opts.GetOptions() {
|
||||
k, v, ok := strings.Cut(o, ":")
|
||||
if ok && strings.TrimSpace(k) == key {
|
||||
if n, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// runBatch is the dispatcher's batch handler and the ONLY caller of the C
|
||||
// engine on the unary path. It concatenates the batch PCM, calls the batched
|
||||
// JSON C-API under engineMu, splits the JSON array, and replies to each request.
|
||||
func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
// Observability: the actual coalesced batch size per engine call. Debug-level
|
||||
// so it stays silent in normal operation but lets operators confirm/tune batching.
|
||||
xlog.Debug("parakeet-cpp: dispatching batch", "size", len(reqs))
|
||||
nSamples := make([]int32, len(reqs))
|
||||
total := 0
|
||||
for i, r := range reqs {
|
||||
nSamples[i] = int32(len(r.pcm))
|
||||
total += len(r.pcm)
|
||||
}
|
||||
concat := make([]float32, 0, total)
|
||||
for _, r := range reqs {
|
||||
concat = append(concat, r.pcm...)
|
||||
}
|
||||
var dec int32
|
||||
if len(reqs) > 0 {
|
||||
dec = reqs[0].decoder
|
||||
}
|
||||
p.engineMu.Lock()
|
||||
cstr := CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
|
||||
p.engineMu.Unlock()
|
||||
if cstr == 0 {
|
||||
err := fmt.Errorf("parakeet-cpp: batch transcribe failed: %s", CppLastError(p.ctxPtr))
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{err: err}
|
||||
}
|
||||
return
|
||||
}
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
var docs []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(raw), &docs); err != nil || len(docs) != len(reqs) {
|
||||
e := fmt.Errorf("parakeet-cpp: batch json: got %d results for %d reqs (%v)", len(docs), len(reqs), err)
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{err: e}
|
||||
}
|
||||
return
|
||||
}
|
||||
for i, r := range reqs {
|
||||
r.reply <- batchReply{json: string(docs[i])}
|
||||
}
|
||||
}
|
||||
|
||||
// AudioTranscription decodes the wav at opts.Dst to 16 kHz mono PCM and
|
||||
// submits it to the in-process batcher, which coalesces concurrent requests
|
||||
// into a single batched engine call (parakeet_capi_transcribe_pcm_batch_json)
|
||||
// with the default decoder (decoder=0, which selects the right head per
|
||||
// architecture: transducer for tdt/rnnt/hybrid, CTC for ctc) and shapes the
|
||||
// per-word timestamps into a LocalAI TranscriptResult.
|
||||
//
|
||||
// Parakeet emits word- and token-level timestamps but no native segment
|
||||
// boundaries, so we synthesise a single whole-clip segment spanning the first
|
||||
@@ -118,7 +228,7 @@ func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
|
||||
// translate/diarize/prompt/temperature/language/threads are not applicable to
|
||||
// parakeet and are ignored; streaming is handled by AudioTranscriptionStream
|
||||
// (L2).
|
||||
func (p *ParakeetCpp) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if p.ctxPtr == 0 {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: model not loaded")
|
||||
}
|
||||
@@ -126,61 +236,74 @@ func (p *ParakeetCpp) AudioTranscription(_ context.Context, opts *pb.TranscriptR
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, opts.Dst, 0)
|
||||
if cstr == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
// Fallback when the batched C-API is unavailable: transcribe directly from
|
||||
// the file path (original behavior, no batching).
|
||||
if p.bat == nil {
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, opts.Dst, 0)
|
||||
if cstr == 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", CppLastError(p.ctxPtr))
|
||||
}
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", msg)
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
var doc transcriptJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
}
|
||||
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
|
||||
// Batched path: decode to PCM, submit to the batcher, wait for this request's
|
||||
// JSON element. The dispatcher is the sole engine caller on this path; both
|
||||
// sends honour ctx cancellation.
|
||||
pcm, _, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
rep := make(chan batchReply, 1)
|
||||
select {
|
||||
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, reply: rep}:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
var res batchReply
|
||||
select {
|
||||
case res = <-rep:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if res.err != nil {
|
||||
return pb.TranscriptResult{}, res.err
|
||||
}
|
||||
var doc transcriptJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
if err := json.Unmarshal([]byte(res.json), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
}
|
||||
|
||||
// transcriptResultFromDoc maps a decoded transcriptJSON to a TranscriptResult,
|
||||
// synthesising a single whole-clip segment and attaching word timings only when
|
||||
// the caller requested word granularity. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest) pb.TranscriptResult {
|
||||
text := strings.TrimSpace(doc.Text)
|
||||
|
||||
words := make([]*pb.TranscriptWord, 0, len(doc.Words))
|
||||
for _, w := range doc.Words {
|
||||
words = append(words, &pb.TranscriptWord{
|
||||
Start: secondsToNanos(w.Start),
|
||||
End: secondsToNanos(w.End),
|
||||
Text: w.W,
|
||||
})
|
||||
words = append(words, &pb.TranscriptWord{Start: secondsToNanos(w.Start), End: secondsToNanos(w.End), Text: w.W})
|
||||
}
|
||||
|
||||
tokens := make([]int32, 0, len(doc.Tokens))
|
||||
for _, t := range doc.Tokens {
|
||||
tokens = append(tokens, t.ID)
|
||||
}
|
||||
|
||||
// Single whole-clip segment, spanning the first word start to the last
|
||||
// word end (0/0 when the clip produced no words).
|
||||
var segStart, segEnd int64
|
||||
if len(words) > 0 {
|
||||
segStart = words[0].Start
|
||||
segEnd = words[len(words)-1].End
|
||||
}
|
||||
seg := &pb.TranscriptSegment{
|
||||
Id: 0,
|
||||
Start: segStart,
|
||||
End: segEnd,
|
||||
Text: text,
|
||||
Tokens: tokens,
|
||||
}
|
||||
seg := &pb.TranscriptSegment{Id: 0, Start: segStart, End: segEnd, Text: text, Tokens: tokens}
|
||||
if wordsRequested(opts.TimestampGranularities) {
|
||||
seg.Words = words
|
||||
}
|
||||
|
||||
return pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: []*pb.TranscriptSegment{seg},
|
||||
}, nil
|
||||
return pb.TranscriptResult{Text: text, Segments: []*pb.TranscriptSegment{seg}}
|
||||
}
|
||||
|
||||
// wordsRequested reports whether the caller asked for word-level timestamps.
|
||||
@@ -243,6 +366,14 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
return nil
|
||||
}
|
||||
defer CppStreamFree(stream)
|
||||
// The C engine is a single shared context: a streaming session and a batched
|
||||
// unary dispatch must never touch it at once, so hold engineMu for the whole
|
||||
// stream. This lock is intentionally taken AFTER the non-streaming fallback
|
||||
// above returns: that fallback goes through AudioTranscription -> the batcher
|
||||
// -> runBatch, which itself acquires engineMu, so locking here first would
|
||||
// deadlock. Do not hoist this lock above the fallback.
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
|
||||
data, duration, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
@@ -362,6 +493,12 @@ func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
// Free releases the underlying parakeet_ctx. Called by LocalAI when the
|
||||
// model is unloaded.
|
||||
func (p *ParakeetCpp) Free() error {
|
||||
// Stop the dispatcher before releasing the engine so no in-flight runBatch
|
||||
// can touch a freed ctx (close leak / use-after-free on reload).
|
||||
if p.batStop != nil {
|
||||
close(p.batStop)
|
||||
p.batStop = nil
|
||||
}
|
||||
if p.ctxPtr != 0 {
|
||||
CppFree(p.ctxPtr)
|
||||
p.ctxPtr = 0
|
||||
|
||||
@@ -43,6 +43,9 @@ func ensureLibLoaded() {
|
||||
purego.RegisterLibFunc(&CppFree, lib, "parakeet_capi_free")
|
||||
purego.RegisterLibFunc(&CppTranscribePath, lib, "parakeet_capi_transcribe_path")
|
||||
purego.RegisterLibFunc(&CppTranscribePathJSON, lib, "parakeet_capi_transcribe_path_json")
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
purego.RegisterLibFunc(&CppStreamBegin, lib, "parakeet_capi_stream_begin")
|
||||
purego.RegisterLibFunc(&CppStreamFeed, lib, "parakeet_capi_stream_feed")
|
||||
purego.RegisterLibFunc(&CppStreamFinalize, lib, "parakeet_capi_stream_finalize")
|
||||
|
||||
@@ -58,6 +58,13 @@ func main() {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
// The batched-JSON entry point exists only in newer libparakeet.so (ABI >= 2).
|
||||
// Probe with Dlsym and register only if present, so the backend still loads
|
||||
// against an older library (it falls back to per-request transcription).
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[parakeet-cpp] ABI=%d\n", CppAbiVersion())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
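The probe-then-register pattern in the hunk above leaves the optional entry point unset when the library is too old, and callers branch on that. A minimal sketch of the same idea without any native library follows; `lookup`, `register`, `transcribe` and the symbol name `transcribe_batch` are hypothetical stand-ins for the `purego.Dlsym` probe and the backend's fallback, not the real API.

```go
package main

import "fmt"

// transcribeBatch stands in for an optional entry point: it stays nil
// unless the probe found the symbol.
var transcribeBatch func(clips []string) []string

// lookup is a hypothetical stand-in for a dlsym-style probe.
func lookup(symbols map[string]bool, name string) bool {
	return symbols[name]
}

// register wires the optional function only when the library exports it.
func register(symbols map[string]bool) {
	transcribeBatch = nil
	if lookup(symbols, "transcribe_batch") {
		transcribeBatch = func(clips []string) []string {
			out := make([]string, len(clips))
			for i, c := range clips {
				out[i] = "batched:" + c
			}
			return out
		}
	}
}

// transcribe prefers the batched path and falls back to per-request work.
func transcribe(clip string) string {
	if transcribeBatch == nil {
		return "single:" + clip
	}
	return transcribeBatch([]string{clip})[0]
}

func main() {
	register(map[string]bool{}) // older library: symbol missing
	fmt.Println(transcribe("a.wav"))

	register(map[string]bool{"transcribe_batch": true}) // newer library
	fmt.Println(transcribe("a.wav"))
}
```

Probing before registering matters because registering a missing symbol unconditionally would fail at load time, which would stop the backend from starting against an older library at all.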
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=0e4ee04488159b81d95a9ffcd983a077fd5dcb77
|
||||
STABLEDIFFUSION_GGML_VERSION?=2d40a8b2adcdf8b5b0ca0535f3bb7801b6ba13e5
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=f24588a272ae8e23280d9c220536437164e6ed28
|
||||
WHISPER_CPP_VERSION?=610e664ba7cfe3af46125ed1b5a1184fccb51bcd
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -122,6 +122,33 @@
|
||||
nvidia-cuda-12: "cuda12-whisper"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-whisper"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-whisper"
|
||||
- &crispasr
|
||||
name: "crispasr"
|
||||
alias: "crispasr"
|
||||
license: mit
|
||||
icon: https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg
|
||||
description: |
|
||||
CrispASR unified speech engine (whisper.cpp fork on ggml) supporting many ASR architectures (Parakeet, Canary, Voxtral, Qwen3-ASR, Granite, Wav2Vec2, Moonshine, OmniASR, FireRedASR, and more).
|
||||
urls:
|
||||
- https://github.com/CrispStrobe/CrispASR
|
||||
tags:
|
||||
- audio-transcription
|
||||
- CPU
|
||||
- GPU
|
||||
- CUDA
|
||||
- HIP
|
||||
capabilities:
|
||||
default: "cpu-crispasr"
|
||||
nvidia: "cuda12-crispasr"
|
||||
intel: "intel-sycl-f16-crispasr"
|
||||
metal: "metal-crispasr"
|
||||
amd: "rocm-crispasr"
|
||||
vulkan: "vulkan-crispasr"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-crispasr"
|
||||
nvidia-cuda-13: "cuda13-crispasr"
|
||||
nvidia-cuda-12: "cuda12-crispasr"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-crispasr"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-crispasr"
|
||||
- &parakeetcpp
|
||||
name: "parakeet-cpp"
|
||||
alias: "parakeet-cpp"
|
||||
@@ -1957,6 +1984,131 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-whisper
|
||||
## crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "crispasr-development"
|
||||
capabilities:
|
||||
default: "cpu-crispasr-development"
|
||||
nvidia: "cuda12-crispasr-development"
|
||||
intel: "intel-sycl-f16-crispasr-development"
|
||||
metal: "metal-crispasr-development"
|
||||
amd: "rocm-crispasr-development"
|
||||
vulkan: "vulkan-crispasr-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-crispasr-development"
|
||||
nvidia-cuda-13: "cuda13-crispasr-development"
|
||||
nvidia-cuda-12: "cuda12-crispasr-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-crispasr-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-crispasr-development"
|
||||
- !!merge <<: *crispasr
|
||||
name: "nvidia-l4t-arm64-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "nvidia-l4t-arm64-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-nvidia-l4t-arm64-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-nvidia-l4t-arm64-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cpu-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cpu-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda12-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "rocm-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f32-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f16-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f16-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "vulkan-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "vulkan-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda12-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "rocm-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f32-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f16-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f16-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-crispasr
|
||||
## parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "parakeet-cpp-development"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch
|
||||
texterrors==1.1.6
|
||||
nemo_toolkit[asr]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.81.0
|
||||
protobuf==7.35.0
|
||||
certifi
|
||||
setuptools
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.81.0
|
||||
protobuf
|
||||
certifi
|
||||
setuptools
|
||||
|
||||
@@ -16,7 +16,9 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
@@ -240,6 +242,84 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
|
||||
)
|
||||
|
||||
// Prefix-cache-aware routing. Enabled by default; an operator can opt out
|
||||
// with --distributed-prefix-cache=false, which leaves prefixProvider and
|
||||
// pressure nil so the SmartRouter and reconciler behave exactly as the
|
||||
// round-robin floor (true no-op). When enabled we build the local index,
|
||||
// wrap it in a NATS-backed Sync (publishes our observations, applies peers'
|
||||
// via the subscriptions below), install the extraction hook used by
|
||||
// core/backend/llm.go, and run a background eviction ticker on the app ctx.
|
||||
var prefixProvider prefixcache.Provider
|
||||
var pressure *prefixcache.Pressure
|
||||
var prefixCfg prefixcache.Config
|
||||
if !cfg.Distributed.PrefixCacheDisabled {
|
||||
prefixCfg = prefixcache.DefaultConfig()
|
||||
if cfg.Distributed.PrefixCacheTTL > 0 {
|
||||
prefixCfg.TTL = cfg.Distributed.PrefixCacheTTL
|
||||
}
|
||||
if err := prefixCfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid prefix-cache configuration: %w", err)
|
||||
}
|
||||
idx := prefixcache.NewIndex(prefixCfg)
|
||||
prefixSync := prefixcache.NewSync(idx, natsClient)
|
||||
pressure = prefixcache.NewPressure(prefixCfg.PressureWindow)
|
||||
prefixProvider = prefixSync
|
||||
|
||||
// Invalidate the prefix-cache index whenever a replica row is removed.
|
||||
// SetReplicaRemovedHook fires from the single chokepoint all removal paths
|
||||
// funnel through (RemoveNodeModel / RemoveAllNodeModelReplicas), so this
|
||||
// one hook covers every path: reconciler scale-down, probe reaper,
|
||||
// health-monitor reap, RemoteUnloaderAdapter, and the router. Registering
|
||||
// it only inside this enabled block keeps the disabled path a true no-op
|
||||
// (the registry stays hook-less).
|
||||
registry.SetReplicaRemovedHook(func(model, node string, replica int) {
|
||||
if replica < 0 {
|
||||
prefixSync.InvalidateNode(model, node)
|
||||
} else {
|
||||
prefixSync.Invalidate(model, prefixcache.ReplicaKey{NodeID: node, Replica: replica})
|
||||
}
|
||||
})
|
||||
|
||||
distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 {
|
||||
return prefixcache.ExtractChain(model, prompt, prefixCfg)
|
||||
}
|
||||
|
||||
// Apply peers' observations/invalidations to the same Sync. ApplyObserve
|
||||
// and ApplyInvalidate update only the local index and do not re-publish,
|
||||
// so there is no broadcast loop.
|
||||
if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheObserve, func(ev messaging.PrefixCacheObserveEvent) {
|
||||
prefixSync.ApplyObserve(ev, time.Now())
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheObserve, err)
|
||||
}
|
||||
if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheInvalidate, func(ev messaging.PrefixCacheInvalidateEvent) {
|
||||
prefixSync.ApplyInvalidate(ev)
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheInvalidate, err)
|
||||
}
|
||||
|
||||
// Background eviction: sweep idle entries on the app context. Stopped
|
||||
// when the app context is cancelled (mirrors the reconciler loop which
|
||||
// also runs on options.Context). TTL/2 keeps stale entries from
|
||||
// outliving their idle window by more than half a TTL.
|
||||
evictInterval := prefixCfg.TTL / 2
|
||||
go func() {
|
||||
ticker := time.NewTicker(evictInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-cfg.Context.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
prefixSync.Evict(time.Now())
|
||||
}
|
||||
}
|
||||
}()
|
||||
xlog.Info("Prefix-cache-aware routing enabled", "ttl", prefixCfg.TTL, "evictInterval", evictInterval)
|
||||
} else {
|
||||
xlog.Info("Prefix-cache-aware routing disabled: using round-robin routing")
|
||||
}
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
var conflictResolver nodes.ConcurrencyConflictResolver
|
||||
if configLoader != nil {
|
||||
@@ -252,6 +332,9 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
AuthToken: routerAuthToken,
|
||||
DB: authDB,
|
||||
ConflictResolver: conflictResolver,
|
||||
PrefixProvider: prefixProvider,
|
||||
PrefixConfig: prefixCfg,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
// Create ReplicaReconciler for auto-scaling model replicas. Adapter +
|
||||
@@ -268,6 +351,8 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
Interval: 30 * time.Second,
|
||||
ScaleDownDelay: 5 * time.Minute,
|
||||
ProbeStaleAfter: 2 * time.Minute,
|
||||
Pressure: pressure,
|
||||
PressureThreshold: prefixCfg.PressureScaleThreshold,
|
||||
})
|
||||
|
||||
// Create ModelRouterAdapter to wire into ModelLoader
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
@@ -94,6 +95,22 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
}
|
||||
|
||||
// Make the rendered prompt's prefix chain available to the distributed router
|
||||
// for prefix-cache-aware node selection. No-op in single-process mode. The
|
||||
// model id MUST match the id ModelOptions feeds to model.WithModelID, so both
|
||||
// use the shared config.ModelConfig.ModelID() helper (Name with a fallback to
|
||||
// Model) or the chain salt and the tracking key would diverge.
|
||||
//
|
||||
// s is empty for UseTokenizerTemplate models (the backend tokenizes the
|
||||
// structured messages itself), so fall back to a prefix-stable serialization
|
||||
// of the messages - otherwise prefix routing would silently degrade to
|
||||
// round-robin for the bulk of modern chat models.
|
||||
chainSource := s
|
||||
if chainSource == "" {
|
||||
chainSource = messagesPrefixSource(messages)
|
||||
}
|
||||
ctx = distributedhdr.MaybeWithPrefixChain(ctx, c.ModelID(), chainSource)
|
||||
|
||||
opts := ModelOptions(*c, o, model.WithContext(ctx))
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
|
||||
@@ -34,16 +34,11 @@ func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, back
|
||||
}
|
||||
|
||||
func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
|
||||
name := c.Name
|
||||
if name == "" {
|
||||
name = c.Model
|
||||
}
|
||||
|
||||
defOpts := []model.Option{
|
||||
model.WithBackendString(c.Backend),
|
||||
model.WithModel(c.Model),
|
||||
model.WithContext(so.Context),
|
||||
model.WithModelID(name),
|
||||
model.WithModelID(c.ModelID()),
|
||||
}
|
||||
|
||||
threads := 1
|
||||
|
||||
36
core/backend/prefix_source.go
Normal file
@@ -0,0 +1,36 @@
package backend

import (
	"strings"

	"github.com/mudler/LocalAI/core/schema"
)

// messagesPrefixSource builds a deterministic, prefix-stable serialization of a
// chat conversation for prefix-cache-aware routing. It is the fallback used when
// the frontend did not render a prompt string: models with
// config.TemplateConfig.UseTokenizerTemplate tokenize the structured messages
// backend-side, so the frontend's rendered prompt is empty and a chain built
// from it would always be empty - silently degrading prefix routing to
// round-robin for the bulk of modern chat models.
//
// Messages are emitted head-first in turn order (role line + content line per
// message), so two conversations sharing a leading system prompt and early turns
// share a leading byte prefix. That is exactly what ExtractChain hashes into a
// shared chain prefix, landing both requests on the same cache-warm replica.
func messagesPrefixSource(messages schema.Messages) string {
	var b strings.Builder
	for _, m := range messages {
		b.WriteString(m.Role)
		b.WriteByte('\n')
		content := m.StringContent
		if content == "" {
			if s, ok := m.Content.(string); ok {
				content = s
			}
		}
		b.WriteString(content)
		b.WriteByte('\n')
	}
	return b.String()
}
|
||||
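Why a shared byte prefix turns into a shared chain prefix can be illustrated with a cumulative block hash. This is a sketch under stated assumptions: `chain` and `sharedLen` are hypothetical, the block size and FNV-1a hash are chosen only for illustration, and the real `prefixcache.ExtractChain` is not shown in this diff and may work differently.

```go
package main

import (
	"fmt"
	"hash/fnv"
)

// chain hashes s in fixed-size blocks. Element i covers the salt plus every
// byte from the start of s through block i, so inputs with equal leading
// blocks produce equal leading elements. A trailing partial block is ignored.
func chain(salt, s string, block int) []uint64 {
	if block <= 0 {
		return nil
	}
	h := fnv.New64a()
	h.Write([]byte(salt)) // e.g. the model id, so different models never share chains
	var out []uint64
	for off := 0; off+block <= len(s); off += block {
		h.Write([]byte(s[off : off+block]))
		out = append(out, h.Sum64()) // Sum64 does not reset the running state
	}
	return out
}

// sharedLen counts the leading elements two chains have in common.
func sharedLen(a, b []uint64) int {
	n := 0
	for n < len(a) && n < len(b) && a[n] == b[n] {
		n++
	}
	return n
}

func main() {
	a := "system\nShared system prompt.\nuser\nQuestion A about cats\n"
	b := "system\nShared system prompt.\nuser\nQuestion B about cats\n"
	ca, cb := chain("model-x", a, 8), chain("model-x", b, 8)
	fmt.Println("chain length:", len(ca), "shared:", sharedLen(ca, cb))
}
```

Because each element depends on everything before it, the chains agree exactly up to the first block that differs, which is the property a router needs to match a request against a replica that has already seen the same leading turns.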
53
core/backend/prefix_source_internal_test.go
Normal file
@@ -0,0 +1,53 @@
package backend

import (
	"strings"

	"github.com/mudler/LocalAI/core/schema"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("messagesPrefixSource", func() {
	mk := func(role, content string) schema.Message {
		return schema.Message{Role: role, StringContent: content}
	}

	It("serializes messages head-first in turn order", func() {
		got := messagesPrefixSource(schema.Messages{
			mk("system", "You are helpful."),
			mk("user", "Hi"),
		})
		Expect(got).To(Equal("system\nYou are helpful.\nuser\nHi\n"))
	})

	It("is deterministic across calls for the same conversation", func() {
		conv := schema.Messages{mk("system", "S"), mk("user", "U")}
		Expect(messagesPrefixSource(conv)).To(Equal(messagesPrefixSource(conv)))
	})

	It("shares a leading byte prefix when the system prompt is shared", func() {
		shared := "system\nShared system prompt.\nuser\n"
		a := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question A")})
		b := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question B")})
		Expect(strings.HasPrefix(a, shared)).To(BeTrue())
		Expect(strings.HasPrefix(b, shared)).To(BeTrue())
	})

	It("does NOT share a prefix when the system prompt differs", func() {
		a := messagesPrefixSource(schema.Messages{mk("system", "Prompt A"), mk("user", "Q")})
		b := messagesPrefixSource(schema.Messages{mk("system", "Prompt B"), mk("user", "Q")})
		Expect(strings.HasPrefix(a, "system\nPrompt A")).To(BeTrue())
		Expect(strings.HasPrefix(b, "system\nPrompt B")).To(BeTrue())
	})

	It("returns empty for no messages", func() {
		Expect(messagesPrefixSource(nil)).To(Equal(""))
	})

	It("falls back to Content when StringContent is empty", func() {
		got := messagesPrefixSource(schema.Messages{{Role: "user", Content: "plain"}})
		Expect(got).To(Equal("user\nplain\n"))
	})
})
|
||||
@@ -145,19 +145,21 @@ type RunCMD struct {
|
||||
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
||||
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
DistributedPrefixCache bool `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE" default:"true" help:"Enable prefix-cache-aware routing in distributed mode (default true). When false, routing falls back to round-robin." group:"distributed"`
|
||||
DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
|
||||
Version bool
|
||||
|
||||
@@ -284,6 +286,16 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.AutoApproveNodes {
|
||||
opts = append(opts, config.EnableAutoApproveNodes)
|
||||
}
|
||||
if !r.DistributedPrefixCache {
|
||||
opts = append(opts, config.DisablePrefixCache)
|
||||
}
|
||||
if r.DistributedPrefixCacheTTL != "" {
|
||||
d, err := time.ParseDuration(r.DistributedPrefixCacheTTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL %q: %w", r.DistributedPrefixCacheTTL, err)
|
||||
}
|
||||
opts = append(opts, config.WithPrefixCacheTTL(d))
|
||||
}
|
||||
if r.ExposeNodeHeader {
|
||||
opts = append(opts, config.WithExposeNodeHeader(true))
|
||||
}
|
||||
|
||||
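The TTL handling in this hunk can be sketched in isolation. The snippet below is a minimal illustration, not the project's code: `resolvePrefixCacheTTL` and `evictionInterval` are hypothetical helper names, and resolving the 5m default eagerly is an assumption (the diff instead leaves the option unset and lets the prefixcache package apply its default).

```go
package main

import (
	"fmt"
	"time"
)

// defaultPrefixCacheTTL mirrors the 5m default mentioned in the flag help text.
const defaultPrefixCacheTTL = 5 * time.Minute

// resolvePrefixCacheTTL is a hypothetical helper: an empty value falls back to
// the default, anything else must parse as a Go duration string ("90s", "10m").
func resolvePrefixCacheTTL(raw string) (time.Duration, error) {
	if raw == "" {
		return defaultPrefixCacheTTL, nil
	}
	d, err := time.ParseDuration(raw)
	if err != nil {
		return 0, fmt.Errorf("invalid LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL %q: %w", raw, err)
	}
	return d, nil
}

// evictionInterval returns the background eviction cadence described in the
// help text: half of the idle timeout.
func evictionInterval(ttl time.Duration) time.Duration {
	return ttl / 2
}

func main() {
	for _, raw := range []string{"", "90s", "10m", "five minutes"} {
		ttl, err := resolvePrefixCacheTTL(raw)
		if err != nil {
			fmt.Println("error:", err)
			continue
		}
		fmt.Printf("raw=%q ttl=%s eviction every %s\n", raw, ttl, evictionInterval(ttl))
	}
}
```

Values without a unit (for example `300`) are rejected by `time.ParseDuration`, which is why the error message echoes the raw input.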
@@ -14,4 +14,5 @@ type Worker struct {
|
||||
LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"`
|
||||
MLXDistributed MLXDistributed `cmd:"" name:"mlx-distributed" help:"Starts an MLX distributed worker in standalone mode (requires --hostfile and --rank)"`
|
||||
VLLMDistributed VLLMDistributed `cmd:"" name:"vllm" help:"Starts a vLLM data-parallel follower process. Multi-node DP for a single model: head runs the existing vllm backend with engine_args.data_parallel_size>1, followers run this command."`
|
||||
DS4Distributed DS4Distributed `cmd:"" name:"ds4-distributed" help:"Starts a ds4 distributed worker in standalone mode: owns a layer slice and dials the coordinator (pass ds4-worker args after --)"`
|
||||
}
|
||||
|
||||
108
core/cli/worker/worker_ds4.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
type DS4Distributed struct {
|
||||
WorkerFlags `embed:""`
|
||||
ExtraDS4Args string `name:"ds4-args" env:"LOCALAI_EXTRA_DS4_ARGS,EXTRA_DS4_ARGS" help:"Arguments passed to ds4-worker (e.g. '--role worker --model m.gguf --layers 20:output --coordinator HOST PORT')"`
|
||||
}
|
||||
|
||||
const (
|
||||
ds4WorkerBinaryName = "ds4-worker"
|
||||
ds4GalleryName = "ds4"
|
||||
)
|
||||
|
||||
// ds4WorkerArgs builds the argv for syscall.Exec when launching ds4-worker
|
||||
// directly: the binary path followed by the space-split extra args. An empty
|
||||
// extra string yields a bare invocation.
|
||||
func ds4WorkerArgs(binary, extra string) []string {
|
||||
args := []string{binary}
|
||||
args = append(args, strings.Fields(extra)...)
|
||||
return args
|
||||
}
|
||||
|
||||
func findDS4Backend(galleries string, systemState *system.SystemState, requireIntegrity bool) (string, error) {
|
||||
backends, err := gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed listing system backends", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
backend, ok := backends.Get(ds4GalleryName)
|
||||
if !ok {
|
||||
ml := model.NewModelLoader(systemState)
|
||||
var gals []config.Gallery
|
||||
if err := json.Unmarshal([]byte(galleries), &gals); err != nil {
|
||||
xlog.Error("failed loading galleries", "error", err)
|
||||
return "", err
|
||||
}
|
||||
if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, ds4GalleryName, nil, true, requireIntegrity); err != nil {
|
||||
xlog.Error("ds4 backend not found, failed to install it", "error", err)
|
||||
return "", err
|
||||
}
|
||||
backends, err = gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
backend, ok = backends.Get(ds4GalleryName)
|
||||
if !ok {
|
||||
return "", errors.New("ds4 backend not found after install")
|
||||
}
|
||||
}
|
||||
|
||||
backendPath := filepath.Dir(backend.RunFile)
|
||||
if backend.RunFile == "" { // filepath.Dir never returns "", so check the run file itself
|
||||
return "", errors.New("ds4 backend not found, install it first")
|
||||
}
|
||||
return filepath.Join(backendPath, ds4WorkerBinaryName), nil
|
||||
}
|
||||
|
||||
func (r *DS4Distributed) Run(ctx *cliContext.Context) error {
|
||||
if r.ExtraDS4Args == "" && len(os.Args) < 4 {
|
||||
return fmt.Errorf("usage: local-ai worker ds4-distributed -- --role worker --model <gguf> --layers <START:END|START:output> --coordinator <host> <port>")
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(r.BackendsPath),
|
||||
system.WithBackendSystemPath(r.BackendsSystemPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
worker, err := findDS4Backend(r.BackendGalleries, systemState, r.RequireBackendIntegrity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ds4 bundles its own dynamic loader (lib/ld.so) for glibc compatibility,
|
||||
// like backend/cpp/ds4/run.sh does for grpc-server. Launch ds4-worker via
|
||||
// that loader when present; otherwise exec it directly. (This is a
|
||||
// deliberate divergence from worker_llamacpp.go, which has no bundled loader.)
|
||||
backendPath := filepath.Dir(worker)
|
||||
env := os.Environ()
|
||||
loader := filepath.Join(backendPath, "lib", "ld.so")
|
||||
if _, statErr := os.Stat(loader); statErr == nil {
|
||||
env = append(env, "LD_LIBRARY_PATH="+filepath.Join(backendPath, "lib")+":"+os.Getenv("LD_LIBRARY_PATH"))
|
||||
args := append([]string{loader}, ds4WorkerArgs(worker, r.ExtraDS4Args)...)
|
||||
return syscall.Exec(loader, args, env)
|
||||
}
|
||||
|
||||
return syscall.Exec(worker, ds4WorkerArgs(worker, r.ExtraDS4Args), env)
|
||||
}
|
||||
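The loader branch above produces an argv of the form `[loader, worker, extra...]`: the bundled `ld.so` is the program being executed and receives the real binary as its first argument. The sketch below reproduces that layout with hypothetical helper names (`loaderArgv`, `libraryPath`) and a made-up install path. `libraryPath` is a variation that is not in the diff: it omits the trailing `:` when `LD_LIBRARY_PATH` is unset, since the dynamic loader treats an empty path element as the current directory.

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// loaderArgv builds the argv used when launching through a bundled loader:
// the loader itself, then the real binary, then the whitespace-split extras.
func loaderArgv(loader, worker, extra string) []string {
	argv := []string{loader, worker}
	return append(argv, strings.Fields(extra)...)
}

// libraryPath prepends libDir to an existing LD_LIBRARY_PATH value without
// leaving an empty element behind when the existing value is empty.
func libraryPath(libDir, existing string) string {
	if existing == "" {
		return libDir
	}
	return libDir + ":" + existing
}

func main() {
	backendPath := "/backends/ds4" // hypothetical install location
	lib := filepath.Join(backendPath, "lib")
	loader := filepath.Join(lib, "ld.so")
	worker := filepath.Join(backendPath, "ds4-worker")

	fmt.Printf("%q\n", loaderArgv(loader, worker, "--role worker"))
	fmt.Println("LD_LIBRARY_PATH=" + libraryPath(lib, ""))
	fmt.Println("LD_LIBRARY_PATH=" + libraryPath(lib, "/usr/lib"))
}
```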
28
core/cli/worker/worker_ds4_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ds4 worker CLI", func() {
|
||||
It("uses the ds4 backend gallery name and worker binary name", func() {
|
||||
Expect(ds4GalleryName).To(Equal("ds4"))
|
||||
Expect(ds4WorkerBinaryName).To(Equal("ds4-worker"))
|
||||
})
|
||||
|
||||
It("assembles direct exec args as [binary, extra-split...]", func() {
|
||||
args := ds4WorkerArgs("/b/ds4-worker", "--role worker --model m.gguf --layers 20:output --coordinator 10.0.0.1 1234")
|
||||
Expect(args).To(Equal([]string{
|
||||
"/b/ds4-worker",
|
||||
"--role", "worker",
|
||||
"--model", "m.gguf",
|
||||
"--layers", "20:output",
|
||||
"--coordinator", "10.0.0.1", "1234",
|
||||
}))
|
||||
})
|
||||
|
||||
It("drops empty extra args to a bare binary invocation", func() {
|
||||
Expect(ds4WorkerArgs("/b/ds4-worker", "")).To(Equal([]string{"/b/ds4-worker"}))
|
||||
})
|
||||
})
|
||||
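One consequence of splitting with `strings.Fields`, which the tests above do not exercise, is that shell-style quoting is not interpreted. The sketch below copies the splitting logic under the hypothetical name `workerArgs` to show both the whitespace collapsing and the quoting limitation; it is an illustration, not part of the change.

```go
package main

import (
	"fmt"
	"strings"
)

// workerArgs mirrors ds4WorkerArgs from the diff: the binary path first, then
// the extra string split on runs of whitespace.
func workerArgs(binary, extra string) []string {
	args := []string{binary}
	return append(args, strings.Fields(extra)...)
}

func main() {
	// Leading, trailing, and repeated spaces never produce empty arguments.
	fmt.Printf("%q\n", workerArgs("/b/ds4-worker", "  --role   worker  "))

	// Quotes are passed through literally, so a path containing a space is
	// split into two arguments that still carry the quote characters.
	fmt.Printf("%q\n", workerArgs("/b/ds4-worker", `--model "my model.gguf"`))
}
```

Under this splitting rule, model paths passed via `--ds4-args` or `LOCALAI_EXTRA_DS4_ARGS` should not contain spaces.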
@@ -198,6 +198,13 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
AcceptsVideos: true,
|
||||
Description: "vLLM engine — high-throughput LLM serving with optional multimodal",
|
||||
},
|
||||
"sglang": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodTokenizeString},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseTokenize, UsecaseVision},
|
||||
DefaultUsecases: []string{UsecaseChat},
|
||||
AcceptsImages: true,
|
||||
Description: "SGLang — fast LLM inference with structured generation and optional vision",
|
||||
},
|
||||
"vllm-omni": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodGenerateImage, MethodGenerateVideo, MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseImage, UsecaseVideo, UsecaseTTS, UsecaseVision},
|
||||
@@ -291,6 +298,12 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "NVIDIA NeMo speech recognition",
|
||||
},
|
||||
"parakeet-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription},
|
||||
PossibleUsecases: []string{UsecaseTranscript},
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "NVIDIA NeMo Parakeet ASR (parakeet.cpp)",
|
||||
},
|
||||
"qwen-asr": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription},
|
||||
PossibleUsecases: []string{UsecaseTranscript},
|
||||
@@ -309,6 +322,18 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseTranscript, UsecaseTTS},
|
||||
Description: "VibeVoice — bidirectional speech (transcription and synthesis)",
|
||||
},
|
||||
"vibevoice-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription, MethodTTS, MethodTTSStream},
|
||||
PossibleUsecases: []string{UsecaseTranscript, UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTranscript, UsecaseTTS},
|
||||
Description: "VibeVoice C++ — bidirectional speech, C++ backend with streaming TTS",
|
||||
},
|
||||
"sherpa-onnx": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription, MethodTTS, MethodTTSStream, MethodVAD},
|
||||
PossibleUsecases: []string{UsecaseTranscript, UsecaseTTS, UsecaseVAD},
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "Sherpa-ONNX — multi-model speech toolkit (ASR, TTS, VAD)",
|
||||
},
|
||||
|
||||
// --- TTS backends ---
|
||||
"piper": {
|
||||
@@ -353,6 +378,12 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Qwen TTS",
|
||||
},
|
||||
"qwen3-tts-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Qwen3 TTS C++ — text-to-speech, C++ backend",
|
||||
},
|
||||
"faster-qwen3-tts": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
@@ -434,6 +465,12 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseDetection},
|
||||
Description: "RF-DETR object detection",
|
||||
},
|
||||
"rfdetr-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodDetect},
|
||||
PossibleUsecases: []string{UsecaseDetection},
|
||||
DefaultUsecases: []string{UsecaseDetection},
|
||||
Description: "RF-DETR C++ object detection",
|
||||
},
|
||||
"silero-vad": {
|
||||
GRPCMethods: []GRPCMethod{MethodVAD},
|
||||
PossibleUsecases: []string{UsecaseVAD},
|
||||
|
||||
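The entries added to `BackendCapabilities` are plain lookup data. As a rough sketch of how such a table can be queried, the snippet below uses a simplified, hypothetical `capability` type and placeholder usecase strings; the real `BackendCapability` struct and the values of the `Usecase*` constants are not shown in this diff, so both are assumptions.

```go
package main

import "fmt"

// capability is a simplified stand-in for the BackendCapability struct.
type capability struct {
	PossibleUsecases []string
	DefaultUsecases  []string
}

// capabilities copies the shape of three entries from the diff; the usecase
// strings are placeholders for the Usecase* constants.
var capabilities = map[string]capability{
	"parakeet-cpp":  {PossibleUsecases: []string{"transcript"}, DefaultUsecases: []string{"transcript"}},
	"vibevoice-cpp": {PossibleUsecases: []string{"transcript", "tts"}, DefaultUsecases: []string{"transcript", "tts"}},
	"sherpa-onnx":   {PossibleUsecases: []string{"transcript", "tts", "vad"}, DefaultUsecases: []string{"transcript"}},
}

// supports reports whether a backend lists the usecase as possible.
// Unknown backends support nothing.
func supports(backend, usecase string) bool {
	c, ok := capabilities[backend]
	if !ok {
		return false
	}
	for _, u := range c.PossibleUsecases {
		if u == usecase {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(supports("sherpa-onnx", "vad"))
	fmt.Println(supports("parakeet-cpp", "tts"))
	fmt.Println(supports("unknown-backend", "tts"))
}
```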
@@ -49,6 +49,17 @@ type DistributedConfig struct {
|
||||
|
||||
AgentWorkerConcurrency int `yaml:"agent_worker_concurrency" json:"agent_worker_concurrency" env:"LOCALAI_AGENT_WORKER_CONCURRENCY"`
|
||||
JobWorkerConcurrency int `yaml:"job_worker_concurrency" json:"job_worker_concurrency" env:"LOCALAI_JOB_WORKER_CONCURRENCY"`
|
||||
|
||||
// PrefixCacheDisabled turns off prefix-cache-aware routing, falling back to
|
||||
// round-robin (the floor). Prefix-cache routing is ON by default in
|
||||
// distributed mode; this flag exists so operators can opt out. The CLI
|
||||
// surfaces a default-true --distributed-prefix-cache enable flag and sets
|
||||
// this when the operator passes --distributed-prefix-cache=false.
|
||||
PrefixCacheDisabled bool
|
||||
// PrefixCacheTTL is the idle-timeout for prefix-cache index entries and
|
||||
// drives the background eviction cadence (eviction runs every TTL/2). Zero
|
||||
// means use the prefixcache package default (5m).
|
||||
PrefixCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// Validate checks that the distributed configuration is internally consistent.
|
||||
@@ -158,6 +169,20 @@ var EnableAutoApproveNodes = func(o *ApplicationConfig) {
|
||||
o.Distributed.AutoApproveNodes = true
|
||||
}
|
||||
|
||||
// DisablePrefixCache turns off prefix-cache-aware routing (falls back to
|
||||
// round-robin). Prefix-cache routing is enabled by default in distributed mode.
|
||||
var DisablePrefixCache = func(o *ApplicationConfig) {
|
||||
o.Distributed.PrefixCacheDisabled = true
|
||||
}
|
||||
|
||||
// WithPrefixCacheTTL sets the prefix-cache index idle-timeout (and the
|
||||
// background eviction cadence, which runs every TTL/2).
|
||||
func WithPrefixCacheTTL(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.PrefixCacheTTL = d
|
||||
}
|
||||
}
|
||||
|
||||
// Flag names for distributed timeout / interval configuration. These are
|
||||
// the kebab-case identifiers kong derives from the matching RunCMD struct
|
||||
// fields; they appear in Validate error messages and any other operator-
|
||||
|
||||
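`DisablePrefixCache` and `WithPrefixCacheTTL` follow the functional-options pattern. The sketch below shows how such options compose, using lowercase stand-in types; `newConfig` and `effectiveTTL` are hypothetical helpers, and applying the 5m fallback at read time is an assumption based on the `PrefixCacheTTL` comment.

```go
package main

import (
	"fmt"
	"time"
)

// Stand-ins for DistributedConfig, ApplicationConfig, and AppOption.
type distributedConfig struct {
	PrefixCacheDisabled bool
	PrefixCacheTTL      time.Duration
}

type applicationConfig struct {
	Distributed distributedConfig
}

type appOption func(*applicationConfig)

// disablePrefixCache is a fixed option, so it is a plain function value.
var disablePrefixCache = func(o *applicationConfig) {
	o.Distributed.PrefixCacheDisabled = true
}

// withPrefixCacheTTL needs a parameter, so it returns a closure.
func withPrefixCacheTTL(d time.Duration) appOption {
	return func(o *applicationConfig) {
		o.Distributed.PrefixCacheTTL = d
	}
}

// newConfig applies the options in order on top of the zero value.
func newConfig(opts ...appOption) *applicationConfig {
	cfg := &applicationConfig{}
	for _, opt := range opts {
		opt(cfg)
	}
	return cfg
}

// effectiveTTL treats a zero TTL as "use the 5m default".
func effectiveTTL(cfg *applicationConfig) time.Duration {
	if cfg.Distributed.PrefixCacheTTL == 0 {
		return 5 * time.Minute
	}
	return cfg.Distributed.PrefixCacheTTL
}

func main() {
	base := newConfig()
	custom := newConfig(disablePrefixCache, withPrefixCacheTTL(10*time.Minute))
	fmt.Println(base.Distributed.PrefixCacheDisabled, effectiveTTL(base))
	fmt.Println(custom.Distributed.PrefixCacheDisabled, effectiveTTL(custom))
}
```

Modelling the setting as `PrefixCacheDisabled` keeps the zero value of the config meaning "prefix-cache routing on", which matches the default-on behavior described in the comments.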
@@ -694,6 +694,18 @@ func (c *ModelConfig) IsModelURL() bool {
|
||||
return uri.LooksLikeURL()
|
||||
}
|
||||
|
||||
// ModelID returns the identifier used to reference this model across the
|
||||
// system: the configured Name, falling back to Model when Name is empty.
|
||||
// This is the single source of truth for the id fed to model.WithModelID and
|
||||
// the prefix-cache chain salt; both MUST agree with the router's tracking key
|
||||
// or the prefix-cache salt diverges silently.
|
||||
func (c ModelConfig) ModelID() string {
|
||||
if c.Name != "" {
|
||||
return c.Name
|
||||
}
|
||||
return c.Model
|
||||
}
|
||||
|
||||
// ModelFileName returns the filename of the model
|
||||
// If the model is a URL, it will return the MD5 of the URL which is the filename
|
||||
func (c *ModelConfig) ModelFileName() string {
|
||||
|
||||
@@ -10,6 +10,23 @@ import (
|
||||
)
|
||||
|
||||
var _ = Describe("Test cases for config related functions", func() {
|
||||
Context("ModelID", func() {
|
||||
It("returns Name when set", func() {
|
||||
c := ModelConfig{Name: "my-name"}
|
||||
c.Model = "my-model"
|
||||
Expect(c.ModelID()).To(Equal("my-name"))
|
||||
})
|
||||
It("falls back to Model when Name is empty", func() {
|
||||
c := ModelConfig{}
|
||||
c.Model = "my-model"
|
||||
Expect(c.ModelID()).To(Equal("my-model"))
|
||||
})
|
||||
It("returns empty string when both are empty", func() {
|
||||
c := ModelConfig{}
|
||||
Expect(c.ModelID()).To(Equal(""))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Test Read configuration functions", func() {
|
||||
It("Test Validate", func() {
|
||||
tmp, err := os.CreateTemp("", "config.yaml")
|
||||
|
||||
@@ -31,6 +31,7 @@ var knownPrefOnlyBackends = []schema.KnownBackend{
|
||||
{Name: "mlx-vlm", Modality: "text", AutoDetect: false, Description: "MLX vision-language models (preference-only)"},
|
||||
// ASR
|
||||
{Name: "whisperx", Modality: "asr", AutoDetect: false, Description: "WhisperX transcription (preference-only)"},
|
||||
{Name: "crispasr", Modality: "asr", AutoDetect: false, Description: "CrispASR multi-architecture transcription (preference-only)"},
|
||||
// TTS
|
||||
{Name: "kokoros", Modality: "tts", AutoDetect: false, Description: "Kokoros TTS (preference-only)"},
|
||||
{Name: "qwen-tts", Modality: "tts", AutoDetect: false, Description: "Qwen TTS (preference-only)"},
|
||||
|
||||
@@ -140,6 +140,7 @@ var _ = Describe("Backend Endpoints", func() {
|
||||
expectPrefOnly("trl", "text")
|
||||
expectPrefOnly("mlx-vlm", "text")
|
||||
expectPrefOnly("whisperx", "asr")
|
||||
expectPrefOnly("crispasr", "asr")
|
||||
expectPrefOnly("kokoros", "tts")
|
||||
expectPrefOnly("qwen-tts", "tts")
|
||||
expectPrefOnly("qwen3-tts-cpp", "tts")
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
@@ -911,14 +913,56 @@ func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
}
|
||||
|
||||
// SetSchedulingRequest is the request body for creating/updating a scheduling config.
|
||||
//
|
||||
// The four prefix-cache fields are POINTERS so an omitted field is
|
||||
// distinguishable from an explicit zero. On update, an omitted prefix-cache
|
||||
// field preserves the model's previously-configured value instead of resetting
|
||||
// it (see SetSchedulingEndpoint's PATCH-style merge). ModelName, NodeSelector,
|
||||
// MinReplicas and MaxReplicas keep their full-replace PUT semantics.
|
||||
type SetSchedulingRequest struct {
|
||||
ModelName string `json:"model_name"`
|
||||
NodeSelector map[string]string `json:"node_selector,omitempty"`
|
||||
MinReplicas int `json:"min_replicas"`
|
||||
MaxReplicas int `json:"max_replicas"`
|
||||
ModelName string `json:"model_name"`
|
||||
NodeSelector map[string]string `json:"node_selector,omitempty"`
|
||||
MinReplicas int `json:"min_replicas"`
|
||||
MaxReplicas int `json:"max_replicas"`
|
||||
RoutePolicy *string `json:"route_policy,omitempty"`
|
||||
BalanceAbsThreshold *int `json:"balance_abs_threshold,omitempty"`
|
||||
BalanceRelThreshold *float64 `json:"balance_rel_threshold,omitempty"`
|
||||
MinPrefixMatch *float64 `json:"min_prefix_match,omitempty"`
|
||||
}
|
||||
|
||||
// validateSchedulingRequest enforces the invariants of a scheduling config.
|
||||
// The prefix-cache bounds are delegated to prefixcache.ValidateThresholds (the
|
||||
// single source of truth), and are checked against the RESOLVED values passed
|
||||
// in (provided-or-preserved), so validation only rejects bad values the caller
|
||||
// actually supplied. It returns nil when valid, or an error with a user-facing
|
||||
// message describing the first violation.
|
||||
func validateSchedulingRequest(req SetSchedulingRequest, routePolicy string, absThr int, relThr, minMatch float64) error {
|
||||
if req.ModelName == "" {
|
||||
return errors.New("model_name is required")
|
||||
}
|
||||
if req.MinReplicas < 0 {
|
||||
return errors.New("min_replicas must be >= 0")
|
||||
}
|
||||
if req.MaxReplicas < 0 {
|
||||
return errors.New("max_replicas must be >= 0")
|
||||
}
|
||||
if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas {
|
||||
return errors.New("min_replicas must be <= max_replicas")
|
||||
}
|
||||
if err := prefixcache.ValidateThresholds(routePolicy, absThr, relThr, minMatch); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetSchedulingEndpoint creates or updates a model scheduling config.
|
||||
//
|
||||
// The registry upsert full-replaces all columns, so a request that omits the
|
||||
// prefix-cache fields would otherwise wipe a model's previously-configured
|
||||
// routing settings. To avoid that footgun the four prefix-cache fields are
|
||||
// merged PATCH-style: a non-nil request pointer wins; a nil one preserves the
|
||||
// existing config's value (or the zero default when no config exists yet). The
|
||||
// non-prefix fields keep their full-replace PUT behavior.
|
||||
func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
@@ -926,17 +970,45 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
||||
}
|
||||
if req.ModelName == "" {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name is required"))
|
||||
|
||||
// Fetch the existing config (may be nil) so omitted prefix-cache fields
|
||||
// can fall back to the stored value rather than resetting to zero.
|
||||
var existing *nodes.ModelSchedulingConfig
|
||||
if req.ModelName != "" {
|
||||
var err error
|
||||
existing, err = registry.GetModelScheduling(ctx, req.ModelName)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to load existing scheduling config"))
|
||||
}
|
||||
}
|
||||
if req.MinReplicas < 0 {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be >= 0"))
|
||||
|
||||
// Resolve each prefix-cache field: provided pointer wins, otherwise keep
|
||||
// the existing value (zero/default when there is no existing config).
|
||||
routePolicy := ""
|
||||
absThr := 0
|
||||
relThr := 0.0
|
||||
minMatch := 0.0
|
||||
if existing != nil {
|
||||
routePolicy = existing.RoutePolicy
|
||||
absThr = existing.BalanceAbsThreshold
|
||||
relThr = existing.BalanceRelThreshold
|
||||
minMatch = existing.MinPrefixMatch
|
||||
}
|
||||
if req.MaxReplicas < 0 {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "max_replicas must be >= 0"))
|
||||
if req.RoutePolicy != nil {
|
||||
routePolicy = *req.RoutePolicy
|
||||
}
|
||||
if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be <= max_replicas"))
|
||||
if req.BalanceAbsThreshold != nil {
|
||||
absThr = *req.BalanceAbsThreshold
|
||||
}
|
||||
if req.BalanceRelThreshold != nil {
|
||||
relThr = *req.BalanceRelThreshold
|
||||
}
|
||||
if req.MinPrefixMatch != nil {
|
||||
minMatch = *req.MinPrefixMatch
|
||||
}
|
||||
|
||||
if err := validateSchedulingRequest(req, routePolicy, absThr, relThr, minMatch); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, err.Error()))
|
||||
}
|
||||
|
||||
// Serialize node selector to JSON
|
||||
@@ -950,10 +1022,14 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
}
|
||||
|
||||
config := &nodes.ModelSchedulingConfig{
|
||||
ModelName: req.ModelName,
|
||||
NodeSelector: selectorJSON,
|
||||
MinReplicas: req.MinReplicas,
|
||||
MaxReplicas: req.MaxReplicas,
|
||||
ModelName: req.ModelName,
|
||||
NodeSelector: selectorJSON,
|
||||
MinReplicas: req.MinReplicas,
|
||||
MaxReplicas: req.MaxReplicas,
|
||||
RoutePolicy: routePolicy,
|
||||
BalanceAbsThreshold: absThr,
|
||||
BalanceRelThreshold: relThr,
|
||||
MinPrefixMatch: minMatch,
|
||||
}
|
||||
if err := registry.SetModelScheduling(ctx, config); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set scheduling config"))
|
||||
|
||||
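The PATCH-style merge relies on `encoding/json` leaving a pointer field nil when the key is absent. The sketch below reduces the idea to two prefix-cache fields; the `request` and `stored` types and the `merge` and `apply` helpers are hypothetical stand-ins for `SetSchedulingRequest`, `ModelSchedulingConfig`, and the handler body.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// request uses pointers for the fields that must distinguish "omitted" from
// an explicit zero; MinReplicas keeps full-replace semantics.
type request struct {
	ModelName      string   `json:"model_name"`
	MinReplicas    int      `json:"min_replicas"`
	RoutePolicy    *string  `json:"route_policy,omitempty"`
	MinPrefixMatch *float64 `json:"min_prefix_match,omitempty"`
}

type stored struct {
	ModelName      string
	MinReplicas    int
	RoutePolicy    string
	MinPrefixMatch float64
}

// merge starts from the existing values (if any) and lets non-nil request
// pointers override them.
func merge(existing *stored, req request) stored {
	out := stored{ModelName: req.ModelName, MinReplicas: req.MinReplicas}
	if existing != nil {
		out.RoutePolicy = existing.RoutePolicy
		out.MinPrefixMatch = existing.MinPrefixMatch
	}
	if req.RoutePolicy != nil {
		out.RoutePolicy = *req.RoutePolicy
	}
	if req.MinPrefixMatch != nil {
		out.MinPrefixMatch = *req.MinPrefixMatch
	}
	return out
}

// apply decodes a JSON body and merges it onto the existing config.
func apply(existing *stored, body string) (stored, error) {
	var req request
	if err := json.Unmarshal([]byte(body), &req); err != nil {
		return stored{}, err
	}
	return merge(existing, req), nil
}

func main() {
	first, _ := apply(nil, `{"model_name":"m","route_policy":"prefix_cache","min_prefix_match":0.4}`)
	second, _ := apply(&first, `{"model_name":"m","min_replicas":2}`)
	third, _ := apply(&second, `{"model_name":"m","min_replicas":2,"min_prefix_match":0}`)
	fmt.Printf("%+v\n%+v\n%+v\n", first, second, third)
}
```

The third call shows why pointers are needed: an explicit `0` resets the stored value, while the omitted key in the second call preserves it.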
@@ -0,0 +1,66 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("validateSchedulingRequest", func() {
|
||||
base := func() SetSchedulingRequest {
|
||||
return SetSchedulingRequest{ModelName: "m"}
|
||||
}
|
||||
|
||||
It("accepts an empty route policy (inherit) with valid thresholds", func() {
|
||||
Expect(validateSchedulingRequest(base(), "", 3, 0, 0.4)).To(Succeed())
|
||||
})
|
||||
|
||||
It("accepts the prefix_cache policy", func() {
|
||||
Expect(validateSchedulingRequest(base(), "prefix_cache", 0, 0, 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("accepts the round_robin policy", func() {
|
||||
Expect(validateSchedulingRequest(base(), "round_robin", 0, 0, 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("accepts balance_rel_threshold >= 1", func() {
|
||||
Expect(validateSchedulingRequest(base(), "", 0, 1.5, 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects a missing model_name", func() {
|
||||
req := base()
|
||||
req.ModelName = ""
|
||||
err := validateSchedulingRequest(req, "", 0, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("model_name is required"))
|
||||
})
|
||||
|
||||
It("rejects an unknown route_policy (no silent default)", func() {
|
||||
err := validateSchedulingRequest(base(), "bogus", 0, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("route_policy"))
|
||||
})
|
||||
|
||||
It("rejects min_prefix_match above 1", func() {
|
||||
err := validateSchedulingRequest(base(), "", 0, 0, 2)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("rejects a negative min_prefix_match", func() {
|
||||
err := validateSchedulingRequest(base(), "", 0, 0, -0.1)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("rejects a negative balance_abs_threshold", func() {
|
||||
err := validateSchedulingRequest(base(), "", -1, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("balance_abs_threshold"))
|
||||
})
|
||||
|
||||
It("rejects balance_rel_threshold between 0 and 1 exclusive", func() {
|
||||
err := validateSchedulingRequest(base(), "", 0, 0.5, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("balance_rel_threshold"))
|
||||
})
|
||||
})
|
||||
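The bounds these tests pin down can be summarised in one function. The sketch below is inferred only from the test cases above and is not the actual `prefixcache.ValidateThresholds` implementation: the name `validateThresholds`, the error wording, and the handling of negative `balance_rel_threshold` values are assumptions.

```go
package main

import (
	"errors"
	"fmt"
)

// validateThresholds encodes the rules implied by the tests:
//   - route_policy must be "", "prefix_cache", or "round_robin"
//   - balance_abs_threshold must be >= 0
//   - balance_rel_threshold must be 0 (unset) or >= 1
//   - min_prefix_match must lie in [0, 1]
func validateThresholds(routePolicy string, absThr int, relThr, minMatch float64) error {
	switch routePolicy {
	case "", "prefix_cache", "round_robin":
	default:
		return fmt.Errorf("route_policy %q is not supported", routePolicy)
	}
	if absThr < 0 {
		return errors.New("balance_abs_threshold must be >= 0")
	}
	// Assumption: negative values are rejected together with the (0, 1) range.
	if relThr != 0 && relThr < 1 {
		return errors.New("balance_rel_threshold must be 0 or >= 1")
	}
	if minMatch < 0 || minMatch > 1 {
		return errors.New("min_prefix_match must be between 0 and 1")
	}
	return nil
}

func main() {
	fmt.Println(validateThresholds("prefix_cache", 3, 1.5, 0.4))
	fmt.Println(validateThresholds("bogus", 0, 0, 0))
	fmt.Println(validateThresholds("", 0, 0.5, 0))
	fmt.Println(validateThresholds("", 0, 0, 2))
}
```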
@@ -230,6 +230,114 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Describe("SetSchedulingEndpoint", func() {
|
||||
postScheduling := func(body string) *httptest.ResponseRecorder {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
handler := SetSchedulingEndpoint(registry)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
return rec
|
||||
}
|
||||
|
||||
It("persists prefix-cache fields and round-trips them via GET", func() {
|
||||
ctx := context.Background()
|
||||
rec := postScheduling(`{"model_name":"pc-model","route_policy":"prefix_cache","balance_abs_threshold":3,"min_prefix_match":0.4}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
cfg, err := registry.GetModelScheduling(ctx, "pc-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg).ToNot(BeNil())
|
||||
Expect(cfg.RoutePolicy).To(Equal("prefix_cache"))
|
||||
Expect(cfg.BalanceAbsThreshold).To(Equal(3))
|
||||
Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9))
|
||||
|
||||
e := echo.New()
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
getRec := httptest.NewRecorder()
|
||||
gc := e.NewContext(getReq, getRec)
|
||||
gc.SetParamNames("model")
|
||||
gc.SetParamValues("pc-model")
|
||||
Expect(GetSchedulingEndpoint(registry)(gc)).To(Succeed())
|
||||
Expect(getRec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var got nodes.ModelSchedulingConfig
|
||||
Expect(json.Unmarshal(getRec.Body.Bytes(), &got)).To(Succeed())
|
||||
Expect(got.RoutePolicy).To(Equal("prefix_cache"))
|
||||
Expect(got.BalanceAbsThreshold).To(Equal(3))
|
||||
Expect(got.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9))
|
||||
})
|
||||
|
||||
It("returns 400 for an out-of-range min_prefix_match", func() {
|
||||
rec := postScheduling(`{"model_name":"bad-mpm","min_prefix_match":2}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
errObj, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(errObj["message"]).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("returns 400 for an unknown route_policy", func() {
|
||||
rec := postScheduling(`{"model_name":"bad-policy","route_policy":"bogus"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
errObj, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(errObj["message"]).To(ContainSubstring("route_policy"))
|
||||
})
|
||||
|
||||
It("returns 400 for a balance_rel_threshold between 0 and 1", func() {
|
||||
rec := postScheduling(`{"model_name":"bad-rel","balance_rel_threshold":0.5}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
errObj, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(errObj["message"]).To(ContainSubstring("balance_rel_threshold"))
|
||||
})
|
||||
|
||||
// Regression for the partial-update footgun: a min/max-only POST used to
|
||||
// full-replace every column and silently reset the prefix-cache settings
|
||||
// to empty/zero. The pointer-merge must preserve omitted prefix fields.
|
||||
It("preserves prefix-cache settings across a min_replicas-only update", func() {
|
||||
ctx := context.Background()
|
||||
|
||||
rec := postScheduling(`{"model_name":"merge-model","route_policy":"prefix_cache","min_prefix_match":0.4}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
// Update only min_replicas - omits all prefix-cache fields.
|
||||
rec = postScheduling(`{"model_name":"merge-model","min_replicas":2}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
cfg, err := registry.GetModelScheduling(ctx, "merge-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg).ToNot(BeNil())
|
||||
Expect(cfg.MinReplicas).To(Equal(2), "the provided non-prefix field must update")
|
||||
Expect(cfg.RoutePolicy).To(Equal("prefix_cache"), "omitted route_policy must be preserved")
|
||||
Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9), "omitted min_prefix_match must be preserved")
|
||||
})
|
||||
|
||||
It("updates a prefix-cache field when it is explicitly provided", func() {
|
||||
ctx := context.Background()
|
||||
|
||||
rec := postScheduling(`{"model_name":"update-model","route_policy":"prefix_cache","min_prefix_match":0.4}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
rec = postScheduling(`{"model_name":"update-model","route_policy":"round_robin"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
cfg, err := registry.GetModelScheduling(ctx, "update-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg).ToNot(BeNil())
|
||||
Expect(cfg.RoutePolicy).To(Equal("round_robin"), "explicitly provided route_policy must update")
|
||||
Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9), "omitted min_prefix_match must still be preserved")
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ListNodesEndpoint", func() {
|
||||
It("returns an empty list when no nodes are registered", func() {
|
||||
e := echo.New()
|
||||
|
||||
@@ -17,7 +17,10 @@ func SecurityHeaders() echo.MiddlewareFunc {
 "img-src 'self' data: blob: https:; " +
 "media-src 'self' data: blob:; " +
 "font-src 'self' data:; " +
-"connect-src 'self' ws: wss: https:; " +
+// blob: lets the waveform renderer XHR/fetch a freshly-created object
+// URL (e.g. an uploaded clip before it has a server URL). XHR/fetch of
+// blob: falls under connect-src, not media-src.
+"connect-src 'self' ws: wss: https: blob:; " +
 "frame-src 'self' blob:; " +
 "worker-src 'self' blob:; " +
 "object-src 'none'; " +
@@ -32,6 +32,9 @@ var _ = Describe("SecurityHeaders", func() {
 	Expect(csp).To(ContainSubstring("frame-ancestors 'self'"))
 	Expect(csp).To(ContainSubstring("object-src 'none'"))
 	Expect(csp).To(ContainSubstring("base-uri 'self'"))
+	// blob: must be in connect-src so the waveform renderer can XHR/fetch
+	// a freshly-created object URL (uploaded/enhanced clip).
+	Expect(csp).To(ContainSubstring("connect-src 'self' ws: wss: https: blob:"))
 })
 
 It("sets X-Content-Type-Options: nosniff", func() {
@@ -1 +1 @@
-38.29
+40.0
@@ -20,5 +20,10 @@ test.describe('Agents page', () => {
|
||||
page.waitForURL(/\/app\/agents\/new$/),
|
||||
create.click(),
|
||||
])
|
||||
// Wait for AgentCreate.jsx to actually render, not just for the URL to
|
||||
// change. Ending the test the instant the route matched let the component
|
||||
// mount race the coverage teardown — its ~400 lines were collected only
|
||||
// when the render won, swinging total UI coverage ~1pp run-to-run.
|
||||
await expect(page.getByRole('heading', { name: 'Create Agent' })).toBeVisible()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -66,6 +66,33 @@ function makeFakeWav(name) {
|
||||
return { name, mimeType: 'audio/wav', buffer: buf }
|
||||
}
|
||||
|
||||
// Build a WAV carrying a real sine tone, long enough that the spectrogram
|
||||
// STFT produces several frames (a few thousand samples). Used to exercise the
|
||||
// FFT / heatmap path, which the 4-sample silent fixture can't.
|
||||
function makeToneWav(name, freq = 1000, seconds = 0.4, sampleRate = 16000) {
|
||||
const samples = Math.floor(seconds * sampleRate)
|
||||
const dataLen = samples * 2
|
||||
const buf = Buffer.alloc(44 + dataLen)
|
||||
buf.write('RIFF', 0)
|
||||
buf.writeUInt32LE(36 + dataLen, 4)
|
||||
buf.write('WAVE', 8)
|
||||
buf.write('fmt ', 12)
|
||||
buf.writeUInt32LE(16, 16)
|
||||
buf.writeUInt16LE(1, 20)
|
||||
buf.writeUInt16LE(1, 22)
|
||||
buf.writeUInt32LE(sampleRate, 24)
|
||||
buf.writeUInt32LE(sampleRate * 2, 28)
|
||||
buf.writeUInt16LE(2, 32)
|
||||
buf.writeUInt16LE(16, 34)
|
||||
buf.write('data', 36)
|
||||
buf.writeUInt32LE(dataLen, 40)
|
||||
for (let i = 0; i < samples; i++) {
|
||||
const v = Math.round(Math.sin((2 * Math.PI * freq * i) / sampleRate) * 16000)
|
||||
buf.writeInt16LE(v, 44 + i * 2)
|
||||
}
|
||||
return { name, mimeType: 'audio/wav', buffer: buf }
|
||||
}
|
||||
|
||||
test.describe('Audio Transform', () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await mockCapabilities(page, [
|
||||
@@ -169,6 +196,26 @@ test.describe('Audio Transform', () => {
|
||||
await expect(page.getByTestId('media-history-item')).toHaveCount(1)
|
||||
})
|
||||
|
||||
test('renders an input spectrogram on upload and an output one after transform', async ({ page }) => {
|
||||
mockAudioTransform(page, 'enhanced.wav')
|
||||
|
||||
await page.goto('/app/transform')
|
||||
await expect(page.getByRole('button', { name: 'localvqe' })).toBeVisible({ timeout: 10_000 })
|
||||
|
||||
// Choosing a clip should render its input spectrogram immediately — no
|
||||
// backend round-trip needed (it's computed client-side from the bytes).
|
||||
await page.locator('input[type="file"]').first().setInputFiles(makeToneWav('tone.wav'))
|
||||
await expect(page.getByTestId('spectrogram-input')).toBeVisible({ timeout: 10_000 })
|
||||
|
||||
// Until a transform runs the output side shows a "compare" placeholder.
|
||||
await expect(page.getByText(/Transform to compare/)).toBeVisible()
|
||||
|
||||
await page.getByRole('button', { name: /Transform/ }).last().click()
|
||||
|
||||
// After processing, the output spectrum panel appears alongside the input.
|
||||
await expect(page.getByText('Output spectrum')).toBeVisible({ timeout: 10_000 })
|
||||
})
|
||||
|
||||
test('shows an error banner when the backend returns 4xx', async ({ page }) => {
|
||||
await page.route('**/audio/transformations', (route) => {
|
||||
if (route.request().method() !== 'POST') return route.continue()
|
||||
|
||||
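Reviewer note: the `makeToneWav` fixture added in this spec writes the canonical 44-byte PCM header by hand. The sketch below rebuilds the same buffer and reads the header back, which is a cheap way to check the field offsets. `readWavHeader` is a hypothetical helper that is not part of the spec file, and the snippet assumes Node's global `Buffer`.

```javascript
// Same layout as the spec's makeToneWav fixture: mono, 16-bit PCM.
function makeToneWav(freq = 1000, seconds = 0.4, sampleRate = 16000) {
  const samples = Math.floor(seconds * sampleRate)
  const dataLen = samples * 2 // 2 bytes per 16-bit mono sample
  const buf = Buffer.alloc(44 + dataLen)
  buf.write('RIFF', 0)
  buf.writeUInt32LE(36 + dataLen, 4) // RIFF chunk size = file size - 8
  buf.write('WAVE', 8)
  buf.write('fmt ', 12)
  buf.writeUInt32LE(16, 16) // fmt chunk size for PCM
  buf.writeUInt16LE(1, 20) // audio format 1 = PCM
  buf.writeUInt16LE(1, 22) // channel count
  buf.writeUInt32LE(sampleRate, 24)
  buf.writeUInt32LE(sampleRate * 2, 28) // byte rate
  buf.writeUInt16LE(2, 32) // block align
  buf.writeUInt16LE(16, 34) // bits per sample
  buf.write('data', 36)
  buf.writeUInt32LE(dataLen, 40)
  for (let i = 0; i < samples; i++) {
    const v = Math.round(Math.sin((2 * Math.PI * freq * i) / sampleRate) * 16000)
    buf.writeInt16LE(v, 44 + i * 2)
  }
  return buf
}

// Hypothetical helper: decode the fields written above.
function readWavHeader(buf) {
  return {
    riff: buf.toString('ascii', 0, 4),
    riffSize: buf.readUInt32LE(4),
    wave: buf.toString('ascii', 8, 12),
    format: buf.readUInt16LE(20),
    channels: buf.readUInt16LE(22),
    sampleRate: buf.readUInt32LE(24),
    byteRate: buf.readUInt32LE(28),
    bitsPerSample: buf.readUInt16LE(34),
    dataLen: buf.readUInt32LE(40),
  }
}

console.log(readWavHeader(makeToneWav(1000, 0.5, 16000)))
```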
40
core/http/react-ui/e2e/page-render-smoke.spec.js
Normal file
@@ -0,0 +1,40 @@
|
||||
import { test, expect } from './coverage-fixtures.js'
|
||||
|
||||
// Render-smoke coverage. Each page is lazy-loaded and runs its full render +
|
||||
// initial effects on mount, so a bare visit captures the bulk of a page's
|
||||
// lines — cheap, real coverage for pages that have no dedicated spec yet.
|
||||
//
|
||||
// This is the project's preferred way to keep the UI coverage gate green:
|
||||
// raise the floor by covering more, rather than loosening the gate's
|
||||
// tolerance (see CONTRIBUTING.md → "React UI coverage"). Auth is disabled in
|
||||
// the test server, so RequireAdmin/RequireFeature resolve to isAdmin=true and
|
||||
// every gated route renders without an auth/capability mock.
|
||||
//
|
||||
// Asserts the page mounted (its .page-title header is visible) and that it did
|
||||
// not bounce to a gate redirect (/login or back to /app home).
|
||||
const PAGES = [
|
||||
['/app/talk', 'Talk'],
|
||||
['/app/usage', 'Usage'],
|
||||
['/app/account', 'Account'],
|
||||
['/app/studio', 'Studio'],
|
||||
['/app/manage', 'Manage'],
|
||||
['/app/backends', 'Backends'],
|
||||
['/app/settings', 'Settings'],
|
||||
['/app/nodes', 'Nodes'],
|
||||
['/app/face', 'Face recognition'],
|
||||
['/app/voice', 'Voice recognition'],
|
||||
['/app/fine-tune', 'Fine-tuning'],
|
||||
['/app/quantize', 'Quantize'],
|
||||
]
|
||||
|
||||
test.describe('Page render smoke', () => {
|
||||
for (const [path, label] of PAGES) {
|
||||
test(`renders ${label} (${path})`, async ({ page }) => {
|
||||
await page.goto(path)
|
||||
// .page-title for the normal header; .empty-state-title for pages that
|
||||
// render a gated/empty state (e.g. Account when auth is disabled).
|
||||
await expect(page.locator('.page-title, .empty-state-title').first()).toBeVisible({ timeout: 15_000 })
|
||||
await expect(page).toHaveURL(new RegExp(path.replace(/\//g, '\\/') + '$'))
|
||||
})
|
||||
}
|
||||
})
|
||||
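Reviewer note: the smoke spec builds its `toHaveURL` pattern by escaping only `/` and appending `$`. The sketch below shows what that matcher accepts. `urlMatcher` and `strictUrlMatcher` are hypothetical names; the strict variant is a suggestion for routes that might later contain regex metacharacters, not something the diff does.

```javascript
// Same construction as the spec: escape slashes, anchor at the end.
function urlMatcher(path) {
  return new RegExp(path.replace(/\//g, '\\/') + '$')
}

// Hypothetical stricter variant: escape every regex metacharacter.
function escapeRegExp(s) {
  return s.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
}
function strictUrlMatcher(path) {
  return new RegExp(escapeRegExp(path) + '$')
}

console.log(urlMatcher('/app/fine-tune').test('http://localhost:8080/app/fine-tune'))
console.log(urlMatcher('/app/talk').test('http://localhost:8080/app/talk/extra'))
```

None of the routes in `PAGES` contains a metacharacter that changes the match, so the two variants agree on the current list.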
@@ -6984,6 +6984,88 @@ select.input {
|
||||
color: var(--color-primary);
|
||||
}
|
||||
|
||||
/* Spectrogram (AudioTransform spectral view) */
|
||||
.audio-spectrogram-pair {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: var(--spacing-md);
|
||||
}
|
||||
@media (max-width: 720px) {
|
||||
.audio-spectrogram-pair {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
.audio-spectrogram {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--spacing-xs);
|
||||
width: 100%;
|
||||
min-width: 0;
|
||||
}
|
||||
.audio-spectrogram__label {
|
||||
font-size: var(--text-sm);
|
||||
color: var(--color-text-secondary);
|
||||
}
|
||||
.audio-spectrogram__canvas-wrap {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
background: var(--color-surface-sunken);
|
||||
border: 1px solid var(--color-border-subtle);
|
||||
border-radius: var(--radius-md);
|
||||
overflow: hidden;
|
||||
}
|
||||
.audio-spectrogram__canvas-wrap--empty {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
.audio-spectrogram__hint {
|
||||
color: var(--color-text-muted);
|
||||
font-size: var(--text-sm);
|
||||
}
|
||||
.audio-spectrogram__loading {
|
||||
position: absolute;
|
||||
inset: 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: var(--color-text-muted);
|
||||
font-size: var(--text-sm);
|
||||
}
|
||||
.audio-spectrogram__error {
|
||||
padding: var(--spacing-md);
|
||||
color: var(--color-error);
|
||||
font-size: var(--text-sm);
|
||||
}
|
||||
.audio-spectrogram__axis {
|
||||
position: absolute;
|
||||
left: 6px;
|
||||
font-size: 10px;
|
||||
color: var(--color-text-muted);
|
||||
background: var(--color-bg-overlay);
|
||||
padding: 0 4px;
|
||||
border-radius: var(--radius-sm);
|
||||
pointer-events: none;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
.audio-spectrogram__axis--top {
|
||||
top: 4px;
|
||||
}
|
||||
.audio-spectrogram__axis--bottom {
|
||||
bottom: 4px;
|
||||
}
|
||||
.audio-spectrogram__duration {
|
||||
position: absolute;
|
||||
right: 8px;
|
||||
bottom: 6px;
|
||||
font-size: 11px;
|
||||
color: var(--color-text-muted);
|
||||
font-variant-numeric: tabular-nums;
|
||||
background: var(--color-bg-overlay);
|
||||
padding: 1px 6px;
|
||||
border-radius: var(--radius-sm);
|
||||
}
|
||||
|
||||
/* Audio Transform Studio tab */
|
||||
.audio-transform-stack {
|
||||
display: flex;
|
||||
|
||||
105
core/http/react-ui/src/components/audio/Spectrogram.jsx
Normal file
@@ -0,0 +1,105 @@
|
||||
import { useEffect, useRef } from 'react'
|
||||
import useSpectrogram from '../../hooks/useSpectrogram'
|
||||
|
||||
// Spectrogram — canvas heatmap of a clip's magnitude STFT (time × frequency).
|
||||
// Time runs left→right, frequency low→high bottom→top, brighter = more energy.
|
||||
// Used on the AudioTransform page to show input next to output so the user can
|
||||
// see which bands the model attenuates (dark gaps that were bright in the
|
||||
// input). Mirrors WaveformPlayer's canvas/label/overlay structure.
|
||||
export default function Spectrogram({ src, label, height = 140, testId }) {
|
||||
const canvasRef = useRef(null)
|
||||
const { spectrogram, frames, bins, maxFreq, duration, error, loading } = useSpectrogram(src)
|
||||
|
||||
useEffect(() => {
|
||||
const canvas = canvasRef.current
|
||||
if (!canvas) return
|
||||
const dpr = window.devicePixelRatio || 1
|
||||
const cssW = canvas.clientWidth
|
||||
const cssH = height
|
||||
canvas.width = Math.floor(cssW * dpr)
|
||||
canvas.height = Math.floor(cssH * dpr)
|
||||
const ctx = canvas.getContext('2d')
|
||||
ctx.setTransform(dpr, 0, 0, dpr, 0, 0)
|
||||
ctx.clearRect(0, 0, cssW, cssH)
|
||||
if (!spectrogram || !frames || !bins) return
|
||||
|
||||
// Paint at native (frames × bins) resolution into an offscreen canvas,
|
||||
// then let drawImage smooth-scale it up — far cheaper than filling
|
||||
// cssW×cssH rects, and the GPU handles the interpolation.
|
||||
const img = ctx.createImageData(frames, bins)
|
||||
for (let f = 0; f < frames; f++) {
|
||||
for (let b = 0; b < bins; b++) {
|
||||
const [r, g, bl] = magma(spectrogram[f * bins + b])
|
||||
// Flip the frequency axis: image row 0 is the top = highest freq.
|
||||
const o = ((bins - 1 - b) * frames + f) * 4
|
||||
img.data[o] = r
|
||||
img.data[o + 1] = g
|
||||
img.data[o + 2] = bl
|
||||
img.data[o + 3] = 255
|
||||
}
|
||||
}
|
||||
const off = document.createElement('canvas')
|
||||
off.width = frames
|
||||
off.height = bins
|
||||
off.getContext('2d').putImageData(img, 0, 0)
|
||||
ctx.imageSmoothingEnabled = true
|
||||
ctx.drawImage(off, 0, 0, cssW, cssH)
|
||||
}, [spectrogram, frames, bins, height])
|
||||
|
||||
if (!src) return null
|
||||
|
||||
return (
|
||||
<div className="audio-spectrogram">
|
||||
{label && <div className="audio-spectrogram__label">{label}</div>}
|
||||
<div className="audio-spectrogram__canvas-wrap" style={{ height }}>
|
||||
{error ? (
|
||||
<div className="audio-spectrogram__error">{error}</div>
|
||||
) : (
|
||||
<canvas ref={canvasRef} data-testid={testId} style={{ width: '100%', height: '100%' }} />
|
||||
)}
|
||||
{maxFreq > 0 && !error && (
|
||||
<>
|
||||
<span className="audio-spectrogram__axis audio-spectrogram__axis--top">{fmtHz(maxFreq)}</span>
|
||||
<span className="audio-spectrogram__axis audio-spectrogram__axis--bottom">0 Hz</span>
|
||||
</>
|
||||
)}
|
||||
{duration > 0 && !error && (
|
||||
<span className="audio-spectrogram__duration">{duration.toFixed(1)}s</span>
|
||||
)}
|
||||
{loading && !error && <div className="audio-spectrogram__loading">Analysing…</div>}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function fmtHz(hz) {
|
||||
if (hz >= 1000) return `${(hz / 1000).toFixed(hz % 1000 === 0 ? 0 : 1)} kHz`
|
||||
return `${Math.round(hz)} Hz`
|
||||
}
|
||||
|
||||
// magma — compact perceptual colormap (black→purple→orange→white) sampled at 8
|
||||
// control points and linearly interpolated. Perceptually uniform maps read
|
||||
// far better for spectral magnitude than a raw hue ramp. v is clamped to [0,1].
|
||||
const MAGMA = [
|
||||
[0, 0, 4],
|
||||
[40, 11, 84],
|
||||
[101, 21, 110],
|
||||
[159, 42, 99],
|
||||
[212, 72, 66],
|
||||
[245, 125, 21],
|
||||
[250, 193, 39],
|
||||
[252, 253, 191],
|
||||
]
|
||||
function magma(v) {
|
||||
const t = v <= 0 ? 0 : v >= 1 ? 1 : v
|
||||
const x = t * (MAGMA.length - 1)
|
||||
const i = Math.floor(x)
|
||||
const frac = x - i
|
||||
const a = MAGMA[i]
|
||||
const b = MAGMA[Math.min(i + 1, MAGMA.length - 1)]
|
||||
return [
|
||||
Math.round(a[0] + (b[0] - a[0]) * frac),
|
||||
Math.round(a[1] + (b[1] - a[1]) * frac),
|
||||
Math.round(a[2] + (b[2] - a[2]) * frac),
|
||||
]
|
||||
}
|
||||
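Reviewer note: the paint loop in `Spectrogram.jsx` transposes frame-major data (`spectrogram[f * bins + b]`) into image rows and flips the frequency axis. The sketch below isolates that index mapping without a canvas. `specToRgba` and the `gray` colormap are hypothetical stand-ins for illustration; the component itself uses `magma` and `ImageData`.

```javascript
// Stand-in colormap: clamp to [0, 1] and map to a gray level.
function gray(v) {
  const t = v <= 0 ? 0 : v >= 1 ? 1 : v
  const c = Math.round(t * 255)
  return [c, c, c]
}

// Lay out frame-major magnitudes as RGBA pixels: width = frames, height = bins.
function specToRgba(spec, frames, bins, colormap = gray) {
  const out = new Uint8ClampedArray(frames * bins * 4)
  for (let f = 0; f < frames; f++) {
    for (let b = 0; b < bins; b++) {
      const [r, g, bl] = colormap(spec[f * bins + b])
      // Row 0 is the top of the image, so the highest bin goes there.
      const o = ((bins - 1 - b) * frames + f) * 4
      out[o] = r
      out[o + 1] = g
      out[o + 2] = bl
      out[o + 3] = 255
    }
  }
  return out
}

// Two frames, two bins: frame 0 = [0, 1], frame 1 = [0.5, 0.25].
console.log(specToRgba([0, 1, 0.5, 0.25], 2, 2))
```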
@@ -5,7 +5,7 @@ import { useEffect, useState } from 'react'
 // and most browsers cap concurrent AudioContexts at ~6. Keep one alive for
 // the lifetime of the tab and reuse it across decodes.
 let sharedCtx = null
-function getSharedAudioContext() {
+export function getSharedAudioContext() {
   if (sharedCtx) return sharedCtx
   const Ctx = window.AudioContext || window.webkitAudioContext
   if (!Ctx) return null
107
core/http/react-ui/src/hooks/useSpectrogram.js
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
import { useEffect, useState } from 'react'
|
||||
import { getSharedAudioContext } from './useAudioPeaks'
|
||||
import { fftRadix2 } from '../utils/fft'
|
||||
|
||||
// Hann windows are reused across frames and across clips, so cache one per
|
||||
// size. The window tapers each frame to suppress spectral leakage (the
|
||||
// vertical smearing you'd otherwise get from hard frame edges).
|
||||
const windowCache = new Map()
|
||||
function hann(n) {
|
||||
let w = windowCache.get(n)
|
||||
if (w) return w
|
||||
w = new Float32Array(n)
|
||||
for (let i = 0; i < n; i++) w[i] = 0.5 - 0.5 * Math.cos((2 * Math.PI * i) / (n - 1))
|
||||
windowCache.set(n, w)
|
||||
return w
|
||||
}
|
||||
|
||||
const EMPTY = { spectrogram: null, frames: 0, bins: 0, maxFreq: 0, duration: 0, error: null, loading: false }
|
||||
|
||||
// useSpectrogram — decode an audio source (blob/data/http URL) and compute a
|
||||
// magnitude STFT suitable for a spectrogram heatmap. Returns
|
||||
// `{ spectrogram, frames, bins, maxFreq, duration, error, loading }` where
|
||||
// `spectrogram` is a Float32Array of `frames * bins` values, row-major by
|
||||
// frame, normalised so the dB floor maps to 0 and the loudest bin to 1.
|
||||
// `bins` spans 0..Nyquist (`maxFreq`).
|
||||
//
|
||||
// fftSize/hop default to the LocalVQE frame geometry (512/256) so the picture
|
||||
// lines up with how the model itself frames the audio. Long clips are
|
||||
// strided down to at most `maxFrames` columns — the heatmap is only a few
|
||||
// hundred px wide, so computing an FFT per native hop would be wasted work.
|
||||
export default function useSpectrogram(
|
||||
src,
|
||||
{ fftSize = 512, hop = 256, maxFrames = 900, dbFloor = -90 } = {},
|
||||
) {
|
||||
const [state, setState] = useState(EMPTY)
|
||||
|
||||
useEffect(() => {
|
||||
setState(EMPTY)
|
||||
if (!src) return
|
||||
let cancelled = false
|
||||
setState((s) => ({ ...s, loading: true }))
|
||||
|
||||
async function run() {
|
||||
try {
|
||||
const resp = await fetch(src)
|
||||
const raw = await resp.arrayBuffer()
|
||||
const ctx = getSharedAudioContext()
|
||||
if (!ctx) throw new Error('Web Audio API not available')
|
||||
const audio = await ctx.decodeAudioData(raw.slice(0))
|
||||
if (cancelled) return
|
||||
|
||||
const data = audio.getChannelData(0)
|
||||
const bins = fftSize >> 1
|
||||
const win = hann(fftSize)
|
||||
|
||||
// Frame count, then a stride so we never run more than maxFrames FFTs.
|
||||
const rawFrames = data.length >= fftSize ? 1 + Math.floor((data.length - fftSize) / hop) : 1
|
||||
const stride = rawFrames > maxFrames ? Math.ceil(rawFrames / maxFrames) : 1
|
||||
const frames = Math.ceil(rawFrames / stride)
|
||||
|
||||
const spec = new Float32Array(frames * bins)
|
||||
const re = new Float64Array(fftSize)
|
||||
const im = new Float64Array(fftSize)
|
||||
let peakDb = dbFloor
|
||||
|
||||
for (let f = 0; f < frames; f++) {
|
||||
const start = f * stride * hop
|
||||
for (let i = 0; i < fftSize; i++) {
|
||||
const s = start + i
|
||||
re[i] = s < data.length ? data[s] * win[i] : 0
|
||||
im[i] = 0
|
||||
}
|
||||
fftRadix2(re, im)
|
||||
for (let b = 0; b < bins; b++) {
|
||||
const mag = Math.hypot(re[b], im[b]) / fftSize
|
||||
let db = mag > 0 ? 20 * Math.log10(mag) : dbFloor
|
||||
if (db < dbFloor) db = dbFloor
|
||||
spec[f * bins + b] = db
|
||||
if (db > peakDb) peakDb = db
|
||||
}
|
||||
}
|
||||
|
||||
// Normalise dB into [0,1] against [dbFloor, peakDb].
|
||||
const range = peakDb - dbFloor || 1
|
||||
for (let i = 0; i < spec.length; i++) spec[i] = (spec[i] - dbFloor) / range
|
||||
|
||||
if (cancelled) return
|
||||
setState({
|
||||
spectrogram: spec,
|
||||
frames,
|
||||
bins,
|
||||
maxFreq: audio.sampleRate / 2,
|
||||
duration: audio.duration,
|
||||
error: null,
|
||||
loading: false,
|
||||
})
|
||||
} catch (e) {
|
||||
if (!cancelled) setState((s) => ({ ...s, error: e?.message || 'Could not analyse audio', loading: false }))
|
||||
}
|
||||
}
|
||||
|
||||
run()
|
||||
return () => { cancelled = true }
|
||||
}, [src, fftSize, hop, maxFrames, dbFloor])
|
||||
|
||||
return state
|
||||
}
|
||||
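Reviewer note: the frame-count and stride arithmetic in `useSpectrogram` is easy to misread, so here it is as a pure function. `frameGeometry` is a hypothetical name; the formulas and defaults are the ones in the hook above.

```javascript
// Frame count for an STFT over `length` samples, then a stride that keeps
// the number of FFTs at or below maxFrames.
function frameGeometry(length, { fftSize = 512, hop = 256, maxFrames = 900 } = {}) {
  // One frame even for clips shorter than fftSize (the hook zero-pads them).
  const rawFrames = length >= fftSize ? 1 + Math.floor((length - fftSize) / hop) : 1
  const stride = rawFrames > maxFrames ? Math.ceil(rawFrames / maxFrames) : 1
  const frames = Math.ceil(rawFrames / stride)
  return { rawFrames, stride, frames }
}

console.log(frameGeometry(6400)) // 0.4 s at 16 kHz, as in the e2e tone fixture
console.log(frameGeometry(1000000)) // long clip: stride grows above 1
console.log(frameGeometry(4)) // shorter than one FFT frame
```

Because the stride is rounded up, the resulting frame count can land well below `maxFrames` rather than exactly on it.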
@@ -5,6 +5,7 @@ import { CAP_AUDIO_TRANSFORM } from '../utils/capabilities'
 import LoadingSpinner from '../components/LoadingSpinner'
 import ErrorWithTraceLink from '../components/ErrorWithTraceLink'
 import WaveformPlayer from '../components/audio/WaveformPlayer'
+import Spectrogram from '../components/audio/Spectrogram'
 import { audioTransformApi } from '../utils/api'
 import { useMediaCapture } from '../hooks/useMediaCapture'
 import useObjectUrl from '../hooks/useObjectUrl'
@@ -261,6 +262,24 @@ export default function AudioTransform() {
|
||||
</div>
|
||||
) : (
|
||||
<div className="audio-transform-stack">
|
||||
{audioUrl && (
|
||||
<div className="audio-spectrogram-pair">
|
||||
<Spectrogram src={audioUrl} label="Input spectrum" testId="spectrogram-input" />
|
||||
{outputUrl ? (
|
||||
<Spectrogram src={outputUrl} label="Output spectrum" testId="spectrogram-output" />
|
||||
) : (
|
||||
<div className="audio-spectrogram">
|
||||
<div className="audio-spectrogram__label">Output spectrum</div>
|
||||
<div
|
||||
className="audio-spectrogram__canvas-wrap audio-spectrogram__canvas-wrap--empty"
|
||||
style={{ height: 140 }}
|
||||
>
|
||||
<span className="audio-spectrogram__hint">Transform to compare attenuation</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<WaveformPlayer src={audioUrl} label="Audio" height={96} />
|
||||
<WaveformPlayer src={referenceUrl} label="Reference" height={96} dimmed={!referenceFile} />
|
||||
{outputUrl && (
|
||||
|
||||
@@ -493,6 +493,13 @@ function SchedulingForm({ onSave, onCancel }) {
   const [selector, setSelector] = useState({})
   const [minReplicas, setMinReplicas] = useState(1)
   const [maxReplicas, setMaxReplicas] = useState(0)
+  // Prefix-cache routing controls. Empty routePolicy means "inherit the
+  // cluster default"; the three thresholds at 0 likewise inherit, so they
+  // stay out of the POST body's effective override only when explicitly set.
+  const [routePolicy, setRoutePolicy] = useState('')
+  const [balanceAbsThreshold, setBalanceAbsThreshold] = useState(0)
+  const [balanceRelThreshold, setBalanceRelThreshold] = useState(0)
+  const [minPrefixMatch, setMinPrefixMatch] = useState(0)
 
   const hasSelector = Object.keys(selector).length > 0
 
@@ -508,6 +515,10 @@ function SchedulingForm({ onSave, onCancel }) {
       node_selector: hasSelector ? selector : undefined,
       min_replicas: mode === 'placement' ? 0 : minReplicas,
       max_replicas: mode === 'placement' ? 0 : maxReplicas,
+      route_policy: routePolicy,
+      balance_abs_threshold: balanceAbsThreshold,
+      balance_rel_threshold: balanceRelThreshold,
+      min_prefix_match: minPrefixMatch,
     })
   }
 
@@ -593,6 +604,76 @@ function SchedulingForm({ onSave, onCancel }) {
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Per-model routing policy. Left empty/zero these inherit the
|
||||
cluster-wide defaults; set them to override how requests for this
|
||||
model are spread across replicas. */}
|
||||
<div>
|
||||
<label className="form-label" htmlFor="sched-route-policy">Routing policy</label>
|
||||
<select
|
||||
id="sched-route-policy"
|
||||
className="input"
|
||||
value={routePolicy}
|
||||
onChange={e => setRoutePolicy(e.target.value)}
|
||||
>
|
||||
<option value="">Default (cluster setting)</option>
|
||||
<option value="round_robin">Round Robin</option>
|
||||
<option value="prefix_cache">Prefix Cache</option>
|
||||
</select>
|
||||
<span style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', display: 'block', marginTop: 6 }}>
|
||||
Prefix Cache routes shared-prefix requests to the same replica to reuse its KV cache, falling back to round-robin when replicas are imbalanced.
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{routePolicy === 'prefix_cache' && (
|
||||
<div style={{ display: 'flex', gap: 'var(--spacing-md)' }}>
|
||||
<div style={{ flex: 1 }}>
|
||||
<label className="form-label" htmlFor="sched-min-prefix-match">Min prefix match</label>
|
||||
<input
|
||||
id="sched-min-prefix-match"
|
||||
className="input"
|
||||
type="number"
|
||||
step="0.05"
|
||||
min="0"
|
||||
max="1"
|
||||
value={minPrefixMatch}
|
||||
onChange={e => setMinPrefixMatch(parseFloat(e.target.value) || 0)}
|
||||
/>
|
||||
<span style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', display: 'block', marginTop: 6 }}>
|
||||
Fraction of the prompt (0..1) that must match a cached prefix before affinity kicks in. 0 inherits the default.
|
||||
</span>
|
||||
</div>
|
||||
<div style={{ flex: 1 }}>
|
||||
<label className="form-label" htmlFor="sched-balance-abs">Balance abs threshold</label>
|
||||
<input
|
||||
id="sched-balance-abs"
|
||||
className="input"
|
||||
type="number"
|
||||
min="0"
|
||||
value={balanceAbsThreshold}
|
||||
onChange={e => setBalanceAbsThreshold(parseInt(e.target.value) || 0)}
|
||||
/>
|
||||
<span style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', display: 'block', marginTop: 6 }}>
|
||||
Max absolute in-flight gap allowed before falling back to round-robin. 0 inherits the default.
|
||||
</span>
|
||||
</div>
|
||||
<div style={{ flex: 1 }}>
|
||||
<label className="form-label" htmlFor="sched-balance-rel">Balance rel threshold</label>
|
||||
<input
|
||||
id="sched-balance-rel"
|
||||
className="input"
|
||||
type="number"
|
||||
step="0.1"
|
||||
min="0"
|
||||
value={balanceRelThreshold}
|
||||
onChange={e => setBalanceRelThreshold(parseFloat(e.target.value) || 0)}
|
||||
/>
|
||||
<span style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', display: 'block', marginTop: 6 }}>
|
||||
Max relative in-flight ratio (>= 1) allowed before falling back to round-robin. 0 inherits the default.
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Hairline divider above the actions, matching the project's form pattern. */}
|
||||
@@ -1475,6 +1556,8 @@ export default function Nodes() {
   <th>Node Selector</th>
   <th>Min Replicas</th>
   <th>Max Replicas</th>
+  <th>Routing</th>
+  <th>Thresholds</th>
   <th>Status</th>
   <th style={{ textAlign: 'right' }}>Actions</th>
 </tr></thead>
@@ -1519,6 +1602,18 @@ export default function Nodes() {
 <td style={{ fontFamily: 'var(--font-mono)' }}>
   {isAutoScaling ? (cfg.max_replicas || 'no limit') : '-'}
 </td>
+<td style={{ fontSize: '0.8125rem' }}>
+  {cfg.route_policy || 'default'}
+</td>
+<td style={{ fontFamily: 'var(--font-mono)', fontSize: '0.75rem', color: 'var(--color-text-muted)' }}>
+  {cfg.route_policy === 'prefix_cache' ? (
+    <>
+      <div>match: {cfg.min_prefix_match ? cfg.min_prefix_match : 'inherit'}</div>
+      <div>abs: {cfg.balance_abs_threshold ? cfg.balance_abs_threshold : 'inherit'}</div>
+      <div>rel: {cfg.balance_rel_threshold ? cfg.balance_rel_threshold : 'inherit'}</div>
+    </>
+  ) : '-'}
+</td>
 <td>
   {isUnsatisfiable ? (
     <span
|
||||
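Reviewer note: the form and the table share one convention, namely that `0` (or an empty input) means "inherit the cluster default". The sketch below pulls the inline expressions from the `Nodes.jsx` hunks into small functions so the round trip is visible. The function names are hypothetical, and the explicit radix in `parseInt` is an addition; the diff calls `parseInt(e.target.value)` without one.

```javascript
// Input parsing: empty or non-numeric text collapses to 0, i.e. "inherit".
function parseFraction(text) {
  return parseFloat(text) || 0
}
function parseCount(text) {
  return parseInt(text, 10) || 0
}

// Table rendering: a falsy threshold is shown as "inherit".
function showThreshold(value) {
  return value ? String(value) : 'inherit'
}

console.log(showThreshold(parseFraction('0.4')))
console.log(showThreshold(parseFraction('')))
console.log(showThreshold(parseCount('12')))
```

One consequence of the convention is that a user cannot set a threshold to a literal zero from this form; zero always reads as "use the default".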
47
core/http/react-ui/src/utils/fft.js
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
// Minimal in-place iterative radix-2 Cooley–Tukey FFT.
|
||||
//
|
||||
// The AudioTransform spectrogram only needs forward transforms of short real
|
||||
// frames (≤2048 samples), so a compact ~30-line implementation beats pulling
|
||||
// in a dependency and shipping it in the bundle. `re` and `im` are mutated in
|
||||
// place; `n = re.length` must be a power of two (the caller picks fftSize).
|
||||
export function fftRadix2(re, im) {
|
||||
const n = re.length
|
||||
if (n <= 1) return
|
||||
|
||||
// Bit-reversal permutation: reorder samples so the butterfly stage below can
|
||||
// run in place.
|
||||
for (let i = 1, j = 0; i < n; i++) {
|
||||
let bit = n >> 1
|
||||
for (; j & bit; bit >>= 1) j ^= bit
|
||||
j ^= bit
|
||||
if (i < j) {
|
||||
const tr = re[i]; re[i] = re[j]; re[j] = tr
|
||||
const ti = im[i]; im[i] = im[j]; im[j] = ti
|
||||
}
|
||||
}
|
||||
|
||||
// Butterflies, doubling the transform length each pass.
|
||||
for (let len = 2; len <= n; len <<= 1) {
|
||||
const half = len >> 1
|
||||
const ang = (-2 * Math.PI) / len
|
||||
const wpr = Math.cos(ang)
|
||||
const wpi = Math.sin(ang)
|
||||
for (let i = 0; i < n; i += len) {
|
||||
let wr = 1
|
||||
let wi = 0
|
||||
for (let k = 0; k < half; k++) {
|
||||
const a = i + k
|
||||
const b = a + half
|
||||
const tr = wr * re[b] - wi * im[b]
|
||||
const ti = wr * im[b] + wi * re[b]
|
||||
re[b] = re[a] - tr
|
||||
im[b] = im[a] - ti
|
||||
re[a] += tr
|
||||
im[a] += ti
|
||||
const nwr = wr * wpr - wi * wpi
|
||||
wi = wr * wpi + wi * wpr
|
||||
wr = nwr
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
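Reviewer note: a quick sanity check for an FFT like this one is to transform a pure cosine and confirm that the energy lands in the expected bin. The snippet below repeats `fftRadix2` from the diff, with indentation restored, so that it runs standalone. `toneSpectrum` and `peakBin` are hypothetical helpers added for the check.

```javascript
// fftRadix2 as in the diff above: in-place, n must be a power of two.
function fftRadix2(re, im) {
  const n = re.length
  if (n <= 1) return
  // Bit-reversal permutation.
  for (let i = 1, j = 0; i < n; i++) {
    let bit = n >> 1
    for (; j & bit; bit >>= 1) j ^= bit
    j ^= bit
    if (i < j) {
      const tr = re[i]; re[i] = re[j]; re[j] = tr
      const ti = im[i]; im[i] = im[j]; im[j] = ti
    }
  }
  // Butterflies, doubling the transform length each pass.
  for (let len = 2; len <= n; len <<= 1) {
    const half = len >> 1
    const ang = (-2 * Math.PI) / len
    const wpr = Math.cos(ang)
    const wpi = Math.sin(ang)
    for (let i = 0; i < n; i += len) {
      let wr = 1
      let wi = 0
      for (let k = 0; k < half; k++) {
        const a = i + k
        const b = a + half
        const tr = wr * re[b] - wi * im[b]
        const ti = wr * im[b] + wi * re[b]
        re[b] = re[a] - tr
        im[b] = im[a] - ti
        re[a] += tr
        im[a] += ti
        const nwr = wr * wpr - wi * wpi
        wi = wr * wpi + wi * wpr
        wr = nwr
      }
    }
  }
}

// Hypothetical helper: spectrum of cos(2*pi*k*i/n), exactly k cycles per frame.
function toneSpectrum(n, k) {
  const re = new Float64Array(n)
  const im = new Float64Array(n)
  for (let i = 0; i < n; i++) re[i] = Math.cos((2 * Math.PI * k * i) / n)
  fftRadix2(re, im)
  return { re, im }
}

// Hypothetical helper: loudest bin in the non-mirrored half 0..n/2-1.
function peakBin({ re, im }) {
  let best = 0
  for (let b = 1; b < re.length / 2; b++) {
    if (Math.hypot(re[b], im[b]) > Math.hypot(re[best], im[best])) best = b
  }
  return best
}

console.log(peakBin(toneSpectrum(16, 3)))
```

For a real cosine with an integer number of cycles per frame, bin `k` should carry a magnitude of `n / 2`, with the mirror image at bin `n - k`.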
@@ -355,3 +355,35 @@ type CacheInvalidateEvent struct {
|
||||
func SubjectCacheInvalidateCollection(name string) string {
|
||||
return "cache.invalidate.collections." + sanitizeSubjectToken(name)
|
||||
}
|
||||
|
||||
// Prefix-Cache Routing Sync (Pub/Sub - broadcast to all frontends)
|
||||
//
|
||||
// Frontends share prefix-cache observations so a request routed to any replica
|
||||
// benefits from the prefix-affinity another replica already learned. This
|
||||
// mirrors the OpCache live-sync pattern: plain NATS Core pub/sub, no JetStream.
|
||||
const (
|
||||
SubjectPrefixCacheObserve = "prefixcache.observe"
|
||||
SubjectPrefixCacheInvalidate = "prefixcache.invalidate"
|
||||
)
|
||||
|
||||
// PrefixCacheObserveEvent announces that the replica (NodeID, Replica) served a
|
||||
// request whose prefix chain ends at the given hashes for model. Chain is the
|
||||
// full shallow-to-deep hash chain so peers can insert the same path. Affinity is
|
||||
// per replica (a backend process with its own KV cache), not per node, so the
|
||||
// replica index is carried so peers attribute the observation to the same one.
|
||||
type PrefixCacheObserveEvent struct {
|
||||
Model string `json:"model"`
|
||||
Chain []uint64 `json:"chain"`
|
||||
NodeID string `json:"node_id"`
|
||||
Replica int `json:"replica"`
|
||||
}
|
||||
|
||||
// PrefixCacheInvalidateEvent tells peers to drop entries for a replica. When
|
||||
// Replica >= 0 it targets the single replica (Model, NodeID, Replica). When
|
||||
// Replica < 0 it targets ALL replicas of (Model, NodeID), for example when a
|
||||
// whole node goes offline.
|
||||
type PrefixCacheInvalidateEvent struct {
|
||||
Model string `json:"model"`
|
||||
NodeID string `json:"node_id"`
|
||||
Replica int `json:"replica"`
|
||||
}
|
||||
|
||||
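Reviewer note: the doc comment on `PrefixCacheInvalidateEvent` gives `Replica` two meanings: a non-negative index targets one replica, and a negative value targets every replica of the node. Below is a minimal sketch of how a subscriber might apply that rule, in JavaScript for illustration only. The flat `entries` array and the `applyInvalidate` name are assumptions; the real subscriber is Go code that is not shown in this diff.

```javascript
// Hypothetical in-memory model: one entry per (model, nodeId, replica).
function applyInvalidate(entries, event) {
  return entries.filter((e) => {
    if (e.model !== event.model || e.nodeId !== event.node_id) return true
    if (event.replica < 0) return false // whole node: drop every replica
    return e.replica !== event.replica // otherwise drop only the named replica
  })
}

const entries = [
  { model: 'm', nodeId: 'A', replica: 0 },
  { model: 'm', nodeId: 'A', replica: 1 },
  { model: 'm', nodeId: 'B', replica: 0 },
]

console.log(applyInvalidate(entries, { model: 'm', node_id: 'A', replica: 0 }))
console.log(applyInvalidate(entries, { model: 'm', node_id: 'A', replica: -1 }))
```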
27
core/services/messaging/subjects_prefixcache_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package messaging_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
)
|
||||
|
||||
var _ = Describe("PrefixCache subjects", func() {
|
||||
It("exposes stable subject constants", func() {
|
||||
Expect(messaging.SubjectPrefixCacheObserve).To(Equal("prefixcache.observe"))
|
||||
Expect(messaging.SubjectPrefixCacheInvalidate).To(Equal("prefixcache.invalidate"))
|
||||
})
|
||||
|
||||
It("carries a replica index on the observe event", func() {
|
||||
ev := messaging.PrefixCacheObserveEvent{Model: "m", Chain: []uint64{1, 2}, NodeID: "A", Replica: 3}
|
||||
Expect(ev.Replica).To(Equal(3))
|
||||
})
|
||||
|
||||
It("uses a negative replica on the invalidate event to mean all replicas of a node", func() {
|
||||
all := messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: -1}
|
||||
Expect(all.Replica).To(BeNumerically("<", 0))
|
||||
one := messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: 0}
|
||||
Expect(one.Replica).To(Equal(0))
|
||||
})
|
||||
})
|
||||
@@ -101,23 +101,52 @@ func (h *HTTPFileStager) EnsureRemote(ctx context.Context, nodeID, localPath, ke
|
||||
fileSize := fi.Size()
|
||||
|
||||
url := fmt.Sprintf("http://%s/v1/files/%s", addr, key)
|
||||
|
||||
// Compute the SHA-256 of the local file once and bind it to every PUT
|
||||
// attempt — the server uses it to detect mid-flight content drift and
|
||||
// reject (409) if a partial upload claims a new identity, forcing a clean
|
||||
// restart.
|
||||
localHash, err := downloader.CalculateSHA(localPath)
|
||||
if err != nil {
|
||||
// Hash failure isn't fatal — we can still upload; we just lose
|
||||
// resume-safety and end-of-transfer integrity checks.
|
||||
xlog.Warn("Failed to hash local file for upload integrity check", "localPath", localPath, "error", err)
|
||||
localHash = ""
|
||||
}
|
||||
|
||||
xlog.Info("Uploading file to remote node", "node", nodeID, "file", filepath.Base(localPath), "size", humanFileSize(fileSize), "url", url)
|
||||
|
||||
// Outer time budget: bound the total resumable-upload duration so a
|
||||
// permanently-unreachable worker doesn't hold the request forever. Default
|
||||
// matches the existing per-response timeout.
|
||||
outerBudget := h.resumeBudget()
|
||||
|
||||
resumeCtx, cancel := context.WithTimeout(ctx, outerBudget)
|
||||
defer cancel()
|
||||
|
||||
var lastErr error
|
||||
attempts := h.maxRetries + 1 // maxRetries=3 means 4 total attempts (1 initial + 3 retries)
|
||||
for attempt := 1; attempt <= attempts; attempt++ {
|
||||
attempt := 0
|
||||
for {
|
||||
attempt++
|
||||
if attempt > 1 {
|
||||
backoff := time.Duration(5<<(attempt-2)) * time.Second // 5s, 10s, 20s
|
||||
backoff := nextBackoff(attempt)
|
||||
xlog.Warn("Retrying file upload", "node", nodeID, "file", filepath.Base(localPath),
|
||||
"attempt", attempt, "of", attempts, "backoff", backoff, "lastError", lastErr)
|
||||
"attempt", attempt, "backoff", backoff, "lastError", lastErr)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", fmt.Errorf("upload cancelled during retry backoff: %w", ctx.Err())
|
||||
case <-resumeCtx.Done():
|
||||
return "", fmt.Errorf("upload cancelled during retry backoff (after %d attempts): %w (last: %v)", attempt-1, resumeCtx.Err(), lastErr)
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
|
||||
result, err := h.doUpload(ctx, addr, nodeID, localPath, key, url, fileSize)
|
||||
// Determine resume offset from the server before each attempt. A
|
||||
// HEAD response that reports an in-progress upload (X-Target-SHA256)
|
||||
// matching ours unlocks resume from the reported size; any other
|
||||
// outcome (missing file, hash mismatch, partial-of-different-file)
|
||||
// resets to 0 and uploads the entire file.
|
||||
startOffset := h.resumeOffset(resumeCtx, addr, key, localHash, fileSize)
|
||||
|
||||
result, err := h.doUpload(ctx, resumeCtx, addr, nodeID, localPath, key, url, fileSize, startOffset, localHash)
|
||||
if err == nil {
|
||||
if attempt > 1 {
|
||||
xlog.Info("File upload succeeded after retry", "node", nodeID, "file", filepath.Base(localPath), "attempt", attempt)
|
||||
@@ -126,50 +155,190 @@ func (h *HTTPFileStager) EnsureRemote(ctx context.Context, nodeID, localPath, ke
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
// Non-transient failures (4xx other than 416, hard auth, etc.) abort
|
||||
// immediately — retrying won't help.
|
||||
if !isTransientError(err) {
|
||||
xlog.Error("File upload failed with non-transient error", "node", nodeID, "file", filepath.Base(localPath), "error", err)
|
||||
return "", err
|
||||
}
|
||||
xlog.Warn("File upload failed with transient error", "node", nodeID, "file", filepath.Base(localPath),
|
||||
"attempt", attempt, "of", attempts, "error", err)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("uploading %s to node %s failed after %d attempts: %w", localPath, nodeID, attempts, lastErr)
|
||||
// Caller-cancelled (not deadline) — give up.
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return "", fmt.Errorf("upload cancelled by caller after %d attempts: %w", attempt, lastErr)
|
||||
}
|
||||
|
||||
// Outer budget exhausted.
|
||||
if errors.Is(resumeCtx.Err(), context.DeadlineExceeded) {
|
||||
return "", fmt.Errorf("uploading %s to node %s failed after %d attempts within %s budget: %w",
|
||||
localPath, nodeID, attempt, outerBudget, lastErr)
|
||||
}
|
||||
|
||||
xlog.Warn("File upload failed with transient error", "node", nodeID, "file", filepath.Base(localPath),
|
||||
"attempt", attempt, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// doUpload performs a single upload attempt.
|
||||
func (h *HTTPFileStager) doUpload(ctx context.Context, addr, nodeID, localPath, key, url string, fileSize int64) (string, error) {
|
||||
// resumeBudget returns the maximum total time the resumable upload loop will
|
||||
// spend retrying transient failures end-to-end. Past this budget the upload
|
||||
// fails rather than spinning forever — 1h covers multi-GB transfers on
|
||||
// pathological links without letting a wedged server jam the master.
|
||||
func (h *HTTPFileStager) resumeBudget() time.Duration {
|
||||
return 1 * time.Hour
|
||||
}
|
||||
|
||||
// nextBackoff returns the sleep before retry #attempt: 1s, 2s, 4s, ..., capped
|
||||
// at 30s, with the first sleep (attempt=2) being 1s.
|
||||
func nextBackoff(attempt int) time.Duration {
|
||||
if attempt < 2 {
|
||||
return 0
|
||||
}
|
||||
const (
|
||||
base = 1 * time.Second
|
||||
ceiling = 30 * time.Second
|
||||
)
|
||||
shift := uint(attempt - 2)
|
||||
if shift > 30 {
|
||||
shift = 30 // saturate before time.Duration overflows
|
||||
}
|
||||
b := base << shift
|
||||
if b > ceiling || b < 0 {
|
||||
b = ceiling
|
||||
}
|
||||
return b
|
||||
}
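The backoff schedule described above can be checked with a small standalone program. This is a minimal sketch that copies `nextBackoff` from the diff into a `main` package (the surrounding package and the loop bound of 8 are assumptions made for illustration). It should show the sleep doubling from 1s and saturating at the 30s ceiling from the seventh attempt onward.

```go
package main

import (
	"fmt"
	"time"
)

// nextBackoff mirrors the function in the diff: no sleep before the first
// attempt, then 1s doubling per retry, capped at 30s.
func nextBackoff(attempt int) time.Duration {
	if attempt < 2 {
		return 0
	}
	const (
		base    = 1 * time.Second
		ceiling = 30 * time.Second
	)
	shift := uint(attempt - 2)
	if shift > 30 {
		shift = 30 // saturate before time.Duration overflows
	}
	b := base << shift
	if b > ceiling || b < 0 {
		b = ceiling
	}
	return b
}

func main() {
	// Print the sleep that precedes each of the first eight attempts.
	for attempt := 1; attempt <= 8; attempt++ {
		fmt.Printf("attempt %d: sleep %v\n", attempt, nextBackoff(attempt))
	}
}
```

Because the loop in `EnsureRemote` no longer has a fixed attempt count, the cap matters: with the 1h budget from `resumeBudget`, most retries of a long outage would sleep the full 30s.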
|
||||
|
||||
// resumeOffset asks the server (via HEAD) how many bytes of the current upload
|
||||
// are already on disk. It returns 0 if the server has no usable partial state
|
||||
// (no file, a finished file, which carries no X-Target-SHA256, or a partial
// under a different target hash). It returns the server-reported size when the
|
||||
// server's X-Target-SHA256 matches our expected final hash AND the size is
|
||||
// strictly less than the local file size.
|
||||
func (h *HTTPFileStager) resumeOffset(ctx context.Context, addr, key, localHash string, fileSize int64) int64 {
|
||||
if localHash == "" || fileSize <= 0 {
|
||||
return 0
|
||||
}
|
||||
url := fmt.Sprintf("http://%s/v1/files/%s", addr, key)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
if h.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+h.token)
|
||||
}
|
||||
resp, err := h.client.Do(req)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return 0
|
||||
}
|
||||
|
||||
sizeStr := resp.Header.Get(HeaderFileSize)
|
||||
if sizeStr == "" {
|
||||
return 0
|
||||
}
|
||||
size, err := strconv.ParseInt(sizeStr, 10, 64)
|
||||
if err != nil || size <= 0 || size >= fileSize {
|
||||
return 0
|
||||
}
|
||||
|
||||
target := resp.Header.Get(HeaderTargetSHA256)
|
||||
if target == "" || !strings.EqualFold(target, localHash) {
|
||||
// No partial-upload metadata, or it's for a different target.
|
||||
return 0
|
||||
}
|
||||
|
||||
xlog.Info("Resuming upload from server-reported offset", "key", key, "offset", size, "total", fileSize)
|
||||
return size
|
||||
}
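The resume decision in `resumeOffset` can be isolated from the HTTP call. The following is a minimal sketch, not the code from the diff: `resumeFrom` is a hypothetical pure helper that takes the values the real function reads from the `X-File-Size` and `X-Target-SHA256` response headers and applies the same checks in the same order.

```go
package main

import (
	"fmt"
	"strings"
)

// resumeFrom is a hypothetical, network-free version of the checks in
// resumeOffset. It returns the offset to resume from, or 0 to restart.
func resumeFrom(serverSize int64, targetHash, localHash string, fileSize int64) int64 {
	// Without a local hash we cannot prove the partial file is ours.
	if localHash == "" || fileSize <= 0 {
		return 0
	}
	// A partial must be non-empty and strictly shorter than the local file.
	if serverSize <= 0 || serverSize >= fileSize {
		return 0
	}
	// The server must be holding a partial for exactly this content.
	if targetHash == "" || !strings.EqualFold(targetHash, localHash) {
		return 0
	}
	return serverSize
}

func main() {
	fmt.Println(resumeFrom(60, "ABC", "abc", 200))  // matching partial: resume
	fmt.Println(resumeFrom(60, "def", "abc", 200))  // different target: restart
	fmt.Println(resumeFrom(200, "abc", "abc", 200)) // not shorter than local file: restart
}
```

Every failure path collapses to offset 0, so a wrong guess costs bandwidth but never correctness: the server still validates the `Content-Range` start against its own file size.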
|
||||
|
||||
// doUpload performs a single upload attempt. When startOffset > 0 the request
|
||||
// is sent as a resumable PUT with a Content-Range header, transferring only
|
||||
// the bytes from startOffset to fileSize-1. ctx is the caller's context and
// supplies the staging-progress callback; outerCtx carries the long-lived
// resume budget and is the context the HTTP request is bound to.
|
||||
func (h *HTTPFileStager) doUpload(ctx, outerCtx context.Context, addr, nodeID, localPath, key, url string, fileSize, startOffset int64, expectedHash string) (string, error) {
|
||||
if startOffset < 0 || startOffset > fileSize {
|
||||
startOffset = 0
|
||||
}
|
||||
|
||||
f, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("opening local file %s: %w", localPath, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var body io.Reader = f
|
||||
cb := StagingProgressFromContext(ctx)
|
||||
// For files > 100MB or when a progress callback is set, wrap with progress reporting
|
||||
const progressThreshold = 100 << 20
|
||||
if fileSize > progressThreshold || cb != nil {
|
||||
body = newProgressReader(f, fileSize, filepath.Base(localPath), nodeID, cb)
|
||||
if startOffset > 0 {
|
||||
if _, err := f.Seek(startOffset, io.SeekStart); err != nil {
|
||||
return "", fmt.Errorf("seeking to offset %d in %s: %w", startOffset, localPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, body)
|
||||
chunkLen := fileSize - startOffset
|
||||
|
||||
var body io.Reader = f
|
||||
cb := StagingProgressFromContext(ctx)
|
||||
// For files > 100MB or when a progress callback is set, wrap with progress reporting.
|
||||
// We report against the FULL fileSize (not the chunkLen) so a resumed upload's
|
||||
// progress bar starts from the actual completed fraction rather than at 0%.
|
||||
const progressThreshold = 100 << 20
|
||||
if fileSize > progressThreshold || cb != nil {
|
||||
pr := newProgressReader(f, fileSize, filepath.Base(localPath), nodeID, cb)
|
||||
pr.read = startOffset // seed prior progress
|
||||
body = pr
|
||||
}
|
||||
|
||||
// The body length we actually send.
|
||||
limitedBody := io.LimitReader(body, chunkLen)
|
||||
|
||||
req, err := http.NewRequestWithContext(outerCtx, http.MethodPut, url, limitedBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
req.ContentLength = fileSize // explicit Content-Length for progress tracking
|
||||
req.ContentLength = chunkLen
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
if h.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+h.token)
|
||||
}
|
||||
if expectedHash != "" {
|
||||
// Lets the server detect cross-attempt content drift and reject
|
||||
// resume with 409 if the local file changed identity.
|
||||
req.Header.Set(HeaderContentSHA256, expectedHash)
|
||||
}
|
||||
if startOffset > 0 || (expectedHash != "" && fileSize > 0) {
|
||||
// Send Content-Range even on the first chunk (0-...) when we have an
|
||||
// expected hash, so the server's range-aware branch records the
|
||||
// target-hash sidecar for future resume attempts.
|
||||
req.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", startOffset, fileSize-1, fileSize))
|
||||
}
|
||||
|
||||
resp, err := h.client.Do(req)
|
||||
if err != nil {
|
||||
xlog.Error("File upload failed", "node", nodeID, "file", filepath.Base(localPath), "size", humanFileSize(fileSize), "error", err)
|
||||
xlog.Error("File upload failed", "node", nodeID, "file", filepath.Base(localPath),
|
||||
"size", humanFileSize(fileSize), "offset", startOffset, "error", err)
|
||||
return "", fmt.Errorf("uploading %s to node %s: %w", localPath, nodeID, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 308 Permanent Redirect ("Resume Incomplete") means the chunk landed but
|
||||
// the upload as a whole hasn't completed. From our perspective the
|
||||
// connection survived and the server has more bytes than before — but
|
||||
// since we always send the whole remainder, hitting 308 means the server
|
||||
// truncated us. Treat as transient so the retry loop re-HEADs and tries
|
||||
// again from the new offset.
|
||||
if resp.StatusCode == http.StatusPermanentRedirect {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", &transientStatusError{status: resp.StatusCode, msg: fmt.Sprintf("server reports resume-incomplete: %s", string(body))}
|
||||
}
|
||||
|
||||
// 416 Range Not Satisfiable: client/server disagree on offset. Treat as
|
||||
// transient — the next iteration re-HEADs to learn the correct offset.
|
||||
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", &transientStatusError{status: resp.StatusCode, msg: fmt.Sprintf("range not satisfiable: %s", string(body))}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
xlog.Error("File upload rejected by remote node", "node", nodeID, "file", filepath.Base(localPath), "status", resp.StatusCode, "response", string(respBody))
|
||||
@@ -187,11 +356,30 @@ func (h *HTTPFileStager) doUpload(ctx context.Context, addr, nodeID, localPath,
|
||||
return result.LocalPath, nil
|
||||
}
|
||||
|
||||
// transientStatusError wraps an HTTP status that should be treated as
|
||||
// transient by the upload retry loop.
|
||||
type transientStatusError struct {
|
||||
status int
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *transientStatusError) Error() string {
|
||||
return fmt.Sprintf("HTTP %d: %s", e.status, e.msg)
|
||||
}
|
||||
|
||||
func (e *transientStatusError) Transient() bool { return true }
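The `Transient()` marker method lets an error opt into retries without the retry loop knowing its concrete type. A minimal sketch of that pattern follows; `statusError`, `optsIntoRetry`, `wrapped416`, and `plainError` are hypothetical names chosen for this example, and the only library behaviour relied on is that `errors.As` accepts a pointer to an interface type and searches the wrapped chain.

```go
package main

import (
	"errors"
	"fmt"
)

// statusError is a stand-in for transientStatusError in the diff.
type statusError struct{ status int }

func (e *statusError) Error() string   { return fmt.Sprintf("HTTP %d", e.status) }
func (e *statusError) Transient() bool { return true }

// optsIntoRetry reports whether any error in the chain implements
// Transient() and returns true from it.
func optsIntoRetry(err error) bool {
	var t interface{ Transient() bool }
	return errors.As(err, &t) && t.Transient()
}

// wrapped416 wraps the marker error the way a caller might with %w.
func wrapped416() error {
	return fmt.Errorf("uploading model.bin: %w", &statusError{status: 416})
}

// plainError has no Transient method anywhere in its chain.
func plainError() error {
	return errors.New("permission denied")
}

func main() {
	fmt.Println(optsIntoRetry(wrapped416()))
	fmt.Println(optsIntoRetry(plainError()))
}
```

Matching on a one-method interface rather than on `*statusError` keeps the check open to other packages: any error type that adds a `Transient() bool` method is picked up even when wrapped.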
|
||||
|
||||
// isTransientError returns true if the error is likely transient and worth retrying.
|
||||
func isTransientError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
// Errors that explicitly opt into transient semantics (e.g. 308/416 from
|
||||
// the resumable-upload protocol).
|
||||
var transient interface{ Transient() bool }
|
||||
if errors.As(err, &transient) && transient.Transient() {
|
||||
return true
|
||||
}
|
||||
// Connection reset by peer
|
||||
if errors.Is(err, syscall.ECONNRESET) {
|
||||
return true
|
||||
|
||||
@@ -30,7 +30,16 @@ const (
|
||||
HeaderContentSHA256 = "X-Content-SHA256"
|
||||
HeaderLocalPath = "X-Local-Path"
|
||||
HeaderFileSize = "X-File-Size"
|
||||
hashSidecarSuffix = ".sha256"
|
||||
// HeaderTargetSHA256 is set on HEAD responses for partial (resumable) uploads
|
||||
// to expose the expected final SHA-256 of the in-progress file. When set,
|
||||
// the file on disk is not yet the full content — the client may resume by
|
||||
// PUT'ing the remainder with a matching X-Content-SHA256 header.
|
||||
HeaderTargetSHA256 = "X-Target-SHA256"
|
||||
hashSidecarSuffix = ".sha256"
|
||||
// targetSidecarSuffix stores the expected final SHA-256 of a partially
|
||||
// uploaded file. Used to detect mid-flight content mismatches across
|
||||
// resumed PUT requests.
|
||||
targetSidecarSuffix = ".sha256.target"
|
||||
)
|
||||
|
||||
// StartFileTransferServer starts a small HTTP server for file transfer in distributed mode.
|
||||
@@ -169,7 +178,25 @@ func handleHead(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir, d
|
||||
}
|
||||
|
||||
w.Header().Set(HeaderFileSize, strconv.FormatInt(info.Size(), 10))
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(info.Size(), 10))
|
||||
w.Header().Set(HeaderLocalPath, filePath)
|
||||
// Advertise resumable-upload support so clients know they may send
|
||||
// Content-Range PUTs to continue partial transfers.
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
|
||||
// If a target-hash sidecar is present the file on disk is a partial
|
||||
// upload, not a finished file. Expose the expected final hash via
|
||||
// X-Target-SHA256 and skip emitting X-Content-SHA256 (which would otherwise
|
||||
// be the hash of just the bytes received so far — misleading for clients
|
||||
// trying to decide whether the file is "the right one").
|
||||
if target, err := os.ReadFile(filePath + targetSidecarSuffix); err == nil {
|
||||
t := strings.TrimSpace(string(target))
|
||||
if len(t) == 64 {
|
||||
w.Header().Set(HeaderTargetSHA256, t)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
hashHex, err := computeAndCacheHash(filePath)
|
||||
if err != nil {
|
||||
@@ -181,6 +208,55 @@ func handleHead(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir, d
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// contentRange describes a parsed Content-Range request header of the form
// "bytes <start>-<end>/<total>". All three values are required:
// parseContentRange rejects any range that violates 0 <= start <= end < total.
|
||||
type contentRange struct {
|
||||
start int64
|
||||
end int64
|
||||
total int64
|
||||
}
|
||||
|
||||
// parseContentRange parses a Content-Range header value of the form
|
||||
// "bytes <start>-<end>/<total>". RFC 9110 §14.4.
|
||||
// Returns (nil, nil) when the header is empty (no range request).
|
||||
func parseContentRange(h string) (*contentRange, error) {
|
||||
h = strings.TrimSpace(h)
|
||||
if h == "" {
|
||||
return nil, nil
|
||||
}
|
||||
const prefix = "bytes "
|
||||
if !strings.HasPrefix(h, prefix) {
|
||||
return nil, fmt.Errorf("invalid Content-Range: missing %q prefix", strings.TrimSpace(prefix))
|
||||
}
|
||||
spec := strings.TrimSpace(h[len(prefix):])
|
||||
slash := strings.IndexByte(spec, '/')
|
||||
if slash < 0 {
|
||||
return nil, fmt.Errorf("invalid Content-Range: missing /total")
|
||||
}
|
||||
rangePart, totalPart := spec[:slash], spec[slash+1:]
|
||||
dash := strings.IndexByte(rangePart, '-')
|
||||
if dash < 0 {
|
||||
return nil, fmt.Errorf("invalid Content-Range: missing - separator")
|
||||
}
|
||||
start, err := strconv.ParseInt(strings.TrimSpace(rangePart[:dash]), 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid Content-Range start: %w", err)
|
||||
}
|
||||
end, err := strconv.ParseInt(strings.TrimSpace(rangePart[dash+1:]), 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid Content-Range end: %w", err)
|
||||
}
|
||||
total, err := strconv.ParseInt(strings.TrimSpace(totalPart), 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid Content-Range total: %w", err)
|
||||
}
|
||||
if start < 0 || end < start || total < end+1 {
|
||||
return nil, fmt.Errorf("invalid Content-Range range: %d-%d/%d", start, end, total)
|
||||
}
|
||||
return &contentRange{start: start, end: end, total: total}, nil
|
||||
}
|
||||
|
||||
func handleUpload(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir, dataDir, key string, maxUploadSize int64) {
|
||||
if key == "" {
|
||||
http.Error(w, "key is required", http.StatusBadRequest)
|
||||
@@ -191,7 +267,19 @@ func handleUpload(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir,
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxUploadSize)
|
||||
}
|
||||
|
||||
xlog.Info("Receiving file upload", "key", key, "contentLength", r.ContentLength, "remote", r.RemoteAddr)
|
||||
// Parse optional Content-Range for resumable uploads.
|
||||
cr, err := parseContentRange(r.Header.Get("Content-Range"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Optional expected total-file SHA-256 used to detect cross-attempt
|
||||
// content drift on resume.
|
||||
expectedFinalHash := strings.TrimSpace(r.Header.Get(HeaderContentSHA256))
|
||||
|
||||
xlog.Info("Receiving file upload", "key", key, "contentLength", r.ContentLength,
|
||||
"contentRange", r.Header.Get("Content-Range"), "remote", r.RemoteAddr)
|
||||
|
||||
// Route keyed files to the appropriate directory
|
||||
targetDir, relName := resolveKeyToDir(key, stagingDir, modelsDir, dataDir)
|
||||
@@ -208,6 +296,21 @@ func handleUpload(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir,
|
||||
return
|
||||
}
|
||||
|
||||
if cr == nil {
|
||||
// Non-resumable (legacy) path: truncate-create, single fire-and-forget.
|
||||
handleFullUpload(w, r, dstPath, key, expectedFinalHash)
|
||||
return
|
||||
}
|
||||
|
||||
handleRangeUpload(w, r, dstPath, key, cr, expectedFinalHash)
|
||||
}
|
||||
|
||||
// handleFullUpload writes the entire request body to dstPath, replacing any
|
||||
// existing content. This is the legacy happy-path with no Content-Range header.
|
||||
func handleFullUpload(w http.ResponseWriter, r *http.Request, dstPath, key, expectedFinalHash string) {
|
||||
// Reset any in-progress resumable state.
|
||||
_ = os.Remove(dstPath + targetSidecarSuffix)
|
||||
|
||||
f, err := os.Create(dstPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("creating file: %v", err), http.StatusInternalServerError)
|
||||
@@ -226,6 +329,14 @@ func handleUpload(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir,
|
||||
}
|
||||
|
||||
hashHex := hex.EncodeToString(hasher.Sum(nil))
|
||||
if expectedFinalHash != "" && !strings.EqualFold(expectedFinalHash, hashHex) {
|
||||
_ = os.Remove(dstPath)
|
||||
_ = os.Remove(dstPath + hashSidecarSuffix)
|
||||
xlog.Error("Uploaded file SHA-256 mismatch", "key", key, "expected", expectedFinalHash, "got", hashHex)
|
||||
http.Error(w, fmt.Sprintf("sha256 mismatch: expected %s got %s", expectedFinalHash, hashHex), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.WriteFile(dstPath+hashSidecarSuffix, []byte(hashHex), 0640); err != nil {
|
||||
xlog.Warn("Failed to write hash sidecar", "path", dstPath+hashSidecarSuffix, "error", err)
|
||||
}
|
||||
@@ -238,6 +349,154 @@ func handleUpload(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir,
|
||||
}
|
||||
}
|
||||
|
||||
// handleRangeUpload appends a Content-Range slice to dstPath, validating that
|
||||
// the request starts at the current file size. When the slice completes the
|
||||
// transfer (end+1 == total), it validates the optional expected final hash and
|
||||
// writes the sidecar.
|
||||
func handleRangeUpload(w http.ResponseWriter, r *http.Request, dstPath, key string, cr *contentRange, expectedFinalHash string) {
|
||||
// Determine the current on-disk size (0 if missing).
|
||||
var currentSize int64
|
||||
if info, err := os.Stat(dstPath); err == nil {
|
||||
if info.IsDir() {
|
||||
http.Error(w, "destination is a directory", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
currentSize = info.Size()
|
||||
} else if !os.IsNotExist(err) {
|
||||
http.Error(w, fmt.Sprintf("stat dst: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
targetSidecar := dstPath + targetSidecarSuffix
|
||||
|
||||
// Decide whether the existing on-disk bytes (if any) belong to the same
|
||||
// logical file the client is uploading now. If they don't, and the client
|
||||
// is starting from byte 0, we transparently truncate the old file and
|
||||
// proceed — this is the natural "re-upload" case.
|
||||
if cr.start == 0 && currentSize > 0 {
|
||||
sameFile := false
|
||||
if expectedFinalHash != "" {
|
||||
// Compare the client's declared target hash against either an
|
||||
// in-progress target sidecar OR the completed-file sidecar.
|
||||
if t, err := os.ReadFile(targetSidecar); err == nil {
|
||||
if strings.EqualFold(strings.TrimSpace(string(t)), expectedFinalHash) {
|
||||
sameFile = true
|
||||
}
|
||||
} else if h, err := os.ReadFile(dstPath + hashSidecarSuffix); err == nil {
|
||||
if strings.EqualFold(strings.TrimSpace(string(h)), expectedFinalHash) {
|
||||
sameFile = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !sameFile {
|
||||
// Different file content claimed under the same key — drop any
|
||||
// existing bytes (completed or partial) so the new upload starts
|
||||
// from a clean slate.
|
||||
_ = os.Remove(dstPath)
|
||||
_ = os.Remove(dstPath + hashSidecarSuffix)
|
||||
_ = os.Remove(targetSidecar)
|
||||
currentSize = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Cross-attempt consistency: if there's an in-progress target sidecar with
|
||||
// a different hash than what's now being claimed, force a restart.
|
||||
if expectedFinalHash != "" && cr.start > 0 {
|
||||
prev, _ := os.ReadFile(targetSidecar)
|
||||
prevHash := strings.TrimSpace(string(prev))
|
||||
if prevHash != "" && !strings.EqualFold(prevHash, expectedFinalHash) {
|
||||
_ = os.Remove(dstPath)
|
||||
_ = os.Remove(dstPath + hashSidecarSuffix)
|
||||
_ = os.Remove(targetSidecar)
|
||||
http.Error(w, fmt.Sprintf("X-Content-SHA256 mismatch with in-progress upload (was %s, now %s); restart from byte 0", prevHash, expectedFinalHash), http.StatusConflict)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// The most important invariant: the client must continue from exactly
|
||||
// where the server left off. If not, return 416 with the current size in
|
||||
// the X-File-Size header (plus "Content-Range: bytes */<total>") so the client can re-sync.
|
||||
if cr.start != currentSize {
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", cr.total))
|
||||
w.Header().Set(HeaderFileSize, strconv.FormatInt(currentSize, 10))
|
||||
http.Error(w, fmt.Sprintf("Content-Range start %d does not match current file size %d", cr.start, currentSize), http.StatusRequestedRangeNotSatisfiable)
|
||||
return
|
||||
}
|
||||
|
||||
// Open the file in append mode (create if missing).
|
||||
f, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0640)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("opening dst: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
// Persist the declared expected hash so subsequent chunks can be
|
||||
// cross-checked.
|
||||
if expectedFinalHash != "" {
|
||||
if err := os.WriteFile(targetSidecar, []byte(expectedFinalHash), 0640); err != nil {
|
||||
xlog.Warn("Failed to write target hash sidecar", "path", targetSidecar, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
expectedChunkLen := cr.end - cr.start + 1
|
||||
limited := io.LimitReader(r.Body, expectedChunkLen)
|
||||
n, err := io.Copy(f, limited)
|
||||
if err != nil {
|
||||
xlog.Error("Range upload chunk failed", "key", key, "bytesReceived", n, "expected", expectedChunkLen, "remote", r.RemoteAddr, "error", err)
|
||||
http.Error(w, fmt.Sprintf("writing file: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if n != expectedChunkLen {
|
||||
xlog.Error("Range upload chunk short", "key", key, "bytesReceived", n, "expected", expectedChunkLen, "remote", r.RemoteAddr)
|
||||
http.Error(w, fmt.Sprintf("short body: got %d expected %d", n, expectedChunkLen), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
newSize := currentSize + n
|
||||
|
||||
// If this chunk does not complete the transfer, return 308 Resume
|
||||
// Incomplete (following the Google Cloud Storage resumable-upload
// convention for "keep going") and report the current size in the Range
// and X-File-Size headers so the client can continue.
|
||||
if newSize < cr.total {
|
||||
w.Header().Set("Range", fmt.Sprintf("bytes=0-%d", newSize-1))
|
||||
w.Header().Set(HeaderFileSize, strconv.FormatInt(newSize, 10))
|
||||
w.WriteHeader(http.StatusPermanentRedirect) // 308 — "Resume Incomplete"
|
||||
xlog.Debug("Range upload chunk accepted", "key", key, "newSize", newSize, "total", cr.total)
|
||||
return
|
||||
}
|
||||
|
||||
// Upload complete — compute the final hash by re-reading the file.
|
||||
finalHash, err := downloader.CalculateSHA(dstPath)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to compute final hash on range upload", "path", dstPath, "error", err)
|
||||
http.Error(w, fmt.Sprintf("computing final hash: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if expectedFinalHash != "" && !strings.EqualFold(expectedFinalHash, finalHash) {
|
||||
_ = os.Remove(dstPath)
|
||||
_ = os.Remove(dstPath + hashSidecarSuffix)
|
||||
_ = os.Remove(targetSidecar)
|
||||
xlog.Error("Resumed upload SHA-256 mismatch", "key", key, "expected", expectedFinalHash, "got", finalHash)
|
||||
http.Error(w, fmt.Sprintf("sha256 mismatch: expected %s got %s", expectedFinalHash, finalHash), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.WriteFile(dstPath+hashSidecarSuffix, []byte(finalHash), 0640); err != nil {
|
||||
xlog.Warn("Failed to write hash sidecar", "path", dstPath+hashSidecarSuffix, "error", err)
|
||||
}
|
||||
// Clear the in-progress sidecar — upload is committed.
|
||||
_ = os.Remove(targetSidecar)
|
||||
|
||||
xlog.Info("Resumable file upload complete", "key", key, "path", dstPath, "size", newSize, "sha256", finalHash)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{"local_path": dstPath}); err != nil {
|
||||
xlog.Warn("Failed to encode upload response", "error", err)
|
||||
}
|
||||
}
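The offset bookkeeping in `handleRangeUpload` reduces to a three-way decision. The sketch below is a simplification under stated assumptions: `rangeOutcome` is a hypothetical helper, it assumes the body delivered exactly `end-start+1` bytes, and it ignores the hash checks and I/O errors the real handler also performs.

```go
package main

import (
	"fmt"
	"net/http"
)

// rangeOutcome is a hypothetical helper returning the status the handler
// would send based on sizes alone.
func rangeOutcome(currentSize, start, end, total int64) int {
	// The chunk must begin exactly where the on-disk file ends.
	if start != currentSize {
		return http.StatusRequestedRangeNotSatisfiable // 416
	}
	newSize := currentSize + (end - start + 1)
	// More bytes are still expected after this chunk.
	if newSize < total {
		return http.StatusPermanentRedirect // 308, used as "Resume Incomplete"
	}
	// The chunk completes the file; the real handler verifies the hash next.
	return http.StatusOK
}

func main() {
	fmt.Println(rangeOutcome(0, 0, 99, 200))     // first half of a 200-byte file
	fmt.Println(rangeOutcome(100, 100, 199, 200)) // second half completes it
	fmt.Println(rangeOutcome(50, 100, 199, 200))  // client skipped ahead
}
```

The three calls in `main` correspond to the chunk sequences exercised by the Ginkgo specs for the two-chunk upload and the offset mismatch.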
|
||||
|
||||
// computeAndCacheHash returns the SHA-256 hex digest for filePath.
|
||||
// It reads a cached sidecar when available and still fresh (sidecar mtime >=
|
||||
// file mtime), otherwise computes the hash and writes/updates the sidecar.
|
||||
|
||||
@@ -5,12 +5,16 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -559,6 +563,330 @@ var _ = Describe("FileTransferServer", func() {
|
||||
Expect(uploaded).To(Equal(content))
|
||||
})
|
||||
})
|
||||
|
||||
// --- Resumable upload (Content-Range) tests ---
|
||||
|
||||
Describe("Resumable upload (Content-Range)", func() {
|
||||
// doPut sends a PUT to ts with the given body, headers, and key.
|
||||
doPut := func(ts *httptest.Server, token, key string, body []byte, headers map[string]string) (*http.Response, []byte) {
|
||||
req, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/"+key, bytes.NewReader(body))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return resp, respBody
|
||||
}
|
||||
|
||||
It("accepts two consecutive Content-Range chunks and produces the full file", func() {
|
||||
ts, stagingDir, _, _ := setupTestServer("tok", 0)
|
||||
|
||||
full := bytes.Repeat([]byte("abcdefghij"), 20) // 200 bytes
|
||||
fullHash := sha256Hex(full)
|
||||
|
||||
// Chunk 1: bytes 0-99
|
||||
resp1, _ := doPut(ts, "tok", "chunked.bin", full[:100], map[string]string{
|
||||
"Content-Range": fmt.Sprintf("bytes 0-99/%d", len(full)),
|
||||
HeaderContentSHA256: fullHash,
|
||||
})
|
||||
Expect(resp1.StatusCode).To(Equal(http.StatusPermanentRedirect))
|
||||
Expect(resp1.Header.Get(HeaderFileSize)).To(Equal("100"))
|
||||
|
||||
// Chunk 2: bytes 100-199
|
||||
resp2, _ := doPut(ts, "tok", "chunked.bin", full[100:], map[string]string{
|
||||
"Content-Range": fmt.Sprintf("bytes 100-199/%d", len(full)),
|
||||
HeaderContentSHA256: fullHash,
|
||||
})
|
||||
Expect(resp2.StatusCode).To(Equal(http.StatusOK))
|
||||
|
||||
// File matches the full content
|
||||
got, err := os.ReadFile(filepath.Join(stagingDir, "chunked.bin"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(got).To(Equal(full))
|
||||
|
||||
// Sidecar holds the final hash
|
||||
sidecar, err := os.ReadFile(filepath.Join(stagingDir, "chunked.bin.sha256"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
			Expect(string(sidecar)).To(Equal(fullHash))

			// Target sidecar (in-progress marker) is cleared once complete
			_, err = os.Stat(filepath.Join(stagingDir, "chunked.bin.sha256.target"))
			Expect(os.IsNotExist(err)).To(BeTrue())
		})

		It("returns 416 when Content-Range start does not match current file size", func() {
			ts, _, _, _ := setupTestServer("tok", 0)

			full := bytes.Repeat([]byte("x"), 200)
			fullHash := sha256Hex(full)

			// First chunk: bytes 0-49
			resp1, _ := doPut(ts, "tok", "mismatch.bin", full[:50], map[string]string{
				"Content-Range":     fmt.Sprintf("bytes 0-49/%d", len(full)),
				HeaderContentSHA256: fullHash,
			})
			Expect(resp1.StatusCode).To(Equal(http.StatusPermanentRedirect))

			// Skip ahead: server has 50 bytes but client tries to send 100-199.
			resp2, _ := doPut(ts, "tok", "mismatch.bin", full[100:200], map[string]string{
				"Content-Range":     fmt.Sprintf("bytes 100-199/%d", len(full)),
				HeaderContentSHA256: fullHash,
			})
			Expect(resp2.StatusCode).To(Equal(http.StatusRequestedRangeNotSatisfiable))
			Expect(resp2.Header.Get(HeaderFileSize)).To(Equal("50"))
		})

		It("returns 409 when X-Content-SHA256 changes between resumed chunks", func() {
			ts, _, _, _ := setupTestServer("tok", 0)

			a := bytes.Repeat([]byte("a"), 200)
			b := bytes.Repeat([]byte("b"), 200)

			// Chunk 1 (file A): bytes 0-49
			resp1, _ := doPut(ts, "tok", "drifted.bin", a[:50], map[string]string{
				"Content-Range":     fmt.Sprintf("bytes 0-49/%d", len(a)),
				HeaderContentSHA256: sha256Hex(a),
			})
			Expect(resp1.StatusCode).To(Equal(http.StatusPermanentRedirect))

			// Chunk 2 claims file B's hash for the *same* key — should be rejected.
			resp2, _ := doPut(ts, "tok", "drifted.bin", b[50:100], map[string]string{
				"Content-Range":     fmt.Sprintf("bytes 50-99/%d", len(b)),
				HeaderContentSHA256: sha256Hex(b),
			})
			Expect(resp2.StatusCode).To(Equal(http.StatusConflict))
		})

		It("returns 400 when final SHA-256 does not match the declared target", func() {
			ts, _, _, _ := setupTestServer("tok", 0)

			full := bytes.Repeat([]byte("z"), 100)
			wrongHash := sha256Hex([]byte("definitely-not-this"))

			resp, _ := doPut(ts, "tok", "bad-hash.bin", full, map[string]string{
				"Content-Range":     fmt.Sprintf("bytes 0-99/%d", len(full)),
				HeaderContentSHA256: wrongHash,
			})
			Expect(resp.StatusCode).To(Equal(http.StatusBadRequest))
		})

		It("HEAD on a partial upload exposes X-Target-SHA256 and current size", func() {
			ts, _, _, _ := setupTestServer("tok", 0)

			full := bytes.Repeat([]byte("q"), 200)
			fullHash := sha256Hex(full)

			// One chunk uploaded, file is partial.
			resp1, _ := doPut(ts, "tok", "partial.bin", full[:60], map[string]string{
				"Content-Range":     fmt.Sprintf("bytes 0-59/%d", len(full)),
				HeaderContentSHA256: fullHash,
			})
			Expect(resp1.StatusCode).To(Equal(http.StatusPermanentRedirect))

			req, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/partial.bin", nil)
			Expect(err).ToNot(HaveOccurred())
			req.Header.Set("Authorization", "Bearer tok")
			resp, err := http.DefaultClient.Do(req)
			Expect(err).ToNot(HaveOccurred())
			_ = resp.Body.Close()
			Expect(resp.StatusCode).To(Equal(http.StatusOK))
			Expect(resp.Header.Get(HeaderFileSize)).To(Equal("60"))
			Expect(resp.Header.Get("Accept-Ranges")).To(Equal("bytes"))
			Expect(resp.Header.Get(HeaderTargetSHA256)).To(Equal(fullHash))
			// While the upload is in progress we must NOT expose a misleading
			// X-Content-SHA256 of the bytes-so-far — clients use HeaderContentSHA256
			// only for completed files.
			Expect(resp.Header.Get(HeaderContentSHA256)).To(BeEmpty())
		})

		It("transparently overwrites an existing finished file when client starts from byte 0 with a new hash", func() {
			ts, stagingDir, _, _ := setupTestServer("tok", 0)

			// Pre-place a finished file (sidecar present, no target sidecar).
			oldContent := []byte("ancient version")
			err := os.WriteFile(filepath.Join(stagingDir, "overwrite.bin"), oldContent, 0644)
			Expect(err).ToNot(HaveOccurred())
			err = os.WriteFile(filepath.Join(stagingDir, "overwrite.bin.sha256"), []byte(sha256Hex(oldContent)), 0644)
			Expect(err).ToNot(HaveOccurred())

			// New upload with a different target hash, range 0-N/total.
			newContent := bytes.Repeat([]byte("new"), 50) // 150 bytes
			newHash := sha256Hex(newContent)

			resp, _ := doPut(ts, "tok", "overwrite.bin", newContent, map[string]string{
				"Content-Range":     fmt.Sprintf("bytes 0-%d/%d", len(newContent)-1, len(newContent)),
				HeaderContentSHA256: newHash,
			})
			Expect(resp.StatusCode).To(Equal(http.StatusOK))

			got, err := os.ReadFile(filepath.Join(stagingDir, "overwrite.bin"))
			Expect(err).ToNot(HaveOccurred())
			Expect(got).To(Equal(newContent))
		})

		It("HEAD advertises Accept-Ranges: bytes on completed files", func() {
			ts, _, _, _ := setupTestServer("tok", 0)

			content := "done"
			doPut(ts, "tok", "ranges-advert.txt", []byte(content), nil)

			req, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/ranges-advert.txt", nil)
			Expect(err).ToNot(HaveOccurred())
			req.Header.Set("Authorization", "Bearer tok")
			resp, err := http.DefaultClient.Do(req)
			Expect(err).ToNot(HaveOccurred())
			_ = resp.Body.Close()
			Expect(resp.StatusCode).To(Equal(http.StatusOK))
			Expect(resp.Header.Get("Accept-Ranges")).To(Equal("bytes"))
			Expect(resp.Header.Get("Content-Length")).To(Equal(strconv.Itoa(len(content))))
		})
	})

	// --- End-to-end client/server resume tests ---

	Describe("HTTPFileStager resume via EnsureRemote", func() {
		It("resumes from server's reported offset when a partial upload exists", func() {
			ts, stagingDir, _, _ := setupTestServer("tok", 0)

			// Create the local file (the master's source-of-truth).
			localDir := GinkgoT().TempDir()
			localPath := filepath.Join(localDir, "resume.bin")
			content := bytes.Repeat([]byte("R"), 500)
			Expect(os.WriteFile(localPath, content, 0644)).To(Succeed())
			fullHash := sha256Hex(content)

			// Pre-seed the "worker" with the first 200 bytes as if a prior
			// attempt had transferred that much, plus a target-hash sidecar
			// claiming the full file's hash.
			dst := filepath.Join(stagingDir, "resume.bin")
			Expect(os.WriteFile(dst, content[:200], 0644)).To(Succeed())
			Expect(os.WriteFile(dst+".sha256.target", []byte(fullHash), 0644)).To(Succeed())

			addr := strings.TrimPrefix(ts.URL, "http://")
			stager := NewHTTPFileStager(func(nodeID string) (string, error) {
				return addr, nil
			}, "tok")

			remotePath, err := stager.EnsureRemote(context.Background(), "node-1", localPath, "resume.bin")
			Expect(err).ToNot(HaveOccurred())
			Expect(remotePath).To(Equal(dst))

			got, err := os.ReadFile(dst)
			Expect(err).ToNot(HaveOccurred())
			Expect(got).To(Equal(content))

			// Final sidecar should hold the full-file hash.
			sidecar, err := os.ReadFile(dst + ".sha256")
			Expect(err).ToNot(HaveOccurred())
			Expect(strings.TrimSpace(string(sidecar))).To(Equal(fullHash))
		})

		It("survives a mid-stream connection drop and resumes on retry", func() {
			// Server that drops the connection after writing the first N bytes
			// on the FIRST PUT attempt, then behaves normally.
			stagingDir := GinkgoT().TempDir()
			modelsDir := GinkgoT().TempDir()
			dataDir := GinkgoT().TempDir()

			var (
				attemptCount int
				attemptMu    sync.Mutex
			)
			const dropAfter = 80 // bytes the server "accepts" before crashing

			mux := http.NewServeMux()
			mux.HandleFunc("/v1/files/", func(w http.ResponseWriter, r *http.Request) {
				if !checkBearerToken(r, "tok") {
					http.Error(w, "unauthorized", http.StatusUnauthorized)
					return
				}
				key := strings.TrimPrefix(r.URL.Path, "/v1/files/")
				if r.Method == http.MethodHead {
					handleHead(w, r, stagingDir, modelsDir, dataDir, key)
					return
				}
				if r.Method != http.MethodPut {
					http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
					return
				}

				attemptMu.Lock()
				attemptCount++
				thisAttempt := attemptCount
				attemptMu.Unlock()

				if thisAttempt == 1 {
					// Read a bounded prefix into the partial file, then hijack
					// the connection and close abruptly to simulate the drop.
					cr, err := parseContentRange(r.Header.Get("Content-Range"))
					if err != nil || cr == nil {
						http.Error(w, "expected content-range", http.StatusBadRequest)
						return
					}
					dst := filepath.Join(stagingDir, key)
					f, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0640)
					if err != nil {
						http.Error(w, err.Error(), http.StatusInternalServerError)
						return
					}
					target := r.Header.Get(HeaderContentSHA256)
					if target != "" {
						_ = os.WriteFile(dst+".sha256.target", []byte(target), 0640)
					}
					_, _ = io.CopyN(f, r.Body, dropAfter)
					_ = f.Close()

					hj, ok := w.(http.Hijacker)
					if !ok {
						http.Error(w, "hijack unsupported", http.StatusInternalServerError)
						return
					}
					conn, _, err := hj.Hijack()
					if err == nil {
						_ = conn.Close() // abrupt close — client sees a transport error
					}
					return
				}

				// Subsequent attempts: behave normally.
				handleUpload(w, r, stagingDir, modelsDir, dataDir, key, 0)
			})
			ts := httptest.NewServer(mux)
			DeferCleanup(ts.Close)

			// Build a small "model" file to upload (300 bytes for speed).
			localDir := GinkgoT().TempDir()
			localPath := filepath.Join(localDir, "flaky.bin")
			content := bytes.Repeat([]byte("F"), 300)
			Expect(os.WriteFile(localPath, content, 0644)).To(Succeed())

			addr := strings.TrimPrefix(ts.URL, "http://")
			stager := NewHTTPFileStager(func(nodeID string) (string, error) {
				return addr, nil
			}, "tok")

			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
			defer cancel()

			remotePath, err := stager.EnsureRemote(ctx, "node-1", localPath, "flaky.bin")
			Expect(err).ToNot(HaveOccurred())
			Expect(remotePath).To(Equal(filepath.Join(stagingDir, "flaky.bin")))

			// Final file is correct
			got, err := os.ReadFile(filepath.Join(stagingDir, "flaky.bin"))
			Expect(err).ToNot(HaveOccurred())
			Expect(got).To(Equal(content))

			// At least one retry happened
			attemptMu.Lock()
			Expect(attemptCount).To(BeNumerically(">=", 2))
			attemptMu.Unlock()
		})
	})
})

func sha256Hex(data []byte) string {
|
||||
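The resume specs above pin down four status codes for a chunked PUT: 308 for an accepted partial chunk, 416 when the range start does not match the stored size, 409 when the declared hash changes mid-upload, and 200 on completion. The following is a minimal sketch of that decision table, not the handler from this change. `byteRange`, `parseRange`, and `chunkStatus` are hypothetical names, the order of the checks is a guess, and the restart-at-byte-0 rule is an assumption drawn from the overwrite spec. The final 400 on a hash mismatch is not modelled.

```go
package main

import (
	"fmt"
	"net/http"
)

// byteRange is a hypothetical parsed form of "bytes start-end/total".
type byteRange struct {
	Start, End, Total int64
}

// parseRange is an illustrative stand-in for the handler's parseContentRange,
// whose real signature and behaviour are not shown in this diff.
func parseRange(header string) (byteRange, error) {
	var r byteRange
	if _, err := fmt.Sscanf(header, "bytes %d-%d/%d", &r.Start, &r.End, &r.Total); err != nil {
		return r, fmt.Errorf("malformed Content-Range %q: %w", header, err)
	}
	if r.Start < 0 || r.End < r.Start || r.End >= r.Total {
		return r, fmt.Errorf("inconsistent Content-Range %q", header)
	}
	return r, nil
}

// chunkStatus returns the status code a resumable PUT would get for one chunk.
// size is the number of bytes already stored, target is the hash recorded by
// the first chunk ("" if none), and declared is the hash sent with this chunk.
func chunkStatus(r byteRange, size int64, target, declared string) int {
	if r.Start == 0 {
		// Assumption: a chunk starting at byte 0 restarts the upload.
		size, target = 0, ""
	}
	switch {
	case target != "" && target != declared:
		return http.StatusConflict // hash drifted between chunks
	case r.Start != size:
		return http.StatusRequestedRangeNotSatisfiable // gap or overlap
	case r.End+1 < r.Total:
		return http.StatusPermanentRedirect // more chunks expected
	default:
		return http.StatusOK // final chunk; hash verification not modelled
	}
}

func main() {
	r, err := parseRange("bytes 100-199/200")
	if err != nil {
		panic(err)
	}
	// The server holds 50 bytes, so a chunk starting at 100 leaves a gap.
	fmt.Println(chunkStatus(r, 50, "abc", "abc"))
}
```

Keeping the decision in a pure function like this makes each spec a one-line table entry, at the cost of separating it from the file I/O that the real handler has to interleave with it.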
@@ -9,7 +9,7 @@ import (
// ModelRouter is used by SmartRouter for routing decisions and model lifecycle.
type ModelRouter interface {
-	FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, *NodeModel, error)
+	FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string, pref *RoutePreference) (*BackendNode, *NodeModel, error)
	DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
	IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
	RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error
@@ -37,6 +37,7 @@ type ModelRouter interface {
	FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
	GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error)
	FindNodesWithModel(ctx context.Context, modelName string) ([]BackendNode, error)
+	LoadedReplicaStats(ctx context.Context, modelName string, candidateNodeIDs []string) ([]ReplicaCandidate, error)
}

// ConcurrencyConflictResolver returns the names of configured models that
@@ -27,7 +27,7 @@ func newFakeModelRouterForSmartRouter() *fakeModelRouterForSmartRouter {
	}
}

-func (f *fakeModelRouterForSmartRouter) FindAndLockNodeWithModel(_ context.Context, _ string, _ []string) (*BackendNode, *NodeModel, error) {
+func (f *fakeModelRouterForSmartRouter) FindAndLockNodeWithModel(_ context.Context, _ string, _ []string, _ *RoutePreference) (*BackendNode, *NodeModel, error) {
	f.mu.Lock()
	defer f.mu.Unlock()
	return f.node, f.nodeModel, f.findErr
@@ -121,6 +121,9 @@ func (f *fakeModelRouterForSmartRouter) GetNodeLabels(_ context.Context, _ strin
func (f *fakeModelRouterForSmartRouter) FindNodesWithModel(_ context.Context, _ string) ([]BackendNode, error) {
	return nil, nil
}
+func (f *fakeModelRouterForSmartRouter) LoadedReplicaStats(_ context.Context, _ string, _ []string) ([]ReplicaCandidate, error) {
+	return nil, nil
+}

// Compile-time check
var _ ModelRouter = (*fakeModelRouterForSmartRouter)(nil)
|
||||
95
core/services/nodes/prefixcache/config.go
Normal file
@@ -0,0 +1,95 @@
package prefixcache

import (
	"fmt"
	"time"
)

// Config holds prefix-cache-aware routing settings. Per-model overrides
// (policy, abs/rel thresholds, min-match) live on ModelSchedulingConfig; TTL
// and window/depth are global-only.
type Config struct {
	GlobalPolicy        RoutePolicy
	MinPrefixMatch      float64       // ratio matched/total, [0,1]
	BalanceAbsThreshold int           // absolute in-flight slack
	BalanceRelThreshold float64       // relative load ratio, >= 1
	TTL                 time.Duration // idle-timeout for entries
	HalfLife            time.Duration // recency decay for cacheWeight
	WindowBytes         int           // chunk window size
	MaxDepth            int           // max head blocks hashed (see ExtractChain)
	// PressureWindow is the rolling window over which forced-disturb events are
	// counted for the autoscale signal (see Pressure). Default 1 minute.
	PressureWindow time.Duration
	// PressureScaleThreshold is the minimum forced-disturb count within
	// PressureWindow that makes the reconciler treat the cache-warm replica as
	// saturated and scale up (subject to MaxReplicas and capacity). Default 1,
	// i.e. any sustained forced-disturb.
	PressureScaleThreshold int
}

func DefaultConfig() Config {
	return Config{
		GlobalPolicy:           RoutePolicyPrefixCache,
		MinPrefixMatch:         0.3,
		BalanceAbsThreshold:    2,
		BalanceRelThreshold:    1.5,
		TTL:                    5 * time.Minute,
		HalfLife:               2 * time.Minute,
		WindowBytes:            256,
		MaxDepth:               64,
		PressureWindow:         time.Minute,
		PressureScaleThreshold: 1,
	}
}

// validateThresholdBounds enforces the numeric bounds shared between the
// per-model override validator (ValidateThresholds) and Config.Validate:
// minMatch in [0,1]; absThr >= 0; relThr == 0 (inherit) or >= 1. It is the
// single source of truth for those bounds so the endpoint and the global
// config cannot drift apart.
func validateThresholdBounds(absThr int, relThr, minMatch float64) error {
	if minMatch < 0 || minMatch > 1 {
		return fmt.Errorf("prefixcache: min_prefix_match must be in [0,1], got %v", minMatch)
	}
	if absThr < 0 {
		return fmt.Errorf("prefixcache: balance_abs_threshold must be >= 0, got %d", absThr)
	}
	if relThr != 0 && relThr < 1 {
		return fmt.Errorf("prefixcache: balance_rel_threshold must be 0 (inherit) or >= 1, got %v", relThr)
	}
	return nil
}

// ValidateThresholds checks per-model override bounds. routePolicy must be one
// of "", "round_robin", "prefix_cache" (explicit allow-list - NOT ParsePolicy,
// which maps unknown to Default and would accept typos). minMatch in [0,1];
// absThr >= 0; relThr == 0 (inherit) or >= 1.
func ValidateThresholds(routePolicy string, absThr int, relThr, minMatch float64) error {
	switch routePolicy {
	case "", "round_robin", "prefix_cache":
	default:
		return fmt.Errorf(`prefixcache: route_policy must be one of "", "round_robin", "prefix_cache", got %q`, routePolicy)
	}
	return validateThresholdBounds(absThr, relThr, minMatch)
}

func (c Config) Validate() error {
	// Config.BalanceRelThreshold has no "inherit" sentinel - it is a concrete
	// global value that must be >= 1 - so pass 0 for relThr to the shared
	// numeric check and assert the >= 1 floor here separately.
	if err := validateThresholdBounds(c.BalanceAbsThreshold, 0, c.MinPrefixMatch); err != nil {
		return err
	}
	if c.BalanceRelThreshold < 1 {
		return fmt.Errorf("prefixcache: balance_rel_threshold must be >= 1, got %v", c.BalanceRelThreshold)
	}
	if c.WindowBytes <= 0 || c.MaxDepth <= 0 {
		return fmt.Errorf("prefixcache: window_bytes and max_depth must be > 0")
	}
	// TTL must be positive: it is the entry idle-lifetime and the eviction
	// ticker runs at TTL/2, so time.NewTicker would panic on TTL <= 0.
	if c.TTL <= 0 {
		return fmt.Errorf("prefixcache: ttl must be > 0, got %v", c.TTL)
	}
	return nil
}
|
||||
73
core/services/nodes/prefixcache/config_test.go
Normal file
@@ -0,0 +1,73 @@
package prefixcache_test

import (
	"time"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"

	"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
)

var _ = Describe("Config", func() {
	It("supplies defaults", func() {
		c := prefixcache.DefaultConfig()
		Expect(c.GlobalPolicy).To(Equal(prefixcache.RoutePolicyPrefixCache)) // default ON
		Expect(c.MinPrefixMatch).To(BeNumerically("==", 0.3))
		Expect(c.BalanceAbsThreshold).To(Equal(2))
		Expect(c.BalanceRelThreshold).To(BeNumerically("==", 1.5))
		Expect(c.TTL).To(Equal(5 * time.Minute))
		Expect(c.WindowBytes).To(Equal(256))
		Expect(c.MaxDepth).To(Equal(64))
	})

	It("rejects invalid values", func() {
		c := prefixcache.DefaultConfig()
		c.MinPrefixMatch = 1.5
		Expect(c.Validate()).To(HaveOccurred())
		c = prefixcache.DefaultConfig()
		c.BalanceAbsThreshold = -1
		Expect(c.Validate()).To(HaveOccurred())
		c = prefixcache.DefaultConfig()
		c.TTL = 0
		Expect(c.Validate()).To(HaveOccurred()) // TTL/2 ticker would panic
	})
})

var _ = Describe("ValidateThresholds", func() {
	It("accepts valid values across all route policies", func() {
		Expect(prefixcache.ValidateThresholds("", 3, 0, 0.4)).To(Succeed())
		Expect(prefixcache.ValidateThresholds("round_robin", 0, 1.5, 0)).To(Succeed())
		Expect(prefixcache.ValidateThresholds("prefix_cache", 2, 2.0, 1.0)).To(Succeed())
	})

	It("rejects an unknown route_policy (explicit allow-list, no silent default)", func() {
		err := prefixcache.ValidateThresholds("bogus", 0, 0, 0)
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("route_policy"))
	})

	It("rejects min_prefix_match above 1", func() {
		err := prefixcache.ValidateThresholds("", 0, 0, 1.5)
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
	})

	It("rejects a negative min_prefix_match", func() {
		err := prefixcache.ValidateThresholds("", 0, 0, -0.1)
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
	})

	It("rejects a negative balance_abs_threshold", func() {
		err := prefixcache.ValidateThresholds("", -1, 0, 0)
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("balance_abs_threshold"))
	})

	It("rejects balance_rel_threshold between 0 and 1 exclusive", func() {
		err := prefixcache.ValidateThresholds("", 0, 0.5, 0)
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("balance_rel_threshold"))
	})
})
18
core/services/nodes/prefixcache/export_test.go
Normal file
@@ -0,0 +1,18 @@
package prefixcache

// LenForTest exposes the internal per-model slice length so black-box tests can
// assert that Record bounds its backing slice. Test-only.
func (p *Pressure) LenForTest(model string) int {
	p.mu.Lock()
	defer p.mu.Unlock()
	return len(p.events[model])
}

// TreeCountForTest exposes the number of per-model radix trees the Index
// currently retains, so black-box tests can assert that Invalidate does not
// intern empty trees for models that never used the prefix cache. Test-only.
func (ix *Index) TreeCountForTest() int {
	ix.mu.RLock()
	defer ix.mu.RUnlock()
	return len(ix.trees)
}
57
core/services/nodes/prefixcache/extractor.go
Normal file
@@ -0,0 +1,57 @@
package prefixcache

import (
	"encoding/binary"

	"github.com/cespare/xxhash/v2"
)

// ExtractChain renders prompt into a cumulative chain of prefix hashes:
// h[0]=H(salt,block0), h[i]=H(h[i-1],block_i). Blocks are fixed
// cfg.WindowBytes-byte windows over the prompt bytes, chunked from absolute
// offset 0 with fixed boundaries [0,W), [W,2W), ... and the chain is capped to
// the FIRST cfg.MaxDepth blocks (the head).
//
// Head-first chunking is what makes this a true prefix-chain. The reusable
// KV/prefix cache is always at the HEAD of the prompt: the system prompt and
// early turns are stable, new content is appended at the end, and the KV cache
// is valid up to the first differing token scanning from the start. Because the
// boundaries are anchored at offset 0 (never length-dependent), a prompt P and
// any extension P+suffix share their entire leading overlap, so turn N and turn
// N+1 match for longest-prefix routing. Prefixes deeper than
// MaxDepth*WindowBytes bytes are treated as equal (two prompts agreeing on the
// first MaxDepth head blocks yield identical chains): an accepted routing-hint
// limitation, since the cap bounds the chain length for very long prompts.
//
// xxhash is used (not hash/maphash) because the hash MUST be identical across
// frontend processes: peers exchange these hashes over NATS, and maphash uses a
// per-process random seed that would make peers disagree.
func ExtractChain(model, prompt string, cfg Config) []uint64 {
	if prompt == "" {
		return nil
	}
	data := []byte(prompt)
	nBlocks := (len(data) + cfg.WindowBytes - 1) / cfg.WindowBytes
	depth := min(nBlocks, cfg.MaxDepth)
	salt := xxhash.Sum64String(model)
	// One Digest reused across blocks: Reset() restores the seed-0 initial
	// state, so Reset()+Write produces the byte-identical value to a fresh
	// New()+Write. xxhash seed 0 is stateless, so output is unchanged while we
	// avoid allocating a Digest per block. The output determinism across
	// processes (peers exchange these hashes over NATS) is preserved.
	h := xxhash.New()
	chain := make([]uint64, 0, depth)
	prev := salt
	var pb [8]byte
	for i := range depth {
		off := i * cfg.WindowBytes
		end := min(off+cfg.WindowBytes, len(data))
		h.Reset()
		binary.LittleEndian.PutUint64(pb[:], prev)
		_, _ = h.Write(pb[:])
		_, _ = h.Write(data[off:end])
		prev = h.Sum64()
		chain = append(chain, prev)
	}
	return chain
}
|
||||
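The head-anchored chaining described in the `ExtractChain` comment can be tried without the xxhash dependency. The sketch below is an approximation and not the package's code: it swaps in FNV-1a from the standard library (also deterministic across processes, but producing different hash values than xxhash), and `chain` and `sharedDepth` are hypothetical names.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// chain builds h[0]=H(salt,block0), h[i]=H(h[i-1],block_i) over fixed
// window-byte blocks anchored at offset 0, capped to the first maxDepth blocks.
// FNV-1a stands in for xxhash here; that substitution is an assumption.
func chain(model, prompt string, window, maxDepth int) []uint64 {
	if prompt == "" || window <= 0 || maxDepth <= 0 {
		return nil
	}
	data := []byte(prompt)
	n := (len(data) + window - 1) / window
	if n > maxDepth {
		n = maxDepth
	}
	salt := fnv.New64a()
	_, _ = salt.Write([]byte(model))
	prev := salt.Sum64()

	out := make([]uint64, 0, n)
	var pb [8]byte
	for i := 0; i < n; i++ {
		off := i * window
		end := off + window
		if end > len(data) {
			end = len(data)
		}
		h := fnv.New64a()
		binary.LittleEndian.PutUint64(pb[:], prev)
		_, _ = h.Write(pb[:])
		_, _ = h.Write(data[off:end])
		prev = h.Sum64()
		out = append(out, prev)
	}
	return out
}

// sharedDepth counts how many leading chain entries two chains have in common.
func sharedDepth(a, b []uint64) int {
	n := 0
	for n < len(a) && n < len(b) && a[n] == b[n] {
		n++
	}
	return n
}

func main() {
	turnN := chain("m", "system rules. user: hi. ", 8, 16)
	turnN1 := chain("m", "system rules. user: hi. assistant: hello. user: more", 8, 16)
	fmt.Println(len(turnN), len(turnN1), sharedDepth(turnN, turnN1))
}
```

One consequence worth noting: only full windows are stable under extension. If turn N ends in a partial block, that last hash changes once the prompt grows, so the shared depth is the number of complete leading windows the two prompts have in common.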
75
core/services/nodes/prefixcache/extractor_test.go
Normal file
@@ -0,0 +1,75 @@
package prefixcache_test

import (
	"strings"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"

	"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
)

var _ = Describe("Extractor", func() {
	cfg := prefixcache.DefaultConfig()

	It("produces a deterministic chain for the same prompt and model", func() {
		a := prefixcache.ExtractChain("modelX", "hello world", cfg)
		b := prefixcache.ExtractChain("modelX", "hello world", cfg)
		Expect(a).To(Equal(b))
		Expect(len(a)).To(BeNumerically(">", 0))
	})

	It("shares the head but diverges on a volatile tail", func() {
		base := strings.Repeat("system rules ", 100) // > one window
		x := prefixcache.ExtractChain("m", base+"Current time 12:00:00", cfg)
		y := prefixcache.ExtractChain("m", base+"Current time 12:00:01", cfg)
		// leading hashes (the stable head) are identical
		Expect(x[0]).To(Equal(y[0]))
		// the final (tail) hash differs
		Expect(x[len(x)-1]).NotTo(Equal(y[len(y)-1]))
	})

	It("salts by model so identical text yields different chains per model", func() {
		Expect(prefixcache.ExtractChain("m1", "abc", cfg)[0]).
			NotTo(Equal(prefixcache.ExtractChain("m2", "abc", cfg)[0]))
	})

	It("caps depth", func() {
		small := cfg
		small.WindowBytes = 1
		small.MaxDepth = 4
		chain := prefixcache.ExtractChain("m", "abcdefghij", small)
		Expect(len(chain)).To(Equal(4))
	})

	It("returns nil for empty prompt", func() {
		Expect(prefixcache.ExtractChain("m", "", cfg)).To(BeNil())
	})

	It("stays stable across turns once the prompt grows past the depth cap", func() {
		small := cfg
		small.WindowBytes = 4
		small.MaxDepth = 3 // 12-byte head budget

		// base is longer than MaxDepth*WindowBytes so the chain is capped to
		// the first 3 head blocks.
		base := "system-rules-stable-prefix-that-exceeds-the-budget"
		Expect(len(base)).To(BeNumerically(">", small.WindowBytes*small.MaxDepth))

		turnN := prefixcache.ExtractChain("m", base, small)
		turnN1 := prefixcache.ExtractChain("m", base+"more text appended", small)
		// Both capped to the same first MaxDepth head blocks -> identical chains.
		Expect(turnN).To(HaveLen(small.MaxDepth))
		Expect(turnN1).To(HaveLen(small.MaxDepth))
		Expect(turnN1).To(Equal(turnN))

		// A prompt diverging WITHIN the budget shares the leading hashes up to
		// the divergence block and differs after. "system-r" matches base for
		// the first two 4-byte blocks ("syst","em-r"), then block 2 differs.
		divergent := prefixcache.ExtractChain("m", "system-rDIFFERENT-tail", small)
		Expect(divergent).To(HaveLen(small.MaxDepth))
		Expect(divergent[0]).To(Equal(turnN[0]))
		Expect(divergent[1]).To(Equal(turnN[1]))
		Expect(divergent[2]).NotTo(Equal(turnN[2]))
	})
})
129
core/services/nodes/prefixcache/index.go
Normal file
@@ -0,0 +1,129 @@
package prefixcache

import (
	"sort"
	"sync"
	"time"

	"github.com/mudler/LocalAI/pkg/radixtree"
)

// Index is the guessed (routing-history) Provider backed by per-model radix
// trees keyed by ReplicaKey. Affinity is per replica, so the same prefix served
// by two replicas of one node resolves back to the exact replica that served it.
// Safe for concurrent use.
type Index struct {
	cfg   Config
	mu    sync.RWMutex
	trees map[string]*radixtree.Tree[ReplicaKey]
}

func NewIndex(cfg Config) *Index {
	return &Index{cfg: cfg, trees: map[string]*radixtree.Tree[ReplicaKey]{}}
}

// existingTree returns the tree for model without creating one. The bool
// reports whether a tree already existed.
func (ix *Index) existingTree(model string) (*radixtree.Tree[ReplicaKey], bool) {
	ix.mu.RLock()
	defer ix.mu.RUnlock()
	t, ok := ix.trees[model]
	return t, ok
}

func (ix *Index) tree(model string) *radixtree.Tree[ReplicaKey] {
	ix.mu.RLock()
	t, ok := ix.trees[model]
	ix.mu.RUnlock()
	if ok {
		return t
	}
	ix.mu.Lock()
	defer ix.mu.Unlock()
	if t, ok = ix.trees[model]; ok {
		return t
	}
	t = radixtree.New[ReplicaKey](radixtree.Options{TTL: ix.cfg.TTL, HalfLife: ix.cfg.HalfLife})
	ix.trees[model] = t
	return t
}

func (ix *Index) Decide(model string, chain []uint64, candidates []ReplicaKey, now time.Time) PrefixDecision {
	t := ix.tree(model)
	var d PrefixDecision
	// WeightsFor computes every candidate weight in a single tree walk and
	// returns a map pre-populated with an entry (weight 0 by default) for every
	// requested candidate. Candidacy is therefore exactly "is a key in weights",
	// so we derive the hot-match membership check from it rather than building a
	// second set.
	weights := t.WeightsFor(candidates, now)
	if len(chain) > 0 {
		if key, depth, ok := t.LongestMatch(chain, now); ok {
			// LongestMatch searches the whole tree, so the deepest match can be
			// a replica that is offline / unloaded / not in the candidate set.
			// Treating that as a hot match produces a false forced-disturb signal
			// upstream (the warm replica was absent, not load-saturated). Only honor
			// the match when the matched replica is an actual candidate; otherwise
			// fall back to cold placement.
			if _, ok := weights[key]; ok {
				d.Hot = key
				d.HasHot = true
				d.MatchRatio = float64(depth) / float64(len(chain))
			}
		}
	}
	// Cold order: candidates ascending by cacheWeight, tie-break by NodeID then
	// Replica. The sort comparator reads precomputed weights instead of triggering
	// an O(tree size) Weight call per comparison. With at most one candidate the
	// input order is already the cold order, so skip the sort.
	order := make([]ReplicaKey, len(candidates))
	copy(order, candidates)
	if len(order) > 1 {
		sort.Slice(order, func(i, j int) bool {
			if weights[order[i]] != weights[order[j]] {
				return weights[order[i]] < weights[order[j]]
			}
			return order[i].less(order[j])
		})
	}
	d.ColdOrder = order
	return d
}

func (ix *Index) Observe(model string, chain []uint64, key ReplicaKey, now time.Time) bool {
	if len(chain) == 0 || key.NodeID == "" {
		return false
	}
	t := ix.tree(model)
	// New/extended iff the current deepest match for this exact chain is not
	// already this replica at full depth.
	cur, depth, ok := t.LongestMatch(chain, now)
	t.Insert(chain, key, now)
	return !ok || depth < len(chain) || cur != key
}

// Invalidate drops all entries for ONE replica. It never interns an empty tree
// (a registry chokepoint fires Invalidate for every replica removal of every
// model, including round-robin models that never used the prefix cache, so
// lazily creating a tree here would grow the trees map unboundedly).
func (ix *Index) Invalidate(model string, key ReplicaKey) {
	if t, ok := ix.existingTree(model); ok {
		t.RemoveFunc(func(k ReplicaKey) bool { return k == key })
	}
}

// InvalidateNode drops entries for ALL replicas of nodeID. Like Invalidate it
// does not intern an empty tree.
func (ix *Index) InvalidateNode(model, nodeID string) {
	if t, ok := ix.existingTree(model); ok {
		t.RemoveFunc(func(k ReplicaKey) bool { return k.NodeID == nodeID })
	}
}

func (ix *Index) Evict(now time.Time) {
	ix.mu.RLock()
	defer ix.mu.RUnlock()
	for _, t := range ix.trees {
		t.Evict(now)
	}
}
|
||||
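`Decide` combines two independent rules: a hot match counts only if the matched replica is among the candidates, and the cold order sorts candidates by ascending cache weight with a deterministic tie-break. The sketch below isolates those two rules using a plain map in place of the radix tree. `replica`, `hotIfCandidate`, and `coldOrder` are hypothetical names, and the weights are supplied by hand rather than computed.

```go
package main

import (
	"fmt"
	"sort"
)

// replica is a hypothetical stand-in for ReplicaKey.
type replica struct {
	NodeID string
	Index  int
}

// less orders replicas by NodeID, then by replica index.
func (a replica) less(b replica) bool {
	if a.NodeID != b.NodeID {
		return a.NodeID < b.NodeID
	}
	return a.Index < b.Index
}

// hotIfCandidate honours a longest-prefix match only when the matched replica
// is in the candidate list; otherwise the caller falls back to cold placement.
func hotIfCandidate(match replica, candidates []replica) (replica, bool) {
	for _, c := range candidates {
		if c == match {
			return match, true
		}
	}
	return replica{}, false
}

// coldOrder returns a copy of candidates sorted by ascending weight, with ties
// broken by less. Candidates missing from weights count as weight 0.
func coldOrder(candidates []replica, weights map[replica]float64) []replica {
	order := append([]replica(nil), candidates...)
	sort.Slice(order, func(i, j int) bool {
		wi, wj := weights[order[i]], weights[order[j]]
		if wi != wj {
			return wi < wj
		}
		return order[i].less(order[j])
	})
	return order
}

func main() {
	a, b, c := replica{"A", 0}, replica{"B", 0}, replica{"C", 0}
	weights := map[replica]float64{a: 2, b: 1}

	_, hot := hotIfCandidate(a, []replica{b, c}) // A is warm but not a candidate
	fmt.Println(hot, coldOrder([]replica{a, b, c}, weights))
}
```

Sending cold requests to the least-weighted replica first spreads new prefixes away from replicas that already hold a lot of cached state, which is the apparent intent of the ascending sort in the source.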
169
core/services/nodes/prefixcache/index_test.go
Normal file
@@ -0,0 +1,169 @@
package prefixcache_test

import (
	"sync"
	"time"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"

	"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
)

var t0 = time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC)

var _ = Describe("Index provider", func() {
	cfg := prefixcache.DefaultConfig()

	It("returns no hot match before anything is observed", func() {
		idx := prefixcache.NewIndex(cfg)
		d := idx.Decide("m", []uint64{1, 2, 3}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
		Expect(d.HasHot).To(BeFalse())
		// cold order present (all weights zero -> deterministic by node id)
		Expect(d.ColdOrder).To(ConsistOf(rk("A", 0), rk("B", 0)))
	})

	It("returns the observed replica as hot match with the right ratio", func() {
		idx := prefixcache.NewIndex(cfg)
		idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
		d := idx.Decide("m", []uint64{1, 2, 3, 4, 5}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
		Expect(d.HasHot).To(BeTrue())
		Expect(d.Hot).To(Equal(rk("A", 0)))
		Expect(d.MatchRatio).To(BeNumerically("~", 4.0/5.0, 0.001))
	})

	It("orders cold candidates by ascending cacheWeight", func() {
		idx := prefixcache.NewIndex(cfg)
		idx.Observe("m", []uint64{1}, rk("A", 0), t0)
		idx.Observe("m", []uint64{2}, rk("A", 0), t0) // A weight 2
		idx.Observe("m", []uint64{3}, rk("B", 0), t0) // B weight 1
		d := idx.Decide("m", []uint64{9}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
		Expect(d.HasHot).To(BeFalse())
		Expect(d.ColdOrder).To(Equal([]prefixcache.ReplicaKey{rk("B", 0), rk("A", 0)})) // B lower weight first
	})

	It("drops the hot match when the matched replica is not in the candidate set", func() {
		idx := prefixcache.NewIndex(cfg)
		idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
		// A holds the longest match, but A is not a candidate (offline /
		// unloaded). The matched replica must be ignored so cold placement runs
		// and no false forced-disturb fires upstream.
		d := idx.Decide("m", []uint64{1, 2, 3, 4, 5}, []prefixcache.ReplicaKey{rk("B", 0), rk("C", 0)}, t0)
		Expect(d.HasHot).To(BeFalse())
		Expect(d.MatchRatio).To(Equal(0.0))
		Expect(d.ColdOrder).To(ConsistOf(rk("B", 0), rk("C", 0)))
	})

	It("returns a hot match for a query that only shares a prefix with an observed chain", func() {
		// The real-world case: a replica served chain [1,2,3,4]; a new request
		// shares the leading block [1,2,3] but diverges at the tail ([1,2,3,9]).
		// With prefix matching (value recorded at every node) Decide must still
		// route to the warm replica, matching at the depth of the shared prefix.
		idx := prefixcache.NewIndex(cfg)
		idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
		d := idx.Decide("m", []uint64{1, 2, 3, 9}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
		Expect(d.HasHot).To(BeTrue())
		Expect(d.Hot).To(Equal(rk("A", 0)))
		Expect(d.MatchRatio).To(BeNumerically("~", 3.0/4.0, 0.001)) // shared [1,2,3] of len-4 query
	})

	It("keeps the hot match when the matched replica is a candidate", func() {
		idx := prefixcache.NewIndex(cfg)
		idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
		d := idx.Decide("m", []uint64{1, 2, 3, 4, 5}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
		Expect(d.HasHot).To(BeTrue())
		Expect(d.Hot).To(Equal(rk("A", 0)))
		Expect(d.MatchRatio).To(BeNumerically("~", 4.0/5.0, 0.001))
	})

	It("tracks affinity per replica, not per node", func() {
		// Two replicas on the SAME node, each serving a different chain that share
		// a leading block. The hot match for a query extending chain1 must be the
		// EXACT replica that served chain1, not the other replica on the same node.
		idx := prefixcache.NewIndex(cfg)
		idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0) // replica 0 owns [1,2,3,4]
		idx.Observe("m", []uint64{1, 2, 5, 6}, rk("A", 1), t0) // replica 1 owns [1,2,5,6]
		cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)}
		d := idx.Decide("m", []uint64{1, 2, 3, 4, 7}, cands, t0)
		Expect(d.HasHot).To(BeTrue())
		Expect(d.Hot).To(Equal(rk("A", 0))) // distinct replicas on one node have distinct affinity
		d2 := idx.Decide("m", []uint64{1, 2, 5, 6, 7}, cands, t0)
		Expect(d2.HasHot).To(BeTrue())
		Expect(d2.Hot).To(Equal(rk("A", 1)))
	})

	It("Invalidate drops one replica while InvalidateNode drops all replicas of a node", func() {
		idx := prefixcache.NewIndex(cfg)
		idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
		idx.Observe("m", []uint64{5, 6, 7, 8}, rk("A", 1), t0)
		cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)}

		// Invalidate replica 0 only: replica 1 survives.
		idx.Invalidate("m", rk("A", 0))
		Expect(idx.Decide("m", []uint64{1, 2, 3, 4}, cands, t0).HasHot).To(BeFalse())
		d1 := idx.Decide("m", []uint64{5, 6, 7, 8}, cands, t0)
		Expect(d1.HasHot).To(BeTrue())
		Expect(d1.Hot).To(Equal(rk("A", 1)))

		// Re-observe both, then InvalidateNode drops BOTH replicas.
		idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
		idx.InvalidateNode("m", "A")
		Expect(idx.Decide("m", []uint64{1, 2, 3, 4}, cands, t0).HasHot).To(BeFalse())
		Expect(idx.Decide("m", []uint64{5, 6, 7, 8}, cands, t0).HasHot).To(BeFalse())
	})

	It("forgets a replica on Invalidate", func() {
		idx := prefixcache.NewIndex(cfg)
		idx.Observe("m", []uint64{1, 2}, rk("A", 0), t0)
		idx.Invalidate("m", rk("A", 0))
		d := idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 0)}, t0)
		Expect(d.HasHot).To(BeFalse())
	})

	It("does not intern an empty tree when invalidating a model that has none", func() {
		idx := prefixcache.NewIndex(cfg)
		Expect(idx.TreeCountForTest()).To(Equal(0))
		// Round-robin model that never used the prefix cache: invalidating a
		// replica removal must be a no-op and must not retain a tree.
		idx.Invalidate("never-cached", rk("A", 0))
		idx.Invalidate("never-cached", rk("B", 0))
		idx.InvalidateNode("other", "C")
		Expect(idx.TreeCountForTest()).To(Equal(0))
		// And a Decide afterwards still works without a hot match.
		d := idx.Decide("never-cached", []uint64{1}, []prefixcache.ReplicaKey{rk("A", 0)}, t0)
		Expect(d.HasHot).To(BeFalse())
	})

	It("is safe for concurrent Decide/Observe/Invalidate (run with -race)", func() {
		idx := prefixcache.NewIndex(cfg)
		models := []string{"m1", "m2"}
		nodes := []string{"A", "B", "C"}
		cands := []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0), rk("C", 0)}
		var wg sync.WaitGroup
		for g := range 8 {
			wg.Add(1)
			go func(g int) {
				defer GinkgoRecover()
				defer wg.Done()
				model := models[g%len(models)]
				node := nodes[g%len(nodes)]
				now := t0
				for i := range 200 {
					chain := []uint64{uint64(g), uint64(i % 7), uint64(i)}
					switch i % 4 {
					case 0:
						idx.Observe(model, chain, prefixcache.ReplicaKey{NodeID: node, Replica: i % 2}, now)
					case 1:
						idx.Decide(model, chain, cands, now)
					case 2:
						idx.Invalidate(model, prefixcache.ReplicaKey{NodeID: node, Replica: i % 2})
					case 3:
						idx.InvalidateNode(model, node)
					}
					now = now.Add(time.Millisecond)
				}
			}(g)
		}
		wg.Wait()
	})
})
47
core/services/nodes/prefixcache/policy.go
Normal file
@@ -0,0 +1,47 @@
// Package prefixcache implements prefix-cache-aware routing for distributed
// mode: it turns a request prompt into a chain of prefix hashes, tracks which
// node served which prefix in an in-memory radix tree, and provides a
// load-guarded preferred-node decision. See docs/content/features/distributed-mode.md.
package prefixcache

// RoutePolicy selects the routing strategy for a model. The zero value is
// RoutePolicyDefault, meaning "inherit the cluster-wide default".
type RoutePolicy int

const (
	RoutePolicyDefault     RoutePolicy = iota // inherit global default
	RoutePolicyRoundRobin                     // today's behavior (the floor)
	RoutePolicyPrefixCache                    // cache-aware routing
)

// ParsePolicy maps a config string to a RoutePolicy. Unknown or empty strings
// map to RoutePolicyDefault.
func ParsePolicy(s string) RoutePolicy {
	switch s {
	case "round_robin":
		return RoutePolicyRoundRobin
	case "prefix_cache":
		return RoutePolicyPrefixCache
	default:
		return RoutePolicyDefault
	}
}

// String returns the config-string form of p. RoutePolicyDefault and any
// unrecognized value render as "default".
func (p RoutePolicy) String() string {
	switch p {
	case RoutePolicyRoundRobin:
		return "round_robin"
	case RoutePolicyPrefixCache:
		return "prefix_cache"
	default:
		return "default"
	}
}

// Resolve returns p unless it is Default, in which case it returns global.
func (p RoutePolicy) Resolve(global RoutePolicy) RoutePolicy {
	if p == RoutePolicyDefault {
		return global
	}
	return p
}
|
||||
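To show how ParsePolicy and Resolve compose, the sketch below resolves a per-model setting against a cluster-wide one. It re-declares a trimmed copy of RoutePolicy so it runs on its own; `effectivePolicy` is a hypothetical helper and the two config strings are assumed inputs, not fields confirmed by this diff.

```go
package main

import "fmt"

// RoutePolicy is a trimmed copy of the type in policy.go, repeated here only
// so the sketch is self-contained.
type RoutePolicy int

const (
	RoutePolicyDefault RoutePolicy = iota
	RoutePolicyRoundRobin
	RoutePolicyPrefixCache
)

// ParsePolicy maps a config string to a RoutePolicy; unknown or empty strings
// map to RoutePolicyDefault.
func ParsePolicy(s string) RoutePolicy {
	switch s {
	case "round_robin":
		return RoutePolicyRoundRobin
	case "prefix_cache":
		return RoutePolicyPrefixCache
	default:
		return RoutePolicyDefault
	}
}

// Resolve returns p unless it is Default, in which case it returns global.
func (p RoutePolicy) Resolve(global RoutePolicy) RoutePolicy {
	if p == RoutePolicyDefault {
		return global
	}
	return p
}

// effectivePolicy is a hypothetical helper: parse both settings, then let the
// per-model value win unless it is unset or unrecognized.
func effectivePolicy(perModel, global string) RoutePolicy {
	return ParsePolicy(perModel).Resolve(ParsePolicy(global))
}

func main() {
	fmt.Println(effectivePolicy("", "prefix_cache") == RoutePolicyPrefixCache)          // model inherits the global policy
	fmt.Println(effectivePolicy("round_robin", "prefix_cache") == RoutePolicyRoundRobin) // explicit per-model value wins
	fmt.Println(effectivePolicy("bogus", "") == RoutePolicyDefault)                      // nothing set on either level
}
```

Note that Resolve can still return RoutePolicyDefault when neither level names a policy; how the router treats that case is not visible in this part of the diff.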
29
core/services/nodes/prefixcache/policy_test.go
Normal file
@@ -0,0 +1,29 @@
package prefixcache_test

import (
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"

	"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
)

var _ = Describe("RoutePolicy", func() {
	It("parses known values and defaults unknown to Default (zero)", func() {
		Expect(prefixcache.ParsePolicy("round_robin")).To(Equal(prefixcache.RoutePolicyRoundRobin))
		Expect(prefixcache.ParsePolicy("prefix_cache")).To(Equal(prefixcache.RoutePolicyPrefixCache))
		Expect(prefixcache.ParsePolicy("")).To(Equal(prefixcache.RoutePolicyDefault))
		Expect(prefixcache.ParsePolicy("bogus")).To(Equal(prefixcache.RoutePolicyDefault))
	})

	It("stringifies", func() {
		Expect(prefixcache.RoutePolicyPrefixCache.String()).To(Equal("prefix_cache"))
		Expect(prefixcache.RoutePolicyRoundRobin.String()).To(Equal("round_robin"))
	})

	It("resolves per-model against a global default", func() {
		Expect(prefixcache.RoutePolicyDefault.Resolve(prefixcache.RoutePolicyPrefixCache)).
			To(Equal(prefixcache.RoutePolicyPrefixCache))
		Expect(prefixcache.RoutePolicyRoundRobin.Resolve(prefixcache.RoutePolicyPrefixCache)).
			To(Equal(prefixcache.RoutePolicyRoundRobin))
	})
})
13
core/services/nodes/prefixcache/prefixcache_suite_test.go
Normal file
@@ -0,0 +1,13 @@
package prefixcache_test

import (
	"testing"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

func TestPrefixCache(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "PrefixCache Suite")
}
82
core/services/nodes/prefixcache/pressure.go
Normal file
@@ -0,0 +1,82 @@
package prefixcache

import (
	"sync"
	"time"
)

// Pressure is a concurrency-safe rolling per-model counter of forced-disturb
// events. A forced-disturb is recorded by the router when a usable hot prefix
// match existed but the load guard forced the request off the warm node (see
// SmartRouter.buildPreference). The reconciler reads Count to decide whether
// the cache-warm replica is saturated enough to warrant a scale-up.
//
// Entries older than the window are dropped on both Record and Count, so the
// slice never grows unbounded - even for a model that takes records but is
// never Counted (e.g. one with zero loaded replicas the reconciler skips). An
// idle model's history also decays to zero on the next read.
type Pressure struct {
	mu     sync.Mutex
	window time.Duration
	events map[string][]time.Time
}

// NewPressure creates a Pressure counter that remembers events for the given
// rolling window.
func NewPressure(window time.Duration) *Pressure {
	return &Pressure{
		window: window,
		events: make(map[string][]time.Time),
	}
}

// pruneLocked drops entries older than cutoff, compacting in place. The cutoff
// boundary itself is inclusive so an event exactly window-old still counts.
// Callers must hold p.mu.
func pruneLocked(ts []time.Time, cutoff time.Time) []time.Time {
	kept := ts[:0]
	for _, t := range ts {
		if !t.Before(cutoff) {
			kept = append(kept, t)
		}
	}
	return kept
}

// Record appends a forced-disturb timestamp for the model and prunes entries
// older than the window, so the per-model slice stays bounded regardless of how
// often Count runs.
func (p *Pressure) Record(model string, now time.Time) {
	p.mu.Lock()
	defer p.mu.Unlock()
	cutoff := now.Add(-p.window)
	kept := append(pruneLocked(p.events[model], cutoff), now)
	p.events[model] = kept
}

// Count returns the number of records for the model within [now-window, now],
// dropping any entries older than the window so the backing slice stays bounded.
func (p *Pressure) Count(model string, now time.Time) int {
	p.mu.Lock()
	defer p.mu.Unlock()
	ts := p.events[model]
	if len(ts) == 0 {
		return 0
	}
	kept := pruneLocked(ts, now.Add(-p.window))
	if len(kept) == 0 {
		delete(p.events, model)
		return 0
	}
	p.events[model] = kept
	return len(kept)
}

// Reset clears all recorded events for model. Call after acting on the signal
// (a pressure-triggered scale-up) so a single burst does not trigger repeated
// scale-ups across consecutive ticks.
func (p *Pressure) Reset(model string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	delete(p.events, model)
}
|
||||
98
core/services/nodes/prefixcache/pressure_test.go
Normal file
@@ -0,0 +1,98 @@
package prefixcache_test

import (
	"time"

	"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("Pressure counter", func() {
	t0 := time.Unix(1700000000, 0)

	It("counts events within the window and forgets older ones", func() {
		p := prefixcache.NewPressure(time.Minute)
		p.Record("m", t0)
		p.Record("m", t0.Add(30*time.Second))
		Expect(p.Count("m", t0.Add(40*time.Second))).To(Equal(2))
		Expect(p.Count("m", t0.Add(90*time.Second))).To(Equal(1)) // first expired
	})

	It("tracks pressure per model independently", func() {
		p := prefixcache.NewPressure(time.Minute)
		p.Record("a", t0)
		p.Record("a", t0.Add(10*time.Second))
		p.Record("b", t0.Add(20*time.Second))
		Expect(p.Count("a", t0.Add(30*time.Second))).To(Equal(2))
		Expect(p.Count("b", t0.Add(30*time.Second))).To(Equal(1))
		Expect(p.Count("c", t0.Add(30*time.Second))).To(Equal(0))
	})

	It("returns zero for a model that was never recorded", func() {
		p := prefixcache.NewPressure(time.Minute)
		Expect(p.Count("never", t0)).To(Equal(0))
	})

	It("includes the boundary timestamp at exactly now-window", func() {
		p := prefixcache.NewPressure(time.Minute)
		p.Record("m", t0)
		// now-window == t0 exactly, so the entry is still within [now-window, now].
		Expect(p.Count("m", t0.Add(time.Minute))).To(Equal(1))
		// one nanosecond past the window drops it.
		Expect(p.Count("m", t0.Add(time.Minute+1))).To(Equal(0))
	})

	It("bounds the backing slice in Record without any Count calls", func() {
		p := prefixcache.NewPressure(time.Minute)
		// Record many timestamps, advancing now well past the window between
		// each, and never call Count. Each Record must prune the entries that
		// have fallen out of [now-window, now] so the slice cannot accumulate.
		var last time.Time
		for i := range 1000 {
			last = t0.Add(time.Duration(i) * 10 * time.Second)
			p.Record("m", last)
		}
		// With a 1m window and 10s spacing, at most ~7 records (the boundary is
		// inclusive) can be within [last-window, last]. The slice must stay that
		// bounded, never growing toward 1000.
		Expect(p.LenForTest("m")).To(BeNumerically("<=", 7))
		// And the in-window count must reflect only those bounded entries.
		Expect(p.Count("m", last)).To(Equal(p.LenForTest("m")))
	})

	It("clears all recorded events on Reset", func() {
		p := prefixcache.NewPressure(time.Minute)
		p.Record("m", t0)
		p.Record("m", t0.Add(10*time.Second))
		p.Record("m", t0.Add(20*time.Second))
		Expect(p.Count("m", t0.Add(30*time.Second))).To(BeNumerically(">", 0))

		p.Reset("m")

		// After Reset the model has no in-window events even though the
		// timestamps would otherwise still be within [now-window, now].
		Expect(p.Count("m", t0.Add(30*time.Second))).To(Equal(0))
		Expect(p.LenForTest("m")).To(Equal(0))
	})

	It("Reset only clears the named model", func() {
		p := prefixcache.NewPressure(time.Minute)
		p.Record("a", t0)
		p.Record("b", t0)
		p.Reset("a")
		Expect(p.Count("a", t0.Add(time.Second))).To(Equal(0))
		Expect(p.Count("b", t0.Add(time.Second))).To(Equal(1))
	})

	It("does not accumulate repeated out-of-window Records", func() {
		p := prefixcache.NewPressure(time.Minute)
		// Each record is more than a window apart, so every Record prunes the
		// previous one. The slice should never hold more than a single entry.
		for i := range 100 {
			p.Record("m", t0.Add(time.Duration(i)*2*time.Minute))
		}
		Expect(p.LenForTest("m")).To(Equal(1))
		Expect(p.Count("m", t0.Add(198*time.Minute))).To(Equal(1))
	})
})
24
core/services/nodes/prefixcache/provider.go
Normal file
@@ -0,0 +1,24 @@
package prefixcache

import "time"

// Provider is the seam between SmartRouter and the prefix-cache implementation.
// The radix-tree (guessed) implementation is the only one today; a future
// KV-event (reported) implementation can satisfy the same interface without
// changing SmartRouter (epic #10063 / #10064). Affinity is tracked per replica:
// each loaded replica is a separate process with its own KV cache.
type Provider interface {
	// Decide computes the prefix decision for a request given the candidate
	// replicas (the selector-filtered set). It does not consult load - load
	// filtering happens in the DB transaction.
	Decide(model string, chain []uint64, candidates []ReplicaKey, now time.Time) PrefixDecision
	// Observe records that the replica served the request whose prefix is chain.
	// Returns true when the assignment was new or extended (caller broadcasts).
	Observe(model string, chain []uint64, key ReplicaKey, now time.Time) bool
	// Invalidate drops all entries for ONE replica.
	Invalidate(model string, key ReplicaKey)
	// InvalidateNode drops entries for ALL replicas of a node.
	InvalidateNode(model, nodeID string)
	// Evict sweeps expired entries for all models.
	Evict(now time.Time)
}
93
core/services/nodes/prefixcache/select.go
Normal file
@@ -0,0 +1,93 @@
package prefixcache

// ReplicaKey identifies a specific loaded replica (a backend process). Affinity
// is tracked per replica, not per node, because each replica is a separate
// process with its own KV cache.
type ReplicaKey struct {
	NodeID  string
	Replica int
}

// less reports whether a sorts before b, ordering by NodeID then Replica. It is
// the deterministic tiebreak used wherever two replicas are otherwise equal.
func (a ReplicaKey) less(b ReplicaKey) bool {
	if a.NodeID != b.NodeID {
		return a.NodeID < b.NodeID
	}
	return a.Replica < b.Replica
}

// Candidate is a load-eligible-or-not replica view from the registry. There is
// one Candidate per LOADED replica: the router no longer collapses replicas per
// node, so two replicas of the same model on the same node are two candidates.
type Candidate struct {
	Key      ReplicaKey
	InFlight int
}

// PrefixDecision is computed from the in-memory tree before the DB transaction.
// Hot is the replica holding the longest prefix match and HasHot reports whether
// there is one (a ReplicaKey has no "" sentinel). MatchRatio is matched/total
// for that match. ColdOrder lists candidate replicas ascending by cacheWeight
// (lowest = least valuable warm cache = best cold target).
type PrefixDecision struct {
	Hot        ReplicaKey
	HasHot     bool
	MatchRatio float64
	ColdOrder  []ReplicaKey
}

// Select implements filter-then-score per replica: keep candidates within the
// load guard (relative to the min in-flight across ALL candidate replicas), then
// prefer the exact hot-match replica, else the lowest-cacheWeight eligible
// replica via ColdOrder, else a deterministic eligible fallback (least in-flight,
// tiebreak by NodeID then Replica). Returns (ReplicaKey{}, false) when nothing is
// selectable.
func Select(cands []Candidate, d PrefixDecision, cfg Config) (ReplicaKey, bool) {
	if len(cands) == 0 {
		return ReplicaKey{}, false
	}
	minIF := cands[0].InFlight
	for _, c := range cands {
		minIF = min(minIF, c.InFlight)
	}
	eligible := map[ReplicaKey]bool{}
	for _, c := range cands {
		withinAbs := c.InFlight <= minIF+cfg.BalanceAbsThreshold
		// +1 softens the relative guard when minIF==0 so a zero baseline does
		// not require exact-zero in-flight; the absolute guard governs near 0.
		withinRel := float64(c.InFlight) <= float64(minIF)*cfg.BalanceRelThreshold+1
		if withinAbs && withinRel {
			eligible[c.Key] = true
		}
	}
	// Hot match wins if eligible and strong enough.
	if d.HasHot && d.MatchRatio >= cfg.MinPrefixMatch && eligible[d.Hot] {
		return d.Hot, true
	}
	// Cold placement: lowest cacheWeight eligible replica.
	for _, k := range d.ColdOrder {
		if eligible[k] {
			return k, true
		}
	}
	// Deterministic eligible fallback: least in-flight, tiebreak NodeID then
	// Replica. ColdOrder may not cover the eligible set (the caller may pass an
	// empty ColdOrder), so this guarantees Select still returns the best eligible
	// replica rather than failing.
	var best Candidate
	found := false
	for _, c := range cands {
		if !eligible[c.Key] {
			continue
		}
		if !found || c.InFlight < best.InFlight ||
			(c.InFlight == best.InFlight && c.Key.less(best.Key)) {
			best, found = c, true
		}
	}
	if found {
		return best.Key, true
	}
	return ReplicaKey{}, false
}
|
||||
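The two-part load guard in Select is easier to follow with concrete numbers. The sketch below isolates the eligibility test into a hypothetical `eligible` function and evaluates it with abs=2 and rel=1.5, the values the comment in select_test.go attributes to DefaultConfig; those values are taken from that comment rather than from the config source, so treat them as assumptions.

```go
package main

import "fmt"

// eligible mirrors the guard in Select: a replica passes only if it is within
// both the absolute and the (softened) relative distance of the minimum
// in-flight count across all candidates.
func eligible(inFlight, minIF, abs int, rel float64) bool {
	withinAbs := inFlight <= minIF+abs
	withinRel := float64(inFlight) <= float64(minIF)*rel+1
	return withinAbs && withinRel
}

func main() {
	const abs, rel = 2, 1.5 // assumed defaults, per the select_test.go comment

	cases := []struct{ inFlight, minIF int }{
		{1, 0}, // 1 <= 0+2 and 1 <= 0*1.5+1
		{2, 0}, // 2 <= 0+2 but 2 >  0*1.5+1
		{5, 0}, // 5 >  0+2
		{6, 4}, // 6 <= 4+2 and 6 <= 4*1.5+1
		{7, 4}, // 7 >  4+2
	}
	for _, c := range cases {
		fmt.Printf("inFlight=%d minIF=%d eligible=%v\n",
			c.inFlight, c.minIF, eligible(c.inFlight, c.minIF, abs, rel))
	}
}
```

With these assumed values a replica one request above an idle minimum stays eligible, while one at two requests above it is filtered by the relative term; at a minimum of 4 the absolute term is the binding one.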
139
core/services/nodes/prefixcache/select_test.go
Normal file
@@ -0,0 +1,139 @@
package prefixcache_test

import (
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"

	"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
)

func rk(node string, replica int) prefixcache.ReplicaKey {
	return prefixcache.ReplicaKey{NodeID: node, Replica: replica}
}

var _ = Describe("Select (filter-then-score)", func() {
	cfg := prefixcache.DefaultConfig() // abs=2, rel=1.5, minMatch=0.3

	cand := func(node string, replica, inflight int) prefixcache.Candidate {
		return prefixcache.Candidate{Key: rk(node, replica), InFlight: inflight}
	}

	It("returns the hot-match replica when it is load-eligible and match >= min", func() {
		cands := []prefixcache.Candidate{cand("A", 0, 1), cand("B", 0, 0)}
		got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
			Hot: rk("A", 0), HasHot: true, MatchRatio: 0.5,
		}, cfg)
		Expect(ok).To(BeTrue())
		Expect(got).To(Equal(rk("A", 0))) // A in-flight 1 <= min(0)+2 and <= 0*1.5+1
	})

	It("rejects the hot match when it violates the absolute load guard", func() {
		cands := []prefixcache.Candidate{cand("A", 0, 5), cand("B", 0, 0)}
		got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
			Hot: rk("A", 0), HasHot: true, MatchRatio: 0.9,
			ColdOrder: []prefixcache.ReplicaKey{rk("B", 0), rk("A", 0)},
		}, cfg)
		Expect(ok).To(BeTrue())
		Expect(got).To(Equal(rk("B", 0))) // A 5 > min(0)+2, drop to cold placement
	})

	It("ignores a match below min_prefix_match", func() {
		cands := []prefixcache.Candidate{cand("A", 0, 0), cand("B", 0, 0)}
		got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
			Hot: rk("A", 0), HasHot: true, MatchRatio: 0.2, // < 0.3
			ColdOrder: []prefixcache.ReplicaKey{rk("B", 0), rk("A", 0)},
		}, cfg)
		Expect(ok).To(BeTrue())
		Expect(got).To(Equal(rk("B", 0))) // cold placement: lowest cacheWeight eligible
	})

	It("cold-places to lowest-cacheWeight replica within the eligible subset", func() {
		cands := []prefixcache.Candidate{cand("A", 0, 0), cand("B", 0, 0), cand("C", 0, 9)}
		got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
			ColdOrder: []prefixcache.ReplicaKey{rk("C", 0), rk("B", 0), rk("A", 0)},
		}, cfg)
		Expect(ok).To(BeTrue())
		Expect(got).To(Equal(rk("B", 0))) // C filtered out by load; B is next in cold order
	})

	It("returns false when no candidates", func() {
		_, ok := prefixcache.Select(nil, prefixcache.PrefixDecision{}, cfg)
		Expect(ok).To(BeFalse())
	})

	It("falls back to the least-in-flight eligible replica when ColdOrder is empty", func() {
		// Deterministic eligible fallback: ColdOrder does not cover the eligible
		// set, so Select picks the least-in-flight eligible replica, tiebreaking by
		// NodeID then Replica.
		cands := []prefixcache.Candidate{cand("B", 1, 0), cand("B", 0, 0), cand("A", 0, 0)}
		got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{}, cfg)
		Expect(ok).To(BeTrue())
		Expect(got).To(Equal(rk("A", 0))) // all in-flight 0; A < B; within B, replica 0 < 1
	})

	It("selects the sole candidate, which is always load-eligible", func() {
		// An empty eligible set is impossible in practice with the default
		// config: the min-in-flight candidate always passes both guards. So
		// rather than the unreachable (ReplicaKey{}, false) return, assert that
		// a single candidate is selected.
		cands := []prefixcache.Candidate{cand("A", 0, 0)}
		got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{}, cfg)
		Expect(ok).To(BeTrue())
		Expect(got).To(Equal(rk("A", 0)))
	})
})

var _ = Describe("Select replica granularity", func() {
	cfg := prefixcache.DefaultConfig()

	It("distinguishes two replicas of the same node as separate candidates", func() {
		// Two replicas on NodeA: replica 0 is hot but saturated, replica 1 is cool.
		// The round-robin floor must drop to replica 1, NOT collapse them per node.
		cands := []prefixcache.Candidate{
			{Key: rk("A", 0), InFlight: 50},
			{Key: rk("A", 1), InFlight: 0},
		}
		got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
			Hot: rk("A", 0), HasHot: true, MatchRatio: 1.0,
			ColdOrder: []prefixcache.ReplicaKey{rk("A", 1), rk("A", 0)},
		}, cfg)
		Expect(ok).To(BeTrue())
		Expect(got).To(Equal(rk("A", 1)))
	})

	It("pins back to the exact hot replica when it is within slack", func() {
		cands := []prefixcache.Candidate{
			{Key: rk("A", 0), InFlight: 1},
			{Key: rk("A", 1), InFlight: 0},
		}
		got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
			Hot: rk("A", 0), HasHot: true, MatchRatio: 1.0,
			ColdOrder: []prefixcache.ReplicaKey{rk("A", 1), rk("A", 0)},
		}, cfg)
		Expect(ok).To(BeTrue())
		Expect(got).To(Equal(rk("A", 0))) // within slack -> reuse exact replica
	})
})

var _ = Describe("Select round-robin floor invariant", func() {
	It("never pins to a saturated hot replica (round-robin floor)", func() {
		cfg := prefixcache.DefaultConfig()
		cands := []prefixcache.Candidate{{Key: rk("hot", 0), InFlight: 50}, {Key: rk("cool", 0), InFlight: 0}}
		got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
			Hot: rk("hot", 0), HasHot: true, MatchRatio: 1.0,
			ColdOrder: []prefixcache.ReplicaKey{rk("cool", 0), rk("hot", 0)},
		}, cfg)
		Expect(ok).To(BeTrue())
		Expect(got).To(Equal(rk("cool", 0)))
	})

	It("improves reuse when balanced", func() {
		cfg := prefixcache.DefaultConfig()
		cands := []prefixcache.Candidate{{Key: rk("hot", 0), InFlight: 1}, {Key: rk("cool", 0), InFlight: 0}}
		got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
			Hot: rk("hot", 0), HasHot: true, MatchRatio: 1.0,
			ColdOrder: []prefixcache.ReplicaKey{rk("cool", 0), rk("hot", 0)},
		}, cfg)
		Expect(ok).To(BeTrue())
		Expect(got).To(Equal(rk("hot", 0))) // within slack -> reuse
	})
})
91
core/services/nodes/prefixcache/sync.go
Normal file
@@ -0,0 +1,91 @@
package prefixcache

import (
	"time"

	"github.com/mudler/LocalAI/core/services/messaging"
	"github.com/mudler/xlog"
)

// publisher is the minimal slice of messaging.Client that Sync needs.
type publisher interface {
	Publish(subject string, v any) error
}

// Sync wraps an Index, broadcasting new/extended observations to peers and
// applying peers' broadcasts. It is the cross-frontend coherence layer.
type Sync struct {
	idx Provider
	pub publisher
}

// NewSync wraps idx so that observations and invalidations are also published
// via pub. A nil pub disables broadcasting.
func NewSync(idx Provider, pub publisher) *Sync { return &Sync{idx: idx, pub: pub} }

// Observe records locally and, if new/extended, broadcasts to peers. It returns
// whether the local index treated the assignment as new or extended, so Sync
// satisfies prefixcache.Provider.
func (s *Sync) Observe(model string, chain []uint64, key ReplicaKey, now time.Time) bool {
	changed := s.idx.Observe(model, chain, key, now)
	if changed && s.pub != nil {
		ev := messaging.PrefixCacheObserveEvent{Model: model, Chain: chain, NodeID: key.NodeID, Replica: key.Replica}
		if err := s.pub.Publish(messaging.SubjectPrefixCacheObserve, ev); err != nil {
			xlog.Debug("prefixcache: observe publish failed", "error", err)
		}
	}
	return changed
}

// Invalidate drops the local entry for one replica and broadcasts to peers. The
// local drop is a no-op for models that were never cached (Index.Invalidate does
// not intern a tree). The broadcast is UNCONDITIONAL (when a publisher is
// configured): the registry chokepoint fires for every replica removal, and a
// peer frontend may hold a stale entry for the model even when THIS frontend
// never cached it, so gating the broadcast on local-tree existence would drop
// cross-frontend invalidations and leave peers routing to a removed replica
// until their TTL.
func (s *Sync) Invalidate(model string, key ReplicaKey) {
	s.idx.Invalidate(model, key)
	if s.pub != nil {
		ev := messaging.PrefixCacheInvalidateEvent{Model: model, NodeID: key.NodeID, Replica: key.Replica}
		if err := s.pub.Publish(messaging.SubjectPrefixCacheInvalidate, ev); err != nil {
			xlog.Debug("prefixcache: invalidate publish failed", "error", err)
		}
	}
}

// InvalidateNode drops the local entries for ALL replicas of node and broadcasts
// to peers. Like Invalidate the broadcast is unconditional for cross-frontend
// coherence. A negative Replica on the wire means "all replicas of the node".
func (s *Sync) InvalidateNode(model, node string) {
	s.idx.InvalidateNode(model, node)
	if s.pub != nil {
		ev := messaging.PrefixCacheInvalidateEvent{Model: model, NodeID: node, Replica: -1}
		if err := s.pub.Publish(messaging.SubjectPrefixCacheInvalidate, ev); err != nil {
|
||||
xlog.Debug("prefixcache: invalidate-node publish failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyObserve applies a peer observe event locally (no re-broadcast).
|
||||
func (s *Sync) ApplyObserve(ev messaging.PrefixCacheObserveEvent, now time.Time) {
|
||||
s.idx.Observe(ev.Model, ev.Chain, ReplicaKey{NodeID: ev.NodeID, Replica: ev.Replica}, now)
|
||||
}
|
||||
|
||||
// ApplyInvalidate applies a peer invalidate event locally (no re-broadcast). A
|
||||
// negative Replica targets all replicas of the node.
|
||||
func (s *Sync) ApplyInvalidate(ev messaging.PrefixCacheInvalidateEvent) {
|
||||
if ev.Replica < 0 {
|
||||
s.idx.InvalidateNode(ev.Model, ev.NodeID)
|
||||
return
|
||||
}
|
||||
s.idx.Invalidate(ev.Model, ReplicaKey{NodeID: ev.NodeID, Replica: ev.Replica})
|
||||
}
|
||||
|
||||
// Decide delegates to the wrapped index.
|
||||
func (s *Sync) Decide(model string, chain []uint64, candidates []ReplicaKey, now time.Time) PrefixDecision {
|
||||
return s.idx.Decide(model, chain, candidates, now)
|
||||
}
|
||||
|
||||
// Evict delegates eviction of expired entries to the wrapped index. It does not
|
||||
// broadcast: each frontend evicts its own copy on its own TTL clock.
|
||||
func (s *Sync) Evict(now time.Time) { s.idx.Evict(now) }
|
||||
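The comments in sync.go describe three rules: an observe is broadcast only when the local index reports a change, an invalidate is broadcast unconditionally, and a negative replica on the wire means "all replicas of the node". The sketch below models those rules with two toy frontends to show why the unconditional broadcast matters. Every type here (`frontend`, `event`, `replicaKey`) is a hypothetical stand-in, the outbox plays the role of the publisher, and none of this is the real `prefixcache` or `messaging` API.

```go
package main

import "fmt"

// replicaKey and event are hypothetical stand-ins for prefixcache.ReplicaKey
// and the messaging events. A negative Replica means "every replica of NodeID".
type replicaKey struct {
	NodeID  string
	Replica int
}

type event struct {
	Invalidate bool
	Model      string
	NodeID     string
	Replica    int
}

// frontend is a toy index (model -> set of warm replicas) plus an outbox that
// stands in for the publisher.
type frontend struct {
	warm   map[string]map[replicaKey]bool
	outbox []event
}

func newFrontend() *frontend {
	return &frontend{warm: map[string]map[replicaKey]bool{}}
}

// Observe records locally and broadcasts only when the entry is new.
func (f *frontend) Observe(model string, k replicaKey) bool {
	if f.warm[model][k] {
		return false
	}
	if f.warm[model] == nil {
		f.warm[model] = map[replicaKey]bool{}
	}
	f.warm[model][k] = true
	f.outbox = append(f.outbox, event{false, model, k.NodeID, k.Replica})
	return true
}

// Invalidate drops locally (a no-op for unknown models, which creates no
// entry) and always broadcasts.
func (f *frontend) Invalidate(model string, k replicaKey) {
	delete(f.warm[model], k)
	f.outbox = append(f.outbox, event{true, model, k.NodeID, k.Replica})
}

// Apply consumes a peer event without re-broadcasting it.
func (f *frontend) Apply(ev event) {
	switch {
	case !ev.Invalidate:
		if f.warm[ev.Model] == nil {
			f.warm[ev.Model] = map[replicaKey]bool{}
		}
		f.warm[ev.Model][replicaKey{ev.NodeID, ev.Replica}] = true
	case ev.Replica < 0:
		for k := range f.warm[ev.Model] {
			if k.NodeID == ev.NodeID {
				delete(f.warm[ev.Model], k)
			}
		}
	default:
		delete(f.warm[ev.Model], replicaKey{ev.NodeID, ev.Replica})
	}
}

func main() {
	a, b := newFrontend(), newFrontend()
	k := replicaKey{"A", 0}
	b.Apply(event{false, "m", "A", 0}) // b learned about A/0 earlier
	a.Invalidate("m", k)               // a never cached "m" but still broadcasts
	for _, ev := range a.outbox {
		b.Apply(ev)
	}
	fmt.Println(len(a.warm), len(a.outbox), b.warm["m"][k]) // 0 1 false
}
```

Had the broadcast been gated on `a` holding a local entry, `b` would keep routing to the removed replica until its own TTL expired.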
118
core/services/nodes/prefixcache/sync_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package prefixcache_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
)
|
||||
|
||||
type fakePub struct{ published []any }
|
||||
|
||||
func (f *fakePub) Publish(subject string, v any) error {
|
||||
f.published = append(f.published, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync must satisfy the Provider seam so SmartRouter can hold a single
|
||||
// prefixcache.Provider that broadcasts via NATS.
|
||||
var _ prefixcache.Provider = (*prefixcache.Sync)(nil)
|
||||
|
||||
var _ = Describe("Sync", func() {
|
||||
It("delegates Evict to the wrapped index", func() {
|
||||
cfg := prefixcache.DefaultConfig()
|
||||
cfg.TTL = time.Minute
|
||||
idx := prefixcache.NewIndex(cfg)
|
||||
s := prefixcache.NewSync(idx, &fakePub{})
|
||||
s.Observe("m", []uint64{1, 2}, rk("A", 0), t0)
|
||||
// Before TTL: still hot.
|
||||
Expect(idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 0)}, t0).HasHot).To(BeTrue())
|
||||
// After TTL via Sync.Evict: entry is swept.
|
||||
s.Evict(t0.Add(2 * time.Minute))
|
||||
Expect(idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 0)}, t0.Add(2*time.Minute)).HasHot).To(BeFalse())
|
||||
})
|
||||
|
||||
It("publishes an observe event with the replica when Observe is new", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
pub := &fakePub{}
|
||||
s := prefixcache.NewSync(idx, pub)
|
||||
s.Observe("m", []uint64{1, 2}, rk("A", 1), t0) // first time -> publish
|
||||
Expect(pub.published).To(HaveLen(1))
|
||||
ev := pub.published[0].(messaging.PrefixCacheObserveEvent)
|
||||
Expect(ev.NodeID).To(Equal("A"))
|
||||
Expect(ev.Replica).To(Equal(1))
|
||||
s.Observe("m", []uint64{1, 2}, rk("A", 1), t0) // same -> no publish
|
||||
Expect(pub.published).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("broadcasts an invalidate even for a model with no local tree, without interning one", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
pub := &fakePub{}
|
||||
s := prefixcache.NewSync(idx, pub)
|
||||
// A peer frontend may hold a stale entry for this model even though THIS
|
||||
// frontend never cached it, so the invalidate MUST be broadcast for
|
||||
// cross-frontend coherence. The local drop must still not intern a tree.
|
||||
s.Invalidate("never-cached", rk("A", 0))
|
||||
Expect(pub.published).To(HaveLen(1))
|
||||
ev := pub.published[0].(messaging.PrefixCacheInvalidateEvent)
|
||||
Expect(ev.NodeID).To(Equal("A"))
|
||||
Expect(ev.Replica).To(Equal(0))
|
||||
Expect(idx.TreeCountForTest()).To(Equal(0))
|
||||
})
|
||||
|
||||
It("broadcasts an invalidate for a cached replica too", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
pub := &fakePub{}
|
||||
s := prefixcache.NewSync(idx, pub)
|
||||
s.Observe("m", []uint64{1, 2}, rk("A", 0), t0) // creates the tree (also publishes observe)
|
||||
pub.published = nil
|
||||
s.Invalidate("m", rk("A", 0))
|
||||
Expect(pub.published).To(HaveLen(1))
|
||||
Expect(pub.published[0]).To(BeAssignableToTypeOf(messaging.PrefixCacheInvalidateEvent{}))
|
||||
})
|
||||
|
||||
It("broadcasts a node-wide invalidate with a negative replica", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
pub := &fakePub{}
|
||||
s := prefixcache.NewSync(idx, pub)
|
||||
s.InvalidateNode("m", "A")
|
||||
Expect(pub.published).To(HaveLen(1))
|
||||
ev := pub.published[0].(messaging.PrefixCacheInvalidateEvent)
|
||||
Expect(ev.NodeID).To(Equal("A"))
|
||||
Expect(ev.Replica).To(BeNumerically("<", 0))
|
||||
})
|
||||
|
||||
It("applies a peer observe event into the local index with the replica", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
s := prefixcache.NewSync(idx, &fakePub{})
|
||||
s.ApplyObserve(messaging.PrefixCacheObserveEvent{Model: "m", Chain: []uint64{1, 2}, NodeID: "A", Replica: 2}, t0)
|
||||
d := idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 2)}, t0)
|
||||
Expect(d.HasHot).To(BeTrue())
|
||||
Expect(d.Hot).To(Equal(rk("A", 2)))
|
||||
})
|
||||
|
||||
It("applies a peer single-replica invalidate", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
s := prefixcache.NewSync(idx, &fakePub{})
|
||||
s.Observe("m", []uint64{1, 2}, rk("A", 0), t0)
|
||||
s.Observe("m", []uint64{3, 4}, rk("A", 1), t0)
|
||||
s.ApplyInvalidate(messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: 0})
|
||||
cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)}
|
||||
Expect(idx.Decide("m", []uint64{1, 2}, cands, t0).HasHot).To(BeFalse())
|
||||
Expect(idx.Decide("m", []uint64{3, 4}, cands, t0).HasHot).To(BeTrue())
|
||||
})
|
||||
|
||||
It("applies a peer node-wide invalidate when replica is negative", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
s := prefixcache.NewSync(idx, &fakePub{})
|
||||
s.Observe("m", []uint64{1, 2}, rk("A", 0), t0)
|
||||
s.Observe("m", []uint64{3, 4}, rk("A", 1), t0)
|
||||
s.ApplyInvalidate(messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: -1})
|
||||
cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)}
|
||||
Expect(idx.Decide("m", []uint64{1, 2}, cands, t0).HasHot).To(BeFalse())
|
||||
Expect(idx.Decide("m", []uint64{3, 4}, cands, t0).HasHot).To(BeFalse())
|
||||
})
|
||||
})
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/advisorylock"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
grpcclient "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/nats-io/nats.go"
|
||||
@@ -56,6 +57,13 @@ type ReplicaReconciler struct {
|
||||
// probeStaleAfter: only probe node_models rows older than this so we
|
||||
// don't hammer every worker every tick for models we just heard from.
|
||||
probeStaleAfter time.Duration
|
||||
// pressure is the shared forced-disturb counter written by the router. When
|
||||
// a model's count within the Pressure's rolling window reaches pressureThreshold the
|
||||
// reconciler treats its cache-warm replica as saturated and scales up,
|
||||
// subject to the same MaxReplicas/capacity/UnsatisfiableUntil machinery as
|
||||
// the other scale-up paths. nil disables this signal (a true no-op).
|
||||
pressure *prefixcache.Pressure
|
||||
pressureThreshold int
|
||||
}
|
||||
|
||||
// ModelScheduler abstracts the scheduling logic needed by the reconciler.
|
||||
@@ -83,6 +91,12 @@ type ReplicaReconcilerOptions struct {
|
||||
Interval time.Duration // default 30s
|
||||
ScaleDownDelay time.Duration // default 5m
|
||||
ProbeStaleAfter time.Duration // default 2m
|
||||
// Pressure is the shared forced-disturb counter written by the router. nil
|
||||
// disables the cache-saturation autoscale signal (a true no-op).
|
||||
Pressure *prefixcache.Pressure
|
||||
// PressureThreshold is the forced-disturb count within PressureWindow that
|
||||
// triggers a scale-up. Default prefixcache.DefaultConfig().PressureScaleThreshold (1).
|
||||
PressureThreshold int
|
||||
}
|
||||
|
||||
// NewReplicaReconciler creates a new ReplicaReconciler.
|
||||
@@ -103,16 +117,22 @@ func NewReplicaReconciler(opts ReplicaReconcilerOptions) *ReplicaReconciler {
 	if prober == nil {
 		prober = grpcModelProber{token: opts.RegistrationToken}
 	}
+	pressureThreshold := opts.PressureThreshold
+	if pressureThreshold == 0 {
+		pressureThreshold = prefixcache.DefaultConfig().PressureScaleThreshold
+	}
 	return &ReplicaReconciler{
-		registry:        opts.Registry,
-		scheduler:       opts.Scheduler,
-		unloader:        opts.Unloader,
-		adapter:         opts.Adapter,
-		prober:          prober,
-		db:              opts.DB,
-		interval:        interval,
-		scaleDownDelay:  scaleDownDelay,
-		probeStaleAfter: probeStaleAfter,
+		registry:          opts.Registry,
+		scheduler:         opts.Scheduler,
+		unloader:          opts.Unloader,
+		adapter:           opts.Adapter,
+		prober:            prober,
+		db:                opts.DB,
+		interval:          interval,
+		scaleDownDelay:    scaleDownDelay,
+		probeStaleAfter:   probeStaleAfter,
+		pressure:          opts.Pressure,
+		pressureThreshold: pressureThreshold,
 	}
 }
 
@@ -409,13 +429,25 @@ func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedu
 		}
 		xlog.Info("Reconciler: scaling up to meet minimum", "model", cfg.ModelName,
 			"current", current, "min", cfg.MinReplicas, "adding", needed)
-		rc.scaleUp(ctx, cfg, needed)
-		// Successful (or partial) scale-up clears the hysteresis so a future
-		// dip starts fresh.
-		_ = rc.registry.ClearUnsatisfiable(ctx, cfg.ModelName)
+		if rc.scaleUp(ctx, cfg, needed) {
+			// A real (or partial) scale-up clears the hysteresis so a future
+			// dip starts fresh. If scaleUp added nothing (scheduler errored or
+			// no node could be loaded) we leave the hysteresis intact so the
+			// next tick retries from where it left off rather than resetting
+			// the unsatisfiable counter on a failed attempt.
+			_ = rc.registry.ClearUnsatisfiable(ctx, cfg.ModelName)
+		}
 		return
 	}
 
+	// scaledUp tracks whether a scale-up already fired in this tick. The two
+	// scale-up paths below (busy-burst and pressure) share the single `current`
+	// value read once above; scaleUp does not re-check it. So at most one of
+	// them may fire per tick, otherwise a model that is both busy AND over the
+	// pressure threshold would scale +2 and could overshoot MaxReplicas by one.
+	// Scale-down is also skipped in a tick that scaled up.
+	scaledUp := false
+
 	// 2. Auto-scale up if all replicas are busy
 	if current > 0 && (cfg.MaxReplicas == 0 || int(current) < cfg.MaxReplicas) {
 		if rc.allReplicasBusy(ctx, cfg.ModelName) {
@@ -432,17 +464,63 @@ func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedu
 			}
 			xlog.Info("Reconciler: all replicas busy, scaling up", "model", cfg.ModelName,
 				"current", current)
-			rc.scaleUp(ctx, cfg, 1)
+			// Only mark the tick as having scaled up if a replica was actually
+			// added. On a failed scaleUp, leave scaledUp false so the pressure
+			// path below and the scale-down logic still apply as they would
+			// have if the busy-burst path had not run.
+			scaledUp = rc.scaleUp(ctx, cfg, 1)
 		}
 	}
 
-	// 3. Scale down idle replicas above minimum
-	floor := cfg.MinReplicas
-	if floor < 1 {
-		floor = 1
+	// 2b. Auto-scale up on prefix-cache forced-disturb pressure. A forced-disturb
+	// is recorded by the router when a request had a usable hot prefix match
+	// but the load guard forced it off the warm node: the cache-warm replica
+	// is saturated. We reuse the same MaxReplicas + capacity guards as the
+	// busy-burst path, and the same UnsatisfiableUntil cooldown gates this
+	// block at the top of reconcileModel, so a no-capacity model will not
+	// spin. Pressure never overrides MaxReplicas or force-evicts.
+	//
+	// Skipped when the busy-burst path already scaled up this tick: at most
+	// one scaleUp(+1) per tick (see scaledUp above).
+	if !scaledUp && rc.pressure != nil && current > 0 && (cfg.MaxReplicas == 0 || int(current) < cfg.MaxReplicas) {
+		if pressureCount := rc.pressure.Count(cfg.ModelName, time.Now()); pressureCount >= rc.pressureThreshold {
+			candidateNodeIDs, selectorMatched := rc.candidateNodeIDsForSelector(ctx, cfg)
+			if selectorMatched {
+				capacity, capErr := rc.registry.ClusterCapacityForModel(ctx, cfg.ModelName, candidateNodeIDs)
+				if capErr == nil && capacity > 0 {
+					xlog.Info("Reconciler: prefix-cache forced-disturb pressure, scaling up",
+						"model", cfg.ModelName, "current", current,
+						"pressure", pressureCount,
+						"threshold", rc.pressureThreshold)
+					if rc.scaleUp(ctx, cfg, 1) {
+						scaledUp = true
+						// Consume the signal only on a real scale-up:
+						// Pressure.Count is non-draining (it prunes only by
+						// age), so a single burst stays in-window for the whole
+						// window and would re-fire scaleUp on every tick. Reset
+						// clears the model's events so a fresh scale-up needs
+						// fresh forced-disturbs to accumulate. If scaleUp added
+						// nothing (scheduler errored or no node could be loaded)
+						// we preserve the signal so the next tick retries off
+						// the same accumulated pressure instead of having to
+						// re-accumulate a full window from scratch.
+						rc.pressure.Reset(cfg.ModelName)
+					}
+				}
+				// No capacity: transient demand, not a misconfig - let the next
+				// tick retry naturally (mirrors the busy-burst path's choice not
+				// to enter cooldown for burst load).
+			}
+		}
 	}
-	if int(current) > floor {
-		rc.scaleDownIdle(ctx, cfg, int(current), floor)
+
+	// 3. Scale down idle replicas above minimum. Skipped in a tick that already
+	// scaled up so we never scale up and down in the same pass.
+	if !scaledUp {
+		floor := max(cfg.MinReplicas, 1)
+		if int(current) > floor {
+			rc.scaleDownIdle(ctx, cfg, int(current), floor)
+		}
 	}
 }
|
||||
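The ordering rules in this hunk (busy-burst first, pressure only if nothing scaled yet, scale-down only if nothing scaled at all, and the pressure signal consumed only on a real scale-up) can be condensed into a pure decision function. This is a simplified sketch under stated assumptions: `tick`, `outcome` and `plan` are hypothetical names, the capacity check is folded into a single `Capacity` number for both paths, and the registry, selector and cooldown machinery of the real `reconcileModel` is left out.

```go
package main

import "fmt"

// tick is a hypothetical snapshot of one model at the start of a reconcile pass.
type tick struct {
	Current, Min, Max   int  // Max == 0 means "no upper bound"
	AllBusy             bool // every replica has in-flight work
	Pressure, Threshold int  // forced-disturb count and trigger level
	Capacity            int  // free slots reported by the cluster
	ScheduleOK          bool // whether a scale-up attempt would add a replica
}

// outcome lists what the pass decided to do.
type outcome struct {
	Added         int  // replicas added this pass (0 or 1)
	PressureReset bool // whether the pressure signal was consumed
	ScaleDown     bool // whether idle scale-down is allowed to run
}

func plan(t tick) outcome {
	var out outcome
	canGrow := t.Current > 0 && (t.Max == 0 || t.Current < t.Max) && t.Capacity > 0
	scaledUp := false

	// Busy-burst path: counts as a scale-up only if a replica was really added.
	if canGrow && t.AllBusy {
		scaledUp = t.ScheduleOK
	}
	// Pressure path: skipped when the busy-burst path already scaled up.
	if !scaledUp && canGrow && t.Pressure >= t.Threshold {
		if t.ScheduleOK {
			scaledUp = true
			out.PressureReset = true // consume the signal only on success
		}
	}
	if scaledUp {
		out.Added = 1
		return out // never scale up and down in the same pass
	}
	floor := t.Min
	if floor < 1 {
		floor = 1
	}
	out.ScaleDown = t.Current > floor
	return out
}

func main() {
	both := tick{Current: 1, Min: 1, Max: 4, AllBusy: true, Pressure: 2, Threshold: 1, Capacity: 3, ScheduleOK: true}
	fmt.Printf("%+v\n", plan(both))

	failed := both
	failed.AllBusy = false
	failed.ScheduleOK = false
	fmt.Printf("%+v\n", plan(failed))
}
```

In the first case both paths are eligible but only one replica is added and the pressure signal is left untouched, because the busy-burst path fired first. In the second case the attempt fails, so nothing is added and the signal is preserved for the next pass.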
@@ -470,10 +548,17 @@ func (rc *ReplicaReconciler) markCapacityProblem(ctx context.Context, modelName,
 // scaleUp schedules additional replicas of the model. Callers in
 // reconcileModel are expected to have already capped `count` against
 // ClusterCapacityForModel so this function never tries to overshoot.
-func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingConfig, count int) {
+//
+// Returns true if at least one replica was actually scheduled. Callers use
+// this to gate signal-consuming side effects (Pressure.Reset,
+// ClearUnsatisfiable) on a real scale-up: a failed/no-op scaleUp must not
+// discard the accumulated forced-disturb pressure or clear the unsatisfiable
+// hysteresis, otherwise the signal has to re-accumulate from scratch and the
+// next tick can't simply retry.
+func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingConfig, count int) bool {
 	if rc.scheduler == nil {
 		xlog.Warn("Reconciler: no scheduler available, cannot scale up")
-		return
+		return false
 	}
 
 	// Resolve selector → candidate node IDs (nil when no selector → "any
@@ -481,18 +566,21 @@ func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingCon
 	// reconcileModel, but defensively short-circuit here too.
 	candidateNodeIDs, ok := rc.candidateNodeIDsForSelector(ctx, cfg)
 	if !ok {
-		return
+		return false
 	}
 
+	scheduled := 0
 	for i := 0; i < count; i++ {
 		node, err := rc.scheduler.ScheduleAndLoadModel(ctx, cfg.ModelName, candidateNodeIDs)
 		if err != nil {
 			xlog.Warn("Reconciler: failed to scale up replica", "model", cfg.ModelName,
 				"attempt", i+1, "error", err)
-			return // stop trying on first failure
+			break // stop trying on first failure
 		}
+		scheduled++
 		xlog.Info("Reconciler: scaled up replica", "model", cfg.ModelName, "node", node.Name)
 	}
+	return scheduled > 0
 }
 
 // scaleDownIdle removes idle replicas above the floor.
|
||||
@@ -2,12 +2,14 @@ package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -245,6 +247,225 @@ var _ = Describe("ReplicaReconciler", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Forced-disturb pressure autoscale (Phase 6)", func() {
|
||||
It("scales up when pressure reaches the threshold, replicas<max, and capacity exists", func() {
|
||||
// One node with spare slots, one loaded idle replica (so the
|
||||
// all-busy path does not fire). Pressure for the model has reached the
// threshold (one event, default threshold 1), which is the only reason to scale here.
|
||||
node := registerNode("pressure-node", "10.0.0.60:50051")
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "pressure-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
setSchedulingConfig("pressure-model", 1, 4, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
pressure.Record("pressure-model", time.Now())
|
||||
|
||||
scheduler := &fakeScheduler{scheduleNode: node}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
reconciler.reconcile(context.Background())
|
||||
|
||||
Expect(scheduler.scheduleCalls).To(HaveLen(1),
|
||||
"forced-disturb pressure above threshold must trigger a scale-up")
|
||||
Expect(scheduler.scheduleCalls[0].modelName).To(Equal("pressure-model"))
|
||||
})
|
||||
|
||||
It("does not scale up on pressure when already at max_replicas", func() {
|
||||
// Two nodes, both loaded (idle), MaxReplicas=2 → at max. Pressure is
|
||||
// high but MaxReplicas must never be overridden.
|
||||
node1 := registerNode("pmax-1", "10.0.0.61:50051")
|
||||
node2 := registerNode("pmax-2", "10.0.0.62:50051")
|
||||
Expect(registry.SetNodeModel(context.Background(), node1.ID, "pmax-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), node2.ID, "pmax-model", 0, "loaded", "addr2", 0)).To(Succeed())
|
||||
setSchedulingConfig("pmax-model", 1, 2, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
pressure.Record("pmax-model", time.Now())
|
||||
pressure.Record("pmax-model", time.Now())
|
||||
|
||||
scheduler := &fakeScheduler{scheduleNode: node1}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
reconciler.reconcile(context.Background())
|
||||
|
||||
Expect(scheduler.scheduleCalls).To(BeEmpty(),
|
||||
"pressure must never override MaxReplicas")
|
||||
})
|
||||
|
||||
It("consumes the pressure signal so a single burst scales up only once", func() {
|
||||
// A single burst of forced-disturbs (well within the window) must
|
||||
// trigger exactly ONE pressure scale-up. A subsequent tick, with the
|
||||
// SAME events still in-window, must NOT scale again: the first
|
||||
// scale-up consumed (Reset) the signal. Without the fix, the
|
||||
// non-draining Count keeps returning >= threshold every tick and
|
||||
// drives the model toward MaxReplicas off a single burst.
|
||||
node := registerNode("consume-node", "10.0.0.64:50051")
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "consume-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
setSchedulingConfig("consume-model", 1, 4, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
now := time.Now()
|
||||
pressure.Record("consume-model", now)
|
||||
pressure.Record("consume-model", now)
|
||||
pressure.Record("consume-model", now)
|
||||
|
||||
scheduler := &fakeScheduler{scheduleNode: node}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
// First tick: pressure above threshold → one scale-up.
|
||||
reconciler.reconcile(context.Background())
|
||||
Expect(scheduler.scheduleCalls).To(HaveLen(1),
|
||||
"first tick must scale up once on the burst")
|
||||
|
||||
// Second tick: the burst's events are still inside the window, but
|
||||
// the first scale-up Reset them, so no further scale-up occurs.
|
||||
reconciler.reconcile(context.Background())
|
||||
Expect(scheduler.scheduleCalls).To(HaveLen(1),
|
||||
"a single burst must not re-trigger scale-up on the next in-window tick")
|
||||
})
|
||||
|
||||
It("does not consume the pressure signal when scaleUp fails", func() {
|
||||
// Pressure above threshold and capacity exists, but the scheduler
|
||||
// errors so no replica is actually added. The forced-disturb signal
|
||||
// must be preserved (NOT Reset) so the next tick retries the
|
||||
// scale-up off the same accumulated pressure, instead of having to
|
||||
// re-accumulate a full window of forced-disturbs from scratch.
|
||||
node := registerNode("fail-node", "10.0.0.66:50051")
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "fail-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
setSchedulingConfig("fail-model", 1, 4, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
now := time.Now()
|
||||
pressure.Record("fail-model", now)
|
||||
pressure.Record("fail-model", now)
|
||||
pressure.Record("fail-model", now)
|
||||
Expect(pressure.Count("fail-model", time.Now())).To(BeNumerically(">=", 1))
|
||||
|
||||
// Scheduler errors: scaleUp attempts but adds nothing.
|
||||
scheduler := &fakeScheduler{scheduleErr: errors.New("schedule boom")}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
reconciler.reconcile(context.Background())
|
||||
|
||||
Expect(scheduler.scheduleCalls).To(HaveLen(1),
|
||||
"scaleUp must have attempted exactly one schedule call")
|
||||
Expect(pressure.Count("fail-model", time.Now())).To(BeNumerically(">=", 1),
|
||||
"a failed scaleUp must NOT consume (Reset) the pressure signal — next tick should retry")
|
||||
})
|
||||
|
||||
It("consumes the pressure signal only when scaleUp succeeds", func() {
|
||||
// Mirror of the failure case: when the scheduler succeeds and a
|
||||
// replica is actually added, the forced-disturb signal IS consumed
|
||||
// (Reset to 0) so a single burst scales up only once.
|
||||
node := registerNode("ok-node", "10.0.0.67:50051")
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "ok-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
setSchedulingConfig("ok-model", 1, 4, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
now := time.Now()
|
||||
pressure.Record("ok-model", now)
|
||||
pressure.Record("ok-model", now)
|
||||
pressure.Record("ok-model", now)
|
||||
|
||||
scheduler := &fakeScheduler{scheduleNode: node}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
reconciler.reconcile(context.Background())
|
||||
|
||||
Expect(scheduler.scheduleCalls).To(HaveLen(1),
|
||||
"successful scaleUp must have scheduled one replica")
|
||||
Expect(pressure.Count("ok-model", time.Now())).To(Equal(0),
|
||||
"a successful scaleUp must consume (Reset) the pressure signal to 0")
|
||||
})
|
||||
|
||||
It("performs at most one scale-up per tick when both busy and over pressure", func() {
|
||||
// The single loaded replica is busy (all-replicas-busy fires) AND
|
||||
// pressure is above threshold. Both scale-up paths are eligible in
|
||||
// the same tick. The invariant is at-most-one scaleUp(+1) per tick,
|
||||
// so exactly one schedule call must happen, not two.
|
||||
node := registerNode("dual-node", "10.0.0.65:50051")
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "dual-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
Expect(registry.IncrementInFlight(context.Background(), node.ID, "dual-model", 0)).To(Succeed())
|
||||
setSchedulingConfig("dual-model", 1, 4, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
pressure.Record("dual-model", time.Now())
|
||||
pressure.Record("dual-model", time.Now())
|
||||
|
||||
scheduler := &fakeScheduler{scheduleNode: node}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
reconciler.reconcile(context.Background())
|
||||
|
||||
Expect(scheduler.scheduleCalls).To(HaveLen(1),
|
||||
"busy + pressure in one tick must still scale up by exactly one, not two")
|
||||
})
|
||||
|
||||
It("does not spin when pressure is high but no capacity exists", func() {
|
||||
// Single node, cap 1, already loaded → capacity 0. Pressure is high
|
||||
// but there is nowhere to place a replica: must not call scheduler.
|
||||
registerCappedNodeFn := func(name, address string, cap int) *BackendNode {
|
||||
node := &BackendNode{
|
||||
Name: name,
|
||||
NodeType: NodeTypeBackend,
|
||||
Address: address,
|
||||
MaxReplicasPerModel: cap,
|
||||
}
|
||||
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
||||
return node
|
||||
}
|
||||
node := registerCappedNodeFn("pcap-node", "10.0.0.63:50051", 1)
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "pcap-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
// MaxReplicas high enough that replicas<max, so only capacity guards it.
|
||||
setSchedulingConfig("pcap-model", 1, 4, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
pressure.Record("pcap-model", time.Now())
|
||||
|
||||
scheduler := &fakeScheduler{scheduleNode: node}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
reconciler.reconcile(context.Background())
|
||||
|
||||
Expect(scheduler.scheduleCalls).To(BeEmpty(),
|
||||
"no capacity means no scale-up: must not spin the scheduler")
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Capacity gating + circuit breaker (PR4)", func() {
|
||||
// Helper: register a node with an explicit per-model replica cap.
|
||||
// Tests in this Describe block want to exercise both "fits" and
|
||||
|
||||
Some files were not shown because too many files have changed in this diff