mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-05 23:36:49 -04:00
Compare commits
61 Commits
v4.3.5
...
feat/p2p-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c410cd7253 | ||
|
|
4e993332af | ||
|
|
42e51894c3 | ||
|
|
d9ae6481fb | ||
|
|
f1c495a748 | ||
|
|
415b561947 | ||
|
|
e6a0d4c375 | ||
|
|
7e59a5c7c5 | ||
|
|
aea954a482 | ||
|
|
595e448714 | ||
|
|
860f9d63ad | ||
|
|
a5a0b3dc4e | ||
|
|
94eca04c60 | ||
|
|
35bd485d6a | ||
|
|
1fe96f8d9a | ||
|
|
c508e9d7c6 | ||
|
|
55e754fd05 | ||
|
|
450376d22f | ||
|
|
8180fddc05 | ||
|
|
5033457f57 | ||
|
|
d88758282a | ||
|
|
a0c7cecddd | ||
|
|
bc42374d8a | ||
|
|
ec2a0645dd | ||
|
|
ce8b97edf2 | ||
|
|
91fc26ff75 | ||
|
|
8df0bb683b | ||
|
|
8ec536a34c | ||
|
|
a17753f7d1 | ||
|
|
14b57aa343 | ||
|
|
288d732af7 | ||
|
|
ed38609d51 | ||
|
|
c61838dba6 | ||
|
|
7013e13f05 | ||
|
|
5a0013defe | ||
|
|
7768b35696 | ||
|
|
830f818c58 | ||
|
|
c01ed631d6 | ||
|
|
d47464cb06 | ||
|
|
63f176346e | ||
|
|
af94d08729 | ||
|
|
6795d38f50 | ||
|
|
718223f33b | ||
|
|
39e050d9e2 | ||
|
|
c222161291 | ||
|
|
aa80d4681b | ||
|
|
0d57957ebb | ||
|
|
76fe0bb929 | ||
|
|
baa11133f1 | ||
|
|
1bdd3338a6 | ||
|
|
e08492a2c3 | ||
|
|
d5d8fe909d | ||
|
|
8a82753277 | ||
|
|
51ca109067 | ||
|
|
07f6c15a37 | ||
|
|
a44bdb29d4 | ||
|
|
aee4611ab2 | ||
|
|
486467623c | ||
|
|
4912c9b73a | ||
|
|
12d1f3a697 | ||
|
|
a7cad704b9 |
@@ -38,9 +38,12 @@ The React UI (`core/http/react-ui/`) has **no component/unit tests** — its onl
|
||||
- **Browser:** the flake dev shell ships `chromium` and exports `PLAYWRIGHT_CHROMIUM_PATH`; `playwright.config.js` uses it via `launchOptions.executablePath`, and the Makefile skips `playwright install` when it's set. This avoids Playwright's downloaded browser, which can't resolve system libs (`libglib-2.0`, …) on NixOS. In CI (no `PLAYWRIGHT_CHROMIUM_PATH`) the Makefile falls back to `playwright install --with-deps chromium`.
|
||||
- The app is a React SPA, so coverage accumulates across in-app navigation within a test; a full `page.goto`/reload resets it.
|
||||
- `.nycrc.json` uses `all: true`, so **every `src/**` file is in the report**, including 0%-coverage ones — that's how you spot features with no test at all (sort the HTML report or `coverage-summary.json` by line% ascending).
|
||||
- **UI coverage gate:** `make test-ui-coverage-check` runs the suite then `scripts/ui-coverage-check.sh`, failing if total line coverage drops more than `UI_COVERAGE_TOLERANCE` (default **1.0pp**) below `core/http/react-ui/coverage-baseline.txt`. `make test-ui-coverage-baseline` regenerates the baseline. **Why a tolerance (unlike the strict Go gate):** UI e2e line coverage is *non-deterministic* — async/debounced paths (e.g. the VRAM estimate's 500ms debounce) make identical specs vary ~0.5pp run-to-run, so a zero-tolerance gate would flake. Keep the tolerance just above the observed jitter. Run in CI (`tests-ui-e2e.yml`) and pre-commit on `core/http/react-ui/` changes.
|
||||
- **UI coverage gate:** `make test-ui-coverage-check` runs the suite then `scripts/ui-coverage-check.sh`, failing if total line coverage drops more than `UI_COVERAGE_TOLERANCE` below `core/http/react-ui/coverage-baseline.txt`. `make test-ui-coverage-baseline` regenerates the baseline. Runs in CI (`tests-ui-e2e.yml`) and pre-commit on `core/http/react-ui/` changes.
|
||||
- **Why it has a tolerance (unlike the strict Go gate):** UI e2e coverage is *non-deterministic*. Specs that assert on state and end while async/lazy render work is still in flight collect those lines only when the render beats the coverage teardown — so the total drifts with machine speed/load (a fast local box reads higher than a slow CI runner), diffusely across many specs. The tolerance absorbs that drift, so set the baseline *below* the slow-CI floor, never to a fast-local `make test-ui-coverage-baseline` number, or CI flaps.
|
||||
- **Raising coverage is cheap:** a *render-smoke* spec (navigate to a route, assert its header renders) mounts a lazy page and runs its full render + initial effects, capturing most of its lines in a few lines of test — see `e2e/page-render-smoke.spec.js`. Auth is disabled in the test server (`isAdmin=true`), so `RequireAdmin`/`RequireFeature` routes render without a mock. The most *deterministic* win is removing a race: make a spec `await` a rendered element before ending (see `e2e/agents.spec.js` → AgentCreate) so its lines count every run.
|
||||
|
||||
Rules:
|
||||
- The gate is **strict — there is no tolerance**. Any decrease fails, regardless of how many lines a PR adds or deletes. `covermode=atomic` makes line coverage deterministic, so there's no run-to-run jitter to excuse.
|
||||
- When a change legitimately **raises** coverage, run `make test-coverage-baseline` and **commit** the updated `coverage-baseline.txt` so the ratchet moves up. Never lower the baseline by hand.
|
||||
- If you can't get coverage back to baseline, the fix is to **add tests**, not to edit the baseline.
|
||||
Rules (both gates):
|
||||
- **Install the hooks:** `make install-hooks` once per clone so lint + coverage run pre-commit. Don't lean on CI for what the hook catches.
|
||||
- **Don't work around the gate:** never `git commit --no-verify`, and never hand-lower a baseline or widen a tolerance to turn a red gate green. The ratchet only moves up.
|
||||
- If a change drops coverage, **add tests** (sort `coverage-summary.json` by line% ascending to find untested code) rather than editing the baseline. When coverage legitimately rises, commit the regenerated baseline (`make test-coverage-baseline` / `test-ui-coverage-baseline`).
|
||||
- The Go gate is **strict — no tolerance**; `covermode=atomic` keeps it deterministic. The UI gate keeps a small tolerance only because its e2e coverage isn't.
|
||||
|
||||
@@ -50,6 +50,17 @@ Do not mix styles within a package. If you are extending tests in a package that
|
||||
|
||||
This is enforced by `golangci-lint` via the `forbidigo` linter (see `.golangci.yml`); calls like `t.Errorf` / `t.Fatalf` / `t.Run` / `t.Skip` / `t.Logf` are flagged. Run `make lint` locally before submitting; the same check runs in CI (`.github/workflows/lint.yml`).
|
||||
|
||||
## Outbound HTTP
|
||||
|
||||
All outbound HTTP must go through `github.com/mudler/LocalAI/pkg/httpclient` rather than the standard library's default client. Use `httpclient.New(...)` (no body deadline — safe for streaming/SSE) or `httpclient.NewWithTimeout(d, ...)` (simple request/response). Both **refuse redirects by default** and set a TLS 1.2 floor.
|
||||
|
||||
The reason is GHSA-3mj3-57v2-4636: the std default client follows redirects, and on a *cross-host* redirect Go forwards custom credential headers (e.g. Anthropic's `x-api-key`) to the redirect target, leaking the secret. `httpclient` fails closed instead.
|
||||
|
||||
- Need to follow redirects (download CDNs, registry blobs, GitHub asset URLs)? Pass `httpclient.WithFollowRedirects()` — it still strips credential headers on any cross-host hop.
|
||||
- Have a custom transport (IP-pinned dialer, HTTP/2 tuning, a credential-injecting `RoundTripper`)? Pass `httpclient.WithTransport(rt)`, basing the transport on `httpclient.HardenedTransport()` to keep the TLS floor. Handed a `*http.Client` by a library? `httpclient.Harden(c)` applies the policy in place.
|
||||
|
||||
This is enforced by `forbidigo` (see `.golangci.yml`): `http.DefaultClient` and `http.Get`/`Post`/`PostForm`/`Head` are flagged. The `&http.Client{}` composite literal can't be matched precisely by forbidigo without also flagging legitimate `*http.Client` type references, so that form is caught by review — don't construct raw clients.
|
||||
|
||||
## Documentation
|
||||
|
||||
The project documentation is located in `docs/content`. When adding new features or changing existing functionality, it is crucial to update the documentation to reflect these changes. This helps users understand how to use the new capabilities and ensures the documentation stays relevant.
|
||||
|
||||
@@ -68,6 +68,34 @@ go test -count=1 -timeout=30m -v ./tests/e2e-backends/...
|
||||
|
||||
CI does not load the model; the suite is opt-in via env vars.
|
||||
|
||||
## Distributed mode
|
||||
|
||||
ds4 supports **layer-split** distributed inference (a model too big for one host,
|
||||
split by transformer layer; the GGUF must be present on every machine, each loads
|
||||
only its slice). Topology is **inverted** vs llama.cpp: the coordinator listens,
|
||||
workers dial in.
|
||||
|
||||
- **`ds4-worker` binary**: built and packaged next to `grpc-server` (`package.sh`
|
||||
copies it into `package/`). Links the same engine objects plus `ds4_distributed.o`;
|
||||
**no gRPC/protobuf dependency** (speaks ds4's own TCP transport), so it builds
|
||||
even where `grpc-server` can't. Runs the worker serving loop (`ds4_dist_run`).
|
||||
- **Coordinator wiring**: the ds4 `grpc-server` acts as coordinator when `LoadModel`
|
||||
`ModelOptions.Options` (from model-YAML `options:`) carry:
|
||||
- `ds4_role:coordinator` (enables distributed mode; absent → single-node, back-compat)
|
||||
- `ds4_layers:0:19` (coordinator's own slice, inclusive; `N:output` includes the head)
|
||||
- `ds4_listen:0.0.0.0:1234` (address workers dial into)
|
||||
- `ds4_route_timeout:60` (optional; seconds Predict/PredictStream wait for the route
|
||||
to form before returning gRPC `UNAVAILABLE`; default 60)
|
||||
- **Worker CLI**: `local-ai worker ds4-distributed -- <ds4-worker args>` resolves the
|
||||
ds4 backend and execs the packaged `ds4-worker` (raw passthrough), e.g.
|
||||
`--role worker --model /models/ds4flash.gguf --layers 20:output --coordinator <host> 1234`.
|
||||
|
||||
Opt-in e2e in `tests/e2e-backends/backend_test.go`, gated by
|
||||
`BACKEND_TEST_DS4_DISTRIBUTED=1` (plus `BACKEND_TEST_DS4_WORKER_BINARY`,
|
||||
`BACKEND_TEST_DS4_WORKER_LAYERS`, `BACKEND_TEST_DS4_COORDINATOR_LAYERS`,
|
||||
`BACKEND_TEST_DS4_LISTEN`). Design spec:
|
||||
`docs/superpowers/specs/2026-05-30-ds4-distributed-inference-design.md`.
|
||||
|
||||
## Importer
|
||||
|
||||
`core/gallery/importers/ds4.go` (`DS4Importer`) auto-detects ds4 weights by
|
||||
|
||||
303
.github/backend-matrix.yml
vendored
303
.github/backend-matrix.yml
vendored
@@ -716,6 +716,32 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -1556,6 +1582,32 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1569,6 +1621,32 @@ include:
|
||||
backend: "whisper"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-crispasr'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-parakeet-cpp'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -2850,6 +2928,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2864,6 +2956,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-crispasr'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2877,6 +2983,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2890,6 +3009,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2904,6 +3036,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2918,6 +3064,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-crispasr'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
@@ -2931,6 +3091,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-crispasr'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2944,6 +3117,128 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-crispasr'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
runs-on: 'ubuntu-latest'
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# parakeet-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-parakeet-cpp'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-parakeet-cpp'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-parakeet-cpp'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-parakeet-cpp'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
runs-on: 'ubuntu-latest'
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# acestep-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -3976,6 +4271,14 @@ includeDarwin:
|
||||
tag-suffix: "-metal-darwin-arm64-whisper"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "crispasr"
|
||||
tag-suffix: "-metal-darwin-arm64-crispasr"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "parakeet-cpp"
|
||||
tag-suffix: "-metal-darwin-arm64-parakeet-cpp"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "acestep-cpp"
|
||||
tag-suffix: "-metal-darwin-arm64-acestep-cpp"
|
||||
build-type: "metal"
|
||||
|
||||
8
.github/workflows/bump_deps.yaml
vendored
8
.github/workflows/bump_deps.yaml
vendored
@@ -30,6 +30,14 @@ jobs:
|
||||
variable: "WHISPER_CPP_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/whisper/Makefile"
|
||||
- repository: "CrispStrobe/CrispASR"
|
||||
variable: "CRISPASR_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/crispasr/Makefile"
|
||||
- repository: "mudler/parakeet.cpp"
|
||||
variable: "PARAKEET_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/parakeet-cpp/Makefile"
|
||||
- repository: "leejet/stable-diffusion.cpp"
|
||||
variable: "STABLEDIFFUSION_GGML_VERSION"
|
||||
branch: "master"
|
||||
|
||||
21
.github/workflows/test-extra.yml
vendored
21
.github/workflows/test-extra.yml
vendored
@@ -46,6 +46,7 @@ jobs:
|
||||
speaker-recognition: ${{ steps.detect.outputs.speaker-recognition }}
|
||||
sherpa-onnx: ${{ steps.detect.outputs.sherpa-onnx }}
|
||||
whisper: ${{ steps.detect.outputs.whisper }}
|
||||
parakeet-cpp: ${{ steps.detect.outputs.parakeet-cpp }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
@@ -633,6 +634,26 @@ jobs:
|
||||
- name: Build whisper backend image and run transcription gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-whisper-transcription
|
||||
# Parakeet ASR via the parakeet-cpp backend (C++/ggml port of NeMo
|
||||
# Parakeet). Drives AudioTranscription (offline, with word timestamps) on
|
||||
# tdt_ctc-110m + the JFK 11s clip.
|
||||
tests-parakeet-cpp-grpc-transcription:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.parakeet-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25.4'
|
||||
- name: Build parakeet-cpp backend image and run transcription gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-parakeet-cpp-transcription
|
||||
# VITS TTS via the sherpa-onnx backend. Drives both TTS (file write) and
|
||||
# TTSStream (PCM chunks) on the e2e-backends harness.
|
||||
tests-sherpa-onnx-grpc-tts:
|
||||
|
||||
@@ -56,6 +56,20 @@ linters:
|
||||
# are exempt — see linters.exclusions.rules below.
|
||||
- pattern: '^os\.(Getenv|LookupEnv|Environ)$'
|
||||
msg: 'Plumb config through ApplicationConfig (or the relevant CLI struct) instead of reading env directly. CLI entry points (core/cli/) bind env vars via kong''s `env:` tag — that is the only sanctioned env→struct boundary. See .agents/coding-style.md.'
|
||||
# Outbound HTTP must go through pkg/httpclient, which refuses redirects
|
||||
# by default and sets a TLS floor. The std-library default client and
|
||||
# the http.Get/Post/... convenience helpers follow redirects (up to 10)
|
||||
# and, on a cross-host redirect, forward custom credential headers such
|
||||
# as Anthropic's x-api-key to the redirect target — leaking the secret
|
||||
# (GHSA-3mj3-57v2-4636). forbidigo can't precisely match the
|
||||
# `&http.Client{}` composite literal without also flagging legitimate
|
||||
# `*http.Client` type references, so that form is enforced by
|
||||
# convention + review; these two patterns catch the implicit-default
|
||||
# client, which is the common footgun.
|
||||
- pattern: '^http\.DefaultClient$'
|
||||
msg: 'Use pkg/httpclient (httpclient.New / NewWithTimeout) instead of http.DefaultClient — the std client follows redirects and leaks credential headers cross-host (GHSA-3mj3-57v2-4636). See .agents/coding-style.md.'
|
||||
- pattern: '^http\.(Get|Post|PostForm|Head)$'
|
||||
msg: 'Use pkg/httpclient (httpclient.New / NewWithTimeout) instead of http.Get/Post/PostForm/Head — these use http.DefaultClient, which follows redirects and leaks credential headers cross-host (GHSA-3mj3-57v2-4636). See .agents/coding-style.md.'
|
||||
exclusions:
|
||||
paths:
|
||||
# Upstream whisper.cpp source tree fetched by the whisper backend Makefile.
|
||||
@@ -95,3 +109,18 @@ linters:
|
||||
- path: _test\.go$
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# pkg/httpclient is the sanctioned home for outbound HTTP clients; it
|
||||
# necessarily references net/http directly.
|
||||
- path: ^pkg/httpclient/
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
# Tests drive local httptest servers where redirect/TLS hardening is
|
||||
# irrelevant; the std client is fine there.
|
||||
- path: _test\.go$
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
# Vendored upstream whisper.cpp Go bindings are a separate module and
|
||||
# cannot import pkg/httpclient.
|
||||
- path: ^backend/go/whisper/sources/
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
|
||||
@@ -35,6 +35,7 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
||||
|
||||
## Quick Reference
|
||||
|
||||
- **Git hooks & coverage gates**: Run `make install-hooks` once per clone so the pre-commit lint + coverage gates run. **Never bypass them with `git commit --no-verify`, and never lower a coverage baseline or widen a gate's tolerance to turn a red gate green** — the coverage ratchet only moves up. If a change drops coverage, add tests to raise it (e.g. render-smoke specs). See [.agents/building-and-testing.md](.agents/building-and-testing.md).
|
||||
- **Logging**: Use `github.com/mudler/xlog` (same API as slog)
|
||||
- **Go style**: Prefer `any` over `interface{}`
|
||||
- **Comments**: Explain *why*, not *what*
|
||||
|
||||
@@ -266,6 +266,12 @@ The e2e tests run LocalAI in a Docker container and exercise the API:
|
||||
make test-e2e
|
||||
```
|
||||
|
||||
### React UI tests and coverage
|
||||
|
||||
The React UI (`core/http/react-ui/`) is covered by Playwright e2e specs, gated by a **monotonic line-coverage ratchet** (`make test-ui-coverage-check`, run in CI and pre-commit). The metric is non-deterministic — a fast local box reads higher than a slow CI runner for the same code — so a small tolerance is unavoidable.
|
||||
|
||||
**If your change lowers UI coverage, raise it back by adding specs — do not widen the tolerance or hand-lower the baseline.** A *render-smoke* spec (navigate to a page, assert its header is visible) cheaply covers an entire lazy page. See `core/http/react-ui/e2e/page-render-smoke.spec.js` and the full policy in [.agents/building-and-testing.md](.agents/building-and-testing.md#react-ui-coverage).
|
||||
|
||||
### Running E2E container tests
|
||||
|
||||
These tests build a standard LocalAI Docker image and run it with pre-configured model configs to verify that most endpoints work correctly:
|
||||
|
||||
21
Makefile
21
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -991,6 +991,19 @@ test-extra-backend-whisper-transcription: docker-build-whisper
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## Audio transcription wrapper for the parakeet-cpp (parakeet.cpp ggml port)
|
||||
## backend. Mirrors test-extra-backend-whisper-transcription: drives the
|
||||
## AudioTranscription / AudioTranscriptionStream RPCs against a published
|
||||
## Parakeet GGUF using the JFK 11s clip from whisper.cpp's CI samples. Not
|
||||
## part of the default test suite - run explicitly once the pinned model URL
|
||||
## is reachable.
|
||||
test-extra-backend-parakeet-cpp-transcription: docker-build-parakeet-cpp
|
||||
BACKEND_IMAGE=local-ai-backend:parakeet-cpp \
|
||||
BACKEND_TEST_MODEL_URL=https://huggingface.co/mudler/parakeet-cpp-gguf/resolve/main/tdt_ctc-110m-f16.gguf \
|
||||
BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## LocalVQE audio transform (joint AEC + noise suppression + dereverb).
|
||||
## Exercises the audio_transform capability end-to-end: batch transform
|
||||
## of a real WAV fixture and bidi streaming of synthetic silent frames.
|
||||
@@ -1149,6 +1162,8 @@ BACKEND_HUGGINGFACE = huggingface|golang|.|false|true
|
||||
BACKEND_SILERO_VAD = silero-vad|golang|.|false|true
|
||||
BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true
|
||||
BACKEND_WHISPER = whisper|golang|.|false|true
|
||||
BACKEND_CRISPASR = crispasr|golang|.|false|true
|
||||
BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true
|
||||
BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
||||
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
|
||||
BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true
|
||||
@@ -1236,6 +1251,8 @@ $(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CRISPASR)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))
|
||||
@@ -1285,7 +1302,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-crispasr docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
|
||||
18
README.md
18
README.md
@@ -31,12 +31,18 @@
|
||||
|
||||
**LocalAI** is the open-source AI engine. Run any model - LLMs, vision, voice, image, video - on any hardware. No GPU required.
|
||||
|
||||
- **Drop-in API compatibility** — OpenAI, Anthropic, ElevenLabs APIs
|
||||
- **36+ backends** — llama.cpp, vLLM, transformers, whisper, diffusers, MLX...
|
||||
- **Any hardware** — NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready** — API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents** — autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first** — your data never leaves your infrastructure
|
||||
**A small core, not a bundle.** Each backend wraps a best-in-class engine (llama.cpp, vLLM, whisper.cpp, stable-diffusion, MLX...) in its own image, pulled only when a model needs it. You install nothing you don't use.
|
||||
|
||||
- **Composable by design**: backends are separate and pulled on demand, so you install only what your model needs
|
||||
- **Open and extensible**: load any model, or build your own backend in any language against an open interface
|
||||
- **Drop-in API compatibility**: OpenAI, Anthropic, and ElevenLabs APIs across every backend
|
||||
- **Any model, any modality**: LLMs, vision, voice, image, and video behind one API
|
||||
- **Any hardware**: NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready**: API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents**: autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first**: your data never leaves your infrastructure
|
||||
|
||||

|
||||
|
||||
Created by [Ettore Di Giacinto](https://github.com/mudler) and maintained by the [LocalAI team](#team).
|
||||
|
||||
|
||||
1
backend/cpp/ds4/.gitignore
vendored
1
backend/cpp/ds4/.gitignore
vendored
@@ -2,6 +2,7 @@ ds4/
|
||||
build/
|
||||
package/
|
||||
grpc-server
|
||||
ds4-worker
|
||||
*.o
|
||||
backend.pb.cc
|
||||
backend.pb.h
|
||||
|
||||
@@ -60,6 +60,11 @@ elseif(DS4_GPU STREQUAL "cpu")
|
||||
set(DS4_OBJS "${DS4_DIR}/ds4_cpu.o")
|
||||
endif()
|
||||
|
||||
# ds4.c now references ds4_distributed.c (distributed inference was split into
|
||||
# its own translation unit upstream). It is a single GPU-agnostic object shared
|
||||
# by every GPU mode, so link it in regardless of DS4_GPU.
|
||||
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_distributed.o")
|
||||
|
||||
add_executable(${TARGET}
|
||||
grpc-server.cpp
|
||||
dsml_parser.cpp
|
||||
@@ -99,3 +104,36 @@ if(DS4_NATIVE)
|
||||
target_compile_options(${TARGET} PRIVATE -march=native)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# ds4-worker: standalone distributed worker. Links the same ds4 engine objects
|
||||
# (including ds4_distributed.o) but has NO gRPC/protobuf dependency - it speaks
|
||||
# ds4's own TCP transport via ds4_dist_run(). Buildable wherever the engine
|
||||
# objects build, even on hosts without protobuf/grpc dev headers.
|
||||
add_executable(ds4-worker worker_main.c)
|
||||
target_include_directories(ds4-worker PRIVATE ${DS4_DIR})
|
||||
foreach(obj ${DS4_OBJS})
|
||||
target_sources(ds4-worker PRIVATE ${obj})
|
||||
set_source_files_properties(${obj} PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE)
|
||||
endforeach()
|
||||
# worker_main.c is C, but the engine objects built by nvcc (ds4_cuda.o) and the
|
||||
# Metal path (ds4_metal.o, Obj-C++) reference the C++ runtime (libstdc++). Force
|
||||
# the C++ linker driver so those symbols resolve; the C driver would not link
|
||||
# libstdc++ and the CUDA/Metal builds fail with undefined std:: references.
|
||||
set_target_properties(ds4-worker PROPERTIES LINKER_LANGUAGE CXX)
|
||||
target_link_libraries(ds4-worker PRIVATE Threads::Threads m)
|
||||
|
||||
if(DS4_GPU STREQUAL "cuda")
|
||||
target_link_libraries(ds4-worker PRIVATE CUDA::cudart CUDA::cublas)
|
||||
elseif(DS4_GPU STREQUAL "metal")
|
||||
target_link_libraries(ds4-worker PRIVATE ${FOUNDATION_LIB} ${METAL_LIB})
|
||||
elseif(DS4_GPU STREQUAL "cpu")
|
||||
target_compile_definitions(ds4-worker PRIVATE DS4_NO_GPU)
|
||||
endif()
|
||||
|
||||
if(DS4_NATIVE)
|
||||
if(APPLE)
|
||||
target_compile_options(ds4-worker PRIVATE -mcpu=native)
|
||||
else()
|
||||
target_compile_options(ds4-worker PRIVATE -march=native)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=22393e770ea8eb7501d8718d6f66c6374004e03f
|
||||
# Upstream pin lives below as DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=22393e770ea8eb7501d8718d6f66c6374004e03f
|
||||
DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
@@ -18,16 +18,19 @@ UNAME_S := $(shell uname -s)
|
||||
|
||||
CMAKE_ARGS ?= -DCMAKE_BUILD_TYPE=Release
|
||||
|
||||
# ds4_distributed.o is a GPU-agnostic translation unit that ds4.c/ds4_cpu.o now
|
||||
# reference (upstream split distributed inference into its own .c). The same
|
||||
# object is shared by every GPU mode, so it is appended unconditionally below.
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS += -DDS4_GPU=cuda
|
||||
DS4_OBJ_TARGET := ds4.o ds4_cuda.o
|
||||
DS4_OBJ_TARGET := ds4.o ds4_cuda.o ds4_distributed.o
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
CMAKE_ARGS += -DDS4_GPU=metal
|
||||
DS4_OBJ_TARGET := ds4.o ds4_metal.o
|
||||
DS4_OBJ_TARGET := ds4.o ds4_metal.o ds4_distributed.o
|
||||
else
|
||||
# CPU reference path (Linux only - macOS CPU path is broken by VM bug per ds4 README).
|
||||
CMAKE_ARGS += -DDS4_GPU=cpu
|
||||
DS4_OBJ_TARGET := ds4_cpu.o
|
||||
DS4_OBJ_TARGET := ds4_cpu.o ds4_distributed.o
|
||||
endif
|
||||
|
||||
ifneq ($(NATIVE),true)
|
||||
@@ -52,17 +55,18 @@ ds4:
|
||||
# the right per-platform compile flags (Objective-C/Metal on Darwin, nvcc on Linux+CUDA).
|
||||
ds4/ds4.o: ds4
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
+$(MAKE) -C ds4 ds4.o ds4_cuda.o
|
||||
+$(MAKE) -C ds4 ds4.o ds4_cuda.o ds4_distributed.o
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
+$(MAKE) -C ds4 ds4.o ds4_metal.o
|
||||
+$(MAKE) -C ds4 ds4.o ds4_metal.o ds4_distributed.o
|
||||
else
|
||||
+$(MAKE) -C ds4 ds4_cpu.o
|
||||
+$(MAKE) -C ds4 ds4_cpu.o ds4_distributed.o
|
||||
endif
|
||||
|
||||
grpc-server: ds4/ds4.o
|
||||
mkdir -p $(BUILD_DIR)
|
||||
cd $(BUILD_DIR) && cmake $(CMAKE_ARGS) $(CURRENT_MAKEFILE_DIR) && cmake --build . --config Release -j $(JOBS)
|
||||
cp $(BUILD_DIR)/grpc-server grpc-server
|
||||
cp $(BUILD_DIR)/ds4-worker ds4-worker
|
||||
|
||||
package: grpc-server
|
||||
bash package.sh
|
||||
@@ -71,7 +75,7 @@ test:
|
||||
@echo "ds4 backend: e2e coverage at tests/e2e-backends/ (BACKEND_BINARY mode)"
|
||||
|
||||
clean:
|
||||
rm -rf $(BUILD_DIR) grpc-server package
|
||||
rm -rf $(BUILD_DIR) grpc-server ds4-worker package
|
||||
if [ -d ds4 ]; then $(MAKE) -C ds4 clean; fi
|
||||
|
||||
purge: clean
|
||||
|
||||
@@ -23,8 +23,11 @@ extern "C" {
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <climits>
|
||||
#include <csignal>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
@@ -51,6 +54,12 @@ ds4_session *g_session = nullptr;
|
||||
int g_ctx_size = 32768;
|
||||
std::string g_kv_cache_dir; // empty disables disk cache
|
||||
|
||||
// Distributed coordinator state. g_distributed is set true when LoadModel is
|
||||
// given 'ds4_role:coordinator'; generation then waits for the worker route to
|
||||
// form before running. Single-node behavior is unchanged when unset.
|
||||
bool g_distributed = false;
|
||||
int g_route_timeout_sec = 60;
|
||||
|
||||
std::atomic<Server *> g_server{nullptr};
|
||||
|
||||
// Parse a "key:value" option string. Returns empty when no colon.
|
||||
@@ -60,6 +69,77 @@ static std::pair<std::string, std::string> split_option(const std::string &opt)
|
||||
return {opt.substr(0, colon), opt.substr(colon + 1)};
|
||||
}
|
||||
|
||||
// Parse a positive base-10 integer. Returns false (without throwing) on empty,
|
||||
// trailing garbage, non-positive, or overflow - unlike std::stoi.
|
||||
static bool parse_positive_int(const std::string &s, int *out) {
|
||||
if (s.empty()) return false;
|
||||
char *end = nullptr;
|
||||
long v = std::strtol(s.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || v <= 0 || v > INT_MAX) return false;
|
||||
*out = static_cast<int>(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parse a ds4 layer spec "START:END" or "START:output" into the engine's
|
||||
// distributed layer fields. Returns false on malformed input.
|
||||
static bool parse_layers_spec(const std::string &spec, ds4_distributed_layers *out) {
|
||||
auto colon = spec.find(':');
|
||||
if (colon == std::string::npos) return false;
|
||||
std::string lhs = spec.substr(0, colon);
|
||||
std::string rhs = spec.substr(colon + 1);
|
||||
if (lhs.empty() || rhs.empty()) return false;
|
||||
char *end = nullptr;
|
||||
long start = std::strtol(lhs.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || start < 0) return false;
|
||||
out->start = static_cast<uint32_t>(start);
|
||||
out->has_output = false;
|
||||
if (rhs == "output") {
|
||||
out->has_output = true;
|
||||
out->end = out->start; // engine treats has_output as "through final layer"
|
||||
} else {
|
||||
long e = std::strtol(rhs.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || e < start) return false;
|
||||
out->end = static_cast<uint32_t>(e);
|
||||
}
|
||||
out->set = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
// When acting as a distributed coordinator, block until the worker route
|
||||
// covers all layers (ds4_session_distributed_route_ready == 1) or the timeout
|
||||
// elapses. Returns an empty string on success, or an error message to return
|
||||
// to the client. No-op when not distributed.
|
||||
//
|
||||
// Takes the g_engine_mu lock by reference and RELEASES it during each poll
|
||||
// sleep. The wait can span up to g_route_timeout_sec seconds while workers
|
||||
// connect; holding g_engine_mu the whole time would block the Status/Health
|
||||
// readiness probes (they also lock g_engine_mu), making LocalAI's loader treat
|
||||
// a still-starting worker as hung.
|
||||
static std::string wait_route_ready(std::unique_lock<std::mutex> &lock) {
|
||||
if (!g_distributed) return "";
|
||||
char err[256] = {0};
|
||||
const int deadline_polls = g_route_timeout_sec * 10; // 100ms per poll
|
||||
for (int i = 0; i <= deadline_polls; ++i) {
|
||||
int ready = ds4_session_distributed_route_ready(g_session, err, sizeof(err));
|
||||
if (ready == 1) return "";
|
||||
if (ready < 0) {
|
||||
return std::string("ds4 distributed route error: ") +
|
||||
(err[0] ? err : "unknown");
|
||||
}
|
||||
// Release the lock while sleeping so Status/Health and other RPCs can
|
||||
// interleave during worker startup.
|
||||
lock.unlock();
|
||||
struct timespec ts = {0, 100L * 1000L * 1000L}; // 100ms
|
||||
nanosleep(&ts, nullptr);
|
||||
lock.lock();
|
||||
// A concurrent Free() may have torn down the engine while we slept.
|
||||
if (!g_engine || !g_session) {
|
||||
return "ds4: model unloaded while waiting for distributed route";
|
||||
}
|
||||
}
|
||||
return "ds4 distributed route incomplete: workers not connected (layers uncovered)";
|
||||
}
|
||||
|
||||
static void append_token_text(ds4_engine *engine, int token, std::string &out) {
|
||||
size_t len = 0;
|
||||
const char *text = ds4_token_text(engine, token, &len);
|
||||
@@ -377,6 +457,11 @@ public:
|
||||
backend::Result *result) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
|
||||
// Reset distributed state so a model swap (a second LoadModel without
|
||||
// ds4_role) doesn't inherit a stale coordinator configuration.
|
||||
g_distributed = false;
|
||||
g_route_timeout_sec = 60;
|
||||
|
||||
if (g_engine) {
|
||||
if (g_session) { ds4_session_free(g_session); g_session = nullptr; }
|
||||
ds4_engine_close(g_engine);
|
||||
@@ -394,12 +479,23 @@ public:
|
||||
std::string mtp_path;
|
||||
int mtp_draft = 0;
|
||||
float mtp_margin = 3.0f;
|
||||
std::string ds4_role, ds4_layers, ds4_listen;
|
||||
for (const auto &opt : request->options()) {
|
||||
auto [k, v] = split_option(opt);
|
||||
if (k == "mtp_path") mtp_path = v;
|
||||
else if (k == "mtp_draft") mtp_draft = std::stoi(v);
|
||||
else if (k == "mtp_margin") mtp_margin = std::stof(v);
|
||||
else if (k == "kv_cache_dir") g_kv_cache_dir = v;
|
||||
else if (k == "ds4_role") ds4_role = v;
|
||||
else if (k == "ds4_layers") ds4_layers = v;
|
||||
else if (k == "ds4_listen") ds4_listen = v;
|
||||
else if (k == "ds4_route_timeout") {
|
||||
if (!parse_positive_int(v, &g_route_timeout_sec)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_route_timeout must be a positive integer");
|
||||
return GStatus::OK;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
g_kv_cache.SetDir(g_kv_cache_dir);
|
||||
@@ -422,6 +518,49 @@ public:
|
||||
opt.backend = DS4_BACKEND_CUDA;
|
||||
#endif
|
||||
|
||||
// Coordinator wiring. 'ds4_role:coordinator' enables layer-split
|
||||
// distributed inference: this process listens on ds4_listen and owns
|
||||
// the ds4_layers slice; workers dial in (see `local-ai worker
|
||||
// ds4-distributed`). Absent ds4_role => unchanged single-node path.
|
||||
// Must be static: opt.distributed.listen_host is a const char* the
|
||||
// engine retains past this call, so it cannot point at a local that
|
||||
// goes out of scope (otherwise a future "simplify to local" refactor
|
||||
// reintroduces a dangling pointer).
|
||||
static std::string s_listen_host;
|
||||
if (ds4_role == "coordinator") {
|
||||
if (ds4_layers.empty() || ds4_listen.empty()) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_role:coordinator requires ds4_layers and ds4_listen");
|
||||
return GStatus::OK;
|
||||
}
|
||||
// host:port for IPv4/hostname; IPv6 literals are unsupported (the
|
||||
// first colon would split inside the address).
|
||||
auto host_port = split_option(ds4_listen); // "host:port" -> {host, port}
|
||||
if (host_port.second.empty()) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_listen must be host:port");
|
||||
return GStatus::OK;
|
||||
}
|
||||
int listen_port = 0;
|
||||
if (!parse_positive_int(host_port.second, &listen_port)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_listen port must be a positive integer");
|
||||
return GStatus::OK;
|
||||
}
|
||||
ds4_distributed_layers layers = {};
|
||||
if (!parse_layers_spec(ds4_layers, &layers)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: invalid ds4_layers (want START:END or START:output)");
|
||||
return GStatus::OK;
|
||||
}
|
||||
s_listen_host = host_port.first;
|
||||
opt.distributed.role = DS4_DISTRIBUTED_COORDINATOR;
|
||||
opt.distributed.layers = layers;
|
||||
opt.distributed.listen_host = s_listen_host.c_str();
|
||||
opt.distributed.listen_port = listen_port;
|
||||
g_distributed = true;
|
||||
}
|
||||
|
||||
int rc = ds4_engine_open(&g_engine, &opt);
|
||||
if (rc != 0 || !g_engine) {
|
||||
result->set_success(false);
|
||||
@@ -458,10 +597,13 @@ public:
|
||||
|
||||
GStatus Predict(ServerContext *, const backend::PredictOptions *request,
|
||||
backend::Reply *reply) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
std::unique_lock<std::mutex> lock(g_engine_mu);
|
||||
if (!g_engine || !g_session) {
|
||||
return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
|
||||
}
|
||||
if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
|
||||
return GStatus(StatusCode::UNAVAILABLE, route_err);
|
||||
}
|
||||
ds4_tokens prompt = {};
|
||||
build_prompt(g_engine, request, &prompt);
|
||||
int n_predict = request->tokens() > 0 ? request->tokens() : 256;
|
||||
@@ -554,10 +696,13 @@ public:
|
||||
|
||||
GStatus PredictStream(ServerContext *, const backend::PredictOptions *request,
|
||||
ServerWriter<backend::Reply> *writer) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
std::unique_lock<std::mutex> lock(g_engine_mu);
|
||||
if (!g_engine || !g_session) {
|
||||
return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
|
||||
}
|
||||
if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
|
||||
return GStatus(StatusCode::UNAVAILABLE, route_err);
|
||||
}
|
||||
ds4_tokens prompt = {};
|
||||
build_prompt(g_engine, request, &prompt);
|
||||
int n_predict = request->tokens() > 0 ? request->tokens() : 256;
|
||||
|
||||
@@ -5,7 +5,8 @@ REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
cp -avf "$CURDIR/grpc-server" "$CURDIR/package/"
|
||||
cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/ds4-worker" "$CURDIR/package/"
|
||||
cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
UNAME_S=$(uname -s)
|
||||
if [ "$UNAME_S" = "Darwin" ]; then
|
||||
|
||||
126
backend/cpp/ds4/worker_main.c
Normal file
126
backend/cpp/ds4/worker_main.c
Normal file
@@ -0,0 +1,126 @@
|
||||
// ds4-worker: standalone distributed worker for the LocalAI ds4 backend.
|
||||
//
|
||||
// A ds4 distributed worker owns a slice of the model's transformer layers,
|
||||
// dials the coordinator, and serves activations for its slice. It does NOT
|
||||
// speak backend.proto - it speaks ds4's own TCP transport via ds4_dist_run().
|
||||
// This binary is intentionally minimal (no HTTP/web/kvstore/linenoise): it
|
||||
// only needs the engine objects + ds4_distributed.o, which the backend already
|
||||
// builds. It is launched by `local-ai worker ds4-distributed`.
|
||||
//
|
||||
// Usage:
|
||||
// ds4-worker --role worker --model <gguf> --layers 20:output \
|
||||
// --coordinator <host> <port> [--cpu|--cuda|--metal] [-c CTX] [-t N]
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <signal.h>
|
||||
#include <limits.h>
|
||||
|
||||
#include "ds4.h"
|
||||
#include "ds4_distributed.h"
|
||||
|
||||
static const char *need_arg(int *i, int argc, char **argv, const char *flag) {
|
||||
if (*i + 1 >= argc) {
|
||||
fprintf(stderr, "ds4-worker: missing value for %s\n", flag);
|
||||
exit(2);
|
||||
}
|
||||
return argv[++(*i)];
|
||||
}
|
||||
|
||||
static int parse_int_arg(const char *s, const char *flag) {
|
||||
char *end = NULL;
|
||||
long v = strtol(s, &end, 10);
|
||||
if (!s[0] || *end || v <= 0 || v > INT_MAX) {
|
||||
fprintf(stderr, "ds4-worker: invalid value for %s: %s\n", flag, s);
|
||||
exit(2);
|
||||
}
|
||||
return (int)v;
|
||||
}
|
||||
|
||||
static ds4_backend default_backend(void) {
|
||||
#if defined(DS4_NO_GPU)
|
||||
return DS4_BACKEND_CPU;
|
||||
#elif defined(__APPLE__)
|
||||
return DS4_BACKEND_METAL;
|
||||
#else
|
||||
return DS4_BACKEND_CUDA;
|
||||
#endif
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
signal(SIGPIPE, SIG_IGN);
|
||||
|
||||
ds4_engine_options opt = {0};
|
||||
opt.backend = default_backend();
|
||||
int ctx_size = 32768;
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
const char *arg = argv[i];
|
||||
if (!strcmp(arg, "-h") || !strcmp(arg, "--help")) {
|
||||
fprintf(stdout, "ds4-worker: standalone ds4 distributed worker\n");
|
||||
ds4_dist_usage(stdout);
|
||||
fprintf(stdout, " -m, --model PATH model GGUF (the worker loads only its --layers slice)\n");
|
||||
fprintf(stdout, " -c, --ctx N context size (default 32768)\n");
|
||||
fprintf(stdout, " -t, --threads N CPU threads\n");
|
||||
fprintf(stdout, " --cpu|--cuda|--metal backend override\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
char dist_err[256] = {0};
|
||||
ds4_dist_cli_parse_result dist_parse =
|
||||
ds4_dist_parse_cli_arg(arg, &i, argc, argv, &opt.distributed,
|
||||
dist_err, sizeof(dist_err));
|
||||
if (dist_parse == DS4_DIST_CLI_ERROR) {
|
||||
fprintf(stderr, "ds4-worker: %s\n",
|
||||
dist_err[0] ? dist_err : "invalid distributed option");
|
||||
return 2;
|
||||
}
|
||||
if (dist_parse == DS4_DIST_CLI_MATCHED) continue;
|
||||
|
||||
if (!strcmp(arg, "-m") || !strcmp(arg, "--model")) {
|
||||
opt.model_path = need_arg(&i, argc, argv, arg);
|
||||
} else if (!strcmp(arg, "-c") || !strcmp(arg, "--ctx")) {
|
||||
ctx_size = parse_int_arg(need_arg(&i, argc, argv, arg), arg);
|
||||
} else if (!strcmp(arg, "-t") || !strcmp(arg, "--threads")) {
|
||||
opt.n_threads = parse_int_arg(need_arg(&i, argc, argv, arg), arg);
|
||||
} else if (!strcmp(arg, "--cpu")) {
|
||||
opt.backend = DS4_BACKEND_CPU;
|
||||
} else if (!strcmp(arg, "--cuda")) {
|
||||
opt.backend = DS4_BACKEND_CUDA;
|
||||
} else if (!strcmp(arg, "--metal")) {
|
||||
opt.backend = DS4_BACKEND_METAL;
|
||||
} else {
|
||||
fprintf(stderr, "ds4-worker: unknown option: %s\n", arg);
|
||||
return 2;
|
||||
}
|
||||
}
|
||||
|
||||
if (opt.distributed.role != DS4_DISTRIBUTED_WORKER) {
|
||||
fprintf(stderr, "ds4-worker: --role worker is required\n");
|
||||
return 2;
|
||||
}
|
||||
if (!opt.model_path) {
|
||||
fprintf(stderr, "ds4-worker: --model is required\n");
|
||||
return 2;
|
||||
}
|
||||
|
||||
char prep_err[256] = {0};
|
||||
if (ds4_dist_prepare_engine_options(&opt.distributed, &opt,
|
||||
prep_err, sizeof(prep_err)) != 0) {
|
||||
fprintf(stderr, "ds4-worker: %s\n", prep_err);
|
||||
return 2;
|
||||
}
|
||||
|
||||
ds4_engine *engine = NULL;
|
||||
if (ds4_engine_open(&engine, &opt) != 0 || !engine) {
|
||||
fprintf(stderr, "ds4-worker: failed to open engine\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
ds4_dist_generation_options gen = {0};
|
||||
gen.ctx_size = ctx_size;
|
||||
int rc = ds4_dist_run(engine, &opt.distributed, &gen);
|
||||
ds4_engine_close(engine);
|
||||
return rc;
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=8960c5ba5ee9db30ba838304373aa4dbec9f7cbd
|
||||
IK_LLAMA_VERSION?=3f40e73c367ad9f0c1b1819f28c7348c26aa340d
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=751ebd17a58a8a513994509214373bb9e6a3d66c
|
||||
LLAMA_VERSION?=5dcb71166686799f0d873eab7386234302d05ecf
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -2204,7 +2204,15 @@ public:
|
||||
// content element — attaching to both would duplicate the first
|
||||
// token since oaicompat_msg_diffs is the same for both.
|
||||
json first_res_json = first_result->to_json();
|
||||
if (first_res_json.is_array()) {
|
||||
// Upstream llama.cpp (ggml-org/llama.cpp#23884) now emits an initial
|
||||
// "begin" partial whose to_json() returns null, used only to signal the
|
||||
// HTTP layer to flush 200 status headers before any token. gRPC has no
|
||||
// such concept, so there is nothing to emit — the real tokens arrive in
|
||||
// the loop below. Feeding this null into build_reply_from_json would
|
||||
// throw (uncaught) and surface as a generic RPC error.
|
||||
if (first_res_json.is_null()) {
|
||||
// skip the begin-of-stream marker
|
||||
} else if (first_res_json.is_array()) {
|
||||
for (const auto & res : first_res_json) {
|
||||
auto reply = build_reply_from_json(res, first_result.get());
|
||||
// Skip chat deltas for role-init elements (have "role" in
|
||||
@@ -2234,7 +2242,10 @@ public:
|
||||
}
|
||||
|
||||
json res_json = result->to_json();
|
||||
if (res_json.is_array()) {
|
||||
if (res_json.is_null()) {
|
||||
// begin-of-stream marker (see note above) — nothing to emit
|
||||
continue;
|
||||
} else if (res_json.is_array()) {
|
||||
for (const auto & res : res_json) {
|
||||
auto reply = build_reply_from_json(res, result.get());
|
||||
bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
|
||||
|
||||
@@ -192,6 +192,61 @@ var _ = Describe("Forward", func() {
|
||||
Expect(<-gotAuth).To(Equal("Bearer sk-real"), "caller-supplied Basic header must be replaced")
|
||||
})
|
||||
|
||||
It("refuses to follow upstream redirects and never leaks the key to the redirect target", func() {
|
||||
// A 3xx from the configured upstream means misconfiguration or a
|
||||
// hijacked/spoofed host. Following it would replay the request —
|
||||
// and the injected API key — to the Location host. Anthropic's
|
||||
// x-api-key is NOT stripped by Go on cross-host redirects, so this
|
||||
// would be a credential leak. The proxy must refuse the redirect.
|
||||
sinkHit := make(chan string, 1)
|
||||
sink := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sinkHit <- r.Header.Get("x-api-key")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer sink.Close()
|
||||
|
||||
redirector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, sink.URL, http.StatusFound)
|
||||
}))
|
||||
defer redirector.Close()
|
||||
|
||||
GinkgoT().Setenv("CLOUD_PROXY_REDIRECT_KEY", "ant-secret")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: redirector.URL,
|
||||
Mode: modePassthrough,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_REDIRECT_KEY",
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-no-redirect"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/messages",
|
||||
Method: "POST",
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
|
||||
// Drain the stream; a refused redirect surfaces as a non-EOF error.
|
||||
var streamErr error
|
||||
for {
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
streamErr = err
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(streamErr).To(HaveOccurred(), "refused redirect must surface as an error")
|
||||
Expect(sinkHit).NotTo(Receive(), "the redirect target must never be contacted")
|
||||
})
|
||||
|
||||
It("handles concurrent calls without interference", func() {
|
||||
// CloudProxy explicitly omits base.SingleThread — independent
|
||||
// Forward streams must not block each other or leak state.
|
||||
|
||||
@@ -11,9 +11,11 @@ import (
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// Mirror of core/config.Proxy{Mode,Provider}* — backends don't
|
||||
@@ -48,10 +50,15 @@ type proxyConfig struct {
|
||||
}
|
||||
|
||||
func NewCloudProxy() *CloudProxy {
|
||||
// No Client-level Timeout — that would bound streaming SSE
|
||||
// responses too, which can legitimately last minutes. Per-request
|
||||
// deadlines come from the gRPC stream context.
|
||||
return &CloudProxy{client: &http.Client{}}
|
||||
// httpclient.New refuses redirects outright: the proxy talks to a
|
||||
// single configured upstream API (OpenAI/Anthropic/...) that answers
|
||||
// directly, so a 3xx means misconfiguration, a hijacked upstream, or
|
||||
// DNS trickery — never normal operation. Following it would replay the
|
||||
// request, including the operator's x-api-key (which Go does NOT strip
|
||||
// on cross-host redirects), to an unvetted host and leak the key
|
||||
// (GHSA-3mj3-57v2-4636). It also imposes no body deadline, so streaming
|
||||
// SSE responses that legitimately last minutes are not truncated.
|
||||
return &CloudProxy{client: httpclient.New()}
|
||||
}
|
||||
|
||||
func (c *CloudProxy) Load(opts *pb.ModelOptions) error {
|
||||
@@ -426,4 +433,3 @@ func isHopByHopHeader(name string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
5
backend/go/crispasr/.gitignore
vendored
Normal file
5
backend/go/crispasr/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
sources
|
||||
build*
|
||||
libgocrispasr*.so
|
||||
crispasr
|
||||
package
|
||||
30
backend/go/crispasr/CMakeLists.txt
Normal file
30
backend/go/crispasr/CMakeLists.txt
Normal file
@@ -0,0 +1,30 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
project(gocrispasr LANGUAGES C CXX)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
add_subdirectory(./sources/CrispASR)
|
||||
|
||||
add_library(gocrispasr MODULE cpp/crispasr_shim.cpp)
|
||||
target_include_directories(gocrispasr PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sources/CrispASR/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sources/CrispASR/ggml/include)
|
||||
# Link the same backend set as crispasr-cli (examples/cli/CMakeLists.txt) so
|
||||
# the session API can dispatch to every compiled-in architecture, not just
|
||||
# whisper. crispasr is the referencer; the backend static libs supply the
|
||||
# per-architecture symbols; ggml is the math/runtime base.
|
||||
target_link_libraries(gocrispasr PRIVATE
|
||||
crispasr
|
||||
parakeet canary canary_ctc cohere granite_speech granite_nle
|
||||
voxtral voxtral4b qwen3_asr qwen3_tts orpheus chatterbox indextts
|
||||
kokoro voxcpm2_tts m2m100 t5_translate wav2vec2-ggml vibevoice
|
||||
silero-lid pyannote-seg funasr paraformer sensevoice
|
||||
crisp_audio
|
||||
ggml)
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
|
||||
target_link_libraries(gocrispasr PRIVATE stdc++fs)
|
||||
endif()
|
||||
|
||||
set_property(TARGET gocrispasr PROPERTY CXX_STANDARD 17)
|
||||
set_target_properties(gocrispasr PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||
132
backend/go/crispasr/Makefile
Normal file
132
backend/go/crispasr/Makefile
Normal file
@@ -0,0 +1,132 @@
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# CrispASR version (release tag)
|
||||
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
||||
CRISPASR_VERSION?=05e60432bcb5bc2113f8c395a41e86497c11504a
|
||||
SO_TARGET?=libgocrispasr.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
# Keep the build lean: no tests/examples/server/SDL2/curl/ffmpeg (the FROM scratch
|
||||
# image cannot satisfy those runtime deps). All ASR/TTS model backends stay enabled.
|
||||
CMAKE_ARGS+=-DCRISPASR_BUILD_TESTS=OFF -DCRISPASR_BUILD_EXAMPLES=OFF -DCRISPASR_BUILD_SERVER=OFF
|
||||
CMAKE_ARGS+=-DCRISPASR_SDL2=OFF -DCRISPASR_CURL=OFF -DCRISPASR_FFMPEG=OFF
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DGGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
ifneq ($(BUILD_TYPE),metal)
|
||||
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||
else
|
||||
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
-DGGML_SYCL_F16=ON
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
endif
|
||||
|
||||
sources/CrispASR:
|
||||
mkdir -p sources/CrispASR
|
||||
cd sources/CrispASR && \
|
||||
git init && \
|
||||
git remote add origin $(CRISPASR_REPO) && \
|
||||
git fetch origin && \
|
||||
git checkout $(CRISPASR_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
# CrispASR's src/CMakeLists.txt locates its vendored llama.cpp
|
||||
# (crispasr-llama-core, used by the chat C-ABI) via ${CMAKE_SOURCE_DIR},
|
||||
# which assumes CrispASR is the top-level CMake project. We add_subdirectory
|
||||
# it, so ${CMAKE_SOURCE_DIR} is THIS backend dir and the talk-llama sources
|
||||
# aren't found. Rewrite to ${PROJECT_SOURCE_DIR} (the crispasr project root),
|
||||
# which is correct both standalone and as a subproject. Idempotent.
|
||||
sed -i 's#\$${CMAKE_SOURCE_DIR}/examples/talk-llama#\$${PROJECT_SOURCE_DIR}/examples/talk-llama#' sources/CrispASR/src/CMakeLists.txt
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
VARIANT_TARGETS = libgocrispasr-avx.so libgocrispasr-avx2.so libgocrispasr-avx512.so libgocrispasr-fallback.so
|
||||
else
|
||||
VARIANT_TARGETS = libgocrispasr-fallback.so
|
||||
endif
|
||||
|
||||
crispasr: main.go gocrispasr.go $(VARIANT_TARGETS)
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o crispasr ./
|
||||
|
||||
package: crispasr
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
clean: purge
|
||||
rm -rf libgocrispasr*.so package sources/CrispASR crispasr
|
||||
|
||||
purge:
|
||||
rm -rf build*
|
||||
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
libgocrispasr-avx.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx${RESET})
|
||||
SO_TARGET=libgocrispasr-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-avx2.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx2${RESET})
|
||||
SO_TARGET=libgocrispasr-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-avx512.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx512${RESET})
|
||||
SO_TARGET=libgocrispasr-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
endif
|
||||
|
||||
libgocrispasr-fallback.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:fallback${RESET})
|
||||
SO_TARGET=libgocrispasr-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-custom: CMakeLists.txt cpp/crispasr_shim.cpp cpp/crispasr_shim.h
|
||||
mkdir -p build-$(SO_TARGET) && \
|
||||
cd build-$(SO_TARGET) && \
|
||||
cmake .. $(CMAKE_ARGS) && \
|
||||
cmake --build . --config Release -j$(JOBS) && \
|
||||
cd .. && \
|
||||
mv build-$(SO_TARGET)/libgocrispasr.so ./$(SO_TARGET)
|
||||
|
||||
test: crispasr
|
||||
CGO_ENABLED=0 $(GOCMD) test -v ./...
|
||||
|
||||
all: crispasr package
|
||||
253
backend/go/crispasr/cpp/crispasr_shim.cpp
Normal file
253
backend/go/crispasr/cpp/crispasr_shim.cpp
Normal file
@@ -0,0 +1,253 @@
|
||||
#include "crispasr_shim.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "crispasr.h"
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
// Opaque session types. crispasr.h declares `struct crispasr_session;` but not
|
||||
// the result type nor the open/transcribe/result accessors — those are
|
||||
// CA_EXPORT extern "C" symbols in src/crispasr_c_api.cpp, so we forward-declare
|
||||
// exactly the ones we use. Signatures verified against
|
||||
// sources/CrispASR/src/crispasr_c_api.cpp.
|
||||
struct crispasr_session_result;
|
||||
extern "C" {
|
||||
crispasr_session *crispasr_session_open(const char *model_path, int n_threads);
|
||||
crispasr_session *crispasr_session_open_explicit(const char *model_path,
|
||||
const char *backend_name,
|
||||
int n_threads);
|
||||
int crispasr_session_set_codec_path(crispasr_session *s, const char *path);
|
||||
void crispasr_session_close(crispasr_session *s);
|
||||
const char *crispasr_session_backend(crispasr_session *s);
|
||||
int crispasr_session_set_translate(crispasr_session *s, int enable);
|
||||
crispasr_session_result *crispasr_session_transcribe_lang(
|
||||
crispasr_session *s, const float *pcm, int n_samples, const char *language);
|
||||
int crispasr_session_result_n_segments(crispasr_session_result *r);
|
||||
const char *crispasr_session_result_segment_text(crispasr_session_result *r,
|
||||
int i);
|
||||
int64_t crispasr_session_result_segment_t0(crispasr_session_result *r, int i);
|
||||
int64_t crispasr_session_result_segment_t1(crispasr_session_result *r, int i);
|
||||
void crispasr_session_result_free(crispasr_session_result *r);
|
||||
float *crispasr_session_synthesize(crispasr_session *s, const char *text,
|
||||
int *out_n_samples);
|
||||
void crispasr_pcm_free(float *pcm);
|
||||
int crispasr_session_set_speaker_name(crispasr_session *s, const char *name);
|
||||
int crispasr_session_set_voice(crispasr_session *s, const char *path,
|
||||
const char *ref_text_or_null);
|
||||
}
|
||||
|
||||
static crispasr_session *g_session = nullptr;
|
||||
static crispasr_session_result *g_result = nullptr;
|
||||
|
||||
static struct whisper_vad_context *vctx;
|
||||
static std::vector<float> flat_segs;
|
||||
|
||||
static std::atomic<int> g_abort{0};
|
||||
|
||||
extern "C" void set_abort(int v) {
|
||||
g_abort.store(v, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
static void ggml_log_cb(enum ggml_log_level level, const char *log,
|
||||
void *data) {
|
||||
const char *level_str;
|
||||
|
||||
if (!log) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (level) {
|
||||
case GGML_LOG_LEVEL_DEBUG:
|
||||
level_str = "DEBUG";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_INFO:
|
||||
level_str = "INFO";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_WARN:
|
||||
level_str = "WARN";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_ERROR:
|
||||
level_str = "ERROR";
|
||||
break;
|
||||
default: /* Potential future-proofing */
|
||||
level_str = "?????";
|
||||
break;
|
||||
}
|
||||
|
||||
fprintf(stderr, "[%-5s] ", level_str);
|
||||
fputs(log, stderr);
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
int load_model(const char *const model_path, int threads,
|
||||
const char *backend_name) {
|
||||
whisper_log_set(ggml_log_cb, nullptr);
|
||||
ggml_backend_load_all();
|
||||
|
||||
if (backend_name && *backend_name) {
|
||||
g_session =
|
||||
crispasr_session_open_explicit(model_path, backend_name, threads);
|
||||
} else {
|
||||
g_session = crispasr_session_open(model_path, threads);
|
||||
}
|
||||
if (g_session == nullptr) {
|
||||
fprintf(stderr, "error: failed to open CrispASR session for model\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "info: CrispASR backend selected: %s\n",
|
||||
crispasr_session_backend(g_session));
|
||||
return 0;
|
||||
}
|
||||
|
||||
// set_codec_path forwards a companion file (qwen3-tts codec, orpheus SNAC,
|
||||
// chatterbox s3gen, or mimo-asr tokenizer) to the active session. Returns 0 on
|
||||
// success or when the active backend needs no companion, negative on failure,
|
||||
// and -1 when no session is open.
|
||||
int set_codec_path(const char *path) {
|
||||
return g_session ? crispasr_session_set_codec_path(g_session, path) : -1;
|
||||
}
|
||||
|
||||
int load_model_vad(const char *const model_path) {
|
||||
whisper_log_set(ggml_log_cb, nullptr);
|
||||
ggml_backend_load_all();
|
||||
|
||||
struct whisper_vad_context_params vcparams =
|
||||
whisper_vad_default_context_params();
|
||||
|
||||
// XXX: Overridden to false in upstream due to performance?
|
||||
// vcparams.use_gpu = true;
|
||||
|
||||
vctx = whisper_vad_init_from_file_with_params(model_path, vcparams);
|
||||
if (vctx == nullptr) {
|
||||
fprintf(stderr, "error: Failed to init model as VAD\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
|
||||
size_t *segs_out_len) {
|
||||
if (!whisper_vad_detect_speech(vctx, pcmf32, pcmf32_len)) {
|
||||
fprintf(stderr, "error: failed to detect speech\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct whisper_vad_params params = whisper_vad_default_params();
|
||||
struct whisper_vad_segments *segs =
|
||||
whisper_vad_segments_from_probs(vctx, params);
|
||||
size_t segn = whisper_vad_segments_n_segments(segs);
|
||||
|
||||
// fprintf(stderr, "Got segments %zd\n", segn);
|
||||
|
||||
flat_segs.clear();
|
||||
|
||||
for (int i = 0; i < segn; i++) {
|
||||
flat_segs.push_back(whisper_vad_segments_get_segment_t0(segs, i));
|
||||
flat_segs.push_back(whisper_vad_segments_get_segment_t1(segs, i));
|
||||
}
|
||||
|
||||
// fprintf(stderr, "setting out variables: %p=%p -> %p, %p=%zx -> %zx\n",
|
||||
// segs_out, *segs_out, flat_segs.data(), segs_out_len, *segs_out_len,
|
||||
// flat_segs.size());
|
||||
*segs_out = flat_segs.data();
|
||||
*segs_out_len = flat_segs.size();
|
||||
|
||||
// fprintf(stderr, "freeing segs\n");
|
||||
whisper_vad_free_segments(segs);
|
||||
|
||||
// fprintf(stderr, "returning\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// threads, diarize and prompt are accepted for Go-side API parity but unused
|
||||
// in Phase 1: the thread count is fixed at session open, and diarization and
|
||||
// the initial prompt are separate CrispASR features not yet wired through the
|
||||
// session ASR path.
|
||||
int transcribe(uint32_t threads, char *lang, bool translate, bool diarize,
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
|
||||
char *prompt) {
|
||||
(void)threads;
|
||||
(void)diarize;
|
||||
(void)prompt;
|
||||
|
||||
if (!g_session) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Reset stale abort flag from any prior cancelled call. set_abort remains
|
||||
// best-effort: the session transcribe call is blocking and exposes no abort
|
||||
// hook, so a mid-decode abort cannot interrupt it.
|
||||
g_abort.store(0, std::memory_order_relaxed);
|
||||
|
||||
crispasr_session_set_translate(g_session, translate ? 1 : 0);
|
||||
|
||||
if (g_result) {
|
||||
crispasr_session_result_free(g_result);
|
||||
g_result = nullptr;
|
||||
}
|
||||
|
||||
const char *language = (lang && *lang) ? lang : nullptr;
|
||||
g_result = crispasr_session_transcribe_lang(g_session, pcmf32, (int)pcmf32_len,
|
||||
language);
|
||||
if (!g_result) {
|
||||
fprintf(stderr, "error: transcription failed\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
*segs_out_len = crispasr_session_result_n_segments(g_result);
|
||||
return 0;
|
||||
}
|
||||
|
||||
const char *get_segment_text(int i) {
|
||||
if (!g_result) {
|
||||
return "";
|
||||
}
|
||||
return crispasr_session_result_segment_text(g_result, i);
|
||||
}
|
||||
|
||||
int64_t get_segment_t0(int i) {
|
||||
if (!g_result) {
|
||||
return 0;
|
||||
}
|
||||
return crispasr_session_result_segment_t0(g_result, i);
|
||||
}
|
||||
|
||||
int64_t get_segment_t1(int i) {
|
||||
if (!g_result) {
|
||||
return 0;
|
||||
}
|
||||
return crispasr_session_result_segment_t1(g_result, i);
|
||||
}
|
||||
|
||||
const char *get_backend(void) {
|
||||
return g_session ? crispasr_session_backend(g_session) : "";
|
||||
}
|
||||
|
||||
// TTS uses the already-open session (crispasr_session_open auto-detects a TTS
|
||||
// model). Output is 24 kHz mono float PCM (upstream CrispASR convention),
|
||||
// malloc'd by the C API; the caller must release it via tts_free.
|
||||
float *tts_synthesize(const char *text, int *out_n_samples) {
|
||||
if (out_n_samples) *out_n_samples = 0;
|
||||
if (!g_session || !text) return nullptr;
|
||||
return crispasr_session_synthesize(g_session, text, out_n_samples);
|
||||
}
|
||||
|
||||
void tts_free(float *pcm) {
|
||||
if (pcm) crispasr_pcm_free(pcm);
|
||||
}
|
||||
|
||||
int tts_set_voice(const char *name) {
|
||||
if (!g_session || !name || !*name) return 0;
|
||||
return crispasr_session_set_speaker_name(g_session, name);
|
||||
}
|
||||
|
||||
// tts_set_voice_file loads a voice from a file: a .gguf path selects a voice
|
||||
// pack, a .wav path with a non-empty ref_text performs zero-shot voice cloning
|
||||
// (the C API returns -2 when ref_text is required but missing). Returns -1 when
|
||||
// no session is open or path is null.
|
||||
int tts_set_voice_file(const char *path, const char *ref_text) {
|
||||
if (!g_session || !path) return -1;
|
||||
const char *ref = (ref_text && *ref_text) ? ref_text : nullptr;
|
||||
return crispasr_session_set_voice(g_session, path, ref);
|
||||
}
|
||||
23
backend/go/crispasr/cpp/crispasr_shim.h
Normal file
23
backend/go/crispasr/cpp/crispasr_shim.h
Normal file
@@ -0,0 +1,23 @@
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
extern "C" {
|
||||
int load_model(const char *const model_path, int threads,
|
||||
const char *backend_name);
|
||||
int set_codec_path(const char *path);
|
||||
int load_model_vad(const char *const model_path);
|
||||
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
|
||||
size_t *segs_out_len);
|
||||
int transcribe(uint32_t threads, char *lang, bool translate, bool diarize,
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
|
||||
char *prompt);
|
||||
const char *get_segment_text(int i);
|
||||
int64_t get_segment_t0(int i);
|
||||
int64_t get_segment_t1(int i);
|
||||
const char *get_backend(void);
|
||||
void set_abort(int v);
|
||||
float *tts_synthesize(const char *text, int *out_n_samples); // 24kHz mono float, malloc'd; NULL on failure
|
||||
void tts_free(float *pcm);
|
||||
int tts_set_voice(const char *name); // best-effort speaker selection; 0 ok
|
||||
int tts_set_voice_file(const char *path, const char *ref_text); // load voice pack (.gguf) or zero-shot clone (.wav + ref_text)
|
||||
}
|
||||
497
backend/go/crispasr/gocrispasr.go
Normal file
497
backend/go/crispasr/gocrispasr.go
Normal file
@@ -0,0 +1,497 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var (
|
||||
CppLoadModel func(modelPath string, threads int, backendName string) int
|
||||
CppSetCodecPath func(path string) int
|
||||
CppLoadModelVAD func(modelPath string) int
|
||||
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
|
||||
CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer, prompt string) int
|
||||
CppGetSegmentText func(i int) string
|
||||
CppGetSegmentStart func(i int) int64
|
||||
CppGetSegmentEnd func(i int) int64
|
||||
CppGetBackend func() string
|
||||
CppSetAbort func(v int)
|
||||
CppTTSSynthesize func(text string, outNSamples unsafe.Pointer) uintptr
|
||||
CppTTSFree func(ptr uintptr)
|
||||
CppTTSSetVoice func(name string) int
|
||||
CppTTSSetVoiceFile func(path string, refText string) int
|
||||
)
|
||||
|
||||
type CrispASR struct {
|
||||
base.SingleThread
|
||||
}
|
||||
|
||||
// splitOption splits a "prefix:value" model option into its key and value,
|
||||
// matching the convention used by other backends (see sherpa-onnx). It returns
|
||||
// ok=false when the option carries no ':' separator.
|
||||
func splitOption(oo string) (key, value string, ok bool) {
|
||||
parts := strings.SplitN(oo, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", "", false
|
||||
}
|
||||
return parts[0], parts[1], true
|
||||
}
|
||||
|
||||
func (w *CrispASR) Load(opts *pb.ModelOptions) error {
|
||||
vadOnly := false
|
||||
backendName := ""
|
||||
codecPath := ""
|
||||
speakerName := ""
|
||||
voicePath := ""
|
||||
voiceRefText := ""
|
||||
|
||||
for _, oo := range opts.Options {
|
||||
if oo == "vad_only" {
|
||||
vadOnly = true
|
||||
continue
|
||||
}
|
||||
switch key, value, ok := splitOption(oo); {
|
||||
case ok && key == "backend":
|
||||
backendName = value
|
||||
case ok && key == "codec":
|
||||
codecPath = value
|
||||
case ok && key == "speaker":
|
||||
speakerName = value
|
||||
case ok && key == "voice":
|
||||
voicePath = value
|
||||
case ok && key == "voice_text":
|
||||
voiceRefText = value
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||
}
|
||||
}
|
||||
|
||||
if vadOnly {
|
||||
if ret := CppLoadModelVAD(opts.ModelFile); ret != 0 {
|
||||
return fmt.Errorf("Failed to load CrispASR VAD model")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve a relative companion path against the model directory so a config
|
||||
// can reference a sibling codec/tokenizer file by name alone.
|
||||
if codecPath != "" && !filepath.IsAbs(codecPath) {
|
||||
codecPath = filepath.Join(filepath.Dir(opts.ModelFile), codecPath)
|
||||
}
|
||||
|
||||
// A voice file (.gguf pack or .wav prompt) is resolved against the model
|
||||
// directory just like the codec, so a config can reference a sibling file.
|
||||
if voicePath != "" && !filepath.IsAbs(voicePath) {
|
||||
voicePath = filepath.Join(filepath.Dir(opts.ModelFile), voicePath)
|
||||
}
|
||||
|
||||
if ret := CppLoadModel(opts.ModelFile, int(opts.Threads), backendName); ret != 0 {
|
||||
return fmt.Errorf("Failed to load CrispASR transcription model")
|
||||
}
|
||||
|
||||
// Load the companion file (codec/tokenizer/s3gen) after the session is open.
|
||||
// rc==0 means success or "not applicable" for the active backend; only a
|
||||
// negative code is fatal.
|
||||
if codecPath != "" {
|
||||
if rc := CppSetCodecPath(codecPath); rc < 0 {
|
||||
return fmt.Errorf("crispasr: failed to load companion file %q (rc=%d)", codecPath, rc)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "CrispASR companion file loaded: %s\n", codecPath)
|
||||
}
|
||||
|
||||
// Apply the Load-time default voice. A baked speaker (speaker:) is selected
|
||||
// by name and is best-effort: a backend that can't honor it is logged, not
|
||||
// fatal. A voice file (voice:) is a hard requirement once configured, so a
|
||||
// negative rc fails Load.
|
||||
if speakerName != "" {
|
||||
if rc := CppTTSSetVoice(speakerName); rc != 0 {
|
||||
fmt.Fprintf(os.Stderr, "crispasr: speaker %q not applied (rc=%d)\n", speakerName, rc)
|
||||
}
|
||||
}
|
||||
if voicePath != "" {
|
||||
if rc := CppTTSSetVoiceFile(voicePath, voiceRefText); rc < 0 {
|
||||
return fmt.Errorf("crispasr: failed to load voice %q (rc=%d)", voicePath, rc)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "CrispASR voice loaded: %s\n", voicePath)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "CrispASR backend selected: %s\n", CppGetBackend())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *CrispASR) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
|
||||
audio := req.Audio
|
||||
// We expect 0xdeadbeef to be overwritten and if we see it in a stack trace we know it wasn't
|
||||
segsPtr, segsLen := uintptr(0xdeadbeef), uintptr(0xdeadbeef)
|
||||
segsPtrPtr, segsLenPtr := unsafe.Pointer(&segsPtr), unsafe.Pointer(&segsLen)
|
||||
|
||||
if ret := CppVAD(audio, uintptr(len(audio)), segsPtrPtr, segsLenPtr); ret != 0 {
|
||||
return pb.VADResponse{}, fmt.Errorf("Failed VAD")
|
||||
}
|
||||
|
||||
// Happens when CPP vector has not had any elements pushed to it
|
||||
if segsPtr == 0 {
|
||||
return pb.VADResponse{
|
||||
Segments: []*pb.VADSegment{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// unsafeptr warning is caused by segsPtr being on the stack and therefor being subject to stack copying AFAICT
|
||||
// however the stack shouldn't have grown between setting segsPtr and now, also the memory pointed to is allocated by C++
|
||||
segs := unsafe.Slice((*float32)(unsafe.Pointer(segsPtr)), segsLen) //nolint:govet // segsPtr addresses C++-owned heap memory passed back through the cgo-free purego boundary; the uintptr->Pointer round-trip is intentional and the buffer outlives this read.
|
||||
|
||||
vadSegments := []*pb.VADSegment{}
|
||||
for i := range len(segs) >> 1 {
|
||||
s := segs[2*i] / 100
|
||||
t := segs[2*i+1] / 100
|
||||
vadSegments = append(vadSegments, &pb.VADSegment{
|
||||
Start: s,
|
||||
End: t,
|
||||
})
|
||||
}
|
||||
|
||||
return pb.VADResponse{
|
||||
Segments: vadSegments,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *CrispASR) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "crispasr")
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
convertedPath := filepath.Join(dir, "converted.wav")
|
||||
|
||||
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
|
||||
// Watcher: flips the C-side abort flag when ctx is cancelled. The
|
||||
// goroutine is joined synchronously (close(done) signals it to exit,
|
||||
// wg.Wait() blocks until it has) so a late CppSetAbort(1) cannot fire
|
||||
// after the function returns and corrupt the next transcription call.
|
||||
done := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
CppSetAbort(1)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||
if ret == 2 {
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if ret != 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
segments := []*pb.TranscriptSegment{}
|
||||
text := ""
|
||||
for i := range int(segsLen) {
|
||||
// segment start/end conversion factor taken from https://github.com/ggml-org/whisper.cpp/blob/master/examples/cli/cli.cpp#L895
|
||||
s := CppGetSegmentStart(i) * (10000000)
|
||||
t := CppGetSegmentEnd(i) * (10000000)
|
||||
// The session result can emit bytes that aren't valid UTF-8 (e.g. a
|
||||
// multibyte codepoint split across token boundaries); protobuf string
|
||||
// fields reject those at marshal time. Scrub before the value escapes
|
||||
// cgo. The session result is segment+word based and exposes no token
|
||||
// IDs, so Tokens is left empty.
|
||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
||||
|
||||
segment := &pb.TranscriptSegment{
|
||||
Id: int32(i),
|
||||
Text: txt,
|
||||
Start: s, End: t,
|
||||
}
|
||||
|
||||
segments = append(segments, segment)
|
||||
|
||||
text += " " + strings.TrimSpace(txt)
|
||||
}
|
||||
|
||||
return pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: strings.TrimSpace(text),
|
||||
Language: opts.Language,
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AudioTranscriptionStream runs the session transcribe to completion and then
|
||||
// emits one delta per non-empty segment, followed by a final TranscriptResult.
|
||||
// Progressive/real-time streaming isn't available via the session API (there
|
||||
// is no per-decode callback), so deltas are emitted per-segment after the
|
||||
// blocking decode returns rather than as segments are produced. The offline
|
||||
// AudioTranscription is unchanged; both paths share the session and the
|
||||
// SingleThread concurrency model.
|
||||
func (w *CrispASR) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
defer close(results)
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "crispasr")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
convertedPath := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
|
||||
// Same abort-watcher pattern as AudioTranscription. Joined synchronously
|
||||
// so a late CppSetAbort(1) cannot fire after this function returns.
|
||||
// Best-effort only: the session transcribe is blocking with no abort hook.
|
||||
done := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
CppSetAbort(1)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||
if ret == 2 {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
// Walk the segments once: emit a delta per non-empty segment and build the
|
||||
// final TranscriptResult.Segments alongside. The first delta has no leading
|
||||
// space and subsequent ones are prefixed with a single space, so
|
||||
// concat(deltas) == final.Text exactly, matching the e2e contract.
|
||||
segments := []*pb.TranscriptSegment{}
|
||||
var assembled strings.Builder
|
||||
for i := range int(segsLen) {
|
||||
s := CppGetSegmentStart(i) * 10000000
|
||||
t := CppGetSegmentEnd(i) * 10000000
|
||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
||||
segments = append(segments, &pb.TranscriptSegment{
|
||||
Id: int32(i),
|
||||
Text: txt,
|
||||
Start: s, End: t,
|
||||
})
|
||||
|
||||
trimmed := strings.TrimSpace(txt)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
var delta string
|
||||
if assembled.Len() == 0 {
|
||||
delta = trimmed
|
||||
} else {
|
||||
delta = " " + trimmed
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{Delta: delta}
|
||||
assembled.WriteString(delta)
|
||||
}
|
||||
|
||||
final := &pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: assembled.String(),
|
||||
Language: opts.Language,
|
||||
Duration: duration,
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{FinalResult: final}
|
||||
return nil
|
||||
}
|
||||
|
||||
// synthesize returns 24 kHz mono float32 PCM for text via the open session.
|
||||
func (w *CrispASR) synthesize(text string) ([]float32, error) {
|
||||
if text == "" {
|
||||
return nil, fmt.Errorf("crispasr: TTS requires non-empty text")
|
||||
}
|
||||
var n int32
|
||||
ptr := CppTTSSynthesize(text, unsafe.Pointer(&n))
|
||||
if ptr == 0 || n <= 0 {
|
||||
return nil, fmt.Errorf("crispasr: synthesis failed (the loaded model may not be a supported TTS backend, or needs extra config e.g. orpheus SNAC codec)")
|
||||
}
|
||||
defer CppTTSFree(ptr)
|
||||
src := unsafe.Slice((*float32)(unsafe.Pointer(ptr)), int(n)) //nolint:govet // ptr addresses C-allocated PCM returned across the purego boundary; copied out immediately below, before tts_free.
|
||||
out := make([]float32, int(n)) // copy out of C memory before free
|
||||
copy(out, src)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// setVoice applies a per-call speaker/voice override (best effort). CrispASR
|
||||
// returns a negative code when the active backend can't honor the name; we log
|
||||
// it rather than fail, so an unknown voice falls back to the default speaker.
|
||||
func setVoice(voice string) {
|
||||
v := strings.TrimSpace(voice)
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
if rc := CppTTSSetVoice(v); rc != 0 {
|
||||
fmt.Fprintf(os.Stderr, "crispasr: voice %q not applied by the active TTS backend (rc=%d); using default\n", v, rc)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *CrispASR) TTS(req *pb.TTSRequest) error {
|
||||
if req.Dst == "" {
|
||||
return fmt.Errorf("crispasr: TTS requires a destination path")
|
||||
}
|
||||
setVoice(req.Voice)
|
||||
pcm, err := w.synthesize(req.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWAV24k(req.Dst, pcm)
|
||||
}
|
||||
|
||||
// TTSStream is the streaming counterpart to TTS. CrispASR has no progressive
|
||||
// (native streaming) synth, so we synthesize the whole utterance, encode it to
|
||||
// a 24 kHz WAV, and emit the encoded bytes as a single chunk. The gRPC server
|
||||
// wrapper (pkg/grpc/server.go:TTSStream) ranges over the channel until it is
|
||||
// closed, so this method owns the close - mirrors vibevoice-cpp's TTSStream.
|
||||
func (w *CrispASR) TTSStream(req *pb.TTSRequest, results chan []byte) error {
|
||||
defer close(results)
|
||||
|
||||
if req.Text == "" {
|
||||
return fmt.Errorf("crispasr: TTSStream requires text")
|
||||
}
|
||||
setVoice(req.Voice)
|
||||
pcm, err := w.synthesize(req.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmp, err := os.CreateTemp("", "crispasr-tts-stream-*.wav")
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: tempfile: %w", err)
|
||||
}
|
||||
dst := tmp.Name()
|
||||
if err := tmp.Close(); err != nil {
|
||||
return fmt.Errorf("crispasr: close tempfile: %w", err)
|
||||
}
|
||||
defer func() { _ = os.Remove(dst) }()
|
||||
|
||||
if err := writeWAV24k(dst, pcm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encoded, err := os.ReadFile(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: read tempfile: %w", err)
|
||||
}
|
||||
results <- encoded
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeWAV24k writes pcm as a 24000 Hz, mono, 16-bit PCM WAV at dst.
|
||||
func writeWAV24k(dst string, pcm []float32) error {
|
||||
f, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: create %q: %w", dst, err)
|
||||
}
|
||||
|
||||
enc := wav.NewEncoder(f, 24000, 16, 1, 1)
|
||||
ints := make([]int, len(pcm))
|
||||
for i, s := range pcm {
|
||||
if s > 1 {
|
||||
s = 1
|
||||
} else if s < -1 {
|
||||
s = -1
|
||||
}
|
||||
ints[i] = int(s * 32767)
|
||||
}
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: 24000},
|
||||
Data: ints,
|
||||
SourceBitDepth: 16,
|
||||
}
|
||||
if err := enc.Write(buf); err != nil {
|
||||
_ = enc.Close()
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("crispasr: encode WAV: %w", err)
|
||||
}
|
||||
if err := enc.Close(); err != nil {
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("crispasr: finalize WAV: %w", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return fmt.Errorf("crispasr: close %q: %w", dst, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
193
backend/go/crispasr/gocrispasr_test.go
Normal file
193
backend/go/crispasr/gocrispasr_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestCrispASR(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "CrispASR Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
|
||||
// bridge without spinning up the gRPC server. Skips the current spec when the
|
||||
// shared library isn't present (e.g. running before `make backends/whisper`).
|
||||
func ensureLibLoaded() {
|
||||
libLoadOnce.Do(func() {
|
||||
libName := os.Getenv("CRISPASR_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgocrispasr-fallback.so"
|
||||
}
|
||||
if _, err := os.Stat(libName); err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&CppLoadModel, gosd, "load_model")
|
||||
purego.RegisterLibFunc(&CppSetCodecPath, gosd, "set_codec_path")
|
||||
purego.RegisterLibFunc(&CppTranscribe, gosd, "transcribe")
|
||||
purego.RegisterLibFunc(&CppGetSegmentText, gosd, "get_segment_text")
|
||||
purego.RegisterLibFunc(&CppGetSegmentStart, gosd, "get_segment_t0")
|
||||
purego.RegisterLibFunc(&CppGetSegmentEnd, gosd, "get_segment_t1")
|
||||
purego.RegisterLibFunc(&CppGetBackend, gosd, "get_backend")
|
||||
purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort")
|
||||
purego.RegisterLibFunc(&CppTTSSynthesize, gosd, "tts_synthesize")
|
||||
purego.RegisterLibFunc(&CppTTSFree, gosd, "tts_free")
|
||||
purego.RegisterLibFunc(&CppTTSSetVoice, gosd, "tts_set_voice")
|
||||
purego.RegisterLibFunc(&CppTTSSetVoiceFile, gosd, "tts_set_voice_file")
|
||||
})
|
||||
if libLoadErr != nil {
|
||||
Skip("whisper library not loadable: " + libLoadErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// fixturesOrSkip returns the model + audio paths or skips the spec if either
|
||||
// env var is unset. The test never runs in default CI — it requires a real
|
||||
// whisper model and a long audio file (~3 minutes) on disk.
|
||||
func fixturesOrSkip() (string, string) {
|
||||
modelPath := os.Getenv("CRISPASR_MODEL_PATH")
|
||||
audioPath := os.Getenv("CRISPASR_AUDIO_PATH")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set CRISPASR_MODEL_PATH and CRISPASR_AUDIO_PATH to run this spec")
|
||||
}
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
// ttsModelOrSkip returns the TTS model path or skips the spec when the env var
|
||||
// is unset. Like the transcription fixtures, this never runs in default CI — it
|
||||
// needs a real TTS model (e.g. a vibevoice GGUF) on disk.
|
||||
func ttsModelOrSkip() string {
|
||||
modelPath := os.Getenv("CRISPASR_TTS_MODEL_PATH")
|
||||
if modelPath == "" {
|
||||
Skip("set CRISPASR_TTS_MODEL_PATH to run this spec")
|
||||
}
|
||||
return modelPath
|
||||
}
|
||||
|
||||
var _ = Describe("CrispASR", func() {
|
||||
Context("AudioTranscription cancellation", func() {
|
||||
It("returns codes.Canceled on a pre-cancelled context and still succeeds afterwards", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
|
||||
// The session transcribe is blocking and exposes no abort hook, so
|
||||
// a mid-decode cancel can't interrupt it. The contract we can rely
|
||||
// on is the pre-call ctx.Err() check: a context cancelled before
|
||||
// the call must yield codes.Canceled without starting a decode.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := w.AudioTranscription(ctx, &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "expected pre-cancelled context to fail")
|
||||
st, ok := status.FromError(err)
|
||||
Expect(ok).To(BeTrue(), "expected gRPC status error, got %v", err)
|
||||
Expect(st.Code()).To(Equal(codes.Canceled), "expected codes.Canceled, got %v", err)
|
||||
|
||||
// Subsequent transcription must succeed — proves g_abort reset.
|
||||
res, err := w.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred(), "post-cancel transcription failed")
|
||||
Expect(res.Text).ToNot(BeEmpty(), "post-cancel transcription returned empty text")
|
||||
})
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("emits multiple deltas progressively for a multi-segment clip", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
|
||||
results := make(chan *pb.TranscriptStreamResponse, 64)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- w.AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
Stream: true,
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var deltas []string
|
||||
var assembled strings.Builder
|
||||
var finalText string
|
||||
var finalSegmentCount int
|
||||
for chunk := range results {
|
||||
if d := chunk.GetDelta(); d != "" {
|
||||
deltas = append(deltas, d)
|
||||
assembled.WriteString(d)
|
||||
}
|
||||
if final := chunk.GetFinalResult(); final != nil {
|
||||
finalText = final.GetText()
|
||||
finalSegmentCount = len(final.GetSegments())
|
||||
}
|
||||
}
|
||||
Expect(<-done).ToNot(HaveOccurred())
|
||||
|
||||
// One delta per non-empty segment is emitted after the blocking
|
||||
// decode returns (the session API has no per-decode callback), so a
|
||||
// multi-segment clip MUST produce >=2 delta events, and
|
||||
// concat(deltas) MUST equal final.Text exactly.
|
||||
Expect(len(deltas)).To(BeNumerically(">=", 2),
|
||||
"expected multiple deltas from a multi-segment clip, got %d (assembled=%q)",
|
||||
len(deltas), assembled.String())
|
||||
Expect(finalSegmentCount).To(BeNumerically(">=", 2),
|
||||
"expected final to carry multiple segments")
|
||||
Expect(assembled.String()).To(Equal(finalText),
|
||||
"concat(deltas) must equal final.Text")
|
||||
})
|
||||
})
|
||||
|
||||
Context("TTS", func() {
|
||||
It("synthesizes a non-empty WAV", func() {
|
||||
ttsModel := ttsModelOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: ttsModel})).To(Succeed())
|
||||
|
||||
dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
|
||||
Expect(w.TTS(&pb.TTSRequest{Text: "Hello from CrispASR.", Dst: dst})).To(Succeed())
|
||||
|
||||
info, err := os.Stat(dst)
|
||||
Expect(err).ToNot(HaveOccurred(), "synthesized WAV should exist at %q", dst)
|
||||
// A real 24 kHz mono WAV is a 44-byte header plus samples; anything
|
||||
// this small would mean an empty/failed synth.
|
||||
Expect(info.Size()).To(BeNumerically(">", 1024),
|
||||
"expected a non-trivial WAV, got %d bytes", info.Size())
|
||||
})
|
||||
})
|
||||
})
|
||||
58
backend/go/crispasr/main.go
Normal file
58
backend/go/crispasr/main.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("CRISPASR_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgocrispasr-fallback.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppLoadModel, "load_model"},
|
||||
{&CppSetCodecPath, "set_codec_path"},
|
||||
{&CppLoadModelVAD, "load_model_vad"},
|
||||
{&CppVAD, "vad"},
|
||||
{&CppTranscribe, "transcribe"},
|
||||
{&CppGetSegmentText, "get_segment_text"},
|
||||
{&CppGetSegmentStart, "get_segment_t0"},
|
||||
{&CppGetSegmentEnd, "get_segment_t1"},
|
||||
{&CppGetBackend, "get_backend"},
|
||||
{&CppSetAbort, "set_abort"},
|
||||
{&CppTTSSynthesize, "tts_synthesize"},
|
||||
{&CppTTSFree, "tts_free"},
|
||||
{&CppTTSSetVoice, "tts_set_voice"},
|
||||
{&CppTTSSetVoiceFile, "tts_set_voice_file"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &CrispASR{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
65
backend/go/crispasr/package.sh
Executable file
65
backend/go/crispasr/package.sh
Executable file
@@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
# This script is used in the final stage of the Dockerfile
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/crispasr $CURDIR/package/
|
||||
cp -fv $CURDIR/libgocrispasr-*.so $CURDIR/package/
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ $(uname -s) = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
# The GPU library packaging script will detect BUILD_TYPE and copy appropriate GPU libraries
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
52
backend/go/crispasr/run.sh
Executable file
52
backend/go/crispasr/run.sh
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
fi
|
||||
|
||||
LIBRARY="$CURDIR/libgocrispasr-fallback.so"
|
||||
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx2.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx2.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx512.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx512.so"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
export CRISPASR_LIBRARY=$LIBRARY
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/crispasr "$@"
|
||||
fi
|
||||
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/crispasr "$@"
|
||||
@@ -9,7 +9,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
# LocalVQE upstream version pin. Bump to a specific commit when picking up
|
||||
# a new release; `main` works for development but is not reproducible.
|
||||
LOCALVQE_REPO?=https://github.com/localai-org/LocalVQE
|
||||
LOCALVQE_VERSION?=72bfb4c6
|
||||
LOCALVQE_VERSION?=b0f0378a450e87c871b85689554801601ca56d98
|
||||
|
||||
# LocalVQE handles CPU feature selection internally (it ships the multiple
|
||||
# libggml-cpu-*.so variants and its loader picks the best one at runtime
|
||||
@@ -27,7 +27,8 @@ endif
|
||||
|
||||
# LocalVQE upstream supports CPU + Vulkan only. Other BUILD_TYPE values
|
||||
# fall through to the default CPU build — Vulkan is already as fast as the
|
||||
# specialised GPU paths would be on this 1.3 M-parameter model.
|
||||
# specialised GPU paths would be on these small (1.3 M–4.8 M parameter)
|
||||
# models.
|
||||
ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON -DLOCALVQE_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
|
||||
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -11,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -46,24 +46,24 @@ const (
|
||||
// through the options builder (CppOptionsNew + setters + CppNewWithOptions)
|
||||
// — the bare localvqe_new path doesn't expose backend / device selection.
|
||||
var (
|
||||
CppOptionsNew func() uintptr
|
||||
CppOptionsFree func(opts uintptr)
|
||||
CppOptionsSetModelPath func(opts uintptr, modelPath string) int32
|
||||
CppOptionsSetBackend func(opts uintptr, backend string) int32
|
||||
CppOptionsSetDevice func(opts uintptr, device int32) int32
|
||||
CppNewWithOptions func(opts uintptr) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppProcessF32 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessS16 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessFrameF32 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppProcessFrameS16 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppReset func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppSampleRate func(ctx uintptr) int32
|
||||
CppHopLength func(ctx uintptr) int32
|
||||
CppFFTSize func(ctx uintptr) int32
|
||||
CppSetNoiseGate func(ctx uintptr, enabled int32, thresholdDBFS float32) int32
|
||||
CppGetNoiseGate func(ctx uintptr, enabledOut, thresholdDBFSOut uintptr) int32
|
||||
CppOptionsNew func() uintptr
|
||||
CppOptionsFree func(opts uintptr)
|
||||
CppOptionsSetModelPath func(opts uintptr, modelPath string) int32
|
||||
CppOptionsSetBackend func(opts uintptr, backend string) int32
|
||||
CppOptionsSetDevice func(opts uintptr, device int32) int32
|
||||
CppNewWithOptions func(opts uintptr) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppProcessF32 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessS16 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessFrameF32 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppProcessFrameS16 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppReset func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppSampleRate func(ctx uintptr) int32
|
||||
CppHopLength func(ctx uintptr) int32
|
||||
CppFFTSize func(ctx uintptr) int32
|
||||
CppSetNoiseGate func(ctx uintptr, enabled int32, thresholdDBFS float32) int32
|
||||
CppGetNoiseGate func(ctx uintptr, enabledOut, thresholdDBFSOut uintptr) int32
|
||||
)
|
||||
|
||||
// LocalVQE speaks gRPC against LocalVQE's flat C ABI. The streaming
|
||||
@@ -490,11 +490,14 @@ func (v *LocalVQE) applyStreamConfig(cfg *pb.AudioTransformStreamConfig) error {
|
||||
|
||||
// ---- WAV I/O ----------------------------------------------------------
|
||||
//
|
||||
// Minimal mono PCM WAV reader/writer. Only handles the subset LocalVQE
|
||||
// cares about (mono, 16-bit signed, no extensible chunks). For broader
|
||||
// audio support the HTTP layer's `audio.NormalizeAudioFile` already
|
||||
// converts arbitrary input to a canonical WAV before we see it; this
|
||||
// reader just decodes the canonical shape.
|
||||
// Reader/writer for the mono 16-bit PCM shape LocalVQE works with. Decoding
|
||||
// goes through the shared go-audio/wav decoder (as the whisper and parakeet
|
||||
// backends do) so RIFF chunk walking is handled robustly — an 18/40-byte
|
||||
// extensible `fmt ` chunk, or JUNK/bext/LIST metadata before or after `data`
|
||||
// (e.g. ffmpeg's trailing "Lavf" tag), is skipped rather than spliced into
|
||||
// the PCM stream as an audible click. The HTTP layer normalises arbitrary
|
||||
// input to WAV before we see it, but that WAV is ffmpeg output and is not
|
||||
// guaranteed to be the canonical 44-byte layout.
|
||||
|
||||
func readMonoWAVf32(path string) ([]float32, int, error) {
|
||||
f, err := os.Open(path)
|
||||
@@ -502,35 +505,26 @@ func readMonoWAVf32(path string) ([]float32, int, error) {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
header := make([]byte, 44)
|
||||
if _, err := io.ReadFull(f, header); err != nil {
|
||||
return nil, 0, err
|
||||
|
||||
buf, err := wav.NewDecoder(f).FullPCMBuffer()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("decode WAV: %w", err)
|
||||
}
|
||||
if string(header[0:4]) != "RIFF" || string(header[8:12]) != "WAVE" {
|
||||
if buf == nil || buf.Format == nil {
|
||||
return nil, 0, fmt.Errorf("not a WAV file")
|
||||
}
|
||||
channels := binary.LittleEndian.Uint16(header[22:24])
|
||||
sampleRate := binary.LittleEndian.Uint32(header[24:28])
|
||||
bitsPerSample := binary.LittleEndian.Uint16(header[34:36])
|
||||
|
||||
if channels != 1 {
|
||||
return nil, 0, fmt.Errorf("only mono WAV supported (got %d channels)", channels)
|
||||
if buf.Format.NumChannels != 1 {
|
||||
return nil, 0, fmt.Errorf("only mono WAV supported (got %d channels)", buf.Format.NumChannels)
|
||||
}
|
||||
if bitsPerSample != 16 {
|
||||
return nil, 0, fmt.Errorf("only 16-bit PCM supported (got %d bits)", bitsPerSample)
|
||||
if buf.SourceBitDepth != 16 {
|
||||
return nil, 0, fmt.Errorf("only 16-bit PCM supported (got %d bits)", buf.SourceBitDepth)
|
||||
}
|
||||
|
||||
rest, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
if len(buf.Data) == 0 {
|
||||
return nil, 0, fmt.Errorf("WAV has no audio data")
|
||||
}
|
||||
n := len(rest) / 2
|
||||
out := make([]float32, n)
|
||||
for i := 0; i < n; i++ {
|
||||
s := int16(binary.LittleEndian.Uint16(rest[i*2 : i*2+2]))
|
||||
out[i] = float32(s) / 32768.0
|
||||
}
|
||||
return out, int(sampleRate), nil
|
||||
// AsFloat32Buffer normalises by 2^(bitDepth-1) == /32768 for 16-bit,
|
||||
// matching the model's expected [-1, 1) input range.
|
||||
return buf.AsFloat32Buffer().Data, buf.Format.SampleRate, nil
|
||||
}
|
||||
|
||||
func writeMonoWAVf32(path string, samples []float32, sampleRate int) error {
|
||||
@@ -546,13 +540,13 @@ func writeMonoWAVf32(path string, samples []float32, sampleRate int) error {
|
||||
binary.LittleEndian.PutUint32(header[4:8], 36+dataLen)
|
||||
copy(header[8:12], []byte("WAVE"))
|
||||
copy(header[12:16], []byte("fmt "))
|
||||
binary.LittleEndian.PutUint32(header[16:20], 16) // fmt chunk size
|
||||
binary.LittleEndian.PutUint16(header[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(header[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(header[16:20], 16) // fmt chunk size
|
||||
binary.LittleEndian.PutUint16(header[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(header[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(header[24:28], uint32(sampleRate))
|
||||
binary.LittleEndian.PutUint32(header[28:32], uint32(sampleRate*2)) // byte rate
|
||||
binary.LittleEndian.PutUint16(header[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(header[34:36], 16) // bits per sample
|
||||
binary.LittleEndian.PutUint16(header[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(header[34:36], 16) // bits per sample
|
||||
copy(header[36:40], []byte("data"))
|
||||
binary.LittleEndian.PutUint32(header[40:44], dataLen)
|
||||
if _, err := f.Write(header); err != nil {
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -92,6 +94,147 @@ var _ = Describe("LocalVQE-cpp", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("readMonoWAVf32 chunk parsing", func() {
|
||||
// chunk builds a word-aligned RIFF sub-chunk (id + size + body + pad).
|
||||
chunk := func(id string, body []byte) []byte {
|
||||
out := append([]byte(id), 0, 0, 0, 0)
|
||||
binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
|
||||
out = append(out, body...)
|
||||
if len(body)&1 == 1 {
|
||||
out = append(out, 0) // pad byte for odd-sized chunks
|
||||
}
|
||||
return out
|
||||
}
|
||||
// fmtBody returns a PCM `fmt ` chunk body. extra bytes simulate the
|
||||
// 18/40-byte extensible form (cbSize + extension).
|
||||
fmtBody := func(channels, bits uint16, rate uint32, extra int) []byte {
|
||||
b := make([]byte, 16+extra)
|
||||
binary.LittleEndian.PutUint16(b[0:2], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(b[2:4], channels)
|
||||
binary.LittleEndian.PutUint32(b[4:8], rate)
|
||||
binary.LittleEndian.PutUint32(b[8:12], rate*uint32(channels)*uint32(bits)/8)
|
||||
binary.LittleEndian.PutUint16(b[12:14], channels*bits/8)
|
||||
binary.LittleEndian.PutUint16(b[14:16], bits)
|
||||
if extra >= 2 {
|
||||
binary.LittleEndian.PutUint16(b[16:18], uint16(extra-2)) // cbSize
|
||||
}
|
||||
return b
|
||||
}
|
||||
// pcm encodes int16 samples little-endian.
|
||||
pcm := func(samples ...int16) []byte {
|
||||
b := make([]byte, len(samples)*2)
|
||||
for i, s := range samples {
|
||||
binary.LittleEndian.PutUint16(b[i*2:i*2+2], uint16(s))
|
||||
}
|
||||
return b
|
||||
}
|
||||
riff := func(chunks ...[]byte) []byte {
|
||||
body := []byte("WAVE")
|
||||
for _, c := range chunks {
|
||||
body = append(body, c...)
|
||||
}
|
||||
out := append([]byte("RIFF"), 0, 0, 0, 0)
|
||||
binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
|
||||
return append(out, body...)
|
||||
}
|
||||
writeWAV := func(b []byte) string {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "in.wav")
|
||||
Expect(os.WriteFile(p, b, 0o600)).To(Succeed())
|
||||
return p
|
||||
}
|
||||
// A canonical sample run with distinct values so any off-by-one /
|
||||
// misalignment shows up as wrong numbers, not just wrong length.
|
||||
samples := []int16{1000, -2000, 3000, -4000, 5000, -6000}
|
||||
expectSamples := func(got []float32) {
|
||||
Expect(got).To(HaveLen(len(samples)))
|
||||
for i, s := range samples {
|
||||
Expect(got[i]).To(BeNumerically("~", float32(s)/32768.0, 1e-6))
|
||||
}
|
||||
}
|
||||
|
||||
It("reads a canonical 44-byte WAV", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out)
|
||||
})
|
||||
|
||||
It("ignores a LIST/JUNK chunk placed before data (no leading-impulse splice)", func() {
|
||||
p := writeWAV(riff(
|
||||
chunk("fmt ", fmtBody(1, 16, 16000, 0)),
|
||||
chunk("JUNK", []byte("padding-bytes-here!")), // odd length → exercises pad
|
||||
chunk("LIST", []byte("INFOISFTLavf60.0")),
|
||||
chunk("data", pcm(samples...)),
|
||||
))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out) // not corrupted by the preceding chunks
|
||||
})
|
||||
|
||||
It("honours the data chunk size and drops a trailing metadata chunk", func() {
|
||||
p := writeWAV(riff(
|
||||
chunk("fmt ", fmtBody(1, 16, 16000, 0)),
|
||||
chunk("data", pcm(samples...)),
|
||||
chunk("LIST", []byte("INFOISFTLavf60.16.100")), // ffmpeg trailer tag
|
||||
))
|
||||
out, _, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectSamples(out) // trailing LIST bytes not decoded as PCM
|
||||
})
|
||||
|
||||
It("handles the 18-byte extensible fmt chunk", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 2)), chunk("data", pcm(samples...))))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out)
|
||||
})
|
||||
|
||||
It("rejects non-mono input", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(2, 16, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("mono"))
|
||||
})
|
||||
|
||||
It("rejects non-16-bit input", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 8, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("16-bit"))
|
||||
})
|
||||
|
||||
It("rejects a non-WAV file", func() {
|
||||
p := writeWAV([]byte("not a riff file at all"))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors when the data chunk is missing", func() {
|
||||
// fmt but no data: the decoder must fail rather than return an
|
||||
// empty (or garbage) sample slice. The exact message is the
|
||||
// decoder's, so just assert it errors.
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 0))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("round-trips through writeMonoWAVf32", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "rt.wav")
|
||||
in := []float32{0.1, -0.2, 0.3, -0.4}
|
||||
Expect(writeMonoWAVf32(p, in, 16000)).To(Succeed())
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
Expect(out).To(HaveLen(len(in)))
|
||||
for i := range in {
|
||||
Expect(out[i]).To(BeNumerically("~", in[i], 1e-4))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("model-gated integration (LOCALVQE_MODEL_PATH)", func() {
|
||||
It("load + sample rate + hop + fft", func() {
|
||||
path := modelPathOrSkip()
|
||||
|
||||
11
backend/go/parakeet-cpp/.gitignore
vendored
Normal file
11
backend/go/parakeet-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
.cache/
|
||||
sources/
|
||||
build/
|
||||
package/
|
||||
parakeet-cpp-grpc
|
||||
# build artifacts staged in-tree by the Makefile (cp from sources/) or
|
||||
# symlinked for local dev; the real sources live in parakeet.cpp upstream.
|
||||
*.so
|
||||
*.so.*
|
||||
parakeet_capi.h
|
||||
compile_commands.json
|
||||
93
backend/go/parakeet-cpp/Makefile
Normal file
93
backend/go/parakeet-cpp/Makefile
Normal file
@@ -0,0 +1,93 @@
|
||||
# parakeet-cpp backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=9edf17c3ada66e0f881dcff155492867db7ac4cf
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# whisper.cpp / ds4 / vibevoice-cpp convention.
|
||||
#
|
||||
# Local dev shortcut: if you already have an out-of-tree parakeet.cpp
|
||||
# build, you can symlink the .so + header into this directory and skip
|
||||
# the clone/cmake steps entirely, e.g.:
|
||||
#
|
||||
# ln -sf /path/to/parakeet.cpp/build-shared/libparakeet.so .
|
||||
# ln -sf /path/to/parakeet.cpp/include/parakeet_capi.h .
|
||||
# go build -o parakeet-cpp-grpc .
|
||||
#
|
||||
# That's what the L0 smoke test uses. The default target below does the
|
||||
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
|
||||
|
||||
PARAKEET_VERSION?=9edf17c3ada66e0f881dcff155492867db7ac4cf
|
||||
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
||||
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
# Build ggml statically into libparakeet.so (PIC) so the shared lib is
|
||||
# self-contained: dlopen needs no libggml*.so alongside it, only system libs
|
||||
# (libstdc++/libgomp/libc) that the runtime image already provides.
|
||||
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DPARAKEET_SHARED=ON -DPARAKEET_BUILD_CLI=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# parakeet.cpp gates its GGML backends behind PARAKEET_GGML_* options and does
|
||||
# set(GGML_CUDA ${PARAKEET_GGML_CUDA} CACHE BOOL "" FORCE), so a bare -DGGML_CUDA=ON
|
||||
# is overwritten back to OFF and the build silently falls back to CPU. Forward the
|
||||
# PARAKEET_GGML_* options instead. (openblas is not gated, so -DGGML_BLAS passes through.)
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_HIP=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_VULKAN=ON
|
||||
endif
|
||||
|
||||
.PHONY: parakeet-cpp-grpc package build clean purge test all
|
||||
|
||||
all: parakeet-cpp-grpc
|
||||
|
||||
# Clone the upstream parakeet.cpp source at the pinned commit. Directory
|
||||
# acts as the target so make only re-clones when missing. After a
|
||||
# PARAKEET_VERSION bump, run 'make purge && make' to refetch.
|
||||
sources/parakeet.cpp:
|
||||
mkdir -p sources/parakeet.cpp
|
||||
cd sources/parakeet.cpp && \
|
||||
git init -q && \
|
||||
git remote add origin $(PARAKEET_REPO) && \
|
||||
git fetch --depth 1 origin $(PARAKEET_VERSION) && \
|
||||
git checkout FETCH_HEAD && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Build the shared lib + header out-of-tree, then stage them next to the
|
||||
# Go sources so purego.Dlopen("libparakeet.so") and the cgo-less build
|
||||
# both pick them up.
|
||||
libparakeet.so: sources/parakeet.cpp
|
||||
cmake -B sources/parakeet.cpp/build-shared -S sources/parakeet.cpp $(CMAKE_ARGS)
|
||||
cmake --build sources/parakeet.cpp/build-shared --config Release -j$(JOBS)
|
||||
cp -fv sources/parakeet.cpp/build-shared/libparakeet.so* ./ 2>/dev/null || true
|
||||
cp -fv sources/parakeet.cpp/include/parakeet_capi.h ./
|
||||
|
||||
parakeet-cpp-grpc: libparakeet.so main.go goparakeetcpp.go
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o parakeet-cpp-grpc .
|
||||
|
||||
package: parakeet-cpp-grpc
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
# Test target. Smoke test is gated on PARAKEET_BACKEND_TEST_MODEL +
|
||||
# PARAKEET_BACKEND_TEST_WAV; without them the spec auto-skips.
|
||||
test:
|
||||
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
|
||||
|
||||
clean: purge
|
||||
rm -rf libparakeet.so* parakeet_capi.h package parakeet-cpp-grpc
|
||||
|
||||
purge:
|
||||
rm -rf sources/parakeet.cpp
|
||||
79
backend/go/parakeet-cpp/batcher.go
Normal file
79
backend/go/parakeet-cpp/batcher.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package main
|
||||
|
||||
import "time"
|
||||
|
||||
// batchRequest is one in-flight unary transcription waiting to be batched.
|
||||
// In production pcm/decoder are set; tag is an opaque marker used by tests.
|
||||
type batchRequest struct {
|
||||
pcm []float32
|
||||
decoder int32
|
||||
tag string
|
||||
reply chan batchReply
|
||||
}
|
||||
|
||||
// batchReply carries one per-item JSON object string (an element of the C-API's
|
||||
// JSON array) or an error back to the waiting handler goroutine.
|
||||
type batchReply struct {
|
||||
json string
|
||||
err error
|
||||
}
|
||||
|
||||
// batcher coalesces concurrent batchRequests into batched runBatch calls. A
|
||||
// single run() goroutine is the sole caller of runBatch, so runBatch (which in
|
||||
// production calls the thread-unsafe C engine) is never entered concurrently.
|
||||
type batcher struct {
|
||||
submit chan *batchRequest
|
||||
maxSize int
|
||||
maxWait time.Duration
|
||||
runBatch func(reqs []*batchRequest) // must deliver a reply to every req
|
||||
}
|
||||
|
||||
func newBatcher(maxSize int, maxWait time.Duration, runBatch func([]*batchRequest)) *batcher {
|
||||
if maxSize < 1 {
|
||||
maxSize = 1
|
||||
}
|
||||
return &batcher{
|
||||
submit: make(chan *batchRequest),
|
||||
maxSize: maxSize,
|
||||
maxWait: maxWait,
|
||||
runBatch: runBatch,
|
||||
}
|
||||
}
|
||||
|
||||
// run is the dispatcher loop: accumulate submitted requests until either maxSize
|
||||
// is reached or maxWait elapses since the first queued request, then dispatch.
|
||||
// Exits when stop is closed (draining any partially-filled batch first).
|
||||
func (b *batcher) run(stop <-chan struct{}) {
|
||||
for {
|
||||
var first *batchRequest
|
||||
select {
|
||||
case first = <-b.submit:
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
batch := []*batchRequest{first}
|
||||
|
||||
// maxSize==1 disables batching: dispatch immediately (passthrough).
|
||||
if b.maxSize == 1 {
|
||||
b.runBatch(batch)
|
||||
continue
|
||||
}
|
||||
|
||||
timer := time.NewTimer(b.maxWait)
|
||||
fill:
|
||||
for len(batch) < b.maxSize {
|
||||
select {
|
||||
case r := <-b.submit:
|
||||
batch = append(batch, r)
|
||||
case <-timer.C:
|
||||
break fill
|
||||
case <-stop:
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
return
|
||||
}
|
||||
}
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
}
|
||||
}
|
||||
108
backend/go/parakeet-cpp/batcher_test.go
Normal file
108
backend/go/parakeet-cpp/batcher_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("batcher", func() {
|
||||
echoReply := func(reqs []*batchRequest) {
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{json: r.tag}
|
||||
}
|
||||
}
|
||||
|
||||
It("coalesces concurrent submits into batches", func() {
|
||||
var mu sync.Mutex
|
||||
var sizes []int
|
||||
run := func(reqs []*batchRequest) {
|
||||
mu.Lock()
|
||||
sizes = append(sizes, len(reqs))
|
||||
mu.Unlock()
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(4, 50*time.Millisecond, run)
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
|
||||
const N = 4
|
||||
var wg sync.WaitGroup
|
||||
got := make([]string, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: string(rune('a' + i)), reply: rep}
|
||||
got[i] = (<-rep).json
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
total, maxBatch := 0, 0
|
||||
for _, s := range sizes {
|
||||
total += s
|
||||
if s > maxBatch {
|
||||
maxBatch = s
|
||||
}
|
||||
}
|
||||
Expect(total).To(Equal(N))
|
||||
Expect(maxBatch).To(BeNumerically(">=", 2), "expected at least one batch to coalesce >1 request")
|
||||
})
|
||||
|
||||
It("dispatches when max size is reached", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(2, time.Hour, run) // huge window: only size can trigger
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
for i := 0; i < 2; i++ {
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func(rep chan batchReply) { <-rep }(rep)
|
||||
}
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(2)))
|
||||
})
|
||||
|
||||
It("dispatches when the wait window elapses", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(8, 20*time.Millisecond, run) // size unreachable; window fires
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
|
||||
It("bypasses batching when max size is 1", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(1, time.Hour, run) // size 1 => immediate dispatch
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
})
|
||||
530
backend/go/parakeet-cpp/goparakeetcpp.go
Normal file
530
backend/go/parakeet-cpp/goparakeetcpp.go
Normal file
@@ -0,0 +1,530 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// purego-bound entry points from libparakeet.so. Names match
|
||||
// parakeet_capi.h exactly so a `nm libparakeet.so | grep parakeet_capi`
|
||||
// is enough to spot drift.
|
||||
//
|
||||
// Functions that return char* are declared as uintptr so we can call
|
||||
// parakeet_capi_free_string on the same pointer after copying, the
|
||||
// C-API contract is "caller owns and must free the returned buffer".
|
||||
var (
|
||||
CppAbiVersion func() int32
|
||||
CppLoad func(ggufPath string) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppTranscribePath func(ctx uintptr, wavPath string, decoder int32) uintptr
|
||||
CppTranscribePathJSON func(ctx uintptr, wavPath string, decoder int32) uintptr
|
||||
CppFreeString func(s uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
|
||||
// Batched JSON transcription: takes a concatenated float buffer of clips
|
||||
// plus their per-clip sample counts (sum(nSamples)==len(samplesConcat))
|
||||
// and returns a malloc'd char* JSON ARRAY of per-clip {"text","words",
|
||||
// "tokens"} objects (uintptr, freed via CppFreeString). purego passes the
|
||||
// Go slices as the base pointer of their backing array (kept alive for the
|
||||
// call), matching the CppStreamFeed pcm []float32 binding pattern; the C
|
||||
// side reads them as const float*/const int*.
|
||||
CppTranscribePcmBatchJSON func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32) uintptr
|
||||
|
||||
// Cache-aware streaming (RNN-T) entry points. stream_begin returns 0 for
|
||||
// non-streaming models. feed/finalize return a malloc'd char* (uintptr,
|
||||
// freed via CppFreeString); feed writes 1 to *eouOut on an <EOU>/<EOB>.
|
||||
CppStreamBegin func(ctx uintptr) uintptr
|
||||
CppStreamFeed func(s uintptr, pcm []float32, nSamples int32, eouOut unsafe.Pointer) uintptr
|
||||
CppStreamFinalize func(s uintptr) uintptr
|
||||
CppStreamFree func(s uintptr)
|
||||
)
|
||||
|
||||
// streamChunkSamples is how much 16 kHz mono PCM we hand to stream_feed per
|
||||
// call (1 s). The session buffers internally and decodes once a full
|
||||
// cache-aware encoder chunk is available, so this only bounds how often we
|
||||
// poll for newly-finalized text, not the model's actual chunk size.
|
||||
const streamChunkSamples = 16000
|
||||
|
||||
// transcriptJSON mirrors the document returned by
|
||||
// parakeet_capi_transcribe_path_json (see parakeet_capi.h):
|
||||
//
|
||||
// {"text":"...",
|
||||
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...],
|
||||
// "tokens":[{"id":123,"t":0.480,"conf":0.9100}, ...]}
|
||||
//
|
||||
// "start"/"end"/"t" are seconds; "conf" is confidence in (0,1].
|
||||
type transcriptJSON struct {
|
||||
Text string `json:"text"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
Tokens []transcriptToken `json:"tokens"`
|
||||
}
|
||||
|
||||
type transcriptWord struct {
|
||||
W string `json:"w"`
|
||||
Start float64 `json:"start"`
|
||||
End float64 `json:"end"`
|
||||
Conf float64 `json:"conf"`
|
||||
}
|
||||
|
||||
type transcriptToken struct {
|
||||
ID int32 `json:"id"`
|
||||
T float64 `json:"t"`
|
||||
Conf float64 `json:"conf"`
|
||||
}
|
||||
|
||||
// ParakeetCpp owns a single loaded parakeet_ctx. The C engine is a
|
||||
// thread-unsafe singleton (mirrors whisper.cpp / vibevoice.cpp). Rather than
|
||||
// serialize every call through base.SingleThread, we route unary
|
||||
// transcription through an in-process batcher (its sole dispatcher goroutine
|
||||
// is the only caller of the engine on that path) and guard the shared engine
|
||||
// with engineMu so a streaming session and a batched-unary dispatch never
|
||||
// touch it concurrently.
|
||||
type ParakeetCpp struct {
|
||||
base.Base
|
||||
ctxPtr uintptr
|
||||
engineMu sync.Mutex // sole guard of the one C engine (dispatcher + streaming)
|
||||
bat *batcher
|
||||
batStop chan struct{}
|
||||
}
|
||||
|
||||
// Load is the LocalAI gRPC entry point for LoadModel: it calls
|
||||
// parakeet_capi_load with the GGUF path and stashes the resulting
|
||||
// opaque context pointer for AudioTranscription.
|
||||
func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
|
||||
if opts.ModelFile == "" {
|
||||
return errors.New("parakeet-cpp: ModelFile is required")
|
||||
}
|
||||
|
||||
ctx := CppLoad(opts.ModelFile)
|
||||
if ctx == 0 {
|
||||
// No ctx to ask for last_error (the C-API's last-error buffer
|
||||
// lives on the ctx that was never returned). Surface the path
|
||||
// so the operator at least knows which load failed.
|
||||
return fmt.Errorf("parakeet-cpp: parakeet_capi_load failed for %q", opts.ModelFile)
|
||||
}
|
||||
p.ctxPtr = ctx
|
||||
|
||||
// Dynamic batching knobs (model YAML options:, key:value form). Batching is
|
||||
// OFF by default (batch_max_size:1): each request runs on its own. On GPU,
|
||||
// raising batch_max_size coalesces concurrent requests into one batched
|
||||
// engine call and improves throughput under load; leave it at 1 on CPU and
|
||||
// for low-concurrency setups, where batching only adds latency.
|
||||
maxSize := optInt(opts, "batch_max_size", 1)
|
||||
maxWaitMs := optInt(opts, "batch_max_wait_ms", 15)
|
||||
if maxWaitMs < 0 {
|
||||
maxWaitMs = 0
|
||||
}
|
||||
if CppTranscribePcmBatchJSON != nil {
|
||||
p.batStop = make(chan struct{})
|
||||
p.bat = newBatcher(maxSize, time.Duration(maxWaitMs)*time.Millisecond, p.runBatch)
|
||||
go p.bat.run(p.batStop) // dispatcher runs until Free closes batStop
|
||||
if maxSize > 1 {
|
||||
xlog.Info("parakeet-cpp: dynamic batching enabled",
|
||||
"batch_max_size", maxSize, "batch_max_wait_ms", maxWaitMs)
|
||||
} else {
|
||||
xlog.Info("parakeet-cpp: dynamic batching off (batch_max_size=1); " +
|
||||
"set batch_max_size>1 to coalesce concurrent requests on GPU")
|
||||
}
|
||||
} else {
|
||||
xlog.Info("parakeet-cpp: batched C-API not present in libparakeet.so; " +
|
||||
"batching disabled, using per-request transcription")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// optInt reads an integer model option (key:value form) from ModelOptions,
|
||||
// returning def when absent or unparseable. The options array carries the
|
||||
// model YAML's options: entries (see core/config; siblings such as
|
||||
// acestep-cpp parse the same key:value form via strings.Cut on ":").
|
||||
func optInt(opts *pb.ModelOptions, key string, def int) int {
|
||||
for _, o := range opts.GetOptions() {
|
||||
k, v, ok := strings.Cut(o, ":")
|
||||
if ok && strings.TrimSpace(k) == key {
|
||||
if n, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// runBatch is the dispatcher's batch handler and the ONLY caller of the C
|
||||
// engine on the unary path. It concatenates the batch PCM, calls the batched
|
||||
// JSON C-API under engineMu, splits the JSON array, and replies to each request.
|
||||
func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
// Observability: the actual coalesced batch size per engine call. Debug-level
|
||||
// so it stays silent in normal operation but lets operators confirm/tune batching.
|
||||
xlog.Debug("parakeet-cpp: dispatching batch", "size", len(reqs))
|
||||
nSamples := make([]int32, len(reqs))
|
||||
total := 0
|
||||
for i, r := range reqs {
|
||||
nSamples[i] = int32(len(r.pcm))
|
||||
total += len(r.pcm)
|
||||
}
|
||||
concat := make([]float32, 0, total)
|
||||
for _, r := range reqs {
|
||||
concat = append(concat, r.pcm...)
|
||||
}
|
||||
var dec int32
|
||||
if len(reqs) > 0 {
|
||||
dec = reqs[0].decoder
|
||||
}
|
||||
p.engineMu.Lock()
|
||||
cstr := CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
|
||||
p.engineMu.Unlock()
|
||||
if cstr == 0 {
|
||||
err := fmt.Errorf("parakeet-cpp: batch transcribe failed: %s", CppLastError(p.ctxPtr))
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{err: err}
|
||||
}
|
||||
return
|
||||
}
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
var docs []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(raw), &docs); err != nil || len(docs) != len(reqs) {
|
||||
e := fmt.Errorf("parakeet-cpp: batch json: got %d results for %d reqs (%v)", len(docs), len(reqs), err)
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{err: e}
|
||||
}
|
||||
return
|
||||
}
|
||||
for i, r := range reqs {
|
||||
r.reply <- batchReply{json: string(docs[i])}
|
||||
}
|
||||
}
|
||||
|
||||
// AudioTranscription decodes the wav at opts.Dst to 16 kHz mono PCM and
|
||||
// submits it to the in-process batcher, which coalesces concurrent requests
|
||||
// into a single batched engine call (parakeet_capi_transcribe_pcm_batch_json)
|
||||
// with the default decoder (decoder=0, which selects the right head per
|
||||
// architecture: transducer for tdt/rnnt/hybrid, CTC for ctc) and shapes the
|
||||
// per-word timestamps into a LocalAI TranscriptResult.
|
||||
//
|
||||
// Parakeet emits word- and token-level timestamps but no native segment
|
||||
// boundaries, so we synthesise a single whole-clip segment spanning the first
|
||||
// word start to the last word end. Word-level timings are attached only when
|
||||
// the caller opts in via timestamp_granularities=["word"] (matching the
|
||||
// OpenAI API, whose default is segment-level); token ids always populate
|
||||
// Segment.Tokens.
|
||||
//
|
||||
// translate/diarize/prompt/temperature/language/threads are not applicable to
|
||||
// parakeet and are ignored; streaming is handled by AudioTranscriptionStream
|
||||
// (L2).
|
||||
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if p.ctxPtr == 0 {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: model not loaded")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
|
||||
// Fallback when the batched C-API is unavailable: transcribe directly from
|
||||
// the file path (original behavior, no batching).
|
||||
if p.bat == nil {
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, opts.Dst, 0)
|
||||
if cstr == 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", CppLastError(p.ctxPtr))
|
||||
}
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
var doc transcriptJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
}
|
||||
|
||||
// Batched path: decode to PCM, submit to the batcher, wait for this request's
|
||||
// JSON element. The dispatcher is the sole engine caller on this path; both
|
||||
// sends honour ctx cancellation.
|
||||
pcm, _, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
rep := make(chan batchReply, 1)
|
||||
select {
|
||||
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, reply: rep}:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
var res batchReply
|
||||
select {
|
||||
case res = <-rep:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if res.err != nil {
|
||||
return pb.TranscriptResult{}, res.err
|
||||
}
|
||||
var doc transcriptJSON
|
||||
if err := json.Unmarshal([]byte(res.json), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
}
|
||||
|
||||
// transcriptResultFromDoc maps a decoded transcriptJSON to a TranscriptResult,
|
||||
// synthesising a single whole-clip segment and attaching word timings only when
|
||||
// the caller requested word granularity. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest) pb.TranscriptResult {
|
||||
text := strings.TrimSpace(doc.Text)
|
||||
words := make([]*pb.TranscriptWord, 0, len(doc.Words))
|
||||
for _, w := range doc.Words {
|
||||
words = append(words, &pb.TranscriptWord{Start: secondsToNanos(w.Start), End: secondsToNanos(w.End), Text: w.W})
|
||||
}
|
||||
tokens := make([]int32, 0, len(doc.Tokens))
|
||||
for _, t := range doc.Tokens {
|
||||
tokens = append(tokens, t.ID)
|
||||
}
|
||||
var segStart, segEnd int64
|
||||
if len(words) > 0 {
|
||||
segStart = words[0].Start
|
||||
segEnd = words[len(words)-1].End
|
||||
}
|
||||
seg := &pb.TranscriptSegment{Id: 0, Start: segStart, End: segEnd, Text: text, Tokens: tokens}
|
||||
if wordsRequested(opts.TimestampGranularities) {
|
||||
seg.Words = words
|
||||
}
|
||||
return pb.TranscriptResult{Text: text, Segments: []*pb.TranscriptSegment{seg}}
|
||||
}
|
||||
|
||||
// wordsRequested reports whether the caller asked for word-level timestamps.
|
||||
// The OpenAI transcription API gates word timings behind
|
||||
// timestamp_granularities[] containing "word" and defaults to segment-level
|
||||
// otherwise; we follow that contract.
|
||||
func wordsRequested(granularities []string) bool {
|
||||
for _, g := range granularities {
|
||||
if strings.EqualFold(strings.TrimSpace(g), "word") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// secondsToNanos converts the C-API's fractional-second timestamps into the
|
||||
// int64 nanoseconds LocalAI carries on TranscriptSegment/TranscriptWord, the
|
||||
// same nanosecond convention the whisper backend uses.
|
||||
func secondsToNanos(sec float64) int64 {
|
||||
return int64(sec * 1e9)
|
||||
}
|
||||
|
||||
// AudioTranscriptionStream drives the cache-aware streaming RNN-T over the
|
||||
// audio at opts.Dst: it decodes the file to 16 kHz mono PCM, feeds it in
|
||||
// chunks to parakeet_capi_stream_feed, and emits each newly-finalized text
|
||||
// run as a TranscriptStreamResponse delta. <EOU>/<EOB> events close the
|
||||
// current segment; a closing FinalResult carries the full transcript and the
|
||||
// per-utterance segments.
|
||||
//
|
||||
// stream_begin returns 0 for models that are not cache-aware streaming models
|
||||
// (only e.g. nvidia/parakeet_realtime_eou_120m-v1 qualifies). For those we fall
|
||||
// back to a single offline transcription emitted as one delta plus a closing
|
||||
// FinalResult, matching LocalAI's non-streaming streaming contract (and the
|
||||
// whisper backend), so the streaming endpoint works for every model.
|
||||
func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
defer close(results)
|
||||
|
||||
if p.ctxPtr == 0 {
|
||||
return errors.New("parakeet-cpp: model not loaded")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
stream := CppStreamBegin(p.ctxPtr)
|
||||
if stream == 0 {
|
||||
// Not a cache-aware streaming model: run a normal offline
|
||||
// transcription and emit it as one delta + a closing final result.
|
||||
res, err := p.AudioTranscription(ctx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t := strings.TrimSpace(res.Text); t != "" {
|
||||
results <- &pb.TranscriptStreamResponse{Delta: t}
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{FinalResult: &res}
|
||||
return nil
|
||||
}
|
||||
defer CppStreamFree(stream)
|
||||
// The C engine is a single shared context: a streaming session and a batched
|
||||
// unary dispatch must never touch it at once, so hold engineMu for the whole
|
||||
// stream. This lock is intentionally taken AFTER the non-streaming fallback
|
||||
// above returns: that fallback goes through AudioTranscription -> the batcher
|
||||
// -> runBatch, which itself acquires engineMu, so locking here first would
|
||||
// deadlock. Do not hoist this lock above the fallback.
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
|
||||
data, duration, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
full strings.Builder
|
||||
segText strings.Builder
|
||||
segments []*pb.TranscriptSegment
|
||||
segID int32
|
||||
)
|
||||
|
||||
flushSegment := func() {
|
||||
t := strings.TrimSpace(segText.String())
|
||||
segText.Reset()
|
||||
if t == "" {
|
||||
return
|
||||
}
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: segID, Text: t})
|
||||
segID++
|
||||
}
|
||||
|
||||
// emitDelta consumes the malloc'd char* returned by feed/finalize: frees
|
||||
// it, accumulates the text, and sends a delta when non-empty. A 0 return
|
||||
// is an error (vs the "" empty-but-non-NULL no-new-text case).
|
||||
emitDelta := func(ret uintptr) error {
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
delta := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
if delta == "" {
|
||||
return nil
|
||||
}
|
||||
full.WriteString(delta)
|
||||
segText.WriteString(delta)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: delta}
|
||||
return nil
|
||||
}
|
||||
|
||||
for off := 0; off < len(data); off += streamChunkSamples {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(data))
|
||||
chunk := data[off:end]
|
||||
|
||||
var eou int32
|
||||
ret := CppStreamFeed(stream, chunk, int32(len(chunk)), unsafe.Pointer(&eou))
|
||||
if err := emitDelta(ret); err != nil {
|
||||
return err
|
||||
}
|
||||
if eou != 0 {
|
||||
flushSegment()
|
||||
}
|
||||
}
|
||||
|
||||
// Flush the streaming tail (final encoder chunk).
|
||||
if err := emitDelta(CppStreamFinalize(stream)); err != nil {
|
||||
return err
|
||||
}
|
||||
flushSegment()
|
||||
|
||||
text := strings.TrimSpace(full.String())
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{
|
||||
FinalResult: &pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: segments,
|
||||
Duration: duration,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeWavMono16k converts any input audio to 16 kHz mono PCM and returns the
|
||||
// float samples plus the clip duration in seconds. Mirrors the whisper
|
||||
// backend: utils.AudioToWav (ffmpeg) normalises rate/channels, go-audio
|
||||
// decodes the PCM.
|
||||
func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
dir, err := os.MkdirTemp("", "parakeet")
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
converted := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(path, converted); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
fh, err := os.Open(converted)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
buf, err := wav.NewDecoder(fh).FullPCMBuffer()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
return data, duration, nil
|
||||
}
|
||||
|
||||
// Free releases the underlying parakeet_ctx. Called by LocalAI when the
|
||||
// model is unloaded.
|
||||
func (p *ParakeetCpp) Free() error {
|
||||
// Stop the dispatcher before releasing the engine so no in-flight runBatch
|
||||
// can touch a freed ctx (close leak / use-after-free on reload).
|
||||
if p.batStop != nil {
|
||||
close(p.batStop)
|
||||
p.batStop = nil
|
||||
}
|
||||
if p.ctxPtr != 0 {
|
||||
CppFree(p.ctxPtr)
|
||||
p.ctxPtr = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// goStringFromCPtr copies a NUL-terminated C string into Go memory.
|
||||
// cptr is the raw pointer returned by purego from the C-API (a malloc'd
|
||||
// buffer the caller owns); callers must free it via CppFreeString after
|
||||
// the copy lands.
|
||||
//
|
||||
// The uintptr->unsafe.Pointer conversion below trips go vet's unsafeptr
|
||||
// check, which can't distinguish a C-owned heap pointer from Go-managed
|
||||
// memory. It is safe here: the pointer addresses a malloc'd C buffer the
|
||||
// Go GC neither tracks nor moves, and we dereference it immediately to
|
||||
// copy the bytes out, the same pattern (and the same tolerated warning)
|
||||
// as the whisper backend's unsafe.Slice over segsPtr.
|
||||
func goStringFromCPtr(cptr uintptr) string {
|
||||
if cptr == 0 {
|
||||
return ""
|
||||
}
|
||||
p := unsafe.Pointer(cptr) //nolint:govet // C-owned malloc'd buffer, not Go-GC memory (see doc above)
|
||||
n := 0
|
||||
for *(*byte)(unsafe.Add(p, n)) != 0 {
|
||||
n++
|
||||
}
|
||||
return string(unsafe.Slice((*byte)(p), n))
|
||||
}
|
||||
167
backend/go/parakeet-cpp/goparakeetcpp_test.go
Normal file
167
backend/go/parakeet-cpp/goparakeetcpp_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestParakeetCpp(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "parakeet-cpp Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive
|
||||
// the C-API bridge without spinning up the gRPC server. Skips the
|
||||
// current spec when libparakeet.so isn't loadable from cwd
|
||||
// ($LD_LIBRARY_PATH or a symlink in ./).
|
||||
func ensureLibLoaded() {
|
||||
libLoadOnce.Do(func() {
|
||||
libName := os.Getenv("PARAKEET_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libparakeet.so"
|
||||
}
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&CppAbiVersion, lib, "parakeet_capi_abi_version")
|
||||
purego.RegisterLibFunc(&CppLoad, lib, "parakeet_capi_load")
|
||||
purego.RegisterLibFunc(&CppFree, lib, "parakeet_capi_free")
|
||||
purego.RegisterLibFunc(&CppTranscribePath, lib, "parakeet_capi_transcribe_path")
|
||||
purego.RegisterLibFunc(&CppTranscribePathJSON, lib, "parakeet_capi_transcribe_path_json")
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
purego.RegisterLibFunc(&CppStreamBegin, lib, "parakeet_capi_stream_begin")
|
||||
purego.RegisterLibFunc(&CppStreamFeed, lib, "parakeet_capi_stream_feed")
|
||||
purego.RegisterLibFunc(&CppStreamFinalize, lib, "parakeet_capi_stream_finalize")
|
||||
purego.RegisterLibFunc(&CppStreamFree, lib, "parakeet_capi_stream_free")
|
||||
purego.RegisterLibFunc(&CppFreeString, lib, "parakeet_capi_free_string")
|
||||
purego.RegisterLibFunc(&CppLastError, lib, "parakeet_capi_last_error")
|
||||
})
|
||||
if libLoadErr != nil {
|
||||
Skip("libparakeet.so not loadable: " + libLoadErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// fixturesOrSkip returns the model + audio paths or skips the spec if
|
||||
// either env var is unset. The smoke test never runs in default CI; it
|
||||
// needs a real parakeet GGUF and a 16 kHz mono WAV on disk.
|
||||
func fixturesOrSkip() (string, string) {
|
||||
modelPath := os.Getenv("PARAKEET_BACKEND_TEST_MODEL")
|
||||
audioPath := os.Getenv("PARAKEET_BACKEND_TEST_WAV")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set PARAKEET_BACKEND_TEST_MODEL and PARAKEET_BACKEND_TEST_WAV to run this spec")
|
||||
}
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
var _ = Describe("ParakeetCpp", func() {
|
||||
Context("AudioTranscription", func() {
|
||||
It("transcribes a WAV via the parakeet C-API", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
p := &ParakeetCpp{}
|
||||
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
defer func() { _ = p.Free() }()
|
||||
|
||||
res, err := p.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(strings.TrimSpace(res.Text)).ToNot(BeEmpty(),
|
||||
"expected non-empty transcript for %s", audioPath)
|
||||
Expect(res.Segments).To(HaveLen(1),
|
||||
"synthesises a single whole-clip segment")
|
||||
Expect(res.Segments[0].Text).To(Equal(res.Text),
|
||||
"single segment text must equal the top-level text")
|
||||
// Default (no granularities) is segment-level: no per-word timings.
|
||||
Expect(res.Segments[0].Words).To(BeEmpty(),
|
||||
"word timings are opt-in via timestamp_granularities")
|
||||
})
|
||||
|
||||
It("emits word-level timestamps when granularity=word", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
p := &ParakeetCpp{}
|
||||
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
defer func() { _ = p.Free() }()
|
||||
|
||||
res, err := p.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
TimestampGranularities: []string{"word"},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.Segments).To(HaveLen(1))
|
||||
seg := res.Segments[0]
|
||||
Expect(seg.Words).ToNot(BeEmpty(),
|
||||
"expected per-word timestamps with granularity=word")
|
||||
// Monotonic, non-negative timings spanning the segment.
|
||||
Expect(seg.Words[0].Start).To(BeNumerically(">=", int64(0)))
|
||||
Expect(seg.End).To(BeNumerically(">=", seg.Start))
|
||||
Expect(seg.Words[len(seg.Words)-1].End).To(Equal(seg.End),
|
||||
"segment end tracks the last word")
|
||||
})
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("streams deltas and a closing FinalResult from a cache-aware model", func() {
|
||||
// Streaming needs a cache-aware streaming model (e.g.
|
||||
// realtime_eou); the offline test model would fail stream_begin.
|
||||
modelPath := os.Getenv("PARAKEET_BACKEND_TEST_STREAM_MODEL")
|
||||
audioPath := os.Getenv("PARAKEET_BACKEND_TEST_WAV")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set PARAKEET_BACKEND_TEST_STREAM_MODEL (cache-aware streaming model) and PARAKEET_BACKEND_TEST_WAV")
|
||||
}
|
||||
ensureLibLoaded()
|
||||
|
||||
p := &ParakeetCpp{}
|
||||
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
defer func() { _ = p.Free() }()
|
||||
|
||||
results := make(chan *pb.TranscriptStreamResponse, 64)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- p.AudioTranscriptionStream(context.Background(),
|
||||
&pb.TranscriptRequest{Dst: audioPath}, results)
|
||||
}()
|
||||
|
||||
var deltas []string
|
||||
var final *pb.TranscriptResult
|
||||
for r := range results {
|
||||
if r.Delta != "" {
|
||||
deltas = append(deltas, r.Delta)
|
||||
}
|
||||
if r.FinalResult != nil {
|
||||
final = r.FinalResult
|
||||
}
|
||||
}
|
||||
Expect(<-errCh).ToNot(HaveOccurred())
|
||||
|
||||
Expect(final).ToNot(BeNil(), "expected a closing FinalResult")
|
||||
Expect(strings.TrimSpace(final.Text)).ToNot(BeEmpty(),
|
||||
"expected a non-empty streamed transcript")
|
||||
Expect(final.Segments).ToNot(BeEmpty(),
|
||||
"FinalResult always carries at least one segment")
|
||||
// The concatenated deltas reconstruct the final transcript.
|
||||
Expect(strings.TrimSpace(strings.Join(deltas, ""))).To(Equal(strings.TrimSpace(final.Text)),
|
||||
"deltas should reconstruct the final text")
|
||||
})
|
||||
})
|
||||
})
|
||||
75
backend/go/parakeet-cpp/main.go
Normal file
75
backend/go/parakeet-cpp/main.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package main
|
||||
|
||||
// Started internally by LocalAI - one gRPC server per loaded model.
|
||||
//
|
||||
// Loads libparakeet.so via purego and registers the flat C-API entry
|
||||
// points declared in parakeet_capi.h. The library name can be overridden
|
||||
// with PARAKEET_LIBRARY (mirrors the WHISPER_LIBRARY / VIBEVOICECPP_LIBRARY
|
||||
// convention in the sibling backends); the default looks for the .so next
|
||||
// to this binary.
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("PARAKEET_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libparakeet.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("parakeet-cpp: dlopen %q: %w", libName, err))
|
||||
}
|
||||
|
||||
// Bound 1:1 to parakeet_capi.h. The C-API returns malloc'd char*
|
||||
// buffers from transcribe_*; we register those as uintptr so we get
|
||||
// the raw pointer back and can call parakeet_capi_free_string on it
|
||||
// (purego's string return would copy and forget the original pointer,
|
||||
// leaking it on every call).
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppAbiVersion, "parakeet_capi_abi_version"},
|
||||
{&CppLoad, "parakeet_capi_load"},
|
||||
{&CppFree, "parakeet_capi_free"},
|
||||
{&CppTranscribePath, "parakeet_capi_transcribe_path"},
|
||||
{&CppTranscribePathJSON, "parakeet_capi_transcribe_path_json"},
|
||||
{&CppStreamBegin, "parakeet_capi_stream_begin"},
|
||||
{&CppStreamFeed, "parakeet_capi_stream_feed"},
|
||||
{&CppStreamFinalize, "parakeet_capi_stream_finalize"},
|
||||
{&CppStreamFree, "parakeet_capi_stream_free"},
|
||||
{&CppFreeString, "parakeet_capi_free_string"},
|
||||
{&CppLastError, "parakeet_capi_last_error"},
|
||||
}
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
// The batched-JSON entry point exists only in newer libparakeet.so (ABI >= 2).
|
||||
// Probe with Dlsym and register only if present, so the backend still loads
|
||||
// against an older library (it falls back to per-request transcription).
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[parakeet-cpp] ABI=%d\n", CppAbiVersion())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &ParakeetCpp{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
backend/go/parakeet-cpp/package.sh
Executable file
23
backend/go/parakeet-cpp/package.sh
Executable file
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# L0 packaging stub: copy the binary, run.sh and libparakeet.so* into
|
||||
# package/. The full ldd walk (libc, libstdc++, libgomp, GPU runtimes,
|
||||
# arch detection) lands in L3, mirroring backend/go/whisper/package.sh.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
|
||||
cp -avf "$CURDIR/parakeet-cpp-grpc" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
# libparakeet.so + any soname symlinks (libparakeet.so.X, libparakeet.so.X.Y).
|
||||
cp -avf "$CURDIR"/libparakeet.so* "$CURDIR/package/lib/" 2>/dev/null || {
|
||||
echo "ERROR: libparakeet.so not found in $CURDIR, run 'make' first" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
echo "L0 package layout (full ldd walk lands in L3):"
|
||||
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"
|
||||
16
backend/go/parakeet-cpp/run.sh
Executable file
16
backend/go/parakeet-cpp/run.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
|
||||
|
||||
# If a self-contained ld.so was packaged, route through it so the
|
||||
# packaged libc / libstdc++ are used instead of the host's (matches the
|
||||
# whisper backend's runtime layout).
|
||||
if [ -f "$CURDIR/lib/ld.so" ]; then
|
||||
echo "Using lib/ld.so"
|
||||
exec "$CURDIR/lib/ld.so" "$CURDIR/parakeet-cpp-grpc" "$@"
|
||||
fi
|
||||
|
||||
exec "$CURDIR/parakeet-cpp-grpc" "$@"
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=0e4ee04488159b81d95a9ffcd983a077fd5dcb77
|
||||
STABLEDIFFUSION_GGML_VERSION?=7948df8ac1070f5f6881b8d34675821893eb97d6
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=f24588a272ae8e23280d9c220536437164e6ed28
|
||||
WHISPER_CPP_VERSION?=23ee03506a91ac3d3f0071b40e66a430eebdfa1d
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -122,6 +122,62 @@
|
||||
nvidia-cuda-12: "cuda12-whisper"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-whisper"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-whisper"
|
||||
- &crispasr
|
||||
name: "crispasr"
|
||||
alias: "crispasr"
|
||||
license: mit
|
||||
icon: https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg
|
||||
description: |
|
||||
CrispASR unified speech engine (whisper.cpp fork on ggml) supporting many ASR architectures (Parakeet, Canary, Voxtral, Qwen3-ASR, Granite, Wav2Vec2, Moonshine, OmniASR, FireRedASR, and more).
|
||||
urls:
|
||||
- https://github.com/CrispStrobe/CrispASR
|
||||
tags:
|
||||
- audio-transcription
|
||||
- CPU
|
||||
- GPU
|
||||
- CUDA
|
||||
- HIP
|
||||
capabilities:
|
||||
default: "cpu-crispasr"
|
||||
nvidia: "cuda12-crispasr"
|
||||
intel: "intel-sycl-f16-crispasr"
|
||||
metal: "metal-crispasr"
|
||||
amd: "rocm-crispasr"
|
||||
vulkan: "vulkan-crispasr"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-crispasr"
|
||||
nvidia-cuda-13: "cuda13-crispasr"
|
||||
nvidia-cuda-12: "cuda12-crispasr"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-crispasr"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-crispasr"
|
||||
- ¶keetcpp
|
||||
name: "parakeet-cpp"
|
||||
alias: "parakeet-cpp"
|
||||
license: mit
|
||||
icon: https://avatars.githubusercontent.com/u/95302084
|
||||
description: |
|
||||
parakeet.cpp is a C++/ggml port of NVIDIA NeMo Parakeet automatic speech recognition (ASR) models.
|
||||
It supports the tdt, ctc, rnnt and hybrid decoder families as well as cache-aware streaming transcription,
|
||||
and runs on CPU, NVIDIA CUDA, AMD ROCm/HIP, Intel SYCL and NVIDIA Jetson (L4T) targets.
|
||||
urls:
|
||||
- https://github.com/mudler/parakeet.cpp
|
||||
tags:
|
||||
- audio-transcription
|
||||
- CPU
|
||||
- GPU
|
||||
- CUDA
|
||||
- HIP
|
||||
capabilities:
|
||||
default: "cpu-parakeet-cpp"
|
||||
nvidia: "cuda12-parakeet-cpp"
|
||||
intel: "intel-sycl-f16-parakeet-cpp"
|
||||
metal: "metal-parakeet-cpp"
|
||||
amd: "rocm-parakeet-cpp"
|
||||
vulkan: "vulkan-parakeet-cpp"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-parakeet-cpp"
|
||||
nvidia-cuda-13: "cuda13-parakeet-cpp"
|
||||
nvidia-cuda-12: "cuda12-parakeet-cpp"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-parakeet-cpp"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-parakeet-cpp"
|
||||
- &voxtral
|
||||
name: "voxtral"
|
||||
alias: "voxtral"
|
||||
@@ -1928,6 +1984,246 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-whisper
|
||||
## crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "crispasr-development"
|
||||
capabilities:
|
||||
default: "cpu-crispasr-development"
|
||||
nvidia: "cuda12-crispasr-development"
|
||||
intel: "intel-sycl-f16-crispasr-development"
|
||||
metal: "metal-crispasr-development"
|
||||
amd: "rocm-crispasr-development"
|
||||
vulkan: "vulkan-crispasr-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-crispasr-development"
|
||||
nvidia-cuda-13: "cuda13-crispasr-development"
|
||||
nvidia-cuda-12: "cuda12-crispasr-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-crispasr-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-crispasr-development"
|
||||
- !!merge <<: *crispasr
|
||||
name: "nvidia-l4t-arm64-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "nvidia-l4t-arm64-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-nvidia-l4t-arm64-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-nvidia-l4t-arm64-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cpu-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cpu-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda12-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "rocm-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f32-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f16-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f16-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "vulkan-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "vulkan-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda12-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "rocm-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f32-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f16-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f16-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-crispasr
|
||||
## parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "parakeet-cpp-development"
|
||||
capabilities:
|
||||
default: "cpu-parakeet-cpp-development"
|
||||
nvidia: "cuda12-parakeet-cpp-development"
|
||||
intel: "intel-sycl-f16-parakeet-cpp-development"
|
||||
metal: "metal-parakeet-cpp-development"
|
||||
amd: "rocm-parakeet-cpp-development"
|
||||
vulkan: "vulkan-parakeet-cpp-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
nvidia-cuda-13: "cuda13-parakeet-cpp-development"
|
||||
nvidia-cuda-12: "cuda12-parakeet-cpp-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "nvidia-l4t-arm64-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda13-nvidia-l4t-arm64-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda13-nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cpu-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cpu-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "metal-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "metal-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda12-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda12-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "rocm-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "rocm-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "intel-sycl-f32-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "intel-sycl-f32-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "intel-sycl-f16-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f16-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "intel-sycl-f16-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f16-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "vulkan-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "vulkan-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda13-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda13-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-parakeet-cpp
|
||||
## stablediffusion-ggml
|
||||
- !!merge <<: *stablediffusionggml
|
||||
name: "cpu-stablediffusion-ggml"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch
|
||||
texterrors==1.1.6
|
||||
nemo_toolkit[asr]
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// Release represents a LocalAI release
|
||||
@@ -67,9 +68,7 @@ func NewReleaseManager() *ReleaseManager {
|
||||
CurrentVersion: internal.PrintableVersion(),
|
||||
ChecksumsPath: checksumsPath,
|
||||
MetadataPath: metadataPath,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
HTTPClient: httpclient.NewWithTimeout(30*time.Second, httpclient.WithFollowRedirects()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,9 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
@@ -240,6 +242,84 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
|
||||
)
|
||||
|
||||
// Prefix-cache-aware routing. Enabled by default; an operator can opt out
|
||||
// with --distributed-prefix-cache=false, which leaves prefixProvider and
|
||||
// pressure nil so the SmartRouter and reconciler behave exactly as the
|
||||
// round-robin floor (true no-op). When enabled we build the local index,
|
||||
// wrap it in a NATS-backed Sync (publishes our observations, applies peers'
|
||||
// via the subscriptions below), install the extraction hook used by
|
||||
// core/backend/llm.go, and run a background eviction ticker on the app ctx.
|
||||
var prefixProvider prefixcache.Provider
|
||||
var pressure *prefixcache.Pressure
|
||||
var prefixCfg prefixcache.Config
|
||||
if !cfg.Distributed.PrefixCacheDisabled {
|
||||
prefixCfg = prefixcache.DefaultConfig()
|
||||
if cfg.Distributed.PrefixCacheTTL > 0 {
|
||||
prefixCfg.TTL = cfg.Distributed.PrefixCacheTTL
|
||||
}
|
||||
if err := prefixCfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid prefix-cache configuration: %w", err)
|
||||
}
|
||||
idx := prefixcache.NewIndex(prefixCfg)
|
||||
prefixSync := prefixcache.NewSync(idx, natsClient)
|
||||
pressure = prefixcache.NewPressure(prefixCfg.PressureWindow)
|
||||
prefixProvider = prefixSync
|
||||
|
||||
// Invalidate the prefix-cache index whenever a replica row is removed.
|
||||
// SetReplicaRemovedHook fires from the single chokepoint all removal paths
|
||||
// funnel through (RemoveNodeModel / RemoveAllNodeModelReplicas), so this
|
||||
// one hook covers every path: reconciler scale-down, probe reaper,
|
||||
// health-monitor reap, RemoteUnloaderAdapter, and the router. Registering
|
||||
// it only inside this enabled block keeps the disabled path a true no-op
|
||||
// (the registry stays hook-less).
|
||||
registry.SetReplicaRemovedHook(func(model, node string, replica int) {
|
||||
if replica < 0 {
|
||||
prefixSync.InvalidateNode(model, node)
|
||||
} else {
|
||||
prefixSync.Invalidate(model, prefixcache.ReplicaKey{NodeID: node, Replica: replica})
|
||||
}
|
||||
})
|
||||
|
||||
distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 {
|
||||
return prefixcache.ExtractChain(model, prompt, prefixCfg)
|
||||
}
|
||||
|
||||
// Apply peers' observations/invalidations to the same Sync. ApplyObserve
|
||||
// and ApplyInvalidate update only the local index and do not re-publish,
|
||||
// so there is no broadcast loop.
|
||||
if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheObserve, func(ev messaging.PrefixCacheObserveEvent) {
|
||||
prefixSync.ApplyObserve(ev, time.Now())
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheObserve, err)
|
||||
}
|
||||
if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheInvalidate, func(ev messaging.PrefixCacheInvalidateEvent) {
|
||||
prefixSync.ApplyInvalidate(ev)
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheInvalidate, err)
|
||||
}
|
||||
|
||||
// Background eviction: sweep idle entries on the app context. Stopped
|
||||
// when the app context is cancelled (mirrors the reconciler loop which
|
||||
// also runs on options.Context). TTL/2 keeps stale entries from
|
||||
// outliving their idle window by more than half a TTL.
|
||||
evictInterval := prefixCfg.TTL / 2
|
||||
go func() {
|
||||
ticker := time.NewTicker(evictInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-cfg.Context.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
prefixSync.Evict(time.Now())
|
||||
}
|
||||
}
|
||||
}()
|
||||
xlog.Info("Prefix-cache-aware routing enabled", "ttl", prefixCfg.TTL, "evictInterval", evictInterval)
|
||||
} else {
|
||||
xlog.Info("Prefix-cache-aware routing disabled: using round-robin routing")
|
||||
}
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
var conflictResolver nodes.ConcurrencyConflictResolver
|
||||
if configLoader != nil {
|
||||
@@ -252,6 +332,9 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
AuthToken: routerAuthToken,
|
||||
DB: authDB,
|
||||
ConflictResolver: conflictResolver,
|
||||
PrefixProvider: prefixProvider,
|
||||
PrefixConfig: prefixCfg,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
// Create ReplicaReconciler for auto-scaling model replicas. Adapter +
|
||||
@@ -268,6 +351,8 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
Interval: 30 * time.Second,
|
||||
ScaleDownDelay: 5 * time.Minute,
|
||||
ProbeStaleAfter: 2 * time.Minute,
|
||||
Pressure: pressure,
|
||||
PressureThreshold: prefixCfg.PressureScaleThreshold,
|
||||
})
|
||||
|
||||
// Create ModelRouterAdapter to wire into ModelLoader
|
||||
|
||||
@@ -53,9 +53,21 @@ func (a *Application) StartP2P() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// modelsFn reports the model names this instance currently serves so the
|
||||
// federation proxy can route a request only to peers that have the
|
||||
// requested model. It is re-evaluated on every announce tick.
|
||||
modelsFn := func() []string {
|
||||
cfgs := a.ModelConfigLoader().GetAllModelsConfigs()
|
||||
names := make([]string, 0, len(cfgs))
|
||||
for _, c := range cfgs {
|
||||
names = append(names, c.Name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// Here a new node is created and started
|
||||
// and a service is exposed by the node
|
||||
node, err := p2p.ExposeService(ctx, "localhost", port, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.FederatedID))
|
||||
node, err := p2p.ExposeService(ctx, "localhost", port, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.FederatedID), modelsFn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
@@ -94,6 +95,22 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
}
|
||||
|
||||
// Make the rendered prompt's prefix chain available to the distributed router
|
||||
// for prefix-cache-aware node selection. No-op in single-process mode. The
|
||||
// model id MUST match the id ModelOptions feeds to model.WithModelID, so both
|
||||
// use the shared config.ModelConfig.ModelID() helper (Name with a fallback to
|
||||
// Model) or the chain salt and the tracking key would diverge.
|
||||
//
|
||||
// s is empty for UseTokenizerTemplate models (the backend tokenizes the
|
||||
// structured messages itself), so fall back to a prefix-stable serialization
|
||||
// of the messages - otherwise prefix routing would silently degrade to
|
||||
// round-robin for the bulk of modern chat models.
|
||||
chainSource := s
|
||||
if chainSource == "" {
|
||||
chainSource = messagesPrefixSource(messages)
|
||||
}
|
||||
ctx = distributedhdr.MaybeWithPrefixChain(ctx, c.ModelID(), chainSource)
|
||||
|
||||
opts := ModelOptions(*c, o, model.WithContext(ctx))
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
|
||||
@@ -34,16 +34,11 @@ func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, back
|
||||
}
|
||||
|
||||
func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
|
||||
name := c.Name
|
||||
if name == "" {
|
||||
name = c.Model
|
||||
}
|
||||
|
||||
defOpts := []model.Option{
|
||||
model.WithBackendString(c.Backend),
|
||||
model.WithModel(c.Model),
|
||||
model.WithContext(so.Context),
|
||||
model.WithModelID(name),
|
||||
model.WithModelID(c.ModelID()),
|
||||
}
|
||||
|
||||
threads := 1
|
||||
|
||||
36
core/backend/prefix_source.go
Normal file
36
core/backend/prefix_source.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
)
|
||||
|
||||
// messagesPrefixSource builds a deterministic, prefix-stable serialization of a
|
||||
// chat conversation for prefix-cache-aware routing. It is the fallback used when
|
||||
// the frontend did not render a prompt string: models with
|
||||
// config.TemplateConfig.UseTokenizerTemplate tokenize the structured messages
|
||||
// backend-side, so the frontend's rendered prompt is empty and a chain built
|
||||
// from it would always be empty - silently degrading prefix routing to
|
||||
// round-robin for the bulk of modern chat models.
|
||||
//
|
||||
// Messages are emitted head-first in turn order (role line + content line per
|
||||
// message), so two conversations sharing a leading system prompt and early turns
|
||||
// share a leading byte prefix. That is exactly what ExtractChain hashes into a
|
||||
// shared chain prefix, landing both requests on the same cache-warm replica.
|
||||
func messagesPrefixSource(messages schema.Messages) string {
|
||||
var b strings.Builder
|
||||
for _, m := range messages {
|
||||
b.WriteString(m.Role)
|
||||
b.WriteByte('\n')
|
||||
content := m.StringContent
|
||||
if content == "" {
|
||||
if s, ok := m.Content.(string); ok {
|
||||
content = s
|
||||
}
|
||||
}
|
||||
b.WriteString(content)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
53
core/backend/prefix_source_internal_test.go
Normal file
53
core/backend/prefix_source_internal_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("messagesPrefixSource", func() {
|
||||
mk := func(role, content string) schema.Message {
|
||||
return schema.Message{Role: role, StringContent: content}
|
||||
}
|
||||
|
||||
It("serializes messages head-first in turn order", func() {
|
||||
got := messagesPrefixSource(schema.Messages{
|
||||
mk("system", "You are helpful."),
|
||||
mk("user", "Hi"),
|
||||
})
|
||||
Expect(got).To(Equal("system\nYou are helpful.\nuser\nHi\n"))
|
||||
})
|
||||
|
||||
It("is deterministic across calls for the same conversation", func() {
|
||||
conv := schema.Messages{mk("system", "S"), mk("user", "U")}
|
||||
Expect(messagesPrefixSource(conv)).To(Equal(messagesPrefixSource(conv)))
|
||||
})
|
||||
|
||||
It("shares a leading byte prefix when the system prompt is shared", func() {
|
||||
shared := "system\nShared system prompt.\nuser\n"
|
||||
a := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question A")})
|
||||
b := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question B")})
|
||||
Expect(strings.HasPrefix(a, shared)).To(BeTrue())
|
||||
Expect(strings.HasPrefix(b, shared)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does NOT share a prefix when the system prompt differs", func() {
|
||||
a := messagesPrefixSource(schema.Messages{mk("system", "Prompt A"), mk("user", "Q")})
|
||||
b := messagesPrefixSource(schema.Messages{mk("system", "Prompt B"), mk("user", "Q")})
|
||||
Expect(strings.HasPrefix(a, "system\nPrompt A")).To(BeTrue())
|
||||
Expect(strings.HasPrefix(b, "system\nPrompt B")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("returns empty for no messages", func() {
|
||||
Expect(messagesPrefixSource(nil)).To(Equal(""))
|
||||
})
|
||||
|
||||
It("falls back to Content when StringContent is empty", func() {
|
||||
got := messagesPrefixSource(schema.Messages{{Role: "user", Content: "plain"}})
|
||||
Expect(got).To(Equal("user\nplain\n"))
|
||||
})
|
||||
})
|
||||
@@ -14,12 +14,14 @@ type FederatedCLI struct {
|
||||
RandomWorker bool `env:"LOCALAI_RANDOM_WORKER,RANDOM_WORKER" default:"false" help:"Select a random worker from the pool" group:"p2p"`
|
||||
Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances." group:"p2p"`
|
||||
TargetWorker string `env:"LOCALAI_TARGET_WORKER,TARGET_WORKER" help:"Target worker to run the federated server on" group:"p2p"`
|
||||
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-size limit in megabytes" group:"api"`
|
||||
AffinitySync bool `env:"LOCALAI_FEDERATED_AFFINITY_SYNC,FEDERATED_AFFINITY_SYNC" default:"false" help:"Broadcast prefix-cache affinity observations to other federation servers over the p2p generic channel (enable on every federation server that should cohere)" group:"p2p"`
|
||||
}
|
||||
|
||||
func (f *FederatedCLI) Run(ctx *cliContext.Context) error {
|
||||
warnDeprecatedFlags()
|
||||
|
||||
fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker)
|
||||
fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker, int64(f.UploadLimit)*1024*1024, f.AffinitySync)
|
||||
|
||||
c, cancel := context.WithCancel(context.Background())
|
||||
|
||||
|
||||
@@ -145,19 +145,21 @@ type RunCMD struct {
|
||||
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
||||
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
DistributedPrefixCache bool `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE" default:"true" help:"Enable prefix-cache-aware routing in distributed mode (default true). When false, routing falls back to round-robin." group:"distributed"`
|
||||
DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
|
||||
Version bool
|
||||
|
||||
@@ -284,6 +286,16 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.AutoApproveNodes {
|
||||
opts = append(opts, config.EnableAutoApproveNodes)
|
||||
}
|
||||
if !r.DistributedPrefixCache {
|
||||
opts = append(opts, config.DisablePrefixCache)
|
||||
}
|
||||
if r.DistributedPrefixCacheTTL != "" {
|
||||
d, err := time.ParseDuration(r.DistributedPrefixCacheTTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL %q: %w", r.DistributedPrefixCacheTTL, err)
|
||||
}
|
||||
opts = append(opts, config.WithPrefixCacheTTL(d))
|
||||
}
|
||||
if r.ExposeNodeHeader {
|
||||
opts = append(opts, config.WithExposeNodeHeader(true))
|
||||
}
|
||||
|
||||
@@ -14,4 +14,5 @@ type Worker struct {
|
||||
LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"`
|
||||
MLXDistributed MLXDistributed `cmd:"" name:"mlx-distributed" help:"Starts an MLX distributed worker in standalone mode (requires --hostfile and --rank)"`
|
||||
VLLMDistributed VLLMDistributed `cmd:"" name:"vllm" help:"Starts a vLLM data-parallel follower process. Multi-node DP for a single model: head runs the existing vllm backend with engine_args.data_parallel_size>1, followers run this command."`
|
||||
DS4Distributed DS4Distributed `cmd:"" name:"ds4-distributed" help:"Starts a ds4 distributed worker in standalone mode: owns a layer slice and dials the coordinator (pass ds4-worker args after --)"`
|
||||
}
|
||||
|
||||
108
core/cli/worker/worker_ds4.go
Normal file
108
core/cli/worker/worker_ds4.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
type DS4Distributed struct {
|
||||
WorkerFlags `embed:""`
|
||||
ExtraDS4Args string `name:"ds4-args" env:"LOCALAI_EXTRA_DS4_ARGS,EXTRA_DS4_ARGS" help:"Arguments passed to ds4-worker (e.g. '--role worker --model m.gguf --layers 20:output --coordinator HOST PORT')"`
|
||||
}
|
||||
|
||||
const (
|
||||
ds4WorkerBinaryName = "ds4-worker"
|
||||
ds4GalleryName = "ds4"
|
||||
)
|
||||
|
||||
// ds4WorkerArgs builds the argv for syscall.Exec when launching ds4-worker
|
||||
// directly: the binary path followed by the space-split extra args. An empty
|
||||
// extra string yields a bare invocation.
|
||||
func ds4WorkerArgs(binary, extra string) []string {
|
||||
args := []string{binary}
|
||||
args = append(args, strings.Fields(extra)...)
|
||||
return args
|
||||
}
|
||||
|
||||
func findDS4Backend(galleries string, systemState *system.SystemState, requireIntegrity bool) (string, error) {
|
||||
backends, err := gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed listing system backends", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
backend, ok := backends.Get(ds4GalleryName)
|
||||
if !ok {
|
||||
ml := model.NewModelLoader(systemState)
|
||||
var gals []config.Gallery
|
||||
if err := json.Unmarshal([]byte(galleries), &gals); err != nil {
|
||||
xlog.Error("failed loading galleries", "error", err)
|
||||
return "", err
|
||||
}
|
||||
if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, ds4GalleryName, nil, true, requireIntegrity); err != nil {
|
||||
xlog.Error("ds4 backend not found, failed to install it", "error", err)
|
||||
return "", err
|
||||
}
|
||||
backends, err = gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
backend, ok = backends.Get(ds4GalleryName)
|
||||
if !ok {
|
||||
return "", errors.New("ds4 backend not found after install")
|
||||
}
|
||||
}
|
||||
|
||||
backendPath := filepath.Dir(backend.RunFile)
|
||||
if backendPath == "" {
|
||||
return "", errors.New("ds4 backend not found, install it first")
|
||||
}
|
||||
return filepath.Join(backendPath, ds4WorkerBinaryName), nil
|
||||
}
|
||||
|
||||
func (r *DS4Distributed) Run(ctx *cliContext.Context) error {
|
||||
if r.ExtraDS4Args == "" && len(os.Args) < 4 {
|
||||
return fmt.Errorf("usage: local-ai worker ds4-distributed -- --role worker --model <gguf> --layers <START:END|START:output> --coordinator <host> <port>")
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(r.BackendsPath),
|
||||
system.WithBackendSystemPath(r.BackendsSystemPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
worker, err := findDS4Backend(r.BackendGalleries, systemState, r.RequireBackendIntegrity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ds4 bundles its own dynamic loader (lib/ld.so) for glibc compatibility,
|
||||
// like backend/cpp/ds4/run.sh does for grpc-server. Launch ds4-worker via
|
||||
// that loader when present; otherwise exec it directly. (This is a
|
||||
// deliberate divergence from worker_llamacpp.go, which has no bundled loader.)
|
||||
backendPath := filepath.Dir(worker)
|
||||
env := os.Environ()
|
||||
loader := filepath.Join(backendPath, "lib", "ld.so")
|
||||
if _, statErr := os.Stat(loader); statErr == nil {
|
||||
env = append(env, "LD_LIBRARY_PATH="+filepath.Join(backendPath, "lib")+":"+os.Getenv("LD_LIBRARY_PATH"))
|
||||
args := append([]string{loader}, ds4WorkerArgs(worker, r.ExtraDS4Args)...)
|
||||
return syscall.Exec(loader, args, env)
|
||||
}
|
||||
|
||||
return syscall.Exec(worker, ds4WorkerArgs(worker, r.ExtraDS4Args), env)
|
||||
}
|
||||
28
core/cli/worker/worker_ds4_test.go
Normal file
28
core/cli/worker/worker_ds4_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ds4 worker CLI", func() {
|
||||
It("uses the ds4 backend gallery name and worker binary name", func() {
|
||||
Expect(ds4GalleryName).To(Equal("ds4"))
|
||||
Expect(ds4WorkerBinaryName).To(Equal("ds4-worker"))
|
||||
})
|
||||
|
||||
It("assembles direct exec args as [binary, extra-split...]", func() {
|
||||
args := ds4WorkerArgs("/b/ds4-worker", "--role worker --model m.gguf --layers 20:output --coordinator 10.0.0.1 1234")
|
||||
Expect(args).To(Equal([]string{
|
||||
"/b/ds4-worker",
|
||||
"--role", "worker",
|
||||
"--model", "m.gguf",
|
||||
"--layers", "20:output",
|
||||
"--coordinator", "10.0.0.1", "1234",
|
||||
}))
|
||||
})
|
||||
|
||||
It("drops empty extra args to a bare binary invocation", func() {
|
||||
Expect(ds4WorkerArgs("/b/ds4-worker", "")).To(Equal([]string{"/b/ds4-worker"}))
|
||||
})
|
||||
})
|
||||
@@ -62,7 +62,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
|
||||
p = r.RunnerPort
|
||||
}
|
||||
|
||||
_, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID))
|
||||
_, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -104,7 +104,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID))
|
||||
_, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ func (r *P2PMLX) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.MLXWorkerID))
|
||||
_, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.MLXWorkerID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// RegistrationClient talks to the frontend's /api/node/* endpoints.
|
||||
@@ -37,7 +39,7 @@ func (c *RegistrationClient) httpTimeout() time.Duration {
|
||||
// httpClient returns the shared HTTP client, initializing it on first use.
|
||||
func (c *RegistrationClient) httpClient() *http.Client {
|
||||
c.clientOnce.Do(func() {
|
||||
c.client = &http.Client{Timeout: c.httpTimeout()}
|
||||
c.client = httpclient.NewWithTimeout(c.httpTimeout())
|
||||
})
|
||||
return c.client
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// Define a struct to hold the store API client
|
||||
@@ -47,7 +49,7 @@ type FindResponse struct {
|
||||
func NewStoreClient(baseUrl string) *StoreClient {
|
||||
return &StoreClient{
|
||||
BaseURL: baseUrl,
|
||||
Client: &http.Client{},
|
||||
Client: httpclient.New(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -198,6 +198,13 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
AcceptsVideos: true,
|
||||
Description: "vLLM engine — high-throughput LLM serving with optional multimodal",
|
||||
},
|
||||
"sglang": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodTokenizeString},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseTokenize, UsecaseVision},
|
||||
DefaultUsecases: []string{UsecaseChat},
|
||||
AcceptsImages: true,
|
||||
Description: "SGLang — fast LLM inference with structured generation and optional vision",
|
||||
},
|
||||
"vllm-omni": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodGenerateImage, MethodGenerateVideo, MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseImage, UsecaseVideo, UsecaseTTS, UsecaseVision},
|
||||
@@ -291,6 +298,12 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "NVIDIA NeMo speech recognition",
|
||||
},
|
||||
"parakeet-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription},
|
||||
PossibleUsecases: []string{UsecaseTranscript},
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "NVIDIA NeMo Parakeet ASR (parakeet.cpp)",
|
||||
},
|
||||
"qwen-asr": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription},
|
||||
PossibleUsecases: []string{UsecaseTranscript},
|
||||
@@ -309,6 +322,18 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseTranscript, UsecaseTTS},
|
||||
Description: "VibeVoice — bidirectional speech (transcription and synthesis)",
|
||||
},
|
||||
"vibevoice-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription, MethodTTS, MethodTTSStream},
|
||||
PossibleUsecases: []string{UsecaseTranscript, UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTranscript, UsecaseTTS},
|
||||
Description: "VibeVoice C++ — bidirectional speech, C++ backend with streaming TTS",
|
||||
},
|
||||
"sherpa-onnx": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription, MethodTTS, MethodTTSStream, MethodVAD},
|
||||
PossibleUsecases: []string{UsecaseTranscript, UsecaseTTS, UsecaseVAD},
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "Sherpa-ONNX — multi-model speech toolkit (ASR, TTS, VAD)",
|
||||
},
|
||||
|
||||
// --- TTS backends ---
|
||||
"piper": {
|
||||
@@ -353,6 +378,12 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Qwen TTS",
|
||||
},
|
||||
"qwen3-tts-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Qwen3 TTS C++ — text-to-speech, C++ backend",
|
||||
},
|
||||
"faster-qwen3-tts": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
@@ -434,6 +465,12 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseDetection},
|
||||
Description: "RF-DETR object detection",
|
||||
},
|
||||
"rfdetr-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodDetect},
|
||||
PossibleUsecases: []string{UsecaseDetection},
|
||||
DefaultUsecases: []string{UsecaseDetection},
|
||||
Description: "RF-DETR C++ object detection",
|
||||
},
|
||||
"silero-vad": {
|
||||
GRPCMethods: []GRPCMethod{MethodVAD},
|
||||
PossibleUsecases: []string{UsecaseVAD},
|
||||
|
||||
@@ -49,6 +49,17 @@ type DistributedConfig struct {
|
||||
|
||||
AgentWorkerConcurrency int `yaml:"agent_worker_concurrency" json:"agent_worker_concurrency" env:"LOCALAI_AGENT_WORKER_CONCURRENCY"`
|
||||
JobWorkerConcurrency int `yaml:"job_worker_concurrency" json:"job_worker_concurrency" env:"LOCALAI_JOB_WORKER_CONCURRENCY"`
|
||||
|
||||
// PrefixCacheDisabled turns off prefix-cache-aware routing, falling back to
|
||||
// round-robin (the floor). Prefix-cache routing is ON by default in
|
||||
// distributed mode; this flag exists so operators can opt out. The CLI
|
||||
// surfaces a default-true --distributed-prefix-cache enable flag and sets
|
||||
// this when the operator passes --distributed-prefix-cache=false.
|
||||
PrefixCacheDisabled bool
|
||||
// PrefixCacheTTL is the idle-timeout for prefix-cache index entries and
|
||||
// drives the background eviction cadence (eviction runs every TTL/2). Zero
|
||||
// means use the prefixcache package default (5m).
|
||||
PrefixCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// Validate checks that the distributed configuration is internally consistent.
|
||||
@@ -158,6 +169,20 @@ var EnableAutoApproveNodes = func(o *ApplicationConfig) {
|
||||
o.Distributed.AutoApproveNodes = true
|
||||
}
|
||||
|
||||
// DisablePrefixCache turns off prefix-cache-aware routing (falls back to
|
||||
// round-robin). Prefix-cache routing is enabled by default in distributed mode.
|
||||
var DisablePrefixCache = func(o *ApplicationConfig) {
|
||||
o.Distributed.PrefixCacheDisabled = true
|
||||
}
|
||||
|
||||
// WithPrefixCacheTTL sets the prefix-cache index idle-timeout (and the
|
||||
// background eviction cadence, which runs every TTL/2).
|
||||
func WithPrefixCacheTTL(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.PrefixCacheTTL = d
|
||||
}
|
||||
}
|
||||
|
||||
// Flag names for distributed timeout / interval configuration. These are
|
||||
// the kebab-case identifiers kong derives from the matching RunCMD struct
|
||||
// fields; they appear in Validate error messages and any other operator-
|
||||
|
||||
@@ -9,10 +9,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -55,7 +56,7 @@ var allowedFields = map[string]bool{
|
||||
func main() {
|
||||
fmt.Fprintf(os.Stderr, "Fetching %s ...\n", unslothURL)
|
||||
|
||||
resp, err := http.Get(unslothURL)
|
||||
resp, err := httpclient.New(httpclient.WithFollowRedirects()).Get(unslothURL)
|
||||
if err != nil {
|
||||
fatal("fetch failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -694,6 +694,18 @@ func (c *ModelConfig) IsModelURL() bool {
|
||||
return uri.LooksLikeURL()
|
||||
}
|
||||
|
||||
// ModelID returns the identifier used to reference this model across the
|
||||
// system: the configured Name, falling back to Model when Name is empty.
|
||||
// This is the single source of truth for the id fed to model.WithModelID and
|
||||
// the prefix-cache chain salt; both MUST agree with the router's tracking key
|
||||
// or the prefix-cache salt diverges silently.
|
||||
func (c ModelConfig) ModelID() string {
|
||||
if c.Name != "" {
|
||||
return c.Name
|
||||
}
|
||||
return c.Model
|
||||
}
|
||||
|
||||
// ModelFileName returns the filename of the model
|
||||
// If the model is a URL, it will return the MD5 of the URL which is the filename
|
||||
func (c *ModelConfig) ModelFileName() string {
|
||||
|
||||
@@ -10,6 +10,23 @@ import (
|
||||
)
|
||||
|
||||
var _ = Describe("Test cases for config related functions", func() {
|
||||
Context("ModelID", func() {
|
||||
It("returns Name when set", func() {
|
||||
c := ModelConfig{Name: "my-name"}
|
||||
c.Model = "my-model"
|
||||
Expect(c.ModelID()).To(Equal("my-name"))
|
||||
})
|
||||
It("falls back to Model when Name is empty", func() {
|
||||
c := ModelConfig{}
|
||||
c.Model = "my-model"
|
||||
Expect(c.ModelID()).To(Equal("my-model"))
|
||||
})
|
||||
It("returns empty string when both are empty", func() {
|
||||
c := ModelConfig{}
|
||||
Expect(c.ModelID()).To(Equal(""))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Test Read configuration functions", func() {
|
||||
It("Test Validate", func() {
|
||||
tmp, err := os.CreateTemp("", "config.yaml")
|
||||
|
||||
@@ -115,6 +115,10 @@ var defaultImporters = []Importer{
|
||||
&NemoImporter{},
|
||||
&FasterWhisperImporter{},
|
||||
&QwenASRImporter{},
|
||||
// ParakeetCppImporter matches only parakeet GGUFs (<arch>-<size>-<quant>.gguf);
|
||||
// kept ahead of LlamaCPPImporter so its .gguf bundles aren't claimed by the
|
||||
// generic GGUF importer.
|
||||
&ParakeetCppImporter{},
|
||||
// TTS (Batch 2)
|
||||
&PiperImporter{},
|
||||
&BarkImporter{},
|
||||
|
||||
180
core/gallery/importers/parakeet-cpp.go
Normal file
180
core/gallery/importers/parakeet-cpp.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &ParakeetCppImporter{}
|
||||
|
||||
// ParakeetCppImporter recognises parakeet.cpp GGUF weights, the C++/ggml port
|
||||
// of NVIDIA NeMo Parakeet. The signal is narrow on purpose: parakeet.cpp names
|
||||
// its weights "<arch>-<size>-<quant>.gguf" (e.g. tdt_ctc-110m-f16.gguf,
|
||||
// rnnt-0.6b-q4_k.gguf, realtime_eou_120m-v1-q8_0.gguf), so we only match a
|
||||
// .gguf whose name carries a parakeet architecture token. That keeps us from
|
||||
// claiming arbitrary llama-style GGUFs (the importer is registered before
|
||||
// llama-cpp), and it deliberately does NOT match the upstream nvidia/parakeet-*
|
||||
// NeMo repos (which ship .nemo checkpoints, not runnable GGUFs).
|
||||
// preferences.backend="parakeet-cpp" forces the importer regardless.
|
||||
type ParakeetCppImporter struct{}
|
||||
|
||||
func (i *ParakeetCppImporter) Name() string { return "parakeet-cpp" }
|
||||
func (i *ParakeetCppImporter) Modality() string { return "asr" }
|
||||
func (i *ParakeetCppImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *ParakeetCppImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "parakeet-cpp" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Direct URL or path to a parakeet GGUF.
|
||||
if isParakeetGGUF(filepath.Base(details.URI)) {
|
||||
return true
|
||||
}
|
||||
|
||||
// HF repo shipping at least one parakeet GGUF.
|
||||
if details.HuggingFace != nil {
|
||||
for _, f := range details.HuggingFace.Files {
|
||||
if isParakeetGGUF(filepath.Base(f.Path)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *ParakeetCppImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
// parakeet quants are near-lossless even at Q4_K (WER 0.0 vs NeMo on 110m),
|
||||
// so default to the smallest, then fall back up the size ladder; the last
|
||||
// file wins if none match (mirrors whisper / llama-cpp).
|
||||
preferredQuants, _ := preferencesMap["quantizations"].(string)
|
||||
quants := []string{"q4_k", "q5_k", "q6_k", "q8_0", "f16"}
|
||||
if preferredQuants != "" {
|
||||
quants = strings.Split(preferredQuants, ",")
|
||||
}
|
||||
|
||||
cfg := gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "parakeet-cpp",
|
||||
KnownUsecaseStrings: []string{"transcript"},
|
||||
}
|
||||
|
||||
uri := downloader.URI(details.URI)
|
||||
directGGUF := isParakeetGGUF(filepath.Base(details.URI))
|
||||
switch {
|
||||
case uri.LooksLikeURL() && directGGUF:
|
||||
// Direct file URL (e.g. .../resolve/main/tdt_ctc-110m-f16.gguf). The
|
||||
// exact file is known, no quant pick.
|
||||
fileName, err := uri.FilenameFromUrl()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
target := filepath.Join("parakeet-cpp", "models", name, fileName)
|
||||
cfg.Files = append(cfg.Files, gallery.File{
|
||||
URI: details.URI,
|
||||
Filename: target,
|
||||
})
|
||||
modelConfig.PredictionOptions = schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: target},
|
||||
}
|
||||
case details.HuggingFace != nil:
|
||||
// HF repo: collect every parakeet GGUF, pick the preferred quant, and
|
||||
// nest under parakeet-cpp/models/<name>/ so a multi-quant repo doesn't
|
||||
// collide on disk.
|
||||
var ggufFiles []hfapi.ModelFile
|
||||
for _, f := range details.HuggingFace.Files {
|
||||
if isParakeetGGUF(filepath.Base(f.Path)) {
|
||||
ggufFiles = append(ggufFiles, f)
|
||||
}
|
||||
}
|
||||
if chosen, ok := pickPreferredGGMLFile(ggufFiles, quants); ok {
|
||||
target := filepath.Join("parakeet-cpp", "models", name, filepath.Base(chosen.Path))
|
||||
cfg.Files = append(cfg.Files, gallery.File{
|
||||
URI: chosen.URL,
|
||||
Filename: target,
|
||||
SHA256: chosen.SHA256,
|
||||
})
|
||||
modelConfig.PredictionOptions = schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: target},
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Bare URI with no HF metadata (pref-only path): point at the basename
|
||||
// so users can tweak the YAML after import.
|
||||
modelConfig.PredictionOptions = schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: filepath.Base(details.URI)},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
cfg.ConfigFile = string(data)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// isParakeetGGUF reports whether name is a parakeet.cpp GGUF: a .gguf file
|
||||
// whose name carries a parakeet architecture token. The .gguf check is
|
||||
// case-insensitive; the tokens cover the published naming
|
||||
// (<arch>-<size>-<quant>.gguf) plus a generic "parakeet" fallback.
|
||||
func isParakeetGGUF(name string) bool {
|
||||
lower := strings.ToLower(name)
|
||||
if !strings.HasSuffix(lower, ".gguf") {
|
||||
return false
|
||||
}
|
||||
for _, tok := range []string{"tdt_ctc", "tdt-", "tdt_", "rnnt", "ctc-", "ctc_", "realtime_eou", "parakeet"} {
|
||||
if strings.Contains(lower, tok) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
103
core/gallery/importers/parakeet-cpp_test.go
Normal file
103
core/gallery/importers/parakeet-cpp_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// hfWith builds Details carrying a synthetic HF file list so detection can be
|
||||
// exercised without hitting the network.
|
||||
func parakeetDetails(uri string, prefs string, files ...hfapi.ModelFile) importers.Details {
|
||||
return importers.Details{
|
||||
URI: uri,
|
||||
Preferences: json.RawMessage(prefs),
|
||||
HuggingFace: &hfapi.ModelDetails{Files: files},
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("ParakeetCppImporter", func() {
|
||||
imp := &importers.ParakeetCppImporter{}
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
Expect(imp.Name()).To(Equal("parakeet-cpp"))
|
||||
Expect(imp.Modality()).To(Equal("asr"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("detection (Match)", func() {
|
||||
It("matches an HF repo shipping a parakeet GGUF", func() {
|
||||
d := parakeetDetails("huggingface://mudler/parakeet-cpp-gguf", `{}`,
|
||||
hfapi.ModelFile{Path: "tdt_ctc-110m-f16.gguf"},
|
||||
hfapi.ModelFile{Path: "README.md"},
|
||||
)
|
||||
Expect(imp.Match(d)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches a direct URL to a parakeet GGUF", func() {
|
||||
d := parakeetDetails("https://huggingface.co/mudler/parakeet-cpp-gguf/resolve/main/rnnt-0.6b-q4_k.gguf", `{}`)
|
||||
Expect(imp.Match(d)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("honours preferences.backend=parakeet-cpp for arbitrary URIs", func() {
|
||||
d := parakeetDetails("https://example.com/whatever", `{"backend": "parakeet-cpp"}`)
|
||||
Expect(imp.Match(d)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does NOT claim a generic llama-style GGUF", func() {
|
||||
d := parakeetDetails("huggingface://someorg/some-llm-gguf", `{}`,
|
||||
hfapi.ModelFile{Path: "llama-3-8b-instruct-q4_k_m.gguf"},
|
||||
)
|
||||
Expect(imp.Match(d)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does NOT claim the upstream NeMo repo (.nemo, no GGUF)", func() {
|
||||
d := parakeetDetails("huggingface://nvidia/parakeet-tdt_ctc-110m", `{}`,
|
||||
hfapi.ModelFile{Path: "parakeet-tdt_ctc-110m.nemo"},
|
||||
)
|
||||
Expect(imp.Match(d)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("import (Import)", func() {
|
||||
It("picks the default quant (q4_k) from a multi-quant HF repo", func() {
|
||||
d := parakeetDetails("huggingface://mudler/parakeet-cpp-gguf", `{"name":"parakeet-110m"}`,
|
||||
hfapi.ModelFile{Path: "tdt_ctc-110m-f16.gguf", URL: "https://hf/f16", SHA256: "aaa"},
|
||||
hfapi.ModelFile{Path: "tdt_ctc-110m-q4_k.gguf", URL: "https://hf/q4k", SHA256: "bbb"},
|
||||
hfapi.ModelFile{Path: "tdt_ctc-110m-q8_0.gguf", URL: "https://hf/q8", SHA256: "ccc"},
|
||||
)
|
||||
cfg, err := imp.Import(d)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.ConfigFile).To(ContainSubstring("backend: parakeet-cpp"), fmt.Sprintf("%+v", cfg))
|
||||
Expect(cfg.ConfigFile).To(ContainSubstring("transcript"))
|
||||
Expect(cfg.Files).To(HaveLen(1))
|
||||
Expect(cfg.Files[0].URI).To(Equal("https://hf/q4k"), "default quant should be q4_k")
|
||||
Expect(cfg.Files[0].Filename).To(ContainSubstring("parakeet-cpp/models/parakeet-110m/tdt_ctc-110m-q4_k.gguf"))
|
||||
})
|
||||
|
||||
It("honours a preferred quantization override", func() {
|
||||
d := parakeetDetails("huggingface://mudler/parakeet-cpp-gguf", `{"name":"p","quantizations":"q8_0"}`,
|
||||
hfapi.ModelFile{Path: "tdt_ctc-110m-f16.gguf", URL: "https://hf/f16"},
|
||||
hfapi.ModelFile{Path: "tdt_ctc-110m-q8_0.gguf", URL: "https://hf/q8"},
|
||||
)
|
||||
cfg, err := imp.Import(d)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Files).To(HaveLen(1))
|
||||
Expect(cfg.Files[0].URI).To(Equal("https://hf/q8"))
|
||||
})
|
||||
|
||||
It("uses the exact file for a direct GGUF URL", func() {
|
||||
d := parakeetDetails("https://huggingface.co/mudler/parakeet-cpp-gguf/resolve/main/ctc-0.6b-q5_k.gguf", `{"name":"ctc"}`)
|
||||
cfg, err := imp.Import(d)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Files).To(HaveLen(1))
|
||||
Expect(cfg.Files[0].Filename).To(ContainSubstring("parakeet-cpp/models/ctc/ctc-0.6b-q5_k.gguf"))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
githubOAuth "golang.org/x/oauth2/github"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// providerEntry holds the OAuth2/OIDC config for a single provider.
|
||||
@@ -389,7 +391,7 @@ func fetchGitHubUserInfoAsOAuth(ctx context.Context, accessToken string) (*oauth
|
||||
}
|
||||
|
||||
func fetchGitHubUserInfo(ctx context.Context, accessToken string) (*githubUserInfo, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
client := httpclient.NewWithTimeout(10 * time.Second)
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
@@ -420,7 +422,7 @@ func fetchGitHubUserInfo(ctx context.Context, accessToken string) (*githubUserIn
|
||||
}
|
||||
|
||||
func fetchGitHubPrimaryEmail(ctx context.Context, accessToken string) (string, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
client := httpclient.NewWithTimeout(10 * time.Second)
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user/emails", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
@@ -458,7 +460,6 @@ func fetchGitHubPrimaryEmail(ctx context.Context, accessToken string) (string, e
|
||||
return "", fmt.Errorf("no verified email found")
|
||||
}
|
||||
|
||||
|
||||
func generateState() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
@@ -22,7 +24,9 @@ import (
|
||||
// decoding on the leading `data:` bytes.
|
||||
var audioDataURIPattern = regexp.MustCompile(`^data:[^,]+?;base64,`)
|
||||
|
||||
var audioDownloadClient = http.Client{Timeout: 30 * time.Second}
|
||||
// Downloading user-supplied media URLs legitimately follows redirects (CDNs);
|
||||
// WithFollowRedirects still strips any credential header on a cross-host hop.
|
||||
var audioDownloadClient = httpclient.NewWithTimeout(30*time.Second, httpclient.WithFollowRedirects())
|
||||
|
||||
// decodeAudioInput materialises a URL / data-URI / raw-base64 audio
|
||||
// payload to a temporary file and returns its path plus a cleanup
|
||||
|
||||
@@ -31,6 +31,7 @@ var knownPrefOnlyBackends = []schema.KnownBackend{
|
||||
{Name: "mlx-vlm", Modality: "text", AutoDetect: false, Description: "MLX vision-language models (preference-only)"},
|
||||
// ASR
|
||||
{Name: "whisperx", Modality: "asr", AutoDetect: false, Description: "WhisperX transcription (preference-only)"},
|
||||
{Name: "crispasr", Modality: "asr", AutoDetect: false, Description: "CrispASR multi-architecture transcription (preference-only)"},
|
||||
// TTS
|
||||
{Name: "kokoros", Modality: "tts", AutoDetect: false, Description: "Kokoros TTS (preference-only)"},
|
||||
{Name: "qwen-tts", Modality: "tts", AutoDetect: false, Description: "Qwen TTS (preference-only)"},
|
||||
|
||||
@@ -140,6 +140,7 @@ var _ = Describe("Backend Endpoints", func() {
|
||||
expectPrefOnly("trl", "text")
|
||||
expectPrefOnly("mlx-vlm", "text")
|
||||
expectPrefOnly("whisperx", "asr")
|
||||
expectPrefOnly("crispasr", "asr")
|
||||
expectPrefOnly("kokoros", "tts")
|
||||
expectPrefOnly("qwen-tts", "tts")
|
||||
expectPrefOnly("qwen3-tts-cpp", "tts")
|
||||
|
||||
@@ -11,9 +11,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
// CORSProxyEndpoint proxies HTTP requests to external MCP servers,
|
||||
@@ -77,7 +79,7 @@ func CORSProxyEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
)
|
||||
},
|
||||
}
|
||||
client := &http.Client{Transport: transport, Timeout: 10 * time.Minute}
|
||||
client := httpclient.New(httpclient.WithTransport(transport), httpclient.WithTimeout(10*time.Minute))
|
||||
|
||||
xlog.Debug("CORS proxy request", "method", c.Request().Method, "target", targetURL)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -16,14 +17,17 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// nodeError builds a schema.ErrorResponse for node endpoints.
|
||||
@@ -65,15 +69,15 @@ func GetNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
|
||||
// RegisterNodeRequest is the request body for registering a new worker node.
|
||||
type RegisterNodeRequest struct {
|
||||
Name string `json:"name"`
|
||||
NodeType string `json:"node_type,omitempty"` // "backend" (default) or "agent"
|
||||
Address string `json:"address"`
|
||||
HTTPAddress string `json:"http_address,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
TotalVRAM uint64 `json:"total_vram,omitempty"`
|
||||
AvailableVRAM uint64 `json:"available_vram,omitempty"`
|
||||
TotalRAM uint64 `json:"total_ram,omitempty"`
|
||||
AvailableRAM uint64 `json:"available_ram,omitempty"`
|
||||
Name string `json:"name"`
|
||||
NodeType string `json:"node_type,omitempty"` // "backend" (default) or "agent"
|
||||
Address string `json:"address"`
|
||||
HTTPAddress string `json:"http_address,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
TotalVRAM uint64 `json:"total_vram,omitempty"`
|
||||
AvailableVRAM uint64 `json:"available_vram,omitempty"`
|
||||
TotalRAM uint64 `json:"total_ram,omitempty"`
|
||||
AvailableRAM uint64 `json:"available_ram,omitempty"`
|
||||
GPUVendor string `json:"gpu_vendor,omitempty"`
|
||||
Labels map[string]string `json:"labels,omitempty"`
|
||||
// MaxReplicasPerModel is the per-node cap on replicas of any single model.
|
||||
@@ -909,14 +913,56 @@ func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
}
|
||||
|
||||
// SetSchedulingRequest is the request body for creating/updating a scheduling config.
|
||||
//
|
||||
// The four prefix-cache fields are POINTERS so an omitted field is
|
||||
// distinguishable from an explicit zero. On update, an omitted prefix-cache
|
||||
// field preserves the model's previously-configured value instead of resetting
|
||||
// it (see SetSchedulingEndpoint's PATCH-style merge). ModelName, NodeSelector,
|
||||
// MinReplicas and MaxReplicas keep their full-replace PUT semantics.
|
||||
type SetSchedulingRequest struct {
|
||||
ModelName string `json:"model_name"`
|
||||
NodeSelector map[string]string `json:"node_selector,omitempty"`
|
||||
MinReplicas int `json:"min_replicas"`
|
||||
MaxReplicas int `json:"max_replicas"`
|
||||
ModelName string `json:"model_name"`
|
||||
NodeSelector map[string]string `json:"node_selector,omitempty"`
|
||||
MinReplicas int `json:"min_replicas"`
|
||||
MaxReplicas int `json:"max_replicas"`
|
||||
RoutePolicy *string `json:"route_policy,omitempty"`
|
||||
BalanceAbsThreshold *int `json:"balance_abs_threshold,omitempty"`
|
||||
BalanceRelThreshold *float64 `json:"balance_rel_threshold,omitempty"`
|
||||
MinPrefixMatch *float64 `json:"min_prefix_match,omitempty"`
|
||||
}
|
||||
|
||||
// validateSchedulingRequest enforces the invariants of a scheduling config.
|
||||
// The prefix-cache bounds are delegated to prefixcache.ValidateThresholds (the
|
||||
// single source of truth), and are checked against the RESOLVED values passed
|
||||
// in (provided-or-preserved), so validation only rejects bad values the caller
|
||||
// actually supplied. It returns nil when valid, or an error with a user-facing
|
||||
// message describing the first violation.
|
||||
func validateSchedulingRequest(req SetSchedulingRequest, routePolicy string, absThr int, relThr, minMatch float64) error {
|
||||
if req.ModelName == "" {
|
||||
return errors.New("model_name is required")
|
||||
}
|
||||
if req.MinReplicas < 0 {
|
||||
return errors.New("min_replicas must be >= 0")
|
||||
}
|
||||
if req.MaxReplicas < 0 {
|
||||
return errors.New("max_replicas must be >= 0")
|
||||
}
|
||||
if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas {
|
||||
return errors.New("min_replicas must be <= max_replicas")
|
||||
}
|
||||
if err := prefixcache.ValidateThresholds(routePolicy, absThr, relThr, minMatch); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetSchedulingEndpoint creates or updates a model scheduling config.
|
||||
//
|
||||
// The registry upsert full-replaces all columns, so a request that omits the
|
||||
// prefix-cache fields would otherwise wipe a model's previously-configured
|
||||
// routing settings. To avoid that footgun the four prefix-cache fields are
|
||||
// merged PATCH-style: a non-nil request pointer wins; a nil one preserves the
|
||||
// existing config's value (or the zero default when no config exists yet). The
|
||||
// non-prefix fields keep their full-replace PUT behavior.
|
||||
func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
@@ -924,17 +970,45 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
||||
}
|
||||
if req.ModelName == "" {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name is required"))
|
||||
|
||||
// Fetch the existing config (may be nil) so omitted prefix-cache fields
|
||||
// can fall back to the stored value rather than resetting to zero.
|
||||
var existing *nodes.ModelSchedulingConfig
|
||||
if req.ModelName != "" {
|
||||
var err error
|
||||
existing, err = registry.GetModelScheduling(ctx, req.ModelName)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to load existing scheduling config"))
|
||||
}
|
||||
}
|
||||
if req.MinReplicas < 0 {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be >= 0"))
|
||||
|
||||
// Resolve each prefix-cache field: provided pointer wins, otherwise keep
|
||||
// the existing value (zero/default when there is no existing config).
|
||||
routePolicy := ""
|
||||
absThr := 0
|
||||
relThr := 0.0
|
||||
minMatch := 0.0
|
||||
if existing != nil {
|
||||
routePolicy = existing.RoutePolicy
|
||||
absThr = existing.BalanceAbsThreshold
|
||||
relThr = existing.BalanceRelThreshold
|
||||
minMatch = existing.MinPrefixMatch
|
||||
}
|
||||
if req.MaxReplicas < 0 {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "max_replicas must be >= 0"))
|
||||
if req.RoutePolicy != nil {
|
||||
routePolicy = *req.RoutePolicy
|
||||
}
|
||||
if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be <= max_replicas"))
|
||||
if req.BalanceAbsThreshold != nil {
|
||||
absThr = *req.BalanceAbsThreshold
|
||||
}
|
||||
if req.BalanceRelThreshold != nil {
|
||||
relThr = *req.BalanceRelThreshold
|
||||
}
|
||||
if req.MinPrefixMatch != nil {
|
||||
minMatch = *req.MinPrefixMatch
|
||||
}
|
||||
|
||||
if err := validateSchedulingRequest(req, routePolicy, absThr, relThr, minMatch); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, err.Error()))
|
||||
}
|
||||
|
||||
// Serialize node selector to JSON
|
||||
@@ -948,10 +1022,14 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
}
|
||||
|
||||
config := &nodes.ModelSchedulingConfig{
|
||||
ModelName: req.ModelName,
|
||||
NodeSelector: selectorJSON,
|
||||
MinReplicas: req.MinReplicas,
|
||||
MaxReplicas: req.MaxReplicas,
|
||||
ModelName: req.ModelName,
|
||||
NodeSelector: selectorJSON,
|
||||
MinReplicas: req.MinReplicas,
|
||||
MaxReplicas: req.MaxReplicas,
|
||||
RoutePolicy: routePolicy,
|
||||
BalanceAbsThreshold: absThr,
|
||||
BalanceRelThreshold: relThr,
|
||||
MinPrefixMatch: minMatch,
|
||||
}
|
||||
if err := registry.SetModelScheduling(ctx, config); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set scheduling config"))
|
||||
@@ -983,6 +1061,6 @@ func proxyHTTPToWorker(httpAddress, path, token string) (*http.Response, error)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
client := httpclient.NewWithTimeout(15 * time.Second)
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("validateSchedulingRequest", func() {
|
||||
base := func() SetSchedulingRequest {
|
||||
return SetSchedulingRequest{ModelName: "m"}
|
||||
}
|
||||
|
||||
It("accepts an empty route policy (inherit) with valid thresholds", func() {
|
||||
Expect(validateSchedulingRequest(base(), "", 3, 0, 0.4)).To(Succeed())
|
||||
})
|
||||
|
||||
It("accepts the prefix_cache policy", func() {
|
||||
Expect(validateSchedulingRequest(base(), "prefix_cache", 0, 0, 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("accepts the round_robin policy", func() {
|
||||
Expect(validateSchedulingRequest(base(), "round_robin", 0, 0, 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("accepts balance_rel_threshold >= 1", func() {
|
||||
Expect(validateSchedulingRequest(base(), "", 0, 1.5, 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects a missing model_name", func() {
|
||||
req := base()
|
||||
req.ModelName = ""
|
||||
err := validateSchedulingRequest(req, "", 0, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("model_name is required"))
|
||||
})
|
||||
|
||||
It("rejects an unknown route_policy (no silent default)", func() {
|
||||
err := validateSchedulingRequest(base(), "bogus", 0, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("route_policy"))
|
||||
})
|
||||
|
||||
It("rejects min_prefix_match above 1", func() {
|
||||
err := validateSchedulingRequest(base(), "", 0, 0, 2)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("rejects a negative min_prefix_match", func() {
|
||||
err := validateSchedulingRequest(base(), "", 0, 0, -0.1)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("rejects a negative balance_abs_threshold", func() {
|
||||
err := validateSchedulingRequest(base(), "", -1, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("balance_abs_threshold"))
|
||||
})
|
||||
|
||||
It("rejects balance_rel_threshold between 0 and 1 exclusive", func() {
|
||||
err := validateSchedulingRequest(base(), "", 0, 0.5, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("balance_rel_threshold"))
|
||||
})
|
||||
})
|
||||
@@ -230,6 +230,114 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Describe("SetSchedulingEndpoint", func() {
|
||||
postScheduling := func(body string) *httptest.ResponseRecorder {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
handler := SetSchedulingEndpoint(registry)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
return rec
|
||||
}
|
||||
|
||||
It("persists prefix-cache fields and round-trips them via GET", func() {
|
||||
ctx := context.Background()
|
||||
rec := postScheduling(`{"model_name":"pc-model","route_policy":"prefix_cache","balance_abs_threshold":3,"min_prefix_match":0.4}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
cfg, err := registry.GetModelScheduling(ctx, "pc-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg).ToNot(BeNil())
|
||||
Expect(cfg.RoutePolicy).To(Equal("prefix_cache"))
|
||||
Expect(cfg.BalanceAbsThreshold).To(Equal(3))
|
||||
Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9))
|
||||
|
||||
e := echo.New()
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
getRec := httptest.NewRecorder()
|
||||
gc := e.NewContext(getReq, getRec)
|
||||
gc.SetParamNames("model")
|
||||
gc.SetParamValues("pc-model")
|
||||
Expect(GetSchedulingEndpoint(registry)(gc)).To(Succeed())
|
||||
Expect(getRec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var got nodes.ModelSchedulingConfig
|
||||
Expect(json.Unmarshal(getRec.Body.Bytes(), &got)).To(Succeed())
|
||||
Expect(got.RoutePolicy).To(Equal("prefix_cache"))
|
||||
Expect(got.BalanceAbsThreshold).To(Equal(3))
|
||||
Expect(got.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9))
|
||||
})
|
||||
|
||||
It("returns 400 for an out-of-range min_prefix_match", func() {
|
||||
rec := postScheduling(`{"model_name":"bad-mpm","min_prefix_match":2}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
errObj, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(errObj["message"]).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("returns 400 for an unknown route_policy", func() {
|
||||
rec := postScheduling(`{"model_name":"bad-policy","route_policy":"bogus"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
errObj, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(errObj["message"]).To(ContainSubstring("route_policy"))
|
||||
})
|
||||
|
||||
It("returns 400 for a balance_rel_threshold between 0 and 1", func() {
|
||||
rec := postScheduling(`{"model_name":"bad-rel","balance_rel_threshold":0.5}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
errObj, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(errObj["message"]).To(ContainSubstring("balance_rel_threshold"))
|
||||
})
|
||||
|
||||
// Regression for the partial-update footgun: a min/max-only POST used to
|
||||
// full-replace every column and silently reset the prefix-cache settings
|
||||
// to empty/zero. The pointer-merge must preserve omitted prefix fields.
|
||||
It("preserves prefix-cache settings across a min_replicas-only update", func() {
|
||||
ctx := context.Background()
|
||||
|
||||
rec := postScheduling(`{"model_name":"merge-model","route_policy":"prefix_cache","min_prefix_match":0.4}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
// Update only min_replicas - omits all prefix-cache fields.
|
||||
rec = postScheduling(`{"model_name":"merge-model","min_replicas":2}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
cfg, err := registry.GetModelScheduling(ctx, "merge-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg).ToNot(BeNil())
|
||||
Expect(cfg.MinReplicas).To(Equal(2), "the provided non-prefix field must update")
|
||||
Expect(cfg.RoutePolicy).To(Equal("prefix_cache"), "omitted route_policy must be preserved")
|
||||
Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9), "omitted min_prefix_match must be preserved")
|
||||
})
|
||||
|
||||
It("updates a prefix-cache field when it is explicitly provided", func() {
|
||||
ctx := context.Background()
|
||||
|
||||
rec := postScheduling(`{"model_name":"update-model","route_policy":"prefix_cache","min_prefix_match":0.4}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
rec = postScheduling(`{"model_name":"update-model","route_policy":"round_robin"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
cfg, err := registry.GetModelScheduling(ctx, "update-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg).ToNot(BeNil())
|
||||
Expect(cfg.RoutePolicy).To(Equal("round_robin"), "explicitly provided route_policy must update")
|
||||
Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9), "omitted min_prefix_match must still be preserved")
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ListNodesEndpoint", func() {
|
||||
It("returns an empty list when no nodes are registered", func() {
|
||||
e := echo.New()
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -15,18 +14,23 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
var videoDownloadClient = http.Client{Timeout: 30 * time.Second}
|
||||
// Downloading user-supplied media URLs legitimately follows redirects (CDNs);
|
||||
// WithFollowRedirects still strips any credential header on a cross-host hop.
|
||||
var videoDownloadClient = httpclient.NewWithTimeout(30*time.Second, httpclient.WithFollowRedirects())
|
||||
|
||||
func downloadFile(url string) (string, error) {
|
||||
if err := utils.ValidateExternalURL(url); err != nil {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
@@ -180,10 +181,10 @@ func SessionsFromMCPConfig(
|
||||
for _, server := range remote.Servers {
|
||||
xlog.Debug("[MCP remote server] Configuration", "server", server)
|
||||
// Create HTTP client with custom roundtripper for bearer token injection
|
||||
httpClient := &http.Client{
|
||||
Timeout: config.DefaultMCPToolTimeout,
|
||||
Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport),
|
||||
}
|
||||
httpClient := httpclient.New(
|
||||
httpclient.WithTimeout(config.DefaultMCPToolTimeout),
|
||||
httpclient.WithTransport(newBearerTokenRoundTripper(server.Token, httpclient.HardenedTransport())),
|
||||
)
|
||||
|
||||
transport := &mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: httpClient}
|
||||
mcpSession, err := client.Connect(ctx, transport, nil)
|
||||
@@ -262,10 +263,10 @@ func NamedSessionsFromMCPConfig(
|
||||
|
||||
for serverName, server := range remote.Servers {
|
||||
xlog.Debug("[MCP remote server] Configuration", "name", serverName, "server", server)
|
||||
httpClient := &http.Client{
|
||||
Timeout: config.DefaultMCPToolTimeout,
|
||||
Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport),
|
||||
}
|
||||
httpClient := httpclient.New(
|
||||
httpclient.WithTimeout(config.DefaultMCPToolTimeout),
|
||||
httpclient.WithTransport(newBearerTokenRoundTripper(server.Token, httpclient.HardenedTransport())),
|
||||
)
|
||||
|
||||
transport := &mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: httpClient}
|
||||
mcpSession, err := client.Connect(ctx, transport, nil)
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -16,15 +15,18 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
func downloadFile(url string) (string, error) {
|
||||
@@ -33,7 +35,7 @@ func downloadFile(url string) (string, error) {
|
||||
}
|
||||
|
||||
// Get the data
|
||||
resp, err := http.Get(url)
|
||||
resp, err := httpclient.New(httpclient.WithFollowRedirects()).Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -17,7 +17,10 @@ func SecurityHeaders() echo.MiddlewareFunc {
|
||||
"img-src 'self' data: blob: https:; " +
|
||||
"media-src 'self' data: blob:; " +
|
||||
"font-src 'self' data:; " +
|
||||
"connect-src 'self' ws: wss: https:; " +
|
||||
// blob: lets the waveform renderer XHR/fetch a freshly-created object
|
||||
// URL (e.g. an uploaded clip before it has a server URL). XHR/fetch of
|
||||
// blob: falls under connect-src, not media-src.
|
||||
"connect-src 'self' ws: wss: https: blob:; " +
|
||||
"frame-src 'self' blob:; " +
|
||||
"worker-src 'self' blob:; " +
|
||||
"object-src 'none'; " +
|
||||
|
||||
@@ -32,6 +32,9 @@ var _ = Describe("SecurityHeaders", func() {
|
||||
Expect(csp).To(ContainSubstring("frame-ancestors 'self'"))
|
||||
Expect(csp).To(ContainSubstring("object-src 'none'"))
|
||||
Expect(csp).To(ContainSubstring("base-uri 'self'"))
|
||||
// blob: must be in connect-src so the waveform renderer can XHR/fetch
|
||||
// a freshly-created object URL (uploaded/enhanced clip).
|
||||
Expect(csp).To(ContainSubstring("connect-src 'self' ws: wss: https: blob:"))
|
||||
})
|
||||
|
||||
It("sets X-Content-Type-Options: nosniff", func() {
|
||||
|
||||
@@ -1 +1 @@
|
||||
38.29
|
||||
40.0
|
||||
@@ -20,5 +20,10 @@ test.describe('Agents page', () => {
|
||||
page.waitForURL(/\/app\/agents\/new$/),
|
||||
create.click(),
|
||||
])
|
||||
// Wait for AgentCreate.jsx to actually render, not just for the URL to
|
||||
// change. Ending the test the instant the route matched let the component
|
||||
// mount race the coverage teardown — its ~400 lines were collected only
|
||||
// when the render won, swinging total UI coverage ~1pp run-to-run.
|
||||
await expect(page.getByRole('heading', { name: 'Create Agent' })).toBeVisible()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -66,6 +66,33 @@ function makeFakeWav(name) {
|
||||
return { name, mimeType: 'audio/wav', buffer: buf }
|
||||
}
|
||||
|
||||
// Build a WAV carrying a real sine tone, long enough that the spectrogram
|
||||
// STFT produces several frames (a few thousand samples). Used to exercise the
|
||||
// FFT / heatmap path, which the 4-sample silent fixture can't.
|
||||
function makeToneWav(name, freq = 1000, seconds = 0.4, sampleRate = 16000) {
|
||||
const samples = Math.floor(seconds * sampleRate)
|
||||
const dataLen = samples * 2
|
||||
const buf = Buffer.alloc(44 + dataLen)
|
||||
buf.write('RIFF', 0)
|
||||
buf.writeUInt32LE(36 + dataLen, 4)
|
||||
buf.write('WAVE', 8)
|
||||
buf.write('fmt ', 12)
|
||||
buf.writeUInt32LE(16, 16)
|
||||
buf.writeUInt16LE(1, 20)
|
||||
buf.writeUInt16LE(1, 22)
|
||||
buf.writeUInt32LE(sampleRate, 24)
|
||||
buf.writeUInt32LE(sampleRate * 2, 28)
|
||||
buf.writeUInt16LE(2, 32)
|
||||
buf.writeUInt16LE(16, 34)
|
||||
buf.write('data', 36)
|
||||
buf.writeUInt32LE(dataLen, 40)
|
||||
for (let i = 0; i < samples; i++) {
|
||||
const v = Math.round(Math.sin((2 * Math.PI * freq * i) / sampleRate) * 16000)
|
||||
buf.writeInt16LE(v, 44 + i * 2)
|
||||
}
|
||||
return { name, mimeType: 'audio/wav', buffer: buf }
|
||||
}
|
||||
|
||||
test.describe('Audio Transform', () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await mockCapabilities(page, [
|
||||
@@ -169,6 +196,26 @@ test.describe('Audio Transform', () => {
|
||||
await expect(page.getByTestId('media-history-item')).toHaveCount(1)
|
||||
})
|
||||
|
||||
test('renders an input spectrogram on upload and an output one after transform', async ({ page }) => {
|
||||
mockAudioTransform(page, 'enhanced.wav')
|
||||
|
||||
await page.goto('/app/transform')
|
||||
await expect(page.getByRole('button', { name: 'localvqe' })).toBeVisible({ timeout: 10_000 })
|
||||
|
||||
// Choosing a clip should render its input spectrogram immediately — no
|
||||
// backend round-trip needed (it's computed client-side from the bytes).
|
||||
await page.locator('input[type="file"]').first().setInputFiles(makeToneWav('tone.wav'))
|
||||
await expect(page.getByTestId('spectrogram-input')).toBeVisible({ timeout: 10_000 })
|
||||
|
||||
// Until a transform runs the output side shows a "compare" placeholder.
|
||||
await expect(page.getByText(/Transform to compare/)).toBeVisible()
|
||||
|
||||
await page.getByRole('button', { name: /Transform/ }).last().click()
|
||||
|
||||
// After processing, the output spectrum panel appears alongside the input.
|
||||
await expect(page.getByText('Output spectrum')).toBeVisible({ timeout: 10_000 })
|
||||
})
|
||||
|
||||
test('shows an error banner when the backend returns 4xx', async ({ page }) => {
|
||||
await page.route('**/audio/transformations', (route) => {
|
||||
if (route.request().method() !== 'POST') return route.continue()
|
||||
|
||||
37
core/http/react-ui/e2e/cluster.spec.js
Normal file
37
core/http/react-ui/e2e/cluster.spec.js
Normal file
@@ -0,0 +1,37 @@
|
||||
import { test, expect } from './coverage-fixtures.js'
|
||||
|
||||
// The Cluster page composes two capability sections: "Distributed (NATS)" (the
|
||||
// former Nodes page) and "Swarm (p2p)" (the former P2P page). Each section only
|
||||
// mounts when its mode is enabled — distributed when /api/nodes answers OK, swarm
|
||||
// when a non-empty p2p network token is present. We mock those probes so the page
|
||||
// renders against the standalone ui-test-server without NATS / p2p running.
|
||||
|
||||
async function mockDistributedOnly(page) {
|
||||
await page.route('**/api/nodes', (route) => {
|
||||
route.fulfill({ status: 200, contentType: 'application/json', body: '[]' })
|
||||
})
|
||||
await page.route('**/api/nodes/scheduling', (route) => {
|
||||
route.fulfill({ status: 200, contentType: 'application/json', body: '[]' })
|
||||
})
|
||||
// Swarm disabled: token probe fails, so the swarm section stays hidden.
|
||||
await page.route('**/api/p2p/token', (route) => {
|
||||
route.fulfill({ status: 503, contentType: 'text/plain', body: '' })
|
||||
})
|
||||
}
|
||||
|
||||
test.describe('Cluster page', () => {
|
||||
test('shows the page title', async ({ page }) => {
|
||||
await mockDistributedOnly(page)
|
||||
await page.goto('/app/cluster')
|
||||
await expect(page).toHaveURL(/\/app\/cluster$/)
|
||||
await expect(page.getByRole('heading', { name: /Cluster/i })).toBeVisible()
|
||||
})
|
||||
|
||||
test('shows the distributed section when /api/nodes responds', async ({ page }) => {
|
||||
await mockDistributedOnly(page)
|
||||
await page.goto('/app/cluster')
|
||||
await expect(page).toHaveURL(/\/app\/cluster$/)
|
||||
// The distributed capability section is titled "Distributed (NATS)".
|
||||
await expect(page.getByText(/Distributed \(NATS\)/i)).toBeVisible()
|
||||
})
|
||||
})
|
||||
@@ -23,4 +23,11 @@ test.describe('Navigation', () => {
|
||||
await expect(page).toHaveURL(/\/app\/traces/)
|
||||
await expect(page.getByRole('heading', { name: 'Traces', exact: true })).toBeVisible()
|
||||
})
|
||||
|
||||
test('old cluster routes redirect to /app/cluster', async ({ page }) => {
|
||||
await page.goto('/app/p2p')
|
||||
await expect(page).toHaveURL(/\/app\/cluster$/)
|
||||
await page.goto('/app/nodes')
|
||||
await expect(page).toHaveURL(/\/app\/cluster$/)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -81,7 +81,7 @@ async function mockDistributedNodes(page, { onDelete } = {}) {
|
||||
}
|
||||
|
||||
async function expandNodeAndWaitForBackends(page) {
|
||||
await page.goto('/app/nodes')
|
||||
await page.goto('/app/cluster')
|
||||
// Click the row to expand it. The chevron toggle and the row both work,
|
||||
// but clicking the name cell is the most user-like.
|
||||
await page.getByText(NODE_NAME).first().click()
|
||||
|
||||
@@ -1,26 +1,65 @@
|
||||
import { test, expect } from './coverage-fixtures.js'
|
||||
|
||||
// P2P (Swarm) admin page — renders in the no-auth test harness (isAdmin).
|
||||
test.describe('P2P page', () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
// The standalone P2P (Swarm) page was merged into the Cluster page: /app/p2p now
|
||||
// redirects to /app/cluster, and the p2p content lives under the "Swarm (p2p)"
|
||||
// section. That section only mounts when p2p is enabled (a network token is
|
||||
// present), so we mock /api/p2p/token to return a non-empty token and assert the
|
||||
// swarm content renders under the cluster page.
|
||||
const P2P_TOKEN = 'test-network-token'
|
||||
|
||||
async function mockSwarmEnabled(page) {
|
||||
await page.route('**/api/p2p/token', (route) => {
|
||||
route.fulfill({
|
||||
status: 200,
|
||||
contentType: 'text/plain',
|
||||
body: P2P_TOKEN,
|
||||
})
|
||||
})
|
||||
await page.route('**/api/p2p/workers', (route) => {
|
||||
route.fulfill({ status: 200, contentType: 'application/json', body: '{"nodes":[]}' })
|
||||
})
|
||||
await page.route('**/api/p2p/federation', (route) => {
|
||||
route.fulfill({ status: 200, contentType: 'application/json', body: '{"nodes":[]}' })
|
||||
})
|
||||
await page.route('**/api/p2p/stats', (route) => {
|
||||
route.fulfill({
|
||||
status: 200,
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify({
|
||||
llama_cpp_workers: { online: 0, total: 0 },
|
||||
federated: { online: 0, total: 0 },
|
||||
mlx_workers: { online: 0, total: 0 },
|
||||
}),
|
||||
})
|
||||
})
|
||||
// The cluster page also probes /api/nodes for the distributed section; keep it
|
||||
// failing (distributed disabled) so only the swarm section renders here.
|
||||
await page.route('**/api/nodes', (route) => {
|
||||
route.fulfill({ status: 503, contentType: 'application/json', body: '{}' })
|
||||
})
|
||||
}
|
||||
|
||||
test.describe('P2P (Swarm) section on the Cluster page', () => {
|
||||
test('the old /app/p2p route lands on the cluster page', async ({ page }) => {
|
||||
await mockSwarmEnabled(page)
|
||||
// /app/p2p redirects to /app/cluster.
|
||||
await page.goto('/app/p2p')
|
||||
await expect(page).toHaveURL(/\/app\/cluster$/)
|
||||
await expect(page.getByRole('heading', { name: /Cluster/i })).toBeVisible()
|
||||
})
|
||||
|
||||
test('renders the P2P distribution overview and capability cards', async ({ page }) => {
|
||||
await expect(page).toHaveURL(/\/app\/p2p$/)
|
||||
await expect(page.getByRole('heading', { name: /P2P Distribution Not Enabled/i })).toBeVisible()
|
||||
await expect(page.getByRole('heading', { name: 'Instance Federation' })).toBeVisible()
|
||||
await expect(page.getByRole('heading', { name: 'Model Sharding' })).toBeVisible()
|
||||
await expect(page.getByRole('heading', { name: 'Resource Sharing' })).toBeVisible()
|
||||
await expect(page.getByRole('heading', { name: /How to Enable P2P/i })).toBeVisible()
|
||||
})
|
||||
test('renders the Swarm (p2p) section when p2p is enabled', async ({ page }) => {
|
||||
await mockSwarmEnabled(page)
|
||||
await page.goto('/app/cluster')
|
||||
await expect(page).toHaveURL(/\/app\/cluster$/)
|
||||
|
||||
test('hardware selector offers build targets and responds to selection', async ({ page }) => {
|
||||
const cpu = page.getByRole('button').filter({ hasText: /^CPU$/ })
|
||||
const cuda = page.getByRole('button').filter({ hasText: /^CUDA 12$/ })
|
||||
await expect(cpu).toBeVisible()
|
||||
await expect(cuda).toBeVisible()
|
||||
await cuda.click() // selecting a build target must not break the page
|
||||
await expect(page.getByRole('heading', { name: /How to Enable P2P/i })).toBeVisible()
|
||||
// The collapsible swarm section is titled "Swarm (p2p)".
|
||||
await expect(page.getByText(/Swarm \(p2p\)/i)).toBeVisible()
|
||||
|
||||
// The enabled p2p content (Network Token panel + the federation / sharding
|
||||
// tabs) is rendered inside the swarm section.
|
||||
await expect(page.getByRole('heading', { name: /Network Token/i })).toBeVisible()
|
||||
await expect(page.getByText('Federation', { exact: true })).toBeVisible()
|
||||
await expect(page.getByText('Model Sharding', { exact: true })).toBeVisible()
|
||||
})
|
||||
})
|
||||
|
||||
40
core/http/react-ui/e2e/page-render-smoke.spec.js
Normal file
40
core/http/react-ui/e2e/page-render-smoke.spec.js
Normal file
@@ -0,0 +1,40 @@
|
||||
import { test, expect } from './coverage-fixtures.js'
|
||||
|
||||
// Render-smoke coverage. Each page is lazy-loaded and runs its full render +
|
||||
// initial effects on mount, so a bare visit captures the bulk of a page's
|
||||
// lines — cheap, real coverage for pages that have no dedicated spec yet.
|
||||
//
|
||||
// This is the project's preferred way to keep the UI coverage gate green:
|
||||
// raise the floor by covering more, rather than loosening the gate's
|
||||
// tolerance (see CONTRIBUTING.md → "React UI coverage"). Auth is disabled in
|
||||
// the test server, so RequireAdmin/RequireFeature resolve to isAdmin=true and
|
||||
// every gated route renders without an auth/capability mock.
|
||||
//
|
||||
// Asserts the page mounted (its .page-title header is visible) and that it did
|
||||
// not bounce to a gate redirect (/login or back to /app home).
|
||||
const PAGES = [
|
||||
['/app/talk', 'Talk'],
|
||||
['/app/usage', 'Usage'],
|
||||
['/app/account', 'Account'],
|
||||
['/app/studio', 'Studio'],
|
||||
['/app/manage', 'Manage'],
|
||||
['/app/backends', 'Backends'],
|
||||
['/app/settings', 'Settings'],
|
||||
['/app/cluster', 'Cluster'],
|
||||
['/app/face', 'Face recognition'],
|
||||
['/app/voice', 'Voice recognition'],
|
||||
['/app/fine-tune', 'Fine-tuning'],
|
||||
['/app/quantize', 'Quantize'],
|
||||
]
|
||||
|
||||
test.describe('Page render smoke', () => {
|
||||
for (const [path, label] of PAGES) {
|
||||
test(`renders ${label} (${path})`, async ({ page }) => {
|
||||
await page.goto(path)
|
||||
// .page-title for the normal header; .empty-state-title for pages that
|
||||
// render a gated/empty state (e.g. Account when auth is disabled).
|
||||
await expect(page.locator('.page-title, .empty-state-title').first()).toBeVisible({ timeout: 15_000 })
|
||||
await expect(page).toHaveURL(new RegExp(path.replace(/\//g, '\\/') + '$'))
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -58,5 +58,21 @@
|
||||
"explorer": {
|
||||
"title": "Explorer",
|
||||
"subtitle": "Dateien und Konfiguration durchsuchen"
|
||||
},
|
||||
"cluster": {
|
||||
"title": "Cluster",
|
||||
"subtitle": "Verteilte und Peer-to-Peer-Knoten, die diese Instanz bedienen",
|
||||
"empty": "Es ist kein verteiltes oder p2p-Clustering aktiviert. Starte LocalAI im verteilten oder föderierten/p2p-Modus, um hier Cluster-Knoten zu verwalten.",
|
||||
"distributed": {
|
||||
"title": "Verteilt (NATS)"
|
||||
},
|
||||
"swarm": {
|
||||
"title": "Swarm (p2p)"
|
||||
},
|
||||
"summary": {
|
||||
"nodes": "Verteilte Knoten",
|
||||
"inFlight": "Laufende Anfragen",
|
||||
"peers": "Swarm-Peers online"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@
|
||||
"traces": "Traces",
|
||||
"nodes": "Knoten",
|
||||
"swarm": "Swarm",
|
||||
"cluster": "Cluster",
|
||||
"system": "System",
|
||||
"settings": "Einstellungen",
|
||||
"api": "API"
|
||||
|
||||
@@ -81,5 +81,21 @@
|
||||
"explorer": {
|
||||
"title": "Explorer",
|
||||
"subtitle": "Browse files and configuration"
|
||||
},
|
||||
"cluster": {
|
||||
"title": "Cluster",
|
||||
"subtitle": "Distributed and peer-to-peer nodes serving this instance",
|
||||
"empty": "No distributed or p2p clustering is enabled. Start LocalAI in distributed or federated/p2p mode to manage cluster nodes here.",
|
||||
"distributed": {
|
||||
"title": "Distributed (NATS)"
|
||||
},
|
||||
"swarm": {
|
||||
"title": "Swarm (p2p)"
|
||||
},
|
||||
"summary": {
|
||||
"nodes": "Distributed nodes",
|
||||
"inFlight": "In-flight requests",
|
||||
"peers": "Swarm peers online"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@
|
||||
"traces": "Traces",
|
||||
"nodes": "Nodes",
|
||||
"swarm": "Swarm",
|
||||
"cluster": "Cluster",
|
||||
"system": "System",
|
||||
"settings": "Settings",
|
||||
"api": "API"
|
||||
|
||||
@@ -58,5 +58,21 @@
|
||||
"explorer": {
|
||||
"title": "Explorador",
|
||||
"subtitle": "Explora archivos y configuración"
|
||||
},
|
||||
"cluster": {
|
||||
"title": "Clúster",
|
||||
"subtitle": "Nodos distribuidos y entre pares que sirven a esta instancia",
|
||||
"empty": "No hay clustering distribuido ni p2p habilitado. Inicia LocalAI en modo distribuido o federado/p2p para gestionar aquí los nodos del clúster.",
|
||||
"distributed": {
|
||||
"title": "Distribuido (NATS)"
|
||||
},
|
||||
"swarm": {
|
||||
"title": "Swarm (p2p)"
|
||||
},
|
||||
"summary": {
|
||||
"nodes": "Nodos distribuidos",
|
||||
"inFlight": "Solicitudes en curso",
|
||||
"peers": "Pares de Swarm en línea"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@
|
||||
"traces": "Trazas",
|
||||
"nodes": "Nodos",
|
||||
"swarm": "Swarm",
|
||||
"cluster": "Clúster",
|
||||
"system": "Sistema",
|
||||
"settings": "Configuración",
|
||||
"api": "API"
|
||||
|
||||
@@ -58,5 +58,21 @@
|
||||
"explorer": {
|
||||
"title": "Esplora risorse",
|
||||
"subtitle": "Sfoglia file e configurazioni"
|
||||
},
|
||||
"cluster": {
|
||||
"title": "Cluster",
|
||||
"subtitle": "Nodi distribuiti e peer-to-peer al servizio di questa istanza",
|
||||
"empty": "Nessun clustering distribuito o p2p è abilitato. Avvia LocalAI in modalità distribuita o federata/p2p per gestire qui i nodi del cluster.",
|
||||
"distributed": {
|
||||
"title": "Distribuito (NATS)"
|
||||
},
|
||||
"swarm": {
|
||||
"title": "Swarm (p2p)"
|
||||
},
|
||||
"summary": {
|
||||
"nodes": "Nodi distribuiti",
|
||||
"inFlight": "Richieste in corso",
|
||||
"peers": "Peer Swarm online"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user