mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-12 02:38:19 -04:00
Compare commits
156 Commits
v4.3.2
...
feat/dllm-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c9c6040fe8 | ||
|
|
8134d6db37 | ||
|
|
ad6d1dbc8b | ||
|
|
eb61e1d770 | ||
|
|
aba9c4794a | ||
|
|
04d6f66a9a | ||
|
|
52b3b68cea | ||
|
|
99184809fa | ||
|
|
294c04ae2f | ||
|
|
778f85c2a0 | ||
|
|
af0db1419c | ||
|
|
892fc49949 | ||
|
|
228a6dfe79 | ||
|
|
51a92b6093 | ||
|
|
b5964d385d | ||
|
|
fba8c9c498 | ||
|
|
6b2badb837 | ||
|
|
8b8506d01a | ||
|
|
6910a0bb48 | ||
|
|
cffd03b522 | ||
|
|
bf448d3794 | ||
|
|
1d4a12f7c0 | ||
|
|
186d62801d | ||
|
|
da4ed05429 | ||
|
|
ec1eea4f45 | ||
|
|
b203b32e57 | ||
|
|
48a8ce98aa | ||
|
|
8344d1c865 | ||
|
|
d2e6b93369 | ||
|
|
e1ec03d33f | ||
|
|
9323f4b5ca | ||
|
|
c20225fc13 | ||
|
|
337acc4c37 | ||
|
|
618e90cd13 | ||
|
|
92dea961c2 | ||
|
|
2e93186043 | ||
|
|
d07037e817 | ||
|
|
f6cc90d258 | ||
|
|
2c804bef5a | ||
|
|
6070402477 | ||
|
|
67f80a152b | ||
|
|
a7cb587d96 | ||
|
|
f7c74ad2da | ||
|
|
7402d1fd20 | ||
|
|
8c42695ef8 | ||
|
|
72e3241431 | ||
|
|
cd2bf95862 | ||
|
|
f64b72dd7d | ||
|
|
03c84cff28 | ||
|
|
9bc69c9e5f | ||
|
|
1e6c9cfd60 | ||
|
|
0e6712f734 | ||
|
|
0e4cee9a97 | ||
|
|
352b7ec604 | ||
|
|
ba706422fb | ||
|
|
e837921c2c | ||
|
|
73385713ca | ||
|
|
a4e671779a | ||
|
|
7051b2e0a1 | ||
|
|
469737101a | ||
|
|
858257eaf0 | ||
|
|
ef80a0e825 | ||
|
|
92726f7631 | ||
|
|
994063ba9a | ||
|
|
c1a55cf72d | ||
|
|
96758841d8 | ||
|
|
7a59260621 | ||
|
|
27e63b9a78 | ||
|
|
55c0911c23 | ||
|
|
f6cb6ab6d9 | ||
|
|
9f11b09c6a | ||
|
|
a5c4f822f0 | ||
|
|
fb36c262fe | ||
|
|
0e4e8980e6 | ||
|
|
3a932a9803 | ||
|
|
9d10418593 | ||
|
|
5470051d4d | ||
|
|
68c5eeebc3 | ||
|
|
1531fabe23 | ||
|
|
b7673d5b76 | ||
|
|
b64bdaf406 | ||
|
|
eebf08ff1d | ||
|
|
42e51894c3 | ||
|
|
d9ae6481fb | ||
|
|
f1c495a748 | ||
|
|
415b561947 | ||
|
|
e6a0d4c375 | ||
|
|
7e59a5c7c5 | ||
|
|
aea954a482 | ||
|
|
595e448714 | ||
|
|
860f9d63ad | ||
|
|
a5a0b3dc4e | ||
|
|
94eca04c60 | ||
|
|
35bd485d6a | ||
|
|
1fe96f8d9a | ||
|
|
c508e9d7c6 | ||
|
|
55e754fd05 | ||
|
|
a17753f7d1 | ||
|
|
c61838dba6 | ||
|
|
7013e13f05 | ||
|
|
5a0013defe | ||
|
|
c01ed631d6 | ||
|
|
d47464cb06 | ||
|
|
63f176346e | ||
|
|
af94d08729 | ||
|
|
6795d38f50 | ||
|
|
718223f33b | ||
|
|
39e050d9e2 | ||
|
|
c222161291 | ||
|
|
aa80d4681b | ||
|
|
0d57957ebb | ||
|
|
76fe0bb929 | ||
|
|
baa11133f1 | ||
|
|
1bdd3338a6 | ||
|
|
e08492a2c3 | ||
|
|
d5d8fe909d | ||
|
|
8a82753277 | ||
|
|
51ca109067 | ||
|
|
07f6c15a37 | ||
|
|
a44bdb29d4 | ||
|
|
aee4611ab2 | ||
|
|
486467623c | ||
|
|
4912c9b73a | ||
|
|
12d1f3a697 | ||
|
|
a7cad704b9 | ||
|
|
7e4df67556 | ||
|
|
5b24b4dacc | ||
|
|
52fdb46892 | ||
|
|
b389f0fe5f | ||
|
|
74281be340 | ||
|
|
cacf2f7a2c | ||
|
|
4a2cc64d07 | ||
|
|
4647770316 | ||
|
|
3c9b9529c0 | ||
|
|
fc2bd0986c | ||
|
|
a473a32678 | ||
|
|
3e220373b0 | ||
|
|
fbcd886a47 | ||
|
|
e1a782b70f | ||
|
|
73cfedc023 | ||
|
|
b982c977d5 | ||
|
|
532ca1b3a2 | ||
|
|
00ad55b590 | ||
|
|
4c58fd302f | ||
|
|
66582e7035 | ||
|
|
1d13949588 | ||
|
|
c8ad67bbca | ||
|
|
1c92b00918 | ||
|
|
b81a6d01b3 | ||
|
|
0fd666ee6e | ||
|
|
7763fb23a3 | ||
|
|
324277ccfd | ||
|
|
10d02e6c59 | ||
|
|
05ae06c17b | ||
|
|
2671e0c6f7 | ||
|
|
81b6b94f0b |
@@ -38,9 +38,12 @@ The React UI (`core/http/react-ui/`) has **no component/unit tests** — its onl
|
||||
- **Browser:** the flake dev shell ships `chromium` and exports `PLAYWRIGHT_CHROMIUM_PATH`; `playwright.config.js` uses it via `launchOptions.executablePath`, and the Makefile skips `playwright install` when it's set. This avoids Playwright's downloaded browser, which can't resolve system libs (`libglib-2.0`, …) on NixOS. In CI (no `PLAYWRIGHT_CHROMIUM_PATH`) the Makefile falls back to `playwright install --with-deps chromium`.
|
||||
- The app is a React SPA, so coverage accumulates across in-app navigation within a test; a full `page.goto`/reload resets it.
|
||||
- `.nycrc.json` uses `all: true`, so **every `src/**` file is in the report**, including 0%-coverage ones — that's how you spot features with no test at all (sort the HTML report or `coverage-summary.json` by line% ascending).
|
||||
- **UI coverage gate:** `make test-ui-coverage-check` runs the suite then `scripts/ui-coverage-check.sh`, failing if total line coverage drops more than `UI_COVERAGE_TOLERANCE` (default **1.0pp**) below `core/http/react-ui/coverage-baseline.txt`. `make test-ui-coverage-baseline` regenerates the baseline. **Why a tolerance (unlike the strict Go gate):** UI e2e line coverage is *non-deterministic* — async/debounced paths (e.g. the VRAM estimate's 500ms debounce) make identical specs vary ~0.5pp run-to-run, so a zero-tolerance gate would flake. Keep the tolerance just above the observed jitter. Run in CI (`tests-ui-e2e.yml`) and pre-commit on `core/http/react-ui/` changes.
|
||||
- **UI coverage gate:** `make test-ui-coverage-check` runs the suite then `scripts/ui-coverage-check.sh`, failing if total line coverage drops more than `UI_COVERAGE_TOLERANCE` below `core/http/react-ui/coverage-baseline.txt`. `make test-ui-coverage-baseline` regenerates the baseline. Runs in CI (`tests-ui-e2e.yml`) and pre-commit on `core/http/react-ui/` changes.
|
||||
- **Why it has a tolerance (unlike the strict Go gate):** UI e2e coverage is *non-deterministic*. Specs that assert on state and end while async/lazy render work is still in flight collect those lines only when the render beats the coverage teardown — so the total drifts with machine speed/load (a fast local box reads higher than a slow CI runner), diffusely across many specs. The tolerance absorbs that drift, so set the baseline *below* the slow-CI floor, never to a fast-local `make test-ui-coverage-baseline` number, or CI flaps.
|
||||
- **Raising coverage is cheap:** a *render-smoke* spec (navigate to a route, assert its header renders) mounts a lazy page and runs its full render + initial effects, capturing most of its lines in a few lines of test — see `e2e/page-render-smoke.spec.js`. Auth is disabled in the test server (`isAdmin=true`), so `RequireAdmin`/`RequireFeature` routes render without a mock. The most *deterministic* win is removing a race: make a spec `await` a rendered element before ending (see `e2e/agents.spec.js` → AgentCreate) so its lines count every run.
|
||||
|
||||
Rules:
|
||||
- The gate is **strict — there is no tolerance**. Any decrease fails, regardless of how many lines a PR adds or deletes. `covermode=atomic` makes line coverage deterministic, so there's no run-to-run jitter to excuse.
|
||||
- When a change legitimately **raises** coverage, run `make test-coverage-baseline` and **commit** the updated `coverage-baseline.txt` so the ratchet moves up. Never lower the baseline by hand.
|
||||
- If you can't get coverage back to baseline, the fix is to **add tests**, not to edit the baseline.
|
||||
Rules (both gates):
|
||||
- **Install the hooks:** `make install-hooks` once per clone so lint + coverage run pre-commit. Don't lean on CI for what the hook catches.
|
||||
- **Don't work around the gate:** never `git commit --no-verify`, and never hand-lower a baseline or widen a tolerance to turn a red gate green. The ratchet only moves up.
|
||||
- If a change drops coverage, **add tests** (sort `coverage-summary.json` by line% ascending to find untested code) rather than editing the baseline. When coverage legitimately rises, commit the regenerated baseline (`make test-coverage-baseline` / `test-ui-coverage-baseline`).
|
||||
- The Go gate is **strict — no tolerance**; `covermode=atomic` keeps it deterministic. The UI gate keeps a small tolerance only because its e2e coverage isn't.
|
||||
|
||||
@@ -50,6 +50,17 @@ Do not mix styles within a package. If you are extending tests in a package that
|
||||
|
||||
This is enforced by `golangci-lint` via the `forbidigo` linter (see `.golangci.yml`); calls like `t.Errorf` / `t.Fatalf` / `t.Run` / `t.Skip` / `t.Logf` are flagged. Run `make lint` locally before submitting; the same check runs in CI (`.github/workflows/lint.yml`).
|
||||
|
||||
## Outbound HTTP
|
||||
|
||||
All outbound HTTP must go through `github.com/mudler/LocalAI/pkg/httpclient` rather than the standard library's default client. Use `httpclient.New(...)` (no body deadline — safe for streaming/SSE) or `httpclient.NewWithTimeout(d, ...)` (simple request/response). Both **refuse redirects by default** and set a TLS 1.2 floor.
|
||||
|
||||
The reason is GHSA-3mj3-57v2-4636: the std default client follows redirects, and on a *cross-host* redirect Go forwards custom credential headers (e.g. Anthropic's `x-api-key`) to the redirect target, leaking the secret. `httpclient` fails closed instead.
|
||||
|
||||
- Need to follow redirects (download CDNs, registry blobs, GitHub asset URLs)? Pass `httpclient.WithFollowRedirects()` — it still strips credential headers on any cross-host hop.
|
||||
- Have a custom transport (IP-pinned dialer, HTTP/2 tuning, a credential-injecting `RoundTripper`)? Pass `httpclient.WithTransport(rt)`, basing the transport on `httpclient.HardenedTransport()` to keep the TLS floor. Handed a `*http.Client` by a library? `httpclient.Harden(c)` applies the policy in place.
|
||||
|
||||
This is enforced by `forbidigo` (see `.golangci.yml`): `http.DefaultClient` and `http.Get`/`Post`/`PostForm`/`Head` are flagged. The `&http.Client{}` composite literal can't be matched precisely by forbidigo without also flagging legitimate `*http.Client` type references, so that form is caught by review — don't construct raw clients.
|
||||
|
||||
## Documentation
|
||||
|
||||
The project documentation is located in `docs/content`. When adding new features or changing existing functionality, it is crucial to update the documentation to reflect these changes. This helps users understand how to use the new capabilities and ensures the documentation stays relevant.
|
||||
|
||||
138
.agents/dllm-backend.md
Normal file
138
.agents/dllm-backend.md
Normal file
@@ -0,0 +1,138 @@
|
||||
# Working on the dllm Backend
|
||||
|
||||
`mudler/dllm.cpp` is a standalone C++/ggml engine for DiffusionGemma
|
||||
block-diffusion models. LocalAI wraps it with a **pure-Go** backend at
|
||||
`backend/go/dllm/` that dlopens `libdllm.so` via purego (ebitengine/purego) -
|
||||
NOT cgo, and NOT a C++ grpc-server fork. The Go side owns chat templating
|
||||
(gemma4 renderer) and output parsing (gemma4 streaming parser) and implements
|
||||
the rich gRPC interface (`PredictRich`/`PredictStreamRich`, ChatDelta replies).
|
||||
|
||||
> NOTE: github.com/mudler/dllm.cpp is still **private** (publishing is
|
||||
> planned). Until then the Makefile's anonymous clone fails; use the local-dev
|
||||
> symlink shortcut documented at the top of `backend/go/dllm/Makefile`
|
||||
> (symlink an out-of-tree `build/libdllm.so` into the backend dir and skip the
|
||||
> clone), or a git credential helper with repo access.
|
||||
|
||||
## Pin
|
||||
|
||||
`backend/go/dllm/Makefile` pins `DLLM_VERSION?=<sha>` at the top
|
||||
(whisper / parakeet-cpp / ds4 convention). The bump-deps bot
|
||||
(`.github/workflows/bump_deps.yaml`) tracks `mudler/dllm.cpp` `main` and
|
||||
rewrites that variable. After a manual bump: `make -C backend/go/dllm purge &&
|
||||
make -C backend/go/dllm` (the clone is keyed on the directory existing, not
|
||||
the sha).
|
||||
|
||||
## C-ABI and the serialization contract
|
||||
|
||||
The binding covers the 9-symbol flat C-ABI from dllm.cpp's
|
||||
`include/dllm_capi.h` (ABI v1; `main.go` hard-fails on a version mismatch):
|
||||
`abi_version, load, free, last_error, free_string, tokenize_json, generate,
|
||||
generate_stream, cancel`. Contract points the Go wiring encodes (`capi.go`
|
||||
header comment has the full list):
|
||||
|
||||
- **One ctx = one concurrent generate/tokenize.** A per-model worker
|
||||
goroutine (`Dllm.jobs` in `dllm.go`) owns ALL C calls, making the
|
||||
serialization structural instead of lock discipline.
|
||||
- **`dllm_capi_cancel` is the ONE exception**: it only flips an atomic and may
|
||||
be called from any goroutine mid-generate, so `Dllm.Cancel` bypasses the
|
||||
worker queue. The flag resets at the start of each generate, so a watchdog
|
||||
racing a new generate must re-issue cancel.
|
||||
- **`last_error` is a borrowed pointer** and must only be read AFTER the
|
||||
failing call returned (never while a generate is in flight on the same ctx).
|
||||
- **Free vs in-flight requests**: requests hold `genMu.RLock` for their full
|
||||
duration; `Free` takes the write lock, so it only runs when nothing is in
|
||||
flight, then drains and closes the worker. Post-Free requests get a clean
|
||||
"model not loaded" error.
|
||||
- `tokenize_json`/`generate` return malloc'd `char*` (bound as `uintptr`,
|
||||
copied, then `dllm_capi_free_string`d); opts/params JSON must be a FLAT
|
||||
object of scalars (`buildOptsJSON` rejects anything else).
|
||||
|
||||
## Wire shape
|
||||
|
||||
| RPC | Implementation |
|
||||
|---|---|
|
||||
| LoadModel | `dllm_capi_load` (params: `n_gpu_layers`, `n_threads`, `ctx_len`); `Options[]` parsed into per-request gen opts (`eb_*`, `blocks`, `kv_cache`) by `parseModelGenOpts` |
|
||||
| PredictRich | render (if templated) → `dllm_capi_generate` → parse → ONE Reply with aggregated ChatDeltas + legacy `Message` bytes |
|
||||
| PredictStreamRich | `dllm_capi_generate_stream`; per committed diffusion block → UTF-8 holdback → parser.Feed → one Reply per non-empty delta batch (channel closed by the CALLER, per `pkg/grpc/interface.go`) |
|
||||
| Predict / PredictStream | Legacy paths, delegate to the rich pair (legacy stream INVERTS channel ownership: the impl closes) |
|
||||
| TokenizeString | `dllm_capi_tokenize_json` (C side prepends BOS per `vocab.add_bos`) |
|
||||
| Cancel | `dllm_capi_cancel`, exposed as the `grpc.Cancellable` capability (`pkg/grpc/interface.go`): the gRPC server arms it via `context.AfterFunc` on the Predict/PredictStream context, so client disconnects/timeouts abort the in-flight generate - llama.cpp `IsCancelled()` parity for Go backends |
|
||||
|
||||
`n_threads` and `ctx_len` are accepted-but-ignored by the engine at the
|
||||
current pin (the context bound comes from GGUF `n_ctx_train`); they are sent
|
||||
for forward compatibility.
|
||||
|
||||
## Renderer / parser (the templated chat path)
|
||||
|
||||
With `use_tokenizer_template` + raw Messages, the backend owns templating and
|
||||
parsing (the ds4 precedent, but in Go):
|
||||
|
||||
- `gemma4_renderer.go` - `RenderGemma4(msgs, toolsJSON, enableThinking,
|
||||
addGenerationPrompt)`. The file embeds the FULL `tokenizer.chat_template`
|
||||
jinja (17466 bytes, md5 `8c34cf93c7a7815b3fdb300a009c4c17`) extracted
|
||||
verbatim from `diffusiongemma-26B-A4B-it-BF16.gguf` via gguf-py - e.g.
|
||||
`python scripts/dump_gguf.py model.gguf | grep -A400 chat_template` in the
|
||||
dllm.cpp checkout - as a numbered comment block; every Go rule cites its
|
||||
"tpl L<n>" line. Re-verify the md5 before blaming the renderer for a
|
||||
mismatch with a new GGUF. **BOS exception**: the template emits
|
||||
`{{- bos_token -}}` but the renderer deliberately does NOT - dllm.cpp's
|
||||
`run_generate` tokenizes with `prepend_bos = vocab.add_bos` (true for
|
||||
gemma4), so a literal `<bos>` would double it.
|
||||
- `gemma4_parser.go` - streaming state machine turning raw model text
|
||||
(fragments can split anywhere, including mid-marker) into ChatDeltas:
|
||||
thought channels → `reasoning_content`, `<|tool_call>call:name{...}` →
|
||||
ToolCallDelta, `<turn|>` → done. Marker grammar cross-checked against vLLM
|
||||
PR #45163's gemma4 tool/reasoning parsers. Malformed payloads are re-emitted
|
||||
raw as content, never dropped.
|
||||
- Thinking is **opt-in** for this family (`Metadata["enable_thinking"]`,
|
||||
default OFF - the inverse of ds4): the template gates every thinking branch
|
||||
on `enable_thinking`, and the no-thinking render pre-closes an empty thought
|
||||
channel, so the parser always starts in content state.
|
||||
- **UTF-8 boundary holdback** (`splitValidUTF8` in `dllm.go`): per-block
|
||||
detokenization can split a multi-byte character across block boundaries, and
|
||||
grpc-go refuses to marshal invalid UTF-8 in proto3 strings. An incomplete
|
||||
trailing sequence (at most 3 bytes) is carried into the next block; genuinely
|
||||
undecodable bytes become U+FFFD.
|
||||
|
||||
Without `use_tokenizer_template`, the prompt passes through verbatim and the
|
||||
output is NOT gemma4-parsed (plain content, like any non-autoparsing backend).
|
||||
|
||||
## Tests
|
||||
|
||||
| Layer | Gate | What |
|
||||
|---|---|---|
|
||||
| `backend/go/dllm/*_test.go` (renderer/parser/wiring) | none - run in plain `go test ./backend/go/dllm/...` | Ginkgo specs over a fake `generator` seam; canonical renderer fixtures from transformers' `test_modeling_diffusion_gemma.py`, parser tables from the vLLM gemma4 parsers |
|
||||
| `backend/go/dllm/dllm_test.go` C-ABI smoke | `DLLM_TEST_LIBRARY` + `DLLM_TEST_TINY_MODEL` (dllm.cpp's `tests/fixtures/tiny_with_vocab.gguf`); Skips when unset | Drives the real `libdllm.so`: ABI check, load, tokenize `[2,18]`, deterministic generate, cancel (incl. mid-stream `Dllm.Cancel` aborting a deliberately slow `eb_max_steps:256` run in ~10ms) |
|
||||
| `tests/e2e-backends/dllm_test.go` | `BACKEND_TEST_DLLM=1` + `BACKEND_BINARY` (packaged run.sh) + `BACKEND_TEST_MODEL_FILE` (tiny fixture) | Templated chat round trip (Messages + UseTokenizerTemplate) over the real gRPC binary, non-streaming + streaming; plus client-context cancellation mid-stream (proves the `Cancellable` server plumbing end to end) |
|
||||
| Real-model e2e | `BACKEND_TEST_DLLM_REAL_MODEL_FILE` (26B BF16, ~50 GB) + `BACKEND_TEST_DLLM_REAL_GPU_LAYERS` | CUDA-13-class hardware only |
|
||||
|
||||
Tool-call e2e is deliberately absent from the tiny-model spec: the fixture has
|
||||
random weights and cannot be coaxed into emitting tool markup; the unit tables
|
||||
carry that coverage.
|
||||
|
||||
## Build matrix
|
||||
|
||||
`cpu-dllm` (amd64 + arm64), `cuda13-dllm` (amd64), and
|
||||
`cuda13-nvidia-l4t-arm64-dllm` (arm64 CUDA: Jetson / DGX Spark GB10), via
|
||||
`.github/backend-matrix.yml`. No darwin/Metal. CUDA builds forward
|
||||
`-DDLLM_CUDA=ON` (dllm.cpp gates ggml's CUDA behind its own flag - a bare
|
||||
`-DGGML_CUDA=ON` is overridden by the cache FORCE). `libdllm.so` is
|
||||
self-contained (ggml statically absorbed, PIC), so `package.sh` only ships
|
||||
the binary, `run.sh` and that one .so (the parakeet-cpp-style stub layout;
|
||||
no ldd walk yet).
|
||||
|
||||
## Known limitations
|
||||
|
||||
- **Cancel granularity**: the C-ABI cancel flag is per-ctx and resets on
|
||||
every generate entry, so a Cancel racing a NEW generate can be lost, and
|
||||
with requests queued on the worker it aborts whichever generate is
|
||||
currently running (acceptable: the server de-registers the hook on normal
|
||||
completion, one process serves one model).
|
||||
- **Throughput**: ~0.15 tok/s on the 26B at default settings (GB10) - every
|
||||
denoise step recomputes the full prompt+canvas. The upstream prefix-KV
|
||||
cache (dllm.cpp P3) is the fix; `kv_cache:on` errors until it lands
|
||||
(`auto`/`off` are accepted no-ops).
|
||||
- **Repo privacy**: see the note at the top - CI clone of dllm.cpp needs the
|
||||
repo published (or credentials) before the backend images can build.
|
||||
- Engine spec/validation references: dllm.cpp `docs/validation.md` and
|
||||
LocalAI `docs/superpowers/specs/2026-06-10-dllm-cpp-design.md`.
|
||||
@@ -68,6 +68,34 @@ go test -count=1 -timeout=30m -v ./tests/e2e-backends/...
|
||||
|
||||
CI does not load the model; the suite is opt-in via env vars.
|
||||
|
||||
## Distributed mode
|
||||
|
||||
ds4 supports **layer-split** distributed inference (a model too big for one host,
|
||||
split by transformer layer; the GGUF must be present on every machine, each loads
|
||||
only its slice). Topology is **inverted** vs llama.cpp: the coordinator listens,
|
||||
workers dial in.
|
||||
|
||||
- **`ds4-worker` binary**: built and packaged next to `grpc-server` (`package.sh`
|
||||
copies it into `package/`). Links the same engine objects plus `ds4_distributed.o`;
|
||||
**no gRPC/protobuf dependency** (speaks ds4's own TCP transport), so it builds
|
||||
even where `grpc-server` can't. Runs the worker serving loop (`ds4_dist_run`).
|
||||
- **Coordinator wiring**: the ds4 `grpc-server` acts as coordinator when `LoadModel`
|
||||
`ModelOptions.Options` (from model-YAML `options:`) carry:
|
||||
- `ds4_role:coordinator` (enables distributed mode; absent → single-node, back-compat)
|
||||
- `ds4_layers:0:19` (coordinator's own slice, inclusive; `N:output` includes the head)
|
||||
- `ds4_listen:0.0.0.0:1234` (address workers dial into)
|
||||
- `ds4_route_timeout:60` (optional; seconds Predict/PredictStream wait for the route
|
||||
to form before returning gRPC `UNAVAILABLE`; default 60)
|
||||
- **Worker CLI**: `local-ai worker ds4-distributed -- <ds4-worker args>` resolves the
|
||||
ds4 backend and execs the packaged `ds4-worker` (raw passthrough), e.g.
|
||||
`--role worker --model /models/ds4flash.gguf --layers 20:output --coordinator <host> 1234`.
|
||||
|
||||
Opt-in e2e in `tests/e2e-backends/backend_test.go`, gated by
|
||||
`BACKEND_TEST_DS4_DISTRIBUTED=1` (plus `BACKEND_TEST_DS4_WORKER_BINARY`,
|
||||
`BACKEND_TEST_DS4_WORKER_LAYERS`, `BACKEND_TEST_DS4_COORDINATOR_LAYERS`,
|
||||
`BACKEND_TEST_DS4_LISTEN`). Design spec:
|
||||
`docs/superpowers/specs/2026-05-30-ds4-distributed-inference-design.md`.
|
||||
|
||||
## Importer
|
||||
|
||||
`core/gallery/importers/ds4.go` (`DS4Importer`) auto-detects ds4 weights by
|
||||
|
||||
372
.github/backend-matrix.yml
vendored
372
.github/backend-matrix.yml
vendored
@@ -716,6 +716,32 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -1556,6 +1582,45 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-dllm'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "dllm"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1569,6 +1634,45 @@ include:
|
||||
backend: "whisper"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-crispasr'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-parakeet-cpp'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-dllm'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "dllm"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1688,20 +1792,6 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-turboquant'
|
||||
builder-base-image: 'quay.io/go-skynet/ci-cache:base-grpc-rocm-amd64'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2850,6 +2940,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2864,6 +2968,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-crispasr'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2877,6 +2995,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2890,6 +3021,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2904,6 +3048,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2918,6 +3076,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-crispasr'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
@@ -2931,6 +3103,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-crispasr'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2944,6 +3129,157 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-crispasr'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
runs-on: 'ubuntu-latest'
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# parakeet-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-parakeet-cpp'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# dllm
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-dllm'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "dllm"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-dllm'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "dllm"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-parakeet-cpp'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-parakeet-cpp'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-parakeet-cpp'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
runs-on: 'ubuntu-latest'
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# acestep-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -3976,6 +4312,14 @@ includeDarwin:
|
||||
tag-suffix: "-metal-darwin-arm64-whisper"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "crispasr"
|
||||
tag-suffix: "-metal-darwin-arm64-crispasr"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "parakeet-cpp"
|
||||
tag-suffix: "-metal-darwin-arm64-parakeet-cpp"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "acestep-cpp"
|
||||
tag-suffix: "-metal-darwin-arm64-acestep-cpp"
|
||||
build-type: "metal"
|
||||
|
||||
13
.github/gallery-agent/main.go
vendored
13
.github/gallery-agent/main.go
vendored
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -113,6 +114,17 @@ func main() {
|
||||
fmt.Println("Searching for trending models on HuggingFace...")
|
||||
rawModels, err := client.GetTrending(searchTerm, limit)
|
||||
if err != nil {
|
||||
if errors.Is(err, hfapi.ErrRateLimited) {
|
||||
fmt.Printf("HuggingFace API is rate limited after retries, skipping this run: %v\n", err)
|
||||
writeSummary(AddedModelSummary{
|
||||
SearchTerm: searchTerm,
|
||||
TotalFound: 0,
|
||||
ModelsAdded: 0,
|
||||
Quantization: quantization,
|
||||
ProcessingTime: time.Since(startTime).String(),
|
||||
})
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Error fetching models: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -277,4 +289,3 @@ func truncateString(s string, maxLen int) string {
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
|
||||
12
.github/workflows/bump_deps.yaml
vendored
12
.github/workflows/bump_deps.yaml
vendored
@@ -30,6 +30,18 @@ jobs:
|
||||
variable: "WHISPER_CPP_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/whisper/Makefile"
|
||||
- repository: "CrispStrobe/CrispASR"
|
||||
variable: "CRISPASR_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/crispasr/Makefile"
|
||||
- repository: "mudler/parakeet.cpp"
|
||||
variable: "PARAKEET_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/parakeet-cpp/Makefile"
|
||||
- repository: "mudler/dllm.cpp"
|
||||
variable: "DLLM_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/dllm/Makefile"
|
||||
- repository: "leejet/stable-diffusion.cpp"
|
||||
variable: "STABLEDIFFUSION_GGML_VERSION"
|
||||
branch: "master"
|
||||
|
||||
2
.github/workflows/secscan.yaml
vendored
2
.github/workflows/secscan.yaml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||
- name: Run Gosec Security Scanner
|
||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||
uses: securego/gosec@v2.22.9
|
||||
uses: securego/gosec@v2.27.1
|
||||
with:
|
||||
# we let the report trigger content trigger a failure using the GitHub Security features.
|
||||
args: '-no-fail -fmt sarif -out results.sarif ./...'
|
||||
|
||||
21
.github/workflows/test-extra.yml
vendored
21
.github/workflows/test-extra.yml
vendored
@@ -46,6 +46,7 @@ jobs:
|
||||
speaker-recognition: ${{ steps.detect.outputs.speaker-recognition }}
|
||||
sherpa-onnx: ${{ steps.detect.outputs.sherpa-onnx }}
|
||||
whisper: ${{ steps.detect.outputs.whisper }}
|
||||
parakeet-cpp: ${{ steps.detect.outputs.parakeet-cpp }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
@@ -633,6 +634,26 @@ jobs:
|
||||
- name: Build whisper backend image and run transcription gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-whisper-transcription
|
||||
# Parakeet ASR via the parakeet-cpp backend (C++/ggml port of NeMo
|
||||
# Parakeet). Drives AudioTranscription (offline, with word timestamps) on
|
||||
# tdt_ctc-110m + the JFK 11s clip.
|
||||
tests-parakeet-cpp-grpc-transcription:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.parakeet-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25.4'
|
||||
- name: Build parakeet-cpp backend image and run transcription gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-parakeet-cpp-transcription
|
||||
# VITS TTS via the sherpa-onnx backend. Drives both TTS (file write) and
|
||||
# TTSStream (PCM chunks) on the e2e-backends harness.
|
||||
tests-sherpa-onnx-grpc-tts:
|
||||
|
||||
@@ -56,6 +56,20 @@ linters:
|
||||
# are exempt — see linters.exclusions.rules below.
|
||||
- pattern: '^os\.(Getenv|LookupEnv|Environ)$'
|
||||
msg: 'Plumb config through ApplicationConfig (or the relevant CLI struct) instead of reading env directly. CLI entry points (core/cli/) bind env vars via kong''s `env:` tag — that is the only sanctioned env→struct boundary. See .agents/coding-style.md.'
|
||||
# Outbound HTTP must go through pkg/httpclient, which refuses redirects
|
||||
# by default and sets a TLS floor. The std-library default client and
|
||||
# the http.Get/Post/... convenience helpers follow redirects (up to 10)
|
||||
# and, on a cross-host redirect, forward custom credential headers such
|
||||
# as Anthropic's x-api-key to the redirect target — leaking the secret
|
||||
# (GHSA-3mj3-57v2-4636). forbidigo can't precisely match the
|
||||
# `&http.Client{}` composite literal without also flagging legitimate
|
||||
# `*http.Client` type references, so that form is enforced by
|
||||
# convention + review; these two patterns catch the implicit-default
|
||||
# client, which is the common footgun.
|
||||
- pattern: '^http\.DefaultClient$'
|
||||
msg: 'Use pkg/httpclient (httpclient.New / NewWithTimeout) instead of http.DefaultClient — the std client follows redirects and leaks credential headers cross-host (GHSA-3mj3-57v2-4636). See .agents/coding-style.md.'
|
||||
- pattern: '^http\.(Get|Post|PostForm|Head)$'
|
||||
msg: 'Use pkg/httpclient (httpclient.New / NewWithTimeout) instead of http.Get/Post/PostForm/Head — these use http.DefaultClient, which follows redirects and leaks credential headers cross-host (GHSA-3mj3-57v2-4636). See .agents/coding-style.md.'
|
||||
exclusions:
|
||||
paths:
|
||||
# Upstream whisper.cpp source tree fetched by the whisper backend Makefile.
|
||||
@@ -95,3 +109,18 @@ linters:
|
||||
- path: _test\.go$
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# pkg/httpclient is the sanctioned home for outbound HTTP clients; it
|
||||
# necessarily references net/http directly.
|
||||
- path: ^pkg/httpclient/
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
# Tests drive local httptest servers where redirect/TLS hardening is
|
||||
# irrelevant; the std client is fine there.
|
||||
- path: _test\.go$
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
# Vendored upstream whisper.cpp Go bindings are a separate module and
|
||||
# cannot import pkg/httpclient.
|
||||
- path: ^backend/go/whisper/sources/
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
|
||||
@@ -26,6 +26,7 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
||||
| [.agents/vllm-backend.md](.agents/vllm-backend.md) | Working on the vLLM / vLLM-omni backends — native parsers, ChatDelta, CPU build, libnuma packaging, backend hooks |
|
||||
| [.agents/sglang-backend.md](.agents/sglang-backend.md) | Working on the SGLang backend — `engine_args` validation against ServerArgs, speculative-decoding (EAGLE/EAGLE3/DFLASH/MTP) recipes, parser handling |
|
||||
| [.agents/ds4-backend.md](.agents/ds4-backend.md) | Working on the ds4 backend - DSML state machine, thinking modes, KV cache, Metal+CUDA matrix |
|
||||
| [.agents/dllm-backend.md](.agents/dllm-backend.md) | Working on the dllm backend (DiffusionGemma block-diffusion) - purego C-ABI binding, per-ctx serialization contract, gemma4 renderer/parser, gated test layers |
|
||||
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
||||
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||
@@ -35,6 +36,7 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
||||
|
||||
## Quick Reference
|
||||
|
||||
- **Git hooks & coverage gates**: Run `make install-hooks` once per clone so the pre-commit lint + coverage gates run. **Never bypass them with `git commit --no-verify`, and never lower a coverage baseline or widen a gate's tolerance to turn a red gate green** — the coverage ratchet only moves up. If a change drops coverage, add tests to raise it (e.g. render-smoke specs). See [.agents/building-and-testing.md](.agents/building-and-testing.md).
|
||||
- **Logging**: Use `github.com/mudler/xlog` (same API as slog)
|
||||
- **Go style**: Prefer `any` over `interface{}`
|
||||
- **Comments**: Explain *why*, not *what*
|
||||
|
||||
@@ -266,6 +266,12 @@ The e2e tests run LocalAI in a Docker container and exercise the API:
|
||||
make test-e2e
|
||||
```
|
||||
|
||||
### React UI tests and coverage
|
||||
|
||||
The React UI (`core/http/react-ui/`) is covered by Playwright e2e specs, gated by a **monotonic line-coverage ratchet** (`make test-ui-coverage-check`, run in CI and pre-commit). The metric is non-deterministic — a fast local box reads higher than a slow CI runner for the same code — so a small tolerance is unavoidable.
|
||||
|
||||
**If your change lowers UI coverage, raise it back by adding specs — do not widen the tolerance or hand-lower the baseline.** A *render-smoke* spec (navigate to a page, assert its header is visible) cheaply covers an entire lazy page. See `core/http/react-ui/e2e/page-render-smoke.spec.js` and the full policy in [.agents/building-and-testing.md](.agents/building-and-testing.md#react-ui-coverage).
|
||||
|
||||
### Running E2E container tests
|
||||
|
||||
These tests build a standard LocalAI Docker image and run it with pre-configured model configs to verify that most endpoints work correctly:
|
||||
|
||||
65
Makefile
65
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/dllm backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -180,7 +180,7 @@ osx-signed: build
|
||||
|
||||
## Run
|
||||
run: ## run local-ai
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./cmd/local-ai
|
||||
|
||||
prepare-test: protogen-go build-mock-backend
|
||||
|
||||
@@ -309,13 +309,20 @@ run-e2e-aio: protogen-go
|
||||
@echo 'Running e2e AIO tests'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
|
||||
|
||||
# Distributed architecture e2e (PostgreSQL + NATS via testcontainers).
|
||||
# Includes NatsJWT specs (JWT-enabled NATS). Requires Docker.
|
||||
# VLLMMultinode is excluded here; use test-e2e-vllm-multinode for that.
|
||||
test-e2e-distributed: protogen-go
|
||||
@echo 'Running distributed e2e tests (label Distributed, incl. NatsJWT)'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter='Distributed && !VLLMMultinode' --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e/distributed
|
||||
|
||||
# vLLM multi-node DP smoke (CPU). Builds local-ai:tests and the
|
||||
# cpu-vllm backend from the current working tree, then drives a
|
||||
# head + headless follower via testcontainers-go and asserts a chat
|
||||
# completion. BuildKit caches both images, so re-runs only rebuild
|
||||
# what changed. The test lives under tests/e2e/distributed and is
|
||||
# selected by the VLLMMultinode label so it doesn't run alongside
|
||||
# the other distributed-suite tests by default.
|
||||
# test-e2e-distributed.
|
||||
test-e2e-vllm-multinode: docker-build-e2e extract-backend-vllm protogen-go
|
||||
@echo 'Running e2e vLLM multi-node DP test'
|
||||
LOCALAI_IMAGE=local-ai \
|
||||
@@ -991,6 +998,19 @@ test-extra-backend-whisper-transcription: docker-build-whisper
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## Audio transcription wrapper for the parakeet-cpp (parakeet.cpp ggml port)
|
||||
## backend. Mirrors test-extra-backend-whisper-transcription: drives the
|
||||
## AudioTranscription / AudioTranscriptionStream RPCs against a published
|
||||
## Parakeet GGUF using the JFK 11s clip from whisper.cpp's CI samples. Not
|
||||
## part of the default test suite - run explicitly once the pinned model URL
|
||||
## is reachable.
|
||||
test-extra-backend-parakeet-cpp-transcription: docker-build-parakeet-cpp
|
||||
BACKEND_IMAGE=local-ai-backend:parakeet-cpp \
|
||||
BACKEND_TEST_MODEL_URL=https://huggingface.co/mudler/parakeet-cpp-gguf/resolve/main/tdt_ctc-110m-f16.gguf \
|
||||
BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## LocalVQE audio transform (joint AEC + noise suppression + dereverb).
|
||||
## Exercises the audio_transform capability end-to-end: batch transform
|
||||
## of a real WAV fixture and bidi streaming of synthetic silent frames.
|
||||
@@ -1149,6 +1169,11 @@ BACKEND_HUGGINGFACE = huggingface|golang|.|false|true
|
||||
BACKEND_SILERO_VAD = silero-vad|golang|.|false|true
|
||||
BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true
|
||||
BACKEND_WHISPER = whisper|golang|.|false|true
|
||||
BACKEND_CRISPASR = crispasr|golang|.|false|true
|
||||
BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true
|
||||
# dllm is mudler/dllm.cpp, the DiffusionGemma block-diffusion engine,
|
||||
# wrapped by the purego backend at backend/go/dllm.
|
||||
BACKEND_DLLM = dllm|golang|.|false|true
|
||||
BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
||||
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
|
||||
BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true
|
||||
@@ -1236,6 +1261,9 @@ $(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CRISPASR)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_DLLM)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))
|
||||
@@ -1285,7 +1313,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-crispasr docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
@@ -1313,6 +1341,13 @@ build-ui-test-server: build-mock-backend react-ui protogen-go
|
||||
test-ui-e2e: build-ui-test-server
|
||||
cd core/http/react-ui && npm install && npx playwright install --with-deps chromium && npx playwright test
|
||||
|
||||
## Optional Playwright worker count for the UI e2e targets below. Pass
|
||||
## UI_TEST_WORKERS=N (e.g. `make test-ui-coverage UI_TEST_WORKERS=20`) to
|
||||
## override Playwright's default (cores/2). Empty by default so Playwright
|
||||
## picks its own worker count.
|
||||
UI_TEST_WORKERS ?=
|
||||
PLAYWRIGHT_WORKERS_FLAG = $(if $(UI_TEST_WORKERS),--workers=$(UI_TEST_WORKERS),)
|
||||
|
||||
## Fast Playwright e2e run used by the pre-commit hook on React UI changes.
|
||||
## Force-rebuilds the (non-instrumented) dist so the suite tests the working
|
||||
## tree — not a stale dist the `react-ui` skip-guard would leave — re-embeds
|
||||
@@ -1322,22 +1357,24 @@ test-ui-e2e: build-ui-test-server
|
||||
test-ui: build-mock-backend protogen-go
|
||||
cd core/http/react-ui && bun install && bun run build
|
||||
$(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui
|
||||
cd core/http/react-ui && sh $(CURDIR)/scripts/ensure-playwright-browser.sh && bunx playwright test
|
||||
cd core/http/react-ui && sh $(CURDIR)/scripts/ensure-playwright-browser.sh && bunx playwright test $(PLAYWRIGHT_WORKERS_FLAG)
|
||||
|
||||
## React UI code coverage from the Playwright e2e suite. Builds an
|
||||
## istanbul-instrumented bundle (COVERAGE=true), re-embeds it into the
|
||||
## ui-test-server (the dist is //go:embed'ed at compile time), runs the
|
||||
## Playwright specs — which harvest window.__coverage__ via the coverage
|
||||
## fixture — and writes an nyc report to core/http/react-ui/coverage/.
|
||||
## Removes the instrumented dist afterwards so normal builds aren't served
|
||||
## instrumented assets.
|
||||
## React UI code coverage from the Playwright e2e suite. Builds a
|
||||
## NON-instrumented bundle with source maps (COVERAGE_V8=true), re-embeds it
|
||||
## into the ui-test-server (the dist is //go:embed'ed at compile time), runs the
|
||||
## Playwright specs which collect native Chromium V8 coverage (PW_V8_COVERAGE=1)
|
||||
## — far cheaper than istanbul's build-time counters (~40% faster end-to-end) —
|
||||
## convert it to istanbul via v8-to-istanbul in the coverage fixture, and write
|
||||
## an nyc report to core/http/react-ui/coverage/. Removes the dist afterwards so
|
||||
## normal builds aren't served source-mapped assets. (The legacy istanbul path
|
||||
## still exists: `bun run build:coverage` + unset PW_V8_COVERAGE.)
|
||||
test-ui-coverage: build-mock-backend protogen-go
|
||||
trap 'rm -rf "$(CURDIR)/core/http/react-ui/dist"' EXIT; \
|
||||
( cd core/http/react-ui && bun install && bun run build:coverage ) && \
|
||||
( cd core/http/react-ui && bun install && bun run build:coverage-v8 ) && \
|
||||
$(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui && \
|
||||
( cd core/http/react-ui && rm -rf .nyc_output coverage && \
|
||||
sh $(CURDIR)/scripts/ensure-playwright-browser.sh && \
|
||||
bunx playwright test && bun run coverage:report )
|
||||
PW_V8_COVERAGE=1 bunx playwright test $(PLAYWRIGHT_WORKERS_FLAG) && bun run coverage:report )
|
||||
|
||||
## UI coverage baseline (committed) and the strict gate that compares against
|
||||
## it — the React mirror of test-coverage-baseline / test-coverage-check.
|
||||
|
||||
28
README.md
28
README.md
@@ -31,12 +31,18 @@
|
||||
|
||||
**LocalAI** is the open-source AI engine. Run any model - LLMs, vision, voice, image, video - on any hardware. No GPU required.
|
||||
|
||||
- **Drop-in API compatibility** — OpenAI, Anthropic, ElevenLabs APIs
|
||||
- **36+ backends** — llama.cpp, vLLM, transformers, whisper, diffusers, MLX...
|
||||
- **Any hardware** — NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready** — API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents** — autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first** — your data never leaves your infrastructure
|
||||
**A small core, not a bundle.** Each backend wraps a best-in-class engine (llama.cpp, vLLM, whisper.cpp, stable-diffusion, MLX...) in its own image, pulled only when a model needs it. You install nothing you don't use.
|
||||
|
||||
- **Composable by design**: backends are separate and pulled on demand, so you install only what your model needs
|
||||
- **Open and extensible**: load any model, or build your own backend in any language against an open interface
|
||||
- **Drop-in API compatibility**: OpenAI, Anthropic, and ElevenLabs APIs across every backend
|
||||
- **Any model, any modality**: LLMs, vision, voice, image, and video behind one API
|
||||
- **Any hardware**: NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready**: API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents**: autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first**: your data never leaves your infrastructure
|
||||
|
||||

|
||||
|
||||
Created by [Ettore Di Giacinto](https://github.com/mudler) and maintained by the [LocalAI team](#team).
|
||||
|
||||
@@ -143,6 +149,16 @@ local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
|
||||
local-ai run oci://localai/phi-2:latest
|
||||
```
|
||||
|
||||
To test a running LocalAI server from the terminal, open an interactive chat session from another shell. Inside the prompt, `/models` lists installed models and `/model <name>` switches between them.
|
||||
|
||||
```bash
|
||||
# Terminal 1
|
||||
local-ai run llama-3.2-1b-instruct:q4_k_m
|
||||
|
||||
# Terminal 2
|
||||
local-ai chat --model llama-3.2-1b-instruct:q4_k_m
|
||||
```
|
||||
|
||||
> **Automatic Backend Detection**: LocalAI automatically detects your GPU capabilities and downloads the appropriate backend. For advanced options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/).
|
||||
|
||||
For more details, see the [Getting Started guide](https://localai.io/basics/getting_started/).
|
||||
|
||||
@@ -537,6 +537,15 @@ message TTSRequest {
|
||||
string dst = 3;
|
||||
string voice = 4;
|
||||
optional string language = 5;
|
||||
// instructions is a free-form, per-request style/voice description (maps to
|
||||
// the OpenAI `instructions` field). Backends that support expressive synthesis
|
||||
// (e.g. Qwen3-TTS CustomVoice/VoiceDesign) prefer this over the static YAML
|
||||
// option when set; backends that don't simply ignore it.
|
||||
optional string instructions = 6;
|
||||
// params carries optional, backend-specific per-request generation parameters
|
||||
// (e.g. Chatterbox exaggeration/cfg_weight/temperature). Values are strings and
|
||||
// coerced by the backend; unset leaves the backend's configured defaults.
|
||||
map<string, string> params = 7;
|
||||
}
|
||||
|
||||
message VADRequest {
|
||||
|
||||
1
backend/cpp/ds4/.gitignore
vendored
1
backend/cpp/ds4/.gitignore
vendored
@@ -2,6 +2,7 @@ ds4/
|
||||
build/
|
||||
package/
|
||||
grpc-server
|
||||
ds4-worker
|
||||
*.o
|
||||
backend.pb.cc
|
||||
backend.pb.h
|
||||
|
||||
@@ -60,6 +60,13 @@ elseif(DS4_GPU STREQUAL "cpu")
|
||||
set(DS4_OBJS "${DS4_DIR}/ds4_cpu.o")
|
||||
endif()
|
||||
|
||||
# ds4.c now references ds4_distributed.c (distributed inference) and ds4_ssd.c
|
||||
# (SSD expert-cache), each split into its own translation unit upstream. Both
|
||||
# are GPU-agnostic objects shared by every GPU mode, so link them in regardless
|
||||
# of DS4_GPU.
|
||||
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_distributed.o")
|
||||
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_ssd.o")
|
||||
|
||||
add_executable(${TARGET}
|
||||
grpc-server.cpp
|
||||
dsml_parser.cpp
|
||||
@@ -99,3 +106,36 @@ if(DS4_NATIVE)
|
||||
target_compile_options(${TARGET} PRIVATE -march=native)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# ds4-worker: standalone distributed worker. Links the same ds4 engine objects
|
||||
# (including ds4_distributed.o) but has NO gRPC/protobuf dependency - it speaks
|
||||
# ds4's own TCP transport via ds4_dist_run(). Buildable wherever the engine
|
||||
# objects build, even on hosts without protobuf/grpc dev headers.
|
||||
add_executable(ds4-worker worker_main.c)
|
||||
target_include_directories(ds4-worker PRIVATE ${DS4_DIR})
|
||||
foreach(obj ${DS4_OBJS})
|
||||
target_sources(ds4-worker PRIVATE ${obj})
|
||||
set_source_files_properties(${obj} PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE)
|
||||
endforeach()
|
||||
# worker_main.c is C, but the engine objects built by nvcc (ds4_cuda.o) and the
|
||||
# Metal path (ds4_metal.o, Obj-C++) reference the C++ runtime (libstdc++). Force
|
||||
# the C++ linker driver so those symbols resolve; the C driver would not link
|
||||
# libstdc++ and the CUDA/Metal builds fail with undefined std:: references.
|
||||
set_target_properties(ds4-worker PROPERTIES LINKER_LANGUAGE CXX)
|
||||
target_link_libraries(ds4-worker PRIVATE Threads::Threads m)
|
||||
|
||||
if(DS4_GPU STREQUAL "cuda")
|
||||
target_link_libraries(ds4-worker PRIVATE CUDA::cudart CUDA::cublas)
|
||||
elseif(DS4_GPU STREQUAL "metal")
|
||||
target_link_libraries(ds4-worker PRIVATE ${FOUNDATION_LIB} ${METAL_LIB})
|
||||
elseif(DS4_GPU STREQUAL "cpu")
|
||||
target_compile_definitions(ds4-worker PRIVATE DS4_NO_GPU)
|
||||
endif()
|
||||
|
||||
if(DS4_NATIVE)
|
||||
if(APPLE)
|
||||
target_compile_options(ds4-worker PRIVATE -mcpu=native)
|
||||
else()
|
||||
target_compile_options(ds4-worker PRIVATE -march=native)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=e8e8779b261c10f36ad6270ba732c8f0be5b62e3
|
||||
# Upstream pin lives below as DS4_VERSION?=8384adf0f9fa0f3bb342dd925372de778b95b263
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=e8e8779b261c10f36ad6270ba732c8f0be5b62e3
|
||||
DS4_VERSION?=8384adf0f9fa0f3bb342dd925372de778b95b263
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
@@ -18,16 +18,20 @@ UNAME_S := $(shell uname -s)
|
||||
|
||||
CMAKE_ARGS ?= -DCMAKE_BUILD_TYPE=Release
|
||||
|
||||
# ds4_distributed.o and ds4_ssd.o are GPU-agnostic translation units that
|
||||
# ds4.c/ds4_cpu.o now reference (upstream split distributed inference and the
|
||||
# SSD expert-cache into their own .c files). Both objects are shared by every
|
||||
# GPU mode, so they are appended unconditionally below.
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS += -DDS4_GPU=cuda
|
||||
DS4_OBJ_TARGET := ds4.o ds4_cuda.o
|
||||
DS4_OBJ_TARGET := ds4.o ds4_cuda.o ds4_distributed.o ds4_ssd.o
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
CMAKE_ARGS += -DDS4_GPU=metal
|
||||
DS4_OBJ_TARGET := ds4.o ds4_metal.o
|
||||
DS4_OBJ_TARGET := ds4.o ds4_metal.o ds4_distributed.o ds4_ssd.o
|
||||
else
|
||||
# CPU reference path (Linux only - macOS CPU path is broken by VM bug per ds4 README).
|
||||
CMAKE_ARGS += -DDS4_GPU=cpu
|
||||
DS4_OBJ_TARGET := ds4_cpu.o
|
||||
DS4_OBJ_TARGET := ds4_cpu.o ds4_distributed.o ds4_ssd.o
|
||||
endif
|
||||
|
||||
ifneq ($(NATIVE),true)
|
||||
@@ -52,17 +56,18 @@ ds4:
|
||||
# the right per-platform compile flags (Objective-C/Metal on Darwin, nvcc on Linux+CUDA).
|
||||
ds4/ds4.o: ds4
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
+$(MAKE) -C ds4 ds4.o ds4_cuda.o
|
||||
+$(MAKE) -C ds4 ds4.o ds4_cuda.o ds4_distributed.o ds4_ssd.o
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
+$(MAKE) -C ds4 ds4.o ds4_metal.o
|
||||
+$(MAKE) -C ds4 ds4.o ds4_metal.o ds4_distributed.o ds4_ssd.o
|
||||
else
|
||||
+$(MAKE) -C ds4 ds4_cpu.o
|
||||
+$(MAKE) -C ds4 ds4_cpu.o ds4_distributed.o ds4_ssd.o
|
||||
endif
|
||||
|
||||
grpc-server: ds4/ds4.o
|
||||
mkdir -p $(BUILD_DIR)
|
||||
cd $(BUILD_DIR) && cmake $(CMAKE_ARGS) $(CURRENT_MAKEFILE_DIR) && cmake --build . --config Release -j $(JOBS)
|
||||
cp $(BUILD_DIR)/grpc-server grpc-server
|
||||
cp $(BUILD_DIR)/ds4-worker ds4-worker
|
||||
|
||||
package: grpc-server
|
||||
bash package.sh
|
||||
@@ -71,7 +76,7 @@ test:
|
||||
@echo "ds4 backend: e2e coverage at tests/e2e-backends/ (BACKEND_BINARY mode)"
|
||||
|
||||
clean:
|
||||
rm -rf $(BUILD_DIR) grpc-server package
|
||||
rm -rf $(BUILD_DIR) grpc-server ds4-worker package
|
||||
if [ -d ds4 ]; then $(MAKE) -C ds4 clean; fi
|
||||
|
||||
purge: clean
|
||||
|
||||
@@ -23,8 +23,11 @@ extern "C" {
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <climits>
|
||||
#include <csignal>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
@@ -51,6 +54,12 @@ ds4_session *g_session = nullptr;
|
||||
int g_ctx_size = 32768;
|
||||
std::string g_kv_cache_dir; // empty disables disk cache
|
||||
|
||||
// Distributed coordinator state. g_distributed is set true when LoadModel is
|
||||
// given 'ds4_role:coordinator'; generation then waits for the worker route to
|
||||
// form before running. Single-node behavior is unchanged when unset.
|
||||
bool g_distributed = false;
|
||||
int g_route_timeout_sec = 60;
|
||||
|
||||
std::atomic<Server *> g_server{nullptr};
|
||||
|
||||
// Parse a "key:value" option string. Returns empty when no colon.
|
||||
@@ -60,6 +69,77 @@ static std::pair<std::string, std::string> split_option(const std::string &opt)
|
||||
return {opt.substr(0, colon), opt.substr(colon + 1)};
|
||||
}
|
||||
|
||||
// Parse a positive base-10 integer. Returns false (without throwing) on empty,
|
||||
// trailing garbage, non-positive, or overflow - unlike std::stoi.
|
||||
static bool parse_positive_int(const std::string &s, int *out) {
|
||||
if (s.empty()) return false;
|
||||
char *end = nullptr;
|
||||
long v = std::strtol(s.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || v <= 0 || v > INT_MAX) return false;
|
||||
*out = static_cast<int>(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parse a ds4 layer spec "START:END" or "START:output" into the engine's
|
||||
// distributed layer fields. Returns false on malformed input.
|
||||
static bool parse_layers_spec(const std::string &spec, ds4_distributed_layers *out) {
|
||||
auto colon = spec.find(':');
|
||||
if (colon == std::string::npos) return false;
|
||||
std::string lhs = spec.substr(0, colon);
|
||||
std::string rhs = spec.substr(colon + 1);
|
||||
if (lhs.empty() || rhs.empty()) return false;
|
||||
char *end = nullptr;
|
||||
long start = std::strtol(lhs.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || start < 0) return false;
|
||||
out->start = static_cast<uint32_t>(start);
|
||||
out->has_output = false;
|
||||
if (rhs == "output") {
|
||||
out->has_output = true;
|
||||
out->end = out->start; // engine treats has_output as "through final layer"
|
||||
} else {
|
||||
long e = std::strtol(rhs.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || e < start) return false;
|
||||
out->end = static_cast<uint32_t>(e);
|
||||
}
|
||||
out->set = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
// When acting as a distributed coordinator, block until the worker route
|
||||
// covers all layers (ds4_session_distributed_route_ready == 1) or the timeout
|
||||
// elapses. Returns an empty string on success, or an error message to return
|
||||
// to the client. No-op when not distributed.
|
||||
//
|
||||
// Takes the g_engine_mu lock by reference and RELEASES it during each poll
|
||||
// sleep. The wait can span up to g_route_timeout_sec seconds while workers
|
||||
// connect; holding g_engine_mu the whole time would block the Status/Health
|
||||
// readiness probes (they also lock g_engine_mu), making LocalAI's loader treat
|
||||
// a still-starting worker as hung.
|
||||
static std::string wait_route_ready(std::unique_lock<std::mutex> &lock) {
|
||||
if (!g_distributed) return "";
|
||||
char err[256] = {0};
|
||||
const int deadline_polls = g_route_timeout_sec * 10; // 100ms per poll
|
||||
for (int i = 0; i <= deadline_polls; ++i) {
|
||||
int ready = ds4_session_distributed_route_ready(g_session, err, sizeof(err));
|
||||
if (ready == 1) return "";
|
||||
if (ready < 0) {
|
||||
return std::string("ds4 distributed route error: ") +
|
||||
(err[0] ? err : "unknown");
|
||||
}
|
||||
// Release the lock while sleeping so Status/Health and other RPCs can
|
||||
// interleave during worker startup.
|
||||
lock.unlock();
|
||||
struct timespec ts = {0, 100L * 1000L * 1000L}; // 100ms
|
||||
nanosleep(&ts, nullptr);
|
||||
lock.lock();
|
||||
// A concurrent Free() may have torn down the engine while we slept.
|
||||
if (!g_engine || !g_session) {
|
||||
return "ds4: model unloaded while waiting for distributed route";
|
||||
}
|
||||
}
|
||||
return "ds4 distributed route incomplete: workers not connected (layers uncovered)";
|
||||
}
|
||||
|
||||
static void append_token_text(ds4_engine *engine, int token, std::string &out) {
|
||||
size_t len = 0;
|
||||
const char *text = ds4_token_text(engine, token, &len);
|
||||
@@ -377,6 +457,11 @@ public:
|
||||
backend::Result *result) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
|
||||
// Reset distributed state so a model swap (a second LoadModel without
|
||||
// ds4_role) doesn't inherit a stale coordinator configuration.
|
||||
g_distributed = false;
|
||||
g_route_timeout_sec = 60;
|
||||
|
||||
if (g_engine) {
|
||||
if (g_session) { ds4_session_free(g_session); g_session = nullptr; }
|
||||
ds4_engine_close(g_engine);
|
||||
@@ -394,12 +479,23 @@ public:
|
||||
std::string mtp_path;
|
||||
int mtp_draft = 0;
|
||||
float mtp_margin = 3.0f;
|
||||
std::string ds4_role, ds4_layers, ds4_listen;
|
||||
for (const auto &opt : request->options()) {
|
||||
auto [k, v] = split_option(opt);
|
||||
if (k == "mtp_path") mtp_path = v;
|
||||
else if (k == "mtp_draft") mtp_draft = std::stoi(v);
|
||||
else if (k == "mtp_margin") mtp_margin = std::stof(v);
|
||||
else if (k == "kv_cache_dir") g_kv_cache_dir = v;
|
||||
else if (k == "ds4_role") ds4_role = v;
|
||||
else if (k == "ds4_layers") ds4_layers = v;
|
||||
else if (k == "ds4_listen") ds4_listen = v;
|
||||
else if (k == "ds4_route_timeout") {
|
||||
if (!parse_positive_int(v, &g_route_timeout_sec)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_route_timeout must be a positive integer");
|
||||
return GStatus::OK;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
g_kv_cache.SetDir(g_kv_cache_dir);
|
||||
@@ -422,6 +518,49 @@ public:
|
||||
opt.backend = DS4_BACKEND_CUDA;
|
||||
#endif
|
||||
|
||||
// Coordinator wiring. 'ds4_role:coordinator' enables layer-split
|
||||
// distributed inference: this process listens on ds4_listen and owns
|
||||
// the ds4_layers slice; workers dial in (see `local-ai worker
|
||||
// ds4-distributed`). Absent ds4_role => unchanged single-node path.
|
||||
// Must be static: opt.distributed.listen_host is a const char* the
|
||||
// engine retains past this call, so it cannot point at a local that
|
||||
// goes out of scope (otherwise a future "simplify to local" refactor
|
||||
// reintroduces a dangling pointer).
|
||||
static std::string s_listen_host;
|
||||
if (ds4_role == "coordinator") {
|
||||
if (ds4_layers.empty() || ds4_listen.empty()) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_role:coordinator requires ds4_layers and ds4_listen");
|
||||
return GStatus::OK;
|
||||
}
|
||||
// host:port for IPv4/hostname; IPv6 literals are unsupported (the
|
||||
// first colon would split inside the address).
|
||||
auto host_port = split_option(ds4_listen); // "host:port" -> {host, port}
|
||||
if (host_port.second.empty()) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_listen must be host:port");
|
||||
return GStatus::OK;
|
||||
}
|
||||
int listen_port = 0;
|
||||
if (!parse_positive_int(host_port.second, &listen_port)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_listen port must be a positive integer");
|
||||
return GStatus::OK;
|
||||
}
|
||||
ds4_distributed_layers layers = {};
|
||||
if (!parse_layers_spec(ds4_layers, &layers)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: invalid ds4_layers (want START:END or START:output)");
|
||||
return GStatus::OK;
|
||||
}
|
||||
s_listen_host = host_port.first;
|
||||
opt.distributed.role = DS4_DISTRIBUTED_COORDINATOR;
|
||||
opt.distributed.layers = layers;
|
||||
opt.distributed.listen_host = s_listen_host.c_str();
|
||||
opt.distributed.listen_port = listen_port;
|
||||
g_distributed = true;
|
||||
}
|
||||
|
||||
int rc = ds4_engine_open(&g_engine, &opt);
|
||||
if (rc != 0 || !g_engine) {
|
||||
result->set_success(false);
|
||||
@@ -458,10 +597,13 @@ public:
|
||||
|
||||
GStatus Predict(ServerContext *, const backend::PredictOptions *request,
|
||||
backend::Reply *reply) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
std::unique_lock<std::mutex> lock(g_engine_mu);
|
||||
if (!g_engine || !g_session) {
|
||||
return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
|
||||
}
|
||||
if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
|
||||
return GStatus(StatusCode::UNAVAILABLE, route_err);
|
||||
}
|
||||
ds4_tokens prompt = {};
|
||||
build_prompt(g_engine, request, &prompt);
|
||||
int n_predict = request->tokens() > 0 ? request->tokens() : 256;
|
||||
@@ -554,10 +696,13 @@ public:
|
||||
|
||||
GStatus PredictStream(ServerContext *, const backend::PredictOptions *request,
|
||||
ServerWriter<backend::Reply> *writer) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
std::unique_lock<std::mutex> lock(g_engine_mu);
|
||||
if (!g_engine || !g_session) {
|
||||
return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
|
||||
}
|
||||
if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
|
||||
return GStatus(StatusCode::UNAVAILABLE, route_err);
|
||||
}
|
||||
ds4_tokens prompt = {};
|
||||
build_prompt(g_engine, request, &prompt);
|
||||
int n_predict = request->tokens() > 0 ? request->tokens() : 256;
|
||||
|
||||
@@ -5,7 +5,8 @@ REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
cp -avf "$CURDIR/grpc-server" "$CURDIR/package/"
|
||||
cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/ds4-worker" "$CURDIR/package/"
|
||||
cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
UNAME_S=$(uname -s)
|
||||
if [ "$UNAME_S" = "Darwin" ]; then
|
||||
|
||||
126
backend/cpp/ds4/worker_main.c
Normal file
126
backend/cpp/ds4/worker_main.c
Normal file
@@ -0,0 +1,126 @@
|
||||
// ds4-worker: standalone distributed worker for the LocalAI ds4 backend.
|
||||
//
|
||||
// A ds4 distributed worker owns a slice of the model's transformer layers,
|
||||
// dials the coordinator, and serves activations for its slice. It does NOT
|
||||
// speak backend.proto - it speaks ds4's own TCP transport via ds4_dist_run().
|
||||
// This binary is intentionally minimal (no HTTP/web/kvstore/linenoise): it
|
||||
// only needs the engine objects + ds4_distributed.o, which the backend already
|
||||
// builds. It is launched by `local-ai worker ds4-distributed`.
|
||||
//
|
||||
// Usage:
|
||||
// ds4-worker --role worker --model <gguf> --layers 20:output \
|
||||
// --coordinator <host> <port> [--cpu|--cuda|--metal] [-c CTX] [-t N]
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <signal.h>
|
||||
#include <limits.h>
|
||||
|
||||
#include "ds4.h"
|
||||
#include "ds4_distributed.h"
|
||||
|
||||
static const char *need_arg(int *i, int argc, char **argv, const char *flag) {
|
||||
if (*i + 1 >= argc) {
|
||||
fprintf(stderr, "ds4-worker: missing value for %s\n", flag);
|
||||
exit(2);
|
||||
}
|
||||
return argv[++(*i)];
|
||||
}
|
||||
|
||||
static int parse_int_arg(const char *s, const char *flag) {
|
||||
char *end = NULL;
|
||||
long v = strtol(s, &end, 10);
|
||||
if (!s[0] || *end || v <= 0 || v > INT_MAX) {
|
||||
fprintf(stderr, "ds4-worker: invalid value for %s: %s\n", flag, s);
|
||||
exit(2);
|
||||
}
|
||||
return (int)v;
|
||||
}
|
||||
|
||||
static ds4_backend default_backend(void) {
|
||||
#if defined(DS4_NO_GPU)
|
||||
return DS4_BACKEND_CPU;
|
||||
#elif defined(__APPLE__)
|
||||
return DS4_BACKEND_METAL;
|
||||
#else
|
||||
return DS4_BACKEND_CUDA;
|
||||
#endif
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
signal(SIGPIPE, SIG_IGN);
|
||||
|
||||
ds4_engine_options opt = {0};
|
||||
opt.backend = default_backend();
|
||||
int ctx_size = 32768;
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
const char *arg = argv[i];
|
||||
if (!strcmp(arg, "-h") || !strcmp(arg, "--help")) {
|
||||
fprintf(stdout, "ds4-worker: standalone ds4 distributed worker\n");
|
||||
ds4_dist_usage(stdout);
|
||||
fprintf(stdout, " -m, --model PATH model GGUF (the worker loads only its --layers slice)\n");
|
||||
fprintf(stdout, " -c, --ctx N context size (default 32768)\n");
|
||||
fprintf(stdout, " -t, --threads N CPU threads\n");
|
||||
fprintf(stdout, " --cpu|--cuda|--metal backend override\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
char dist_err[256] = {0};
|
||||
ds4_dist_cli_parse_result dist_parse =
|
||||
ds4_dist_parse_cli_arg(arg, &i, argc, argv, &opt.distributed,
|
||||
dist_err, sizeof(dist_err));
|
||||
if (dist_parse == DS4_DIST_CLI_ERROR) {
|
||||
fprintf(stderr, "ds4-worker: %s\n",
|
||||
dist_err[0] ? dist_err : "invalid distributed option");
|
||||
return 2;
|
||||
}
|
||||
if (dist_parse == DS4_DIST_CLI_MATCHED) continue;
|
||||
|
||||
if (!strcmp(arg, "-m") || !strcmp(arg, "--model")) {
|
||||
opt.model_path = need_arg(&i, argc, argv, arg);
|
||||
} else if (!strcmp(arg, "-c") || !strcmp(arg, "--ctx")) {
|
||||
ctx_size = parse_int_arg(need_arg(&i, argc, argv, arg), arg);
|
||||
} else if (!strcmp(arg, "-t") || !strcmp(arg, "--threads")) {
|
||||
opt.n_threads = parse_int_arg(need_arg(&i, argc, argv, arg), arg);
|
||||
} else if (!strcmp(arg, "--cpu")) {
|
||||
opt.backend = DS4_BACKEND_CPU;
|
||||
} else if (!strcmp(arg, "--cuda")) {
|
||||
opt.backend = DS4_BACKEND_CUDA;
|
||||
} else if (!strcmp(arg, "--metal")) {
|
||||
opt.backend = DS4_BACKEND_METAL;
|
||||
} else {
|
||||
fprintf(stderr, "ds4-worker: unknown option: %s\n", arg);
|
||||
return 2;
|
||||
}
|
||||
}
|
||||
|
||||
if (opt.distributed.role != DS4_DISTRIBUTED_WORKER) {
|
||||
fprintf(stderr, "ds4-worker: --role worker is required\n");
|
||||
return 2;
|
||||
}
|
||||
if (!opt.model_path) {
|
||||
fprintf(stderr, "ds4-worker: --model is required\n");
|
||||
return 2;
|
||||
}
|
||||
|
||||
char prep_err[256] = {0};
|
||||
if (ds4_dist_prepare_engine_options(&opt.distributed, &opt,
|
||||
prep_err, sizeof(prep_err)) != 0) {
|
||||
fprintf(stderr, "ds4-worker: %s\n", prep_err);
|
||||
return 2;
|
||||
}
|
||||
|
||||
ds4_engine *engine = NULL;
|
||||
if (ds4_engine_open(&engine, &opt) != 0 || !engine) {
|
||||
fprintf(stderr, "ds4-worker: failed to open engine\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
ds4_dist_generation_options gen = {0};
|
||||
gen.ctx_size = ctx_size;
|
||||
int rc = ds4_dist_run(engine, &opt.distributed, &gen);
|
||||
ds4_engine_close(engine);
|
||||
return rc;
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=d2da6da05c73aeb658a3d1751f386c24e6963856
|
||||
IK_LLAMA_VERSION?=e6f8112f3ba126eed3ff5b30cdd08085414a7516
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=0d18aaa9d1a8af3df9abccd828e22eeaac7f840b
|
||||
LLAMA_VERSION?=039e20a2db9e87b2477c76cc04905f3e1acad77f
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -381,6 +381,15 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
||||
});
|
||||
}
|
||||
|
||||
// for each video in the request, add the video data
|
||||
for (int i = 0; i < predict->videos_size(); i++) {
|
||||
data["video_data"].push_back(json
|
||||
{
|
||||
{"id", i},
|
||||
{"data", predict->videos(i)},
|
||||
});
|
||||
}
|
||||
|
||||
data["stop"] = predict->stopprompts();
|
||||
// data["n_probs"] = predict->nprobs();
|
||||
//TODO: images,
|
||||
@@ -482,23 +491,13 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!request->draftmodel().empty()) {
|
||||
params.speculative.draft.mparams.path = request->draftmodel();
|
||||
// Default to draft type if a draft model is set but no explicit type.
|
||||
// Upstream (post ggml-org/llama.cpp#22838) made the speculative type a
|
||||
// vector; the turboquant fork still uses the legacy scalar. The
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC macro is injected by
|
||||
// backend/cpp/turboquant/patch-grpc-server.sh for fork builds only.
|
||||
// Upstream renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE
|
||||
// in ggml-org/llama.cpp#22964; the fork still uses the old name.
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
||||
}
|
||||
#else
|
||||
// Upstream made the speculative type a vector (ggml-org/llama.cpp#22838)
|
||||
// and renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE (#22964).
|
||||
const bool no_spec_type = params.speculative.types.empty() ||
|
||||
(params.speculative.types.size() == 1 && params.speculative.types[0] == COMMON_SPECULATIVE_TYPE_NONE);
|
||||
if (no_spec_type) {
|
||||
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE };
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// params.model_alias ??
|
||||
@@ -573,8 +572,13 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// checkpoint_min_step: minimum spacing between context checkpoints in
|
||||
// tokens (0 disables the minimum). Match upstream's default (256). This
|
||||
// field was renamed from `checkpoint_every_nt` in llama.cpp; the semantics
|
||||
// also shifted from a fixed cadence to a minimum spacing.
|
||||
// also shifted from a fixed cadence to a minimum spacing. The turboquant
|
||||
// fork still lacks common_params::checkpoint_min_step, so skip it there
|
||||
// (LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP is injected by
|
||||
// backend/cpp/turboquant/patch-grpc-server.sh).
|
||||
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
|
||||
params.checkpoint_min_step = 256;
|
||||
#endif
|
||||
|
||||
// decode options. Options are in form optname:optvale, or if booleans only optname.
|
||||
for (int i = 0; i < request->options_size(); i++) {
|
||||
@@ -748,11 +752,18 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.cache_idle_slots = false;
|
||||
}
|
||||
|
||||
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
|
||||
// --- minimum context-checkpoint spacing (upstream -cms / --checkpoint-min-step) ---
|
||||
// 0 disables the minimum-spacing gate. Old option names (`checkpoint_every_nt`,
|
||||
// `checkpoint_every_n_tokens`) are kept as aliases for backward compatibility
|
||||
// with existing user configs: upstream renamed the field and shifted its
|
||||
// semantics from a fixed cadence to a minimum spacing.
|
||||
//
|
||||
// Gated out for the turboquant fork, which lacks common_params::
|
||||
// checkpoint_min_step. The leading `}` closing the cache_idle_slots
|
||||
// branch is removed with this block; the next `} else if` (n_ubatch)
|
||||
// then closes cache_idle_slots, so braces stay balanced under both
|
||||
// preprocessor branches.
|
||||
} else if (!strcmp(optname, "checkpoint_min_step") || !strcmp(optname, "checkpoint_min_spacing") ||
|
||||
!strcmp(optname, "checkpoint_every_nt") || !strcmp(optname, "checkpoint_every_n_tokens")) {
|
||||
if (optval != NULL) {
|
||||
@@ -762,6 +773,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// If conversion fails, keep default value (256)
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// --- physical batch size (upstream -ub / --ubatch-size) ---
|
||||
// Note: line ~482 already aliases n_ubatch to n_batch as a default; this
|
||||
@@ -894,17 +906,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
|
||||
// Speculative decoding options
|
||||
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
// Fork only knows a single scalar `type`. Take the first comma-
|
||||
// separated value and assign it via the singular helper.
|
||||
std::string first = optval_str;
|
||||
const auto comma = first.find(',');
|
||||
if (comma != std::string::npos) first = first.substr(0, comma);
|
||||
auto type = common_speculative_type_from_name(first);
|
||||
if (type != COMMON_SPECULATIVE_TYPE_COUNT) {
|
||||
params.speculative.type = type;
|
||||
}
|
||||
#else
|
||||
// Upstream switched to a vector of types (comma-separated for multi-type
|
||||
// chaining via common_speculative_types_from_names). We keep accepting a
|
||||
// single value here, but also tolerate comma-separated lists.
|
||||
@@ -933,7 +934,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!parsed.empty()) {
|
||||
params.speculative.types = parsed;
|
||||
}
|
||||
#endif
|
||||
} else if (!strcmp(optname, "spec_n_max") || !strcmp(optname, "draft_max")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.draft.n_max = std::stoi(optval_str); } catch (...) {}
|
||||
@@ -971,21 +971,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// shares the target context size. Accept the option for backward
|
||||
// compatibility but silently ignore it.
|
||||
|
||||
// Everything below relies on struct shape introduced in ggml-org/llama.cpp#22838
|
||||
// (parallel drafting): `ngram_mod`, `ngram_map_k`, `ngram_map_k4v`,
|
||||
// `ngram_cache`, and the `draft.{cache_type_*, cpuparams*, tensor_buft_overrides}`
|
||||
// fields. The turboquant fork branched before that, so its build defines
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC via patch-grpc-server.sh and these option
|
||||
// keys become unrecognized (silently dropped, like any unknown opt) for it.
|
||||
//
|
||||
// The `#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC` / `#else` split below sits at the
|
||||
// closing-brace position of the `draft_ctx_size` branch on purpose: in the
|
||||
// legacy build the chain ends here (the brace closes draft_ctx_size), and in
|
||||
// the modern build the chain continues with `} else if (...)` instead, so the
|
||||
// brace count stays balanced under both branches of the preprocessor.
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
}
|
||||
#else
|
||||
// --- ngram_mod family (upstream --spec-ngram-mod-*) ---
|
||||
} else if (!strcmp(optname, "spec_ngram_mod_n_min")) {
|
||||
if (optval != NULL) {
|
||||
@@ -1115,7 +1100,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
}
|
||||
if (!cur.empty()) flush(cur);
|
||||
}
|
||||
#endif // LOCALAI_LEGACY_LLAMA_CPP_SPEC — closes the `else`/`#ifdef` opened at draft_ctx_size
|
||||
}
|
||||
|
||||
// Set params.n_parallel from environment variable if not set via options (fallback)
|
||||
@@ -1165,6 +1149,8 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
}
|
||||
// Terminate the draft tensor_buft_overrides list with a sentinel, mirroring
|
||||
// the main-model handling above.
|
||||
if (!params.speculative.draft.tensor_buft_overrides.empty()) {
|
||||
params.speculative.draft.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
@@ -1526,7 +1512,7 @@ public:
|
||||
msg_json["role"] = msg.role();
|
||||
|
||||
bool is_last_user_msg = (i == last_user_msg_idx);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
|
||||
|
||||
// Handle content - can be string, null, or array
|
||||
// For multimodal content, we'll embed images/audio from separate fields
|
||||
@@ -1577,6 +1563,16 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
} else {
|
||||
// Use content as-is (already array or not last user message)
|
||||
@@ -1611,6 +1607,16 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
} else if (msg.role() == "tool") {
|
||||
// Tool role messages must have content field set, even if empty
|
||||
@@ -1926,6 +1932,17 @@ public:
|
||||
body_json["chat_template_kwargs"]["enable_thinking"] = (et_it->second == "true");
|
||||
}
|
||||
|
||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
||||
// from enable_thinking which those templates ignore.
|
||||
auto re_it = metadata.find("reasoning_effort");
|
||||
if (re_it != metadata.end() && !re_it->second.empty()) {
|
||||
if (!body_json.contains("chat_template_kwargs")) {
|
||||
body_json["chat_template_kwargs"] = json::object();
|
||||
}
|
||||
body_json["chat_template_kwargs"]["reasoning_effort"] = re_it->second;
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
@@ -2051,6 +2068,16 @@ public:
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
|
||||
const auto &video_data = data.find("video_data");
|
||||
if (video_data != data.end() && video_data->is_array())
|
||||
{
|
||||
for (const auto &video : *video_data)
|
||||
{
|
||||
auto decoded_data = base64_decode(video["data"].get<std::string>());
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const bool has_mtmd = ctx_server.impl->mctx != nullptr;
|
||||
@@ -2186,7 +2213,15 @@ public:
|
||||
// content element — attaching to both would duplicate the first
|
||||
// token since oaicompat_msg_diffs is the same for both.
|
||||
json first_res_json = first_result->to_json();
|
||||
if (first_res_json.is_array()) {
|
||||
// Upstream llama.cpp (ggml-org/llama.cpp#23884) now emits an initial
|
||||
// "begin" partial whose to_json() returns null, used only to signal the
|
||||
// HTTP layer to flush 200 status headers before any token. gRPC has no
|
||||
// such concept, so there is nothing to emit — the real tokens arrive in
|
||||
// the loop below. Feeding this null into build_reply_from_json would
|
||||
// throw (uncaught) and surface as a generic RPC error.
|
||||
if (first_res_json.is_null()) {
|
||||
// skip the begin-of-stream marker
|
||||
} else if (first_res_json.is_array()) {
|
||||
for (const auto & res : first_res_json) {
|
||||
auto reply = build_reply_from_json(res, first_result.get());
|
||||
// Skip chat deltas for role-init elements (have "role" in
|
||||
@@ -2216,7 +2251,10 @@ public:
|
||||
}
|
||||
|
||||
json res_json = result->to_json();
|
||||
if (res_json.is_array()) {
|
||||
if (res_json.is_null()) {
|
||||
// begin-of-stream marker (see note above) — nothing to emit
|
||||
continue;
|
||||
} else if (res_json.is_array()) {
|
||||
for (const auto & res : res_json) {
|
||||
auto reply = build_reply_from_json(res, result.get());
|
||||
bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
|
||||
@@ -2292,7 +2330,7 @@ public:
|
||||
}
|
||||
|
||||
bool is_last_user_msg = (i == last_user_msg_idx);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
|
||||
|
||||
// Handle content - can be string, null, or array
|
||||
// For multimodal content, we'll embed images/audio from separate fields
|
||||
@@ -2345,6 +2383,16 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
} else {
|
||||
// Use content as-is (already array or not last user message)
|
||||
@@ -2384,6 +2432,16 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d created content array with media\n", i);
|
||||
} else if (!msg.tool_calls().empty()) {
|
||||
@@ -2708,6 +2766,17 @@ public:
|
||||
body_json["chat_template_kwargs"]["enable_thinking"] = (predict_et_it->second == "true");
|
||||
}
|
||||
|
||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
||||
// from enable_thinking which those templates ignore.
|
||||
auto predict_re_it = predict_metadata.find("reasoning_effort");
|
||||
if (predict_re_it != predict_metadata.end() && !predict_re_it->second.empty()) {
|
||||
if (!body_json.contains("chat_template_kwargs")) {
|
||||
body_json["chat_template_kwargs"] = json::object();
|
||||
}
|
||||
body_json["chat_template_kwargs"]["reasoning_effort"] = predict_re_it->second;
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] Predict: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
@@ -2835,6 +2904,16 @@ public:
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
|
||||
const auto &video_data = data.find("video_data");
|
||||
if (video_data != data.end() && video_data->is_array())
|
||||
{
|
||||
for (const auto &video : *video_data)
|
||||
{
|
||||
auto decoded_data = base64_decode(video["data"].get<std::string>());
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process files
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
|
||||
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
|
||||
TURBOQUANT_VERSION?=5aeb2fdbe26cd4c534c6fa15de73cb5749bd0403
|
||||
TURBOQUANT_VERSION?=7d9715f1f071fa07c7b2ad3dbfd320b314139e65
|
||||
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -4,21 +4,19 @@
|
||||
#
|
||||
# 1. Augment the kv_cache_types[] allow-list so `LoadModel` accepts the
|
||||
# fork-specific `turbo2` / `turbo3` / `turbo4` cache types.
|
||||
# 2. Replace `get_media_marker()` (added upstream in ggml-org/llama.cpp#21962,
|
||||
# server-side random per-instance marker) with the legacy "<__media__>"
|
||||
# literal. The fork branched before that PR, so server-common.cpp has no
|
||||
# get_media_marker symbol. The fork's mtmd_default_marker() still returns
|
||||
# "<__media__>", and Go-side tooling falls back to that sentinel when the
|
||||
# backend does not expose media_marker, so substituting the literal keeps
|
||||
# behavior identical on the turboquant path.
|
||||
# 3. Revert the `common_params_speculative` field references to the
|
||||
# pre-refactor flat layout. Upstream ggml-org/llama.cpp#22397 split the
|
||||
# struct into nested `draft` / `ngram_simple` / `ngram_mod` / etc. members;
|
||||
# the turboquant fork branched before that PR and still exposes the flat
|
||||
# `n_max`, `mparams_dft`, `ngram_size_n`, ... fields. The substitutions
|
||||
# below map the new nested paths back to the legacy flat names so the
|
||||
# shared grpc-server.cpp keeps compiling against the fork's common.h.
|
||||
# Drop this block once the fork rebases past #22397.
|
||||
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file
|
||||
# so the grpc-server option parser skips the two references to
|
||||
# common_params::checkpoint_min_step (the default and the option handler).
|
||||
# That field does not exist in the fork yet; drop this once it does.
|
||||
#
|
||||
# The fork used to lag upstream on the whole common_params_speculative refactor
|
||||
# (ggml-org/llama.cpp#22397/#22838/#22964), the model_tgt rename (#22838) and
|
||||
# get_media_marker (#21962), which required a much larger compat shim here
|
||||
# (flat-field sed renames + a coarse LOCALAI_LEGACY_LLAMA_CPP_SPEC define). The
|
||||
# fork has since rebased past all of those, so the only remaining gap is
|
||||
# checkpoint_min_step. If a future bump reintroduces a divergence, add a narrow
|
||||
# guard in grpc-server.cpp keyed on a fork-specific macro and inject it here
|
||||
# rather than resurrecting the coarse one.
|
||||
#
|
||||
# We patch the *copy* sitting in turboquant-<flavor>-build/, never the original
|
||||
# under backend/cpp/llama-cpp/, so the stock llama-cpp build keeps compiling
|
||||
@@ -72,69 +70,20 @@ else
|
||||
echo "==> KV allow-list patch OK"
|
||||
fi
|
||||
|
||||
if grep -q 'get_media_marker()' "$SRC"; then
|
||||
echo "==> patching $SRC to replace get_media_marker() with legacy \"<__media__>\" literal"
|
||||
# Only one call site today (ModelMetadata), but replace all occurrences to
|
||||
# stay robust if upstream adds more. Use a temp file to avoid relying on
|
||||
# sed -i portability (the builder image uses GNU sed, but keeping this
|
||||
# consistent with the awk block above).
|
||||
sed 's/get_media_marker()/"<__media__>"/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> get_media_marker() substitution OK"
|
||||
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file so
|
||||
# the grpc-server option parser skips the two references to
|
||||
# common_params::checkpoint_min_step (the default assignment and the option
|
||||
# handler). That field does not exist in the fork yet. Drop this block once
|
||||
# the fork rebases past the bump that added checkpoint_min_step.
|
||||
if grep -q '^#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP' "$SRC"; then
|
||||
echo "==> $SRC already defines LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP, skipping"
|
||||
else
|
||||
echo "==> $SRC has no get_media_marker() call, skipping media-marker patch"
|
||||
fi
|
||||
|
||||
if grep -q 'params\.speculative\.draft\.\|params\.speculative\.ngram_simple\.' "$SRC"; then
|
||||
echo "==> patching $SRC to revert common_params_speculative refs to pre-#22397 flat layout"
|
||||
# Each substitution is the exact post-refactor path → legacy flat field.
|
||||
# Order doesn't matter because the source paths are disjoint, but we keep
|
||||
# the most-specific (mparams.path) first for readability.
|
||||
sed -E \
|
||||
-e 's/params\.speculative\.draft\.mparams\.path/params.speculative.mparams_dft.path/g' \
|
||||
-e 's/params\.speculative\.draft\.n_max/params.speculative.n_max/g' \
|
||||
-e 's/params\.speculative\.draft\.n_min/params.speculative.n_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_min/params.speculative.p_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_split/params.speculative.p_split/g' \
|
||||
-e 's/params\.speculative\.draft\.n_gpu_layers/params.speculative.n_gpu_layers/g' \
|
||||
-e 's/params\.speculative\.draft\.n_ctx/params.speculative.n_ctx/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_n/params.speculative.ngram_size_n/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_m/params.speculative.ngram_size_m/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.min_hits/params.speculative.ngram_min_hits/g' \
|
||||
"$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> speculative field rename OK"
|
||||
else
|
||||
echo "==> $SRC has no post-#22397 speculative field refs, skipping spec rename patch"
|
||||
fi
|
||||
|
||||
# 4. Revert the `ctx_server.impl->model_tgt` rename introduced by upstream
|
||||
# ggml-org/llama.cpp#22838 (parallel drafting). The turboquant fork still
|
||||
# exposes the field as `model` on `server_context_impl`. The two call sites
|
||||
# are in the Rerank and ModelMetadata RPC handlers.
|
||||
if grep -q 'ctx_server\.impl->model_tgt' "$SRC"; then
|
||||
echo "==> patching $SRC to revert ctx_server.impl->model_tgt -> ctx_server.impl->model"
|
||||
sed -E 's/ctx_server\.impl->model_tgt/ctx_server.impl->model/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> model_tgt rename OK"
|
||||
else
|
||||
echo "==> $SRC has no ctx_server.impl->model_tgt refs, skipping model_tgt rename patch"
|
||||
fi
|
||||
|
||||
# 5. Define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top of the file so the
|
||||
# grpc-server option parser skips the new option-handler blocks (ngram_mod,
|
||||
# ngram_map_k, ngram_map_k4v, ngram_cache, draft.cache_type_*, draft.cpuparams*,
|
||||
# draft.tensor_buft_overrides) introduced for the post-#22838 layout. Those
|
||||
# blocks reference struct fields that simply do not exist in the fork.
|
||||
if grep -q '^#define LOCALAI_LEGACY_LLAMA_CPP_SPEC' "$SRC"; then
|
||||
echo "==> $SRC already defines LOCALAI_LEGACY_LLAMA_CPP_SPEC, skipping"
|
||||
else
|
||||
echo "==> patching $SRC to define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top"
|
||||
# Insert the define before the very first `#include` so it precedes all the
|
||||
# speculative-decoding code paths.
|
||||
echo "==> patching $SRC to define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top"
|
||||
# Insert the define before the very first `#include` so it precedes the
|
||||
# checkpoint_min_step references.
|
||||
awk '
|
||||
!done && /^#include/ {
|
||||
print "#define LOCALAI_LEGACY_LLAMA_CPP_SPEC 1"
|
||||
print "#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP 1"
|
||||
print "// ^ injected by backend/cpp/turboquant/patch-grpc-server.sh"
|
||||
print ""
|
||||
done = 1
|
||||
@@ -142,13 +91,13 @@ else
|
||||
{ print }
|
||||
END {
|
||||
if (!done) {
|
||||
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_LEGACY_LLAMA_CPP_SPEC" > "/dev/stderr"
|
||||
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP" > "/dev/stderr"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> LOCALAI_LEGACY_LLAMA_CPP_SPEC define OK"
|
||||
echo "==> LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP define OK"
|
||||
fi
|
||||
|
||||
echo "==> all patches applied"
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
hip: port the turboquant CUDA additions that ggml's HIP shim doesn't cover
|
||||
|
||||
The turboquant fork adds/modifies a few ggml-cuda.cu spots with CUDA APIs
|
||||
that ggml's HIP (and MUSA) compatibility layer does not provide, breaking
|
||||
the -gpu-rocm-hipblas-turboquant build:
|
||||
|
||||
1. ggml_cuda_copy2d_across_devices() (host-staged cross-device copy for
|
||||
split mul_mat output) uses the CUDA 3D-peer copy APIs
|
||||
cudaMemcpy3DPeerParms / make_cudaPitchedPtr / make_cudaExtent /
|
||||
cudaMemcpy3DPeerAsync. HIP genuinely does not support these (see the
|
||||
fork's own comment "HIP does not support cudaMemcpy3DPeerAsync"), so
|
||||
guard the peer fast path with #if !defined(GGML_USE_HIP) &&
|
||||
!defined(GGML_USE_MUSA) -- matching how the fork already guards the
|
||||
same API for the sibling 2D copy -- and fall through to the existing
|
||||
cudaMemcpyAsync staging fallback below (functionally identical,
|
||||
slightly slower on multi-GPU ROCm).
|
||||
|
||||
2. ggml_backend_cuda_device_event_new() creates its event with plain
|
||||
cudaEventCreate, which ggml's HIP shim does not alias (it only aliases
|
||||
cudaEventCreateWithFlags). Use cudaEventCreateWithFlags(...,
|
||||
cudaEventDisableTiming) -- exactly what the rest of this file already
|
||||
does (cf. lines ~1034, ~3461) and HIP-safe.
|
||||
|
||||
CUDA builds are unaffected. Drop the relevant hunk once the fork HIP-ports
|
||||
these; apply-patches.sh fails fast if an anchor goes stale.
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 0427e6b..6352e6a 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -1933,6 +1933,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
|
||||
size_t width, size_t height, cudaStream_t dst_stream, cudaStream_t src_stream) {
|
||||
|
||||
const auto & info = ggml_cuda_info();
|
||||
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) // 3D-peer copy types unmapped by ggml's HIP/MUSA shim; use staging fallback below
|
||||
if (info.peer_access[src_device][dst_device]) {
|
||||
cudaMemcpy3DPeerParms p = {};
|
||||
p.dstDevice = dst_device;
|
||||
@@ -1942,6 +1943,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
|
||||
p.extent = make_cudaExtent(width, height, 1);
|
||||
return cudaMemcpy3DPeerAsync(&p, dst_stream);
|
||||
}
|
||||
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
|
||||
// Fallback: stage all rows through a single contiguous pinned buffer
|
||||
int prev_device = ggml_cuda_get_device();
|
||||
@@ -5714,7 +5716,7 @@ static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_
|
||||
ggml_cuda_set_device(dev_ctx->device);
|
||||
|
||||
cudaEvent_t event;
|
||||
- CUDA_CHECK(cudaEventCreate(&event));
|
||||
+ CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
|
||||
|
||||
return new ggml_backend_event {
|
||||
/* .device = */ dev,
|
||||
@@ -192,6 +192,61 @@ var _ = Describe("Forward", func() {
|
||||
Expect(<-gotAuth).To(Equal("Bearer sk-real"), "caller-supplied Basic header must be replaced")
|
||||
})
|
||||
|
||||
It("refuses to follow upstream redirects and never leaks the key to the redirect target", func() {
|
||||
// A 3xx from the configured upstream means misconfiguration or a
|
||||
// hijacked/spoofed host. Following it would replay the request —
|
||||
// and the injected API key — to the Location host. Anthropic's
|
||||
// x-api-key is NOT stripped by Go on cross-host redirects, so this
|
||||
// would be a credential leak. The proxy must refuse the redirect.
|
||||
sinkHit := make(chan string, 1)
|
||||
sink := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sinkHit <- r.Header.Get("x-api-key")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer sink.Close()
|
||||
|
||||
redirector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, sink.URL, http.StatusFound)
|
||||
}))
|
||||
defer redirector.Close()
|
||||
|
||||
GinkgoT().Setenv("CLOUD_PROXY_REDIRECT_KEY", "ant-secret")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: redirector.URL,
|
||||
Mode: modePassthrough,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_REDIRECT_KEY",
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-no-redirect"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/messages",
|
||||
Method: "POST",
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
|
||||
// Drain the stream; a refused redirect surfaces as a non-EOF error.
|
||||
var streamErr error
|
||||
for {
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
streamErr = err
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(streamErr).To(HaveOccurred(), "refused redirect must surface as an error")
|
||||
Expect(sinkHit).NotTo(Receive(), "the redirect target must never be contacted")
|
||||
})
|
||||
|
||||
It("handles concurrent calls without interference", func() {
|
||||
// CloudProxy explicitly omits base.SingleThread — independent
|
||||
// Forward streams must not block each other or leak state.
|
||||
|
||||
@@ -11,9 +11,12 @@ import (
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// Mirror of core/config.Proxy{Mode,Provider}* — backends don't
|
||||
@@ -48,10 +51,15 @@ type proxyConfig struct {
|
||||
}
|
||||
|
||||
func NewCloudProxy() *CloudProxy {
|
||||
// No Client-level Timeout — that would bound streaming SSE
|
||||
// responses too, which can legitimately last minutes. Per-request
|
||||
// deadlines come from the gRPC stream context.
|
||||
return &CloudProxy{client: &http.Client{}}
|
||||
// httpclient.New refuses redirects outright: the proxy talks to a
|
||||
// single configured upstream API (OpenAI/Anthropic/...) that answers
|
||||
// directly, so a 3xx means misconfiguration, a hijacked upstream, or
|
||||
// DNS trickery — never normal operation. Following it would replay the
|
||||
// request, including the operator's x-api-key (which Go does NOT strip
|
||||
// on cross-host redirects), to an unvetted host and leak the key
|
||||
// (GHSA-3mj3-57v2-4636). It also imposes no body deadline, so streaming
|
||||
// SSE responses that legitimately last minutes are not truncated.
|
||||
return &CloudProxy{client: httpclient.New()}
|
||||
}
|
||||
|
||||
func (c *CloudProxy) Load(opts *pb.ModelOptions) error {
|
||||
@@ -138,7 +146,7 @@ func resolveAPIKey(envName, filePath string) (string, error) {
|
||||
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cloud-proxy: model not loaded")
|
||||
return nil, grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -168,7 +176,7 @@ func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err
|
||||
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -262,7 +270,7 @@ func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest,
|
||||
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modePassthrough {
|
||||
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
|
||||
@@ -426,4 +434,3 @@ func isHopByHopHeader(name string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
5
backend/go/crispasr/.gitignore
vendored
Normal file
5
backend/go/crispasr/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
sources
|
||||
build*
|
||||
libgocrispasr*.so
|
||||
crispasr
|
||||
package
|
||||
30
backend/go/crispasr/CMakeLists.txt
Normal file
30
backend/go/crispasr/CMakeLists.txt
Normal file
@@ -0,0 +1,30 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
project(gocrispasr LANGUAGES C CXX)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
add_subdirectory(./sources/CrispASR)
|
||||
|
||||
add_library(gocrispasr MODULE cpp/crispasr_shim.cpp)
|
||||
target_include_directories(gocrispasr PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sources/CrispASR/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sources/CrispASR/ggml/include)
|
||||
# Link the same backend set as crispasr-cli (examples/cli/CMakeLists.txt) so
|
||||
# the session API can dispatch to every compiled-in architecture, not just
|
||||
# whisper. crispasr is the referencer; the backend static libs supply the
|
||||
# per-architecture symbols; ggml is the math/runtime base.
|
||||
target_link_libraries(gocrispasr PRIVATE
|
||||
crispasr-lib
|
||||
parakeet canary canary_ctc cohere granite_speech granite_nle
|
||||
voxtral voxtral4b qwen3_asr qwen3_tts orpheus chatterbox indextts
|
||||
kokoro voxcpm2_tts m2m100 t5_translate wav2vec2-ggml vibevoice
|
||||
silero-lid pyannote-seg funasr paraformer sensevoice
|
||||
crisp_audio
|
||||
ggml)
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
|
||||
target_link_libraries(gocrispasr PRIVATE stdc++fs)
|
||||
endif()
|
||||
|
||||
set_property(TARGET gocrispasr PROPERTY CXX_STANDARD 17)
|
||||
set_target_properties(gocrispasr PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||
132
backend/go/crispasr/Makefile
Normal file
132
backend/go/crispasr/Makefile
Normal file
@@ -0,0 +1,132 @@
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# CrispASR version (release tag)
|
||||
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
||||
CRISPASR_VERSION?=c29f6653a516a3001d923944dad8892072cc7334
|
||||
SO_TARGET?=libgocrispasr.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
# Keep the build lean: no tests/examples/server/SDL2/curl/ffmpeg (the FROM scratch
|
||||
# image cannot satisfy those runtime deps). All ASR/TTS model backends stay enabled.
|
||||
CMAKE_ARGS+=-DCRISPASR_BUILD_TESTS=OFF -DCRISPASR_BUILD_EXAMPLES=OFF -DCRISPASR_BUILD_SERVER=OFF
|
||||
CMAKE_ARGS+=-DCRISPASR_SDL2=OFF -DCRISPASR_CURL=OFF -DCRISPASR_FFMPEG=OFF
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DGGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
ifneq ($(BUILD_TYPE),metal)
|
||||
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||
else
|
||||
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
-DGGML_SYCL_F16=ON
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
endif
|
||||
|
||||
sources/CrispASR:
|
||||
mkdir -p sources/CrispASR
|
||||
cd sources/CrispASR && \
|
||||
git init && \
|
||||
git remote add origin $(CRISPASR_REPO) && \
|
||||
git fetch origin && \
|
||||
git checkout $(CRISPASR_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
# CrispASR's src/CMakeLists.txt locates its vendored llama.cpp
|
||||
# (crispasr-llama-core, used by the chat C-ABI) via ${CMAKE_SOURCE_DIR},
|
||||
# which assumes CrispASR is the top-level CMake project. We add_subdirectory
|
||||
# it, so ${CMAKE_SOURCE_DIR} is THIS backend dir and the talk-llama sources
|
||||
# aren't found. Rewrite to ${PROJECT_SOURCE_DIR} (the crispasr project root),
|
||||
# which is correct both standalone and as a subproject. Idempotent.
|
||||
sed -i 's#\$${CMAKE_SOURCE_DIR}/examples/talk-llama#\$${PROJECT_SOURCE_DIR}/examples/talk-llama#' sources/CrispASR/src/CMakeLists.txt
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
VARIANT_TARGETS = libgocrispasr-avx.so libgocrispasr-avx2.so libgocrispasr-avx512.so libgocrispasr-fallback.so
|
||||
else
|
||||
VARIANT_TARGETS = libgocrispasr-fallback.so
|
||||
endif
|
||||
|
||||
crispasr: main.go gocrispasr.go $(VARIANT_TARGETS)
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o crispasr ./
|
||||
|
||||
package: crispasr
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
clean: purge
|
||||
rm -rf libgocrispasr*.so package sources/CrispASR crispasr
|
||||
|
||||
purge:
|
||||
rm -rf build*
|
||||
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
libgocrispasr-avx.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx${RESET})
|
||||
SO_TARGET=libgocrispasr-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-avx2.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx2${RESET})
|
||||
SO_TARGET=libgocrispasr-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-avx512.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx512${RESET})
|
||||
SO_TARGET=libgocrispasr-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
endif
|
||||
|
||||
libgocrispasr-fallback.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:fallback${RESET})
|
||||
SO_TARGET=libgocrispasr-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-custom: CMakeLists.txt cpp/crispasr_shim.cpp cpp/crispasr_shim.h
|
||||
mkdir -p build-$(SO_TARGET) && \
|
||||
cd build-$(SO_TARGET) && \
|
||||
cmake .. $(CMAKE_ARGS) && \
|
||||
cmake --build . --config Release -j$(JOBS) && \
|
||||
cd .. && \
|
||||
mv build-$(SO_TARGET)/libgocrispasr.so ./$(SO_TARGET)
|
||||
|
||||
test: crispasr
|
||||
CGO_ENABLED=0 $(GOCMD) test -v ./...
|
||||
|
||||
all: crispasr package
|
||||
253
backend/go/crispasr/cpp/crispasr_shim.cpp
Normal file
253
backend/go/crispasr/cpp/crispasr_shim.cpp
Normal file
@@ -0,0 +1,253 @@
|
||||
#include "crispasr_shim.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "crispasr.h"
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
// Opaque session types. crispasr.h declares `struct crispasr_session;` but not
|
||||
// the result type nor the open/transcribe/result accessors — those are
|
||||
// CA_EXPORT extern "C" symbols in src/crispasr_c_api.cpp, so we forward-declare
|
||||
// exactly the ones we use. Signatures verified against
|
||||
// sources/CrispASR/src/crispasr_c_api.cpp.
|
||||
struct crispasr_session_result;
|
||||
extern "C" {
|
||||
crispasr_session *crispasr_session_open(const char *model_path, int n_threads);
|
||||
crispasr_session *crispasr_session_open_explicit(const char *model_path,
|
||||
const char *backend_name,
|
||||
int n_threads);
|
||||
int crispasr_session_set_codec_path(crispasr_session *s, const char *path);
|
||||
void crispasr_session_close(crispasr_session *s);
|
||||
const char *crispasr_session_backend(crispasr_session *s);
|
||||
int crispasr_session_set_translate(crispasr_session *s, int enable);
|
||||
crispasr_session_result *crispasr_session_transcribe_lang(
|
||||
crispasr_session *s, const float *pcm, int n_samples, const char *language);
|
||||
int crispasr_session_result_n_segments(crispasr_session_result *r);
|
||||
const char *crispasr_session_result_segment_text(crispasr_session_result *r,
|
||||
int i);
|
||||
int64_t crispasr_session_result_segment_t0(crispasr_session_result *r, int i);
|
||||
int64_t crispasr_session_result_segment_t1(crispasr_session_result *r, int i);
|
||||
void crispasr_session_result_free(crispasr_session_result *r);
|
||||
float *crispasr_session_synthesize(crispasr_session *s, const char *text,
|
||||
int *out_n_samples);
|
||||
void crispasr_pcm_free(float *pcm);
|
||||
int crispasr_session_set_speaker_name(crispasr_session *s, const char *name);
|
||||
int crispasr_session_set_voice(crispasr_session *s, const char *path,
|
||||
const char *ref_text_or_null);
|
||||
}
|
||||
|
||||
static crispasr_session *g_session = nullptr;
|
||||
static crispasr_session_result *g_result = nullptr;
|
||||
|
||||
static struct whisper_vad_context *vctx;
|
||||
static std::vector<float> flat_segs;
|
||||
|
||||
static std::atomic<int> g_abort{0};
|
||||
|
||||
extern "C" void set_abort(int v) {
|
||||
g_abort.store(v, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
static void ggml_log_cb(enum ggml_log_level level, const char *log,
|
||||
void *data) {
|
||||
const char *level_str;
|
||||
|
||||
if (!log) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (level) {
|
||||
case GGML_LOG_LEVEL_DEBUG:
|
||||
level_str = "DEBUG";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_INFO:
|
||||
level_str = "INFO";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_WARN:
|
||||
level_str = "WARN";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_ERROR:
|
||||
level_str = "ERROR";
|
||||
break;
|
||||
default: /* Potential future-proofing */
|
||||
level_str = "?????";
|
||||
break;
|
||||
}
|
||||
|
||||
fprintf(stderr, "[%-5s] ", level_str);
|
||||
fputs(log, stderr);
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
int load_model(const char *const model_path, int threads,
|
||||
const char *backend_name) {
|
||||
whisper_log_set(ggml_log_cb, nullptr);
|
||||
ggml_backend_load_all();
|
||||
|
||||
if (backend_name && *backend_name) {
|
||||
g_session =
|
||||
crispasr_session_open_explicit(model_path, backend_name, threads);
|
||||
} else {
|
||||
g_session = crispasr_session_open(model_path, threads);
|
||||
}
|
||||
if (g_session == nullptr) {
|
||||
fprintf(stderr, "error: failed to open CrispASR session for model\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "info: CrispASR backend selected: %s\n",
|
||||
crispasr_session_backend(g_session));
|
||||
return 0;
|
||||
}
|
||||
|
||||
// set_codec_path forwards a companion file (qwen3-tts codec, orpheus SNAC,
|
||||
// chatterbox s3gen, or mimo-asr tokenizer) to the active session. Returns 0 on
|
||||
// success or when the active backend needs no companion, negative on failure,
|
||||
// and -1 when no session is open.
|
||||
int set_codec_path(const char *path) {
|
||||
return g_session ? crispasr_session_set_codec_path(g_session, path) : -1;
|
||||
}
|
||||
|
||||
int load_model_vad(const char *const model_path) {
|
||||
whisper_log_set(ggml_log_cb, nullptr);
|
||||
ggml_backend_load_all();
|
||||
|
||||
struct whisper_vad_context_params vcparams =
|
||||
whisper_vad_default_context_params();
|
||||
|
||||
// XXX: Overridden to false in upstream due to performance?
|
||||
// vcparams.use_gpu = true;
|
||||
|
||||
vctx = whisper_vad_init_from_file_with_params(model_path, vcparams);
|
||||
if (vctx == nullptr) {
|
||||
fprintf(stderr, "error: Failed to init model as VAD\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
|
||||
size_t *segs_out_len) {
|
||||
if (!whisper_vad_detect_speech(vctx, pcmf32, pcmf32_len)) {
|
||||
fprintf(stderr, "error: failed to detect speech\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct whisper_vad_params params = whisper_vad_default_params();
|
||||
struct whisper_vad_segments *segs =
|
||||
whisper_vad_segments_from_probs(vctx, params);
|
||||
size_t segn = whisper_vad_segments_n_segments(segs);
|
||||
|
||||
// fprintf(stderr, "Got segments %zd\n", segn);
|
||||
|
||||
flat_segs.clear();
|
||||
|
||||
for (int i = 0; i < segn; i++) {
|
||||
flat_segs.push_back(whisper_vad_segments_get_segment_t0(segs, i));
|
||||
flat_segs.push_back(whisper_vad_segments_get_segment_t1(segs, i));
|
||||
}
|
||||
|
||||
// fprintf(stderr, "setting out variables: %p=%p -> %p, %p=%zx -> %zx\n",
|
||||
// segs_out, *segs_out, flat_segs.data(), segs_out_len, *segs_out_len,
|
||||
// flat_segs.size());
|
||||
*segs_out = flat_segs.data();
|
||||
*segs_out_len = flat_segs.size();
|
||||
|
||||
// fprintf(stderr, "freeing segs\n");
|
||||
whisper_vad_free_segments(segs);
|
||||
|
||||
// fprintf(stderr, "returning\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// threads, diarize and prompt are accepted for Go-side API parity but unused
|
||||
// in Phase 1: the thread count is fixed at session open, and diarization and
|
||||
// the initial prompt are separate CrispASR features not yet wired through the
|
||||
// session ASR path.
|
||||
int transcribe(uint32_t threads, char *lang, bool translate, bool diarize,
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
|
||||
char *prompt) {
|
||||
(void)threads;
|
||||
(void)diarize;
|
||||
(void)prompt;
|
||||
|
||||
if (!g_session) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Reset stale abort flag from any prior cancelled call. set_abort remains
|
||||
// best-effort: the session transcribe call is blocking and exposes no abort
|
||||
// hook, so a mid-decode abort cannot interrupt it.
|
||||
g_abort.store(0, std::memory_order_relaxed);
|
||||
|
||||
crispasr_session_set_translate(g_session, translate ? 1 : 0);
|
||||
|
||||
if (g_result) {
|
||||
crispasr_session_result_free(g_result);
|
||||
g_result = nullptr;
|
||||
}
|
||||
|
||||
const char *language = (lang && *lang) ? lang : nullptr;
|
||||
g_result = crispasr_session_transcribe_lang(g_session, pcmf32, (int)pcmf32_len,
|
||||
language);
|
||||
if (!g_result) {
|
||||
fprintf(stderr, "error: transcription failed\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
*segs_out_len = crispasr_session_result_n_segments(g_result);
|
||||
return 0;
|
||||
}
|
||||
|
||||
const char *get_segment_text(int i) {
|
||||
if (!g_result) {
|
||||
return "";
|
||||
}
|
||||
return crispasr_session_result_segment_text(g_result, i);
|
||||
}
|
||||
|
||||
int64_t get_segment_t0(int i) {
|
||||
if (!g_result) {
|
||||
return 0;
|
||||
}
|
||||
return crispasr_session_result_segment_t0(g_result, i);
|
||||
}
|
||||
|
||||
int64_t get_segment_t1(int i) {
|
||||
if (!g_result) {
|
||||
return 0;
|
||||
}
|
||||
return crispasr_session_result_segment_t1(g_result, i);
|
||||
}
|
||||
|
||||
const char *get_backend(void) {
|
||||
return g_session ? crispasr_session_backend(g_session) : "";
|
||||
}
|
||||
|
||||
// TTS uses the already-open session (crispasr_session_open auto-detects a TTS
|
||||
// model). Output is 24 kHz mono float PCM (upstream CrispASR convention),
|
||||
// malloc'd by the C API; the caller must release it via tts_free.
|
||||
float *tts_synthesize(const char *text, int *out_n_samples) {
|
||||
if (out_n_samples) *out_n_samples = 0;
|
||||
if (!g_session || !text) return nullptr;
|
||||
return crispasr_session_synthesize(g_session, text, out_n_samples);
|
||||
}
|
||||
|
||||
void tts_free(float *pcm) {
|
||||
if (pcm) crispasr_pcm_free(pcm);
|
||||
}
|
||||
|
||||
int tts_set_voice(const char *name) {
|
||||
if (!g_session || !name || !*name) return 0;
|
||||
return crispasr_session_set_speaker_name(g_session, name);
|
||||
}
|
||||
|
||||
// tts_set_voice_file loads a voice from a file: a .gguf path selects a voice
|
||||
// pack, a .wav path with a non-empty ref_text performs zero-shot voice cloning
|
||||
// (the C API returns -2 when ref_text is required but missing). Returns -1 when
|
||||
// no session is open or path is null.
|
||||
int tts_set_voice_file(const char *path, const char *ref_text) {
|
||||
if (!g_session || !path) return -1;
|
||||
const char *ref = (ref_text && *ref_text) ? ref_text : nullptr;
|
||||
return crispasr_session_set_voice(g_session, path, ref);
|
||||
}
|
||||
23
backend/go/crispasr/cpp/crispasr_shim.h
Normal file
23
backend/go/crispasr/cpp/crispasr_shim.h
Normal file
@@ -0,0 +1,23 @@
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
extern "C" {
|
||||
int load_model(const char *const model_path, int threads,
|
||||
const char *backend_name);
|
||||
int set_codec_path(const char *path);
|
||||
int load_model_vad(const char *const model_path);
|
||||
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
|
||||
size_t *segs_out_len);
|
||||
int transcribe(uint32_t threads, char *lang, bool translate, bool diarize,
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
|
||||
char *prompt);
|
||||
const char *get_segment_text(int i);
|
||||
int64_t get_segment_t0(int i);
|
||||
int64_t get_segment_t1(int i);
|
||||
const char *get_backend(void);
|
||||
void set_abort(int v);
|
||||
float *tts_synthesize(const char *text, int *out_n_samples); // 24kHz mono float, malloc'd; NULL on failure
|
||||
void tts_free(float *pcm);
|
||||
int tts_set_voice(const char *name); // best-effort speaker selection; 0 ok
|
||||
int tts_set_voice_file(const char *path, const char *ref_text); // load voice pack (.gguf) or zero-shot clone (.wav + ref_text)
|
||||
}
|
||||
497
backend/go/crispasr/gocrispasr.go
Normal file
497
backend/go/crispasr/gocrispasr.go
Normal file
@@ -0,0 +1,497 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var (
|
||||
CppLoadModel func(modelPath string, threads int, backendName string) int
|
||||
CppSetCodecPath func(path string) int
|
||||
CppLoadModelVAD func(modelPath string) int
|
||||
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
|
||||
CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer, prompt string) int
|
||||
CppGetSegmentText func(i int) string
|
||||
CppGetSegmentStart func(i int) int64
|
||||
CppGetSegmentEnd func(i int) int64
|
||||
CppGetBackend func() string
|
||||
CppSetAbort func(v int)
|
||||
CppTTSSynthesize func(text string, outNSamples unsafe.Pointer) uintptr
|
||||
CppTTSFree func(ptr uintptr)
|
||||
CppTTSSetVoice func(name string) int
|
||||
CppTTSSetVoiceFile func(path string, refText string) int
|
||||
)
|
||||
|
||||
type CrispASR struct {
|
||||
base.SingleThread
|
||||
}
|
||||
|
||||
// splitOption splits a "prefix:value" model option into its key and value,
|
||||
// matching the convention used by other backends (see sherpa-onnx). It returns
|
||||
// ok=false when the option carries no ':' separator.
|
||||
func splitOption(oo string) (key, value string, ok bool) {
|
||||
parts := strings.SplitN(oo, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", "", false
|
||||
}
|
||||
return parts[0], parts[1], true
|
||||
}
|
||||
|
||||
func (w *CrispASR) Load(opts *pb.ModelOptions) error {
|
||||
vadOnly := false
|
||||
backendName := ""
|
||||
codecPath := ""
|
||||
speakerName := ""
|
||||
voicePath := ""
|
||||
voiceRefText := ""
|
||||
|
||||
for _, oo := range opts.Options {
|
||||
if oo == "vad_only" {
|
||||
vadOnly = true
|
||||
continue
|
||||
}
|
||||
switch key, value, ok := splitOption(oo); {
|
||||
case ok && key == "backend":
|
||||
backendName = value
|
||||
case ok && key == "codec":
|
||||
codecPath = value
|
||||
case ok && key == "speaker":
|
||||
speakerName = value
|
||||
case ok && key == "voice":
|
||||
voicePath = value
|
||||
case ok && key == "voice_text":
|
||||
voiceRefText = value
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||
}
|
||||
}
|
||||
|
||||
if vadOnly {
|
||||
if ret := CppLoadModelVAD(opts.ModelFile); ret != 0 {
|
||||
return fmt.Errorf("Failed to load CrispASR VAD model")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve a relative companion path against the model directory so a config
|
||||
// can reference a sibling codec/tokenizer file by name alone.
|
||||
if codecPath != "" && !filepath.IsAbs(codecPath) {
|
||||
codecPath = filepath.Join(filepath.Dir(opts.ModelFile), codecPath)
|
||||
}
|
||||
|
||||
// A voice file (.gguf pack or .wav prompt) is resolved against the model
|
||||
// directory just like the codec, so a config can reference a sibling file.
|
||||
if voicePath != "" && !filepath.IsAbs(voicePath) {
|
||||
voicePath = filepath.Join(filepath.Dir(opts.ModelFile), voicePath)
|
||||
}
|
||||
|
||||
if ret := CppLoadModel(opts.ModelFile, int(opts.Threads), backendName); ret != 0 {
|
||||
return fmt.Errorf("Failed to load CrispASR transcription model")
|
||||
}
|
||||
|
||||
// Load the companion file (codec/tokenizer/s3gen) after the session is open.
|
||||
// rc==0 means success or "not applicable" for the active backend; only a
|
||||
// negative code is fatal.
|
||||
if codecPath != "" {
|
||||
if rc := CppSetCodecPath(codecPath); rc < 0 {
|
||||
return fmt.Errorf("crispasr: failed to load companion file %q (rc=%d)", codecPath, rc)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "CrispASR companion file loaded: %s\n", codecPath)
|
||||
}
|
||||
|
||||
// Apply the Load-time default voice. A baked speaker (speaker:) is selected
|
||||
// by name and is best-effort: a backend that can't honor it is logged, not
|
||||
// fatal. A voice file (voice:) is a hard requirement once configured, so a
|
||||
// negative rc fails Load.
|
||||
if speakerName != "" {
|
||||
if rc := CppTTSSetVoice(speakerName); rc != 0 {
|
||||
fmt.Fprintf(os.Stderr, "crispasr: speaker %q not applied (rc=%d)\n", speakerName, rc)
|
||||
}
|
||||
}
|
||||
if voicePath != "" {
|
||||
if rc := CppTTSSetVoiceFile(voicePath, voiceRefText); rc < 0 {
|
||||
return fmt.Errorf("crispasr: failed to load voice %q (rc=%d)", voicePath, rc)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "CrispASR voice loaded: %s\n", voicePath)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "CrispASR backend selected: %s\n", CppGetBackend())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *CrispASR) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
|
||||
audio := req.Audio
|
||||
// We expect 0xdeadbeef to be overwritten and if we see it in a stack trace we know it wasn't
|
||||
segsPtr, segsLen := uintptr(0xdeadbeef), uintptr(0xdeadbeef)
|
||||
segsPtrPtr, segsLenPtr := unsafe.Pointer(&segsPtr), unsafe.Pointer(&segsLen)
|
||||
|
||||
if ret := CppVAD(audio, uintptr(len(audio)), segsPtrPtr, segsLenPtr); ret != 0 {
|
||||
return pb.VADResponse{}, fmt.Errorf("Failed VAD")
|
||||
}
|
||||
|
||||
// Happens when CPP vector has not had any elements pushed to it
|
||||
if segsPtr == 0 {
|
||||
return pb.VADResponse{
|
||||
Segments: []*pb.VADSegment{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// unsafeptr warning is caused by segsPtr being on the stack and therefor being subject to stack copying AFAICT
|
||||
// however the stack shouldn't have grown between setting segsPtr and now, also the memory pointed to is allocated by C++
|
||||
segs := unsafe.Slice((*float32)(unsafe.Pointer(segsPtr)), segsLen) //nolint:govet // segsPtr addresses C++-owned heap memory passed back through the cgo-free purego boundary; the uintptr->Pointer round-trip is intentional and the buffer outlives this read.
|
||||
|
||||
vadSegments := []*pb.VADSegment{}
|
||||
for i := range len(segs) >> 1 {
|
||||
s := segs[2*i] / 100
|
||||
t := segs[2*i+1] / 100
|
||||
vadSegments = append(vadSegments, &pb.VADSegment{
|
||||
Start: s,
|
||||
End: t,
|
||||
})
|
||||
}
|
||||
|
||||
return pb.VADResponse{
|
||||
Segments: vadSegments,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *CrispASR) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "crispasr")
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
convertedPath := filepath.Join(dir, "converted.wav")
|
||||
|
||||
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
|
||||
// Watcher: flips the C-side abort flag when ctx is cancelled. The
|
||||
// goroutine is joined synchronously (close(done) signals it to exit,
|
||||
// wg.Wait() blocks until it has) so a late CppSetAbort(1) cannot fire
|
||||
// after the function returns and corrupt the next transcription call.
|
||||
done := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
CppSetAbort(1)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||
if ret == 2 {
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if ret != 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
segments := []*pb.TranscriptSegment{}
|
||||
text := ""
|
||||
for i := range int(segsLen) {
|
||||
// segment start/end conversion factor taken from https://github.com/ggml-org/whisper.cpp/blob/master/examples/cli/cli.cpp#L895
|
||||
s := CppGetSegmentStart(i) * (10000000)
|
||||
t := CppGetSegmentEnd(i) * (10000000)
|
||||
// The session result can emit bytes that aren't valid UTF-8 (e.g. a
|
||||
// multibyte codepoint split across token boundaries); protobuf string
|
||||
// fields reject those at marshal time. Scrub before the value escapes
|
||||
// cgo. The session result is segment+word based and exposes no token
|
||||
// IDs, so Tokens is left empty.
|
||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
||||
|
||||
segment := &pb.TranscriptSegment{
|
||||
Id: int32(i),
|
||||
Text: txt,
|
||||
Start: s, End: t,
|
||||
}
|
||||
|
||||
segments = append(segments, segment)
|
||||
|
||||
text += " " + strings.TrimSpace(txt)
|
||||
}
|
||||
|
||||
return pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: strings.TrimSpace(text),
|
||||
Language: opts.Language,
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AudioTranscriptionStream runs the session transcribe to completion and then
|
||||
// emits one delta per non-empty segment, followed by a final TranscriptResult.
|
||||
// Progressive/real-time streaming isn't available via the session API (there
|
||||
// is no per-decode callback), so deltas are emitted per-segment after the
|
||||
// blocking decode returns rather than as segments are produced. The offline
|
||||
// AudioTranscription is unchanged; both paths share the session and the
|
||||
// SingleThread concurrency model.
|
||||
func (w *CrispASR) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
defer close(results)
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "crispasr")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
convertedPath := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
|
||||
// Same abort-watcher pattern as AudioTranscription. Joined synchronously
|
||||
// so a late CppSetAbort(1) cannot fire after this function returns.
|
||||
// Best-effort only: the session transcribe is blocking with no abort hook.
|
||||
done := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
CppSetAbort(1)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||
if ret == 2 {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
// Walk the segments once: emit a delta per non-empty segment and build the
|
||||
// final TranscriptResult.Segments alongside. The first delta has no leading
|
||||
// space and subsequent ones are prefixed with a single space, so
|
||||
// concat(deltas) == final.Text exactly, matching the e2e contract.
|
||||
segments := []*pb.TranscriptSegment{}
|
||||
var assembled strings.Builder
|
||||
for i := range int(segsLen) {
|
||||
s := CppGetSegmentStart(i) * 10000000
|
||||
t := CppGetSegmentEnd(i) * 10000000
|
||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
||||
segments = append(segments, &pb.TranscriptSegment{
|
||||
Id: int32(i),
|
||||
Text: txt,
|
||||
Start: s, End: t,
|
||||
})
|
||||
|
||||
trimmed := strings.TrimSpace(txt)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
var delta string
|
||||
if assembled.Len() == 0 {
|
||||
delta = trimmed
|
||||
} else {
|
||||
delta = " " + trimmed
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{Delta: delta}
|
||||
assembled.WriteString(delta)
|
||||
}
|
||||
|
||||
final := &pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: assembled.String(),
|
||||
Language: opts.Language,
|
||||
Duration: duration,
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{FinalResult: final}
|
||||
return nil
|
||||
}
|
||||
|
||||
// synthesize returns 24 kHz mono float32 PCM for text via the open session.
|
||||
func (w *CrispASR) synthesize(text string) ([]float32, error) {
|
||||
if text == "" {
|
||||
return nil, fmt.Errorf("crispasr: TTS requires non-empty text")
|
||||
}
|
||||
var n int32
|
||||
ptr := CppTTSSynthesize(text, unsafe.Pointer(&n))
|
||||
if ptr == 0 || n <= 0 {
|
||||
return nil, fmt.Errorf("crispasr: synthesis failed (the loaded model may not be a supported TTS backend, or needs extra config e.g. orpheus SNAC codec)")
|
||||
}
|
||||
defer CppTTSFree(ptr)
|
||||
src := unsafe.Slice((*float32)(unsafe.Pointer(ptr)), int(n)) //nolint:govet // ptr addresses C-allocated PCM returned across the purego boundary; copied out immediately below, before tts_free.
|
||||
out := make([]float32, int(n)) // copy out of C memory before free
|
||||
copy(out, src)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// setVoice applies a per-call speaker/voice override (best effort). CrispASR
|
||||
// returns a negative code when the active backend can't honor the name; we log
|
||||
// it rather than fail, so an unknown voice falls back to the default speaker.
|
||||
func setVoice(voice string) {
|
||||
v := strings.TrimSpace(voice)
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
if rc := CppTTSSetVoice(v); rc != 0 {
|
||||
fmt.Fprintf(os.Stderr, "crispasr: voice %q not applied by the active TTS backend (rc=%d); using default\n", v, rc)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *CrispASR) TTS(req *pb.TTSRequest) error {
|
||||
if req.Dst == "" {
|
||||
return fmt.Errorf("crispasr: TTS requires a destination path")
|
||||
}
|
||||
setVoice(req.Voice)
|
||||
pcm, err := w.synthesize(req.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWAV24k(req.Dst, pcm)
|
||||
}
|
||||
|
||||
// TTSStream is the streaming counterpart to TTS. CrispASR has no progressive
|
||||
// (native streaming) synth, so we synthesize the whole utterance, encode it to
|
||||
// a 24 kHz WAV, and emit the encoded bytes as a single chunk. The gRPC server
|
||||
// wrapper (pkg/grpc/server.go:TTSStream) ranges over the channel until it is
|
||||
// closed, so this method owns the close - mirrors vibevoice-cpp's TTSStream.
|
||||
func (w *CrispASR) TTSStream(req *pb.TTSRequest, results chan []byte) error {
|
||||
defer close(results)
|
||||
|
||||
if req.Text == "" {
|
||||
return fmt.Errorf("crispasr: TTSStream requires text")
|
||||
}
|
||||
setVoice(req.Voice)
|
||||
pcm, err := w.synthesize(req.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmp, err := os.CreateTemp("", "crispasr-tts-stream-*.wav")
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: tempfile: %w", err)
|
||||
}
|
||||
dst := tmp.Name()
|
||||
if err := tmp.Close(); err != nil {
|
||||
return fmt.Errorf("crispasr: close tempfile: %w", err)
|
||||
}
|
||||
defer func() { _ = os.Remove(dst) }()
|
||||
|
||||
if err := writeWAV24k(dst, pcm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encoded, err := os.ReadFile(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: read tempfile: %w", err)
|
||||
}
|
||||
results <- encoded
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeWAV24k writes pcm as a 24000 Hz, mono, 16-bit PCM WAV at dst.
|
||||
func writeWAV24k(dst string, pcm []float32) error {
|
||||
f, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: create %q: %w", dst, err)
|
||||
}
|
||||
|
||||
enc := wav.NewEncoder(f, 24000, 16, 1, 1)
|
||||
ints := make([]int, len(pcm))
|
||||
for i, s := range pcm {
|
||||
if s > 1 {
|
||||
s = 1
|
||||
} else if s < -1 {
|
||||
s = -1
|
||||
}
|
||||
ints[i] = int(s * 32767)
|
||||
}
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: 24000},
|
||||
Data: ints,
|
||||
SourceBitDepth: 16,
|
||||
}
|
||||
if err := enc.Write(buf); err != nil {
|
||||
_ = enc.Close()
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("crispasr: encode WAV: %w", err)
|
||||
}
|
||||
if err := enc.Close(); err != nil {
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("crispasr: finalize WAV: %w", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return fmt.Errorf("crispasr: close %q: %w", dst, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
193
backend/go/crispasr/gocrispasr_test.go
Normal file
193
backend/go/crispasr/gocrispasr_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestCrispASR(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "CrispASR Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
|
||||
// bridge without spinning up the gRPC server. Skips the current spec when the
|
||||
// shared library isn't present (e.g. running before `make backends/whisper`).
|
||||
func ensureLibLoaded() {
|
||||
libLoadOnce.Do(func() {
|
||||
libName := os.Getenv("CRISPASR_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgocrispasr-fallback.so"
|
||||
}
|
||||
if _, err := os.Stat(libName); err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&CppLoadModel, gosd, "load_model")
|
||||
purego.RegisterLibFunc(&CppSetCodecPath, gosd, "set_codec_path")
|
||||
purego.RegisterLibFunc(&CppTranscribe, gosd, "transcribe")
|
||||
purego.RegisterLibFunc(&CppGetSegmentText, gosd, "get_segment_text")
|
||||
purego.RegisterLibFunc(&CppGetSegmentStart, gosd, "get_segment_t0")
|
||||
purego.RegisterLibFunc(&CppGetSegmentEnd, gosd, "get_segment_t1")
|
||||
purego.RegisterLibFunc(&CppGetBackend, gosd, "get_backend")
|
||||
purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort")
|
||||
purego.RegisterLibFunc(&CppTTSSynthesize, gosd, "tts_synthesize")
|
||||
purego.RegisterLibFunc(&CppTTSFree, gosd, "tts_free")
|
||||
purego.RegisterLibFunc(&CppTTSSetVoice, gosd, "tts_set_voice")
|
||||
purego.RegisterLibFunc(&CppTTSSetVoiceFile, gosd, "tts_set_voice_file")
|
||||
})
|
||||
if libLoadErr != nil {
|
||||
Skip("whisper library not loadable: " + libLoadErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// fixturesOrSkip returns the model + audio paths or skips the spec if either
|
||||
// env var is unset. The test never runs in default CI — it requires a real
|
||||
// whisper model and a long audio file (~3 minutes) on disk.
|
||||
func fixturesOrSkip() (string, string) {
|
||||
modelPath := os.Getenv("CRISPASR_MODEL_PATH")
|
||||
audioPath := os.Getenv("CRISPASR_AUDIO_PATH")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set CRISPASR_MODEL_PATH and CRISPASR_AUDIO_PATH to run this spec")
|
||||
}
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
// ttsModelOrSkip returns the TTS model path or skips the spec when the env var
|
||||
// is unset. Like the transcription fixtures, this never runs in default CI — it
|
||||
// needs a real TTS model (e.g. a vibevoice GGUF) on disk.
|
||||
func ttsModelOrSkip() string {
|
||||
modelPath := os.Getenv("CRISPASR_TTS_MODEL_PATH")
|
||||
if modelPath == "" {
|
||||
Skip("set CRISPASR_TTS_MODEL_PATH to run this spec")
|
||||
}
|
||||
return modelPath
|
||||
}
|
||||
|
||||
var _ = Describe("CrispASR", func() {
|
||||
Context("AudioTranscription cancellation", func() {
|
||||
It("returns codes.Canceled on a pre-cancelled context and still succeeds afterwards", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
|
||||
// The session transcribe is blocking and exposes no abort hook, so
|
||||
// a mid-decode cancel can't interrupt it. The contract we can rely
|
||||
// on is the pre-call ctx.Err() check: a context cancelled before
|
||||
// the call must yield codes.Canceled without starting a decode.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := w.AudioTranscription(ctx, &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "expected pre-cancelled context to fail")
|
||||
st, ok := status.FromError(err)
|
||||
Expect(ok).To(BeTrue(), "expected gRPC status error, got %v", err)
|
||||
Expect(st.Code()).To(Equal(codes.Canceled), "expected codes.Canceled, got %v", err)
|
||||
|
||||
// Subsequent transcription must succeed — proves g_abort reset.
|
||||
res, err := w.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred(), "post-cancel transcription failed")
|
||||
Expect(res.Text).ToNot(BeEmpty(), "post-cancel transcription returned empty text")
|
||||
})
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("emits multiple deltas progressively for a multi-segment clip", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
|
||||
results := make(chan *pb.TranscriptStreamResponse, 64)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- w.AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
Stream: true,
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var deltas []string
|
||||
var assembled strings.Builder
|
||||
var finalText string
|
||||
var finalSegmentCount int
|
||||
for chunk := range results {
|
||||
if d := chunk.GetDelta(); d != "" {
|
||||
deltas = append(deltas, d)
|
||||
assembled.WriteString(d)
|
||||
}
|
||||
if final := chunk.GetFinalResult(); final != nil {
|
||||
finalText = final.GetText()
|
||||
finalSegmentCount = len(final.GetSegments())
|
||||
}
|
||||
}
|
||||
Expect(<-done).ToNot(HaveOccurred())
|
||||
|
||||
// One delta per non-empty segment is emitted after the blocking
|
||||
// decode returns (the session API has no per-decode callback), so a
|
||||
// multi-segment clip MUST produce >=2 delta events, and
|
||||
// concat(deltas) MUST equal final.Text exactly.
|
||||
Expect(len(deltas)).To(BeNumerically(">=", 2),
|
||||
"expected multiple deltas from a multi-segment clip, got %d (assembled=%q)",
|
||||
len(deltas), assembled.String())
|
||||
Expect(finalSegmentCount).To(BeNumerically(">=", 2),
|
||||
"expected final to carry multiple segments")
|
||||
Expect(assembled.String()).To(Equal(finalText),
|
||||
"concat(deltas) must equal final.Text")
|
||||
})
|
||||
})
|
||||
|
||||
Context("TTS", func() {
|
||||
It("synthesizes a non-empty WAV", func() {
|
||||
ttsModel := ttsModelOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: ttsModel})).To(Succeed())
|
||||
|
||||
dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
|
||||
Expect(w.TTS(&pb.TTSRequest{Text: "Hello from CrispASR.", Dst: dst})).To(Succeed())
|
||||
|
||||
info, err := os.Stat(dst)
|
||||
Expect(err).ToNot(HaveOccurred(), "synthesized WAV should exist at %q", dst)
|
||||
// A real 24 kHz mono WAV is a 44-byte header plus samples; anything
|
||||
// this small would mean an empty/failed synth.
|
||||
Expect(info.Size()).To(BeNumerically(">", 1024),
|
||||
"expected a non-trivial WAV, got %d bytes", info.Size())
|
||||
})
|
||||
})
|
||||
})
|
||||
58
backend/go/crispasr/main.go
Normal file
58
backend/go/crispasr/main.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("CRISPASR_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgocrispasr-fallback.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppLoadModel, "load_model"},
|
||||
{&CppSetCodecPath, "set_codec_path"},
|
||||
{&CppLoadModelVAD, "load_model_vad"},
|
||||
{&CppVAD, "vad"},
|
||||
{&CppTranscribe, "transcribe"},
|
||||
{&CppGetSegmentText, "get_segment_text"},
|
||||
{&CppGetSegmentStart, "get_segment_t0"},
|
||||
{&CppGetSegmentEnd, "get_segment_t1"},
|
||||
{&CppGetBackend, "get_backend"},
|
||||
{&CppSetAbort, "set_abort"},
|
||||
{&CppTTSSynthesize, "tts_synthesize"},
|
||||
{&CppTTSFree, "tts_free"},
|
||||
{&CppTTSSetVoice, "tts_set_voice"},
|
||||
{&CppTTSSetVoiceFile, "tts_set_voice_file"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &CrispASR{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
65
backend/go/crispasr/package.sh
Executable file
65
backend/go/crispasr/package.sh
Executable file
@@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
# This script is used in the final stage of the Dockerfile
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/crispasr $CURDIR/package/
|
||||
cp -fv $CURDIR/libgocrispasr-*.so $CURDIR/package/
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ $(uname -s) = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
# The GPU library packaging script will detect BUILD_TYPE and copy appropriate GPU libraries
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
52
backend/go/crispasr/run.sh
Executable file
52
backend/go/crispasr/run.sh
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
fi
|
||||
|
||||
LIBRARY="$CURDIR/libgocrispasr-fallback.so"
|
||||
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx2.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx2.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx512.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx512.so"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
export CRISPASR_LIBRARY=$LIBRARY
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/crispasr "$@"
|
||||
fi
|
||||
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/crispasr "$@"
|
||||
10
backend/go/dllm/.gitignore
vendored
Normal file
10
backend/go/dllm/.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
.cache/
|
||||
sources/
|
||||
build/
|
||||
package/
|
||||
dllm-grpc
|
||||
# build artifacts staged in-tree by the Makefile (cp from sources/) or
|
||||
# symlinked for local dev; the real sources live in dllm.cpp upstream.
|
||||
*.so
|
||||
*.so.*
|
||||
compile_commands.json
|
||||
93
backend/go/dllm/Makefile
Normal file
93
backend/go/dllm/Makefile
Normal file
@@ -0,0 +1,93 @@
|
||||
# dllm backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DLLM_VERSION?=<sha> so .github/bump_deps.sh
|
||||
# can find and update it - matches the whisper.cpp / parakeet-cpp / ds4
|
||||
# convention.
|
||||
#
|
||||
# Local dev shortcut: if you already have an out-of-tree dllm.cpp build,
|
||||
# you can symlink the .so into this directory and skip the clone/cmake
|
||||
# steps entirely, e.g.:
|
||||
#
|
||||
# ln -sf /path/to/dllm.cpp/build/libdllm.so .
|
||||
# go build -o dllm-grpc .
|
||||
#
|
||||
# That's what the gated C-ABI binding smoke uses (DLLM_TEST_LIBRARY). The
|
||||
# default target below does the proper clone-at-pin + cmake build so CI
|
||||
# doesn't need a side-checkout.
|
||||
#
|
||||
# NOTE: github.com/mudler/dllm.cpp is still private (publishing is planned);
|
||||
# until then the anonymous clone below fails. Use the symlink shortcut above
|
||||
# with a local checkout, or a git credential helper with access to the repo.
|
||||
|
||||
DLLM_VERSION?=b22fcebebfb225131113188599a9ae542b2935d7
|
||||
DLLM_REPO?=https://github.com/mudler/dllm.cpp
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
||||
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
# libdllm.so is self-contained: dllm.cpp's CMakeLists statically absorbs ggml
|
||||
# (BUILD_SHARED_LIBS=OFF + PIC) into the shared lib, so dlopen needs no
|
||||
# libggml*.so alongside it, only system libs (libstdc++/libgomp/libc) the
|
||||
# runtime image already provides. Tests/CLI are upstream-only concerns.
|
||||
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DDLLM_BUILD_TESTS=OFF
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# Same arch set the sibling ggml backends (acestep/vibevoice/qwen3-tts) bake
|
||||
# for their cublas images; override for a native build.
|
||||
CUDA_ARCHITECTURES?=75-virtual;80-virtual;86-real;89-real
|
||||
|
||||
# dllm.cpp gates CUDA behind DLLM_CUDA (set(GGML_CUDA ... CACHE FORCE)), so
|
||||
# forward that instead of a bare -DGGML_CUDA=ON.
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DDLLM_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="$(CUDA_ARCHITECTURES)"
|
||||
endif
|
||||
|
||||
.PHONY: dllm-grpc package build clean purge test all
|
||||
|
||||
all: dllm-grpc
|
||||
|
||||
# Clone the upstream dllm.cpp source at the pinned commit (ggml comes in as
|
||||
# a submodule). Directory acts as the target so make only re-clones when
|
||||
# missing. After a DLLM_VERSION bump, run 'make purge && make' to refetch.
|
||||
sources/dllm.cpp:
|
||||
mkdir -p sources/dllm.cpp
|
||||
cd sources/dllm.cpp && \
|
||||
git init -q && \
|
||||
git remote add origin $(DLLM_REPO) && \
|
||||
git fetch --depth 1 origin $(DLLM_VERSION) && \
|
||||
git checkout FETCH_HEAD && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Build the shared lib out-of-tree, then stage it next to the Go sources so
|
||||
# purego.Dlopen("libdllm.so") and the packaging step both pick it up.
|
||||
libdllm.so: sources/dllm.cpp
|
||||
cmake -B sources/dllm.cpp/build -S sources/dllm.cpp $(CMAKE_ARGS)
|
||||
cmake --build sources/dllm.cpp/build --config Release -j$(JOBS)
|
||||
cp -fv sources/dllm.cpp/build/libdllm.so ./
|
||||
|
||||
dllm-grpc: libdllm.so main.go capi.go
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o dllm-grpc .
|
||||
|
||||
package: dllm-grpc
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
# Test target. The C-ABI binding smoke is gated on DLLM_TEST_LIBRARY +
|
||||
# DLLM_TEST_TINY_MODEL; without them the gated specs auto-skip and only the
|
||||
# pure-Go helper specs run.
|
||||
test:
|
||||
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
|
||||
|
||||
clean: purge
|
||||
rm -rf libdllm.so* package dllm-grpc
|
||||
|
||||
purge:
|
||||
rm -rf sources/dllm.cpp
|
||||
256
backend/go/dllm/capi.go
Normal file
256
backend/go/dllm/capi.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package main
|
||||
|
||||
// Typed Go wrappers over dllm.cpp's flat C-ABI (include/dllm_capi.h, ABI v1).
|
||||
//
|
||||
// Contract highlights the wrappers encode (see the header + src/capi.cpp):
|
||||
// - tokenize_json/generate return malloc'd char* the CALLER owns: bound as
|
||||
// uintptr, copied with goStringFromCPtr, released via dllm_capi_free_string.
|
||||
// - last_error returns a BORROWED pointer (valid until the next call on the
|
||||
// same ctx): bound as a plain string (purego copies), never freed, and only
|
||||
// read AFTER the failing call has returned - reading it while a generate is
|
||||
// in flight on the same ctx violates the per-ctx serialization contract.
|
||||
// - All entry points except dllm_capi_cancel must be externally serialized
|
||||
// per ctx (one ctx = one concurrent generate/tokenize). Cancel only flips
|
||||
// an atomic and may be called from any goroutine mid-generate.
|
||||
// - No C++ exception crosses the boundary; failures land in last_error.
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
)
|
||||
|
||||
// dllmABIVersion is the DLLM_CAPI_ABI_VERSION this binding was written
|
||||
// against; main.go refuses to start against a libdllm.so reporting another.
|
||||
const dllmABIVersion = 1
|
||||
|
||||
// purego-bound entry points from libdllm.so. Names match dllm_capi.h
|
||||
// exactly; loadCAPI (main.go) fills these in at boot.
|
||||
var (
|
||||
cppAbiVersion func() int32
|
||||
cppLoad func(ggufPath, paramsJSON string) uintptr
|
||||
cppFree func(ctx uintptr)
|
||||
cppLastError func(ctx uintptr) string // borrowed pointer: purego copies, do NOT free
|
||||
cppFreeString func(s uintptr)
|
||||
// malloc'd char* returns, hence uintptr (see loadCAPI's doc comment).
|
||||
cppTokenizeJSON func(ctx uintptr, text string) uintptr
|
||||
cppGenerate func(ctx uintptr, prompt, optsJSON string) uintptr
|
||||
// on_block/on_step are C function pointers produced by purego.NewCallback;
|
||||
// userData carries the streamCallStates registry key.
|
||||
cppGenerateStream func(ctx uintptr, prompt, optsJSON string, onBlock, onStep, userData uintptr) int32
|
||||
cppCancel func(ctx uintptr)
|
||||
)
|
||||
|
||||
// cAbiVersion returns the library's DLLM_CAPI_ABI_VERSION.
|
||||
func cAbiVersion() int32 {
|
||||
return cppAbiVersion()
|
||||
}
|
||||
|
||||
// cLoad opens the GGUF at path with the flat params JSON (e.g.
|
||||
// {"n_gpu_layers":99}). Returns 0 on failure; per the header contract there
|
||||
// is no ctx to carry the reason, the C side logs it to stderr (and
|
||||
// cLastError(0) only yields the static NULL-ctx message).
|
||||
func cLoad(path, paramsJSON string) uintptr {
|
||||
return cppLoad(path, paramsJSON)
|
||||
}
|
||||
|
||||
// cFree releases a ctx; safe on 0 (delete nullptr).
|
||||
func cFree(h uintptr) {
|
||||
cppFree(h)
|
||||
}
|
||||
|
||||
// cLastError returns the ctx's last error message (or the static NULL-ctx
|
||||
// message for h==0). The C pointer is borrowed and only valid until the next
|
||||
// call on the same ctx; purego's string return copies it immediately, so the
|
||||
// returned Go string is safe to keep. Must not be called while another call
|
||||
// on the same ctx is in flight.
|
||||
func cLastError(h uintptr) string {
|
||||
return cppLastError(h)
|
||||
}
|
||||
|
||||
// lastErrorOr is cLastError with a fallback for the empty-message case, so
|
||||
// wrapped errors never end in ": ".
|
||||
func lastErrorOr(h uintptr, fallback string) string {
|
||||
if msg := cLastError(h); msg != "" {
|
||||
return msg
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// cTokenizeJSON tokenizes text (the C side prepends bos per vocab.add_bos)
|
||||
// and returns the token ids as a JSON array string, e.g. "[2,18]".
|
||||
func cTokenizeJSON(h uintptr, text string) (string, error) {
|
||||
ret := cppTokenizeJSON(h, text)
|
||||
if ret == 0 {
|
||||
return "", fmt.Errorf("dllm: tokenize failed: %s", lastErrorOr(h, "unknown error"))
|
||||
}
|
||||
out := goStringFromCPtr(ret)
|
||||
cppFreeString(ret)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// cGenerate runs a blocking generation and returns the detokenized text.
|
||||
// optsJSON must be a FLAT JSON object of scalars (use buildOptsJSON); the C
|
||||
// parser rejects nested objects/arrays. NULL return -> last_error (read only
|
||||
// after the call returned, per the serialization contract); a cancelled call
|
||||
// surfaces as the "cancelled" message.
|
||||
func cGenerate(h uintptr, prompt, optsJSON string) (string, error) {
|
||||
ret := cppGenerate(h, prompt, optsJSON)
|
||||
if ret == 0 {
|
||||
return "", fmt.Errorf("dllm: generate failed: %s", lastErrorOr(h, "unknown error"))
|
||||
}
|
||||
out := goStringFromCPtr(ret)
|
||||
cppFreeString(ret)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// streamCallState carries the Go callbacks for one in-flight
|
||||
// cGenerateStream call; the registry key travels through C as user_data.
|
||||
// The map shape mirrors the whisper backend's streamCallStates: only one
|
||||
// entry per ctx is ever live (the C-ABI is serialized per ctx), but keying
|
||||
// by call survives multiple models/processes sharing the package.
|
||||
type streamCallState struct {
|
||||
onBlock func(text string)
|
||||
onStep func(step, total int, preview string)
|
||||
}
|
||||
|
||||
var (
|
||||
streamCallStates sync.Map // uint64 -> *streamCallState
|
||||
streamCallSeq atomic.Uint64
|
||||
|
||||
// purego.NewCallback allocates a finite, never-released callback slot, so
|
||||
// the two trampolines are created exactly once and reused across calls.
|
||||
streamCbOnce sync.Once
|
||||
blockCbPtr uintptr
|
||||
stepCbPtr uintptr
|
||||
)
|
||||
|
||||
// onBlockTrampoline is the Go side of dllm_block_cb. It runs on the C
|
||||
// calling thread, mid-generate: keep it tiny and non-blocking (callers that
|
||||
// bridge to goroutines must hand off via buffered channels). The text
|
||||
// pointer is only valid for the duration of the invocation, so it is copied
|
||||
// to a Go string immediately.
|
||||
func onBlockTrampoline(text uintptr, userData uintptr) {
|
||||
v, ok := streamCallStates.Load(uint64(userData))
|
||||
if !ok {
|
||||
return // call already torn down
|
||||
}
|
||||
state := v.(*streamCallState)
|
||||
if state.onBlock != nil {
|
||||
state.onBlock(goStringFromCPtr(text))
|
||||
}
|
||||
}
|
||||
|
||||
// onStepTrampoline is the Go side of dllm_step_cb; same threading and
|
||||
// lifetime caveats as onBlockTrampoline.
|
||||
func onStepTrampoline(step int32, totalSteps int32, canvasPreview uintptr, userData uintptr) {
|
||||
v, ok := streamCallStates.Load(uint64(userData))
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
state := v.(*streamCallState)
|
||||
if state.onStep != nil {
|
||||
state.onStep(int(step), int(totalSteps), goStringFromCPtr(canvasPreview))
|
||||
}
|
||||
}
|
||||
|
||||
// cGenerateStream runs a generation with per-committed-block (onBlock) and
|
||||
// per-denoising-step (onStep) callbacks; either may be nil. The callbacks
|
||||
// run on the C thread (see the trampoline docs). Returns an error carrying
|
||||
// last_error on failure; cancellation surfaces as the "cancelled" message.
|
||||
func cGenerateStream(h uintptr, prompt, optsJSON string, onBlock func(text string), onStep func(step, total int, preview string)) error {
|
||||
streamCbOnce.Do(func() {
|
||||
blockCbPtr = purego.NewCallback(onBlockTrampoline)
|
||||
stepCbPtr = purego.NewCallback(onStepTrampoline)
|
||||
})
|
||||
|
||||
id := streamCallSeq.Add(1)
|
||||
streamCallStates.Store(id, &streamCallState{onBlock: onBlock, onStep: onStep})
|
||||
defer streamCallStates.Delete(id)
|
||||
|
||||
// Pass NULL for absent callbacks so the C side skips the per-block /
|
||||
// per-step detokenize work entirely.
|
||||
var blockPtr, stepPtr uintptr
|
||||
if onBlock != nil {
|
||||
blockPtr = blockCbPtr
|
||||
}
|
||||
if onStep != nil {
|
||||
stepPtr = stepCbPtr
|
||||
}
|
||||
|
||||
if rc := cppGenerateStream(h, prompt, optsJSON, blockPtr, stepPtr, uintptr(id)); rc != 0 {
|
||||
return fmt.Errorf("dllm: generate_stream failed: %s", lastErrorOr(h, "unknown error"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cCancel requests cancellation of the in-flight generate on h. This is the
|
||||
// ONE entry point safe to call from any goroutine while a generate runs (it
|
||||
// only flips an atomic). Note the cancel-reset race from the header: each
|
||||
// generate resets the flag on entry, so a watchdog should re-issue cancel if
|
||||
// the call has not returned.
|
||||
func cCancel(h uintptr) {
|
||||
cppCancel(h)
|
||||
}
|
||||
|
||||
// buildOptsJSON renders generation options as the flat JSON object the
|
||||
// C-ABI expects (known keys: n_predict, blocks, seed, eb_*, kv_cache). The
|
||||
// C-side scanner only understands scalar number/string values and rejects
|
||||
// nested objects/arrays loudly; bools are rejected here too because the
|
||||
// scanner has no concept of them. Fail loud rather than let an option be
|
||||
// silently misread.
|
||||
//
|
||||
// CAVEAT: json.Marshal HTML-escapes <, > and & inside string values (e.g.
|
||||
// "<" becomes the six-byte \u003c sequence). None of the known string-valued keys
|
||||
// (kv_cache: auto|on|off) can contain those bytes today; if one ever does,
|
||||
// switch to an Encoder with SetEscapeHTML(false) like gemma4JSONString.
|
||||
func buildOptsJSON(opts map[string]any) (string, error) {
|
||||
if len(opts) == 0 {
|
||||
return "{}", nil
|
||||
}
|
||||
for k, v := range opts {
|
||||
switch v.(type) {
|
||||
case string,
|
||||
int, int8, int16, int32, int64,
|
||||
uint, uint8, uint16, uint32, uint64,
|
||||
float32, float64,
|
||||
json.Number:
|
||||
// scalar: fine
|
||||
default:
|
||||
return "", fmt.Errorf("dllm: opts key %q has non-scalar value %T (the C-ABI only accepts flat number/string scalars)", k, v)
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(opts)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("dllm: marshal opts: %w", err)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// goStringFromCPtr copies a NUL-terminated C string into Go memory. cptr is
|
||||
// the raw pointer returned by purego from the C-ABI (a malloc'd buffer the
|
||||
// caller owns, or a callback argument only valid during the invocation);
|
||||
// owning callers must free it via cppFreeString after the copy lands.
|
||||
//
|
||||
// A direct unsafe.Pointer(cptr) conversion trips go vet's unsafeptr check,
|
||||
// which can't distinguish a C-owned heap pointer from Go-managed memory (the
|
||||
// parakeet-cpp and whisper backends tolerate that warning). Reinterpreting
|
||||
// through &cptr below is equivalent at runtime and keeps plain `go vet`
|
||||
// clean. It is safe either way: the pointer addresses C memory the Go GC
|
||||
// neither tracks nor moves, and we dereference it immediately to copy the
|
||||
// bytes out.
|
||||
func goStringFromCPtr(cptr uintptr) string {
|
||||
if cptr == 0 {
|
||||
return ""
|
||||
}
|
||||
p := *(*unsafe.Pointer)(unsafe.Pointer(&cptr)) // C-owned buffer, not Go-GC memory (see doc above)
|
||||
n := 0
|
||||
for *(*byte)(unsafe.Add(p, n)) != 0 {
|
||||
n++
|
||||
}
|
||||
return string(unsafe.Slice((*byte)(p), n))
|
||||
}
|
||||
553
backend/go/dllm/dllm.go
Normal file
553
backend/go/dllm/dllm.go
Normal file
@@ -0,0 +1,553 @@
|
||||
package main
|
||||
|
||||
// LocalAI gRPC backend for dllm.cpp (DiffusionGemma block-diffusion models).
|
||||
//
|
||||
// Wiring overview:
|
||||
// - Load opens the GGUF via dllm_capi_load and starts the per-model worker
|
||||
// goroutine that serializes every C call (see submit).
|
||||
// - PredictRich / PredictStreamRich implement grpc.AIModelRich: when the
|
||||
// request carries raw messages (use_tokenizer_template), the backend owns
|
||||
// templating (RenderGemma4) and output parsing (Gemma4Parser) and replies
|
||||
// with ChatDeltas, like the llama.cpp autoparser and the ds4 backend.
|
||||
// - The legacy Predict / PredictStream methods delegate to the rich pair
|
||||
// (cloud-proxy precedent); the gRPC server prefers the rich path anyway.
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// The gRPC server cancels in-flight generations on client disconnect only
|
||||
// for backends advertising the Cancellable capability; keep Dllm pinned to
|
||||
// it so a signature drift fails the build, not the disconnect path.
|
||||
var _ grpc.Cancellable = (*Dllm)(nil)
|
||||
|
||||
// generator is the seam between the backend wiring and the dllm.cpp C-ABI:
|
||||
// the real implementation (capiGenerator) wraps the cGenerate/cTokenizeJSON
|
||||
// family, while tests substitute a fake to exercise prompt construction,
|
||||
// parsing and serialization without libdllm.so.
|
||||
type generator interface {
|
||||
generate(prompt, optsJSON string) (string, error)
|
||||
// generateStream invokes onBlock once per committed diffusion block, on
|
||||
// the thread running the C call, before returning.
|
||||
generateStream(prompt, optsJSON string, onBlock func(text string)) error
|
||||
tokenizeJSON(text string) (string, error)
|
||||
// cancel is the ONE entry point safe to call concurrently with an
|
||||
// in-flight generate on the same ctx (dllm_capi.h: it only flips an
|
||||
// atomic; everything else must be externally serialized per ctx).
|
||||
cancel()
|
||||
free()
|
||||
}
|
||||
|
||||
// capiGenerator is the production generator over one dllm_ctx handle.
|
||||
type capiGenerator struct {
|
||||
h uintptr
|
||||
}
|
||||
|
||||
func (g *capiGenerator) generate(prompt, optsJSON string) (string, error) {
|
||||
return cGenerate(g.h, prompt, optsJSON)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
|
||||
// on_step (per-denoise-step canvas preview, dllm.cpp's --visual) is
|
||||
// passed as nil for now: a future progress hook for the React UI can
|
||||
// plumb it through without touching the C binding.
|
||||
return cGenerateStream(g.h, prompt, optsJSON, onBlock, nil)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) tokenizeJSON(text string) (string, error) {
|
||||
return cTokenizeJSON(g.h, text)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) cancel() {
|
||||
cCancel(g.h)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) free() {
|
||||
cFree(g.h)
|
||||
}
|
||||
|
||||
// Dllm is the gRPC backend instance: one per loaded model (LocalAI starts
|
||||
// one backend process per model).
|
||||
type Dllm struct {
|
||||
base.Base
|
||||
|
||||
gen generator
|
||||
// genOpts holds the model-level generation overrides parsed from
|
||||
// ModelOptions.Options at Load (eb_*, blocks, kv_cache). The C-ABI takes
|
||||
// them per-generate, not per-load, so they are merged into every
|
||||
// request's opts JSON (requestOptsJSON).
|
||||
genOpts map[string]any
|
||||
|
||||
// jobs is the per-model worker queue. dllm_capi.h requires every entry
|
||||
// point EXCEPT dllm_capi_cancel to be externally serialized per ctx (one
|
||||
// ctx = one concurrent generate/tokenize; last_error is unsafe to read
|
||||
// while a call is in flight). A single goroutine owning all C calls makes
|
||||
// that contract structural instead of relying on lock discipline.
|
||||
jobs chan func()
|
||||
workerWG sync.WaitGroup
|
||||
|
||||
// genMu guards gen against Free racing in-flight requests: requests hold
|
||||
// the read lock for their full duration (they stay concurrent with each
|
||||
// other - the worker still serializes the C calls), Free takes the write
|
||||
// lock so it can only run when no request is in flight.
|
||||
genMu sync.RWMutex
|
||||
}
|
||||
|
||||
func (d *Dllm) startWorker() {
|
||||
d.jobs = make(chan func())
|
||||
d.workerWG.Add(1)
|
||||
go func() {
|
||||
defer d.workerWG.Done()
|
||||
for job := range d.jobs {
|
||||
job()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// submit runs job on the worker goroutine and waits for it to finish.
|
||||
// Concurrent gRPC requests therefore queue up and execute one at a time
|
||||
// against the single dllm_ctx.
|
||||
func (d *Dllm) submit(job func()) {
|
||||
done := make(chan struct{})
|
||||
d.jobs <- func() {
|
||||
defer close(done)
|
||||
job()
|
||||
}
|
||||
<-done
|
||||
}
|
||||
|
||||
// Load opens the GGUF and prepares the worker. Load-time engine parameters
|
||||
// travel as the flat params JSON of dllm_capi_load; generation overrides
|
||||
// from Options are stored for per-request opts JSON instead (the C-ABI has
|
||||
// no per-load sampler state).
|
||||
func (d *Dllm) Load(opts *pb.ModelOptions) error {
|
||||
if d.gen != nil {
|
||||
return errors.New("dllm: model already loaded")
|
||||
}
|
||||
|
||||
params := map[string]any{
|
||||
"n_gpu_layers": opts.GetNGPULayers(),
|
||||
}
|
||||
if opts.GetThreads() > 0 {
|
||||
params["n_threads"] = opts.GetThreads()
|
||||
}
|
||||
if opts.GetContextSize() > 0 {
|
||||
params["ctx_len"] = opts.GetContextSize()
|
||||
}
|
||||
paramsJSON, err := buildOptsJSON(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.genOpts = parseModelGenOpts(opts.GetOptions())
|
||||
|
||||
h := cLoad(opts.GetModelFile(), paramsJSON)
|
||||
if h == 0 {
|
||||
// No ctx exists on load failure, so last_error(NULL) only carries the
|
||||
// static NULL-ctx message; the real reason is on the backend's stderr.
|
||||
return fmt.Errorf("dllm: load %q failed: %s (see backend log for details)",
|
||||
opts.GetModelFile(), lastErrorOr(0, "unknown error"))
|
||||
}
|
||||
d.gen = &capiGenerator{h: h}
|
||||
d.startWorker()
|
||||
xlog.Info("dllm: model loaded", "model", opts.GetModelFile(), "params", paramsJSON, "gen_opts", d.genOpts)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Free releases the dllm ctx and stops the worker. Safe when never loaded.
|
||||
//
|
||||
// The write lock is essential: the gRPC server (pkg/grpc/server.go, see the
|
||||
// model-unload path around line 764) calls Free with no locking of its own,
|
||||
// and base.Base provides none either. Without it a request racing Free would
|
||||
// panic sending on the closed jobs channel - or worse, generate on a freed C
|
||||
// ctx. Holding genMu until gen is nil also turns post-Free requests into a
|
||||
// clean "model not loaded" error instead of a crash.
|
||||
func (d *Dllm) Free() error {
|
||||
d.genMu.Lock()
|
||||
defer d.genMu.Unlock()
|
||||
if d.gen == nil {
|
||||
return nil
|
||||
}
|
||||
d.submit(d.gen.free)
|
||||
close(d.jobs)
|
||||
d.workerWG.Wait()
|
||||
d.gen = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel requests cancellation of the in-flight generate (the
|
||||
// grpc.Cancellable capability). The gRPC server arms it via
|
||||
// context.AfterFunc on the request/stream context, so a client
|
||||
// disconnect or timeout aborts the generation server-side - the same
|
||||
// semantics the llama.cpp C++ backend gets from polling IsCancelled().
|
||||
// It deliberately bypasses the worker queue: dllm_capi_cancel is the one
|
||||
// call the C-ABI allows from any goroutine mid-generate (it only flips
|
||||
// an atomic).
|
||||
//
|
||||
// Note dllm_capi.h's cancel-reset race: each generate resets the flag on
|
||||
// entry, so a Cancel racing a NEW generate on the same ctx can be lost
|
||||
// (and, with requests queued on the worker, it aborts whichever generate
|
||||
// is currently running). The single-flag granularity is acceptable here
|
||||
// because the server de-registers the hook on normal completion and one
|
||||
// backend process serves one model.
|
||||
func (d *Dllm) Cancel() {
|
||||
// RLock so a server-side AfterFunc firing in the window between a
|
||||
// request finishing and a model unload cannot touch a freed C ctx
|
||||
// (Free holds the write lock while tearing gen down). cancel() is the
|
||||
// one C call that is safe concurrently with an in-flight generate, so
|
||||
// taking a read lock here cannot deadlock against request holders.
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen != nil {
|
||||
d.gen.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// dllmGenOptKeys are the ModelOptions.Options keys this backend forwards to
|
||||
// the engine. Options is a shared free-form bag (other layers put their own
|
||||
// entries there), so unknown keys are skipped with a warning, not an error.
|
||||
var dllmGenOptKeys = map[string]bool{
|
||||
"blocks": true,
|
||||
"kv_cache": true, // "auto"|"on"|"off"; honored by the engine from P3
|
||||
}
|
||||
|
||||
// parseModelGenOpts parses "key:value" Options entries into the flat scalar
|
||||
// map merged into every generate's opts JSON. eb_* (Entropy-Bound sampler
|
||||
// knobs) and the keys in dllmGenOptKeys are recognized; values are typed by
|
||||
// first successful parse (int, then float, else string) to match the C
|
||||
// scanner's number/string scalars.
|
||||
func parseModelGenOpts(options []string) map[string]any {
|
||||
out := map[string]any{}
|
||||
for _, o := range options {
|
||||
key, val, found := strings.Cut(o, ":")
|
||||
if !found {
|
||||
xlog.Warn("dllm: ignoring malformed option (want key:value)", "option", o)
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(key, "eb_") && !dllmGenOptKeys[key] {
|
||||
xlog.Debug("dllm: ignoring unrecognized option", "key", key)
|
||||
continue
|
||||
}
|
||||
out[key] = parseScalarOpt(val)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseScalarOpt(v string) any {
|
||||
if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return iv
|
||||
}
|
||||
if fv, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return fv
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// metadataEnableThinking reads the enable_thinking gate. Unlike ds4 (default
|
||||
// ON, matching ds4-server), dllm defaults OFF: DiffusionGemma's chat
|
||||
// template guards every thinking branch with `enable_thinking is defined and
|
||||
// enable_thinking`, i.e. thinking is opt-in for this model family, and the
|
||||
// no-thinking render pre-closes an empty thought channel that the OFF
|
||||
// default must produce.
|
||||
func metadataEnableThinking(opts *pb.PredictOptions) bool {
|
||||
v := opts.GetMetadata()["enable_thinking"]
|
||||
return v == "true" || v == "1"
|
||||
}
|
||||
|
||||
// buildPrompt resolves the prompt for a request. With use_tokenizer_template
|
||||
// and raw messages the backend owns templating (RenderGemma4) and the output
|
||||
// is in the known gemma4 format, so parse=true. Without it the caller
|
||||
// templated the prompt themselves (LocalAI's Go templates + PEG fallback, or
|
||||
// a bare completion): the prompt passes through verbatim and the output is
|
||||
// NOT gemma4-parsed - it is emitted as plain content and the Go side's
|
||||
// extraction applies, as for any non-autoparsing backend.
|
||||
func buildPrompt(opts *pb.PredictOptions) (prompt string, parse bool, err error) {
|
||||
if opts.GetUseTokenizerTemplate() && len(opts.GetMessages()) > 0 {
|
||||
prompt, err = RenderGemma4(opts.GetMessages(), opts.GetTools(), metadataEnableThinking(opts), true)
|
||||
return prompt, true, err
|
||||
}
|
||||
return opts.GetPrompt(), false, nil
|
||||
}
|
||||
|
||||
// requestOptsJSON merges the model-level overrides with the request's
|
||||
// sampling fields into the flat opts JSON for one generate call.
|
||||
func (d *Dllm) requestOptsJSON(opts *pb.PredictOptions) (string, error) {
|
||||
m := make(map[string]any, len(d.genOpts)+2)
|
||||
for k, v := range d.genOpts {
|
||||
m[k] = v
|
||||
}
|
||||
if n := opts.GetTokens(); n > 0 {
|
||||
// The engine rounds n_predict UP to a whole number of diffusion
|
||||
// blocks (the canvas is denoised block-wise), so the completion may
|
||||
// run slightly past the requested budget. Tokens==0 omits the key so
|
||||
// the C-ABI default of 256 applies (hardcoded in capi.cpp's
|
||||
// parse_gen_opts, independent of canvas_length).
|
||||
m["n_predict"] = n
|
||||
}
|
||||
if s := opts.GetSeed(); s > 0 {
|
||||
// The engine seeds mt19937 with explicit non-negative seeds. Seed<=0
|
||||
// is omitted: proto3 cannot distinguish 0 from unset, and negative
|
||||
// values conventionally mean "random" across LocalAI backends.
|
||||
m["seed"] = s
|
||||
}
|
||||
return buildOptsJSON(m)
|
||||
}
|
||||
|
||||
// prepareRequest is the shared prologue of the rich methods: resolve the
|
||||
// prompt (and whether the output gets gemma4-parsed) and build the per-call
|
||||
// opts JSON.
|
||||
func (d *Dllm) prepareRequest(opts *pb.PredictOptions) (prompt string, parse bool, optsJSON string, err error) {
|
||||
prompt, parse, err = buildPrompt(opts)
|
||||
if err != nil {
|
||||
return "", false, "", err
|
||||
}
|
||||
optsJSON, err = d.requestOptsJSON(opts)
|
||||
if err != nil {
|
||||
return "", false, "", err
|
||||
}
|
||||
return prompt, parse, optsJSON, nil
|
||||
}
|
||||
|
||||
// sanitizeUTF8 makes s safe for a proto3 string field. Block-boundary
|
||||
// detokenization and byte-fallback tokens can produce invalid UTF-8, and
|
||||
// grpc-go refuses to marshal it ("string field contains invalid UTF-8"), so
|
||||
// every string destined for a Reply/ChatDelta must pass through here (or
|
||||
// through splitValidUTF8, which calls it). Lone malformed bytes are genuinely
|
||||
// undecodable: replace with U+FFFD rather than crash the stream.
|
||||
func sanitizeUTF8(s string) string {
|
||||
if utf8.ValidString(s) {
|
||||
return s
|
||||
}
|
||||
return strings.ToValidUTF8(s, "<22>")
|
||||
}
|
||||
|
||||
// utf8SeqLen returns the declared sequence length of a UTF-8 leading byte
|
||||
// (1 for bytes that can never lead a multi-byte sequence, so they are never
|
||||
// held back and fall through to sanitizeUTF8's replacement).
|
||||
func utf8SeqLen(b byte) int {
|
||||
switch {
|
||||
case b&0xE0 == 0xC0:
|
||||
return 2
|
||||
case b&0xF0 == 0xE0:
|
||||
return 3
|
||||
case b&0xF8 == 0xF0:
|
||||
return 4
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
// splitValidUTF8 prepends the previous block's carry to the new block and
|
||||
// splits the result into text safe to emit now and a trailing INCOMPLETE
|
||||
// UTF-8 sequence (at most utf8.UTFMax-1 bytes) to carry into the next block:
|
||||
// the per-block detokenize can split a multi-byte character across block
|
||||
// boundaries (llama.cpp's grpc-server holds back the same way). Only a
|
||||
// suffix that can still become a valid rune is withheld; bytes that are
|
||||
// already undecodable are replaced immediately so the carry stays bounded.
|
||||
func splitValidUTF8(carry, block string) (emit, newCarry string) {
|
||||
s := carry + block
|
||||
cut := len(s)
|
||||
for i := len(s) - 1; i >= 0 && len(s)-i < utf8.UTFMax; i-- {
|
||||
b := s[i]
|
||||
if b < utf8.RuneSelf {
|
||||
break // ASCII: everything before the tail scan is complete
|
||||
}
|
||||
if !utf8.RuneStart(b) {
|
||||
continue // continuation byte: keep looking for its leading byte
|
||||
}
|
||||
// Leading byte: hold the sequence back iff it declares more bytes
|
||||
// than the stream has produced so far (it may complete next block).
|
||||
if utf8SeqLen(b) > len(s)-i {
|
||||
cut = i
|
||||
}
|
||||
break
|
||||
}
|
||||
return sanitizeUTF8(s[:cut]), s[cut:]
|
||||
}
|
||||
|
||||
// PredictRich is the non-streaming inference path (grpc.AIModelRich).
|
||||
// Returns one Reply whose Message is the aggregated assistant content and
|
||||
// whose ChatDeltas carry the parsed content/reasoning/tool-call events.
|
||||
func (d *Dllm) PredictRich(opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen == nil {
|
||||
return nil, grpcerrors.ModelNotLoaded("dllm")
|
||||
}
|
||||
prompt, parse, optsJSON, err := d.prepareRequest(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out string
|
||||
var genErr error
|
||||
d.submit(func() {
|
||||
out, genErr = d.gen.generate(prompt, optsJSON)
|
||||
})
|
||||
if genErr != nil {
|
||||
return nil, genErr
|
||||
}
|
||||
// Byte-fallback tokens can detokenize to invalid UTF-8; proto3 strings
|
||||
// must be valid or grpc-go fails the whole reply at marshal time.
|
||||
out = sanitizeUTF8(out)
|
||||
|
||||
if !parse {
|
||||
// Raw-prompt mode: plain content, no gemma4 parsing (see buildPrompt).
|
||||
return &pb.Reply{Message: []byte(out), ChatDeltas: []*pb.ChatDelta{{Content: out}}}, nil
|
||||
}
|
||||
|
||||
// The prompt renders with add_generation_prompt; both thinking modes
|
||||
// leave the model starting in content state (see the Gemma4Parser header
|
||||
// comment), hence NewGemma4Parser(false).
|
||||
parser := NewGemma4Parser(false)
|
||||
if reply := replyFromDeltas(append(parser.Feed(out), parser.Close()...)); reply != nil {
|
||||
return reply, nil
|
||||
}
|
||||
// Everything was markers (or out was empty): an empty but non-nil Reply.
|
||||
return &pb.Reply{}, nil
|
||||
}
|
||||
|
||||
// PredictStreamRich is the streaming counterpart (grpc.AIModelRich): one
|
||||
// Reply per committed diffusion block that produced deltas. Per the
|
||||
// interface contract the channel is only sent into here - the gRPC server
|
||||
// closes it after this returns (opposite to legacy PredictStream).
|
||||
func (d *Dllm) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen == nil {
|
||||
return grpcerrors.ModelNotLoaded("dllm")
|
||||
}
|
||||
prompt, parse, optsJSON, err := d.prepareRequest(opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var parser *Gemma4Parser
|
||||
if parse {
|
||||
parser = NewGemma4Parser(false)
|
||||
}
|
||||
// emit runs inside onBlock, i.e. on the thread driving the C generate.
|
||||
// Sending on results can block on a slow consumer, but the server-side
|
||||
// pump (pkg/grpc/server.go PredictStream) drains continuously and drops
|
||||
// undeliverable sends, so this backpressure is brief and bounded - and
|
||||
// pausing the diffusion loop under it is the desired behavior anyway.
|
||||
emit := func(text string) {
|
||||
if !parse {
|
||||
if text != "" {
|
||||
results <- &pb.Reply{Message: []byte(text), ChatDeltas: []*pb.ChatDelta{{Content: text}}}
|
||||
}
|
||||
return
|
||||
}
|
||||
deltas := parser.Feed(text)
|
||||
if reply := replyFromDeltas(deltas); reply != nil {
|
||||
results <- reply
|
||||
}
|
||||
}
|
||||
// onBlock guards emit (and through it the parser) against invalid UTF-8:
|
||||
// a multi-byte character split across block boundaries is held back until
|
||||
// it completes (see splitValidUTF8), so proto3 marshaling never fails.
|
||||
var carry string
|
||||
onBlock := func(block string) {
|
||||
var text string
|
||||
text, carry = splitValidUTF8(carry, block)
|
||||
emit(text)
|
||||
}
|
||||
|
||||
var genErr error
|
||||
d.submit(func() {
|
||||
genErr = d.gen.generateStream(prompt, optsJSON, onBlock)
|
||||
})
|
||||
if genErr != nil {
|
||||
return genErr
|
||||
}
|
||||
if carry != "" {
|
||||
// The stream ended mid-sequence: the held-back bytes can no longer
|
||||
// complete, so flush them through the U+FFFD last resort.
|
||||
emit(sanitizeUTF8(carry))
|
||||
}
|
||||
if parse {
|
||||
if reply := replyFromDeltas(parser.Close()); reply != nil {
|
||||
results <- reply
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// replyFromDeltas wraps one batch of parsed deltas into a streaming Reply,
|
||||
// or nil when the batch is empty (markers consumed, nothing emitted yet).
|
||||
// Message mirrors the batch's content text so legacy chan-string consumers
|
||||
// see exactly the displayed tokens.
|
||||
func replyFromDeltas(deltas []*pb.ChatDelta) *pb.Reply {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
var content strings.Builder
|
||||
for _, delta := range deltas {
|
||||
content.WriteString(delta.GetContent())
|
||||
}
|
||||
return &pb.Reply{Message: []byte(content.String()), ChatDeltas: deltas}
|
||||
}
|
||||
|
||||
// Predict is the legacy (string, error) signature; the gRPC server prefers
|
||||
// PredictRich, this exists for non-rich callers (cloud-proxy precedent).
|
||||
func (d *Dllm) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
reply, err := d.PredictRich(opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(reply.GetMessage()), nil
|
||||
}
|
||||
|
||||
// PredictStream is the legacy chan-string path: rich replies reduced to
|
||||
// their content text. Note the inverted channel ownership - the LEGACY
|
||||
// contract requires the impl to close the channel.
|
||||
func (d *Dllm) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
defer close(results)
|
||||
richCh := make(chan *pb.Reply)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- d.PredictStreamRich(opts, richCh)
|
||||
close(richCh)
|
||||
}()
|
||||
for reply := range richCh {
|
||||
if msg := reply.GetMessage(); len(msg) > 0 {
|
||||
results <- string(msg)
|
||||
}
|
||||
}
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
// TokenizeString tokenizes opts.Prompt via dllm_capi_tokenize_json (the C
|
||||
// side prepends bos per the vocab) and decodes the returned id array.
|
||||
func (d *Dllm) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen == nil {
|
||||
return pb.TokenizationResponse{}, grpcerrors.ModelNotLoaded("dllm")
|
||||
}
|
||||
var out string
|
||||
var tokErr error
|
||||
d.submit(func() {
|
||||
out, tokErr = d.gen.tokenizeJSON(opts.GetPrompt())
|
||||
})
|
||||
if tokErr != nil {
|
||||
return pb.TokenizationResponse{}, tokErr
|
||||
}
|
||||
var tokens []int32
|
||||
if err := json.Unmarshal([]byte(out), &tokens); err != nil {
|
||||
return pb.TokenizationResponse{}, fmt.Errorf("dllm: decode tokenize result %q: %w", out, err)
|
||||
}
|
||||
return pb.TokenizationResponse{Length: int32(len(tokens)), Tokens: tokens}, nil
|
||||
}
|
||||
807
backend/go/dllm/dllm_test.go
Normal file
807
backend/go/dllm/dllm_test.go
Normal file
@@ -0,0 +1,807 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
"unsafe"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
func TestDllm(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "dllm Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
|
||||
// C-ABI bridge without spinning up the gRPC server. The library path comes
|
||||
// from DLLM_TEST_LIBRARY (gated specs Skip when it is unset).
|
||||
func ensureLibLoaded() {
|
||||
libLoadOnce.Do(func() {
|
||||
libLoadErr = loadCAPI(os.Getenv("DLLM_TEST_LIBRARY"))
|
||||
})
|
||||
}
|
||||
|
||||
// C-ABI binding smoke: drives the real libdllm.so against the tiny GGUF
|
||||
// fixture from dllm.cpp (tests/fixtures/tiny_with_vocab.gguf). Gated on:
|
||||
//
|
||||
// DLLM_TEST_LIBRARY absolute path to libdllm.so
|
||||
// DLLM_TEST_TINY_MODEL absolute path to tiny_with_vocab.gguf
|
||||
var _ = Describe("C-ABI binding", func() {
|
||||
BeforeEach(func() {
|
||||
if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" {
|
||||
Skip("set DLLM_TEST_LIBRARY and DLLM_TEST_TINY_MODEL to run the C-ABI binding smoke")
|
||||
}
|
||||
ensureLibLoaded()
|
||||
Expect(libLoadErr).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("binds the 9 symbols and round-trips the tiny model", func() {
|
||||
Expect(cAbiVersion()).To(Equal(int32(1)))
|
||||
|
||||
h := cLoad(os.Getenv("DLLM_TEST_TINY_MODEL"), "{}")
|
||||
Expect(h).ToNot(BeZero(), "dllm_capi_load of the tiny fixture")
|
||||
|
||||
// Tiny fixture vocab: "hello" tokenizes to ids [2,18] (bos prepended
|
||||
// by the C side: vocab.add_bos).
|
||||
toks, err := cTokenizeJSON(h, "hello")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(toks).To(Equal("[2,18]"))
|
||||
|
||||
// Deterministic generation: an explicit non-negative seed seeds
|
||||
// mt19937, so two identical calls must produce identical text.
|
||||
out1, err := cGenerate(h, "hello", `{"n_predict":16,"seed":7}`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out1).ToNot(BeEmpty())
|
||||
// Cancel with no call in flight is dropped: each generate resets the
|
||||
// cancel flag on entry (header contract), so this must not affect
|
||||
// the next call. Also binds the 9th symbol; safe on NULL too.
|
||||
cCancel(h)
|
||||
cCancel(0)
|
||||
|
||||
out2, err := cGenerate(h, "hello", `{"n_predict":16,"seed":7}`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out2).To(Equal(out1))
|
||||
|
||||
// Streaming variant: same opts, blocks arrive via the purego
|
||||
// callback trampoline. The per-block detokenize can differ from the
|
||||
// seamless full-text decode at block boundaries, so only assert that
|
||||
// blocks arrived and were non-trivial, not byte equality with out1.
|
||||
var blocks []string
|
||||
var steps int
|
||||
err = cGenerateStream(h, "hello", `{"n_predict":16,"seed":7}`,
|
||||
func(text string) { blocks = append(blocks, text) },
|
||||
func(step, total int, preview string) { steps++ },
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(blocks).ToNot(BeEmpty())
|
||||
Expect(steps).To(BeNumerically(">", 0))
|
||||
|
||||
// Load failure path: NULL ctx back, and last_error(NULL) returns the
|
||||
// static NULL-ctx message (there is no ctx to carry the real reason).
|
||||
bad := cLoad("/nonexistent/dllm-model.gguf", "{}")
|
||||
Expect(bad).To(BeZero())
|
||||
Expect(cLastError(0)).ToNot(BeEmpty())
|
||||
|
||||
// Free is safe on a live handle and a NULL one (delete nullptr).
|
||||
cFree(h)
|
||||
cFree(0)
|
||||
})
|
||||
})
|
||||
|
||||
// Ungated specs for the pure-Go helpers (no libdllm.so required).
|
||||
var _ = Describe("buildOptsJSON", func() {
|
||||
It("renders flat scalars as a JSON object", func() {
|
||||
out, err := buildOptsJSON(map[string]any{
|
||||
"n_predict": 16,
|
||||
"seed": int64(7),
|
||||
"eb_t_min": 0.5,
|
||||
"kv_cache": "auto",
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out).To(MatchJSON(`{"n_predict":16,"seed":7,"eb_t_min":0.5,"kv_cache":"auto"}`))
|
||||
})
|
||||
|
||||
It("renders an empty object for no options", func() {
|
||||
out, err := buildOptsJSON(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out).To(Equal("{}"))
|
||||
})
|
||||
|
||||
It("rejects nested objects (the C-side scanner only reads flat scalars)", func() {
|
||||
_, err := buildOptsJSON(map[string]any{"sampler": map[string]any{"seed": 1}})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects arrays", func() {
|
||||
_, err := buildOptsJSON(map[string]any{"stop": []string{"a"}})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects booleans (the C-side scanner only understands numbers and strings)", func() {
|
||||
_, err := buildOptsJSON(map[string]any{"flag": true})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("splitValidUTF8", func() {
|
||||
It("holds back a trailing incomplete sequence and completes it next block", func() {
|
||||
emit, carry := splitValidUTF8("", "caf\xe2")
|
||||
Expect(emit).To(Equal("caf"))
|
||||
Expect(carry).To(Equal("\xe2"))
|
||||
|
||||
emit, carry = splitValidUTF8(carry, "\x82")
|
||||
Expect(emit).To(BeEmpty())
|
||||
Expect(carry).To(Equal("\xe2\x82"))
|
||||
|
||||
emit, carry = splitValidUTF8(carry, "\xac!")
|
||||
Expect(emit).To(Equal("€!"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("holds back up to 3 bytes of a 4-byte sequence", func() {
|
||||
emit, carry := splitValidUTF8("", "x\xf0\x9f\x98") // 😀 missing its last byte
|
||||
Expect(emit).To(Equal("x"))
|
||||
Expect(carry).To(Equal("\xf0\x9f\x98"))
|
||||
|
||||
emit, carry = splitValidUTF8(carry, "\x80")
|
||||
Expect(emit).To(Equal("😀"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("replaces undecodable bytes immediately instead of carrying them", func() {
|
||||
// A mid-string invalid byte can never complete: carrying it would let
|
||||
// the carry grow unboundedly, so it is substituted on the spot.
|
||||
emit, carry := splitValidUTF8("", "a\xe2bc")
|
||||
Expect(emit).To(Equal("a<>bc"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
|
||||
// Orphan continuation bytes at the end have no leading byte to wait
|
||||
// for either.
|
||||
emit, carry = splitValidUTF8("", "a\x82")
|
||||
Expect(emit).To(Equal("a<>"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("passes pure ASCII and complete UTF-8 through untouched", func() {
|
||||
emit, carry := splitValidUTF8("", "héllo €")
|
||||
Expect(emit).To(Equal("héllo €"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("goStringFromCPtr", func() {
|
||||
It("copies a NUL-terminated buffer", func() {
|
||||
buf := []byte("dllm\x00")
|
||||
s := goStringFromCPtr(uintptr(unsafe.Pointer(&buf[0])))
|
||||
// The uintptr round-trip hides buf from the GC's liveness analysis;
|
||||
// keep it reachable until after the copy.
|
||||
runtime.KeepAlive(buf)
|
||||
Expect(s).To(Equal("dllm"))
|
||||
})
|
||||
|
||||
It("returns the empty string for NULL", func() {
|
||||
Expect(goStringFromCPtr(0)).To(Equal(""))
|
||||
})
|
||||
})
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Backend wiring (T4): fake-generator specs, no libdllm.so required.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type fakeGenCall struct {
|
||||
prompt string
|
||||
optsJSON string
|
||||
}
|
||||
|
||||
// fakeGen implements generator in-process. It records every call (prompt +
|
||||
// opts JSON), tracks concurrent in-flight calls to prove worker
|
||||
// serialization, and replays canned output (out for generate/tokenize,
|
||||
// blocks for generateStream).
|
||||
type fakeGen struct {
|
||||
mu sync.Mutex
|
||||
calls []fakeGenCall
|
||||
inFlight int
|
||||
maxInFlight int
|
||||
|
||||
out string
|
||||
blocks []string
|
||||
err error
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
func (f *fakeGen) begin(prompt, optsJSON string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.calls = append(f.calls, fakeGenCall{prompt: prompt, optsJSON: optsJSON})
|
||||
f.inFlight++
|
||||
if f.inFlight > f.maxInFlight {
|
||||
f.maxInFlight = f.inFlight
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeGen) end() {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.inFlight--
|
||||
}
|
||||
|
||||
func (f *fakeGen) snapshot() (calls []fakeGenCall, maxInFlight int) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return append([]fakeGenCall(nil), f.calls...), f.maxInFlight
|
||||
}
|
||||
|
||||
func (f *fakeGen) generate(prompt, optsJSON string) (string, error) {
|
||||
f.begin(prompt, optsJSON)
|
||||
defer f.end()
|
||||
if f.delay > 0 {
|
||||
time.Sleep(f.delay)
|
||||
}
|
||||
return f.out, f.err
|
||||
}
|
||||
|
||||
func (f *fakeGen) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
|
||||
f.begin(prompt, optsJSON)
|
||||
defer f.end()
|
||||
if f.err != nil {
|
||||
return f.err
|
||||
}
|
||||
for _, b := range f.blocks {
|
||||
onBlock(b)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeGen) tokenizeJSON(text string) (string, error) {
|
||||
f.begin(text, "")
|
||||
defer f.end()
|
||||
return f.out, f.err
|
||||
}
|
||||
|
||||
func (f *fakeGen) cancel() {}
|
||||
func (f *fakeGen) free() {}
|
||||
|
||||
// newTestDllm assembles a backend around a fake generator (bypassing Load,
|
||||
// which needs libdllm.so) and registers cleanup of the worker goroutine.
|
||||
func newTestDllm(g generator, genOpts map[string]any) *Dllm {
|
||||
d := &Dllm{gen: g, genOpts: genOpts}
|
||||
d.startWorker()
|
||||
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
|
||||
return d
|
||||
}
|
||||
|
||||
// drainReplies empties ch without blocking, failing the spec if the channel
|
||||
// was closed (PredictStreamRich must NOT close it - interface.go contract).
|
||||
// Size ch above the expected reply count: an overflow deadlocks the spec on
|
||||
// the producer's send instead of failing it.
|
||||
func drainReplies(ch chan *pb.Reply) []*pb.Reply {
|
||||
var out []*pb.Reply
|
||||
for {
|
||||
select {
|
||||
case r, ok := <-ch:
|
||||
if !ok {
|
||||
Fail("PredictStreamRich closed the results channel (the gRPC server owns the close)")
|
||||
}
|
||||
expectValidUTF8Reply(r)
|
||||
out = append(out, r)
|
||||
default:
|
||||
return out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// expectValidUTF8Reply is the blanket guard for the proto3 marshaling
|
||||
// contract: grpc-go rejects any string field carrying invalid UTF-8, so every
|
||||
// reply field that ends up in a proto string must validate.
|
||||
func expectValidUTF8Reply(r *pb.Reply) {
|
||||
GinkgoHelper()
|
||||
Expect(utf8.ValidString(string(r.GetMessage()))).To(BeTrue(), "Reply.Message carries invalid UTF-8")
|
||||
for _, delta := range r.GetChatDeltas() {
|
||||
Expect(utf8.ValidString(delta.GetContent())).To(BeTrue(), "ChatDelta.Content carries invalid UTF-8")
|
||||
Expect(utf8.ValidString(delta.GetReasoningContent())).To(BeTrue(), "ChatDelta.ReasoningContent carries invalid UTF-8")
|
||||
for _, tc := range delta.GetToolCalls() {
|
||||
Expect(utf8.ValidString(tc.GetName())).To(BeTrue(), "ToolCallDelta.Name carries invalid UTF-8")
|
||||
Expect(utf8.ValidString(tc.GetArguments())).To(BeTrue(), "ToolCallDelta.Arguments carries invalid UTF-8")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("Dllm backend wiring", func() {
|
||||
Describe("PredictRich", func() {
|
||||
It("renders gemma4 from raw messages and parses the output when use_tokenizer_template is set", func() {
|
||||
fake := &fakeGen{out: "<|channel>thought\npondering<channel|>The answer.<turn|>"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
reply, err := d.PredictRich(&pb.PredictOptions{
|
||||
UseTokenizerTemplate: true,
|
||||
Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}},
|
||||
Metadata: map[string]string{"enable_thinking": "true"},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
calls, _ := fake.snapshot()
|
||||
Expect(calls).To(HaveLen(1))
|
||||
// The enable_thinking=true render from the transformers fixture.
|
||||
Expect(calls[0].prompt).To(Equal(
|
||||
"<|turn>system\n<|think|>\n<turn|>\n<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n"))
|
||||
|
||||
Expect(string(reply.GetMessage())).To(Equal("The answer."))
|
||||
Expect(reply.GetChatDeltas()).To(HaveLen(2))
|
||||
Expect(reply.GetChatDeltas()[0].GetReasoningContent()).To(Equal("pondering"))
|
||||
Expect(reply.GetChatDeltas()[1].GetContent()).To(Equal("The answer."))
|
||||
})
|
||||
|
||||
It("defaults enable_thinking OFF (the gemma4 template treats thinking as opt-in)", func() {
|
||||
fake := &fakeGen{out: "hi"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
_, err := d.PredictRich(&pb.PredictOptions{
|
||||
UseTokenizerTemplate: true,
|
||||
Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
calls, _ := fake.snapshot()
|
||||
// No-thinking render: the template pre-opens AND pre-closes an
|
||||
// empty thought channel in the generation prompt.
|
||||
Expect(calls[0].prompt).To(Equal(
|
||||
"<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"))
|
||||
})
|
||||
|
||||
It("passes the raw prompt verbatim and skips gemma4 parsing without use_tokenizer_template", func() {
|
||||
// Marker-looking text must survive untouched: in raw-prompt mode
|
||||
// the caller templates themselves and the Go-side extraction
|
||||
// applies, so the backend must not interpret the output.
|
||||
fake := &fakeGen{out: "<|channel>thought\nnot parsed<channel|>tail"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "my raw prompt"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
calls, _ := fake.snapshot()
|
||||
Expect(calls[0].prompt).To(Equal("my raw prompt"))
|
||||
Expect(string(reply.GetMessage())).To(Equal(fake.out))
|
||||
Expect(reply.GetChatDeltas()).To(HaveLen(1))
|
||||
Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal(fake.out))
|
||||
})
|
||||
|
||||
It("sanitizes invalid UTF-8 in the non-streaming output", func() {
|
||||
// Byte-fallback tokens can decode to lone malformed bytes; the
|
||||
// whole-output sanitize must replace them so proto3 marshaling of
|
||||
// Message/ChatDeltas cannot fail.
|
||||
fake := &fakeGen{out: "a\xe2b"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectValidUTF8Reply(reply)
|
||||
Expect(string(reply.GetMessage())).To(Equal("a<>b"))
|
||||
Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal("a<>b"))
|
||||
})
|
||||
|
||||
It("maps Tokens and Seed into the opts JSON on top of the model-level overrides", func() {
|
||||
fake := &fakeGen{out: "x"}
|
||||
d := newTestDllm(fake, map[string]any{"eb_t_min": 0.5, "kv_cache": "auto"})
|
||||
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p", Tokens: 32, Seed: 7})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
calls, _ := fake.snapshot()
|
||||
Expect(calls[0].optsJSON).To(MatchJSON(`{"n_predict":32,"seed":7,"eb_t_min":0.5,"kv_cache":"auto"}`))
|
||||
})
|
||||
|
||||
It("omits n_predict and seed when unset so the engine defaults apply", func() {
|
||||
fake := &fakeGen{out: "x"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
calls, _ := fake.snapshot()
|
||||
Expect(calls[0].optsJSON).To(MatchJSON(`{}`))
|
||||
})
|
||||
|
||||
It("surfaces generator errors", func() {
|
||||
fake := &fakeGen{err: errors.New("boom")}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).To(MatchError("boom"))
|
||||
})
|
||||
|
||||
It("errors before generating when no model is loaded", func() {
|
||||
d := &Dllm{} // no Load, no worker: must fail fast, not hang
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("makes a concurrent Free wait for the in-flight request (both finish cleanly)", func() {
|
||||
// server.go's Free has no locking of its own: the backend's genMu
|
||||
// must hold Free back until the racing generate drains, instead of
|
||||
// closing the jobs channel (panic) or freeing the C ctx under it.
|
||||
fake := &fakeGen{out: "x", delay: 50 * time.Millisecond}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
predictDone := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
predictDone <- err
|
||||
}()
|
||||
// Wait until the fake generate is actually in flight (the read
|
||||
// lock is held from before submit until PredictRich returns).
|
||||
Eventually(func() int {
|
||||
_, maxInFlight := fake.snapshot()
|
||||
return maxInFlight
|
||||
}).Should(Equal(1))
|
||||
|
||||
Expect(d.Free()).To(Succeed())
|
||||
// Free's write lock means the request finished before Free did.
|
||||
var predictErr error
|
||||
Eventually(predictDone).Should(Receive(&predictErr))
|
||||
Expect(predictErr).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("returns model-not-loaded for requests after Free", func() {
|
||||
fake := &fakeGen{out: "x"}
|
||||
d := newTestDllm(fake, nil)
|
||||
Expect(d.Free()).To(Succeed())
|
||||
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
|
||||
})
|
||||
|
||||
It("serializes concurrent requests through the worker goroutine", func() {
|
||||
// dllm_capi.h: one ctx = one concurrent generate. Two overlapping
|
||||
// PredictRich calls must execute the C calls one at a time.
|
||||
fake := &fakeGen{out: "x", delay: 30 * time.Millisecond}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 2 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer GinkgoRecover()
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
calls, maxInFlight := fake.snapshot()
|
||||
Expect(calls).To(HaveLen(2))
|
||||
Expect(maxInFlight).To(Equal(1), "generate calls overlapped despite the worker queue")
|
||||
})
|
||||
})
|
||||
|
||||
Describe("PredictStreamRich", func() {
|
||||
It("emits one reply per delta-producing block and leaves the channel open", func() {
|
||||
// Blocks split mid-marker and mid-payload: the parser's holdback
|
||||
// must keep marker fragments out of the emitted deltas.
|
||||
fake := &fakeGen{blocks: []string{
|
||||
"<|channel>thou", // partial channel open: no deltas yet
|
||||
"ght\nponder", // header completes, reasoning starts
|
||||
"ing<channel|>Hi ", // reasoning ends, content starts
|
||||
"there<turn|>discarded", // turn ends: trailing text dropped
|
||||
}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{
|
||||
UseTokenizerTemplate: true,
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, ch)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
replies := drainReplies(ch)
|
||||
Expect(replies).To(HaveLen(3), "block 1 completes no delta and must not produce a reply")
|
||||
|
||||
var content, reasoning string
|
||||
for _, r := range replies {
|
||||
for _, delta := range r.GetChatDeltas() {
|
||||
content += delta.GetContent()
|
||||
reasoning += delta.GetReasoningContent()
|
||||
}
|
||||
}
|
||||
Expect(reasoning).To(Equal("pondering"))
|
||||
Expect(content).To(Equal("Hi there"))
|
||||
// Message mirrors each reply's content so legacy consumers see
|
||||
// exactly the displayed tokens.
|
||||
Expect(string(replies[1].GetMessage())).To(Equal("Hi "))
|
||||
Expect(string(replies[2].GetMessage())).To(Equal("there"))
|
||||
})
|
||||
|
||||
It("streams raw blocks verbatim without use_tokenizer_template", func() {
|
||||
fake := &fakeGen{blocks: []string{"abc", "", "<|channel>def"}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
replies := drainReplies(ch)
|
||||
Expect(replies).To(HaveLen(2), "empty blocks produce no reply")
|
||||
Expect(string(replies[0].GetMessage())).To(Equal("abc"))
|
||||
Expect(string(replies[1].GetMessage())).To(Equal("<|channel>def"))
|
||||
Expect(replies[1].GetChatDeltas()).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("flushes parser holdback after the stream ends", func() {
|
||||
// The unterminated partial marker "<chan" is held back during the
|
||||
// stream and must come out as content on the final flush.
|
||||
fake := &fakeGen{blocks: []string{"tail<chan"}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{
|
||||
UseTokenizerTemplate: true,
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, ch)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var content string
|
||||
for _, r := range drainReplies(ch) {
|
||||
content += string(r.GetMessage())
|
||||
}
|
||||
Expect(content).To(Equal("tail<chan"))
|
||||
})
|
||||
|
||||
It("reassembles a multi-byte character split across block boundaries", func() {
|
||||
// Per-block detokenize can split "€" (E2 82 AC) as E2 | 82 AC.
|
||||
// Emitting the lone E2 would make grpc-go fail the marshal of the
|
||||
// whole reply; the trailing incomplete sequence must be held back
|
||||
// and completed by the next block.
|
||||
fake := &fakeGen{blocks: []string{"caf\xe2", "\x82\xac ok"}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var content string
|
||||
for _, r := range drainReplies(ch) { // drain asserts ValidString per reply
|
||||
content += string(r.GetMessage())
|
||||
}
|
||||
Expect(content).To(Equal("caf€ ok"))
|
||||
})
|
||||
|
||||
It("reassembles a split multi-byte character in parsed (gemma4) mode too", func() {
|
||||
fake := &fakeGen{blocks: []string{"caf\xe2", "\x82\xac<turn|>"}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{
|
||||
UseTokenizerTemplate: true,
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, ch)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var content string
|
||||
for _, r := range drainReplies(ch) {
|
||||
for _, delta := range r.GetChatDeltas() {
|
||||
content += delta.GetContent()
|
||||
}
|
||||
}
|
||||
Expect(content).To(Equal("caf€"))
|
||||
})
|
||||
|
||||
It("replaces an incomplete sequence left at stream end with U+FFFD", func() {
|
||||
// A byte-fallback token can leave a lone leading byte (0xE2) that
|
||||
// no later block completes: the final flush must substitute it,
|
||||
// never emit it raw and never drop into a marshal error.
|
||||
fake := &fakeGen{blocks: []string{"ok\xe2"}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var content string
|
||||
for _, r := range drainReplies(ch) {
|
||||
content += string(r.GetMessage())
|
||||
}
|
||||
Expect(content).To(Equal("ok<6F>"))
|
||||
})
|
||||
|
||||
It("surfaces generator errors without sending replies", func() {
|
||||
fake := &fakeGen{err: errors.New("stream boom")}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch)
|
||||
Expect(err).To(MatchError("stream boom"))
|
||||
Expect(drainReplies(ch)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("errors before generating when no model is loaded", func() {
|
||||
d := &Dllm{} // no Load, no worker: must fail fast, not hang
|
||||
ch := make(chan *pb.Reply, 1)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch)
|
||||
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
|
||||
Expect(drainReplies(ch)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("legacy Predict/PredictStream adapters", func() {
|
||||
It("Predict returns the aggregated content string", func() {
|
||||
fake := &fakeGen{out: "plain text"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
out, err := d.Predict(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out).To(Equal("plain text"))
|
||||
})
|
||||
|
||||
It("PredictStream forwards content strings and closes the channel (legacy ownership)", func() {
|
||||
fake := &fakeGen{blocks: []string{"a", "b"}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan string, 16)
|
||||
Expect(d.PredictStream(&pb.PredictOptions{Prompt: "p"}, ch)).To(Succeed())
|
||||
|
||||
var got []string
|
||||
for s := range ch { // terminates only if the impl closed ch
|
||||
got = append(got, s)
|
||||
}
|
||||
Expect(got).To(Equal([]string{"a", "b"}))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("TokenizeString", func() {
|
||||
It("decodes the C-side JSON id array", func() {
|
||||
fake := &fakeGen{out: "[2,18]"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Length).To(Equal(int32(2)))
|
||||
Expect(resp.Tokens).To(Equal([]int32{2, 18}))
|
||||
|
||||
calls, _ := fake.snapshot()
|
||||
Expect(calls[0].prompt).To(Equal("hello"))
|
||||
})
|
||||
|
||||
It("fails loud on a malformed id array", func() {
|
||||
fake := &fakeGen{out: "not json"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
_, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors before tokenizing when no model is loaded", func() {
|
||||
d := &Dllm{} // no Load, no worker: must fail fast, not hang
|
||||
_, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
|
||||
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("parseModelGenOpts", func() {
|
||||
It("parses eb_*/blocks/kv_cache entries and types values by first successful parse", func() {
|
||||
got := parseModelGenOpts([]string{
|
||||
"eb_max_steps:16",
|
||||
"eb_t_min:0.25",
|
||||
"kv_cache:auto",
|
||||
"blocks:4",
|
||||
"unrelated_key:1", // other layers' options: skipped
|
||||
"malformed", // no colon: skipped
|
||||
})
|
||||
Expect(got).To(Equal(map[string]any{
|
||||
"eb_max_steps": int64(16),
|
||||
"eb_t_min": 0.25,
|
||||
"kv_cache": "auto",
|
||||
"blocks": int64(4),
|
||||
}))
|
||||
})
|
||||
|
||||
It("round-trips through buildOptsJSON (only flat scalars are produced)", func() {
|
||||
got := parseModelGenOpts([]string{"eb_entropy_bound:0.8", "kv_cache:off"})
|
||||
out, err := buildOptsJSON(got)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out).To(MatchJSON(`{"eb_entropy_bound":0.8,"kv_cache":"off"}`))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Gated backend round-trip against the real libdllm.so + tiny GGUF fixture.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var _ = Describe("Dllm backend (real tiny model)", func() {
|
||||
BeforeEach(func() {
|
||||
if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" {
|
||||
Skip("set DLLM_TEST_LIBRARY and DLLM_TEST_TINY_MODEL to run the backend round-trip")
|
||||
}
|
||||
ensureLibLoaded()
|
||||
Expect(libLoadErr).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("round-trips Load, PredictRich, PredictStreamRich and TokenizeString", func() {
|
||||
d := &Dllm{}
|
||||
Expect(d.Load(&pb.ModelOptions{ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL")})).To(Succeed())
|
||||
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
|
||||
|
||||
// TokenizeString: tiny fixture vocab tokenizes "hello" to [2,18].
|
||||
resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Tokens).To(Equal([]int32{2, 18}))
|
||||
Expect(resp.Length).To(Equal(int32(2)))
|
||||
|
||||
req := &pb.PredictOptions{
|
||||
UseTokenizerTemplate: true,
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hello"}},
|
||||
Tokens: 16,
|
||||
Seed: 7,
|
||||
}
|
||||
|
||||
// Non-streaming: the tiny random-weight model emits arbitrary vocab
|
||||
// words; with no gemma4 markers in them everything is content.
|
||||
reply, err := d.PredictRich(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(reply.GetMessage())).ToNot(BeEmpty())
|
||||
Expect(reply.GetChatDeltas()).ToNot(BeEmpty())
|
||||
|
||||
// Streaming: at least one reply, and the channel-ownership rule is
|
||||
// honored (drainReplies fails the spec on a closed channel).
|
||||
ch := make(chan *pb.Reply, 64)
|
||||
Expect(d.PredictStreamRich(req, ch)).To(Succeed())
|
||||
replies := drainReplies(ch)
|
||||
Expect(replies).ToNot(BeEmpty())
|
||||
var streamed string
|
||||
for _, r := range replies {
|
||||
streamed += string(r.GetMessage())
|
||||
}
|
||||
Expect(streamed).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("aborts an in-flight generation promptly on Cancel", func() {
|
||||
d := &Dllm{}
|
||||
// eb_max_steps inflates the per-block denoise loop so the full run
|
||||
// takes ~10s on the tiny fixture (vs ~40ms at engine defaults; 16
|
||||
// blocks, first block after ~0.7s) - long enough that a prompt
|
||||
// post-cancel return is distinguishable from the generation simply
|
||||
// finishing.
|
||||
Expect(d.Load(&pb.ModelOptions{
|
||||
ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL"),
|
||||
Options: []string{"eb_max_steps:256"},
|
||||
})).To(Succeed())
|
||||
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
|
||||
|
||||
ch := make(chan *pb.Reply, 64)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
errCh <- d.PredictStreamRich(&pb.PredictOptions{Prompt: "hello", Tokens: 256, Seed: 7}, ch)
|
||||
}()
|
||||
|
||||
// Cancel only once the first block proves the generate is in
|
||||
// flight: the C side resets the cancel flag on generate entry, so
|
||||
// an earlier Cancel would be swallowed (dllm_capi.h race note).
|
||||
Eventually(ch, "60s").Should(Receive())
|
||||
cancelAt := time.Now()
|
||||
d.Cancel()
|
||||
|
||||
// Uncancelled, ~10s of generation remain; the cancelled call must
|
||||
// come back in milliseconds (the flag is checked per denoise step).
|
||||
var genErr error
|
||||
Eventually(errCh, "5s").Should(Receive(&genErr))
|
||||
latency := time.Since(cancelAt)
|
||||
Expect(genErr).To(MatchError(ContainSubstring("cancelled")))
|
||||
GinkgoWriter.Printf("dllm cancel: PredictStreamRich returned %v after Cancel\n", latency)
|
||||
})
|
||||
})
|
||||
562
backend/go/dllm/gemma4_parser.go
Normal file
562
backend/go/dllm/gemma4_parser.go
Normal file
@@ -0,0 +1,562 @@
|
||||
// Gemma4 (DiffusionGemma) streaming output parser: raw model text, fed in
|
||||
// arbitrary fragments (per committed diffusion block; a fragment can split
|
||||
// anywhere, including mid-marker and mid-payload), is turned into
|
||||
// pb.ChatDelta events (content / reasoning_content / tool_calls).
|
||||
//
|
||||
// Normative sources:
|
||||
// - The chat template embedded at the top of gemma4_renderer.go ("tpl L<n>"
|
||||
// citations below refer to its numbered lines). The OUTPUT format mirrors
|
||||
// what the template renders for assistant history: thought channels
|
||||
// (<|channel>thought\n ... <channel|>, tpl L240), tool calls
|
||||
// (<|tool_call>call:name{...}<tool_call|>, tpl L246-L257) and turn ends
|
||||
// (<turn|>, tpl L351).
|
||||
// - vLLM PR #45163: vllm/tool_parsers/gemma4_tool_parser.py (marker
|
||||
// handling, the call:name{...} argument grammar and its decoder, ported
|
||||
// below) and vllm/reasoning/gemma4_reasoning_parser.py (channel markers,
|
||||
// the "thought\n" role label, is_reasoning_end semantics).
|
||||
//
|
||||
// Initial state (derived from the generation prompt, tpl L356-L362, see
|
||||
// RenderGemma4):
|
||||
// - enable_thinking=false: the prompt ends with "<|turn>model\n" +
|
||||
// "<|channel>thought\n<channel|>" - an EMPTY thought channel, pre-opened
|
||||
// AND pre-closed by the template. The model's output therefore starts in
|
||||
// plain content. Use NewGemma4Parser(false).
|
||||
// - enable_thinking=true: the prompt ends at "<|turn>model\n" and the model
|
||||
// opens and closes its own thought channel in the OUTPUT
|
||||
// ("<|channel>thought\n...reasoning...<channel|>final answer", per the
|
||||
// vLLM Gemma4ReasoningParser docstring). The parser still starts in
|
||||
// content state - the channel markers in the output drive the switch.
|
||||
// Use NewGemma4Parser(false) here too.
|
||||
// - NewGemma4Parser(true) is for callers that pre-open the thought channel
|
||||
// in the prompt themselves (appending "<|channel>thought\n" after the
|
||||
// generation prompt to force thinking): the output then begins mid-thought
|
||||
// and everything is reasoning until the first <channel|>.
|
||||
//
|
||||
// State diagram (markers are consumed, never emitted):
|
||||
//
|
||||
// <|channel> \n (channel name dropped: the
|
||||
// [content] --------------> [chan-header] ----> [thought] "thought\n" role
|
||||
// ^ | <channel|> (stray close: swallowed, label, stripped
|
||||
// +-+ strip_thinking semantics, tpl L148-L158) like vLLM does)
|
||||
// ^ <channel|>
|
||||
// +----------------------------------------- [thought]
|
||||
// ^ <tool_call|> | <|tool_call> (implicit
|
||||
// +-------------- [tool-call] <-------------------+ reasoning end, vLLM
|
||||
// | <|tool_call> ^ is_reasoning_end)
|
||||
// +-------------------+
|
||||
// [content]/[thought] --- <turn|> ---> [done] (everything after is dropped)
|
||||
//
|
||||
// Buffering rules:
|
||||
// - content/thought states hold back at most len(longest marker)-1 bytes:
|
||||
// the longest tail that is still a proper prefix of a watched marker.
|
||||
// Content is otherwise emitted immediately (no unbounded buffering).
|
||||
// - the tool-call state buffers the whole payload until <tool_call|>. This
|
||||
// is unbounded in principle but bounded in practice by the model's
|
||||
// diffusion canvas, and is required because the call:name{...} payload
|
||||
// only becomes decodable (and trustworthy) once complete - the same
|
||||
// reason vLLM's parser accumulates before parsing.
|
||||
// - Close() flushes whatever is still held: partial markers come out as
|
||||
// content/reasoning (per the state that held them); an unterminated
|
||||
// channel header or tool-call payload is re-emitted RAW (including its
|
||||
// opening marker) as content - malformed output is never silently
|
||||
// dropped (mirrors vLLM extract_tool_calls returning the raw text as
|
||||
// content when its regex does not match).
|
||||
//
|
||||
// Streaming granularity DIVERGENCE from vLLM: vLLM re-parses the partial
|
||||
// payload on every token and streams argument-JSON diffs (its `partial=True`
|
||||
// decoder mode plus withholding logic exist only for that). Our fragments are
|
||||
// whole committed diffusion blocks, so each completed tool call is emitted
|
||||
// once, as a single ToolCallDelta carrying index + id + name + the full
|
||||
// arguments JSON - exactly the shape backend/python/vllm/backend.py emits
|
||||
// per call and pkg/functions.ToolCallsFromChatDeltas re-accumulates.
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// gemma4CallRE is vLLM's tool_call_regex
|
||||
// (`<\|tool_call>call:([\w\-\.]+)\{(.*?)\}<tool_call\|>`, DOTALL) anchored to
|
||||
// a single already-extracted payload: name charset [\w\-.], braces mandatory.
|
||||
var gemma4CallRE = regexp.MustCompile(`(?s)^call:([\w\-.]+)\{(.*)\}$`)
|
||||
|
||||
type g4State int
|
||||
|
||||
const (
|
||||
g4Content g4State = iota
|
||||
g4ChanHeader
|
||||
g4Thought
|
||||
g4ToolCall
|
||||
g4Done
|
||||
)
|
||||
|
||||
// Markers watched per emitting state. A stray <tool_call|> outside a tool
|
||||
// call is deliberately NOT watched: it passes through verbatim, consistent
|
||||
// with the malformed-payload fallback re-emitting it as content.
|
||||
var (
|
||||
gemma4ContentMarkers = []string{gemma4ChannelOpen, gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd}
|
||||
gemma4ThoughtMarkers = []string{gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd}
|
||||
)
|
||||
|
||||
type Gemma4Parser struct {
|
||||
state g4State
|
||||
// held is the per-state carry-over between Feed calls: a partial marker
|
||||
// (content/thought), a partial channel header (chan-header) or the
|
||||
// payload accumulated so far (tool-call).
|
||||
held string
|
||||
toolIdx int
|
||||
}
|
||||
|
||||
// NewGemma4Parser returns a parser positioned per the initial-state rules in
|
||||
// the header comment: startInThought=true only when the caller pre-opened a
|
||||
// thought channel in the prompt.
|
||||
func NewGemma4Parser(startInThought bool) *Gemma4Parser {
|
||||
state := g4Content
|
||||
if startInThought {
|
||||
state = g4Thought
|
||||
}
|
||||
return &Gemma4Parser{state: state}
|
||||
}
|
||||
|
||||
// Feed consumes the next output fragment and returns the deltas it completes.
|
||||
func (p *Gemma4Parser) Feed(text string) []*pb.ChatDelta {
|
||||
if text == "" || p.state == g4Done {
|
||||
return nil
|
||||
}
|
||||
pending := p.held + text
|
||||
p.held = ""
|
||||
var em g4Emitter
|
||||
for pending != "" {
|
||||
switch p.state {
|
||||
case g4Content, g4Thought:
|
||||
markers := gemma4ContentMarkers
|
||||
if p.state == g4Thought {
|
||||
markers = gemma4ThoughtMarkers
|
||||
}
|
||||
idx, marker := findEarliestGemma4Marker(pending, markers)
|
||||
if idx == -1 {
|
||||
hold := gemma4MarkerHoldback(pending, markers)
|
||||
p.emitText(&em, pending[:len(pending)-hold])
|
||||
p.held = pending[len(pending)-hold:]
|
||||
pending = ""
|
||||
continue
|
||||
}
|
||||
p.emitText(&em, pending[:idx])
|
||||
pending = pending[idx+len(marker):]
|
||||
switch marker {
|
||||
case gemma4ChannelOpen:
|
||||
p.state = g4ChanHeader
|
||||
case gemma4ChannelClose:
|
||||
// In thought: channel ends. In content: stray close,
|
||||
// swallowed (strip_thinking keeps both sides, tpl L148-L158).
|
||||
p.state = g4Content
|
||||
case gemma4ToolCallOpen:
|
||||
p.state = g4ToolCall
|
||||
case gemma4TurnEnd:
|
||||
p.state = g4Done
|
||||
}
|
||||
case g4ChanHeader:
|
||||
// The channel header is "<name>\n"; the template only ever writes
|
||||
// "thought" (tpl L240/L360) and the label is structural, so it is
|
||||
// dropped, not emitted (vLLM strips the same "thought\n" prefix).
|
||||
nl := strings.IndexByte(pending, '\n')
|
||||
if nl == -1 {
|
||||
p.held = pending
|
||||
pending = ""
|
||||
continue
|
||||
}
|
||||
pending = pending[nl+1:]
|
||||
p.state = g4Thought
|
||||
case g4ToolCall:
|
||||
end := strings.Index(pending, gemma4ToolCallClose)
|
||||
if end == -1 {
|
||||
p.held = pending
|
||||
pending = ""
|
||||
continue
|
||||
}
|
||||
p.emitToolCall(&em, pending[:end])
|
||||
pending = pending[end+len(gemma4ToolCallClose):]
|
||||
p.state = g4Content
|
||||
case g4Done:
|
||||
pending = ""
|
||||
}
|
||||
}
|
||||
return em.deltas
|
||||
}
|
||||
|
||||
// Close flushes held-back partials. Incomplete structures (open channel
|
||||
// header, unterminated tool payload) are re-emitted raw as content rather
|
||||
// than dropped. The parser is finished afterwards.
|
||||
func (p *Gemma4Parser) Close() []*pb.ChatDelta {
|
||||
var em g4Emitter
|
||||
switch p.state {
|
||||
case g4Content:
|
||||
em.content(p.held)
|
||||
case g4Thought:
|
||||
em.reasoning(p.held)
|
||||
case g4ChanHeader:
|
||||
em.content(gemma4ChannelOpen + p.held)
|
||||
case g4ToolCall:
|
||||
em.content(gemma4ToolCallOpen + p.held)
|
||||
case g4Done:
|
||||
}
|
||||
p.held = ""
|
||||
p.state = g4Done
|
||||
return em.deltas
|
||||
}
|
||||
|
||||
func (p *Gemma4Parser) emitText(em *g4Emitter, s string) {
|
||||
if p.state == g4Thought {
|
||||
em.reasoning(s)
|
||||
return
|
||||
}
|
||||
em.content(s)
|
||||
}
|
||||
|
||||
// emitToolCall decodes one complete <|tool_call>...<tool_call|> payload. On a
|
||||
// payload that does not match call:name{...} the raw text (markers included)
|
||||
// is emitted as content, mirroring vLLM's extract_tool_calls fallback.
|
||||
func (p *Gemma4Parser) emitToolCall(em *g4Emitter, payload string) {
|
||||
m := gemma4CallRE.FindStringSubmatch(payload)
|
||||
if m == nil {
|
||||
em.content(gemma4ToolCallOpen + payload + gemma4ToolCallClose)
|
||||
return
|
||||
}
|
||||
// Index-based ids: deterministic (the split-invariance property relies
|
||||
// on it) and matching the call_<n> convention of pkg/grpc/rich_test.go;
|
||||
// core only needs ids to be non-empty and unique within the response.
|
||||
em.tool(p.toolIdx, "call_"+strconv.Itoa(p.toolIdx), m[1], decodeGemma4Args(m[2], 0))
|
||||
p.toolIdx++
|
||||
}
|
||||
|
||||
// g4Emitter collects ChatDeltas; empty text events are dropped.
|
||||
type g4Emitter struct {
|
||||
deltas []*pb.ChatDelta
|
||||
}
|
||||
|
||||
func (e *g4Emitter) content(s string) {
|
||||
if s != "" {
|
||||
e.deltas = append(e.deltas, &pb.ChatDelta{Content: s})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *g4Emitter) reasoning(s string) {
|
||||
if s != "" {
|
||||
e.deltas = append(e.deltas, &pb.ChatDelta{ReasoningContent: s})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *g4Emitter) tool(index int, id, name, argsJSON string) {
|
||||
e.deltas = append(e.deltas, &pb.ChatDelta{ToolCalls: []*pb.ToolCallDelta{{
|
||||
Index: int32(index),
|
||||
Id: id,
|
||||
Name: name,
|
||||
Arguments: argsJSON,
|
||||
}}})
|
||||
}
|
||||
|
||||
// findEarliestGemma4Marker returns the position and value of the first
|
||||
// complete marker occurrence, or (-1, "").
|
||||
func findEarliestGemma4Marker(s string, markers []string) (int, string) {
|
||||
best, bestMarker := -1, ""
|
||||
for _, m := range markers {
|
||||
if idx := strings.Index(s, m); idx >= 0 && (best == -1 || idx < best) {
|
||||
best, bestMarker = idx, m
|
||||
}
|
||||
}
|
||||
return best, bestMarker
|
||||
}
|
||||
|
||||
// gemma4MarkerHoldback returns the length of the longest suffix of s that is
|
||||
// a proper prefix of a watched marker - the only bytes that may still grow
|
||||
// into a marker and therefore must not be emitted yet (bounded by the
|
||||
// longest marker, so content is never buffered unboundedly).
|
||||
func gemma4MarkerHoldback(s string, markers []string) int {
|
||||
maxHold := 0
|
||||
for _, m := range markers {
|
||||
if len(m)-1 > maxHold {
|
||||
maxHold = len(m) - 1
|
||||
}
|
||||
}
|
||||
if len(s) < maxHold {
|
||||
maxHold = len(s)
|
||||
}
|
||||
for k := maxHold; k >= 1; k-- {
|
||||
tail := s[len(s)-k:]
|
||||
for _, m := range markers {
|
||||
if strings.HasPrefix(m, tail) {
|
||||
return k
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// call:name{...} argument decoder
|
||||
//
|
||||
// Port of vLLM's _parse_gemma4_args / _parse_gemma4_array /
|
||||
// _parse_gemma4_value (gemma4_tool_parser.py) in non-partial mode only: this
|
||||
// parser decodes exclusively COMPLETE payloads (incomplete ones fall back to
|
||||
// raw content at Close), so vLLM's partial-withholding machinery
|
||||
// (trailing-dot floats, withheld bare tails) is intentionally not ported.
|
||||
//
|
||||
// Grammar (inverse of the renderer's formatGemma4Argument, tpl L118-L147):
|
||||
//
|
||||
// args := pair (',' pair)*
|
||||
// pair := key ':' value (keys unquoted, up to the first ':')
|
||||
// value := string | object | array | bare
|
||||
// string := '<|"|>' ... '<|"|>' (no escapes; unterminated -> rest)
|
||||
// object := '{' args '}' (delimited strings skipped when
|
||||
// array := '[' value,* ']' counting braces/brackets)
|
||||
// bare := true | false | null/none/nil | number | bare-string
|
||||
//
|
||||
// Output is a JSON object/array string with keys in payload order (Python
|
||||
// dict insertion order), built with HTML escaping off so payload text
|
||||
// survives byte-for-byte.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func isGemma4Space(c byte) bool { return c == ' ' || c == '\n' || c == '\t' }
|
||||
|
||||
// gemma4MaxArgsDepth caps the mutual recursion between decodeGemma4Args and
|
||||
// decodeGemma4Array. Defense against model-generated deep nesting: a Go stack
|
||||
// overflow is a fatal process kill, not a recoverable error, so past the cap
|
||||
// a nested body gracefully degrades to a JSON string of its raw text.
|
||||
const gemma4MaxArgsDepth = 100
|
||||
|
||||
// decodeGemma4Args decodes one args body (the text between the outer braces
|
||||
// of call:name{...}) into a JSON object string. depth is the current nesting
|
||||
// level (0 at the payload root); see gemma4MaxArgsDepth.
|
||||
func decodeGemma4Args(s string, depth int) string {
|
||||
if depth > gemma4MaxArgsDepth {
|
||||
return gemma4JSONString(s)
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("{")
|
||||
first := true
|
||||
pair := func(key, val string) {
|
||||
if !first {
|
||||
b.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
b.WriteString(gemma4JSONString(key))
|
||||
b.WriteString(":")
|
||||
b.WriteString(val)
|
||||
}
|
||||
i, n := 0, len(s)
|
||||
for i < n {
|
||||
for i < n && (isGemma4Space(s[i]) || s[i] == ',') {
|
||||
i++
|
||||
}
|
||||
if i >= n {
|
||||
break
|
||||
}
|
||||
keyStart := i
|
||||
for i < n && s[i] != ':' {
|
||||
i++
|
||||
}
|
||||
if i >= n {
|
||||
break // no ':' -> trailing junk, dropped (vLLM does the same)
|
||||
}
|
||||
key := strings.TrimSpace(s[keyStart:i])
|
||||
i++ // skip ':'
|
||||
for i < n && isGemma4Space(s[i]) {
|
||||
i++
|
||||
}
|
||||
if i >= n {
|
||||
pair(key, `""`) // "key:" with nothing after -> empty string
|
||||
break
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(s[i:], gemma4StringDelim):
|
||||
i += len(gemma4StringDelim)
|
||||
if end := strings.Index(s[i:], gemma4StringDelim); end == -1 {
|
||||
pair(key, gemma4JSONString(s[i:])) // unterminated -> take rest
|
||||
i = n
|
||||
} else {
|
||||
pair(key, gemma4JSONString(s[i:i+end]))
|
||||
i += end + len(gemma4StringDelim)
|
||||
}
|
||||
case s[i] == '{':
|
||||
inner, next := scanGemma4Balanced(s, i, '{', '}')
|
||||
pair(key, decodeGemma4Args(inner, depth+1))
|
||||
i = next
|
||||
case s[i] == '[':
|
||||
inner, next := scanGemma4Balanced(s, i, '[', ']')
|
||||
pair(key, decodeGemma4Array(inner, depth+1))
|
||||
i = next
|
||||
default:
|
||||
valStart := i
|
||||
for i < n && s[i] != ',' && s[i] != '}' && s[i] != ']' {
|
||||
i++
|
||||
}
|
||||
if i == valStart {
|
||||
// No progress (value starts on a stray '}'/']'): abort on
|
||||
// malformed input rather than loop, like vLLM.
|
||||
i = n
|
||||
continue
|
||||
}
|
||||
pair(key, decodeGemma4Bare(s[valStart:i]))
|
||||
}
|
||||
}
|
||||
b.WriteString("}")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// decodeGemma4Array decodes one array body (the text between '[' and ']')
|
||||
// into a JSON array string. depth is the current nesting level; see
|
||||
// gemma4MaxArgsDepth.
|
||||
func decodeGemma4Array(s string, depth int) string {
|
||||
if depth > gemma4MaxArgsDepth {
|
||||
return gemma4JSONString(s)
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("[")
|
||||
first := true
|
||||
item := func(val string) {
|
||||
if !first {
|
||||
b.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
b.WriteString(val)
|
||||
}
|
||||
i, n := 0, len(s)
|
||||
for i < n {
|
||||
for i < n && (isGemma4Space(s[i]) || s[i] == ',') {
|
||||
i++
|
||||
}
|
||||
if i >= n {
|
||||
break
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(s[i:], gemma4StringDelim):
|
||||
i += len(gemma4StringDelim)
|
||||
if end := strings.Index(s[i:], gemma4StringDelim); end == -1 {
|
||||
item(gemma4JSONString(s[i:]))
|
||||
i = n
|
||||
} else {
|
||||
item(gemma4JSONString(s[i : i+end]))
|
||||
i += end + len(gemma4StringDelim)
|
||||
}
|
||||
case s[i] == '{':
|
||||
inner, next := scanGemma4Balanced(s, i, '{', '}')
|
||||
item(decodeGemma4Args(inner, depth+1))
|
||||
i = next
|
||||
case s[i] == '[':
|
||||
inner, next := scanGemma4Balanced(s, i, '[', ']')
|
||||
item(decodeGemma4Array(inner, depth+1))
|
||||
i = next
|
||||
default:
|
||||
valStart := i
|
||||
for i < n && s[i] != ',' && s[i] != ']' {
|
||||
i++
|
||||
}
|
||||
if i == valStart {
|
||||
i = n // no progress: abort on malformed input, like vLLM
|
||||
continue
|
||||
}
|
||||
item(decodeGemma4Bare(s[valStart:i]))
|
||||
}
|
||||
}
|
||||
b.WriteString("]")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// scanGemma4Balanced scans a brace/bracket-balanced span starting at the
|
||||
// opener s[start], skipping over <|"|>-delimited strings so structural
|
||||
// characters inside them do not count (vLLM's depth scan). Returns the inner
|
||||
// text and the index just past the closer; an unterminated span yields the
|
||||
// rest of the string (the inner decoder still extracts what is there - this
|
||||
// path is only reachable from genuinely malformed complete payloads).
|
||||
func scanGemma4Balanced(s string, start int, open, close byte) (string, int) {
|
||||
depth := 1
|
||||
i := start + 1
|
||||
innerStart := i
|
||||
n := len(s)
|
||||
for i < n && depth > 0 {
|
||||
if strings.HasPrefix(s[i:], gemma4StringDelim) {
|
||||
i += len(gemma4StringDelim)
|
||||
if nd := strings.Index(s[i:], gemma4StringDelim); nd == -1 {
|
||||
i = n
|
||||
} else {
|
||||
i += nd + len(gemma4StringDelim)
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch s[i] {
|
||||
case open:
|
||||
depth++
|
||||
case close:
|
||||
depth--
|
||||
}
|
||||
i++
|
||||
}
|
||||
if depth > 0 {
|
||||
return s[innerStart:], n
|
||||
}
|
||||
return s[innerStart : i-1], i
|
||||
}
|
||||
|
||||
// decodeGemma4Bare maps an undelimited value to its JSON form: booleans,
|
||||
// null aliases (null/none/nil, case-insensitive - the renderer writes
|
||||
// Python None as "None", tpl L144-L145 via format_argument's else branch),
|
||||
// numbers (vLLM's rule: a '.' tries float, otherwise int; anything that
|
||||
// fails parses as a bare string).
|
||||
func decodeGemma4Bare(raw string) string {
|
||||
v := strings.TrimSpace(raw)
|
||||
if v == "" {
|
||||
return `""`
|
||||
}
|
||||
if v == "true" || v == "false" {
|
||||
return v
|
||||
}
|
||||
switch strings.ToLower(v) {
|
||||
case "null", "none", "nil":
|
||||
return "null"
|
||||
}
|
||||
if strings.Contains(v, ".") {
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return formatGemma4Float(f)
|
||||
}
|
||||
} else if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return strconv.FormatInt(iv, 10)
|
||||
}
|
||||
return gemma4JSONString(v)
|
||||
}
|
||||
|
||||
// formatGemma4Float renders like Python's json.dumps(float): integral floats
|
||||
// keep a ".0" suffix ("108." decodes to 108.0, not 108), so the arguments
|
||||
// JSON matches what vLLM would have produced for the same payload.
|
||||
func formatGemma4Float(f float64) string {
|
||||
s := strconv.FormatFloat(f, 'g', -1, 64)
|
||||
if !strings.ContainsAny(s, ".eE") {
|
||||
s += ".0"
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// gemma4JSONString encodes a JSON string WITHOUT HTML escaping (json.Marshal
|
||||
// would escape the angle brackets in "<div>" to \u003c / \u003e sequences;
|
||||
// payload text should survive
|
||||
// byte-for-byte, like Python's json.dumps(ensure_ascii=False)).
|
||||
func gemma4JSONString(s string) string {
|
||||
var sb strings.Builder
|
||||
enc := json.NewEncoder(&sb)
|
||||
enc.SetEscapeHTML(false)
|
||||
if err := enc.Encode(s); err != nil {
|
||||
// Unreachable for plain strings; fall back to default escaping
|
||||
// rather than emitting invalid JSON.
|
||||
b, mErr := json.Marshal(s)
|
||||
if mErr != nil {
|
||||
return `""`
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
// Encode appends a trailing newline.
|
||||
return strings.TrimSuffix(sb.String(), "\n")
|
||||
}
|
||||
592
backend/go/dllm/gemma4_parser_test.go
Normal file
592
backend/go/dllm/gemma4_parser_test.go
Normal file
@@ -0,0 +1,592 @@
|
||||
package main
|
||||
|
||||
// Parser specs for Gemma4Parser (model output text -> pb.ChatDelta events).
|
||||
//
|
||||
// Fixture provenance:
|
||||
// - Entries marked "vLLM: <name>" are direct ports of the named test from
|
||||
// vLLM PR #45163, tests/tool_parsers/test_gemma4_tool_parser.py (the
|
||||
// authoritative test-suite for the gemma4 tool-call wire format). The
|
||||
// streaming tests' chunk lists are reused verbatim as Feed fragments.
|
||||
// - Decoder entries port the TestParseGemma4Args / TestParseGemma4Array
|
||||
// classes from the same file (non-partial mode only; this parser never
|
||||
// decodes partial payloads, see the divergence note in gemma4_parser.go).
|
||||
// - Channel/turn-marker expectations come from the chat template embedded
|
||||
// in gemma4_renderer.go (tpl L356-L362 generation prompt, L148-L158
|
||||
// strip_thinking) and vLLM's Gemma4ReasoningParser
|
||||
// (vllm/reasoning/gemma4_reasoning_parser.py).
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// flatGemma4Tool is one accumulated tool call, mirroring how LocalAI core
|
||||
// folds ToolCallDelta streams (pkg/functions/chat_deltas.go
|
||||
// ToolCallsFromChatDeltas: name/id latch on first non-empty, arguments
|
||||
// concatenate per index). Tests flatten through the same rules so they
|
||||
// assert exactly what core will reconstruct.
|
||||
type flatGemma4Tool struct {
|
||||
id string
|
||||
name string
|
||||
args string
|
||||
}
|
||||
|
||||
func flattenGemma4Deltas(deltas []*pb.ChatDelta) (string, string, []flatGemma4Tool) {
|
||||
var content, reasoning strings.Builder
|
||||
byIndex := map[int32]*flatGemma4Tool{}
|
||||
maxIdx := int32(-1)
|
||||
for _, d := range deltas {
|
||||
content.WriteString(d.GetContent())
|
||||
reasoning.WriteString(d.GetReasoningContent())
|
||||
for _, tc := range d.GetToolCalls() {
|
||||
acc, ok := byIndex[tc.GetIndex()]
|
||||
if !ok {
|
||||
acc = &flatGemma4Tool{}
|
||||
byIndex[tc.GetIndex()] = acc
|
||||
}
|
||||
if tc.GetName() != "" {
|
||||
acc.name = tc.GetName()
|
||||
}
|
||||
if tc.GetId() != "" {
|
||||
acc.id = tc.GetId()
|
||||
}
|
||||
acc.args += tc.GetArguments()
|
||||
if tc.GetIndex() > maxIdx {
|
||||
maxIdx = tc.GetIndex()
|
||||
}
|
||||
}
|
||||
}
|
||||
var tools []flatGemma4Tool
|
||||
for i := int32(0); i <= maxIdx; i++ {
|
||||
if acc, ok := byIndex[i]; ok {
|
||||
tools = append(tools, *acc)
|
||||
}
|
||||
}
|
||||
return content.String(), reasoning.String(), tools
|
||||
}
|
||||
|
||||
type wantGemma4Tool struct {
|
||||
name string
|
||||
argsJSON string // compared with MatchJSON (key order irrelevant)
|
||||
}
|
||||
|
||||
type parseGemma4Case struct {
|
||||
startInThought bool
|
||||
fragments []string
|
||||
wantContent string
|
||||
wantReasoning string
|
||||
wantTools []wantGemma4Tool
|
||||
}
|
||||
|
||||
func parseGemma4Fragments(startInThought bool, fragments []string) []*pb.ChatDelta {
|
||||
p := NewGemma4Parser(startInThought)
|
||||
var all []*pb.ChatDelta
|
||||
for _, f := range fragments {
|
||||
all = append(all, p.Feed(f)...)
|
||||
}
|
||||
return append(all, p.Close()...)
|
||||
}
|
||||
|
||||
var _ = Describe("Gemma4Parser", func() {
|
||||
DescribeTable("parses streamed gemma4 output into ChatDeltas",
|
||||
func(c parseGemma4Case) {
|
||||
content, reasoning, tools := flattenGemma4Deltas(parseGemma4Fragments(c.startInThought, c.fragments))
|
||||
Expect(content).To(Equal(c.wantContent))
|
||||
Expect(reasoning).To(Equal(c.wantReasoning))
|
||||
Expect(tools).To(HaveLen(len(c.wantTools)))
|
||||
seenIDs := map[string]bool{}
|
||||
for i, want := range c.wantTools {
|
||||
Expect(tools[i].name).To(Equal(want.name), "tool %d name", i)
|
||||
Expect(tools[i].args).To(MatchJSON(want.argsJSON), "tool %d arguments", i)
|
||||
Expect(tools[i].id).ToNot(BeEmpty(), "tool %d id", i)
|
||||
Expect(seenIDs).ToNot(HaveKey(tools[i].id), "tool %d id must be unique", i)
|
||||
seenIDs[tools[i].id] = true
|
||||
}
|
||||
},
|
||||
|
||||
// --- (1) pure content -------------------------------------------------
|
||||
// vLLM: test_no_tool_calls
|
||||
Entry("pure content, single fragment", parseGemma4Case{
|
||||
fragments: []string{"Hello, how can I help you today?"},
|
||||
wantContent: "Hello, how can I help you today?",
|
||||
}),
|
||||
|
||||
// --- (2) thought -> final transition ----------------------------------
|
||||
// enable_thinking render: prompt ends at <|turn>model\n and the model
|
||||
// opens/closes its own thought channel in the OUTPUT (vLLM
|
||||
// Gemma4ReasoningParser docstring; tpl L356-L362). The "thought\n"
|
||||
// role label after <|channel> is structural and must be stripped
|
||||
// (vLLM _THOUGHT_PREFIX handling).
|
||||
Entry("thought channel then final content", parseGemma4Case{
|
||||
fragments: []string{"<|channel>thought\nLet me think about this.\n<channel|>The answer is 42."},
|
||||
wantReasoning: "Let me think about this.\n",
|
||||
wantContent: "The answer is 42.",
|
||||
}),
|
||||
|
||||
// --- (3) startInThought both ways -------------------------------------
|
||||
Entry("startInThought=true routes initial text to reasoning until <channel|>", parseGemma4Case{
|
||||
startInThought: true,
|
||||
fragments: []string{"I am thinking hard.<channel|>Done."},
|
||||
wantReasoning: "I am thinking hard.",
|
||||
wantContent: "Done.",
|
||||
}),
|
||||
// A stray <channel|> with no open channel is swallowed, matching the
|
||||
// template's strip_thinking (tpl L148-L158: the marker is dropped,
|
||||
// text on both sides is kept).
|
||||
Entry("startInThought=false keeps the same text as content, stray <channel|> swallowed", parseGemma4Case{
|
||||
startInThought: false,
|
||||
fragments: []string{"I am thinking hard.<channel|>Done."},
|
||||
wantContent: "I am thinking hard.Done.",
|
||||
}),
|
||||
|
||||
// --- (4) one tool call, full payload type zoo --------------------------
|
||||
Entry("single tool call: strings, numbers, bools, null, nested object and array", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:complex_function{text:<|"|>with, comma and {braces}<|"|>,count:42,score:3.14,yes:true,no:false,nothing:null,obj:{inner:<|"|>v<|"|>,k:1},arr:[<|"|>a<|"|>,2,true]}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{
|
||||
name: "complex_function",
|
||||
argsJSON: `{"text":"with, comma and {braces}","count":42,"score":3.14,"yes":true,"no":false,"nothing":null,"obj":{"inner":"v","k":1},"arr":["a",2,true]}`,
|
||||
}},
|
||||
}),
|
||||
|
||||
// --- (5) payload split across 3 fragments ------------------------------
|
||||
Entry("tool-call payload split across three fragments", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>call:get_weather{loc",
|
||||
`ation:<|"|>Paris, Fra`,
|
||||
`nce<|"|>}<tool_call|>`,
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}},
|
||||
}),
|
||||
|
||||
// --- (6) marker split across fragments ----------------------------------
|
||||
Entry("tool-call open marker split across fragments", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_ca",
|
||||
`ll>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`,
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
Entry("channel open marker split across fragments", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|chan",
|
||||
"nel>thought\ndeep thought<channel|>final",
|
||||
},
|
||||
wantReasoning: "deep thought",
|
||||
wantContent: "final",
|
||||
}),
|
||||
|
||||
// --- (7) trailing partial marker held, flushed by Close -----------------
|
||||
Entry("trailing partial marker is held back and flushed by Close", parseGemma4Case{
|
||||
fragments: []string{"Hello <|tool"},
|
||||
wantContent: "Hello <|tool",
|
||||
}),
|
||||
|
||||
// --- (8) malformed/incomplete payload -> content fallback ---------------
|
||||
// vLLM: test_incomplete_tool_call (no end marker: the whole text stays
|
||||
// content, never silently dropped).
|
||||
Entry("incomplete tool payload at Close is emitted as raw content", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London`},
|
||||
wantContent: `<|tool_call>call:get_weather{location:<|"|>London`,
|
||||
}),
|
||||
Entry("malformed complete payload is emitted as raw content, parsing continues", parseGemma4Case{
|
||||
fragments: []string{"<|tool_call>oops no call syntax<tool_call|> done"},
|
||||
wantContent: "<|tool_call>oops no call syntax<tool_call|> done",
|
||||
}),
|
||||
|
||||
// --- (9) <turn|> ends the turn -------------------------------------------
|
||||
Entry("text after <turn|> is ignored, including later fragments", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"before<turn|>after",
|
||||
`more <|tool_call>call:f{}<tool_call|>`,
|
||||
},
|
||||
wantContent: "before",
|
||||
}),
|
||||
Entry("<turn|> inside a thought channel ends the turn", parseGemma4Case{
|
||||
startInThought: true,
|
||||
fragments: []string{"thinking<turn|>ignored"},
|
||||
wantReasoning: "thinking",
|
||||
}),
|
||||
|
||||
// --- (10) ported vLLM non-streaming cases ---------------------------------
|
||||
// vLLM: test_single_tool_call
|
||||
Entry("vLLM: test_single_tool_call", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
// vLLM: test_multiple_arguments
|
||||
Entry("vLLM: test_multiple_arguments", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"San Francisco","unit":"celsius"}`}},
|
||||
}),
|
||||
// vLLM: test_text_before_tool_call. DIVERGENCE: vLLM's non-streaming
|
||||
// extractor trims the content ("...you."); a streaming parser cannot
|
||||
// retroactively trim already-emitted text, so the trailing space is
|
||||
// kept (vLLM's own streaming path keeps it too, see
|
||||
// test_streaming_text_before_tool_call which only checks a prefix).
|
||||
Entry("vLLM: test_text_before_tool_call (streaming semantics: no trim)", parseGemma4Case{
|
||||
fragments: []string{`Let me check the weather for you. <|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|>`},
|
||||
wantContent: "Let me check the weather for you. ",
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}},
|
||||
}),
|
||||
// vLLM: test_multiple_tool_calls (also covers case 11: multi-tool sequence)
|
||||
Entry("vLLM: test_multiple_tool_calls", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|><|tool_call>call:get_time{location:<|"|>London<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{
|
||||
{name: "get_weather", argsJSON: `{"location":"London"}`},
|
||||
{name: "get_time", argsJSON: `{"location":"London"}`},
|
||||
},
|
||||
}),
|
||||
// vLLM: test_nested_arguments
|
||||
Entry("vLLM: test_nested_arguments", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:complex_function{nested:{inner:<|"|>value<|"|>},list:[<|"|>a<|"|>,<|"|>b<|"|>]}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "complex_function", argsJSON: `{"nested":{"inner":"value"},"list":["a","b"]}`}},
|
||||
}),
|
||||
// vLLM: test_tool_call_with_number_and_boolean
|
||||
Entry("vLLM: test_tool_call_with_number_and_boolean", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:set_status{is_active:true,count:42,score:3.14}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "set_status", argsJSON: `{"is_active":true,"count":42,"score":3.14}`}},
|
||||
}),
|
||||
// vLLM: test_hyphenated_function_name
|
||||
Entry("vLLM: test_hyphenated_function_name", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get-weather{location:<|"|>London<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "get-weather", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
// vLLM: test_dotted_function_name
|
||||
Entry("vLLM: test_dotted_function_name", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:weather.get{location:<|"|>London<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "weather.get", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
// vLLM: test_no_arguments
|
||||
Entry("vLLM: test_no_arguments", parseGemma4Case{
|
||||
fragments: []string{"<|tool_call>call:get_status{}<tool_call|>"},
|
||||
wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}},
|
||||
}),
|
||||
|
||||
// --- ported vLLM streaming cases (chunk lists reused as fragments) --------
|
||||
// vLLM: test_basic_streaming_single_tool
|
||||
Entry("vLLM: test_basic_streaming_single_tool", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
`location:<|"|>Paris`,
|
||||
", France",
|
||||
`<|"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_multi_arg
|
||||
Entry("vLLM: test_streaming_multi_arg", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
`location:<|"|>Tokyo<|"|>,`,
|
||||
`unit:<|"|>celsius<|"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Tokyo","unit":"celsius"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_text_before_tool_call
|
||||
Entry("vLLM: test_streaming_text_before_tool_call", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"Let me check ",
|
||||
"the weather. ",
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
`location:<|"|>London<|"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantContent: "Let me check the weather. ",
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_numeric_args
|
||||
Entry("vLLM: test_streaming_numeric_args", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:set_config{",
|
||||
"count:42,",
|
||||
"active:true}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "set_config", argsJSON: `{"count":42,"active":true}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_boolean_split_across_chunks
|
||||
Entry("vLLM: test_streaming_boolean_split_across_chunks", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:search{input:{all:tru",
|
||||
"e}}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "search", argsJSON: `{"input":{"all":true}}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_false_split_across_chunks
|
||||
Entry("vLLM: test_streaming_false_split_across_chunks", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:set{flag:fals",
|
||||
"e}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"flag":false}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_number_split_across_chunks
|
||||
Entry("vLLM: test_streaming_number_split_across_chunks", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:set{count:4",
|
||||
"2}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"count":42}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_empty_args
|
||||
Entry("vLLM: test_streaming_empty_args", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:get_status{}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_split_delimiter_no_invalid_json (string
|
||||
// delimiter <|"|> split across fragments must not leak fragments).
|
||||
Entry("vLLM: test_streaming_split_delimiter_no_invalid_json", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:todowrite{",
|
||||
`content:<|"|>Buy milk<|`,
|
||||
`"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "todowrite", argsJSON: `{"content":"Buy milk"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call
|
||||
Entry("vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
`location:<|"|>Paris<|"|>}`,
|
||||
"<tool_call|><",
|
||||
"div>",
|
||||
},
|
||||
wantContent: "<div>",
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes
|
||||
Entry("vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:write_file{",
|
||||
`path:<|"|>index.html<|"|>,`,
|
||||
`content:<|"|><!DOCTYPE html>` + "\n<",
|
||||
`html lang="zh-CN">` + "\n<",
|
||||
"head>\n <",
|
||||
`meta charset="UTF-8">` + "\n <",
|
||||
`meta name="viewport" content="width=device-width">` + "\n",
|
||||
`<|"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{
|
||||
name: "write_file",
|
||||
argsJSON: `{"path":"index.html","content":"<!DOCTYPE html>\n<html lang=\"zh-CN\">\n<head>\n <meta charset=\"UTF-8\">\n <meta name=\"viewport\" content=\"width=device-width\">\n"}`,
|
||||
}},
|
||||
}),
|
||||
// vLLM: test_streaming_single_chunk_complete_tool_call
|
||||
Entry("vLLM: test_streaming_single_chunk_complete_tool_call", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:name_a_color{color_hex:<|"|>00ff11<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "name_a_color", argsJSON: `{"color_hex":"00ff11"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_multi_chunk_batched_tool_calls (two complete
|
||||
// calls in ONE fragment; both must come out with distinct indices)
|
||||
Entry("vLLM: test_streaming_multi_chunk_batched_tool_calls", parseGemma4Case{
|
||||
fragments: []string{
|
||||
`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>` +
|
||||
`<|tool_call>call:get_time{timezone:<|"|>GMT<|"|>}<tool_call|>`,
|
||||
},
|
||||
wantTools: []wantGemma4Tool{
|
||||
{name: "get_weather", argsJSON: `{"location":"London"}`},
|
||||
{name: "get_time", argsJSON: `{"timezone":"GMT"}`},
|
||||
},
|
||||
}),
|
||||
// vLLM: test_streaming_trailing_bare_bool_not_duplicated
|
||||
Entry("vLLM: test_streaming_trailing_bare_bool_not_duplicated", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:Edit{",
|
||||
`file_path:<|"|>src/env.py<|"|>,`,
|
||||
`old_string:<|"|>old_val<|"|>,`,
|
||||
`new_string:<|"|>new_val<|"|>,`,
|
||||
"replace_all:",
|
||||
"false}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{
|
||||
name: "Edit",
|
||||
argsJSON: `{"file_path":"src/env.py","old_string":"old_val","new_string":"new_val","replace_all":false}`,
|
||||
}},
|
||||
}),
|
||||
|
||||
// --- implicit reasoning end on <|tool_call> (vLLM is_reasoning_end:
|
||||
// a tool_call token means reasoning is over) -----------------------------
|
||||
Entry("tool call inside an open thought channel ends the reasoning", parseGemma4Case{
|
||||
startInThought: true,
|
||||
fragments: []string{`need the weather<|tool_call>call:get_weather{location:<|"|>Rome<|"|>}<tool_call|>`},
|
||||
wantReasoning: "need the weather",
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Rome"}`}},
|
||||
}),
|
||||
|
||||
// --- (12) empty fragments are no-ops --------------------------------------
|
||||
Entry("empty fragments are no-ops", parseGemma4Case{
|
||||
fragments: []string{"", "Hello", "", "", " world", ""},
|
||||
wantContent: "Hello world",
|
||||
}),
|
||||
)
|
||||
|
||||
It("returns no deltas for an empty fragment and after Close", func() {
|
||||
p := NewGemma4Parser(false)
|
||||
Expect(p.Feed("")).To(BeEmpty())
|
||||
Expect(p.Feed("hi")).ToNot(BeEmpty())
|
||||
Expect(p.Close()).To(BeEmpty()) // nothing held back
|
||||
// The parser is finished after Close: further input is dropped.
|
||||
Expect(p.Feed("more")).To(BeEmpty())
|
||||
Expect(p.Close()).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("generates index-based tool call ids (call_<index>)", func() {
|
||||
// Mirrors the index-based id convention of pkg/grpc/rich_test.go and
|
||||
// keeps ids deterministic for the split-invariance property below.
|
||||
deltas := parseGemma4Fragments(false, []string{
|
||||
`<|tool_call>call:a{}<tool_call|><|tool_call>call:b{}<tool_call|>`,
|
||||
})
|
||||
_, _, tools := flattenGemma4Deltas(deltas)
|
||||
Expect(tools).To(HaveLen(2))
|
||||
Expect(tools[0].id).To(Equal("call_0"))
|
||||
Expect(tools[1].id).To(Equal("call_1"))
|
||||
})
|
||||
|
||||
// Property: for a fixed full output, EVERY 2-split position must yield
|
||||
// exactly the same flattened result as the unsplit parse. This kills
|
||||
// fragment-boundary bugs (mid-marker, mid-delimiter, mid-payload splits).
|
||||
DescribeTable("2-split fragment invariance",
|
||||
func(startInThought bool, full string) {
|
||||
refContent, refReasoning, refTools := flattenGemma4Deltas(
|
||||
parseGemma4Fragments(startInThought, []string{full}))
|
||||
for i := 0; i <= len(full); i++ {
|
||||
content, reasoning, tools := flattenGemma4Deltas(
|
||||
parseGemma4Fragments(startInThought, []string{full[:i], full[i:]}))
|
||||
Expect(content).To(Equal(refContent), fmt.Sprintf("content diverged at split %d", i))
|
||||
Expect(reasoning).To(Equal(refReasoning), fmt.Sprintf("reasoning diverged at split %d", i))
|
||||
Expect(tools).To(Equal(refTools), fmt.Sprintf("tool calls diverged at split %d", i))
|
||||
}
|
||||
},
|
||||
Entry("thought + content + two tool calls + turn end", false,
|
||||
"<|channel>thought\nPondering the request...\n<channel|>Sure - calling tools now. "+
|
||||
`<|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>,days:3,detailed:true}<tool_call|>`+
|
||||
`<|tool_call>call:get_time{timezone:<|"|>Europe/Lisbon<|"|>,nested:{flag:false,vals:[1,2.5,<|"|>x<|"|>]}}<tool_call|>`+
|
||||
"Done.<turn|>ignored tail"),
|
||||
Entry("startInThought + tool call + trailing partial marker", true,
|
||||
`Deep thought<channel|>final answer <|tool_call>call:noop{}<tool_call|> trailing <|tool`),
|
||||
Entry("malformed payload fallback", false,
|
||||
`pre <|tool_call>not a call<tool_call|> post`),
|
||||
)
|
||||
})
|
||||
|
||||
// Decoder-level ports of vLLM's TestParseGemma4Args / TestParseGemma4Array
|
||||
// (non-partial mode; the partial-withholding tests do not apply because this
|
||||
// parser only ever decodes COMPLETE payloads, see gemma4_parser.go).
|
||||
var _ = Describe("decodeGemma4Args", func() {
|
||||
DescribeTable("decodes the gemma4 call syntax into JSON arguments",
|
||||
func(in, wantJSON string) {
|
||||
Expect(decodeGemma4Args(in, 0)).To(MatchJSON(wantJSON))
|
||||
},
|
||||
// vLLM: test_empty_string / test_whitespace_only
|
||||
Entry("empty string", "", `{}`),
|
||||
Entry("whitespace only", " ", `{}`),
|
||||
// vLLM: test_single_string_value
|
||||
Entry("single string value", `location:<|"|>Paris<|"|>`, `{"location":"Paris"}`),
|
||||
// vLLM: test_string_value_with_comma
|
||||
Entry("string value with comma", `location:<|"|>Paris, France<|"|>`, `{"location":"Paris, France"}`),
|
||||
// vLLM: test_multiple_string_values
|
||||
Entry("multiple string values", `location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>`, `{"location":"San Francisco","unit":"celsius"}`),
|
||||
// vLLM: test_integer_value / test_float_value
|
||||
Entry("integer value", "count:42", `{"count":42}`),
|
||||
Entry("float value", "score:3.14", `{"score":3.14}`),
|
||||
// vLLM: test_boolean_true / test_boolean_false
|
||||
Entry("boolean true", "flag:true", `{"flag":true}`),
|
||||
Entry("boolean false", "flag:false", `{"flag":false}`),
|
||||
// vLLM: test_null_value (bare null must become JSON null, not "null")
|
||||
Entry("null value", "param:null", `{"param":null}`),
|
||||
// vLLM: test_mixed_types
|
||||
Entry("mixed types", `name:<|"|>test<|"|>,count:42,active:true,score:3.14`,
|
||||
`{"name":"test","count":42,"active":true,"score":3.14}`),
|
||||
// vLLM: test_nested_object
|
||||
Entry("nested object", `nested:{inner:<|"|>value<|"|>}`, `{"nested":{"inner":"value"}}`),
|
||||
// vLLM: test_array_of_strings
|
||||
Entry("array of strings", `items:[<|"|>a<|"|>,<|"|>b<|"|>]`, `{"items":["a","b"]}`),
|
||||
// vLLM: test_unterminated_string (take everything after the delimiter)
|
||||
Entry("unterminated string", `key:<|"|>unterminated`, `{"key":"unterminated"}`),
|
||||
// vLLM: test_empty_value (key with no value after colon)
|
||||
Entry("empty value", "key:", `{"key":""}`),
|
||||
// vLLM: test_trailing_dot_float_partial_withheld, non-partial branch
|
||||
// (trailing-dot floats parse normally outside streaming).
|
||||
Entry("trailing dot float, complete payload", "left:108.,right:22.8", `{"left":108.0,"right":22.8}`),
|
||||
)
|
||||
|
||||
It("terminates and yields valid JSON on malformed input", func() {
|
||||
// vLLM: test_malformed_partial_array (the assertion there is only
|
||||
// "returns a dict without hanging"; ours is "valid JSON object").
|
||||
out := decodeGemma4Args(":[t:[]", 0)
|
||||
var v map[string]any
|
||||
Expect(json.Unmarshal([]byte(out), &v)).To(Succeed())
|
||||
})
|
||||
|
||||
It("degrades nesting beyond the recursion cap to a string value", func() {
|
||||
// 200 levels of a:{a:{...a:1...}}. Without the depth cap the mutual
|
||||
// recursion would grow the stack with the model's output; a Go stack
|
||||
// overflow is a fatal process kill, so levels past gemma4MaxArgsDepth
|
||||
// must gracefully fall back to the raw inner text as a JSON string.
|
||||
const depth = 200
|
||||
body := strings.Repeat("a:{", depth-1) + "a:1" + strings.Repeat("}", depth-1)
|
||||
out := decodeGemma4Args(body, 0)
|
||||
var v map[string]any
|
||||
Expect(json.Unmarshal([]byte(out), &v)).To(Succeed())
|
||||
levels := 0
|
||||
var cur any = v
|
||||
for {
|
||||
m, ok := cur.(map[string]any)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
Expect(m).To(HaveKey("a"))
|
||||
cur = m["a"]
|
||||
levels++
|
||||
}
|
||||
Expect(levels).To(Equal(gemma4MaxArgsDepth + 1))
|
||||
Expect(cur).To(BeAssignableToTypeOf(""))
|
||||
Expect(cur).To(ContainSubstring("a:{"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("decodeGemma4Array", func() {
|
||||
DescribeTable("decodes gemma4 array bodies into JSON arrays",
|
||||
func(in, wantJSON string) {
|
||||
Expect(decodeGemma4Array(in, 0)).To(MatchJSON(wantJSON))
|
||||
},
|
||||
// vLLM: test_string_array / test_empty_array / test_bare_values
|
||||
Entry("string array", `<|"|>a<|"|>,<|"|>b<|"|>`, `["a","b"]`),
|
||||
Entry("empty array", "", `[]`),
|
||||
Entry("bare values", "42,true,3.14", `[42,true,3.14]`),
|
||||
// vLLM: test_string_element_with_closing_bracket (a ']' inside a
|
||||
// delimited string must not close the array)
|
||||
Entry("string element with closing bracket", `[<|"|>a]b<|"|>,<|"|>c<|"|>],<|"|>tail<|"|>`, `[["a]b","c"],"tail"]`),
|
||||
// vLLM: test_stray_closing_bracket (no-progress abort, keep prefix)
|
||||
Entry("stray closing bracket", "42,]trailing", `[42]`),
|
||||
)
|
||||
})
|
||||
1026
backend/go/dllm/gemma4_renderer.go
Normal file
1026
backend/go/dllm/gemma4_renderer.go
Normal file
File diff suppressed because it is too large
Load Diff
347
backend/go/dllm/gemma4_renderer_test.go
Normal file
347
backend/go/dllm/gemma4_renderer_test.go
Normal file
@@ -0,0 +1,347 @@
|
||||
package main
|
||||
|
||||
// Renderer specs for RenderGemma4 against the canonical gemma4 chat template
|
||||
// (see the normative template comment in gemma4_renderer.go).
|
||||
//
|
||||
// Fixture provenance:
|
||||
// - "single user message" and "enable_thinking" are the EXACT expected
|
||||
// decodes from transformers tests/models/diffusion_gemma/
|
||||
// test_modeling_diffusion_gemma.py (test_diffusion_gemma_chat_template
|
||||
// and ..._with_thinking) with ONE difference: the transformers fixtures
|
||||
// start with "<bos>" because apply_chat_template tokenizes the rendered
|
||||
// text with add_bos. Our prompt goes through dllm_capi_generate, whose
|
||||
// run_generate already tokenizes with prepend_bos = vocab.add_bos
|
||||
// (dllm.cpp src/capi.cpp:230-231, true for gemma4), so the renderer must
|
||||
// NOT emit a literal <bos> (it would double) and every expected string
|
||||
// here drops that leading token.
|
||||
// - All other expected strings were produced by rendering the verbatim
|
||||
// GGUF template with jinja2 3.1.2 (bos_token="<bos>") and dropping the
|
||||
// leading "<bos>" for the same reason.
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// Two-function tools array used by the tool fixtures (OpenAI wire shape, as
|
||||
// LocalAI passes it through PredictOptions.Tools).
|
||||
const testToolsJSON = `[{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a location.","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city name."},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}},{"type":"function","function":{"name":"get_time","description":"Get the current time in a timezone.","parameters":{"type":"object","properties":{"timezone":{"type":"string","description":"IANA timezone name."}},"required":["timezone"]}}}]`
|
||||
|
||||
// The <|tool>...<tool|> block the template renders for testToolsJSON inside
|
||||
// the system turn (jinja2-verified).
|
||||
const testToolsBlock = `<|tool>declaration:get_weather{description:<|"|>Get the current weather in a location.<|"|>,parameters:{properties:{location:{description:<|"|>The city name.<|"|>,type:<|"|>STRING<|"|>},unit:{enum:[<|"|>celsius<|"|>,<|"|>fahrenheit<|"|>],type:<|"|>STRING<|"|>}},required:[<|"|>location<|"|>],type:<|"|>OBJECT<|"|>}}<tool|><|tool>declaration:get_time{description:<|"|>Get the current time in a timezone.<|"|>,parameters:{properties:{timezone:{description:<|"|>IANA timezone name.<|"|>,type:<|"|>STRING<|"|>}},required:[<|"|>timezone<|"|>],type:<|"|>OBJECT<|"|>}}<tool|>`
|
||||
|
||||
// A single tool exercising the deep format_parameters branches: array items
|
||||
// (string-typed and nested-array), nullable, enum+nullable, nested object
|
||||
// properties/required, and a response declaration.
|
||||
const complexToolsJSON = `[{"type":"function","function":{"name":"complex_tool","description":"A complex tool.","parameters":{"type":"object","properties":{"tags":{"type":"array","description":"Tags.","items":{"type":"string"}},"matrix":{"type":"array","items":{"type":"array","items":{"type":"number"}}},"opts":{"type":"object","description":"Options.","properties":{"depth":{"type":"integer","nullable":true}},"required":["depth"]},"mode":{"type":"string","enum":["a","b"],"nullable":true}},"required":["tags","opts"]},"response":{"description":"The result.","type":"object"}}}]`
|
||||
|
||||
// jinja2-verified render of complexToolsJSON. Notable template quirks pinned
|
||||
// here: nested array items go through format_argument with ESCAPED keys and
|
||||
// an un-uppercased type (<|"|>type<|"|>:<|"|>number<|"|>), while direct item
|
||||
// types are uppercased; properties dictsort case-insensitively.
|
||||
const complexToolsBlock = `<|tool>declaration:complex_tool{description:<|"|>A complex tool.<|"|>,parameters:{properties:{matrix:{items:{items:{<|"|>type<|"|>:<|"|>number<|"|>},type:<|"|>ARRAY<|"|>},type:<|"|>ARRAY<|"|>},mode:{enum:[<|"|>a<|"|>,<|"|>b<|"|>],nullable:true,type:<|"|>STRING<|"|>},opts:{description:<|"|>Options.<|"|>,properties:{depth:{nullable:true,type:<|"|>INTEGER<|"|>}},required:[<|"|>depth<|"|>],type:<|"|>OBJECT<|"|>},tags:{description:<|"|>Tags.<|"|>,items:{type:<|"|>STRING<|"|>},type:<|"|>ARRAY<|"|>}},required:[<|"|>tags<|"|>,<|"|>opts<|"|>],type:<|"|>OBJECT<|"|>},response:{description:<|"|>The result.<|"|>,type:<|"|>OBJECT<|"|>}}<tool|>`
|
||||
|
||||
type renderGemma4Case struct {
|
||||
msgs []*pb.Message
|
||||
toolsJSON string
|
||||
enableThinking bool
|
||||
noGenerationPrompt bool // inverted so the zero value is the common case
|
||||
expected string
|
||||
}
|
||||
|
||||
var _ = Describe("RenderGemma4", func() {
|
||||
DescribeTable("renders the canonical gemma4 prompt",
|
||||
func(c renderGemma4Case) {
|
||||
out, err := RenderGemma4(c.msgs, c.toolsJSON, c.enableThinking, !c.noGenerationPrompt)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out).To(Equal(c.expected))
|
||||
// The C-ABI generate prepends BOS itself: a literal <bos>
|
||||
// anywhere in the rendered prompt would double-encode it.
|
||||
Expect(out).ToNot(ContainSubstring("<bos>"))
|
||||
},
|
||||
|
||||
// transformers fixture (test_diffusion_gemma_chat_template), sans <bos>:
|
||||
// default thinking pre-opens an EMPTY thought channel in the
|
||||
// generation prompt.
|
||||
Entry("single user message, default (no thinking)", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "Write a long essay about Portugal."},
|
||||
},
|
||||
expected: "<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// transformers fixture (test_diffusion_gemma_chat_template_with_thinking),
|
||||
// sans <bos>: a system turn carrying <|think|> and NO auto-opened
|
||||
// thought channel.
|
||||
Entry("enable_thinking=true", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "Write a long essay about Portugal."},
|
||||
},
|
||||
enableThinking: true,
|
||||
expected: "<|turn>system\n<|think|>\n<turn|>\n<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n",
|
||||
}),
|
||||
|
||||
Entry("multi-turn user/assistant/user", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "Hello, who are you?"},
|
||||
{Role: "assistant", Content: "I am Gemma, a helpful assistant."},
|
||||
{Role: "user", Content: "Tell me a joke."},
|
||||
},
|
||||
expected: "<|turn>user\nHello, who are you?<turn|>\n<|turn>model\nI am Gemma, a helpful assistant.<turn|>\n<|turn>user\nTell me a joke.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// tpl L178-L195: a leading system message is folded into the system
|
||||
// turn (trimmed) and consumed from the loop.
|
||||
Entry("system message folds into the system turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "system", Content: "You are a pirate."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|turn>system\nYou are a pirate.<turn|>\n<|turn>user\nHello!<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// tpl L182-L185: <|think|> goes at the very top of the SAME system
|
||||
// turn, before the system prompt text.
|
||||
Entry("system message with enable_thinking shares the turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "system", Content: "You are a pirate."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
enableThinking: true,
|
||||
expected: "<|turn>system\n<|think|>\nYou are a pirate.<turn|>\n<|turn>user\nHello!<turn|>\n<|turn>model\n",
|
||||
}),
|
||||
|
||||
// tpl L196-L203: tool declarations render in the system turn, one
|
||||
// <|tool>declaration:...<tool|> block per tool, no separators.
|
||||
Entry("tools array (two functions)", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "What is the weather in Tokyo?"},
|
||||
},
|
||||
toolsJSON: testToolsJSON,
|
||||
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// format_parameters deep branches (tpl L1-L85) + response declaration
|
||||
// (tpl L106-L116).
|
||||
Entry("complex tool schema (array items, nullable, nested object, response)", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
},
|
||||
toolsJSON: complexToolsJSON,
|
||||
expected: "<|turn>system\n" + complexToolsBlock + "<turn|>\n<|turn>user\ngo<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// tpl L243-L313: assistant tool_calls render as
|
||||
// <|tool_call>call:name{args}<tool_call|>; the following role=tool
|
||||
// message renders inline as <|tool_response>response:name{value:..}
|
||||
// <tool_response|>; the model turn stays OPEN (no <turn|>, no new
|
||||
// generation prompt) so the model continues after the response.
|
||||
Entry("assistant tool_calls + role=tool result", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "What is the weather in Tokyo?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`},
|
||||
{Role: "tool", ToolCallId: "call_1", Content: "Sunny, 22 degrees celsius."},
|
||||
},
|
||||
toolsJSON: testToolsJSON,
|
||||
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny, 22 degrees celsius.<|"|>}<tool_response|>`,
|
||||
}),
|
||||
|
||||
// tpl L348-L349: a tool_calls turn with no rendered responses ends
|
||||
// on an OPEN <|tool_response> marker for the runtime to fill, and
|
||||
// add_generation_prompt adds nothing (tpl L357).
|
||||
Entry("assistant tool_calls without a result leaves <|tool_response> open", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "What is the weather in Tokyo?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`},
|
||||
},
|
||||
toolsJSON: testToolsJSON,
|
||||
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>`,
|
||||
}),
|
||||
|
||||
// tpl L237-L241: reasoning_content renders as a thought channel only
|
||||
// on a tool-calling turn after the last user message.
|
||||
Entry("reasoning_content with tool_calls renders the thought channel", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "weather?"},
|
||||
{Role: "assistant", Content: "", ReasoningContent: "I should call the tool", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`},
|
||||
{Role: "tool", ToolCallId: "c1", Content: "Sunny"},
|
||||
},
|
||||
expected: "<|turn>user\nweather?<turn|>\n<|turn>model\n<|channel>thought\nI should call the tool\n<channel|>" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}<tool_response|>`,
|
||||
}),
|
||||
|
||||
// tpl L220-L235: the assistant answer following its own tool round
|
||||
// continues the SAME model turn (no second <|turn>model).
|
||||
Entry("tool round then final assistant answer then user", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "weather?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`},
|
||||
{Role: "tool", ToolCallId: "c1", Content: "Sunny"},
|
||||
{Role: "assistant", Content: "It is sunny."},
|
||||
{Role: "user", Content: "thanks"},
|
||||
},
|
||||
expected: "<|turn>user\nweather?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}<tool_response|>` + "It is sunny.<turn|>\n<|turn>user\nthanks<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// format_argument (tpl L118-L147): numbers keep their JSON literal,
|
||||
// booleans lower-case, nested maps have unquoted dictsorted keys,
|
||||
// arrays bracketed; top-level args are dictsorted case-insensitively.
|
||||
Entry("tool_call argument types (number/bool/nested/array)", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"count\":42,\"ratio\":3.5,\"flag\":true,\"off\":false,\"nested\":{\"x\":\"y\",\"n\":7},\"list\":[\"a\",1,true]}"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n" + `<|tool_call>call:f{count:42,flag:true,list:[<|"|>a<|"|>,1,true],nested:{n:7,x:<|"|>y<|"|>},off:false,ratio:3.5}<tool_call|><|tool_response>`,
|
||||
}),
|
||||
|
||||
// jinja dictsort is case-insensitive: alpha sorts before Beta.
|
||||
Entry("tool_call argument dictsort is case-insensitive", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"Beta\":1,\"alpha\":2}"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{alpha:2,Beta:1}<tool_call|><|tool_response>",
|
||||
}),
|
||||
|
||||
// jinja renders Python None as "None" (round-trips through vLLM's
|
||||
// parser, which lowers "none" back to null).
|
||||
Entry("tool_call null argument renders as None", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"maybe\":null}"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{maybe:None}<tool_call|><|tool_response>",
|
||||
}),
|
||||
|
||||
Entry("tool_call empty arguments render empty braces", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{}<tool_call|><|tool_response>",
|
||||
}),
|
||||
|
||||
// tpl L253-L254: a non-object arguments string renders verbatim.
|
||||
Entry("tool_call non-object string arguments render verbatim", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"just text"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{just text}<tool_call|><|tool_response>",
|
||||
}),
|
||||
|
||||
// tpl L278-L285: unmatched tool_call_id falls back to the tool
|
||||
// message's own name.
|
||||
Entry("tool result name falls back when tool_call_id does not match", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`},
|
||||
{Role: "tool", ToolCallId: "OTHER", Name: "named_tool", Content: "out"},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n" + `<|tool_call>call:f{}<tool_call|><|tool_response>response:named_tool{value:<|"|>out<|"|>}<tool_response|>`,
|
||||
}),
|
||||
|
||||
// strip_thinking (tpl L148-L158): historical assistant content loses
|
||||
// its <|channel>...<channel|> spans.
|
||||
Entry("assistant content thinking channels are stripped", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "<|channel>thought\nsecret\n<channel|>visible answer"},
|
||||
{Role: "user", Content: "more"},
|
||||
},
|
||||
expected: "<|turn>user\nhi<turn|>\n<|turn>model\nvisible answer<turn|>\n<|turn>user\nmore<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// tpl L220-L235: consecutive assistant messages suppress the second
|
||||
// <|turn>model (continuation), but each still closes with <turn|>.
|
||||
Entry("consecutive assistant messages continue the model turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "part one"},
|
||||
{Role: "assistant", Content: "part two"},
|
||||
{Role: "user", Content: "ok"},
|
||||
},
|
||||
expected: "<|turn>user\nhi<turn|>\n<|turn>model\npart one<turn|>\npart two<turn|>\n<|turn>user\nok<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
Entry("add_generation_prompt=false renders no model turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
},
|
||||
noGenerationPrompt: true,
|
||||
expected: "<|turn>user\nhi<turn|>\n",
|
||||
}),
|
||||
)
|
||||
|
||||
Describe("error handling", func() {
|
||||
It("fails loud on an unknown role", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "narrator", Content: "Meanwhile..."},
|
||||
}, "", false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(`unknown role "narrator"`))
|
||||
})
|
||||
|
||||
It("fails on invalid tools JSON", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
}, "{not json", false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tools JSON"))
|
||||
})
|
||||
|
||||
It("fails on invalid tool_calls JSON", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "", ToolCalls: "{not json"},
|
||||
}, "", false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tool_calls JSON"))
|
||||
})
|
||||
|
||||
It("fails on an orphan tool message, naming its index", func() {
|
||||
// A role:tool message with no preceding assistant tool_calls turn
|
||||
// would be silently dropped by the jinja; we fail loud instead.
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "tool", Content: `{"temp": 20}`, ToolCallId: "call_1"},
|
||||
}, "", false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("orphan tool message 1"))
|
||||
})
|
||||
|
||||
It("fails on trailing garbage after the tools JSON array", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
}, "[] junk", false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tools JSON"))
|
||||
})
|
||||
|
||||
It("fails when the tools JSON is not an array", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
}, `{"type":"function"}`, false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tools JSON is not an array"))
|
||||
})
|
||||
|
||||
It("fails when a tools array element is not an object", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
}, `[42]`, false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tools[0] is not an object"))
|
||||
})
|
||||
|
||||
It("rejects a nil message via the unknown-role check", func() {
|
||||
// Pins current behavior: pb getters are nil-safe, so a nil message
|
||||
// reads as role "" and trips the fail-loud unknown-role guard.
|
||||
_, err := RenderGemma4([]*pb.Message{nil}, "", false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(`unknown role "" in message 0`))
|
||||
})
|
||||
})
|
||||
})
|
||||
85
backend/go/dllm/main.go
Normal file
85
backend/go/dllm/main.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package main
|
||||
|
||||
// Started internally by LocalAI - one gRPC server per loaded model.
|
||||
//
|
||||
// Loads libdllm.so via purego and registers the 9-symbol flat C-ABI
|
||||
// declared in dllm.cpp's include/dllm_capi.h (ABI v1). The library name can
|
||||
// be overridden with DLLM_LIBRARY (mirrors the PARAKEET_LIBRARY /
|
||||
// WHISPER_LIBRARY convention in the sibling backends); the default looks
|
||||
// for the .so next to this binary (run.sh puts the package dir on
|
||||
// LD_LIBRARY_PATH).
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
// loadCAPI dlopens libName and binds the 9 dllm_capi_* entry points 1:1 to
|
||||
// dllm_capi.h, so an `nm libdllm.so | grep dllm_capi` is enough to spot
|
||||
// drift. Shared with the test suite (ensureLibLoaded), which drives the
|
||||
// bridge without the gRPC server.
|
||||
//
|
||||
// The C-ABI returns malloc'd char* buffers from tokenize_json/generate; we
|
||||
// register those as uintptr so we get the raw pointer back and can call
|
||||
// dllm_capi_free_string on it (purego's string return would copy and forget
|
||||
// the original pointer, leaking it on every call). last_error returns a
|
||||
// BORROWED pointer instead, so it is registered as a plain string: purego
|
||||
// copies it and nothing must be freed.
|
||||
func loadCAPI(libName string) error {
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dllm: dlopen %q: %w", libName, err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&cppAbiVersion, "dllm_capi_abi_version"},
|
||||
{&cppLoad, "dllm_capi_load"},
|
||||
{&cppFree, "dllm_capi_free"},
|
||||
{&cppLastError, "dllm_capi_last_error"},
|
||||
{&cppFreeString, "dllm_capi_free_string"},
|
||||
{&cppTokenizeJSON, "dllm_capi_tokenize_json"},
|
||||
{&cppGenerate, "dllm_capi_generate"},
|
||||
{&cppGenerateStream, "dllm_capi_generate_stream"},
|
||||
{&cppCancel, "dllm_capi_cancel"},
|
||||
}
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("DLLM_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libdllm.so"
|
||||
}
|
||||
|
||||
if err := loadCAPI(libName); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Hard-fail on an ABI mismatch: the flat-pointer bindings above would
|
||||
// otherwise misbehave silently against a future libdllm.so.
|
||||
if v := cAbiVersion(); v != dllmABIVersion {
|
||||
panic(fmt.Errorf("dllm: libdllm.so ABI=%d, this backend speaks ABI=%d", v, dllmABIVersion))
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "[dllm] ABI=%d\n", cAbiVersion())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &Dllm{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
24
backend/go/dllm/package.sh
Executable file
24
backend/go/dllm/package.sh
Executable file
@@ -0,0 +1,24 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# T1 packaging stub: copy the binary, run.sh and libdllm.so into package/.
|
||||
# The full ldd walk (libc, libstdc++, libgomp, GPU runtimes, arch
|
||||
# detection) lands with the registration task, mirroring
|
||||
# backend/go/whisper/package.sh.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
|
||||
cp -avf "$CURDIR/dllm-grpc" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
# libdllm.so + any soname symlinks, should upstream ever add them.
|
||||
cp -avf "$CURDIR"/libdllm.so* "$CURDIR/package/lib/" 2>/dev/null || {
|
||||
echo "ERROR: libdllm.so not found in $CURDIR, run 'make' first" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
echo "T1 package layout (full ldd walk lands with registration):"
|
||||
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"
|
||||
16
backend/go/dllm/run.sh
Executable file
16
backend/go/dllm/run.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
|
||||
|
||||
# If a self-contained ld.so was packaged, route through it so the
|
||||
# packaged libc / libstdc++ are used instead of the host's (matches the
|
||||
# whisper / parakeet-cpp backends' runtime layout).
|
||||
if [ -f "$CURDIR/lib/ld.so" ]; then
|
||||
echo "Using lib/ld.so"
|
||||
exec "$CURDIR/lib/ld.so" "$CURDIR/dllm-grpc" "$@"
|
||||
fi
|
||||
|
||||
exec "$CURDIR/dllm-grpc" "$@"
|
||||
@@ -9,7 +9,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
# LocalVQE upstream version pin. Bump to a specific commit when picking up
|
||||
# a new release; `main` works for development but is not reproducible.
|
||||
LOCALVQE_REPO?=https://github.com/localai-org/LocalVQE
|
||||
LOCALVQE_VERSION?=72bfb4c6
|
||||
LOCALVQE_VERSION?=b0f0378a450e87c871b85689554801601ca56d98
|
||||
|
||||
# LocalVQE handles CPU feature selection internally (it ships the multiple
|
||||
# libggml-cpu-*.so variants and its loader picks the best one at runtime
|
||||
@@ -27,7 +27,8 @@ endif
|
||||
|
||||
# LocalVQE upstream supports CPU + Vulkan only. Other BUILD_TYPE values
|
||||
# fall through to the default CPU build — Vulkan is already as fast as the
|
||||
# specialised GPU paths would be on this 1.3 M-parameter model.
|
||||
# specialised GPU paths would be on these small (1.3 M–4.8 M parameter)
|
||||
# models.
|
||||
ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON -DLOCALVQE_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
|
||||
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -11,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -46,24 +46,24 @@ const (
|
||||
// through the options builder (CppOptionsNew + setters + CppNewWithOptions)
|
||||
// — the bare localvqe_new path doesn't expose backend / device selection.
|
||||
var (
|
||||
CppOptionsNew func() uintptr
|
||||
CppOptionsFree func(opts uintptr)
|
||||
CppOptionsSetModelPath func(opts uintptr, modelPath string) int32
|
||||
CppOptionsSetBackend func(opts uintptr, backend string) int32
|
||||
CppOptionsSetDevice func(opts uintptr, device int32) int32
|
||||
CppNewWithOptions func(opts uintptr) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppProcessF32 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessS16 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessFrameF32 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppProcessFrameS16 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppReset func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppSampleRate func(ctx uintptr) int32
|
||||
CppHopLength func(ctx uintptr) int32
|
||||
CppFFTSize func(ctx uintptr) int32
|
||||
CppSetNoiseGate func(ctx uintptr, enabled int32, thresholdDBFS float32) int32
|
||||
CppGetNoiseGate func(ctx uintptr, enabledOut, thresholdDBFSOut uintptr) int32
|
||||
CppOptionsNew func() uintptr
|
||||
CppOptionsFree func(opts uintptr)
|
||||
CppOptionsSetModelPath func(opts uintptr, modelPath string) int32
|
||||
CppOptionsSetBackend func(opts uintptr, backend string) int32
|
||||
CppOptionsSetDevice func(opts uintptr, device int32) int32
|
||||
CppNewWithOptions func(opts uintptr) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppProcessF32 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessS16 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessFrameF32 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppProcessFrameS16 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppReset func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppSampleRate func(ctx uintptr) int32
|
||||
CppHopLength func(ctx uintptr) int32
|
||||
CppFFTSize func(ctx uintptr) int32
|
||||
CppSetNoiseGate func(ctx uintptr, enabled int32, thresholdDBFS float32) int32
|
||||
CppGetNoiseGate func(ctx uintptr, enabledOut, thresholdDBFSOut uintptr) int32
|
||||
)
|
||||
|
||||
// LocalVQE speaks gRPC against LocalVQE's flat C ABI. The streaming
|
||||
@@ -490,11 +490,14 @@ func (v *LocalVQE) applyStreamConfig(cfg *pb.AudioTransformStreamConfig) error {
|
||||
|
||||
// ---- WAV I/O ----------------------------------------------------------
|
||||
//
|
||||
// Minimal mono PCM WAV reader/writer. Only handles the subset LocalVQE
|
||||
// cares about (mono, 16-bit signed, no extensible chunks). For broader
|
||||
// audio support the HTTP layer's `audio.NormalizeAudioFile` already
|
||||
// converts arbitrary input to a canonical WAV before we see it; this
|
||||
// reader just decodes the canonical shape.
|
||||
// Reader/writer for the mono 16-bit PCM shape LocalVQE works with. Decoding
|
||||
// goes through the shared go-audio/wav decoder (as the whisper and parakeet
|
||||
// backends do) so RIFF chunk walking is handled robustly — an 18/40-byte
|
||||
// extensible `fmt ` chunk, or JUNK/bext/LIST metadata before or after `data`
|
||||
// (e.g. ffmpeg's trailing "Lavf" tag), is skipped rather than spliced into
|
||||
// the PCM stream as an audible click. The HTTP layer normalises arbitrary
|
||||
// input to WAV before we see it, but that WAV is ffmpeg output and is not
|
||||
// guaranteed to be the canonical 44-byte layout.
|
||||
|
||||
func readMonoWAVf32(path string) ([]float32, int, error) {
|
||||
f, err := os.Open(path)
|
||||
@@ -502,35 +505,26 @@ func readMonoWAVf32(path string) ([]float32, int, error) {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
header := make([]byte, 44)
|
||||
if _, err := io.ReadFull(f, header); err != nil {
|
||||
return nil, 0, err
|
||||
|
||||
buf, err := wav.NewDecoder(f).FullPCMBuffer()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("decode WAV: %w", err)
|
||||
}
|
||||
if string(header[0:4]) != "RIFF" || string(header[8:12]) != "WAVE" {
|
||||
if buf == nil || buf.Format == nil {
|
||||
return nil, 0, fmt.Errorf("not a WAV file")
|
||||
}
|
||||
channels := binary.LittleEndian.Uint16(header[22:24])
|
||||
sampleRate := binary.LittleEndian.Uint32(header[24:28])
|
||||
bitsPerSample := binary.LittleEndian.Uint16(header[34:36])
|
||||
|
||||
if channels != 1 {
|
||||
return nil, 0, fmt.Errorf("only mono WAV supported (got %d channels)", channels)
|
||||
if buf.Format.NumChannels != 1 {
|
||||
return nil, 0, fmt.Errorf("only mono WAV supported (got %d channels)", buf.Format.NumChannels)
|
||||
}
|
||||
if bitsPerSample != 16 {
|
||||
return nil, 0, fmt.Errorf("only 16-bit PCM supported (got %d bits)", bitsPerSample)
|
||||
if buf.SourceBitDepth != 16 {
|
||||
return nil, 0, fmt.Errorf("only 16-bit PCM supported (got %d bits)", buf.SourceBitDepth)
|
||||
}
|
||||
|
||||
rest, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
if len(buf.Data) == 0 {
|
||||
return nil, 0, fmt.Errorf("WAV has no audio data")
|
||||
}
|
||||
n := len(rest) / 2
|
||||
out := make([]float32, n)
|
||||
for i := 0; i < n; i++ {
|
||||
s := int16(binary.LittleEndian.Uint16(rest[i*2 : i*2+2]))
|
||||
out[i] = float32(s) / 32768.0
|
||||
}
|
||||
return out, int(sampleRate), nil
|
||||
// AsFloat32Buffer normalises by 2^(bitDepth-1) == /32768 for 16-bit,
|
||||
// matching the model's expected [-1, 1) input range.
|
||||
return buf.AsFloat32Buffer().Data, buf.Format.SampleRate, nil
|
||||
}
|
||||
|
||||
func writeMonoWAVf32(path string, samples []float32, sampleRate int) error {
|
||||
@@ -546,13 +540,13 @@ func writeMonoWAVf32(path string, samples []float32, sampleRate int) error {
|
||||
binary.LittleEndian.PutUint32(header[4:8], 36+dataLen)
|
||||
copy(header[8:12], []byte("WAVE"))
|
||||
copy(header[12:16], []byte("fmt "))
|
||||
binary.LittleEndian.PutUint32(header[16:20], 16) // fmt chunk size
|
||||
binary.LittleEndian.PutUint16(header[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(header[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(header[16:20], 16) // fmt chunk size
|
||||
binary.LittleEndian.PutUint16(header[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(header[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(header[24:28], uint32(sampleRate))
|
||||
binary.LittleEndian.PutUint32(header[28:32], uint32(sampleRate*2)) // byte rate
|
||||
binary.LittleEndian.PutUint16(header[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(header[34:36], 16) // bits per sample
|
||||
binary.LittleEndian.PutUint16(header[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(header[34:36], 16) // bits per sample
|
||||
copy(header[36:40], []byte("data"))
|
||||
binary.LittleEndian.PutUint32(header[40:44], dataLen)
|
||||
if _, err := f.Write(header); err != nil {
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -92,6 +94,147 @@ var _ = Describe("LocalVQE-cpp", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("readMonoWAVf32 chunk parsing", func() {
|
||||
// chunk builds a word-aligned RIFF sub-chunk (id + size + body + pad).
|
||||
chunk := func(id string, body []byte) []byte {
|
||||
out := append([]byte(id), 0, 0, 0, 0)
|
||||
binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
|
||||
out = append(out, body...)
|
||||
if len(body)&1 == 1 {
|
||||
out = append(out, 0) // pad byte for odd-sized chunks
|
||||
}
|
||||
return out
|
||||
}
|
||||
// fmtBody returns a PCM `fmt ` chunk body. extra bytes simulate the
|
||||
// 18/40-byte extensible form (cbSize + extension).
|
||||
fmtBody := func(channels, bits uint16, rate uint32, extra int) []byte {
|
||||
b := make([]byte, 16+extra)
|
||||
binary.LittleEndian.PutUint16(b[0:2], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(b[2:4], channels)
|
||||
binary.LittleEndian.PutUint32(b[4:8], rate)
|
||||
binary.LittleEndian.PutUint32(b[8:12], rate*uint32(channels)*uint32(bits)/8)
|
||||
binary.LittleEndian.PutUint16(b[12:14], channels*bits/8)
|
||||
binary.LittleEndian.PutUint16(b[14:16], bits)
|
||||
if extra >= 2 {
|
||||
binary.LittleEndian.PutUint16(b[16:18], uint16(extra-2)) // cbSize
|
||||
}
|
||||
return b
|
||||
}
|
||||
// pcm encodes int16 samples little-endian.
|
||||
pcm := func(samples ...int16) []byte {
|
||||
b := make([]byte, len(samples)*2)
|
||||
for i, s := range samples {
|
||||
binary.LittleEndian.PutUint16(b[i*2:i*2+2], uint16(s))
|
||||
}
|
||||
return b
|
||||
}
|
||||
riff := func(chunks ...[]byte) []byte {
|
||||
body := []byte("WAVE")
|
||||
for _, c := range chunks {
|
||||
body = append(body, c...)
|
||||
}
|
||||
out := append([]byte("RIFF"), 0, 0, 0, 0)
|
||||
binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
|
||||
return append(out, body...)
|
||||
}
|
||||
writeWAV := func(b []byte) string {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "in.wav")
|
||||
Expect(os.WriteFile(p, b, 0o600)).To(Succeed())
|
||||
return p
|
||||
}
|
||||
// A canonical sample run with distinct values so any off-by-one /
|
||||
// misalignment shows up as wrong numbers, not just wrong length.
|
||||
samples := []int16{1000, -2000, 3000, -4000, 5000, -6000}
|
||||
expectSamples := func(got []float32) {
|
||||
Expect(got).To(HaveLen(len(samples)))
|
||||
for i, s := range samples {
|
||||
Expect(got[i]).To(BeNumerically("~", float32(s)/32768.0, 1e-6))
|
||||
}
|
||||
}
|
||||
|
||||
It("reads a canonical 44-byte WAV", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out)
|
||||
})
|
||||
|
||||
It("ignores a LIST/JUNK chunk placed before data (no leading-impulse splice)", func() {
|
||||
p := writeWAV(riff(
|
||||
chunk("fmt ", fmtBody(1, 16, 16000, 0)),
|
||||
chunk("JUNK", []byte("padding-bytes-here!")), // odd length → exercises pad
|
||||
chunk("LIST", []byte("INFOISFTLavf60.0")),
|
||||
chunk("data", pcm(samples...)),
|
||||
))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out) // not corrupted by the preceding chunks
|
||||
})
|
||||
|
||||
It("honours the data chunk size and drops a trailing metadata chunk", func() {
|
||||
p := writeWAV(riff(
|
||||
chunk("fmt ", fmtBody(1, 16, 16000, 0)),
|
||||
chunk("data", pcm(samples...)),
|
||||
chunk("LIST", []byte("INFOISFTLavf60.16.100")), // ffmpeg trailer tag
|
||||
))
|
||||
out, _, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectSamples(out) // trailing LIST bytes not decoded as PCM
|
||||
})
|
||||
|
||||
It("handles the 18-byte extensible fmt chunk", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 2)), chunk("data", pcm(samples...))))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out)
|
||||
})
|
||||
|
||||
It("rejects non-mono input", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(2, 16, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("mono"))
|
||||
})
|
||||
|
||||
It("rejects non-16-bit input", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 8, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("16-bit"))
|
||||
})
|
||||
|
||||
It("rejects a non-WAV file", func() {
|
||||
p := writeWAV([]byte("not a riff file at all"))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors when the data chunk is missing", func() {
|
||||
// fmt but no data: the decoder must fail rather than return an
|
||||
// empty (or garbage) sample slice. The exact message is the
|
||||
// decoder's, so just assert it errors.
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 0))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("round-trips through writeMonoWAVf32", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "rt.wav")
|
||||
in := []float32{0.1, -0.2, 0.3, -0.4}
|
||||
Expect(writeMonoWAVf32(p, in, 16000)).To(Succeed())
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
Expect(out).To(HaveLen(len(in)))
|
||||
for i := range in {
|
||||
Expect(out[i]).To(BeNumerically("~", in[i], 1e-4))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("model-gated integration (LOCALVQE_MODEL_PATH)", func() {
|
||||
It("load + sample rate + hop + fft", func() {
|
||||
path := modelPathOrSkip()
|
||||
|
||||
11
backend/go/parakeet-cpp/.gitignore
vendored
Normal file
11
backend/go/parakeet-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
.cache/
|
||||
sources/
|
||||
build/
|
||||
package/
|
||||
parakeet-cpp-grpc
|
||||
# build artifacts staged in-tree by the Makefile (cp from sources/) or
|
||||
# symlinked for local dev; the real sources live in parakeet.cpp upstream.
|
||||
*.so
|
||||
*.so.*
|
||||
parakeet_capi.h
|
||||
compile_commands.json
|
||||
93
backend/go/parakeet-cpp/Makefile
Normal file
93
backend/go/parakeet-cpp/Makefile
Normal file
@@ -0,0 +1,93 @@
|
||||
# parakeet-cpp backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# whisper.cpp / ds4 / vibevoice-cpp convention.
|
||||
#
|
||||
# Local dev shortcut: if you already have an out-of-tree parakeet.cpp
|
||||
# build, you can symlink the .so + header into this directory and skip
|
||||
# the clone/cmake steps entirely, e.g.:
|
||||
#
|
||||
# ln -sf /path/to/parakeet.cpp/build-shared/libparakeet.so .
|
||||
# ln -sf /path/to/parakeet.cpp/include/parakeet_capi.h .
|
||||
# go build -o parakeet-cpp-grpc .
|
||||
#
|
||||
# That's what the L0 smoke test uses. The default target below does the
|
||||
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
|
||||
|
||||
PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
|
||||
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
||||
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
# Build ggml statically into libparakeet.so (PIC) so the shared lib is
|
||||
# self-contained: dlopen needs no libggml*.so alongside it, only system libs
|
||||
# (libstdc++/libgomp/libc) that the runtime image already provides.
|
||||
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DPARAKEET_SHARED=ON -DPARAKEET_BUILD_CLI=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# parakeet.cpp gates its GGML backends behind PARAKEET_GGML_* options and does
|
||||
# set(GGML_CUDA ${PARAKEET_GGML_CUDA} CACHE BOOL "" FORCE), so a bare -DGGML_CUDA=ON
|
||||
# is overwritten back to OFF and the build silently falls back to CPU. Forward the
|
||||
# PARAKEET_GGML_* options instead. (openblas is not gated, so -DGGML_BLAS passes through.)
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_HIP=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_VULKAN=ON
|
||||
endif
|
||||
|
||||
.PHONY: parakeet-cpp-grpc package build clean purge test all
|
||||
|
||||
all: parakeet-cpp-grpc
|
||||
|
||||
# Clone the upstream parakeet.cpp source at the pinned commit. Directory
|
||||
# acts as the target so make only re-clones when missing. After a
|
||||
# PARAKEET_VERSION bump, run 'make purge && make' to refetch.
|
||||
sources/parakeet.cpp:
|
||||
mkdir -p sources/parakeet.cpp
|
||||
cd sources/parakeet.cpp && \
|
||||
git init -q && \
|
||||
git remote add origin $(PARAKEET_REPO) && \
|
||||
git fetch --depth 1 origin $(PARAKEET_VERSION) && \
|
||||
git checkout FETCH_HEAD && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Build the shared lib + header out-of-tree, then stage them next to the
|
||||
# Go sources so purego.Dlopen("libparakeet.so") and the cgo-less build
|
||||
# both pick them up.
|
||||
libparakeet.so: sources/parakeet.cpp
|
||||
cmake -B sources/parakeet.cpp/build-shared -S sources/parakeet.cpp $(CMAKE_ARGS)
|
||||
cmake --build sources/parakeet.cpp/build-shared --config Release -j$(JOBS)
|
||||
cp -fv sources/parakeet.cpp/build-shared/libparakeet.so* ./ 2>/dev/null || true
|
||||
cp -fv sources/parakeet.cpp/include/parakeet_capi.h ./
|
||||
|
||||
parakeet-cpp-grpc: libparakeet.so main.go goparakeetcpp.go
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o parakeet-cpp-grpc .
|
||||
|
||||
package: parakeet-cpp-grpc
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
# Test target. Smoke test is gated on PARAKEET_BACKEND_TEST_MODEL +
|
||||
# PARAKEET_BACKEND_TEST_WAV; without them the spec auto-skips.
|
||||
test:
|
||||
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
|
||||
|
||||
clean: purge
|
||||
rm -rf libparakeet.so* parakeet_capi.h package parakeet-cpp-grpc
|
||||
|
||||
purge:
|
||||
rm -rf sources/parakeet.cpp
|
||||
105
backend/go/parakeet-cpp/batcher.go
Normal file
105
backend/go/parakeet-cpp/batcher.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package main
|
||||
|
||||
import "time"
|
||||
|
||||
// batchRequest is one in-flight unary transcription waiting to be batched.
|
||||
// In production pcm/decoder are set; tag is an opaque marker used by tests.
|
||||
type batchRequest struct {
|
||||
pcm []float32
|
||||
decoder int32
|
||||
// language is the per-request target locale ("" means the model default).
|
||||
// parakeet.cpp's batched C-API takes ONE target_lang for the whole batch,
|
||||
// so the dispatcher only coalesces requests that share a language.
|
||||
language string
|
||||
tag string
|
||||
reply chan batchReply
|
||||
}
|
||||
|
||||
// batchReply carries one per-item JSON object string (an element of the C-API's
|
||||
// JSON array) or an error back to the waiting handler goroutine.
|
||||
type batchReply struct {
|
||||
json string
|
||||
err error
|
||||
}
|
||||
|
||||
// batcher coalesces concurrent batchRequests into batched runBatch calls. A
|
||||
// single run() goroutine is the sole caller of runBatch, so runBatch (which in
|
||||
// production calls the thread-unsafe C engine) is never entered concurrently.
|
||||
type batcher struct {
|
||||
submit chan *batchRequest
|
||||
maxSize int
|
||||
maxWait time.Duration
|
||||
runBatch func(reqs []*batchRequest) // must deliver a reply to every req
|
||||
}
|
||||
|
||||
func newBatcher(maxSize int, maxWait time.Duration, runBatch func([]*batchRequest)) *batcher {
|
||||
if maxSize < 1 {
|
||||
maxSize = 1
|
||||
}
|
||||
return &batcher{
|
||||
submit: make(chan *batchRequest),
|
||||
maxSize: maxSize,
|
||||
maxWait: maxWait,
|
||||
runBatch: runBatch,
|
||||
}
|
||||
}
|
||||
|
||||
// run is the dispatcher loop: accumulate submitted requests until either maxSize
|
||||
// is reached or maxWait elapses since the first queued request, then dispatch.
|
||||
// Exits when stop is closed (draining any partially-filled batch first).
|
||||
//
|
||||
// A batch carries ONE language (parakeet.cpp's batched C-API takes a single
|
||||
// target_lang), so a request whose language differs from the batch leader is
|
||||
// not coalesced: it is held in carry and becomes the leader of the next batch.
|
||||
// carry is therefore never dropped and its caller never deadlocks: every batch
|
||||
// (including a lone carry on stop) is dispatched, and runBatch replies to all.
|
||||
func (b *batcher) run(stop <-chan struct{}) {
|
||||
var carry *batchRequest
|
||||
for {
|
||||
var first *batchRequest
|
||||
if carry != nil {
|
||||
// A mismatched request from the previous fill leads this batch.
|
||||
first, carry = carry, nil
|
||||
} else {
|
||||
select {
|
||||
case first = <-b.submit:
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
batch := []*batchRequest{first}
|
||||
|
||||
// maxSize==1 disables batching: dispatch immediately (passthrough).
|
||||
if b.maxSize == 1 {
|
||||
b.runBatch(batch)
|
||||
continue
|
||||
}
|
||||
|
||||
timer := time.NewTimer(b.maxWait)
|
||||
fill:
|
||||
for len(batch) < b.maxSize {
|
||||
select {
|
||||
case r := <-b.submit:
|
||||
if r.language != first.language {
|
||||
// Different language: carry it to the next batch so this
|
||||
// batch stays single-language, then dispatch what we have.
|
||||
carry = r
|
||||
break fill
|
||||
}
|
||||
batch = append(batch, r)
|
||||
case <-timer.C:
|
||||
break fill
|
||||
case <-stop:
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
// Don't strand a carried request's caller on shutdown.
|
||||
if carry != nil {
|
||||
b.runBatch([]*batchRequest{carry})
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
}
|
||||
}
|
||||
164
backend/go/parakeet-cpp/batcher_test.go
Normal file
164
backend/go/parakeet-cpp/batcher_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("batcher", func() {
|
||||
echoReply := func(reqs []*batchRequest) {
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{json: r.tag}
|
||||
}
|
||||
}
|
||||
|
||||
It("coalesces concurrent submits into batches", func() {
|
||||
var mu sync.Mutex
|
||||
var sizes []int
|
||||
run := func(reqs []*batchRequest) {
|
||||
mu.Lock()
|
||||
sizes = append(sizes, len(reqs))
|
||||
mu.Unlock()
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(4, 50*time.Millisecond, run)
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
|
||||
const N = 4
|
||||
var wg sync.WaitGroup
|
||||
got := make([]string, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: string(rune('a' + i)), reply: rep}
|
||||
got[i] = (<-rep).json
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
total, maxBatch := 0, 0
|
||||
for _, s := range sizes {
|
||||
total += s
|
||||
if s > maxBatch {
|
||||
maxBatch = s
|
||||
}
|
||||
}
|
||||
Expect(total).To(Equal(N))
|
||||
Expect(maxBatch).To(BeNumerically(">=", 2), "expected at least one batch to coalesce >1 request")
|
||||
})
|
||||
|
||||
It("dispatches when max size is reached", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(2, time.Hour, run) // huge window: only size can trigger
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
for i := 0; i < 2; i++ {
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func(rep chan batchReply) { <-rep }(rep)
|
||||
}
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(2)))
|
||||
})
|
||||
|
||||
It("dispatches when the wait window elapses", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(8, 20*time.Millisecond, run) // size unreachable; window fires
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
|
||||
It("bypasses batching when max size is 1", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(1, time.Hour, run) // size 1 => immediate dispatch
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
|
||||
It("never coalesces requests with different languages into one batch", func() {
|
||||
// parakeet.cpp's batched C-API takes ONE target_lang per batch, so the
|
||||
// dispatcher must keep every dispatched batch single-language. Submit a
|
||||
// mix of languages and assert (a) no batch ever carries more than one
|
||||
// distinct language and (b) every submitted request still gets a reply
|
||||
// (the mismatched carry-over is never dropped).
|
||||
var mu sync.Mutex
|
||||
var langsPerBatch [][]string
|
||||
run := func(reqs []*batchRequest) {
|
||||
seen := map[string]struct{}{}
|
||||
var distinct []string
|
||||
for _, r := range reqs {
|
||||
if _, ok := seen[r.language]; !ok {
|
||||
seen[r.language] = struct{}{}
|
||||
distinct = append(distinct, r.language)
|
||||
}
|
||||
}
|
||||
mu.Lock()
|
||||
langsPerBatch = append(langsPerBatch, distinct)
|
||||
mu.Unlock()
|
||||
echoReply(reqs)
|
||||
}
|
||||
// Large window + size so the fill loop stays open across submits and the
|
||||
// language constraint (not the timer) is what splits the batches.
|
||||
b := newBatcher(16, 200*time.Millisecond, run)
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
|
||||
langs := []string{"en", "en", "de", "de", "en", "fr", "fr"}
|
||||
const N = 7
|
||||
var wg sync.WaitGroup
|
||||
got := make([]string, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: string(rune('a' + i)), language: langs[i], reply: rep}
|
||||
got[i] = (<-rep).json
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
// Invariant: every dispatched batch is single-language.
|
||||
for _, distinct := range langsPerBatch {
|
||||
Expect(len(distinct)).To(Equal(1), "a batch coalesced more than one language: %v", distinct)
|
||||
}
|
||||
// Liveness: every request got a reply (carry-over never stranded).
|
||||
for i := 0; i < N; i++ {
|
||||
Expect(got[i]).To(Equal(string(rune('a' + i))))
|
||||
}
|
||||
})
|
||||
})
|
||||
826
backend/go/parakeet-cpp/goparakeetcpp.go
Normal file
826
backend/go/parakeet-cpp/goparakeetcpp.go
Normal file
@@ -0,0 +1,826 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// purego-bound entry points from libparakeet.so. Names match
|
||||
// parakeet_capi.h exactly so a `nm libparakeet.so | grep parakeet_capi`
|
||||
// is enough to spot drift.
|
||||
//
|
||||
// Functions that return char* are declared as uintptr so we can call
|
||||
// parakeet_capi_free_string on the same pointer after copying, the
|
||||
// C-API contract is "caller owns and must free the returned buffer".
|
||||
var (
|
||||
CppAbiVersion func() int32
|
||||
CppLoad func(ggufPath string) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppTranscribePath func(ctx uintptr, wavPath string, decoder int32) uintptr
|
||||
CppTranscribePathJSON func(ctx uintptr, wavPath string, decoder int32) uintptr
|
||||
CppFreeString func(s uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
|
||||
// Batched JSON transcription: takes a concatenated float buffer of clips
|
||||
// plus their per-clip sample counts (sum(nSamples)==len(samplesConcat))
|
||||
// and returns a malloc'd char* JSON ARRAY of per-clip {"text","words",
|
||||
// "tokens"} objects (uintptr, freed via CppFreeString). purego passes the
|
||||
// Go slices as the base pointer of their backing array (kept alive for the
|
||||
// call), matching the CppStreamFeed pcm []float32 binding pattern; the C
|
||||
// side reads them as const float*/const int*.
|
||||
CppTranscribePcmBatchJSON func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32) uintptr
|
||||
|
||||
// CppTranscribePcmBatchJSONLang is the multilingual variant of the batched
|
||||
// JSON entry point: identical, plus a trailing target_lang. "" (the model
|
||||
// default, "auto") is passed for non-prompt models, which ignore it; an
|
||||
// unknown locale on a prompt model returns 0 and sets last_error. Present
|
||||
// only in newer libparakeet.so; nil falls back to CppTranscribePcmBatchJSON.
|
||||
CppTranscribePcmBatchJSONLang func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32, targetLang string) uintptr
|
||||
|
||||
// Cache-aware streaming (RNN-T) entry points. stream_begin returns 0 for
|
||||
// non-streaming models. feed/finalize return a malloc'd char* (uintptr,
|
||||
// freed via CppFreeString); feed writes 1 to *eouOut on an <EOU>/<EOB>.
|
||||
CppStreamBegin func(ctx uintptr) uintptr
|
||||
CppStreamFeed func(s uintptr, pcm []float32, nSamples int32, eouOut unsafe.Pointer) uintptr
|
||||
CppStreamFinalize func(s uintptr) uintptr
|
||||
CppStreamFree func(s uintptr)
|
||||
|
||||
// CppStreamBeginLang is the multilingual variant of stream_begin: identical,
|
||||
// plus a trailing target_lang ("" means the model default). Present only in
|
||||
// newer libparakeet.so; nil falls back to CppStreamBegin.
|
||||
CppStreamBeginLang func(ctx uintptr, targetLang string) uintptr
|
||||
|
||||
// Streaming JSON variants (ABI v4): feed/finalize returning a malloc'd char*
|
||||
// JSON document {text,eou,frame_sec,words} (uintptr, freed via CppFreeString)
|
||||
// so streaming segments can carry per-word timestamps. Present only in newer
|
||||
// libparakeet.so; nil falls back to the text-only CppStreamFeed/Finalize path.
|
||||
CppStreamFeedJSON func(s uintptr, pcm []float32, nSamples int32) uintptr
|
||||
CppStreamFinalizeJSON func(s uintptr) uintptr
|
||||
)
|
||||
|
||||
// streamChunkSamples is how much 16 kHz mono PCM we hand to stream_feed per
|
||||
// call (1 s). The session buffers internally and decodes once a full
|
||||
// cache-aware encoder chunk is available, so this only bounds how often we
|
||||
// poll for newly-finalized text, not the model's actual chunk size.
|
||||
const streamChunkSamples = 16000
|
||||
|
||||
// transcriptJSON mirrors the document returned by
|
||||
// parakeet_capi_transcribe_path_json (see parakeet_capi.h):
|
||||
//
|
||||
// {"text":"...",
|
||||
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...],
|
||||
// "tokens":[{"id":123,"t":0.480,"conf":0.9100}, ...]}
|
||||
//
|
||||
// "start"/"end"/"t" are seconds; "conf" is confidence in (0,1].
|
||||
type transcriptJSON struct {
|
||||
Text string `json:"text"`
|
||||
FrameSec float64 `json:"frame_sec"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
Tokens []transcriptToken `json:"tokens"`
|
||||
}
|
||||
|
||||
// streamFeedJSON mirrors the document returned by
|
||||
// parakeet_capi_stream_feed_json / parakeet_capi_stream_finalize_json (ABI v4):
|
||||
//
|
||||
// {"text":"...","eou":0,"frame_sec":0.080000,
|
||||
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...]}
|
||||
//
|
||||
// "text" is the newly-finalized text since the last call; "eou" is 1 when an
|
||||
// <EOU>/<EOB> fired this feed; "words" are the words finalized this call with
|
||||
// absolute (stream-relative) start/end seconds.
|
||||
type streamFeedJSON struct {
|
||||
Text string `json:"text"`
|
||||
Eou int `json:"eou"`
|
||||
FrameSec float64 `json:"frame_sec"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
}
|
||||
|
||||
type transcriptWord struct {
|
||||
W string `json:"w"`
|
||||
Start float64 `json:"start"`
|
||||
End float64 `json:"end"`
|
||||
Conf float64 `json:"conf"`
|
||||
}
|
||||
|
||||
type transcriptToken struct {
|
||||
ID int32 `json:"id"`
|
||||
T float64 `json:"t"`
|
||||
Conf float64 `json:"conf"`
|
||||
}
|
||||
|
||||
// ParakeetCpp owns a single loaded parakeet_ctx. The C engine is a
|
||||
// thread-unsafe singleton (mirrors whisper.cpp / vibevoice.cpp). Rather than
|
||||
// serialize every call through base.SingleThread, we route unary
|
||||
// transcription through an in-process batcher (its sole dispatcher goroutine
|
||||
// is the only caller of the engine on that path) and guard the shared engine
|
||||
// with engineMu so a streaming session and a batched-unary dispatch never
|
||||
// touch it concurrently.
|
||||
type ParakeetCpp struct {
|
||||
base.Base
|
||||
ctxPtr uintptr
|
||||
engineMu sync.Mutex // sole guard of the one C engine (dispatcher + streaming)
|
||||
bat *batcher
|
||||
batStop chan struct{}
|
||||
// segmentGapFrames is NeMo's segment_gap_threshold in ENCODER FRAMES (model
|
||||
// YAML option, default 0=off). When >0 it adds NeMo's silence-gap split on
|
||||
// top of the punctuation split; converted to seconds via the JSON frame_sec.
|
||||
segmentGapFrames int
|
||||
}
|
||||
|
||||
// Load is the LocalAI gRPC entry point for LoadModel: it calls
|
||||
// parakeet_capi_load with the GGUF path and stashes the resulting
|
||||
// opaque context pointer for AudioTranscription.
|
||||
func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
|
||||
if opts.ModelFile == "" {
|
||||
return errors.New("parakeet-cpp: ModelFile is required")
|
||||
}
|
||||
|
||||
ctx := CppLoad(opts.ModelFile)
|
||||
if ctx == 0 {
|
||||
// No ctx to ask for last_error (the C-API's last-error buffer
|
||||
// lives on the ctx that was never returned). Surface the path
|
||||
// so the operator at least knows which load failed.
|
||||
return fmt.Errorf("parakeet-cpp: parakeet_capi_load failed for %q", opts.ModelFile)
|
||||
}
|
||||
p.ctxPtr = ctx
|
||||
|
||||
// Dynamic batching knobs (model YAML options:, key:value form). Batching is
|
||||
// OFF by default (batch_max_size:1): each request runs on its own. On GPU,
|
||||
// raising batch_max_size coalesces concurrent requests into one batched
|
||||
// engine call and improves throughput under load; leave it at 1 on CPU and
|
||||
// for low-concurrency setups, where batching only adds latency.
|
||||
maxSize := optInt(opts, "batch_max_size", 1)
|
||||
maxWaitMs := optInt(opts, "batch_max_wait_ms", 15)
|
||||
if maxWaitMs < 0 {
|
||||
maxWaitMs = 0
|
||||
}
|
||||
|
||||
// NeMo's segment_gap_threshold (encoder frames, default 0=off). Off by
|
||||
// default matches NeMo's default (punctuation-only segments); when set it
|
||||
// additionally splits segments on inter-word silence (see transcriptResultFromDoc).
|
||||
p.segmentGapFrames = optInt(opts, "segment_gap_threshold", 0)
|
||||
if CppTranscribePcmBatchJSON != nil {
|
||||
p.batStop = make(chan struct{})
|
||||
p.bat = newBatcher(maxSize, time.Duration(maxWaitMs)*time.Millisecond, p.runBatch)
|
||||
go p.bat.run(p.batStop) // dispatcher runs until Free closes batStop
|
||||
if maxSize > 1 {
|
||||
xlog.Info("parakeet-cpp: dynamic batching enabled",
|
||||
"batch_max_size", maxSize, "batch_max_wait_ms", maxWaitMs)
|
||||
} else {
|
||||
xlog.Info("parakeet-cpp: dynamic batching off (batch_max_size=1); " +
|
||||
"set batch_max_size>1 to coalesce concurrent requests on GPU")
|
||||
}
|
||||
} else {
|
||||
xlog.Info("parakeet-cpp: batched C-API not present in libparakeet.so; " +
|
||||
"batching disabled, using per-request transcription")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// optInt reads an integer model option (key:value form) from ModelOptions,
|
||||
// returning def when absent or unparseable. The options array carries the
|
||||
// model YAML's options: entries (see core/config; siblings such as
|
||||
// acestep-cpp parse the same key:value form via strings.Cut on ":").
|
||||
func optInt(opts *pb.ModelOptions, key string, def int) int {
|
||||
for _, o := range opts.GetOptions() {
|
||||
k, v, ok := strings.Cut(o, ":")
|
||||
if ok && strings.TrimSpace(k) == key {
|
||||
if n, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// runBatch is the dispatcher's batch handler and the ONLY caller of the C
|
||||
// engine on the unary path. It concatenates the batch PCM, calls the batched
|
||||
// JSON C-API under engineMu, splits the JSON array, and replies to each request.
|
||||
func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
// Observability: the actual coalesced batch size per engine call. Debug-level
|
||||
// so it stays silent in normal operation but lets operators confirm/tune batching.
|
||||
xlog.Debug("parakeet-cpp: dispatching batch", "size", len(reqs))
|
||||
nSamples := make([]int32, len(reqs))
|
||||
total := 0
|
||||
for i, r := range reqs {
|
||||
nSamples[i] = int32(len(r.pcm))
|
||||
total += len(r.pcm)
|
||||
}
|
||||
concat := make([]float32, 0, total)
|
||||
for _, r := range reqs {
|
||||
concat = append(concat, r.pcm...)
|
||||
}
|
||||
var dec int32
|
||||
if len(reqs) > 0 {
|
||||
dec = reqs[0].decoder
|
||||
}
|
||||
// All requests in a batch share one language (the batcher coalesces only
|
||||
// same-language requests), so any element's language describes the batch.
|
||||
lang := ""
|
||||
if len(reqs) > 0 {
|
||||
lang = reqs[0].language
|
||||
}
|
||||
p.engineMu.Lock()
|
||||
var cstr uintptr
|
||||
if CppTranscribePcmBatchJSONLang != nil {
|
||||
cstr = CppTranscribePcmBatchJSONLang(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec, lang)
|
||||
} else {
|
||||
cstr = CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
|
||||
}
|
||||
p.engineMu.Unlock()
|
||||
if cstr == 0 {
|
||||
err := fmt.Errorf("parakeet-cpp: batch transcribe failed: %s", CppLastError(p.ctxPtr))
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{err: err}
|
||||
}
|
||||
return
|
||||
}
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
var docs []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(raw), &docs); err != nil || len(docs) != len(reqs) {
|
||||
e := fmt.Errorf("parakeet-cpp: batch json: got %d results for %d reqs (%v)", len(docs), len(reqs), err)
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{err: e}
|
||||
}
|
||||
return
|
||||
}
|
||||
for i, r := range reqs {
|
||||
r.reply <- batchReply{json: string(docs[i])}
|
||||
}
|
||||
}
|
||||
|
||||
// AudioTranscription decodes the wav at opts.Dst to 16 kHz mono PCM and
|
||||
// submits it to the in-process batcher, which coalesces concurrent requests
|
||||
// into a single batched engine call (parakeet_capi_transcribe_pcm_batch_json)
|
||||
// with the default decoder (decoder=0, which selects the right head per
|
||||
// architecture: transducer for tdt/rnnt/hybrid, CTC for ctc) and shapes the
|
||||
// per-word timestamps into a LocalAI TranscriptResult.
|
||||
//
|
||||
// Parakeet emits word- and token-level timestamps but no native segment
|
||||
// boundaries, so we synthesise a single whole-clip segment spanning the first
|
||||
// word start to the last word end. Word-level timings are attached only when
|
||||
// the caller opts in via timestamp_granularities=["word"] (matching the
|
||||
// OpenAI API, whose default is segment-level); token ids always populate
|
||||
// Segment.Tokens.
|
||||
//
|
||||
// translate/diarize/prompt/temperature/threads are not applicable to parakeet
|
||||
// and are ignored; language is honored on the batched + streaming paths (see
|
||||
// opts.GetLanguage() below); streaming is handled by AudioTranscriptionStream
|
||||
// (L2).
|
||||
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if p.ctxPtr == 0 {
|
||||
return pb.TranscriptResult{}, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
|
||||
// Fallback when the batched C-API is unavailable: transcribe from a file
|
||||
// path (original behavior, no batching). The C library's audio loader only
|
||||
// understands 16 kHz mono WAV/PCM, so convert the input first - otherwise
|
||||
// any non-WAV upload (MP3, etc.) fails with "failed to load audio". This
|
||||
// mirrors what every other audio backend (whisper, crispasr) does via
|
||||
// utils.AudioToWav before handing the file to the engine.
|
||||
if p.bat == nil {
|
||||
converted, cleanup, err := convertToWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer cleanup()
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, converted, 0)
|
||||
if cstr == 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", CppLastError(p.ctxPtr))
|
||||
}
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
var doc transcriptJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
|
||||
}
|
||||
|
||||
// Batched path: decode to PCM, submit to the batcher, wait for this request's
|
||||
// JSON element. The dispatcher is the sole engine caller on this path; both
|
||||
// sends honour ctx cancellation.
|
||||
pcm, _, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
rep := make(chan batchReply, 1)
|
||||
select {
|
||||
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, language: opts.GetLanguage(), reply: rep}:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
var res batchReply
|
||||
select {
|
||||
case res = <-rep:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if res.err != nil {
|
||||
return pb.TranscriptResult{}, res.err
|
||||
}
|
||||
var doc transcriptJSON
|
||||
if err := json.Unmarshal([]byte(res.json), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
|
||||
}
|
||||
|
||||
// segmentSeparators is NeMo's default segment_seperators (sentence-ending
|
||||
// punctuation). Splitting on these matches NeMo's default segment timestamps.
|
||||
var segmentSeparators = []rune{'.', '?', '!'}
|
||||
|
||||
// transcriptResultFromDoc maps a decoded transcriptJSON to a TranscriptResult,
|
||||
// grouping words into NeMo-faithful segments (see splitWordsIntoSegments). The
|
||||
// optional gapFrames (NeMo's segment_gap_threshold, in encoder FRAMES; 0=off)
|
||||
// additionally splits on inter-word silence; it is converted to a seconds gap
|
||||
// with the document's frame_sec. Per-segment word timings are attached only when
|
||||
// the caller requested word granularity; token ids populate each segment's
|
||||
// Tokens by time-window membership. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gapFrames int) pb.TranscriptResult {
|
||||
text := strings.TrimSpace(doc.Text)
|
||||
|
||||
// Frame-unit gap threshold -> seconds (NeMo segment_gap_threshold). 0 = off.
|
||||
gapSeconds := 0.0
|
||||
if gapFrames > 0 {
|
||||
if doc.FrameSec > 0 {
|
||||
gapSeconds = float64(gapFrames) * doc.FrameSec
|
||||
} else {
|
||||
xlog.Warn("parakeet-cpp: segment_gap_threshold set but libparakeet.so " +
|
||||
"did not report frame_sec; falling back to punctuation-only segments")
|
||||
}
|
||||
}
|
||||
|
||||
groups := splitWordsIntoSegments(doc.Words, segmentSeparators, gapSeconds)
|
||||
if len(groups) == 0 {
|
||||
// No words (edge case): single whole-clip text segment.
|
||||
return pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: []*pb.TranscriptSegment{{Id: 0, Text: text}},
|
||||
}
|
||||
}
|
||||
|
||||
wantWords := wordsRequested(opts.TimestampGranularities)
|
||||
segments := make([]*pb.TranscriptSegment, 0, len(groups))
|
||||
for id, group := range groups {
|
||||
parts := make([]string, len(group))
|
||||
for i, gw := range group {
|
||||
parts[i] = gw.W
|
||||
}
|
||||
seg := &pb.TranscriptSegment{
|
||||
Id: int32(id),
|
||||
Start: secondsToNanos(group[0].Start),
|
||||
End: secondsToNanos(group[len(group)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
Tokens: tokensInWindow(doc.Tokens, group[0].Start, group[len(group)-1].End),
|
||||
}
|
||||
if wantWords {
|
||||
ws := make([]*pb.TranscriptWord, len(group))
|
||||
for i, gw := range group {
|
||||
ws[i] = &pb.TranscriptWord{Start: secondsToNanos(gw.Start), End: secondsToNanos(gw.End), Text: gw.W}
|
||||
}
|
||||
seg.Words = ws
|
||||
}
|
||||
segments = append(segments, seg)
|
||||
}
|
||||
return pb.TranscriptResult{Text: text, Segments: segments}
|
||||
}
|
||||
|
||||
// splitWordsIntoSegments groups words into segments exactly as NeMo's
|
||||
// get_segment_offsets does (nemo/collections/asr/parts/utils/timestamp_utils.py).
|
||||
// Walking the words, it closes a segment when (1) the gap rule is enabled
|
||||
// (gapSeconds > 0) and the segment already has words and the gap from the
|
||||
// previous word's end to this word's start is >= gapSeconds - the current word
|
||||
// then STARTS a new segment - or, checked only when the gap rule did not apply
|
||||
// (NeMo's elif), (2) the word ends with (or is) a separator, which closes the
|
||||
// segment INCLUDING that word. Trailing words flush into a final segment.
|
||||
// gapSeconds <= 0 disables the gap rule, matching NeMo's default
|
||||
// segment_gap_threshold=None (punctuation-only segments).
|
||||
func splitWordsIntoSegments(words []transcriptWord, separators []rune, gapSeconds float64) [][]transcriptWord {
|
||||
var segments [][]transcriptWord
|
||||
var cur []transcriptWord
|
||||
for i, word := range words {
|
||||
gapActive := gapSeconds > 0 && len(cur) > 0
|
||||
if gapActive && (word.Start-words[i-1].End) >= gapSeconds {
|
||||
segments = append(segments, cur)
|
||||
cur = []transcriptWord{word}
|
||||
continue
|
||||
}
|
||||
if !gapActive && endsWithSeparator(word.W, separators) {
|
||||
cur = append(cur, word)
|
||||
segments = append(segments, cur)
|
||||
cur = nil
|
||||
continue
|
||||
}
|
||||
cur = append(cur, word)
|
||||
}
|
||||
if len(cur) > 0 {
|
||||
segments = append(segments, cur)
|
||||
}
|
||||
return segments
|
||||
}
|
||||
|
||||
// endsWithSeparator reports whether w's last rune is in separators (matching
|
||||
// NeMo's `word[-1] in delims or word in delims`).
|
||||
func endsWithSeparator(w string, separators []rune) bool {
|
||||
r := []rune(strings.TrimSpace(w))
|
||||
if len(r) == 0 {
|
||||
return false
|
||||
}
|
||||
last := r[len(r)-1]
|
||||
for _, s := range separators {
|
||||
if last == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// tokensInWindow returns the ids of tokens whose timestamp t falls in
|
||||
// [start, end] (inclusive), assigning each token to the segment that spans its
|
||||
// time. The last segment's end is the last word end, so the final token is
|
||||
// included.
|
||||
func tokensInWindow(tokens []transcriptToken, start, end float64) []int32 {
|
||||
var ids []int32
|
||||
for _, t := range tokens {
|
||||
if t.T >= start && t.T <= end {
|
||||
ids = append(ids, t.ID)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// streamSegmenter accumulates streaming words into per-utterance segments. EOU
|
||||
// is the model's own utterance boundary; each closed segment takes its start/end
|
||||
// from its first/last accumulated word.
|
||||
type streamSegmenter struct {
|
||||
segs []*pb.TranscriptSegment
|
||||
cur []transcriptWord
|
||||
nextID int32
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) add(doc streamFeedJSON) {
|
||||
s.cur = append(s.cur, doc.Words...)
|
||||
if doc.Eou != 0 {
|
||||
s.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) flush() {
|
||||
if len(s.cur) == 0 {
|
||||
return
|
||||
}
|
||||
parts := make([]string, len(s.cur))
|
||||
for i, w := range s.cur {
|
||||
parts[i] = w.W
|
||||
}
|
||||
s.segs = append(s.segs, &pb.TranscriptSegment{
|
||||
Id: s.nextID,
|
||||
Start: secondsToNanos(s.cur[0].Start),
|
||||
End: secondsToNanos(s.cur[len(s.cur)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
})
|
||||
s.nextID++
|
||||
s.cur = nil
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) segments() []*pb.TranscriptSegment { return s.segs }
|
||||
|
||||
// wordsRequested reports whether the caller asked for word-level timestamps.
|
||||
// The OpenAI transcription API gates word timings behind
|
||||
// timestamp_granularities[] containing "word" and defaults to segment-level
|
||||
// otherwise; we follow that contract.
|
||||
func wordsRequested(granularities []string) bool {
|
||||
for _, g := range granularities {
|
||||
if strings.EqualFold(strings.TrimSpace(g), "word") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// secondsToNanos converts the C-API's fractional-second timestamps into the
|
||||
// int64 nanoseconds LocalAI carries on TranscriptSegment/TranscriptWord, the
|
||||
// same nanosecond convention the whisper backend uses.
|
||||
func secondsToNanos(sec float64) int64 {
|
||||
return int64(sec * 1e9)
|
||||
}
|
||||
|
||||
// AudioTranscriptionStream drives the cache-aware streaming RNN-T over the
|
||||
// audio at opts.Dst: it decodes the file to 16 kHz mono PCM, feeds it in
|
||||
// chunks to parakeet_capi_stream_feed, and emits each newly-finalized text
|
||||
// run as a TranscriptStreamResponse delta. <EOU>/<EOB> events close the
|
||||
// current segment; a closing FinalResult carries the full transcript and the
|
||||
// per-utterance segments.
|
||||
//
|
||||
// stream_begin returns 0 for models that are not cache-aware streaming models
|
||||
// (only e.g. nvidia/parakeet_realtime_eou_120m-v1 qualifies). For those we fall
|
||||
// back to a single offline transcription emitted as one delta plus a closing
|
||||
// FinalResult, matching LocalAI's non-streaming streaming contract (and the
|
||||
// whisper backend), so the streaming endpoint works for every model.
|
||||
func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
defer close(results)
|
||||
|
||||
if p.ctxPtr == 0 {
|
||||
return grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
var stream uintptr
|
||||
if CppStreamBeginLang != nil {
|
||||
stream = CppStreamBeginLang(p.ctxPtr, opts.GetLanguage())
|
||||
} else {
|
||||
stream = CppStreamBegin(p.ctxPtr)
|
||||
}
|
||||
if stream == 0 {
|
||||
// Not a cache-aware streaming model: run a normal offline
|
||||
// transcription and emit it as one delta + a closing final result.
|
||||
res, err := p.AudioTranscription(ctx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t := strings.TrimSpace(res.Text); t != "" {
|
||||
results <- &pb.TranscriptStreamResponse{Delta: t}
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{FinalResult: &res}
|
||||
return nil
|
||||
}
|
||||
defer CppStreamFree(stream)
|
||||
// The C engine is a single shared context: a streaming session and a batched
|
||||
// unary dispatch must never touch it at once, so hold engineMu for the whole
|
||||
// stream. This lock is intentionally taken AFTER the non-streaming fallback
|
||||
// above returns: that fallback goes through AudioTranscription -> the batcher
|
||||
// -> runBatch, which itself acquires engineMu, so locking here first would
|
||||
// deadlock. Do not hoist this lock above the fallback.
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
|
||||
data, duration, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ABI v4: when the streaming JSON entry points are present, drive them so the
|
||||
// per-utterance segments carry per-word start/end timestamps. Falls through to
|
||||
// the text-only loop below against an older libparakeet.so. Runs under the
|
||||
// engineMu already held above.
|
||||
if CppStreamFeedJSON != nil {
|
||||
return p.streamJSON(ctx, stream, data, duration, results)
|
||||
}
|
||||
|
||||
var (
|
||||
full strings.Builder
|
||||
segText strings.Builder
|
||||
segments []*pb.TranscriptSegment
|
||||
segID int32
|
||||
)
|
||||
|
||||
flushSegment := func() {
|
||||
t := strings.TrimSpace(segText.String())
|
||||
segText.Reset()
|
||||
if t == "" {
|
||||
return
|
||||
}
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: segID, Text: t})
|
||||
segID++
|
||||
}
|
||||
|
||||
// emitDelta consumes the malloc'd char* returned by feed/finalize: frees
|
||||
// it, accumulates the text, and sends a delta when non-empty. A 0 return
|
||||
// is an error (vs the "" empty-but-non-NULL no-new-text case).
|
||||
emitDelta := func(ret uintptr) error {
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
delta := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
if delta == "" {
|
||||
return nil
|
||||
}
|
||||
full.WriteString(delta)
|
||||
segText.WriteString(delta)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: delta}
|
||||
return nil
|
||||
}
|
||||
|
||||
for off := 0; off < len(data); off += streamChunkSamples {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(data))
|
||||
chunk := data[off:end]
|
||||
|
||||
var eou int32
|
||||
ret := CppStreamFeed(stream, chunk, int32(len(chunk)), unsafe.Pointer(&eou))
|
||||
if err := emitDelta(ret); err != nil {
|
||||
return err
|
||||
}
|
||||
if eou != 0 {
|
||||
flushSegment()
|
||||
}
|
||||
}
|
||||
|
||||
// Flush the streaming tail (final encoder chunk).
|
||||
if err := emitDelta(CppStreamFinalize(stream)); err != nil {
|
||||
return err
|
||||
}
|
||||
flushSegment()
|
||||
|
||||
text := strings.TrimSpace(full.String())
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{
|
||||
FinalResult: &pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: segments,
|
||||
Duration: duration,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamJSON drives the ABI v4 streaming JSON entry points: each feed/finalize
|
||||
// returns a {text,eou,frame_sec,words} document. The newly-finalized text is
|
||||
// emitted as a delta (unchanged streaming contract) while words are accumulated
|
||||
// into per-utterance segments (closed on EOU) so the closing FinalResult carries
|
||||
// timestamped segments. Runs under engineMu (already held by the caller).
|
||||
func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []float32,
|
||||
duration float32, results chan *pb.TranscriptStreamResponse) error {
|
||||
var (
|
||||
full strings.Builder
|
||||
seg streamSegmenter
|
||||
)
|
||||
// consume frees the malloc'd char* (a 0 return is an error), parses the JSON,
|
||||
// emits the delta, and routes words through the segmenter.
|
||||
consume := func(ret uintptr) error {
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
raw := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
var doc streamFeedJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return fmt.Errorf("parakeet-cpp: decode stream json: %w", err)
|
||||
}
|
||||
if doc.Text != "" {
|
||||
full.WriteString(doc.Text)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: doc.Text}
|
||||
}
|
||||
seg.add(doc)
|
||||
return nil
|
||||
}
|
||||
|
||||
for off := 0; off < len(data); off += streamChunkSamples {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(data))
|
||||
chunk := data[off:end]
|
||||
if err := consume(CppStreamFeedJSON(stream, chunk, int32(len(chunk)))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := consume(CppStreamFinalizeJSON(stream)); err != nil {
|
||||
return err
|
||||
}
|
||||
seg.flush() // close any trailing utterance that never saw an EOU
|
||||
|
||||
text := strings.TrimSpace(full.String())
|
||||
segments := seg.segments()
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{
|
||||
FinalResult: &pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: segments,
|
||||
Duration: duration,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeWavMono16k converts any input audio to 16 kHz mono PCM and returns the
|
||||
// float samples plus the clip duration in seconds. Mirrors the whisper
|
||||
// backend: utils.AudioToWav (ffmpeg) normalises rate/channels, go-audio
|
||||
// decodes the PCM.
|
||||
// convertToWavMono16k converts an arbitrary audio file to a 16 kHz mono WAV in
|
||||
// a fresh temp dir and returns the path together with a cleanup func the caller
|
||||
// must defer. WAV inputs already at 16 kHz/mono/16-bit are passed through by
|
||||
// utils.AudioToWav (hardlink/copy), everything else is transcoded via ffmpeg.
|
||||
// Used by the direct (non-batched) transcription path, which hands a file path
|
||||
// to the C library's WAV-only audio loader.
|
||||
func convertToWavMono16k(path string) (string, func(), error) {
|
||||
dir, err := os.MkdirTemp("", "parakeet")
|
||||
if err != nil {
|
||||
return "", func() {}, err
|
||||
}
|
||||
cleanup := func() { _ = os.RemoveAll(dir) }
|
||||
|
||||
converted := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(path, converted); err != nil {
|
||||
cleanup()
|
||||
return "", func() {}, err
|
||||
}
|
||||
return converted, cleanup, nil
|
||||
}
|
||||
|
||||
func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
converted, cleanup, err := convertToWavMono16k(path)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
fh, err := os.Open(converted)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
buf, err := wav.NewDecoder(fh).FullPCMBuffer()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
return data, duration, nil
|
||||
}
|
||||
|
||||
// Free releases the underlying parakeet_ctx. Called by LocalAI when the
|
||||
// model is unloaded.
|
||||
func (p *ParakeetCpp) Free() error {
|
||||
// Stop the dispatcher before releasing the engine so no in-flight runBatch
|
||||
// can touch a freed ctx (close leak / use-after-free on reload).
|
||||
if p.batStop != nil {
|
||||
close(p.batStop)
|
||||
p.batStop = nil
|
||||
}
|
||||
if p.ctxPtr != 0 {
|
||||
CppFree(p.ctxPtr)
|
||||
p.ctxPtr = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// goStringFromCPtr copies a NUL-terminated C string into Go memory.
|
||||
// cptr is the raw pointer returned by purego from the C-API (a malloc'd
|
||||
// buffer the caller owns); callers must free it via CppFreeString after
|
||||
// the copy lands.
|
||||
//
|
||||
// The uintptr->unsafe.Pointer conversion below trips go vet's unsafeptr
|
||||
// check, which can't distinguish a C-owned heap pointer from Go-managed
|
||||
// memory. It is safe here: the pointer addresses a malloc'd C buffer the
|
||||
// Go GC neither tracks nor moves, and we dereference it immediately to
|
||||
// copy the bytes out, the same pattern (and the same tolerated warning)
|
||||
// as the whisper backend's unsafe.Slice over segsPtr.
|
||||
func goStringFromCPtr(cptr uintptr) string {
|
||||
if cptr == 0 {
|
||||
return ""
|
||||
}
|
||||
p := unsafe.Pointer(cptr) //nolint:govet // C-owned malloc'd buffer, not Go-GC memory (see doc above)
|
||||
n := 0
|
||||
for *(*byte)(unsafe.Add(p, n)) != 0 {
|
||||
n++
|
||||
}
|
||||
return string(unsafe.Slice((*byte)(p), n))
|
||||
}
|
||||
247
backend/go/parakeet-cpp/goparakeetcpp_test.go
Normal file
247
backend/go/parakeet-cpp/goparakeetcpp_test.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestParakeetCpp(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "parakeet-cpp Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive
|
||||
// the C-API bridge without spinning up the gRPC server. Skips the
|
||||
// current spec when libparakeet.so isn't loadable from cwd
|
||||
// ($LD_LIBRARY_PATH or a symlink in ./).
|
||||
func ensureLibLoaded() {
|
||||
libLoadOnce.Do(func() {
|
||||
libName := os.Getenv("PARAKEET_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libparakeet.so"
|
||||
}
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&CppAbiVersion, lib, "parakeet_capi_abi_version")
|
||||
purego.RegisterLibFunc(&CppLoad, lib, "parakeet_capi_load")
|
||||
purego.RegisterLibFunc(&CppFree, lib, "parakeet_capi_free")
|
||||
purego.RegisterLibFunc(&CppTranscribePath, lib, "parakeet_capi_transcribe_path")
|
||||
purego.RegisterLibFunc(&CppTranscribePathJSON, lib, "parakeet_capi_transcribe_path_json")
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
purego.RegisterLibFunc(&CppStreamBegin, lib, "parakeet_capi_stream_begin")
|
||||
purego.RegisterLibFunc(&CppStreamFeed, lib, "parakeet_capi_stream_feed")
|
||||
purego.RegisterLibFunc(&CppStreamFinalize, lib, "parakeet_capi_stream_finalize")
|
||||
purego.RegisterLibFunc(&CppStreamFree, lib, "parakeet_capi_stream_free")
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
|
||||
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
|
||||
}
|
||||
purego.RegisterLibFunc(&CppFreeString, lib, "parakeet_capi_free_string")
|
||||
purego.RegisterLibFunc(&CppLastError, lib, "parakeet_capi_last_error")
|
||||
})
|
||||
if libLoadErr != nil {
|
||||
Skip("libparakeet.so not loadable: " + libLoadErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// fixturesOrSkip returns the model + audio paths or skips the spec if
|
||||
// either env var is unset. The smoke test never runs in default CI; it
|
||||
// needs a real parakeet GGUF and a 16 kHz mono WAV on disk.
|
||||
func fixturesOrSkip() (string, string) {
|
||||
modelPath := os.Getenv("PARAKEET_BACKEND_TEST_MODEL")
|
||||
audioPath := os.Getenv("PARAKEET_BACKEND_TEST_WAV")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set PARAKEET_BACKEND_TEST_MODEL and PARAKEET_BACKEND_TEST_WAV to run this spec")
|
||||
}
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
// writeMono16kWav writes `samples` frames of 16 kHz mono 16-bit silence to
|
||||
// path. The result is already in AudioToWav's target format, so the conversion
|
||||
// helper copies it through without invoking ffmpeg.
|
||||
func writeMono16kWav(path string, samples int) {
|
||||
GinkgoHelper()
|
||||
f, err := os.Create(path)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
enc := wav.NewEncoder(f, 16000, 16, 1, 1)
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: 16000},
|
||||
SourceBitDepth: 16,
|
||||
Data: make([]int, samples),
|
||||
}
|
||||
Expect(enc.Write(buf)).To(Succeed())
|
||||
Expect(enc.Close()).To(Succeed())
|
||||
Expect(f.Close()).To(Succeed())
|
||||
}
|
||||
|
||||
var _ = Describe("ParakeetCpp", func() {
|
||||
Context("AudioTranscription", func() {
|
||||
It("transcribes a WAV via the parakeet C-API", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
p := &ParakeetCpp{}
|
||||
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
defer func() { _ = p.Free() }()
|
||||
|
||||
res, err := p.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(strings.TrimSpace(res.Text)).ToNot(BeEmpty(),
|
||||
"expected non-empty transcript for %s", audioPath)
|
||||
// NeMo-faithful segmentation: one or more punctuation-delimited
|
||||
// segments, each with text and a monotonically-advancing time span.
|
||||
Expect(res.Segments).ToNot(BeEmpty(), "expected at least one segment")
|
||||
var prevEnd int64
|
||||
for i, seg := range res.Segments {
|
||||
Expect(strings.TrimSpace(seg.Text)).ToNot(BeEmpty(),
|
||||
"segment %d must have text", i)
|
||||
Expect(seg.End).To(BeNumerically(">=", seg.Start),
|
||||
"segment %d end must not precede its start", i)
|
||||
Expect(seg.Start).To(BeNumerically(">=", prevEnd),
|
||||
"segments must be in time order")
|
||||
prevEnd = seg.End
|
||||
// Default (no granularities) is segment-level: no per-word timings.
|
||||
Expect(seg.Words).To(BeEmpty(),
|
||||
"word timings are opt-in via timestamp_granularities")
|
||||
}
|
||||
})
|
||||
|
||||
It("emits word-level timestamps when granularity=word", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
p := &ParakeetCpp{}
|
||||
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
defer func() { _ = p.Free() }()
|
||||
|
||||
res, err := p.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
TimestampGranularities: []string{"word"},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.Segments).ToNot(BeEmpty())
|
||||
// With word granularity every segment carries its own words, and each
|
||||
// segment's span tracks its first/last word; word starts advance
|
||||
// monotonically across the whole transcript.
|
||||
totalWords := 0
|
||||
var prevStart int64 = -1
|
||||
for i, seg := range res.Segments {
|
||||
Expect(seg.Words).ToNot(BeEmpty(),
|
||||
"segment %d must carry per-word timestamps with granularity=word", i)
|
||||
Expect(seg.Start).To(Equal(seg.Words[0].Start),
|
||||
"segment %d start tracks its first word", i)
|
||||
Expect(seg.End).To(Equal(seg.Words[len(seg.Words)-1].End),
|
||||
"segment %d end tracks its last word", i)
|
||||
for _, w := range seg.Words {
|
||||
Expect(w.End).To(BeNumerically(">=", w.Start))
|
||||
Expect(w.Start).To(BeNumerically(">=", prevStart))
|
||||
prevStart = w.Start
|
||||
totalWords++
|
||||
}
|
||||
}
|
||||
Expect(totalWords).To(BeNumerically(">", 0))
|
||||
Expect(res.Segments[0].Words[0].Start).To(BeNumerically(">=", int64(0)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("convertToWavMono16k", func() {
|
||||
// The non-batched transcription path hands a file path to the C
|
||||
// library's WAV-only audio loader, so it must convert first.
|
||||
// utils.AudioToWav passes an already-16kHz/mono/16-bit WAV through
|
||||
// without ffmpeg, which lets us exercise the helper (and the
|
||||
// regression: the direct path used to skip conversion entirely)
|
||||
// without a model, the C library, or ffmpeg.
|
||||
It("returns a decodable 16kHz mono WAV copy and cleans it up", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
src := filepath.Join(dir, "input.wav")
|
||||
writeMono16kWav(src, 16000) // 1s of silence at 16 kHz
|
||||
|
||||
converted, cleanup, err := convertToWavMono16k(src)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// It must produce a fresh temp file, not return the original path.
|
||||
Expect(converted).ToNot(Equal(src))
|
||||
Expect(converted).To(BeAnExistingFile())
|
||||
|
||||
pcm, _, err := decodeWavMono16k(converted)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(pcm).To(HaveLen(16000), "round-trips the sample count")
|
||||
|
||||
cleanup()
|
||||
Expect(converted).ToNot(BeAnExistingFile(), "cleanup removes the temp dir")
|
||||
})
|
||||
|
||||
It("errors on a non-existent input rather than passing the path through", func() {
|
||||
_, _, err := convertToWavMono16k(filepath.Join(GinkgoT().TempDir(), "missing.mp3"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("streams deltas and a closing FinalResult from a cache-aware model", func() {
|
||||
// Streaming needs a cache-aware streaming model (e.g.
|
||||
// realtime_eou); the offline test model would fail stream_begin.
|
||||
modelPath := os.Getenv("PARAKEET_BACKEND_TEST_STREAM_MODEL")
|
||||
audioPath := os.Getenv("PARAKEET_BACKEND_TEST_WAV")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set PARAKEET_BACKEND_TEST_STREAM_MODEL (cache-aware streaming model) and PARAKEET_BACKEND_TEST_WAV")
|
||||
}
|
||||
ensureLibLoaded()
|
||||
|
||||
p := &ParakeetCpp{}
|
||||
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
defer func() { _ = p.Free() }()
|
||||
|
||||
results := make(chan *pb.TranscriptStreamResponse, 64)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- p.AudioTranscriptionStream(context.Background(),
|
||||
&pb.TranscriptRequest{Dst: audioPath}, results)
|
||||
}()
|
||||
|
||||
var deltas []string
|
||||
var final *pb.TranscriptResult
|
||||
for r := range results {
|
||||
if r.Delta != "" {
|
||||
deltas = append(deltas, r.Delta)
|
||||
}
|
||||
if r.FinalResult != nil {
|
||||
final = r.FinalResult
|
||||
}
|
||||
}
|
||||
Expect(<-errCh).ToNot(HaveOccurred())
|
||||
|
||||
Expect(final).ToNot(BeNil(), "expected a closing FinalResult")
|
||||
Expect(strings.TrimSpace(final.Text)).ToNot(BeEmpty(),
|
||||
"expected a non-empty streamed transcript")
|
||||
Expect(final.Segments).ToNot(BeEmpty(),
|
||||
"FinalResult always carries at least one segment")
|
||||
// The concatenated deltas reconstruct the final transcript.
|
||||
Expect(strings.TrimSpace(strings.Join(deltas, ""))).To(Equal(strings.TrimSpace(final.Text)),
|
||||
"deltas should reconstruct the final text")
|
||||
})
|
||||
})
|
||||
})
|
||||
94
backend/go/parakeet-cpp/main.go
Normal file
94
backend/go/parakeet-cpp/main.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package main
|
||||
|
||||
// Started internally by LocalAI - one gRPC server per loaded model.
|
||||
//
|
||||
// Loads libparakeet.so via purego and registers the flat C-API entry
|
||||
// points declared in parakeet_capi.h. The library name can be overridden
|
||||
// with PARAKEET_LIBRARY (mirrors the WHISPER_LIBRARY / VIBEVOICECPP_LIBRARY
|
||||
// convention in the sibling backends); the default looks for the .so next
|
||||
// to this binary.
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("PARAKEET_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libparakeet.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("parakeet-cpp: dlopen %q: %w", libName, err))
|
||||
}
|
||||
|
||||
// Bound 1:1 to parakeet_capi.h. The C-API returns malloc'd char*
|
||||
// buffers from transcribe_*; we register those as uintptr so we get
|
||||
// the raw pointer back and can call parakeet_capi_free_string on it
|
||||
// (purego's string return would copy and forget the original pointer,
|
||||
// leaking it on every call).
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppAbiVersion, "parakeet_capi_abi_version"},
|
||||
{&CppLoad, "parakeet_capi_load"},
|
||||
{&CppFree, "parakeet_capi_free"},
|
||||
{&CppTranscribePath, "parakeet_capi_transcribe_path"},
|
||||
{&CppTranscribePathJSON, "parakeet_capi_transcribe_path_json"},
|
||||
{&CppStreamBegin, "parakeet_capi_stream_begin"},
|
||||
{&CppStreamFeed, "parakeet_capi_stream_feed"},
|
||||
{&CppStreamFinalize, "parakeet_capi_stream_finalize"},
|
||||
{&CppStreamFree, "parakeet_capi_stream_free"},
|
||||
{&CppFreeString, "parakeet_capi_free_string"},
|
||||
{&CppLastError, "parakeet_capi_last_error"},
|
||||
}
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
// The batched-JSON entry point exists only in newer libparakeet.so (ABI >= 2).
|
||||
// Probe with Dlsym and register only if present, so the backend still loads
|
||||
// against an older library (it falls back to per-request transcription).
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
|
||||
// Per-request language variants (multilingual nemotron). Same probe pattern:
|
||||
// present only in libparakeet.so built with multilingual support, so the
|
||||
// backend still loads against an older library and falls back to the
|
||||
// non-lang batched + streaming entry points (model default / "auto").
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json_lang"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSONLang, lib, "parakeet_capi_transcribe_pcm_batch_json_lang")
|
||||
}
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_begin_lang"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamBeginLang, lib, "parakeet_capi_stream_begin_lang")
|
||||
}
|
||||
|
||||
// Streaming JSON entry points (ABI v4): surface per-word timestamps on the
|
||||
// streaming path. Same probe pattern; absent in older libparakeet.so, where
|
||||
// the backend falls back to the text-only streaming feed.
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
|
||||
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[parakeet-cpp] ABI=%d\n", CppAbiVersion())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &ParakeetCpp{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
backend/go/parakeet-cpp/package.sh
Executable file
23
backend/go/parakeet-cpp/package.sh
Executable file
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# L0 packaging stub: copy the binary, run.sh and libparakeet.so* into
|
||||
# package/. The full ldd walk (libc, libstdc++, libgomp, GPU runtimes,
|
||||
# arch detection) lands in L3, mirroring backend/go/whisper/package.sh.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
|
||||
cp -avf "$CURDIR/parakeet-cpp-grpc" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
# libparakeet.so + any soname symlinks (libparakeet.so.X, libparakeet.so.X.Y).
|
||||
cp -avf "$CURDIR"/libparakeet.so* "$CURDIR/package/lib/" 2>/dev/null || {
|
||||
echo "ERROR: libparakeet.so not found in $CURDIR, run 'make' first" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
echo "L0 package layout (full ldd walk lands in L3):"
|
||||
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"
|
||||
16
backend/go/parakeet-cpp/run.sh
Executable file
16
backend/go/parakeet-cpp/run.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
|
||||
|
||||
# If a self-contained ld.so was packaged, route through it so the
|
||||
# packaged libc / libstdc++ are used instead of the host's (matches the
|
||||
# whisper backend's runtime layout).
|
||||
if [ -f "$CURDIR/lib/ld.so" ]; then
|
||||
echo "Using lib/ld.so"
|
||||
exec "$CURDIR/lib/ld.so" "$CURDIR/parakeet-cpp-grpc" "$@"
|
||||
fi
|
||||
|
||||
exec "$CURDIR/parakeet-cpp-grpc" "$@"
|
||||
127
backend/go/parakeet-cpp/segments_test.go
Normal file
127
backend/go/parakeet-cpp/segments_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func tw(text string, start, end float64) transcriptWord {
|
||||
return transcriptWord{W: text, Start: start, End: end}
|
||||
}
|
||||
|
||||
var _ = Describe("splitWordsIntoSegments (NeMo get_segment_offsets parity)", func() {
|
||||
seps := []rune{'.', '?', '!'}
|
||||
|
||||
It("splits on sentence-ending punctuation, including the delimiter word", func() {
|
||||
words := []transcriptWord{tw("hello", 0, 0.4), tw("world.", 0.4, 0.8), tw("bye", 1.0, 1.3)}
|
||||
segs := splitWordsIntoSegments(words, seps, 0)
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2))
|
||||
Expect(segs[0][1].W).To(Equal("world."))
|
||||
Expect(segs[1]).To(HaveLen(1))
|
||||
Expect(segs[1][0].W).To(Equal("bye"))
|
||||
})
|
||||
|
||||
It("keeps a single segment with no terminal punctuation and gap off", func() {
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 0)
|
||||
Expect(segs).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("splits on the gap rule when enabled, the gapped word starting the next segment", func() {
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 1.0) // c is 4.6s after b
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2)) // a b
|
||||
Expect(segs[1]).To(HaveLen(1)) // c
|
||||
Expect(segs[1][0].W).To(Equal("c"))
|
||||
})
|
||||
|
||||
It("checks the gap rule before punctuation (NeMo elif order)", func() {
|
||||
// "b." would terminate, but c is far after it -> gap closes [a b.] at b.
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b.", 0.2, 0.4), tw("c", 9.0, 9.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 1.0)
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2))
|
||||
Expect(segs[1][0].W).To(Equal("c"))
|
||||
})
|
||||
|
||||
It("still splits on punctuation when the gap rule is enabled but does not fire", func() {
|
||||
words := []transcriptWord{tw("hi.", 0, 0.4), tw("bye", 0.4, 0.8)}
|
||||
segs := splitWordsIntoSegments(words, seps, 5.0) // gap never reached
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0][0].W).To(Equal("hi."))
|
||||
})
|
||||
|
||||
It("returns nothing for empty input", func() {
|
||||
Expect(splitWordsIntoSegments(nil, seps, 0)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("transcriptResultFromDoc (multi-segment)", func() {
|
||||
doc := transcriptJSON{
|
||||
Text: "hello world. bye now",
|
||||
FrameSec: 0.08,
|
||||
Words: []transcriptWord{
|
||||
{W: "hello", Start: 0.0, End: 0.4},
|
||||
{W: "world.", Start: 0.4, End: 0.8},
|
||||
{W: "bye", Start: 1.0, End: 1.3},
|
||||
{W: "now", Start: 1.3, End: 1.6},
|
||||
},
|
||||
Tokens: []transcriptToken{{ID: 1, T: 0.1}, {ID: 2, T: 0.5}, {ID: 3, T: 1.1}, {ID: 4, T: 1.4}},
|
||||
}
|
||||
|
||||
It("emits one segment per punctuation-delimited group with start/end", func() {
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments).To(HaveLen(2))
|
||||
Expect(res.Segments[0].Text).To(Equal("hello world."))
|
||||
Expect(res.Segments[0].Start).To(Equal(int64(0)))
|
||||
Expect(res.Segments[0].End).To(Equal(secondsToNanos(0.8)))
|
||||
Expect(res.Segments[1].Text).To(Equal("bye now"))
|
||||
Expect(res.Segments[1].Start).To(Equal(secondsToNanos(1.0)))
|
||||
Expect(res.Segments[1].Id).To(Equal(int32(1)))
|
||||
})
|
||||
|
||||
It("assigns tokens to the segment whose time window contains them", func() {
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments[0].Tokens).To(Equal([]int32{1, 2}))
|
||||
Expect(res.Segments[1].Tokens).To(Equal([]int32{3, 4}))
|
||||
})
|
||||
|
||||
It("attaches per-segment words only when word granularity requested", func() {
|
||||
plain := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(plain.Segments[0].Words).To(BeEmpty())
|
||||
withWords := transcriptResultFromDoc(doc, &pb.TranscriptRequest{TimestampGranularities: []string{"word"}}, 0)
|
||||
Expect(withWords.Segments[0].Words).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("falls back to a single text segment when there are no words", func() {
|
||||
res := transcriptResultFromDoc(transcriptJSON{Text: "hi"}, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments).To(HaveLen(1))
|
||||
Expect(res.Segments[0].Text).To(Equal("hi"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("streaming segment assembly", func() {
|
||||
It("closes a segment with start/end from its words on EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hello world", Eou: 1, Words: []transcriptWord{
|
||||
{W: "hello", Start: 0.0, End: 0.4}, {W: "world", Start: 0.4, End: 0.9},
|
||||
}})
|
||||
segs := acc.segments()
|
||||
Expect(segs).To(HaveLen(1))
|
||||
Expect(segs[0].Text).To(Equal("hello world"))
|
||||
Expect(segs[0].Start).To(Equal(int64(0)))
|
||||
Expect(segs[0].End).To(Equal(secondsToNanos(0.9)))
|
||||
})
|
||||
|
||||
It("buffers words across feeds until EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hi", Eou: 0, Words: []transcriptWord{{W: "hi", Start: 0, End: 0.3}}})
|
||||
Expect(acc.segments()).To(BeEmpty())
|
||||
acc.add(streamFeedJSON{Text: "there", Eou: 1, Words: []transcriptWord{{W: "there", Start: 0.3, End: 0.7}}})
|
||||
Expect(acc.segments()).To(HaveLen(1))
|
||||
Expect(acc.segments()[0].Text).To(Equal("hi there"))
|
||||
})
|
||||
})
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# qwen3-tts.cpp version
|
||||
QWEN3TTS_REPO?=https://github.com/predict-woo/qwen3-tts.cpp
|
||||
QWEN3TTS_CPP_VERSION?=7a762e2ad4bacc6fdda81d81bf10a09ffb546f29
|
||||
QWEN3TTS_CPP_VERSION?=136e5d36c17083da0321fd96512dc7b263f94a44
|
||||
SO_TARGET?=libgoqwen3ttscpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -21,6 +22,43 @@ type Qwen3TtsCpp struct {
|
||||
threads int
|
||||
}
|
||||
|
||||
// languageNameAliases maps common full language names to the canonical
|
||||
// two-letter code understood by the C++ language_to_id table.
|
||||
var languageNameAliases = map[string]string{
|
||||
"english": "en",
|
||||
"russian": "ru",
|
||||
"chinese": "zh",
|
||||
"japanese": "ja",
|
||||
"korean": "ko",
|
||||
"german": "de",
|
||||
"french": "fr",
|
||||
"spanish": "es",
|
||||
"italian": "it",
|
||||
"portuguese": "pt",
|
||||
}
|
||||
|
||||
// normalizeLanguage coerces a caller-supplied language into the canonical code
|
||||
// the model expects. It lowercases, trims, strips any region/locale suffix
|
||||
// (en-US, en_US, ja.JP -> en/ja), and resolves common full names (english -> en).
|
||||
// An empty input stays empty so the C++ side applies its English default; an
|
||||
// unrecognized value is returned normalized so C++ can log it and default.
|
||||
func normalizeLanguage(lang string) string {
|
||||
lang = strings.ToLower(strings.TrimSpace(lang))
|
||||
if lang == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strip region/locale suffix: keep the segment before the first separator.
|
||||
if i := strings.IndexAny(lang, "-_."); i >= 0 {
|
||||
lang = lang[:i]
|
||||
}
|
||||
|
||||
if code, ok := languageNameAliases[lang]; ok {
|
||||
return code
|
||||
}
|
||||
return lang
|
||||
}
|
||||
|
||||
func (q *Qwen3TtsCpp) Load(opts *pb.ModelOptions) error {
|
||||
// ModelFile is the model directory path (containing GGUF files)
|
||||
modelDir := opts.ModelFile
|
||||
@@ -54,7 +92,7 @@ func (q *Qwen3TtsCpp) TTS(req *pb.TTSRequest) error {
|
||||
dst := req.Dst
|
||||
language := ""
|
||||
if req.Language != nil {
|
||||
language = *req.Language
|
||||
language = normalizeLanguage(*req.Language)
|
||||
}
|
||||
|
||||
// Synthesis parameters with sensible defaults
|
||||
|
||||
53
backend/go/qwen3-tts-cpp/language_test.go
Normal file
53
backend/go/qwen3-tts-cpp/language_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestLanguageNormalization(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "qwen3-tts-cpp language normalization")
|
||||
}
|
||||
|
||||
var _ = Describe("normalizeLanguage", func() {
|
||||
DescribeTable("maps caller input to the canonical model language code",
|
||||
func(input, expected string) {
|
||||
Expect(normalizeLanguage(input)).To(Equal(expected))
|
||||
},
|
||||
// Canonical codes pass through unchanged
|
||||
Entry("canonical en", "en", "en"),
|
||||
Entry("canonical zh", "zh", "zh"),
|
||||
Entry("canonical pt", "pt", "pt"),
|
||||
|
||||
// Case-insensitive
|
||||
Entry("uppercase", "EN", "en"),
|
||||
Entry("mixed case", "Ja", "ja"),
|
||||
|
||||
// Surrounding whitespace
|
||||
Entry("trims whitespace", " en ", "en"),
|
||||
|
||||
// Region/locale stripping
|
||||
Entry("BCP-47 region", "en-US", "en"),
|
||||
Entry("underscore region", "en_US", "en"),
|
||||
Entry("dotted locale", "ja.JP", "ja"),
|
||||
Entry("region + case", "ZH-CN", "zh"),
|
||||
|
||||
// Full-name aliases
|
||||
Entry("english name", "english", "en"),
|
||||
Entry("chinese name cased", "Chinese", "zh"),
|
||||
Entry("japanese name", "japanese", "ja"),
|
||||
Entry("russian name", "russian", "ru"),
|
||||
Entry("portuguese name", "portuguese", "pt"),
|
||||
|
||||
// Empty stays empty (C++ applies the English default)
|
||||
Entry("empty", "", ""),
|
||||
Entry("whitespace only", " ", ""),
|
||||
|
||||
// Unknown values pass through normalized so C++ can log + default
|
||||
Entry("unknown code", "klingon", "klingon"),
|
||||
Entry("unknown with region", "xx-YY", "xx"),
|
||||
)
|
||||
})
|
||||
@@ -11,7 +11,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
# build; leaving this on `master` always picks up the latest C-API surface
|
||||
# (incl. the per-detection accessor functions used by gorfdetrcpp.go).
|
||||
RFDETR_REPO?=https://github.com/mudler/rf-detr.cpp.git
|
||||
RFDETR_VERSION?=main
|
||||
RFDETR_VERSION?=65c0ffcc9a9bc9dae38252f63d0417c9845a6cf7
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=92dc7268fc4ffb0c0cc0bd52dfcefea91326e797
|
||||
STABLEDIFFUSION_GGML_VERSION?=19bdfe22d255d5b4dff39d449318b9bc5ea2317f
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -386,6 +386,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
const char *llm_vision_path = "";
|
||||
const char *diffusion_model_path = stableDiffusionModel;
|
||||
const char *high_noise_diffusion_model_path = "";
|
||||
const char *uncond_diffusion_model_path = "";
|
||||
const char *taesd_path = "";
|
||||
const char *control_net_path = "";
|
||||
const char *embedding_dir = "";
|
||||
@@ -472,6 +473,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
if (!strcmp(optname, "llm_vision_path")) llm_vision_path = strdup(optval);
|
||||
if (!strcmp(optname, "diffusion_model_path")) diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "high_noise_diffusion_model_path")) high_noise_diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "uncond_diffusion_model_path")) uncond_diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "taesd_path")) taesd_path = strdup(optval);
|
||||
if (!strcmp(optname, "control_net_path")) control_net_path = strdup(optval);
|
||||
if (!strcmp(optname, "embedding_dir")) {
|
||||
@@ -571,6 +573,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
ctx_params.llm_vision_path = llm_vision_path;
|
||||
ctx_params.diffusion_model_path = diffusion_model_path;
|
||||
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
|
||||
ctx_params.uncond_diffusion_model_path = uncond_diffusion_model_path;
|
||||
ctx_params.vae_path = vae_path;
|
||||
ctx_params.audio_vae_path = audio_vae_path;
|
||||
ctx_params.embeddings_connectors_path = embeddings_connectors_path;
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=27101c01dcac1676e2b6422256233cd0f1f9ae28
|
||||
WHISPER_CPP_VERSION?=df7638d8229a243af8a4b5a8ae557e0d74e0a0ae
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -95,6 +95,29 @@
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ds4"
|
||||
metal: "metal-ds4"
|
||||
metal-darwin-arm64: "metal-ds4"
|
||||
- &dllm
|
||||
name: "dllm"
|
||||
alias: "dllm"
|
||||
license: mit
|
||||
description: |
|
||||
mudler/dllm.cpp - DiffusionGemma block-diffusion LLM inference engine
|
||||
(C++/ggml, GGUF weights). Decodes whole token canvases per diffusion
|
||||
round instead of autoregressive sampling. Runs on CPU and NVIDIA CUDA 13
|
||||
(including Jetson/GB10 L4T targets).
|
||||
urls:
|
||||
- https://github.com/mudler/dllm.cpp
|
||||
tags:
|
||||
- text-to-text
|
||||
- LLM
|
||||
- gguf
|
||||
- diffusion
|
||||
- CPU
|
||||
- CUDA
|
||||
capabilities:
|
||||
default: "cpu-dllm"
|
||||
nvidia: "cuda13-dllm"
|
||||
nvidia-cuda-13: "cuda13-dllm"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-dllm"
|
||||
- &whispercpp
|
||||
name: "whisper"
|
||||
alias: "whisper"
|
||||
@@ -122,6 +145,62 @@
|
||||
nvidia-cuda-12: "cuda12-whisper"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-whisper"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-whisper"
|
||||
- &crispasr
|
||||
name: "crispasr"
|
||||
alias: "crispasr"
|
||||
license: mit
|
||||
icon: https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg
|
||||
description: |
|
||||
CrispASR unified speech engine (whisper.cpp fork on ggml) supporting many ASR architectures (Parakeet, Canary, Voxtral, Qwen3-ASR, Granite, Wav2Vec2, Moonshine, OmniASR, FireRedASR, and more).
|
||||
urls:
|
||||
- https://github.com/CrispStrobe/CrispASR
|
||||
tags:
|
||||
- audio-transcription
|
||||
- CPU
|
||||
- GPU
|
||||
- CUDA
|
||||
- HIP
|
||||
capabilities:
|
||||
default: "cpu-crispasr"
|
||||
nvidia: "cuda12-crispasr"
|
||||
intel: "intel-sycl-f16-crispasr"
|
||||
metal: "metal-crispasr"
|
||||
amd: "rocm-crispasr"
|
||||
vulkan: "vulkan-crispasr"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-crispasr"
|
||||
nvidia-cuda-13: "cuda13-crispasr"
|
||||
nvidia-cuda-12: "cuda12-crispasr"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-crispasr"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-crispasr"
|
||||
- ¶keetcpp
|
||||
name: "parakeet-cpp"
|
||||
alias: "parakeet-cpp"
|
||||
license: mit
|
||||
icon: https://avatars.githubusercontent.com/u/95302084
|
||||
description: |
|
||||
parakeet.cpp is a C++/ggml port of NVIDIA NeMo Parakeet automatic speech recognition (ASR) models.
|
||||
It supports the tdt, ctc, rnnt and hybrid decoder families as well as cache-aware streaming transcription,
|
||||
and runs on CPU, NVIDIA CUDA, AMD ROCm/HIP, Intel SYCL and NVIDIA Jetson (L4T) targets.
|
||||
urls:
|
||||
- https://github.com/mudler/parakeet.cpp
|
||||
tags:
|
||||
- audio-transcription
|
||||
- CPU
|
||||
- GPU
|
||||
- CUDA
|
||||
- HIP
|
||||
capabilities:
|
||||
default: "cpu-parakeet-cpp"
|
||||
nvidia: "cuda12-parakeet-cpp"
|
||||
intel: "intel-sycl-f16-parakeet-cpp"
|
||||
metal: "metal-parakeet-cpp"
|
||||
amd: "rocm-parakeet-cpp"
|
||||
vulkan: "vulkan-parakeet-cpp"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-parakeet-cpp"
|
||||
nvidia-cuda-13: "cuda13-parakeet-cpp"
|
||||
nvidia-cuda-12: "cuda12-parakeet-cpp"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-parakeet-cpp"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-parakeet-cpp"
|
||||
- &voxtral
|
||||
name: "voxtral"
|
||||
alias: "voxtral"
|
||||
@@ -1216,6 +1295,13 @@
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ds4-development"
|
||||
metal: "metal-ds4-development"
|
||||
metal-darwin-arm64: "metal-ds4-development"
|
||||
- !!merge <<: *dllm
|
||||
name: "dllm-development"
|
||||
capabilities:
|
||||
default: "cpu-dllm-development"
|
||||
nvidia: "cuda13-dllm-development"
|
||||
nvidia-cuda-13: "cuda13-dllm-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-dllm-development"
|
||||
- !!merge <<: *stablediffusionggml
|
||||
name: "stablediffusion-ggml-development"
|
||||
capabilities:
|
||||
@@ -1803,6 +1889,37 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-ds4"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-ds4
|
||||
## dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cpu-dllm"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cpu-dllm-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cuda13-dllm"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cuda13-dllm-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cuda13-nvidia-l4t-arm64-dllm"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cuda13-nvidia-l4t-arm64-dllm-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-dllm
|
||||
## whisper
|
||||
- !!merge <<: *whispercpp
|
||||
name: "whisper-development"
|
||||
@@ -1928,6 +2045,246 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-whisper
|
||||
## crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "crispasr-development"
|
||||
capabilities:
|
||||
default: "cpu-crispasr-development"
|
||||
nvidia: "cuda12-crispasr-development"
|
||||
intel: "intel-sycl-f16-crispasr-development"
|
||||
metal: "metal-crispasr-development"
|
||||
amd: "rocm-crispasr-development"
|
||||
vulkan: "vulkan-crispasr-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-crispasr-development"
|
||||
nvidia-cuda-13: "cuda13-crispasr-development"
|
||||
nvidia-cuda-12: "cuda12-crispasr-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-crispasr-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-crispasr-development"
|
||||
- !!merge <<: *crispasr
|
||||
name: "nvidia-l4t-arm64-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "nvidia-l4t-arm64-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-nvidia-l4t-arm64-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-nvidia-l4t-arm64-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cpu-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cpu-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda12-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "rocm-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f32-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f16-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f16-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "vulkan-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "vulkan-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda12-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "rocm-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f32-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f16-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f16-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-crispasr
|
||||
## parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "parakeet-cpp-development"
|
||||
capabilities:
|
||||
default: "cpu-parakeet-cpp-development"
|
||||
nvidia: "cuda12-parakeet-cpp-development"
|
||||
intel: "intel-sycl-f16-parakeet-cpp-development"
|
||||
metal: "metal-parakeet-cpp-development"
|
||||
amd: "rocm-parakeet-cpp-development"
|
||||
vulkan: "vulkan-parakeet-cpp-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
nvidia-cuda-13: "cuda13-parakeet-cpp-development"
|
||||
nvidia-cuda-12: "cuda12-parakeet-cpp-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "nvidia-l4t-arm64-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda13-nvidia-l4t-arm64-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda13-nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cpu-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cpu-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "metal-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "metal-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda12-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda12-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "rocm-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "rocm-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "intel-sycl-f32-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "intel-sycl-f32-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "intel-sycl-f16-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f16-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "intel-sycl-f16-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f16-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "vulkan-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "vulkan-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda13-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda13-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-parakeet-cpp
|
||||
## stablediffusion-ggml
|
||||
- !!merge <<: *stablediffusionggml
|
||||
name: "cpu-stablediffusion-ggml"
|
||||
|
||||
@@ -37,6 +37,20 @@ def is_int(s):
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def coerce_param_value(value):
|
||||
"""Coerce a TTSRequest.params value (string on the wire) to the type the
|
||||
Chatterbox generate() kwargs expect (float/int/bool), matching how static
|
||||
YAML options are coerced at load time. Non-string values pass through."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
if is_float(value):
|
||||
return float(value)
|
||||
if is_int(value):
|
||||
return int(value)
|
||||
if value.lower() in ["true", "false"]:
|
||||
return value.lower() == "true"
|
||||
return value
|
||||
|
||||
def split_text_at_word_boundary(text, max_length=250):
|
||||
"""
|
||||
Split text at word boundaries without truncating words.
|
||||
@@ -191,6 +205,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# add options to kwargs
|
||||
kwargs.update(self.options)
|
||||
|
||||
# Merge per-request params (TTSRequest.params), overriding the static
|
||||
# YAML options. This exposes Chatterbox generation knobs (e.g.
|
||||
# exaggeration, cfg_weight, temperature) per request. Values arrive as
|
||||
# strings on the wire and are coerced to float/int/bool.
|
||||
if hasattr(request, "params") and request.params:
|
||||
for key, value in request.params.items():
|
||||
kwargs[key] = coerce_param_value(value)
|
||||
|
||||
# Check if text exceeds 250 characters
|
||||
# (chatterbox does not support long text)
|
||||
# https://github.com/resemble-ai/chatterbox/issues/60
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch
|
||||
texterrors==1.1.6
|
||||
nemo_toolkit[asr]
|
||||
|
||||
@@ -47,6 +47,26 @@ def is_int(s):
|
||||
return False
|
||||
|
||||
|
||||
def coerce_param_value(value):
|
||||
"""Coerce a string param value (from the TTSRequest.params map, which is
|
||||
string-typed on the wire) into the most specific Python type the model
|
||||
generation kwargs expect: bool, int, float, else the original string."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
lowered = value.strip().lower()
|
||||
if lowered in ("true", "false"):
|
||||
return lowered == "true"
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
return value
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
@@ -322,6 +342,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def _effective_instruct(self, request):
|
||||
"""Resolve the instruction/style string for this request, preferring the
|
||||
per-request TTSRequest.instructions value and falling back to the static
|
||||
YAML `instruct` option. Empty string means "no instruction"."""
|
||||
req_instruct = (
|
||||
request.instructions
|
||||
if hasattr(request, "instructions") and request.instructions
|
||||
else ""
|
||||
)
|
||||
if req_instruct:
|
||||
return req_instruct
|
||||
return self.options.get("instruct", "") or ""
|
||||
|
||||
def _detect_mode(self, request):
|
||||
"""Detect which mode to use based on request parameters."""
|
||||
# Priority: VoiceClone > VoiceDesign > CustomVoice
|
||||
@@ -338,8 +371,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if self.audio_path or self.voices:
|
||||
return "VoiceClone"
|
||||
|
||||
# VoiceDesign: instruct option is provided
|
||||
if "instruct" in self.options and self.options["instruct"]:
|
||||
# VoiceDesign: instruct provided per-request or via YAML option
|
||||
if self._effective_instruct(request):
|
||||
return "VoiceDesign"
|
||||
|
||||
# Default to CustomVoice
|
||||
@@ -690,10 +723,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if do_sample is not None:
|
||||
generation_kwargs["do_sample"] = do_sample
|
||||
|
||||
instruct = self.options.get("instruct", "")
|
||||
# Prefer the per-request instruction (TTSRequest.instructions) over the
|
||||
# static YAML `instruct` option. This lets clients set a different style
|
||||
# (CustomVoice emotion) or designed voice (VoiceDesign) per request.
|
||||
instruct = self._effective_instruct(request)
|
||||
if instruct is not None and instruct != "":
|
||||
generation_kwargs["instruct"] = instruct
|
||||
|
||||
# Merge any per-request backend-specific params (TTSRequest.params).
|
||||
# Values arrive as strings on the wire; coerce to int/float/bool so the
|
||||
# model receives the types it expects. These override YAML-derived kwargs.
|
||||
if hasattr(request, "params") and request.params:
|
||||
for key, value in request.params.items():
|
||||
generation_kwargs[key] = coerce_param_value(value)
|
||||
|
||||
# Generate audio based on mode
|
||||
if mode == "VoiceClone":
|
||||
# VoiceClone mode
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.81.0
|
||||
protobuf==7.35.0
|
||||
certifi
|
||||
setuptools
|
||||
|
||||
@@ -26,7 +26,10 @@ from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
try:
|
||||
from vllm.tokenizers import get_tokenizer # vLLM >= 0.22
|
||||
except ImportError:
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer # vLLM < 0.22
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.assets.video import VideoAsset
|
||||
import base64
|
||||
|
||||
@@ -3,5 +3,5 @@
|
||||
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
|
||||
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
|
||||
# so uv consults this index alongside PyPI.
|
||||
--extra-index-url https://wheels.vllm.ai/0.21.0/cu130
|
||||
vllm==0.21.0
|
||||
--extra-index-url https://wheels.vllm.ai/0.22.1/cu130
|
||||
vllm==0.22.1
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.81.0
|
||||
protobuf
|
||||
certifi
|
||||
setuptools
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// Release represents a LocalAI release
|
||||
@@ -67,9 +68,7 @@ func NewReleaseManager() *ReleaseManager {
|
||||
CurrentVersion: internal.PrintableVersion(),
|
||||
ChecksumsPath: checksumsPath,
|
||||
MetadataPath: metadataPath,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
HTTPClient: httpclient.NewWithTimeout(30*time.Second, httpclient.WithFollowRedirects()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -90,6 +90,8 @@ type Application struct {
|
||||
// LocalAI Assistant in-process MCP server. nil when DisableLocalAIAssistant
|
||||
// is set; otherwise initialised in start() after galleryService.
|
||||
localAIAssistant *mcpTools.LocalAIAssistantHolder
|
||||
|
||||
shutdownOnce sync.Once
|
||||
}
|
||||
|
||||
func newApplication(appConfig *config.ApplicationConfig) *Application {
|
||||
@@ -320,6 +322,24 @@ func (a *Application) IsDistributed() bool {
|
||||
return a.distributed != nil
|
||||
}
|
||||
|
||||
// Shutdown stops backend gRPC processes and distributed services
|
||||
// synchronously on the caller's stack. The context-cancel goroutine wired
|
||||
// in New does the same work asynchronously, which races test-binary exit
|
||||
// and CLI shutdown — orphaning spawned mock-backend / llama.cpp / etc.
|
||||
// children to init. Callers that need a guarantee that cleanup has
|
||||
// finished before they proceed (AfterSuite/AfterEach, signal handlers)
|
||||
// must call this. Safe to call multiple times.
|
||||
func (a *Application) Shutdown() error {
|
||||
var err error
|
||||
a.shutdownOnce.Do(func() {
|
||||
a.distributed.Shutdown()
|
||||
if a.modelLoader != nil {
|
||||
err = a.modelLoader.StopAllGRPC()
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// waitForHealthyWorker blocks until at least one healthy backend worker is registered.
|
||||
// This prevents the agent pool from failing during startup when workers haven't connected yet.
|
||||
func (a *Application) waitForHealthyWorker() {
|
||||
|
||||
@@ -16,7 +16,9 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
@@ -100,7 +102,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)
|
||||
|
||||
// Connect to NATS
|
||||
natsClient, err := messaging.New(cfg.Distributed.NatsURL)
|
||||
natsAuth := cfg.Distributed.NatsAuthConfig()
|
||||
if natsAuth.RequireAuth && (natsAuth.ServiceUserJWT == "" || natsAuth.ServiceUserSeed == "") {
|
||||
return nil, fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED")
|
||||
}
|
||||
natsOpts := cfg.Distributed.NatsMessagingOptions("", "")
|
||||
natsClient, err := messaging.New(cfg.Distributed.NatsURL, natsOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connecting to NATS: %w", err)
|
||||
}
|
||||
@@ -240,6 +247,84 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
|
||||
)
|
||||
|
||||
// Prefix-cache-aware routing. Enabled by default; an operator can opt out
|
||||
// with --distributed-prefix-cache=false, which leaves prefixProvider and
|
||||
// pressure nil so the SmartRouter and reconciler behave exactly as the
|
||||
// round-robin floor (true no-op). When enabled we build the local index,
|
||||
// wrap it in a NATS-backed Sync (publishes our observations, applies peers'
|
||||
// via the subscriptions below), install the extraction hook used by
|
||||
// core/backend/llm.go, and run a background eviction ticker on the app ctx.
|
||||
var prefixProvider prefixcache.Provider
|
||||
var pressure *prefixcache.Pressure
|
||||
var prefixCfg prefixcache.Config
|
||||
if !cfg.Distributed.PrefixCacheDisabled {
|
||||
prefixCfg = prefixcache.DefaultConfig()
|
||||
if cfg.Distributed.PrefixCacheTTL > 0 {
|
||||
prefixCfg.TTL = cfg.Distributed.PrefixCacheTTL
|
||||
}
|
||||
if err := prefixCfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid prefix-cache configuration: %w", err)
|
||||
}
|
||||
idx := prefixcache.NewIndex(prefixCfg)
|
||||
prefixSync := prefixcache.NewSync(idx, natsClient)
|
||||
pressure = prefixcache.NewPressure(prefixCfg.PressureWindow)
|
||||
prefixProvider = prefixSync
|
||||
|
||||
// Invalidate the prefix-cache index whenever a replica row is removed.
|
||||
// SetReplicaRemovedHook fires from the single chokepoint all removal paths
|
||||
// funnel through (RemoveNodeModel / RemoveAllNodeModelReplicas), so this
|
||||
// one hook covers every path: reconciler scale-down, probe reaper,
|
||||
// health-monitor reap, RemoteUnloaderAdapter, and the router. Registering
|
||||
// it only inside this enabled block keeps the disabled path a true no-op
|
||||
// (the registry stays hook-less).
|
||||
registry.SetReplicaRemovedHook(func(model, node string, replica int) {
|
||||
if replica < 0 {
|
||||
prefixSync.InvalidateNode(model, node)
|
||||
} else {
|
||||
prefixSync.Invalidate(model, prefixcache.ReplicaKey{NodeID: node, Replica: replica})
|
||||
}
|
||||
})
|
||||
|
||||
distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 {
|
||||
return prefixcache.ExtractChain(model, prompt, prefixCfg)
|
||||
}
|
||||
|
||||
// Apply peers' observations/invalidations to the same Sync. ApplyObserve
|
||||
// and ApplyInvalidate update only the local index and do not re-publish,
|
||||
// so there is no broadcast loop.
|
||||
if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheObserve, func(ev messaging.PrefixCacheObserveEvent) {
|
||||
prefixSync.ApplyObserve(ev, time.Now())
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheObserve, err)
|
||||
}
|
||||
if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheInvalidate, func(ev messaging.PrefixCacheInvalidateEvent) {
|
||||
prefixSync.ApplyInvalidate(ev)
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheInvalidate, err)
|
||||
}
|
||||
|
||||
// Background eviction: sweep idle entries on the app context. Stopped
|
||||
// when the app context is cancelled (mirrors the reconciler loop which
|
||||
// also runs on options.Context). TTL/2 keeps stale entries from
|
||||
// outliving their idle window by more than half a TTL.
|
||||
evictInterval := prefixCfg.TTL / 2
|
||||
go func() {
|
||||
ticker := time.NewTicker(evictInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-cfg.Context.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
prefixSync.Evict(time.Now())
|
||||
}
|
||||
}
|
||||
}()
|
||||
xlog.Info("Prefix-cache-aware routing enabled", "ttl", prefixCfg.TTL, "evictInterval", evictInterval)
|
||||
} else {
|
||||
xlog.Info("Prefix-cache-aware routing disabled: using round-robin routing")
|
||||
}
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
var conflictResolver nodes.ConcurrencyConflictResolver
|
||||
if configLoader != nil {
|
||||
@@ -252,6 +337,9 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
AuthToken: routerAuthToken,
|
||||
DB: authDB,
|
||||
ConflictResolver: conflictResolver,
|
||||
PrefixProvider: prefixProvider,
|
||||
PrefixConfig: prefixCfg,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
// Create ReplicaReconciler for auto-scaling model replicas. Adapter +
|
||||
@@ -268,6 +356,8 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
Interval: 30 * time.Second,
|
||||
ScaleDownDelay: 5 * time.Minute,
|
||||
ProbeStaleAfter: 2 * time.Minute,
|
||||
Pressure: pressure,
|
||||
PressureThreshold: prefixCfg.PressureScaleThreshold,
|
||||
})
|
||||
|
||||
// Create ModelRouterAdapter to wire into ModelLoader
|
||||
|
||||
@@ -23,9 +23,9 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -308,10 +308,31 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
application.galleryService.SetNATSClient(distSvc.Nats)
|
||||
if distSvc.DistStores != nil && distSvc.DistStores.Gallery != nil {
|
||||
// Clean up stale in-progress operations from previous crashed instances
|
||||
if err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
|
||||
if _, err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
|
||||
xlog.Warn("Failed to clean stale gallery operations", "error", err)
|
||||
}
|
||||
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
|
||||
|
||||
// Reap stale ops periodically, not just at boot: an op orphaned by
|
||||
// a replica that died mid-install (its foreground handler goroutine
|
||||
// gone) would otherwise linger "processing" in the UI until the next
|
||||
// restart. 30m matches the install/upgrade ceiling so a genuinely
|
||||
// slow op is never reaped out from under itself.
|
||||
gsvc := application.galleryService
|
||||
go func() {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-options.Context.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if _, err := gsvc.ReapStaleOperations(30 * time.Minute); err != nil {
|
||||
xlog.Warn("Failed to reap stale gallery operations", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
// Hydrate from the store first so the wildcard subscriber finds an
|
||||
// already-populated statuses map for any operations still in flight
|
||||
@@ -449,13 +470,15 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
|
||||
application.ModelLoader().SetBackendLoggingEnabled(options.EnableBackendLogging)
|
||||
|
||||
// turn off any process that was started by GRPC if the context is canceled
|
||||
// Safety-net cleanup if the application context is cancelled without
|
||||
// the caller invoking Shutdown directly. This is fire-and-forget — it
|
||||
// races binary exit and is unreliable in tests; the deterministic path
|
||||
// is application.Shutdown(), which Shutdown's sync.Once dedupes with
|
||||
// this goroutine.
|
||||
go func() {
|
||||
<-options.Context.Done()
|
||||
xlog.Debug("Context canceled, shutting down")
|
||||
application.distributed.Shutdown()
|
||||
err := application.ModelLoader().StopAllGRPC()
|
||||
if err != nil {
|
||||
if err := application.Shutdown(); err != nil {
|
||||
xlog.Error("error while stopping all grpc backends", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -214,7 +214,9 @@ func (uc *UpgradeChecker) runCheck(ctx context.Context) {
|
||||
"from", info.InstalledVersion, "to", info.AvailableVersion)
|
||||
var err error
|
||||
if bm != nil {
|
||||
err = bm.UpgradeBackend(ctx, name, nil)
|
||||
// Background auto-upgrade: no live admin watching a progress bar,
|
||||
// so opID is empty and the distributed path skips progress streaming.
|
||||
err = bm.UpgradeBackend(ctx, "", name, nil)
|
||||
} else {
|
||||
err = gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
|
||||
uc.galleries, name, nil, uc.appConfig.RequireBackendIntegrity)
|
||||
|
||||
@@ -123,14 +123,14 @@ var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
|
||||
})
|
||||
|
||||
It("ModelTTS forwards the request context to the SmartRouter", func() {
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", loader, appCfg, modelCfg)
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTTSStream forwards the request context to the SmartRouter", func() {
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
@@ -94,6 +95,22 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
}
|
||||
|
||||
// Make the rendered prompt's prefix chain available to the distributed router
|
||||
// for prefix-cache-aware node selection. No-op in single-process mode. The
|
||||
// model id MUST match the id ModelOptions feeds to model.WithModelID, so both
|
||||
// use the shared config.ModelConfig.ModelID() helper (Name with a fallback to
|
||||
// Model) or the chain salt and the tracking key would diverge.
|
||||
//
|
||||
// s is empty for UseTokenizerTemplate models (the backend tokenizes the
|
||||
// structured messages itself), so fall back to a prefix-stable serialization
|
||||
// of the messages - otherwise prefix routing would silently degrade to
|
||||
// round-robin for the bulk of modern chat models.
|
||||
chainSource := s
|
||||
if chainSource == "" {
|
||||
chainSource = messagesPrefixSource(messages)
|
||||
}
|
||||
ctx = distributedhdr.MaybeWithPrefixChain(ctx, c.ModelID(), chainSource)
|
||||
|
||||
opts := ModelOptions(*c, o, model.WithContext(ctx))
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
|
||||
@@ -34,16 +34,11 @@ func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, back
|
||||
}
|
||||
|
||||
func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
|
||||
name := c.Name
|
||||
if name == "" {
|
||||
name = c.Model
|
||||
}
|
||||
|
||||
defOpts := []model.Option{
|
||||
model.WithBackendString(c.Backend),
|
||||
model.WithModel(c.Model),
|
||||
model.WithContext(so.Context),
|
||||
model.WithModelID(name),
|
||||
model.WithModelID(c.ModelID()),
|
||||
}
|
||||
|
||||
threads := 1
|
||||
@@ -244,13 +239,13 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
|
||||
|
||||
if c.Backend == "cloud-proxy" {
|
||||
opts.Proxy = &pb.ProxyOptions{
|
||||
UpstreamUrl: c.Proxy.UpstreamURL,
|
||||
Mode: c.Proxy.Mode,
|
||||
Provider: c.Proxy.Provider,
|
||||
ApiKeyEnv: c.Proxy.APIKeyEnv,
|
||||
ApiKeyFile: c.Proxy.APIKeyFile,
|
||||
UpstreamModel: c.Proxy.UpstreamModel,
|
||||
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
|
||||
UpstreamUrl: c.Proxy.UpstreamURL,
|
||||
Mode: c.Proxy.Mode,
|
||||
Provider: c.Proxy.Provider,
|
||||
ApiKeyEnv: c.Proxy.APIKeyEnv,
|
||||
ApiKeyFile: c.Proxy.APIKeyFile,
|
||||
UpstreamModel: c.Proxy.UpstreamModel,
|
||||
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -328,6 +323,12 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
|
||||
metadata["enable_thinking"] = "true"
|
||||
}
|
||||
}
|
||||
// Forward the effective reasoning effort so the backend can pass it to the
|
||||
// jinja chat template (chat_template_kwargs.reasoning_effort) — the lever
|
||||
// models like gpt-oss / LFM2.5 actually read, distinct from enable_thinking.
|
||||
if c.ReasoningEffort != "" {
|
||||
metadata["reasoning_effort"] = c.ReasoningEffort
|
||||
}
|
||||
pbOpts.Metadata = metadata
|
||||
|
||||
// Logprobs and TopLogprobs are set by the caller if provided
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/reasoning"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -42,3 +43,57 @@ var _ = Describe("grpcModelOpts EngineArgs", func() {
|
||||
Expect(opts.EngineArgs).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
// Guards the DisableReasoning -> enable_thinking metadata conversion that the
|
||||
// per-request reasoning_effort feature (issue #10072) relies on: the request
|
||||
// merge sets ReasoningConfig.DisableReasoning, and gRPCPredictOpts is where it
|
||||
// becomes the gRPC PredictOptions.Metadata the backend reads.
|
||||
var _ = Describe("gRPCPredictOpts enable_thinking metadata", func() {
|
||||
// withReasoning builds a fully-defaulted config (gRPCPredictOpts dereferences
|
||||
// many pointer fields) and overrides only the reasoning toggle.
|
||||
withReasoning := func(disable *bool) config.ModelConfig {
|
||||
cfg := config.ModelConfig{}
|
||||
cfg.SetDefaults()
|
||||
cfg.ReasoningConfig = reasoning.Config{DisableReasoning: disable}
|
||||
return cfg
|
||||
}
|
||||
disabled := true
|
||||
enabled := false
|
||||
|
||||
It("emits enable_thinking=false when reasoning is disabled", func() {
|
||||
opts := gRPCPredictOpts(withReasoning(&disabled), "/tmp/models")
|
||||
Expect(opts.Metadata).To(HaveKeyWithValue("enable_thinking", "false"))
|
||||
})
|
||||
|
||||
It("emits enable_thinking=true when reasoning is enabled", func() {
|
||||
opts := gRPCPredictOpts(withReasoning(&enabled), "/tmp/models")
|
||||
Expect(opts.Metadata).To(HaveKeyWithValue("enable_thinking", "true"))
|
||||
})
|
||||
|
||||
It("omits enable_thinking when reasoning is unset", func() {
|
||||
opts := gRPCPredictOpts(withReasoning(nil), "/tmp/models")
|
||||
Expect(opts.Metadata).ToNot(HaveKey("enable_thinking"))
|
||||
})
|
||||
})
|
||||
|
||||
// Guards forwarding the effective reasoning_effort into PredictOptions.Metadata,
|
||||
// where the backend passes it to the jinja chat template (chat_template_kwargs)
|
||||
// so models like gpt-oss / LFM2.5 honor it.
|
||||
var _ = Describe("gRPCPredictOpts reasoning_effort metadata", func() {
|
||||
withEffort := func(effort string) config.ModelConfig {
|
||||
cfg := config.ModelConfig{}
|
||||
cfg.SetDefaults()
|
||||
cfg.ReasoningEffort = effort
|
||||
return cfg
|
||||
}
|
||||
|
||||
It("forwards reasoning_effort when set", func() {
|
||||
opts := gRPCPredictOpts(withEffort("none"), "/tmp/models")
|
||||
Expect(opts.Metadata).To(HaveKeyWithValue("reasoning_effort", "none"))
|
||||
})
|
||||
|
||||
It("omits reasoning_effort when empty", func() {
|
||||
opts := gRPCPredictOpts(withEffort(""), "/tmp/models")
|
||||
Expect(opts.Metadata).ToNot(HaveKey("reasoning_effort"))
|
||||
})
|
||||
})
|
||||
|
||||
36
core/backend/prefix_source.go
Normal file
36
core/backend/prefix_source.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
)
|
||||
|
||||
// messagesPrefixSource builds a deterministic, prefix-stable serialization of a
|
||||
// chat conversation for prefix-cache-aware routing. It is the fallback used when
|
||||
// the frontend did not render a prompt string: models with
|
||||
// config.TemplateConfig.UseTokenizerTemplate tokenize the structured messages
|
||||
// backend-side, so the frontend's rendered prompt is empty and a chain built
|
||||
// from it would always be empty - silently degrading prefix routing to
|
||||
// round-robin for the bulk of modern chat models.
|
||||
//
|
||||
// Messages are emitted head-first in turn order (role line + content line per
|
||||
// message), so two conversations sharing a leading system prompt and early turns
|
||||
// share a leading byte prefix. That is exactly what ExtractChain hashes into a
|
||||
// shared chain prefix, landing both requests on the same cache-warm replica.
|
||||
func messagesPrefixSource(messages schema.Messages) string {
|
||||
var b strings.Builder
|
||||
for _, m := range messages {
|
||||
b.WriteString(m.Role)
|
||||
b.WriteByte('\n')
|
||||
content := m.StringContent
|
||||
if content == "" {
|
||||
if s, ok := m.Content.(string); ok {
|
||||
content = s
|
||||
}
|
||||
}
|
||||
b.WriteString(content)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
53
core/backend/prefix_source_internal_test.go
Normal file
53
core/backend/prefix_source_internal_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("messagesPrefixSource", func() {
|
||||
mk := func(role, content string) schema.Message {
|
||||
return schema.Message{Role: role, StringContent: content}
|
||||
}
|
||||
|
||||
It("serializes messages head-first in turn order", func() {
|
||||
got := messagesPrefixSource(schema.Messages{
|
||||
mk("system", "You are helpful."),
|
||||
mk("user", "Hi"),
|
||||
})
|
||||
Expect(got).To(Equal("system\nYou are helpful.\nuser\nHi\n"))
|
||||
})
|
||||
|
||||
It("is deterministic across calls for the same conversation", func() {
|
||||
conv := schema.Messages{mk("system", "S"), mk("user", "U")}
|
||||
Expect(messagesPrefixSource(conv)).To(Equal(messagesPrefixSource(conv)))
|
||||
})
|
||||
|
||||
It("shares a leading byte prefix when the system prompt is shared", func() {
|
||||
shared := "system\nShared system prompt.\nuser\n"
|
||||
a := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question A")})
|
||||
b := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question B")})
|
||||
Expect(strings.HasPrefix(a, shared)).To(BeTrue())
|
||||
Expect(strings.HasPrefix(b, shared)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does NOT share a prefix when the system prompt differs", func() {
|
||||
a := messagesPrefixSource(schema.Messages{mk("system", "Prompt A"), mk("user", "Q")})
|
||||
b := messagesPrefixSource(schema.Messages{mk("system", "Prompt B"), mk("user", "Q")})
|
||||
Expect(strings.HasPrefix(a, "system\nPrompt A")).To(BeTrue())
|
||||
Expect(strings.HasPrefix(b, "system\nPrompt B")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("returns empty for no messages", func() {
|
||||
Expect(messagesPrefixSource(nil)).To(Equal(""))
|
||||
})
|
||||
|
||||
It("falls back to Content when StringContent is empty", func() {
|
||||
got := messagesPrefixSource(schema.Messages{{Role: "user", Content: "plain"}})
|
||||
Expect(got).To(Equal("user\nplain\n"))
|
||||
})
|
||||
})
|
||||
@@ -20,11 +20,32 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
// newTTSRequest assembles the gRPC TTSRequest from the per-request inputs. The
|
||||
// optional instructions string is only attached when non-empty so backends can
|
||||
// distinguish "no per-request instruction" (fall back to YAML) from an explicit
|
||||
// empty one. params is forwarded as-is (nil when unset).
|
||||
func newTTSRequest(text, modelPath, voice, dst, language, instructions string, params map[string]string) *proto.TTSRequest {
|
||||
req := &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Dst: dst,
|
||||
Language: &language,
|
||||
Params: params,
|
||||
}
|
||||
if instructions != "" {
|
||||
req.Instructions = &instructions
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func ModelTTS(
|
||||
ctx context.Context,
|
||||
text,
|
||||
voice,
|
||||
language string,
|
||||
language,
|
||||
instructions string,
|
||||
params map[string]string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
@@ -74,13 +95,9 @@ func ModelTTS(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := ttsModel.TTS(ctx, &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Dst: filePath,
|
||||
Language: &language,
|
||||
})
|
||||
ttsRequest := newTTSRequest(text, modelPath, voice, filePath, language, instructions, params)
|
||||
|
||||
res, err := ttsModel.TTS(ctx, ttsRequest)
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
@@ -128,7 +145,9 @@ func ModelTTSStream(
|
||||
ctx context.Context,
|
||||
text,
|
||||
voice,
|
||||
language string,
|
||||
language,
|
||||
instructions string,
|
||||
params map[string]string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
@@ -177,12 +196,10 @@ func ModelTTSStream(
|
||||
var totalPCMBytes int
|
||||
snippetCapped := false
|
||||
|
||||
err = ttsModel.TTSStream(ctx, &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Language: &language,
|
||||
}, func(reply *proto.Reply) {
|
||||
// Streaming TTS writes to the HTTP response, not a file, so dst is empty.
|
||||
ttsRequest := newTTSRequest(text, modelPath, voice, "", language, instructions, params)
|
||||
|
||||
err = ttsModel.TTSStream(ctx, ttsRequest, func(reply *proto.Reply) {
|
||||
// First message contains sample rate info
|
||||
if !headerSent && len(reply.Message) > 0 {
|
||||
var info map[string]any
|
||||
|
||||
42
core/backend/tts_test.go
Normal file
42
core/backend/tts_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package backend
|
||||
|
||||
// Specs for the TTSRequest assembly that carries the per-request
|
||||
// instructions/params from the OpenAI `instructions` field (and the LocalAI
|
||||
// `params` extension) through to the gRPC boundary. Before this plumbing the
|
||||
// instruction value was dropped before reaching the backend; these specs pin
|
||||
// that it now survives, and that the empty case stays backward compatible.
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("newTTSRequest", func() {
|
||||
It("attaches the instructions when a per-request value is set", func() {
|
||||
req := newTTSRequest("hi", "/m", "alloy", "/out.wav", "en", "cheerful narrator", nil)
|
||||
Expect(req.Instructions).ToNot(BeNil())
|
||||
Expect(req.GetInstructions()).To(Equal("cheerful narrator"))
|
||||
Expect(req.GetText()).To(Equal("hi"))
|
||||
Expect(req.GetVoice()).To(Equal("alloy"))
|
||||
Expect(req.GetDst()).To(Equal("/out.wav"))
|
||||
Expect(req.GetLanguage()).To(Equal("en"))
|
||||
})
|
||||
|
||||
It("leaves instructions unset when empty so backends fall back to YAML", func() {
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
|
||||
Expect(req.Instructions).To(BeNil())
|
||||
Expect(req.GetInstructions()).To(Equal(""))
|
||||
})
|
||||
|
||||
It("forwards per-request params through to the backend", func() {
|
||||
params := map[string]string{"exaggeration": "0.7", "cfg_weight": "0.3"}
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", params)
|
||||
Expect(req.GetParams()).To(HaveKeyWithValue("exaggeration", "0.7"))
|
||||
Expect(req.GetParams()).To(HaveKeyWithValue("cfg_weight", "0.3"))
|
||||
})
|
||||
|
||||
It("leaves params nil when none are supplied", func() {
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
|
||||
Expect(req.GetParams()).To(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -52,10 +52,28 @@ type AgentWorkerCMD struct {
|
||||
Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
|
||||
Queue string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`
|
||||
|
||||
NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (defaults to nats_jwt from registration)" group:"distributed"`
|
||||
NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user seed override (defaults to nats_user_seed from registration)" group:"distributed"`
|
||||
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"Fallback NATS service JWT when registration does not mint agent JWT" group:"distributed"`
|
||||
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"Fallback NATS service seed paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
|
||||
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed to connect" group:"distributed"`
|
||||
// DistributedRequireAuth is the umbrella switch; for the agent worker (which
|
||||
// has no file-transfer server) it implies NATS auth is required.
|
||||
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch implying --nats-require-auth (agent workers have no file-transfer server)" group:"distributed"`
|
||||
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
|
||||
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
|
||||
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
|
||||
|
||||
// Timeouts
|
||||
MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
|
||||
}
|
||||
|
||||
// natsAuthRequired reports whether NATS JWT credentials must be present — the
|
||||
// granular flag or the umbrella (LOCALAI_DISTRIBUTED_REQUIRE_AUTH).
|
||||
func (cmd *AgentWorkerCMD) natsAuthRequired() bool {
|
||||
return cmd.NatsRequireAuth || cmd.DistributedRequireAuth
|
||||
}
|
||||
|
||||
func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
xlog.Info("Starting agent worker", "nats", sanitize.URL(cmd.NatsURL), "register_to", cmd.RegisterTo)
|
||||
|
||||
@@ -81,15 +99,30 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
registrationBody["token"] = cmd.RegistrationToken
|
||||
}
|
||||
|
||||
nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
|
||||
// Context cancelled on shutdown — used by registration waits, heartbeat, and
|
||||
// other background goroutines.
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
// Acquire credentials via (re)registration. When the bus requires auth and no
|
||||
// static fallback is configured, wait through admin approval until the
|
||||
// frontend mints credentials rather than starting unauthenticated.
|
||||
credMgr := workerregistry.NewNATSCredentialManager(
|
||||
func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
|
||||
return regClient.RegisterFull(ctx, registrationBody)
|
||||
},
|
||||
cmd.natsAuthRequired() && cmd.NatsJWT == "" && cmd.NatsServiceJWT == "",
|
||||
)
|
||||
res, err := credMgr.Acquire(shutdownCtx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("registration failed: %w", err)
|
||||
}
|
||||
nodeID := res.ID
|
||||
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
|
||||
|
||||
// Use provisioned API token if none was set
|
||||
if cmd.APIToken == "" {
|
||||
cmd.APIToken = apiToken
|
||||
cmd.APIToken = res.APIToken
|
||||
}
|
||||
|
||||
// Start heartbeat
|
||||
@@ -98,14 +131,40 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
|
||||
}
|
||||
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
|
||||
// Context cancelled on shutdown — used by heartbeat and other background goroutines
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })
|
||||
|
||||
// Connect to NATS
|
||||
natsClient, err := messaging.New(cmd.NatsURL)
|
||||
// Resolve NATS credentials with precedence: explicit env override, then
|
||||
// frontend-minted (auto-refreshed before expiry), then service fallback.
|
||||
// Each static source must supply JWT and seed together.
|
||||
natsTLS := messaging.TLSFiles{CA: cmd.NatsTLSCA, Cert: cmd.NatsTLSCert, Key: cmd.NatsTLSKey}
|
||||
var natsOpts []messaging.Option
|
||||
switch {
|
||||
case cmd.NatsJWT != "" || cmd.NatsUserSeed != "":
|
||||
if (cmd.NatsJWT == "") != (cmd.NatsUserSeed == "") {
|
||||
return fmt.Errorf("LOCALAI_NATS_JWT and LOCALAI_NATS_USER_SEED must be set together")
|
||||
}
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsJWT, cmd.NatsUserSeed))
|
||||
case credMgr.HasCredentials():
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWTProvider(credMgr.Provider()))
|
||||
go func() {
|
||||
if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
|
||||
xlog.Error("NATS credential refresh permanently failed; shutting down agent worker", "error", err)
|
||||
shutdownCancel()
|
||||
}
|
||||
}()
|
||||
case cmd.NatsServiceJWT != "" || cmd.NatsServiceSeed != "":
|
||||
if (cmd.NatsServiceJWT == "") != (cmd.NatsServiceSeed == "") {
|
||||
return fmt.Errorf("LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED must be set together")
|
||||
}
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsServiceJWT, cmd.NatsServiceSeed))
|
||||
case cmd.natsAuthRequired():
|
||||
return fmt.Errorf("NATS JWT+seed required: enable frontend minting or set LOCALAI_NATS_* env vars")
|
||||
}
|
||||
if natsTLS.Enabled() {
|
||||
natsOpts = append(natsOpts, messaging.WithTLS(natsTLS))
|
||||
}
|
||||
natsClient, err := messaging.New(cmd.NatsURL, natsOpts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connecting to NATS: %w", err)
|
||||
}
|
||||
@@ -183,17 +242,25 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
|
||||
xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)
|
||||
|
||||
// Wait for shutdown
|
||||
// Wait for an OS signal or an internal fatal condition (e.g. NATS
|
||||
// credentials became unrenewable), so the worker restarts and re-acquires
|
||||
// rather than lingering unable to serve.
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
var runErr error
|
||||
select {
|
||||
case <-sigCh:
|
||||
case <-shutdownCtx.Done():
|
||||
runErr = fmt.Errorf("agent worker shutting down: NATS credentials unavailable")
|
||||
xlog.Error("Internal shutdown requested", "error", runErr)
|
||||
}
|
||||
|
||||
xlog.Info("Shutting down agent worker")
|
||||
shutdownCancel() // stop heartbeat loop immediately
|
||||
dispatcher.Stop()
|
||||
mcpTools.CloseAllMCPSessions()
|
||||
regClient.GracefulDeregister(nodeID)
|
||||
return nil
|
||||
return runErr
|
||||
}
|
||||
|
||||
// handleMCPToolRequest handles a NATS request-reply for MCP tool execution.
|
||||
|
||||
30
core/cli/chat/chat.go
Normal file
30
core/cli/chat/chat.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Model string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
In io.Reader
|
||||
Out io.Writer
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, opts Options) error {
|
||||
if opts.In == nil {
|
||||
opts.In = strings.NewReader("")
|
||||
}
|
||||
if opts.Out == nil {
|
||||
opts.Out = io.Discard
|
||||
}
|
||||
|
||||
session, err := newChatSession(ctx, newLocalAIChatClient(opts.BaseURL, opts.APIKey), opts.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return runTerminalChat(ctx, session, opts.In, opts.Out)
|
||||
}
|
||||
13
core/cli/chat/chat_suite_test.go
Normal file
13
core/cli/chat/chat_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestChat(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Chat Suite")
|
||||
}
|
||||
172
core/cli/chat/chat_test.go
Normal file
172
core/cli/chat/chat_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Run chat", func() {
|
||||
It("streams a single chat response", func() {
|
||||
var capturedModel string
|
||||
var capturedAuth string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/v1/models" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
writeResponse(w, `{"object":"list","data":[{"id":"test-model","object":"model"}]}`)
|
||||
return
|
||||
}
|
||||
|
||||
Expect(r.URL.Path).To(Equal("/v1/chat/completions"))
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
|
||||
var body struct {
|
||||
Model string `json:"model"`
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
Expect(json.NewDecoder(r.Body).Decode(&body)).To(Succeed())
|
||||
capturedModel = body.Model
|
||||
Expect(body.Messages).To(HaveLen(1))
|
||||
Expect(body.Messages[0].Role).To(Equal("user"))
|
||||
Expect(body.Messages[0].Content).To(Equal("hello"))
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n")
|
||||
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\"}}]}\n\n")
|
||||
writeResponse(w, "data: [DONE]\n\n")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
var out bytes.Buffer
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
Model: "test-model",
|
||||
BaseURL: server.URL + "/v1",
|
||||
APIKey: "secret",
|
||||
In: strings.NewReader("hello\n/exit\n"),
|
||||
Out: &out,
|
||||
})
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(capturedModel).To(Equal("test-model"))
|
||||
Expect(capturedAuth).To(Equal("Bearer secret"))
|
||||
Expect(out.String()).To(ContainSubstring("assistant: hi!"))
|
||||
Expect(out.String()).To(ContainSubstring("bye"))
|
||||
})
|
||||
|
||||
It("auto-selects the only available model", func() {
|
||||
server := chatTestServer([]string{"solo"}, nil)
|
||||
defer server.Close()
|
||||
|
||||
var out bytes.Buffer
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
BaseURL: server.URL + "/v1",
|
||||
In: strings.NewReader("/exit\n"),
|
||||
Out: &out,
|
||||
})
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out.String()).To(ContainSubstring("LocalAI chat (solo)"))
|
||||
})
|
||||
|
||||
It("returns an actionable error when no models are installed", func() {
|
||||
server := chatTestServer(nil, nil)
|
||||
defer server.Close()
|
||||
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
BaseURL: server.URL + "/v1",
|
||||
In: strings.NewReader(""),
|
||||
})
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no chat models are installed"))
|
||||
Expect(err.Error()).To(ContainSubstring("local-ai models install <model>"))
|
||||
})
|
||||
|
||||
It("returns an actionable error when multiple models are available without a selection", func() {
|
||||
server := chatTestServer([]string{"alpha", "beta"}, nil)
|
||||
defer server.Close()
|
||||
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
BaseURL: server.URL + "/v1",
|
||||
In: strings.NewReader(""),
|
||||
})
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("multiple models are available"))
|
||||
Expect(err.Error()).To(ContainSubstring("--model"))
|
||||
Expect(err.Error()).To(ContainSubstring("alpha"))
|
||||
Expect(err.Error()).To(ContainSubstring("beta"))
|
||||
})
|
||||
|
||||
It("lists and switches models inside the chat", func() {
|
||||
requestedModels := []string{}
|
||||
server := chatTestServer([]string{"alpha", "beta"}, func(model string) {
|
||||
requestedModels = append(requestedModels, model)
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
var out bytes.Buffer
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
Model: "alpha",
|
||||
BaseURL: server.URL + "/v1",
|
||||
In: strings.NewReader("/models\n/model beta\nhello\n/exit\n"),
|
||||
Out: &out,
|
||||
})
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out.String()).To(ContainSubstring("* alpha"))
|
||||
Expect(out.String()).To(ContainSubstring(" beta"))
|
||||
Expect(out.String()).To(ContainSubstring("switched to beta; conversation cleared"))
|
||||
Expect(requestedModels).To(Equal([]string{"beta"}))
|
||||
})
|
||||
})
|
||||
|
||||
func chatTestServer(models []string, onChat func(model string)) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/v1/models":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
writeResponse(w, `{"object":"list","data":[`)
|
||||
for i, model := range models {
|
||||
if i > 0 {
|
||||
writeResponse(w, ",")
|
||||
}
|
||||
writeResponsef(w, `{"id":%q,"object":"model"}`, model)
|
||||
}
|
||||
writeResponse(w, `]}`)
|
||||
case "/v1/chat/completions":
|
||||
var body struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
Expect(json.NewDecoder(r.Body).Decode(&body)).To(Succeed())
|
||||
if onChat != nil {
|
||||
onChat(body.Model)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ok\"}}]}\n\n")
|
||||
writeResponse(w, "data: [DONE]\n\n")
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func writeResponse(w io.Writer, text string) {
|
||||
_, err := fmt.Fprint(w, text)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
func writeResponsef(w io.Writer, format string, args ...any) {
|
||||
_, err := fmt.Fprintf(w, format, args...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
114
core/cli/chat/client.go
Normal file
114
core/cli/chat/client.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type chatClient interface {
|
||||
ListModels(ctx context.Context) ([]string, error)
|
||||
StreamChat(ctx context.Context, model string, messages []chatMessage, out io.Writer) (string, error)
|
||||
}
|
||||
|
||||
type localAIChatClient struct {
|
||||
client *openai.Client
|
||||
}
|
||||
|
||||
func newLocalAIChatClient(baseURL string, apiKey string) *localAIChatClient {
|
||||
cfg := openai.DefaultConfig(apiKey)
|
||||
cfg.BaseURL = baseURL
|
||||
return &localAIChatClient{client: openai.NewClientWithConfig(cfg)}
|
||||
}
|
||||
|
||||
func (c *localAIChatClient) ListModels(ctx context.Context) ([]string, error) {
|
||||
resp, err := c.client.ListModels(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(resp.Models))
|
||||
for _, model := range resp.Models {
|
||||
if model.ID != "" {
|
||||
models = append(models, model.ID)
|
||||
}
|
||||
}
|
||||
sort.Strings(models)
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (c *localAIChatClient) StreamChat(ctx context.Context, model string, messages []chatMessage, out io.Writer) (string, error) {
|
||||
stream, err := c.client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: openAIChatMessages(messages),
|
||||
})
|
||||
if err != nil {
|
||||
return "", friendlyChatError(err, model)
|
||||
}
|
||||
defer func() {
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
var answer strings.Builder
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return answer.String(), friendlyChatError(err, model)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
token := resp.Choices[0].Delta.Content
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
answer.WriteString(token)
|
||||
if _, err := fmt.Fprint(out, token); err != nil {
|
||||
return answer.String(), err
|
||||
}
|
||||
}
|
||||
|
||||
return answer.String(), nil
|
||||
}
|
||||
|
||||
func openAIChatMessages(messages []chatMessage) []openai.ChatCompletionMessage {
|
||||
converted := make([]openai.ChatCompletionMessage, len(messages))
|
||||
for i, message := range messages {
|
||||
converted[i] = openai.ChatCompletionMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
}
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func friendlyChatError(err error, model string) error {
|
||||
var apiErr *openai.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
switch apiErr.HTTPStatusCode {
|
||||
case 404:
|
||||
return fmt.Errorf("model %q is not available. Run `local-ai models list`, install a model with `local-ai models install <model>`, or switch with `/model <name>`", model)
|
||||
case 403:
|
||||
return fmt.Errorf("model %q is disabled. Enable it from LocalAI settings or choose another model with `/model <name>`", model)
|
||||
}
|
||||
if apiErr.Message != "" {
|
||||
return errors.New(apiErr.Message)
|
||||
}
|
||||
}
|
||||
|
||||
msg := err.Error()
|
||||
if strings.Contains(msg, "model") && strings.Contains(msg, "not found") {
|
||||
return fmt.Errorf("model %q is not available. Run `local-ai models list`, install a model with `local-ai models install <model>`, or switch with `/model <name>`", model)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
17
core/cli/chat/models.go
Normal file
17
core/cli/chat/models.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package chat
|
||||
|
||||
import "strings"
|
||||
|
||||
func formatChatModelList(models []string, current string) string {
|
||||
var b strings.Builder
|
||||
for _, model := range models {
|
||||
prefix := " "
|
||||
if model == current {
|
||||
prefix = "* "
|
||||
}
|
||||
b.WriteString(prefix)
|
||||
b.WriteString(model)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
120
core/cli/chat/session.go
Normal file
120
core/cli/chat/session.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
chatRoleUser = "user"
|
||||
chatRoleAssistant = "assistant"
|
||||
)
|
||||
|
||||
type chatMessage struct {
|
||||
Role string
|
||||
Content string
|
||||
}
|
||||
|
||||
type chatSession struct {
|
||||
client chatClient
|
||||
model string
|
||||
models []string
|
||||
messages []chatMessage
|
||||
}
|
||||
|
||||
func newChatSession(ctx context.Context, client chatClient, requestedModel string) (*chatSession, error) {
|
||||
models, err := client.ListModels(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list models: %w", err)
|
||||
}
|
||||
|
||||
model, err := resolveChatModel(requestedModel, models)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &chatSession{
|
||||
client: client,
|
||||
model: model,
|
||||
models: models,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *chatSession) CurrentModel() string {
|
||||
return s.model
|
||||
}
|
||||
|
||||
func (s *chatSession) Models() []string {
|
||||
models := make([]string, len(s.models))
|
||||
copy(models, s.models)
|
||||
return models
|
||||
}
|
||||
|
||||
func (s *chatSession) Clear() {
|
||||
s.messages = nil
|
||||
}
|
||||
|
||||
func (s *chatSession) SwitchModel(model string) error {
|
||||
if !modelExists(s.models, model) {
|
||||
return fmt.Errorf("model %q is not available. Use /models to see installed models", model)
|
||||
}
|
||||
s.model = model
|
||||
s.Clear()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *chatSession) Send(ctx context.Context, prompt string, out io.Writer) error {
|
||||
s.messages = append(s.messages, chatMessage{
|
||||
Role: chatRoleUser,
|
||||
Content: prompt,
|
||||
})
|
||||
|
||||
answer, err := s.client.StreamChat(ctx, s.model, s.messages, out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.messages = append(s.messages, chatMessage{
|
||||
Role: chatRoleAssistant,
|
||||
Content: answer,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveChatModel(requested string, models []string) (string, error) {
|
||||
switch {
|
||||
case requested == "" && len(models) == 0:
|
||||
return "", errors.New(`no chat models are installed.
|
||||
|
||||
Install a model first, for example:
|
||||
local-ai models list
|
||||
local-ai models install <model>
|
||||
local-ai run
|
||||
|
||||
Then start a chat session:
|
||||
local-ai chat --model <model>`)
|
||||
case requested == "" && len(models) == 1:
|
||||
return models[0], nil
|
||||
case requested == "" && len(models) > 1:
|
||||
var b strings.Builder
|
||||
b.WriteString("multiple models are available; choose one with --model:\n")
|
||||
b.WriteString(formatChatModelList(models, ""))
|
||||
return "", errors.New(b.String())
|
||||
case !modelExists(models, requested):
|
||||
return "", fmt.Errorf("model %q is not available. Use `local-ai models list` and `local-ai models install <model>`, or pass an installed model with --model", requested)
|
||||
default:
|
||||
return requested, nil
|
||||
}
|
||||
}
|
||||
|
||||
func modelExists(models []string, name string) bool {
|
||||
for _, model := range models {
|
||||
if model == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
56
core/cli/chat/session_test.go
Normal file
56
core/cli/chat/session_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Chat session", func() {
|
||||
It("keeps model switching and message history out of the terminal adapter", func() {
|
||||
client := &fakeChatClient{
|
||||
models: []string{"alpha", "beta"},
|
||||
answer: "pong",
|
||||
}
|
||||
|
||||
session, err := newChatSession(context.Background(), client, "alpha")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(session.CurrentModel()).To(Equal("alpha"))
|
||||
|
||||
Expect(session.SwitchModel("beta")).To(Succeed())
|
||||
Expect(session.CurrentModel()).To(Equal("beta"))
|
||||
Expect(session.Send(context.Background(), "ping", io.Discard)).To(Succeed())
|
||||
|
||||
Expect(client.requests).To(HaveLen(1))
|
||||
Expect(client.requests[0].model).To(Equal("beta"))
|
||||
Expect(client.requests[0].messages).To(HaveLen(1))
|
||||
Expect(client.requests[0].messages[0].Content).To(Equal("ping"))
|
||||
})
|
||||
})
|
||||
|
||||
type fakeChatClient struct {
|
||||
models []string
|
||||
answer string
|
||||
requests []fakeChatRequest
|
||||
}
|
||||
|
||||
type fakeChatRequest struct {
|
||||
model string
|
||||
messages []chatMessage
|
||||
}
|
||||
|
||||
func (c *fakeChatClient) ListModels(context.Context) ([]string, error) {
|
||||
return c.models, nil
|
||||
}
|
||||
|
||||
func (c *fakeChatClient) StreamChat(_ context.Context, model string, messages []chatMessage, out io.Writer) (string, error) {
|
||||
copied := make([]chatMessage, len(messages))
|
||||
copy(copied, messages)
|
||||
c.requests = append(c.requests, fakeChatRequest{model: model, messages: copied})
|
||||
if _, err := io.WriteString(out, c.answer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return c.answer, nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user