mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-14 11:49:33 -04:00
Compare commits
193 Commits
update/RFD
...
v4.4.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d3d54d61b | ||
|
|
36e3419203 | ||
|
|
4ec6e3221e | ||
|
|
4bb592cf91 | ||
|
|
3e838c0cff | ||
|
|
36b4a81d1e | ||
|
|
0854932a25 | ||
|
|
203410871b | ||
|
|
7637f8cf1b | ||
|
|
f0e001b7f8 | ||
|
|
cf9debf4eb | ||
|
|
e1556aa1dc | ||
|
|
53cbb578a9 | ||
|
|
99c8205740 | ||
|
|
d7162b9f89 | ||
|
|
3351b62c91 | ||
|
|
0eca930b8d | ||
|
|
81ab62e874 | ||
|
|
0413fc03f8 | ||
|
|
7088572f75 | ||
|
|
c1e8440f5b | ||
|
|
8f0059123b | ||
|
|
a906438a69 | ||
|
|
d28a5b6da1 | ||
|
|
edeacf22c4 | ||
|
|
51f4f67c47 | ||
|
|
cf71e291b4 | ||
|
|
a7a7bd646b | ||
|
|
cec93d2e00 | ||
|
|
722bdb87e9 | ||
|
|
50dea8c983 | ||
|
|
46ba70632b | ||
|
|
60facc7252 | ||
|
|
8c8204d3c4 | ||
|
|
4ce0f6102a | ||
|
|
085fc53bbc | ||
|
|
56cc4f63fc | ||
|
|
a53f34e78f | ||
|
|
1cea96f09f | ||
|
|
006a9d38c7 | ||
|
|
892ce951ce | ||
|
|
7cda221d36 | ||
|
|
9a88eb81e7 | ||
|
|
58cdc050e9 | ||
|
|
b962f4a192 | ||
|
|
b6fcb3e1db | ||
|
|
ff09683d84 | ||
|
|
f618636c71 | ||
|
|
892fc49949 | ||
|
|
228a6dfe79 | ||
|
|
51a92b6093 | ||
|
|
b5964d385d | ||
|
|
fba8c9c498 | ||
|
|
6b2badb837 | ||
|
|
8b8506d01a | ||
|
|
6910a0bb48 | ||
|
|
cffd03b522 | ||
|
|
bf448d3794 | ||
|
|
1d4a12f7c0 | ||
|
|
186d62801d | ||
|
|
da4ed05429 | ||
|
|
ec1eea4f45 | ||
|
|
b203b32e57 | ||
|
|
48a8ce98aa | ||
|
|
8344d1c865 | ||
|
|
d2e6b93369 | ||
|
|
e1ec03d33f | ||
|
|
9323f4b5ca | ||
|
|
c20225fc13 | ||
|
|
337acc4c37 | ||
|
|
618e90cd13 | ||
|
|
92dea961c2 | ||
|
|
2e93186043 | ||
|
|
d07037e817 | ||
|
|
f6cc90d258 | ||
|
|
2c804bef5a | ||
|
|
6070402477 | ||
|
|
67f80a152b | ||
|
|
a7cb587d96 | ||
|
|
f7c74ad2da | ||
|
|
7402d1fd20 | ||
|
|
8c42695ef8 | ||
|
|
72e3241431 | ||
|
|
cd2bf95862 | ||
|
|
f64b72dd7d | ||
|
|
03c84cff28 | ||
|
|
9bc69c9e5f | ||
|
|
1e6c9cfd60 | ||
|
|
0e6712f734 | ||
|
|
0e4cee9a97 | ||
|
|
352b7ec604 | ||
|
|
ba706422fb | ||
|
|
e837921c2c | ||
|
|
73385713ca | ||
|
|
a4e671779a | ||
|
|
7051b2e0a1 | ||
|
|
469737101a | ||
|
|
858257eaf0 | ||
|
|
ef80a0e825 | ||
|
|
92726f7631 | ||
|
|
994063ba9a | ||
|
|
c1a55cf72d | ||
|
|
96758841d8 | ||
|
|
7a59260621 | ||
|
|
27e63b9a78 | ||
|
|
55c0911c23 | ||
|
|
f6cb6ab6d9 | ||
|
|
9f11b09c6a | ||
|
|
a5c4f822f0 | ||
|
|
fb36c262fe | ||
|
|
0e4e8980e6 | ||
|
|
3a932a9803 | ||
|
|
9d10418593 | ||
|
|
5470051d4d | ||
|
|
68c5eeebc3 | ||
|
|
1531fabe23 | ||
|
|
b7673d5b76 | ||
|
|
b64bdaf406 | ||
|
|
eebf08ff1d | ||
|
|
42e51894c3 | ||
|
|
d9ae6481fb | ||
|
|
f1c495a748 | ||
|
|
415b561947 | ||
|
|
e6a0d4c375 | ||
|
|
7e59a5c7c5 | ||
|
|
aea954a482 | ||
|
|
595e448714 | ||
|
|
860f9d63ad | ||
|
|
a5a0b3dc4e | ||
|
|
94eca04c60 | ||
|
|
35bd485d6a | ||
|
|
1fe96f8d9a | ||
|
|
c508e9d7c6 | ||
|
|
55e754fd05 | ||
|
|
a17753f7d1 | ||
|
|
c61838dba6 | ||
|
|
7013e13f05 | ||
|
|
5a0013defe | ||
|
|
c01ed631d6 | ||
|
|
d47464cb06 | ||
|
|
63f176346e | ||
|
|
af94d08729 | ||
|
|
6795d38f50 | ||
|
|
718223f33b | ||
|
|
39e050d9e2 | ||
|
|
c222161291 | ||
|
|
aa80d4681b | ||
|
|
0d57957ebb | ||
|
|
76fe0bb929 | ||
|
|
baa11133f1 | ||
|
|
1bdd3338a6 | ||
|
|
e08492a2c3 | ||
|
|
d5d8fe909d | ||
|
|
8a82753277 | ||
|
|
51ca109067 | ||
|
|
07f6c15a37 | ||
|
|
a44bdb29d4 | ||
|
|
aee4611ab2 | ||
|
|
486467623c | ||
|
|
4912c9b73a | ||
|
|
12d1f3a697 | ||
|
|
a7cad704b9 | ||
|
|
7e4df67556 | ||
|
|
5b24b4dacc | ||
|
|
52fdb46892 | ||
|
|
b389f0fe5f | ||
|
|
74281be340 | ||
|
|
cacf2f7a2c | ||
|
|
4a2cc64d07 | ||
|
|
4647770316 | ||
|
|
3c9b9529c0 | ||
|
|
fc2bd0986c | ||
|
|
a473a32678 | ||
|
|
3e220373b0 | ||
|
|
fbcd886a47 | ||
|
|
e1a782b70f | ||
|
|
73cfedc023 | ||
|
|
b982c977d5 | ||
|
|
532ca1b3a2 | ||
|
|
00ad55b590 | ||
|
|
4c58fd302f | ||
|
|
66582e7035 | ||
|
|
1d13949588 | ||
|
|
c8ad67bbca | ||
|
|
1c92b00918 | ||
|
|
b81a6d01b3 | ||
|
|
0fd666ee6e | ||
|
|
7763fb23a3 | ||
|
|
324277ccfd | ||
|
|
10d02e6c59 | ||
|
|
05ae06c17b | ||
|
|
2671e0c6f7 | ||
|
|
81b6b94f0b |
@@ -38,9 +38,12 @@ The React UI (`core/http/react-ui/`) has **no component/unit tests** — its onl
|
||||
- **Browser:** the flake dev shell ships `chromium` and exports `PLAYWRIGHT_CHROMIUM_PATH`; `playwright.config.js` uses it via `launchOptions.executablePath`, and the Makefile skips `playwright install` when it's set. This avoids Playwright's downloaded browser, which can't resolve system libs (`libglib-2.0`, …) on NixOS. In CI (no `PLAYWRIGHT_CHROMIUM_PATH`) the Makefile falls back to `playwright install --with-deps chromium`.
|
||||
- The app is a React SPA, so coverage accumulates across in-app navigation within a test; a full `page.goto`/reload resets it.
|
||||
- `.nycrc.json` uses `all: true`, so **every `src/**` file is in the report**, including 0%-coverage ones — that's how you spot features with no test at all (sort the HTML report or `coverage-summary.json` by line% ascending).
|
||||
- **UI coverage gate:** `make test-ui-coverage-check` runs the suite then `scripts/ui-coverage-check.sh`, failing if total line coverage drops more than `UI_COVERAGE_TOLERANCE` (default **1.0pp**) below `core/http/react-ui/coverage-baseline.txt`. `make test-ui-coverage-baseline` regenerates the baseline. **Why a tolerance (unlike the strict Go gate):** UI e2e line coverage is *non-deterministic* — async/debounced paths (e.g. the VRAM estimate's 500ms debounce) make identical specs vary ~0.5pp run-to-run, so a zero-tolerance gate would flake. Keep the tolerance just above the observed jitter. Run in CI (`tests-ui-e2e.yml`) and pre-commit on `core/http/react-ui/` changes.
|
||||
- **UI coverage gate:** `make test-ui-coverage-check` runs the suite then `scripts/ui-coverage-check.sh`, failing if total line coverage drops more than `UI_COVERAGE_TOLERANCE` below `core/http/react-ui/coverage-baseline.txt`. `make test-ui-coverage-baseline` regenerates the baseline. Runs in CI (`tests-ui-e2e.yml`) and pre-commit on `core/http/react-ui/` changes.
|
||||
- **Why it has a tolerance (unlike the strict Go gate):** UI e2e coverage is *non-deterministic*. Specs that assert on state and end while async/lazy render work is still in flight collect those lines only when the render beats the coverage teardown — so the total drifts with machine speed/load (a fast local box reads higher than a slow CI runner), diffusely across many specs. The tolerance absorbs that drift, so set the baseline *below* the slow-CI floor, never to a fast-local `make test-ui-coverage-baseline` number, or CI flaps.
|
||||
- **Raising coverage is cheap:** a *render-smoke* spec (navigate to a route, assert its header renders) mounts a lazy page and runs its full render + initial effects, capturing most of its lines in a few lines of test — see `e2e/page-render-smoke.spec.js`. Auth is disabled in the test server (`isAdmin=true`), so `RequireAdmin`/`RequireFeature` routes render without a mock. The most *deterministic* win is removing a race: make a spec `await` a rendered element before ending (see `e2e/agents.spec.js` → AgentCreate) so its lines count every run.
|
||||
|
||||
Rules:
|
||||
- The gate is **strict — there is no tolerance**. Any decrease fails, regardless of how many lines a PR adds or deletes. `covermode=atomic` makes line coverage deterministic, so there's no run-to-run jitter to excuse.
|
||||
- When a change legitimately **raises** coverage, run `make test-coverage-baseline` and **commit** the updated `coverage-baseline.txt` so the ratchet moves up. Never lower the baseline by hand.
|
||||
- If you can't get coverage back to baseline, the fix is to **add tests**, not to edit the baseline.
|
||||
Rules (both gates):
|
||||
- **Install the hooks:** `make install-hooks` once per clone so lint + coverage run pre-commit. Don't lean on CI for what the hook catches.
|
||||
- **Don't work around the gate:** never `git commit --no-verify`, and never hand-lower a baseline or widen a tolerance to turn a red gate green. The ratchet only moves up.
|
||||
- If a change drops coverage, **add tests** (sort `coverage-summary.json` by line% ascending to find untested code) rather than editing the baseline. When coverage legitimately rises, commit the regenerated baseline (`make test-coverage-baseline` / `test-ui-coverage-baseline`).
|
||||
- The Go gate is **strict — no tolerance**; `covermode=atomic` keeps it deterministic. The UI gate keeps a small tolerance only because its e2e coverage isn't.
|
||||
|
||||
@@ -50,6 +50,17 @@ Do not mix styles within a package. If you are extending tests in a package that
|
||||
|
||||
This is enforced by `golangci-lint` via the `forbidigo` linter (see `.golangci.yml`); calls like `t.Errorf` / `t.Fatalf` / `t.Run` / `t.Skip` / `t.Logf` are flagged. Run `make lint` locally before submitting; the same check runs in CI (`.github/workflows/lint.yml`).
|
||||
|
||||
## Outbound HTTP
|
||||
|
||||
All outbound HTTP must go through `github.com/mudler/LocalAI/pkg/httpclient` rather than the standard library's default client. Use `httpclient.New(...)` (no body deadline — safe for streaming/SSE) or `httpclient.NewWithTimeout(d, ...)` (simple request/response). Both **refuse redirects by default** and set a TLS 1.2 floor.
|
||||
|
||||
The reason is GHSA-3mj3-57v2-4636: the std default client follows redirects, and on a *cross-host* redirect Go forwards custom credential headers (e.g. Anthropic's `x-api-key`) to the redirect target, leaking the secret. `httpclient` fails closed instead.
|
||||
|
||||
- Need to follow redirects (download CDNs, registry blobs, GitHub asset URLs)? Pass `httpclient.WithFollowRedirects()` — it still strips credential headers on any cross-host hop.
|
||||
- Have a custom transport (IP-pinned dialer, HTTP/2 tuning, a credential-injecting `RoundTripper`)? Pass `httpclient.WithTransport(rt)`, basing the transport on `httpclient.HardenedTransport()` to keep the TLS floor. Handed a `*http.Client` by a library? `httpclient.Harden(c)` applies the policy in place.
|
||||
|
||||
This is enforced by `forbidigo` (see `.golangci.yml`): `http.DefaultClient` and `http.Get`/`Post`/`PostForm`/`Head` are flagged. The `&http.Client{}` composite literal can't be matched precisely by forbidigo without also flagging legitimate `*http.Client` type references, so that form is caught by review — don't construct raw clients.
|
||||
|
||||
## Documentation
|
||||
|
||||
The project documentation is located in `docs/content`. When adding new features or changing existing functionality, it is crucial to update the documentation to reflect these changes. This helps users understand how to use the new capabilities and ensures the documentation stays relevant.
|
||||
|
||||
@@ -68,6 +68,34 @@ go test -count=1 -timeout=30m -v ./tests/e2e-backends/...
|
||||
|
||||
CI does not load the model; the suite is opt-in via env vars.
|
||||
|
||||
## Distributed mode
|
||||
|
||||
ds4 supports **layer-split** distributed inference (a model too big for one host,
|
||||
split by transformer layer; the GGUF must be present on every machine, each loads
|
||||
only its slice). Topology is **inverted** vs llama.cpp: the coordinator listens,
|
||||
workers dial in.
|
||||
|
||||
- **`ds4-worker` binary**: built and packaged next to `grpc-server` (`package.sh`
|
||||
copies it into `package/`). Links the same engine objects plus `ds4_distributed.o`;
|
||||
**no gRPC/protobuf dependency** (speaks ds4's own TCP transport), so it builds
|
||||
even where `grpc-server` can't. Runs the worker serving loop (`ds4_dist_run`).
|
||||
- **Coordinator wiring**: the ds4 `grpc-server` acts as coordinator when `LoadModel`
|
||||
`ModelOptions.Options` (from model-YAML `options:`) carry:
|
||||
- `ds4_role:coordinator` (enables distributed mode; absent → single-node, back-compat)
|
||||
- `ds4_layers:0:19` (coordinator's own slice, inclusive; `N:output` includes the head)
|
||||
- `ds4_listen:0.0.0.0:1234` (address workers dial into)
|
||||
- `ds4_route_timeout:60` (optional; seconds Predict/PredictStream wait for the route
|
||||
to form before returning gRPC `UNAVAILABLE`; default 60)
|
||||
- **Worker CLI**: `local-ai worker ds4-distributed -- <ds4-worker args>` resolves the
|
||||
ds4 backend and execs the packaged `ds4-worker` (raw passthrough), e.g.
|
||||
`--role worker --model /models/ds4flash.gguf --layers 20:output --coordinator <host> 1234`.
|
||||
|
||||
Opt-in e2e in `tests/e2e-backends/backend_test.go`, gated by
|
||||
`BACKEND_TEST_DS4_DISTRIBUTED=1` (plus `BACKEND_TEST_DS4_WORKER_BINARY`,
|
||||
`BACKEND_TEST_DS4_WORKER_LAYERS`, `BACKEND_TEST_DS4_COORDINATOR_LAYERS`,
|
||||
`BACKEND_TEST_DS4_LISTEN`). Design spec:
|
||||
`docs/superpowers/specs/2026-05-30-ds4-distributed-inference-design.md`.
|
||||
|
||||
## Importer
|
||||
|
||||
`core/gallery/importers/ds4.go` (`DS4Importer`) auto-detects ds4 weights by
|
||||
|
||||
599
.github/backend-matrix.yml
vendored
599
.github/backend-matrix.yml
vendored
@@ -703,6 +703,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-locate-anything-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "locate-anything-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -716,6 +729,32 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -742,6 +781,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-omnivoice-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "omnivoice-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -1517,6 +1569,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-locate-anything-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "locate-anything-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1543,6 +1608,19 @@ include:
|
||||
backend: "rfdetr-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-locate-anything-cpp'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "locate-anything-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1556,6 +1634,32 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1569,6 +1673,32 @@ include:
|
||||
backend: "whisper"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-crispasr'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-parakeet-cpp'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1595,6 +1725,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-omnivoice-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "omnivoice-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1634,6 +1777,19 @@ include:
|
||||
backend: "qwen3-tts-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-omnivoice-cpp'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "omnivoice-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1688,20 +1844,6 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-turboquant'
|
||||
builder-base-image: 'quay.io/go-skynet/ci-cache:base-grpc-rocm-amd64'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2742,6 +2884,74 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# locate-anything-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-locate-anything-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "locate-anything-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-locate-anything-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "locate-anything-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-locate-anything-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "locate-anything-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-locate-anything-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "locate-anything-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-locate-anything-cpp'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "locate-anything-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2835,6 +3045,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-locate-anything-cpp'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "locate-anything-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
# whisper
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -2850,6 +3073,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2864,6 +3101,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-crispasr'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2877,6 +3128,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2890,6 +3154,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2904,6 +3181,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2918,6 +3209,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-crispasr'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
@@ -2931,6 +3236,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-crispasr'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2944,6 +3262,128 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-crispasr'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
runs-on: 'ubuntu-latest'
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# parakeet-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-parakeet-cpp'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-parakeet-cpp'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-parakeet-cpp'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-parakeet-cpp'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
runs-on: 'ubuntu-latest'
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# acestep-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -3082,6 +3522,35 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# omnivoice-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-omnivoice-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "omnivoice-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-omnivoice-cpp'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "omnivoice-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -3095,6 +3564,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-omnivoice-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "omnivoice-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -3108,6 +3590,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-omnivoice-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "omnivoice-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -3122,6 +3617,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-omnivoice-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "omnivoice-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -3136,6 +3645,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-omnivoice-cpp'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "omnivoice-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
@@ -3149,6 +3672,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-omnivoice-cpp'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "omnivoice-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -3162,6 +3698,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-omnivoice-cpp'
|
||||
base-image: "rocm/dev-ubuntu-24.04:6.4.4"
|
||||
runs-on: 'ubuntu-latest'
|
||||
skip-drivers: 'false'
|
||||
backend: "omnivoice-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# vibevoice-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -3976,6 +4525,14 @@ includeDarwin:
|
||||
tag-suffix: "-metal-darwin-arm64-whisper"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "crispasr"
|
||||
tag-suffix: "-metal-darwin-arm64-crispasr"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "parakeet-cpp"
|
||||
tag-suffix: "-metal-darwin-arm64-parakeet-cpp"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "acestep-cpp"
|
||||
tag-suffix: "-metal-darwin-arm64-acestep-cpp"
|
||||
build-type: "metal"
|
||||
@@ -3984,6 +4541,10 @@ includeDarwin:
|
||||
tag-suffix: "-metal-darwin-arm64-qwen3-tts-cpp"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "omnivoice-cpp"
|
||||
tag-suffix: "-metal-darwin-arm64-omnivoice-cpp"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "vibevoice-cpp"
|
||||
tag-suffix: "-metal-darwin-arm64-vibevoice-cpp"
|
||||
build-type: "metal"
|
||||
@@ -4052,6 +4613,10 @@ includeDarwin:
|
||||
tag-suffix: "-metal-darwin-arm64-silero-vad"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "sherpa-onnx"
|
||||
tag-suffix: "-metal-darwin-arm64-sherpa-onnx"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "local-store"
|
||||
tag-suffix: "-metal-darwin-arm64-local-store"
|
||||
build-type: "metal"
|
||||
@@ -4059,3 +4624,9 @@ includeDarwin:
|
||||
- backend: "llama-cpp-quantization"
|
||||
tag-suffix: "-metal-darwin-arm64-llama-cpp-quantization"
|
||||
build-type: "mps"
|
||||
- backend: "speaker-recognition"
|
||||
tag-suffix: "-metal-darwin-arm64-speaker-recognition"
|
||||
build-type: "mps"
|
||||
- backend: "ds4"
|
||||
tag-suffix: "-metal-darwin-arm64-ds4"
|
||||
lang: "go"
|
||||
|
||||
13
.github/gallery-agent/main.go
vendored
13
.github/gallery-agent/main.go
vendored
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -113,6 +114,17 @@ func main() {
|
||||
fmt.Println("Searching for trending models on HuggingFace...")
|
||||
rawModels, err := client.GetTrending(searchTerm, limit)
|
||||
if err != nil {
|
||||
if errors.Is(err, hfapi.ErrRateLimited) {
|
||||
fmt.Printf("HuggingFace API is rate limited after retries, skipping this run: %v\n", err)
|
||||
writeSummary(AddedModelSummary{
|
||||
SearchTerm: searchTerm,
|
||||
TotalFound: 0,
|
||||
ModelsAdded: 0,
|
||||
Quantization: quantization,
|
||||
ProcessingTime: time.Since(startTime).String(),
|
||||
})
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Error fetching models: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -277,4 +289,3 @@ func truncateString(s string, maxLen int) string {
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
|
||||
16
.github/workflows/bump_deps.yaml
vendored
16
.github/workflows/bump_deps.yaml
vendored
@@ -30,6 +30,14 @@ jobs:
|
||||
variable: "WHISPER_CPP_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/whisper/Makefile"
|
||||
- repository: "CrispStrobe/CrispASR"
|
||||
variable: "CRISPASR_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/crispasr/Makefile"
|
||||
- repository: "mudler/parakeet.cpp"
|
||||
variable: "PARAKEET_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/parakeet-cpp/Makefile"
|
||||
- repository: "leejet/stable-diffusion.cpp"
|
||||
variable: "STABLEDIFFUSION_GGML_VERSION"
|
||||
branch: "master"
|
||||
@@ -54,10 +62,18 @@ jobs:
|
||||
variable: "RFDETR_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/rfdetr-cpp/Makefile"
|
||||
- repository: "mudler/locate-anything.cpp"
|
||||
variable: "LOCATEANYTHING_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/locate-anything-cpp/Makefile"
|
||||
- repository: "predict-woo/qwen3-tts.cpp"
|
||||
variable: "QWEN3TTS_CPP_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/qwen3-tts-cpp/Makefile"
|
||||
- repository: "ServeurpersoCom/omnivoice.cpp"
|
||||
variable: "OMNIVOICE_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/omnivoice-cpp/Makefile"
|
||||
- repository: "localai-org/vibevoice.cpp"
|
||||
variable: "VIBEVOICE_CPP_VERSION"
|
||||
branch: "master"
|
||||
|
||||
2
.github/workflows/secscan.yaml
vendored
2
.github/workflows/secscan.yaml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||
- name: Run Gosec Security Scanner
|
||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||
uses: securego/gosec@v2.22.9
|
||||
uses: securego/gosec@v2.27.1
|
||||
with:
|
||||
# we let the report trigger content trigger a failure using the GitHub Security features.
|
||||
args: '-no-fail -fmt sarif -out results.sarif ./...'
|
||||
|
||||
63
.github/workflows/test-extra.yml
vendored
63
.github/workflows/test-extra.yml
vendored
@@ -38,6 +38,7 @@ jobs:
|
||||
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
||||
qwen3-tts-cpp: ${{ steps.detect.outputs.qwen3-tts-cpp }}
|
||||
rfdetr-cpp: ${{ steps.detect.outputs.rfdetr-cpp }}
|
||||
locate-anything-cpp: ${{ steps.detect.outputs.locate-anything-cpp }}
|
||||
vibevoice-cpp: ${{ steps.detect.outputs.vibevoice-cpp }}
|
||||
localvqe: ${{ steps.detect.outputs.localvqe }}
|
||||
voxtral: ${{ steps.detect.outputs.voxtral }}
|
||||
@@ -46,6 +47,7 @@ jobs:
|
||||
speaker-recognition: ${{ steps.detect.outputs.speaker-recognition }}
|
||||
sherpa-onnx: ${{ steps.detect.outputs.sherpa-onnx }}
|
||||
whisper: ${{ steps.detect.outputs.whisper }}
|
||||
parakeet-cpp: ${{ steps.detect.outputs.parakeet-cpp }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
@@ -562,7 +564,7 @@ jobs:
|
||||
- name: Run e2e-backends smoke
|
||||
env:
|
||||
BACKEND_IMAGE: quay.io/go-skynet/local-ai-backends:master-cpu-llama-cpp
|
||||
BACKEND_TEST_CAPS: health,load,predict,stream,logprobs,logit_bias
|
||||
BACKEND_TEST_CAPS: health,load,predict,stream,logprobs,logit_bias,tokenize
|
||||
run: |
|
||||
make test-extra-backend
|
||||
# Realtime e2e with sherpa-onnx driving VAD + STT + TTS against a mocked LLM.
|
||||
@@ -633,6 +635,26 @@ jobs:
|
||||
- name: Build whisper backend image and run transcription gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-whisper-transcription
|
||||
# Parakeet ASR via the parakeet-cpp backend (C++/ggml port of NeMo
|
||||
# Parakeet). Drives AudioTranscription (offline, with word timestamps) on
|
||||
# tdt_ctc-110m + the JFK 11s clip.
|
||||
tests-parakeet-cpp-grpc-transcription:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.parakeet-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25.4'
|
||||
- name: Build parakeet-cpp backend image and run transcription gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-parakeet-cpp-transcription
|
||||
# VITS TTS via the sherpa-onnx backend. Drives both TTS (file write) and
|
||||
# TTSStream (PCM chunks) on the e2e-backends harness.
|
||||
tests-sherpa-onnx-grpc-tts:
|
||||
@@ -880,6 +902,45 @@ jobs:
|
||||
- name: Test rfdetr-cpp
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/go/rfdetr-cpp test
|
||||
# Per-backend e2e for locate-anything-cpp: builds the .so + Go binary and
|
||||
# runs `make -C backend/go/locate-anything-cpp test`. test.sh fetches the
|
||||
# locate-anything-q8_0 GGUF (~6.3 GB, NVIDIA LocateAnything-3B) from the
|
||||
# published mudler/locate-anything.cpp-gguf HF repo + a COCO image, then the
|
||||
# Go wire test loads the model and runs an open-vocabulary Detect, asserting
|
||||
# at least one labeled box. Heavier than the other Go backends (it is a 3B),
|
||||
# so it is gated to changes under backend/go/locate-anything-cpp/.
|
||||
tests-locate-anything-cpp:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.locate-anything-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential cmake curl libopenblas-dev
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
- name: Display Go version
|
||||
run: go version
|
||||
- name: Proto Dependencies
|
||||
run: |
|
||||
# Install protoc
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
|
||||
PATH="$PATH:$HOME/go/bin" make protogen-go
|
||||
- name: Build locate-anything-cpp
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/go/locate-anything-cpp
|
||||
- name: Test locate-anything-cpp
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/go/locate-anything-cpp test
|
||||
# Per-backend smoke for vibevoice-cpp: builds the .so + Go binary and
|
||||
# runs `make -C backend/go/vibevoice-cpp test`. test.sh auto-downloads
|
||||
# the published mudler/vibevoice.cpp-models bundle (TTS Q8_0 + ASR Q4_K
|
||||
|
||||
@@ -56,6 +56,20 @@ linters:
|
||||
# are exempt — see linters.exclusions.rules below.
|
||||
- pattern: '^os\.(Getenv|LookupEnv|Environ)$'
|
||||
msg: 'Plumb config through ApplicationConfig (or the relevant CLI struct) instead of reading env directly. CLI entry points (core/cli/) bind env vars via kong''s `env:` tag — that is the only sanctioned env→struct boundary. See .agents/coding-style.md.'
|
||||
# Outbound HTTP must go through pkg/httpclient, which refuses redirects
|
||||
# by default and sets a TLS floor. The std-library default client and
|
||||
# the http.Get/Post/... convenience helpers follow redirects (up to 10)
|
||||
# and, on a cross-host redirect, forward custom credential headers such
|
||||
# as Anthropic's x-api-key to the redirect target — leaking the secret
|
||||
# (GHSA-3mj3-57v2-4636). forbidigo can't precisely match the
|
||||
# `&http.Client{}` composite literal without also flagging legitimate
|
||||
# `*http.Client` type references, so that form is enforced by
|
||||
# convention + review; these two patterns catch the implicit-default
|
||||
# client, which is the common footgun.
|
||||
- pattern: '^http\.DefaultClient$'
|
||||
msg: 'Use pkg/httpclient (httpclient.New / NewWithTimeout) instead of http.DefaultClient — the std client follows redirects and leaks credential headers cross-host (GHSA-3mj3-57v2-4636). See .agents/coding-style.md.'
|
||||
- pattern: '^http\.(Get|Post|PostForm|Head)$'
|
||||
msg: 'Use pkg/httpclient (httpclient.New / NewWithTimeout) instead of http.Get/Post/PostForm/Head — these use http.DefaultClient, which follows redirects and leaks credential headers cross-host (GHSA-3mj3-57v2-4636). See .agents/coding-style.md.'
|
||||
exclusions:
|
||||
paths:
|
||||
# Upstream whisper.cpp source tree fetched by the whisper backend Makefile.
|
||||
@@ -95,3 +109,18 @@ linters:
|
||||
- path: _test\.go$
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# pkg/httpclient is the sanctioned home for outbound HTTP clients; it
|
||||
# necessarily references net/http directly.
|
||||
- path: ^pkg/httpclient/
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
# Tests drive local httptest servers where redirect/TLS hardening is
|
||||
# irrelevant; the std client is fine there.
|
||||
- path: _test\.go$
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
# Vendored upstream whisper.cpp Go bindings are a separate module and
|
||||
# cannot import pkg/httpclient.
|
||||
- path: ^backend/go/whisper/sources/
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
|
||||
@@ -35,6 +35,7 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
||||
|
||||
## Quick Reference
|
||||
|
||||
- **Git hooks & coverage gates**: Run `make install-hooks` once per clone so the pre-commit lint + coverage gates run. **Never bypass them with `git commit --no-verify`, and never lower a coverage baseline or widen a gate's tolerance to turn a red gate green** — the coverage ratchet only moves up. If a change drops coverage, add tests to raise it (e.g. render-smoke specs). See [.agents/building-and-testing.md](.agents/building-and-testing.md).
|
||||
- **Logging**: Use `github.com/mudler/xlog` (same API as slog)
|
||||
- **Go style**: Prefer `any` over `interface{}`
|
||||
- **Comments**: Explain *why*, not *what*
|
||||
|
||||
@@ -266,6 +266,12 @@ The e2e tests run LocalAI in a Docker container and exercise the API:
|
||||
make test-e2e
|
||||
```
|
||||
|
||||
### React UI tests and coverage
|
||||
|
||||
The React UI (`core/http/react-ui/`) is covered by Playwright e2e specs, gated by a **monotonic line-coverage ratchet** (`make test-ui-coverage-check`, run in CI and pre-commit). The metric is non-deterministic — a fast local box reads higher than a slow CI runner for the same code — so a small tolerance is unavoidable.
|
||||
|
||||
**If your change lowers UI coverage, raise it back by adding specs — do not widen the tolerance or hand-lower the baseline.** A *render-smoke* spec (navigate to a page, assert its header is visible) cheaply covers an entire lazy page. See `core/http/react-ui/e2e/page-render-smoke.spec.js` and the full policy in [.agents/building-and-testing.md](.agents/building-and-testing.md#react-ui-coverage).
|
||||
|
||||
### Running E2E container tests
|
||||
|
||||
These tests build a standard LocalAI Docker image and run it with pre-configured model configs to verify that most endpoints work correctly:
|
||||
|
||||
@@ -108,6 +108,7 @@ RUN <<EOT bash
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
cuda-nvrtc-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
|
||||
65
Makefile
65
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/omnivoice-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -180,7 +180,7 @@ osx-signed: build
|
||||
|
||||
## Run
|
||||
run: ## run local-ai
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./cmd/local-ai
|
||||
|
||||
prepare-test: protogen-go build-mock-backend
|
||||
|
||||
@@ -309,13 +309,20 @@ run-e2e-aio: protogen-go
|
||||
@echo 'Running e2e AIO tests'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
|
||||
|
||||
# Distributed architecture e2e (PostgreSQL + NATS via testcontainers).
|
||||
# Includes NatsJWT specs (JWT-enabled NATS). Requires Docker.
|
||||
# VLLMMultinode is excluded here; use test-e2e-vllm-multinode for that.
|
||||
test-e2e-distributed: protogen-go
|
||||
@echo 'Running distributed e2e tests (label Distributed, incl. NatsJWT)'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter='Distributed && !VLLMMultinode' --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e/distributed
|
||||
|
||||
# vLLM multi-node DP smoke (CPU). Builds local-ai:tests and the
|
||||
# cpu-vllm backend from the current working tree, then drives a
|
||||
# head + headless follower via testcontainers-go and asserts a chat
|
||||
# completion. BuildKit caches both images, so re-runs only rebuild
|
||||
# what changed. The test lives under tests/e2e/distributed and is
|
||||
# selected by the VLLMMultinode label so it doesn't run alongside
|
||||
# the other distributed-suite tests by default.
|
||||
# test-e2e-distributed.
|
||||
test-e2e-vllm-multinode: docker-build-e2e extract-backend-vllm protogen-go
|
||||
@echo 'Running e2e vLLM multi-node DP test'
|
||||
LOCALAI_IMAGE=local-ai \
|
||||
@@ -559,6 +566,7 @@ prepare-test-extra: protogen-python
|
||||
$(MAKE) -C backend/python/speaker-recognition
|
||||
$(MAKE) -C backend/rust/kokoros kokoros-grpc
|
||||
$(MAKE) -C backend/go/rfdetr-cpp
|
||||
$(MAKE) -C backend/go/locate-anything-cpp
|
||||
|
||||
test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/transformers test
|
||||
@@ -586,6 +594,7 @@ test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/speaker-recognition test
|
||||
$(MAKE) -C backend/rust/kokoros test
|
||||
$(MAKE) -C backend/go/rfdetr-cpp test
|
||||
$(MAKE) -C backend/go/locate-anything-cpp test
|
||||
|
||||
##
|
||||
## End-to-end gRPC tests that exercise a built backend container image.
|
||||
@@ -991,6 +1000,19 @@ test-extra-backend-whisper-transcription: docker-build-whisper
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## Audio transcription wrapper for the parakeet-cpp (parakeet.cpp ggml port)
|
||||
## backend. Mirrors test-extra-backend-whisper-transcription: drives the
|
||||
## AudioTranscription / AudioTranscriptionStream RPCs against a published
|
||||
## Parakeet GGUF using the JFK 11s clip from whisper.cpp's CI samples. Not
|
||||
## part of the default test suite - run explicitly once the pinned model URL
|
||||
## is reachable.
|
||||
test-extra-backend-parakeet-cpp-transcription: docker-build-parakeet-cpp
|
||||
BACKEND_IMAGE=local-ai-backend:parakeet-cpp \
|
||||
BACKEND_TEST_MODEL_URL=https://huggingface.co/mudler/parakeet-cpp-gguf/resolve/main/tdt_ctc-110m-f16.gguf \
|
||||
BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## LocalVQE audio transform (joint AEC + noise suppression + dereverb).
|
||||
## Exercises the audio_transform capability end-to-end: batch transform
|
||||
## of a real WAV fixture and bidi streaming of synthetic silent frames.
|
||||
@@ -1149,9 +1171,12 @@ BACKEND_HUGGINGFACE = huggingface|golang|.|false|true
|
||||
BACKEND_SILERO_VAD = silero-vad|golang|.|false|true
|
||||
BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true
|
||||
BACKEND_WHISPER = whisper|golang|.|false|true
|
||||
BACKEND_CRISPASR = crispasr|golang|.|false|true
|
||||
BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true
|
||||
BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
||||
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
|
||||
BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true
|
||||
BACKEND_OMNIVOICE_CPP = omnivoice-cpp|golang|.|false|true
|
||||
BACKEND_VIBEVOICE_CPP = vibevoice-cpp|golang|.|false|true
|
||||
BACKEND_LOCALVQE = localvqe|golang|.|false|true
|
||||
BACKEND_OPUS = opus|golang|.|false|true
|
||||
@@ -1236,6 +1261,8 @@ $(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CRISPASR)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))
|
||||
@@ -1268,6 +1295,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_QWEN3_TTS_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_OMNIVOICE_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VIBEVOICE_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LOCALVQE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX)))
|
||||
@@ -1285,7 +1313,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-crispasr docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-omnivoice-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
@@ -1313,6 +1341,13 @@ build-ui-test-server: build-mock-backend react-ui protogen-go
|
||||
test-ui-e2e: build-ui-test-server
|
||||
cd core/http/react-ui && npm install && npx playwright install --with-deps chromium && npx playwright test
|
||||
|
||||
## Optional Playwright worker count for the UI e2e targets below. Pass
|
||||
## UI_TEST_WORKERS=N (e.g. `make test-ui-coverage UI_TEST_WORKERS=20`) to
|
||||
## override Playwright's default (cores/2). Empty by default so Playwright
|
||||
## picks its own worker count.
|
||||
UI_TEST_WORKERS ?=
|
||||
PLAYWRIGHT_WORKERS_FLAG = $(if $(UI_TEST_WORKERS),--workers=$(UI_TEST_WORKERS),)
|
||||
|
||||
## Fast Playwright e2e run used by the pre-commit hook on React UI changes.
|
||||
## Force-rebuilds the (non-instrumented) dist so the suite tests the working
|
||||
## tree — not a stale dist the `react-ui` skip-guard would leave — re-embeds
|
||||
@@ -1322,22 +1357,24 @@ test-ui-e2e: build-ui-test-server
|
||||
test-ui: build-mock-backend protogen-go
|
||||
cd core/http/react-ui && bun install && bun run build
|
||||
$(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui
|
||||
cd core/http/react-ui && sh $(CURDIR)/scripts/ensure-playwright-browser.sh && bunx playwright test
|
||||
cd core/http/react-ui && sh $(CURDIR)/scripts/ensure-playwright-browser.sh && bunx playwright test $(PLAYWRIGHT_WORKERS_FLAG)
|
||||
|
||||
## React UI code coverage from the Playwright e2e suite. Builds an
|
||||
## istanbul-instrumented bundle (COVERAGE=true), re-embeds it into the
|
||||
## ui-test-server (the dist is //go:embed'ed at compile time), runs the
|
||||
## Playwright specs — which harvest window.__coverage__ via the coverage
|
||||
## fixture — and writes an nyc report to core/http/react-ui/coverage/.
|
||||
## Removes the instrumented dist afterwards so normal builds aren't served
|
||||
## instrumented assets.
|
||||
## React UI code coverage from the Playwright e2e suite. Builds a
|
||||
## NON-instrumented bundle with source maps (COVERAGE_V8=true), re-embeds it
|
||||
## into the ui-test-server (the dist is //go:embed'ed at compile time), runs the
|
||||
## Playwright specs which collect native Chromium V8 coverage (PW_V8_COVERAGE=1)
|
||||
## — far cheaper than istanbul's build-time counters (~40% faster end-to-end) —
|
||||
## convert it to istanbul via v8-to-istanbul in the coverage fixture, and write
|
||||
## an nyc report to core/http/react-ui/coverage/. Removes the dist afterwards so
|
||||
## normal builds aren't served source-mapped assets. (The legacy istanbul path
|
||||
## still exists: `bun run build:coverage` + unset PW_V8_COVERAGE.)
|
||||
test-ui-coverage: build-mock-backend protogen-go
|
||||
trap 'rm -rf "$(CURDIR)/core/http/react-ui/dist"' EXIT; \
|
||||
( cd core/http/react-ui && bun install && bun run build:coverage ) && \
|
||||
( cd core/http/react-ui && bun install && bun run build:coverage-v8 ) && \
|
||||
$(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui && \
|
||||
( cd core/http/react-ui && rm -rf .nyc_output coverage && \
|
||||
sh $(CURDIR)/scripts/ensure-playwright-browser.sh && \
|
||||
bunx playwright test && bun run coverage:report )
|
||||
PW_V8_COVERAGE=1 bunx playwright test $(PLAYWRIGHT_WORKERS_FLAG) && bun run coverage:report )
|
||||
|
||||
## UI coverage baseline (committed) and the strict gate that compares against
|
||||
## it — the React mirror of test-coverage-baseline / test-coverage-check.
|
||||
|
||||
34
README.md
34
README.md
@@ -31,12 +31,18 @@
|
||||
|
||||
**LocalAI** is the open-source AI engine. Run any model - LLMs, vision, voice, image, video - on any hardware. No GPU required.
|
||||
|
||||
- **Drop-in API compatibility** — OpenAI, Anthropic, ElevenLabs APIs
|
||||
- **36+ backends** — llama.cpp, vLLM, transformers, whisper, diffusers, MLX...
|
||||
- **Any hardware** — NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready** — API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents** — autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first** — your data never leaves your infrastructure
|
||||
**A small core, not a bundle.** Each backend wraps a best-in-class engine (llama.cpp, vLLM, whisper.cpp, stable-diffusion, MLX...) in its own image, pulled only when a model needs it. You install nothing you don't use.
|
||||
|
||||
- **Composable by design**: backends are separate and pulled on demand, so you install only what your model needs
|
||||
- **Open and extensible**: load any model, or build your own backend in any language against an open interface
|
||||
- **Drop-in API compatibility**: OpenAI, Anthropic, and ElevenLabs APIs across every backend
|
||||
- **Any model, any modality**: LLMs, vision, voice, image, and video behind one API
|
||||
- **Any hardware**: NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready**: API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents**: autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first**: your data never leaves your infrastructure
|
||||
|
||||

|
||||
|
||||
Created by [Ettore Di Giacinto](https://github.com/mudler) and maintained by the [LocalAI team](#team).
|
||||
|
||||
@@ -143,12 +149,26 @@ local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
|
||||
local-ai run oci://localai/phi-2:latest
|
||||
```
|
||||
|
||||
To test a running LocalAI server from the terminal, open an interactive chat session from another shell. Inside the prompt, `/models` lists installed models and `/model <name>` switches between them.
|
||||
|
||||
```bash
|
||||
# Terminal 1
|
||||
local-ai run llama-3.2-1b-instruct:q4_k_m
|
||||
|
||||
# Terminal 2
|
||||
local-ai chat --model llama-3.2-1b-instruct:q4_k_m
|
||||
```
|
||||
|
||||
> **Automatic Backend Detection**: LocalAI automatically detects your GPU capabilities and downloads the appropriate backend. For advanced options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/).
|
||||
|
||||
For more details, see the [Getting Started guide](https://localai.io/basics/getting_started/).
|
||||
|
||||
## Latest News
|
||||
|
||||
- **June 2026**: New [realtime voice assistant demo](https://github.com/localai-org/localai-realtime-demo) (a tiny Go client for the Realtime API with a full talk-back voice loop and tool calling), plus [streaming of the realtime LLM / TTS / transcription pipeline stages](https://github.com/mudler/LocalAI/pull/10176) and [configurable WebRTC ICE candidates](https://github.com/mudler/LocalAI/pull/10231).
|
||||
- **June 2026**: Big speech push: the [parakeet.cpp](https://github.com/mudler/parakeet.cpp) ASR engine gains [NeMo-faithful segment timestamps](https://github.com/mudler/LocalAI/pull/10207), a [multilingual streaming Nemotron-3.5 model](https://github.com/mudler/LocalAI/pull/10199), [dynamic batching for concurrent transcription](https://github.com/mudler/LocalAI/pull/10112) and [CUDA graphs](https://github.com/mudler/LocalAI/pull/10273); the new [CrispASR backend](https://github.com/mudler/LocalAI/pull/10099) adds multi-architecture ASR + TTS, and [60 Piper TTS voices across 42 languages](https://github.com/mudler/LocalAI/pull/10296) land in the gallery (plus [per-request TTS instructions and params](https://github.com/mudler/LocalAI/pull/10172)).
|
||||
- **June 2026**: New backends and models: [locate-anything.cpp](https://github.com/mudler/LocalAI/pull/10264) for open-vocabulary object detection via ggml, [Ideogram4 image generation](https://github.com/mudler/LocalAI/pull/10201) in stablediffusion-ggml, [llama.cpp video input](https://github.com/mudler/LocalAI/pull/10216), and the [Gemma 4 QAT family with MTP speculative-decoding pairs](https://github.com/mudler/LocalAI/pull/10215). Plus an [interactive CLI chat mode](https://github.com/mudler/LocalAI/pull/10226) and [RAG source citations in agent responses](https://github.com/mudler/LocalAI/pull/10228).
|
||||
- **June 2026**: Distributed mode hardening: [prefix-cache-aware routing](https://github.com/mudler/LocalAI/pull/10071), a [production-ready request router with auto-sized embedding/rerank batches](https://github.com/mudler/LocalAI/pull/10104), [ds4 layer-split distributed inference](https://github.com/mudler/LocalAI/pull/10098), [NATS JWT auth + TLS/mTLS](https://github.com/mudler/LocalAI/pull/10159), and [resumable file uploads](https://github.com/mudler/LocalAI/pull/10109).
|
||||
- **May 2026**: **LocalAI 4.3.0** - `llama.cpp` [prompt cache on by default](https://github.com/mudler/LocalAI/pull/9925) (repeated system prompts collapse from minutes to seconds), [keyless cosign signing of backend OCI images](https://github.com/mudler/LocalAI/pull/9823), [per-API-key + per-user usage attribution](https://github.com/mudler/LocalAI/pull/9920), Distributed v3 with [per-request replica routing](https://github.com/mudler/LocalAI/pull/9968). [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.3.0)
|
||||
- **May 2026**: **LocalAI 4.2.0** - LocalAI sees and hears: [voice recognition](https://github.com/mudler/LocalAI/pull/9500), [face recognition + antispoofing liveness](https://github.com/mudler/LocalAI/pull/9480), speaker diarization. Plus [drop-in Ollama API](https://github.com/mudler/LocalAI/pull/9284), [video generation](https://github.com/mudler/LocalAI/pull/9420), redesigned UI with i18n + admin-configurable branding, vLLM at feature parity with llama.cpp, and 11 new backends. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.2.0)
|
||||
- **April 2026**: **LocalAI 4.1.0** - LocalAI becomes a control tower: distributed cluster mode with VRAM-aware smart routing + autoscaling, multi-user platform with OIDC and API keys, per-user quotas with predictive analytics, in-UI fine-tuning with TRL (auto-export to GGUF), on-the-fly quantization backend, visual pipeline editor. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.1.0)
|
||||
@@ -201,7 +221,7 @@ See the full [Backend & Model Compatibility Table](https://localai.io/model-comp
|
||||
- [Integrations & community projects](https://localai.io/docs/integrations/)
|
||||
- [Installation video walkthrough](https://www.youtube.com/watch?v=cMVNnlqwfw4)
|
||||
- [Media & blog posts](https://localai.io/basics/news/#media-blogs-social)
|
||||
- [Examples](https://github.com/mudler/LocalAI-examples)
|
||||
- [Examples](https://github.com/mudler/LocalAI-examples) — including the [realtime voice assistant demo](https://github.com/localai-org/localai-realtime-demo) (Go client for the Realtime API with tool calling)
|
||||
|
||||
## Team
|
||||
|
||||
|
||||
@@ -206,6 +206,16 @@ RUN if [ "${BACKEND}" = "opus" ]; then \
|
||||
apt-get clean && rm -rf /var/lib/apt/lists/*; \
|
||||
fi
|
||||
|
||||
# CrispASR's piper TTS backend dlopens libespeak-ng at runtime to phonemize
|
||||
# non-English text (the MIT-clean path; English uses a built-in G2P). Install
|
||||
# the espeak-ng runtime + its libpcaudio/libsonic deps + voice data so
|
||||
# package.sh can bundle them into the FROM scratch image.
|
||||
RUN if [ "${BACKEND}" = "crispasr" ]; then \
|
||||
apt-get update && apt-get install -y --no-install-recommends \
|
||||
espeak-ng-data libespeak-ng1 libpcaudio0 libsonic0 && \
|
||||
apt-get clean && rm -rf /var/lib/apt/lists/*; \
|
||||
fi
|
||||
|
||||
COPY . /LocalAI
|
||||
|
||||
RUN git config --global --add safe.directory /LocalAI
|
||||
|
||||
@@ -126,6 +126,7 @@ RUN <<EOT bash
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
cuda-nvrtc-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
|
||||
@@ -537,6 +537,15 @@ message TTSRequest {
|
||||
string dst = 3;
|
||||
string voice = 4;
|
||||
optional string language = 5;
|
||||
// instructions is a free-form, per-request style/voice description (maps to
|
||||
// the OpenAI `instructions` field). Backends that support expressive synthesis
|
||||
// (e.g. Qwen3-TTS CustomVoice/VoiceDesign) prefer this over the static YAML
|
||||
// option when set; backends that don't simply ignore it.
|
||||
optional string instructions = 6;
|
||||
// params carries optional, backend-specific per-request generation parameters
|
||||
// (e.g. Chatterbox exaggeration/cfg_weight/temperature). Values are strings and
|
||||
// coerced by the backend; unset leaves the backend's configured defaults.
|
||||
map<string, string> params = 7;
|
||||
}
|
||||
|
||||
message VADRequest {
|
||||
|
||||
1
backend/cpp/ds4/.gitignore
vendored
1
backend/cpp/ds4/.gitignore
vendored
@@ -2,6 +2,7 @@ ds4/
|
||||
build/
|
||||
package/
|
||||
grpc-server
|
||||
ds4-worker
|
||||
*.o
|
||||
backend.pb.cc
|
||||
backend.pb.h
|
||||
|
||||
@@ -60,6 +60,13 @@ elseif(DS4_GPU STREQUAL "cpu")
|
||||
set(DS4_OBJS "${DS4_DIR}/ds4_cpu.o")
|
||||
endif()
|
||||
|
||||
# ds4.c now references ds4_distributed.c (distributed inference) and ds4_ssd.c
|
||||
# (SSD expert-cache), each split into its own translation unit upstream. Both
|
||||
# are GPU-agnostic objects shared by every GPU mode, so link them in regardless
|
||||
# of DS4_GPU.
|
||||
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_distributed.o")
|
||||
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_ssd.o")
|
||||
|
||||
add_executable(${TARGET}
|
||||
grpc-server.cpp
|
||||
dsml_parser.cpp
|
||||
@@ -99,3 +106,36 @@ if(DS4_NATIVE)
|
||||
target_compile_options(${TARGET} PRIVATE -march=native)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# ds4-worker: standalone distributed worker. Links the same ds4 engine objects
|
||||
# (including ds4_distributed.o) but has NO gRPC/protobuf dependency - it speaks
|
||||
# ds4's own TCP transport via ds4_dist_run(). Buildable wherever the engine
|
||||
# objects build, even on hosts without protobuf/grpc dev headers.
|
||||
add_executable(ds4-worker worker_main.c)
|
||||
target_include_directories(ds4-worker PRIVATE ${DS4_DIR})
|
||||
foreach(obj ${DS4_OBJS})
|
||||
target_sources(ds4-worker PRIVATE ${obj})
|
||||
set_source_files_properties(${obj} PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE)
|
||||
endforeach()
|
||||
# worker_main.c is C, but the engine objects built by nvcc (ds4_cuda.o) and the
|
||||
# Metal path (ds4_metal.o, Obj-C++) reference the C++ runtime (libstdc++). Force
|
||||
# the C++ linker driver so those symbols resolve; the C driver would not link
|
||||
# libstdc++ and the CUDA/Metal builds fail with undefined std:: references.
|
||||
set_target_properties(ds4-worker PROPERTIES LINKER_LANGUAGE CXX)
|
||||
target_link_libraries(ds4-worker PRIVATE Threads::Threads m)
|
||||
|
||||
if(DS4_GPU STREQUAL "cuda")
|
||||
target_link_libraries(ds4-worker PRIVATE CUDA::cudart CUDA::cublas)
|
||||
elseif(DS4_GPU STREQUAL "metal")
|
||||
target_link_libraries(ds4-worker PRIVATE ${FOUNDATION_LIB} ${METAL_LIB})
|
||||
elseif(DS4_GPU STREQUAL "cpu")
|
||||
target_compile_definitions(ds4-worker PRIVATE DS4_NO_GPU)
|
||||
endif()
|
||||
|
||||
if(DS4_NATIVE)
|
||||
if(APPLE)
|
||||
target_compile_options(ds4-worker PRIVATE -mcpu=native)
|
||||
else()
|
||||
target_compile_options(ds4-worker PRIVATE -march=native)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=e8e8779b261c10f36ad6270ba732c8f0be5b62e3
|
||||
# Upstream pin lives below as DS4_VERSION?=d881f2a05e8ff6bec001315a36b794b4aa310173
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=e8e8779b261c10f36ad6270ba732c8f0be5b62e3
|
||||
DS4_VERSION?=d881f2a05e8ff6bec001315a36b794b4aa310173
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
@@ -18,16 +18,20 @@ UNAME_S := $(shell uname -s)
|
||||
|
||||
CMAKE_ARGS ?= -DCMAKE_BUILD_TYPE=Release
|
||||
|
||||
# ds4_distributed.o and ds4_ssd.o are GPU-agnostic translation units that
|
||||
# ds4.c/ds4_cpu.o now reference (upstream split distributed inference and the
|
||||
# SSD expert-cache into their own .c files). Both objects are shared by every
|
||||
# GPU mode, so they are appended unconditionally below.
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS += -DDS4_GPU=cuda
|
||||
DS4_OBJ_TARGET := ds4.o ds4_cuda.o
|
||||
DS4_OBJ_TARGET := ds4.o ds4_cuda.o ds4_distributed.o ds4_ssd.o
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
CMAKE_ARGS += -DDS4_GPU=metal
|
||||
DS4_OBJ_TARGET := ds4.o ds4_metal.o
|
||||
DS4_OBJ_TARGET := ds4.o ds4_metal.o ds4_distributed.o ds4_ssd.o
|
||||
else
|
||||
# CPU reference path (Linux only - macOS CPU path is broken by VM bug per ds4 README).
|
||||
CMAKE_ARGS += -DDS4_GPU=cpu
|
||||
DS4_OBJ_TARGET := ds4_cpu.o
|
||||
DS4_OBJ_TARGET := ds4_cpu.o ds4_distributed.o ds4_ssd.o
|
||||
endif
|
||||
|
||||
ifneq ($(NATIVE),true)
|
||||
@@ -52,17 +56,18 @@ ds4:
|
||||
# the right per-platform compile flags (Objective-C/Metal on Darwin, nvcc on Linux+CUDA).
|
||||
ds4/ds4.o: ds4
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
+$(MAKE) -C ds4 ds4.o ds4_cuda.o
|
||||
+$(MAKE) -C ds4 ds4.o ds4_cuda.o ds4_distributed.o ds4_ssd.o
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
+$(MAKE) -C ds4 ds4.o ds4_metal.o
|
||||
+$(MAKE) -C ds4 ds4.o ds4_metal.o ds4_distributed.o ds4_ssd.o
|
||||
else
|
||||
+$(MAKE) -C ds4 ds4_cpu.o
|
||||
+$(MAKE) -C ds4 ds4_cpu.o ds4_distributed.o ds4_ssd.o
|
||||
endif
|
||||
|
||||
grpc-server: ds4/ds4.o
|
||||
mkdir -p $(BUILD_DIR)
|
||||
cd $(BUILD_DIR) && cmake $(CMAKE_ARGS) $(CURRENT_MAKEFILE_DIR) && cmake --build . --config Release -j $(JOBS)
|
||||
cp $(BUILD_DIR)/grpc-server grpc-server
|
||||
cp $(BUILD_DIR)/ds4-worker ds4-worker
|
||||
|
||||
package: grpc-server
|
||||
bash package.sh
|
||||
@@ -71,7 +76,7 @@ test:
|
||||
@echo "ds4 backend: e2e coverage at tests/e2e-backends/ (BACKEND_BINARY mode)"
|
||||
|
||||
clean:
|
||||
rm -rf $(BUILD_DIR) grpc-server package
|
||||
rm -rf $(BUILD_DIR) grpc-server ds4-worker package
|
||||
if [ -d ds4 ]; then $(MAKE) -C ds4 clean; fi
|
||||
|
||||
purge: clean
|
||||
|
||||
@@ -23,8 +23,11 @@ extern "C" {
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <climits>
|
||||
#include <csignal>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
@@ -51,6 +54,12 @@ ds4_session *g_session = nullptr;
|
||||
int g_ctx_size = 32768;
|
||||
std::string g_kv_cache_dir; // empty disables disk cache
|
||||
|
||||
// Distributed coordinator state. g_distributed is set true when LoadModel is
|
||||
// given 'ds4_role:coordinator'; generation then waits for the worker route to
|
||||
// form before running. Single-node behavior is unchanged when unset.
|
||||
bool g_distributed = false;
|
||||
int g_route_timeout_sec = 60;
|
||||
|
||||
std::atomic<Server *> g_server{nullptr};
|
||||
|
||||
// Parse a "key:value" option string. Returns empty when no colon.
|
||||
@@ -60,6 +69,77 @@ static std::pair<std::string, std::string> split_option(const std::string &opt)
|
||||
return {opt.substr(0, colon), opt.substr(colon + 1)};
|
||||
}
|
||||
|
||||
// Parse a positive base-10 integer. Returns false (without throwing) on empty,
|
||||
// trailing garbage, non-positive, or overflow - unlike std::stoi.
|
||||
static bool parse_positive_int(const std::string &s, int *out) {
|
||||
if (s.empty()) return false;
|
||||
char *end = nullptr;
|
||||
long v = std::strtol(s.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || v <= 0 || v > INT_MAX) return false;
|
||||
*out = static_cast<int>(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parse a ds4 layer spec "START:END" or "START:output" into the engine's
|
||||
// distributed layer fields. Returns false on malformed input.
|
||||
static bool parse_layers_spec(const std::string &spec, ds4_distributed_layers *out) {
|
||||
auto colon = spec.find(':');
|
||||
if (colon == std::string::npos) return false;
|
||||
std::string lhs = spec.substr(0, colon);
|
||||
std::string rhs = spec.substr(colon + 1);
|
||||
if (lhs.empty() || rhs.empty()) return false;
|
||||
char *end = nullptr;
|
||||
long start = std::strtol(lhs.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || start < 0) return false;
|
||||
out->start = static_cast<uint32_t>(start);
|
||||
out->has_output = false;
|
||||
if (rhs == "output") {
|
||||
out->has_output = true;
|
||||
out->end = out->start; // engine treats has_output as "through final layer"
|
||||
} else {
|
||||
long e = std::strtol(rhs.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || e < start) return false;
|
||||
out->end = static_cast<uint32_t>(e);
|
||||
}
|
||||
out->set = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
// When acting as a distributed coordinator, block until the worker route
|
||||
// covers all layers (ds4_session_distributed_route_ready == 1) or the timeout
|
||||
// elapses. Returns an empty string on success, or an error message to return
|
||||
// to the client. No-op when not distributed.
|
||||
//
|
||||
// Takes the g_engine_mu lock by reference and RELEASES it during each poll
|
||||
// sleep. The wait can span up to g_route_timeout_sec seconds while workers
|
||||
// connect; holding g_engine_mu the whole time would block the Status/Health
|
||||
// readiness probes (they also lock g_engine_mu), making LocalAI's loader treat
|
||||
// a still-starting worker as hung.
|
||||
static std::string wait_route_ready(std::unique_lock<std::mutex> &lock) {
|
||||
if (!g_distributed) return "";
|
||||
char err[256] = {0};
|
||||
const int deadline_polls = g_route_timeout_sec * 10; // 100ms per poll
|
||||
for (int i = 0; i <= deadline_polls; ++i) {
|
||||
int ready = ds4_session_distributed_route_ready(g_session, err, sizeof(err));
|
||||
if (ready == 1) return "";
|
||||
if (ready < 0) {
|
||||
return std::string("ds4 distributed route error: ") +
|
||||
(err[0] ? err : "unknown");
|
||||
}
|
||||
// Release the lock while sleeping so Status/Health and other RPCs can
|
||||
// interleave during worker startup.
|
||||
lock.unlock();
|
||||
struct timespec ts = {0, 100L * 1000L * 1000L}; // 100ms
|
||||
nanosleep(&ts, nullptr);
|
||||
lock.lock();
|
||||
// A concurrent Free() may have torn down the engine while we slept.
|
||||
if (!g_engine || !g_session) {
|
||||
return "ds4: model unloaded while waiting for distributed route";
|
||||
}
|
||||
}
|
||||
return "ds4 distributed route incomplete: workers not connected (layers uncovered)";
|
||||
}
|
||||
|
||||
static void append_token_text(ds4_engine *engine, int token, std::string &out) {
|
||||
size_t len = 0;
|
||||
const char *text = ds4_token_text(engine, token, &len);
|
||||
@@ -377,6 +457,11 @@ public:
|
||||
backend::Result *result) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
|
||||
// Reset distributed state so a model swap (a second LoadModel without
|
||||
// ds4_role) doesn't inherit a stale coordinator configuration.
|
||||
g_distributed = false;
|
||||
g_route_timeout_sec = 60;
|
||||
|
||||
if (g_engine) {
|
||||
if (g_session) { ds4_session_free(g_session); g_session = nullptr; }
|
||||
ds4_engine_close(g_engine);
|
||||
@@ -394,12 +479,23 @@ public:
|
||||
std::string mtp_path;
|
||||
int mtp_draft = 0;
|
||||
float mtp_margin = 3.0f;
|
||||
std::string ds4_role, ds4_layers, ds4_listen;
|
||||
for (const auto &opt : request->options()) {
|
||||
auto [k, v] = split_option(opt);
|
||||
if (k == "mtp_path") mtp_path = v;
|
||||
else if (k == "mtp_draft") mtp_draft = std::stoi(v);
|
||||
else if (k == "mtp_margin") mtp_margin = std::stof(v);
|
||||
else if (k == "kv_cache_dir") g_kv_cache_dir = v;
|
||||
else if (k == "ds4_role") ds4_role = v;
|
||||
else if (k == "ds4_layers") ds4_layers = v;
|
||||
else if (k == "ds4_listen") ds4_listen = v;
|
||||
else if (k == "ds4_route_timeout") {
|
||||
if (!parse_positive_int(v, &g_route_timeout_sec)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_route_timeout must be a positive integer");
|
||||
return GStatus::OK;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
g_kv_cache.SetDir(g_kv_cache_dir);
|
||||
@@ -422,6 +518,49 @@ public:
|
||||
opt.backend = DS4_BACKEND_CUDA;
|
||||
#endif
|
||||
|
||||
// Coordinator wiring. 'ds4_role:coordinator' enables layer-split
|
||||
// distributed inference: this process listens on ds4_listen and owns
|
||||
// the ds4_layers slice; workers dial in (see `local-ai worker
|
||||
// ds4-distributed`). Absent ds4_role => unchanged single-node path.
|
||||
// Must be static: opt.distributed.listen_host is a const char* the
|
||||
// engine retains past this call, so it cannot point at a local that
|
||||
// goes out of scope (otherwise a future "simplify to local" refactor
|
||||
// reintroduces a dangling pointer).
|
||||
static std::string s_listen_host;
|
||||
if (ds4_role == "coordinator") {
|
||||
if (ds4_layers.empty() || ds4_listen.empty()) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_role:coordinator requires ds4_layers and ds4_listen");
|
||||
return GStatus::OK;
|
||||
}
|
||||
// host:port for IPv4/hostname; IPv6 literals are unsupported (the
|
||||
// first colon would split inside the address).
|
||||
auto host_port = split_option(ds4_listen); // "host:port" -> {host, port}
|
||||
if (host_port.second.empty()) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_listen must be host:port");
|
||||
return GStatus::OK;
|
||||
}
|
||||
int listen_port = 0;
|
||||
if (!parse_positive_int(host_port.second, &listen_port)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_listen port must be a positive integer");
|
||||
return GStatus::OK;
|
||||
}
|
||||
ds4_distributed_layers layers = {};
|
||||
if (!parse_layers_spec(ds4_layers, &layers)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: invalid ds4_layers (want START:END or START:output)");
|
||||
return GStatus::OK;
|
||||
}
|
||||
s_listen_host = host_port.first;
|
||||
opt.distributed.role = DS4_DISTRIBUTED_COORDINATOR;
|
||||
opt.distributed.layers = layers;
|
||||
opt.distributed.listen_host = s_listen_host.c_str();
|
||||
opt.distributed.listen_port = listen_port;
|
||||
g_distributed = true;
|
||||
}
|
||||
|
||||
int rc = ds4_engine_open(&g_engine, &opt);
|
||||
if (rc != 0 || !g_engine) {
|
||||
result->set_success(false);
|
||||
@@ -458,10 +597,13 @@ public:
|
||||
|
||||
GStatus Predict(ServerContext *, const backend::PredictOptions *request,
|
||||
backend::Reply *reply) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
std::unique_lock<std::mutex> lock(g_engine_mu);
|
||||
if (!g_engine || !g_session) {
|
||||
return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
|
||||
}
|
||||
if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
|
||||
return GStatus(StatusCode::UNAVAILABLE, route_err);
|
||||
}
|
||||
ds4_tokens prompt = {};
|
||||
build_prompt(g_engine, request, &prompt);
|
||||
int n_predict = request->tokens() > 0 ? request->tokens() : 256;
|
||||
@@ -554,10 +696,13 @@ public:
|
||||
|
||||
GStatus PredictStream(ServerContext *, const backend::PredictOptions *request,
|
||||
ServerWriter<backend::Reply> *writer) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
std::unique_lock<std::mutex> lock(g_engine_mu);
|
||||
if (!g_engine || !g_session) {
|
||||
return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
|
||||
}
|
||||
if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
|
||||
return GStatus(StatusCode::UNAVAILABLE, route_err);
|
||||
}
|
||||
ds4_tokens prompt = {};
|
||||
build_prompt(g_engine, request, &prompt);
|
||||
int n_predict = request->tokens() > 0 ? request->tokens() : 256;
|
||||
|
||||
@@ -5,7 +5,8 @@ REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
cp -avf "$CURDIR/grpc-server" "$CURDIR/package/"
|
||||
cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/ds4-worker" "$CURDIR/package/"
|
||||
cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
UNAME_S=$(uname -s)
|
||||
if [ "$UNAME_S" = "Darwin" ]; then
|
||||
|
||||
126
backend/cpp/ds4/worker_main.c
Normal file
126
backend/cpp/ds4/worker_main.c
Normal file
@@ -0,0 +1,126 @@
|
||||
// ds4-worker: standalone distributed worker for the LocalAI ds4 backend.
|
||||
//
|
||||
// A ds4 distributed worker owns a slice of the model's transformer layers,
|
||||
// dials the coordinator, and serves activations for its slice. It does NOT
|
||||
// speak backend.proto - it speaks ds4's own TCP transport via ds4_dist_run().
|
||||
// This binary is intentionally minimal (no HTTP/web/kvstore/linenoise): it
|
||||
// only needs the engine objects + ds4_distributed.o, which the backend already
|
||||
// builds. It is launched by `local-ai worker ds4-distributed`.
|
||||
//
|
||||
// Usage:
|
||||
// ds4-worker --role worker --model <gguf> --layers 20:output \
|
||||
// --coordinator <host> <port> [--cpu|--cuda|--metal] [-c CTX] [-t N]
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <signal.h>
|
||||
#include <limits.h>
|
||||
|
||||
#include "ds4.h"
|
||||
#include "ds4_distributed.h"
|
||||
|
||||
static const char *need_arg(int *i, int argc, char **argv, const char *flag) {
|
||||
if (*i + 1 >= argc) {
|
||||
fprintf(stderr, "ds4-worker: missing value for %s\n", flag);
|
||||
exit(2);
|
||||
}
|
||||
return argv[++(*i)];
|
||||
}
|
||||
|
||||
static int parse_int_arg(const char *s, const char *flag) {
|
||||
char *end = NULL;
|
||||
long v = strtol(s, &end, 10);
|
||||
if (!s[0] || *end || v <= 0 || v > INT_MAX) {
|
||||
fprintf(stderr, "ds4-worker: invalid value for %s: %s\n", flag, s);
|
||||
exit(2);
|
||||
}
|
||||
return (int)v;
|
||||
}
|
||||
|
||||
static ds4_backend default_backend(void) {
|
||||
#if defined(DS4_NO_GPU)
|
||||
return DS4_BACKEND_CPU;
|
||||
#elif defined(__APPLE__)
|
||||
return DS4_BACKEND_METAL;
|
||||
#else
|
||||
return DS4_BACKEND_CUDA;
|
||||
#endif
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
signal(SIGPIPE, SIG_IGN);
|
||||
|
||||
ds4_engine_options opt = {0};
|
||||
opt.backend = default_backend();
|
||||
int ctx_size = 32768;
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
const char *arg = argv[i];
|
||||
if (!strcmp(arg, "-h") || !strcmp(arg, "--help")) {
|
||||
fprintf(stdout, "ds4-worker: standalone ds4 distributed worker\n");
|
||||
ds4_dist_usage(stdout);
|
||||
fprintf(stdout, " -m, --model PATH model GGUF (the worker loads only its --layers slice)\n");
|
||||
fprintf(stdout, " -c, --ctx N context size (default 32768)\n");
|
||||
fprintf(stdout, " -t, --threads N CPU threads\n");
|
||||
fprintf(stdout, " --cpu|--cuda|--metal backend override\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
char dist_err[256] = {0};
|
||||
ds4_dist_cli_parse_result dist_parse =
|
||||
ds4_dist_parse_cli_arg(arg, &i, argc, argv, &opt.distributed,
|
||||
dist_err, sizeof(dist_err));
|
||||
if (dist_parse == DS4_DIST_CLI_ERROR) {
|
||||
fprintf(stderr, "ds4-worker: %s\n",
|
||||
dist_err[0] ? dist_err : "invalid distributed option");
|
||||
return 2;
|
||||
}
|
||||
if (dist_parse == DS4_DIST_CLI_MATCHED) continue;
|
||||
|
||||
if (!strcmp(arg, "-m") || !strcmp(arg, "--model")) {
|
||||
opt.model_path = need_arg(&i, argc, argv, arg);
|
||||
} else if (!strcmp(arg, "-c") || !strcmp(arg, "--ctx")) {
|
||||
ctx_size = parse_int_arg(need_arg(&i, argc, argv, arg), arg);
|
||||
} else if (!strcmp(arg, "-t") || !strcmp(arg, "--threads")) {
|
||||
opt.n_threads = parse_int_arg(need_arg(&i, argc, argv, arg), arg);
|
||||
} else if (!strcmp(arg, "--cpu")) {
|
||||
opt.backend = DS4_BACKEND_CPU;
|
||||
} else if (!strcmp(arg, "--cuda")) {
|
||||
opt.backend = DS4_BACKEND_CUDA;
|
||||
} else if (!strcmp(arg, "--metal")) {
|
||||
opt.backend = DS4_BACKEND_METAL;
|
||||
} else {
|
||||
fprintf(stderr, "ds4-worker: unknown option: %s\n", arg);
|
||||
return 2;
|
||||
}
|
||||
}
|
||||
|
||||
if (opt.distributed.role != DS4_DISTRIBUTED_WORKER) {
|
||||
fprintf(stderr, "ds4-worker: --role worker is required\n");
|
||||
return 2;
|
||||
}
|
||||
if (!opt.model_path) {
|
||||
fprintf(stderr, "ds4-worker: --model is required\n");
|
||||
return 2;
|
||||
}
|
||||
|
||||
char prep_err[256] = {0};
|
||||
if (ds4_dist_prepare_engine_options(&opt.distributed, &opt,
|
||||
prep_err, sizeof(prep_err)) != 0) {
|
||||
fprintf(stderr, "ds4-worker: %s\n", prep_err);
|
||||
return 2;
|
||||
}
|
||||
|
||||
ds4_engine *engine = NULL;
|
||||
if (ds4_engine_open(&engine, &opt) != 0 || !engine) {
|
||||
fprintf(stderr, "ds4-worker: failed to open engine\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
ds4_dist_generation_options gen = {0};
|
||||
gen.ctx_size = ctx_size;
|
||||
int rc = ds4_dist_run(engine, &opt.distributed, &gen);
|
||||
ds4_engine_close(engine);
|
||||
return rc;
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=d2da6da05c73aeb658a3d1751f386c24e6963856
|
||||
IK_LLAMA_VERSION?=e6f8112f3ba126eed3ff5b30cdd08085414a7516
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=0d18aaa9d1a8af3df9abccd828e22eeaac7f840b
|
||||
LLAMA_VERSION?=4c6595503fe45d5a39f88d194e270f64c7424677
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -381,6 +381,15 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
||||
});
|
||||
}
|
||||
|
||||
// for each video in the request, add the video data
|
||||
for (int i = 0; i < predict->videos_size(); i++) {
|
||||
data["video_data"].push_back(json
|
||||
{
|
||||
{"id", i},
|
||||
{"data", predict->videos(i)},
|
||||
});
|
||||
}
|
||||
|
||||
data["stop"] = predict->stopprompts();
|
||||
// data["n_probs"] = predict->nprobs();
|
||||
//TODO: images,
|
||||
@@ -482,23 +491,13 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!request->draftmodel().empty()) {
|
||||
params.speculative.draft.mparams.path = request->draftmodel();
|
||||
// Default to draft type if a draft model is set but no explicit type.
|
||||
// Upstream (post ggml-org/llama.cpp#22838) made the speculative type a
|
||||
// vector; the turboquant fork still uses the legacy scalar. The
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC macro is injected by
|
||||
// backend/cpp/turboquant/patch-grpc-server.sh for fork builds only.
|
||||
// Upstream renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE
|
||||
// in ggml-org/llama.cpp#22964; the fork still uses the old name.
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
||||
}
|
||||
#else
|
||||
// Upstream made the speculative type a vector (ggml-org/llama.cpp#22838)
|
||||
// and renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE (#22964).
|
||||
const bool no_spec_type = params.speculative.types.empty() ||
|
||||
(params.speculative.types.size() == 1 && params.speculative.types[0] == COMMON_SPECULATIVE_TYPE_NONE);
|
||||
if (no_spec_type) {
|
||||
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE };
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// params.model_alias ??
|
||||
@@ -573,8 +572,13 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// checkpoint_min_step: minimum spacing between context checkpoints in
|
||||
// tokens (0 disables the minimum). Match upstream's default (256). This
|
||||
// field was renamed from `checkpoint_every_nt` in llama.cpp; the semantics
|
||||
// also shifted from a fixed cadence to a minimum spacing.
|
||||
// also shifted from a fixed cadence to a minimum spacing. The turboquant
|
||||
// fork still lacks common_params::checkpoint_min_step, so skip it there
|
||||
// (LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP is injected by
|
||||
// backend/cpp/turboquant/patch-grpc-server.sh).
|
||||
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
|
||||
params.checkpoint_min_step = 256;
|
||||
#endif
|
||||
|
||||
// decode options. Options are in form optname:optvale, or if booleans only optname.
|
||||
for (int i = 0; i < request->options_size(); i++) {
|
||||
@@ -748,11 +752,18 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.cache_idle_slots = false;
|
||||
}
|
||||
|
||||
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
|
||||
// --- minimum context-checkpoint spacing (upstream -cms / --checkpoint-min-step) ---
|
||||
// 0 disables the minimum-spacing gate. Old option names (`checkpoint_every_nt`,
|
||||
// `checkpoint_every_n_tokens`) are kept as aliases for backward compatibility
|
||||
// with existing user configs: upstream renamed the field and shifted its
|
||||
// semantics from a fixed cadence to a minimum spacing.
|
||||
//
|
||||
// Gated out for the turboquant fork, which lacks common_params::
|
||||
// checkpoint_min_step. The leading `}` closing the cache_idle_slots
|
||||
// branch is removed with this block; the next `} else if` (n_ubatch)
|
||||
// then closes cache_idle_slots, so braces stay balanced under both
|
||||
// preprocessor branches.
|
||||
} else if (!strcmp(optname, "checkpoint_min_step") || !strcmp(optname, "checkpoint_min_spacing") ||
|
||||
!strcmp(optname, "checkpoint_every_nt") || !strcmp(optname, "checkpoint_every_n_tokens")) {
|
||||
if (optval != NULL) {
|
||||
@@ -762,6 +773,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// If conversion fails, keep default value (256)
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// --- physical batch size (upstream -ub / --ubatch-size) ---
|
||||
// Note: line ~482 already aliases n_ubatch to n_batch as a default; this
|
||||
@@ -894,17 +906,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
|
||||
// Speculative decoding options
|
||||
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
// Fork only knows a single scalar `type`. Take the first comma-
|
||||
// separated value and assign it via the singular helper.
|
||||
std::string first = optval_str;
|
||||
const auto comma = first.find(',');
|
||||
if (comma != std::string::npos) first = first.substr(0, comma);
|
||||
auto type = common_speculative_type_from_name(first);
|
||||
if (type != COMMON_SPECULATIVE_TYPE_COUNT) {
|
||||
params.speculative.type = type;
|
||||
}
|
||||
#else
|
||||
// Upstream switched to a vector of types (comma-separated for multi-type
|
||||
// chaining via common_speculative_types_from_names). We keep accepting a
|
||||
// single value here, but also tolerate comma-separated lists.
|
||||
@@ -933,7 +934,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!parsed.empty()) {
|
||||
params.speculative.types = parsed;
|
||||
}
|
||||
#endif
|
||||
} else if (!strcmp(optname, "spec_n_max") || !strcmp(optname, "draft_max")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.draft.n_max = std::stoi(optval_str); } catch (...) {}
|
||||
@@ -971,21 +971,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// shares the target context size. Accept the option for backward
|
||||
// compatibility but silently ignore it.
|
||||
|
||||
// Everything below relies on struct shape introduced in ggml-org/llama.cpp#22838
|
||||
// (parallel drafting): `ngram_mod`, `ngram_map_k`, `ngram_map_k4v`,
|
||||
// `ngram_cache`, and the `draft.{cache_type_*, cpuparams*, tensor_buft_overrides}`
|
||||
// fields. The turboquant fork branched before that, so its build defines
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC via patch-grpc-server.sh and these option
|
||||
// keys become unrecognized (silently dropped, like any unknown opt) for it.
|
||||
//
|
||||
// The `#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC` / `#else` split below sits at the
|
||||
// closing-brace position of the `draft_ctx_size` branch on purpose: in the
|
||||
// legacy build the chain ends here (the brace closes draft_ctx_size), and in
|
||||
// the modern build the chain continues with `} else if (...)` instead, so the
|
||||
// brace count stays balanced under both branches of the preprocessor.
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
}
|
||||
#else
|
||||
// --- ngram_mod family (upstream --spec-ngram-mod-*) ---
|
||||
} else if (!strcmp(optname, "spec_ngram_mod_n_min")) {
|
||||
if (optval != NULL) {
|
||||
@@ -1115,7 +1100,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
}
|
||||
if (!cur.empty()) flush(cur);
|
||||
}
|
||||
#endif // LOCALAI_LEGACY_LLAMA_CPP_SPEC — closes the `else`/`#ifdef` opened at draft_ctx_size
|
||||
}
|
||||
|
||||
// Set params.n_parallel from environment variable if not set via options (fallback)
|
||||
@@ -1165,6 +1149,8 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
}
|
||||
// Terminate the draft tensor_buft_overrides list with a sentinel, mirroring
|
||||
// the main-model handling above.
|
||||
if (!params.speculative.draft.tensor_buft_overrides.empty()) {
|
||||
params.speculative.draft.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
@@ -1526,7 +1512,7 @@ public:
|
||||
msg_json["role"] = msg.role();
|
||||
|
||||
bool is_last_user_msg = (i == last_user_msg_idx);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
|
||||
|
||||
// Handle content - can be string, null, or array
|
||||
// For multimodal content, we'll embed images/audio from separate fields
|
||||
@@ -1577,6 +1563,16 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
} else {
|
||||
// Use content as-is (already array or not last user message)
|
||||
@@ -1611,6 +1607,16 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
} else if (msg.role() == "tool") {
|
||||
// Tool role messages must have content field set, even if empty
|
||||
@@ -1926,6 +1932,17 @@ public:
|
||||
body_json["chat_template_kwargs"]["enable_thinking"] = (et_it->second == "true");
|
||||
}
|
||||
|
||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
||||
// from enable_thinking which those templates ignore.
|
||||
auto re_it = metadata.find("reasoning_effort");
|
||||
if (re_it != metadata.end() && !re_it->second.empty()) {
|
||||
if (!body_json.contains("chat_template_kwargs")) {
|
||||
body_json["chat_template_kwargs"] = json::object();
|
||||
}
|
||||
body_json["chat_template_kwargs"]["reasoning_effort"] = re_it->second;
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
@@ -2051,6 +2068,16 @@ public:
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
|
||||
const auto &video_data = data.find("video_data");
|
||||
if (video_data != data.end() && video_data->is_array())
|
||||
{
|
||||
for (const auto &video : *video_data)
|
||||
{
|
||||
auto decoded_data = base64_decode(video["data"].get<std::string>());
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const bool has_mtmd = ctx_server.impl->mctx != nullptr;
|
||||
@@ -2186,7 +2213,15 @@ public:
|
||||
// content element — attaching to both would duplicate the first
|
||||
// token since oaicompat_msg_diffs is the same for both.
|
||||
json first_res_json = first_result->to_json();
|
||||
if (first_res_json.is_array()) {
|
||||
// Upstream llama.cpp (ggml-org/llama.cpp#23884) now emits an initial
|
||||
// "begin" partial whose to_json() returns null, used only to signal the
|
||||
// HTTP layer to flush 200 status headers before any token. gRPC has no
|
||||
// such concept, so there is nothing to emit — the real tokens arrive in
|
||||
// the loop below. Feeding this null into build_reply_from_json would
|
||||
// throw (uncaught) and surface as a generic RPC error.
|
||||
if (first_res_json.is_null()) {
|
||||
// skip the begin-of-stream marker
|
||||
} else if (first_res_json.is_array()) {
|
||||
for (const auto & res : first_res_json) {
|
||||
auto reply = build_reply_from_json(res, first_result.get());
|
||||
// Skip chat deltas for role-init elements (have "role" in
|
||||
@@ -2216,7 +2251,10 @@ public:
|
||||
}
|
||||
|
||||
json res_json = result->to_json();
|
||||
if (res_json.is_array()) {
|
||||
if (res_json.is_null()) {
|
||||
// begin-of-stream marker (see note above) — nothing to emit
|
||||
continue;
|
||||
} else if (res_json.is_array()) {
|
||||
for (const auto & res : res_json) {
|
||||
auto reply = build_reply_from_json(res, result.get());
|
||||
bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
|
||||
@@ -2292,7 +2330,7 @@ public:
|
||||
}
|
||||
|
||||
bool is_last_user_msg = (i == last_user_msg_idx);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
|
||||
|
||||
// Handle content - can be string, null, or array
|
||||
// For multimodal content, we'll embed images/audio from separate fields
|
||||
@@ -2345,6 +2383,16 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
} else {
|
||||
// Use content as-is (already array or not last user message)
|
||||
@@ -2384,6 +2432,16 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d created content array with media\n", i);
|
||||
} else if (!msg.tool_calls().empty()) {
|
||||
@@ -2708,6 +2766,17 @@ public:
|
||||
body_json["chat_template_kwargs"]["enable_thinking"] = (predict_et_it->second == "true");
|
||||
}
|
||||
|
||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
||||
// from enable_thinking which those templates ignore.
|
||||
auto predict_re_it = predict_metadata.find("reasoning_effort");
|
||||
if (predict_re_it != predict_metadata.end() && !predict_re_it->second.empty()) {
|
||||
if (!body_json.contains("chat_template_kwargs")) {
|
||||
body_json["chat_template_kwargs"] = json::object();
|
||||
}
|
||||
body_json["chat_template_kwargs"]["reasoning_effort"] = predict_re_it->second;
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] Predict: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
@@ -2835,6 +2904,16 @@ public:
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
|
||||
const auto &video_data = data.find("video_data");
|
||||
if (video_data != data.end() && video_data->is_array())
|
||||
{
|
||||
for (const auto &video : *video_data)
|
||||
{
|
||||
auto decoded_data = base64_decode(video["data"].get<std::string>());
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process files
|
||||
@@ -3407,7 +3486,7 @@ public:
|
||||
if (body.count("prompt") != 0) {
|
||||
const bool add_special = json_value(body, "add_special", false);
|
||||
|
||||
llama_tokens tokens = tokenize_mixed(ctx_server.impl->vocab, body.at("content"), add_special, true);
|
||||
llama_tokens tokens = tokenize_mixed(ctx_server.impl->vocab, body.at("prompt"), add_special, true);
|
||||
|
||||
|
||||
for (const auto& token : tokens) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
|
||||
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
|
||||
TURBOQUANT_VERSION?=5aeb2fdbe26cd4c534c6fa15de73cb5749bd0403
|
||||
TURBOQUANT_VERSION?=7d9715f1f071fa07c7b2ad3dbfd320b314139e65
|
||||
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -4,21 +4,19 @@
|
||||
#
|
||||
# 1. Augment the kv_cache_types[] allow-list so `LoadModel` accepts the
|
||||
# fork-specific `turbo2` / `turbo3` / `turbo4` cache types.
|
||||
# 2. Replace `get_media_marker()` (added upstream in ggml-org/llama.cpp#21962,
|
||||
# server-side random per-instance marker) with the legacy "<__media__>"
|
||||
# literal. The fork branched before that PR, so server-common.cpp has no
|
||||
# get_media_marker symbol. The fork's mtmd_default_marker() still returns
|
||||
# "<__media__>", and Go-side tooling falls back to that sentinel when the
|
||||
# backend does not expose media_marker, so substituting the literal keeps
|
||||
# behavior identical on the turboquant path.
|
||||
# 3. Revert the `common_params_speculative` field references to the
|
||||
# pre-refactor flat layout. Upstream ggml-org/llama.cpp#22397 split the
|
||||
# struct into nested `draft` / `ngram_simple` / `ngram_mod` / etc. members;
|
||||
# the turboquant fork branched before that PR and still exposes the flat
|
||||
# `n_max`, `mparams_dft`, `ngram_size_n`, ... fields. The substitutions
|
||||
# below map the new nested paths back to the legacy flat names so the
|
||||
# shared grpc-server.cpp keeps compiling against the fork's common.h.
|
||||
# Drop this block once the fork rebases past #22397.
|
||||
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file
|
||||
# so the grpc-server option parser skips the two references to
|
||||
# common_params::checkpoint_min_step (the default and the option handler).
|
||||
# That field does not exist in the fork yet; drop this once it does.
|
||||
#
|
||||
# The fork used to lag upstream on the whole common_params_speculative refactor
|
||||
# (ggml-org/llama.cpp#22397/#22838/#22964), the model_tgt rename (#22838) and
|
||||
# get_media_marker (#21962), which required a much larger compat shim here
|
||||
# (flat-field sed renames + a coarse LOCALAI_LEGACY_LLAMA_CPP_SPEC define). The
|
||||
# fork has since rebased past all of those, so the only remaining gap is
|
||||
# checkpoint_min_step. If a future bump reintroduces a divergence, add a narrow
|
||||
# guard in grpc-server.cpp keyed on a fork-specific macro and inject it here
|
||||
# rather than resurrecting the coarse one.
|
||||
#
|
||||
# We patch the *copy* sitting in turboquant-<flavor>-build/, never the original
|
||||
# under backend/cpp/llama-cpp/, so the stock llama-cpp build keeps compiling
|
||||
@@ -72,69 +70,20 @@ else
|
||||
echo "==> KV allow-list patch OK"
|
||||
fi
|
||||
|
||||
if grep -q 'get_media_marker()' "$SRC"; then
|
||||
echo "==> patching $SRC to replace get_media_marker() with legacy \"<__media__>\" literal"
|
||||
# Only one call site today (ModelMetadata), but replace all occurrences to
|
||||
# stay robust if upstream adds more. Use a temp file to avoid relying on
|
||||
# sed -i portability (the builder image uses GNU sed, but keeping this
|
||||
# consistent with the awk block above).
|
||||
sed 's/get_media_marker()/"<__media__>"/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> get_media_marker() substitution OK"
|
||||
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file so
|
||||
# the grpc-server option parser skips the two references to
|
||||
# common_params::checkpoint_min_step (the default assignment and the option
|
||||
# handler). That field does not exist in the fork yet. Drop this block once
|
||||
# the fork rebases past the bump that added checkpoint_min_step.
|
||||
if grep -q '^#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP' "$SRC"; then
|
||||
echo "==> $SRC already defines LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP, skipping"
|
||||
else
|
||||
echo "==> $SRC has no get_media_marker() call, skipping media-marker patch"
|
||||
fi
|
||||
|
||||
if grep -q 'params\.speculative\.draft\.\|params\.speculative\.ngram_simple\.' "$SRC"; then
|
||||
echo "==> patching $SRC to revert common_params_speculative refs to pre-#22397 flat layout"
|
||||
# Each substitution is the exact post-refactor path → legacy flat field.
|
||||
# Order doesn't matter because the source paths are disjoint, but we keep
|
||||
# the most-specific (mparams.path) first for readability.
|
||||
sed -E \
|
||||
-e 's/params\.speculative\.draft\.mparams\.path/params.speculative.mparams_dft.path/g' \
|
||||
-e 's/params\.speculative\.draft\.n_max/params.speculative.n_max/g' \
|
||||
-e 's/params\.speculative\.draft\.n_min/params.speculative.n_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_min/params.speculative.p_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_split/params.speculative.p_split/g' \
|
||||
-e 's/params\.speculative\.draft\.n_gpu_layers/params.speculative.n_gpu_layers/g' \
|
||||
-e 's/params\.speculative\.draft\.n_ctx/params.speculative.n_ctx/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_n/params.speculative.ngram_size_n/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_m/params.speculative.ngram_size_m/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.min_hits/params.speculative.ngram_min_hits/g' \
|
||||
"$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> speculative field rename OK"
|
||||
else
|
||||
echo "==> $SRC has no post-#22397 speculative field refs, skipping spec rename patch"
|
||||
fi
|
||||
|
||||
# 4. Revert the `ctx_server.impl->model_tgt` rename introduced by upstream
|
||||
# ggml-org/llama.cpp#22838 (parallel drafting). The turboquant fork still
|
||||
# exposes the field as `model` on `server_context_impl`. The two call sites
|
||||
# are in the Rerank and ModelMetadata RPC handlers.
|
||||
if grep -q 'ctx_server\.impl->model_tgt' "$SRC"; then
|
||||
echo "==> patching $SRC to revert ctx_server.impl->model_tgt -> ctx_server.impl->model"
|
||||
sed -E 's/ctx_server\.impl->model_tgt/ctx_server.impl->model/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> model_tgt rename OK"
|
||||
else
|
||||
echo "==> $SRC has no ctx_server.impl->model_tgt refs, skipping model_tgt rename patch"
|
||||
fi
|
||||
|
||||
# 5. Define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top of the file so the
|
||||
# grpc-server option parser skips the new option-handler blocks (ngram_mod,
|
||||
# ngram_map_k, ngram_map_k4v, ngram_cache, draft.cache_type_*, draft.cpuparams*,
|
||||
# draft.tensor_buft_overrides) introduced for the post-#22838 layout. Those
|
||||
# blocks reference struct fields that simply do not exist in the fork.
|
||||
if grep -q '^#define LOCALAI_LEGACY_LLAMA_CPP_SPEC' "$SRC"; then
|
||||
echo "==> $SRC already defines LOCALAI_LEGACY_LLAMA_CPP_SPEC, skipping"
|
||||
else
|
||||
echo "==> patching $SRC to define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top"
|
||||
# Insert the define before the very first `#include` so it precedes all the
|
||||
# speculative-decoding code paths.
|
||||
echo "==> patching $SRC to define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top"
|
||||
# Insert the define before the very first `#include` so it precedes the
|
||||
# checkpoint_min_step references.
|
||||
awk '
|
||||
!done && /^#include/ {
|
||||
print "#define LOCALAI_LEGACY_LLAMA_CPP_SPEC 1"
|
||||
print "#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP 1"
|
||||
print "// ^ injected by backend/cpp/turboquant/patch-grpc-server.sh"
|
||||
print ""
|
||||
done = 1
|
||||
@@ -142,13 +91,13 @@ else
|
||||
{ print }
|
||||
END {
|
||||
if (!done) {
|
||||
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_LEGACY_LLAMA_CPP_SPEC" > "/dev/stderr"
|
||||
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP" > "/dev/stderr"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> LOCALAI_LEGACY_LLAMA_CPP_SPEC define OK"
|
||||
echo "==> LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP define OK"
|
||||
fi
|
||||
|
||||
echo "==> all patches applied"
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
hip: port the turboquant CUDA additions that ggml's HIP shim doesn't cover
|
||||
|
||||
The turboquant fork adds/modifies a few ggml-cuda.cu spots with CUDA APIs
|
||||
that ggml's HIP (and MUSA) compatibility layer does not provide, breaking
|
||||
the -gpu-rocm-hipblas-turboquant build:
|
||||
|
||||
1. ggml_cuda_copy2d_across_devices() (host-staged cross-device copy for
|
||||
split mul_mat output) uses the CUDA 3D-peer copy APIs
|
||||
cudaMemcpy3DPeerParms / make_cudaPitchedPtr / make_cudaExtent /
|
||||
cudaMemcpy3DPeerAsync. HIP genuinely does not support these (see the
|
||||
fork's own comment "HIP does not support cudaMemcpy3DPeerAsync"), so
|
||||
guard the peer fast path with #if !defined(GGML_USE_HIP) &&
|
||||
!defined(GGML_USE_MUSA) -- matching how the fork already guards the
|
||||
same API for the sibling 2D copy -- and fall through to the existing
|
||||
cudaMemcpyAsync staging fallback below (functionally identical,
|
||||
slightly slower on multi-GPU ROCm).
|
||||
|
||||
2. ggml_backend_cuda_device_event_new() creates its event with plain
|
||||
cudaEventCreate, which ggml's HIP shim does not alias (it only aliases
|
||||
cudaEventCreateWithFlags). Use cudaEventCreateWithFlags(...,
|
||||
cudaEventDisableTiming) -- exactly what the rest of this file already
|
||||
does (cf. lines ~1034, ~3461) and HIP-safe.
|
||||
|
||||
CUDA builds are unaffected. Drop the relevant hunk once the fork HIP-ports
|
||||
these; apply-patches.sh fails fast if an anchor goes stale.
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 0427e6b..6352e6a 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -1933,6 +1933,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
|
||||
size_t width, size_t height, cudaStream_t dst_stream, cudaStream_t src_stream) {
|
||||
|
||||
const auto & info = ggml_cuda_info();
|
||||
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) // 3D-peer copy types unmapped by ggml's HIP/MUSA shim; use staging fallback below
|
||||
if (info.peer_access[src_device][dst_device]) {
|
||||
cudaMemcpy3DPeerParms p = {};
|
||||
p.dstDevice = dst_device;
|
||||
@@ -1942,6 +1943,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
|
||||
p.extent = make_cudaExtent(width, height, 1);
|
||||
return cudaMemcpy3DPeerAsync(&p, dst_stream);
|
||||
}
|
||||
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
|
||||
// Fallback: stage all rows through a single contiguous pinned buffer
|
||||
int prev_device = ggml_cuda_get_device();
|
||||
@@ -5714,7 +5716,7 @@ static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_
|
||||
ggml_cuda_set_device(dev_ctx->device);
|
||||
|
||||
cudaEvent_t event;
|
||||
- CUDA_CHECK(cudaEventCreate(&event));
|
||||
+ CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
|
||||
|
||||
return new ggml_backend_event {
|
||||
/* .device = */ dev,
|
||||
@@ -192,6 +192,61 @@ var _ = Describe("Forward", func() {
|
||||
Expect(<-gotAuth).To(Equal("Bearer sk-real"), "caller-supplied Basic header must be replaced")
|
||||
})
|
||||
|
||||
It("refuses to follow upstream redirects and never leaks the key to the redirect target", func() {
|
||||
// A 3xx from the configured upstream means misconfiguration or a
|
||||
// hijacked/spoofed host. Following it would replay the request —
|
||||
// and the injected API key — to the Location host. Anthropic's
|
||||
// x-api-key is NOT stripped by Go on cross-host redirects, so this
|
||||
// would be a credential leak. The proxy must refuse the redirect.
|
||||
sinkHit := make(chan string, 1)
|
||||
sink := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sinkHit <- r.Header.Get("x-api-key")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer sink.Close()
|
||||
|
||||
redirector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, sink.URL, http.StatusFound)
|
||||
}))
|
||||
defer redirector.Close()
|
||||
|
||||
GinkgoT().Setenv("CLOUD_PROXY_REDIRECT_KEY", "ant-secret")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: redirector.URL,
|
||||
Mode: modePassthrough,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_REDIRECT_KEY",
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-no-redirect"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/messages",
|
||||
Method: "POST",
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
|
||||
// Drain the stream; a refused redirect surfaces as a non-EOF error.
|
||||
var streamErr error
|
||||
for {
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
streamErr = err
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(streamErr).To(HaveOccurred(), "refused redirect must surface as an error")
|
||||
Expect(sinkHit).NotTo(Receive(), "the redirect target must never be contacted")
|
||||
})
|
||||
|
||||
It("handles concurrent calls without interference", func() {
|
||||
// CloudProxy explicitly omits base.SingleThread — independent
|
||||
// Forward streams must not block each other or leak state.
|
||||
|
||||
@@ -11,9 +11,12 @@ import (
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// Mirror of core/config.Proxy{Mode,Provider}* — backends don't
|
||||
@@ -48,10 +51,15 @@ type proxyConfig struct {
|
||||
}
|
||||
|
||||
func NewCloudProxy() *CloudProxy {
|
||||
// No Client-level Timeout — that would bound streaming SSE
|
||||
// responses too, which can legitimately last minutes. Per-request
|
||||
// deadlines come from the gRPC stream context.
|
||||
return &CloudProxy{client: &http.Client{}}
|
||||
// httpclient.New refuses redirects outright: the proxy talks to a
|
||||
// single configured upstream API (OpenAI/Anthropic/...) that answers
|
||||
// directly, so a 3xx means misconfiguration, a hijacked upstream, or
|
||||
// DNS trickery — never normal operation. Following it would replay the
|
||||
// request, including the operator's x-api-key (which Go does NOT strip
|
||||
// on cross-host redirects), to an unvetted host and leak the key
|
||||
// (GHSA-3mj3-57v2-4636). It also imposes no body deadline, so streaming
|
||||
// SSE responses that legitimately last minutes are not truncated.
|
||||
return &CloudProxy{client: httpclient.New()}
|
||||
}
|
||||
|
||||
func (c *CloudProxy) Load(opts *pb.ModelOptions) error {
|
||||
@@ -138,7 +146,7 @@ func resolveAPIKey(envName, filePath string) (string, error) {
|
||||
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cloud-proxy: model not loaded")
|
||||
return nil, grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -168,7 +176,7 @@ func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err
|
||||
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -262,7 +270,7 @@ func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest,
|
||||
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modePassthrough {
|
||||
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
|
||||
@@ -426,4 +434,3 @@ func isHopByHopHeader(name string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
5
backend/go/crispasr/.gitignore
vendored
Normal file
5
backend/go/crispasr/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
sources
|
||||
build*
|
||||
libgocrispasr*.so
|
||||
crispasr
|
||||
package
|
||||
30
backend/go/crispasr/CMakeLists.txt
Normal file
30
backend/go/crispasr/CMakeLists.txt
Normal file
@@ -0,0 +1,30 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
project(gocrispasr LANGUAGES C CXX)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
add_subdirectory(./sources/CrispASR)
|
||||
|
||||
add_library(gocrispasr MODULE cpp/crispasr_shim.cpp)
|
||||
target_include_directories(gocrispasr PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sources/CrispASR/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sources/CrispASR/ggml/include)
|
||||
# Link the same backend set as crispasr-cli (examples/cli/CMakeLists.txt) so
|
||||
# the session API can dispatch to every compiled-in architecture, not just
|
||||
# whisper. crispasr is the referencer; the backend static libs supply the
|
||||
# per-architecture symbols; ggml is the math/runtime base.
|
||||
target_link_libraries(gocrispasr PRIVATE
|
||||
crispasr-lib
|
||||
parakeet canary canary_ctc cohere granite_speech granite_nle
|
||||
voxtral voxtral4b qwen3_asr qwen3_tts orpheus chatterbox indextts
|
||||
kokoro voxcpm2_tts m2m100 t5_translate wav2vec2-ggml vibevoice
|
||||
silero-lid pyannote-seg funasr paraformer sensevoice
|
||||
crisp_audio
|
||||
ggml)
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
|
||||
target_link_libraries(gocrispasr PRIVATE stdc++fs)
|
||||
endif()
|
||||
|
||||
set_property(TARGET gocrispasr PROPERTY CXX_STANDARD 17)
|
||||
set_target_properties(gocrispasr PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||
132
backend/go/crispasr/Makefile
Normal file
132
backend/go/crispasr/Makefile
Normal file
@@ -0,0 +1,132 @@
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# CrispASR version (release tag)
|
||||
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
||||
CRISPASR_VERSION?=d745bda4386ae0f9d1d2f23fff8ec95d76428221
|
||||
SO_TARGET?=libgocrispasr.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
# Keep the build lean: no tests/examples/server/SDL2/curl/ffmpeg (the FROM scratch
|
||||
# image cannot satisfy those runtime deps). All ASR/TTS model backends stay enabled.
|
||||
CMAKE_ARGS+=-DCRISPASR_BUILD_TESTS=OFF -DCRISPASR_BUILD_EXAMPLES=OFF -DCRISPASR_BUILD_SERVER=OFF
|
||||
CMAKE_ARGS+=-DCRISPASR_SDL2=OFF -DCRISPASR_CURL=OFF -DCRISPASR_FFMPEG=OFF
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DGGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
ifneq ($(BUILD_TYPE),metal)
|
||||
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||
else
|
||||
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
-DGGML_SYCL_F16=ON
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
endif
|
||||
|
||||
sources/CrispASR:
|
||||
mkdir -p sources/CrispASR
|
||||
cd sources/CrispASR && \
|
||||
git init && \
|
||||
git remote add origin $(CRISPASR_REPO) && \
|
||||
git fetch origin && \
|
||||
git checkout $(CRISPASR_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
# CrispASR's src/CMakeLists.txt locates its vendored llama.cpp
|
||||
# (crispasr-llama-core, used by the chat C-ABI) via ${CMAKE_SOURCE_DIR},
|
||||
# which assumes CrispASR is the top-level CMake project. We add_subdirectory
|
||||
# it, so ${CMAKE_SOURCE_DIR} is THIS backend dir and the talk-llama sources
|
||||
# aren't found. Rewrite to ${PROJECT_SOURCE_DIR} (the crispasr project root),
|
||||
# which is correct both standalone and as a subproject. Idempotent.
|
||||
sed -i 's#\$${CMAKE_SOURCE_DIR}/examples/talk-llama#\$${PROJECT_SOURCE_DIR}/examples/talk-llama#' sources/CrispASR/src/CMakeLists.txt
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
VARIANT_TARGETS = libgocrispasr-avx.so libgocrispasr-avx2.so libgocrispasr-avx512.so libgocrispasr-fallback.so
|
||||
else
|
||||
VARIANT_TARGETS = libgocrispasr-fallback.so
|
||||
endif
|
||||
|
||||
crispasr: main.go gocrispasr.go $(VARIANT_TARGETS)
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o crispasr ./
|
||||
|
||||
package: crispasr
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
clean: purge
|
||||
rm -rf libgocrispasr*.so package sources/CrispASR crispasr
|
||||
|
||||
purge:
|
||||
rm -rf build*
|
||||
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
libgocrispasr-avx.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx${RESET})
|
||||
SO_TARGET=libgocrispasr-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-avx2.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx2${RESET})
|
||||
SO_TARGET=libgocrispasr-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-avx512.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx512${RESET})
|
||||
SO_TARGET=libgocrispasr-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
endif
|
||||
|
||||
libgocrispasr-fallback.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:fallback${RESET})
|
||||
SO_TARGET=libgocrispasr-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-custom: CMakeLists.txt cpp/crispasr_shim.cpp cpp/crispasr_shim.h
|
||||
mkdir -p build-$(SO_TARGET) && \
|
||||
cd build-$(SO_TARGET) && \
|
||||
cmake .. $(CMAKE_ARGS) && \
|
||||
cmake --build . --config Release -j$(JOBS) && \
|
||||
cd .. && \
|
||||
mv build-$(SO_TARGET)/libgocrispasr.so ./$(SO_TARGET)
|
||||
|
||||
test: crispasr
|
||||
CGO_ENABLED=0 $(GOCMD) test -v ./...
|
||||
|
||||
all: crispasr package
|
||||
253
backend/go/crispasr/cpp/crispasr_shim.cpp
Normal file
253
backend/go/crispasr/cpp/crispasr_shim.cpp
Normal file
@@ -0,0 +1,253 @@
|
||||
#include "crispasr_shim.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "crispasr.h"
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
// Opaque session types. crispasr.h declares `struct crispasr_session;` but not
|
||||
// the result type nor the open/transcribe/result accessors — those are
|
||||
// CA_EXPORT extern "C" symbols in src/crispasr_c_api.cpp, so we forward-declare
|
||||
// exactly the ones we use. Signatures verified against
|
||||
// sources/CrispASR/src/crispasr_c_api.cpp.
|
||||
struct crispasr_session_result;
|
||||
extern "C" {
|
||||
crispasr_session *crispasr_session_open(const char *model_path, int n_threads);
|
||||
crispasr_session *crispasr_session_open_explicit(const char *model_path,
|
||||
const char *backend_name,
|
||||
int n_threads);
|
||||
int crispasr_session_set_codec_path(crispasr_session *s, const char *path);
|
||||
void crispasr_session_close(crispasr_session *s);
|
||||
const char *crispasr_session_backend(crispasr_session *s);
|
||||
int crispasr_session_set_translate(crispasr_session *s, int enable);
|
||||
crispasr_session_result *crispasr_session_transcribe_lang(
|
||||
crispasr_session *s, const float *pcm, int n_samples, const char *language);
|
||||
int crispasr_session_result_n_segments(crispasr_session_result *r);
|
||||
const char *crispasr_session_result_segment_text(crispasr_session_result *r,
|
||||
int i);
|
||||
int64_t crispasr_session_result_segment_t0(crispasr_session_result *r, int i);
|
||||
int64_t crispasr_session_result_segment_t1(crispasr_session_result *r, int i);
|
||||
void crispasr_session_result_free(crispasr_session_result *r);
|
||||
float *crispasr_session_synthesize(crispasr_session *s, const char *text,
|
||||
int *out_n_samples);
|
||||
void crispasr_pcm_free(float *pcm);
|
||||
int crispasr_session_set_speaker_name(crispasr_session *s, const char *name);
|
||||
int crispasr_session_set_voice(crispasr_session *s, const char *path,
|
||||
const char *ref_text_or_null);
|
||||
}
|
||||
|
||||
static crispasr_session *g_session = nullptr;
|
||||
static crispasr_session_result *g_result = nullptr;
|
||||
|
||||
static struct whisper_vad_context *vctx;
|
||||
static std::vector<float> flat_segs;
|
||||
|
||||
static std::atomic<int> g_abort{0};
|
||||
|
||||
extern "C" void set_abort(int v) {
|
||||
g_abort.store(v, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
static void ggml_log_cb(enum ggml_log_level level, const char *log,
|
||||
void *data) {
|
||||
const char *level_str;
|
||||
|
||||
if (!log) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (level) {
|
||||
case GGML_LOG_LEVEL_DEBUG:
|
||||
level_str = "DEBUG";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_INFO:
|
||||
level_str = "INFO";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_WARN:
|
||||
level_str = "WARN";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_ERROR:
|
||||
level_str = "ERROR";
|
||||
break;
|
||||
default: /* Potential future-proofing */
|
||||
level_str = "?????";
|
||||
break;
|
||||
}
|
||||
|
||||
fprintf(stderr, "[%-5s] ", level_str);
|
||||
fputs(log, stderr);
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
int load_model(const char *const model_path, int threads,
|
||||
const char *backend_name) {
|
||||
whisper_log_set(ggml_log_cb, nullptr);
|
||||
ggml_backend_load_all();
|
||||
|
||||
if (backend_name && *backend_name) {
|
||||
g_session =
|
||||
crispasr_session_open_explicit(model_path, backend_name, threads);
|
||||
} else {
|
||||
g_session = crispasr_session_open(model_path, threads);
|
||||
}
|
||||
if (g_session == nullptr) {
|
||||
fprintf(stderr, "error: failed to open CrispASR session for model\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "info: CrispASR backend selected: %s\n",
|
||||
crispasr_session_backend(g_session));
|
||||
return 0;
|
||||
}
|
||||
|
||||
// set_codec_path forwards a companion file (qwen3-tts codec, orpheus SNAC,
|
||||
// chatterbox s3gen, or mimo-asr tokenizer) to the active session. Returns 0 on
|
||||
// success or when the active backend needs no companion, negative on failure,
|
||||
// and -1 when no session is open.
|
||||
int set_codec_path(const char *path) {
|
||||
return g_session ? crispasr_session_set_codec_path(g_session, path) : -1;
|
||||
}
|
||||
|
||||
int load_model_vad(const char *const model_path) {
|
||||
whisper_log_set(ggml_log_cb, nullptr);
|
||||
ggml_backend_load_all();
|
||||
|
||||
struct whisper_vad_context_params vcparams =
|
||||
whisper_vad_default_context_params();
|
||||
|
||||
// XXX: Overridden to false in upstream due to performance?
|
||||
// vcparams.use_gpu = true;
|
||||
|
||||
vctx = whisper_vad_init_from_file_with_params(model_path, vcparams);
|
||||
if (vctx == nullptr) {
|
||||
fprintf(stderr, "error: Failed to init model as VAD\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
|
||||
size_t *segs_out_len) {
|
||||
if (!whisper_vad_detect_speech(vctx, pcmf32, pcmf32_len)) {
|
||||
fprintf(stderr, "error: failed to detect speech\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct whisper_vad_params params = whisper_vad_default_params();
|
||||
struct whisper_vad_segments *segs =
|
||||
whisper_vad_segments_from_probs(vctx, params);
|
||||
size_t segn = whisper_vad_segments_n_segments(segs);
|
||||
|
||||
// fprintf(stderr, "Got segments %zd\n", segn);
|
||||
|
||||
flat_segs.clear();
|
||||
|
||||
for (int i = 0; i < segn; i++) {
|
||||
flat_segs.push_back(whisper_vad_segments_get_segment_t0(segs, i));
|
||||
flat_segs.push_back(whisper_vad_segments_get_segment_t1(segs, i));
|
||||
}
|
||||
|
||||
// fprintf(stderr, "setting out variables: %p=%p -> %p, %p=%zx -> %zx\n",
|
||||
// segs_out, *segs_out, flat_segs.data(), segs_out_len, *segs_out_len,
|
||||
// flat_segs.size());
|
||||
*segs_out = flat_segs.data();
|
||||
*segs_out_len = flat_segs.size();
|
||||
|
||||
// fprintf(stderr, "freeing segs\n");
|
||||
whisper_vad_free_segments(segs);
|
||||
|
||||
// fprintf(stderr, "returning\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// threads, diarize and prompt are accepted for Go-side API parity but unused
|
||||
// in Phase 1: the thread count is fixed at session open, and diarization and
|
||||
// the initial prompt are separate CrispASR features not yet wired through the
|
||||
// session ASR path.
|
||||
int transcribe(uint32_t threads, char *lang, bool translate, bool diarize,
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
|
||||
char *prompt) {
|
||||
(void)threads;
|
||||
(void)diarize;
|
||||
(void)prompt;
|
||||
|
||||
if (!g_session) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Reset stale abort flag from any prior cancelled call. set_abort remains
|
||||
// best-effort: the session transcribe call is blocking and exposes no abort
|
||||
// hook, so a mid-decode abort cannot interrupt it.
|
||||
g_abort.store(0, std::memory_order_relaxed);
|
||||
|
||||
crispasr_session_set_translate(g_session, translate ? 1 : 0);
|
||||
|
||||
if (g_result) {
|
||||
crispasr_session_result_free(g_result);
|
||||
g_result = nullptr;
|
||||
}
|
||||
|
||||
const char *language = (lang && *lang) ? lang : nullptr;
|
||||
g_result = crispasr_session_transcribe_lang(g_session, pcmf32, (int)pcmf32_len,
|
||||
language);
|
||||
if (!g_result) {
|
||||
fprintf(stderr, "error: transcription failed\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
*segs_out_len = crispasr_session_result_n_segments(g_result);
|
||||
return 0;
|
||||
}
|
||||
|
||||
const char *get_segment_text(int i) {
|
||||
if (!g_result) {
|
||||
return "";
|
||||
}
|
||||
return crispasr_session_result_segment_text(g_result, i);
|
||||
}
|
||||
|
||||
int64_t get_segment_t0(int i) {
|
||||
if (!g_result) {
|
||||
return 0;
|
||||
}
|
||||
return crispasr_session_result_segment_t0(g_result, i);
|
||||
}
|
||||
|
||||
int64_t get_segment_t1(int i) {
|
||||
if (!g_result) {
|
||||
return 0;
|
||||
}
|
||||
return crispasr_session_result_segment_t1(g_result, i);
|
||||
}
|
||||
|
||||
const char *get_backend(void) {
|
||||
return g_session ? crispasr_session_backend(g_session) : "";
|
||||
}
|
||||
|
||||
// TTS uses the already-open session (crispasr_session_open auto-detects a TTS
|
||||
// model). Output is 24 kHz mono float PCM (upstream CrispASR convention),
|
||||
// malloc'd by the C API; the caller must release it via tts_free.
|
||||
float *tts_synthesize(const char *text, int *out_n_samples) {
|
||||
if (out_n_samples) *out_n_samples = 0;
|
||||
if (!g_session || !text) return nullptr;
|
||||
return crispasr_session_synthesize(g_session, text, out_n_samples);
|
||||
}
|
||||
|
||||
void tts_free(float *pcm) {
|
||||
if (pcm) crispasr_pcm_free(pcm);
|
||||
}
|
||||
|
||||
int tts_set_voice(const char *name) {
|
||||
if (!g_session || !name || !*name) return 0;
|
||||
return crispasr_session_set_speaker_name(g_session, name);
|
||||
}
|
||||
|
||||
// tts_set_voice_file loads a voice from a file: a .gguf path selects a voice
|
||||
// pack, a .wav path with a non-empty ref_text performs zero-shot voice cloning
|
||||
// (the C API returns -2 when ref_text is required but missing). Returns -1 when
|
||||
// no session is open or path is null.
|
||||
int tts_set_voice_file(const char *path, const char *ref_text) {
|
||||
if (!g_session || !path) return -1;
|
||||
const char *ref = (ref_text && *ref_text) ? ref_text : nullptr;
|
||||
return crispasr_session_set_voice(g_session, path, ref);
|
||||
}
|
||||
23
backend/go/crispasr/cpp/crispasr_shim.h
Normal file
23
backend/go/crispasr/cpp/crispasr_shim.h
Normal file
@@ -0,0 +1,23 @@
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
extern "C" {
|
||||
int load_model(const char *const model_path, int threads,
|
||||
const char *backend_name);
|
||||
int set_codec_path(const char *path);
|
||||
int load_model_vad(const char *const model_path);
|
||||
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
|
||||
size_t *segs_out_len);
|
||||
int transcribe(uint32_t threads, char *lang, bool translate, bool diarize,
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
|
||||
char *prompt);
|
||||
const char *get_segment_text(int i);
|
||||
int64_t get_segment_t0(int i);
|
||||
int64_t get_segment_t1(int i);
|
||||
const char *get_backend(void);
|
||||
void set_abort(int v);
|
||||
float *tts_synthesize(const char *text, int *out_n_samples); // 24kHz mono float, malloc'd; NULL on failure
|
||||
void tts_free(float *pcm);
|
||||
int tts_set_voice(const char *name); // best-effort speaker selection; 0 ok
|
||||
int tts_set_voice_file(const char *path, const char *ref_text); // load voice pack (.gguf) or zero-shot clone (.wav + ref_text)
|
||||
}
|
||||
539
backend/go/crispasr/gocrispasr.go
Normal file
539
backend/go/crispasr/gocrispasr.go
Normal file
@@ -0,0 +1,539 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var (
|
||||
CppLoadModel func(modelPath string, threads int, backendName string) int
|
||||
CppSetCodecPath func(path string) int
|
||||
CppLoadModelVAD func(modelPath string) int
|
||||
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
|
||||
CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer, prompt string) int
|
||||
CppGetSegmentText func(i int) string
|
||||
CppGetSegmentStart func(i int) int64
|
||||
CppGetSegmentEnd func(i int) int64
|
||||
CppGetBackend func() string
|
||||
CppSetAbort func(v int)
|
||||
CppTTSSynthesize func(text string, outNSamples unsafe.Pointer) uintptr
|
||||
CppTTSFree func(ptr uintptr)
|
||||
CppTTSSetVoice func(name string) int
|
||||
CppTTSSetVoiceFile func(path string, refText string) int
|
||||
)
|
||||
|
||||
type CrispASR struct {
|
||||
base.SingleThread
|
||||
// sampleRate is the output rate (Hz) of the loaded TTS engine's PCM, used to
|
||||
// write a correct WAV header. Most CrispASR TTS backends emit 24 kHz, but
|
||||
// piper returns its model's native rate (16 kHz for x_low/low voices,
|
||||
// 22.05 kHz for medium/high), so it is read from the GGUF metadata at Load.
|
||||
sampleRate int
|
||||
}
|
||||
|
||||
// defaultTTSSampleRate is the output rate assumed for CrispASR TTS engines that
|
||||
// don't advertise one in GGUF metadata (vibevoice/orpheus/chatterbox/qwen3-tts
|
||||
// all emit 24 kHz). piper is the exception and carries piper.sample_rate.
|
||||
const defaultTTSSampleRate = 24000
|
||||
|
||||
// piperSampleRate reads the piper.sample_rate metadata key from a GGUF model.
|
||||
// CrispASR's piper backend returns PCM at the model's native rate without
|
||||
// resampling, so the WAV header must match it. Returns ok=false for non-piper
|
||||
// models (key absent) or an unreadable file, letting the caller fall back to
|
||||
// defaultTTSSampleRate.
|
||||
func piperSampleRate(modelPath string) (int, bool) {
|
||||
// Only scalar architecture keys are read, so skip the large array metadata
|
||||
// (phoneme map) and mmap the header - same rationale as pkg/vram's reader.
|
||||
f, err := gguf.ParseGGUFFile(modelPath, gguf.UseMMap(), gguf.SkipLargeMetadata())
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
kv, ok := f.Header.MetadataKV.Get("piper.sample_rate")
|
||||
if !ok || kv.ValueType != gguf.GGUFMetadataValueTypeUint32 {
|
||||
return 0, false
|
||||
}
|
||||
rate := int(kv.ValueUint32())
|
||||
if rate <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return rate, true
|
||||
}
|
||||
|
||||
// splitOption splits a "prefix:value" model option into its key and value,
|
||||
// matching the convention used by other backends (see sherpa-onnx). It returns
|
||||
// ok=false when the option carries no ':' separator.
|
||||
func splitOption(oo string) (key, value string, ok bool) {
|
||||
parts := strings.SplitN(oo, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", "", false
|
||||
}
|
||||
return parts[0], parts[1], true
|
||||
}
|
||||
|
||||
func (w *CrispASR) Load(opts *pb.ModelOptions) error {
|
||||
vadOnly := false
|
||||
backendName := ""
|
||||
codecPath := ""
|
||||
speakerName := ""
|
||||
voicePath := ""
|
||||
voiceRefText := ""
|
||||
|
||||
for _, oo := range opts.Options {
|
||||
if oo == "vad_only" {
|
||||
vadOnly = true
|
||||
continue
|
||||
}
|
||||
switch key, value, ok := splitOption(oo); {
|
||||
case ok && key == "backend":
|
||||
backendName = value
|
||||
case ok && key == "codec":
|
||||
codecPath = value
|
||||
case ok && key == "speaker":
|
||||
speakerName = value
|
||||
case ok && key == "voice":
|
||||
voicePath = value
|
||||
case ok && key == "voice_text":
|
||||
voiceRefText = value
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||
}
|
||||
}
|
||||
|
||||
if vadOnly {
|
||||
if ret := CppLoadModelVAD(opts.ModelFile); ret != 0 {
|
||||
return fmt.Errorf("Failed to load CrispASR VAD model")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve a relative companion path against the model directory so a config
|
||||
// can reference a sibling codec/tokenizer file by name alone.
|
||||
if codecPath != "" && !filepath.IsAbs(codecPath) {
|
||||
codecPath = filepath.Join(filepath.Dir(opts.ModelFile), codecPath)
|
||||
}
|
||||
|
||||
// A voice file (.gguf pack or .wav prompt) is resolved against the model
|
||||
// directory just like the codec, so a config can reference a sibling file.
|
||||
if voicePath != "" && !filepath.IsAbs(voicePath) {
|
||||
voicePath = filepath.Join(filepath.Dir(opts.ModelFile), voicePath)
|
||||
}
|
||||
|
||||
if ret := CppLoadModel(opts.ModelFile, int(opts.Threads), backendName); ret != 0 {
|
||||
return fmt.Errorf("Failed to load CrispASR transcription model")
|
||||
}
|
||||
|
||||
// Determine the TTS output sample rate for the WAV header. piper voices
|
||||
// carry their native rate in GGUF metadata and CrispASR does not resample;
|
||||
// every other engine emits the 24 kHz default.
|
||||
w.sampleRate = defaultTTSSampleRate
|
||||
if rate, ok := piperSampleRate(opts.ModelFile); ok {
|
||||
w.sampleRate = rate
|
||||
}
|
||||
|
||||
// Load the companion file (codec/tokenizer/s3gen) after the session is open.
|
||||
// rc==0 means success or "not applicable" for the active backend; only a
|
||||
// negative code is fatal.
|
||||
if codecPath != "" {
|
||||
if rc := CppSetCodecPath(codecPath); rc < 0 {
|
||||
return fmt.Errorf("crispasr: failed to load companion file %q (rc=%d)", codecPath, rc)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "CrispASR companion file loaded: %s\n", codecPath)
|
||||
}
|
||||
|
||||
// Apply the Load-time default voice. A baked speaker (speaker:) is selected
|
||||
// by name and is best-effort: a backend that can't honor it is logged, not
|
||||
// fatal. A voice file (voice:) is a hard requirement once configured, so a
|
||||
// negative rc fails Load.
|
||||
if speakerName != "" {
|
||||
if rc := CppTTSSetVoice(speakerName); rc != 0 {
|
||||
fmt.Fprintf(os.Stderr, "crispasr: speaker %q not applied (rc=%d)\n", speakerName, rc)
|
||||
}
|
||||
}
|
||||
if voicePath != "" {
|
||||
if rc := CppTTSSetVoiceFile(voicePath, voiceRefText); rc < 0 {
|
||||
return fmt.Errorf("crispasr: failed to load voice %q (rc=%d)", voicePath, rc)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "CrispASR voice loaded: %s\n", voicePath)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "CrispASR backend selected: %s\n", CppGetBackend())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *CrispASR) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
|
||||
audio := req.Audio
|
||||
// We expect 0xdeadbeef to be overwritten and if we see it in a stack trace we know it wasn't
|
||||
segsPtr, segsLen := uintptr(0xdeadbeef), uintptr(0xdeadbeef)
|
||||
segsPtrPtr, segsLenPtr := unsafe.Pointer(&segsPtr), unsafe.Pointer(&segsLen)
|
||||
|
||||
if ret := CppVAD(audio, uintptr(len(audio)), segsPtrPtr, segsLenPtr); ret != 0 {
|
||||
return pb.VADResponse{}, fmt.Errorf("Failed VAD")
|
||||
}
|
||||
|
||||
// Happens when CPP vector has not had any elements pushed to it
|
||||
if segsPtr == 0 {
|
||||
return pb.VADResponse{
|
||||
Segments: []*pb.VADSegment{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// unsafeptr warning is caused by segsPtr being on the stack and therefor being subject to stack copying AFAICT
|
||||
// however the stack shouldn't have grown between setting segsPtr and now, also the memory pointed to is allocated by C++
|
||||
segs := unsafe.Slice((*float32)(unsafe.Pointer(segsPtr)), segsLen) //nolint:govet // segsPtr addresses C++-owned heap memory passed back through the cgo-free purego boundary; the uintptr->Pointer round-trip is intentional and the buffer outlives this read.
|
||||
|
||||
vadSegments := []*pb.VADSegment{}
|
||||
for i := range len(segs) >> 1 {
|
||||
s := segs[2*i] / 100
|
||||
t := segs[2*i+1] / 100
|
||||
vadSegments = append(vadSegments, &pb.VADSegment{
|
||||
Start: s,
|
||||
End: t,
|
||||
})
|
||||
}
|
||||
|
||||
return pb.VADResponse{
|
||||
Segments: vadSegments,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *CrispASR) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "crispasr")
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
convertedPath := filepath.Join(dir, "converted.wav")
|
||||
|
||||
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
|
||||
// Watcher: flips the C-side abort flag when ctx is cancelled. The
|
||||
// goroutine is joined synchronously (close(done) signals it to exit,
|
||||
// wg.Wait() blocks until it has) so a late CppSetAbort(1) cannot fire
|
||||
// after the function returns and corrupt the next transcription call.
|
||||
done := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
CppSetAbort(1)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||
if ret == 2 {
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if ret != 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
segments := []*pb.TranscriptSegment{}
|
||||
text := ""
|
||||
for i := range int(segsLen) {
|
||||
// segment start/end conversion factor taken from https://github.com/ggml-org/whisper.cpp/blob/master/examples/cli/cli.cpp#L895
|
||||
s := CppGetSegmentStart(i) * (10000000)
|
||||
t := CppGetSegmentEnd(i) * (10000000)
|
||||
// The session result can emit bytes that aren't valid UTF-8 (e.g. a
|
||||
// multibyte codepoint split across token boundaries); protobuf string
|
||||
// fields reject those at marshal time. Scrub before the value escapes
|
||||
// cgo. The session result is segment+word based and exposes no token
|
||||
// IDs, so Tokens is left empty.
|
||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
||||
|
||||
segment := &pb.TranscriptSegment{
|
||||
Id: int32(i),
|
||||
Text: txt,
|
||||
Start: s, End: t,
|
||||
}
|
||||
|
||||
segments = append(segments, segment)
|
||||
|
||||
text += " " + strings.TrimSpace(txt)
|
||||
}
|
||||
|
||||
return pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: strings.TrimSpace(text),
|
||||
Language: opts.Language,
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AudioTranscriptionStream runs the session transcribe to completion and then
|
||||
// emits one delta per non-empty segment, followed by a final TranscriptResult.
|
||||
// Progressive/real-time streaming isn't available via the session API (there
|
||||
// is no per-decode callback), so deltas are emitted per-segment after the
|
||||
// blocking decode returns rather than as segments are produced. The offline
|
||||
// AudioTranscription is unchanged; both paths share the session and the
|
||||
// SingleThread concurrency model.
|
||||
func (w *CrispASR) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
defer close(results)
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "crispasr")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
convertedPath := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
|
||||
// Same abort-watcher pattern as AudioTranscription. Joined synchronously
|
||||
// so a late CppSetAbort(1) cannot fire after this function returns.
|
||||
// Best-effort only: the session transcribe is blocking with no abort hook.
|
||||
done := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
CppSetAbort(1)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||
if ret == 2 {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
// Walk the segments once: emit a delta per non-empty segment and build the
|
||||
// final TranscriptResult.Segments alongside. The first delta has no leading
|
||||
// space and subsequent ones are prefixed with a single space, so
|
||||
// concat(deltas) == final.Text exactly, matching the e2e contract.
|
||||
segments := []*pb.TranscriptSegment{}
|
||||
var assembled strings.Builder
|
||||
for i := range int(segsLen) {
|
||||
s := CppGetSegmentStart(i) * 10000000
|
||||
t := CppGetSegmentEnd(i) * 10000000
|
||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
||||
segments = append(segments, &pb.TranscriptSegment{
|
||||
Id: int32(i),
|
||||
Text: txt,
|
||||
Start: s, End: t,
|
||||
})
|
||||
|
||||
trimmed := strings.TrimSpace(txt)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
var delta string
|
||||
if assembled.Len() == 0 {
|
||||
delta = trimmed
|
||||
} else {
|
||||
delta = " " + trimmed
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{Delta: delta}
|
||||
assembled.WriteString(delta)
|
||||
}
|
||||
|
||||
final := &pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: assembled.String(),
|
||||
Language: opts.Language,
|
||||
Duration: duration,
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{FinalResult: final}
|
||||
return nil
|
||||
}
|
||||
|
||||
// synthesize returns 24 kHz mono float32 PCM for text via the open session.
|
||||
func (w *CrispASR) synthesize(text string) ([]float32, error) {
|
||||
if text == "" {
|
||||
return nil, fmt.Errorf("crispasr: TTS requires non-empty text")
|
||||
}
|
||||
var n int32
|
||||
ptr := CppTTSSynthesize(text, unsafe.Pointer(&n))
|
||||
if ptr == 0 || n <= 0 {
|
||||
return nil, fmt.Errorf("crispasr: synthesis failed (the loaded model may not be a supported TTS backend, or needs extra config e.g. orpheus SNAC codec)")
|
||||
}
|
||||
defer CppTTSFree(ptr)
|
||||
src := unsafe.Slice((*float32)(unsafe.Pointer(ptr)), int(n)) //nolint:govet // ptr addresses C-allocated PCM returned across the purego boundary; copied out immediately below, before tts_free.
|
||||
out := make([]float32, int(n)) // copy out of C memory before free
|
||||
copy(out, src)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// setVoice applies a per-call speaker/voice override (best effort). CrispASR
|
||||
// returns a negative code when the active backend can't honor the name; we log
|
||||
// it rather than fail, so an unknown voice falls back to the default speaker.
|
||||
func setVoice(voice string) {
|
||||
v := strings.TrimSpace(voice)
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
if rc := CppTTSSetVoice(v); rc != 0 {
|
||||
fmt.Fprintf(os.Stderr, "crispasr: voice %q not applied by the active TTS backend (rc=%d); using default\n", v, rc)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *CrispASR) TTS(req *pb.TTSRequest) error {
|
||||
if req.Dst == "" {
|
||||
return fmt.Errorf("crispasr: TTS requires a destination path")
|
||||
}
|
||||
setVoice(req.Voice)
|
||||
pcm, err := w.synthesize(req.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWAV(req.Dst, pcm, w.sampleRate)
|
||||
}
|
||||
|
||||
// TTSStream is the streaming counterpart to TTS. CrispASR has no progressive
|
||||
// (native streaming) synth, so we synthesize the whole utterance, encode it to
|
||||
// a 24 kHz WAV, and emit the encoded bytes as a single chunk. The gRPC server
|
||||
// wrapper (pkg/grpc/server.go:TTSStream) ranges over the channel until it is
|
||||
// closed, so this method owns the close - mirrors vibevoice-cpp's TTSStream.
|
||||
func (w *CrispASR) TTSStream(req *pb.TTSRequest, results chan []byte) error {
|
||||
defer close(results)
|
||||
|
||||
if req.Text == "" {
|
||||
return fmt.Errorf("crispasr: TTSStream requires text")
|
||||
}
|
||||
setVoice(req.Voice)
|
||||
pcm, err := w.synthesize(req.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmp, err := os.CreateTemp("", "crispasr-tts-stream-*.wav")
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: tempfile: %w", err)
|
||||
}
|
||||
dst := tmp.Name()
|
||||
if err := tmp.Close(); err != nil {
|
||||
return fmt.Errorf("crispasr: close tempfile: %w", err)
|
||||
}
|
||||
defer func() { _ = os.Remove(dst) }()
|
||||
|
||||
if err := writeWAV(dst, pcm, w.sampleRate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encoded, err := os.ReadFile(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: read tempfile: %w", err)
|
||||
}
|
||||
results <- encoded
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeWAV writes pcm as a sampleRate Hz, mono, 16-bit PCM WAV at dst.
|
||||
func writeWAV(dst string, pcm []float32, sampleRate int) error {
|
||||
f, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: create %q: %w", dst, err)
|
||||
}
|
||||
|
||||
enc := wav.NewEncoder(f, sampleRate, 16, 1, 1)
|
||||
ints := make([]int, len(pcm))
|
||||
for i, s := range pcm {
|
||||
if s > 1 {
|
||||
s = 1
|
||||
} else if s < -1 {
|
||||
s = -1
|
||||
}
|
||||
ints[i] = int(s * 32767)
|
||||
}
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: sampleRate},
|
||||
Data: ints,
|
||||
SourceBitDepth: 16,
|
||||
}
|
||||
if err := enc.Write(buf); err != nil {
|
||||
_ = enc.Close()
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("crispasr: encode WAV: %w", err)
|
||||
}
|
||||
if err := enc.Close(); err != nil {
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("crispasr: finalize WAV: %w", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return fmt.Errorf("crispasr: close %q: %w", dst, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
164
backend/go/crispasr/gocrispasr_samplerate_test.go
Normal file
164
backend/go/crispasr/gocrispasr_samplerate_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// GGUF metadata value type tags (subset) from the GGUF spec.
|
||||
const (
|
||||
ggufTypeUint32 uint32 = 4
|
||||
ggufTypeString uint32 = 8
|
||||
)
|
||||
|
||||
type ggufKV struct {
|
||||
key string
|
||||
vtype uint32
|
||||
val any
|
||||
}
|
||||
|
||||
// writeMinimalGGUF emits a valid, tensor-less GGUF file carrying only the given
|
||||
// metadata key-values. Enough for the header-only parse path piperSampleRate
|
||||
// uses; avoids pulling a real multi-MB voice into the test.
|
||||
func writeMinimalGGUF(path string, kvs []ggufKV) error {
|
||||
var b bytes.Buffer
|
||||
b.WriteString("GGUF") // magic
|
||||
_ = binary.Write(&b, binary.LittleEndian, uint32(3)) // version
|
||||
_ = binary.Write(&b, binary.LittleEndian, uint64(0)) // tensor count
|
||||
_ = binary.Write(&b, binary.LittleEndian, uint64(len(kvs)))
|
||||
for _, kv := range kvs {
|
||||
_ = binary.Write(&b, binary.LittleEndian, uint64(len(kv.key)))
|
||||
b.WriteString(kv.key)
|
||||
_ = binary.Write(&b, binary.LittleEndian, kv.vtype)
|
||||
switch v := kv.val.(type) {
|
||||
case uint32:
|
||||
_ = binary.Write(&b, binary.LittleEndian, v)
|
||||
case string:
|
||||
_ = binary.Write(&b, binary.LittleEndian, uint64(len(v)))
|
||||
b.WriteString(v)
|
||||
}
|
||||
}
|
||||
return os.WriteFile(path, b.Bytes(), 0o644)
|
||||
}
|
||||
|
||||
// wavSampleRate decodes the WAV header at path and returns its sample rate.
|
||||
func wavSampleRate(path string) (int, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
dec := wav.NewDecoder(f)
|
||||
dec.ReadInfo()
|
||||
return int(dec.SampleRate), nil
|
||||
}
|
||||
|
||||
var _ = Describe("piper sample rate", func() {
|
||||
Context("piperSampleRate", func() {
|
||||
It("reads piper.sample_rate from a piper GGUF (medium = 22050)", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "voice.gguf")
|
||||
Expect(writeMinimalGGUF(p, []ggufKV{
|
||||
{key: "general.architecture", vtype: ggufTypeString, val: "piper"},
|
||||
{key: "piper.sample_rate", vtype: ggufTypeUint32, val: uint32(22050)},
|
||||
})).To(Succeed())
|
||||
|
||||
rate, ok := piperSampleRate(p)
|
||||
Expect(ok).To(BeTrue(), "piper.sample_rate should be found")
|
||||
Expect(rate).To(Equal(22050))
|
||||
})
|
||||
|
||||
It("reads the low-quality rate (16000)", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "voice.gguf")
|
||||
Expect(writeMinimalGGUF(p, []ggufKV{
|
||||
{key: "piper.sample_rate", vtype: ggufTypeUint32, val: uint32(16000)},
|
||||
})).To(Succeed())
|
||||
|
||||
rate, ok := piperSampleRate(p)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(rate).To(Equal(16000))
|
||||
})
|
||||
|
||||
It("returns ok=false for a non-piper GGUF (no piper.sample_rate key)", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "other.gguf")
|
||||
Expect(writeMinimalGGUF(p, []ggufKV{
|
||||
{key: "general.architecture", vtype: ggufTypeString, val: "vibevoice"},
|
||||
})).To(Succeed())
|
||||
|
||||
_, ok := piperSampleRate(p)
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
|
||||
It("returns ok=false for an unreadable/non-GGUF file", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "garbage.gguf")
|
||||
Expect(os.WriteFile(p, []byte("not a gguf"), 0o644)).To(Succeed())
|
||||
|
||||
_, ok := piperSampleRate(p)
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
// End-to-end through the built .so. Gated on CRISPASR_PIPER_MODEL_PATH (a
|
||||
// real piper voice GGUF) like the other model-backed specs; never runs in
|
||||
// default CI. Proves CrispASR's piper backend output rate flows into the
|
||||
// WAV header instead of the hardcoded 24 kHz default.
|
||||
Context("piper TTS end-to-end", func() {
|
||||
It("writes the WAV at the model's native piper.sample_rate", func() {
|
||||
model := os.Getenv("CRISPASR_PIPER_MODEL_PATH")
|
||||
if model == "" {
|
||||
Skip("set CRISPASR_PIPER_MODEL_PATH to run the piper e2e spec")
|
||||
}
|
||||
ensureLibLoaded()
|
||||
|
||||
expected, ok := piperSampleRate(model)
|
||||
Expect(ok).To(BeTrue(), "model should carry piper.sample_rate metadata")
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{
|
||||
ModelFile: model,
|
||||
Options: []string{"backend:piper"},
|
||||
Threads: 4,
|
||||
})).To(Succeed())
|
||||
|
||||
dst := filepath.Join(GinkgoT().TempDir(), "piper.wav")
|
||||
Expect(w.TTS(&pb.TTSRequest{Text: "Hello from CrispASR piper.", Dst: dst})).To(Succeed())
|
||||
|
||||
info, err := os.Stat(dst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(info.Size()).To(BeNumerically(">", 1024), "expected a non-trivial WAV")
|
||||
|
||||
rate, err := wavSampleRate(dst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rate).To(Equal(expected),
|
||||
"WAV header rate must equal the model's native piper.sample_rate, not the 24k default")
|
||||
})
|
||||
})
|
||||
|
||||
Context("writeWAV", func() {
|
||||
It("writes the WAV header at the given sample rate (22050 for piper, not the 24k default)", func() {
|
||||
dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
|
||||
pcm := make([]float32, 220) // 10 ms of silence is enough for a header
|
||||
Expect(writeWAV(dst, pcm, 22050)).To(Succeed())
|
||||
|
||||
rate, err := wavSampleRate(dst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rate).To(Equal(22050))
|
||||
})
|
||||
|
||||
It("writes a 16000 Hz header for low-quality piper voices", func() {
|
||||
dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
|
||||
pcm := make([]float32, 160)
|
||||
Expect(writeWAV(dst, pcm, 16000)).To(Succeed())
|
||||
|
||||
rate, err := wavSampleRate(dst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rate).To(Equal(16000))
|
||||
})
|
||||
})
|
||||
})
|
||||
193
backend/go/crispasr/gocrispasr_test.go
Normal file
193
backend/go/crispasr/gocrispasr_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestCrispASR(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "CrispASR Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
|
||||
// bridge without spinning up the gRPC server. Skips the current spec when the
|
||||
// shared library isn't present (e.g. running before `make backends/whisper`).
|
||||
func ensureLibLoaded() {
|
||||
libLoadOnce.Do(func() {
|
||||
libName := os.Getenv("CRISPASR_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgocrispasr-fallback.so"
|
||||
}
|
||||
if _, err := os.Stat(libName); err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&CppLoadModel, gosd, "load_model")
|
||||
purego.RegisterLibFunc(&CppSetCodecPath, gosd, "set_codec_path")
|
||||
purego.RegisterLibFunc(&CppTranscribe, gosd, "transcribe")
|
||||
purego.RegisterLibFunc(&CppGetSegmentText, gosd, "get_segment_text")
|
||||
purego.RegisterLibFunc(&CppGetSegmentStart, gosd, "get_segment_t0")
|
||||
purego.RegisterLibFunc(&CppGetSegmentEnd, gosd, "get_segment_t1")
|
||||
purego.RegisterLibFunc(&CppGetBackend, gosd, "get_backend")
|
||||
purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort")
|
||||
purego.RegisterLibFunc(&CppTTSSynthesize, gosd, "tts_synthesize")
|
||||
purego.RegisterLibFunc(&CppTTSFree, gosd, "tts_free")
|
||||
purego.RegisterLibFunc(&CppTTSSetVoice, gosd, "tts_set_voice")
|
||||
purego.RegisterLibFunc(&CppTTSSetVoiceFile, gosd, "tts_set_voice_file")
|
||||
})
|
||||
if libLoadErr != nil {
|
||||
Skip("whisper library not loadable: " + libLoadErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// fixturesOrSkip returns the model + audio paths or skips the spec if either
|
||||
// env var is unset. The test never runs in default CI — it requires a real
|
||||
// whisper model and a long audio file (~3 minutes) on disk.
|
||||
func fixturesOrSkip() (string, string) {
|
||||
modelPath := os.Getenv("CRISPASR_MODEL_PATH")
|
||||
audioPath := os.Getenv("CRISPASR_AUDIO_PATH")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set CRISPASR_MODEL_PATH and CRISPASR_AUDIO_PATH to run this spec")
|
||||
}
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
// ttsModelOrSkip returns the TTS model path or skips the spec when the env var
|
||||
// is unset. Like the transcription fixtures, this never runs in default CI — it
|
||||
// needs a real TTS model (e.g. a vibevoice GGUF) on disk.
|
||||
func ttsModelOrSkip() string {
|
||||
modelPath := os.Getenv("CRISPASR_TTS_MODEL_PATH")
|
||||
if modelPath == "" {
|
||||
Skip("set CRISPASR_TTS_MODEL_PATH to run this spec")
|
||||
}
|
||||
return modelPath
|
||||
}
|
||||
|
||||
var _ = Describe("CrispASR", func() {
|
||||
Context("AudioTranscription cancellation", func() {
|
||||
It("returns codes.Canceled on a pre-cancelled context and still succeeds afterwards", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
|
||||
// The session transcribe is blocking and exposes no abort hook, so
|
||||
// a mid-decode cancel can't interrupt it. The contract we can rely
|
||||
// on is the pre-call ctx.Err() check: a context cancelled before
|
||||
// the call must yield codes.Canceled without starting a decode.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := w.AudioTranscription(ctx, &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "expected pre-cancelled context to fail")
|
||||
st, ok := status.FromError(err)
|
||||
Expect(ok).To(BeTrue(), "expected gRPC status error, got %v", err)
|
||||
Expect(st.Code()).To(Equal(codes.Canceled), "expected codes.Canceled, got %v", err)
|
||||
|
||||
// Subsequent transcription must succeed — proves g_abort reset.
|
||||
res, err := w.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred(), "post-cancel transcription failed")
|
||||
Expect(res.Text).ToNot(BeEmpty(), "post-cancel transcription returned empty text")
|
||||
})
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("emits multiple deltas progressively for a multi-segment clip", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
|
||||
results := make(chan *pb.TranscriptStreamResponse, 64)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- w.AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
Stream: true,
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var deltas []string
|
||||
var assembled strings.Builder
|
||||
var finalText string
|
||||
var finalSegmentCount int
|
||||
for chunk := range results {
|
||||
if d := chunk.GetDelta(); d != "" {
|
||||
deltas = append(deltas, d)
|
||||
assembled.WriteString(d)
|
||||
}
|
||||
if final := chunk.GetFinalResult(); final != nil {
|
||||
finalText = final.GetText()
|
||||
finalSegmentCount = len(final.GetSegments())
|
||||
}
|
||||
}
|
||||
Expect(<-done).ToNot(HaveOccurred())
|
||||
|
||||
// One delta per non-empty segment is emitted after the blocking
|
||||
// decode returns (the session API has no per-decode callback), so a
|
||||
// multi-segment clip MUST produce >=2 delta events, and
|
||||
// concat(deltas) MUST equal final.Text exactly.
|
||||
Expect(len(deltas)).To(BeNumerically(">=", 2),
|
||||
"expected multiple deltas from a multi-segment clip, got %d (assembled=%q)",
|
||||
len(deltas), assembled.String())
|
||||
Expect(finalSegmentCount).To(BeNumerically(">=", 2),
|
||||
"expected final to carry multiple segments")
|
||||
Expect(assembled.String()).To(Equal(finalText),
|
||||
"concat(deltas) must equal final.Text")
|
||||
})
|
||||
})
|
||||
|
||||
Context("TTS", func() {
|
||||
It("synthesizes a non-empty WAV", func() {
|
||||
ttsModel := ttsModelOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: ttsModel})).To(Succeed())
|
||||
|
||||
dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
|
||||
Expect(w.TTS(&pb.TTSRequest{Text: "Hello from CrispASR.", Dst: dst})).To(Succeed())
|
||||
|
||||
info, err := os.Stat(dst)
|
||||
Expect(err).ToNot(HaveOccurred(), "synthesized WAV should exist at %q", dst)
|
||||
// A real 24 kHz mono WAV is a 44-byte header plus samples; anything
|
||||
// this small would mean an empty/failed synth.
|
||||
Expect(info.Size()).To(BeNumerically(">", 1024),
|
||||
"expected a non-trivial WAV, got %d bytes", info.Size())
|
||||
})
|
||||
})
|
||||
})
|
||||
58
backend/go/crispasr/main.go
Normal file
58
backend/go/crispasr/main.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("CRISPASR_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgocrispasr-fallback.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppLoadModel, "load_model"},
|
||||
{&CppSetCodecPath, "set_codec_path"},
|
||||
{&CppLoadModelVAD, "load_model_vad"},
|
||||
{&CppVAD, "vad"},
|
||||
{&CppTranscribe, "transcribe"},
|
||||
{&CppGetSegmentText, "get_segment_text"},
|
||||
{&CppGetSegmentStart, "get_segment_t0"},
|
||||
{&CppGetSegmentEnd, "get_segment_t1"},
|
||||
{&CppGetBackend, "get_backend"},
|
||||
{&CppSetAbort, "set_abort"},
|
||||
{&CppTTSSynthesize, "tts_synthesize"},
|
||||
{&CppTTSFree, "tts_free"},
|
||||
{&CppTTSSetVoice, "tts_set_voice"},
|
||||
{&CppTTSSetVoiceFile, "tts_set_voice_file"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &CrispASR{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
91
backend/go/crispasr/package.sh
Executable file
91
backend/go/crispasr/package.sh
Executable file
@@ -0,0 +1,91 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
# This script is used in the final stage of the Dockerfile
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/crispasr $CURDIR/package/
|
||||
cp -fv $CURDIR/libgocrispasr-*.so $CURDIR/package/
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ $(uname -s) = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Bundle espeak-ng (+ its libpcaudio/libsonic runtime deps) and its voice data so
|
||||
# the piper TTS backend can phonemize non-English text. CrispASR dlopens
|
||||
# libespeak-ng.so.1 at runtime (the MIT-clean path); the dlopen succeeds loading
|
||||
# libespeak-ng but FAILS if libpcaudio/libsonic are absent, so all three .so are
|
||||
# required. run.sh points CRISPASR_ESPEAK_DATA_PATH at the bundled data dir.
|
||||
# Best-effort: only copied when present, so a local dev build without espeak-ng
|
||||
# installed still packages the rest (English voices keep working).
|
||||
ESPEAK_LIBDIR=""
|
||||
for d in /usr/lib/x86_64-linux-gnu /usr/lib/aarch64-linux-gnu; do
|
||||
if [ -f "$d/libespeak-ng.so.1" ]; then
|
||||
ESPEAK_LIBDIR="$d"
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ -n "$ESPEAK_LIBDIR" ]; then
|
||||
echo "Bundling espeak-ng from $ESPEAK_LIBDIR ..."
|
||||
cp -arfLv "$ESPEAK_LIBDIR/libespeak-ng.so.1" $CURDIR/package/lib/
|
||||
cp -arfLv "$ESPEAK_LIBDIR/libpcaudio.so.0" $CURDIR/package/lib/
|
||||
cp -arfLv "$ESPEAK_LIBDIR/libsonic.so.0" $CURDIR/package/lib/
|
||||
if [ -d "$ESPEAK_LIBDIR/espeak-ng-data" ]; then
|
||||
cp -arfLv "$ESPEAK_LIBDIR/espeak-ng-data" $CURDIR/package/
|
||||
fi
|
||||
else
|
||||
echo "espeak-ng not found; non-English piper voices will not phonemize"
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
# The GPU library packaging script will detect BUILD_TYPE and copy appropriate GPU libraries
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
57
backend/go/crispasr/run.sh
Executable file
57
backend/go/crispasr/run.sh
Executable file
@@ -0,0 +1,57 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
fi
|
||||
|
||||
LIBRARY="$CURDIR/libgocrispasr-fallback.so"
|
||||
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx2.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx2.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx512.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx512.so"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
export CRISPASR_LIBRARY=$LIBRARY
|
||||
|
||||
# Point piper's espeak-ng phonemizer at the bundled voice data. The variable
|
||||
# names the directory CONTAINING espeak-ng-data (package.sh drops it next to
|
||||
# this script). Harmless when espeak-ng wasn't bundled.
|
||||
export CRISPASR_ESPEAK_DATA_PATH=$CURDIR
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/crispasr "$@"
|
||||
fi
|
||||
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/crispasr "$@"
|
||||
@@ -9,7 +9,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
# LocalVQE upstream version pin. Bump to a specific commit when picking up
|
||||
# a new release; `main` works for development but is not reproducible.
|
||||
LOCALVQE_REPO?=https://github.com/localai-org/LocalVQE
|
||||
LOCALVQE_VERSION?=72bfb4c6
|
||||
LOCALVQE_VERSION?=b0f0378a450e87c871b85689554801601ca56d98
|
||||
|
||||
# LocalVQE handles CPU feature selection internally (it ships the multiple
|
||||
# libggml-cpu-*.so variants and its loader picks the best one at runtime
|
||||
@@ -27,7 +27,8 @@ endif
|
||||
|
||||
# LocalVQE upstream supports CPU + Vulkan only. Other BUILD_TYPE values
|
||||
# fall through to the default CPU build — Vulkan is already as fast as the
|
||||
# specialised GPU paths would be on this 1.3 M-parameter model.
|
||||
# specialised GPU paths would be on these small (1.3 M–4.8 M parameter)
|
||||
# models.
|
||||
ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON -DLOCALVQE_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
|
||||
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -11,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -46,24 +46,24 @@ const (
|
||||
// through the options builder (CppOptionsNew + setters + CppNewWithOptions)
|
||||
// — the bare localvqe_new path doesn't expose backend / device selection.
|
||||
var (
|
||||
CppOptionsNew func() uintptr
|
||||
CppOptionsFree func(opts uintptr)
|
||||
CppOptionsSetModelPath func(opts uintptr, modelPath string) int32
|
||||
CppOptionsSetBackend func(opts uintptr, backend string) int32
|
||||
CppOptionsSetDevice func(opts uintptr, device int32) int32
|
||||
CppNewWithOptions func(opts uintptr) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppProcessF32 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessS16 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessFrameF32 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppProcessFrameS16 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppReset func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppSampleRate func(ctx uintptr) int32
|
||||
CppHopLength func(ctx uintptr) int32
|
||||
CppFFTSize func(ctx uintptr) int32
|
||||
CppSetNoiseGate func(ctx uintptr, enabled int32, thresholdDBFS float32) int32
|
||||
CppGetNoiseGate func(ctx uintptr, enabledOut, thresholdDBFSOut uintptr) int32
|
||||
CppOptionsNew func() uintptr
|
||||
CppOptionsFree func(opts uintptr)
|
||||
CppOptionsSetModelPath func(opts uintptr, modelPath string) int32
|
||||
CppOptionsSetBackend func(opts uintptr, backend string) int32
|
||||
CppOptionsSetDevice func(opts uintptr, device int32) int32
|
||||
CppNewWithOptions func(opts uintptr) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppProcessF32 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessS16 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessFrameF32 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppProcessFrameS16 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppReset func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppSampleRate func(ctx uintptr) int32
|
||||
CppHopLength func(ctx uintptr) int32
|
||||
CppFFTSize func(ctx uintptr) int32
|
||||
CppSetNoiseGate func(ctx uintptr, enabled int32, thresholdDBFS float32) int32
|
||||
CppGetNoiseGate func(ctx uintptr, enabledOut, thresholdDBFSOut uintptr) int32
|
||||
)
|
||||
|
||||
// LocalVQE speaks gRPC against LocalVQE's flat C ABI. The streaming
|
||||
@@ -490,11 +490,14 @@ func (v *LocalVQE) applyStreamConfig(cfg *pb.AudioTransformStreamConfig) error {
|
||||
|
||||
// ---- WAV I/O ----------------------------------------------------------
|
||||
//
|
||||
// Minimal mono PCM WAV reader/writer. Only handles the subset LocalVQE
|
||||
// cares about (mono, 16-bit signed, no extensible chunks). For broader
|
||||
// audio support the HTTP layer's `audio.NormalizeAudioFile` already
|
||||
// converts arbitrary input to a canonical WAV before we see it; this
|
||||
// reader just decodes the canonical shape.
|
||||
// Reader/writer for the mono 16-bit PCM shape LocalVQE works with. Decoding
|
||||
// goes through the shared go-audio/wav decoder (as the whisper and parakeet
|
||||
// backends do) so RIFF chunk walking is handled robustly — an 18/40-byte
|
||||
// extensible `fmt ` chunk, or JUNK/bext/LIST metadata before or after `data`
|
||||
// (e.g. ffmpeg's trailing "Lavf" tag), is skipped rather than spliced into
|
||||
// the PCM stream as an audible click. The HTTP layer normalises arbitrary
|
||||
// input to WAV before we see it, but that WAV is ffmpeg output and is not
|
||||
// guaranteed to be the canonical 44-byte layout.
|
||||
|
||||
func readMonoWAVf32(path string) ([]float32, int, error) {
|
||||
f, err := os.Open(path)
|
||||
@@ -502,35 +505,26 @@ func readMonoWAVf32(path string) ([]float32, int, error) {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
header := make([]byte, 44)
|
||||
if _, err := io.ReadFull(f, header); err != nil {
|
||||
return nil, 0, err
|
||||
|
||||
buf, err := wav.NewDecoder(f).FullPCMBuffer()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("decode WAV: %w", err)
|
||||
}
|
||||
if string(header[0:4]) != "RIFF" || string(header[8:12]) != "WAVE" {
|
||||
if buf == nil || buf.Format == nil {
|
||||
return nil, 0, fmt.Errorf("not a WAV file")
|
||||
}
|
||||
channels := binary.LittleEndian.Uint16(header[22:24])
|
||||
sampleRate := binary.LittleEndian.Uint32(header[24:28])
|
||||
bitsPerSample := binary.LittleEndian.Uint16(header[34:36])
|
||||
|
||||
if channels != 1 {
|
||||
return nil, 0, fmt.Errorf("only mono WAV supported (got %d channels)", channels)
|
||||
if buf.Format.NumChannels != 1 {
|
||||
return nil, 0, fmt.Errorf("only mono WAV supported (got %d channels)", buf.Format.NumChannels)
|
||||
}
|
||||
if bitsPerSample != 16 {
|
||||
return nil, 0, fmt.Errorf("only 16-bit PCM supported (got %d bits)", bitsPerSample)
|
||||
if buf.SourceBitDepth != 16 {
|
||||
return nil, 0, fmt.Errorf("only 16-bit PCM supported (got %d bits)", buf.SourceBitDepth)
|
||||
}
|
||||
|
||||
rest, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
if len(buf.Data) == 0 {
|
||||
return nil, 0, fmt.Errorf("WAV has no audio data")
|
||||
}
|
||||
n := len(rest) / 2
|
||||
out := make([]float32, n)
|
||||
for i := 0; i < n; i++ {
|
||||
s := int16(binary.LittleEndian.Uint16(rest[i*2 : i*2+2]))
|
||||
out[i] = float32(s) / 32768.0
|
||||
}
|
||||
return out, int(sampleRate), nil
|
||||
// AsFloat32Buffer normalises by 2^(bitDepth-1) == /32768 for 16-bit,
|
||||
// matching the model's expected [-1, 1) input range.
|
||||
return buf.AsFloat32Buffer().Data, buf.Format.SampleRate, nil
|
||||
}
|
||||
|
||||
func writeMonoWAVf32(path string, samples []float32, sampleRate int) error {
|
||||
@@ -546,13 +540,13 @@ func writeMonoWAVf32(path string, samples []float32, sampleRate int) error {
|
||||
binary.LittleEndian.PutUint32(header[4:8], 36+dataLen)
|
||||
copy(header[8:12], []byte("WAVE"))
|
||||
copy(header[12:16], []byte("fmt "))
|
||||
binary.LittleEndian.PutUint32(header[16:20], 16) // fmt chunk size
|
||||
binary.LittleEndian.PutUint16(header[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(header[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(header[16:20], 16) // fmt chunk size
|
||||
binary.LittleEndian.PutUint16(header[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(header[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(header[24:28], uint32(sampleRate))
|
||||
binary.LittleEndian.PutUint32(header[28:32], uint32(sampleRate*2)) // byte rate
|
||||
binary.LittleEndian.PutUint16(header[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(header[34:36], 16) // bits per sample
|
||||
binary.LittleEndian.PutUint16(header[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(header[34:36], 16) // bits per sample
|
||||
copy(header[36:40], []byte("data"))
|
||||
binary.LittleEndian.PutUint32(header[40:44], dataLen)
|
||||
if _, err := f.Write(header); err != nil {
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -92,6 +94,147 @@ var _ = Describe("LocalVQE-cpp", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("readMonoWAVf32 chunk parsing", func() {
|
||||
// chunk builds a word-aligned RIFF sub-chunk (id + size + body + pad).
|
||||
chunk := func(id string, body []byte) []byte {
|
||||
out := append([]byte(id), 0, 0, 0, 0)
|
||||
binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
|
||||
out = append(out, body...)
|
||||
if len(body)&1 == 1 {
|
||||
out = append(out, 0) // pad byte for odd-sized chunks
|
||||
}
|
||||
return out
|
||||
}
|
||||
// fmtBody returns a PCM `fmt ` chunk body. extra bytes simulate the
|
||||
// 18/40-byte extensible form (cbSize + extension).
|
||||
fmtBody := func(channels, bits uint16, rate uint32, extra int) []byte {
|
||||
b := make([]byte, 16+extra)
|
||||
binary.LittleEndian.PutUint16(b[0:2], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(b[2:4], channels)
|
||||
binary.LittleEndian.PutUint32(b[4:8], rate)
|
||||
binary.LittleEndian.PutUint32(b[8:12], rate*uint32(channels)*uint32(bits)/8)
|
||||
binary.LittleEndian.PutUint16(b[12:14], channels*bits/8)
|
||||
binary.LittleEndian.PutUint16(b[14:16], bits)
|
||||
if extra >= 2 {
|
||||
binary.LittleEndian.PutUint16(b[16:18], uint16(extra-2)) // cbSize
|
||||
}
|
||||
return b
|
||||
}
|
||||
// pcm encodes int16 samples little-endian.
|
||||
pcm := func(samples ...int16) []byte {
|
||||
b := make([]byte, len(samples)*2)
|
||||
for i, s := range samples {
|
||||
binary.LittleEndian.PutUint16(b[i*2:i*2+2], uint16(s))
|
||||
}
|
||||
return b
|
||||
}
|
||||
riff := func(chunks ...[]byte) []byte {
|
||||
body := []byte("WAVE")
|
||||
for _, c := range chunks {
|
||||
body = append(body, c...)
|
||||
}
|
||||
out := append([]byte("RIFF"), 0, 0, 0, 0)
|
||||
binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
|
||||
return append(out, body...)
|
||||
}
|
||||
writeWAV := func(b []byte) string {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "in.wav")
|
||||
Expect(os.WriteFile(p, b, 0o600)).To(Succeed())
|
||||
return p
|
||||
}
|
||||
// A canonical sample run with distinct values so any off-by-one /
|
||||
// misalignment shows up as wrong numbers, not just wrong length.
|
||||
samples := []int16{1000, -2000, 3000, -4000, 5000, -6000}
|
||||
expectSamples := func(got []float32) {
|
||||
Expect(got).To(HaveLen(len(samples)))
|
||||
for i, s := range samples {
|
||||
Expect(got[i]).To(BeNumerically("~", float32(s)/32768.0, 1e-6))
|
||||
}
|
||||
}
|
||||
|
||||
It("reads a canonical 44-byte WAV", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out)
|
||||
})
|
||||
|
||||
It("ignores a LIST/JUNK chunk placed before data (no leading-impulse splice)", func() {
|
||||
p := writeWAV(riff(
|
||||
chunk("fmt ", fmtBody(1, 16, 16000, 0)),
|
||||
chunk("JUNK", []byte("padding-bytes-here!")), // odd length → exercises pad
|
||||
chunk("LIST", []byte("INFOISFTLavf60.0")),
|
||||
chunk("data", pcm(samples...)),
|
||||
))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out) // not corrupted by the preceding chunks
|
||||
})
|
||||
|
||||
It("honours the data chunk size and drops a trailing metadata chunk", func() {
|
||||
p := writeWAV(riff(
|
||||
chunk("fmt ", fmtBody(1, 16, 16000, 0)),
|
||||
chunk("data", pcm(samples...)),
|
||||
chunk("LIST", []byte("INFOISFTLavf60.16.100")), // ffmpeg trailer tag
|
||||
))
|
||||
out, _, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectSamples(out) // trailing LIST bytes not decoded as PCM
|
||||
})
|
||||
|
||||
It("handles the 18-byte extensible fmt chunk", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 2)), chunk("data", pcm(samples...))))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out)
|
||||
})
|
||||
|
||||
It("rejects non-mono input", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(2, 16, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("mono"))
|
||||
})
|
||||
|
||||
It("rejects non-16-bit input", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 8, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("16-bit"))
|
||||
})
|
||||
|
||||
It("rejects a non-WAV file", func() {
|
||||
p := writeWAV([]byte("not a riff file at all"))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors when the data chunk is missing", func() {
|
||||
// fmt but no data: the decoder must fail rather than return an
|
||||
// empty (or garbage) sample slice. The exact message is the
|
||||
// decoder's, so just assert it errors.
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 0))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("round-trips through writeMonoWAVf32", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "rt.wav")
|
||||
in := []float32{0.1, -0.2, 0.3, -0.4}
|
||||
Expect(writeMonoWAVf32(p, in, 16000)).To(Succeed())
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
Expect(out).To(HaveLen(len(in)))
|
||||
for i := range in {
|
||||
Expect(out[i]).To(BeNumerically("~", in[i], 1e-4))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("model-gated integration (LOCALVQE_MODEL_PATH)", func() {
|
||||
It("load + sample rate + hop + fft", func() {
|
||||
path := modelPathOrSkip()
|
||||
|
||||
7
backend/go/locate-anything-cpp/.gitignore
vendored
Normal file
7
backend/go/locate-anything-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
sources/
|
||||
build*/
|
||||
package/
|
||||
liblocateanythingcpp*.so
|
||||
locate-anything-cpp
|
||||
test-models/
|
||||
test-data/
|
||||
57
backend/go/locate-anything-cpp/CMakeLists.txt
Normal file
57
backend/go/locate-anything-cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,57 @@
|
||||
cmake_minimum_required(VERSION 3.18)
|
||||
project(liblocateanythingcpp LANGUAGES C CXX)
|
||||
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
# Static-link ggml + locate_anything so the resulting .so has no runtime
|
||||
# dependency on extra ggml/locate_anything shared libraries — only on
|
||||
# libc/libstdc++/libgomp, which the LocalAI package step bundles into the
|
||||
# docker image.
|
||||
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build static libraries" FORCE)
|
||||
|
||||
# locate-anything.cpp build switches: skip CLI/tests, keep static lib.
|
||||
set(LA_BUILD_CLI OFF CACHE BOOL "Disable locate-anything CLI" FORCE)
|
||||
set(LA_BUILD_TESTS OFF CACHE BOOL "Disable locate-anything tests" FORCE)
|
||||
set(LA_SHARED OFF CACHE BOOL "Build locate_anything as static lib" FORCE)
|
||||
|
||||
# Unlike rt-detr.cpp, locate-anything.cpp ships no in-tree ggml patches, so
|
||||
# there is no apply_ggml_patches.sh hook to shim here.
|
||||
add_subdirectory(./sources/locate-anything.cpp)
|
||||
|
||||
# locate-anything.cpp's top-level CMakeLists points its own target's include
|
||||
# dirs at ${CMAKE_SOURCE_DIR}/{include,src,third_party,...}. CMAKE_SOURCE_DIR
|
||||
# is the *top-level* source dir of the whole CMake tree, so when we pull it in
|
||||
# via add_subdirectory it resolves to OUR directory, not theirs, and the
|
||||
# locate_anything target fails to find its own headers (la_capi.h, stb_image.h,
|
||||
# la_gguf_keys.h). Re-add the correct, subdir-relative include paths to the
|
||||
# already-defined target so it compiles regardless of where it's nested.
|
||||
set(LA_SRC ${CMAKE_CURRENT_SOURCE_DIR}/sources/locate-anything.cpp)
|
||||
target_include_directories(locate_anything PRIVATE
|
||||
${LA_SRC}/include
|
||||
${LA_SRC}/src
|
||||
${LA_SRC}/third_party
|
||||
${LA_SRC}/third_party/stb)
|
||||
|
||||
# locate-anything.cpp's C-API symbols already live inside liblocate_anything
|
||||
# (src/la_capi.cpp is compiled into the lib). We re-export them via a MODULE
|
||||
# library that links locate_anything so the symbols are visible at dlopen time.
|
||||
add_library(locateanythingcpp MODULE
|
||||
sources/locate-anything.cpp/src/la_capi.cpp)
|
||||
|
||||
target_include_directories(locateanythingcpp PRIVATE
|
||||
sources/locate-anything.cpp/include
|
||||
sources/locate-anything.cpp/src
|
||||
sources/locate-anything.cpp/third_party
|
||||
sources/locate-anything.cpp/third_party/stb
|
||||
)
|
||||
|
||||
target_link_libraries(locateanythingcpp PRIVATE locate_anything ggml)
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
|
||||
target_link_libraries(locateanythingcpp PRIVATE stdc++fs)
|
||||
endif()
|
||||
|
||||
set_property(TARGET locateanythingcpp PROPERTY CXX_STANDARD 17)
|
||||
set_target_properties(locateanythingcpp PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||
134
backend/go/locate-anything-cpp/Makefile
Normal file
134
backend/go/locate-anything-cpp/Makefile
Normal file
@@ -0,0 +1,134 @@
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# locate-anything.cpp. Pin to a specific commit for a stable build; leaving
|
||||
# this on `master` always picks up the latest C-API surface (incl. the
|
||||
# per-detection accessor functions used by golocateanythingcpp.go).
|
||||
LOCATEANYTHING_REPO?=https://github.com/mudler/locate-anything.cpp.git
|
||||
LOCATEANYTHING_VERSION?=92c1682da792c1e8a5dec91acc2be4b02c742ded
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# Forward LocalAI's BUILD_TYPE to the matching ggml backend switch.
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DGGML_CUDA=ON -DLA_GGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
ROCM_HOME ?= /opt/rocm
|
||||
ROCM_PATH ?= /opt/rocm
|
||||
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
||||
export CC=$(ROCM_HOME)/llvm/bin/clang
|
||||
AMDGPU_TARGETS?=gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON -DLA_GGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
ifneq ($(BUILD_TYPE),metal)
|
||||
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||
else
|
||||
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||
CMAKE_ARGS+=-DLA_GGML_METAL=ON
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
-DGGML_SYCL_F16=ON
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
endif
|
||||
|
||||
sources/locate-anything.cpp:
|
||||
mkdir -p sources && \
|
||||
git clone --recursive $(LOCATEANYTHING_REPO) sources/locate-anything.cpp && \
|
||||
cd sources/locate-anything.cpp && \
|
||||
git checkout $(LOCATEANYTHING_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
# Only build CPU variants on Linux
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
VARIANT_TARGETS = liblocateanythingcpp-avx.so liblocateanythingcpp-avx2.so liblocateanythingcpp-avx512.so liblocateanythingcpp-fallback.so
|
||||
else
|
||||
# On non-Linux (e.g., Darwin), build only fallback variant
|
||||
VARIANT_TARGETS = liblocateanythingcpp-fallback.so
|
||||
endif
|
||||
|
||||
locate-anything-cpp: main.go golocateanythingcpp.go $(VARIANT_TARGETS)
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o locate-anything-cpp ./
|
||||
|
||||
package: locate-anything-cpp
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
clean: purge
|
||||
rm -rf liblocateanythingcpp*.so locate-anything-cpp package sources
|
||||
|
||||
purge:
|
||||
rm -rf build*
|
||||
|
||||
# Build all variants (Linux only)
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
liblocateanythingcpp-avx.so: sources/locate-anything.cpp
|
||||
rm -rfv build-$@
|
||||
$(info ${GREEN}I locate-anything-cpp build info:avx${RESET})
|
||||
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) liblocateanythingcpp-custom
|
||||
rm -rfv build-$@
|
||||
|
||||
liblocateanythingcpp-avx2.so: sources/locate-anything.cpp
|
||||
rm -rfv build-$@
|
||||
$(info ${GREEN}I locate-anything-cpp build info:avx2${RESET})
|
||||
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) liblocateanythingcpp-custom
|
||||
rm -rfv build-$@
|
||||
|
||||
liblocateanythingcpp-avx512.so: sources/locate-anything.cpp
|
||||
rm -rfv build-$@
|
||||
$(info ${GREEN}I locate-anything-cpp build info:avx512${RESET})
|
||||
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) liblocateanythingcpp-custom
|
||||
rm -rfv build-$@
|
||||
endif
|
||||
|
||||
# Build fallback variant (all platforms)
|
||||
liblocateanythingcpp-fallback.so: sources/locate-anything.cpp
|
||||
rm -rfv build-$@
|
||||
$(info ${GREEN}I locate-anything-cpp build info:fallback${RESET})
|
||||
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) liblocateanythingcpp-custom
|
||||
rm -rfv build-$@
|
||||
|
||||
liblocateanythingcpp-custom: CMakeLists.txt
|
||||
mkdir -p build-$(SO_TARGET) && \
|
||||
cd build-$(SO_TARGET) && \
|
||||
cmake .. $(CMAKE_ARGS) && \
|
||||
cmake --build . --config Release -j$(JOBS) && \
|
||||
cd .. && \
|
||||
mv build-$(SO_TARGET)/liblocateanythingcpp.so ./$(SO_TARGET)
|
||||
|
||||
all: locate-anything-cpp package
|
||||
|
||||
# `test` is invoked by the top-level Makefile's `test-extra` target. It builds
|
||||
# the backend binary + the fallback shared library (needed for dlopen at
|
||||
# runtime), then runs test.sh which downloads the q8_0 GGUF + COCO image and
|
||||
# exercises the gRPC Load/Detect wire path via the Go smoke test in
|
||||
# main_test.go.
|
||||
test: locate-anything-cpp liblocateanythingcpp-fallback.so
|
||||
bash test.sh
|
||||
174
backend/go/locate-anything-cpp/golocateanythingcpp.go
Normal file
174
backend/go/locate-anything-cpp/golocateanythingcpp.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package main
|
||||
|
||||
// golocateanythingcpp.go - gRPC handlers (Load, Detect) for the
|
||||
// locate-anything-cpp backend.
|
||||
//
|
||||
// Embeds base.SingleThread to default unimplemented RPCs to "not supported"
|
||||
// while we only implement open-vocabulary object detection (Detect).
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// la_ctx* is an opaque handle. la_capi_load returns it directly (0 == failure),
|
||||
// unlike rfdetr's out-parameter convention.
|
||||
var (
|
||||
// la_capi_load(const char* gguf_path, int n_threads) -> la_ctx* (0 = fail)
|
||||
CapiLoad func(gguf string, nThreads int32) uintptr
|
||||
// la_capi_free(la_ctx* ctx)
|
||||
CapiFree func(handle uintptr)
|
||||
// la_capi_locate_path(ctx, image_path, prompt, mode) -> char* json (0 = err)
|
||||
CapiLocatePath func(handle uintptr, imagePath string, prompt string, mode int32) uintptr
|
||||
// la_capi_locate_buffer(ctx, bytes, len, prompt, mode) -> char* json (0 = err)
|
||||
CapiLocateBuffer func(handle uintptr, bytes uintptr, length uintptr, prompt string, mode int32) uintptr
|
||||
// la_capi_get_n_detections(ctx) -> int
|
||||
CapiGetNDetections func(handle uintptr) int32
|
||||
// la_capi_get_detection_box(ctx, i, out_xyxy[4]) -> int (0 on success)
|
||||
CapiGetDetectionBox func(handle uintptr, i int32, outXYXY uintptr) int32
|
||||
// la_capi_get_detection_label(ctx, i, buf, buf_size) -> int (required size incl NUL; two-call sizing)
|
||||
CapiGetDetectionLabel func(handle uintptr, i int32, buf uintptr, bufSize int32) int32
|
||||
// la_capi_free_string(char* s)
|
||||
CapiFreeString func(s uintptr)
|
||||
// la_capi_last_error(ctx) -> const char* (owned by ctx, "" if none / null ctx).
|
||||
// purego marshals the returned C string into a Go string (a copy), so we
|
||||
// never free it and avoid raw pointer arithmetic.
|
||||
CapiLastError func(handle uintptr) string
|
||||
)
|
||||
|
||||
type LocateAnythingCpp struct {
|
||||
base.SingleThread
|
||||
handle uintptr
|
||||
}
|
||||
|
||||
// Load loads the GGUF model at opts.ModelFile (joined with opts.ModelPath if
|
||||
// relative) and stores the la_ctx handle for later Detect calls.
|
||||
func (r *LocateAnythingCpp) Load(opts *pb.ModelOptions) error {
|
||||
modelFile := opts.ModelFile
|
||||
if modelFile == "" {
|
||||
modelFile = opts.Model
|
||||
}
|
||||
if modelFile == "" {
|
||||
return fmt.Errorf("locate-anything-cpp: ModelFile is empty")
|
||||
}
|
||||
|
||||
var modelPath string
|
||||
if filepath.IsAbs(modelFile) {
|
||||
modelPath = modelFile
|
||||
} else {
|
||||
modelPath = filepath.Join(opts.ModelPath, modelFile)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(modelPath); err != nil {
|
||||
return fmt.Errorf("locate-anything-cpp: model file not found: %s: %w", modelPath, err)
|
||||
}
|
||||
|
||||
threads := opts.Threads
|
||||
if threads <= 0 {
|
||||
threads = 4
|
||||
}
|
||||
|
||||
// Release previous model if any (re-Load).
|
||||
if r.handle != 0 {
|
||||
CapiFree(r.handle)
|
||||
r.handle = 0
|
||||
}
|
||||
|
||||
h := CapiLoad(modelPath, threads)
|
||||
if h == 0 {
|
||||
// la_capi_last_error needs a ctx; on a failed load we have none (it
|
||||
// returns "" for a null ctx), so the text is best-effort. Surface it
|
||||
// when present.
|
||||
if msg := CapiLastError(0); msg != "" {
|
||||
return fmt.Errorf("locate-anything-cpp: la_capi_load failed for %s: %s", modelPath, msg)
|
||||
}
|
||||
return fmt.Errorf("locate-anything-cpp: la_capi_load failed for %s", modelPath)
|
||||
}
|
||||
r.handle = h
|
||||
return nil
|
||||
}
|
||||
|
||||
// Detect runs open-vocabulary detection on the base64-encoded image in opts.Src
|
||||
// using the required text prompt in opts.Prompt, returning one pb.Detection per
|
||||
// located object with its predicted label as ClassName.
|
||||
func (r *LocateAnythingCpp) Detect(opts *pb.DetectOptions) (pb.DetectResponse, error) {
|
||||
if r.handle == 0 {
|
||||
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: model not loaded")
|
||||
}
|
||||
|
||||
// Open-vocabulary detection is prompt-driven; without a prompt there is
|
||||
// nothing to locate.
|
||||
prompt := opts.Prompt
|
||||
if prompt == "" {
|
||||
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: a text prompt is required (open-vocabulary detection)")
|
||||
}
|
||||
|
||||
// Decode base64 image and write to temp file.
|
||||
imgData, err := base64.StdEncoding.DecodeString(opts.Src)
|
||||
if err != nil {
|
||||
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: failed to decode base64 image: %w", err)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "locate-anything-*.img")
|
||||
if err != nil {
|
||||
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: failed to create temp file: %w", err)
|
||||
}
|
||||
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||
|
||||
if _, err := tmpFile.Write(imgData); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: failed to write temp file: %w", err)
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: failed to close temp file: %w", err)
|
||||
}
|
||||
|
||||
// mode 0 = hybrid (Parallel Box Decoding). The JSON return value is unused:
|
||||
// structured detections are read via the accessor functions. Still must
|
||||
// free the returned string.
|
||||
jsonPtr := CapiLocatePath(r.handle, tmpFile.Name(), prompt, 0)
|
||||
if jsonPtr != 0 {
|
||||
CapiFreeString(jsonPtr)
|
||||
}
|
||||
|
||||
n := CapiGetNDetections(r.handle)
|
||||
if n < 0 {
|
||||
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: invalid n_detections=%d", n)
|
||||
}
|
||||
|
||||
detections := make([]*pb.Detection, 0, n)
|
||||
for i := int32(0); i < n; i++ {
|
||||
var xyxy [4]float32 // x1, y1, x2, y2
|
||||
if CapiGetDetectionBox(r.handle, i, uintptr(unsafe.Pointer(&xyxy[0]))) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Two-call sizing for the label string.
|
||||
label := ""
|
||||
need := CapiGetDetectionLabel(r.handle, i, 0, 0)
|
||||
if need > 0 {
|
||||
buf := make([]byte, need)
|
||||
CapiGetDetectionLabel(r.handle, i, uintptr(unsafe.Pointer(&buf[0])), need)
|
||||
label = string(buf[:need-1])
|
||||
}
|
||||
|
||||
detections = append(detections, &pb.Detection{
|
||||
X: xyxy[0],
|
||||
Y: xyxy[1],
|
||||
Width: xyxy[2] - xyxy[0],
|
||||
Height: xyxy[3] - xyxy[1],
|
||||
Confidence: 1.0,
|
||||
ClassName: label,
|
||||
})
|
||||
}
|
||||
|
||||
return pb.DetectResponse{
|
||||
Detections: detections,
|
||||
}, nil
|
||||
}
|
||||
59
backend/go/locate-anything-cpp/main.go
Normal file
59
backend/go/locate-anything-cpp/main.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package main
|
||||
|
||||
// main.go - entry point for the locate-anything-cpp gRPC backend.
|
||||
//
|
||||
// Dlopens liblocateanythingcpp-<variant>.so via purego at the path in
|
||||
// LOCATEANYTHING_LIBRARY (set by run.sh based on /proc/cpuinfo), registers
|
||||
// the la_capi_* C ABI symbols, then starts the gRPC server.
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Get library name from environment variable, default to fallback
|
||||
libName := os.Getenv("LOCATEANYTHING_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./liblocateanythingcpp-fallback.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CapiLoad, "la_capi_load"},
|
||||
{&CapiFree, "la_capi_free"},
|
||||
{&CapiLocatePath, "la_capi_locate_path"},
|
||||
{&CapiLocateBuffer, "la_capi_locate_buffer"},
|
||||
{&CapiGetNDetections, "la_capi_get_n_detections"},
|
||||
{&CapiGetDetectionBox, "la_capi_get_detection_box"},
|
||||
{&CapiGetDetectionLabel, "la_capi_get_detection_label"},
|
||||
{&CapiFreeString, "la_capi_free_string"},
|
||||
{&CapiLastError, "la_capi_last_error"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &LocateAnythingCpp{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
176
backend/go/locate-anything-cpp/main_test.go
Normal file
176
backend/go/locate-anything-cpp/main_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package main
|
||||
|
||||
// main_test.go - end-to-end smoke test for the locate-anything-cpp gRPC backend.
|
||||
//
|
||||
// Spawns the compiled locate-anything-cpp binary on a free local port, dials it
|
||||
// via gRPC, and exercises LoadModel + Detect against the test fixtures
|
||||
// downloaded by test.sh: the q8_0 GGUF of nvidia/LocateAnything-3B and a real
|
||||
// COCO image with people + cars. Asserts that open-vocabulary detection driven
|
||||
// by a text prompt returns at least one detection, each carrying a non-empty
|
||||
// class name and a bounding box of non-zero size.
|
||||
//
|
||||
// The spec Skip()s cleanly if its fixtures (the ~6.3 GB model, the test image,
|
||||
// the built binary, or the fallback .so) are missing, so the test target stays
|
||||
// usable on a fresh checkout / on CI runners where the large model hasn't been
|
||||
// downloaded.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
func TestDetect(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "locate-anything-cpp backend smoke suite")
|
||||
}
|
||||
|
||||
// freePort grabs an ephemeral TCP port and immediately releases it so the
|
||||
// spawned backend can bind to it. There is a tiny TOCTOU window here but in
|
||||
// practice it's adequate for a smoke test on a quiet runner.
|
||||
func freePort() int {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
Expect(err).ToNot(HaveOccurred(), "freePort listen")
|
||||
port := l.Addr().(*net.TCPAddr).Port
|
||||
Expect(l.Close()).To(Succeed())
|
||||
return port
|
||||
}
|
||||
|
||||
// startBackend spawns the locate-anything-cpp binary on the given port and
|
||||
// waits until it accepts TCP connections (up to 10s). It mirrors how main.go
|
||||
// resolves the purego library: the LOCATEANYTHING_LIBRARY env var points the
|
||||
// dlopen at the freshly built fallback .so, and the la_capi_* symbols are
|
||||
// registered there. The returned cleanup func kills the process and reaps it.
|
||||
func startBackend(port int) func() {
|
||||
binary, err := filepath.Abs("./locate-anything-cpp")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if _, err := os.Stat(binary); err != nil {
|
||||
Skip(fmt.Sprintf("backend binary not built: %s (run `make locate-anything-cpp` first)", binary))
|
||||
}
|
||||
|
||||
libPath, err := filepath.Abs("./liblocateanythingcpp-fallback.so")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if _, err := os.Stat(libPath); err != nil {
|
||||
Skip(fmt.Sprintf("fallback library not built: %s (run `make liblocateanythingcpp-fallback.so` first)", libPath))
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
cmd := exec.Command(binary, "--addr", addr)
|
||||
cmd.Env = append(os.Environ(), "LOCATEANYTHING_LIBRARY="+libPath)
|
||||
cmd.Stdout = os.Stderr
|
||||
cmd.Stderr = os.Stderr
|
||||
Expect(cmd.Start()).To(Succeed())
|
||||
|
||||
cleanup := func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
_, _ = cmd.Process.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(10 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
c, err := net.DialTimeout("tcp", addr, 200*time.Millisecond)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
return cleanup
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
cleanup()
|
||||
Fail(fmt.Sprintf("backend did not become ready on %s within 10s", addr))
|
||||
return func() {}
|
||||
}
|
||||
|
||||
// loadTestImage reads the COCO test image downloaded by test.sh and returns its
|
||||
// base64-encoded content (the wire format accepted by the Detect RPC).
|
||||
func loadTestImage() string {
|
||||
imgPath, err := filepath.Abs("test-data/test.jpg")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
imgBytes, err := os.ReadFile(imgPath)
|
||||
if err != nil {
|
||||
Skip(fmt.Sprintf("test image not present: %s (run test.sh first)", imgPath))
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(imgBytes)
|
||||
}
|
||||
|
||||
// dialBackend opens a gRPC client connection to the spawned backend.
|
||||
func dialBackend(port int) (pb.BackendClient, func()) {
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return pb.NewBackendClient(conn), func() { _ = conn.Close() }
|
||||
}
|
||||
|
||||
// modelPathOrSkip resolves the model file under ./test-models/ and Skip()s the
|
||||
// current spec if it's missing (the ~6.3 GB GGUF is not present on a fresh
|
||||
// checkout / on CI runners without the download).
|
||||
func modelPathOrSkip(name string) string {
|
||||
modelDir, err := filepath.Abs("test-models")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
modelPath := filepath.Join(modelDir, name)
|
||||
if _, err := os.Stat(modelPath); err != nil {
|
||||
Skip(fmt.Sprintf("model not present: %s (run test.sh first)", modelPath))
|
||||
}
|
||||
return modelPath
|
||||
}
|
||||
|
||||
var _ = Describe("locate-anything-cpp backend", func() {
|
||||
It("runs open-vocabulary detection against a known-good COCO image", func() {
|
||||
modelPath := modelPathOrSkip("locate-anything-q8_0.gguf")
|
||||
imgB64 := loadTestImage()
|
||||
|
||||
port := freePort()
|
||||
cleanup := startBackend(port)
|
||||
defer cleanup()
|
||||
|
||||
client, closeConn := dialBackend(port)
|
||||
defer closeConn()
|
||||
|
||||
// The q8_0 model is ~6.3 GB and hybrid Parallel Box Decoding on CPU is
|
||||
// not cheap, so give LoadModel + Detect a generous deadline.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
loadResp, err := client.LoadModel(ctx, &pb.ModelOptions{
|
||||
Model: "locate-anything-q8_0.gguf",
|
||||
ModelFile: modelPath,
|
||||
Threads: 4,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred(), "LoadModel")
|
||||
Expect(loadResp.GetSuccess()).To(BeTrue(), "LoadModel reported failure: %s", loadResp.GetMessage())
|
||||
|
||||
// Open-vocabulary detection is prompt-driven; the prompt names the
|
||||
// classes to locate (people + cars), separated by the </c> control token.
|
||||
detResp, err := client.Detect(ctx, &pb.DetectOptions{
|
||||
Src: imgB64,
|
||||
Prompt: "Locate all the instances that matches the following description: person</c>car.",
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred(), "Detect")
|
||||
Expect(detResp.GetDetections()).ToNot(BeEmpty(), "no detections returned on a known-good COCO image")
|
||||
|
||||
_, _ = fmt.Fprintf(GinkgoWriter, "detection OK: %d detections\n", len(detResp.GetDetections()))
|
||||
for i, d := range detResp.GetDetections() {
|
||||
Expect(d.GetClassName()).ToNot(BeEmpty(), "detection %d has empty class_name", i)
|
||||
Expect(d.GetWidth()).To(BeNumerically(">", float32(0)),
|
||||
"detection %d has non-positive width", i)
|
||||
Expect(d.GetHeight()).To(BeNumerically(">", float32(0)),
|
||||
"detection %d has non-positive height", i)
|
||||
_, _ = fmt.Fprintf(GinkgoWriter, " [%d] %s box=(%.1f,%.1f,%.1fx%.1f)\n",
|
||||
i, d.GetClassName(), d.GetX(), d.GetY(), d.GetWidth(), d.GetHeight())
|
||||
}
|
||||
})
|
||||
})
|
||||
59
backend/go/locate-anything-cpp/package.sh
Executable file
59
backend/go/locate-anything-cpp/package.sh
Executable file
@@ -0,0 +1,59 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/liblocateanythingcpp-*.so $CURDIR/package/
|
||||
cp -avf $CURDIR/locate-anything-cpp $CURDIR/package/
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ $(uname -s) = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
52
backend/go/locate-anything-cpp/run.sh
Executable file
52
backend/go/locate-anything-cpp/run.sh
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
fi
|
||||
|
||||
LIBRARY="$CURDIR/liblocateanythingcpp-fallback.so"
|
||||
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/liblocateanythingcpp-avx.so ]; then
|
||||
LIBRARY="$CURDIR/liblocateanythingcpp-avx.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/liblocateanythingcpp-avx2.so ]; then
|
||||
LIBRARY="$CURDIR/liblocateanythingcpp-avx2.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/liblocateanythingcpp-avx512.so ]; then
|
||||
LIBRARY="$CURDIR/liblocateanythingcpp-avx512.so"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
export LOCATEANYTHING_LIBRARY=$LIBRARY
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/locate-anything-cpp "$@"
|
||||
fi
|
||||
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/locate-anything-cpp "$@"
|
||||
47
backend/go/locate-anything-cpp/test.sh
Executable file
47
backend/go/locate-anything-cpp/test.sh
Executable file
@@ -0,0 +1,47 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
echo "Running locate-anything-cpp backend tests..."
|
||||
|
||||
# Test model from the mudler/locate-anything.cpp-gguf HuggingFace repo. This is
|
||||
# the q8_0 quantization of nvidia/LocateAnything-3B (~6.3 GB), so the download
|
||||
# is the slow step. It is resumed with `curl -C -` and skipped entirely if the
|
||||
# file is already present.
|
||||
LOCATEANYTHING_MODEL_DIR="${LOCATEANYTHING_MODEL_DIR:-$CURDIR/test-models}"
|
||||
|
||||
LOCATEANYTHING_MODEL_FILE="${LOCATEANYTHING_MODEL_FILE:-locate-anything-q8_0.gguf}"
|
||||
LOCATEANYTHING_MODEL_URL="${LOCATEANYTHING_MODEL_URL:-https://huggingface.co/mudler/locate-anything.cpp-gguf/resolve/main/locate-anything-q8_0.gguf}"
|
||||
|
||||
mkdir -p "$LOCATEANYTHING_MODEL_DIR"
|
||||
|
||||
if [ ! -f "$LOCATEANYTHING_MODEL_DIR/$LOCATEANYTHING_MODEL_FILE" ]; then
|
||||
echo "Downloading locate-anything q8_0 model (~6.3 GB, this is slow)..."
|
||||
# -C - resumes a partial download so an interrupted run doesn't restart from 0.
|
||||
curl -L -C - -o "$LOCATEANYTHING_MODEL_DIR/$LOCATEANYTHING_MODEL_FILE" "$LOCATEANYTHING_MODEL_URL" --progress-bar
|
||||
fi
|
||||
|
||||
# Use a real COCO test image (people + cars) from the upstream rf-detr.cpp repo
|
||||
# (~46 KB). Open-vocabulary detection needs real content to locate, so a
|
||||
# synthetic image would trivially yield zero detections.
|
||||
TEST_IMAGE_DIR="$CURDIR/test-data"
|
||||
TEST_IMAGE_FILE="$TEST_IMAGE_DIR/test.jpg"
|
||||
TEST_IMAGE_URL="${TEST_IMAGE_URL:-https://raw.githubusercontent.com/mudler/rf-detr.cpp/main/tests/fixtures/ci/test_image.jpg}"
|
||||
|
||||
mkdir -p "$TEST_IMAGE_DIR"
|
||||
if [ ! -f "$TEST_IMAGE_FILE" ]; then
|
||||
echo "Downloading COCO test image..."
|
||||
curl -L -o "$TEST_IMAGE_FILE" "$TEST_IMAGE_URL" --progress-bar
|
||||
fi
|
||||
|
||||
echo "locate-anything-cpp test setup complete."
|
||||
echo " model: $LOCATEANYTHING_MODEL_DIR/$LOCATEANYTHING_MODEL_FILE"
|
||||
echo " test image: $TEST_IMAGE_FILE"
|
||||
|
||||
# Run the Go smoke test: spawns the backend binary on a free port, calls
|
||||
# LoadModel + Detect via gRPC against the downloaded GGUF + COCO image.
|
||||
echo ""
|
||||
echo "Running Go smoke test..."
|
||||
cd "$CURDIR"
|
||||
go test -v -timeout 30m ./...
|
||||
17
backend/go/omnivoice-cpp/.gitignore
vendored
Normal file
17
backend/go/omnivoice-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
# Fetched upstream sources
|
||||
sources/
|
||||
|
||||
# CMake build directories
|
||||
build*/
|
||||
|
||||
# Compiled shared libraries
|
||||
*.so
|
||||
|
||||
# Compiled backend binary
|
||||
omnivoice-cpp
|
||||
|
||||
# Packaging output
|
||||
package/
|
||||
|
||||
# Downloaded e2e models
|
||||
omnivoice-models/
|
||||
53
backend/go/omnivoice-cpp/CMakeLists.txt
Normal file
53
backend/go/omnivoice-cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,53 @@
|
||||
cmake_minimum_required(VERSION 3.14)
|
||||
project(gomnivoicecpp LANGUAGES C CXX)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
set(OMNIVOICE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/sources/omnivoice.cpp)
|
||||
|
||||
# Override upstream's CMAKE_CUDA_ARCHITECTURES before add_subdirectory.
|
||||
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
|
||||
set(CMAKE_CUDA_ARCHITECTURES "75-virtual;80-virtual;86-real;89-real")
|
||||
endif()
|
||||
|
||||
# Add the upstream project. Its own CMakeLists adds ggml + builds
|
||||
# omnivoice-core (STATIC, contains src/omnivoice.cpp i.e. the ov_* impl).
|
||||
# EXCLUDE_FROM_ALL keeps its CLI tools/tests from building unless referenced.
|
||||
add_subdirectory(${OMNIVOICE_DIR} omnivoice EXCLUDE_FROM_ALL)
|
||||
|
||||
# Upstream generates version.h into its own CMAKE_CURRENT_BINARY_DIR and adds
|
||||
# the top-level ${CMAKE_BINARY_DIR} to omnivoice-core's include path. When the
|
||||
# project is nested under add_subdirectory those two directories differ
|
||||
# (<build>/omnivoice vs <build>), so omnivoice.cpp cannot find version.h. Point
|
||||
# omnivoice-core at the subproject binary dir where version.h is actually
|
||||
# generated. (Fix lives here, never in the fetched upstream checkout.)
|
||||
target_include_directories(omnivoice-core PRIVATE ${CMAKE_BINARY_DIR}/omnivoice)
|
||||
|
||||
add_library(gomnivoicecpp MODULE cpp/gomnivoicecpp.cpp)
|
||||
target_link_libraries(gomnivoicecpp PRIVATE omnivoice-core)
|
||||
|
||||
target_include_directories(gomnivoicecpp PRIVATE ${OMNIVOICE_DIR}/src)
|
||||
target_include_directories(gomnivoicecpp SYSTEM PRIVATE ${OMNIVOICE_DIR}/ggml/include)
|
||||
|
||||
# Link GPU backends if the upstream ggml created them.
|
||||
foreach(backend blas cuda metal vulkan sycl)
|
||||
if(TARGET ggml-${backend})
|
||||
target_link_libraries(gomnivoicecpp PRIVATE ggml-${backend})
|
||||
if(backend STREQUAL "cuda")
|
||||
find_package(CUDAToolkit QUIET)
|
||||
if(CUDAToolkit_FOUND)
|
||||
target_link_libraries(gomnivoicecpp PRIVATE CUDA::cudart)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if(MSVC)
|
||||
target_compile_options(gomnivoicecpp PRIVATE /W4 /wd4100 /wd4505)
|
||||
else()
|
||||
target_compile_options(gomnivoicecpp PRIVATE -Wall -Wextra
|
||||
-Wno-unused-parameter -Wno-unused-function)
|
||||
endif()
|
||||
|
||||
set_property(TARGET gomnivoicecpp PROPERTY CXX_STANDARD 17)
|
||||
set_target_properties(gomnivoicecpp PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||
122
backend/go/omnivoice-cpp/Makefile
Normal file
122
backend/go/omnivoice-cpp/Makefile
Normal file
@@ -0,0 +1,122 @@
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# omnivoice.cpp version
|
||||
OMNIVOICE_REPO?=https://github.com/ServeurpersoCom/omnivoice.cpp
|
||||
OMNIVOICE_VERSION?=2603355a5dfacae5cfc33531d5d0933221843509
|
||||
SO_TARGET?=libgomnivoicecpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DGGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
ifneq ($(BUILD_TYPE),metal)
|
||||
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||
else
|
||||
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
-DGGML_SYCL_F16=ON
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
endif
|
||||
|
||||
sources/omnivoice.cpp:
|
||||
mkdir -p sources/omnivoice.cpp
|
||||
cd sources/omnivoice.cpp && \
|
||||
git init && \
|
||||
git remote add origin $(OMNIVOICE_REPO) && \
|
||||
git fetch origin && \
|
||||
git checkout $(OMNIVOICE_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
# Only build CPU variants on Linux
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
VARIANT_TARGETS = libgomnivoicecpp-avx.so libgomnivoicecpp-avx2.so libgomnivoicecpp-avx512.so libgomnivoicecpp-fallback.so
|
||||
else
|
||||
VARIANT_TARGETS = libgomnivoicecpp-fallback.so
|
||||
endif
|
||||
|
||||
omnivoice-cpp: main.go gomnivoicecpp.go $(VARIANT_TARGETS)
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o omnivoice-cpp ./
|
||||
|
||||
package: omnivoice-cpp
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
clean: purge
|
||||
rm -rf libgomnivoicecpp*.so package sources/omnivoice.cpp omnivoice-cpp
|
||||
|
||||
purge:
|
||||
rm -rf build*
|
||||
|
||||
.NOTPARALLEL:
|
||||
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
libgomnivoicecpp-avx.so: sources/omnivoice.cpp
|
||||
$(info ${GREEN}I omnivoice-cpp build info:avx${RESET})
|
||||
SO_TARGET=libgomnivoicecpp-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgomnivoicecpp-custom
|
||||
rm -rf build-libgomnivoicecpp-avx.so
|
||||
|
||||
libgomnivoicecpp-avx2.so: sources/omnivoice.cpp
|
||||
$(info ${GREEN}I omnivoice-cpp build info:avx2${RESET})
|
||||
SO_TARGET=libgomnivoicecpp-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgomnivoicecpp-custom
|
||||
rm -rf build-libgomnivoicecpp-avx2.so
|
||||
|
||||
libgomnivoicecpp-avx512.so: sources/omnivoice.cpp
|
||||
$(info ${GREEN}I omnivoice-cpp build info:avx512${RESET})
|
||||
SO_TARGET=libgomnivoicecpp-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgomnivoicecpp-custom
|
||||
rm -rf build-libgomnivoicecpp-avx512.so
|
||||
endif
|
||||
|
||||
libgomnivoicecpp-fallback.so: sources/omnivoice.cpp
|
||||
$(info ${GREEN}I omnivoice-cpp build info:fallback${RESET})
|
||||
SO_TARGET=libgomnivoicecpp-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgomnivoicecpp-custom
|
||||
rm -rf build-libgomnivoicecpp-fallback.so
|
||||
|
||||
libgomnivoicecpp-custom: CMakeLists.txt cpp/gomnivoicecpp.cpp cpp/gomnivoicecpp.h
|
||||
mkdir -p build-$(SO_TARGET) && \
|
||||
cd build-$(SO_TARGET) && \
|
||||
cmake .. $(CMAKE_ARGS) && \
|
||||
cmake --build . --config Release -j$(JOBS) --target gomnivoicecpp && \
|
||||
cd .. && \
|
||||
mv build-$(SO_TARGET)/libgomnivoicecpp.so ./$(SO_TARGET)
|
||||
|
||||
test: omnivoice-cpp
|
||||
@echo "Running omnivoice-cpp tests..."
|
||||
bash test.sh
|
||||
@echo "omnivoice-cpp tests completed."
|
||||
|
||||
all: omnivoice-cpp package
|
||||
129
backend/go/omnivoice-cpp/audio.go
Normal file
129
backend/go/omnivoice-cpp/audio.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
)
|
||||
|
||||
const omnivoiceSampleRate = 24000
|
||||
|
||||
// wavHeader24k returns a 44-byte WAV header for a streaming 24 kHz mono 16-bit
|
||||
// PCM stream, with placeholder (0xFFFFFFFF) sizes since the total length is
|
||||
// unknown up front. Emitted as the first chunk of TTSStream so the HTTP layer
|
||||
// receives a self-describing WAV (the gRPC TTSStream path never sets Message,
|
||||
// so the backend owns the header - see core/backend/tts.go:ModelTTSStream).
|
||||
func wavHeader24k() []byte {
|
||||
var buf bytes.Buffer
|
||||
w := func(v any) { _ = binary.Write(&buf, binary.LittleEndian, v) }
|
||||
buf.WriteString("RIFF")
|
||||
w(uint32(0xFFFFFFFF))
|
||||
buf.WriteString("WAVE")
|
||||
buf.WriteString("fmt ")
|
||||
w(uint32(16)) // Subchunk1Size
|
||||
w(uint16(1)) // PCM
|
||||
w(uint16(1)) // mono
|
||||
w(uint32(omnivoiceSampleRate)) // sample rate
|
||||
w(uint32(omnivoiceSampleRate * 2)) // byte rate = SR * blockAlign
|
||||
w(uint16(2)) // block align (16-bit mono)
|
||||
w(uint16(16)) // bits per sample
|
||||
buf.WriteString("data")
|
||||
w(uint32(0xFFFFFFFF))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// floatToPCM16LE clamps each sample to [-1,1] and encodes it as little-endian
|
||||
// signed 16-bit PCM.
|
||||
func floatToPCM16LE(samples []float32) []byte {
|
||||
out := make([]byte, len(samples)*2)
|
||||
for i, s := range samples {
|
||||
if s > 1 {
|
||||
s = 1
|
||||
} else if s < -1 {
|
||||
s = -1
|
||||
}
|
||||
v := int16(s * 32767)
|
||||
out[i*2] = byte(v)
|
||||
out[i*2+1] = byte(v >> 8)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// writeWAV24k writes samples as a finalized 24 kHz mono 16-bit WAV at dst.
|
||||
func writeWAV24k(dst string, samples []float32) error {
|
||||
f, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("omnivoice: create %q: %w", dst, err)
|
||||
}
|
||||
enc := wav.NewEncoder(f, omnivoiceSampleRate, 16, 1, 1)
|
||||
ints := make([]int, len(samples))
|
||||
for i, s := range samples {
|
||||
if s > 1 {
|
||||
s = 1
|
||||
} else if s < -1 {
|
||||
s = -1
|
||||
}
|
||||
ints[i] = int(s * 32767)
|
||||
}
|
||||
b := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: omnivoiceSampleRate},
|
||||
Data: ints,
|
||||
SourceBitDepth: 16,
|
||||
}
|
||||
if err := enc.Write(b); err != nil {
|
||||
_ = enc.Close()
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("omnivoice: encode WAV: %w", err)
|
||||
}
|
||||
if err := enc.Close(); err != nil {
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("omnivoice: finalize WAV: %w", err)
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
|
||||
// readWAVAsFloat decodes a WAV file (any sample rate/channels) to a mono
|
||||
// float32 slice in [-1,1] for use as reference audio. OmniVoice expects 24 kHz;
|
||||
// callers should supply 24 kHz reference clips.
|
||||
func readWAVAsFloat(path string) ([]float32, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("omnivoice: open ref %q: %w", path, err)
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
dec := wav.NewDecoder(f)
|
||||
buf, err := dec.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("omnivoice: decode ref %q: %w", path, err)
|
||||
}
|
||||
ch := int(buf.Format.NumChannels)
|
||||
if ch < 1 {
|
||||
ch = 1
|
||||
}
|
||||
bitDepth := int(buf.SourceBitDepth)
|
||||
if bitDepth == 0 {
|
||||
bitDepth = 16
|
||||
}
|
||||
scale := float32(int64(1) << uint(bitDepth-1))
|
||||
n := len(buf.Data) / ch
|
||||
out := make([]float32, n)
|
||||
for i := 0; i < n; i++ {
|
||||
// Downmix to mono by averaging channels.
|
||||
var acc int
|
||||
for c := 0; c < ch; c++ {
|
||||
acc += buf.Data[i*ch+c]
|
||||
}
|
||||
out[i] = float32(acc) / float32(ch) / scale
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// runtimeKeepAlive prevents the GC from reclaiming the reference-audio slice
|
||||
// while its backing pointer is in use across the C call.
|
||||
func runtimeKeepAlive(v any) { runtime.KeepAlive(v) }
|
||||
166
backend/go/omnivoice-cpp/cpp/gomnivoicecpp.cpp
Normal file
166
backend/go/omnivoice-cpp/cpp/gomnivoicecpp.cpp
Normal file
@@ -0,0 +1,166 @@
|
||||
#include "gomnivoicecpp.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "omnivoice.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
static ov_context *g_ctx = nullptr;
|
||||
|
||||
static void ggml_log_cb(enum ggml_log_level level, const char *log,
|
||||
void * /*data*/) {
|
||||
if (!log)
|
||||
return;
|
||||
const char *lvl = "?????";
|
||||
switch (level) {
|
||||
case GGML_LOG_LEVEL_DEBUG: lvl = "DEBUG"; break;
|
||||
case GGML_LOG_LEVEL_INFO: lvl = "INFO"; break;
|
||||
case GGML_LOG_LEVEL_WARN: lvl = "WARN"; break;
|
||||
case GGML_LOG_LEVEL_ERROR: lvl = "ERROR"; break;
|
||||
default: break;
|
||||
}
|
||||
fprintf(stderr, "[%-5s] %s", lvl, log);
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
int omni_load(const char *model_path, const char *codec_path, int use_fa,
|
||||
int clamp_fp16) {
|
||||
ggml_log_set(ggml_log_cb, nullptr);
|
||||
ggml_backend_load_all();
|
||||
|
||||
if (!model_path || model_path[0] == '\0') {
|
||||
fprintf(stderr, "[omnivoice-cpp] ERROR: model_path is required\n");
|
||||
return 1;
|
||||
}
|
||||
if (!codec_path || codec_path[0] == '\0') {
|
||||
fprintf(stderr, "[omnivoice-cpp] ERROR: codec_path is required\n");
|
||||
return 2;
|
||||
}
|
||||
|
||||
ov_init_params p;
|
||||
ov_init_default_params(&p);
|
||||
p.model_path = model_path;
|
||||
p.codec_path = codec_path;
|
||||
p.use_fa = use_fa != 0;
|
||||
p.clamp_fp16 = clamp_fp16 != 0;
|
||||
|
||||
fprintf(stderr, "[omnivoice-cpp] Loading model=%s codec=%s\n", model_path,
|
||||
codec_path);
|
||||
|
||||
g_ctx = ov_init(&p);
|
||||
if (!g_ctx) {
|
||||
fprintf(stderr, "[omnivoice-cpp] FATAL: ov_init failed: %s\n",
|
||||
ov_last_error());
|
||||
return 3;
|
||||
}
|
||||
fprintf(stderr, "[omnivoice-cpp] Model loaded (%s)\n", ov_version());
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Fill an ov_tts_params from the flat wrapper arguments.
|
||||
static void fill_params(ov_tts_params *tp, const char *text, const char *lang,
|
||||
const char *instruct, const float *ref_samples,
|
||||
int ref_n, const char *ref_text, long long seed,
|
||||
int denoise) {
|
||||
ov_tts_default_params(tp);
|
||||
tp->text = text ? text : "";
|
||||
tp->lang = lang ? lang : "";
|
||||
if (instruct && instruct[0] != '\0')
|
||||
tp->instruct = instruct;
|
||||
if (ref_samples && ref_n > 0) {
|
||||
tp->ref_audio_24k = ref_samples;
|
||||
tp->ref_n_samples = ref_n;
|
||||
if (ref_text && ref_text[0] != '\0')
|
||||
tp->ref_text = ref_text;
|
||||
tp->denoise = denoise != 0;
|
||||
}
|
||||
if (seed >= 0)
|
||||
tp->mg_seed = (uint64_t)seed;
|
||||
}
|
||||
|
||||
float *omni_tts(const char *text, const char *lang, const char *instruct,
|
||||
const float *ref_samples, int ref_n, const char *ref_text,
|
||||
long long seed, int denoise, int *out_n) {
|
||||
if (out_n)
|
||||
*out_n = 0;
|
||||
if (!g_ctx) {
|
||||
fprintf(stderr, "[omnivoice-cpp] ERROR: model not loaded\n");
|
||||
return nullptr;
|
||||
}
|
||||
if (!text || text[0] == '\0') {
|
||||
fprintf(stderr, "[omnivoice-cpp] ERROR: text is required\n");
|
||||
return nullptr; // omni_tts: out_n already 0
|
||||
}
|
||||
ov_tts_params tp;
|
||||
fill_params(&tp, text, lang, instruct, ref_samples, ref_n, ref_text, seed,
|
||||
denoise);
|
||||
|
||||
ov_audio out = {0};
|
||||
enum ov_status rc = ov_synthesize(g_ctx, &tp, &out);
|
||||
if (rc != OV_STATUS_OK || out.n_samples <= 0 || !out.samples) {
|
||||
fprintf(stderr, "[omnivoice-cpp] ERROR: synthesize failed (rc=%d): %s\n",
|
||||
(int)rc, ov_last_error());
|
||||
ov_audio_free(&out);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Copy into a plain malloc buffer the Go side can free symmetrically via
|
||||
// omni_pcm_free; then release the ov_audio-owned buffer.
|
||||
size_t bytes = (size_t)out.n_samples * sizeof(float);
|
||||
float *buf = (float *)malloc(bytes);
|
||||
if (!buf) {
|
||||
fprintf(stderr, "[omnivoice-cpp] ERROR: malloc(%zu) failed\n", bytes);
|
||||
ov_audio_free(&out);
|
||||
return nullptr;
|
||||
}
|
||||
memcpy(buf, out.samples, bytes);
|
||||
if (out_n)
|
||||
*out_n = out.n_samples;
|
||||
ov_audio_free(&out);
|
||||
return buf;
|
||||
}
|
||||
|
||||
int omni_tts_stream(const char *text, const char *lang, const char *instruct,
|
||||
const float *ref_samples, int ref_n, const char *ref_text,
|
||||
long long seed, int denoise, omni_pcm_chunk_cb cb,
|
||||
void *user_data) {
|
||||
if (!g_ctx) {
|
||||
fprintf(stderr, "[omnivoice-cpp] ERROR: model not loaded\n");
|
||||
return 1;
|
||||
}
|
||||
if (!cb) {
|
||||
fprintf(stderr, "[omnivoice-cpp] ERROR: stream callback is null\n");
|
||||
return 2;
|
||||
}
|
||||
if (!text || text[0] == '\0') {
|
||||
fprintf(stderr, "[omnivoice-cpp] ERROR: text is required\n");
|
||||
return 4;
|
||||
}
|
||||
ov_tts_params tp;
|
||||
fill_params(&tp, text, lang, instruct, ref_samples, ref_n, ref_text, seed,
|
||||
denoise);
|
||||
// ov_audio_chunk_cb has the identical signature to omni_pcm_chunk_cb
|
||||
// (bool vs int return are ABI-compatible; non-zero == true).
|
||||
tp.on_chunk = (ov_audio_chunk_cb)cb;
|
||||
tp.on_chunk_user_data = user_data;
|
||||
|
||||
ov_audio out = {0}; // stays empty in streaming mode
|
||||
enum ov_status rc = ov_synthesize(g_ctx, &tp, &out);
|
||||
ov_audio_free(&out);
|
||||
if (rc != OV_STATUS_OK && rc != OV_STATUS_CANCELLED) {
|
||||
fprintf(stderr, "[omnivoice-cpp] ERROR: stream synth failed (rc=%d): %s\n",
|
||||
(int)rc, ov_last_error());
|
||||
return 3;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
void omni_pcm_free(float *p) { free(p); }
|
||||
|
||||
void omni_unload(void) {
|
||||
if (g_ctx) {
|
||||
ov_free(g_ctx);
|
||||
g_ctx = nullptr;
|
||||
}
|
||||
}
|
||||
38
backend/go/omnivoice-cpp/cpp/gomnivoicecpp.h
Normal file
38
backend/go/omnivoice-cpp/cpp/gomnivoicecpp.h
Normal file
@@ -0,0 +1,38 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Streaming PCM chunk callback. samples is mono float PCM at 24 kHz, valid
|
||||
// only for the duration of the call. Return non-zero to continue, 0 to abort.
|
||||
typedef int (*omni_pcm_chunk_cb)(const float *samples, int n_samples,
|
||||
void *user_data);
|
||||
|
||||
// Load the LM (model_path) + codec (codec_path) GGUFs. use_fa / clamp_fp16
|
||||
// map to ov_init_params. Returns 0 on success, non-zero on failure.
|
||||
int omni_load(const char *model_path, const char *codec_path, int use_fa,
|
||||
int clamp_fp16);
|
||||
|
||||
// Synthesize to a malloc'd float PCM buffer (caller frees via omni_pcm_free).
|
||||
// ref_samples != null && ref_n > 0 => voice cloning (ref_text optional).
|
||||
// instruct != null && non-empty => voice design. seed < 0 keeps the default
|
||||
// MaskGIT seed. denoise toggles the <|denoise|> marker (only with a reference).
|
||||
// Writes the sample count to *out_n. Returns NULL on failure (out_n set to 0).
|
||||
float *omni_tts(const char *text, const char *lang, const char *instruct,
|
||||
const float *ref_samples, int ref_n, const char *ref_text,
|
||||
long long seed, int denoise, int *out_n);
|
||||
|
||||
// Streaming synthesis: cb is invoked per PCM chunk as audio is produced.
|
||||
// Same reference/design/seed semantics as omni_tts. Returns 0 on success.
|
||||
int omni_tts_stream(const char *text, const char *lang, const char *instruct,
|
||||
const float *ref_samples, int ref_n, const char *ref_text,
|
||||
long long seed, int denoise, omni_pcm_chunk_cb cb,
|
||||
void *user_data);
|
||||
|
||||
// Free a buffer returned by omni_tts.
|
||||
void omni_pcm_free(float *p);
|
||||
|
||||
// Release the OmniVoice context.
|
||||
void omni_unload(void);
|
||||
}
|
||||
74
backend/go/omnivoice-cpp/e2e_test.go
Normal file
74
backend/go/omnivoice-cpp/e2e_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func ttsReq(text, voice string, lang *string, dst string) *pb.TTSRequest {
|
||||
return &pb.TTSRequest{Text: text, Voice: voice, Language: lang, Dst: dst}
|
||||
}
|
||||
|
||||
var _ = Describe("OmniVoice e2e", Label("e2e"), func() {
|
||||
var loaded bool
|
||||
|
||||
BeforeEach(func() {
|
||||
modelPath := os.Getenv("OMNIVOICE_MODEL")
|
||||
codecPath := os.Getenv("OMNIVOICE_CODEC")
|
||||
if modelPath == "" || codecPath == "" {
|
||||
Skip("OMNIVOICE_MODEL / OMNIVOICE_CODEC not set; skipping e2e")
|
||||
}
|
||||
if !loaded {
|
||||
lib := os.Getenv("OMNIVOICE_LIBRARY")
|
||||
if lib == "" {
|
||||
lib = "./libgomnivoicecpp-fallback.so"
|
||||
}
|
||||
h, err := purego.Dlopen(lib, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
purego.RegisterLibFunc(&CppLoad, h, "omni_load")
|
||||
purego.RegisterLibFunc(&CppTTS, h, "omni_tts")
|
||||
purego.RegisterLibFunc(&CppTTSStream, h, "omni_tts_stream")
|
||||
purego.RegisterLibFunc(&CppPCMFree, h, "omni_pcm_free")
|
||||
purego.RegisterLibFunc(&CppUnload, h, "omni_unload")
|
||||
Expect(CppLoad(modelPath, codecPath, 0, 0)).To(Equal(0))
|
||||
loaded = true
|
||||
}
|
||||
})
|
||||
|
||||
It("synthesizes a WAV file via TTS", func() {
|
||||
b := &OmnivoiceCpp{opts: loadOptions{seed: 42, denoise: true}}
|
||||
dst := GinkgoT().TempDir() + "/out.wav"
|
||||
lang := "en"
|
||||
err := b.TTS(ttsReq("Hello world.", "", &lang, dst))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
fi, err := os.Stat(dst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(fi.Size()).To(BeNumerically(">", int64(44)))
|
||||
})
|
||||
|
||||
It("streams audio chunks via TTSStream", func() {
|
||||
b := &OmnivoiceCpp{opts: loadOptions{seed: 42, denoise: true}}
|
||||
results := make(chan []byte, 1024)
|
||||
lang := "en"
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- b.TTSStream(ttsReq("Hello there, streaming test.", "", &lang, ""), results) }()
|
||||
|
||||
var chunks int
|
||||
var first []byte
|
||||
for c := range results {
|
||||
if chunks == 0 {
|
||||
first = c
|
||||
}
|
||||
chunks++
|
||||
}
|
||||
Expect(<-done).ToNot(HaveOccurred())
|
||||
Expect(chunks).To(BeNumerically(">=", 2))
|
||||
Expect(string(first[0:4])).To(Equal("RIFF"))
|
||||
Expect(strings.HasPrefix(string(first[8:12]), "WAVE")).To(BeTrue())
|
||||
})
|
||||
})
|
||||
246
backend/go/omnivoice-cpp/gomnivoicecpp.go
Normal file
246
backend/go/omnivoice-cpp/gomnivoicecpp.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
// omni_load(model_path, codec_path, use_fa, clamp_fp16) int
|
||||
CppLoad func(modelPath, codecPath string, useFA, clampFP16 int) int
|
||||
// omni_tts(text, lang, instruct, ref_samples, ref_n, ref_text, seed, denoise, out_n) -> float* (uintptr)
|
||||
CppTTS func(text, lang, instruct string, refSamples unsafe.Pointer, refN int,
|
||||
refText string, seed int64, denoise int, outN unsafe.Pointer) uintptr
|
||||
// omni_tts_stream(text, lang, instruct, ref_samples, ref_n, ref_text, seed, denoise, cb, user) int
|
||||
CppTTSStream func(text, lang, instruct string, refSamples unsafe.Pointer, refN int,
|
||||
refText string, seed int64, denoise int, cb uintptr, user uintptr) int
|
||||
CppPCMFree func(ptr uintptr)
|
||||
CppUnload func()
|
||||
)
|
||||
|
||||
type OmnivoiceCpp struct {
|
||||
base.SingleThread
|
||||
opts loadOptions
|
||||
// audioPath is the model-config reference voice (tts.audio_path), used as
|
||||
// the default voice-cloning reference when a request does not set Voice.
|
||||
audioPath string
|
||||
}
|
||||
|
||||
func (o *OmnivoiceCpp) Load(opts *pb.ModelOptions) error {
|
||||
model := opts.ModelFile
|
||||
if model == "" {
|
||||
model = opts.ModelPath
|
||||
}
|
||||
if !filepath.IsAbs(model) && opts.ModelPath != "" {
|
||||
model = filepath.Join(opts.ModelPath, model)
|
||||
}
|
||||
|
||||
o.opts = parseOptions(opts.Options)
|
||||
|
||||
// Resolve the codec/tokenizer GGUF: explicit option, else auto-discover a
|
||||
// *tokenizer*.gguf sibling of the base model.
|
||||
codec := o.opts.codecPath
|
||||
if codec != "" && !filepath.IsAbs(codec) {
|
||||
codec = filepath.Join(filepath.Dir(model), codec)
|
||||
}
|
||||
if codec == "" {
|
||||
codec = discoverTokenizer(filepath.Dir(model))
|
||||
}
|
||||
if codec == "" {
|
||||
return fmt.Errorf("omnivoice: no codec/tokenizer GGUF found; set option 'tokenizer:<file>'")
|
||||
}
|
||||
o.opts.codecPath = codec
|
||||
|
||||
// tts.audio_path (ModelOptions.AudioPath) is the config-level voice-cloning
|
||||
// reference: a default reference WAV used when a request omits Voice.
|
||||
// Resolved relative to the model directory like the codec.
|
||||
o.audioPath = opts.AudioPath
|
||||
if o.audioPath != "" && !filepath.IsAbs(o.audioPath) {
|
||||
o.audioPath = filepath.Join(filepath.Dir(model), o.audioPath)
|
||||
}
|
||||
|
||||
useFA := boolToInt(o.opts.useFA)
|
||||
clamp := boolToInt(o.opts.clampFP16)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[omnivoice-cpp] Load model=%s codec=%s use_fa=%d clamp_fp16=%d\n",
|
||||
model, codec, useFA, clamp)
|
||||
|
||||
if rc := CppLoad(model, codec, useFA, clamp); rc != 0 {
|
||||
return fmt.Errorf("omnivoice: failed to load model (rc=%d)", rc)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// discoverTokenizer returns the first *tokenizer*.gguf in dir, or "".
|
||||
func discoverTokenizer(dir string) string {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, e := range entries {
|
||||
name := strings.ToLower(e.Name())
|
||||
if strings.Contains(name, "tokenizer") && strings.HasSuffix(name, ".gguf") {
|
||||
return filepath.Join(dir, e.Name())
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func boolToInt(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// refAudio loads the reference WAV (voice cloning) if voice points to a file.
|
||||
// Returns nil if no cloning (empty or non-path - voice design uses Instructions).
|
||||
func (o *OmnivoiceCpp) refAudio(voice string) ([]float32, error) {
|
||||
v := strings.TrimSpace(voice)
|
||||
if v == "" {
|
||||
return nil, nil
|
||||
}
|
||||
if _, err := os.Stat(v); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return readWAVAsFloat(v)
|
||||
}
|
||||
|
||||
// refAudioFor resolves the cloning reference for a request: the per-request
|
||||
// Voice takes precedence, falling back to the model-config audio_path. Empty
|
||||
// result means no cloning (voice design via Instructions still applies).
|
||||
func (o *OmnivoiceCpp) refAudioFor(req *pb.TTSRequest) ([]float32, error) {
|
||||
voice := strings.TrimSpace(req.Voice)
|
||||
if voice == "" {
|
||||
voice = o.audioPath
|
||||
}
|
||||
return o.refAudio(voice)
|
||||
}
|
||||
|
||||
func reqParam(req *pb.TTSRequest, key string) string {
|
||||
if req.Params == nil {
|
||||
return ""
|
||||
}
|
||||
return req.Params[key]
|
||||
}
|
||||
|
||||
func (o *OmnivoiceCpp) seedFor(req *pb.TTSRequest) int64 {
|
||||
if s := reqParam(req, "seed"); s != "" {
|
||||
var n int64
|
||||
if _, err := fmt.Sscan(s, &n); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return o.opts.seed
|
||||
}
|
||||
|
||||
func optStr(p *string) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
return *p
|
||||
}
|
||||
|
||||
func (o *OmnivoiceCpp) TTS(req *pb.TTSRequest) error {
|
||||
if req.Dst == "" {
|
||||
return fmt.Errorf("omnivoice: TTS requires a destination path")
|
||||
}
|
||||
lang := normalizeLanguage(optStr(req.Language))
|
||||
instruct := optStr(req.Instructions)
|
||||
refText := reqParam(req, "ref_text")
|
||||
seed := o.seedFor(req)
|
||||
|
||||
ref, err := o.refAudioFor(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var refPtr unsafe.Pointer
|
||||
if len(ref) > 0 {
|
||||
refPtr = unsafe.Pointer(&ref[0])
|
||||
}
|
||||
|
||||
var n int32
|
||||
ptr := CppTTS(req.Text, lang, instruct, refPtr, len(ref), refText, seed,
|
||||
boolToInt(o.opts.denoise), unsafe.Pointer(&n))
|
||||
runtimeKeepAlive(ref)
|
||||
if ptr == 0 || n <= 0 {
|
||||
return fmt.Errorf("omnivoice: synthesis failed")
|
||||
}
|
||||
defer CppPCMFree(ptr)
|
||||
src := unsafe.Slice((*float32)(unsafe.Pointer(ptr)), int(n)) //nolint:govet // C-allocated PCM, copied out before free
|
||||
out := make([]float32, int(n))
|
||||
copy(out, src)
|
||||
return writeWAV24k(req.Dst, out)
|
||||
}
|
||||
|
||||
// streamState carries the active TTSStream channel to the single shared C
|
||||
// callback. base.SingleThread serializes TTS/TTSStream, so one global slot is
|
||||
// safe and avoids leaking a purego callback per request (purego callbacks
|
||||
// cannot be freed and are capped).
|
||||
var (
|
||||
streamMu sync.Mutex
|
||||
streamChan chan []byte
|
||||
streamCbOnce sync.Once
|
||||
streamCbPtr uintptr
|
||||
)
|
||||
|
||||
// streamCallback is registered once and forwards each PCM chunk to streamChan.
|
||||
func streamCallback(samples *float32, nSamples int32, _ uintptr) uintptr {
|
||||
if nSamples <= 0 || samples == nil || streamChan == nil {
|
||||
return 1 // continue
|
||||
}
|
||||
src := unsafe.Slice(samples, int(nSamples))
|
||||
cp := make([]float32, int(nSamples)) // copy out of C memory before returning
|
||||
copy(cp, src)
|
||||
streamChan <- floatToPCM16LE(cp)
|
||||
return 1 // continue
|
||||
}
|
||||
|
||||
func (o *OmnivoiceCpp) TTSStream(req *pb.TTSRequest, results chan []byte) error {
|
||||
defer close(results)
|
||||
if req.Text == "" {
|
||||
return fmt.Errorf("omnivoice: TTSStream requires text")
|
||||
}
|
||||
|
||||
streamCbOnce.Do(func() {
|
||||
streamCbPtr = purego.NewCallback(streamCallback)
|
||||
})
|
||||
|
||||
lang := normalizeLanguage(optStr(req.Language))
|
||||
instruct := optStr(req.Instructions)
|
||||
refText := reqParam(req, "ref_text")
|
||||
seed := o.seedFor(req)
|
||||
|
||||
ref, err := o.refAudioFor(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var refPtr unsafe.Pointer
|
||||
if len(ref) > 0 {
|
||||
refPtr = unsafe.Pointer(&ref[0])
|
||||
}
|
||||
|
||||
// Emit the WAV header first so the HTTP layer gets a self-describing stream.
|
||||
results <- wavHeader24k()
|
||||
|
||||
streamMu.Lock()
|
||||
streamChan = results
|
||||
rc := CppTTSStream(req.Text, lang, instruct, refPtr, len(ref), refText, seed,
|
||||
boolToInt(o.opts.denoise), streamCbPtr, 0)
|
||||
streamChan = nil
|
||||
streamMu.Unlock()
|
||||
runtimeKeepAlive(ref)
|
||||
|
||||
if rc != 0 {
|
||||
return fmt.Errorf("omnivoice: streaming synthesis failed (rc=%d)", rc)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
90
backend/go/omnivoice-cpp/gomnivoicecpp_test.go
Normal file
90
backend/go/omnivoice-cpp/gomnivoicecpp_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestOmnivoiceCpp(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "omnivoice-cpp suite")
|
||||
}
|
||||
|
||||
var _ = Describe("normalizeLanguage", func() {
|
||||
DescribeTable("maps caller language to OmniVoice codes",
|
||||
func(in, want string) {
|
||||
Expect(normalizeLanguage(in)).To(Equal(want))
|
||||
},
|
||||
Entry("empty stays empty", "", ""),
|
||||
Entry("english full name", "English", "en"),
|
||||
Entry("chinese full name", "Chinese", "zh"),
|
||||
Entry("locale suffix stripped", "en-US", "en"),
|
||||
Entry("underscore locale", "zh_CN", "zh"),
|
||||
Entry("already a code", "en", "en"),
|
||||
Entry("unknown passes through normalized", "xx", "xx"),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("parseOptions", func() {
|
||||
It("extracts codec, use_fa, clamp_fp16, seed, denoise", func() {
|
||||
o := parseOptions([]string{
|
||||
"tokenizer:tok.gguf",
|
||||
"use_fa:true",
|
||||
"clamp_fp16:true",
|
||||
"seed:7",
|
||||
"denoise:false",
|
||||
"unknown:ignored",
|
||||
})
|
||||
Expect(o.codecPath).To(Equal("tok.gguf"))
|
||||
Expect(o.useFA).To(BeTrue())
|
||||
Expect(o.clampFP16).To(BeTrue())
|
||||
Expect(o.seed).To(Equal(int64(7)))
|
||||
Expect(o.denoise).To(BeFalse())
|
||||
})
|
||||
|
||||
It("accepts codec: as an alias for tokenizer:", func() {
|
||||
o := parseOptions([]string{"codec:c.gguf"})
|
||||
Expect(o.codecPath).To(Equal("c.gguf"))
|
||||
})
|
||||
|
||||
It("defaults seed to -1 and denoise to true", func() {
|
||||
o := parseOptions(nil)
|
||||
Expect(o.seed).To(Equal(int64(-1)))
|
||||
Expect(o.denoise).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("wavHeader24k", func() {
|
||||
It("emits a 44-byte streaming WAV header at 24 kHz mono 16-bit", func() {
|
||||
h := wavHeader24k()
|
||||
Expect(h).To(HaveLen(44))
|
||||
Expect(string(h[0:4])).To(Equal("RIFF"))
|
||||
Expect(string(h[8:12])).To(Equal("WAVE"))
|
||||
Expect(string(h[12:16])).To(Equal("fmt "))
|
||||
Expect(string(h[36:40])).To(Equal("data"))
|
||||
var sampleRate uint32
|
||||
Expect(binary.Read(bytes.NewReader(h[24:28]), binary.LittleEndian, &sampleRate)).To(Succeed())
|
||||
Expect(sampleRate).To(Equal(uint32(24000)))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("floatToPCM16LE", func() {
|
||||
It("clamps and converts float PCM to little-endian int16 bytes", func() {
|
||||
b := floatToPCM16LE([]float32{0, 1.0, -1.0, 2.0, -2.0})
|
||||
Expect(b).To(HaveLen(10)) // 5 samples * 2 bytes
|
||||
read := func(off int) int16 {
|
||||
var v int16
|
||||
_ = binary.Read(bytes.NewReader(b[off:off+2]), binary.LittleEndian, &v)
|
||||
return v
|
||||
}
|
||||
Expect(read(0)).To(Equal(int16(0)))
|
||||
Expect(read(2)).To(Equal(int16(32767)))
|
||||
Expect(read(4)).To(Equal(int16(-32767)))
|
||||
Expect(read(6)).To(Equal(int16(32767))) // clamped from 2.0
|
||||
Expect(read(8)).To(Equal(int16(-32767))) // clamped from -2.0
|
||||
})
|
||||
})
|
||||
48
backend/go/omnivoice-cpp/main.go
Normal file
48
backend/go/omnivoice-cpp/main.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("OMNIVOICE_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgomnivoicecpp-fallback.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppLoad, "omni_load"},
|
||||
{&CppTTS, "omni_tts"},
|
||||
{&CppTTSStream, "omni_tts_stream"},
|
||||
{&CppPCMFree, "omni_pcm_free"},
|
||||
{&CppUnload, "omni_unload"},
|
||||
}
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &OmnivoiceCpp{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
74
backend/go/omnivoice-cpp/options.go
Normal file
74
backend/go/omnivoice-cpp/options.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// loadOptions holds the parsed model-level options for OmniVoice.
|
||||
type loadOptions struct {
|
||||
codecPath string
|
||||
useFA bool
|
||||
clampFP16 bool
|
||||
seed int64
|
||||
denoise bool
|
||||
}
|
||||
|
||||
func splitOption(o string) (key, value string, ok bool) {
|
||||
i := strings.Index(o, ":")
|
||||
if i < 0 {
|
||||
return "", "", false
|
||||
}
|
||||
return strings.TrimSpace(o[:i]), strings.TrimSpace(o[i+1:]), true
|
||||
}
|
||||
|
||||
// parseOptions reads the backend "key:value" option slice. Unknown keys are
|
||||
// ignored. Defaults: seed -1 (engine default), denoise true.
|
||||
func parseOptions(opts []string) loadOptions {
|
||||
o := loadOptions{seed: -1, denoise: true}
|
||||
for _, oo := range opts {
|
||||
key, value, ok := splitOption(oo)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch key {
|
||||
case "tokenizer", "codec":
|
||||
o.codecPath = value
|
||||
case "use_fa":
|
||||
o.useFA = value == "true" || value == "1"
|
||||
case "clamp_fp16":
|
||||
o.clampFP16 = value == "true" || value == "1"
|
||||
case "seed":
|
||||
if n, err := strconv.ParseInt(value, 10, 64); err == nil {
|
||||
o.seed = n
|
||||
}
|
||||
case "denoise":
|
||||
o.denoise = value == "true" || value == "1"
|
||||
}
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
// languageNameAliases maps full language names to OmniVoice codes. OmniVoice's
|
||||
// lang hint accepts "" (auto), "en", "zh" per the upstream convention; other
|
||||
// codes pass through and the engine treats unknown hints as auto.
|
||||
var languageNameAliases = map[string]string{
|
||||
"english": "en",
|
||||
"chinese": "zh",
|
||||
}
|
||||
|
||||
// normalizeLanguage lowercases, trims, strips a region/locale suffix, and
|
||||
// resolves common full names. Empty stays empty so the engine auto-detects.
|
||||
func normalizeLanguage(lang string) string {
|
||||
lang = strings.ToLower(strings.TrimSpace(lang))
|
||||
if lang == "" {
|
||||
return ""
|
||||
}
|
||||
if i := strings.IndexAny(lang, "-_."); i >= 0 {
|
||||
lang = lang[:i]
|
||||
}
|
||||
if code, ok := languageNameAliases[lang]; ok {
|
||||
return code
|
||||
}
|
||||
return lang
|
||||
}
|
||||
64
backend/go/omnivoice-cpp/package.sh
Executable file
64
backend/go/omnivoice-cpp/package.sh
Executable file
@@ -0,0 +1,64 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
# This script is used in the final stage of the Dockerfile
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/omnivoice-cpp $CURDIR/package/
|
||||
cp -fv $CURDIR/libgomnivoicecpp-*.so $CURDIR/package/
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ $(uname -s) = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
52
backend/go/omnivoice-cpp/run.sh
Executable file
52
backend/go/omnivoice-cpp/run.sh
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
fi
|
||||
|
||||
LIBRARY="$CURDIR/libgomnivoicecpp-fallback.so"
|
||||
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/libgomnivoicecpp-avx.so ]; then
|
||||
LIBRARY="$CURDIR/libgomnivoicecpp-avx.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/libgomnivoicecpp-avx2.so ]; then
|
||||
LIBRARY="$CURDIR/libgomnivoicecpp-avx2.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/libgomnivoicecpp-avx512.so ]; then
|
||||
LIBRARY="$CURDIR/libgomnivoicecpp-avx512.so"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
export OMNIVOICE_LIBRARY=$LIBRARY
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/omnivoice-cpp "$@"
|
||||
fi
|
||||
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/omnivoice-cpp "$@"
|
||||
30
backend/go/omnivoice-cpp/test.sh
Executable file
30
backend/go/omnivoice-cpp/test.sh
Executable file
@@ -0,0 +1,30 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
cd "$CURDIR"
|
||||
|
||||
echo "Running omnivoice-cpp backend tests..."
|
||||
|
||||
if [ -z "$OMNIVOICE_MODEL" ]; then
|
||||
MODEL_DIR="./omnivoice-models"
|
||||
mkdir -p "$MODEL_DIR"
|
||||
REPO_ID="Serveurperso/OmniVoice-GGUF"
|
||||
BASE_URL="https://huggingface.co/${REPO_ID}/resolve/main"
|
||||
FILES=( "omnivoice-base-Q4_K_M.gguf" "omnivoice-tokenizer-Q4_K_M.gguf" )
|
||||
for file in "${FILES[@]}"; do
|
||||
dest="${MODEL_DIR}/${file}"
|
||||
if [ -f "${dest}" ]; then
|
||||
echo " [skip] ${file}"
|
||||
else
|
||||
echo " [download] ${file}..."
|
||||
curl -L -o "${dest}" "${BASE_URL}/${file}" --progress-bar
|
||||
fi
|
||||
done
|
||||
export OMNIVOICE_MODEL="${MODEL_DIR}/omnivoice-base-Q4_K_M.gguf"
|
||||
export OMNIVOICE_CODEC="${MODEL_DIR}/omnivoice-tokenizer-Q4_K_M.gguf"
|
||||
fi
|
||||
|
||||
go test -v -timeout 1200s .
|
||||
|
||||
echo "All omnivoice-cpp e2e tests passed."
|
||||
11
backend/go/parakeet-cpp/.gitignore
vendored
Normal file
11
backend/go/parakeet-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
.cache/
|
||||
sources/
|
||||
build/
|
||||
package/
|
||||
parakeet-cpp-grpc
|
||||
# build artifacts staged in-tree by the Makefile (cp from sources/) or
|
||||
# symlinked for local dev; the real sources live in parakeet.cpp upstream.
|
||||
*.so
|
||||
*.so.*
|
||||
parakeet_capi.h
|
||||
compile_commands.json
|
||||
96
backend/go/parakeet-cpp/Makefile
Normal file
96
backend/go/parakeet-cpp/Makefile
Normal file
@@ -0,0 +1,96 @@
|
||||
# parakeet-cpp backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=b8012f11e5269126eddb7f4fd02f891a2ccc29b0
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# whisper.cpp / ds4 / vibevoice-cpp convention.
|
||||
#
|
||||
# Local dev shortcut: if you already have an out-of-tree parakeet.cpp
|
||||
# build, you can symlink the .so + header into this directory and skip
|
||||
# the clone/cmake steps entirely, e.g.:
|
||||
#
|
||||
# ln -sf /path/to/parakeet.cpp/build-shared/libparakeet.so .
|
||||
# ln -sf /path/to/parakeet.cpp/include/parakeet_capi.h .
|
||||
# go build -o parakeet-cpp-grpc .
|
||||
#
|
||||
# That's what the L0 smoke test uses. The default target below does the
|
||||
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
|
||||
|
||||
PARAKEET_VERSION?=b8012f11e5269126eddb7f4fd02f891a2ccc29b0
|
||||
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
||||
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
# Build ggml statically into libparakeet.so (PIC) so the shared lib is
|
||||
# self-contained: dlopen needs no libggml*.so alongside it, only system libs
|
||||
# (libstdc++/libgomp/libc) that the runtime image already provides.
|
||||
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DPARAKEET_SHARED=ON -DPARAKEET_BUILD_CLI=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# parakeet.cpp gates its GGML backends behind PARAKEET_GGML_* options and does
|
||||
# set(GGML_CUDA ${PARAKEET_GGML_CUDA} CACHE BOOL "" FORCE), so a bare -DGGML_CUDA=ON
|
||||
# is overwritten back to OFF and the build silently falls back to CPU. Forward the
|
||||
# PARAKEET_GGML_* options instead. (openblas is not gated, so -DGGML_BLAS passes through.)
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
# GGML_CUDA_GRAPHS is OFF by ggml default; enabling it gives a small free
|
||||
# speedup (~1% measured on GB10, never negative) by capturing/replaying the
|
||||
# CUDA graph. Not gated by parakeet.cpp, so it passes straight through to ggml.
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_CUDA=ON -DGGML_CUDA_GRAPHS=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_HIP=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_VULKAN=ON
|
||||
endif
|
||||
|
||||
.PHONY: parakeet-cpp-grpc package build clean purge test all
|
||||
|
||||
all: parakeet-cpp-grpc
|
||||
|
||||
# Clone the upstream parakeet.cpp source at the pinned commit. Directory
|
||||
# acts as the target so make only re-clones when missing. After a
|
||||
# PARAKEET_VERSION bump, run 'make purge && make' to refetch.
|
||||
sources/parakeet.cpp:
|
||||
mkdir -p sources/parakeet.cpp
|
||||
cd sources/parakeet.cpp && \
|
||||
git init -q && \
|
||||
git remote add origin $(PARAKEET_REPO) && \
|
||||
git fetch --depth 1 origin $(PARAKEET_VERSION) && \
|
||||
git checkout FETCH_HEAD && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Build the shared lib + header out-of-tree, then stage them next to the
|
||||
# Go sources so purego.Dlopen("libparakeet.so") and the cgo-less build
|
||||
# both pick them up.
|
||||
libparakeet.so: sources/parakeet.cpp
|
||||
cmake -B sources/parakeet.cpp/build-shared -S sources/parakeet.cpp $(CMAKE_ARGS)
|
||||
cmake --build sources/parakeet.cpp/build-shared --config Release -j$(JOBS)
|
||||
cp -fv sources/parakeet.cpp/build-shared/libparakeet.so* ./ 2>/dev/null || true
|
||||
cp -fv sources/parakeet.cpp/include/parakeet_capi.h ./
|
||||
|
||||
parakeet-cpp-grpc: libparakeet.so main.go goparakeetcpp.go
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o parakeet-cpp-grpc .
|
||||
|
||||
package: parakeet-cpp-grpc
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
# Test target. Smoke test is gated on PARAKEET_BACKEND_TEST_MODEL +
|
||||
# PARAKEET_BACKEND_TEST_WAV; without them the spec auto-skips.
|
||||
test:
|
||||
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
|
||||
|
||||
clean: purge
|
||||
rm -rf libparakeet.so* parakeet_capi.h package parakeet-cpp-grpc
|
||||
|
||||
purge:
|
||||
rm -rf sources/parakeet.cpp
|
||||
105
backend/go/parakeet-cpp/batcher.go
Normal file
105
backend/go/parakeet-cpp/batcher.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package main
|
||||
|
||||
import "time"
|
||||
|
||||
// batchRequest is one in-flight unary transcription waiting to be batched.
|
||||
// In production pcm/decoder are set; tag is an opaque marker used by tests.
|
||||
type batchRequest struct {
|
||||
pcm []float32
|
||||
decoder int32
|
||||
// language is the per-request target locale ("" means the model default).
|
||||
// parakeet.cpp's batched C-API takes ONE target_lang for the whole batch,
|
||||
// so the dispatcher only coalesces requests that share a language.
|
||||
language string
|
||||
tag string
|
||||
reply chan batchReply
|
||||
}
|
||||
|
||||
// batchReply carries one per-item JSON object string (an element of the C-API's
|
||||
// JSON array) or an error back to the waiting handler goroutine.
|
||||
type batchReply struct {
|
||||
json string
|
||||
err error
|
||||
}
|
||||
|
||||
// batcher coalesces concurrent batchRequests into batched runBatch calls. A
|
||||
// single run() goroutine is the sole caller of runBatch, so runBatch (which in
|
||||
// production calls the thread-unsafe C engine) is never entered concurrently.
|
||||
type batcher struct {
|
||||
submit chan *batchRequest
|
||||
maxSize int
|
||||
maxWait time.Duration
|
||||
runBatch func(reqs []*batchRequest) // must deliver a reply to every req
|
||||
}
|
||||
|
||||
func newBatcher(maxSize int, maxWait time.Duration, runBatch func([]*batchRequest)) *batcher {
|
||||
if maxSize < 1 {
|
||||
maxSize = 1
|
||||
}
|
||||
return &batcher{
|
||||
submit: make(chan *batchRequest),
|
||||
maxSize: maxSize,
|
||||
maxWait: maxWait,
|
||||
runBatch: runBatch,
|
||||
}
|
||||
}
|
||||
|
||||
// run is the dispatcher loop: accumulate submitted requests until either maxSize
|
||||
// is reached or maxWait elapses since the first queued request, then dispatch.
|
||||
// Exits when stop is closed (draining any partially-filled batch first).
|
||||
//
|
||||
// A batch carries ONE language (parakeet.cpp's batched C-API takes a single
|
||||
// target_lang), so a request whose language differs from the batch leader is
|
||||
// not coalesced: it is held in carry and becomes the leader of the next batch.
|
||||
// carry is therefore never dropped and its caller never deadlocks: every batch
|
||||
// (including a lone carry on stop) is dispatched, and runBatch replies to all.
|
||||
func (b *batcher) run(stop <-chan struct{}) {
|
||||
var carry *batchRequest
|
||||
for {
|
||||
var first *batchRequest
|
||||
if carry != nil {
|
||||
// A mismatched request from the previous fill leads this batch.
|
||||
first, carry = carry, nil
|
||||
} else {
|
||||
select {
|
||||
case first = <-b.submit:
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
batch := []*batchRequest{first}
|
||||
|
||||
// maxSize==1 disables batching: dispatch immediately (passthrough).
|
||||
if b.maxSize == 1 {
|
||||
b.runBatch(batch)
|
||||
continue
|
||||
}
|
||||
|
||||
timer := time.NewTimer(b.maxWait)
|
||||
fill:
|
||||
for len(batch) < b.maxSize {
|
||||
select {
|
||||
case r := <-b.submit:
|
||||
if r.language != first.language {
|
||||
// Different language: carry it to the next batch so this
|
||||
// batch stays single-language, then dispatch what we have.
|
||||
carry = r
|
||||
break fill
|
||||
}
|
||||
batch = append(batch, r)
|
||||
case <-timer.C:
|
||||
break fill
|
||||
case <-stop:
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
// Don't strand a carried request's caller on shutdown.
|
||||
if carry != nil {
|
||||
b.runBatch([]*batchRequest{carry})
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
}
|
||||
}
|
||||
164
backend/go/parakeet-cpp/batcher_test.go
Normal file
164
backend/go/parakeet-cpp/batcher_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("batcher", func() {
|
||||
echoReply := func(reqs []*batchRequest) {
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{json: r.tag}
|
||||
}
|
||||
}
|
||||
|
||||
It("coalesces concurrent submits into batches", func() {
|
||||
var mu sync.Mutex
|
||||
var sizes []int
|
||||
run := func(reqs []*batchRequest) {
|
||||
mu.Lock()
|
||||
sizes = append(sizes, len(reqs))
|
||||
mu.Unlock()
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(4, 50*time.Millisecond, run)
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
|
||||
const N = 4
|
||||
var wg sync.WaitGroup
|
||||
got := make([]string, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: string(rune('a' + i)), reply: rep}
|
||||
got[i] = (<-rep).json
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
total, maxBatch := 0, 0
|
||||
for _, s := range sizes {
|
||||
total += s
|
||||
if s > maxBatch {
|
||||
maxBatch = s
|
||||
}
|
||||
}
|
||||
Expect(total).To(Equal(N))
|
||||
Expect(maxBatch).To(BeNumerically(">=", 2), "expected at least one batch to coalesce >1 request")
|
||||
})
|
||||
|
||||
It("dispatches when max size is reached", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(2, time.Hour, run) // huge window: only size can trigger
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
for i := 0; i < 2; i++ {
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func(rep chan batchReply) { <-rep }(rep)
|
||||
}
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(2)))
|
||||
})
|
||||
|
||||
It("dispatches when the wait window elapses", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(8, 20*time.Millisecond, run) // size unreachable; window fires
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
|
||||
It("bypasses batching when max size is 1", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(1, time.Hour, run) // size 1 => immediate dispatch
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
|
||||
It("never coalesces requests with different languages into one batch", func() {
|
||||
// parakeet.cpp's batched C-API takes ONE target_lang per batch, so the
|
||||
// dispatcher must keep every dispatched batch single-language. Submit a
|
||||
// mix of languages and assert (a) no batch ever carries more than one
|
||||
// distinct language and (b) every submitted request still gets a reply
|
||||
// (the mismatched carry-over is never dropped).
|
||||
var mu sync.Mutex
|
||||
var langsPerBatch [][]string
|
||||
run := func(reqs []*batchRequest) {
|
||||
seen := map[string]struct{}{}
|
||||
var distinct []string
|
||||
for _, r := range reqs {
|
||||
if _, ok := seen[r.language]; !ok {
|
||||
seen[r.language] = struct{}{}
|
||||
distinct = append(distinct, r.language)
|
||||
}
|
||||
}
|
||||
mu.Lock()
|
||||
langsPerBatch = append(langsPerBatch, distinct)
|
||||
mu.Unlock()
|
||||
echoReply(reqs)
|
||||
}
|
||||
// Large window + size so the fill loop stays open across submits and the
|
||||
// language constraint (not the timer) is what splits the batches.
|
||||
b := newBatcher(16, 200*time.Millisecond, run)
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
|
||||
langs := []string{"en", "en", "de", "de", "en", "fr", "fr"}
|
||||
const N = 7
|
||||
var wg sync.WaitGroup
|
||||
got := make([]string, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: string(rune('a' + i)), language: langs[i], reply: rep}
|
||||
got[i] = (<-rep).json
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
// Invariant: every dispatched batch is single-language.
|
||||
for _, distinct := range langsPerBatch {
|
||||
Expect(len(distinct)).To(Equal(1), "a batch coalesced more than one language: %v", distinct)
|
||||
}
|
||||
// Liveness: every request got a reply (carry-over never stranded).
|
||||
for i := 0; i < N; i++ {
|
||||
Expect(got[i]).To(Equal(string(rune('a' + i))))
|
||||
}
|
||||
})
|
||||
})
|
||||
834
backend/go/parakeet-cpp/goparakeetcpp.go
Normal file
834
backend/go/parakeet-cpp/goparakeetcpp.go
Normal file
@@ -0,0 +1,834 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// purego-bound entry points from libparakeet.so. Names match
|
||||
// parakeet_capi.h exactly so a `nm libparakeet.so | grep parakeet_capi`
|
||||
// is enough to spot drift.
|
||||
//
|
||||
// Functions that return char* are declared as uintptr so we can call
|
||||
// parakeet_capi_free_string on the same pointer after copying, the
|
||||
// C-API contract is "caller owns and must free the returned buffer".
|
||||
var (
|
||||
CppAbiVersion func() int32
|
||||
CppLoad func(ggufPath string) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppTranscribePath func(ctx uintptr, wavPath string, decoder int32) uintptr
|
||||
CppTranscribePathJSON func(ctx uintptr, wavPath string, decoder int32) uintptr
|
||||
CppFreeString func(s uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
|
||||
// Batched JSON transcription: takes a concatenated float buffer of clips
|
||||
// plus their per-clip sample counts (sum(nSamples)==len(samplesConcat))
|
||||
// and returns a malloc'd char* JSON ARRAY of per-clip {"text","words",
|
||||
// "tokens"} objects (uintptr, freed via CppFreeString). purego passes the
|
||||
// Go slices as the base pointer of their backing array (kept alive for the
|
||||
// call), matching the CppStreamFeed pcm []float32 binding pattern; the C
|
||||
// side reads them as const float*/const int*.
|
||||
CppTranscribePcmBatchJSON func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32) uintptr
|
||||
|
||||
// CppTranscribePcmBatchJSONLang is the multilingual variant of the batched
|
||||
// JSON entry point: identical, plus a trailing target_lang. "" (the model
|
||||
// default, "auto") is passed for non-prompt models, which ignore it; an
|
||||
// unknown locale on a prompt model returns 0 and sets last_error. Present
|
||||
// only in newer libparakeet.so; nil falls back to CppTranscribePcmBatchJSON.
|
||||
CppTranscribePcmBatchJSONLang func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32, targetLang string) uintptr
|
||||
|
||||
// Cache-aware streaming (RNN-T) entry points. stream_begin returns 0 for
|
||||
// non-streaming models. feed/finalize return a malloc'd char* (uintptr,
|
||||
// freed via CppFreeString); feed writes 1 to *eouOut on an <EOU>/<EOB>.
|
||||
CppStreamBegin func(ctx uintptr) uintptr
|
||||
CppStreamFeed func(s uintptr, pcm []float32, nSamples int32, eouOut unsafe.Pointer) uintptr
|
||||
CppStreamFinalize func(s uintptr) uintptr
|
||||
CppStreamFree func(s uintptr)
|
||||
|
||||
// CppStreamBeginLang is the multilingual variant of stream_begin: identical,
|
||||
// plus a trailing target_lang ("" means the model default). Present only in
|
||||
// newer libparakeet.so; nil falls back to CppStreamBegin.
|
||||
CppStreamBeginLang func(ctx uintptr, targetLang string) uintptr
|
||||
|
||||
// Streaming JSON variants (ABI v4): feed/finalize returning a malloc'd char*
|
||||
// JSON document {text,eou,frame_sec,words} (uintptr, freed via CppFreeString)
|
||||
// so streaming segments can carry per-word timestamps. Present only in newer
|
||||
// libparakeet.so; nil falls back to the text-only CppStreamFeed/Finalize path.
|
||||
CppStreamFeedJSON func(s uintptr, pcm []float32, nSamples int32) uintptr
|
||||
CppStreamFinalizeJSON func(s uintptr) uintptr
|
||||
)
|
||||
|
||||
// streamChunkSamples is how much 16 kHz mono PCM we hand to stream_feed per
|
||||
// call (1 s). The session buffers internally and decodes once a full
|
||||
// cache-aware encoder chunk is available, so this only bounds how often we
|
||||
// poll for newly-finalized text, not the model's actual chunk size.
|
||||
const streamChunkSamples = 16000
|
||||
|
||||
// transcriptJSON mirrors the document returned by
|
||||
// parakeet_capi_transcribe_path_json (see parakeet_capi.h):
|
||||
//
|
||||
// {"text":"...",
|
||||
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...],
|
||||
// "tokens":[{"id":123,"t":0.480,"conf":0.9100}, ...]}
|
||||
//
|
||||
// "start"/"end"/"t" are seconds; "conf" is confidence in (0,1].
|
||||
type transcriptJSON struct {
|
||||
Text string `json:"text"`
|
||||
FrameSec float64 `json:"frame_sec"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
Tokens []transcriptToken `json:"tokens"`
|
||||
}
|
||||
|
||||
// streamFeedJSON mirrors the document returned by
|
||||
// parakeet_capi_stream_feed_json / parakeet_capi_stream_finalize_json (ABI v5):
|
||||
//
|
||||
// {"text":"...","eou":0,"eob":0,"frame_sec":0.080000,
|
||||
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...]}
|
||||
//
|
||||
// "text" is the newly-finalized text since the last call; "eou" is 1 when an
|
||||
// <EOU> (end of utterance) fired this feed and "eob" is 1 when an <EOB>
|
||||
// (backchannel) fired. ABI v4 conflated the two into "eou"; v5 split them, so
|
||||
// we read both and treat either as an utterance boundary for segmentation.
|
||||
// "words" are the words finalized this call with absolute (stream-relative)
|
||||
// start/end seconds.
|
||||
type streamFeedJSON struct {
|
||||
Text string `json:"text"`
|
||||
Eou int `json:"eou"`
|
||||
Eob int `json:"eob"`
|
||||
FrameSec float64 `json:"frame_sec"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
}
|
||||
|
||||
type transcriptWord struct {
|
||||
W string `json:"w"`
|
||||
Start float64 `json:"start"`
|
||||
End float64 `json:"end"`
|
||||
Conf float64 `json:"conf"`
|
||||
}
|
||||
|
||||
type transcriptToken struct {
|
||||
ID int32 `json:"id"`
|
||||
T float64 `json:"t"`
|
||||
Conf float64 `json:"conf"`
|
||||
}
|
||||
|
||||
// ParakeetCpp owns a single loaded parakeet_ctx. The C engine is a
|
||||
// thread-unsafe singleton (mirrors whisper.cpp / vibevoice.cpp). Rather than
|
||||
// serialize every call through base.SingleThread, we route unary
|
||||
// transcription through an in-process batcher (its sole dispatcher goroutine
|
||||
// is the only caller of the engine on that path) and guard the shared engine
|
||||
// with engineMu so a streaming session and a batched-unary dispatch never
|
||||
// touch it concurrently.
|
||||
type ParakeetCpp struct {
|
||||
base.Base
|
||||
ctxPtr uintptr
|
||||
engineMu sync.Mutex // sole guard of the one C engine (dispatcher + streaming)
|
||||
bat *batcher
|
||||
batStop chan struct{}
|
||||
// segmentGapFrames is NeMo's segment_gap_threshold in ENCODER FRAMES (model
|
||||
// YAML option, default 0=off). When >0 it adds NeMo's silence-gap split on
|
||||
// top of the punctuation split; converted to seconds via the JSON frame_sec.
|
||||
segmentGapFrames int
|
||||
}
|
||||
|
||||
// Load is the LocalAI gRPC entry point for LoadModel: it calls
|
||||
// parakeet_capi_load with the GGUF path and stashes the resulting
|
||||
// opaque context pointer for AudioTranscription.
|
||||
func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
|
||||
if opts.ModelFile == "" {
|
||||
return errors.New("parakeet-cpp: ModelFile is required")
|
||||
}
|
||||
|
||||
ctx := CppLoad(opts.ModelFile)
|
||||
if ctx == 0 {
|
||||
// No ctx to ask for last_error (the C-API's last-error buffer
|
||||
// lives on the ctx that was never returned). Surface the path
|
||||
// so the operator at least knows which load failed.
|
||||
return fmt.Errorf("parakeet-cpp: parakeet_capi_load failed for %q", opts.ModelFile)
|
||||
}
|
||||
p.ctxPtr = ctx
|
||||
|
||||
// Dynamic batching knobs (model YAML options:, key:value form). Batching is
|
||||
// OFF by default (batch_max_size:1): each request runs on its own. On GPU,
|
||||
// raising batch_max_size coalesces concurrent requests into one batched
|
||||
// engine call and improves throughput under load; leave it at 1 on CPU and
|
||||
// for low-concurrency setups, where batching only adds latency.
|
||||
maxSize := optInt(opts, "batch_max_size", 1)
|
||||
maxWaitMs := optInt(opts, "batch_max_wait_ms", 15)
|
||||
if maxWaitMs < 0 {
|
||||
maxWaitMs = 0
|
||||
}
|
||||
|
||||
// NeMo's segment_gap_threshold (encoder frames, default 0=off). Off by
|
||||
// default matches NeMo's default (punctuation-only segments); when set it
|
||||
// additionally splits segments on inter-word silence (see transcriptResultFromDoc).
|
||||
p.segmentGapFrames = optInt(opts, "segment_gap_threshold", 0)
|
||||
if CppTranscribePcmBatchJSON != nil {
|
||||
p.batStop = make(chan struct{})
|
||||
p.bat = newBatcher(maxSize, time.Duration(maxWaitMs)*time.Millisecond, p.runBatch)
|
||||
go p.bat.run(p.batStop) // dispatcher runs until Free closes batStop
|
||||
if maxSize > 1 {
|
||||
xlog.Info("parakeet-cpp: dynamic batching enabled",
|
||||
"batch_max_size", maxSize, "batch_max_wait_ms", maxWaitMs)
|
||||
} else {
|
||||
xlog.Info("parakeet-cpp: dynamic batching off (batch_max_size=1); " +
|
||||
"set batch_max_size>1 to coalesce concurrent requests on GPU")
|
||||
}
|
||||
} else {
|
||||
xlog.Info("parakeet-cpp: batched C-API not present in libparakeet.so; " +
|
||||
"batching disabled, using per-request transcription")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// optInt reads an integer model option (key:value form) from ModelOptions,
|
||||
// returning def when absent or unparseable. The options array carries the
|
||||
// model YAML's options: entries (see core/config; siblings such as
|
||||
// acestep-cpp parse the same key:value form via strings.Cut on ":").
|
||||
func optInt(opts *pb.ModelOptions, key string, def int) int {
|
||||
for _, o := range opts.GetOptions() {
|
||||
k, v, ok := strings.Cut(o, ":")
|
||||
if ok && strings.TrimSpace(k) == key {
|
||||
if n, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// runBatch is the dispatcher's batch handler and the ONLY caller of the C
|
||||
// engine on the unary path. It concatenates the batch PCM, calls the batched
|
||||
// JSON C-API under engineMu, splits the JSON array, and replies to each request.
|
||||
func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
// Observability: the actual coalesced batch size per engine call. Debug-level
|
||||
// so it stays silent in normal operation but lets operators confirm/tune batching.
|
||||
xlog.Debug("parakeet-cpp: dispatching batch", "size", len(reqs))
|
||||
nSamples := make([]int32, len(reqs))
|
||||
total := 0
|
||||
for i, r := range reqs {
|
||||
nSamples[i] = int32(len(r.pcm))
|
||||
total += len(r.pcm)
|
||||
}
|
||||
concat := make([]float32, 0, total)
|
||||
for _, r := range reqs {
|
||||
concat = append(concat, r.pcm...)
|
||||
}
|
||||
var dec int32
|
||||
if len(reqs) > 0 {
|
||||
dec = reqs[0].decoder
|
||||
}
|
||||
// All requests in a batch share one language (the batcher coalesces only
|
||||
// same-language requests), so any element's language describes the batch.
|
||||
lang := ""
|
||||
if len(reqs) > 0 {
|
||||
lang = reqs[0].language
|
||||
}
|
||||
p.engineMu.Lock()
|
||||
var cstr uintptr
|
||||
if CppTranscribePcmBatchJSONLang != nil {
|
||||
cstr = CppTranscribePcmBatchJSONLang(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec, lang)
|
||||
} else {
|
||||
cstr = CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
|
||||
}
|
||||
p.engineMu.Unlock()
|
||||
if cstr == 0 {
|
||||
err := fmt.Errorf("parakeet-cpp: batch transcribe failed: %s", CppLastError(p.ctxPtr))
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{err: err}
|
||||
}
|
||||
return
|
||||
}
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
var docs []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(raw), &docs); err != nil || len(docs) != len(reqs) {
|
||||
e := fmt.Errorf("parakeet-cpp: batch json: got %d results for %d reqs (%v)", len(docs), len(reqs), err)
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{err: e}
|
||||
}
|
||||
return
|
||||
}
|
||||
for i, r := range reqs {
|
||||
r.reply <- batchReply{json: string(docs[i])}
|
||||
}
|
||||
}
|
||||
|
||||
// AudioTranscription decodes the wav at opts.Dst to 16 kHz mono PCM and
|
||||
// submits it to the in-process batcher, which coalesces concurrent requests
|
||||
// into a single batched engine call (parakeet_capi_transcribe_pcm_batch_json)
|
||||
// with the default decoder (decoder=0, which selects the right head per
|
||||
// architecture: transducer for tdt/rnnt/hybrid, CTC for ctc) and shapes the
|
||||
// per-word timestamps into a LocalAI TranscriptResult.
|
||||
//
|
||||
// Parakeet emits word- and token-level timestamps but no native segment
|
||||
// boundaries, so we synthesise a single whole-clip segment spanning the first
|
||||
// word start to the last word end. Word-level timings are attached only when
|
||||
// the caller opts in via timestamp_granularities=["word"] (matching the
|
||||
// OpenAI API, whose default is segment-level); token ids always populate
|
||||
// Segment.Tokens.
|
||||
//
|
||||
// translate/diarize/prompt/temperature/threads are not applicable to parakeet
|
||||
// and are ignored; language is honored on the batched + streaming paths (see
|
||||
// opts.GetLanguage() below); streaming is handled by AudioTranscriptionStream
|
||||
// (L2).
|
||||
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if p.ctxPtr == 0 {
|
||||
return pb.TranscriptResult{}, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
|
||||
// Fallback when the batched C-API is unavailable: transcribe from a file
|
||||
// path (original behavior, no batching). The C library's audio loader only
|
||||
// understands 16 kHz mono WAV/PCM, so convert the input first - otherwise
|
||||
// any non-WAV upload (MP3, etc.) fails with "failed to load audio". This
|
||||
// mirrors what every other audio backend (whisper, crispasr) does via
|
||||
// utils.AudioToWav before handing the file to the engine.
|
||||
if p.bat == nil {
|
||||
converted, cleanup, err := convertToWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer cleanup()
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, converted, 0)
|
||||
if cstr == 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", CppLastError(p.ctxPtr))
|
||||
}
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
var doc transcriptJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
|
||||
}
|
||||
|
||||
// Batched path: decode to PCM, submit to the batcher, wait for this request's
|
||||
// JSON element. The dispatcher is the sole engine caller on this path; both
|
||||
// sends honour ctx cancellation.
|
||||
pcm, _, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
rep := make(chan batchReply, 1)
|
||||
select {
|
||||
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, language: opts.GetLanguage(), reply: rep}:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
var res batchReply
|
||||
select {
|
||||
case res = <-rep:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if res.err != nil {
|
||||
return pb.TranscriptResult{}, res.err
|
||||
}
|
||||
var doc transcriptJSON
|
||||
if err := json.Unmarshal([]byte(res.json), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
|
||||
}
|
||||
|
||||
// segmentSeparators is NeMo's default segment_seperators (sentence-ending
|
||||
// punctuation). Splitting on these matches NeMo's default segment timestamps.
|
||||
var segmentSeparators = []rune{'.', '?', '!'}
|
||||
|
||||
// transcriptResultFromDoc maps a decoded transcriptJSON to a TranscriptResult,
|
||||
// grouping words into NeMo-faithful segments (see splitWordsIntoSegments). The
|
||||
// optional gapFrames (NeMo's segment_gap_threshold, in encoder FRAMES; 0=off)
|
||||
// additionally splits on inter-word silence; it is converted to a seconds gap
|
||||
// with the document's frame_sec. Per-segment word timings are attached only when
|
||||
// the caller requested word granularity; token ids populate each segment's
|
||||
// Tokens by time-window membership. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gapFrames int) pb.TranscriptResult {
|
||||
text := strings.TrimSpace(doc.Text)
|
||||
|
||||
// Frame-unit gap threshold -> seconds (NeMo segment_gap_threshold). 0 = off.
|
||||
gapSeconds := 0.0
|
||||
if gapFrames > 0 {
|
||||
if doc.FrameSec > 0 {
|
||||
gapSeconds = float64(gapFrames) * doc.FrameSec
|
||||
} else {
|
||||
xlog.Warn("parakeet-cpp: segment_gap_threshold set but libparakeet.so " +
|
||||
"did not report frame_sec; falling back to punctuation-only segments")
|
||||
}
|
||||
}
|
||||
|
||||
groups := splitWordsIntoSegments(doc.Words, segmentSeparators, gapSeconds)
|
||||
if len(groups) == 0 {
|
||||
// No words (edge case): single whole-clip text segment.
|
||||
return pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: []*pb.TranscriptSegment{{Id: 0, Text: text}},
|
||||
}
|
||||
}
|
||||
|
||||
wantWords := wordsRequested(opts.TimestampGranularities)
|
||||
segments := make([]*pb.TranscriptSegment, 0, len(groups))
|
||||
for id, group := range groups {
|
||||
parts := make([]string, len(group))
|
||||
for i, gw := range group {
|
||||
parts[i] = gw.W
|
||||
}
|
||||
seg := &pb.TranscriptSegment{
|
||||
Id: int32(id),
|
||||
Start: secondsToNanos(group[0].Start),
|
||||
End: secondsToNanos(group[len(group)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
Tokens: tokensInWindow(doc.Tokens, group[0].Start, group[len(group)-1].End),
|
||||
}
|
||||
if wantWords {
|
||||
ws := make([]*pb.TranscriptWord, len(group))
|
||||
for i, gw := range group {
|
||||
ws[i] = &pb.TranscriptWord{Start: secondsToNanos(gw.Start), End: secondsToNanos(gw.End), Text: gw.W}
|
||||
}
|
||||
seg.Words = ws
|
||||
}
|
||||
segments = append(segments, seg)
|
||||
}
|
||||
return pb.TranscriptResult{Text: text, Segments: segments}
|
||||
}
|
||||
|
||||
// splitWordsIntoSegments groups words into segments exactly as NeMo's
|
||||
// get_segment_offsets does (nemo/collections/asr/parts/utils/timestamp_utils.py).
|
||||
// Walking the words, it closes a segment when (1) the gap rule is enabled
|
||||
// (gapSeconds > 0) and the segment already has words and the gap from the
|
||||
// previous word's end to this word's start is >= gapSeconds - the current word
|
||||
// then STARTS a new segment - or, checked only when the gap rule did not apply
|
||||
// (NeMo's elif), (2) the word ends with (or is) a separator, which closes the
|
||||
// segment INCLUDING that word. Trailing words flush into a final segment.
|
||||
// gapSeconds <= 0 disables the gap rule, matching NeMo's default
|
||||
// segment_gap_threshold=None (punctuation-only segments).
|
||||
func splitWordsIntoSegments(words []transcriptWord, separators []rune, gapSeconds float64) [][]transcriptWord {
|
||||
var segments [][]transcriptWord
|
||||
var cur []transcriptWord
|
||||
for i, word := range words {
|
||||
gapActive := gapSeconds > 0 && len(cur) > 0
|
||||
if gapActive && (word.Start-words[i-1].End) >= gapSeconds {
|
||||
segments = append(segments, cur)
|
||||
cur = []transcriptWord{word}
|
||||
continue
|
||||
}
|
||||
if !gapActive && endsWithSeparator(word.W, separators) {
|
||||
cur = append(cur, word)
|
||||
segments = append(segments, cur)
|
||||
cur = nil
|
||||
continue
|
||||
}
|
||||
cur = append(cur, word)
|
||||
}
|
||||
if len(cur) > 0 {
|
||||
segments = append(segments, cur)
|
||||
}
|
||||
return segments
|
||||
}
|
||||
|
||||
// endsWithSeparator reports whether w's last rune is in separators (matching
|
||||
// NeMo's `word[-1] in delims or word in delims`).
|
||||
func endsWithSeparator(w string, separators []rune) bool {
|
||||
r := []rune(strings.TrimSpace(w))
|
||||
if len(r) == 0 {
|
||||
return false
|
||||
}
|
||||
last := r[len(r)-1]
|
||||
for _, s := range separators {
|
||||
if last == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// tokensInWindow returns the ids of tokens whose timestamp t falls in
|
||||
// [start, end] (inclusive), assigning each token to the segment that spans its
|
||||
// time. The last segment's end is the last word end, so the final token is
|
||||
// included.
|
||||
func tokensInWindow(tokens []transcriptToken, start, end float64) []int32 {
|
||||
var ids []int32
|
||||
for _, t := range tokens {
|
||||
if t.T >= start && t.T <= end {
|
||||
ids = append(ids, t.ID)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// streamSegmenter accumulates streaming words into per-utterance segments. EOU
|
||||
// is the model's own utterance boundary; each closed segment takes its start/end
|
||||
// from its first/last accumulated word.
|
||||
type streamSegmenter struct {
|
||||
segs []*pb.TranscriptSegment
|
||||
cur []transcriptWord
|
||||
nextID int32
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) add(doc streamFeedJSON) {
|
||||
s.cur = append(s.cur, doc.Words...)
|
||||
// Close the segment on either turn signal: <EOU> (end of utterance) or
|
||||
// <EOB> (backchannel). ABI v4 reported both via "eou"; v5 split them, so we
|
||||
// OR them here to keep the v4 segmentation boundaries.
|
||||
if doc.Eou != 0 || doc.Eob != 0 {
|
||||
s.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) flush() {
|
||||
if len(s.cur) == 0 {
|
||||
return
|
||||
}
|
||||
parts := make([]string, len(s.cur))
|
||||
for i, w := range s.cur {
|
||||
parts[i] = w.W
|
||||
}
|
||||
s.segs = append(s.segs, &pb.TranscriptSegment{
|
||||
Id: s.nextID,
|
||||
Start: secondsToNanos(s.cur[0].Start),
|
||||
End: secondsToNanos(s.cur[len(s.cur)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
})
|
||||
s.nextID++
|
||||
s.cur = nil
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) segments() []*pb.TranscriptSegment { return s.segs }
|
||||
|
||||
// wordsRequested reports whether the caller asked for word-level timestamps.
|
||||
// The OpenAI transcription API gates word timings behind
|
||||
// timestamp_granularities[] containing "word" and defaults to segment-level
|
||||
// otherwise; we follow that contract.
|
||||
func wordsRequested(granularities []string) bool {
|
||||
for _, g := range granularities {
|
||||
if strings.EqualFold(strings.TrimSpace(g), "word") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// secondsToNanos converts the C-API's fractional-second timestamps into the
|
||||
// int64 nanoseconds LocalAI carries on TranscriptSegment/TranscriptWord, the
|
||||
// same nanosecond convention the whisper backend uses.
|
||||
func secondsToNanos(sec float64) int64 {
|
||||
return int64(sec * 1e9)
|
||||
}
|
||||
|
||||
// AudioTranscriptionStream drives the cache-aware streaming RNN-T over the
|
||||
// audio at opts.Dst: it decodes the file to 16 kHz mono PCM, feeds it in
|
||||
// chunks to parakeet_capi_stream_feed, and emits each newly-finalized text
|
||||
// run as a TranscriptStreamResponse delta. <EOU>/<EOB> events close the
|
||||
// current segment; a closing FinalResult carries the full transcript and the
|
||||
// per-utterance segments.
|
||||
//
|
||||
// stream_begin returns 0 for models that are not cache-aware streaming models
|
||||
// (only e.g. nvidia/parakeet_realtime_eou_120m-v1 qualifies). For those we fall
|
||||
// back to a single offline transcription emitted as one delta plus a closing
|
||||
// FinalResult, matching LocalAI's non-streaming streaming contract (and the
|
||||
// whisper backend), so the streaming endpoint works for every model.
|
||||
func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
defer close(results)
|
||||
|
||||
if p.ctxPtr == 0 {
|
||||
return grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
var stream uintptr
|
||||
if CppStreamBeginLang != nil {
|
||||
stream = CppStreamBeginLang(p.ctxPtr, opts.GetLanguage())
|
||||
} else {
|
||||
stream = CppStreamBegin(p.ctxPtr)
|
||||
}
|
||||
if stream == 0 {
|
||||
// Not a cache-aware streaming model: run a normal offline
|
||||
// transcription and emit it as one delta + a closing final result.
|
||||
res, err := p.AudioTranscription(ctx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t := strings.TrimSpace(res.Text); t != "" {
|
||||
results <- &pb.TranscriptStreamResponse{Delta: t}
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{FinalResult: &res}
|
||||
return nil
|
||||
}
|
||||
defer CppStreamFree(stream)
|
||||
// The C engine is a single shared context: a streaming session and a batched
|
||||
// unary dispatch must never touch it at once, so hold engineMu for the whole
|
||||
// stream. This lock is intentionally taken AFTER the non-streaming fallback
|
||||
// above returns: that fallback goes through AudioTranscription -> the batcher
|
||||
// -> runBatch, which itself acquires engineMu, so locking here first would
|
||||
// deadlock. Do not hoist this lock above the fallback.
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
|
||||
data, duration, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ABI v4: when the streaming JSON entry points are present, drive them so the
|
||||
// per-utterance segments carry per-word start/end timestamps. Falls through to
|
||||
// the text-only loop below against an older libparakeet.so. Runs under the
|
||||
// engineMu already held above.
|
||||
if CppStreamFeedJSON != nil {
|
||||
return p.streamJSON(ctx, stream, data, duration, results)
|
||||
}
|
||||
|
||||
var (
|
||||
full strings.Builder
|
||||
segText strings.Builder
|
||||
segments []*pb.TranscriptSegment
|
||||
segID int32
|
||||
)
|
||||
|
||||
flushSegment := func() {
|
||||
t := strings.TrimSpace(segText.String())
|
||||
segText.Reset()
|
||||
if t == "" {
|
||||
return
|
||||
}
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: segID, Text: t})
|
||||
segID++
|
||||
}
|
||||
|
||||
// emitDelta consumes the malloc'd char* returned by feed/finalize: frees
|
||||
// it, accumulates the text, and sends a delta when non-empty. A 0 return
|
||||
// is an error (vs the "" empty-but-non-NULL no-new-text case).
|
||||
emitDelta := func(ret uintptr) error {
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
delta := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
if delta == "" {
|
||||
return nil
|
||||
}
|
||||
full.WriteString(delta)
|
||||
segText.WriteString(delta)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: delta}
|
||||
return nil
|
||||
}
|
||||
|
||||
for off := 0; off < len(data); off += streamChunkSamples {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(data))
|
||||
chunk := data[off:end]
|
||||
|
||||
var eou int32
|
||||
ret := CppStreamFeed(stream, chunk, int32(len(chunk)), unsafe.Pointer(&eou))
|
||||
if err := emitDelta(ret); err != nil {
|
||||
return err
|
||||
}
|
||||
if eou != 0 {
|
||||
flushSegment()
|
||||
}
|
||||
}
|
||||
|
||||
// Flush the streaming tail (final encoder chunk).
|
||||
if err := emitDelta(CppStreamFinalize(stream)); err != nil {
|
||||
return err
|
||||
}
|
||||
flushSegment()
|
||||
|
||||
text := strings.TrimSpace(full.String())
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{
|
||||
FinalResult: &pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: segments,
|
||||
Duration: duration,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamJSON drives the streaming JSON entry points (present since ABI v4): each
|
||||
// feed/finalize returns a {text,eou,eob,frame_sec,words} document. The
|
||||
// newly-finalized text is emitted as a delta (unchanged streaming contract)
|
||||
// while words are accumulated into per-utterance segments (closed on <EOU> or
|
||||
// <EOB>) so the closing FinalResult carries timestamped segments. Runs under
|
||||
// engineMu (already held by the caller).
|
||||
func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []float32,
|
||||
duration float32, results chan *pb.TranscriptStreamResponse) error {
|
||||
var (
|
||||
full strings.Builder
|
||||
seg streamSegmenter
|
||||
)
|
||||
// consume frees the malloc'd char* (a 0 return is an error), parses the JSON,
|
||||
// emits the delta, and routes words through the segmenter.
|
||||
consume := func(ret uintptr) error {
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
raw := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
var doc streamFeedJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return fmt.Errorf("parakeet-cpp: decode stream json: %w", err)
|
||||
}
|
||||
if doc.Text != "" {
|
||||
full.WriteString(doc.Text)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: doc.Text}
|
||||
}
|
||||
seg.add(doc)
|
||||
return nil
|
||||
}
|
||||
|
||||
for off := 0; off < len(data); off += streamChunkSamples {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(data))
|
||||
chunk := data[off:end]
|
||||
if err := consume(CppStreamFeedJSON(stream, chunk, int32(len(chunk)))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := consume(CppStreamFinalizeJSON(stream)); err != nil {
|
||||
return err
|
||||
}
|
||||
seg.flush() // close any trailing utterance that never saw an EOU
|
||||
|
||||
text := strings.TrimSpace(full.String())
|
||||
segments := seg.segments()
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{
|
||||
FinalResult: &pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: segments,
|
||||
Duration: duration,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeWavMono16k converts any input audio to 16 kHz mono PCM and returns the
|
||||
// float samples plus the clip duration in seconds. Mirrors the whisper
|
||||
// backend: utils.AudioToWav (ffmpeg) normalises rate/channels, go-audio
|
||||
// decodes the PCM.
|
||||
// convertToWavMono16k converts an arbitrary audio file to a 16 kHz mono WAV in
|
||||
// a fresh temp dir and returns the path together with a cleanup func the caller
|
||||
// must defer. WAV inputs already at 16 kHz/mono/16-bit are passed through by
|
||||
// utils.AudioToWav (hardlink/copy), everything else is transcoded via ffmpeg.
|
||||
// Used by the direct (non-batched) transcription path, which hands a file path
|
||||
// to the C library's WAV-only audio loader.
|
||||
func convertToWavMono16k(path string) (string, func(), error) {
|
||||
dir, err := os.MkdirTemp("", "parakeet")
|
||||
if err != nil {
|
||||
return "", func() {}, err
|
||||
}
|
||||
cleanup := func() { _ = os.RemoveAll(dir) }
|
||||
|
||||
converted := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(path, converted); err != nil {
|
||||
cleanup()
|
||||
return "", func() {}, err
|
||||
}
|
||||
return converted, cleanup, nil
|
||||
}
|
||||
|
||||
func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
converted, cleanup, err := convertToWavMono16k(path)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
fh, err := os.Open(converted)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
buf, err := wav.NewDecoder(fh).FullPCMBuffer()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
return data, duration, nil
|
||||
}
|
||||
|
||||
// Free releases the underlying parakeet_ctx. Called by LocalAI when the
|
||||
// model is unloaded.
|
||||
func (p *ParakeetCpp) Free() error {
|
||||
// Stop the dispatcher before releasing the engine so no in-flight runBatch
|
||||
// can touch a freed ctx (close leak / use-after-free on reload).
|
||||
if p.batStop != nil {
|
||||
close(p.batStop)
|
||||
p.batStop = nil
|
||||
}
|
||||
if p.ctxPtr != 0 {
|
||||
CppFree(p.ctxPtr)
|
||||
p.ctxPtr = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// goStringFromCPtr copies a NUL-terminated C string into Go memory.
|
||||
// cptr is the raw pointer returned by purego from the C-API (a malloc'd
|
||||
// buffer the caller owns); callers must free it via CppFreeString after
|
||||
// the copy lands.
|
||||
//
|
||||
// The uintptr->unsafe.Pointer conversion below trips go vet's unsafeptr
|
||||
// check, which can't distinguish a C-owned heap pointer from Go-managed
|
||||
// memory. It is safe here: the pointer addresses a malloc'd C buffer the
|
||||
// Go GC neither tracks nor moves, and we dereference it immediately to
|
||||
// copy the bytes out, the same pattern (and the same tolerated warning)
|
||||
// as the whisper backend's unsafe.Slice over segsPtr.
|
||||
func goStringFromCPtr(cptr uintptr) string {
|
||||
if cptr == 0 {
|
||||
return ""
|
||||
}
|
||||
p := unsafe.Pointer(cptr) //nolint:govet // C-owned malloc'd buffer, not Go-GC memory (see doc above)
|
||||
n := 0
|
||||
for *(*byte)(unsafe.Add(p, n)) != 0 {
|
||||
n++
|
||||
}
|
||||
return string(unsafe.Slice((*byte)(p), n))
|
||||
}
|
||||
247
backend/go/parakeet-cpp/goparakeetcpp_test.go
Normal file
247
backend/go/parakeet-cpp/goparakeetcpp_test.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestParakeetCpp(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "parakeet-cpp Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive
|
||||
// the C-API bridge without spinning up the gRPC server. Skips the
|
||||
// current spec when libparakeet.so isn't loadable from cwd
|
||||
// ($LD_LIBRARY_PATH or a symlink in ./).
|
||||
func ensureLibLoaded() {
|
||||
libLoadOnce.Do(func() {
|
||||
libName := os.Getenv("PARAKEET_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libparakeet.so"
|
||||
}
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&CppAbiVersion, lib, "parakeet_capi_abi_version")
|
||||
purego.RegisterLibFunc(&CppLoad, lib, "parakeet_capi_load")
|
||||
purego.RegisterLibFunc(&CppFree, lib, "parakeet_capi_free")
|
||||
purego.RegisterLibFunc(&CppTranscribePath, lib, "parakeet_capi_transcribe_path")
|
||||
purego.RegisterLibFunc(&CppTranscribePathJSON, lib, "parakeet_capi_transcribe_path_json")
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
purego.RegisterLibFunc(&CppStreamBegin, lib, "parakeet_capi_stream_begin")
|
||||
purego.RegisterLibFunc(&CppStreamFeed, lib, "parakeet_capi_stream_feed")
|
||||
purego.RegisterLibFunc(&CppStreamFinalize, lib, "parakeet_capi_stream_finalize")
|
||||
purego.RegisterLibFunc(&CppStreamFree, lib, "parakeet_capi_stream_free")
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
|
||||
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
|
||||
}
|
||||
purego.RegisterLibFunc(&CppFreeString, lib, "parakeet_capi_free_string")
|
||||
purego.RegisterLibFunc(&CppLastError, lib, "parakeet_capi_last_error")
|
||||
})
|
||||
if libLoadErr != nil {
|
||||
Skip("libparakeet.so not loadable: " + libLoadErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// fixturesOrSkip returns the model + audio paths or skips the spec if
|
||||
// either env var is unset. The smoke test never runs in default CI; it
|
||||
// needs a real parakeet GGUF and a 16 kHz mono WAV on disk.
|
||||
func fixturesOrSkip() (string, string) {
|
||||
modelPath := os.Getenv("PARAKEET_BACKEND_TEST_MODEL")
|
||||
audioPath := os.Getenv("PARAKEET_BACKEND_TEST_WAV")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set PARAKEET_BACKEND_TEST_MODEL and PARAKEET_BACKEND_TEST_WAV to run this spec")
|
||||
}
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
// writeMono16kWav writes `samples` frames of 16 kHz mono 16-bit silence to
|
||||
// path. The result is already in AudioToWav's target format, so the conversion
|
||||
// helper copies it through without invoking ffmpeg.
|
||||
func writeMono16kWav(path string, samples int) {
|
||||
GinkgoHelper()
|
||||
f, err := os.Create(path)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
enc := wav.NewEncoder(f, 16000, 16, 1, 1)
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: 16000},
|
||||
SourceBitDepth: 16,
|
||||
Data: make([]int, samples),
|
||||
}
|
||||
Expect(enc.Write(buf)).To(Succeed())
|
||||
Expect(enc.Close()).To(Succeed())
|
||||
Expect(f.Close()).To(Succeed())
|
||||
}
|
||||
|
||||
var _ = Describe("ParakeetCpp", func() {
|
||||
Context("AudioTranscription", func() {
|
||||
It("transcribes a WAV via the parakeet C-API", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
p := &ParakeetCpp{}
|
||||
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
defer func() { _ = p.Free() }()
|
||||
|
||||
res, err := p.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(strings.TrimSpace(res.Text)).ToNot(BeEmpty(),
|
||||
"expected non-empty transcript for %s", audioPath)
|
||||
// NeMo-faithful segmentation: one or more punctuation-delimited
|
||||
// segments, each with text and a monotonically-advancing time span.
|
||||
Expect(res.Segments).ToNot(BeEmpty(), "expected at least one segment")
|
||||
var prevEnd int64
|
||||
for i, seg := range res.Segments {
|
||||
Expect(strings.TrimSpace(seg.Text)).ToNot(BeEmpty(),
|
||||
"segment %d must have text", i)
|
||||
Expect(seg.End).To(BeNumerically(">=", seg.Start),
|
||||
"segment %d end must not precede its start", i)
|
||||
Expect(seg.Start).To(BeNumerically(">=", prevEnd),
|
||||
"segments must be in time order")
|
||||
prevEnd = seg.End
|
||||
// Default (no granularities) is segment-level: no per-word timings.
|
||||
Expect(seg.Words).To(BeEmpty(),
|
||||
"word timings are opt-in via timestamp_granularities")
|
||||
}
|
||||
})
|
||||
|
||||
It("emits word-level timestamps when granularity=word", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
p := &ParakeetCpp{}
|
||||
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
defer func() { _ = p.Free() }()
|
||||
|
||||
res, err := p.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
TimestampGranularities: []string{"word"},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.Segments).ToNot(BeEmpty())
|
||||
// With word granularity every segment carries its own words, and each
|
||||
// segment's span tracks its first/last word; word starts advance
|
||||
// monotonically across the whole transcript.
|
||||
totalWords := 0
|
||||
var prevStart int64 = -1
|
||||
for i, seg := range res.Segments {
|
||||
Expect(seg.Words).ToNot(BeEmpty(),
|
||||
"segment %d must carry per-word timestamps with granularity=word", i)
|
||||
Expect(seg.Start).To(Equal(seg.Words[0].Start),
|
||||
"segment %d start tracks its first word", i)
|
||||
Expect(seg.End).To(Equal(seg.Words[len(seg.Words)-1].End),
|
||||
"segment %d end tracks its last word", i)
|
||||
for _, w := range seg.Words {
|
||||
Expect(w.End).To(BeNumerically(">=", w.Start))
|
||||
Expect(w.Start).To(BeNumerically(">=", prevStart))
|
||||
prevStart = w.Start
|
||||
totalWords++
|
||||
}
|
||||
}
|
||||
Expect(totalWords).To(BeNumerically(">", 0))
|
||||
Expect(res.Segments[0].Words[0].Start).To(BeNumerically(">=", int64(0)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("convertToWavMono16k", func() {
|
||||
// The non-batched transcription path hands a file path to the C
|
||||
// library's WAV-only audio loader, so it must convert first.
|
||||
// utils.AudioToWav passes an already-16kHz/mono/16-bit WAV through
|
||||
// without ffmpeg, which lets us exercise the helper (and the
|
||||
// regression: the direct path used to skip conversion entirely)
|
||||
// without a model, the C library, or ffmpeg.
|
||||
It("returns a decodable 16kHz mono WAV copy and cleans it up", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
src := filepath.Join(dir, "input.wav")
|
||||
writeMono16kWav(src, 16000) // 1s of silence at 16 kHz
|
||||
|
||||
converted, cleanup, err := convertToWavMono16k(src)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// It must produce a fresh temp file, not return the original path.
|
||||
Expect(converted).ToNot(Equal(src))
|
||||
Expect(converted).To(BeAnExistingFile())
|
||||
|
||||
pcm, _, err := decodeWavMono16k(converted)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(pcm).To(HaveLen(16000), "round-trips the sample count")
|
||||
|
||||
cleanup()
|
||||
Expect(converted).ToNot(BeAnExistingFile(), "cleanup removes the temp dir")
|
||||
})
|
||||
|
||||
It("errors on a non-existent input rather than passing the path through", func() {
|
||||
_, _, err := convertToWavMono16k(filepath.Join(GinkgoT().TempDir(), "missing.mp3"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("streams deltas and a closing FinalResult from a cache-aware model", func() {
|
||||
// Streaming needs a cache-aware streaming model (e.g.
|
||||
// realtime_eou); the offline test model would fail stream_begin.
|
||||
modelPath := os.Getenv("PARAKEET_BACKEND_TEST_STREAM_MODEL")
|
||||
audioPath := os.Getenv("PARAKEET_BACKEND_TEST_WAV")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set PARAKEET_BACKEND_TEST_STREAM_MODEL (cache-aware streaming model) and PARAKEET_BACKEND_TEST_WAV")
|
||||
}
|
||||
ensureLibLoaded()
|
||||
|
||||
p := &ParakeetCpp{}
|
||||
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
defer func() { _ = p.Free() }()
|
||||
|
||||
results := make(chan *pb.TranscriptStreamResponse, 64)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- p.AudioTranscriptionStream(context.Background(),
|
||||
&pb.TranscriptRequest{Dst: audioPath}, results)
|
||||
}()
|
||||
|
||||
var deltas []string
|
||||
var final *pb.TranscriptResult
|
||||
for r := range results {
|
||||
if r.Delta != "" {
|
||||
deltas = append(deltas, r.Delta)
|
||||
}
|
||||
if r.FinalResult != nil {
|
||||
final = r.FinalResult
|
||||
}
|
||||
}
|
||||
Expect(<-errCh).ToNot(HaveOccurred())
|
||||
|
||||
Expect(final).ToNot(BeNil(), "expected a closing FinalResult")
|
||||
Expect(strings.TrimSpace(final.Text)).ToNot(BeEmpty(),
|
||||
"expected a non-empty streamed transcript")
|
||||
Expect(final.Segments).ToNot(BeEmpty(),
|
||||
"FinalResult always carries at least one segment")
|
||||
// The concatenated deltas reconstruct the final transcript.
|
||||
Expect(strings.TrimSpace(strings.Join(deltas, ""))).To(Equal(strings.TrimSpace(final.Text)),
|
||||
"deltas should reconstruct the final text")
|
||||
})
|
||||
})
|
||||
})
|
||||
94
backend/go/parakeet-cpp/main.go
Normal file
94
backend/go/parakeet-cpp/main.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package main
|
||||
|
||||
// Started internally by LocalAI - one gRPC server per loaded model.
|
||||
//
|
||||
// Loads libparakeet.so via purego and registers the flat C-API entry
|
||||
// points declared in parakeet_capi.h. The library name can be overridden
|
||||
// with PARAKEET_LIBRARY (mirrors the WHISPER_LIBRARY / VIBEVOICECPP_LIBRARY
|
||||
// convention in the sibling backends); the default looks for the .so next
|
||||
// to this binary.
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("PARAKEET_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libparakeet.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("parakeet-cpp: dlopen %q: %w", libName, err))
|
||||
}
|
||||
|
||||
// Bound 1:1 to parakeet_capi.h. The C-API returns malloc'd char*
|
||||
// buffers from transcribe_*; we register those as uintptr so we get
|
||||
// the raw pointer back and can call parakeet_capi_free_string on it
|
||||
// (purego's string return would copy and forget the original pointer,
|
||||
// leaking it on every call).
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppAbiVersion, "parakeet_capi_abi_version"},
|
||||
{&CppLoad, "parakeet_capi_load"},
|
||||
{&CppFree, "parakeet_capi_free"},
|
||||
{&CppTranscribePath, "parakeet_capi_transcribe_path"},
|
||||
{&CppTranscribePathJSON, "parakeet_capi_transcribe_path_json"},
|
||||
{&CppStreamBegin, "parakeet_capi_stream_begin"},
|
||||
{&CppStreamFeed, "parakeet_capi_stream_feed"},
|
||||
{&CppStreamFinalize, "parakeet_capi_stream_finalize"},
|
||||
{&CppStreamFree, "parakeet_capi_stream_free"},
|
||||
{&CppFreeString, "parakeet_capi_free_string"},
|
||||
{&CppLastError, "parakeet_capi_last_error"},
|
||||
}
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
// The batched-JSON entry point exists only in newer libparakeet.so (ABI >= 2).
|
||||
// Probe with Dlsym and register only if present, so the backend still loads
|
||||
// against an older library (it falls back to per-request transcription).
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
|
||||
// Per-request language variants (multilingual nemotron). Same probe pattern:
|
||||
// present only in libparakeet.so built with multilingual support, so the
|
||||
// backend still loads against an older library and falls back to the
|
||||
// non-lang batched + streaming entry points (model default / "auto").
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json_lang"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSONLang, lib, "parakeet_capi_transcribe_pcm_batch_json_lang")
|
||||
}
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_begin_lang"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamBeginLang, lib, "parakeet_capi_stream_begin_lang")
|
||||
}
|
||||
|
||||
// Streaming JSON entry points (ABI v4): surface per-word timestamps on the
|
||||
// streaming path. Same probe pattern; absent in older libparakeet.so, where
|
||||
// the backend falls back to the text-only streaming feed.
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
|
||||
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[parakeet-cpp] ABI=%d\n", CppAbiVersion())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &ParakeetCpp{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
backend/go/parakeet-cpp/package.sh
Executable file
23
backend/go/parakeet-cpp/package.sh
Executable file
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# L0 packaging stub: copy the binary, run.sh and libparakeet.so* into
|
||||
# package/. The full ldd walk (libc, libstdc++, libgomp, GPU runtimes,
|
||||
# arch detection) lands in L3, mirroring backend/go/whisper/package.sh.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
|
||||
cp -avf "$CURDIR/parakeet-cpp-grpc" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
# libparakeet.so + any soname symlinks (libparakeet.so.X, libparakeet.so.X.Y).
|
||||
cp -avf "$CURDIR"/libparakeet.so* "$CURDIR/package/lib/" 2>/dev/null || {
|
||||
echo "ERROR: libparakeet.so not found in $CURDIR, run 'make' first" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
echo "L0 package layout (full ldd walk lands in L3):"
|
||||
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"
|
||||
16
backend/go/parakeet-cpp/run.sh
Executable file
16
backend/go/parakeet-cpp/run.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
|
||||
|
||||
# If a self-contained ld.so was packaged, route through it so the
|
||||
# packaged libc / libstdc++ are used instead of the host's (matches the
|
||||
# whisper backend's runtime layout).
|
||||
if [ -f "$CURDIR/lib/ld.so" ]; then
|
||||
echo "Using lib/ld.so"
|
||||
exec "$CURDIR/lib/ld.so" "$CURDIR/parakeet-cpp-grpc" "$@"
|
||||
fi
|
||||
|
||||
exec "$CURDIR/parakeet-cpp-grpc" "$@"
|
||||
140
backend/go/parakeet-cpp/segments_test.go
Normal file
140
backend/go/parakeet-cpp/segments_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func tw(text string, start, end float64) transcriptWord {
|
||||
return transcriptWord{W: text, Start: start, End: end}
|
||||
}
|
||||
|
||||
var _ = Describe("splitWordsIntoSegments (NeMo get_segment_offsets parity)", func() {
|
||||
seps := []rune{'.', '?', '!'}
|
||||
|
||||
It("splits on sentence-ending punctuation, including the delimiter word", func() {
|
||||
words := []transcriptWord{tw("hello", 0, 0.4), tw("world.", 0.4, 0.8), tw("bye", 1.0, 1.3)}
|
||||
segs := splitWordsIntoSegments(words, seps, 0)
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2))
|
||||
Expect(segs[0][1].W).To(Equal("world."))
|
||||
Expect(segs[1]).To(HaveLen(1))
|
||||
Expect(segs[1][0].W).To(Equal("bye"))
|
||||
})
|
||||
|
||||
It("keeps a single segment with no terminal punctuation and gap off", func() {
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 0)
|
||||
Expect(segs).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("splits on the gap rule when enabled, the gapped word starting the next segment", func() {
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 1.0) // c is 4.6s after b
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2)) // a b
|
||||
Expect(segs[1]).To(HaveLen(1)) // c
|
||||
Expect(segs[1][0].W).To(Equal("c"))
|
||||
})
|
||||
|
||||
It("checks the gap rule before punctuation (NeMo elif order)", func() {
|
||||
// "b." would terminate, but c is far after it -> gap closes [a b.] at b.
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b.", 0.2, 0.4), tw("c", 9.0, 9.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 1.0)
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2))
|
||||
Expect(segs[1][0].W).To(Equal("c"))
|
||||
})
|
||||
|
||||
It("still splits on punctuation when the gap rule is enabled but does not fire", func() {
|
||||
words := []transcriptWord{tw("hi.", 0, 0.4), tw("bye", 0.4, 0.8)}
|
||||
segs := splitWordsIntoSegments(words, seps, 5.0) // gap never reached
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0][0].W).To(Equal("hi."))
|
||||
})
|
||||
|
||||
It("returns nothing for empty input", func() {
|
||||
Expect(splitWordsIntoSegments(nil, seps, 0)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("transcriptResultFromDoc (multi-segment)", func() {
|
||||
doc := transcriptJSON{
|
||||
Text: "hello world. bye now",
|
||||
FrameSec: 0.08,
|
||||
Words: []transcriptWord{
|
||||
{W: "hello", Start: 0.0, End: 0.4},
|
||||
{W: "world.", Start: 0.4, End: 0.8},
|
||||
{W: "bye", Start: 1.0, End: 1.3},
|
||||
{W: "now", Start: 1.3, End: 1.6},
|
||||
},
|
||||
Tokens: []transcriptToken{{ID: 1, T: 0.1}, {ID: 2, T: 0.5}, {ID: 3, T: 1.1}, {ID: 4, T: 1.4}},
|
||||
}
|
||||
|
||||
It("emits one segment per punctuation-delimited group with start/end", func() {
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments).To(HaveLen(2))
|
||||
Expect(res.Segments[0].Text).To(Equal("hello world."))
|
||||
Expect(res.Segments[0].Start).To(Equal(int64(0)))
|
||||
Expect(res.Segments[0].End).To(Equal(secondsToNanos(0.8)))
|
||||
Expect(res.Segments[1].Text).To(Equal("bye now"))
|
||||
Expect(res.Segments[1].Start).To(Equal(secondsToNanos(1.0)))
|
||||
Expect(res.Segments[1].Id).To(Equal(int32(1)))
|
||||
})
|
||||
|
||||
It("assigns tokens to the segment whose time window contains them", func() {
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments[0].Tokens).To(Equal([]int32{1, 2}))
|
||||
Expect(res.Segments[1].Tokens).To(Equal([]int32{3, 4}))
|
||||
})
|
||||
|
||||
It("attaches per-segment words only when word granularity requested", func() {
|
||||
plain := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(plain.Segments[0].Words).To(BeEmpty())
|
||||
withWords := transcriptResultFromDoc(doc, &pb.TranscriptRequest{TimestampGranularities: []string{"word"}}, 0)
|
||||
Expect(withWords.Segments[0].Words).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("falls back to a single text segment when there are no words", func() {
|
||||
res := transcriptResultFromDoc(transcriptJSON{Text: "hi"}, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments).To(HaveLen(1))
|
||||
Expect(res.Segments[0].Text).To(Equal("hi"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("streaming segment assembly", func() {
|
||||
It("closes a segment with start/end from its words on EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hello world", Eou: 1, Words: []transcriptWord{
|
||||
{W: "hello", Start: 0.0, End: 0.4}, {W: "world", Start: 0.4, End: 0.9},
|
||||
}})
|
||||
segs := acc.segments()
|
||||
Expect(segs).To(HaveLen(1))
|
||||
Expect(segs[0].Text).To(Equal("hello world"))
|
||||
Expect(segs[0].Start).To(Equal(int64(0)))
|
||||
Expect(segs[0].End).To(Equal(secondsToNanos(0.9)))
|
||||
})
|
||||
|
||||
It("buffers words across feeds until EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hi", Eou: 0, Words: []transcriptWord{{W: "hi", Start: 0, End: 0.3}}})
|
||||
Expect(acc.segments()).To(BeEmpty())
|
||||
acc.add(streamFeedJSON{Text: "there", Eou: 1, Words: []transcriptWord{{W: "there", Start: 0.3, End: 0.7}}})
|
||||
Expect(acc.segments()).To(HaveLen(1))
|
||||
Expect(acc.segments()[0].Text).To(Equal("hi there"))
|
||||
})
|
||||
|
||||
// ABI v5 split <EOB> (backchannel) out of the "eou" flag into its own "eob"
|
||||
// field; a backchannel must still close the segment as it did in v4.
|
||||
It("closes a segment on EOB (backchannel) too", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "uh huh", Eou: 0, Eob: 1, Words: []transcriptWord{
|
||||
{W: "uh", Start: 0.0, End: 0.2}, {W: "huh", Start: 0.2, End: 0.5},
|
||||
}})
|
||||
segs := acc.segments()
|
||||
Expect(segs).To(HaveLen(1))
|
||||
Expect(segs[0].Text).To(Equal("uh huh"))
|
||||
Expect(segs[0].End).To(Equal(secondsToNanos(0.5)))
|
||||
})
|
||||
})
|
||||
@@ -3,35 +3,36 @@ project(goqwen3ttscpp LANGUAGES C CXX)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
set(QWEN3TTS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/sources/qwen3-tts.cpp)
|
||||
set(QWENTTS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/sources/qwentts.cpp)
|
||||
|
||||
# Override upstream's CMAKE_CUDA_ARCHITECTURES before add_subdirectory.
|
||||
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
|
||||
set(CMAKE_CUDA_ARCHITECTURES "75-virtual;80-virtual;86-real;89-real")
|
||||
endif()
|
||||
|
||||
# Build ggml from the upstream's submodule FIRST, so that ggml/ggml-base/ggml-cpu
|
||||
# CMake targets exist when the upstream project references them by name.
|
||||
# The upstream CMakeLists.txt uses target_link_libraries(... ggml ggml-base ggml-cpu)
|
||||
# with target_link_directories pointing at a pre-built ggml/build/. By adding ggml
|
||||
# as a subdirectory here, CMake resolves those names as targets instead.
|
||||
add_subdirectory(${QWEN3TTS_DIR}/ggml ggml EXCLUDE_FROM_ALL)
|
||||
# Add the upstream project. Its own CMakeLists adds ggml + cpp-httplib + yyjson
|
||||
# and builds qwen-core (STATIC, the qt_* impl). EXCLUDE_FROM_ALL keeps its CLI
|
||||
# tools / tts-server / tests from building unless referenced.
|
||||
add_subdirectory(${QWENTTS_DIR} qwentts EXCLUDE_FROM_ALL)
|
||||
|
||||
# Now add the upstream project
|
||||
add_subdirectory(${QWEN3TTS_DIR} qwen3tts EXCLUDE_FROM_ALL)
|
||||
# Upstream generates version.h into its own CMAKE_CURRENT_BINARY_DIR and adds
|
||||
# the top-level ${CMAKE_BINARY_DIR} to qwen-core's include path. Under
|
||||
# add_subdirectory those two dirs differ (<build>/qwentts vs <build>), so
|
||||
# qwen.cpp cannot find version.h. Point qwen-core at the subproject binary dir
|
||||
# where version.h is actually generated. (Fix lives here, never in the fetched
|
||||
# upstream checkout.)
|
||||
target_include_directories(qwen-core PRIVATE ${CMAKE_BINARY_DIR}/qwentts)
|
||||
|
||||
add_library(goqwen3ttscpp MODULE cpp/goqwen3ttscpp.cpp)
|
||||
target_link_libraries(goqwen3ttscpp PRIVATE qwen3_tts)
|
||||
target_link_libraries(goqwen3ttscpp PRIVATE qwen-core)
|
||||
|
||||
target_include_directories(goqwen3ttscpp PRIVATE ${QWEN3TTS_DIR}/src)
|
||||
target_include_directories(goqwen3ttscpp SYSTEM PRIVATE ${QWEN3TTS_DIR}/ggml/include)
|
||||
target_include_directories(goqwen3ttscpp PRIVATE ${QWENTTS_DIR}/src)
|
||||
target_include_directories(goqwen3ttscpp SYSTEM PRIVATE ${QWENTTS_DIR}/ggml/include)
|
||||
|
||||
# Link GPU backends if available
|
||||
foreach(backend blas cuda metal vulkan)
|
||||
# Link GPU backends if the upstream ggml created them.
|
||||
foreach(backend blas cuda metal vulkan sycl)
|
||||
if(TARGET ggml-${backend})
|
||||
target_link_libraries(goqwen3ttscpp PRIVATE ggml-${backend})
|
||||
string(TOUPPER ${backend} BACKEND_UPPER)
|
||||
target_compile_definitions(goqwen3ttscpp PRIVATE QWEN3TTS_HAVE_${BACKEND_UPPER})
|
||||
if(backend STREQUAL "cuda")
|
||||
find_package(CUDAToolkit QUIET)
|
||||
if(CUDAToolkit_FOUND)
|
||||
@@ -44,12 +45,8 @@ endforeach()
|
||||
if(MSVC)
|
||||
target_compile_options(goqwen3ttscpp PRIVATE /W4 /wd4100 /wd4505)
|
||||
else()
|
||||
target_compile_options(goqwen3ttscpp PRIVATE -Wall -Wextra -Wshadow -Wconversion
|
||||
-Wno-unused-parameter -Wno-unused-function -Wno-sign-conversion)
|
||||
endif()
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
|
||||
target_link_libraries(goqwen3ttscpp PRIVATE stdc++fs)
|
||||
target_compile_options(goqwen3ttscpp PRIVATE -Wall -Wextra
|
||||
-Wno-unused-parameter -Wno-unused-function)
|
||||
endif()
|
||||
|
||||
set_property(TARGET goqwen3ttscpp PROPERTY CXX_STANDARD 17)
|
||||
|
||||
@@ -6,9 +6,9 @@ GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# qwen3-tts.cpp version
|
||||
QWEN3TTS_REPO?=https://github.com/predict-woo/qwen3-tts.cpp
|
||||
QWEN3TTS_CPP_VERSION?=7a762e2ad4bacc6fdda81d81bf10a09ffb546f29
|
||||
# qwentts.cpp version
|
||||
QWEN3TTS_REPO?=https://github.com/ServeurpersoCom/qwentts.cpp
|
||||
QWEN3TTS_CPP_VERSION?=0bf4a18b22e8bb8718d95294e9f7f45c0d4270a4
|
||||
SO_TARGET?=libgoqwen3ttscpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
@@ -49,9 +49,9 @@ ifeq ($(BUILD_TYPE),sycl_f32)
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
endif
|
||||
|
||||
sources/qwen3-tts.cpp:
|
||||
mkdir -p sources/qwen3-tts.cpp
|
||||
cd sources/qwen3-tts.cpp && \
|
||||
sources/qwentts.cpp:
|
||||
mkdir -p sources/qwentts.cpp
|
||||
cd sources/qwentts.cpp && \
|
||||
git init && \
|
||||
git remote add origin $(QWEN3TTS_REPO) && \
|
||||
git fetch origin && \
|
||||
@@ -78,7 +78,7 @@ package: qwen3-tts-cpp
|
||||
build: package
|
||||
|
||||
clean: purge
|
||||
rm -rf libgoqwen3ttscpp*.so package sources/qwen3-tts.cpp qwen3-tts-cpp
|
||||
rm -rf libgoqwen3ttscpp*.so package sources/qwentts.cpp qwen3-tts-cpp
|
||||
|
||||
purge:
|
||||
rm -rf build*
|
||||
@@ -88,24 +88,24 @@ purge:
|
||||
|
||||
# Build all variants (Linux only)
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
libgoqwen3ttscpp-avx.so: sources/qwen3-tts.cpp
|
||||
libgoqwen3ttscpp-avx.so: sources/qwentts.cpp
|
||||
$(info ${GREEN}I qwen3-tts-cpp build info:avx${RESET})
|
||||
SO_TARGET=libgoqwen3ttscpp-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgoqwen3ttscpp-custom
|
||||
rm -rf build-libgoqwen3ttscpp-avx.so
|
||||
|
||||
libgoqwen3ttscpp-avx2.so: sources/qwen3-tts.cpp
|
||||
libgoqwen3ttscpp-avx2.so: sources/qwentts.cpp
|
||||
$(info ${GREEN}I qwen3-tts-cpp build info:avx2${RESET})
|
||||
SO_TARGET=libgoqwen3ttscpp-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgoqwen3ttscpp-custom
|
||||
rm -rf build-libgoqwen3ttscpp-avx2.so
|
||||
|
||||
libgoqwen3ttscpp-avx512.so: sources/qwen3-tts.cpp
|
||||
libgoqwen3ttscpp-avx512.so: sources/qwentts.cpp
|
||||
$(info ${GREEN}I qwen3-tts-cpp build info:avx512${RESET})
|
||||
SO_TARGET=libgoqwen3ttscpp-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgoqwen3ttscpp-custom
|
||||
rm -rf build-libgoqwen3ttscpp-avx512.so
|
||||
endif
|
||||
|
||||
# Build fallback variant (all platforms)
|
||||
libgoqwen3ttscpp-fallback.so: sources/qwen3-tts.cpp
|
||||
libgoqwen3ttscpp-fallback.so: sources/qwentts.cpp
|
||||
$(info ${GREEN}I qwen3-tts-cpp build info:fallback${RESET})
|
||||
SO_TARGET=libgoqwen3ttscpp-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgoqwen3ttscpp-custom
|
||||
rm -rf build-libgoqwen3ttscpp-fallback.so
|
||||
|
||||
128
backend/go/qwen3-tts-cpp/audio.go
Normal file
128
backend/go/qwen3-tts-cpp/audio.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
)
|
||||
|
||||
const qwen3ttsSampleRate = 24000
|
||||
|
||||
// wavHeader24k returns a 44-byte WAV header for a streaming 24 kHz mono 16-bit
|
||||
// PCM stream, with placeholder (0xFFFFFFFF) sizes since the total length is
|
||||
// unknown up front. Emitted as the first chunk of TTSStream so the HTTP layer
|
||||
// receives a self-describing WAV (the gRPC TTSStream path never sets Message,
|
||||
// so the backend owns the header - see core/backend/tts.go:ModelTTSStream).
|
||||
func wavHeader24k() []byte {
|
||||
var buf bytes.Buffer
|
||||
w := func(v any) { _ = binary.Write(&buf, binary.LittleEndian, v) }
|
||||
buf.WriteString("RIFF")
|
||||
w(uint32(0xFFFFFFFF))
|
||||
buf.WriteString("WAVE")
|
||||
buf.WriteString("fmt ")
|
||||
w(uint32(16)) // Subchunk1Size
|
||||
w(uint16(1)) // PCM
|
||||
w(uint16(1)) // mono
|
||||
w(uint32(qwen3ttsSampleRate)) // sample rate
|
||||
w(uint32(qwen3ttsSampleRate * 2)) // byte rate = SR * blockAlign
|
||||
w(uint16(2)) // block align (16-bit mono)
|
||||
w(uint16(16)) // bits per sample
|
||||
buf.WriteString("data")
|
||||
w(uint32(0xFFFFFFFF))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// floatToPCM16LE clamps each sample to [-1,1] and encodes it as little-endian
|
||||
// signed 16-bit PCM.
|
||||
func floatToPCM16LE(samples []float32) []byte {
|
||||
out := make([]byte, len(samples)*2)
|
||||
for i, s := range samples {
|
||||
if s > 1 {
|
||||
s = 1
|
||||
} else if s < -1 {
|
||||
s = -1
|
||||
}
|
||||
v := int16(s * 32767)
|
||||
out[i*2] = byte(v)
|
||||
out[i*2+1] = byte(v >> 8)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// writeWAV24k writes samples as a finalized 24 kHz mono 16-bit WAV at dst.
|
||||
func writeWAV24k(dst string, samples []float32) error {
|
||||
f, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("qwen3-tts: create %q: %w", dst, err)
|
||||
}
|
||||
enc := wav.NewEncoder(f, qwen3ttsSampleRate, 16, 1, 1)
|
||||
ints := make([]int, len(samples))
|
||||
for i, s := range samples {
|
||||
if s > 1 {
|
||||
s = 1
|
||||
} else if s < -1 {
|
||||
s = -1
|
||||
}
|
||||
ints[i] = int(s * 32767)
|
||||
}
|
||||
b := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: qwen3ttsSampleRate},
|
||||
Data: ints,
|
||||
SourceBitDepth: 16,
|
||||
}
|
||||
if err := enc.Write(b); err != nil {
|
||||
_ = enc.Close()
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("qwen3-tts: encode WAV: %w", err)
|
||||
}
|
||||
if err := enc.Close(); err != nil {
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("qwen3-tts: finalize WAV: %w", err)
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
|
||||
// readWAVAsFloat decodes a WAV file (any sample rate/channels) to a mono
|
||||
// float32 slice in [-1,1] for use as cloning reference audio. qwentts expects
|
||||
// 24 kHz; callers should supply 24 kHz reference clips.
|
||||
func readWAVAsFloat(path string) ([]float32, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qwen3-tts: open ref %q: %w", path, err)
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
dec := wav.NewDecoder(f)
|
||||
buf, err := dec.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qwen3-tts: decode ref %q: %w", path, err)
|
||||
}
|
||||
ch := int(buf.Format.NumChannels)
|
||||
if ch < 1 {
|
||||
ch = 1
|
||||
}
|
||||
bitDepth := int(buf.SourceBitDepth)
|
||||
if bitDepth == 0 {
|
||||
bitDepth = 16
|
||||
}
|
||||
scale := float32(int64(1) << uint(bitDepth-1))
|
||||
n := len(buf.Data) / ch
|
||||
out := make([]float32, n)
|
||||
for i := 0; i < n; i++ {
|
||||
var acc int
|
||||
for c := 0; c < ch; c++ {
|
||||
acc += buf.Data[i*ch+c]
|
||||
}
|
||||
out[i] = float32(acc) / float32(ch) / scale
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// runtimeKeepAlive prevents the GC from reclaiming the reference-audio slice
|
||||
// while its backing pointer is in use across the C call.
|
||||
func runtimeKeepAlive(v any) { runtime.KeepAlive(v) }
|
||||
54
backend/go/qwen3-tts-cpp/audiopath_test.go
Normal file
54
backend/go/qwen3-tts-cpp/audiopath_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// These specs pin the voice-selection logic in resolveRequest, in particular
|
||||
// the config-level audio_path (tts.audio_path -> ModelOptions.AudioPath) being
|
||||
// used as the default voice-cloning reference. No model/C library is needed:
|
||||
// resolveRequest only reads the reference WAV via readWAVAsFloat (pure Go).
|
||||
var _ = Describe("resolveRequest voice/clone selection", func() {
|
||||
var dir, refWav string
|
||||
|
||||
BeforeEach(func() {
|
||||
dir = GinkgoT().TempDir()
|
||||
refWav = filepath.Join(dir, "ref.wav")
|
||||
// 0.5s of non-silent 24kHz mono audio as a clone reference.
|
||||
samples := make([]float32, qwen3ttsSampleRate/2)
|
||||
for i := range samples {
|
||||
samples[i] = 0.1
|
||||
}
|
||||
Expect(writeWAV24k(refWav, samples)).To(Succeed())
|
||||
})
|
||||
|
||||
It("uses the config audio_path as the clone reference when Voice is empty", func() {
|
||||
q := &Qwen3TtsCpp{audioPath: refWav}
|
||||
_, _, speaker, _, ref, _, err := q.resolveRequest(&pb.TTSRequest{Text: "hi"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(speaker).To(BeEmpty())
|
||||
Expect(len(ref)).To(Equal(qwen3ttsSampleRate / 2))
|
||||
})
|
||||
|
||||
It("lets a per-request audio Voice override audio_path", func() {
|
||||
other := filepath.Join(dir, "other.wav")
|
||||
Expect(writeWAV24k(other, make([]float32, 100))).To(Succeed())
|
||||
q := &Qwen3TtsCpp{audioPath: refWav}
|
||||
_, _, speaker, _, ref, _, err := q.resolveRequest(&pb.TTSRequest{Text: "hi", Voice: other})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(speaker).To(BeEmpty())
|
||||
Expect(len(ref)).To(Equal(100))
|
||||
})
|
||||
|
||||
It("does not trigger audio_path cloning for a named-speaker Voice", func() {
|
||||
q := &Qwen3TtsCpp{audioPath: refWav}
|
||||
_, _, speaker, _, ref, _, err := q.resolveRequest(&pb.TTSRequest{Text: "hi", Voice: "serena"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(speaker).To(Equal("serena"))
|
||||
Expect(ref).To(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -1,161 +1,191 @@
|
||||
#include "goqwen3ttscpp.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "qwen3_tts.h"
|
||||
#include "qwen.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
|
||||
using namespace qwen3_tts;
|
||||
static qt_context *g_ctx = nullptr;
|
||||
|
||||
// Global engine (loaded once, reused across requests)
|
||||
static Qwen3TTS *g_engine = nullptr;
|
||||
static bool g_loaded = false;
|
||||
static int g_threads = 4;
|
||||
|
||||
static void ggml_log_cb(enum ggml_log_level level, const char *log, void *data) {
|
||||
const char *level_str;
|
||||
static void ggml_log_cb(enum ggml_log_level level, const char *log,
|
||||
void * /*data*/) {
|
||||
if (!log)
|
||||
return;
|
||||
const char *lvl = "?????";
|
||||
switch (level) {
|
||||
case GGML_LOG_LEVEL_DEBUG:
|
||||
level_str = "DEBUG";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_INFO:
|
||||
level_str = "INFO";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_WARN:
|
||||
level_str = "WARN";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_ERROR:
|
||||
level_str = "ERROR";
|
||||
break;
|
||||
default:
|
||||
level_str = "?????";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_DEBUG: lvl = "DEBUG"; break;
|
||||
case GGML_LOG_LEVEL_INFO: lvl = "INFO"; break;
|
||||
case GGML_LOG_LEVEL_WARN: lvl = "WARN"; break;
|
||||
case GGML_LOG_LEVEL_ERROR: lvl = "ERROR"; break;
|
||||
default: break;
|
||||
}
|
||||
fprintf(stderr, "[%-5s] ", level_str);
|
||||
fputs(log, stderr);
|
||||
fprintf(stderr, "[%-5s] %s", lvl, log);
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
// Map language string to language_id token used by the model
|
||||
static int language_to_id(const char *lang) {
|
||||
if (!lang || lang[0] == '\0')
|
||||
return 2050; // default: English
|
||||
std::string l(lang);
|
||||
if (l == "en")
|
||||
return 2050;
|
||||
if (l == "ru")
|
||||
return 2069;
|
||||
if (l == "zh")
|
||||
return 2055;
|
||||
if (l == "ja")
|
||||
return 2058;
|
||||
if (l == "ko")
|
||||
return 2064;
|
||||
if (l == "de")
|
||||
return 2053;
|
||||
if (l == "fr")
|
||||
return 2061;
|
||||
if (l == "es")
|
||||
return 2054;
|
||||
if (l == "it")
|
||||
return 2056;
|
||||
if (l == "pt")
|
||||
return 2057;
|
||||
fprintf(stderr, "[qwen3-tts-cpp] Unknown language '%s', defaulting to English\n",
|
||||
lang);
|
||||
return 2050;
|
||||
}
|
||||
|
||||
int load_model(const char *model_dir, int n_threads) {
|
||||
int qt3_load(const char *talker_path, const char *codec_path, int use_fa,
|
||||
int clamp_fp16) {
|
||||
ggml_log_set(ggml_log_cb, nullptr);
|
||||
ggml_backend_load_all();
|
||||
|
||||
if (n_threads <= 0)
|
||||
n_threads = 4;
|
||||
g_threads = n_threads;
|
||||
|
||||
fprintf(stderr, "[qwen3-tts-cpp] Loading models from %s (threads=%d)\n",
|
||||
model_dir, n_threads);
|
||||
|
||||
g_engine = new Qwen3TTS();
|
||||
if (!g_engine->load_models(model_dir)) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] FATAL: failed to load models from %s\n",
|
||||
model_dir);
|
||||
delete g_engine;
|
||||
g_engine = nullptr;
|
||||
if (!talker_path || talker_path[0] == '\0') {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: talker_path is required\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
g_loaded = true;
|
||||
fprintf(stderr, "[qwen3-tts-cpp] Models loaded successfully\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
int synthesize(const char *text, const char *ref_audio_path, const char *dst,
|
||||
const char *language, float temperature, float top_p,
|
||||
int top_k, float repetition_penalty, int max_audio_tokens,
|
||||
int n_threads) {
|
||||
if (!g_loaded || !g_engine) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: models not loaded\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!text || !dst) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: text and dst are required\n");
|
||||
if (!codec_path || codec_path[0] == '\0') {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: codec_path is required\n");
|
||||
return 2;
|
||||
}
|
||||
|
||||
tts_params params;
|
||||
params.max_audio_tokens = max_audio_tokens > 0 ? max_audio_tokens : 4096;
|
||||
params.temperature = temperature;
|
||||
params.top_p = top_p;
|
||||
params.top_k = top_k;
|
||||
params.repetition_penalty = repetition_penalty;
|
||||
params.n_threads = n_threads > 0 ? n_threads : g_threads;
|
||||
params.language_id = language_to_id(language);
|
||||
qt_init_params p;
|
||||
qt_init_default_params(&p);
|
||||
p.talker_path = talker_path;
|
||||
p.codec_path = codec_path;
|
||||
p.use_fa = use_fa != 0;
|
||||
p.clamp_fp16 = clamp_fp16 != 0;
|
||||
|
||||
fprintf(stderr, "[qwen3-tts-cpp] Synthesizing: text='%.50s%s', lang_id=%d, "
|
||||
"temp=%.2f, threads=%d\n",
|
||||
text, (strlen(text) > 50 ? "..." : ""), params.language_id,
|
||||
temperature, params.n_threads);
|
||||
fprintf(stderr, "[qwen3-tts-cpp] Loading talker=%s codec=%s\n", talker_path,
|
||||
codec_path);
|
||||
|
||||
tts_result result;
|
||||
bool has_ref = ref_audio_path && ref_audio_path[0] != '\0';
|
||||
|
||||
if (has_ref) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] Voice cloning with ref: %s\n",
|
||||
ref_audio_path);
|
||||
result = g_engine->synthesize_with_voice(text, ref_audio_path, params);
|
||||
} else {
|
||||
result = g_engine->synthesize(text, params);
|
||||
}
|
||||
|
||||
if (!result.success) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: synthesis failed: %s\n",
|
||||
result.error_msg.c_str());
|
||||
g_ctx = qt_init(&p);
|
||||
if (!g_ctx) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] FATAL: qt_init failed: %s\n",
|
||||
qt_last_error());
|
||||
return 3;
|
||||
}
|
||||
|
||||
int n_samples = (int)result.audio.size();
|
||||
if (n_samples == 0) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: synthesis produced no samples\n");
|
||||
return 4;
|
||||
}
|
||||
|
||||
fprintf(stderr,
|
||||
"[qwen3-tts-cpp] Synthesis done: %d samples (%.2fs @ 24kHz)\n",
|
||||
n_samples, (float)n_samples / 24000.0f);
|
||||
|
||||
if (!save_audio_file(dst, result.audio, result.sample_rate)) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: failed to write %s\n", dst);
|
||||
return 5;
|
||||
}
|
||||
|
||||
fprintf(stderr, "[qwen3-tts-cpp] Wrote %s\n", dst);
|
||||
fprintf(stderr, "[qwen3-tts-cpp] Model loaded (%s)\n", qt_version());
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Fill a qt_tts_params from the flat wrapper arguments. Unset/zero scalars keep
|
||||
// the qt defaults (temperature 0.9, top_k 50, top_p 1.0, rep 1.05, max 2048).
|
||||
static void fill_params(qt_tts_params *tp, const char *text, const char *lang,
|
||||
const char *instruct, const char *speaker,
|
||||
const float *ref_samples, int ref_n,
|
||||
const char *ref_text, long long seed, float temperature,
|
||||
int top_k, float top_p, float repetition_penalty,
|
||||
int max_new_tokens) {
|
||||
qt_tts_default_params(tp);
|
||||
tp->text = text ? text : "";
|
||||
if (lang && lang[0] != '\0')
|
||||
tp->lang = lang; // else keep default NULL -> auto
|
||||
if (instruct && instruct[0] != '\0')
|
||||
tp->instruct = instruct;
|
||||
if (speaker && speaker[0] != '\0')
|
||||
tp->speaker = speaker;
|
||||
if (ref_samples && ref_n > 0) {
|
||||
tp->ref_audio_24k = ref_samples;
|
||||
tp->ref_n_samples = ref_n;
|
||||
if (ref_text && ref_text[0] != '\0')
|
||||
tp->ref_text = ref_text;
|
||||
}
|
||||
if (seed >= 0)
|
||||
tp->seed = (int64_t)seed; // else default -1 (random)
|
||||
if (temperature > 0.0f)
|
||||
tp->temperature = temperature;
|
||||
if (top_k > 0)
|
||||
tp->top_k = top_k;
|
||||
if (top_p > 0.0f)
|
||||
tp->top_p = top_p;
|
||||
if (repetition_penalty > 0.0f)
|
||||
tp->repetition_penalty = repetition_penalty;
|
||||
if (max_new_tokens > 0)
|
||||
tp->max_new_tokens = max_new_tokens;
|
||||
}
|
||||
|
||||
float *qt3_tts(const char *text, const char *lang, const char *instruct,
|
||||
const char *speaker, const float *ref_samples, int ref_n,
|
||||
const char *ref_text, long long seed, float temperature,
|
||||
int top_k, float top_p, float repetition_penalty,
|
||||
int max_new_tokens, int *out_n) {
|
||||
if (out_n)
|
||||
*out_n = 0;
|
||||
if (!g_ctx) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: model not loaded\n");
|
||||
return nullptr;
|
||||
}
|
||||
if (!text || text[0] == '\0') {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: text is required\n");
|
||||
return nullptr;
|
||||
}
|
||||
qt_tts_params tp;
|
||||
fill_params(&tp, text, lang, instruct, speaker, ref_samples, ref_n,
|
||||
ref_text, seed, temperature, top_k, top_p, repetition_penalty,
|
||||
max_new_tokens);
|
||||
|
||||
qt_audio out = {0};
|
||||
enum qt_status rc = qt_synthesize(g_ctx, &tp, &out);
|
||||
if (rc != QT_STATUS_OK || out.n_samples <= 0 || !out.samples) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: synthesize failed (rc=%d): %s\n",
|
||||
(int)rc, qt_last_error());
|
||||
qt_audio_free(&out);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Copy into a plain malloc buffer the Go side frees via qt3_pcm_free.
|
||||
size_t bytes = (size_t)out.n_samples * sizeof(float);
|
||||
float *buf = (float *)malloc(bytes);
|
||||
if (!buf) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: malloc(%zu) failed\n", bytes);
|
||||
qt_audio_free(&out);
|
||||
return nullptr;
|
||||
}
|
||||
memcpy(buf, out.samples, bytes);
|
||||
if (out_n)
|
||||
*out_n = out.n_samples;
|
||||
qt_audio_free(&out);
|
||||
return buf;
|
||||
}
|
||||
|
||||
int qt3_tts_stream(const char *text, const char *lang, const char *instruct,
|
||||
const char *speaker, const float *ref_samples, int ref_n,
|
||||
const char *ref_text, long long seed, float temperature,
|
||||
int top_k, float top_p, float repetition_penalty,
|
||||
int max_new_tokens, qt3_chunk_cb cb, void *user_data) {
|
||||
if (!g_ctx) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: model not loaded\n");
|
||||
return 1;
|
||||
}
|
||||
if (!cb) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: stream callback is null\n");
|
||||
return 2;
|
||||
}
|
||||
if (!text || text[0] == '\0') {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: text is required\n");
|
||||
return 4;
|
||||
}
|
||||
qt_tts_params tp;
|
||||
fill_params(&tp, text, lang, instruct, speaker, ref_samples, ref_n,
|
||||
ref_text, seed, temperature, top_k, top_p, repetition_penalty,
|
||||
max_new_tokens);
|
||||
// qt_audio_chunk_cb has the identical signature to qt3_chunk_cb
|
||||
// (bool vs int return are ABI-compatible; non-zero == true).
|
||||
tp.on_chunk = (qt_audio_chunk_cb)cb;
|
||||
tp.on_chunk_user_data = user_data;
|
||||
|
||||
qt_audio out = {0}; // stays empty in streaming mode
|
||||
enum qt_status rc = qt_synthesize(g_ctx, &tp, &out);
|
||||
qt_audio_free(&out);
|
||||
if (rc != QT_STATUS_OK && rc != QT_STATUS_CANCELLED) {
|
||||
fprintf(stderr, "[qwen3-tts-cpp] ERROR: stream synth failed (rc=%d): %s\n",
|
||||
(int)rc, qt_last_error());
|
||||
return 3;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
void qt3_pcm_free(float *p) { free(p); }
|
||||
|
||||
void qt3_unload(void) {
|
||||
if (g_ctx) {
|
||||
qt_free(g_ctx);
|
||||
g_ctx = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
int qt3_n_speakers(void) { return g_ctx ? qt_n_speakers(g_ctx) : 0; }
|
||||
|
||||
const char *qt3_speaker_name(int i) {
|
||||
return g_ctx ? qt_speaker_name(g_ctx, i) : nullptr;
|
||||
}
|
||||
|
||||
@@ -1,12 +1,47 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
extern "C" {
|
||||
int load_model(const char *model_dir, int n_threads);
|
||||
int synthesize(const char *text, const char *ref_audio_path, const char *dst,
|
||||
const char *language, float temperature, float top_p,
|
||||
int top_k, float repetition_penalty, int max_audio_tokens,
|
||||
int n_threads);
|
||||
|
||||
// Streaming PCM chunk callback. samples is mono float PCM at 24 kHz, valid
|
||||
// only for the duration of the call. Return non-zero to continue, 0 to abort.
|
||||
typedef int (*qt3_chunk_cb)(const float *samples, int n_samples,
|
||||
void *user_data);
|
||||
|
||||
// Load the talker + codec/tokenizer GGUFs. use_fa / clamp_fp16 map to
|
||||
// qt_init_params (the qt ABI exposes no thread count; ggml uses its own
|
||||
// default). Returns 0 on success, non-zero on failure.
|
||||
int qt3_load(const char *talker_path, const char *codec_path, int use_fa,
|
||||
int clamp_fp16);
|
||||
|
||||
// Synthesize to a malloc'd float PCM buffer (caller frees via qt3_pcm_free).
|
||||
// The synthesis mode (base / custom_voice / voice_design) is auto-detected by
|
||||
// qt from the talker GGUF; speaker is honoured only for custom_voice, instruct
|
||||
// for voice_design / custom_voice, and ref_samples (+ optional ref_text) drive
|
||||
// base-mode cloning. qt enforces the rules and we surface qt_last_error() on
|
||||
// QT_STATUS_MODE_INVALID. Writes the sample count to *out_n. Returns NULL on
|
||||
// failure (out_n set to 0).
|
||||
float *qt3_tts(const char *text, const char *lang, const char *instruct,
|
||||
const char *speaker, const float *ref_samples, int ref_n,
|
||||
const char *ref_text, long long seed, float temperature,
|
||||
int top_k, float top_p, float repetition_penalty,
|
||||
int max_new_tokens, int *out_n);
|
||||
|
||||
// Streaming synthesis: cb is invoked per PCM chunk as audio is produced. Same
|
||||
// param semantics as qt3_tts. Returns 0 on success.
|
||||
int qt3_tts_stream(const char *text, const char *lang, const char *instruct,
|
||||
const char *speaker, const float *ref_samples, int ref_n,
|
||||
const char *ref_text, long long seed, float temperature,
|
||||
int top_k, float top_p, float repetition_penalty,
|
||||
int max_new_tokens, qt3_chunk_cb cb, void *user_data);
|
||||
|
||||
// Free a buffer returned by qt3_tts.
|
||||
void qt3_pcm_free(float *p);
|
||||
|
||||
// Release the qt context.
|
||||
void qt3_unload(void);
|
||||
|
||||
// Named-speaker introspection (custom_voice models). Returns 0 / NULL when no
|
||||
// model is loaded or the index is out of range.
|
||||
int qt3_n_speakers(void);
|
||||
const char *qt3_speaker_name(int i);
|
||||
}
|
||||
|
||||
95
backend/go/qwen3-tts-cpp/e2e_test.go
Normal file
95
backend/go/qwen3-tts-cpp/e2e_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func ttsReq(text, voice string, lang *string, dst string) *pb.TTSRequest {
|
||||
return &pb.TTSRequest{Text: text, Voice: voice, Language: lang, Dst: dst}
|
||||
}
|
||||
|
||||
var _ = Describe("qwen3-tts-cpp e2e", Label("e2e"), func() {
|
||||
var loaded bool
|
||||
|
||||
BeforeEach(func() {
|
||||
modelPath := os.Getenv("QWEN3TTS_MODEL")
|
||||
codecPath := os.Getenv("QWEN3TTS_CODEC")
|
||||
if modelPath == "" || codecPath == "" {
|
||||
Skip("QWEN3TTS_MODEL / QWEN3TTS_CODEC not set; skipping e2e")
|
||||
}
|
||||
if !loaded {
|
||||
lib := os.Getenv("QWEN3TTS_LIBRARY")
|
||||
if lib == "" {
|
||||
lib = "./libgoqwen3ttscpp-fallback.so"
|
||||
}
|
||||
h, err := purego.Dlopen(lib, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
purego.RegisterLibFunc(&CppLoad, h, "qt3_load")
|
||||
purego.RegisterLibFunc(&CppTTS, h, "qt3_tts")
|
||||
purego.RegisterLibFunc(&CppTTSStream, h, "qt3_tts_stream")
|
||||
purego.RegisterLibFunc(&CppPCMFree, h, "qt3_pcm_free")
|
||||
purego.RegisterLibFunc(&CppUnload, h, "qt3_unload")
|
||||
Expect(CppLoad(modelPath, codecPath, 1, 0)).To(Equal(0))
|
||||
loaded = true
|
||||
}
|
||||
})
|
||||
|
||||
It("synthesizes a WAV file via TTS", func() {
|
||||
b := &Qwen3TtsCpp{opts: loadOptions{seed: 42, useFA: true}}
|
||||
dst := GinkgoT().TempDir() + "/out.wav"
|
||||
lang := "english"
|
||||
err := b.TTS(ttsReq("Hello world.", "", &lang, dst))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
fi, err := os.Stat(dst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(fi.Size()).To(BeNumerically(">", int64(44)))
|
||||
})
|
||||
|
||||
It("streams audio chunks via TTSStream", func() {
|
||||
b := &Qwen3TtsCpp{opts: loadOptions{seed: 42, useFA: true}}
|
||||
results := make(chan []byte, 1024)
|
||||
lang := "english"
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- b.TTSStream(ttsReq("Hello there, streaming test.", "", &lang, ""), results) }()
|
||||
|
||||
var chunks int
|
||||
var first []byte
|
||||
for c := range results {
|
||||
if chunks == 0 {
|
||||
first = c
|
||||
}
|
||||
chunks++
|
||||
}
|
||||
Expect(<-done).ToNot(HaveOccurred())
|
||||
Expect(chunks).To(BeNumerically(">=", 2))
|
||||
Expect(string(first[0:4])).To(Equal("RIFF"))
|
||||
Expect(strings.HasPrefix(string(first[8:12]), "WAVE")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("clones a voice from the config audio_path reference", func() {
|
||||
// 1s of 24kHz mono audio as a clone reference; the base model carries
|
||||
// a speaker encoder, so audio_path drives x-vector voice cloning.
|
||||
ref := GinkgoT().TempDir() + "/ref.wav"
|
||||
samples := make([]float32, qwen3ttsSampleRate)
|
||||
for i := range samples {
|
||||
samples[i] = float32(0.05 * math.Sin(float64(i)*0.06))
|
||||
}
|
||||
Expect(writeWAV24k(ref, samples)).To(Succeed())
|
||||
|
||||
b := &Qwen3TtsCpp{opts: loadOptions{seed: 42, useFA: true}, audioPath: ref}
|
||||
dst := GinkgoT().TempDir() + "/clone.wav"
|
||||
lang := "english"
|
||||
// Empty Voice -> the config audio_path is used as the clone reference.
|
||||
Expect(b.TTS(ttsReq("Cloned voice test.", "", &lang, dst))).To(Succeed())
|
||||
fi, err := os.Stat(dst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(fi.Size()).To(BeNumerically(">", int64(44)))
|
||||
})
|
||||
})
|
||||
@@ -4,71 +4,226 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
CppLoadModel func(modelDir string, nThreads int) int
|
||||
CppSynthesize func(text, refAudioPath, dst, language string,
|
||||
temperature, topP float32, topK int,
|
||||
repetitionPenalty float32, maxAudioTokens, nThreads int) int
|
||||
// qt3_load(talker_path, codec_path, use_fa, clamp_fp16) int
|
||||
CppLoad func(talkerPath, codecPath string, useFA, clampFP16 int) int
|
||||
// qt3_tts(text, lang, instruct, speaker, ref_samples, ref_n, ref_text,
|
||||
// seed, temperature, top_k, top_p, rep_pen, max_new, out_n) -> float*
|
||||
CppTTS func(text, lang, instruct, speaker string, refSamples unsafe.Pointer,
|
||||
refN int, refText string, seed int64, temperature float32, topK int,
|
||||
topP, repPen float32, maxNew int, outN unsafe.Pointer) uintptr
|
||||
// qt3_tts_stream(..., cb, user) int
|
||||
CppTTSStream func(text, lang, instruct, speaker string, refSamples unsafe.Pointer,
|
||||
refN int, refText string, seed int64, temperature float32, topK int,
|
||||
topP, repPen float32, maxNew int, cb uintptr, user uintptr) int
|
||||
CppPCMFree func(ptr uintptr)
|
||||
CppUnload func()
|
||||
)
|
||||
|
||||
type Qwen3TtsCpp struct {
|
||||
base.SingleThread
|
||||
threads int
|
||||
opts loadOptions
|
||||
// audioPath is the model-config reference voice (tts.audio_path), the
|
||||
// default clone reference when a request omits an audio Voice.
|
||||
audioPath string
|
||||
}
|
||||
|
||||
func (q *Qwen3TtsCpp) Load(opts *pb.ModelOptions) error {
|
||||
// ModelFile is the model directory path (containing GGUF files)
|
||||
modelDir := opts.ModelFile
|
||||
if modelDir == "" {
|
||||
modelDir = opts.ModelPath
|
||||
model := opts.ModelFile
|
||||
if model == "" {
|
||||
model = opts.ModelPath
|
||||
}
|
||||
if !filepath.IsAbs(model) && opts.ModelPath != "" {
|
||||
model = filepath.Join(opts.ModelPath, model)
|
||||
}
|
||||
|
||||
// Resolve relative paths
|
||||
if !filepath.IsAbs(modelDir) && opts.ModelPath != "" {
|
||||
modelDir = filepath.Join(opts.ModelPath, modelDir)
|
||||
q.opts = parseOptions(opts.Options)
|
||||
|
||||
// Resolve the codec/tokenizer GGUF: explicit option, else auto-discover a
|
||||
// *tokenizer*.gguf sibling of the talker model.
|
||||
codec := q.opts.codecPath
|
||||
if codec != "" && !filepath.IsAbs(codec) {
|
||||
codec = filepath.Join(filepath.Dir(model), codec)
|
||||
}
|
||||
if codec == "" {
|
||||
codec = discoverTokenizer(filepath.Dir(model))
|
||||
}
|
||||
if codec == "" {
|
||||
return fmt.Errorf("qwen3-tts: no codec/tokenizer GGUF found; set option 'tokenizer:<file>'")
|
||||
}
|
||||
q.opts.codecPath = codec
|
||||
|
||||
q.audioPath = opts.AudioPath
|
||||
if q.audioPath != "" && !filepath.IsAbs(q.audioPath) {
|
||||
q.audioPath = filepath.Join(filepath.Dir(model), q.audioPath)
|
||||
}
|
||||
|
||||
threads := int(opts.Threads)
|
||||
if threads <= 0 {
|
||||
threads = 4
|
||||
useFA := boolToInt(q.opts.useFA)
|
||||
clamp := boolToInt(q.opts.clampFP16)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[qwen3-tts-cpp] Load talker=%s codec=%s use_fa=%d clamp_fp16=%d\n",
|
||||
model, codec, useFA, clamp)
|
||||
|
||||
if rc := CppLoad(model, codec, useFA, clamp); rc != 0 {
|
||||
return fmt.Errorf("qwen3-tts: failed to load model (rc=%d)", rc)
|
||||
}
|
||||
q.threads = threads
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[qwen3-tts-cpp] Loading models from: %s (threads=%d)\n", modelDir, threads)
|
||||
|
||||
if ret := CppLoadModel(modelDir, threads); ret != 0 {
|
||||
return fmt.Errorf("failed to load qwen3-tts model (error code: %d)", ret)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// discoverTokenizer returns the first *tokenizer*.gguf in dir, or "".
|
||||
func discoverTokenizer(dir string) string {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, e := range entries {
|
||||
name := strings.ToLower(e.Name())
|
||||
if strings.Contains(name, "tokenizer") && strings.HasSuffix(name, ".gguf") {
|
||||
return filepath.Join(dir, e.Name())
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func boolToInt(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func optStr(p *string) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
return *p
|
||||
}
|
||||
|
||||
// resolveRequest derives the synthesis inputs from a TTSRequest:
|
||||
// language, instruct, speaker, ref-audio samples, ref-text and sampling.
|
||||
func (q *Qwen3TtsCpp) resolveRequest(req *pb.TTSRequest) (lang, instruct, speaker, refText string, ref []float32, s sampling, err error) {
|
||||
lang = normalizeLanguage(optStr(req.Language))
|
||||
instruct = optStr(req.Instructions)
|
||||
|
||||
var refPath string
|
||||
speaker, refPath = resolveVoice(req.Voice)
|
||||
if refPath == "" && speaker == "" && q.audioPath != "" {
|
||||
// No per-request voice: fall back to the config clone reference.
|
||||
refPath = q.audioPath
|
||||
}
|
||||
if refPath != "" {
|
||||
ref, err = readWAVAsFloat(refPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if req.Params != nil {
|
||||
refText = req.Params["ref_text"]
|
||||
}
|
||||
s = parseSampling(req.Params, q.opts.seed)
|
||||
return
|
||||
}
|
||||
|
||||
func (q *Qwen3TtsCpp) TTS(req *pb.TTSRequest) error {
|
||||
text := req.Text
|
||||
voice := req.Voice // reference audio path for voice cloning (empty = no cloning)
|
||||
dst := req.Dst
|
||||
language := ""
|
||||
if req.Language != nil {
|
||||
language = *req.Language
|
||||
if req.Dst == "" {
|
||||
return fmt.Errorf("qwen3-tts: TTS requires a destination path")
|
||||
}
|
||||
if req.Text == "" {
|
||||
return fmt.Errorf("qwen3-tts: TTS requires text")
|
||||
}
|
||||
lang, instruct, speaker, refText, ref, s, err := q.resolveRequest(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var refPtr unsafe.Pointer
|
||||
if len(ref) > 0 {
|
||||
refPtr = unsafe.Pointer(&ref[0])
|
||||
}
|
||||
|
||||
// Synthesis parameters with sensible defaults
|
||||
temperature := float32(0.9)
|
||||
topP := float32(0.8)
|
||||
topK := 50
|
||||
repetitionPenalty := float32(1.05)
|
||||
maxAudioTokens := 4096
|
||||
var n int32
|
||||
ptr := CppTTS(req.Text, lang, instruct, speaker, refPtr, len(ref), refText,
|
||||
s.seed, s.temperature, s.topK, s.topP, s.repPen, s.maxNew, unsafe.Pointer(&n))
|
||||
runtimeKeepAlive(ref)
|
||||
if ptr == 0 {
|
||||
return fmt.Errorf("qwen3-tts: synthesis failed")
|
||||
}
|
||||
// Register the free as soon as we own a non-null buffer, so the n<=0 guard
|
||||
// below cannot leak it (defensive: the C contract returns NULL on failure).
|
||||
defer CppPCMFree(ptr)
|
||||
if n <= 0 {
|
||||
return fmt.Errorf("qwen3-tts: synthesis produced no samples")
|
||||
}
|
||||
src := unsafe.Slice((*float32)(unsafe.Pointer(ptr)), int(n)) //nolint:govet // C-allocated PCM, copied out before free
|
||||
out := make([]float32, int(n))
|
||||
copy(out, src)
|
||||
return writeWAV24k(req.Dst, out)
|
||||
}
|
||||
|
||||
if ret := CppSynthesize(text, voice, dst, language,
|
||||
temperature, topP, topK, repetitionPenalty,
|
||||
maxAudioTokens, q.threads); ret != 0 {
|
||||
return fmt.Errorf("failed to synthesize audio (error code: %d)", ret)
|
||||
// streamState carries the active TTSStream channel to the single shared C
|
||||
// callback. base.SingleThread serializes TTS/TTSStream, so one global slot is
|
||||
// safe and avoids leaking a purego callback per request (purego callbacks
|
||||
// cannot be freed and are capped).
|
||||
var (
|
||||
streamMu sync.Mutex
|
||||
streamChan chan []byte
|
||||
streamCbOnce sync.Once
|
||||
streamCbPtr uintptr
|
||||
)
|
||||
|
||||
// streamCallback is registered once and forwards each PCM chunk to streamChan.
|
||||
func streamCallback(samples *float32, nSamples int32, _ uintptr) uintptr {
|
||||
if nSamples <= 0 || samples == nil || streamChan == nil {
|
||||
return 1 // continue
|
||||
}
|
||||
src := unsafe.Slice(samples, int(nSamples))
|
||||
cp := make([]float32, int(nSamples)) // copy out of C memory before returning
|
||||
copy(cp, src)
|
||||
streamChan <- floatToPCM16LE(cp)
|
||||
return 1 // continue
|
||||
}
|
||||
|
||||
func (q *Qwen3TtsCpp) TTSStream(req *pb.TTSRequest, results chan []byte) error {
|
||||
defer close(results)
|
||||
if req.Text == "" {
|
||||
return fmt.Errorf("qwen3-tts: TTSStream requires text")
|
||||
}
|
||||
|
||||
streamCbOnce.Do(func() {
|
||||
streamCbPtr = purego.NewCallback(streamCallback)
|
||||
})
|
||||
|
||||
lang, instruct, speaker, refText, ref, s, err := q.resolveRequest(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var refPtr unsafe.Pointer
|
||||
if len(ref) > 0 {
|
||||
refPtr = unsafe.Pointer(&ref[0])
|
||||
}
|
||||
|
||||
// Emit the WAV header first so the HTTP layer gets a self-describing stream.
|
||||
results <- wavHeader24k()
|
||||
|
||||
streamMu.Lock()
|
||||
streamChan = results
|
||||
rc := CppTTSStream(req.Text, lang, instruct, speaker, refPtr, len(ref), refText,
|
||||
s.seed, s.temperature, s.topK, s.topP, s.repPen, s.maxNew, streamCbPtr, 0)
|
||||
streamChan = nil
|
||||
streamMu.Unlock()
|
||||
runtimeKeepAlive(ref)
|
||||
|
||||
if rc != 0 {
|
||||
return fmt.Errorf("qwen3-tts: streaming synthesis failed (rc=%d)", rc)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -19,24 +19,25 @@ type LibFuncs struct {
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Get library name from environment variable, default to fallback
|
||||
libName := os.Getenv("QWEN3TTS_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgoqwen3ttscpp-fallback.so"
|
||||
}
|
||||
|
||||
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppLoadModel, "load_model"},
|
||||
{&CppSynthesize, "synthesize"},
|
||||
{&CppLoad, "qt3_load"},
|
||||
{&CppTTS, "qt3_tts"},
|
||||
{&CppTTSStream, "qt3_tts_stream"},
|
||||
{&CppPCMFree, "qt3_pcm_free"},
|
||||
{&CppUnload, "qt3_unload"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, gosd, lf.Name)
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
161
backend/go/qwen3-tts-cpp/options.go
Normal file
161
backend/go/qwen3-tts-cpp/options.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// loadOptions holds the parsed model-level options.
|
||||
type loadOptions struct {
|
||||
codecPath string
|
||||
useFA bool
|
||||
clampFP16 bool
|
||||
seed int64
|
||||
}
|
||||
|
||||
// sampling holds per-request generation parameters with qt defaults applied.
|
||||
type sampling struct {
|
||||
temperature float32
|
||||
topK int
|
||||
topP float32
|
||||
repPen float32
|
||||
maxNew int
|
||||
seed int64
|
||||
}
|
||||
|
||||
func splitOption(o string) (key, value string, ok bool) {
|
||||
i := strings.Index(o, ":")
|
||||
if i < 0 {
|
||||
return "", "", false
|
||||
}
|
||||
return strings.TrimSpace(o[:i]), strings.TrimSpace(o[i+1:]), true
|
||||
}
|
||||
|
||||
func parseBool(v string) bool { return v == "true" || v == "1" }
|
||||
|
||||
// parseOptions reads the backend "key:value" option slice. Unknown keys are
|
||||
// ignored. Defaults: use_fa true (qt default; CPU still uses the F32 chain),
|
||||
// seed -1 (engine random).
|
||||
func parseOptions(opts []string) loadOptions {
|
||||
o := loadOptions{useFA: true, seed: -1}
|
||||
for _, oo := range opts {
|
||||
key, value, ok := splitOption(oo)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch key {
|
||||
case "tokenizer", "codec":
|
||||
o.codecPath = value
|
||||
case "use_fa":
|
||||
o.useFA = parseBool(value)
|
||||
case "clamp_fp16":
|
||||
o.clampFP16 = parseBool(value)
|
||||
case "seed":
|
||||
if n, err := strconv.ParseInt(value, 10, 64); err == nil {
|
||||
o.seed = n
|
||||
}
|
||||
}
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
// languageAliases maps codes / locales / full names to the upstream qwentts
|
||||
// language names. "auto" (and empty) map to "" so the engine auto-detects.
|
||||
var languageAliases = map[string]string{
|
||||
"en": "english", "english": "english",
|
||||
"zh": "chinese", "chinese": "chinese", "mandarin": "chinese",
|
||||
"ja": "japanese", "japanese": "japanese",
|
||||
"ko": "korean", "korean": "korean",
|
||||
"de": "german", "german": "german",
|
||||
"fr": "french", "french": "french",
|
||||
"es": "spanish", "spanish": "spanish",
|
||||
"it": "italian", "italian": "italian",
|
||||
"pt": "portuguese", "portuguese": "portuguese",
|
||||
"ru": "russian", "russian": "russian",
|
||||
"auto": "",
|
||||
}
|
||||
|
||||
// normalizeLanguage lowercases, trims, strips a region/locale suffix
|
||||
// (en-US -> en), and resolves to the qwentts language name. Empty stays empty
|
||||
// (engine auto-detects); an unknown value passes through normalized.
|
||||
func normalizeLanguage(lang string) string {
|
||||
lang = strings.ToLower(strings.TrimSpace(lang))
|
||||
if lang == "" {
|
||||
return ""
|
||||
}
|
||||
if i := strings.IndexAny(lang, "-_."); i >= 0 {
|
||||
lang = lang[:i]
|
||||
}
|
||||
if v, ok := languageAliases[lang]; ok {
|
||||
return v
|
||||
}
|
||||
return lang
|
||||
}
|
||||
|
||||
var refAudioExts = []string{".wav", ".flac", ".mp3", ".ogg", ".m4a"}
|
||||
|
||||
// resolveVoice interprets the request Voice field: a value ending in a known
|
||||
// audio extension is a clone-reference path; anything else is a named speaker
|
||||
// (custom_voice). Empty input yields no speaker and no reference.
|
||||
func resolveVoice(voice string) (speaker, refPath string) {
|
||||
v := strings.TrimSpace(voice)
|
||||
if v == "" {
|
||||
return "", ""
|
||||
}
|
||||
lower := strings.ToLower(v)
|
||||
for _, ext := range refAudioExts {
|
||||
if strings.HasSuffix(lower, ext) {
|
||||
return "", v
|
||||
}
|
||||
}
|
||||
return v, ""
|
||||
}
|
||||
|
||||
func parseFloat32(v string, def float32) float32 {
|
||||
if v == "" {
|
||||
return def
|
||||
}
|
||||
f, err := strconv.ParseFloat(v, 32)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return float32(f)
|
||||
}
|
||||
|
||||
func parseInt(v string, def int) int {
|
||||
if v == "" {
|
||||
return def
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func parseInt64(v string, def int64) int64 {
|
||||
if v == "" {
|
||||
return def
|
||||
}
|
||||
n, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// parseSampling reads per-request sampling params from the TTSRequest params
|
||||
// map, applying qt defaults (matching qt_tts_default_params).
|
||||
func parseSampling(params map[string]string, defaultSeed int64) sampling {
|
||||
s := sampling{temperature: 0.9, topK: 50, topP: 1.0, repPen: 1.05, maxNew: 2048, seed: defaultSeed}
|
||||
if params == nil {
|
||||
return s
|
||||
}
|
||||
s.temperature = parseFloat32(params["temperature"], s.temperature)
|
||||
s.topK = parseInt(params["top_k"], s.topK)
|
||||
s.topP = parseFloat32(params["top_p"], s.topP)
|
||||
s.repPen = parseFloat32(params["repetition_penalty"], s.repPen)
|
||||
s.maxNew = parseInt(params["max_new_tokens"], s.maxNew)
|
||||
s.seed = parseInt64(params["seed"], s.seed)
|
||||
return s
|
||||
}
|
||||
@@ -1,173 +1,136 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const (
|
||||
testAddr = "localhost:50051"
|
||||
startupWait = 5 * time.Second
|
||||
)
|
||||
|
||||
func skipIfNoModel(t *testing.T) string {
|
||||
t.Helper()
|
||||
modelDir := os.Getenv("QWEN3TTS_MODEL_DIR")
|
||||
if modelDir == "" {
|
||||
t.Skip("QWEN3TTS_MODEL_DIR not set, skipping test (set to directory with GGUF models)")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(modelDir, "qwen3-tts-0.6b-f16.gguf")); os.IsNotExist(err) {
|
||||
t.Skipf("TTS model file not found in %s, skipping", modelDir)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(modelDir, "qwen3-tts-tokenizer-f16.gguf")); os.IsNotExist(err) {
|
||||
t.Skipf("Tokenizer model file not found in %s, skipping", modelDir)
|
||||
}
|
||||
return modelDir
|
||||
func TestQwen3TtsCpp(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "qwen3-tts-cpp suite")
|
||||
}
|
||||
|
||||
func startServer(t *testing.T) *exec.Cmd {
|
||||
t.Helper()
|
||||
binary := os.Getenv("QWEN3TTS_BINARY")
|
||||
if binary == "" {
|
||||
binary = "./qwen3-tts-cpp"
|
||||
}
|
||||
if _, err := os.Stat(binary); os.IsNotExist(err) {
|
||||
t.Skipf("Backend binary not found at %s, skipping", binary)
|
||||
}
|
||||
cmd := exec.Command(binary, "--addr", testAddr)
|
||||
cmd.Stdout = os.Stderr
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
time.Sleep(startupWait)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func stopServer(cmd *exec.Cmd) {
|
||||
if cmd != nil && cmd.Process != nil {
|
||||
cmd.Process.Kill()
|
||||
cmd.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func dialGRPC(t *testing.T) *grpc.ClientConn {
|
||||
t.Helper()
|
||||
conn, err := grpc.Dial(testAddr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithDefaultCallOptions(
|
||||
grpc.MaxCallRecvMsgSize(50*1024*1024),
|
||||
grpc.MaxCallSendMsgSize(50*1024*1024),
|
||||
),
|
||||
var _ = Describe("normalizeLanguage", func() {
|
||||
DescribeTable("maps caller language to qwentts language names",
|
||||
func(in, want string) {
|
||||
Expect(normalizeLanguage(in)).To(Equal(want))
|
||||
},
|
||||
Entry("empty stays empty", "", ""),
|
||||
Entry("auto maps to empty", "auto", ""),
|
||||
Entry("english full name", "English", "english"),
|
||||
Entry("english code", "en", "english"),
|
||||
Entry("locale suffix stripped", "en-US", "english"),
|
||||
Entry("underscore locale", "zh_CN", "chinese"),
|
||||
Entry("mandarin alias", "mandarin", "chinese"),
|
||||
Entry("japanese already full", "japanese", "japanese"),
|
||||
Entry("unknown passes through normalized", "xx", "xx"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial gRPC: %v", err)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
})
|
||||
|
||||
func TestServerHealth(t *testing.T) {
|
||||
cmd := startServer(t)
|
||||
defer stopServer(cmd)
|
||||
|
||||
conn := dialGRPC(t)
|
||||
defer conn.Close()
|
||||
|
||||
client := pb.NewBackendClient(conn)
|
||||
resp, err := client.Health(context.Background(), &pb.HealthMessage{})
|
||||
if err != nil {
|
||||
t.Fatalf("Health check failed: %v", err)
|
||||
}
|
||||
if string(resp.Message) != "OK" {
|
||||
t.Fatalf("Expected OK, got %s", string(resp.Message))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadModel(t *testing.T) {
|
||||
modelDir := skipIfNoModel(t)
|
||||
cmd := startServer(t)
|
||||
defer stopServer(cmd)
|
||||
|
||||
conn := dialGRPC(t)
|
||||
defer conn.Close()
|
||||
|
||||
client := pb.NewBackendClient(conn)
|
||||
|
||||
resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
||||
ModelFile: modelDir,
|
||||
Threads: 4,
|
||||
var _ = Describe("resolveVoice", func() {
|
||||
It("treats a bare token as a named speaker", func() {
|
||||
sp, ref := resolveVoice("serena")
|
||||
Expect(sp).To(Equal("serena"))
|
||||
Expect(ref).To(BeEmpty())
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModel failed: %v", err)
|
||||
}
|
||||
if !resp.Success {
|
||||
t.Fatalf("LoadModel returned failure: %s", resp.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTTS(t *testing.T) {
|
||||
modelDir := skipIfNoModel(t)
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "qwen3tts-test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||
|
||||
outputFile := filepath.Join(tmpDir, "output.wav")
|
||||
|
||||
cmd := startServer(t)
|
||||
defer stopServer(cmd)
|
||||
|
||||
conn := dialGRPC(t)
|
||||
defer conn.Close()
|
||||
|
||||
client := pb.NewBackendClient(conn)
|
||||
|
||||
// Load models
|
||||
loadResp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
||||
ModelFile: modelDir,
|
||||
Threads: 4,
|
||||
It("treats an audio path as a clone reference (case-insensitive ext)", func() {
|
||||
sp, ref := resolveVoice("/x/ref.WAV")
|
||||
Expect(sp).To(BeEmpty())
|
||||
Expect(ref).To(Equal("/x/ref.WAV"))
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModel failed: %v", err)
|
||||
}
|
||||
if !loadResp.Success {
|
||||
t.Fatalf("LoadModel returned failure: %s", loadResp.Message)
|
||||
}
|
||||
|
||||
// Synthesize speech
|
||||
language := "en"
|
||||
_, err = client.TTS(context.Background(), &pb.TTSRequest{
|
||||
Text: "Hello, this is a test of the Qwen3 text to speech system.",
|
||||
Dst: outputFile,
|
||||
Language: &language,
|
||||
It("recognizes mp3/flac/ogg/m4a", func() {
|
||||
for _, p := range []string{"a.mp3", "b.flac", "c.ogg", "d.m4a"} {
|
||||
sp, ref := resolveVoice(p)
|
||||
Expect(sp).To(BeEmpty())
|
||||
Expect(ref).To(Equal(p))
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("TTS failed: %v", err)
|
||||
}
|
||||
It("returns empty for empty input", func() {
|
||||
sp, ref := resolveVoice(" ")
|
||||
Expect(sp).To(BeEmpty())
|
||||
Expect(ref).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
// Verify output file exists and has content
|
||||
info, err := os.Stat(outputFile)
|
||||
if os.IsNotExist(err) {
|
||||
t.Fatal("Output audio file was not created")
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to stat output file: %v", err)
|
||||
}
|
||||
var _ = Describe("parseOptions", func() {
|
||||
It("extracts codec, use_fa, clamp_fp16, seed", func() {
|
||||
o := parseOptions([]string{
|
||||
"tokenizer:tok.gguf", "use_fa:false", "clamp_fp16:true",
|
||||
"seed:7", "unknown:ignored",
|
||||
})
|
||||
Expect(o.codecPath).To(Equal("tok.gguf"))
|
||||
Expect(o.useFA).To(BeFalse())
|
||||
Expect(o.clampFP16).To(BeTrue())
|
||||
Expect(o.seed).To(Equal(int64(7)))
|
||||
})
|
||||
It("accepts codec: as an alias for tokenizer:", func() {
|
||||
Expect(parseOptions([]string{"codec:c.gguf"}).codecPath).To(Equal("c.gguf"))
|
||||
})
|
||||
It("defaults use_fa true and seed -1", func() {
|
||||
o := parseOptions(nil)
|
||||
Expect(o.useFA).To(BeTrue())
|
||||
Expect(o.seed).To(Equal(int64(-1)))
|
||||
})
|
||||
})
|
||||
|
||||
t.Logf("Output file size: %d bytes", info.Size())
|
||||
var _ = Describe("parseSampling", func() {
|
||||
It("applies qt defaults when params are absent", func() {
|
||||
s := parseSampling(nil, -1)
|
||||
Expect(s.temperature).To(BeNumerically("~", 0.9, 1e-6))
|
||||
Expect(s.topK).To(Equal(50))
|
||||
Expect(s.topP).To(BeNumerically("~", 1.0, 1e-6))
|
||||
Expect(s.repPen).To(BeNumerically("~", 1.05, 1e-6))
|
||||
Expect(s.maxNew).To(Equal(2048))
|
||||
Expect(s.seed).To(Equal(int64(-1)))
|
||||
})
|
||||
It("reads overrides and falls back to default seed", func() {
|
||||
s := parseSampling(map[string]string{
|
||||
"temperature": "0.5", "top_k": "10", "top_p": "0.8",
|
||||
"repetition_penalty": "1.2", "max_new_tokens": "512",
|
||||
}, 99)
|
||||
Expect(s.temperature).To(BeNumerically("~", 0.5, 1e-6))
|
||||
Expect(s.topK).To(Equal(10))
|
||||
Expect(s.topP).To(BeNumerically("~", 0.8, 1e-6))
|
||||
Expect(s.repPen).To(BeNumerically("~", 1.2, 1e-6))
|
||||
Expect(s.maxNew).To(Equal(512))
|
||||
Expect(s.seed).To(Equal(int64(99)))
|
||||
})
|
||||
It("reads an explicit seed override", func() {
|
||||
Expect(parseSampling(map[string]string{"seed": "123"}, -1).seed).To(Equal(int64(123)))
|
||||
})
|
||||
})
|
||||
|
||||
// WAV header is 44 bytes minimum; any real audio should be much larger
|
||||
if info.Size() < 1000 {
|
||||
t.Errorf("Output file too small (%d bytes), expected real audio data", info.Size())
|
||||
}
|
||||
}
|
||||
var _ = Describe("wavHeader24k", func() {
|
||||
It("emits a 44-byte streaming WAV header at 24 kHz mono 16-bit", func() {
|
||||
h := wavHeader24k()
|
||||
Expect(h).To(HaveLen(44))
|
||||
Expect(string(h[0:4])).To(Equal("RIFF"))
|
||||
Expect(string(h[8:12])).To(Equal("WAVE"))
|
||||
Expect(string(h[12:16])).To(Equal("fmt "))
|
||||
Expect(string(h[36:40])).To(Equal("data"))
|
||||
var sampleRate uint32
|
||||
Expect(binary.Read(bytes.NewReader(h[24:28]), binary.LittleEndian, &sampleRate)).To(Succeed())
|
||||
Expect(sampleRate).To(Equal(uint32(24000)))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("floatToPCM16LE", func() {
|
||||
It("clamps and converts float PCM to little-endian int16 bytes", func() {
|
||||
b := floatToPCM16LE([]float32{0, 1.0, -1.0, 2.0, -2.0})
|
||||
Expect(b).To(HaveLen(10))
|
||||
read := func(off int) int16 {
|
||||
var v int16
|
||||
_ = binary.Read(bytes.NewReader(b[off:off+2]), binary.LittleEndian, &v)
|
||||
return v
|
||||
}
|
||||
Expect(read(0)).To(Equal(int16(0)))
|
||||
Expect(read(2)).To(Equal(int16(32767)))
|
||||
Expect(read(4)).To(Equal(int16(-32767)))
|
||||
Expect(read(6)).To(Equal(int16(32767))) // clamped from 2.0
|
||||
Expect(read(8)).To(Equal(int16(-32767))) // clamped from -2.0
|
||||
})
|
||||
})
|
||||
|
||||
@@ -2,51 +2,30 @@
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
cd "$CURDIR"
|
||||
|
||||
echo "Running qwen3-tts-cpp backend tests..."
|
||||
|
||||
# The test requires:
|
||||
# - QWEN3TTS_MODEL_DIR: path to directory containing GGUF model files
|
||||
# - QWEN3TTS_BINARY: path to the qwen3-tts-cpp binary (defaults to ./qwen3-tts-cpp)
|
||||
#
|
||||
# Tests that require the model will be skipped if QWEN3TTS_MODEL_DIR is not set
|
||||
# or the directory does not contain the required model files.
|
||||
|
||||
cd "$CURDIR"
|
||||
|
||||
# Only auto-download models when QWEN3TTS_MODEL_DIR is not explicitly set
|
||||
if [ -z "$QWEN3TTS_MODEL_DIR" ]; then
|
||||
export QWEN3TTS_MODEL_DIR="./qwen3-tts-models"
|
||||
|
||||
if [ ! -d "$QWEN3TTS_MODEL_DIR" ]; then
|
||||
echo "Creating qwen3-tts-models directory for tests..."
|
||||
mkdir -p "$QWEN3TTS_MODEL_DIR"
|
||||
REPO_ID="endo5501/qwen3-tts.cpp"
|
||||
echo "Repository: ${REPO_ID}"
|
||||
echo ""
|
||||
|
||||
# Files to download (smallest model for testing)
|
||||
FILES=(
|
||||
"qwen3-tts-0.6b-f16.gguf"
|
||||
"qwen3-tts-tokenizer-f16.gguf"
|
||||
)
|
||||
|
||||
BASE_URL="https://huggingface.co/${REPO_ID}/resolve/main"
|
||||
|
||||
for file in "${FILES[@]}"; do
|
||||
dest="${QWEN3TTS_MODEL_DIR}/${file}"
|
||||
if [ -f "${dest}" ]; then
|
||||
echo " [skip] ${file} (already exists)"
|
||||
else
|
||||
echo " [download] ${file}..."
|
||||
curl -L -o "${dest}" "${BASE_URL}/${file}" --progress-bar
|
||||
echo " [done] ${file}"
|
||||
fi
|
||||
done
|
||||
fi
|
||||
# Auto-download a small model pair only when QWEN3TTS_MODEL is not set.
|
||||
if [ -z "$QWEN3TTS_MODEL" ]; then
|
||||
MODEL_DIR="./qwen3-tts-models"
|
||||
mkdir -p "$MODEL_DIR"
|
||||
REPO_ID="Serveurperso/Qwen3-TTS-GGUF"
|
||||
BASE_URL="https://huggingface.co/${REPO_ID}/resolve/main"
|
||||
FILES=( "qwen-talker-0.6b-base-Q4_K_M.gguf" "qwen-tokenizer-12hz-Q4_K_M.gguf" )
|
||||
for file in "${FILES[@]}"; do
|
||||
dest="${MODEL_DIR}/${file}"
|
||||
if [ -f "${dest}" ]; then
|
||||
echo " [skip] ${file}"
|
||||
else
|
||||
echo " [download] ${file}..."
|
||||
curl -L -o "${dest}" "${BASE_URL}/${file}" --progress-bar
|
||||
fi
|
||||
done
|
||||
export QWEN3TTS_MODEL="${MODEL_DIR}/qwen-talker-0.6b-base-Q4_K_M.gguf"
|
||||
export QWEN3TTS_CODEC="${MODEL_DIR}/qwen-tokenizer-12hz-Q4_K_M.gguf"
|
||||
fi
|
||||
|
||||
# Run Go tests
|
||||
go test -v -timeout 600s .
|
||||
go test -v -timeout 1200s .
|
||||
|
||||
echo "All qwen3-tts-cpp tests passed."
|
||||
|
||||
@@ -11,7 +11,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
# build; leaving this on `master` always picks up the latest C-API surface
|
||||
# (incl. the per-detection accessor functions used by gorfdetrcpp.go).
|
||||
RFDETR_REPO?=https://github.com/mudler/rf-detr.cpp.git
|
||||
RFDETR_VERSION?=main
|
||||
RFDETR_VERSION?=65c0ffcc9a9bc9dae38252f63d0417c9845a6cf7
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
|
||||
@@ -62,7 +62,7 @@ var (
|
||||
shimVadConfigSetDebug func(uintptr, int32)
|
||||
shimCreateVad func(uintptr, float32) uintptr
|
||||
|
||||
// TTS (offline, VITS) config
|
||||
// TTS (offline, VITS/Piper and Kokoro) config
|
||||
shimTtsConfigNew func() uintptr
|
||||
shimTtsConfigFree func(uintptr)
|
||||
shimTtsConfigSetVitsModel func(uintptr, string)
|
||||
@@ -76,6 +76,14 @@ var (
|
||||
shimTtsConfigSetDebug func(uintptr, int32)
|
||||
shimTtsConfigSetProvider func(uintptr, string)
|
||||
shimTtsConfigSetMaxNumSentences func(uintptr, int32)
|
||||
shimTtsConfigSetKokoroModel func(uintptr, string)
|
||||
shimTtsConfigSetKokoroVoices func(uintptr, string)
|
||||
shimTtsConfigSetKokoroTokens func(uintptr, string)
|
||||
shimTtsConfigSetKokoroDataDir func(uintptr, string)
|
||||
shimTtsConfigSetKokoroDictDir func(uintptr, string)
|
||||
shimTtsConfigSetKokoroLexicon func(uintptr, string)
|
||||
shimTtsConfigSetKokoroLang func(uintptr, string)
|
||||
shimTtsConfigSetKokoroLengthScale func(uintptr, float32)
|
||||
shimCreateOfflineTts func(uintptr) uintptr
|
||||
|
||||
// Offline recognizer config
|
||||
@@ -101,37 +109,37 @@ var (
|
||||
shimCreateOfflineRecognizer func(uintptr) uintptr
|
||||
|
||||
// Online recognizer config
|
||||
shimOnlineRecogConfigNew func() uintptr
|
||||
shimOnlineRecogConfigFree func(uintptr)
|
||||
shimOnlineRecogConfigSetTransducerEncoder func(uintptr, string)
|
||||
shimOnlineRecogConfigSetTransducerDecoder func(uintptr, string)
|
||||
shimOnlineRecogConfigSetTransducerJoiner func(uintptr, string)
|
||||
shimOnlineRecogConfigSetTokens func(uintptr, string)
|
||||
shimOnlineRecogConfigSetNumThreads func(uintptr, int32)
|
||||
shimOnlineRecogConfigSetDebug func(uintptr, int32)
|
||||
shimOnlineRecogConfigSetProvider func(uintptr, string)
|
||||
shimOnlineRecogConfigSetFeatSampleRate func(uintptr, int32)
|
||||
shimOnlineRecogConfigSetFeatFeatureDim func(uintptr, int32)
|
||||
shimOnlineRecogConfigSetDecodingMethod func(uintptr, string)
|
||||
shimOnlineRecogConfigSetEnableEndpoint func(uintptr, int32)
|
||||
shimOnlineRecogConfigNew func() uintptr
|
||||
shimOnlineRecogConfigFree func(uintptr)
|
||||
shimOnlineRecogConfigSetTransducerEncoder func(uintptr, string)
|
||||
shimOnlineRecogConfigSetTransducerDecoder func(uintptr, string)
|
||||
shimOnlineRecogConfigSetTransducerJoiner func(uintptr, string)
|
||||
shimOnlineRecogConfigSetTokens func(uintptr, string)
|
||||
shimOnlineRecogConfigSetNumThreads func(uintptr, int32)
|
||||
shimOnlineRecogConfigSetDebug func(uintptr, int32)
|
||||
shimOnlineRecogConfigSetProvider func(uintptr, string)
|
||||
shimOnlineRecogConfigSetFeatSampleRate func(uintptr, int32)
|
||||
shimOnlineRecogConfigSetFeatFeatureDim func(uintptr, int32)
|
||||
shimOnlineRecogConfigSetDecodingMethod func(uintptr, string)
|
||||
shimOnlineRecogConfigSetEnableEndpoint func(uintptr, int32)
|
||||
shimOnlineRecogConfigSetRule1MinTrailingSilence func(uintptr, float32)
|
||||
shimOnlineRecogConfigSetRule2MinTrailingSilence func(uintptr, float32)
|
||||
shimOnlineRecogConfigSetRule3MinUtteranceLength func(uintptr, float32)
|
||||
shimCreateOnlineRecognizer func(uintptr) uintptr
|
||||
shimCreateOnlineRecognizer func(uintptr) uintptr
|
||||
|
||||
// Result accessors. Pointer returns use unsafe.Pointer so Go's
|
||||
// vet checker doesn't flag them — the returned memory is C-owned,
|
||||
// not subject to Go GC motion.
|
||||
shimWaveSampleRate func(uintptr) int32
|
||||
shimWaveNumSamples func(uintptr) int32
|
||||
shimWaveSamples func(uintptr) unsafe.Pointer
|
||||
shimOfflineResultText func(uintptr) unsafe.Pointer
|
||||
shimOnlineResultText func(uintptr) unsafe.Pointer
|
||||
shimGeneratedAudioSampleRate func(uintptr) int32
|
||||
shimGeneratedAudioN func(uintptr) int32
|
||||
shimGeneratedAudioSamples func(uintptr) unsafe.Pointer
|
||||
shimSpeechSegmentStart func(uintptr) int32
|
||||
shimSpeechSegmentN func(uintptr) int32
|
||||
shimWaveSampleRate func(uintptr) int32
|
||||
shimWaveNumSamples func(uintptr) int32
|
||||
shimWaveSamples func(uintptr) unsafe.Pointer
|
||||
shimOfflineResultText func(uintptr) unsafe.Pointer
|
||||
shimOnlineResultText func(uintptr) unsafe.Pointer
|
||||
shimGeneratedAudioSampleRate func(uintptr) int32
|
||||
shimGeneratedAudioN func(uintptr) int32
|
||||
shimGeneratedAudioSamples func(uintptr) unsafe.Pointer
|
||||
shimSpeechSegmentStart func(uintptr) int32
|
||||
shimSpeechSegmentN func(uintptr) int32
|
||||
|
||||
// TTS streaming callback trampoline
|
||||
shimTtsGenerateWithCallback func(tts uintptr, text string, sid int32, speed float32, cb uintptr, ud uintptr) uintptr
|
||||
@@ -161,13 +169,13 @@ var (
|
||||
// pointer returned by the shim or `unsafe.Pointer(&slice[0])` from Go.
|
||||
var (
|
||||
// VAD
|
||||
sherpaVadAcceptWaveform func(vad uintptr, samples unsafe.Pointer, n int32)
|
||||
sherpaVadReset func(vad uintptr)
|
||||
sherpaVadFlush func(vad uintptr)
|
||||
sherpaVadEmpty func(vad uintptr) int32
|
||||
sherpaVadFront func(vad uintptr) uintptr
|
||||
sherpaVadPop func(vad uintptr)
|
||||
sherpaDestroySpeechSegment func(seg uintptr)
|
||||
sherpaVadAcceptWaveform func(vad uintptr, samples unsafe.Pointer, n int32)
|
||||
sherpaVadReset func(vad uintptr)
|
||||
sherpaVadFlush func(vad uintptr)
|
||||
sherpaVadEmpty func(vad uintptr) int32
|
||||
sherpaVadFront func(vad uintptr) uintptr
|
||||
sherpaVadPop func(vad uintptr)
|
||||
sherpaDestroySpeechSegment func(seg uintptr)
|
||||
|
||||
// Wave IO
|
||||
sherpaReadWave func(filename string) uintptr
|
||||
@@ -175,11 +183,11 @@ var (
|
||||
sherpaWriteWave func(samples unsafe.Pointer, n int32, sampleRate int32, filename string) int32
|
||||
|
||||
// Offline ASR
|
||||
sherpaCreateOfflineStream func(rec uintptr) uintptr
|
||||
sherpaDestroyOfflineStream func(stream uintptr)
|
||||
sherpaAcceptWaveformOffline func(stream uintptr, sr int32, samples unsafe.Pointer, n int32)
|
||||
sherpaDecodeOfflineStream func(rec uintptr, stream uintptr)
|
||||
sherpaGetOfflineStreamResult func(stream uintptr) uintptr
|
||||
sherpaCreateOfflineStream func(rec uintptr) uintptr
|
||||
sherpaDestroyOfflineStream func(stream uintptr)
|
||||
sherpaAcceptWaveformOffline func(stream uintptr, sr int32, samples unsafe.Pointer, n int32)
|
||||
sherpaDecodeOfflineStream func(rec uintptr, stream uintptr)
|
||||
sherpaGetOfflineStreamResult func(stream uintptr) uintptr
|
||||
sherpaDestroyOfflineRecognizerResult func(result uintptr)
|
||||
|
||||
// Online ASR
|
||||
@@ -195,21 +203,21 @@ var (
|
||||
sherpaOnlineStreamInputFinished func(stream uintptr)
|
||||
|
||||
// TTS
|
||||
sherpaOfflineTtsGenerate func(tts uintptr, text string, sid int32, speed float32) uintptr
|
||||
sherpaOfflineTtsGenerate func(tts uintptr, text string, sid int32, speed float32) uintptr
|
||||
sherpaDestroyOfflineTtsGeneratedAudio func(audio uintptr)
|
||||
sherpaOfflineTtsSampleRate func(tts uintptr) int32
|
||||
sherpaOfflineTtsSampleRate func(tts uintptr) int32
|
||||
|
||||
// Offline speaker diarization. Result handle owns the segment-array
|
||||
// pointer returned by ResultSortByStartTime; destroy the segment
|
||||
// array first, then the result, then (at backend Free()) the diarizer.
|
||||
sherpaDestroyOfflineSpeakerDiarization func(sd uintptr)
|
||||
sherpaOfflineSpeakerDiarizationGetSampleRate func(sd uintptr) int32
|
||||
sherpaOfflineSpeakerDiarizationProcess func(sd uintptr, samples unsafe.Pointer, n int32) uintptr
|
||||
sherpaOfflineSpeakerDiarizationResultGetNumSegments func(result uintptr) int32
|
||||
sherpaOfflineSpeakerDiarizationResultGetNumSpeakers func(result uintptr) int32
|
||||
sherpaOfflineSpeakerDiarizationResultSortByStartTime func(result uintptr) uintptr
|
||||
sherpaOfflineSpeakerDiarizationDestroySegment func(segs uintptr)
|
||||
sherpaDestroyOfflineSpeakerDiarizationResult func(result uintptr)
|
||||
sherpaDestroyOfflineSpeakerDiarization func(sd uintptr)
|
||||
sherpaOfflineSpeakerDiarizationGetSampleRate func(sd uintptr) int32
|
||||
sherpaOfflineSpeakerDiarizationProcess func(sd uintptr, samples unsafe.Pointer, n int32) uintptr
|
||||
sherpaOfflineSpeakerDiarizationResultGetNumSegments func(result uintptr) int32
|
||||
sherpaOfflineSpeakerDiarizationResultGetNumSpeakers func(result uintptr) int32
|
||||
sherpaOfflineSpeakerDiarizationResultSortByStartTime func(result uintptr) uintptr
|
||||
sherpaOfflineSpeakerDiarizationDestroySegment func(segs uintptr)
|
||||
sherpaDestroyOfflineSpeakerDiarizationResult func(result uintptr)
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -278,6 +286,14 @@ func loadSherpaLibsOnce() error {
|
||||
{&shimTtsConfigSetDebug, "sherpa_shim_tts_config_set_debug"},
|
||||
{&shimTtsConfigSetProvider, "sherpa_shim_tts_config_set_provider"},
|
||||
{&shimTtsConfigSetMaxNumSentences, "sherpa_shim_tts_config_set_max_num_sentences"},
|
||||
{&shimTtsConfigSetKokoroModel, "sherpa_shim_tts_config_set_kokoro_model"},
|
||||
{&shimTtsConfigSetKokoroVoices, "sherpa_shim_tts_config_set_kokoro_voices"},
|
||||
{&shimTtsConfigSetKokoroTokens, "sherpa_shim_tts_config_set_kokoro_tokens"},
|
||||
{&shimTtsConfigSetKokoroDataDir, "sherpa_shim_tts_config_set_kokoro_data_dir"},
|
||||
{&shimTtsConfigSetKokoroDictDir, "sherpa_shim_tts_config_set_kokoro_dict_dir"},
|
||||
{&shimTtsConfigSetKokoroLexicon, "sherpa_shim_tts_config_set_kokoro_lexicon"},
|
||||
{&shimTtsConfigSetKokoroLang, "sherpa_shim_tts_config_set_kokoro_lang"},
|
||||
{&shimTtsConfigSetKokoroLengthScale, "sherpa_shim_tts_config_set_kokoro_length_scale"},
|
||||
{&shimCreateOfflineTts, "sherpa_shim_create_offline_tts"},
|
||||
|
||||
{&shimOfflineRecogConfigNew, "sherpa_shim_offline_recog_config_new"},
|
||||
@@ -688,21 +704,14 @@ func (s *SherpaBackend) loadTTS(opts *pb.ModelOptions) error {
|
||||
cfg := shimTtsConfigNew()
|
||||
defer shimTtsConfigFree(cfg)
|
||||
|
||||
shimTtsConfigSetVitsModel(cfg, modelFile)
|
||||
|
||||
if tokensPath := filepath.Join(modelDir, "tokens.txt"); fileExists(tokensPath) {
|
||||
shimTtsConfigSetVitsTokens(cfg, tokensPath)
|
||||
// Kokoro models ship a voices style file alongside the ONNX, whereas
|
||||
// VITS/Piper voices do not. That presence is what tells the two model
|
||||
// families apart, since both arrive as a plain *.onnx in modelDir.
|
||||
if isKokoroModel(modelDir) {
|
||||
s.configureKokoroTTS(cfg, opts, modelFile, modelDir)
|
||||
} else {
|
||||
s.configureVitsTTS(cfg, opts, modelFile, modelDir)
|
||||
}
|
||||
if lexiconPath := filepath.Join(modelDir, "lexicon.txt"); fileExists(lexiconPath) {
|
||||
shimTtsConfigSetVitsLexicon(cfg, lexiconPath)
|
||||
}
|
||||
if dataDir := filepath.Join(modelDir, "espeak-ng-data"); dirExists(dataDir) {
|
||||
shimTtsConfigSetVitsDataDir(cfg, dataDir)
|
||||
}
|
||||
|
||||
shimTtsConfigSetVitsNoiseScale(cfg, findOptionFloat(opts, optionTtsNoiseScale, 0.667))
|
||||
shimTtsConfigSetVitsNoiseScaleW(cfg, findOptionFloat(opts, optionTtsNoiseScaleW, 0.8))
|
||||
shimTtsConfigSetVitsLengthScale(cfg, findOptionFloat(opts, optionTtsLengthScale, 1.0))
|
||||
|
||||
threads := int32(1)
|
||||
if opts.Threads != 0 {
|
||||
@@ -723,6 +732,80 @@ func (s *SherpaBackend) loadTTS(opts *pb.ModelOptions) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// kokoroVoicesFile is the speaker-style bank that ships with Kokoro models and
|
||||
// is absent from VITS/Piper voices; its presence is how loadTTS tells them apart.
|
||||
const kokoroVoicesFile = "voices.bin"
|
||||
|
||||
// isKokoroModel reports whether modelDir holds a Kokoro model (a voices file
|
||||
// next to the ONNX) rather than a VITS/Piper single-speaker model.
|
||||
func isKokoroModel(modelDir string) bool {
|
||||
return fileExists(filepath.Join(modelDir, kokoroVoicesFile))
|
||||
}
|
||||
|
||||
// configureVitsTTS wires a VITS/Piper single-speaker model into cfg: the ONNX
|
||||
// plus the optional tokens, lexicon and espeak-ng-data found beside it.
|
||||
func (s *SherpaBackend) configureVitsTTS(cfg uintptr, opts *pb.ModelOptions, modelFile, modelDir string) {
|
||||
shimTtsConfigSetVitsModel(cfg, modelFile)
|
||||
|
||||
if tokensPath := filepath.Join(modelDir, "tokens.txt"); fileExists(tokensPath) {
|
||||
shimTtsConfigSetVitsTokens(cfg, tokensPath)
|
||||
}
|
||||
if lexiconPath := filepath.Join(modelDir, "lexicon.txt"); fileExists(lexiconPath) {
|
||||
shimTtsConfigSetVitsLexicon(cfg, lexiconPath)
|
||||
}
|
||||
if dataDir := filepath.Join(modelDir, "espeak-ng-data"); dirExists(dataDir) {
|
||||
shimTtsConfigSetVitsDataDir(cfg, dataDir)
|
||||
}
|
||||
|
||||
shimTtsConfigSetVitsNoiseScale(cfg, findOptionFloat(opts, optionTtsNoiseScale, 0.667))
|
||||
shimTtsConfigSetVitsNoiseScaleW(cfg, findOptionFloat(opts, optionTtsNoiseScaleW, 0.8))
|
||||
shimTtsConfigSetVitsLengthScale(cfg, findOptionFloat(opts, optionTtsLengthScale, 1.0))
|
||||
}
|
||||
|
||||
// configureKokoroTTS wires a Kokoro model into cfg: the ONNX, its voices bank,
|
||||
// tokens, and the optional espeak-ng-data / jieba dict / lexicon assets the
|
||||
// multi-lingual packs ship. A language hint comes from the `language=` option.
|
||||
func (s *SherpaBackend) configureKokoroTTS(cfg uintptr, opts *pb.ModelOptions, modelFile, modelDir string) {
|
||||
shimTtsConfigSetKokoroModel(cfg, modelFile)
|
||||
shimTtsConfigSetKokoroVoices(cfg, filepath.Join(modelDir, kokoroVoicesFile))
|
||||
|
||||
if tokensPath := filepath.Join(modelDir, "tokens.txt"); fileExists(tokensPath) {
|
||||
shimTtsConfigSetKokoroTokens(cfg, tokensPath)
|
||||
}
|
||||
if dataDir := filepath.Join(modelDir, "espeak-ng-data"); dirExists(dataDir) {
|
||||
shimTtsConfigSetKokoroDataDir(cfg, dataDir)
|
||||
}
|
||||
if dictDir := filepath.Join(modelDir, "dict"); dirExists(dictDir) {
|
||||
shimTtsConfigSetKokoroDictDir(cfg, dictDir)
|
||||
}
|
||||
|
||||
// Multi-lingual Kokoro ships per-language lexicons; the C API takes them as
|
||||
// a single comma-separated list. US and GB English overlap almost entirely,
|
||||
// so pass only one (US preferred) to avoid tens of thousands of "duplicated
|
||||
// word" warnings at load; non-English lexicons (e.g. zh) are additive.
|
||||
var lexicons []string
|
||||
addLexicon := func(name string) {
|
||||
if p := filepath.Join(modelDir, name); fileExists(p) {
|
||||
lexicons = append(lexicons, p)
|
||||
}
|
||||
}
|
||||
if fileExists(filepath.Join(modelDir, "lexicon-us-en.txt")) {
|
||||
addLexicon("lexicon-us-en.txt")
|
||||
} else {
|
||||
addLexicon("lexicon-gb-en.txt")
|
||||
}
|
||||
addLexicon("lexicon-zh.txt")
|
||||
addLexicon("lexicon.txt")
|
||||
if len(lexicons) > 0 {
|
||||
shimTtsConfigSetKokoroLexicon(cfg, strings.Join(lexicons, ","))
|
||||
}
|
||||
|
||||
if lang := findOptionValue(opts, optionLanguage, ""); lang != "" {
|
||||
shimTtsConfigSetKokoroLang(cfg, lang)
|
||||
}
|
||||
shimTtsConfigSetKokoroLengthScale(cfg, findOptionFloat(opts, optionTtsLengthScale, 1.0))
|
||||
}
|
||||
|
||||
func fileExists(p string) bool {
|
||||
info, err := os.Stat(p)
|
||||
return err == nil && !info.IsDir()
|
||||
@@ -1252,7 +1335,7 @@ type ttsStreamState struct {
|
||||
var (
|
||||
ttsStates sync.Map // uint64 → *ttsStreamState
|
||||
ttsNextID atomic.Uint64
|
||||
ttsCallbackPtr uintptr // purego.NewCallback return; registered in loadSherpaLibs
|
||||
ttsCallbackPtr uintptr // purego.NewCallback return; registered in loadSherpaLibs
|
||||
)
|
||||
|
||||
// ttsStreamCallback is invoked by sherpa-onnx for each PCM chunk VITS
|
||||
|
||||
@@ -124,6 +124,20 @@ var _ = Describe("Sherpa-ONNX", func() {
|
||||
Entry("empty", "", false),
|
||||
Entry("other", "other", false),
|
||||
)
|
||||
|
||||
It("isKokoroModel detects a voices file beside the ONNX", func() {
|
||||
dir, err := os.MkdirTemp("", "sherpa-kokoro-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
// A bare VITS/Piper directory (ONNX only) is not Kokoro.
|
||||
Expect(os.WriteFile(filepath.Join(dir, "model.onnx"), []byte("x"), 0o600)).To(Succeed())
|
||||
Expect(isKokoroModel(dir)).To(BeFalse())
|
||||
|
||||
// Adding the Kokoro voices bank flips detection on.
|
||||
Expect(os.WriteFile(filepath.Join(dir, kokoroVoicesFile), []byte("x"), 0o600)).To(Succeed())
|
||||
Expect(isKokoroModel(dir)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("option parsing", func() {
|
||||
|
||||
@@ -79,6 +79,13 @@ void sherpa_shim_tts_config_free(void *h) {
|
||||
free((char *)c->model.vits.tokens);
|
||||
free((char *)c->model.vits.lexicon);
|
||||
free((char *)c->model.vits.data_dir);
|
||||
free((char *)c->model.kokoro.model);
|
||||
free((char *)c->model.kokoro.voices);
|
||||
free((char *)c->model.kokoro.tokens);
|
||||
free((char *)c->model.kokoro.data_dir);
|
||||
free((char *)c->model.kokoro.dict_dir);
|
||||
free((char *)c->model.kokoro.lexicon);
|
||||
free((char *)c->model.kokoro.lang);
|
||||
free((char *)c->model.provider);
|
||||
free(c);
|
||||
}
|
||||
@@ -117,6 +124,34 @@ void sherpa_shim_tts_config_set_max_num_sentences(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineTtsConfig *)h)->max_num_sentences = v;
|
||||
}
|
||||
|
||||
// Kokoro multi-speaker / multi-lingual TTS. Distinct ONNX + a voices style
|
||||
// file (voices.bin) instead of VITS' single-speaker graph; espeak-ng-data,
|
||||
// lexicon and a language hint are optional refinements.
|
||||
void sherpa_shim_tts_config_set_kokoro_model(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.kokoro.model, v);
|
||||
}
|
||||
void sherpa_shim_tts_config_set_kokoro_voices(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.kokoro.voices, v);
|
||||
}
|
||||
void sherpa_shim_tts_config_set_kokoro_tokens(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.kokoro.tokens, v);
|
||||
}
|
||||
void sherpa_shim_tts_config_set_kokoro_data_dir(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.kokoro.data_dir, v);
|
||||
}
|
||||
void sherpa_shim_tts_config_set_kokoro_dict_dir(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.kokoro.dict_dir, v);
|
||||
}
|
||||
void sherpa_shim_tts_config_set_kokoro_lexicon(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.kokoro.lexicon, v);
|
||||
}
|
||||
void sherpa_shim_tts_config_set_kokoro_lang(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.kokoro.lang, v);
|
||||
}
|
||||
void sherpa_shim_tts_config_set_kokoro_length_scale(void *h, float v) {
|
||||
((SherpaOnnxOfflineTtsConfig *)h)->model.kokoro.length_scale = v;
|
||||
}
|
||||
|
||||
void *sherpa_shim_create_offline_tts(void *h) {
|
||||
return (void *)SherpaOnnxCreateOfflineTts(
|
||||
(const SherpaOnnxOfflineTtsConfig *)h);
|
||||
|
||||
@@ -37,7 +37,7 @@ void sherpa_shim_vad_config_set_provider(void *cfg, const char *v);
|
||||
void sherpa_shim_vad_config_set_debug(void *cfg, int32_t v);
|
||||
void *sherpa_shim_create_vad(void *cfg, float buffer_size_seconds);
|
||||
|
||||
// --- Offline TTS config (VITS path — the only TTS family the backend uses) ---
|
||||
// --- Offline TTS config (VITS/Piper and Kokoro model families) ---
|
||||
void *sherpa_shim_tts_config_new(void);
|
||||
void sherpa_shim_tts_config_free(void *cfg);
|
||||
void sherpa_shim_tts_config_set_vits_model(void *cfg, const char *v);
|
||||
@@ -51,6 +51,14 @@ void sherpa_shim_tts_config_set_num_threads(void *cfg, int32_t v);
|
||||
void sherpa_shim_tts_config_set_debug(void *cfg, int32_t v);
|
||||
void sherpa_shim_tts_config_set_provider(void *cfg, const char *v);
|
||||
void sherpa_shim_tts_config_set_max_num_sentences(void *cfg, int32_t v);
|
||||
void sherpa_shim_tts_config_set_kokoro_model(void *cfg, const char *v);
|
||||
void sherpa_shim_tts_config_set_kokoro_voices(void *cfg, const char *v);
|
||||
void sherpa_shim_tts_config_set_kokoro_tokens(void *cfg, const char *v);
|
||||
void sherpa_shim_tts_config_set_kokoro_data_dir(void *cfg, const char *v);
|
||||
void sherpa_shim_tts_config_set_kokoro_dict_dir(void *cfg, const char *v);
|
||||
void sherpa_shim_tts_config_set_kokoro_lexicon(void *cfg, const char *v);
|
||||
void sherpa_shim_tts_config_set_kokoro_lang(void *cfg, const char *v);
|
||||
void sherpa_shim_tts_config_set_kokoro_length_scale(void *cfg, float v);
|
||||
void *sherpa_shim_create_offline_tts(void *cfg);
|
||||
|
||||
// --- Offline recognizer config (Whisper / Paraformer / SenseVoice / Omnilingual) ---
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=92dc7268fc4ffb0c0cc0bd52dfcefea91326e797
|
||||
STABLEDIFFUSION_GGML_VERSION?=19bdfe22d255d5b4dff39d449318b9bc5ea2317f
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -386,6 +386,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
const char *llm_vision_path = "";
|
||||
const char *diffusion_model_path = stableDiffusionModel;
|
||||
const char *high_noise_diffusion_model_path = "";
|
||||
const char *uncond_diffusion_model_path = "";
|
||||
const char *taesd_path = "";
|
||||
const char *control_net_path = "";
|
||||
const char *embedding_dir = "";
|
||||
@@ -472,6 +473,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
if (!strcmp(optname, "llm_vision_path")) llm_vision_path = strdup(optval);
|
||||
if (!strcmp(optname, "diffusion_model_path")) diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "high_noise_diffusion_model_path")) high_noise_diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "uncond_diffusion_model_path")) uncond_diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "taesd_path")) taesd_path = strdup(optval);
|
||||
if (!strcmp(optname, "control_net_path")) control_net_path = strdup(optval);
|
||||
if (!strcmp(optname, "embedding_dir")) {
|
||||
@@ -571,6 +573,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
ctx_params.llm_vision_path = llm_vision_path;
|
||||
ctx_params.diffusion_model_path = diffusion_model_path;
|
||||
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
|
||||
ctx_params.uncond_diffusion_model_path = uncond_diffusion_model_path;
|
||||
ctx_params.vae_path = vae_path;
|
||||
ctx_params.audio_vae_path = audio_vae_path;
|
||||
ctx_params.embeddings_connectors_path = embeddings_connectors_path;
|
||||
|
||||
@@ -26,8 +26,16 @@ add_library(govibevoicecpp MODULE cpp/govibevoicecpp.cpp)
|
||||
# vv_capi_* symbols (purego dlopens them by name, nothing in our
|
||||
# translation unit references them). Force the static archive's
|
||||
# entire contents into the MODULE so dlsym finds vv_capi_load etc.
|
||||
#
|
||||
# Link the `vibevoice` TARGET (not a bare archive path) so CMake builds
|
||||
# libvibevoice.a first and tracks the dependency: the upstream project is added
|
||||
# with EXCLUDE_FROM_ALL, so without a target-level link there is no rule to
|
||||
# build it. Passing only $<TARGET_FILE:vibevoice> as a path on Apple left the
|
||||
# build with "No rule to make target 'vibevoice/libvibevoice.a'" (issue #10267).
|
||||
# force_load is then applied as a separate link option.
|
||||
if(APPLE)
|
||||
target_link_libraries(govibevoicecpp PRIVATE -Wl,-force_load $<TARGET_FILE:vibevoice>)
|
||||
target_link_libraries(govibevoicecpp PRIVATE vibevoice)
|
||||
target_link_options(govibevoicecpp PRIVATE "-Wl,-force_load,$<TARGET_FILE:vibevoice>")
|
||||
elseif(MSVC)
|
||||
target_link_libraries(govibevoicecpp PRIVATE vibevoice)
|
||||
set_property(TARGET govibevoicecpp APPEND PROPERTY LINK_FLAGS "/WHOLEARCHIVE:vibevoice")
|
||||
|
||||
@@ -94,26 +94,30 @@ purge:
|
||||
# Build all variants (Linux only)
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
libgovibevoicecpp-avx.so: sources/vibevoice.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I vibevoice-cpp build info:avx${RESET})
|
||||
SO_TARGET=libgovibevoicecpp-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgovibevoicecpp-custom
|
||||
rm -rf build-libgovibevoicecpp-avx.so
|
||||
rm -rfv build*
|
||||
|
||||
libgovibevoicecpp-avx2.so: sources/vibevoice.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I vibevoice-cpp build info:avx2${RESET})
|
||||
SO_TARGET=libgovibevoicecpp-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgovibevoicecpp-custom
|
||||
rm -rf build-libgovibevoicecpp-avx2.so
|
||||
rm -rfv build*
|
||||
|
||||
libgovibevoicecpp-avx512.so: sources/vibevoice.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I vibevoice-cpp build info:avx512${RESET})
|
||||
SO_TARGET=libgovibevoicecpp-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgovibevoicecpp-custom
|
||||
rm -rf build-libgovibevoicecpp-avx512.so
|
||||
rm -rfv build*
|
||||
endif
|
||||
|
||||
# Build fallback variant (all platforms)
|
||||
libgovibevoicecpp-fallback.so: sources/vibevoice.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I vibevoice-cpp build info:fallback${RESET})
|
||||
SO_TARGET=libgovibevoicecpp-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgovibevoicecpp-custom
|
||||
rm -rf build-libgovibevoicecpp-fallback.so
|
||||
rm -rfv build*
|
||||
|
||||
libgovibevoicecpp-custom: CMakeLists.txt cpp/govibevoicecpp.cpp cpp/govibevoicecpp.h
|
||||
mkdir -p build-$(SO_TARGET) && \
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=27101c01dcac1676e2b6422256233cd0f1f9ae28
|
||||
WHISPER_CPP_VERSION?=df7638d8229a243af8a4b5a8ae557e0d74e0a0ae
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user