mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-22 15:49:12 -04:00
Compare commits
152 Commits
feat/dllm-
...
feat/recon
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c6170b875d | ||
|
|
a9c7484986 | ||
|
|
e05dece93c | ||
|
|
7c2a347e79 | ||
|
|
6e0c491380 | ||
|
|
2bcdfe2a68 | ||
|
|
b843f498ca | ||
|
|
46d7d59a82 | ||
|
|
e3bca9a172 | ||
|
|
a19ab22186 | ||
|
|
91d08d88e6 | ||
|
|
2c5ed413cb | ||
|
|
01e098a844 | ||
|
|
600dafd20b | ||
|
|
ce8a3e9266 | ||
|
|
a88d9d2de3 | ||
|
|
1cf1bf32e1 | ||
|
|
f45c6acc54 | ||
|
|
1a1bd57469 | ||
|
|
1f29e96030 | ||
|
|
64560a974b | ||
|
|
32c47706ae | ||
|
|
e58870a573 | ||
|
|
8fab1d2e45 | ||
|
|
7b462a0d51 | ||
|
|
aed181e6c1 | ||
|
|
a556cd9afc | ||
|
|
b50b1fe418 | ||
|
|
b4c0dc67fe | ||
|
|
01fa12e0de | ||
|
|
cf7f9573a2 | ||
|
|
c6303104c7 | ||
|
|
3e96d811b7 | ||
|
|
23f225260c | ||
|
|
aef10723c9 | ||
|
|
9565db5f94 | ||
|
|
e19c43cf04 | ||
|
|
b081247d95 | ||
|
|
1be959ce30 | ||
|
|
518381278e | ||
|
|
93706fec57 | ||
|
|
11aee03a80 | ||
|
|
8915f2ab91 | ||
|
|
f143d7f688 | ||
|
|
dd928f0bdd | ||
|
|
c43a752afc | ||
|
|
079ac0e15a | ||
|
|
2e734bf560 | ||
|
|
72d46c1115 | ||
|
|
606128e4e9 | ||
|
|
59c7ad5153 | ||
|
|
78d682224a | ||
|
|
29dbba7a25 | ||
|
|
4ad754eea3 | ||
|
|
67692cb984 | ||
|
|
f68edfc85f | ||
|
|
c3b3336654 | ||
|
|
c4cd86bb15 | ||
|
|
13f59f0822 | ||
|
|
3fa7b2955c | ||
|
|
c133ca39dc | ||
|
|
757822cd74 | ||
|
|
91f97f2a54 | ||
|
|
55f9ff6805 | ||
|
|
88726f2da4 | ||
|
|
5c2ae7857a | ||
|
|
4af360300f | ||
|
|
5ac864dbed | ||
|
|
9b57dcb721 | ||
|
|
95e7149c87 | ||
|
|
fd26c8c753 | ||
|
|
e60c094a7d | ||
|
|
159df8e2ef | ||
|
|
de299ca101 | ||
|
|
980ec4a311 | ||
|
|
dfd5a00e6f | ||
|
|
63be479066 | ||
|
|
4c6750fe6b | ||
|
|
a6e1c6d0b3 | ||
|
|
294170d3ed | ||
|
|
1ab61a0875 | ||
|
|
f44034021e | ||
|
|
6b9f1bd4b3 | ||
|
|
416f871bea | ||
|
|
8bd2df8f68 | ||
|
|
6799d802d3 | ||
|
|
40cc549882 | ||
|
|
3d295adfa8 | ||
|
|
4fa2064875 | ||
|
|
cb74399b3a | ||
|
|
2388686369 | ||
|
|
edc61053aa | ||
|
|
9ba8521e7e | ||
|
|
51c23197ed | ||
|
|
2df2876db2 | ||
|
|
f648f07b13 | ||
|
|
1dedb5277c | ||
|
|
7d2a762b53 | ||
|
|
61cde6fd77 | ||
|
|
ca1668dd85 | ||
|
|
fdc352a618 | ||
|
|
692970e507 | ||
|
|
e046a7749f | ||
|
|
e5c95e0449 | ||
|
|
4d3d54d61b | ||
|
|
36e3419203 | ||
|
|
4ec6e3221e | ||
|
|
4bb592cf91 | ||
|
|
3e838c0cff | ||
|
|
36b4a81d1e | ||
|
|
0854932a25 | ||
|
|
203410871b | ||
|
|
7637f8cf1b | ||
|
|
f0e001b7f8 | ||
|
|
cf9debf4eb | ||
|
|
e1556aa1dc | ||
|
|
53cbb578a9 | ||
|
|
99c8205740 | ||
|
|
d7162b9f89 | ||
|
|
3351b62c91 | ||
|
|
0eca930b8d | ||
|
|
81ab62e874 | ||
|
|
0413fc03f8 | ||
|
|
7088572f75 | ||
|
|
c1e8440f5b | ||
|
|
8f0059123b | ||
|
|
a906438a69 | ||
|
|
d28a5b6da1 | ||
|
|
edeacf22c4 | ||
|
|
51f4f67c47 | ||
|
|
cf71e291b4 | ||
|
|
a7a7bd646b | ||
|
|
cec93d2e00 | ||
|
|
722bdb87e9 | ||
|
|
50dea8c983 | ||
|
|
46ba70632b | ||
|
|
60facc7252 | ||
|
|
8c8204d3c4 | ||
|
|
4ce0f6102a | ||
|
|
085fc53bbc | ||
|
|
56cc4f63fc | ||
|
|
a53f34e78f | ||
|
|
1cea96f09f | ||
|
|
006a9d38c7 | ||
|
|
892ce951ce | ||
|
|
7cda221d36 | ||
|
|
9a88eb81e7 | ||
|
|
58cdc050e9 | ||
|
|
b962f4a192 | ||
|
|
b6fcb3e1db | ||
|
|
ff09683d84 | ||
|
|
f618636c71 |
@@ -198,6 +198,27 @@ docker-build-backends: ... docker-build-<backend-name>
|
|||||||
- If the backend is in `backend/python/<backend-name>/` but uses `.` as context in the workflow file, use `.` context
|
- If the backend is in `backend/python/<backend-name>/` but uses `.` as context in the workflow file, use `.` context
|
||||||
- Check similar backends to determine the correct context
|
- Check similar backends to determine the correct context
|
||||||
|
|
||||||
|
## Documenting the backend (README + docs)
|
||||||
|
|
||||||
|
A backend is not "added" until it is discoverable. Update the user-facing docs:
|
||||||
|
|
||||||
|
- **`docs/content/features/backends.md`** - add the backend to the right
|
||||||
|
category in the "LocalAI supports various types of backends" list (and add a
|
||||||
|
new category if it introduces a new modality, e.g. sound classification).
|
||||||
|
- If the backend introduces a **new API surface** (a new endpoint or a realtime
|
||||||
|
capability), document it under `docs/content/` where its area lives (audio,
|
||||||
|
vision, etc.) and follow the api-endpoints checklist in
|
||||||
|
[api-endpoints-and-auth.md](api-endpoints-and-auth.md).
|
||||||
|
|
||||||
|
**If the backend is a native C/C++/GGML engine created and maintained by the
|
||||||
|
LocalAI team** (a from-scratch port like `parakeet.cpp`, `ced.cpp`,
|
||||||
|
`vibevoice.cpp`, `rf-detr.cpp`, not a wrapper around a third-party runtime), it
|
||||||
|
ALSO belongs in the top-level **`README.md`** table under "native C/C++/GGML
|
||||||
|
engines ... developed and maintained by the LocalAI project itself". Add a row
|
||||||
|
linking the upstream engine repo with a one-line description. This is the
|
||||||
|
project's showcase of its own engines; a new in-house backend that is missing
|
||||||
|
from it is a documentation bug.
|
||||||
|
|
||||||
## 5. Verification Checklist
|
## 5. Verification Checklist
|
||||||
|
|
||||||
After adding a new backend, verify:
|
After adding a new backend, verify:
|
||||||
@@ -211,6 +232,8 @@ After adding a new backend, verify:
|
|||||||
- [ ] No YAML syntax errors (check with linter)
|
- [ ] No YAML syntax errors (check with linter)
|
||||||
- [ ] No Makefile syntax errors (check with linter)
|
- [ ] No Makefile syntax errors (check with linter)
|
||||||
- [ ] Follows the same pattern as similar backends (e.g., if it's a transcription backend, follow `faster-whisper` pattern)
|
- [ ] Follows the same pattern as similar backends (e.g., if it's a transcription backend, follow `faster-whisper` pattern)
|
||||||
|
- [ ] Documented: added to the category list in `docs/content/features/backends.md` (and any new endpoint/realtime capability documented under `docs/content/`)
|
||||||
|
- [ ] If it is an in-house native C/C++/GGML engine, added to the maintained-engines table in the top-level `README.md`
|
||||||
|
|
||||||
## Bundling runtime shared libraries (`package.sh`)
|
## Bundling runtime shared libraries (`package.sh`)
|
||||||
|
|
||||||
|
|||||||
@@ -1,138 +0,0 @@
|
|||||||
# Working on the dllm Backend
|
|
||||||
|
|
||||||
`mudler/dllm.cpp` is a standalone C++/ggml engine for DiffusionGemma
|
|
||||||
block-diffusion models. LocalAI wraps it with a **pure-Go** backend at
|
|
||||||
`backend/go/dllm/` that dlopens `libdllm.so` via purego (ebitengine/purego) -
|
|
||||||
NOT cgo, and NOT a C++ grpc-server fork. The Go side owns chat templating
|
|
||||||
(gemma4 renderer) and output parsing (gemma4 streaming parser) and implements
|
|
||||||
the rich gRPC interface (`PredictRich`/`PredictStreamRich`, ChatDelta replies).
|
|
||||||
|
|
||||||
> NOTE: github.com/mudler/dllm.cpp is still **private** (publishing is
|
|
||||||
> planned). Until then the Makefile's anonymous clone fails; use the local-dev
|
|
||||||
> symlink shortcut documented at the top of `backend/go/dllm/Makefile`
|
|
||||||
> (symlink an out-of-tree `build/libdllm.so` into the backend dir and skip the
|
|
||||||
> clone), or a git credential helper with repo access.
|
|
||||||
|
|
||||||
## Pin
|
|
||||||
|
|
||||||
`backend/go/dllm/Makefile` pins `DLLM_VERSION?=<sha>` at the top
|
|
||||||
(whisper / parakeet-cpp / ds4 convention). The bump-deps bot
|
|
||||||
(`.github/workflows/bump_deps.yaml`) tracks `mudler/dllm.cpp` `main` and
|
|
||||||
rewrites that variable. After a manual bump: `make -C backend/go/dllm purge &&
|
|
||||||
make -C backend/go/dllm` (the clone is keyed on the directory existing, not
|
|
||||||
the sha).
|
|
||||||
|
|
||||||
## C-ABI and the serialization contract
|
|
||||||
|
|
||||||
The binding covers the 9-symbol flat C-ABI from dllm.cpp's
|
|
||||||
`include/dllm_capi.h` (ABI v1; `main.go` hard-fails on a version mismatch):
|
|
||||||
`abi_version, load, free, last_error, free_string, tokenize_json, generate,
|
|
||||||
generate_stream, cancel`. Contract points the Go wiring encodes (`capi.go`
|
|
||||||
header comment has the full list):
|
|
||||||
|
|
||||||
- **One ctx = one concurrent generate/tokenize.** A per-model worker
|
|
||||||
goroutine (`Dllm.jobs` in `dllm.go`) owns ALL C calls, making the
|
|
||||||
serialization structural instead of lock discipline.
|
|
||||||
- **`dllm_capi_cancel` is the ONE exception**: it only flips an atomic and may
|
|
||||||
be called from any goroutine mid-generate, so `Dllm.Cancel` bypasses the
|
|
||||||
worker queue. The flag resets at the start of each generate, so a watchdog
|
|
||||||
racing a new generate must re-issue cancel.
|
|
||||||
- **`last_error` is a borrowed pointer** and must only be read AFTER the
|
|
||||||
failing call returned (never while a generate is in flight on the same ctx).
|
|
||||||
- **Free vs in-flight requests**: requests hold `genMu.RLock` for their full
|
|
||||||
duration; `Free` takes the write lock, so it only runs when nothing is in
|
|
||||||
flight, then drains and closes the worker. Post-Free requests get a clean
|
|
||||||
"model not loaded" error.
|
|
||||||
- `tokenize_json`/`generate` return malloc'd `char*` (bound as `uintptr`,
|
|
||||||
copied, then `dllm_capi_free_string`d); opts/params JSON must be a FLAT
|
|
||||||
object of scalars (`buildOptsJSON` rejects anything else).
|
|
||||||
|
|
||||||
## Wire shape
|
|
||||||
|
|
||||||
| RPC | Implementation |
|
|
||||||
|---|---|
|
|
||||||
| LoadModel | `dllm_capi_load` (params: `n_gpu_layers`, `n_threads`, `ctx_len`); `Options[]` parsed into per-request gen opts (`eb_*`, `blocks`, `kv_cache`) by `parseModelGenOpts` |
|
|
||||||
| PredictRich | render (if templated) → `dllm_capi_generate` → parse → ONE Reply with aggregated ChatDeltas + legacy `Message` bytes |
|
|
||||||
| PredictStreamRich | `dllm_capi_generate_stream`; per committed diffusion block → UTF-8 holdback → parser.Feed → one Reply per non-empty delta batch (channel closed by the CALLER, per `pkg/grpc/interface.go`) |
|
|
||||||
| Predict / PredictStream | Legacy paths, delegate to the rich pair (legacy stream INVERTS channel ownership: the impl closes) |
|
|
||||||
| TokenizeString | `dllm_capi_tokenize_json` (C side prepends BOS per `vocab.add_bos`) |
|
|
||||||
| Cancel | `dllm_capi_cancel`, exposed as the `grpc.Cancellable` capability (`pkg/grpc/interface.go`): the gRPC server arms it via `context.AfterFunc` on the Predict/PredictStream context, so client disconnects/timeouts abort the in-flight generate - llama.cpp `IsCancelled()` parity for Go backends |
|
|
||||||
|
|
||||||
`n_threads` and `ctx_len` are accepted-but-ignored by the engine at the
|
|
||||||
current pin (the context bound comes from GGUF `n_ctx_train`); they are sent
|
|
||||||
for forward compatibility.
|
|
||||||
|
|
||||||
## Renderer / parser (the templated chat path)
|
|
||||||
|
|
||||||
With `use_tokenizer_template` + raw Messages, the backend owns templating and
|
|
||||||
parsing (the ds4 precedent, but in Go):
|
|
||||||
|
|
||||||
- `gemma4_renderer.go` - `RenderGemma4(msgs, toolsJSON, enableThinking,
|
|
||||||
addGenerationPrompt)`. The file embeds the FULL `tokenizer.chat_template`
|
|
||||||
jinja (17466 bytes, md5 `8c34cf93c7a7815b3fdb300a009c4c17`) extracted
|
|
||||||
verbatim from `diffusiongemma-26B-A4B-it-BF16.gguf` via gguf-py - e.g.
|
|
||||||
`python scripts/dump_gguf.py model.gguf | grep -A400 chat_template` in the
|
|
||||||
dllm.cpp checkout - as a numbered comment block; every Go rule cites its
|
|
||||||
"tpl L<n>" line. Re-verify the md5 before blaming the renderer for a
|
|
||||||
mismatch with a new GGUF. **BOS exception**: the template emits
|
|
||||||
`{{- bos_token -}}` but the renderer deliberately does NOT - dllm.cpp's
|
|
||||||
`run_generate` tokenizes with `prepend_bos = vocab.add_bos` (true for
|
|
||||||
gemma4), so a literal `<bos>` would double it.
|
|
||||||
- `gemma4_parser.go` - streaming state machine turning raw model text
|
|
||||||
(fragments can split anywhere, including mid-marker) into ChatDeltas:
|
|
||||||
thought channels → `reasoning_content`, `<|tool_call>call:name{...}` →
|
|
||||||
ToolCallDelta, `<turn|>` → done. Marker grammar cross-checked against vLLM
|
|
||||||
PR #45163's gemma4 tool/reasoning parsers. Malformed payloads are re-emitted
|
|
||||||
raw as content, never dropped.
|
|
||||||
- Thinking is **opt-in** for this family (`Metadata["enable_thinking"]`,
|
|
||||||
default OFF - the inverse of ds4): the template gates every thinking branch
|
|
||||||
on `enable_thinking`, and the no-thinking render pre-closes an empty thought
|
|
||||||
channel, so the parser always starts in content state.
|
|
||||||
- **UTF-8 boundary holdback** (`splitValidUTF8` in `dllm.go`): per-block
|
|
||||||
detokenization can split a multi-byte character across block boundaries, and
|
|
||||||
grpc-go refuses to marshal invalid UTF-8 in proto3 strings. An incomplete
|
|
||||||
trailing sequence (at most 3 bytes) is carried into the next block; genuinely
|
|
||||||
undecodable bytes become U+FFFD.
|
|
||||||
|
|
||||||
Without `use_tokenizer_template`, the prompt passes through verbatim and the
|
|
||||||
output is NOT gemma4-parsed (plain content, like any non-autoparsing backend).
|
|
||||||
|
|
||||||
## Tests
|
|
||||||
|
|
||||||
| Layer | Gate | What |
|
|
||||||
|---|---|---|
|
|
||||||
| `backend/go/dllm/*_test.go` (renderer/parser/wiring) | none - run in plain `go test ./backend/go/dllm/...` | Ginkgo specs over a fake `generator` seam; canonical renderer fixtures from transformers' `test_modeling_diffusion_gemma.py`, parser tables from the vLLM gemma4 parsers |
|
|
||||||
| `backend/go/dllm/dllm_test.go` C-ABI smoke | `DLLM_TEST_LIBRARY` + `DLLM_TEST_TINY_MODEL` (dllm.cpp's `tests/fixtures/tiny_with_vocab.gguf`); Skips when unset | Drives the real `libdllm.so`: ABI check, load, tokenize `[2,18]`, deterministic generate, cancel (incl. mid-stream `Dllm.Cancel` aborting a deliberately slow `eb_max_steps:256` run in ~10ms) |
|
|
||||||
| `tests/e2e-backends/dllm_test.go` | `BACKEND_TEST_DLLM=1` + `BACKEND_BINARY` (packaged run.sh) + `BACKEND_TEST_MODEL_FILE` (tiny fixture) | Templated chat round trip (Messages + UseTokenizerTemplate) over the real gRPC binary, non-streaming + streaming; plus client-context cancellation mid-stream (proves the `Cancellable` server plumbing end to end) |
|
|
||||||
| Real-model e2e | `BACKEND_TEST_DLLM_REAL_MODEL_FILE` (26B BF16, ~50 GB) + `BACKEND_TEST_DLLM_REAL_GPU_LAYERS` | CUDA-13-class hardware only |
|
|
||||||
|
|
||||||
Tool-call e2e is deliberately absent from the tiny-model spec: the fixture has
|
|
||||||
random weights and cannot be coaxed into emitting tool markup; the unit tables
|
|
||||||
carry that coverage.
|
|
||||||
|
|
||||||
## Build matrix
|
|
||||||
|
|
||||||
`cpu-dllm` (amd64 + arm64), `cuda13-dllm` (amd64), and
|
|
||||||
`cuda13-nvidia-l4t-arm64-dllm` (arm64 CUDA: Jetson / DGX Spark GB10), via
|
|
||||||
`.github/backend-matrix.yml`. No darwin/Metal. CUDA builds forward
|
|
||||||
`-DDLLM_CUDA=ON` (dllm.cpp gates ggml's CUDA behind its own flag - a bare
|
|
||||||
`-DGGML_CUDA=ON` is overridden by the cache FORCE). `libdllm.so` is
|
|
||||||
self-contained (ggml statically absorbed, PIC), so `package.sh` only ships
|
|
||||||
the binary, `run.sh` and that one .so (the parakeet-cpp-style stub layout;
|
|
||||||
no ldd walk yet).
|
|
||||||
|
|
||||||
## Known limitations
|
|
||||||
|
|
||||||
- **Cancel granularity**: the C-ABI cancel flag is per-ctx and resets on
|
|
||||||
every generate entry, so a Cancel racing a NEW generate can be lost, and
|
|
||||||
with requests queued on the worker it aborts whichever generate is
|
|
||||||
currently running (acceptable: the server de-registers the hook on normal
|
|
||||||
completion, one process serves one model).
|
|
||||||
- **Throughput**: ~0.15 tok/s on the 26B at default settings (GB10) - every
|
|
||||||
denoise step recomputes the full prompt+canvas. The upstream prefix-KV
|
|
||||||
cache (dllm.cpp P3) is the fix; `kv_cache:on` errors until it lands
|
|
||||||
(`auto`/`off` are accepted no-ops).
|
|
||||||
- **Repo privacy**: see the note at the top - CI clone of dllm.cpp needs the
|
|
||||||
repo published (or credentials) before the backend images can build.
|
|
||||||
- Engine spec/validation references: dllm.cpp `docs/validation.md` and
|
|
||||||
LocalAI `docs/superpowers/specs/2026-06-10-dllm-cpp-design.md`.
|
|
||||||
@@ -44,6 +44,39 @@ maps to `DS4_THINK_HIGH`. We pass the chosen mode to `ds4_chat_append_assistant_
|
|||||||
via `ModelOptions.Options[] = "kv_cache_dir:/some/path"`. Format is **our own** -
|
via `ModelOptions.Options[] = "kv_cache_dir:/some/path"`. Format is **our own** -
|
||||||
NOT bit-compatible with ds4-server's KVC files (interop is a follow-up plan).
|
NOT bit-compatible with ds4-server's KVC files (interop is a follow-up plan).
|
||||||
|
|
||||||
|
## Engine options (LoadModel)
|
||||||
|
|
||||||
|
`LoadModel` maps `ModelOptions.Options[]` (`"key:value"`, from model-YAML
|
||||||
|
`options:`) onto `ds4_engine_options` through a **declarative table**
|
||||||
|
(`kEngineOptSpecs` + `apply_engine_option` in `grpc-server.cpp`). The struct is
|
||||||
|
plain C with no reflection, so the field set is enumerated once in the table;
|
||||||
|
adding a future engine knob is a one-line table row, not a new branch. Unknown
|
||||||
|
keys are ignored (back-compat). A bare flag (`ssd_streaming` with no value)
|
||||||
|
means `true`. Path-type values (`mtp_path`, `expert_profile_path`,
|
||||||
|
`directional_steering_file`) resolve **relative to the model directory**, so a
|
||||||
|
gallery entry can reference a companion file it downloaded by bare filename;
|
||||||
|
absolute values pass through. `ds4_role` / `ds4_layers` / `ds4_listen` /
|
||||||
|
`ds4_route_timeout` / `kv_cache_dir` keep their dedicated handling (validation
|
||||||
|
+ coordinator wiring) and are not in the table.
|
||||||
|
|
||||||
|
Wired keys: `mtp_path`, `mtp_draft`, `mtp_margin`, `prefill_chunk`,
|
||||||
|
`power_percent`, `warm_weights`, `quality`, `ssd_streaming`,
|
||||||
|
`ssd_streaming_cold`, `ssd_streaming_preload_experts`,
|
||||||
|
`ssd_streaming_cache_experts` (count or `NGB`, sets both experts+bytes via
|
||||||
|
`ds4_parse_streaming_cache_experts_arg`), `simulate_used_memory` (`NGB` via
|
||||||
|
`ds4_parse_gib_arg`), `expert_profile_path`, `directional_steering_file`,
|
||||||
|
`directional_steering_attn`, `directional_steering_ffn`.
|
||||||
|
|
||||||
|
## SSD streaming (running models larger than RAM)
|
||||||
|
|
||||||
|
ds4's **SSD streaming** keeps non-routed weights resident and streams routed MoE
|
||||||
|
experts from the GGUF on cache misses, turning "does it fit in RAM" into a speed
|
||||||
|
spectrum. **Metal (Darwin) only** - it is a no-op on CUDA/CPU. Enable with
|
||||||
|
`options: ["ssd_streaming"]`; size the routed-expert cache with
|
||||||
|
`ssd_streaming_cache_experts:NGB` (omit for ds4's automatic 80%-of-working-set
|
||||||
|
budget). Gallery entries built on this: `deepseek-v4-flash-q4-ssd` (153 GB Flash
|
||||||
|
on a 128 GB Mac) and `deepseek-v4-pro-q2-ssd` (433 GB Pro, experimental).
|
||||||
|
|
||||||
## Build matrix
|
## Build matrix
|
||||||
|
|
||||||
| Build | Where | Notes |
|
| Build | Where | Notes |
|
||||||
|
|||||||
@@ -70,6 +70,12 @@ if [ "${BUILD_TYPE:-}" = "vulkan" ] && [ "${SKIP_DRIVERS:-false}" = "false" ]; t
|
|||||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||||
|
# Mesa Vulkan ICD drivers (ANV/RADV/lavapipe + Arm SoC) and their ICD
|
||||||
|
# manifests. The LunarG SDK below only provides the loader and shader
|
||||||
|
# tooling, not hardware drivers — without Mesa the packaged Vulkan backend
|
||||||
|
# would ship a loader that finds no GPU. package-gpu-libs.sh bundles these
|
||||||
|
# .so files plus their deps into the backend so it stays self-contained.
|
||||||
|
apt-get install -y mesa-vulkan-drivers libdrm2
|
||||||
if [ "amd64" = "${TARGETARCH:-}" ]; then
|
if [ "amd64" = "${TARGETARCH:-}" ]; then
|
||||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz"
|
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz"
|
||||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz
|
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz
|
||||||
|
|||||||
@@ -31,6 +31,15 @@ backend/python/**/source
|
|||||||
backend/cpp/llama-cpp/llama.cpp
|
backend/cpp/llama-cpp/llama.cpp
|
||||||
backend/cpp/llama-cpp-*-build
|
backend/cpp/llama-cpp-*-build
|
||||||
|
|
||||||
|
# privacy-filter: same in-place pattern. The Makefile fetches privacy-filter.cpp
|
||||||
|
# at the pinned commit (or symlinks a PRIVACY_FILTER_SRC checkout for local dev).
|
||||||
|
# A stale dir/symlink COPY'd into the image makes the clone step fail (dangling
|
||||||
|
# symlink) or compile against the wrong commit, so keep host build state out.
|
||||||
|
backend/cpp/privacy-filter/privacy-filter.cpp
|
||||||
|
backend/cpp/privacy-filter/build
|
||||||
|
backend/cpp/privacy-filter/grpc-server
|
||||||
|
backend/cpp/privacy-filter/package
|
||||||
|
|
||||||
# Rust backend build output (sources are tracked; target/ is generated)
|
# Rust backend build output (sources are tracked; target/ is generated)
|
||||||
backend/rust/*/target
|
backend/rust/*/target
|
||||||
|
|
||||||
|
|||||||
1014
.github/backend-matrix.yml
vendored
1014
.github/backend-matrix.yml
vendored
File diff suppressed because it is too large
Load Diff
9
.github/workflows/backend_build_darwin.yml
vendored
9
.github/workflows/backend_build_darwin.yml
vendored
@@ -98,6 +98,7 @@ jobs:
|
|||||||
/opt/homebrew/Cellar/hiredis
|
/opt/homebrew/Cellar/hiredis
|
||||||
/opt/homebrew/Cellar/xxhash
|
/opt/homebrew/Cellar/xxhash
|
||||||
/opt/homebrew/Cellar/zstd
|
/opt/homebrew/Cellar/zstd
|
||||||
|
/opt/homebrew/Cellar/nlohmann-json
|
||||||
key: brew-${{ runner.os }}-${{ runner.arch }}-v1-${{ hashFiles('.github/workflows/backend_build_darwin.yml') }}
|
key: brew-${{ runner.os }}-${{ runner.arch }}-v1-${{ hashFiles('.github/workflows/backend_build_darwin.yml') }}
|
||||||
|
|
||||||
- name: Dependencies
|
- name: Dependencies
|
||||||
@@ -109,7 +110,10 @@ jobs:
|
|||||||
# Without explicitly installing them, a brew cache-hit run restores
|
# Without explicitly installing them, a brew cache-hit run restores
|
||||||
# ccache's Cellar dir but skips installing those transitive deps,
|
# ccache's Cellar dir but skips installing those transitive deps,
|
||||||
# and ccache fails at runtime with `dyld: Library not loaded`.
|
# and ccache fails at runtime with `dyld: Library not loaded`.
|
||||||
brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm ccache blake3 fmt hiredis xxhash zstd
|
# nlohmann-json is header-only and required by the ds4 backend
|
||||||
|
# (dsml_renderer.cpp includes <nlohmann/json.hpp>); on Linux it comes
|
||||||
|
# from the apt-installed nlohmann-json3-dev in the build image.
|
||||||
|
brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm ccache blake3 fmt hiredis xxhash zstd nlohmann-json
|
||||||
# Force-reinstall ccache so brew re-validates its full runtime-dep
|
# Force-reinstall ccache so brew re-validates its full runtime-dep
|
||||||
# closure on every run. This is the durable fix: when the upstream
|
# closure on every run. This is the durable fix: when the upstream
|
||||||
# ccache formula gains a new transitive dep (as it has multiple times
|
# ccache formula gains a new transitive dep (as it has multiple times
|
||||||
@@ -128,7 +132,7 @@ jobs:
|
|||||||
# and decides "already installed" without re-linking, so on a cache-
|
# and decides "already installed" without re-linking, so on a cache-
|
||||||
# hit run the formulas aren't on PATH. Force-link them; --overwrite
|
# hit run the formulas aren't on PATH. Force-link them; --overwrite
|
||||||
# tolerates pre-existing symlinks from earlier installs.
|
# tolerates pre-existing symlinks from earlier installs.
|
||||||
brew link --overwrite protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm ccache blake3 fmt hiredis xxhash zstd 2>/dev/null || true
|
brew link --overwrite protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm ccache blake3 fmt hiredis xxhash zstd nlohmann-json 2>/dev/null || true
|
||||||
|
|
||||||
- name: Save Homebrew cache
|
- name: Save Homebrew cache
|
||||||
if: github.event_name != 'pull_request' && steps.brew-cache.outputs.cache-hit != 'true'
|
if: github.event_name != 'pull_request' && steps.brew-cache.outputs.cache-hit != 'true'
|
||||||
@@ -148,6 +152,7 @@ jobs:
|
|||||||
/opt/homebrew/Cellar/hiredis
|
/opt/homebrew/Cellar/hiredis
|
||||||
/opt/homebrew/Cellar/xxhash
|
/opt/homebrew/Cellar/xxhash
|
||||||
/opt/homebrew/Cellar/zstd
|
/opt/homebrew/Cellar/zstd
|
||||||
|
/opt/homebrew/Cellar/nlohmann-json
|
||||||
key: brew-${{ runner.os }}-${{ runner.arch }}-v1-${{ hashFiles('.github/workflows/backend_build_darwin.yml') }}
|
key: brew-${{ runner.os }}-${{ runner.arch }}-v1-${{ hashFiles('.github/workflows/backend_build_darwin.yml') }}
|
||||||
|
|
||||||
# ---- ccache for llama.cpp CMake builds ----
|
# ---- ccache for llama.cpp CMake builds ----
|
||||||
|
|||||||
36
.github/workflows/bump_deps.yaml
vendored
36
.github/workflows/bump_deps.yaml
vendored
@@ -26,6 +26,10 @@ jobs:
|
|||||||
variable: "DS4_VERSION"
|
variable: "DS4_VERSION"
|
||||||
branch: "main"
|
branch: "main"
|
||||||
file: "backend/cpp/ds4/Makefile"
|
file: "backend/cpp/ds4/Makefile"
|
||||||
|
- repository: "localai-org/privacy-filter.cpp"
|
||||||
|
variable: "PRIVACY_FILTER_VERSION"
|
||||||
|
branch: "master"
|
||||||
|
file: "backend/cpp/privacy-filter/Makefile"
|
||||||
- repository: "ggml-org/whisper.cpp"
|
- repository: "ggml-org/whisper.cpp"
|
||||||
variable: "WHISPER_CPP_VERSION"
|
variable: "WHISPER_CPP_VERSION"
|
||||||
branch: "master"
|
branch: "master"
|
||||||
@@ -38,10 +42,22 @@ jobs:
|
|||||||
variable: "PARAKEET_VERSION"
|
variable: "PARAKEET_VERSION"
|
||||||
branch: "master"
|
branch: "master"
|
||||||
file: "backend/go/parakeet-cpp/Makefile"
|
file: "backend/go/parakeet-cpp/Makefile"
|
||||||
- repository: "mudler/dllm.cpp"
|
- repository: "mudler/ced.cpp"
|
||||||
variable: "DLLM_VERSION"
|
variable: "CED_VERSION"
|
||||||
branch: "main"
|
branch: "master"
|
||||||
file: "backend/go/dllm/Makefile"
|
file: "backend/go/ced/Makefile"
|
||||||
|
- repository: "mudler/voice-detect.cpp"
|
||||||
|
variable: "VOICEDETECT_VERSION"
|
||||||
|
branch: "master"
|
||||||
|
file: "backend/go/voice-detect/Makefile"
|
||||||
|
- repository: "mudler/face-detect.cpp"
|
||||||
|
variable: "FACEDETECT_VERSION"
|
||||||
|
branch: "master"
|
||||||
|
file: "backend/go/face-detect/Makefile"
|
||||||
|
- repository: "mudler/depth-anything.cpp"
|
||||||
|
variable: "DEPTHANYTHING_VERSION"
|
||||||
|
branch: "master"
|
||||||
|
file: "backend/go/depth-anything-cpp/Makefile"
|
||||||
- repository: "leejet/stable-diffusion.cpp"
|
- repository: "leejet/stable-diffusion.cpp"
|
||||||
variable: "STABLEDIFFUSION_GGML_VERSION"
|
variable: "STABLEDIFFUSION_GGML_VERSION"
|
||||||
branch: "master"
|
branch: "master"
|
||||||
@@ -66,10 +82,18 @@ jobs:
|
|||||||
variable: "RFDETR_VERSION"
|
variable: "RFDETR_VERSION"
|
||||||
branch: "main"
|
branch: "main"
|
||||||
file: "backend/go/rfdetr-cpp/Makefile"
|
file: "backend/go/rfdetr-cpp/Makefile"
|
||||||
- repository: "predict-woo/qwen3-tts.cpp"
|
- repository: "mudler/locate-anything.cpp"
|
||||||
|
variable: "LOCATEANYTHING_VERSION"
|
||||||
|
branch: "master"
|
||||||
|
file: "backend/go/locate-anything-cpp/Makefile"
|
||||||
|
- repository: "ServeurpersoCom/qwentts.cpp"
|
||||||
variable: "QWEN3TTS_CPP_VERSION"
|
variable: "QWEN3TTS_CPP_VERSION"
|
||||||
branch: "main"
|
branch: "master"
|
||||||
file: "backend/go/qwen3-tts-cpp/Makefile"
|
file: "backend/go/qwen3-tts-cpp/Makefile"
|
||||||
|
- repository: "ServeurpersoCom/omnivoice.cpp"
|
||||||
|
variable: "OMNIVOICE_VERSION"
|
||||||
|
branch: "master"
|
||||||
|
file: "backend/go/omnivoice-cpp/Makefile"
|
||||||
- repository: "localai-org/vibevoice.cpp"
|
- repository: "localai-org/vibevoice.cpp"
|
||||||
variable: "VIBEVOICE_CPP_VERSION"
|
variable: "VIBEVOICE_CPP_VERSION"
|
||||||
branch: "master"
|
branch: "master"
|
||||||
|
|||||||
5
.github/workflows/secscan.yaml
vendored
5
.github/workflows/secscan.yaml
vendored
@@ -21,7 +21,10 @@ jobs:
|
|||||||
uses: securego/gosec@v2.27.1
|
uses: securego/gosec@v2.27.1
|
||||||
with:
|
with:
|
||||||
# we let the report trigger content trigger a failure using the GitHub Security features.
|
# we let the report trigger content trigger a failure using the GitHub Security features.
|
||||||
args: '-no-fail -fmt sarif -out results.sarif ./...'
|
# backend/go/supertonic is excluded: it vendors upstream supertone-inc/supertonic
|
||||||
|
# (helper.go), whose findings (G304 model-file loads, G404 math/rand for flow-matching
|
||||||
|
# noise, G104 unhandled errors) are inherent to that upstream code, not ours to rewrite.
|
||||||
|
args: '-no-fail -exclude-dir=backend/go/supertonic -fmt sarif -out results.sarif ./...'
|
||||||
- name: Upload SARIF file
|
- name: Upload SARIF file
|
||||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||||
uses: github/codeql-action/upload-sarif@v4
|
uses: github/codeql-action/upload-sarif@v4
|
||||||
|
|||||||
42
.github/workflows/test-extra.yml
vendored
42
.github/workflows/test-extra.yml
vendored
@@ -38,6 +38,7 @@ jobs:
|
|||||||
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
||||||
qwen3-tts-cpp: ${{ steps.detect.outputs.qwen3-tts-cpp }}
|
qwen3-tts-cpp: ${{ steps.detect.outputs.qwen3-tts-cpp }}
|
||||||
rfdetr-cpp: ${{ steps.detect.outputs.rfdetr-cpp }}
|
rfdetr-cpp: ${{ steps.detect.outputs.rfdetr-cpp }}
|
||||||
|
locate-anything-cpp: ${{ steps.detect.outputs.locate-anything-cpp }}
|
||||||
vibevoice-cpp: ${{ steps.detect.outputs.vibevoice-cpp }}
|
vibevoice-cpp: ${{ steps.detect.outputs.vibevoice-cpp }}
|
||||||
localvqe: ${{ steps.detect.outputs.localvqe }}
|
localvqe: ${{ steps.detect.outputs.localvqe }}
|
||||||
voxtral: ${{ steps.detect.outputs.voxtral }}
|
voxtral: ${{ steps.detect.outputs.voxtral }}
|
||||||
@@ -563,7 +564,7 @@ jobs:
|
|||||||
- name: Run e2e-backends smoke
|
- name: Run e2e-backends smoke
|
||||||
env:
|
env:
|
||||||
BACKEND_IMAGE: quay.io/go-skynet/local-ai-backends:master-cpu-llama-cpp
|
BACKEND_IMAGE: quay.io/go-skynet/local-ai-backends:master-cpu-llama-cpp
|
||||||
BACKEND_TEST_CAPS: health,load,predict,stream,logprobs,logit_bias
|
BACKEND_TEST_CAPS: health,load,predict,stream,logprobs,logit_bias,tokenize
|
||||||
run: |
|
run: |
|
||||||
make test-extra-backend
|
make test-extra-backend
|
||||||
# Realtime e2e with sherpa-onnx driving VAD + STT + TTS against a mocked LLM.
|
# Realtime e2e with sherpa-onnx driving VAD + STT + TTS against a mocked LLM.
|
||||||
@@ -901,6 +902,45 @@ jobs:
|
|||||||
- name: Test rfdetr-cpp
|
- name: Test rfdetr-cpp
|
||||||
run: |
|
run: |
|
||||||
make --jobs=5 --output-sync=target -C backend/go/rfdetr-cpp test
|
make --jobs=5 --output-sync=target -C backend/go/rfdetr-cpp test
|
||||||
|
# Per-backend e2e for locate-anything-cpp: builds the .so + Go binary and
|
||||||
|
# runs `make -C backend/go/locate-anything-cpp test`. test.sh fetches the
|
||||||
|
# locate-anything-q8_0 GGUF (~6.3 GB, NVIDIA LocateAnything-3B) from the
|
||||||
|
# published mudler/locate-anything.cpp-gguf HF repo + a COCO image, then the
|
||||||
|
# Go wire test loads the model and runs an open-vocabulary Detect, asserting
|
||||||
|
# at least one labeled box. Heavier than the other Go backends (it is a 3B),
|
||||||
|
# so it is gated to changes under backend/go/locate-anything-cpp/.
|
||||||
|
tests-locate-anything-cpp:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.locate-anything-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Clone
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
submodules: true
|
||||||
|
- name: Dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y build-essential cmake curl libopenblas-dev
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
- name: Display Go version
|
||||||
|
run: go version
|
||||||
|
- name: Proto Dependencies
|
||||||
|
run: |
|
||||||
|
# Install protoc
|
||||||
|
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
|
||||||
|
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||||
|
rm protoc.zip
|
||||||
|
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
|
||||||
|
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
|
||||||
|
PATH="$PATH:$HOME/go/bin" make protogen-go
|
||||||
|
- name: Build locate-anything-cpp
|
||||||
|
run: |
|
||||||
|
make --jobs=5 --output-sync=target -C backend/go/locate-anything-cpp
|
||||||
|
- name: Test locate-anything-cpp
|
||||||
|
run: |
|
||||||
|
make --jobs=5 --output-sync=target -C backend/go/locate-anything-cpp test
|
||||||
# Per-backend smoke for vibevoice-cpp: builds the .so + Go binary and
|
# Per-backend smoke for vibevoice-cpp: builds the .so + Go binary and
|
||||||
# runs `make -C backend/go/vibevoice-cpp test`. test.sh auto-downloads
|
# runs `make -C backend/go/vibevoice-cpp test`. test.sh auto-downloads
|
||||||
# the published mudler/vibevoice.cpp-models bundle (TTS Q8_0 + ASR Q4_K
|
# the published mudler/vibevoice.cpp-models bundle (TTS Q8_0 + ASR Q4_K
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ linters:
|
|||||||
paths:
|
paths:
|
||||||
# Upstream whisper.cpp source tree fetched by the whisper backend Makefile.
|
# Upstream whisper.cpp source tree fetched by the whisper backend Makefile.
|
||||||
- 'backend/go/whisper/sources'
|
- 'backend/go/whisper/sources'
|
||||||
|
# Vendored upstream supertonic pipeline (supertone-inc/supertonic go/helper.go).
|
||||||
|
- 'backend/go/supertonic/helper.go'
|
||||||
- 'docs/'
|
- 'docs/'
|
||||||
rules:
|
rules:
|
||||||
# CLI entry points: kong's `env:"..."` tag is the legitimate env→struct
|
# CLI entry points: kong's `env:"..."` tag is the legitimate env→struct
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
|||||||
| [.agents/vllm-backend.md](.agents/vllm-backend.md) | Working on the vLLM / vLLM-omni backends — native parsers, ChatDelta, CPU build, libnuma packaging, backend hooks |
|
| [.agents/vllm-backend.md](.agents/vllm-backend.md) | Working on the vLLM / vLLM-omni backends — native parsers, ChatDelta, CPU build, libnuma packaging, backend hooks |
|
||||||
| [.agents/sglang-backend.md](.agents/sglang-backend.md) | Working on the SGLang backend — `engine_args` validation against ServerArgs, speculative-decoding (EAGLE/EAGLE3/DFLASH/MTP) recipes, parser handling |
|
| [.agents/sglang-backend.md](.agents/sglang-backend.md) | Working on the SGLang backend — `engine_args` validation against ServerArgs, speculative-decoding (EAGLE/EAGLE3/DFLASH/MTP) recipes, parser handling |
|
||||||
| [.agents/ds4-backend.md](.agents/ds4-backend.md) | Working on the ds4 backend - DSML state machine, thinking modes, KV cache, Metal+CUDA matrix |
|
| [.agents/ds4-backend.md](.agents/ds4-backend.md) | Working on the ds4 backend - DSML state machine, thinking modes, KV cache, Metal+CUDA matrix |
|
||||||
| [.agents/dllm-backend.md](.agents/dllm-backend.md) | Working on the dllm backend (DiffusionGemma block-diffusion) - purego C-ABI binding, per-ctx serialization contract, gemma4 renderer/parser, gated test layers |
|
|
||||||
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
||||||
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
||||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ RUN <<EOT bash
|
|||||||
apt-get update && \
|
apt-get update && \
|
||||||
apt-get install -y --no-install-recommends \
|
apt-get install -y --no-install-recommends \
|
||||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||||
|
cuda-nvrtc-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||||
|
|||||||
23
Makefile
23
Makefile
@@ -1,5 +1,5 @@
|
|||||||
# Disable parallel execution for backend builds
|
# Disable parallel execution for backend builds
|
||||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/dllm backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/omnivoice-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio backends/supertonic backends/depth-anything-cpp backends/privacy-filter
|
||||||
|
|
||||||
GOCMD=go
|
GOCMD=go
|
||||||
GOTEST=$(GOCMD) test
|
GOTEST=$(GOCMD) test
|
||||||
@@ -566,6 +566,7 @@ prepare-test-extra: protogen-python
|
|||||||
$(MAKE) -C backend/python/speaker-recognition
|
$(MAKE) -C backend/python/speaker-recognition
|
||||||
$(MAKE) -C backend/rust/kokoros kokoros-grpc
|
$(MAKE) -C backend/rust/kokoros kokoros-grpc
|
||||||
$(MAKE) -C backend/go/rfdetr-cpp
|
$(MAKE) -C backend/go/rfdetr-cpp
|
||||||
|
$(MAKE) -C backend/go/locate-anything-cpp
|
||||||
|
|
||||||
test-extra: prepare-test-extra
|
test-extra: prepare-test-extra
|
||||||
$(MAKE) -C backend/python/transformers test
|
$(MAKE) -C backend/python/transformers test
|
||||||
@@ -593,6 +594,9 @@ test-extra: prepare-test-extra
|
|||||||
$(MAKE) -C backend/python/speaker-recognition test
|
$(MAKE) -C backend/python/speaker-recognition test
|
||||||
$(MAKE) -C backend/rust/kokoros test
|
$(MAKE) -C backend/rust/kokoros test
|
||||||
$(MAKE) -C backend/go/rfdetr-cpp test
|
$(MAKE) -C backend/go/rfdetr-cpp test
|
||||||
|
$(MAKE) -C backend/go/locate-anything-cpp test
|
||||||
|
$(MAKE) -C backend/go/depth-anything-cpp test
|
||||||
|
$(MAKE) -C backend/go/supertonic test
|
||||||
|
|
||||||
##
|
##
|
||||||
## End-to-end gRPC tests that exercise a built backend container image.
|
## End-to-end gRPC tests that exercise a built backend container image.
|
||||||
@@ -1160,6 +1164,10 @@ BACKEND_TURBOQUANT = turboquant|turboquant|.|false|false
|
|||||||
# Single-model; hardware-only validation lives at tests/e2e-backends/
|
# Single-model; hardware-only validation lives at tests/e2e-backends/
|
||||||
# (BACKEND_BINARY mode); see docs/superpowers/plans/2026-05-11-ds4-backend.md.
|
# (BACKEND_BINARY mode); see docs/superpowers/plans/2026-05-11-ds4-backend.md.
|
||||||
BACKEND_DS4 = ds4|ds4|.|false|false
|
BACKEND_DS4 = ds4|ds4|.|false|false
|
||||||
|
# privacy-filter wraps the standalone privacy-filter.cpp GGML engine (the
|
||||||
|
# openai-privacy-filter PII/NER token classifier) — the TokenClassify RPC for
|
||||||
|
# the PII redactor tier, on stock ggml with no llama.cpp carry-patches.
|
||||||
|
BACKEND_PRIVACY_FILTER = privacy-filter|privacy-filter|.|false|false
|
||||||
|
|
||||||
# Golang backends
|
# Golang backends
|
||||||
BACKEND_PIPER = piper|golang|.|false|true
|
BACKEND_PIPER = piper|golang|.|false|true
|
||||||
@@ -1171,16 +1179,16 @@ BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|tr
|
|||||||
BACKEND_WHISPER = whisper|golang|.|false|true
|
BACKEND_WHISPER = whisper|golang|.|false|true
|
||||||
BACKEND_CRISPASR = crispasr|golang|.|false|true
|
BACKEND_CRISPASR = crispasr|golang|.|false|true
|
||||||
BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true
|
BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true
|
||||||
# dllm is mudler/dllm.cpp, the DiffusionGemma block-diffusion engine,
|
BACKEND_DEPTH_ANYTHING_CPP = depth-anything-cpp|golang|.|false|true
|
||||||
# wrapped by the purego backend at backend/go/dllm.
|
|
||||||
BACKEND_DLLM = dllm|golang|.|false|true
|
|
||||||
BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
||||||
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
|
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
|
||||||
BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true
|
BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true
|
||||||
|
BACKEND_OMNIVOICE_CPP = omnivoice-cpp|golang|.|false|true
|
||||||
BACKEND_VIBEVOICE_CPP = vibevoice-cpp|golang|.|false|true
|
BACKEND_VIBEVOICE_CPP = vibevoice-cpp|golang|.|false|true
|
||||||
BACKEND_LOCALVQE = localvqe|golang|.|false|true
|
BACKEND_LOCALVQE = localvqe|golang|.|false|true
|
||||||
BACKEND_OPUS = opus|golang|.|false|true
|
BACKEND_OPUS = opus|golang|.|false|true
|
||||||
BACKEND_SHERPA_ONNX = sherpa-onnx|golang|.|false|true
|
BACKEND_SHERPA_ONNX = sherpa-onnx|golang|.|false|true
|
||||||
|
BACKEND_SUPERTONIC = supertonic|golang|.|false|true
|
||||||
|
|
||||||
# Python backends with root context
|
# Python backends with root context
|
||||||
BACKEND_RERANKERS = rerankers|python|.|false|true
|
BACKEND_RERANKERS = rerankers|python|.|false|true
|
||||||
@@ -1254,6 +1262,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP)))
|
|||||||
$(eval $(call generate-docker-build-target,$(BACKEND_IK_LLAMA_CPP)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_IK_LLAMA_CPP)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_TURBOQUANT)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_TURBOQUANT)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_DS4)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_DS4)))
|
||||||
|
$(eval $(call generate-docker-build-target,$(BACKEND_PRIVACY_FILTER)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_PIPER)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_PIPER)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_CLOUD_PROXY)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_CLOUD_PROXY)))
|
||||||
@@ -1263,7 +1272,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
|||||||
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_CRISPASR)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_CRISPASR)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_DLLM)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_DEPTH_ANYTHING_CPP)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))
|
||||||
@@ -1296,6 +1305,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
|
|||||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_QWEN3_TTS_CPP)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_QWEN3_TTS_CPP)))
|
||||||
|
$(eval $(call generate-docker-build-target,$(BACKEND_OMNIVOICE_CPP)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_VIBEVOICE_CPP)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_VIBEVOICE_CPP)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_LOCALVQE)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_LOCALVQE)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_MLX)))
|
||||||
@@ -1308,12 +1318,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS)))
|
|||||||
$(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_RFDETR_CPP)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_RFDETR_CPP)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
||||||
|
$(eval $(call generate-docker-build-target,$(BACKEND_SUPERTONIC)))
|
||||||
|
|
||||||
# Pattern rule for docker-save targets
|
# Pattern rule for docker-save targets
|
||||||
docker-save-%: backend-images
|
docker-save-%: backend-images
|
||||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||||
|
|
||||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-crispasr docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-crispasr docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-omnivoice-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy docker-build-supertonic docker-build-depth-anything-cpp docker-build-privacy-filter
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
### Mock Backend for E2E Tests
|
### Mock Backend for E2E Tests
|
||||||
|
|||||||
39
README.md
39
README.md
@@ -29,6 +29,18 @@
|
|||||||
<a href="https://trendshift.io/repositories/5539" target="_blank"><img src="https://trendshift.io/api/badge/repositories/5539" alt="mudler%2FLocalAI | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
<a href="https://trendshift.io/repositories/5539" target="_blank"><img src="https://trendshift.io/api/badge/repositories/5539" alt="mudler%2FLocalAI | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
<!-- Keep these links, translations synced daily. -->
|
||||||
|
<p align="center">
|
||||||
|
<a href="https://zdoc.app/de/mudler/LocalAI">Deutsch</a> |
|
||||||
|
<a href="https://zdoc.app/es/mudler/LocalAI">Español</a> |
|
||||||
|
<a href="https://zdoc.app/fr/mudler/LocalAI">français</a> |
|
||||||
|
<a href="https://zdoc.app/ja/mudler/LocalAI">日本語</a> |
|
||||||
|
<a href="https://zdoc.app/ko/mudler/LocalAI">한국어</a> |
|
||||||
|
<a href="https://zdoc.app/pt/mudler/LocalAI">Português</a> |
|
||||||
|
<a href="https://zdoc.app/ru/mudler/LocalAI">Русский</a> |
|
||||||
|
<a href="https://zdoc.app/zh/mudler/LocalAI">中文</a>
|
||||||
|
</p>
|
||||||
|
|
||||||
**LocalAI** is the open-source AI engine. Run any model - LLMs, vision, voice, image, video - on any hardware. No GPU required.
|
**LocalAI** is the open-source AI engine. Run any model - LLMs, vision, voice, image, video - on any hardware. No GPU required.
|
||||||
|
|
||||||
**A small core, not a bundle.** Each backend wraps a best-in-class engine (llama.cpp, vLLM, whisper.cpp, stable-diffusion, MLX...) in its own image, pulled only when a model needs it. You install nothing you don't use.
|
**A small core, not a bundle.** Each backend wraps a best-in-class engine (llama.cpp, vLLM, whisper.cpp, stable-diffusion, MLX...) in its own image, pulled only when a model needs it. You install nothing you don't use.
|
||||||
@@ -165,6 +177,10 @@ For more details, see the [Getting Started guide](https://localai.io/basics/gett
|
|||||||
|
|
||||||
## Latest News
|
## Latest News
|
||||||
|
|
||||||
|
- **June 2026**: New [realtime voice assistant demo](https://github.com/localai-org/localai-realtime-demo) (a tiny Go client for the Realtime API with a full talk-back voice loop and tool calling), plus [streaming of the realtime LLM / TTS / transcription pipeline stages](https://github.com/mudler/LocalAI/pull/10176) and [configurable WebRTC ICE candidates](https://github.com/mudler/LocalAI/pull/10231).
|
||||||
|
- **June 2026**: Big speech push: the [parakeet.cpp](https://github.com/mudler/parakeet.cpp) ASR engine gains [NeMo-faithful segment timestamps](https://github.com/mudler/LocalAI/pull/10207), a [multilingual streaming Nemotron-3.5 model](https://github.com/mudler/LocalAI/pull/10199), [dynamic batching for concurrent transcription](https://github.com/mudler/LocalAI/pull/10112) and [CUDA graphs](https://github.com/mudler/LocalAI/pull/10273); the new [CrispASR backend](https://github.com/mudler/LocalAI/pull/10099) adds multi-architecture ASR + TTS, and [60 Piper TTS voices across 42 languages](https://github.com/mudler/LocalAI/pull/10296) land in the gallery (plus [per-request TTS instructions and params](https://github.com/mudler/LocalAI/pull/10172)).
|
||||||
|
- **June 2026**: New backends and models: [locate-anything.cpp](https://github.com/mudler/LocalAI/pull/10264) for open-vocabulary object detection via ggml, [Ideogram4 image generation](https://github.com/mudler/LocalAI/pull/10201) in stablediffusion-ggml, [llama.cpp video input](https://github.com/mudler/LocalAI/pull/10216), and the [Gemma 4 QAT family with MTP speculative-decoding pairs](https://github.com/mudler/LocalAI/pull/10215). Plus an [interactive CLI chat mode](https://github.com/mudler/LocalAI/pull/10226) and [RAG source citations in agent responses](https://github.com/mudler/LocalAI/pull/10228).
|
||||||
|
- **June 2026**: Distributed mode hardening: [prefix-cache-aware routing](https://github.com/mudler/LocalAI/pull/10071), a [production-ready request router with auto-sized embedding/rerank batches](https://github.com/mudler/LocalAI/pull/10104), [ds4 layer-split distributed inference](https://github.com/mudler/LocalAI/pull/10098), [NATS JWT auth + TLS/mTLS](https://github.com/mudler/LocalAI/pull/10159), and [resumable file uploads](https://github.com/mudler/LocalAI/pull/10109).
|
||||||
- **May 2026**: **LocalAI 4.3.0** - `llama.cpp` [prompt cache on by default](https://github.com/mudler/LocalAI/pull/9925) (repeated system prompts collapse from minutes to seconds), [keyless cosign signing of backend OCI images](https://github.com/mudler/LocalAI/pull/9823), [per-API-key + per-user usage attribution](https://github.com/mudler/LocalAI/pull/9920), Distributed v3 with [per-request replica routing](https://github.com/mudler/LocalAI/pull/9968). [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.3.0)
|
- **May 2026**: **LocalAI 4.3.0** - `llama.cpp` [prompt cache on by default](https://github.com/mudler/LocalAI/pull/9925) (repeated system prompts collapse from minutes to seconds), [keyless cosign signing of backend OCI images](https://github.com/mudler/LocalAI/pull/9823), [per-API-key + per-user usage attribution](https://github.com/mudler/LocalAI/pull/9920), Distributed v3 with [per-request replica routing](https://github.com/mudler/LocalAI/pull/9968). [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.3.0)
|
||||||
- **May 2026**: **LocalAI 4.2.0** - LocalAI sees and hears: [voice recognition](https://github.com/mudler/LocalAI/pull/9500), [face recognition + antispoofing liveness](https://github.com/mudler/LocalAI/pull/9480), speaker diarization. Plus [drop-in Ollama API](https://github.com/mudler/LocalAI/pull/9284), [video generation](https://github.com/mudler/LocalAI/pull/9420), redesigned UI with i18n + admin-configurable branding, vLLM at feature parity with llama.cpp, and 11 new backends. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.2.0)
|
- **May 2026**: **LocalAI 4.2.0** - LocalAI sees and hears: [voice recognition](https://github.com/mudler/LocalAI/pull/9500), [face recognition + antispoofing liveness](https://github.com/mudler/LocalAI/pull/9480), speaker diarization. Plus [drop-in Ollama API](https://github.com/mudler/LocalAI/pull/9284), [video generation](https://github.com/mudler/LocalAI/pull/9420), redesigned UI with i18n + admin-configurable branding, vLLM at feature parity with llama.cpp, and 11 new backends. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.2.0)
|
||||||
- **April 2026**: **LocalAI 4.1.0** - LocalAI becomes a control tower: distributed cluster mode with VRAM-aware smart routing + autoscaling, multi-user platform with OIDC and API keys, per-user quotas with predictive analytics, in-UI fine-tuning with TRL (auto-export to GGUF), on-the-fly quantization backend, visual pipeline editor. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.1.0)
|
- **April 2026**: **LocalAI 4.1.0** - LocalAI becomes a control tower: distributed cluster mode with VRAM-aware smart routing + autoscaling, multi-user platform with OIDC and API keys, per-user quotas with predictive analytics, in-UI fine-tuning with TRL (auto-export to GGUF), on-the-fly quantization backend, visual pipeline editor. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.1.0)
|
||||||
@@ -204,10 +220,29 @@ For older news and full release notes, see [GitHub Releases](https://github.com/
|
|||||||
|
|
||||||
## Supported Backends & Acceleration
|
## Supported Backends & Acceleration
|
||||||
|
|
||||||
LocalAI supports **36+ backends** including llama.cpp, vLLM, transformers, whisper.cpp, diffusers, MLX, MLX-VLM, and many more. Hardware acceleration is available for **NVIDIA** (CUDA 12/13), **AMD** (ROCm), **Intel** (oneAPI/SYCL), **Apple Silicon** (Metal), **Vulkan**, and **NVIDIA Jetson** (L4T). All backends can be installed on-the-fly from the [Backend Gallery](https://localai.io/backends/).
|
LocalAI supports **60+ backends** including llama.cpp, vLLM, SGLang, transformers, whisper.cpp, diffusers, MLX, MLX-VLM, and many more. Hardware acceleration is available for **NVIDIA** (CUDA 12/13), **AMD** (ROCm), **Intel** (oneAPI/SYCL), **Apple Silicon** (Metal), **Vulkan**, and **NVIDIA Jetson** (L4T). All backends can be installed on-the-fly from the [Backend Gallery](https://localai.io/backends/).
|
||||||
|
|
||||||
See the full [Backend & Model Compatibility Table](https://localai.io/model-compatibility/) and [GPU Acceleration guide](https://localai.io/features/gpu-acceleration/).
|
See the full [Backend & Model Compatibility Table](https://localai.io/model-compatibility/) and [GPU Acceleration guide](https://localai.io/features/gpu-acceleration/).
|
||||||
|
|
||||||
|
### Backends built by us
|
||||||
|
|
||||||
|
Most backends wrap a best-in-class upstream engine. A handful of them are native C/C++/GGML engines (no Python at inference) developed and maintained by the LocalAI project itself:
|
||||||
|
|
||||||
|
| Backend | What it does |
|
||||||
|
|---------|-------------|
|
||||||
|
| [parakeet.cpp](https://github.com/mudler/parakeet.cpp) | C++/GGML port of NVIDIA NeMo Parakeet ASR (tdt/ctc/rnnt/hybrid), with cache-aware streaming transcription |
|
||||||
|
| [ced.cpp](https://github.com/mudler/ced.cpp) | C++/GGML port of the CED audio-tagging models: sound-event classification (527-class AudioSet) over REST and the realtime API for live recognition |
|
||||||
|
| [voxtral.c](https://github.com/mudler/voxtral.c) | Voxtral Realtime 4B speech-to-text in pure C |
|
||||||
|
| [vibevoice.cpp](https://github.com/mudler/vibevoice.cpp) | Native port of Microsoft VibeVoice for TTS (voice cloning) and long-form ASR with speaker diarization |
|
||||||
|
| [rf-detr.cpp](https://github.com/mudler/rf-detr.cpp) | Native RF-DETR object detection and instance segmentation |
|
||||||
|
| [locate-anything.cpp](https://github.com/mudler/locate-anything.cpp) | Open-vocabulary object detection and visual grounding (LocateAnything-3B) |
|
||||||
|
| [depth-anything.cpp](https://github.com/mudler/depth-anything.cpp) | Depth Anything 3 monocular metric depth + camera pose estimation |
|
||||||
|
| [privacy-filter.cpp](https://github.com/localai-org/privacy-filter.cpp) | Standalone GGML PII/NER token-classification engine powering LocalAI's PII redaction tier |
|
||||||
|
| [LocalVQE](https://github.com/localai-org/LocalVQE) | Joint acoustic echo cancellation, noise suppression, and dereverberation |
|
||||||
|
| [local-store](https://github.com/mudler/LocalAI) | Local-first vector database for embeddings (shipped in-tree) |
|
||||||
|
|
||||||
|
We also maintain [apex-quant](https://github.com/localai-org/apex-quant), a per-tensor, per-layer quantization recipe for Mixture-of-Experts models that exploits their structural sparsity to produce GGUFs matching or beating Q8_0 quality - and they run out of the box on stock llama.cpp.
|
||||||
|
|
||||||
## Resources
|
## Resources
|
||||||
|
|
||||||
- [Documentation](https://localai.io/)
|
- [Documentation](https://localai.io/)
|
||||||
@@ -217,7 +252,7 @@ See the full [Backend & Model Compatibility Table](https://localai.io/model-comp
|
|||||||
- [Integrations & community projects](https://localai.io/docs/integrations/)
|
- [Integrations & community projects](https://localai.io/docs/integrations/)
|
||||||
- [Installation video walkthrough](https://www.youtube.com/watch?v=cMVNnlqwfw4)
|
- [Installation video walkthrough](https://www.youtube.com/watch?v=cMVNnlqwfw4)
|
||||||
- [Media & blog posts](https://localai.io/basics/news/#media-blogs-social)
|
- [Media & blog posts](https://localai.io/basics/news/#media-blogs-social)
|
||||||
- [Examples](https://github.com/mudler/LocalAI-examples)
|
- [Examples](https://github.com/mudler/LocalAI-examples) — including the [realtime voice assistant demo](https://github.com/localai-org/localai-realtime-demo) (Go client for the Realtime API with tool calling)
|
||||||
|
|
||||||
## Team
|
## Team
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,12 @@ RUN <<EOT bash
|
|||||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils && \
|
||||||
|
apt-get install -y mesa-vulkan-drivers libdrm2
|
||||||
|
# Mesa Vulkan ICD drivers (ANV/RADV/lavapipe) + their manifests. The
|
||||||
|
# LunarG SDK below only provides the loader and shader tooling, not
|
||||||
|
# hardware drivers — without Mesa, package-gpu-libs.sh has no ICD to
|
||||||
|
# bundle and the packaged backend finds no GPU at runtime.
|
||||||
if [ "amd64" = "$TARGETARCH" ]; then
|
if [ "amd64" = "$TARGETARCH" ]; then
|
||||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||||
@@ -206,6 +211,16 @@ RUN if [ "${BACKEND}" = "opus" ]; then \
|
|||||||
apt-get clean && rm -rf /var/lib/apt/lists/*; \
|
apt-get clean && rm -rf /var/lib/apt/lists/*; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# CrispASR's piper TTS backend dlopens libespeak-ng at runtime to phonemize
|
||||||
|
# non-English text (the MIT-clean path; English uses a built-in G2P). Install
|
||||||
|
# the espeak-ng runtime + its libpcaudio/libsonic deps + voice data so
|
||||||
|
# package.sh can bundle them into the FROM scratch image.
|
||||||
|
RUN if [ "${BACKEND}" = "crispasr" ]; then \
|
||||||
|
apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
espeak-ng-data libespeak-ng1 libpcaudio0 libsonic0 && \
|
||||||
|
apt-get clean && rm -rf /var/lib/apt/lists/*; \
|
||||||
|
fi
|
||||||
|
|
||||||
COPY . /LocalAI
|
COPY . /LocalAI
|
||||||
|
|
||||||
RUN git config --global --add safe.directory /LocalAI
|
RUN git config --global --add safe.directory /LocalAI
|
||||||
|
|||||||
109
backend/Dockerfile.privacy-filter
Normal file
109
backend/Dockerfile.privacy-filter
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
ARG BASE_IMAGE=ubuntu:24.04
|
||||||
|
# BUILDER_BASE_IMAGE defaults to BASE_IMAGE so the Dockerfile parses when no
|
||||||
|
# prebuilt base is supplied; the builder-prebuilt stage is only entered when
|
||||||
|
# BUILDER_TARGET=builder-prebuilt, so the fallback content is harmless
|
||||||
|
# (BuildKit prunes the unreferenced builder).
|
||||||
|
ARG BUILDER_BASE_IMAGE=${BASE_IMAGE}
|
||||||
|
# BUILDER_TARGET selects which builder stage the scratch image copies from.
|
||||||
|
# Declared before any FROM so it is usable in `FROM ${BUILDER_TARGET}`. The
|
||||||
|
# backend_build workflow sets it to builder-prebuilt when the matrix entry
|
||||||
|
# provides builder-base-image, else builder-fromsource (the local default).
|
||||||
|
ARG BUILDER_TARGET=builder-fromsource
|
||||||
|
ARG APT_MIRROR=""
|
||||||
|
ARG APT_PORTS_MIRROR=""
|
||||||
|
|
||||||
|
# privacy-filter: standalone GGML engine for the openai-privacy-filter PII/NER
|
||||||
|
# token classifier, wrapped as a LocalAI gRPC backend.
|
||||||
|
#
|
||||||
|
# Mirrors backend/Dockerfile.llama-cpp: the build toolchain (gRPC + cmake +
|
||||||
|
# protoc + conditional CUDA/Vulkan) comes from the shared
|
||||||
|
# .docker/install-base-deps.sh (from-source path) or a prebuilt
|
||||||
|
# quay.io/go-skynet/ci-cache:base-grpc-* image (CI path) — nothing GPU-specific
|
||||||
|
# is hand-rolled here. BUILD_TYPE selects the engine backend in the Makefile:
|
||||||
|
# "" = cpu, "cublas" -> -DPF_CUDA=ON, "vulkan" -> -DPF_VULKAN=ON.
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Stage: builder-fromsource — self-contained build. Runs the same install
|
||||||
|
# script backend/Dockerfile.base-grpc-builder runs, so this path is
|
||||||
|
# bit-equivalent to the prebuilt base. Used when BUILDER_TARGET=builder-fromsource
|
||||||
|
# (the default; local `make backends/privacy-filter`).
|
||||||
|
# ============================================================================
|
||||||
|
FROM ${BASE_IMAGE} AS builder-fromsource
|
||||||
|
ARG BUILD_TYPE
|
||||||
|
ARG CUDA_MAJOR_VERSION
|
||||||
|
ARG CUDA_MINOR_VERSION
|
||||||
|
ARG CMAKE_FROM_SOURCE=false
|
||||||
|
# CUDA Toolkit 13.x needs CMake 3.31.9+ for correct toolchain/arch detection.
|
||||||
|
ARG CMAKE_VERSION=3.31.10
|
||||||
|
ARG GRPC_VERSION=v1.65.0
|
||||||
|
ARG GRPC_MAKEFLAGS="-j4 -Otarget"
|
||||||
|
ARG SKIP_DRIVERS=false
|
||||||
|
ARG TARGETARCH
|
||||||
|
ARG UBUNTU_VERSION=2404
|
||||||
|
ARG APT_MIRROR
|
||||||
|
ARG APT_PORTS_MIRROR
|
||||||
|
|
||||||
|
ENV BUILD_TYPE=${BUILD_TYPE} \
|
||||||
|
CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION} \
|
||||||
|
CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION} \
|
||||||
|
CMAKE_FROM_SOURCE=${CMAKE_FROM_SOURCE} \
|
||||||
|
CMAKE_VERSION=${CMAKE_VERSION} \
|
||||||
|
GRPC_VERSION=${GRPC_VERSION} \
|
||||||
|
GRPC_MAKEFLAGS=${GRPC_MAKEFLAGS} \
|
||||||
|
SKIP_DRIVERS=${SKIP_DRIVERS} \
|
||||||
|
TARGETARCH=${TARGETARCH} \
|
||||||
|
UBUNTU_VERSION=${UBUNTU_VERSION} \
|
||||||
|
APT_MIRROR=${APT_MIRROR} \
|
||||||
|
APT_PORTS_MIRROR=${APT_PORTS_MIRROR} \
|
||||||
|
DEBIAN_FRONTEND=noninteractive
|
||||||
|
# CUDA on PATH (a no-op when CUDA is not installed, e.g. cpu/vulkan builds).
|
||||||
|
ENV PATH=/usr/local/cuda/bin:${PATH}
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
# apt deps + cmake + protoc + gRPC + conditional CUDA/Vulkan, all from the
|
||||||
|
# shared script (the source of truth that base-grpc-builder also runs).
|
||||||
|
RUN --mount=type=bind,source=.docker/install-base-deps.sh,target=/usr/local/sbin/install-base-deps \
|
||||||
|
--mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||||
|
bash /usr/local/sbin/install-base-deps
|
||||||
|
|
||||||
|
# install-base-deps installs gRPC under /opt/grpc; copy it to /usr/local so the
|
||||||
|
# backend's find_package(gRPC CONFIG) resolves it at the canonical prefix.
|
||||||
|
RUN cp -a /opt/grpc/. /usr/local/
|
||||||
|
|
||||||
|
COPY . /LocalAI
|
||||||
|
|
||||||
|
RUN --mount=type=cache,target=/root/.ccache,id=privacy-filter-ccache-${TARGETARCH}-${BUILD_TYPE},sharing=locked \
|
||||||
|
make -C /LocalAI/backend/cpp/privacy-filter BUILD_TYPE=${BUILD_TYPE} NATIVE=false grpc-server package
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Stage: builder-prebuilt — FROM a prebuilt
|
||||||
|
# quay.io/go-skynet/ci-cache:base-grpc-* image (gRPC at /opt/grpc + apt deps +
|
||||||
|
# CUDA/Vulkan already installed). Used in CI when the matrix entry sets
|
||||||
|
# builder-base-image.
|
||||||
|
# ============================================================================
|
||||||
|
FROM ${BUILDER_BASE_IMAGE} AS builder-prebuilt
|
||||||
|
ARG BUILD_TYPE
|
||||||
|
ARG TARGETARCH
|
||||||
|
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||||
|
# CUDA on PATH (a no-op for the cpu/vulkan base images).
|
||||||
|
ENV PATH=/usr/local/cuda/bin:${PATH}
|
||||||
|
|
||||||
|
# Mirror builder-fromsource: the base-grpc image installs gRPC to /opt/grpc but
|
||||||
|
# does not copy it to /usr/local.
|
||||||
|
RUN cp -a /opt/grpc/. /usr/local/
|
||||||
|
|
||||||
|
COPY . /LocalAI
|
||||||
|
|
||||||
|
RUN --mount=type=cache,target=/root/.ccache,id=privacy-filter-ccache-${TARGETARCH}-${BUILD_TYPE},sharing=locked \
|
||||||
|
make -C /LocalAI/backend/cpp/privacy-filter BUILD_TYPE=${BUILD_TYPE} NATIVE=false grpc-server package
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Final stage — copy the package output from the selected builder. BuildKit
|
||||||
|
# does not expand variables in `COPY --from=`, so alias the chosen builder to a
|
||||||
|
# fixed stage name first.
|
||||||
|
# ============================================================================
|
||||||
|
FROM ${BUILDER_TARGET} AS builder
|
||||||
|
|
||||||
|
FROM scratch
|
||||||
|
COPY --from=builder /LocalAI/backend/cpp/privacy-filter/package/. ./
|
||||||
@@ -66,7 +66,12 @@ RUN <<EOT bash
|
|||||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils && \
|
||||||
|
apt-get install -y mesa-vulkan-drivers libdrm2
|
||||||
|
# Mesa Vulkan ICD drivers (ANV/RADV/lavapipe) + their manifests. The
|
||||||
|
# LunarG SDK below only provides the loader and shader tooling, not
|
||||||
|
# hardware drivers — without Mesa, package-gpu-libs.sh has no ICD to
|
||||||
|
# bundle and the packaged backend finds no GPU at runtime.
|
||||||
if [ "amd64" = "$TARGETARCH" ]; then
|
if [ "amd64" = "$TARGETARCH" ]; then
|
||||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||||
@@ -126,6 +131,7 @@ RUN <<EOT bash
|
|||||||
apt-get update && \
|
apt-get update && \
|
||||||
apt-get install -y --no-install-recommends \
|
apt-get install -y --no-install-recommends \
|
||||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||||
|
cuda-nvrtc-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||||
|
|||||||
@@ -24,6 +24,10 @@ service Backend {
|
|||||||
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
|
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
|
||||||
rpc Status(HealthMessage) returns (StatusResponse) {}
|
rpc Status(HealthMessage) returns (StatusResponse) {}
|
||||||
rpc Detect(DetectOptions) returns (DetectResponse) {}
|
rpc Detect(DetectOptions) returns (DetectResponse) {}
|
||||||
|
// SoundDetection runs an audio-tagging / sound-event-classification model
|
||||||
|
// (e.g. CED over the AudioSet ontology) on a clip and returns scored labels.
|
||||||
|
rpc SoundDetection(SoundDetectionRequest) returns (SoundDetectionResponse) {}
|
||||||
|
rpc Depth(DepthRequest) returns (DepthResponse) {}
|
||||||
rpc FaceVerify(FaceVerifyRequest) returns (FaceVerifyResponse) {}
|
rpc FaceVerify(FaceVerifyRequest) returns (FaceVerifyResponse) {}
|
||||||
rpc FaceAnalyze(FaceAnalyzeRequest) returns (FaceAnalyzeResponse) {}
|
rpc FaceAnalyze(FaceAnalyzeRequest) returns (FaceAnalyzeResponse) {}
|
||||||
rpc VoiceVerify(VoiceVerifyRequest) returns (VoiceVerifyResponse) {}
|
rpc VoiceVerify(VoiceVerifyRequest) returns (VoiceVerifyResponse) {}
|
||||||
@@ -670,6 +674,53 @@ message DetectResponse {
|
|||||||
repeated Detection Detections = 1;
|
repeated Detection Detections = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Sound-event classification / audio tagging messages (CED) ---
|
||||||
|
|
||||||
|
message SoundDetectionRequest {
|
||||||
|
string src = 1; // audio file path (LocalAI writes the upload to disk)
|
||||||
|
int32 top_k = 2; // number of top tags to return (0 = all classes)
|
||||||
|
float threshold = 3; // optional: drop tags scoring below this
|
||||||
|
}
|
||||||
|
|
||||||
|
message SoundClass {
|
||||||
|
string label = 1; // AudioSet class name, e.g. "Baby cry, infant cry"
|
||||||
|
float score = 2; // per-class probability (multi-label, independent)
|
||||||
|
int32 index = 3; // class index in the model ontology
|
||||||
|
}
|
||||||
|
|
||||||
|
message SoundDetectionResponse {
|
||||||
|
repeated SoundClass detections = 1; // score-descending
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Depth estimation messages (Depth Anything 3) ---
|
||||||
|
|
||||||
|
message DepthRequest {
|
||||||
|
string src = 1; // input image (filesystem path or base64-encoded payload)
|
||||||
|
string dst = 2; // optional output directory for exports (glb/colmap)
|
||||||
|
bool include_depth = 3; // return the per-pixel metric depth map
|
||||||
|
bool include_confidence = 4; // return the per-pixel confidence map (DualDPT)
|
||||||
|
bool include_pose = 5; // return camera extrinsics/intrinsics (DualDPT)
|
||||||
|
bool include_sky = 6; // return the per-pixel sky map (mono models)
|
||||||
|
bool include_points = 7; // back-project to a 3D point cloud (DualDPT)
|
||||||
|
float points_conf_thresh = 8; // keep points with confidence >= this threshold
|
||||||
|
repeated string exports = 9; // requested exports: "glb", "colmap"
|
||||||
|
}
|
||||||
|
|
||||||
|
message DepthResponse {
|
||||||
|
int32 width = 1; // processed depth-map width
|
||||||
|
int32 height = 2; // processed depth-map height
|
||||||
|
repeated float depth = 3; // width*height row-major metric depth
|
||||||
|
repeated float confidence = 4; // width*height row-major confidence (DualDPT)
|
||||||
|
repeated float sky = 5; // width*height row-major sky map (mono)
|
||||||
|
repeated float extrinsics = 6; // 12 floats, 3x4 row-major (world-to-camera)
|
||||||
|
repeated float intrinsics = 7; // 9 floats, 3x3 row-major
|
||||||
|
int32 num_points = 8; // number of 3D points
|
||||||
|
repeated float points = 9; // num_points*3 xyz, world space
|
||||||
|
bytes point_colors = 10; // num_points*3 uint8 rgb
|
||||||
|
repeated string export_paths = 11; // paths written for the requested exports
|
||||||
|
bool is_metric = 12; // depth is in metric units
|
||||||
|
}
|
||||||
|
|
||||||
// --- Face recognition messages ---
|
// --- Face recognition messages ---
|
||||||
|
|
||||||
message FacialArea {
|
message FacialArea {
|
||||||
|
|||||||
@@ -9,6 +9,22 @@ option(DS4_NATIVE "Compile with -march=native / -mcpu=native" ON)
|
|||||||
set(DS4_GPU "cpu" CACHE STRING "GPU backend: cpu, cuda, or metal")
|
set(DS4_GPU "cpu" CACHE STRING "GPU backend: cpu, cuda, or metal")
|
||||||
set(DS4_DIR "${CMAKE_CURRENT_SOURCE_DIR}/ds4" CACHE PATH "Path to cloned ds4 source")
|
set(DS4_DIR "${CMAKE_CURRENT_SOURCE_DIR}/ds4" CACHE PATH "Path to cloned ds4 source")
|
||||||
|
|
||||||
|
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||||
|
# Homebrew installs protobuf/grpc under a non-default prefix. The generated
|
||||||
|
# backend.pb.cc / backend.grpc.pb.cc pull in google/protobuf and grpcpp
|
||||||
|
# headers, but the hw_grpc_proto library links neither target, so on macOS
|
||||||
|
# the headers (e.g. google/protobuf/runtime_version.h) are never on the
|
||||||
|
# compiler's include path. Add the Homebrew prefix globally, matching the
|
||||||
|
# llama-cpp backend which builds on Darwin CI.
|
||||||
|
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "arm64")
|
||||||
|
set(HOMEBREW_DEFAULT_PREFIX "/opt/homebrew")
|
||||||
|
else()
|
||||||
|
set(HOMEBREW_DEFAULT_PREFIX "/usr/local")
|
||||||
|
endif()
|
||||||
|
link_directories("${HOMEBREW_DEFAULT_PREFIX}/lib")
|
||||||
|
include_directories("${HOMEBREW_DEFAULT_PREFIX}/include")
|
||||||
|
endif()
|
||||||
|
|
||||||
find_package(Threads REQUIRED)
|
find_package(Threads REQUIRED)
|
||||||
find_package(Protobuf CONFIG QUIET)
|
find_package(Protobuf CONFIG QUIET)
|
||||||
if(NOT Protobuf_FOUND)
|
if(NOT Protobuf_FOUND)
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
# ds4 backend Makefile.
|
# ds4 backend Makefile.
|
||||||
#
|
#
|
||||||
# Upstream pin lives below as DS4_VERSION?=8384adf0f9fa0f3bb342dd925372de778b95b263
|
# Upstream pin lives below as DS4_VERSION?=80ebbc396aee40eedc1d829222f3362d10fa4c6c
|
||||||
# (.github/bump_deps.sh) can find and update it - matches the
|
# (.github/bump_deps.sh) can find and update it - matches the
|
||||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||||
|
|
||||||
DS4_VERSION?=8384adf0f9fa0f3bb342dd925372de778b95b263
|
DS4_VERSION?=80ebbc396aee40eedc1d829222f3362d10fa4c6c
|
||||||
DS4_REPO?=https://github.com/antirez/ds4
|
DS4_REPO?=https://github.com/antirez/ds4
|
||||||
|
|
||||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ extern "C" {
|
|||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <climits>
|
#include <climits>
|
||||||
#include <csignal>
|
#include <csignal>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <ctime>
|
#include <ctime>
|
||||||
@@ -105,6 +107,130 @@ static bool parse_layers_spec(const std::string &spec, ds4_distributed_layers *o
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse a boolean LoadModel option. An empty value (a bare flag-style option
|
||||||
|
// like "ssd_streaming" with no colon) means true so model YAMLs can write
|
||||||
|
// options: ["ssd_streaming"] to enable a switch.
|
||||||
|
static bool parse_bool_option(const std::string &s, bool *out) {
|
||||||
|
if (s.empty() || s == "true" || s == "1" || s == "yes" || s == "on") { *out = true; return true; }
|
||||||
|
if (s == "false" || s == "0" || s == "no" || s == "off") { *out = false; return true; }
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Table-driven mapping from LoadModel option keys to ds4_engine_options fields.
|
||||||
|
// ds4_engine_options is a fixed C struct with no reflection, so the field set
|
||||||
|
// is enumerated once here; adding a future engine knob is a one-line table
|
||||||
|
// entry rather than a new branch in LoadModel. Two fields need ds4's own typed
|
||||||
|
// parsers (Gib, CacheExperts) so a plain string passthrough can't cover them.
|
||||||
|
enum class DsOptType { Bool, Int, Uint, Float, Str, Gib, CacheExperts };
|
||||||
|
|
||||||
|
struct DsOptSpec {
|
||||||
|
const char *key;
|
||||||
|
DsOptType type;
|
||||||
|
size_t off; // byte offset into ds4_engine_options
|
||||||
|
size_t off2; // second offset (CacheExperts writes experts + bytes)
|
||||||
|
bool is_path; // Str values: resolve a relative value against the model dir
|
||||||
|
};
|
||||||
|
|
||||||
|
static const DsOptSpec kEngineOptSpecs[] = {
|
||||||
|
{"mtp_path", DsOptType::Str, offsetof(ds4_engine_options, mtp_path), 0, true},
|
||||||
|
{"mtp_draft", DsOptType::Int, offsetof(ds4_engine_options, mtp_draft_tokens), 0},
|
||||||
|
{"mtp_margin", DsOptType::Float, offsetof(ds4_engine_options, mtp_margin), 0},
|
||||||
|
{"prefill_chunk", DsOptType::Uint, offsetof(ds4_engine_options, prefill_chunk), 0},
|
||||||
|
{"power_percent", DsOptType::Int, offsetof(ds4_engine_options, power_percent), 0},
|
||||||
|
{"warm_weights", DsOptType::Bool, offsetof(ds4_engine_options, warm_weights), 0},
|
||||||
|
{"quality", DsOptType::Bool, offsetof(ds4_engine_options, quality), 0},
|
||||||
|
{"ssd_streaming", DsOptType::Bool, offsetof(ds4_engine_options, ssd_streaming), 0},
|
||||||
|
{"ssd_streaming_cold", DsOptType::Bool, offsetof(ds4_engine_options, ssd_streaming_cold), 0},
|
||||||
|
{"ssd_streaming_preload_experts", DsOptType::Uint, offsetof(ds4_engine_options, ssd_streaming_preload_experts), 0},
|
||||||
|
{"ssd_streaming_cache_experts", DsOptType::CacheExperts, offsetof(ds4_engine_options, ssd_streaming_cache_experts),
|
||||||
|
offsetof(ds4_engine_options, ssd_streaming_cache_bytes)},
|
||||||
|
{"simulate_used_memory", DsOptType::Gib, offsetof(ds4_engine_options, simulate_used_memory_bytes), 0},
|
||||||
|
{"expert_profile_path", DsOptType::Str, offsetof(ds4_engine_options, expert_profile_path), 0, true},
|
||||||
|
{"directional_steering_file", DsOptType::Str, offsetof(ds4_engine_options, directional_steering_file), 0, true},
|
||||||
|
{"directional_steering_attn", DsOptType::Float, offsetof(ds4_engine_options, directional_steering_attn), 0},
|
||||||
|
{"directional_steering_ffn", DsOptType::Float, offsetof(ds4_engine_options, directional_steering_ffn), 0},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Apply a single key:value LoadModel option to the engine options struct.
|
||||||
|
// Unknown keys are ignored (back-compat: callers pass mixed option sets).
|
||||||
|
// String values are copied into `storage`, whose elements the engine reads by
|
||||||
|
// pointer during ds4_engine_open; `storage` MUST have reserved capacity so
|
||||||
|
// push_back never reallocates and dangles an earlier c_str(). Returns false
|
||||||
|
// with `err` set when a recognized key has an invalid value.
|
||||||
|
static bool apply_engine_option(ds4_engine_options *opt, const std::string &key,
|
||||||
|
const std::string &val, const std::string &model_dir,
|
||||||
|
std::vector<std::string> &storage, std::string &err) {
|
||||||
|
const DsOptSpec *spec = nullptr;
|
||||||
|
for (const auto &s : kEngineOptSpecs) {
|
||||||
|
if (key == s.key) { spec = &s; break; }
|
||||||
|
}
|
||||||
|
if (!spec) return true; // unknown key: ignore
|
||||||
|
|
||||||
|
char *base = reinterpret_cast<char *>(opt);
|
||||||
|
switch (spec->type) {
|
||||||
|
case DsOptType::Bool: {
|
||||||
|
bool b = false;
|
||||||
|
if (!parse_bool_option(val, &b)) { err = key + " must be true/false"; return false; }
|
||||||
|
*reinterpret_cast<bool *>(base + spec->off) = b;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
case DsOptType::Int: {
|
||||||
|
char *end = nullptr;
|
||||||
|
long v = std::strtol(val.c_str(), &end, 10);
|
||||||
|
if (val.empty() || !end || *end != '\0') { err = key + " must be an integer"; return false; }
|
||||||
|
*reinterpret_cast<int *>(base + spec->off) = static_cast<int>(v);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
case DsOptType::Uint: {
|
||||||
|
char *end = nullptr;
|
||||||
|
long v = std::strtol(val.c_str(), &end, 10);
|
||||||
|
if (val.empty() || !end || *end != '\0' || v < 0 || v > static_cast<long>(UINT32_MAX)) {
|
||||||
|
err = key + " must be a non-negative integer"; return false;
|
||||||
|
}
|
||||||
|
*reinterpret_cast<uint32_t *>(base + spec->off) = static_cast<uint32_t>(v);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
case DsOptType::Float: {
|
||||||
|
char *end = nullptr;
|
||||||
|
float f = std::strtof(val.c_str(), &end);
|
||||||
|
if (val.empty() || !end || *end != '\0') { err = key + " must be a number"; return false; }
|
||||||
|
*reinterpret_cast<float *>(base + spec->off) = f;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
case DsOptType::Str: {
|
||||||
|
// Resolve a relative path option (e.g. mtp_path: a sibling GGUF the
|
||||||
|
// gallery downloaded next to the model) against the model directory, so
|
||||||
|
// YAMLs reference companion files by name. Absolute values pass through.
|
||||||
|
if (spec->is_path && !model_dir.empty() && !val.empty() && val.front() != '/') {
|
||||||
|
storage.push_back(model_dir + "/" + val);
|
||||||
|
} else {
|
||||||
|
storage.push_back(val);
|
||||||
|
}
|
||||||
|
*reinterpret_cast<const char **>(base + spec->off) = storage.back().c_str();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
case DsOptType::Gib: {
|
||||||
|
uint64_t bytes = 0;
|
||||||
|
if (!ds4_parse_gib_arg(val.c_str(), &bytes)) {
|
||||||
|
err = key + " must be a GiB value, e.g. 64GB"; return false;
|
||||||
|
}
|
||||||
|
*reinterpret_cast<uint64_t *>(base + spec->off) = bytes;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
case DsOptType::CacheExperts: {
|
||||||
|
uint32_t experts = 0;
|
||||||
|
uint64_t bytes = 0;
|
||||||
|
if (!ds4_parse_streaming_cache_experts_arg(val.c_str(), &experts, &bytes)) {
|
||||||
|
err = key + " must be a positive expert count or a <number>GB budget"; return false;
|
||||||
|
}
|
||||||
|
*reinterpret_cast<uint32_t *>(base + spec->off) = experts;
|
||||||
|
*reinterpret_cast<uint64_t *>(base + spec->off2) = bytes;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// When acting as a distributed coordinator, block until the worker route
|
// When acting as a distributed coordinator, block until the worker route
|
||||||
// covers all layers (ds4_session_distributed_route_ready == 1) or the timeout
|
// covers all layers (ds4_session_distributed_route_ready == 1) or the timeout
|
||||||
// elapses. Returns an empty string on success, or an error message to return
|
// elapses. Returns an empty string on success, or an error message to return
|
||||||
@@ -476,39 +602,10 @@ public:
|
|||||||
return GStatus::OK;
|
return GStatus::OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string mtp_path;
|
|
||||||
int mtp_draft = 0;
|
|
||||||
float mtp_margin = 3.0f;
|
|
||||||
std::string ds4_role, ds4_layers, ds4_listen;
|
|
||||||
for (const auto &opt : request->options()) {
|
|
||||||
auto [k, v] = split_option(opt);
|
|
||||||
if (k == "mtp_path") mtp_path = v;
|
|
||||||
else if (k == "mtp_draft") mtp_draft = std::stoi(v);
|
|
||||||
else if (k == "mtp_margin") mtp_margin = std::stof(v);
|
|
||||||
else if (k == "kv_cache_dir") g_kv_cache_dir = v;
|
|
||||||
else if (k == "ds4_role") ds4_role = v;
|
|
||||||
else if (k == "ds4_layers") ds4_layers = v;
|
|
||||||
else if (k == "ds4_listen") ds4_listen = v;
|
|
||||||
else if (k == "ds4_route_timeout") {
|
|
||||||
if (!parse_positive_int(v, &g_route_timeout_sec)) {
|
|
||||||
result->set_success(false);
|
|
||||||
result->set_message("ds4: ds4_route_timeout must be a positive integer");
|
|
||||||
return GStatus::OK;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
g_kv_cache.SetDir(g_kv_cache_dir);
|
|
||||||
|
|
||||||
ds4_engine_options opt = {};
|
ds4_engine_options opt = {};
|
||||||
opt.model_path = model_path.c_str();
|
opt.model_path = model_path.c_str();
|
||||||
opt.mtp_path = mtp_path.empty() ? nullptr : mtp_path.c_str();
|
|
||||||
opt.n_threads = request->threads() > 0 ? request->threads() : 0;
|
opt.n_threads = request->threads() > 0 ? request->threads() : 0;
|
||||||
opt.mtp_draft_tokens = mtp_draft;
|
opt.mtp_margin = 3.0f; // ds4 default; overridable via the mtp_margin option
|
||||||
opt.mtp_margin = mtp_margin;
|
|
||||||
opt.directional_steering_file = nullptr;
|
|
||||||
opt.warm_weights = false;
|
|
||||||
opt.quality = false;
|
|
||||||
|
|
||||||
#if defined(DS4_NO_GPU)
|
#if defined(DS4_NO_GPU)
|
||||||
opt.backend = DS4_BACKEND_CPU;
|
opt.backend = DS4_BACKEND_CPU;
|
||||||
@@ -518,6 +615,46 @@ public:
|
|||||||
opt.backend = DS4_BACKEND_CUDA;
|
opt.backend = DS4_BACKEND_CUDA;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Stable storage for string-valued engine options. The engine reads
|
||||||
|
// these by pointer during ds4_engine_open, so the std::string backing
|
||||||
|
// store must outlive the call and not reallocate; reserve up front so
|
||||||
|
// push_back keeps every prior c_str() valid. Static + clear() reuses
|
||||||
|
// the buffer across LoadModel calls (the old engine is closed above).
|
||||||
|
static std::vector<std::string> s_opt_strings;
|
||||||
|
s_opt_strings.clear();
|
||||||
|
s_opt_strings.reserve(sizeof(kEngineOptSpecs) / sizeof(kEngineOptSpecs[0]));
|
||||||
|
|
||||||
|
// Directory of the main model, used to resolve relative path options.
|
||||||
|
std::string model_dir;
|
||||||
|
if (auto slash = model_path.find_last_of('/'); slash != std::string::npos) {
|
||||||
|
model_dir = model_path.substr(0, slash);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ds4_role, ds4_layers, ds4_listen;
|
||||||
|
for (const auto &o : request->options()) {
|
||||||
|
auto [k, v] = split_option(o);
|
||||||
|
if (k == "kv_cache_dir") { g_kv_cache_dir = v; continue; }
|
||||||
|
else if (k == "ds4_role") { ds4_role = v; continue; }
|
||||||
|
else if (k == "ds4_layers") { ds4_layers = v; continue; }
|
||||||
|
else if (k == "ds4_listen") { ds4_listen = v; continue; }
|
||||||
|
else if (k == "ds4_route_timeout") {
|
||||||
|
if (!parse_positive_int(v, &g_route_timeout_sec)) {
|
||||||
|
result->set_success(false);
|
||||||
|
result->set_message("ds4: ds4_route_timeout must be a positive integer");
|
||||||
|
return GStatus::OK;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::string err;
|
||||||
|
if (!apply_engine_option(&opt, k, v, model_dir, s_opt_strings, err)) {
|
||||||
|
result->set_success(false);
|
||||||
|
result->set_message("ds4: " + err);
|
||||||
|
return GStatus::OK;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
g_kv_cache.SetDir(g_kv_cache_dir);
|
||||||
|
|
||||||
// Coordinator wiring. 'ds4_role:coordinator' enables layer-split
|
// Coordinator wiring. 'ds4_role:coordinator' enables layer-split
|
||||||
// distributed inference: this process listens on ds4_listen and owns
|
// distributed inference: this process listens on ds4_listen and owns
|
||||||
// the ds4_layers slice; workers dial in (see `local-ai worker
|
// the ds4_layers slice; workers dial in (see `local-ai worker
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
IK_LLAMA_VERSION?=e6f8112f3ba126eed3ff5b30cdd08085414a7516
|
IK_LLAMA_VERSION?=6c00e87ac84404af588ad2e65935bd6f079c696f
|
||||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||||
|
|
||||||
CMAKE_ARGS?=
|
CMAKE_ARGS?=
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
LLAMA_VERSION?=039e20a2db9e87b2477c76cc04905f3e1acad77f
|
LLAMA_VERSION?=e475fa2b5f9fb50c3d6fc3e7c6fdf1e004465b62
|
||||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||||
|
|
||||||
CMAKE_ARGS?=
|
CMAKE_ARGS?=
|
||||||
|
|||||||
@@ -18,6 +18,18 @@
|
|||||||
#if __has_include("server-chat.cpp")
|
#if __has_include("server-chat.cpp")
|
||||||
#include "server-chat.cpp"
|
#include "server-chat.cpp"
|
||||||
#endif
|
#endif
|
||||||
|
// server-schema.cpp exists only in llama.cpp after the upstream refactor that
|
||||||
|
// extracted the JSON request-schema evaluation (previously the static
|
||||||
|
// server_task::params_from_json_cmpl) into server_schema::eval_llama_cmpl_schema.
|
||||||
|
// server-context.cpp and grpc-server.cpp both call into it, so its definitions
|
||||||
|
// must be part of this translation unit or the link fails. __has_include keeps
|
||||||
|
// the source compatible with older pins/forks (e.g. llama-cpp-turboquant) that
|
||||||
|
// predate the split and still expose params_from_json_cmpl (see the guarded
|
||||||
|
// call sites below).
|
||||||
|
#if __has_include("server-schema.cpp")
|
||||||
|
#define LOCALAI_HAS_SERVER_SCHEMA 1
|
||||||
|
#include "server-schema.cpp"
|
||||||
|
#endif
|
||||||
#include "server-context.cpp"
|
#include "server-context.cpp"
|
||||||
|
|
||||||
// LocalAI
|
// LocalAI
|
||||||
@@ -1922,25 +1934,27 @@ public:
|
|||||||
body_json["min_p"] = data["min_p"];
|
body_json["min_p"] = data["min_p"];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pass enable_thinking via chat_template_kwargs (where oaicompat_chat_params_parse reads it)
|
// Forward the chat_template_kwargs the Go layer resolved (model config
|
||||||
|
// chat_template_kwargs + per-request metadata: enable_thinking,
|
||||||
|
// reasoning_effort, preserve_thinking, ...). One generic merge replaces
|
||||||
|
// the previous per-key handling - new template levers need no C++ change.
|
||||||
|
// oaicompat_chat_params_parse reads these from body_json.
|
||||||
const auto& metadata = request->metadata();
|
const auto& metadata = request->metadata();
|
||||||
auto et_it = metadata.find("enable_thinking");
|
auto ctk_it = metadata.find("chat_template_kwargs");
|
||||||
if (et_it != metadata.end()) {
|
if (ctk_it != metadata.end() && !ctk_it->second.empty()) {
|
||||||
if (!body_json.contains("chat_template_kwargs")) {
|
try {
|
||||||
body_json["chat_template_kwargs"] = json::object();
|
json ctk = json::parse(ctk_it->second);
|
||||||
|
if (ctk.is_object()) {
|
||||||
|
if (!body_json.contains("chat_template_kwargs")) {
|
||||||
|
body_json["chat_template_kwargs"] = json::object();
|
||||||
|
}
|
||||||
|
for (auto& el : ctk.items()) {
|
||||||
|
body_json["chat_template_kwargs"][el.key()] = el.value();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (const std::exception & e) {
|
||||||
|
SRV_WRN("failed to parse chat_template_kwargs metadata: %s\n", e.what());
|
||||||
}
|
}
|
||||||
body_json["chat_template_kwargs"]["enable_thinking"] = (et_it->second == "true");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
|
||||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
|
||||||
// from enable_thinking which those templates ignore.
|
|
||||||
auto re_it = metadata.find("reasoning_effort");
|
|
||||||
if (re_it != metadata.end() && !re_it->second.empty()) {
|
|
||||||
if (!body_json.contains("chat_template_kwargs")) {
|
|
||||||
body_json["chat_template_kwargs"] = json::object();
|
|
||||||
}
|
|
||||||
body_json["chat_template_kwargs"]["reasoning_effort"] = re_it->second;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||||
@@ -2100,7 +2114,11 @@ public:
|
|||||||
task.index = i;
|
task.index = i;
|
||||||
|
|
||||||
task.tokens = std::move(inputs[i]);
|
task.tokens = std::move(inputs[i]);
|
||||||
|
#ifdef LOCALAI_HAS_SERVER_SCHEMA
|
||||||
|
task.params = server_schema::eval_llama_cmpl_schema(
|
||||||
|
#else
|
||||||
task.params = server_task::params_from_json_cmpl(
|
task.params = server_task::params_from_json_cmpl(
|
||||||
|
#endif
|
||||||
ctx_server.impl->vocab,
|
ctx_server.impl->vocab,
|
||||||
params_base,
|
params_base,
|
||||||
ctx_server.get_meta().slot_n_ctx,
|
ctx_server.get_meta().slot_n_ctx,
|
||||||
@@ -2114,7 +2132,7 @@ public:
|
|||||||
// cannot detect tool calls or separate reasoning from content.
|
// cannot detect tool calls or separate reasoning from content.
|
||||||
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||||
task.params.oaicompat_cmpl_id = completion_id;
|
task.params.oaicompat_cmpl_id = completion_id;
|
||||||
// oaicompat_model is already populated by params_from_json_cmpl
|
// oaicompat_model is already populated by eval_llama_cmpl_schema
|
||||||
|
|
||||||
tasks.push_back(std::move(task));
|
tasks.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
@@ -2756,25 +2774,26 @@ public:
|
|||||||
body_json["min_p"] = data["min_p"];
|
body_json["min_p"] = data["min_p"];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pass enable_thinking via chat_template_kwargs (where oaicompat_chat_params_parse reads it)
|
// Forward the chat_template_kwargs the Go layer resolved (model config
|
||||||
|
// chat_template_kwargs + per-request metadata: enable_thinking,
|
||||||
|
// reasoning_effort, preserve_thinking, ...). One generic merge replaces
|
||||||
|
// the previous per-key handling - new template levers need no C++ change.
|
||||||
const auto& predict_metadata = request->metadata();
|
const auto& predict_metadata = request->metadata();
|
||||||
auto predict_et_it = predict_metadata.find("enable_thinking");
|
auto predict_ctk_it = predict_metadata.find("chat_template_kwargs");
|
||||||
if (predict_et_it != predict_metadata.end()) {
|
if (predict_ctk_it != predict_metadata.end() && !predict_ctk_it->second.empty()) {
|
||||||
if (!body_json.contains("chat_template_kwargs")) {
|
try {
|
||||||
body_json["chat_template_kwargs"] = json::object();
|
json ctk = json::parse(predict_ctk_it->second);
|
||||||
|
if (ctk.is_object()) {
|
||||||
|
if (!body_json.contains("chat_template_kwargs")) {
|
||||||
|
body_json["chat_template_kwargs"] = json::object();
|
||||||
|
}
|
||||||
|
for (auto& el : ctk.items()) {
|
||||||
|
body_json["chat_template_kwargs"][el.key()] = el.value();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (const std::exception & e) {
|
||||||
|
SRV_WRN("failed to parse chat_template_kwargs metadata: %s\n", e.what());
|
||||||
}
|
}
|
||||||
body_json["chat_template_kwargs"]["enable_thinking"] = (predict_et_it->second == "true");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
|
||||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
|
||||||
// from enable_thinking which those templates ignore.
|
|
||||||
auto predict_re_it = predict_metadata.find("reasoning_effort");
|
|
||||||
if (predict_re_it != predict_metadata.end() && !predict_re_it->second.empty()) {
|
|
||||||
if (!body_json.contains("chat_template_kwargs")) {
|
|
||||||
body_json["chat_template_kwargs"] = json::object();
|
|
||||||
}
|
|
||||||
body_json["chat_template_kwargs"]["reasoning_effort"] = predict_re_it->second;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||||
@@ -2937,7 +2956,11 @@ public:
|
|||||||
task.index = i;
|
task.index = i;
|
||||||
|
|
||||||
task.tokens = std::move(inputs[i]);
|
task.tokens = std::move(inputs[i]);
|
||||||
|
#ifdef LOCALAI_HAS_SERVER_SCHEMA
|
||||||
|
task.params = server_schema::eval_llama_cmpl_schema(
|
||||||
|
#else
|
||||||
task.params = server_task::params_from_json_cmpl(
|
task.params = server_task::params_from_json_cmpl(
|
||||||
|
#endif
|
||||||
ctx_server.impl->vocab,
|
ctx_server.impl->vocab,
|
||||||
params_base,
|
params_base,
|
||||||
ctx_server.get_meta().slot_n_ctx,
|
ctx_server.get_meta().slot_n_ctx,
|
||||||
@@ -2949,7 +2972,7 @@ public:
|
|||||||
// reasoning, tool calls, and content are classified into ChatDeltas.
|
// reasoning, tool calls, and content are classified into ChatDeltas.
|
||||||
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||||
task.params.oaicompat_cmpl_id = completion_id;
|
task.params.oaicompat_cmpl_id = completion_id;
|
||||||
// oaicompat_model is already populated by params_from_json_cmpl
|
// oaicompat_model is already populated by eval_llama_cmpl_schema
|
||||||
|
|
||||||
tasks.push_back(std::move(task));
|
tasks.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
@@ -3486,7 +3509,7 @@ public:
|
|||||||
if (body.count("prompt") != 0) {
|
if (body.count("prompt") != 0) {
|
||||||
const bool add_special = json_value(body, "add_special", false);
|
const bool add_special = json_value(body, "add_special", false);
|
||||||
|
|
||||||
llama_tokens tokens = tokenize_mixed(ctx_server.impl->vocab, body.at("content"), add_special, true);
|
llama_tokens tokens = tokenize_mixed(ctx_server.impl->vocab, body.at("prompt"), add_special, true);
|
||||||
|
|
||||||
|
|
||||||
for (const auto& token : tokens) {
|
for (const auto& token : tokens) {
|
||||||
|
|||||||
9
backend/cpp/privacy-filter/.gitignore
vendored
Normal file
9
backend/cpp/privacy-filter/.gitignore
vendored
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
/privacy-filter.cpp
|
||||||
|
build/
|
||||||
|
package/
|
||||||
|
grpc-server
|
||||||
|
*.o
|
||||||
|
backend.pb.cc
|
||||||
|
backend.pb.h
|
||||||
|
backend.grpc.pb.cc
|
||||||
|
backend.grpc.pb.h
|
||||||
69
backend/cpp/privacy-filter/CMakeLists.txt
Normal file
69
backend/cpp/privacy-filter/CMakeLists.txt
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.21)
|
||||||
|
project(privacy-filter-grpc-server LANGUAGES CXX C)
|
||||||
|
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
set(TARGET grpc-server)
|
||||||
|
|
||||||
|
# Path to the privacy-filter.cpp engine sources. The Makefile arranges for this
|
||||||
|
# to exist (clone of a pinned commit, or a symlink to PRIVACY_FILTER_SRC).
|
||||||
|
set(PRIVACY_FILTER_DIR "${CMAKE_CURRENT_SOURCE_DIR}/privacy-filter.cpp"
|
||||||
|
CACHE PATH "Path to the privacy-filter.cpp engine source tree")
|
||||||
|
|
||||||
|
find_package(Threads REQUIRED)
|
||||||
|
find_package(Protobuf CONFIG QUIET)
|
||||||
|
if(NOT Protobuf_FOUND)
|
||||||
|
find_package(Protobuf REQUIRED)
|
||||||
|
endif()
|
||||||
|
find_package(gRPC CONFIG QUIET)
|
||||||
|
if(NOT gRPC_FOUND)
|
||||||
|
# Ubuntu's apt-installed grpc++ does not ship a CMake config - fall back.
|
||||||
|
find_library(GRPCPP_LIB grpc++ REQUIRED)
|
||||||
|
find_library(GRPCPP_REFLECTION_LIB grpc++_reflection REQUIRED)
|
||||||
|
add_library(gRPC::grpc++ INTERFACE IMPORTED)
|
||||||
|
set_target_properties(gRPC::grpc++ PROPERTIES INTERFACE_LINK_LIBRARIES "${GRPCPP_LIB}")
|
||||||
|
add_library(gRPC::grpc++_reflection INTERFACE IMPORTED)
|
||||||
|
set_target_properties(gRPC::grpc++_reflection PROPERTIES INTERFACE_LINK_LIBRARIES "${GRPCPP_REFLECTION_LIB}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
find_program(_PROTOC NAMES protoc REQUIRED)
|
||||||
|
find_program(_GRPC_CPP_PLUGIN NAMES grpc_cpp_plugin REQUIRED)
|
||||||
|
|
||||||
|
get_filename_component(HW_PROTO "${CMAKE_CURRENT_SOURCE_DIR}/../../backend.proto" ABSOLUTE)
|
||||||
|
get_filename_component(HW_PROTO_PATH "${HW_PROTO}" PATH)
|
||||||
|
|
||||||
|
set(HW_PROTO_SRCS "${CMAKE_CURRENT_BINARY_DIR}/backend.pb.cc")
|
||||||
|
set(HW_PROTO_HDRS "${CMAKE_CURRENT_BINARY_DIR}/backend.pb.h")
|
||||||
|
set(HW_GRPC_SRCS "${CMAKE_CURRENT_BINARY_DIR}/backend.grpc.pb.cc")
|
||||||
|
set(HW_GRPC_HDRS "${CMAKE_CURRENT_BINARY_DIR}/backend.grpc.pb.h")
|
||||||
|
|
||||||
|
add_custom_command(
|
||||||
|
OUTPUT "${HW_PROTO_SRCS}" "${HW_PROTO_HDRS}" "${HW_GRPC_SRCS}" "${HW_GRPC_HDRS}"
|
||||||
|
COMMAND ${_PROTOC}
|
||||||
|
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
|
||||||
|
--cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
|
||||||
|
-I "${HW_PROTO_PATH}"
|
||||||
|
--plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN}"
|
||||||
|
"${HW_PROTO}"
|
||||||
|
DEPENDS "${HW_PROTO}")
|
||||||
|
|
||||||
|
add_library(hw_grpc_proto STATIC
|
||||||
|
${HW_GRPC_SRCS} ${HW_GRPC_HDRS}
|
||||||
|
${HW_PROTO_SRCS} ${HW_PROTO_HDRS})
|
||||||
|
target_include_directories(hw_grpc_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
|
||||||
|
|
||||||
|
# Build only the pf static lib (+ ggml) from the engine tree — no CLI/bench/tests.
|
||||||
|
# PF_VULKAN is honored when passed on the cmake command line (it lands in the
|
||||||
|
# shared cache the engine reads).
|
||||||
|
set(PF_BUILD_TOOLS OFF CACHE BOOL "" FORCE)
|
||||||
|
set(PF_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||||
|
add_subdirectory(${PRIVACY_FILTER_DIR} ${CMAKE_CURRENT_BINARY_DIR}/privacy-filter.cpp)
|
||||||
|
|
||||||
|
add_executable(${TARGET} grpc-server.cpp)
|
||||||
|
target_link_libraries(${TARGET} PRIVATE
|
||||||
|
pf
|
||||||
|
hw_grpc_proto
|
||||||
|
gRPC::grpc++
|
||||||
|
gRPC::grpc++_reflection
|
||||||
|
protobuf::libprotobuf
|
||||||
|
Threads::Threads)
|
||||||
77
backend/cpp/privacy-filter/Makefile
Normal file
77
backend/cpp/privacy-filter/Makefile
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# privacy-filter backend Makefile.
|
||||||
|
#
|
||||||
|
# Wraps the standalone privacy-filter.cpp GGML engine (the openai-privacy-filter
|
||||||
|
# PII/NER token classifier) as a LocalAI gRPC backend. The engine source is
|
||||||
|
# fetched at the pin below — .github/workflows/bump_deps.yaml finds and updates
|
||||||
|
# PRIVACY_FILTER_VERSION, matching the llama-cpp / ds4 convention.
|
||||||
|
#
|
||||||
|
# Local development: point at a working checkout instead of cloning, e.g.
|
||||||
|
# make PRIVACY_FILTER_SRC=$HOME/c/privacy-filter.cpp grpc-server
|
||||||
|
|
||||||
|
PRIVACY_FILTER_VERSION?=98f52c5ef2250f207cc6b9a6aef05393a120cb7c
|
||||||
|
PRIVACY_FILTER_REPO?=https://github.com/localai-org/privacy-filter.cpp
|
||||||
|
PRIVACY_FILTER_SRC?=
|
||||||
|
|
||||||
|
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||||
|
BUILD_DIR := build
|
||||||
|
|
||||||
|
BUILD_TYPE ?=
|
||||||
|
NATIVE ?= false
|
||||||
|
JOBS ?= $(shell nproc 2>/dev/null || echo 4)
|
||||||
|
|
||||||
|
CMAKE_ARGS ?= -DCMAKE_BUILD_TYPE=Release
|
||||||
|
|
||||||
|
# GPU backends; the default (cpu) needs no extra flags. 'cublas' is LocalAI's
|
||||||
|
# name for the CUDA build (matches llama-cpp / ds4), mapping to the engine's
|
||||||
|
# GGML_CUDA path; 'vulkan' selects the ggml Vulkan backend.
|
||||||
|
ifeq ($(BUILD_TYPE),cublas)
|
||||||
|
CMAKE_ARGS += -DPF_CUDA=ON
|
||||||
|
endif
|
||||||
|
ifeq ($(BUILD_TYPE),vulkan)
|
||||||
|
CMAKE_ARGS += -DPF_VULKAN=ON
|
||||||
|
endif
|
||||||
|
|
||||||
|
# Portable binaries for distribution: disable -march=native unless asked.
|
||||||
|
ifneq ($(NATIVE),true)
|
||||||
|
CMAKE_ARGS += -DGGML_NATIVE=OFF
|
||||||
|
endif
|
||||||
|
|
||||||
|
.PHONY: grpc-server package clean purge test all
|
||||||
|
all: grpc-server
|
||||||
|
|
||||||
|
# Provide the engine sources at ./privacy-filter.cpp. With PRIVACY_FILTER_SRC
|
||||||
|
# set we symlink a local checkout (instant, no network); otherwise we clone the
|
||||||
|
# pinned commit and its ggml submodule. The directory/symlink is the target, so
|
||||||
|
# make only does this once — run 'make purge && make' to refetch after a bump.
|
||||||
|
privacy-filter.cpp:
|
||||||
|
ifneq ($(PRIVACY_FILTER_SRC),)
|
||||||
|
ln -sfn $(abspath $(PRIVACY_FILTER_SRC)) privacy-filter.cpp
|
||||||
|
else
|
||||||
|
mkdir -p privacy-filter.cpp
|
||||||
|
cd privacy-filter.cpp && \
|
||||||
|
git init -q && \
|
||||||
|
git remote add origin $(PRIVACY_FILTER_REPO) && \
|
||||||
|
git fetch --depth 1 origin $(PRIVACY_FILTER_VERSION) && \
|
||||||
|
git checkout FETCH_HEAD && \
|
||||||
|
git submodule update --init --recursive --depth 1
|
||||||
|
endif
|
||||||
|
|
||||||
|
grpc-server: privacy-filter.cpp
|
||||||
|
@echo "Building privacy-filter grpc-server ($(BUILD_TYPE)) with $(CMAKE_ARGS)"
|
||||||
|
mkdir -p $(BUILD_DIR)
|
||||||
|
cd $(BUILD_DIR) && cmake $(CMAKE_ARGS) $(CURRENT_MAKEFILE_DIR) && cmake --build . --config Release -j $(JOBS)
|
||||||
|
cp $(BUILD_DIR)/grpc-server grpc-server
|
||||||
|
|
||||||
|
package: grpc-server
|
||||||
|
bash package.sh
|
||||||
|
|
||||||
|
test:
|
||||||
|
@echo "privacy-filter backend: parity/regression coverage lives in the engine repo"
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -rf $(BUILD_DIR) grpc-server package
|
||||||
|
|
||||||
|
# 'privacy-filter.cpp' may be a symlink (PRIVACY_FILTER_SRC) — rm without a
|
||||||
|
# trailing slash removes the link, never the linked-to checkout.
|
||||||
|
purge: clean
|
||||||
|
rm -rf privacy-filter.cpp
|
||||||
210
backend/cpp/privacy-filter/grpc-server.cpp
Normal file
210
backend/cpp/privacy-filter/grpc-server.cpp
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
// privacy-filter LocalAI gRPC backend.
|
||||||
|
//
|
||||||
|
// Thin shim over privacy-filter.cpp's flat C API (include/pf.h): a standalone
|
||||||
|
// GGML engine for the openai-privacy-filter token-classification model family
|
||||||
|
// (PII NER). It replaces the llama.cpp-patched TokenClassify path for this one
|
||||||
|
// model family — same GGUF files, no llama.cpp carry-patches.
|
||||||
|
//
|
||||||
|
// Only the RPCs the PII tier needs are implemented: LoadModel, TokenClassify,
|
||||||
|
// plus Health / Status / Free. Everything else inherits the generated base
|
||||||
|
// class default (UNIMPLEMENTED).
|
||||||
|
|
||||||
|
#include "backend.pb.h"
|
||||||
|
#include "backend.grpc.pb.h"
|
||||||
|
|
||||||
|
#include "pf.h"
|
||||||
|
|
||||||
|
#include <grpcpp/grpcpp.h>
|
||||||
|
#include <grpcpp/server.h>
|
||||||
|
#include <grpcpp/server_builder.h>
|
||||||
|
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <chrono>
|
||||||
|
#include <csignal>
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
#include <mutex>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
using grpc::Server;
|
||||||
|
using grpc::ServerBuilder;
|
||||||
|
using grpc::ServerContext;
|
||||||
|
// NOTE: do NOT alias grpc::Status as Status — the Status RPC method below would
|
||||||
|
// shadow the type and break the other method signatures. Use GStatus instead.
|
||||||
|
using GStatus = ::grpc::Status;
|
||||||
|
using grpc::StatusCode;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// The engine is single-model-per-process: LocalAI spawns one backend process
|
||||||
|
// per loaded model. g_mu guards (re)load against in-flight classification.
|
||||||
|
std::mutex g_mu;
|
||||||
|
pf_ctx * g_ctx = nullptr;
|
||||||
|
std::atomic<Server *> g_server{nullptr};
|
||||||
|
|
||||||
|
// Resolve the device string the engine expects ("cpu" / "gpu" / "cuda" /
|
||||||
|
// "vulkan", optionally ":N"). Priority: an explicit "device:..." in
|
||||||
|
// ModelOptions.Options, then a non-zero NGPULayers as a coarse "use the GPU"
|
||||||
|
// signal, else CPU. "gpu" lets the engine pick whichever GPU backend this
|
||||||
|
// binary was compiled with (CUDA or Vulkan), so the same config works on
|
||||||
|
// either build; pin "device:cuda"/"device:vulkan" to be explicit.
|
||||||
|
std::string resolve_device(const backend::ModelOptions * opts) {
|
||||||
|
for (const auto & o : opts->options()) {
|
||||||
|
const std::string prefix = "device:";
|
||||||
|
if (o.rfind(prefix, 0) == 0) {
|
||||||
|
return o.substr(prefix.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (opts->ngpulayers() > 0) {
|
||||||
|
return "gpu";
|
||||||
|
}
|
||||||
|
return "cpu";
|
||||||
|
}
|
||||||
|
|
||||||
|
class PrivacyFilterBackend final : public backend::Backend::Service {
|
||||||
|
public:
|
||||||
|
GStatus Health(ServerContext *, const backend::HealthMessage *,
|
||||||
|
backend::Reply * reply) override {
|
||||||
|
reply->set_message("OK");
|
||||||
|
return GStatus::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
GStatus Status(ServerContext *, const backend::HealthMessage *,
|
||||||
|
backend::StatusResponse * response) override {
|
||||||
|
std::lock_guard<std::mutex> lock(g_mu);
|
||||||
|
response->set_state(g_ctx ? backend::StatusResponse::READY
|
||||||
|
: backend::StatusResponse::UNINITIALIZED);
|
||||||
|
return GStatus::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
GStatus LoadModel(ServerContext *, const backend::ModelOptions * request,
|
||||||
|
backend::Result * result) override {
|
||||||
|
std::lock_guard<std::mutex> lock(g_mu);
|
||||||
|
|
||||||
|
// ModelFile is the absolute path LocalAI resolves; Model is the bare
|
||||||
|
// name. Prefer the former, fall back to the latter.
|
||||||
|
const std::string path =
|
||||||
|
!request->modelfile().empty() ? request->modelfile() : request->model();
|
||||||
|
if (path.empty()) {
|
||||||
|
result->set_success(false);
|
||||||
|
result->set_message("no model path supplied");
|
||||||
|
return GStatus::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string device = resolve_device(request);
|
||||||
|
|
||||||
|
if (g_ctx) { pf_free(g_ctx); g_ctx = nullptr; }
|
||||||
|
|
||||||
|
pf_ctx * ctx = pf_load(path.c_str(), device.c_str(), request->threads());
|
||||||
|
const char * err = pf_last_error(ctx);
|
||||||
|
if (err) {
|
||||||
|
result->set_success(false);
|
||||||
|
result->set_message(std::string("privacy-filter load failed: ") + err);
|
||||||
|
pf_free(ctx);
|
||||||
|
return GStatus::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContextSize, when set, becomes the per-forward window. The engine
|
||||||
|
// ignores values that are too small to window (<= 2*halo) and just
|
||||||
|
// runs a single forward, so passing it through is always safe.
|
||||||
|
if (request->contextsize() > 0) {
|
||||||
|
pf_set_window(ctx, request->contextsize());
|
||||||
|
}
|
||||||
|
|
||||||
|
g_ctx = ctx;
|
||||||
|
result->set_success(true);
|
||||||
|
result->set_message("privacy-filter loaded (" + device + ")");
|
||||||
|
return GStatus::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
GStatus TokenClassify(ServerContext *, const backend::TokenClassifyRequest * request,
|
||||||
|
backend::TokenClassifyResponse * response) override {
|
||||||
|
std::lock_guard<std::mutex> lock(g_mu);
|
||||||
|
if (!g_ctx) {
|
||||||
|
return GStatus(StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string & text = request->text();
|
||||||
|
if (text.empty()) {
|
||||||
|
return GStatus::OK; // no text -> no entities
|
||||||
|
}
|
||||||
|
|
||||||
|
pf_entity * ents = nullptr;
|
||||||
|
size_t n = 0;
|
||||||
|
if (pf_classify(g_ctx, text.data(), text.size(), request->threshold(), &ents, &n) != 0) {
|
||||||
|
const char * err = pf_last_error(g_ctx);
|
||||||
|
return GStatus(StatusCode::INTERNAL,
|
||||||
|
std::string("TokenClassify failed: ") + (err ? err : "unknown"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Byte offsets are into the original UTF-8 text; the engine already
|
||||||
|
// applied the threshold and whitespace-trimmed span edges.
|
||||||
|
for (size_t i = 0; i < n; i++) {
|
||||||
|
backend::TokenClassifyEntity * ent = response->add_entities();
|
||||||
|
ent->set_entity_group(ents[i].label ? ents[i].label : "");
|
||||||
|
ent->set_start(ents[i].start);
|
||||||
|
ent->set_end(ents[i].end);
|
||||||
|
ent->set_score(ents[i].score);
|
||||||
|
ent->set_text(text.substr((size_t) ents[i].start,
|
||||||
|
(size_t) (ents[i].end - ents[i].start)));
|
||||||
|
}
|
||||||
|
pf_entities_free(ents, n);
|
||||||
|
return GStatus::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
GStatus Free(ServerContext *, const backend::HealthMessage *,
|
||||||
|
backend::Result * result) override {
|
||||||
|
std::lock_guard<std::mutex> lock(g_mu);
|
||||||
|
if (g_ctx) { pf_free(g_ctx); g_ctx = nullptr; }
|
||||||
|
result->set_success(true);
|
||||||
|
return GStatus::OK;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void RunServer(const std::string & addr) {
|
||||||
|
PrivacyFilterBackend service;
|
||||||
|
grpc::EnableDefaultHealthCheckService(true);
|
||||||
|
grpc::reflection::InitProtoReflectionServerBuilderPlugin();
|
||||||
|
|
||||||
|
ServerBuilder builder;
|
||||||
|
builder.AddListeningPort(addr, grpc::InsecureServerCredentials());
|
||||||
|
builder.RegisterService(&service);
|
||||||
|
builder.SetMaxReceiveMessageSize(64 * 1024 * 1024);
|
||||||
|
builder.SetMaxSendMessageSize(64 * 1024 * 1024);
|
||||||
|
|
||||||
|
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||||
|
if (!server) {
|
||||||
|
std::cerr << "privacy-filter grpc-server: failed to bind " << addr << "\n";
|
||||||
|
std::exit(1);
|
||||||
|
}
|
||||||
|
g_server = server.get();
|
||||||
|
std::cerr << "privacy-filter grpc-server listening on " << addr << "\n";
|
||||||
|
server->Wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
void signal_handler(int) {
|
||||||
|
if (auto * srv = g_server.load()) {
|
||||||
|
srv->Shutdown(std::chrono::system_clock::now() + std::chrono::seconds(3));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
int main(int argc, char * argv[]) {
|
||||||
|
std::string addr = "127.0.0.1:50051";
|
||||||
|
for (int i = 1; i < argc; ++i) {
|
||||||
|
std::string a = argv[i];
|
||||||
|
const std::string addr_flag = "--addr=";
|
||||||
|
if (a.rfind(addr_flag, 0) == 0) addr = a.substr(addr_flag.size());
|
||||||
|
else if (a == "--addr" && i + 1 < argc) addr = argv[++i];
|
||||||
|
else if (a == "--help" || a == "-h") {
|
||||||
|
std::cout << "Usage: grpc-server --addr=HOST:PORT\n";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::signal(SIGINT, signal_handler);
|
||||||
|
std::signal(SIGTERM, signal_handler);
|
||||||
|
RunServer(addr);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
39
backend/cpp/privacy-filter/package.sh
Executable file
39
backend/cpp/privacy-filter/package.sh
Executable file
@@ -0,0 +1,39 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Assemble package/ for the from-scratch backend image: the grpc-server binary,
|
||||||
|
# run.sh, the dynamic loader, and every shared library the binary needs.
|
||||||
|
set -e
|
||||||
|
CURDIR=$(dirname "$(realpath "$0")")
|
||||||
|
REPO_ROOT="${CURDIR}/../../.."
|
||||||
|
|
||||||
|
mkdir -p "$CURDIR/package/lib"
|
||||||
|
cp -avf "$CURDIR/grpc-server" "$CURDIR/package/"
|
||||||
|
cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
|
||||||
|
|
||||||
|
# The dynamic loader, renamed to lib/ld.so so run.sh can invoke it explicitly
|
||||||
|
# (makes the image independent of the host's glibc layout).
|
||||||
|
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||||
|
cp -arfLv /lib64/ld-linux-x86-64.so.2 "$CURDIR/package/lib/ld.so"
|
||||||
|
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||||
|
cp -arfLv /lib/ld-linux-aarch64.so.1 "$CURDIR/package/lib/ld.so"
|
||||||
|
else
|
||||||
|
echo "package.sh: unknown architecture" >&2; exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Bundle the binary's transitive shared deps (libstdc++, libgomp, and the apt
|
||||||
|
# grpc++/protobuf/absl stack) by walking ldd — robust to whichever of those are
|
||||||
|
# linked shared vs static. The loader line (no "=>") is skipped; ld.so above
|
||||||
|
# already covers it.
|
||||||
|
ldd "$CURDIR/grpc-server" | awk '$2 == "=>" && $3 ~ /^\// { print $3 }' | sort -u | \
|
||||||
|
while read -r so; do
|
||||||
|
[ -f "$so" ] && cp -arfLv "$so" "$CURDIR/package/lib/"
|
||||||
|
done
|
||||||
|
|
||||||
|
# Vulkan loader / GPU libs when building the GPU variant.
|
||||||
|
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||||
|
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||||
|
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||||
|
package_gpu_libs
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "privacy-filter package contents:"
|
||||||
|
ls -lah "$CURDIR/package/" "$CURDIR/package/lib/"
|
||||||
9
backend/cpp/privacy-filter/run.sh
Executable file
9
backend/cpp/privacy-filter/run.sh
Executable file
@@ -0,0 +1,9 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Entry point for the privacy-filter backend image / BACKEND_BINARY mode.
|
||||||
|
set -e
|
||||||
|
CURDIR=$(dirname "$(realpath "$0")")
|
||||||
|
export LD_LIBRARY_PATH="$CURDIR/lib:$LD_LIBRARY_PATH"
|
||||||
|
if [ -f "$CURDIR/lib/ld.so" ]; then
|
||||||
|
exec "$CURDIR/lib/ld.so" "$CURDIR/grpc-server" "$@"
|
||||||
|
fi
|
||||||
|
exec "$CURDIR/grpc-server" "$@"
|
||||||
@@ -2,9 +2,10 @@
|
|||||||
sources/
|
sources/
|
||||||
build/
|
build/
|
||||||
package/
|
package/
|
||||||
dllm-grpc
|
ced-grpc
|
||||||
# build artifacts staged in-tree by the Makefile (cp from sources/) or
|
# build artifacts staged in-tree by the Makefile (cp from sources/) or
|
||||||
# symlinked for local dev; the real sources live in dllm.cpp upstream.
|
# symlinked for local dev; the real sources live in ced.cpp upstream.
|
||||||
*.so
|
*.so
|
||||||
*.so.*
|
*.so.*
|
||||||
|
ced_capi.h
|
||||||
compile_commands.json
|
compile_commands.json
|
||||||
77
backend/go/ced/Makefile
Normal file
77
backend/go/ced/Makefile
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# ced sound-classification backend Makefile.
|
||||||
|
#
|
||||||
|
# Upstream pin lives below as CED_VERSION?=<sha> so .github/bump_deps.sh can find
|
||||||
|
# and update it (matches the parakeet-cpp / whisper.cpp convention).
|
||||||
|
#
|
||||||
|
# Local dev shortcut: symlink an out-of-tree ced.cpp shared build + header and
|
||||||
|
# skip the clone/cmake steps entirely:
|
||||||
|
# ln -sf /path/to/ced.cpp/build-shared/libced.so .
|
||||||
|
# ln -sf /path/to/ced.cpp/include/ced_capi.h .
|
||||||
|
# go build -o ced-grpc .
|
||||||
|
|
||||||
|
CED_VERSION?=c04ac14b7992d00584d9e812c9bb6268598a6ce7
|
||||||
|
CED_REPO?=https://github.com/mudler/ced.cpp
|
||||||
|
|
||||||
|
GOCMD?=go
|
||||||
|
GO_TAGS?=
|
||||||
|
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
||||||
|
|
||||||
|
BUILD_TYPE?=
|
||||||
|
NATIVE?=false
|
||||||
|
|
||||||
|
# Static-link ggml into libced.so (PIC) so the shared lib is self-contained:
|
||||||
|
# dlopen needs no libggml*.so alongside it, only system libs the runtime image
|
||||||
|
# already provides.
|
||||||
|
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DCED_SHARED=ON -DCED_BUILD_CLI=OFF -DCED_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||||
|
|
||||||
|
ifeq ($(NATIVE),false)
|
||||||
|
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||||
|
endif
|
||||||
|
|
||||||
|
# ced.cpp gates its ggml backends behind CED_GGML_* options (set(... CACHE BOOL
|
||||||
|
# "" FORCE)), so forward those instead of a bare -DGGML_CUDA=ON.
|
||||||
|
ifeq ($(BUILD_TYPE),cublas)
|
||||||
|
CMAKE_ARGS+=-DCED_GGML_CUDA=ON -DGGML_CUDA_GRAPHS=ON
|
||||||
|
else ifeq ($(BUILD_TYPE),openblas)
|
||||||
|
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||||
|
else ifeq ($(BUILD_TYPE),hipblas)
|
||||||
|
CMAKE_ARGS+=-DCED_GGML_HIP=ON
|
||||||
|
else ifeq ($(BUILD_TYPE),vulkan)
|
||||||
|
CMAKE_ARGS+=-DCED_GGML_VULKAN=ON
|
||||||
|
endif
|
||||||
|
|
||||||
|
.PHONY: ced-grpc package build clean purge test all
|
||||||
|
|
||||||
|
all: ced-grpc
|
||||||
|
|
||||||
|
sources/ced.cpp:
|
||||||
|
mkdir -p sources/ced.cpp
|
||||||
|
cd sources/ced.cpp && \
|
||||||
|
git init -q && \
|
||||||
|
git remote add origin $(CED_REPO) && \
|
||||||
|
git fetch --depth 1 origin $(CED_VERSION) && \
|
||||||
|
git checkout FETCH_HEAD && \
|
||||||
|
git submodule update --init --recursive --depth 1 --single-branch
|
||||||
|
|
||||||
|
libced.so: sources/ced.cpp
|
||||||
|
cmake -B sources/ced.cpp/build-shared -S sources/ced.cpp $(CMAKE_ARGS)
|
||||||
|
cmake --build sources/ced.cpp/build-shared --config Release -j$(JOBS)
|
||||||
|
cp -fv sources/ced.cpp/build-shared/libced.so* ./ 2>/dev/null || true
|
||||||
|
cp -fv sources/ced.cpp/include/ced_capi.h ./
|
||||||
|
|
||||||
|
ced-grpc: libced.so main.go goced.go
|
||||||
|
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o ced-grpc .
|
||||||
|
|
||||||
|
package: ced-grpc
|
||||||
|
bash package.sh
|
||||||
|
|
||||||
|
build: package
|
||||||
|
|
||||||
|
test:
|
||||||
|
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
|
||||||
|
|
||||||
|
clean: purge
|
||||||
|
rm -rf libced.so* ced_capi.h package ced-grpc
|
||||||
|
|
||||||
|
purge:
|
||||||
|
rm -rf sources/ced.cpp
|
||||||
130
backend/go/ced/goced.go
Normal file
130
backend/go/ced/goced.go
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// Go side of the ced backend: purego bindings over ced_capi.h plus the gRPC
|
||||||
|
// SoundDetection implementation.
|
||||||
|
//
|
||||||
|
// SKETCH: the pb.SoundDetection* types come from backend.proto (regenerate with
|
||||||
|
// `make protogen-go`). The C side is single-threaded per ctx, so we guard the
|
||||||
|
// engine with engineMu; LocalAI also serializes via base.SingleThread.
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// purego-bound entry points from libced.so. Names match ced_capi.h exactly.
|
||||||
|
var (
|
||||||
|
CppAbiVersion func() int32
|
||||||
|
CppLoad func(ggufPath string) uintptr
|
||||||
|
CppFree func(ctx uintptr)
|
||||||
|
CppLastError func(ctx uintptr) string
|
||||||
|
CppNumClasses func(ctx uintptr) int32
|
||||||
|
CppSampleRate func(ctx uintptr) int32
|
||||||
|
CppClassifyPathJSON func(ctx uintptr, wavPath string, topK int32) uintptr
|
||||||
|
CppClassifyPcmJSON func(ctx uintptr, pcm []float32, nSamples int32, sampleRate int32, topK int32) uintptr
|
||||||
|
CppFreeString func(s uintptr)
|
||||||
|
)
|
||||||
|
|
||||||
|
// cstr copies a malloc'd C string (returned as uintptr) into a Go string and
|
||||||
|
// frees the original via ced_capi_free_string. Empty/0 -> "".
|
||||||
|
func cstr(p uintptr) string {
|
||||||
|
if p == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer CppFreeString(p)
|
||||||
|
var b []byte
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
ch := *(*byte)(unsafe.Pointer(p + uintptr(i))) //nolint:govet // #nosec G103 -- C-owned NUL-terminated string from libced (not Go-GC memory)
|
||||||
|
if ch == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
b = append(b, ch)
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ced is the gRPC backend. One loaded CED model per instance.
|
||||||
|
type Ced struct {
|
||||||
|
base.Base
|
||||||
|
ctxPtr uintptr
|
||||||
|
engineMu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load resolves the GGUF and opens the C-API context.
|
||||||
|
func (c *Ced) Load(opts *pb.ModelOptions) error {
|
||||||
|
if opts.ModelFile == "" {
|
||||||
|
return errors.New("ced: ModelFile is required")
|
||||||
|
}
|
||||||
|
ctx := CppLoad(opts.ModelFile)
|
||||||
|
if ctx == 0 {
|
||||||
|
return fmt.Errorf("ced: ced_capi_load failed for %q: %s", opts.ModelFile, CppLastError(0))
|
||||||
|
}
|
||||||
|
c.ctxPtr = ctx
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// jsonTag mirrors the ced_capi JSON tag objects.
|
||||||
|
type jsonTag struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Score float32 `json:"score"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoundDetection classifies the clip at req.Src and returns scored AudioSet tags.
|
||||||
|
func (c *Ced) SoundDetection(ctx context.Context, req *pb.SoundDetectionRequest) (*pb.SoundDetectionResponse, error) {
|
||||||
|
if c.ctxPtr == 0 {
|
||||||
|
return nil, errors.New("ced: model not loaded")
|
||||||
|
}
|
||||||
|
if req.GetSrc() == "" {
|
||||||
|
return nil, errors.New("ced: SoundDetectionRequest.src (audio path) is required")
|
||||||
|
}
|
||||||
|
topK := req.GetTopK()
|
||||||
|
if topK <= 0 {
|
||||||
|
topK = 10 // sensible default for a tagging response
|
||||||
|
}
|
||||||
|
|
||||||
|
c.engineMu.Lock()
|
||||||
|
out := cstr(CppClassifyPathJSON(c.ctxPtr, req.GetSrc(), topK))
|
||||||
|
lastErr := CppLastError(c.ctxPtr)
|
||||||
|
c.engineMu.Unlock()
|
||||||
|
|
||||||
|
if out == "" {
|
||||||
|
return nil, fmt.Errorf("ced: classification failed: %s", lastErr)
|
||||||
|
}
|
||||||
|
var tags []jsonTag
|
||||||
|
if err := json.Unmarshal([]byte(out), &tags); err != nil {
|
||||||
|
return nil, fmt.Errorf("ced: bad classifier JSON: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
thr := req.GetThreshold()
|
||||||
|
resp := &pb.SoundDetectionResponse{}
|
||||||
|
for _, t := range tags {
|
||||||
|
if t.Score < thr {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resp.Detections = append(resp.Detections, &pb.SoundClass{
|
||||||
|
Label: t.Label, Score: t.Score, Index: int32(t.Index),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
sort.Slice(resp.Detections, func(i, j int) bool {
|
||||||
|
return resp.Detections[i].Score > resp.Detections[j].Score
|
||||||
|
})
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Ced) Free() error {
|
||||||
|
c.engineMu.Lock()
|
||||||
|
defer c.engineMu.Unlock()
|
||||||
|
if c.ctxPtr != 0 {
|
||||||
|
CppFree(c.ctxPtr)
|
||||||
|
c.ctxPtr = 0
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
59
backend/go/ced/main.go
Normal file
59
backend/go/ced/main.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// ced sound-classification backend. Started internally by LocalAI: one gRPC
|
||||||
|
// server per loaded model. Loads libced.so via purego and registers the flat
|
||||||
|
// C-API declared in ced_capi.h. The library name can be overridden with
|
||||||
|
// CED_LIBRARY (mirrors PARAKEET_LIBRARY / WHISPER_LIBRARY); the default looks
|
||||||
|
// for the .so next to this binary.
|
||||||
|
//
|
||||||
|
// SKETCH: requires `make protogen-go` after the backend.proto SoundDetection
|
||||||
|
// addition, and a built libced.so (see Makefile). See DESIGN.md.
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
var addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||||
|
|
||||||
|
type libFunc struct {
|
||||||
|
ptr any
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
libName := os.Getenv("CED_LIBRARY")
|
||||||
|
if libName == "" {
|
||||||
|
libName = "libced.so"
|
||||||
|
}
|
||||||
|
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("ced: dlopen %q: %w", libName, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bound 1:1 to ced_capi.h. char*-returning functions are declared uintptr
|
||||||
|
// so we can free the same pointer with ced_capi_free_string after copying
|
||||||
|
// (purego's string return would copy and leak the original).
|
||||||
|
for _, lf := range []libFunc{
|
||||||
|
{&CppAbiVersion, "ced_capi_abi_version"},
|
||||||
|
{&CppLoad, "ced_capi_load"},
|
||||||
|
{&CppFree, "ced_capi_free"},
|
||||||
|
{&CppLastError, "ced_capi_last_error"},
|
||||||
|
{&CppNumClasses, "ced_capi_num_classes"},
|
||||||
|
{&CppSampleRate, "ced_capi_sample_rate"},
|
||||||
|
{&CppClassifyPathJSON, "ced_capi_classify_path_json"},
|
||||||
|
{&CppClassifyPcmJSON, "ced_capi_classify_pcm_json"},
|
||||||
|
{&CppFreeString, "ced_capi_free_string"},
|
||||||
|
} {
|
||||||
|
purego.RegisterLibFunc(lf.ptr, lib, lf.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "[ced] ABI=%d\n", CppAbiVersion())
|
||||||
|
flag.Parse()
|
||||||
|
if err := grpc.StartServer(*addr, &Ced{}); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
60
backend/go/ced/package.sh
Executable file
60
backend/go/ced/package.sh
Executable file
@@ -0,0 +1,60 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Bundle the ced-grpc binary, libced.so, the core runtime libs (libc/libstdc++/
|
||||||
|
# libgomp + ld.so) and the GPU runtime for the active BUILD_TYPE so the package
|
||||||
|
# is self-contained. Mirrors backend/go/parakeet-cpp/package.sh; run.sh routes
|
||||||
|
# the (CGO_ENABLED=0) binary through lib/ld.so so the packaged libc is used.
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath "$0")")
|
||||||
|
REPO_ROOT="${CURDIR}/../../.."
|
||||||
|
|
||||||
|
mkdir -p "$CURDIR/package/lib"
|
||||||
|
|
||||||
|
cp -avf "$CURDIR/ced-grpc" "$CURDIR/package/"
|
||||||
|
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
|
||||||
|
|
||||||
|
cp -avf "$CURDIR"/libced.so* "$CURDIR/package/lib/" 2>/dev/null || {
|
||||||
|
echo "ERROR: libced.so not found in $CURDIR, run 'make' first" >&2
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||||
|
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||||
|
cp -arfLv /lib64/ld-linux-x86-64.so.2 "$CURDIR/package/lib/ld.so"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 "$CURDIR/package/lib/libc.so.6"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 "$CURDIR/package/lib/libgcc_s.so.1"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 "$CURDIR/package/lib/libstdc++.so.6"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 "$CURDIR/package/lib/libm.so.6"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 "$CURDIR/package/lib/libgomp.so.1"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 "$CURDIR/package/lib/libdl.so.2"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 "$CURDIR/package/lib/librt.so.1"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 "$CURDIR/package/lib/libpthread.so.0"
|
||||||
|
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||||
|
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||||
|
cp -arfLv /lib/ld-linux-aarch64.so.1 "$CURDIR/package/lib/ld.so"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 "$CURDIR/package/lib/libc.so.6"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 "$CURDIR/package/lib/libgcc_s.so.1"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 "$CURDIR/package/lib/libstdc++.so.6"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 "$CURDIR/package/lib/libm.so.6"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 "$CURDIR/package/lib/libgomp.so.1"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 "$CURDIR/package/lib/libdl.so.2"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 "$CURDIR/package/lib/librt.so.1"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 "$CURDIR/package/lib/libpthread.so.0"
|
||||||
|
elif [ "$(uname -s)" = "Darwin" ]; then
|
||||||
|
echo "Detected Darwin"
|
||||||
|
else
|
||||||
|
echo "Error: Could not detect architecture"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||||
|
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||||
|
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||||
|
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||||
|
package_gpu_libs
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Packaging completed successfully"
|
||||||
|
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"
|
||||||
@@ -5,12 +5,11 @@ CURDIR=$(dirname "$(realpath "$0")")
|
|||||||
|
|
||||||
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
|
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
|
||||||
|
|
||||||
# If a self-contained ld.so was packaged, route through it so the
|
# If a self-contained ld.so was packaged, route through it so the packaged
|
||||||
# packaged libc / libstdc++ are used instead of the host's (matches the
|
# libc / libstdc++ are used instead of the host's (matches the sibling backends).
|
||||||
# whisper / parakeet-cpp backends' runtime layout).
|
|
||||||
if [ -f "$CURDIR/lib/ld.so" ]; then
|
if [ -f "$CURDIR/lib/ld.so" ]; then
|
||||||
echo "Using lib/ld.so"
|
echo "Using lib/ld.so"
|
||||||
exec "$CURDIR/lib/ld.so" "$CURDIR/dllm-grpc" "$@"
|
exec "$CURDIR/lib/ld.so" "$CURDIR/ced-grpc" "$@"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
exec "$CURDIR/dllm-grpc" "$@"
|
exec "$CURDIR/ced-grpc" "$@"
|
||||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
|||||||
|
|
||||||
# CrispASR version (release tag)
|
# CrispASR version (release tag)
|
||||||
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
||||||
CRISPASR_VERSION?=c29f6653a516a3001d923944dad8892072cc7334
|
CRISPASR_VERSION?=d745bda4386ae0f9d1d2f23fff8ec95d76428221
|
||||||
SO_TARGET?=libgocrispasr.so
|
SO_TARGET?=libgocrispasr.so
|
||||||
|
|
||||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||||
@@ -67,7 +67,7 @@ sources/CrispASR:
|
|||||||
# it, so ${CMAKE_SOURCE_DIR} is THIS backend dir and the talk-llama sources
|
# it, so ${CMAKE_SOURCE_DIR} is THIS backend dir and the talk-llama sources
|
||||||
# aren't found. Rewrite to ${PROJECT_SOURCE_DIR} (the crispasr project root),
|
# aren't found. Rewrite to ${PROJECT_SOURCE_DIR} (the crispasr project root),
|
||||||
# which is correct both standalone and as a subproject. Idempotent.
|
# which is correct both standalone and as a subproject. Idempotent.
|
||||||
sed -i 's#\$${CMAKE_SOURCE_DIR}/examples/talk-llama#\$${PROJECT_SOURCE_DIR}/examples/talk-llama#' sources/CrispASR/src/CMakeLists.txt
|
sed -i.bak 's#\$${CMAKE_SOURCE_DIR}/examples/talk-llama#\$${PROJECT_SOURCE_DIR}/examples/talk-llama#' sources/CrispASR/src/CMakeLists.txt && rm -f sources/CrispASR/src/CMakeLists.txt.bak
|
||||||
|
|
||||||
# Detect OS
|
# Detect OS
|
||||||
UNAME_S := $(shell uname -s)
|
UNAME_S := $(shell uname -s)
|
||||||
|
|||||||
@@ -47,6 +47,74 @@ extern "C" void set_abort(int v) {
|
|||||||
g_abort.store(v, std::memory_order_relaxed);
|
g_abort.store(v, std::memory_order_relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- word-level timestamp accessors ---
|
||||||
|
extern "C" {
|
||||||
|
int crispasr_session_result_n_words(crispasr_session_result *r, int seg_i);
|
||||||
|
const char *crispasr_session_result_word_text(crispasr_session_result *r,
|
||||||
|
int seg_i, int word_i);
|
||||||
|
int64_t crispasr_session_result_word_t0(crispasr_session_result *r, int seg_i,
|
||||||
|
int word_i);
|
||||||
|
int64_t crispasr_session_result_word_t1(crispasr_session_result *r, int seg_i,
|
||||||
|
int word_i);
|
||||||
|
|
||||||
|
// Parakeet-specific word accessors
|
||||||
|
int crispasr_parakeet_result_n_words(void *r);
|
||||||
|
const char *crispasr_parakeet_result_word_text(void *r, int word_i);
|
||||||
|
int64_t crispasr_parakeet_result_word_t0(void *r, int word_i);
|
||||||
|
int64_t crispasr_parakeet_result_word_t1(void *r, int word_i);
|
||||||
|
}
|
||||||
|
|
||||||
|
void *get_result(void) { return g_result; }
|
||||||
|
|
||||||
|
int get_word_count(int seg_i) {
|
||||||
|
if (!g_result)
|
||||||
|
return 0;
|
||||||
|
return crispasr_session_result_n_words(g_result, seg_i);
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *get_word_text(int seg_i, int word_i) {
|
||||||
|
if (!g_result)
|
||||||
|
return "";
|
||||||
|
return crispasr_session_result_word_text(g_result, seg_i, word_i);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t get_word_t0(int seg_i, int word_i) {
|
||||||
|
if (!g_result)
|
||||||
|
return 0;
|
||||||
|
return crispasr_session_result_word_t0(g_result, seg_i, word_i);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t get_word_t1(int seg_i, int word_i) {
|
||||||
|
if (!g_result)
|
||||||
|
return 0;
|
||||||
|
return crispasr_session_result_word_t1(g_result, seg_i, word_i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parakeet-specific word accessors
|
||||||
|
int get_parakeet_word_count(void) {
|
||||||
|
if (!g_result)
|
||||||
|
return 0;
|
||||||
|
return crispasr_parakeet_result_n_words(g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *get_parakeet_word_text(int word_i) {
|
||||||
|
if (!g_result)
|
||||||
|
return "";
|
||||||
|
return crispasr_parakeet_result_word_text(g_result, word_i);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t get_parakeet_word_t0(int word_i) {
|
||||||
|
if (!g_result)
|
||||||
|
return 0;
|
||||||
|
return crispasr_parakeet_result_word_t0(g_result, word_i);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t get_parakeet_word_t1(int word_i) {
|
||||||
|
if (!g_result)
|
||||||
|
return 0;
|
||||||
|
return crispasr_parakeet_result_word_t1(g_result, word_i);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_log_cb(enum ggml_log_level level, const char *log,
|
static void ggml_log_cb(enum ggml_log_level level, const char *log,
|
||||||
void *data) {
|
void *data) {
|
||||||
const char *level_str;
|
const char *level_str;
|
||||||
|
|||||||
@@ -20,4 +20,18 @@ float *tts_synthesize(const char *text, int *out_n_samples); // 24kHz mono float
|
|||||||
void tts_free(float *pcm);
|
void tts_free(float *pcm);
|
||||||
int tts_set_voice(const char *name); // best-effort speaker selection; 0 ok
|
int tts_set_voice(const char *name); // best-effort speaker selection; 0 ok
|
||||||
int tts_set_voice_file(const char *path, const char *ref_text); // load voice pack (.gguf) or zero-shot clone (.wav + ref_text)
|
int tts_set_voice_file(const char *path, const char *ref_text); // load voice pack (.gguf) or zero-shot clone (.wav + ref_text)
|
||||||
|
|
||||||
|
// --- word-level timestamp accessors ---
|
||||||
|
// Session-based (works for whisper-like backends)
|
||||||
|
void *get_result(void);
|
||||||
|
int get_word_count(int seg_i);
|
||||||
|
const char *get_word_text(int seg_i, int word_i);
|
||||||
|
int64_t get_word_t0(int seg_i, int word_i);
|
||||||
|
int64_t get_word_t1(int seg_i, int word_i);
|
||||||
|
|
||||||
|
// Parakeet-specific (global word list, no segment index)
|
||||||
|
int get_parakeet_word_count(void);
|
||||||
|
const char *get_parakeet_word_text(int word_i);
|
||||||
|
int64_t get_parakeet_word_t0(int word_i);
|
||||||
|
int64_t get_parakeet_word_t1(int word_i);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/go-audio/audio"
|
"github.com/go-audio/audio"
|
||||||
"github.com/go-audio/wav"
|
"github.com/go-audio/wav"
|
||||||
|
gguf "github.com/gpustack/gguf-parser-go"
|
||||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
"github.com/mudler/LocalAI/pkg/utils"
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
@@ -33,10 +34,55 @@ var (
|
|||||||
CppTTSFree func(ptr uintptr)
|
CppTTSFree func(ptr uintptr)
|
||||||
CppTTSSetVoice func(name string) int
|
CppTTSSetVoice func(name string) int
|
||||||
CppTTSSetVoiceFile func(path string, refText string) int
|
CppTTSSetVoiceFile func(path string, refText string) int
|
||||||
|
|
||||||
|
// Word-level timestamp accessors (session-based, per-segment)
|
||||||
|
CppGetWordCount func(segI int) int
|
||||||
|
CppGetWordText func(segI int, wordI int) string
|
||||||
|
CppGetWordT0 func(segI int, wordI int) int64
|
||||||
|
CppGetWordT1 func(segI int, wordI int) int64
|
||||||
|
|
||||||
|
// Parakeet-specific word accessors (global, no segment index)
|
||||||
|
CppGetParakeetWordCount func() int
|
||||||
|
CppGetParakeetWordText func(wordI int) string
|
||||||
|
CppGetParakeetWordT0 func(wordI int) int64
|
||||||
|
CppGetParakeetWordT1 func(wordI int) int64
|
||||||
)
|
)
|
||||||
|
|
||||||
type CrispASR struct {
|
type CrispASR struct {
|
||||||
base.SingleThread
|
base.SingleThread
|
||||||
|
// sampleRate is the output rate (Hz) of the loaded TTS engine's PCM, used to
|
||||||
|
// write a correct WAV header. Most CrispASR TTS backends emit 24 kHz, but
|
||||||
|
// piper returns its model's native rate (16 kHz for x_low/low voices,
|
||||||
|
// 22.05 kHz for medium/high), so it is read from the GGUF metadata at Load.
|
||||||
|
sampleRate int
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultTTSSampleRate is the output rate assumed for CrispASR TTS engines that
|
||||||
|
// don't advertise one in GGUF metadata (vibevoice/orpheus/chatterbox/qwen3-tts
|
||||||
|
// all emit 24 kHz). piper is the exception and carries piper.sample_rate.
|
||||||
|
const defaultTTSSampleRate = 24000
|
||||||
|
|
||||||
|
// piperSampleRate reads the piper.sample_rate metadata key from a GGUF model.
|
||||||
|
// CrispASR's piper backend returns PCM at the model's native rate without
|
||||||
|
// resampling, so the WAV header must match it. Returns ok=false for non-piper
|
||||||
|
// models (key absent) or an unreadable file, letting the caller fall back to
|
||||||
|
// defaultTTSSampleRate.
|
||||||
|
func piperSampleRate(modelPath string) (int, bool) {
|
||||||
|
// Only scalar architecture keys are read, so skip the large array metadata
|
||||||
|
// (phoneme map) and mmap the header - same rationale as pkg/vram's reader.
|
||||||
|
f, err := gguf.ParseGGUFFile(modelPath, gguf.UseMMap(), gguf.SkipLargeMetadata())
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
kv, ok := f.Header.MetadataKV.Get("piper.sample_rate")
|
||||||
|
if !ok || kv.ValueType != gguf.GGUFMetadataValueTypeUint32 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
rate := int(kv.ValueUint32())
|
||||||
|
if rate <= 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return rate, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// splitOption splits a "prefix:value" model option into its key and value,
|
// splitOption splits a "prefix:value" model option into its key and value,
|
||||||
@@ -103,6 +149,14 @@ func (w *CrispASR) Load(opts *pb.ModelOptions) error {
|
|||||||
return fmt.Errorf("Failed to load CrispASR transcription model")
|
return fmt.Errorf("Failed to load CrispASR transcription model")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine the TTS output sample rate for the WAV header. piper voices
|
||||||
|
// carry their native rate in GGUF metadata and CrispASR does not resample;
|
||||||
|
// every other engine emits the 24 kHz default.
|
||||||
|
w.sampleRate = defaultTTSSampleRate
|
||||||
|
if rate, ok := piperSampleRate(opts.ModelFile); ok {
|
||||||
|
w.sampleRate = rate
|
||||||
|
}
|
||||||
|
|
||||||
// Load the companion file (codec/tokenizer/s3gen) after the session is open.
|
// Load the companion file (codec/tokenizer/s3gen) after the session is open.
|
||||||
// rc==0 means success or "not applicable" for the active backend; only a
|
// rc==0 means success or "not applicable" for the active backend; only a
|
||||||
// negative code is fatal.
|
// negative code is fatal.
|
||||||
@@ -170,6 +224,28 @@ func (w *CrispASR) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isValidWord reports whether a TranscriptWord contains recognisable speech
|
||||||
|
// content. The parakeet-specific word accessors can return stale initialisation
|
||||||
|
// data (model name, binary blobs) when a segment has no real speech. A word is
|
||||||
|
// considered valid only when:
|
||||||
|
// - the text is non-empty after trimming,
|
||||||
|
// - it contains no U+FFFD replacement characters (from binary data scrubbing),
|
||||||
|
// - both timestamps are non-negative,
|
||||||
|
// - the word has positive duration (end > start).
|
||||||
|
func isValidWord(w *pb.TranscriptWord) bool {
|
||||||
|
txt := strings.TrimSpace(w.Text)
|
||||||
|
if txt == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.ContainsRune(txt, '\uFFFD') {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if w.Start < 0 || w.End < 0 || w.End <= w.Start {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (w *CrispASR) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
func (w *CrispASR) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||||
if err := ctx.Err(); err != nil {
|
if err := ctx.Err(); err != nil {
|
||||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||||
@@ -248,15 +324,54 @@ func (w *CrispASR) AudioTranscription(ctx context.Context, opts *pb.TranscriptRe
|
|||||||
// IDs, so Tokens is left empty.
|
// IDs, so Tokens is left empty.
|
||||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
||||||
|
|
||||||
|
// Populate word-level timestamps. Try session-based functions first
|
||||||
|
// (per-segment); fall back to parakeet-specific functions (global word
|
||||||
|
// list with no segment index — only populated on the first segment to
|
||||||
|
// avoid duplication).
|
||||||
|
words := []*pb.TranscriptWord{}
|
||||||
|
wordCount := CppGetWordCount(i)
|
||||||
|
if wordCount == 0 && i == 0 {
|
||||||
|
wordCount = CppGetParakeetWordCount()
|
||||||
|
for j := 0; j < wordCount; j++ {
|
||||||
|
w := &pb.TranscriptWord{
|
||||||
|
Start: CppGetParakeetWordT0(j) * (10000000),
|
||||||
|
End: CppGetParakeetWordT1(j) * (10000000),
|
||||||
|
Text: strings.ToValidUTF8(strings.Clone(CppGetParakeetWordText(j)), "<22>"),
|
||||||
|
}
|
||||||
|
if isValidWord(w) {
|
||||||
|
words = append(words, w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for j := 0; j < wordCount; j++ {
|
||||||
|
w := &pb.TranscriptWord{
|
||||||
|
Start: CppGetWordT0(i, j) * (10000000),
|
||||||
|
End: CppGetWordT1(i, j) * (10000000),
|
||||||
|
Text: strings.ToValidUTF8(strings.Clone(CppGetWordText(i, j)), "<22>"),
|
||||||
|
}
|
||||||
|
if isValidWord(w) {
|
||||||
|
words = append(words, w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip empty segments with no recognisable content (e.g. trailing
|
||||||
|
// silence segments that parakeet emits with stale init data).
|
||||||
|
trimmed := strings.TrimSpace(txt)
|
||||||
|
if trimmed == "" && len(words) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
segment := &pb.TranscriptSegment{
|
segment := &pb.TranscriptSegment{
|
||||||
Id: int32(i),
|
Id: int32(i),
|
||||||
Text: txt,
|
Text: txt,
|
||||||
Start: s, End: t,
|
Start: s, End: t,
|
||||||
|
Words: words,
|
||||||
}
|
}
|
||||||
|
|
||||||
segments = append(segments, segment)
|
segments = append(segments, segment)
|
||||||
|
|
||||||
text += " " + strings.TrimSpace(txt)
|
text += " " + trimmed
|
||||||
}
|
}
|
||||||
|
|
||||||
return pb.TranscriptResult{
|
return pb.TranscriptResult{
|
||||||
@@ -348,13 +463,20 @@ func (w *CrispASR) AudioTranscriptionStream(ctx context.Context, opts *pb.Transc
|
|||||||
s := CppGetSegmentStart(i) * 10000000
|
s := CppGetSegmentStart(i) * 10000000
|
||||||
t := CppGetSegmentEnd(i) * 10000000
|
t := CppGetSegmentEnd(i) * 10000000
|
||||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
||||||
|
|
||||||
|
// Skip empty segments (e.g. trailing silence that parakeet emits
|
||||||
|
// with stale init data).
|
||||||
|
trimmed := strings.TrimSpace(txt)
|
||||||
|
if trimmed == "" && s == t {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
segments = append(segments, &pb.TranscriptSegment{
|
segments = append(segments, &pb.TranscriptSegment{
|
||||||
Id: int32(i),
|
Id: int32(i),
|
||||||
Text: txt,
|
Text: txt,
|
||||||
Start: s, End: t,
|
Start: s, End: t,
|
||||||
})
|
})
|
||||||
|
|
||||||
trimmed := strings.TrimSpace(txt)
|
|
||||||
if trimmed == "" {
|
if trimmed == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -390,7 +512,7 @@ func (w *CrispASR) synthesize(text string) ([]float32, error) {
|
|||||||
}
|
}
|
||||||
defer CppTTSFree(ptr)
|
defer CppTTSFree(ptr)
|
||||||
src := unsafe.Slice((*float32)(unsafe.Pointer(ptr)), int(n)) //nolint:govet // ptr addresses C-allocated PCM returned across the purego boundary; copied out immediately below, before tts_free.
|
src := unsafe.Slice((*float32)(unsafe.Pointer(ptr)), int(n)) //nolint:govet // ptr addresses C-allocated PCM returned across the purego boundary; copied out immediately below, before tts_free.
|
||||||
out := make([]float32, int(n)) // copy out of C memory before free
|
out := make([]float32, int(n)) // copy out of C memory before free
|
||||||
copy(out, src)
|
copy(out, src)
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
@@ -417,7 +539,7 @@ func (w *CrispASR) TTS(req *pb.TTSRequest) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return writeWAV24k(req.Dst, pcm)
|
return writeWAV(req.Dst, pcm, w.sampleRate)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TTSStream is the streaming counterpart to TTS. CrispASR has no progressive
|
// TTSStream is the streaming counterpart to TTS. CrispASR has no progressive
|
||||||
@@ -447,7 +569,7 @@ func (w *CrispASR) TTSStream(req *pb.TTSRequest, results chan []byte) error {
|
|||||||
}
|
}
|
||||||
defer func() { _ = os.Remove(dst) }()
|
defer func() { _ = os.Remove(dst) }()
|
||||||
|
|
||||||
if err := writeWAV24k(dst, pcm); err != nil {
|
if err := writeWAV(dst, pcm, w.sampleRate); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -459,14 +581,14 @@ func (w *CrispASR) TTSStream(req *pb.TTSRequest, results chan []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeWAV24k writes pcm as a 24000 Hz, mono, 16-bit PCM WAV at dst.
|
// writeWAV writes pcm as a sampleRate Hz, mono, 16-bit PCM WAV at dst.
|
||||||
func writeWAV24k(dst string, pcm []float32) error {
|
func writeWAV(dst string, pcm []float32, sampleRate int) error {
|
||||||
f, err := os.Create(dst)
|
f, err := os.Create(dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("crispasr: create %q: %w", dst, err)
|
return fmt.Errorf("crispasr: create %q: %w", dst, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
enc := wav.NewEncoder(f, 24000, 16, 1, 1)
|
enc := wav.NewEncoder(f, sampleRate, 16, 1, 1)
|
||||||
ints := make([]int, len(pcm))
|
ints := make([]int, len(pcm))
|
||||||
for i, s := range pcm {
|
for i, s := range pcm {
|
||||||
if s > 1 {
|
if s > 1 {
|
||||||
@@ -477,7 +599,7 @@ func writeWAV24k(dst string, pcm []float32) error {
|
|||||||
ints[i] = int(s * 32767)
|
ints[i] = int(s * 32767)
|
||||||
}
|
}
|
||||||
buf := &audio.IntBuffer{
|
buf := &audio.IntBuffer{
|
||||||
Format: &audio.Format{NumChannels: 1, SampleRate: 24000},
|
Format: &audio.Format{NumChannels: 1, SampleRate: sampleRate},
|
||||||
Data: ints,
|
Data: ints,
|
||||||
SourceBitDepth: 16,
|
SourceBitDepth: 16,
|
||||||
}
|
}
|
||||||
|
|||||||
164
backend/go/crispasr/gocrispasr_samplerate_test.go
Normal file
164
backend/go/crispasr/gocrispasr_samplerate_test.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/go-audio/wav"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GGUF metadata value type tags (subset) from the GGUF spec.
|
||||||
|
const (
|
||||||
|
ggufTypeUint32 uint32 = 4
|
||||||
|
ggufTypeString uint32 = 8
|
||||||
|
)
|
||||||
|
|
||||||
|
type ggufKV struct {
|
||||||
|
key string
|
||||||
|
vtype uint32
|
||||||
|
val any
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeMinimalGGUF emits a valid, tensor-less GGUF file carrying only the given
|
||||||
|
// metadata key-values. Enough for the header-only parse path piperSampleRate
|
||||||
|
// uses; avoids pulling a real multi-MB voice into the test.
|
||||||
|
func writeMinimalGGUF(path string, kvs []ggufKV) error {
|
||||||
|
var b bytes.Buffer
|
||||||
|
b.WriteString("GGUF") // magic
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, uint32(3)) // version
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, uint64(0)) // tensor count
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, uint64(len(kvs)))
|
||||||
|
for _, kv := range kvs {
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, uint64(len(kv.key)))
|
||||||
|
b.WriteString(kv.key)
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, kv.vtype)
|
||||||
|
switch v := kv.val.(type) {
|
||||||
|
case uint32:
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, v)
|
||||||
|
case string:
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, uint64(len(v)))
|
||||||
|
b.WriteString(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return os.WriteFile(path, b.Bytes(), 0o644)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wavSampleRate decodes the WAV header at path and returns its sample rate.
|
||||||
|
func wavSampleRate(path string) (int, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer func() { _ = f.Close() }()
|
||||||
|
dec := wav.NewDecoder(f)
|
||||||
|
dec.ReadInfo()
|
||||||
|
return int(dec.SampleRate), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = Describe("piper sample rate", func() {
|
||||||
|
Context("piperSampleRate", func() {
|
||||||
|
It("reads piper.sample_rate from a piper GGUF (medium = 22050)", func() {
|
||||||
|
p := filepath.Join(GinkgoT().TempDir(), "voice.gguf")
|
||||||
|
Expect(writeMinimalGGUF(p, []ggufKV{
|
||||||
|
{key: "general.architecture", vtype: ggufTypeString, val: "piper"},
|
||||||
|
{key: "piper.sample_rate", vtype: ggufTypeUint32, val: uint32(22050)},
|
||||||
|
})).To(Succeed())
|
||||||
|
|
||||||
|
rate, ok := piperSampleRate(p)
|
||||||
|
Expect(ok).To(BeTrue(), "piper.sample_rate should be found")
|
||||||
|
Expect(rate).To(Equal(22050))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads the low-quality rate (16000)", func() {
|
||||||
|
p := filepath.Join(GinkgoT().TempDir(), "voice.gguf")
|
||||||
|
Expect(writeMinimalGGUF(p, []ggufKV{
|
||||||
|
{key: "piper.sample_rate", vtype: ggufTypeUint32, val: uint32(16000)},
|
||||||
|
})).To(Succeed())
|
||||||
|
|
||||||
|
rate, ok := piperSampleRate(p)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(rate).To(Equal(16000))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns ok=false for a non-piper GGUF (no piper.sample_rate key)", func() {
|
||||||
|
p := filepath.Join(GinkgoT().TempDir(), "other.gguf")
|
||||||
|
Expect(writeMinimalGGUF(p, []ggufKV{
|
||||||
|
{key: "general.architecture", vtype: ggufTypeString, val: "vibevoice"},
|
||||||
|
})).To(Succeed())
|
||||||
|
|
||||||
|
_, ok := piperSampleRate(p)
|
||||||
|
Expect(ok).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns ok=false for an unreadable/non-GGUF file", func() {
|
||||||
|
p := filepath.Join(GinkgoT().TempDir(), "garbage.gguf")
|
||||||
|
Expect(os.WriteFile(p, []byte("not a gguf"), 0o644)).To(Succeed())
|
||||||
|
|
||||||
|
_, ok := piperSampleRate(p)
|
||||||
|
Expect(ok).To(BeFalse())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// End-to-end through the built .so. Gated on CRISPASR_PIPER_MODEL_PATH (a
|
||||||
|
// real piper voice GGUF) like the other model-backed specs; never runs in
|
||||||
|
// default CI. Proves CrispASR's piper backend output rate flows into the
|
||||||
|
// WAV header instead of the hardcoded 24 kHz default.
|
||||||
|
Context("piper TTS end-to-end", func() {
|
||||||
|
It("writes the WAV at the model's native piper.sample_rate", func() {
|
||||||
|
model := os.Getenv("CRISPASR_PIPER_MODEL_PATH")
|
||||||
|
if model == "" {
|
||||||
|
Skip("set CRISPASR_PIPER_MODEL_PATH to run the piper e2e spec")
|
||||||
|
}
|
||||||
|
ensureLibLoaded()
|
||||||
|
|
||||||
|
expected, ok := piperSampleRate(model)
|
||||||
|
Expect(ok).To(BeTrue(), "model should carry piper.sample_rate metadata")
|
||||||
|
|
||||||
|
w := &CrispASR{}
|
||||||
|
Expect(w.Load(&pb.ModelOptions{
|
||||||
|
ModelFile: model,
|
||||||
|
Options: []string{"backend:piper"},
|
||||||
|
Threads: 4,
|
||||||
|
})).To(Succeed())
|
||||||
|
|
||||||
|
dst := filepath.Join(GinkgoT().TempDir(), "piper.wav")
|
||||||
|
Expect(w.TTS(&pb.TTSRequest{Text: "Hello from CrispASR piper.", Dst: dst})).To(Succeed())
|
||||||
|
|
||||||
|
info, err := os.Stat(dst)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(info.Size()).To(BeNumerically(">", 1024), "expected a non-trivial WAV")
|
||||||
|
|
||||||
|
rate, err := wavSampleRate(dst)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(rate).To(Equal(expected),
|
||||||
|
"WAV header rate must equal the model's native piper.sample_rate, not the 24k default")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("writeWAV", func() {
|
||||||
|
It("writes the WAV header at the given sample rate (22050 for piper, not the 24k default)", func() {
|
||||||
|
dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
|
||||||
|
pcm := make([]float32, 220) // 10 ms of silence is enough for a header
|
||||||
|
Expect(writeWAV(dst, pcm, 22050)).To(Succeed())
|
||||||
|
|
||||||
|
rate, err := wavSampleRate(dst)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(rate).To(Equal(22050))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("writes a 16000 Hz header for low-quality piper voices", func() {
|
||||||
|
dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
|
||||||
|
pcm := make([]float32, 160)
|
||||||
|
Expect(writeWAV(dst, pcm, 16000)).To(Succeed())
|
||||||
|
|
||||||
|
rate, err := wavSampleRate(dst)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(rate).To(Equal(16000))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -44,6 +44,14 @@ func main() {
|
|||||||
{&CppTTSFree, "tts_free"},
|
{&CppTTSFree, "tts_free"},
|
||||||
{&CppTTSSetVoice, "tts_set_voice"},
|
{&CppTTSSetVoice, "tts_set_voice"},
|
||||||
{&CppTTSSetVoiceFile, "tts_set_voice_file"},
|
{&CppTTSSetVoiceFile, "tts_set_voice_file"},
|
||||||
|
{&CppGetWordCount, "get_word_count"},
|
||||||
|
{&CppGetWordText, "get_word_text"},
|
||||||
|
{&CppGetWordT0, "get_word_t0"},
|
||||||
|
{&CppGetWordT1, "get_word_t1"},
|
||||||
|
{&CppGetParakeetWordCount, "get_parakeet_word_count"},
|
||||||
|
{&CppGetParakeetWordText, "get_parakeet_word_text"},
|
||||||
|
{&CppGetParakeetWordT0, "get_parakeet_word_t0"},
|
||||||
|
{&CppGetParakeetWordT1, "get_parakeet_word_t1"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, lf := range libFuncs {
|
for _, lf := range libFuncs {
|
||||||
|
|||||||
@@ -51,6 +51,32 @@ else
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Bundle espeak-ng (+ its libpcaudio/libsonic runtime deps) and its voice data so
|
||||||
|
# the piper TTS backend can phonemize non-English text. CrispASR dlopens
|
||||||
|
# libespeak-ng.so.1 at runtime (the MIT-clean path); the dlopen succeeds loading
|
||||||
|
# libespeak-ng but FAILS if libpcaudio/libsonic are absent, so all three .so are
|
||||||
|
# required. run.sh points CRISPASR_ESPEAK_DATA_PATH at the bundled data dir.
|
||||||
|
# Best-effort: only copied when present, so a local dev build without espeak-ng
|
||||||
|
# installed still packages the rest (English voices keep working).
|
||||||
|
ESPEAK_LIBDIR=""
|
||||||
|
for d in /usr/lib/x86_64-linux-gnu /usr/lib/aarch64-linux-gnu; do
|
||||||
|
if [ -f "$d/libespeak-ng.so.1" ]; then
|
||||||
|
ESPEAK_LIBDIR="$d"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
if [ -n "$ESPEAK_LIBDIR" ]; then
|
||||||
|
echo "Bundling espeak-ng from $ESPEAK_LIBDIR ..."
|
||||||
|
cp -arfLv "$ESPEAK_LIBDIR/libespeak-ng.so.1" $CURDIR/package/lib/
|
||||||
|
cp -arfLv "$ESPEAK_LIBDIR/libpcaudio.so.0" $CURDIR/package/lib/
|
||||||
|
cp -arfLv "$ESPEAK_LIBDIR/libsonic.so.0" $CURDIR/package/lib/
|
||||||
|
if [ -d "$ESPEAK_LIBDIR/espeak-ng-data" ]; then
|
||||||
|
cp -arfLv "$ESPEAK_LIBDIR/espeak-ng-data" $CURDIR/package/
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "espeak-ng not found; non-English piper voices will not phonemize"
|
||||||
|
fi
|
||||||
|
|
||||||
# Package GPU libraries based on BUILD_TYPE
|
# Package GPU libraries based on BUILD_TYPE
|
||||||
# The GPU library packaging script will detect BUILD_TYPE and copy appropriate GPU libraries
|
# The GPU library packaging script will detect BUILD_TYPE and copy appropriate GPU libraries
|
||||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||||
|
|||||||
@@ -41,6 +41,11 @@ fi
|
|||||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||||
export CRISPASR_LIBRARY=$LIBRARY
|
export CRISPASR_LIBRARY=$LIBRARY
|
||||||
|
|
||||||
|
# Point piper's espeak-ng phonemizer at the bundled voice data. The variable
|
||||||
|
# names the directory CONTAINING espeak-ng-data (package.sh drops it next to
|
||||||
|
# this script). Harmless when espeak-ng wasn't bundled.
|
||||||
|
export CRISPASR_ESPEAK_DATA_PATH=$CURDIR
|
||||||
|
|
||||||
# If there is a lib/ld.so, use it
|
# If there is a lib/ld.so, use it
|
||||||
if [ -f $CURDIR/lib/ld.so ]; then
|
if [ -f $CURDIR/lib/ld.so ]; then
|
||||||
echo "Using lib/ld.so"
|
echo "Using lib/ld.so"
|
||||||
|
|||||||
7
backend/go/depth-anything-cpp/.gitignore
vendored
Normal file
7
backend/go/depth-anything-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
sources/
|
||||||
|
build*/
|
||||||
|
package/
|
||||||
|
libdepthanythingcpp*.so
|
||||||
|
depth-anything-cpp
|
||||||
|
test-models/
|
||||||
|
test-data/
|
||||||
28
backend/go/depth-anything-cpp/CMakeLists.txt
Normal file
28
backend/go/depth-anything-cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.18)
|
||||||
|
project(libdepthanythingcpp LANGUAGES C CXX)
|
||||||
|
|
||||||
|
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
|
||||||
|
# Static-link ggml into the depth-anything shared library so the resulting .so
|
||||||
|
# has no runtime dependency on an external libggml — only on
|
||||||
|
# libc/libstdc++/libgomp, which the LocalAI package step bundles into the
|
||||||
|
# docker image.
|
||||||
|
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build static libraries" FORCE)
|
||||||
|
|
||||||
|
# depth-anything.cpp build switches: skip CLI/tests, but build libdepthanything
|
||||||
|
# itself as a SHARED library (DA_SHARED) while ggml stays static
|
||||||
|
# (BUILD_SHARED_LIBS OFF above). The da_capi_* C ABI is compiled into
|
||||||
|
# src/da_capi.cpp and re-exported by that shared library, so no extra MODULE
|
||||||
|
# wrapper is needed (unlike locate-anything.cpp).
|
||||||
|
set(DA_BUILD_CLI OFF CACHE BOOL "Disable depth-anything CLI" FORCE)
|
||||||
|
set(DA_BUILD_TESTS OFF CACHE BOOL "Disable depth-anything tests" FORCE)
|
||||||
|
set(DA_SHARED ON CACHE BOOL "Build libdepthanything as a shared lib" FORCE)
|
||||||
|
|
||||||
|
add_subdirectory(./sources/depth-anything.cpp)
|
||||||
|
|
||||||
|
# Emit libdepthanything.so into the top-level build dir so the Makefile can
|
||||||
|
# rename it to the per-variant libdepthanythingcpp-<variant>.so.
|
||||||
|
set_target_properties(depthanything PROPERTIES
|
||||||
|
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||||
141
backend/go/depth-anything-cpp/Makefile
Normal file
141
backend/go/depth-anything-cpp/Makefile
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
CMAKE_ARGS?=
|
||||||
|
BUILD_TYPE?=
|
||||||
|
NATIVE?=false
|
||||||
|
|
||||||
|
GOCMD?=go
|
||||||
|
GO_TAGS?=
|
||||||
|
JOBS?=$(shell nproc --ignore=1)
|
||||||
|
|
||||||
|
# depth-anything.cpp. Pin to a specific commit for a stable build; a squash
|
||||||
|
# merge upstream can orphan a branch, so the native version is pinned by SHA.
|
||||||
|
# This SHA adds the Depth Anything V2 engine + C-API routing (depth-only,
|
||||||
|
# relative + metric) on top of the nested two-file metric C-API (abi_version 4,
|
||||||
|
# da_capi_load_nested) required by the depth-anything-3-nested gallery model.
|
||||||
|
# It is kept alive by the upstream tag da2-support (survives a squash-merge);
|
||||||
|
# repoint to the master merge commit once mudler/depth-anything.cpp PR #1 lands.
|
||||||
|
DEPTHANYTHING_REPO?=https://github.com/mudler/depth-anything.cpp.git
|
||||||
|
DEPTHANYTHING_VERSION?=f4e17dea695dd12ae76bea98ba58030996b98118
|
||||||
|
|
||||||
|
ifeq ($(NATIVE),false)
|
||||||
|
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||||
|
endif
|
||||||
|
|
||||||
|
# Forward LocalAI's BUILD_TYPE to the matching ggml backend switch. depth-anything.cpp
|
||||||
|
# force-sets GGML_CUDA/GGML_VULKAN/GGML_METAL from its own DA_GGML_* options, so
|
||||||
|
# those must be toggled via the DA_GGML_* names (a bare -DGGML_CUDA=ON would be
|
||||||
|
# overridden); the remaining ggml switches pass straight through.
|
||||||
|
ifeq ($(BUILD_TYPE),cublas)
|
||||||
|
CMAKE_ARGS+=-DGGML_CUDA=ON -DDA_GGML_CUDA=ON
|
||||||
|
else ifeq ($(BUILD_TYPE),openblas)
|
||||||
|
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||||
|
else ifeq ($(BUILD_TYPE),clblas)
|
||||||
|
CMAKE_ARGS+=-DGGML_CLBLAST=ON
|
||||||
|
else ifeq ($(BUILD_TYPE),hipblas)
|
||||||
|
ROCM_HOME ?= /opt/rocm
|
||||||
|
ROCM_PATH ?= /opt/rocm
|
||||||
|
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
||||||
|
export CC=$(ROCM_HOME)/llvm/bin/clang
|
||||||
|
AMDGPU_TARGETS?=gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201
|
||||||
|
CMAKE_ARGS+=-DGGML_HIPBLAS=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
|
||||||
|
else ifeq ($(BUILD_TYPE),vulkan)
|
||||||
|
CMAKE_ARGS+=-DGGML_VULKAN=ON -DDA_GGML_VULKAN=ON
|
||||||
|
else ifeq ($(OS),Darwin)
|
||||||
|
ifneq ($(BUILD_TYPE),metal)
|
||||||
|
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||||
|
else
|
||||||
|
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||||
|
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||||
|
CMAKE_ARGS+=-DDA_GGML_METAL=ON
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||||
|
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||||
|
-DCMAKE_C_COMPILER=icx \
|
||||||
|
-DCMAKE_CXX_COMPILER=icpx \
|
||||||
|
-DGGML_SYCL_F16=ON
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||||
|
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||||
|
-DCMAKE_C_COMPILER=icx \
|
||||||
|
-DCMAKE_CXX_COMPILER=icpx
|
||||||
|
endif
|
||||||
|
|
||||||
|
sources/depth-anything.cpp:
|
||||||
|
mkdir -p sources && \
|
||||||
|
git clone --recursive $(DEPTHANYTHING_REPO) sources/depth-anything.cpp && \
|
||||||
|
cd sources/depth-anything.cpp && \
|
||||||
|
git checkout $(DEPTHANYTHING_VERSION) && \
|
||||||
|
git submodule update --init --recursive --depth 1 --single-branch
|
||||||
|
|
||||||
|
# Detect OS
|
||||||
|
UNAME_S := $(shell uname -s)
|
||||||
|
|
||||||
|
# Only build CPU variants on Linux
|
||||||
|
ifeq ($(UNAME_S),Linux)
|
||||||
|
VARIANT_TARGETS = libdepthanythingcpp-avx.so libdepthanythingcpp-avx2.so libdepthanythingcpp-avx512.so libdepthanythingcpp-fallback.so
|
||||||
|
else
|
||||||
|
# On non-Linux (e.g., Darwin), build only fallback variant
|
||||||
|
VARIANT_TARGETS = libdepthanythingcpp-fallback.so
|
||||||
|
endif
|
||||||
|
|
||||||
|
depth-anything-cpp: main.go godepthanythingcpp.go $(VARIANT_TARGETS)
|
||||||
|
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o depth-anything-cpp ./
|
||||||
|
|
||||||
|
package: depth-anything-cpp
|
||||||
|
bash package.sh
|
||||||
|
|
||||||
|
build: package
|
||||||
|
|
||||||
|
clean: purge
|
||||||
|
rm -rf libdepthanythingcpp*.so depth-anything-cpp package sources
|
||||||
|
|
||||||
|
purge:
|
||||||
|
rm -rf build*
|
||||||
|
|
||||||
|
# Build all variants (Linux only)
|
||||||
|
ifeq ($(UNAME_S),Linux)
|
||||||
|
libdepthanythingcpp-avx.so: sources/depth-anything.cpp
|
||||||
|
rm -rfv build-$@
|
||||||
|
$(info ${GREEN}I depth-anything-cpp build info:avx${RESET})
|
||||||
|
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libdepthanythingcpp-custom
|
||||||
|
rm -rfv build-$@
|
||||||
|
|
||||||
|
libdepthanythingcpp-avx2.so: sources/depth-anything.cpp
|
||||||
|
rm -rfv build-$@
|
||||||
|
$(info ${GREEN}I depth-anything-cpp build info:avx2${RESET})
|
||||||
|
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libdepthanythingcpp-custom
|
||||||
|
rm -rfv build-$@
|
||||||
|
|
||||||
|
libdepthanythingcpp-avx512.so: sources/depth-anything.cpp
|
||||||
|
rm -rfv build-$@
|
||||||
|
$(info ${GREEN}I depth-anything-cpp build info:avx512${RESET})
|
||||||
|
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libdepthanythingcpp-custom
|
||||||
|
rm -rfv build-$@
|
||||||
|
endif
|
||||||
|
|
||||||
|
# Build fallback variant (all platforms)
|
||||||
|
libdepthanythingcpp-fallback.so: sources/depth-anything.cpp
|
||||||
|
rm -rfv build-$@
|
||||||
|
$(info ${GREEN}I depth-anything-cpp build info:fallback${RESET})
|
||||||
|
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libdepthanythingcpp-custom
|
||||||
|
rm -rfv build-$@
|
||||||
|
|
||||||
|
libdepthanythingcpp-custom: CMakeLists.txt
|
||||||
|
mkdir -p build-$(SO_TARGET) && \
|
||||||
|
cd build-$(SO_TARGET) && \
|
||||||
|
cmake .. $(CMAKE_ARGS) && \
|
||||||
|
cmake --build . --config Release -j$(JOBS) && \
|
||||||
|
cd .. && \
|
||||||
|
mv build-$(SO_TARGET)/libdepthanything.so ./$(SO_TARGET)
|
||||||
|
|
||||||
|
all: depth-anything-cpp package
|
||||||
|
|
||||||
|
# `test` is invoked by the top-level Makefile's `test-extra` target. It builds
|
||||||
|
# the backend binary + the fallback shared library (needed for dlopen at
|
||||||
|
# runtime), then runs test.sh which downloads a small GGUF + a test image and
|
||||||
|
# exercises the gRPC Load/Predict wire path via the Go smoke test in
|
||||||
|
# main_test.go.
|
||||||
|
test: depth-anything-cpp libdepthanythingcpp-fallback.so
|
||||||
|
bash test.sh
|
||||||
556
backend/go/depth-anything-cpp/godepthanythingcpp.go
Normal file
556
backend/go/depth-anything-cpp/godepthanythingcpp.go
Normal file
@@ -0,0 +1,556 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// godepthanythingcpp.go - gRPC handlers (Load, Predict, GenerateImage) for the
|
||||||
|
// depth-anything-cpp backend, wrapping the Depth Anything 3 ggml C-API
|
||||||
|
// (libdepthanythingcpp-<variant>.so) via purego.
|
||||||
|
//
|
||||||
|
// Embeds base.SingleThread to default the unimplemented RPCs to "not supported"
|
||||||
|
// and to serialize calls — the C side shares a ggml graph allocator and is NOT
|
||||||
|
// reentrant, so all inference must run one-at-a-time.
|
||||||
|
//
|
||||||
|
// Depth has no native OpenAI endpoint, so the model is exposed two ways:
|
||||||
|
//
|
||||||
|
// - GenerateImage(src, dst): run depth on the src image and write a
|
||||||
|
// min-max-normalised grayscale depth PNG to dst.
|
||||||
|
// - Predict(images[0]): run depth+pose and return a JSON blob with the depth
|
||||||
|
// dimensions, depth stats and the camera extrinsics (3x4) / intrinsics (3x3).
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"image/png"
|
||||||
|
"math"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// C-API function pointers, registered in main.go via purego. The da_capi_*
|
||||||
|
// symbols live inside libdepthanything (src/da_capi.cpp) and are re-exported by
|
||||||
|
// the DA_SHARED build.
|
||||||
|
var (
|
||||||
|
// da_capi_load(const char* gguf_path, int n_threads) -> da_ctx* (0 = fail)
|
||||||
|
CapiLoad func(gguf string, nThreads int32) uintptr
|
||||||
|
// da_capi_load_nested(const char* anyview_gguf, const char* metric_gguf,
|
||||||
|
// int n_threads) -> da_ctx* (0 = fail). The returned ctx serves the nested
|
||||||
|
// metric model: depth/pose calls produce final metric-scale depth + scaled pose.
|
||||||
|
CapiLoadNested func(anyview string, metric string, nThreads int32) uintptr
|
||||||
|
// da_capi_free(da_ctx* ctx) — safe on a 0 handle.
|
||||||
|
CapiFree func(handle uintptr)
|
||||||
|
// da_capi_last_error(da_ctx* ctx) -> const char* (owned by ctx, "" if none).
|
||||||
|
// purego marshals the returned C string into a Go string (a copy), so we
|
||||||
|
// never free it.
|
||||||
|
CapiLastError func(handle uintptr) string
|
||||||
|
// da_capi_depth_path(ctx, image_path, out_h*, out_w*) -> float* depth map
|
||||||
|
// (row-major H*W); nil on error. Caller frees via da_capi_free_floats.
|
||||||
|
CapiDepthPath func(handle uintptr, imagePath string, outH *int32, outW *int32) *float32
|
||||||
|
// da_capi_free_floats(float* p)
|
||||||
|
CapiFreeFloats func(p *float32)
|
||||||
|
// da_capi_pose_path(ctx, image_path, out_ext[12], out_intr[9]) -> 0 ok, -1 err
|
||||||
|
CapiPosePath func(handle uintptr, imagePath string, outExt *float32, outIntr *float32) int32
|
||||||
|
// da_capi_depth_dense(ctx, image_path, out_h*, out_w*, out_depth**, out_conf**,
|
||||||
|
// out_sky**, out_ext[12], out_intr[9], out_is_metric*) -> 0 ok, -1 err.
|
||||||
|
// Each non-NULL out_depth/out_conf/out_sky receives a malloc'd float[H*W] (free
|
||||||
|
// via da_capi_free_floats); buffers the model doesn't produce are set NULL.
|
||||||
|
CapiDepthDense func(handle uintptr, imagePath string,
|
||||||
|
outH, outW *int32,
|
||||||
|
outDepth, outConf, outSky **float32,
|
||||||
|
outExt, outIntr *float32,
|
||||||
|
outIsMetric *int32) int32
|
||||||
|
// da_capi_points(ctx, image_path, conf_thresh, out_n*, out_xyz**, out_rgb**) ->
|
||||||
|
// 0 ok, -1 err. *out_xyz = malloc'd float[3*N] (free via da_capi_free_floats),
|
||||||
|
// *out_rgb = malloc'd uint8[3*N] (free via da_capi_free_bytes).
|
||||||
|
CapiPoints func(handle uintptr, imagePath string, confThresh float32,
|
||||||
|
outN *int32, outXyz **float32, outRgb **byte) int32
|
||||||
|
// da_capi_free_bytes(unsigned char* p)
|
||||||
|
CapiFreeBytes func(p *byte)
|
||||||
|
// da_capi_export_glb(ctx, image_path, out_glb) -> 0 ok, -1 err
|
||||||
|
CapiExportGlb func(handle uintptr, imagePath string, outGlb string) int32
|
||||||
|
// da_capi_export_colmap(ctx, image_path, out_dir, binary) -> 0 ok, -1 err
|
||||||
|
CapiExportColmap func(handle uintptr, imagePath string, outDir string, binary int32) int32
|
||||||
|
)
|
||||||
|
|
||||||
|
type DepthAnythingCpp struct {
|
||||||
|
base.SingleThread
|
||||||
|
handle uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load loads the GGUF model at opts.ModelFile (joined with opts.ModelPath if
|
||||||
|
// relative) and stores the da_ctx handle for later inference calls.
|
||||||
|
func (r *DepthAnythingCpp) Load(opts *pb.ModelOptions) error {
|
||||||
|
modelFile := opts.ModelFile
|
||||||
|
if modelFile == "" {
|
||||||
|
modelFile = opts.Model
|
||||||
|
}
|
||||||
|
if modelFile == "" {
|
||||||
|
return fmt.Errorf("depth-anything-cpp: ModelFile is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
resolve := func(name string) string {
|
||||||
|
if filepath.IsAbs(name) {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
return filepath.Join(opts.ModelPath, name)
|
||||||
|
}
|
||||||
|
modelPath := resolve(modelFile)
|
||||||
|
|
||||||
|
if _, err := os.Stat(modelPath); err != nil {
|
||||||
|
return fmt.Errorf("depth-anything-cpp: model file not found: %s: %w", modelPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nested metric models are a two-file pair: the main model is the anyview
|
||||||
|
// (GIANT) branch and the metric (ViT-L + DPT/sky) branch is named via a
|
||||||
|
// "metric_model:<filename>" entry in opts.Options. When present we load both
|
||||||
|
// branches so the engine runs the nested metric alignment.
|
||||||
|
metricFile := optionValue(opts.Options, "metric_model")
|
||||||
|
|
||||||
|
threads := opts.Threads
|
||||||
|
if threads <= 0 {
|
||||||
|
threads = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release previous model if any (re-Load).
|
||||||
|
if r.handle != 0 {
|
||||||
|
CapiFree(r.handle)
|
||||||
|
r.handle = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var h uintptr
|
||||||
|
if metricFile != "" {
|
||||||
|
metricPath := resolve(metricFile)
|
||||||
|
if _, err := os.Stat(metricPath); err != nil {
|
||||||
|
return fmt.Errorf("depth-anything-cpp: metric_model file not found: %s: %w", metricPath, err)
|
||||||
|
}
|
||||||
|
h = CapiLoadNested(modelPath, metricPath, threads)
|
||||||
|
if h == 0 {
|
||||||
|
if msg := CapiLastError(0); msg != "" {
|
||||||
|
return fmt.Errorf("depth-anything-cpp: da_capi_load_nested failed for %s + %s: %s", modelPath, metricPath, msg)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("depth-anything-cpp: da_capi_load_nested failed for %s + %s", modelPath, metricPath)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
h = CapiLoad(modelPath, threads)
|
||||||
|
if h == 0 {
|
||||||
|
// da_capi_last_error needs a ctx; on a failed load we have none (it
|
||||||
|
// returns "" for a null ctx), so the text is best-effort.
|
||||||
|
if msg := CapiLastError(0); msg != "" {
|
||||||
|
return fmt.Errorf("depth-anything-cpp: da_capi_load failed for %s: %s", modelPath, msg)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("depth-anything-cpp: da_capi_load failed for %s", modelPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.handle = h
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// optionValue returns the value of the first "key:value" entry in opts whose key
|
||||||
|
// matches (case-sensitive), or "" if absent. Mirrors how other LocalAI backends
|
||||||
|
// read ModelOptions.Options.
|
||||||
|
func optionValue(opts []string, key string) string {
|
||||||
|
prefix := key + ":"
|
||||||
|
for _, o := range opts {
|
||||||
|
if strings.HasPrefix(o, prefix) {
|
||||||
|
return strings.TrimSpace(o[len(prefix):])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// depthResult is the JSON payload returned by Predict.
|
||||||
|
type depthResult struct {
|
||||||
|
DepthW int `json:"depth_w"`
|
||||||
|
DepthH int `json:"depth_h"`
|
||||||
|
DepthMin float32 `json:"depth_min"`
|
||||||
|
DepthMax float32 `json:"depth_max"`
|
||||||
|
Extrinsics [12]float32 `json:"extrinsics"` // 3x4 row-major
|
||||||
|
Intrinsics [9]float32 `json:"intrinsics"` // 3x3 row-major
|
||||||
|
}
|
||||||
|
|
||||||
|
// Predict runs depth+pose on the first supplied image and returns depth
|
||||||
|
// statistics + camera pose as a JSON string. LocalAI wraps the string into the
|
||||||
|
// Reply.Message of the gRPC response. The image in Images[0] may be a
|
||||||
|
// filesystem path or a base64-encoded payload.
|
||||||
|
func (r *DepthAnythingCpp) Predict(opts *pb.PredictOptions) (string, error) {
|
||||||
|
imgs := opts.GetImages()
|
||||||
|
if len(imgs) == 0 {
|
||||||
|
return "", fmt.Errorf("depth-anything-cpp: Predict requires an image in Images[]")
|
||||||
|
}
|
||||||
|
|
||||||
|
imgPath, cleanup, err := materializeImage(imgs[0])
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("depth-anything-cpp: %w", err)
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
depth, h, w, ext, intr, err := r.runDepthPose(imgPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
dmin, dmax := minMax(depth)
|
||||||
|
payload, err := json.Marshal(depthResult{
|
||||||
|
DepthW: w, DepthH: h,
|
||||||
|
DepthMin: dmin, DepthMax: dmax,
|
||||||
|
Extrinsics: ext, Intrinsics: intr,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("depth-anything-cpp: marshal: %w", err)
|
||||||
|
}
|
||||||
|
return string(payload), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateImage runs depth on req.Src and writes a normalised grayscale depth
|
||||||
|
// PNG to req.Dst.
|
||||||
|
func (r *DepthAnythingCpp) GenerateImage(req *pb.GenerateImageRequest) error {
|
||||||
|
if req.GetSrc() == "" {
|
||||||
|
return fmt.Errorf("depth-anything-cpp: GenerateImage requires src")
|
||||||
|
}
|
||||||
|
if req.GetDst() == "" {
|
||||||
|
return fmt.Errorf("depth-anything-cpp: GenerateImage requires dst")
|
||||||
|
}
|
||||||
|
|
||||||
|
imgPath, cleanup, err := materializeImage(req.GetSrc())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("depth-anything-cpp: %w", err)
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
depth, h, w, _, _, err := r.runDepthPose(imgPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return writeDepthPNG(req.GetDst(), depth, h, w)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Depth is the typed Depth RPC. It runs the Depth Anything 3 pipeline on the
|
||||||
|
// request's src image and fills a DepthResponse honoring the include_* flags and
|
||||||
|
// exports: per-pixel metric depth + confidence (DualDPT) or depth + sky (mono),
|
||||||
|
// camera extrinsics/intrinsics, an optional back-projected 3D point cloud and
|
||||||
|
// glb/COLMAP exports. The src may be a filesystem path or a base64 payload.
|
||||||
|
func (r *DepthAnythingCpp) Depth(in *pb.DepthRequest) (pb.DepthResponse, error) {
|
||||||
|
// Accumulate into locals and return a single composite literal at the end:
|
||||||
|
// returning a named pb.DepthResponse value would copy its embedded mutex
|
||||||
|
// (go vet copylocks).
|
||||||
|
if r.handle == 0 {
|
||||||
|
return pb.DepthResponse{}, fmt.Errorf("depth-anything-cpp: model not loaded")
|
||||||
|
}
|
||||||
|
if in.GetSrc() == "" {
|
||||||
|
return pb.DepthResponse{}, fmt.Errorf("depth-anything-cpp: Depth requires src")
|
||||||
|
}
|
||||||
|
|
||||||
|
imgPath, cleanup, err := materializeImage(in.GetSrc())
|
||||||
|
if err != nil {
|
||||||
|
return pb.DepthResponse{}, fmt.Errorf("depth-anything-cpp: %w", err)
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Dense per-pixel output + pose. Pass buffer pointers only for the
|
||||||
|
// requested maps so the native side can skip unrequested work; ext/intr
|
||||||
|
// must always point at 12/9 floats per the C ABI.
|
||||||
|
var (
|
||||||
|
h, w, isMetric int32
|
||||||
|
depthPtr, confPtr *float32
|
||||||
|
skyPtr *float32
|
||||||
|
ext [12]float32
|
||||||
|
intr [9]float32
|
||||||
|
pDepth, pConf, pSky **float32
|
||||||
|
)
|
||||||
|
if in.GetIncludeDepth() {
|
||||||
|
pDepth = &depthPtr
|
||||||
|
}
|
||||||
|
if in.GetIncludeConfidence() {
|
||||||
|
pConf = &confPtr
|
||||||
|
}
|
||||||
|
if in.GetIncludeSky() {
|
||||||
|
pSky = &skyPtr
|
||||||
|
}
|
||||||
|
|
||||||
|
rc := CapiDepthDense(r.handle, imgPath, &h, &w, pDepth, pConf, pSky, &ext[0], &intr[0], &isMetric)
|
||||||
|
if rc != 0 {
|
||||||
|
return pb.DepthResponse{}, fmt.Errorf("depth-anything-cpp: da_capi_depth_dense failed (rc=%d): %s", rc, r.lastError())
|
||||||
|
}
|
||||||
|
|
||||||
|
n := int(h) * int(w)
|
||||||
|
var (
|
||||||
|
depth, conf, sky []float32
|
||||||
|
extrinsics, intrinsic []float32
|
||||||
|
numPoints int32
|
||||||
|
points []float32
|
||||||
|
pointColors []byte
|
||||||
|
exportPaths []string
|
||||||
|
)
|
||||||
|
|
||||||
|
if depthPtr != nil {
|
||||||
|
depth = copyFloats(depthPtr, n)
|
||||||
|
CapiFreeFloats(depthPtr)
|
||||||
|
}
|
||||||
|
if confPtr != nil {
|
||||||
|
conf = copyFloats(confPtr, n)
|
||||||
|
CapiFreeFloats(confPtr)
|
||||||
|
}
|
||||||
|
if skyPtr != nil {
|
||||||
|
sky = copyFloats(skyPtr, n)
|
||||||
|
CapiFreeFloats(skyPtr)
|
||||||
|
}
|
||||||
|
if in.GetIncludePose() {
|
||||||
|
extrinsics = append([]float32(nil), ext[:]...)
|
||||||
|
intrinsic = append([]float32(nil), intr[:]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3D point cloud (DualDPT / pose-capable models only).
|
||||||
|
if in.GetIncludePoints() {
|
||||||
|
var (
|
||||||
|
np int32
|
||||||
|
xyzPtr *float32
|
||||||
|
rgbPtr *byte
|
||||||
|
)
|
||||||
|
if rc := CapiPoints(r.handle, imgPath, in.GetPointsConfThresh(), &np, &xyzPtr, &rgbPtr); rc != 0 {
|
||||||
|
return pb.DepthResponse{}, fmt.Errorf("depth-anything-cpp: da_capi_points failed (rc=%d): %s", rc, r.lastError())
|
||||||
|
}
|
||||||
|
numPoints = np
|
||||||
|
if xyzPtr != nil {
|
||||||
|
points = copyFloats(xyzPtr, int(np)*3)
|
||||||
|
CapiFreeFloats(xyzPtr)
|
||||||
|
}
|
||||||
|
if rgbPtr != nil {
|
||||||
|
pointColors = copyBytes(rgbPtr, int(np)*3)
|
||||||
|
CapiFreeBytes(rgbPtr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exports (glb / colmap). They are written under in.Dst (a directory); a
|
||||||
|
// temp dir is used when Dst is empty.
|
||||||
|
if len(in.GetExports()) > 0 {
|
||||||
|
exportPaths, err = r.runExports(imgPath, in.GetDst(), in.GetExports())
|
||||||
|
if err != nil {
|
||||||
|
return pb.DepthResponse{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pb.DepthResponse{
|
||||||
|
Width: w,
|
||||||
|
Height: h,
|
||||||
|
Depth: depth,
|
||||||
|
Confidence: conf,
|
||||||
|
Sky: sky,
|
||||||
|
Extrinsics: extrinsics,
|
||||||
|
Intrinsics: intrinsic,
|
||||||
|
NumPoints: numPoints,
|
||||||
|
Points: points,
|
||||||
|
PointColors: pointColors,
|
||||||
|
ExportPaths: exportPaths,
|
||||||
|
IsMetric: isMetric != 0,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runExports writes the requested exports for imgPath into dstDir and returns
|
||||||
|
// the written paths. Supported exports: "glb", "colmap".
|
||||||
|
func (r *DepthAnythingCpp) runExports(imgPath, dstDir string, exports []string) ([]string, error) {
|
||||||
|
if dstDir == "" {
|
||||||
|
tmp, err := os.MkdirTemp("", "depth-anything-export-*")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("depth-anything-cpp: mkdir export dir: %w", err)
|
||||||
|
}
|
||||||
|
dstDir = tmp
|
||||||
|
} else if err := os.MkdirAll(dstDir, 0o750); err != nil {
|
||||||
|
return nil, fmt.Errorf("depth-anything-cpp: mkdir %s: %w", dstDir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var paths []string
|
||||||
|
for _, exp := range exports {
|
||||||
|
switch exp {
|
||||||
|
case "glb":
|
||||||
|
out := filepath.Join(dstDir, "pointcloud.glb")
|
||||||
|
if rc := CapiExportGlb(r.handle, imgPath, out); rc != 0 {
|
||||||
|
return nil, fmt.Errorf("depth-anything-cpp: da_capi_export_glb failed (rc=%d): %s", rc, r.lastError())
|
||||||
|
}
|
||||||
|
paths = append(paths, out)
|
||||||
|
case "colmap":
|
||||||
|
out := filepath.Join(dstDir, "colmap")
|
||||||
|
if err := os.MkdirAll(out, 0o750); err != nil {
|
||||||
|
return nil, fmt.Errorf("depth-anything-cpp: mkdir %s: %w", out, err)
|
||||||
|
}
|
||||||
|
if rc := CapiExportColmap(r.handle, imgPath, out, 1); rc != 0 {
|
||||||
|
return nil, fmt.Errorf("depth-anything-cpp: da_capi_export_colmap failed (rc=%d): %s", rc, r.lastError())
|
||||||
|
}
|
||||||
|
paths = append(paths, out)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("depth-anything-cpp: unknown export %q (want glb|colmap)", exp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return paths, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyFloats copies n float32 values from a C heap pointer into a fresh Go
|
||||||
|
// slice so the C buffer can be freed afterwards.
|
||||||
|
func copyFloats(p *float32, n int) []float32 {
|
||||||
|
if p == nil || n <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
src := unsafe.Slice(p, n)
|
||||||
|
out := make([]float32, n)
|
||||||
|
copy(out, src)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyBytes copies n bytes from a C heap pointer into a fresh Go slice.
|
||||||
|
func copyBytes(p *byte, n int) []byte {
|
||||||
|
if p == nil || n <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
src := unsafe.Slice(p, n)
|
||||||
|
out := make([]byte, n)
|
||||||
|
copy(out, src)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// runDepthPose runs depth estimation then pose recovery on an image file. It
|
||||||
|
// returns the row-major depth map (length h*w), its dimensions, the 3x4
|
||||||
|
// extrinsics (12 floats) and 3x3 intrinsics (9 floats).
|
||||||
|
// runDepthPose returns depth + camera pose via two C-API calls (depth then pose).
|
||||||
|
// For a nested metric model both calls run the full two-branch pipeline, so this
|
||||||
|
// path infers twice; the typed Depth RPC (single da_capi_depth_dense call) is the
|
||||||
|
// efficient path for nested models.
|
||||||
|
func (r *DepthAnythingCpp) runDepthPose(imagePath string) (depth []float32, h, w int, ext [12]float32, intr [9]float32, err error) {
|
||||||
|
if r.handle == 0 {
|
||||||
|
err = fmt.Errorf("depth-anything-cpp: model not loaded")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var ch, cw int32
|
||||||
|
ptr := CapiDepthPath(r.handle, imagePath, &ch, &cw)
|
||||||
|
if ptr == nil {
|
||||||
|
err = fmt.Errorf("depth-anything-cpp: da_capi_depth_path failed: %s", r.lastError())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h, w = int(ch), int(cw)
|
||||||
|
n := h * w
|
||||||
|
if n > 0 {
|
||||||
|
src := unsafe.Slice(ptr, n)
|
||||||
|
depth = make([]float32, n)
|
||||||
|
copy(depth, src)
|
||||||
|
}
|
||||||
|
CapiFreeFloats(ptr)
|
||||||
|
|
||||||
|
if rc := CapiPosePath(r.handle, imagePath, &ext[0], &intr[0]); rc != 0 {
|
||||||
|
err = fmt.Errorf("depth-anything-cpp: da_capi_pose_path failed (rc=%d): %s", rc, r.lastError())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// lastError returns the context's last error string, or "" if none.
|
||||||
|
func (r *DepthAnythingCpp) lastError() string {
|
||||||
|
if CapiLastError == nil || r.handle == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return CapiLastError(r.handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
// materializeImage returns a filesystem path for an image argument that may be
|
||||||
|
// either an existing path or a base64-encoded payload. When the input is
|
||||||
|
// base64 it is decoded into a temp file; cleanup removes it (no-op for a path).
|
||||||
|
func materializeImage(arg string) (path string, cleanup func(), err error) {
|
||||||
|
cleanup = func() {}
|
||||||
|
if _, statErr := os.Stat(arg); statErr == nil {
|
||||||
|
return arg, cleanup, nil
|
||||||
|
}
|
||||||
|
// Strip an optional data URL prefix (data:image/...;base64,<payload>).
|
||||||
|
b64 := arg
|
||||||
|
if i := indexComma(b64); i >= 0 && hasDataPrefix(b64) {
|
||||||
|
b64 = b64[i+1:]
|
||||||
|
}
|
||||||
|
data, decErr := base64.StdEncoding.DecodeString(b64)
|
||||||
|
if decErr != nil {
|
||||||
|
return "", cleanup, fmt.Errorf("image is neither an existing path nor valid base64: %v", decErr)
|
||||||
|
}
|
||||||
|
f, tErr := os.CreateTemp("", "depth-anything-*.img")
|
||||||
|
if tErr != nil {
|
||||||
|
return "", cleanup, tErr
|
||||||
|
}
|
||||||
|
if _, wErr := f.Write(data); wErr != nil {
|
||||||
|
_ = f.Close()
|
||||||
|
_ = os.Remove(f.Name())
|
||||||
|
return "", cleanup, wErr
|
||||||
|
}
|
||||||
|
_ = f.Close()
|
||||||
|
name := f.Name()
|
||||||
|
return name, func() { _ = os.Remove(name) }, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasDataPrefix(s string) bool {
|
||||||
|
return len(s) >= 5 && s[:5] == "data:"
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexComma(s string) int {
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
if s[i] == ',' {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeDepthPNG min-max normalises a depth map and writes it as an 8-bit
|
||||||
|
// grayscale PNG. Near = bright (255), far = dark (0), matching the usual
|
||||||
|
// depth-map convention for inverse-depth-like outputs.
|
||||||
|
func writeDepthPNG(dst string, depth []float32, h, w int) error {
|
||||||
|
if h <= 0 || w <= 0 || len(depth) < h*w {
|
||||||
|
return fmt.Errorf("depth-anything-cpp: writeDepthPNG: bad dims h=%d w=%d len=%d", h, w, len(depth))
|
||||||
|
}
|
||||||
|
dmin, dmax := minMax(depth)
|
||||||
|
span := dmax - dmin
|
||||||
|
if span <= 0 || math.IsNaN(float64(span)) {
|
||||||
|
span = 1
|
||||||
|
}
|
||||||
|
img := image.NewGray(image.Rect(0, 0, w, h))
|
||||||
|
for y := 0; y < h; y++ {
|
||||||
|
for x := 0; x < w; x++ {
|
||||||
|
v := depth[y*w+x]
|
||||||
|
n := (v - dmin) / span // 0..1
|
||||||
|
if math.IsNaN(float64(n)) {
|
||||||
|
n = 0
|
||||||
|
}
|
||||||
|
if n < 0 {
|
||||||
|
n = 0
|
||||||
|
} else if n > 1 {
|
||||||
|
n = 1
|
||||||
|
}
|
||||||
|
img.Pix[y*img.Stride+x] = uint8(n * 255)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// dst is the gRPC-provided output path chosen by the LocalAI core (the
|
||||||
|
// intended write destination for the rendered depth map), not
|
||||||
|
// attacker-controlled input, so the variable path is expected here.
|
||||||
|
f, err := os.Create(dst) // #nosec G304
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() { _ = f.Close() }()
|
||||||
|
return png.Encode(f, img)
|
||||||
|
}
|
||||||
|
|
||||||
|
func minMax(v []float32) (mn, mx float32) {
|
||||||
|
if len(v) == 0 {
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
mn, mx = v[0], v[0]
|
||||||
|
for _, x := range v {
|
||||||
|
if math.IsNaN(float64(x)) || math.IsInf(float64(x), 0) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if x < mn {
|
||||||
|
mn = x
|
||||||
|
}
|
||||||
|
if x > mx {
|
||||||
|
mx = x
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mn, mx
|
||||||
|
}
|
||||||
62
backend/go/depth-anything-cpp/main.go
Normal file
62
backend/go/depth-anything-cpp/main.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// main.go - entry point for the depth-anything-cpp gRPC backend.
|
||||||
|
//
|
||||||
|
// Dlopens libdepthanythingcpp-<variant>.so via purego at the path in
|
||||||
|
// DEPTHANYTHING_LIBRARY (set by run.sh based on /proc/cpuinfo), registers the
|
||||||
|
// da_capi_* C ABI symbols, then starts the gRPC server.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||||
|
)
|
||||||
|
|
||||||
|
type LibFuncs struct {
|
||||||
|
FuncPtr any
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Get library name from environment variable, default to fallback
|
||||||
|
libName := os.Getenv("DEPTHANYTHING_LIBRARY")
|
||||||
|
if libName == "" {
|
||||||
|
libName = "./libdepthanythingcpp-fallback.so"
|
||||||
|
}
|
||||||
|
|
||||||
|
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
libFuncs := []LibFuncs{
|
||||||
|
{&CapiLoad, "da_capi_load"},
|
||||||
|
{&CapiLoadNested, "da_capi_load_nested"},
|
||||||
|
{&CapiFree, "da_capi_free"},
|
||||||
|
{&CapiLastError, "da_capi_last_error"},
|
||||||
|
{&CapiDepthPath, "da_capi_depth_path"},
|
||||||
|
{&CapiFreeFloats, "da_capi_free_floats"},
|
||||||
|
{&CapiPosePath, "da_capi_pose_path"},
|
||||||
|
{&CapiDepthDense, "da_capi_depth_dense"},
|
||||||
|
{&CapiPoints, "da_capi_points"},
|
||||||
|
{&CapiFreeBytes, "da_capi_free_bytes"},
|
||||||
|
{&CapiExportGlb, "da_capi_export_glb"},
|
||||||
|
{&CapiExportColmap, "da_capi_export_colmap"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, lf := range libFuncs {
|
||||||
|
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &DepthAnythingCpp{}); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
167
backend/go/depth-anything-cpp/main_test.go
Normal file
167
backend/go/depth-anything-cpp/main_test.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// main_test.go - end-to-end smoke test for the depth-anything-cpp gRPC backend.
|
||||||
|
//
|
||||||
|
// Spawns the compiled depth-anything-cpp binary on a free local port, dials it
|
||||||
|
// via gRPC, and exercises LoadModel + Predict against the test fixtures
|
||||||
|
// downloaded by test.sh: the small (vits) f32 GGUF of Depth Anything 3 and a
|
||||||
|
// real photo. Asserts that Predict returns a JSON payload with a positive
|
||||||
|
// depth-map width/height.
|
||||||
|
//
|
||||||
|
// The spec Skip()s cleanly if its fixtures (the model, the test image, the
|
||||||
|
// built binary, or the fallback .so) are missing, so the test target stays
|
||||||
|
// usable on a fresh checkout / on CI runners where the model hasn't been
|
||||||
|
// downloaded.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDepth(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "depth-anything-cpp backend smoke suite")
|
||||||
|
}
|
||||||
|
|
||||||
|
// freePort grabs an ephemeral TCP port and immediately releases it so the
|
||||||
|
// spawned backend can bind to it. There is a tiny TOCTOU window here but in
|
||||||
|
// practice it's adequate for a smoke test on a quiet runner.
|
||||||
|
func freePort() int {
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
Expect(err).ToNot(HaveOccurred(), "freePort listen")
|
||||||
|
port := l.Addr().(*net.TCPAddr).Port
|
||||||
|
Expect(l.Close()).To(Succeed())
|
||||||
|
return port
|
||||||
|
}
|
||||||
|
|
||||||
|
// startBackend spawns the depth-anything-cpp binary on the given port and waits
|
||||||
|
// until it accepts TCP connections (up to 10s). It mirrors how main.go resolves
|
||||||
|
// the purego library: the DEPTHANYTHING_LIBRARY env var points the dlopen at the
|
||||||
|
// freshly built fallback .so. The returned cleanup func kills the process.
|
||||||
|
func startBackend(port int) func() {
|
||||||
|
binary, err := filepath.Abs("./depth-anything-cpp")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
if _, err := os.Stat(binary); err != nil {
|
||||||
|
Skip(fmt.Sprintf("backend binary not built: %s (run `make depth-anything-cpp` first)", binary))
|
||||||
|
}
|
||||||
|
|
||||||
|
libPath, err := filepath.Abs("./libdepthanythingcpp-fallback.so")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
if _, err := os.Stat(libPath); err != nil {
|
||||||
|
Skip(fmt.Sprintf("fallback library not built: %s (run `make libdepthanythingcpp-fallback.so` first)", libPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||||
|
cmd := exec.Command(binary, "--addr", addr)
|
||||||
|
cmd.Env = append(os.Environ(), "DEPTHANYTHING_LIBRARY="+libPath)
|
||||||
|
cmd.Stdout = os.Stderr
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
Expect(cmd.Start()).To(Succeed())
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
if cmd.Process != nil {
|
||||||
|
_ = cmd.Process.Kill()
|
||||||
|
_, _ = cmd.Process.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(10 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
c, err := net.DialTimeout("tcp", addr, 200*time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
_ = c.Close()
|
||||||
|
return cleanup
|
||||||
|
}
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
Fail(fmt.Sprintf("backend did not become ready on %s within 10s", addr))
|
||||||
|
return func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadTestImage reads the test image downloaded by test.sh and returns its
|
||||||
|
// base64-encoded content (one of the wire formats accepted by Predict).
|
||||||
|
func loadTestImage() string {
|
||||||
|
imgPath, err := filepath.Abs("test-data/test.jpg")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
imgBytes, err := os.ReadFile(imgPath)
|
||||||
|
if err != nil {
|
||||||
|
Skip(fmt.Sprintf("test image not present: %s (run test.sh first)", imgPath))
|
||||||
|
}
|
||||||
|
return base64.StdEncoding.EncodeToString(imgBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dialBackend opens a gRPC client connection to the spawned backend.
|
||||||
|
func dialBackend(port int) (pb.BackendClient, func()) {
|
||||||
|
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||||
|
conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
return pb.NewBackendClient(conn), func() { _ = conn.Close() }
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelPathOrSkip resolves the model file under ./test-models/ and Skip()s the
|
||||||
|
// current spec if it's missing (not present on a fresh checkout / on CI runners
|
||||||
|
// without the download).
|
||||||
|
func modelPathOrSkip(name string) string {
|
||||||
|
modelDir, err := filepath.Abs("test-models")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
modelPath := filepath.Join(modelDir, name)
|
||||||
|
if _, err := os.Stat(modelPath); err != nil {
|
||||||
|
Skip(fmt.Sprintf("model not present: %s (run test.sh first)", modelPath))
|
||||||
|
}
|
||||||
|
return modelPath
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = Describe("depth-anything-cpp backend", func() {
|
||||||
|
It("runs depth+pose against a known-good image", func() {
|
||||||
|
modelPath := modelPathOrSkip("depth-anything-small-f32.gguf")
|
||||||
|
imgB64 := loadTestImage()
|
||||||
|
|
||||||
|
port := freePort()
|
||||||
|
cleanup := startBackend(port)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
client, closeConn := dialBackend(port)
|
||||||
|
defer closeConn()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
loadResp, err := client.LoadModel(ctx, &pb.ModelOptions{
|
||||||
|
Model: "depth-anything-small-f32.gguf",
|
||||||
|
ModelFile: modelPath,
|
||||||
|
Threads: 4,
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred(), "LoadModel")
|
||||||
|
Expect(loadResp.GetSuccess()).To(BeTrue(), "LoadModel reported failure: %s", loadResp.GetMessage())
|
||||||
|
|
||||||
|
// Predict runs depth+pose and returns the JSON depthResult in Reply.Message.
|
||||||
|
reply, err := client.Predict(ctx, &pb.PredictOptions{
|
||||||
|
Images: []string{imgB64},
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred(), "Predict")
|
||||||
|
|
||||||
|
var res depthResult
|
||||||
|
Expect(json.Unmarshal(reply.GetMessage(), &res)).To(Succeed(), "Predict returned non-JSON: %q", string(reply.GetMessage()))
|
||||||
|
Expect(res.DepthW).To(BeNumerically(">", 0), "depth width should be positive")
|
||||||
|
Expect(res.DepthH).To(BeNumerically(">", 0), "depth height should be positive")
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(GinkgoWriter, "depth OK: %dx%d min=%.3f max=%.3f\n",
|
||||||
|
res.DepthW, res.DepthH, res.DepthMin, res.DepthMax)
|
||||||
|
})
|
||||||
|
})
|
||||||
64
backend/go/depth-anything-cpp/nested_e2e_test.go
Normal file
64
backend/go/depth-anything-cpp/nested_e2e_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// nested_e2e_test.go - e2e smoke for the nested two-file metric model. Loads the
|
||||||
|
// anyview branch as the main model and points the metric branch via the
|
||||||
|
// "metric_model:<file>" option (exactly as the depth-anything-3-nested gallery
|
||||||
|
// entry does), then exercises the typed Depth RPC and asserts a metric depth map.
|
||||||
|
//
|
||||||
|
// Skips cleanly unless both nested GGUFs are present under ./test-models/ and the
|
||||||
|
// backend binary + fallback .so are built.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("depth-anything-cpp nested metric model", func() {
|
||||||
|
It("loads the two-file pair via the metric_model option and returns metric depth", func() {
|
||||||
|
anyviewPath := modelPathOrSkip("depth-anything-nested-anyview.gguf")
|
||||||
|
_ = modelPathOrSkip("depth-anything-nested-metric.gguf")
|
||||||
|
imgB64 := loadTestImage()
|
||||||
|
|
||||||
|
port := freePort()
|
||||||
|
cleanup := startBackend(port)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
client, closeConn := dialBackend(port)
|
||||||
|
defer closeConn()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
loadResp, err := client.LoadModel(ctx, &pb.ModelOptions{
|
||||||
|
Model: "depth-anything-nested-anyview.gguf",
|
||||||
|
ModelFile: anyviewPath,
|
||||||
|
ModelPath: filepath.Dir(anyviewPath),
|
||||||
|
Options: []string{"metric_model:depth-anything-nested-metric.gguf"},
|
||||||
|
Threads: 8,
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred(), "LoadModel(nested)")
|
||||||
|
Expect(loadResp.GetSuccess()).To(BeTrue(), "LoadModel reported failure: %s", loadResp.GetMessage())
|
||||||
|
|
||||||
|
resp, err := client.Depth(ctx, &pb.DepthRequest{
|
||||||
|
Src: imgB64,
|
||||||
|
IncludeDepth: true,
|
||||||
|
IncludePose: true,
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred(), "Depth(nested)")
|
||||||
|
Expect(resp.GetWidth()).To(BeNumerically(">", 0), "depth width")
|
||||||
|
Expect(resp.GetHeight()).To(BeNumerically(">", 0), "depth height")
|
||||||
|
Expect(resp.GetIsMetric()).To(BeTrue(), "nested output must be metric")
|
||||||
|
Expect(len(resp.GetDepth())).To(Equal(int(resp.GetWidth())*int(resp.GetHeight())), "dense depth length")
|
||||||
|
Expect(len(resp.GetExtrinsics())).To(Equal(12), "extrinsics 3x4")
|
||||||
|
Expect(resp.GetIntrinsics()[0]).To(BeNumerically(">", 0), "fx > 0")
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(GinkgoWriter, "nested depth OK: %dx%d is_metric=%v fx=%.2f\n",
|
||||||
|
resp.GetWidth(), resp.GetHeight(), resp.GetIsMetric(), resp.GetIntrinsics()[0])
|
||||||
|
})
|
||||||
|
})
|
||||||
20
backend/go/depth-anything-cpp/options_test.go
Normal file
20
backend/go/depth-anything-cpp/options_test.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = DescribeTable("optionValue",
|
||||||
|
func(opts []string, key, want string) {
|
||||||
|
Expect(optionValue(opts, key)).To(Equal(want))
|
||||||
|
},
|
||||||
|
Entry("present", []string{"foo:bar", "metric_model:m.gguf"}, "metric_model", "m.gguf"),
|
||||||
|
Entry("absent", []string{"foo:bar"}, "metric_model", ""),
|
||||||
|
Entry("nil", []string(nil), "metric_model", ""),
|
||||||
|
Entry("trims space", []string{"metric_model: m.gguf "}, "metric_model", "m.gguf"),
|
||||||
|
Entry("value with colon", []string{"metric_model:a:b.gguf"}, "metric_model", "a:b.gguf"),
|
||||||
|
Entry("first wins", []string{"metric_model:first.gguf", "metric_model:second.gguf"}, "metric_model", "first.gguf"),
|
||||||
|
Entry("empty value", []string{"metric_model:"}, "metric_model", ""),
|
||||||
|
Entry("prefix not key", []string{"metric_model_extra:x"}, "metric_model", ""),
|
||||||
|
)
|
||||||
59
backend/go/depth-anything-cpp/package.sh
Executable file
59
backend/go/depth-anything-cpp/package.sh
Executable file
@@ -0,0 +1,59 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Script to copy the appropriate libraries based on architecture
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
REPO_ROOT="${CURDIR}/../../.."
|
||||||
|
|
||||||
|
# Create lib directory
|
||||||
|
mkdir -p $CURDIR/package/lib
|
||||||
|
|
||||||
|
cp -avf $CURDIR/libdepthanythingcpp-*.so $CURDIR/package/
|
||||||
|
cp -avf $CURDIR/depth-anything-cpp $CURDIR/package/
|
||||||
|
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||||
|
|
||||||
|
# Detect architecture and copy appropriate libraries
|
||||||
|
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||||
|
# x86_64 architecture
|
||||||
|
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||||
|
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||||
|
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||||
|
# ARM64 architecture
|
||||||
|
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||||
|
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||||
|
elif [ $(uname -s) = "Darwin" ]; then
|
||||||
|
echo "Detected Darwin"
|
||||||
|
else
|
||||||
|
echo "Error: Could not detect architecture"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Package GPU libraries based on BUILD_TYPE
|
||||||
|
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||||
|
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||||
|
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||||
|
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||||
|
package_gpu_libs
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Packaging completed successfully"
|
||||||
|
ls -liah $CURDIR/package/
|
||||||
|
ls -liah $CURDIR/package/lib/
|
||||||
52
backend/go/depth-anything-cpp/run.sh
Executable file
52
backend/go/depth-anything-cpp/run.sh
Executable file
@@ -0,0 +1,52 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
# Get the absolute current dir where the script is located
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
|
||||||
|
cd /
|
||||||
|
|
||||||
|
echo "CPU info:"
|
||||||
|
if [ "$(uname)" != "Darwin" ]; then
|
||||||
|
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||||
|
grep -e "flags" /proc/cpuinfo | head -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
LIBRARY="$CURDIR/libdepthanythingcpp-fallback.so"
|
||||||
|
|
||||||
|
if [ "$(uname)" != "Darwin" ]; then
|
||||||
|
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX found OK"
|
||||||
|
if [ -e $CURDIR/libdepthanythingcpp-avx.so ]; then
|
||||||
|
LIBRARY="$CURDIR/libdepthanythingcpp-avx.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX2 found OK"
|
||||||
|
if [ -e $CURDIR/libdepthanythingcpp-avx2.so ]; then
|
||||||
|
LIBRARY="$CURDIR/libdepthanythingcpp-avx2.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check avx 512
|
||||||
|
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX512F found OK"
|
||||||
|
if [ -e $CURDIR/libdepthanythingcpp-avx512.so ]; then
|
||||||
|
LIBRARY="$CURDIR/libdepthanythingcpp-avx512.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||||
|
export DEPTHANYTHING_LIBRARY=$LIBRARY
|
||||||
|
|
||||||
|
# If there is a lib/ld.so, use it
|
||||||
|
if [ -f $CURDIR/lib/ld.so ]; then
|
||||||
|
echo "Using lib/ld.so"
|
||||||
|
echo "Using library: $LIBRARY"
|
||||||
|
exec $CURDIR/lib/ld.so $CURDIR/depth-anything-cpp "$@"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Using library: $LIBRARY"
|
||||||
|
exec $CURDIR/depth-anything-cpp "$@"
|
||||||
45
backend/go/depth-anything-cpp/test.sh
Executable file
45
backend/go/depth-anything-cpp/test.sh
Executable file
@@ -0,0 +1,45 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
|
||||||
|
echo "Running depth-anything-cpp backend tests..."
|
||||||
|
|
||||||
|
# Test model from the mudler/depth-anything.cpp-gguf HuggingFace repo. The small
|
||||||
|
# (vits) f32 GGUF is the lightest backbone (~131 MB), so it keeps the download
|
||||||
|
# cheap. It is resumed with `curl -C -` and skipped entirely if already present.
|
||||||
|
DEPTHANYTHING_MODEL_DIR="${DEPTHANYTHING_MODEL_DIR:-$CURDIR/test-models}"
|
||||||
|
|
||||||
|
DEPTHANYTHING_MODEL_FILE="${DEPTHANYTHING_MODEL_FILE:-depth-anything-small-f32.gguf}"
|
||||||
|
DEPTHANYTHING_MODEL_URL="${DEPTHANYTHING_MODEL_URL:-https://huggingface.co/mudler/depth-anything.cpp-gguf/resolve/main/depth-anything-small-f32.gguf}"
|
||||||
|
|
||||||
|
mkdir -p "$DEPTHANYTHING_MODEL_DIR"
|
||||||
|
|
||||||
|
if [ ! -f "$DEPTHANYTHING_MODEL_DIR/$DEPTHANYTHING_MODEL_FILE" ]; then
|
||||||
|
echo "Downloading depth-anything small f32 model (~131 MB)..."
|
||||||
|
# -C - resumes a partial download so an interrupted run doesn't restart from 0.
|
||||||
|
curl -L -C - -o "$DEPTHANYTHING_MODEL_DIR/$DEPTHANYTHING_MODEL_FILE" "$DEPTHANYTHING_MODEL_URL" --progress-bar
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Use a real photo (people + cars) from the upstream rf-detr.cpp repo (~46 KB).
|
||||||
|
# Depth estimation needs real content; a synthetic image would be degenerate.
|
||||||
|
TEST_IMAGE_DIR="$CURDIR/test-data"
|
||||||
|
TEST_IMAGE_FILE="$TEST_IMAGE_DIR/test.jpg"
|
||||||
|
TEST_IMAGE_URL="${TEST_IMAGE_URL:-https://raw.githubusercontent.com/mudler/rf-detr.cpp/main/tests/fixtures/ci/test_image.jpg}"
|
||||||
|
|
||||||
|
mkdir -p "$TEST_IMAGE_DIR"
|
||||||
|
if [ ! -f "$TEST_IMAGE_FILE" ]; then
|
||||||
|
echo "Downloading test image..."
|
||||||
|
curl -L -o "$TEST_IMAGE_FILE" "$TEST_IMAGE_URL" --progress-bar
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "depth-anything-cpp test setup complete."
|
||||||
|
echo " model: $DEPTHANYTHING_MODEL_DIR/$DEPTHANYTHING_MODEL_FILE"
|
||||||
|
echo " test image: $TEST_IMAGE_FILE"
|
||||||
|
|
||||||
|
# Run the Go smoke test: spawns the backend binary on a free port, calls
|
||||||
|
# LoadModel + Predict via gRPC against the downloaded GGUF + image.
|
||||||
|
echo ""
|
||||||
|
echo "Running Go smoke test..."
|
||||||
|
cd "$CURDIR"
|
||||||
|
go test -v -timeout 30m ./...
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
# dllm backend Makefile.
|
|
||||||
#
|
|
||||||
# Upstream pin lives below as DLLM_VERSION?=<sha> so .github/bump_deps.sh
|
|
||||||
# can find and update it - matches the whisper.cpp / parakeet-cpp / ds4
|
|
||||||
# convention.
|
|
||||||
#
|
|
||||||
# Local dev shortcut: if you already have an out-of-tree dllm.cpp build,
|
|
||||||
# you can symlink the .so into this directory and skip the clone/cmake
|
|
||||||
# steps entirely, e.g.:
|
|
||||||
#
|
|
||||||
# ln -sf /path/to/dllm.cpp/build/libdllm.so .
|
|
||||||
# go build -o dllm-grpc .
|
|
||||||
#
|
|
||||||
# That's what the gated C-ABI binding smoke uses (DLLM_TEST_LIBRARY). The
|
|
||||||
# default target below does the proper clone-at-pin + cmake build so CI
|
|
||||||
# doesn't need a side-checkout.
|
|
||||||
#
|
|
||||||
# NOTE: github.com/mudler/dllm.cpp is still private (publishing is planned);
|
|
||||||
# until then the anonymous clone below fails. Use the symlink shortcut above
|
|
||||||
# with a local checkout, or a git credential helper with access to the repo.
|
|
||||||
|
|
||||||
# The pin below is the P5 performance-parity head (device-resident
|
|
||||||
# self-conditioning, full-GPU placement at ngl >= n_layer, graph reuse,
|
|
||||||
# device-side EB reductions: ~8x per-step on GB10, see dllm.cpp
|
|
||||||
# docs/validation.md section 10). C-ABI unchanged (still version 1). It
|
|
||||||
# also carries the multimodal entry points (dllm_capi_generate_mm /
|
|
||||||
# dllm_capi_generate_stream_mm) the image-input path probes for; older
|
|
||||||
# libs still load, but image requests then fail with "library predates
|
|
||||||
# the multimodal entry points".
|
|
||||||
DLLM_VERSION?=320b57756efc3460169b8ea9e8c782867198f2a5
|
|
||||||
DLLM_REPO?=https://github.com/mudler/dllm.cpp
|
|
||||||
|
|
||||||
GOCMD?=go
|
|
||||||
GO_TAGS?=
|
|
||||||
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
|
||||||
|
|
||||||
BUILD_TYPE?=
|
|
||||||
NATIVE?=false
|
|
||||||
|
|
||||||
# libdllm.so is self-contained: dllm.cpp's CMakeLists statically absorbs ggml
|
|
||||||
# (BUILD_SHARED_LIBS=OFF + PIC) into the shared lib, so dlopen needs no
|
|
||||||
# libggml*.so alongside it, only system libs (libstdc++/libgomp/libc) the
|
|
||||||
# runtime image already provides. Tests/CLI are upstream-only concerns.
|
|
||||||
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DDLLM_BUILD_TESTS=OFF
|
|
||||||
|
|
||||||
ifeq ($(NATIVE),false)
|
|
||||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
|
||||||
endif
|
|
||||||
|
|
||||||
# Same arch set the sibling ggml backends (acestep/vibevoice/qwen3-tts) bake
|
|
||||||
# for their cublas images; override for a native build.
|
|
||||||
CUDA_ARCHITECTURES?=75-virtual;80-virtual;86-real;89-real
|
|
||||||
|
|
||||||
# dllm.cpp gates CUDA behind DLLM_CUDA (set(GGML_CUDA ... CACHE FORCE)), so
|
|
||||||
# forward that instead of a bare -DGGML_CUDA=ON.
|
|
||||||
ifeq ($(BUILD_TYPE),cublas)
|
|
||||||
CMAKE_ARGS+=-DDLLM_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="$(CUDA_ARCHITECTURES)"
|
|
||||||
endif
|
|
||||||
|
|
||||||
.PHONY: dllm-grpc package build clean purge test all
|
|
||||||
|
|
||||||
all: dllm-grpc
|
|
||||||
|
|
||||||
# Clone the upstream dllm.cpp source at the pinned commit (ggml comes in as
|
|
||||||
# a submodule). Directory acts as the target so make only re-clones when
|
|
||||||
# missing. After a DLLM_VERSION bump, run 'make purge && make' to refetch.
|
|
||||||
sources/dllm.cpp:
|
|
||||||
mkdir -p sources/dllm.cpp
|
|
||||||
cd sources/dllm.cpp && \
|
|
||||||
git init -q && \
|
|
||||||
git remote add origin $(DLLM_REPO) && \
|
|
||||||
git fetch --depth 1 origin $(DLLM_VERSION) && \
|
|
||||||
git checkout FETCH_HEAD && \
|
|
||||||
git submodule update --init --recursive --depth 1 --single-branch
|
|
||||||
|
|
||||||
# Build the shared lib out-of-tree, then stage it next to the Go sources so
|
|
||||||
# purego.Dlopen("libdllm.so") and the packaging step both pick it up.
|
|
||||||
libdllm.so: sources/dllm.cpp
|
|
||||||
cmake -B sources/dllm.cpp/build -S sources/dllm.cpp $(CMAKE_ARGS)
|
|
||||||
cmake --build sources/dllm.cpp/build --config Release -j$(JOBS)
|
|
||||||
cp -fv sources/dllm.cpp/build/libdllm.so ./
|
|
||||||
|
|
||||||
dllm-grpc: libdllm.so main.go capi.go
|
|
||||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o dllm-grpc .
|
|
||||||
|
|
||||||
package: dllm-grpc
|
|
||||||
bash package.sh
|
|
||||||
|
|
||||||
build: package
|
|
||||||
|
|
||||||
# Test target. The C-ABI binding smoke is gated on DLLM_TEST_LIBRARY +
|
|
||||||
# DLLM_TEST_TINY_MODEL; without them the gated specs auto-skip and only the
|
|
||||||
# pure-Go helper specs run.
|
|
||||||
test:
|
|
||||||
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
|
|
||||||
|
|
||||||
clean: purge
|
|
||||||
rm -rf libdllm.so* package dllm-grpc
|
|
||||||
|
|
||||||
purge:
|
|
||||||
rm -rf sources/dllm.cpp
|
|
||||||
@@ -1,326 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
// Typed Go wrappers over dllm.cpp's flat C-ABI (include/dllm_capi.h, ABI v1).
|
|
||||||
//
|
|
||||||
// Contract highlights the wrappers encode (see the header + src/capi.cpp):
|
|
||||||
// - tokenize_json/generate return malloc'd char* the CALLER owns: bound as
|
|
||||||
// uintptr, copied with goStringFromCPtr, released via dllm_capi_free_string.
|
|
||||||
// - last_error returns a BORROWED pointer (valid until the next call on the
|
|
||||||
// same ctx): bound as a plain string (purego copies), never freed, and only
|
|
||||||
// read AFTER the failing call has returned - reading it while a generate is
|
|
||||||
// in flight on the same ctx violates the per-ctx serialization contract.
|
|
||||||
// - All entry points except dllm_capi_cancel must be externally serialized
|
|
||||||
// per ctx (one ctx = one concurrent generate/tokenize). Cancel only flips
|
|
||||||
// an atomic and may be called from any goroutine mid-generate.
|
|
||||||
// - No C++ exception crosses the boundary; failures land in last_error.
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/ebitengine/purego"
|
|
||||||
)
|
|
||||||
|
|
||||||
// dllmABIVersion is the DLLM_CAPI_ABI_VERSION this binding was written
|
|
||||||
// against; main.go refuses to start against a libdllm.so reporting another.
|
|
||||||
const dllmABIVersion = 1
|
|
||||||
|
|
||||||
// purego-bound entry points from libdllm.so. Names match dllm_capi.h
|
|
||||||
// exactly; loadCAPI (main.go) fills these in at boot.
|
|
||||||
var (
|
|
||||||
cppAbiVersion func() int32
|
|
||||||
cppLoad func(ggufPath, paramsJSON string) uintptr
|
|
||||||
cppFree func(ctx uintptr)
|
|
||||||
cppLastError func(ctx uintptr) string // borrowed pointer: purego copies, do NOT free
|
|
||||||
cppFreeString func(s uintptr)
|
|
||||||
// malloc'd char* returns, hence uintptr (see loadCAPI's doc comment).
|
|
||||||
cppTokenizeJSON func(ctx uintptr, text string) uintptr
|
|
||||||
cppGenerate func(ctx uintptr, prompt, optsJSON string) uintptr
|
|
||||||
// on_block/on_step are C function pointers produced by purego.NewCallback;
|
|
||||||
// userData carries the streamCallStates registry key.
|
|
||||||
cppGenerateStream func(ctx uintptr, prompt, optsJSON string, onBlock, onStep, userData uintptr) int32
|
|
||||||
cppCancel func(ctx uintptr)
|
|
||||||
)
|
|
||||||
|
|
||||||
// Optional multimodal entry points (dllm_capi.h's P4 surface). The ABI
|
|
||||||
// version stays 1: presence is detected by PROBING the symbols with Dlsym at
|
|
||||||
// boot (loadCAPI, mirroring the parakeet-cpp optional-symbol pattern). nil
|
|
||||||
// means the loaded libdllm.so predates the mm surface; the wrappers below
|
|
||||||
// then fail with errMMUnsupported instead of crashing on a nil call.
|
|
||||||
var (
|
|
||||||
cppGenerateMM func(ctx uintptr, prompt, imagesJSON, optsJSON string) uintptr
|
|
||||||
cppGenerateStreamMM func(ctx uintptr, prompt, imagesJSON, optsJSON string, onBlock, onStep, userData uintptr) int32
|
|
||||||
)
|
|
||||||
|
|
||||||
// mmImageMarker is the literal placeholder dllm_capi_generate_mm expands to
|
|
||||||
// <boi> + soft-token placeholders + <eoi> (dllm_capi.h placeholder contract;
|
|
||||||
// capi.cpp MM_MARKER). The prompt must carry exactly one marker per
|
|
||||||
// images_json entry, in image order.
|
|
||||||
const mmImageMarker = "<image>"
|
|
||||||
|
|
||||||
// errMMUnsupported is returned for image-bearing requests against an old
|
|
||||||
// text-only libdllm.so (the Dlsym probe found no mm symbols).
|
|
||||||
var errMMUnsupported = errors.New(
|
|
||||||
"dllm: image input requires libdllm.so with the multimodal entry points (dllm_capi_generate_mm), but the loaded library predates them - rebuild/upgrade the dllm backend to use images")
|
|
||||||
|
|
||||||
// cMMSupported reports whether the loaded libdllm.so carries the multimodal
|
|
||||||
// generate pair. Both symbols ship together (same dllm.cpp commit), but the
|
|
||||||
// guard requires both anyway so a half-present surface can never dispatch.
|
|
||||||
func cMMSupported() bool {
|
|
||||||
return cppGenerateMM != nil && cppGenerateStreamMM != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cAbiVersion returns the library's DLLM_CAPI_ABI_VERSION.
|
|
||||||
func cAbiVersion() int32 {
|
|
||||||
return cppAbiVersion()
|
|
||||||
}
|
|
||||||
|
|
||||||
// cLoad opens the GGUF at path with the flat params JSON (e.g.
|
|
||||||
// {"n_gpu_layers":99}). Returns 0 on failure; per the header contract there
|
|
||||||
// is no ctx to carry the reason, the C side logs it to stderr (and
|
|
||||||
// cLastError(0) only yields the static NULL-ctx message).
|
|
||||||
func cLoad(path, paramsJSON string) uintptr {
|
|
||||||
return cppLoad(path, paramsJSON)
|
|
||||||
}
|
|
||||||
|
|
||||||
// cFree releases a ctx; safe on 0 (delete nullptr).
|
|
||||||
func cFree(h uintptr) {
|
|
||||||
cppFree(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// cLastError returns the ctx's last error message (or the static NULL-ctx
|
|
||||||
// message for h==0). The C pointer is borrowed and only valid until the next
|
|
||||||
// call on the same ctx; purego's string return copies it immediately, so the
|
|
||||||
// returned Go string is safe to keep. Must not be called while another call
|
|
||||||
// on the same ctx is in flight.
|
|
||||||
func cLastError(h uintptr) string {
|
|
||||||
return cppLastError(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// lastErrorOr is cLastError with a fallback for the empty-message case, so
|
|
||||||
// wrapped errors never end in ": ".
|
|
||||||
func lastErrorOr(h uintptr, fallback string) string {
|
|
||||||
if msg := cLastError(h); msg != "" {
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
// cTokenizeJSON tokenizes text (the C side prepends bos per vocab.add_bos)
|
|
||||||
// and returns the token ids as a JSON array string, e.g. "[2,18]".
|
|
||||||
func cTokenizeJSON(h uintptr, text string) (string, error) {
|
|
||||||
ret := cppTokenizeJSON(h, text)
|
|
||||||
if ret == 0 {
|
|
||||||
return "", fmt.Errorf("dllm: tokenize failed: %s", lastErrorOr(h, "unknown error"))
|
|
||||||
}
|
|
||||||
out := goStringFromCPtr(ret)
|
|
||||||
cppFreeString(ret)
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cGenerate runs a blocking generation and returns the detokenized text.
|
|
||||||
// optsJSON must be a FLAT JSON object of scalars (use buildOptsJSON); the C
|
|
||||||
// parser rejects nested objects/arrays. NULL return -> last_error (read only
|
|
||||||
// after the call returned, per the serialization contract); a cancelled call
|
|
||||||
// surfaces as the "cancelled" message.
|
|
||||||
func cGenerate(h uintptr, prompt, optsJSON string) (string, error) {
|
|
||||||
ret := cppGenerate(h, prompt, optsJSON)
|
|
||||||
if ret == 0 {
|
|
||||||
return "", fmt.Errorf("dllm: generate failed: %s", lastErrorOr(h, "unknown error"))
|
|
||||||
}
|
|
||||||
out := goStringFromCPtr(ret)
|
|
||||||
cppFreeString(ret)
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cGenerateMM is cGenerate's multimodal counterpart. imagesJSON is the flat
|
|
||||||
// JSON array of image entries (data: base64 URIs here; the C side also takes
|
|
||||||
// file paths) and the prompt must carry one mmImageMarker per entry - the
|
|
||||||
// engine enforces the 1:1 match and reports mismatches through last_error.
|
|
||||||
func cGenerateMM(h uintptr, prompt, imagesJSON, optsJSON string) (string, error) {
|
|
||||||
if !cMMSupported() {
|
|
||||||
return "", errMMUnsupported
|
|
||||||
}
|
|
||||||
ret := cppGenerateMM(h, prompt, imagesJSON, optsJSON)
|
|
||||||
if ret == 0 {
|
|
||||||
return "", fmt.Errorf("dllm: generate_mm failed: %s", lastErrorOr(h, "unknown error"))
|
|
||||||
}
|
|
||||||
out := goStringFromCPtr(ret)
|
|
||||||
cppFreeString(ret)
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// streamCallState carries the Go callbacks for one in-flight
|
|
||||||
// cGenerateStream call; the registry key travels through C as user_data.
|
|
||||||
// The map shape mirrors the whisper backend's streamCallStates: only one
|
|
||||||
// entry per ctx is ever live (the C-ABI is serialized per ctx), but keying
|
|
||||||
// by call survives multiple models/processes sharing the package.
|
|
||||||
type streamCallState struct {
|
|
||||||
onBlock func(text string)
|
|
||||||
onStep func(step, total int, preview string)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
streamCallStates sync.Map // uint64 -> *streamCallState
|
|
||||||
streamCallSeq atomic.Uint64
|
|
||||||
|
|
||||||
// purego.NewCallback allocates a finite, never-released callback slot, so
|
|
||||||
// the two trampolines are created exactly once and reused across calls.
|
|
||||||
streamCbOnce sync.Once
|
|
||||||
blockCbPtr uintptr
|
|
||||||
stepCbPtr uintptr
|
|
||||||
)
|
|
||||||
|
|
||||||
// onBlockTrampoline is the Go side of dllm_block_cb. It runs on the C
|
|
||||||
// calling thread, mid-generate: keep it tiny and non-blocking (callers that
|
|
||||||
// bridge to goroutines must hand off via buffered channels). The text
|
|
||||||
// pointer is only valid for the duration of the invocation, so it is copied
|
|
||||||
// to a Go string immediately.
|
|
||||||
func onBlockTrampoline(text uintptr, userData uintptr) {
|
|
||||||
v, ok := streamCallStates.Load(uint64(userData))
|
|
||||||
if !ok {
|
|
||||||
return // call already torn down
|
|
||||||
}
|
|
||||||
state := v.(*streamCallState)
|
|
||||||
if state.onBlock != nil {
|
|
||||||
state.onBlock(goStringFromCPtr(text))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// onStepTrampoline is the Go side of dllm_step_cb; same threading and
|
|
||||||
// lifetime caveats as onBlockTrampoline.
|
|
||||||
func onStepTrampoline(step int32, totalSteps int32, canvasPreview uintptr, userData uintptr) {
|
|
||||||
v, ok := streamCallStates.Load(uint64(userData))
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
state := v.(*streamCallState)
|
|
||||||
if state.onStep != nil {
|
|
||||||
state.onStep(int(step), int(totalSteps), goStringFromCPtr(canvasPreview))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// withStreamCallbacks registers onBlock/onStep in the trampoline registry
|
|
||||||
// for the duration of one streaming C call and invokes call with the C
|
|
||||||
// function pointers (NULL for absent callbacks, so the C side skips the
|
|
||||||
// per-block / per-step detokenize work entirely) plus the registry key to
|
|
||||||
// pass as user_data. Shared by the text and multimodal stream wrappers.
|
|
||||||
func withStreamCallbacks(onBlock func(text string), onStep func(step, total int, preview string), call func(blockPtr, stepPtr, userData uintptr) int32) int32 {
|
|
||||||
streamCbOnce.Do(func() {
|
|
||||||
blockCbPtr = purego.NewCallback(onBlockTrampoline)
|
|
||||||
stepCbPtr = purego.NewCallback(onStepTrampoline)
|
|
||||||
})
|
|
||||||
|
|
||||||
id := streamCallSeq.Add(1)
|
|
||||||
streamCallStates.Store(id, &streamCallState{onBlock: onBlock, onStep: onStep})
|
|
||||||
defer streamCallStates.Delete(id)
|
|
||||||
|
|
||||||
var blockPtr, stepPtr uintptr
|
|
||||||
if onBlock != nil {
|
|
||||||
blockPtr = blockCbPtr
|
|
||||||
}
|
|
||||||
if onStep != nil {
|
|
||||||
stepPtr = stepCbPtr
|
|
||||||
}
|
|
||||||
return call(blockPtr, stepPtr, uintptr(id))
|
|
||||||
}
|
|
||||||
|
|
||||||
// cGenerateStream runs a generation with per-committed-block (onBlock) and
|
|
||||||
// per-denoising-step (onStep) callbacks; either may be nil. The callbacks
|
|
||||||
// run on the C thread (see the trampoline docs). Returns an error carrying
|
|
||||||
// last_error on failure; cancellation surfaces as the "cancelled" message.
|
|
||||||
func cGenerateStream(h uintptr, prompt, optsJSON string, onBlock func(text string), onStep func(step, total int, preview string)) error {
|
|
||||||
rc := withStreamCallbacks(onBlock, onStep, func(blockPtr, stepPtr, userData uintptr) int32 {
|
|
||||||
return cppGenerateStream(h, prompt, optsJSON, blockPtr, stepPtr, userData)
|
|
||||||
})
|
|
||||||
if rc != 0 {
|
|
||||||
return fmt.Errorf("dllm: generate_stream failed: %s", lastErrorOr(h, "unknown error"))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cGenerateStreamMM is cGenerateStream's multimodal counterpart; see
|
|
||||||
// cGenerateMM for the imagesJSON/marker contract.
|
|
||||||
func cGenerateStreamMM(h uintptr, prompt, imagesJSON, optsJSON string, onBlock func(text string), onStep func(step, total int, preview string)) error {
|
|
||||||
if !cMMSupported() {
|
|
||||||
return errMMUnsupported
|
|
||||||
}
|
|
||||||
rc := withStreamCallbacks(onBlock, onStep, func(blockPtr, stepPtr, userData uintptr) int32 {
|
|
||||||
return cppGenerateStreamMM(h, prompt, imagesJSON, optsJSON, blockPtr, stepPtr, userData)
|
|
||||||
})
|
|
||||||
if rc != 0 {
|
|
||||||
return fmt.Errorf("dllm: generate_stream_mm failed: %s", lastErrorOr(h, "unknown error"))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cCancel requests cancellation of the in-flight generate on h. This is the
|
|
||||||
// ONE entry point safe to call from any goroutine while a generate runs (it
|
|
||||||
// only flips an atomic). Note the cancel-reset race from the header: each
|
|
||||||
// generate resets the flag on entry, so a watchdog should re-issue cancel if
|
|
||||||
// the call has not returned.
|
|
||||||
func cCancel(h uintptr) {
|
|
||||||
cppCancel(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildOptsJSON renders generation options as the flat JSON object the
|
|
||||||
// C-ABI expects (known keys: n_predict, blocks, seed, eb_*, kv_cache). The
|
|
||||||
// C-side scanner only understands scalar number/string values and rejects
|
|
||||||
// nested objects/arrays loudly; bools are rejected here too because the
|
|
||||||
// scanner has no concept of them. Fail loud rather than let an option be
|
|
||||||
// silently misread.
|
|
||||||
//
|
|
||||||
// CAVEAT: json.Marshal HTML-escapes <, > and & inside string values (e.g.
|
|
||||||
// "<" becomes the six-byte \u003c sequence). None of the known string-valued keys
|
|
||||||
// (kv_cache: auto|on|off) can contain those bytes today; if one ever does,
|
|
||||||
// switch to an Encoder with SetEscapeHTML(false) like gemma4JSONString.
|
|
||||||
func buildOptsJSON(opts map[string]any) (string, error) {
|
|
||||||
if len(opts) == 0 {
|
|
||||||
return "{}", nil
|
|
||||||
}
|
|
||||||
for k, v := range opts {
|
|
||||||
switch v.(type) {
|
|
||||||
case string,
|
|
||||||
int, int8, int16, int32, int64,
|
|
||||||
uint, uint8, uint16, uint32, uint64,
|
|
||||||
float32, float64,
|
|
||||||
json.Number:
|
|
||||||
// scalar: fine
|
|
||||||
default:
|
|
||||||
return "", fmt.Errorf("dllm: opts key %q has non-scalar value %T (the C-ABI only accepts flat number/string scalars)", k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
b, err := json.Marshal(opts)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("dllm: marshal opts: %w", err)
|
|
||||||
}
|
|
||||||
return string(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// goStringFromCPtr copies a NUL-terminated C string into Go memory. cptr is
|
|
||||||
// the raw pointer returned by purego from the C-ABI (a malloc'd buffer the
|
|
||||||
// caller owns, or a callback argument only valid during the invocation);
|
|
||||||
// owning callers must free it via cppFreeString after the copy lands.
|
|
||||||
//
|
|
||||||
// A direct unsafe.Pointer(cptr) conversion trips go vet's unsafeptr check,
|
|
||||||
// which can't distinguish a C-owned heap pointer from Go-managed memory (the
|
|
||||||
// parakeet-cpp and whisper backends tolerate that warning). Reinterpreting
|
|
||||||
// through &cptr below is equivalent at runtime and keeps plain `go vet`
|
|
||||||
// clean. It is safe either way: the pointer addresses C memory the Go GC
|
|
||||||
// neither tracks nor moves, and we dereference it immediately to copy the
|
|
||||||
// bytes out.
|
|
||||||
func goStringFromCPtr(cptr uintptr) string {
|
|
||||||
if cptr == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
p := *(*unsafe.Pointer)(unsafe.Pointer(&cptr)) // C-owned buffer, not Go-GC memory (see doc above)
|
|
||||||
n := 0
|
|
||||||
for *(*byte)(unsafe.Add(p, n)) != 0 {
|
|
||||||
n++
|
|
||||||
}
|
|
||||||
return string(unsafe.Slice((*byte)(p), n))
|
|
||||||
}
|
|
||||||
@@ -1,622 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
// LocalAI gRPC backend for dllm.cpp (DiffusionGemma block-diffusion models).
|
|
||||||
//
|
|
||||||
// Wiring overview:
|
|
||||||
// - Load opens the GGUF via dllm_capi_load and starts the per-model worker
|
|
||||||
// goroutine that serializes every C call (see submit).
|
|
||||||
// - PredictRich / PredictStreamRich implement grpc.AIModelRich: when the
|
|
||||||
// request carries raw messages (use_tokenizer_template), the backend owns
|
|
||||||
// templating (RenderGemma4) and output parsing (Gemma4Parser) and replies
|
|
||||||
// with ChatDeltas, like the llama.cpp autoparser and the ds4 backend.
|
|
||||||
// - The legacy Predict / PredictStream methods delegate to the rich pair
|
|
||||||
// (cloud-proxy precedent); the gRPC server prefers the rich path anyway.
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"unicode/utf8"
|
|
||||||
|
|
||||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
|
||||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
|
||||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
|
||||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
||||||
"github.com/mudler/xlog"
|
|
||||||
)
|
|
||||||
|
|
||||||
// The gRPC server cancels in-flight generations on client disconnect only
|
|
||||||
// for backends advertising the Cancellable capability; keep Dllm pinned to
|
|
||||||
// it so a signature drift fails the build, not the disconnect path.
|
|
||||||
var _ grpc.Cancellable = (*Dllm)(nil)
|
|
||||||
|
|
||||||
// generator is the seam between the backend wiring and the dllm.cpp C-ABI:
|
|
||||||
// the real implementation (capiGenerator) wraps the cGenerate/cTokenizeJSON
|
|
||||||
// family, while tests substitute a fake to exercise prompt construction,
|
|
||||||
// parsing and serialization without libdllm.so.
|
|
||||||
type generator interface {
|
|
||||||
generate(prompt, optsJSON string) (string, error)
|
|
||||||
// generateStream invokes onBlock once per committed diffusion block, on
|
|
||||||
// the thread running the C call, before returning.
|
|
||||||
generateStream(prompt, optsJSON string, onBlock func(text string)) error
|
|
||||||
// generateMM / generateStreamMM are the multimodal counterparts:
|
|
||||||
// imagesJSON is a flat JSON array of data: base64 URIs and the prompt
|
|
||||||
// carries one mmImageMarker per entry (dllm_capi.h placeholder
|
|
||||||
// contract). Against an old text-only libdllm.so they fail with
|
|
||||||
// errMMUnsupported.
|
|
||||||
generateMM(prompt, imagesJSON, optsJSON string) (string, error)
|
|
||||||
generateStreamMM(prompt, imagesJSON, optsJSON string, onBlock func(text string)) error
|
|
||||||
tokenizeJSON(text string) (string, error)
|
|
||||||
// cancel is the ONE entry point safe to call concurrently with an
|
|
||||||
// in-flight generate on the same ctx (dllm_capi.h: it only flips an
|
|
||||||
// atomic; everything else must be externally serialized per ctx).
|
|
||||||
cancel()
|
|
||||||
free()
|
|
||||||
}
|
|
||||||
|
|
||||||
// capiGenerator is the production generator over one dllm_ctx handle.
|
|
||||||
type capiGenerator struct {
|
|
||||||
h uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *capiGenerator) generate(prompt, optsJSON string) (string, error) {
|
|
||||||
return cGenerate(g.h, prompt, optsJSON)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *capiGenerator) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
|
|
||||||
// on_step (per-denoise-step canvas preview, dllm.cpp's --visual) is
|
|
||||||
// passed as nil for now: a future progress hook for the React UI can
|
|
||||||
// plumb it through without touching the C binding.
|
|
||||||
return cGenerateStream(g.h, prompt, optsJSON, onBlock, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *capiGenerator) generateMM(prompt, imagesJSON, optsJSON string) (string, error) {
|
|
||||||
return cGenerateMM(g.h, prompt, imagesJSON, optsJSON)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *capiGenerator) generateStreamMM(prompt, imagesJSON, optsJSON string, onBlock func(text string)) error {
|
|
||||||
// on_step is nil for the same reason as generateStream.
|
|
||||||
return cGenerateStreamMM(g.h, prompt, imagesJSON, optsJSON, onBlock, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *capiGenerator) tokenizeJSON(text string) (string, error) {
|
|
||||||
return cTokenizeJSON(g.h, text)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *capiGenerator) cancel() {
|
|
||||||
cCancel(g.h)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *capiGenerator) free() {
|
|
||||||
cFree(g.h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dllm is the gRPC backend instance: one per loaded model (LocalAI starts
|
|
||||||
// one backend process per model).
|
|
||||||
type Dllm struct {
|
|
||||||
base.Base
|
|
||||||
|
|
||||||
gen generator
|
|
||||||
// genOpts holds the model-level generation overrides parsed from
|
|
||||||
// ModelOptions.Options at Load (eb_*, blocks, kv_cache). The C-ABI takes
|
|
||||||
// them per-generate, not per-load, so they are merged into every
|
|
||||||
// request's opts JSON (requestOptsJSON).
|
|
||||||
genOpts map[string]any
|
|
||||||
|
|
||||||
// jobs is the per-model worker queue. dllm_capi.h requires every entry
|
|
||||||
// point EXCEPT dllm_capi_cancel to be externally serialized per ctx (one
|
|
||||||
// ctx = one concurrent generate/tokenize; last_error is unsafe to read
|
|
||||||
// while a call is in flight). A single goroutine owning all C calls makes
|
|
||||||
// that contract structural instead of relying on lock discipline.
|
|
||||||
jobs chan func()
|
|
||||||
workerWG sync.WaitGroup
|
|
||||||
|
|
||||||
// genMu guards gen against Free racing in-flight requests: requests hold
|
|
||||||
// the read lock for their full duration (they stay concurrent with each
|
|
||||||
// other - the worker still serializes the C calls), Free takes the write
|
|
||||||
// lock so it can only run when no request is in flight.
|
|
||||||
genMu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Dllm) startWorker() {
|
|
||||||
d.jobs = make(chan func())
|
|
||||||
d.workerWG.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer d.workerWG.Done()
|
|
||||||
for job := range d.jobs {
|
|
||||||
job()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// submit runs job on the worker goroutine and waits for it to finish.
|
|
||||||
// Concurrent gRPC requests therefore queue up and execute one at a time
|
|
||||||
// against the single dllm_ctx.
|
|
||||||
func (d *Dllm) submit(job func()) {
|
|
||||||
done := make(chan struct{})
|
|
||||||
d.jobs <- func() {
|
|
||||||
defer close(done)
|
|
||||||
job()
|
|
||||||
}
|
|
||||||
<-done
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load opens the GGUF and prepares the worker. Load-time engine parameters
|
|
||||||
// travel as the flat params JSON of dllm_capi_load; generation overrides
|
|
||||||
// from Options are stored for per-request opts JSON instead (the C-ABI has
|
|
||||||
// no per-load sampler state).
|
|
||||||
func (d *Dllm) Load(opts *pb.ModelOptions) error {
|
|
||||||
if d.gen != nil {
|
|
||||||
return errors.New("dllm: model already loaded")
|
|
||||||
}
|
|
||||||
|
|
||||||
params := map[string]any{
|
|
||||||
"n_gpu_layers": opts.GetNGPULayers(),
|
|
||||||
}
|
|
||||||
if opts.GetThreads() > 0 {
|
|
||||||
params["n_threads"] = opts.GetThreads()
|
|
||||||
}
|
|
||||||
if opts.GetContextSize() > 0 {
|
|
||||||
params["ctx_len"] = opts.GetContextSize()
|
|
||||||
}
|
|
||||||
paramsJSON, err := buildOptsJSON(params)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
d.genOpts = parseModelGenOpts(opts.GetOptions())
|
|
||||||
|
|
||||||
h := cLoad(opts.GetModelFile(), paramsJSON)
|
|
||||||
if h == 0 {
|
|
||||||
// No ctx exists on load failure, so last_error(NULL) only carries the
|
|
||||||
// static NULL-ctx message; the real reason is on the backend's stderr.
|
|
||||||
return fmt.Errorf("dllm: load %q failed: %s (see backend log for details)",
|
|
||||||
opts.GetModelFile(), lastErrorOr(0, "unknown error"))
|
|
||||||
}
|
|
||||||
d.gen = &capiGenerator{h: h}
|
|
||||||
d.startWorker()
|
|
||||||
xlog.Info("dllm: model loaded", "model", opts.GetModelFile(), "params", paramsJSON, "gen_opts", d.genOpts)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Free releases the dllm ctx and stops the worker. Safe when never loaded.
|
|
||||||
//
|
|
||||||
// The write lock is essential: the gRPC server (pkg/grpc/server.go, see the
|
|
||||||
// model-unload path around line 764) calls Free with no locking of its own,
|
|
||||||
// and base.Base provides none either. Without it a request racing Free would
|
|
||||||
// panic sending on the closed jobs channel - or worse, generate on a freed C
|
|
||||||
// ctx. Holding genMu until gen is nil also turns post-Free requests into a
|
|
||||||
// clean "model not loaded" error instead of a crash.
|
|
||||||
func (d *Dllm) Free() error {
|
|
||||||
d.genMu.Lock()
|
|
||||||
defer d.genMu.Unlock()
|
|
||||||
if d.gen == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
d.submit(d.gen.free)
|
|
||||||
close(d.jobs)
|
|
||||||
d.workerWG.Wait()
|
|
||||||
d.gen = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cancel requests cancellation of the in-flight generate (the
|
|
||||||
// grpc.Cancellable capability). The gRPC server arms it via
|
|
||||||
// context.AfterFunc on the request/stream context, so a client
|
|
||||||
// disconnect or timeout aborts the generation server-side - the same
|
|
||||||
// semantics the llama.cpp C++ backend gets from polling IsCancelled().
|
|
||||||
// It deliberately bypasses the worker queue: dllm_capi_cancel is the one
|
|
||||||
// call the C-ABI allows from any goroutine mid-generate (it only flips
|
|
||||||
// an atomic).
|
|
||||||
//
|
|
||||||
// Note dllm_capi.h's cancel-reset race: each generate resets the flag on
|
|
||||||
// entry, so a Cancel racing a NEW generate on the same ctx can be lost
|
|
||||||
// (and, with requests queued on the worker, it aborts whichever generate
|
|
||||||
// is currently running). The single-flag granularity is acceptable here
|
|
||||||
// because the server de-registers the hook on normal completion and one
|
|
||||||
// backend process serves one model.
|
|
||||||
func (d *Dllm) Cancel() {
|
|
||||||
// RLock so a server-side AfterFunc firing in the window between a
|
|
||||||
// request finishing and a model unload cannot touch a freed C ctx
|
|
||||||
// (Free holds the write lock while tearing gen down). cancel() is the
|
|
||||||
// one C call that is safe concurrently with an in-flight generate, so
|
|
||||||
// taking a read lock here cannot deadlock against request holders.
|
|
||||||
d.genMu.RLock()
|
|
||||||
defer d.genMu.RUnlock()
|
|
||||||
if d.gen != nil {
|
|
||||||
d.gen.cancel()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// dllmGenOptKeys are the ModelOptions.Options keys this backend forwards to
|
|
||||||
// the engine. Options is a shared free-form bag (other layers put their own
|
|
||||||
// entries there), so unknown keys are skipped with a warning, not an error.
|
|
||||||
var dllmGenOptKeys = map[string]bool{
|
|
||||||
"blocks": true,
|
|
||||||
"kv_cache": true, // "auto"|"on"|"off"; honored by the engine from P3
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseModelGenOpts parses "key:value" Options entries into the flat scalar
|
|
||||||
// map merged into every generate's opts JSON. eb_* (Entropy-Bound sampler
|
|
||||||
// knobs) and the keys in dllmGenOptKeys are recognized; values are typed by
|
|
||||||
// first successful parse (int, then float, else string) to match the C
|
|
||||||
// scanner's number/string scalars.
|
|
||||||
func parseModelGenOpts(options []string) map[string]any {
|
|
||||||
out := map[string]any{}
|
|
||||||
for _, o := range options {
|
|
||||||
key, val, found := strings.Cut(o, ":")
|
|
||||||
if !found {
|
|
||||||
xlog.Warn("dllm: ignoring malformed option (want key:value)", "option", o)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(key, "eb_") && !dllmGenOptKeys[key] {
|
|
||||||
xlog.Debug("dllm: ignoring unrecognized option", "key", key)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
out[key] = parseScalarOpt(val)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseScalarOpt(v string) any {
|
|
||||||
if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
|
|
||||||
return iv
|
|
||||||
}
|
|
||||||
if fv, err := strconv.ParseFloat(v, 64); err == nil {
|
|
||||||
return fv
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
// metadataEnableThinking reads the enable_thinking gate. Unlike ds4 (default
|
|
||||||
// ON, matching ds4-server), dllm defaults OFF: DiffusionGemma's chat
|
|
||||||
// template guards every thinking branch with `enable_thinking is defined and
|
|
||||||
// enable_thinking`, i.e. thinking is opt-in for this model family, and the
|
|
||||||
// no-thinking render pre-closes an empty thought channel that the OFF
|
|
||||||
// default must produce.
|
|
||||||
func metadataEnableThinking(opts *pb.PredictOptions) bool {
|
|
||||||
v := opts.GetMetadata()["enable_thinking"]
|
|
||||||
return v == "true" || v == "1"
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildPrompt resolves the prompt for a request. With use_tokenizer_template
|
|
||||||
// and raw messages the backend owns templating (RenderGemma4, including the
|
|
||||||
// mmImageMarker injection for opts.Images) and the output is in the known
|
|
||||||
// gemma4 format, so parse=true. Without it the caller templated the prompt
|
|
||||||
// themselves (LocalAI's Go templates + PEG fallback, or a bare completion):
|
|
||||||
// the prompt passes through verbatim - for image requests it must already
|
|
||||||
// carry one literal mmImageMarker per image (the engine enforces the 1:1
|
|
||||||
// match) - and the output is NOT gemma4-parsed - it is emitted as plain
|
|
||||||
// content and the Go side's extraction applies, as for any non-autoparsing
|
|
||||||
// backend.
|
|
||||||
func buildPrompt(opts *pb.PredictOptions) (prompt string, parse bool, err error) {
|
|
||||||
if opts.GetUseTokenizerTemplate() && len(opts.GetMessages()) > 0 {
|
|
||||||
prompt, err = RenderGemma4(opts.GetMessages(), opts.GetTools(), len(opts.GetImages()), metadataEnableThinking(opts), true)
|
|
||||||
return prompt, true, err
|
|
||||||
}
|
|
||||||
return opts.GetPrompt(), false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// imagesJSON renders opts.Images as the flat JSON array of data: URIs the mm
|
|
||||||
// C-ABI expects, or "" when the request carries no images. The entries arrive
|
|
||||||
// as RAW base64 payloads: LocalAI's OpenAI layer decodes every image_url /
|
|
||||||
// image content part (URL download or data: URI) to plain base64 via
|
|
||||||
// utils.GetContentURIAsBase64 (core/http/middleware/request.go) and core
|
|
||||||
// flattens them into PredictOptions.Images (core/backend/llm.go). The
|
|
||||||
// hardcoded image/jpeg mime mirrors the llama.cpp backend's re-wrapping
|
|
||||||
// convention (grpc-server.cpp, "data:image/jpeg;base64," + images(i)); the
|
|
||||||
// engine ignores the declared mime and sniffs the real format from the
|
|
||||||
// decoded bytes (stb_image), so PNG/BMP payloads work through it too.
|
|
||||||
func imagesJSON(images []string) (string, error) {
|
|
||||||
if len(images) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
uris := make([]string, len(images))
|
|
||||||
for i, img := range images {
|
|
||||||
// dllm_capi.h: array entries are read VERBATIM up to the closing
|
|
||||||
// quote, with NO escape handling. json.Marshal would escape these
|
|
||||||
// bytes and the C side would misparse the entry, so fail loud (they
|
|
||||||
// can never appear in genuine base64 anyway).
|
|
||||||
if strings.ContainsAny(img, "\"\\") {
|
|
||||||
return "", fmt.Errorf("dllm: image %d is not base64 (contains a quote or backslash; PredictOptions.Images entries must be raw base64 payloads)", i)
|
|
||||||
}
|
|
||||||
uris[i] = "data:image/jpeg;base64," + img
|
|
||||||
}
|
|
||||||
b, err := json.Marshal(uris)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("dllm: marshal images: %w", err)
|
|
||||||
}
|
|
||||||
return string(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// requestOptsJSON merges the model-level overrides with the request's
|
|
||||||
// sampling fields into the flat opts JSON for one generate call.
|
|
||||||
func (d *Dllm) requestOptsJSON(opts *pb.PredictOptions) (string, error) {
|
|
||||||
m := make(map[string]any, len(d.genOpts)+2)
|
|
||||||
for k, v := range d.genOpts {
|
|
||||||
m[k] = v
|
|
||||||
}
|
|
||||||
if n := opts.GetTokens(); n > 0 {
|
|
||||||
// The engine rounds n_predict UP to a whole number of diffusion
|
|
||||||
// blocks (the canvas is denoised block-wise), so the completion may
|
|
||||||
// run slightly past the requested budget. Tokens==0 omits the key so
|
|
||||||
// the C-ABI default of 256 applies (hardcoded in capi.cpp's
|
|
||||||
// parse_gen_opts, independent of canvas_length).
|
|
||||||
m["n_predict"] = n
|
|
||||||
}
|
|
||||||
if s := opts.GetSeed(); s > 0 {
|
|
||||||
// The engine seeds mt19937 with explicit non-negative seeds. Seed<=0
|
|
||||||
// is omitted: proto3 cannot distinguish 0 from unset, and negative
|
|
||||||
// values conventionally mean "random" across LocalAI backends.
|
|
||||||
m["seed"] = s
|
|
||||||
}
|
|
||||||
return buildOptsJSON(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepareRequest is the shared prologue of the rich methods: resolve the
|
|
||||||
// prompt (and whether the output gets gemma4-parsed) and build the per-call
|
|
||||||
// opts JSON plus the images JSON ("" for text-only requests, which routes
|
|
||||||
// the call through the text generate entry points).
|
|
||||||
func (d *Dllm) prepareRequest(opts *pb.PredictOptions) (prompt string, parse bool, optsJSON, imgJSON string, err error) {
|
|
||||||
// Fail loud on media the engine has no path for, instead of silently
|
|
||||||
// generating from a prompt that ignores them.
|
|
||||||
if len(opts.GetVideos()) > 0 || len(opts.GetAudios()) > 0 {
|
|
||||||
return "", false, "", "", errors.New("dllm: video/audio input is not supported (images only)")
|
|
||||||
}
|
|
||||||
prompt, parse, err = buildPrompt(opts)
|
|
||||||
if err != nil {
|
|
||||||
return "", false, "", "", err
|
|
||||||
}
|
|
||||||
optsJSON, err = d.requestOptsJSON(opts)
|
|
||||||
if err != nil {
|
|
||||||
return "", false, "", "", err
|
|
||||||
}
|
|
||||||
imgJSON, err = imagesJSON(opts.GetImages())
|
|
||||||
if err != nil {
|
|
||||||
return "", false, "", "", err
|
|
||||||
}
|
|
||||||
return prompt, parse, optsJSON, imgJSON, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sanitizeUTF8 makes s safe for a proto3 string field. Block-boundary
|
|
||||||
// detokenization and byte-fallback tokens can produce invalid UTF-8, and
|
|
||||||
// grpc-go refuses to marshal it ("string field contains invalid UTF-8"), so
|
|
||||||
// every string destined for a Reply/ChatDelta must pass through here (or
|
|
||||||
// through splitValidUTF8, which calls it). Lone malformed bytes are genuinely
|
|
||||||
// undecodable: replace with U+FFFD rather than crash the stream.
|
|
||||||
func sanitizeUTF8(s string) string {
|
|
||||||
if utf8.ValidString(s) {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
return strings.ToValidUTF8(s, "<22>")
|
|
||||||
}
|
|
||||||
|
|
||||||
// utf8SeqLen returns the declared sequence length of a UTF-8 leading byte
|
|
||||||
// (1 for bytes that can never lead a multi-byte sequence, so they are never
|
|
||||||
// held back and fall through to sanitizeUTF8's replacement).
|
|
||||||
func utf8SeqLen(b byte) int {
|
|
||||||
switch {
|
|
||||||
case b&0xE0 == 0xC0:
|
|
||||||
return 2
|
|
||||||
case b&0xF0 == 0xE0:
|
|
||||||
return 3
|
|
||||||
case b&0xF8 == 0xF0:
|
|
||||||
return 4
|
|
||||||
default:
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// splitValidUTF8 prepends the previous block's carry to the new block and
|
|
||||||
// splits the result into text safe to emit now and a trailing INCOMPLETE
|
|
||||||
// UTF-8 sequence (at most utf8.UTFMax-1 bytes) to carry into the next block:
|
|
||||||
// the per-block detokenize can split a multi-byte character across block
|
|
||||||
// boundaries (llama.cpp's grpc-server holds back the same way). Only a
|
|
||||||
// suffix that can still become a valid rune is withheld; bytes that are
|
|
||||||
// already undecodable are replaced immediately so the carry stays bounded.
|
|
||||||
func splitValidUTF8(carry, block string) (emit, newCarry string) {
|
|
||||||
s := carry + block
|
|
||||||
cut := len(s)
|
|
||||||
for i := len(s) - 1; i >= 0 && len(s)-i < utf8.UTFMax; i-- {
|
|
||||||
b := s[i]
|
|
||||||
if b < utf8.RuneSelf {
|
|
||||||
break // ASCII: everything before the tail scan is complete
|
|
||||||
}
|
|
||||||
if !utf8.RuneStart(b) {
|
|
||||||
continue // continuation byte: keep looking for its leading byte
|
|
||||||
}
|
|
||||||
// Leading byte: hold the sequence back iff it declares more bytes
|
|
||||||
// than the stream has produced so far (it may complete next block).
|
|
||||||
if utf8SeqLen(b) > len(s)-i {
|
|
||||||
cut = i
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return sanitizeUTF8(s[:cut]), s[cut:]
|
|
||||||
}
|
|
||||||
|
|
||||||
// PredictRich is the non-streaming inference path (grpc.AIModelRich).
|
|
||||||
// Returns one Reply whose Message is the aggregated assistant content and
|
|
||||||
// whose ChatDeltas carry the parsed content/reasoning/tool-call events.
|
|
||||||
func (d *Dllm) PredictRich(opts *pb.PredictOptions) (*pb.Reply, error) {
|
|
||||||
d.genMu.RLock()
|
|
||||||
defer d.genMu.RUnlock()
|
|
||||||
if d.gen == nil {
|
|
||||||
return nil, grpcerrors.ModelNotLoaded("dllm")
|
|
||||||
}
|
|
||||||
prompt, parse, optsJSON, imgJSON, err := d.prepareRequest(opts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var out string
|
|
||||||
var genErr error
|
|
||||||
d.submit(func() {
|
|
||||||
if imgJSON != "" {
|
|
||||||
out, genErr = d.gen.generateMM(prompt, imgJSON, optsJSON)
|
|
||||||
} else {
|
|
||||||
out, genErr = d.gen.generate(prompt, optsJSON)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
if genErr != nil {
|
|
||||||
return nil, genErr
|
|
||||||
}
|
|
||||||
// Byte-fallback tokens can detokenize to invalid UTF-8; proto3 strings
|
|
||||||
// must be valid or grpc-go fails the whole reply at marshal time.
|
|
||||||
out = sanitizeUTF8(out)
|
|
||||||
|
|
||||||
if !parse {
|
|
||||||
// Raw-prompt mode: plain content, no gemma4 parsing (see buildPrompt).
|
|
||||||
return &pb.Reply{Message: []byte(out), ChatDeltas: []*pb.ChatDelta{{Content: out}}}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// The prompt renders with add_generation_prompt; both thinking modes
|
|
||||||
// leave the model starting in content state (see the Gemma4Parser header
|
|
||||||
// comment), hence NewGemma4Parser(false).
|
|
||||||
parser := NewGemma4Parser(false)
|
|
||||||
if reply := replyFromDeltas(append(parser.Feed(out), parser.Close()...)); reply != nil {
|
|
||||||
return reply, nil
|
|
||||||
}
|
|
||||||
// Everything was markers (or out was empty): an empty but non-nil Reply.
|
|
||||||
return &pb.Reply{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PredictStreamRich is the streaming counterpart (grpc.AIModelRich): one
|
|
||||||
// Reply per committed diffusion block that produced deltas. Per the
|
|
||||||
// interface contract the channel is only sent into here - the gRPC server
|
|
||||||
// closes it after this returns (opposite to legacy PredictStream).
|
|
||||||
func (d *Dllm) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
|
||||||
d.genMu.RLock()
|
|
||||||
defer d.genMu.RUnlock()
|
|
||||||
if d.gen == nil {
|
|
||||||
return grpcerrors.ModelNotLoaded("dllm")
|
|
||||||
}
|
|
||||||
prompt, parse, optsJSON, imgJSON, err := d.prepareRequest(opts)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var parser *Gemma4Parser
|
|
||||||
if parse {
|
|
||||||
parser = NewGemma4Parser(false)
|
|
||||||
}
|
|
||||||
// emit runs inside onBlock, i.e. on the thread driving the C generate.
|
|
||||||
// Sending on results can block on a slow consumer, but the server-side
|
|
||||||
// pump (pkg/grpc/server.go PredictStream) drains continuously and drops
|
|
||||||
// undeliverable sends, so this backpressure is brief and bounded - and
|
|
||||||
// pausing the diffusion loop under it is the desired behavior anyway.
|
|
||||||
emit := func(text string) {
|
|
||||||
if !parse {
|
|
||||||
if text != "" {
|
|
||||||
results <- &pb.Reply{Message: []byte(text), ChatDeltas: []*pb.ChatDelta{{Content: text}}}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
deltas := parser.Feed(text)
|
|
||||||
if reply := replyFromDeltas(deltas); reply != nil {
|
|
||||||
results <- reply
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// onBlock guards emit (and through it the parser) against invalid UTF-8:
|
|
||||||
// a multi-byte character split across block boundaries is held back until
|
|
||||||
// it completes (see splitValidUTF8), so proto3 marshaling never fails.
|
|
||||||
var carry string
|
|
||||||
onBlock := func(block string) {
|
|
||||||
var text string
|
|
||||||
text, carry = splitValidUTF8(carry, block)
|
|
||||||
emit(text)
|
|
||||||
}
|
|
||||||
|
|
||||||
var genErr error
|
|
||||||
d.submit(func() {
|
|
||||||
if imgJSON != "" {
|
|
||||||
genErr = d.gen.generateStreamMM(prompt, imgJSON, optsJSON, onBlock)
|
|
||||||
} else {
|
|
||||||
genErr = d.gen.generateStream(prompt, optsJSON, onBlock)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
if genErr != nil {
|
|
||||||
return genErr
|
|
||||||
}
|
|
||||||
if carry != "" {
|
|
||||||
// The stream ended mid-sequence: the held-back bytes can no longer
|
|
||||||
// complete, so flush them through the U+FFFD last resort.
|
|
||||||
emit(sanitizeUTF8(carry))
|
|
||||||
}
|
|
||||||
if parse {
|
|
||||||
if reply := replyFromDeltas(parser.Close()); reply != nil {
|
|
||||||
results <- reply
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// replyFromDeltas wraps one batch of parsed deltas into a streaming Reply,
|
|
||||||
// or nil when the batch is empty (markers consumed, nothing emitted yet).
|
|
||||||
// Message mirrors the batch's content text so legacy chan-string consumers
|
|
||||||
// see exactly the displayed tokens.
|
|
||||||
func replyFromDeltas(deltas []*pb.ChatDelta) *pb.Reply {
|
|
||||||
if len(deltas) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var content strings.Builder
|
|
||||||
for _, delta := range deltas {
|
|
||||||
content.WriteString(delta.GetContent())
|
|
||||||
}
|
|
||||||
return &pb.Reply{Message: []byte(content.String()), ChatDeltas: deltas}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Predict is the legacy (string, error) signature; the gRPC server prefers
|
|
||||||
// PredictRich, this exists for non-rich callers (cloud-proxy precedent).
|
|
||||||
func (d *Dllm) Predict(opts *pb.PredictOptions) (string, error) {
|
|
||||||
reply, err := d.PredictRich(opts)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return string(reply.GetMessage()), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PredictStream is the legacy chan-string path: rich replies reduced to
|
|
||||||
// their content text. Note the inverted channel ownership - the LEGACY
|
|
||||||
// contract requires the impl to close the channel.
|
|
||||||
func (d *Dllm) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
|
||||||
defer close(results)
|
|
||||||
richCh := make(chan *pb.Reply)
|
|
||||||
errCh := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
errCh <- d.PredictStreamRich(opts, richCh)
|
|
||||||
close(richCh)
|
|
||||||
}()
|
|
||||||
for reply := range richCh {
|
|
||||||
if msg := reply.GetMessage(); len(msg) > 0 {
|
|
||||||
results <- string(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return <-errCh
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenizeString tokenizes opts.Prompt via dllm_capi_tokenize_json (the C
|
|
||||||
// side prepends bos per the vocab) and decodes the returned id array.
|
|
||||||
func (d *Dllm) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
|
|
||||||
d.genMu.RLock()
|
|
||||||
defer d.genMu.RUnlock()
|
|
||||||
if d.gen == nil {
|
|
||||||
return pb.TokenizationResponse{}, grpcerrors.ModelNotLoaded("dllm")
|
|
||||||
}
|
|
||||||
var out string
|
|
||||||
var tokErr error
|
|
||||||
d.submit(func() {
|
|
||||||
out, tokErr = d.gen.tokenizeJSON(opts.GetPrompt())
|
|
||||||
})
|
|
||||||
if tokErr != nil {
|
|
||||||
return pb.TokenizationResponse{}, tokErr
|
|
||||||
}
|
|
||||||
var tokens []int32
|
|
||||||
if err := json.Unmarshal([]byte(out), &tokens); err != nil {
|
|
||||||
return pb.TokenizationResponse{}, fmt.Errorf("dllm: decode tokenize result %q: %w", out, err)
|
|
||||||
}
|
|
||||||
return pb.TokenizationResponse{Length: int32(len(tokens)), Tokens: tokens}, nil
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,562 +0,0 @@
|
|||||||
// Gemma4 (DiffusionGemma) streaming output parser: raw model text, fed in
|
|
||||||
// arbitrary fragments (per committed diffusion block; a fragment can split
|
|
||||||
// anywhere, including mid-marker and mid-payload), is turned into
|
|
||||||
// pb.ChatDelta events (content / reasoning_content / tool_calls).
|
|
||||||
//
|
|
||||||
// Normative sources:
|
|
||||||
// - The chat template embedded at the top of gemma4_renderer.go ("tpl L<n>"
|
|
||||||
// citations below refer to its numbered lines). The OUTPUT format mirrors
|
|
||||||
// what the template renders for assistant history: thought channels
|
|
||||||
// (<|channel>thought\n ... <channel|>, tpl L240), tool calls
|
|
||||||
// (<|tool_call>call:name{...}<tool_call|>, tpl L246-L257) and turn ends
|
|
||||||
// (<turn|>, tpl L351).
|
|
||||||
// - vLLM PR #45163: vllm/tool_parsers/gemma4_tool_parser.py (marker
|
|
||||||
// handling, the call:name{...} argument grammar and its decoder, ported
|
|
||||||
// below) and vllm/reasoning/gemma4_reasoning_parser.py (channel markers,
|
|
||||||
// the "thought\n" role label, is_reasoning_end semantics).
|
|
||||||
//
|
|
||||||
// Initial state (derived from the generation prompt, tpl L356-L362, see
|
|
||||||
// RenderGemma4):
|
|
||||||
// - enable_thinking=false: the prompt ends with "<|turn>model\n" +
|
|
||||||
// "<|channel>thought\n<channel|>" - an EMPTY thought channel, pre-opened
|
|
||||||
// AND pre-closed by the template. The model's output therefore starts in
|
|
||||||
// plain content. Use NewGemma4Parser(false).
|
|
||||||
// - enable_thinking=true: the prompt ends at "<|turn>model\n" and the model
|
|
||||||
// opens and closes its own thought channel in the OUTPUT
|
|
||||||
// ("<|channel>thought\n...reasoning...<channel|>final answer", per the
|
|
||||||
// vLLM Gemma4ReasoningParser docstring). The parser still starts in
|
|
||||||
// content state - the channel markers in the output drive the switch.
|
|
||||||
// Use NewGemma4Parser(false) here too.
|
|
||||||
// - NewGemma4Parser(true) is for callers that pre-open the thought channel
|
|
||||||
// in the prompt themselves (appending "<|channel>thought\n" after the
|
|
||||||
// generation prompt to force thinking): the output then begins mid-thought
|
|
||||||
// and everything is reasoning until the first <channel|>.
|
|
||||||
//
|
|
||||||
// State diagram (markers are consumed, never emitted):
|
|
||||||
//
|
|
||||||
// <|channel> \n (channel name dropped: the
|
|
||||||
// [content] --------------> [chan-header] ----> [thought] "thought\n" role
|
|
||||||
// ^ | <channel|> (stray close: swallowed, label, stripped
|
|
||||||
// +-+ strip_thinking semantics, tpl L148-L158) like vLLM does)
|
|
||||||
// ^ <channel|>
|
|
||||||
// +----------------------------------------- [thought]
|
|
||||||
// ^ <tool_call|> | <|tool_call> (implicit
|
|
||||||
// +-------------- [tool-call] <-------------------+ reasoning end, vLLM
|
|
||||||
// | <|tool_call> ^ is_reasoning_end)
|
|
||||||
// +-------------------+
|
|
||||||
// [content]/[thought] --- <turn|> ---> [done] (everything after is dropped)
|
|
||||||
//
|
|
||||||
// Buffering rules:
|
|
||||||
// - content/thought states hold back at most len(longest marker)-1 bytes:
|
|
||||||
// the longest tail that is still a proper prefix of a watched marker.
|
|
||||||
// Content is otherwise emitted immediately (no unbounded buffering).
|
|
||||||
// - the tool-call state buffers the whole payload until <tool_call|>. This
|
|
||||||
// is unbounded in principle but bounded in practice by the model's
|
|
||||||
// diffusion canvas, and is required because the call:name{...} payload
|
|
||||||
// only becomes decodable (and trustworthy) once complete - the same
|
|
||||||
// reason vLLM's parser accumulates before parsing.
|
|
||||||
// - Close() flushes whatever is still held: partial markers come out as
|
|
||||||
// content/reasoning (per the state that held them); an unterminated
|
|
||||||
// channel header or tool-call payload is re-emitted RAW (including its
|
|
||||||
// opening marker) as content - malformed output is never silently
|
|
||||||
// dropped (mirrors vLLM extract_tool_calls returning the raw text as
|
|
||||||
// content when its regex does not match).
|
|
||||||
//
|
|
||||||
// Streaming granularity DIVERGENCE from vLLM: vLLM re-parses the partial
|
|
||||||
// payload on every token and streams argument-JSON diffs (its `partial=True`
|
|
||||||
// decoder mode plus withholding logic exist only for that). Our fragments are
|
|
||||||
// whole committed diffusion blocks, so each completed tool call is emitted
|
|
||||||
// once, as a single ToolCallDelta carrying index + id + name + the full
|
|
||||||
// arguments JSON - exactly the shape backend/python/vllm/backend.py emits
|
|
||||||
// per call and pkg/functions.ToolCallsFromChatDeltas re-accumulates.
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// gemma4CallRE is vLLM's tool_call_regex
|
|
||||||
// (`<\|tool_call>call:([\w\-\.]+)\{(.*?)\}<tool_call\|>`, DOTALL) anchored to
|
|
||||||
// a single already-extracted payload: name charset [\w\-.], braces mandatory.
|
|
||||||
var gemma4CallRE = regexp.MustCompile(`(?s)^call:([\w\-.]+)\{(.*)\}$`)
|
|
||||||
|
|
||||||
type g4State int
|
|
||||||
|
|
||||||
const (
|
|
||||||
g4Content g4State = iota
|
|
||||||
g4ChanHeader
|
|
||||||
g4Thought
|
|
||||||
g4ToolCall
|
|
||||||
g4Done
|
|
||||||
)
|
|
||||||
|
|
||||||
// Markers watched per emitting state. A stray <tool_call|> outside a tool
|
|
||||||
// call is deliberately NOT watched: it passes through verbatim, consistent
|
|
||||||
// with the malformed-payload fallback re-emitting it as content.
|
|
||||||
var (
|
|
||||||
gemma4ContentMarkers = []string{gemma4ChannelOpen, gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd}
|
|
||||||
gemma4ThoughtMarkers = []string{gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd}
|
|
||||||
)
|
|
||||||
|
|
||||||
type Gemma4Parser struct {
|
|
||||||
state g4State
|
|
||||||
// held is the per-state carry-over between Feed calls: a partial marker
|
|
||||||
// (content/thought), a partial channel header (chan-header) or the
|
|
||||||
// payload accumulated so far (tool-call).
|
|
||||||
held string
|
|
||||||
toolIdx int
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewGemma4Parser returns a parser positioned per the initial-state rules in
|
|
||||||
// the header comment: startInThought=true only when the caller pre-opened a
|
|
||||||
// thought channel in the prompt.
|
|
||||||
func NewGemma4Parser(startInThought bool) *Gemma4Parser {
|
|
||||||
state := g4Content
|
|
||||||
if startInThought {
|
|
||||||
state = g4Thought
|
|
||||||
}
|
|
||||||
return &Gemma4Parser{state: state}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Feed consumes the next output fragment and returns the deltas it completes.
|
|
||||||
func (p *Gemma4Parser) Feed(text string) []*pb.ChatDelta {
|
|
||||||
if text == "" || p.state == g4Done {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
pending := p.held + text
|
|
||||||
p.held = ""
|
|
||||||
var em g4Emitter
|
|
||||||
for pending != "" {
|
|
||||||
switch p.state {
|
|
||||||
case g4Content, g4Thought:
|
|
||||||
markers := gemma4ContentMarkers
|
|
||||||
if p.state == g4Thought {
|
|
||||||
markers = gemma4ThoughtMarkers
|
|
||||||
}
|
|
||||||
idx, marker := findEarliestGemma4Marker(pending, markers)
|
|
||||||
if idx == -1 {
|
|
||||||
hold := gemma4MarkerHoldback(pending, markers)
|
|
||||||
p.emitText(&em, pending[:len(pending)-hold])
|
|
||||||
p.held = pending[len(pending)-hold:]
|
|
||||||
pending = ""
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
p.emitText(&em, pending[:idx])
|
|
||||||
pending = pending[idx+len(marker):]
|
|
||||||
switch marker {
|
|
||||||
case gemma4ChannelOpen:
|
|
||||||
p.state = g4ChanHeader
|
|
||||||
case gemma4ChannelClose:
|
|
||||||
// In thought: channel ends. In content: stray close,
|
|
||||||
// swallowed (strip_thinking keeps both sides, tpl L148-L158).
|
|
||||||
p.state = g4Content
|
|
||||||
case gemma4ToolCallOpen:
|
|
||||||
p.state = g4ToolCall
|
|
||||||
case gemma4TurnEnd:
|
|
||||||
p.state = g4Done
|
|
||||||
}
|
|
||||||
case g4ChanHeader:
|
|
||||||
// The channel header is "<name>\n"; the template only ever writes
|
|
||||||
// "thought" (tpl L240/L360) and the label is structural, so it is
|
|
||||||
// dropped, not emitted (vLLM strips the same "thought\n" prefix).
|
|
||||||
nl := strings.IndexByte(pending, '\n')
|
|
||||||
if nl == -1 {
|
|
||||||
p.held = pending
|
|
||||||
pending = ""
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
pending = pending[nl+1:]
|
|
||||||
p.state = g4Thought
|
|
||||||
case g4ToolCall:
|
|
||||||
end := strings.Index(pending, gemma4ToolCallClose)
|
|
||||||
if end == -1 {
|
|
||||||
p.held = pending
|
|
||||||
pending = ""
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
p.emitToolCall(&em, pending[:end])
|
|
||||||
pending = pending[end+len(gemma4ToolCallClose):]
|
|
||||||
p.state = g4Content
|
|
||||||
case g4Done:
|
|
||||||
pending = ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return em.deltas
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close flushes held-back partials. Incomplete structures (open channel
|
|
||||||
// header, unterminated tool payload) are re-emitted raw as content rather
|
|
||||||
// than dropped. The parser is finished afterwards.
|
|
||||||
func (p *Gemma4Parser) Close() []*pb.ChatDelta {
|
|
||||||
var em g4Emitter
|
|
||||||
switch p.state {
|
|
||||||
case g4Content:
|
|
||||||
em.content(p.held)
|
|
||||||
case g4Thought:
|
|
||||||
em.reasoning(p.held)
|
|
||||||
case g4ChanHeader:
|
|
||||||
em.content(gemma4ChannelOpen + p.held)
|
|
||||||
case g4ToolCall:
|
|
||||||
em.content(gemma4ToolCallOpen + p.held)
|
|
||||||
case g4Done:
|
|
||||||
}
|
|
||||||
p.held = ""
|
|
||||||
p.state = g4Done
|
|
||||||
return em.deltas
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Gemma4Parser) emitText(em *g4Emitter, s string) {
|
|
||||||
if p.state == g4Thought {
|
|
||||||
em.reasoning(s)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
em.content(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// emitToolCall decodes one complete <|tool_call>...<tool_call|> payload. On a
|
|
||||||
// payload that does not match call:name{...} the raw text (markers included)
|
|
||||||
// is emitted as content, mirroring vLLM's extract_tool_calls fallback.
|
|
||||||
func (p *Gemma4Parser) emitToolCall(em *g4Emitter, payload string) {
|
|
||||||
m := gemma4CallRE.FindStringSubmatch(payload)
|
|
||||||
if m == nil {
|
|
||||||
em.content(gemma4ToolCallOpen + payload + gemma4ToolCallClose)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Index-based ids: deterministic (the split-invariance property relies
|
|
||||||
// on it) and matching the call_<n> convention of pkg/grpc/rich_test.go;
|
|
||||||
// core only needs ids to be non-empty and unique within the response.
|
|
||||||
em.tool(p.toolIdx, "call_"+strconv.Itoa(p.toolIdx), m[1], decodeGemma4Args(m[2], 0))
|
|
||||||
p.toolIdx++
|
|
||||||
}
|
|
||||||
|
|
||||||
// g4Emitter collects ChatDeltas; empty text events are dropped.
|
|
||||||
type g4Emitter struct {
|
|
||||||
deltas []*pb.ChatDelta
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *g4Emitter) content(s string) {
|
|
||||||
if s != "" {
|
|
||||||
e.deltas = append(e.deltas, &pb.ChatDelta{Content: s})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *g4Emitter) reasoning(s string) {
|
|
||||||
if s != "" {
|
|
||||||
e.deltas = append(e.deltas, &pb.ChatDelta{ReasoningContent: s})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *g4Emitter) tool(index int, id, name, argsJSON string) {
|
|
||||||
e.deltas = append(e.deltas, &pb.ChatDelta{ToolCalls: []*pb.ToolCallDelta{{
|
|
||||||
Index: int32(index),
|
|
||||||
Id: id,
|
|
||||||
Name: name,
|
|
||||||
Arguments: argsJSON,
|
|
||||||
}}})
|
|
||||||
}
|
|
||||||
|
|
||||||
// findEarliestGemma4Marker returns the position and value of the first
|
|
||||||
// complete marker occurrence, or (-1, "").
|
|
||||||
func findEarliestGemma4Marker(s string, markers []string) (int, string) {
|
|
||||||
best, bestMarker := -1, ""
|
|
||||||
for _, m := range markers {
|
|
||||||
if idx := strings.Index(s, m); idx >= 0 && (best == -1 || idx < best) {
|
|
||||||
best, bestMarker = idx, m
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return best, bestMarker
|
|
||||||
}
|
|
||||||
|
|
||||||
// gemma4MarkerHoldback returns the length of the longest suffix of s that is
|
|
||||||
// a proper prefix of a watched marker - the only bytes that may still grow
|
|
||||||
// into a marker and therefore must not be emitted yet (bounded by the
|
|
||||||
// longest marker, so content is never buffered unboundedly).
|
|
||||||
func gemma4MarkerHoldback(s string, markers []string) int {
|
|
||||||
maxHold := 0
|
|
||||||
for _, m := range markers {
|
|
||||||
if len(m)-1 > maxHold {
|
|
||||||
maxHold = len(m) - 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(s) < maxHold {
|
|
||||||
maxHold = len(s)
|
|
||||||
}
|
|
||||||
for k := maxHold; k >= 1; k-- {
|
|
||||||
tail := s[len(s)-k:]
|
|
||||||
for _, m := range markers {
|
|
||||||
if strings.HasPrefix(m, tail) {
|
|
||||||
return k
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
// call:name{...} argument decoder
|
|
||||||
//
|
|
||||||
// Port of vLLM's _parse_gemma4_args / _parse_gemma4_array /
|
|
||||||
// _parse_gemma4_value (gemma4_tool_parser.py) in non-partial mode only: this
|
|
||||||
// parser decodes exclusively COMPLETE payloads (incomplete ones fall back to
|
|
||||||
// raw content at Close), so vLLM's partial-withholding machinery
|
|
||||||
// (trailing-dot floats, withheld bare tails) is intentionally not ported.
|
|
||||||
//
|
|
||||||
// Grammar (inverse of the renderer's formatGemma4Argument, tpl L118-L147):
|
|
||||||
//
|
|
||||||
// args := pair (',' pair)*
|
|
||||||
// pair := key ':' value (keys unquoted, up to the first ':')
|
|
||||||
// value := string | object | array | bare
|
|
||||||
// string := '<|"|>' ... '<|"|>' (no escapes; unterminated -> rest)
|
|
||||||
// object := '{' args '}' (delimited strings skipped when
|
|
||||||
// array := '[' value,* ']' counting braces/brackets)
|
|
||||||
// bare := true | false | null/none/nil | number | bare-string
|
|
||||||
//
|
|
||||||
// Output is a JSON object/array string with keys in payload order (Python
|
|
||||||
// dict insertion order), built with HTML escaping off so payload text
|
|
||||||
// survives byte-for-byte.
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
func isGemma4Space(c byte) bool { return c == ' ' || c == '\n' || c == '\t' }
|
|
||||||
|
|
||||||
// gemma4MaxArgsDepth caps the mutual recursion between decodeGemma4Args and
|
|
||||||
// decodeGemma4Array. Defense against model-generated deep nesting: a Go stack
|
|
||||||
// overflow is a fatal process kill, not a recoverable error, so past the cap
|
|
||||||
// a nested body gracefully degrades to a JSON string of its raw text.
|
|
||||||
const gemma4MaxArgsDepth = 100
|
|
||||||
|
|
||||||
// decodeGemma4Args decodes one args body (the text between the outer braces
|
|
||||||
// of call:name{...}) into a JSON object string. depth is the current nesting
|
|
||||||
// level (0 at the payload root); see gemma4MaxArgsDepth.
|
|
||||||
func decodeGemma4Args(s string, depth int) string {
|
|
||||||
if depth > gemma4MaxArgsDepth {
|
|
||||||
return gemma4JSONString(s)
|
|
||||||
}
|
|
||||||
var b strings.Builder
|
|
||||||
b.WriteString("{")
|
|
||||||
first := true
|
|
||||||
pair := func(key, val string) {
|
|
||||||
if !first {
|
|
||||||
b.WriteString(",")
|
|
||||||
}
|
|
||||||
first = false
|
|
||||||
b.WriteString(gemma4JSONString(key))
|
|
||||||
b.WriteString(":")
|
|
||||||
b.WriteString(val)
|
|
||||||
}
|
|
||||||
i, n := 0, len(s)
|
|
||||||
for i < n {
|
|
||||||
for i < n && (isGemma4Space(s[i]) || s[i] == ',') {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
if i >= n {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
keyStart := i
|
|
||||||
for i < n && s[i] != ':' {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
if i >= n {
|
|
||||||
break // no ':' -> trailing junk, dropped (vLLM does the same)
|
|
||||||
}
|
|
||||||
key := strings.TrimSpace(s[keyStart:i])
|
|
||||||
i++ // skip ':'
|
|
||||||
for i < n && isGemma4Space(s[i]) {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
if i >= n {
|
|
||||||
pair(key, `""`) // "key:" with nothing after -> empty string
|
|
||||||
break
|
|
||||||
}
|
|
||||||
switch {
|
|
||||||
case strings.HasPrefix(s[i:], gemma4StringDelim):
|
|
||||||
i += len(gemma4StringDelim)
|
|
||||||
if end := strings.Index(s[i:], gemma4StringDelim); end == -1 {
|
|
||||||
pair(key, gemma4JSONString(s[i:])) // unterminated -> take rest
|
|
||||||
i = n
|
|
||||||
} else {
|
|
||||||
pair(key, gemma4JSONString(s[i:i+end]))
|
|
||||||
i += end + len(gemma4StringDelim)
|
|
||||||
}
|
|
||||||
case s[i] == '{':
|
|
||||||
inner, next := scanGemma4Balanced(s, i, '{', '}')
|
|
||||||
pair(key, decodeGemma4Args(inner, depth+1))
|
|
||||||
i = next
|
|
||||||
case s[i] == '[':
|
|
||||||
inner, next := scanGemma4Balanced(s, i, '[', ']')
|
|
||||||
pair(key, decodeGemma4Array(inner, depth+1))
|
|
||||||
i = next
|
|
||||||
default:
|
|
||||||
valStart := i
|
|
||||||
for i < n && s[i] != ',' && s[i] != '}' && s[i] != ']' {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
if i == valStart {
|
|
||||||
// No progress (value starts on a stray '}'/']'): abort on
|
|
||||||
// malformed input rather than loop, like vLLM.
|
|
||||||
i = n
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
pair(key, decodeGemma4Bare(s[valStart:i]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
b.WriteString("}")
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// decodeGemma4Array decodes one array body (the text between '[' and ']')
|
|
||||||
// into a JSON array string. depth is the current nesting level; see
|
|
||||||
// gemma4MaxArgsDepth.
|
|
||||||
func decodeGemma4Array(s string, depth int) string {
|
|
||||||
if depth > gemma4MaxArgsDepth {
|
|
||||||
return gemma4JSONString(s)
|
|
||||||
}
|
|
||||||
var b strings.Builder
|
|
||||||
b.WriteString("[")
|
|
||||||
first := true
|
|
||||||
item := func(val string) {
|
|
||||||
if !first {
|
|
||||||
b.WriteString(",")
|
|
||||||
}
|
|
||||||
first = false
|
|
||||||
b.WriteString(val)
|
|
||||||
}
|
|
||||||
i, n := 0, len(s)
|
|
||||||
for i < n {
|
|
||||||
for i < n && (isGemma4Space(s[i]) || s[i] == ',') {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
if i >= n {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
switch {
|
|
||||||
case strings.HasPrefix(s[i:], gemma4StringDelim):
|
|
||||||
i += len(gemma4StringDelim)
|
|
||||||
if end := strings.Index(s[i:], gemma4StringDelim); end == -1 {
|
|
||||||
item(gemma4JSONString(s[i:]))
|
|
||||||
i = n
|
|
||||||
} else {
|
|
||||||
item(gemma4JSONString(s[i : i+end]))
|
|
||||||
i += end + len(gemma4StringDelim)
|
|
||||||
}
|
|
||||||
case s[i] == '{':
|
|
||||||
inner, next := scanGemma4Balanced(s, i, '{', '}')
|
|
||||||
item(decodeGemma4Args(inner, depth+1))
|
|
||||||
i = next
|
|
||||||
case s[i] == '[':
|
|
||||||
inner, next := scanGemma4Balanced(s, i, '[', ']')
|
|
||||||
item(decodeGemma4Array(inner, depth+1))
|
|
||||||
i = next
|
|
||||||
default:
|
|
||||||
valStart := i
|
|
||||||
for i < n && s[i] != ',' && s[i] != ']' {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
if i == valStart {
|
|
||||||
i = n // no progress: abort on malformed input, like vLLM
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
item(decodeGemma4Bare(s[valStart:i]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
b.WriteString("]")
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// scanGemma4Balanced scans a brace/bracket-balanced span starting at the
|
|
||||||
// opener s[start], skipping over <|"|>-delimited strings so structural
|
|
||||||
// characters inside them do not count (vLLM's depth scan). Returns the inner
|
|
||||||
// text and the index just past the closer; an unterminated span yields the
|
|
||||||
// rest of the string (the inner decoder still extracts what is there - this
|
|
||||||
// path is only reachable from genuinely malformed complete payloads).
|
|
||||||
func scanGemma4Balanced(s string, start int, open, close byte) (string, int) {
|
|
||||||
depth := 1
|
|
||||||
i := start + 1
|
|
||||||
innerStart := i
|
|
||||||
n := len(s)
|
|
||||||
for i < n && depth > 0 {
|
|
||||||
if strings.HasPrefix(s[i:], gemma4StringDelim) {
|
|
||||||
i += len(gemma4StringDelim)
|
|
||||||
if nd := strings.Index(s[i:], gemma4StringDelim); nd == -1 {
|
|
||||||
i = n
|
|
||||||
} else {
|
|
||||||
i += nd + len(gemma4StringDelim)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
switch s[i] {
|
|
||||||
case open:
|
|
||||||
depth++
|
|
||||||
case close:
|
|
||||||
depth--
|
|
||||||
}
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
if depth > 0 {
|
|
||||||
return s[innerStart:], n
|
|
||||||
}
|
|
||||||
return s[innerStart : i-1], i
|
|
||||||
}
|
|
||||||
|
|
||||||
// decodeGemma4Bare maps an undelimited value to its JSON form: booleans,
|
|
||||||
// null aliases (null/none/nil, case-insensitive - the renderer writes
|
|
||||||
// Python None as "None", tpl L144-L145 via format_argument's else branch),
|
|
||||||
// numbers (vLLM's rule: a '.' tries float, otherwise int; anything that
|
|
||||||
// fails parses as a bare string).
|
|
||||||
func decodeGemma4Bare(raw string) string {
|
|
||||||
v := strings.TrimSpace(raw)
|
|
||||||
if v == "" {
|
|
||||||
return `""`
|
|
||||||
}
|
|
||||||
if v == "true" || v == "false" {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
switch strings.ToLower(v) {
|
|
||||||
case "null", "none", "nil":
|
|
||||||
return "null"
|
|
||||||
}
|
|
||||||
if strings.Contains(v, ".") {
|
|
||||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
|
||||||
return formatGemma4Float(f)
|
|
||||||
}
|
|
||||||
} else if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
|
|
||||||
return strconv.FormatInt(iv, 10)
|
|
||||||
}
|
|
||||||
return gemma4JSONString(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// formatGemma4Float renders like Python's json.dumps(float): integral floats
|
|
||||||
// keep a ".0" suffix ("108." decodes to 108.0, not 108), so the arguments
|
|
||||||
// JSON matches what vLLM would have produced for the same payload.
|
|
||||||
func formatGemma4Float(f float64) string {
|
|
||||||
s := strconv.FormatFloat(f, 'g', -1, 64)
|
|
||||||
if !strings.ContainsAny(s, ".eE") {
|
|
||||||
s += ".0"
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// gemma4JSONString encodes a JSON string WITHOUT HTML escaping (json.Marshal
|
|
||||||
// would escape the angle brackets in "<div>" to \u003c / \u003e sequences;
|
|
||||||
// payload text should survive
|
|
||||||
// byte-for-byte, like Python's json.dumps(ensure_ascii=False)).
|
|
||||||
func gemma4JSONString(s string) string {
|
|
||||||
var sb strings.Builder
|
|
||||||
enc := json.NewEncoder(&sb)
|
|
||||||
enc.SetEscapeHTML(false)
|
|
||||||
if err := enc.Encode(s); err != nil {
|
|
||||||
// Unreachable for plain strings; fall back to default escaping
|
|
||||||
// rather than emitting invalid JSON.
|
|
||||||
b, mErr := json.Marshal(s)
|
|
||||||
if mErr != nil {
|
|
||||||
return `""`
|
|
||||||
}
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
// Encode appends a trailing newline.
|
|
||||||
return strings.TrimSuffix(sb.String(), "\n")
|
|
||||||
}
|
|
||||||
@@ -1,592 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
// Parser specs for Gemma4Parser (model output text -> pb.ChatDelta events).
|
|
||||||
//
|
|
||||||
// Fixture provenance:
|
|
||||||
// - Entries marked "vLLM: <name>" are direct ports of the named test from
|
|
||||||
// vLLM PR #45163, tests/tool_parsers/test_gemma4_tool_parser.py (the
|
|
||||||
// authoritative test-suite for the gemma4 tool-call wire format). The
|
|
||||||
// streaming tests' chunk lists are reused verbatim as Feed fragments.
|
|
||||||
// - Decoder entries port the TestParseGemma4Args / TestParseGemma4Array
|
|
||||||
// classes from the same file (non-partial mode only; this parser never
|
|
||||||
// decodes partial payloads, see the divergence note in gemma4_parser.go).
|
|
||||||
// - Channel/turn-marker expectations come from the chat template embedded
|
|
||||||
// in gemma4_renderer.go (tpl L356-L362 generation prompt, L148-L158
|
|
||||||
// strip_thinking) and vLLM's Gemma4ReasoningParser
|
|
||||||
// (vllm/reasoning/gemma4_reasoning_parser.py).
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// flatGemma4Tool is one accumulated tool call, mirroring how LocalAI core
|
|
||||||
// folds ToolCallDelta streams (pkg/functions/chat_deltas.go
|
|
||||||
// ToolCallsFromChatDeltas: name/id latch on first non-empty, arguments
|
|
||||||
// concatenate per index). Tests flatten through the same rules so they
|
|
||||||
// assert exactly what core will reconstruct.
|
|
||||||
type flatGemma4Tool struct {
|
|
||||||
id string
|
|
||||||
name string
|
|
||||||
args string
|
|
||||||
}
|
|
||||||
|
|
||||||
func flattenGemma4Deltas(deltas []*pb.ChatDelta) (string, string, []flatGemma4Tool) {
|
|
||||||
var content, reasoning strings.Builder
|
|
||||||
byIndex := map[int32]*flatGemma4Tool{}
|
|
||||||
maxIdx := int32(-1)
|
|
||||||
for _, d := range deltas {
|
|
||||||
content.WriteString(d.GetContent())
|
|
||||||
reasoning.WriteString(d.GetReasoningContent())
|
|
||||||
for _, tc := range d.GetToolCalls() {
|
|
||||||
acc, ok := byIndex[tc.GetIndex()]
|
|
||||||
if !ok {
|
|
||||||
acc = &flatGemma4Tool{}
|
|
||||||
byIndex[tc.GetIndex()] = acc
|
|
||||||
}
|
|
||||||
if tc.GetName() != "" {
|
|
||||||
acc.name = tc.GetName()
|
|
||||||
}
|
|
||||||
if tc.GetId() != "" {
|
|
||||||
acc.id = tc.GetId()
|
|
||||||
}
|
|
||||||
acc.args += tc.GetArguments()
|
|
||||||
if tc.GetIndex() > maxIdx {
|
|
||||||
maxIdx = tc.GetIndex()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var tools []flatGemma4Tool
|
|
||||||
for i := int32(0); i <= maxIdx; i++ {
|
|
||||||
if acc, ok := byIndex[i]; ok {
|
|
||||||
tools = append(tools, *acc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return content.String(), reasoning.String(), tools
|
|
||||||
}
|
|
||||||
|
|
||||||
type wantGemma4Tool struct {
|
|
||||||
name string
|
|
||||||
argsJSON string // compared with MatchJSON (key order irrelevant)
|
|
||||||
}
|
|
||||||
|
|
||||||
type parseGemma4Case struct {
|
|
||||||
startInThought bool
|
|
||||||
fragments []string
|
|
||||||
wantContent string
|
|
||||||
wantReasoning string
|
|
||||||
wantTools []wantGemma4Tool
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseGemma4Fragments(startInThought bool, fragments []string) []*pb.ChatDelta {
|
|
||||||
p := NewGemma4Parser(startInThought)
|
|
||||||
var all []*pb.ChatDelta
|
|
||||||
for _, f := range fragments {
|
|
||||||
all = append(all, p.Feed(f)...)
|
|
||||||
}
|
|
||||||
return append(all, p.Close()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Gemma4Parser", func() {
|
|
||||||
DescribeTable("parses streamed gemma4 output into ChatDeltas",
|
|
||||||
func(c parseGemma4Case) {
|
|
||||||
content, reasoning, tools := flattenGemma4Deltas(parseGemma4Fragments(c.startInThought, c.fragments))
|
|
||||||
Expect(content).To(Equal(c.wantContent))
|
|
||||||
Expect(reasoning).To(Equal(c.wantReasoning))
|
|
||||||
Expect(tools).To(HaveLen(len(c.wantTools)))
|
|
||||||
seenIDs := map[string]bool{}
|
|
||||||
for i, want := range c.wantTools {
|
|
||||||
Expect(tools[i].name).To(Equal(want.name), "tool %d name", i)
|
|
||||||
Expect(tools[i].args).To(MatchJSON(want.argsJSON), "tool %d arguments", i)
|
|
||||||
Expect(tools[i].id).ToNot(BeEmpty(), "tool %d id", i)
|
|
||||||
Expect(seenIDs).ToNot(HaveKey(tools[i].id), "tool %d id must be unique", i)
|
|
||||||
seenIDs[tools[i].id] = true
|
|
||||||
}
|
|
||||||
},
|
|
||||||
|
|
||||||
// --- (1) pure content -------------------------------------------------
|
|
||||||
// vLLM: test_no_tool_calls
|
|
||||||
Entry("pure content, single fragment", parseGemma4Case{
|
|
||||||
fragments: []string{"Hello, how can I help you today?"},
|
|
||||||
wantContent: "Hello, how can I help you today?",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// --- (2) thought -> final transition ----------------------------------
|
|
||||||
// enable_thinking render: prompt ends at <|turn>model\n and the model
|
|
||||||
// opens/closes its own thought channel in the OUTPUT (vLLM
|
|
||||||
// Gemma4ReasoningParser docstring; tpl L356-L362). The "thought\n"
|
|
||||||
// role label after <|channel> is structural and must be stripped
|
|
||||||
// (vLLM _THOUGHT_PREFIX handling).
|
|
||||||
Entry("thought channel then final content", parseGemma4Case{
|
|
||||||
fragments: []string{"<|channel>thought\nLet me think about this.\n<channel|>The answer is 42."},
|
|
||||||
wantReasoning: "Let me think about this.\n",
|
|
||||||
wantContent: "The answer is 42.",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// --- (3) startInThought both ways -------------------------------------
|
|
||||||
Entry("startInThought=true routes initial text to reasoning until <channel|>", parseGemma4Case{
|
|
||||||
startInThought: true,
|
|
||||||
fragments: []string{"I am thinking hard.<channel|>Done."},
|
|
||||||
wantReasoning: "I am thinking hard.",
|
|
||||||
wantContent: "Done.",
|
|
||||||
}),
|
|
||||||
// A stray <channel|> with no open channel is swallowed, matching the
|
|
||||||
// template's strip_thinking (tpl L148-L158: the marker is dropped,
|
|
||||||
// text on both sides is kept).
|
|
||||||
Entry("startInThought=false keeps the same text as content, stray <channel|> swallowed", parseGemma4Case{
|
|
||||||
startInThought: false,
|
|
||||||
fragments: []string{"I am thinking hard.<channel|>Done."},
|
|
||||||
wantContent: "I am thinking hard.Done.",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// --- (4) one tool call, full payload type zoo --------------------------
|
|
||||||
Entry("single tool call: strings, numbers, bools, null, nested object and array", parseGemma4Case{
|
|
||||||
fragments: []string{`<|tool_call>call:complex_function{text:<|"|>with, comma and {braces}<|"|>,count:42,score:3.14,yes:true,no:false,nothing:null,obj:{inner:<|"|>v<|"|>,k:1},arr:[<|"|>a<|"|>,2,true]}<tool_call|>`},
|
|
||||||
wantTools: []wantGemma4Tool{{
|
|
||||||
name: "complex_function",
|
|
||||||
argsJSON: `{"text":"with, comma and {braces}","count":42,"score":3.14,"yes":true,"no":false,"nothing":null,"obj":{"inner":"v","k":1},"arr":["a",2,true]}`,
|
|
||||||
}},
|
|
||||||
}),
|
|
||||||
|
|
||||||
// --- (5) payload split across 3 fragments ------------------------------
|
|
||||||
Entry("tool-call payload split across three fragments", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|tool_call>call:get_weather{loc",
|
|
||||||
`ation:<|"|>Paris, Fra`,
|
|
||||||
`nce<|"|>}<tool_call|>`,
|
|
||||||
},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}},
|
|
||||||
}),
|
|
||||||
|
|
||||||
// --- (6) marker split across fragments ----------------------------------
|
|
||||||
Entry("tool-call open marker split across fragments", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|tool_ca",
|
|
||||||
`ll>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`,
|
|
||||||
},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
|
|
||||||
}),
|
|
||||||
Entry("channel open marker split across fragments", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|chan",
|
|
||||||
"nel>thought\ndeep thought<channel|>final",
|
|
||||||
},
|
|
||||||
wantReasoning: "deep thought",
|
|
||||||
wantContent: "final",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// --- (7) trailing partial marker held, flushed by Close -----------------
|
|
||||||
Entry("trailing partial marker is held back and flushed by Close", parseGemma4Case{
|
|
||||||
fragments: []string{"Hello <|tool"},
|
|
||||||
wantContent: "Hello <|tool",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// --- (8) malformed/incomplete payload -> content fallback ---------------
|
|
||||||
// vLLM: test_incomplete_tool_call (no end marker: the whole text stays
|
|
||||||
// content, never silently dropped).
|
|
||||||
Entry("incomplete tool payload at Close is emitted as raw content", parseGemma4Case{
|
|
||||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London`},
|
|
||||||
wantContent: `<|tool_call>call:get_weather{location:<|"|>London`,
|
|
||||||
}),
|
|
||||||
Entry("malformed complete payload is emitted as raw content, parsing continues", parseGemma4Case{
|
|
||||||
fragments: []string{"<|tool_call>oops no call syntax<tool_call|> done"},
|
|
||||||
wantContent: "<|tool_call>oops no call syntax<tool_call|> done",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// --- (9) <turn|> ends the turn -------------------------------------------
|
|
||||||
Entry("text after <turn|> is ignored, including later fragments", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"before<turn|>after",
|
|
||||||
`more <|tool_call>call:f{}<tool_call|>`,
|
|
||||||
},
|
|
||||||
wantContent: "before",
|
|
||||||
}),
|
|
||||||
Entry("<turn|> inside a thought channel ends the turn", parseGemma4Case{
|
|
||||||
startInThought: true,
|
|
||||||
fragments: []string{"thinking<turn|>ignored"},
|
|
||||||
wantReasoning: "thinking",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// --- (10) ported vLLM non-streaming cases ---------------------------------
|
|
||||||
// vLLM: test_single_tool_call
|
|
||||||
Entry("vLLM: test_single_tool_call", parseGemma4Case{
|
|
||||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_multiple_arguments
|
|
||||||
Entry("vLLM: test_multiple_arguments", parseGemma4Case{
|
|
||||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>}<tool_call|>`},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"San Francisco","unit":"celsius"}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_text_before_tool_call. DIVERGENCE: vLLM's non-streaming
|
|
||||||
// extractor trims the content ("...you."); a streaming parser cannot
|
|
||||||
// retroactively trim already-emitted text, so the trailing space is
|
|
||||||
// kept (vLLM's own streaming path keeps it too, see
|
|
||||||
// test_streaming_text_before_tool_call which only checks a prefix).
|
|
||||||
Entry("vLLM: test_text_before_tool_call (streaming semantics: no trim)", parseGemma4Case{
|
|
||||||
fragments: []string{`Let me check the weather for you. <|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|>`},
|
|
||||||
wantContent: "Let me check the weather for you. ",
|
|
||||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_multiple_tool_calls (also covers case 11: multi-tool sequence)
|
|
||||||
Entry("vLLM: test_multiple_tool_calls", parseGemma4Case{
|
|
||||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|><|tool_call>call:get_time{location:<|"|>London<|"|>}<tool_call|>`},
|
|
||||||
wantTools: []wantGemma4Tool{
|
|
||||||
{name: "get_weather", argsJSON: `{"location":"London"}`},
|
|
||||||
{name: "get_time", argsJSON: `{"location":"London"}`},
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
// vLLM: test_nested_arguments
|
|
||||||
Entry("vLLM: test_nested_arguments", parseGemma4Case{
|
|
||||||
fragments: []string{`<|tool_call>call:complex_function{nested:{inner:<|"|>value<|"|>},list:[<|"|>a<|"|>,<|"|>b<|"|>]}<tool_call|>`},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "complex_function", argsJSON: `{"nested":{"inner":"value"},"list":["a","b"]}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_tool_call_with_number_and_boolean
|
|
||||||
Entry("vLLM: test_tool_call_with_number_and_boolean", parseGemma4Case{
|
|
||||||
fragments: []string{`<|tool_call>call:set_status{is_active:true,count:42,score:3.14}<tool_call|>`},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "set_status", argsJSON: `{"is_active":true,"count":42,"score":3.14}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_hyphenated_function_name
|
|
||||||
Entry("vLLM: test_hyphenated_function_name", parseGemma4Case{
|
|
||||||
fragments: []string{`<|tool_call>call:get-weather{location:<|"|>London<|"|>}<tool_call|>`},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "get-weather", argsJSON: `{"location":"London"}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_dotted_function_name
|
|
||||||
Entry("vLLM: test_dotted_function_name", parseGemma4Case{
|
|
||||||
fragments: []string{`<|tool_call>call:weather.get{location:<|"|>London<|"|>}<tool_call|>`},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "weather.get", argsJSON: `{"location":"London"}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_no_arguments
|
|
||||||
Entry("vLLM: test_no_arguments", parseGemma4Case{
|
|
||||||
fragments: []string{"<|tool_call>call:get_status{}<tool_call|>"},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}},
|
|
||||||
}),
|
|
||||||
|
|
||||||
// --- ported vLLM streaming cases (chunk lists reused as fragments) --------
|
|
||||||
// vLLM: test_basic_streaming_single_tool
|
|
||||||
Entry("vLLM: test_basic_streaming_single_tool", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|tool_call>",
|
|
||||||
"call:get_weather{",
|
|
||||||
`location:<|"|>Paris`,
|
|
||||||
", France",
|
|
||||||
`<|"|>}`,
|
|
||||||
"<tool_call|>",
|
|
||||||
},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_streaming_multi_arg
|
|
||||||
Entry("vLLM: test_streaming_multi_arg", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|tool_call>",
|
|
||||||
"call:get_weather{",
|
|
||||||
`location:<|"|>Tokyo<|"|>,`,
|
|
||||||
`unit:<|"|>celsius<|"|>}`,
|
|
||||||
"<tool_call|>",
|
|
||||||
},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Tokyo","unit":"celsius"}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_streaming_text_before_tool_call
|
|
||||||
Entry("vLLM: test_streaming_text_before_tool_call", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"Let me check ",
|
|
||||||
"the weather. ",
|
|
||||||
"<|tool_call>",
|
|
||||||
"call:get_weather{",
|
|
||||||
`location:<|"|>London<|"|>}`,
|
|
||||||
"<tool_call|>",
|
|
||||||
},
|
|
||||||
wantContent: "Let me check the weather. ",
|
|
||||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_streaming_numeric_args
|
|
||||||
Entry("vLLM: test_streaming_numeric_args", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|tool_call>",
|
|
||||||
"call:set_config{",
|
|
||||||
"count:42,",
|
|
||||||
"active:true}",
|
|
||||||
"<tool_call|>",
|
|
||||||
},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "set_config", argsJSON: `{"count":42,"active":true}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_streaming_boolean_split_across_chunks
|
|
||||||
Entry("vLLM: test_streaming_boolean_split_across_chunks", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|tool_call>",
|
|
||||||
"call:search{input:{all:tru",
|
|
||||||
"e}}",
|
|
||||||
"<tool_call|>",
|
|
||||||
},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "search", argsJSON: `{"input":{"all":true}}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_streaming_false_split_across_chunks
|
|
||||||
Entry("vLLM: test_streaming_false_split_across_chunks", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|tool_call>",
|
|
||||||
"call:set{flag:fals",
|
|
||||||
"e}",
|
|
||||||
"<tool_call|>",
|
|
||||||
},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"flag":false}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_streaming_number_split_across_chunks
|
|
||||||
Entry("vLLM: test_streaming_number_split_across_chunks", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|tool_call>",
|
|
||||||
"call:set{count:4",
|
|
||||||
"2}",
|
|
||||||
"<tool_call|>",
|
|
||||||
},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"count":42}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_streaming_empty_args
|
|
||||||
Entry("vLLM: test_streaming_empty_args", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|tool_call>",
|
|
||||||
"call:get_status{}",
|
|
||||||
"<tool_call|>",
|
|
||||||
},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_streaming_split_delimiter_no_invalid_json (string
|
|
||||||
// delimiter <|"|> split across fragments must not leak fragments).
|
|
||||||
Entry("vLLM: test_streaming_split_delimiter_no_invalid_json", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|tool_call>",
|
|
||||||
"call:todowrite{",
|
|
||||||
`content:<|"|>Buy milk<|`,
|
|
||||||
`"|>}`,
|
|
||||||
"<tool_call|>",
|
|
||||||
},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "todowrite", argsJSON: `{"content":"Buy milk"}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call
|
|
||||||
Entry("vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|tool_call>",
|
|
||||||
"call:get_weather{",
|
|
||||||
`location:<|"|>Paris<|"|>}`,
|
|
||||||
"<tool_call|><",
|
|
||||||
"div>",
|
|
||||||
},
|
|
||||||
wantContent: "<div>",
|
|
||||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes
|
|
||||||
Entry("vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|tool_call>",
|
|
||||||
"call:write_file{",
|
|
||||||
`path:<|"|>index.html<|"|>,`,
|
|
||||||
`content:<|"|><!DOCTYPE html>` + "\n<",
|
|
||||||
`html lang="zh-CN">` + "\n<",
|
|
||||||
"head>\n <",
|
|
||||||
`meta charset="UTF-8">` + "\n <",
|
|
||||||
`meta name="viewport" content="width=device-width">` + "\n",
|
|
||||||
`<|"|>}`,
|
|
||||||
"<tool_call|>",
|
|
||||||
},
|
|
||||||
wantTools: []wantGemma4Tool{{
|
|
||||||
name: "write_file",
|
|
||||||
argsJSON: `{"path":"index.html","content":"<!DOCTYPE html>\n<html lang=\"zh-CN\">\n<head>\n <meta charset=\"UTF-8\">\n <meta name=\"viewport\" content=\"width=device-width\">\n"}`,
|
|
||||||
}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_streaming_single_chunk_complete_tool_call
|
|
||||||
Entry("vLLM: test_streaming_single_chunk_complete_tool_call", parseGemma4Case{
|
|
||||||
fragments: []string{`<|tool_call>call:name_a_color{color_hex:<|"|>00ff11<|"|>}<tool_call|>`},
|
|
||||||
wantTools: []wantGemma4Tool{{name: "name_a_color", argsJSON: `{"color_hex":"00ff11"}`}},
|
|
||||||
}),
|
|
||||||
// vLLM: test_streaming_multi_chunk_batched_tool_calls (two complete
|
|
||||||
// calls in ONE fragment; both must come out with distinct indices)
|
|
||||||
Entry("vLLM: test_streaming_multi_chunk_batched_tool_calls", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>` +
|
|
||||||
`<|tool_call>call:get_time{timezone:<|"|>GMT<|"|>}<tool_call|>`,
|
|
||||||
},
|
|
||||||
wantTools: []wantGemma4Tool{
|
|
||||||
{name: "get_weather", argsJSON: `{"location":"London"}`},
|
|
||||||
{name: "get_time", argsJSON: `{"timezone":"GMT"}`},
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
// vLLM: test_streaming_trailing_bare_bool_not_duplicated
|
|
||||||
Entry("vLLM: test_streaming_trailing_bare_bool_not_duplicated", parseGemma4Case{
|
|
||||||
fragments: []string{
|
|
||||||
"<|tool_call>",
|
|
||||||
"call:Edit{",
|
|
||||||
`file_path:<|"|>src/env.py<|"|>,`,
|
|
||||||
`old_string:<|"|>old_val<|"|>,`,
|
|
||||||
`new_string:<|"|>new_val<|"|>,`,
|
|
||||||
"replace_all:",
|
|
||||||
"false}",
|
|
||||||
"<tool_call|>",
|
|
||||||
},
|
|
||||||
wantTools: []wantGemma4Tool{{
|
|
||||||
name: "Edit",
|
|
||||||
argsJSON: `{"file_path":"src/env.py","old_string":"old_val","new_string":"new_val","replace_all":false}`,
|
|
||||||
}},
|
|
||||||
}),
|
|
||||||
|
|
||||||
// --- implicit reasoning end on <|tool_call> (vLLM is_reasoning_end:
|
|
||||||
// a tool_call token means reasoning is over) -----------------------------
|
|
||||||
Entry("tool call inside an open thought channel ends the reasoning", parseGemma4Case{
|
|
||||||
startInThought: true,
|
|
||||||
fragments: []string{`need the weather<|tool_call>call:get_weather{location:<|"|>Rome<|"|>}<tool_call|>`},
|
|
||||||
wantReasoning: "need the weather",
|
|
||||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Rome"}`}},
|
|
||||||
}),
|
|
||||||
|
|
||||||
// --- (12) empty fragments are no-ops --------------------------------------
|
|
||||||
Entry("empty fragments are no-ops", parseGemma4Case{
|
|
||||||
fragments: []string{"", "Hello", "", "", " world", ""},
|
|
||||||
wantContent: "Hello world",
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
|
|
||||||
It("returns no deltas for an empty fragment and after Close", func() {
|
|
||||||
p := NewGemma4Parser(false)
|
|
||||||
Expect(p.Feed("")).To(BeEmpty())
|
|
||||||
Expect(p.Feed("hi")).ToNot(BeEmpty())
|
|
||||||
Expect(p.Close()).To(BeEmpty()) // nothing held back
|
|
||||||
// The parser is finished after Close: further input is dropped.
|
|
||||||
Expect(p.Feed("more")).To(BeEmpty())
|
|
||||||
Expect(p.Close()).To(BeEmpty())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("generates index-based tool call ids (call_<index>)", func() {
|
|
||||||
// Mirrors the index-based id convention of pkg/grpc/rich_test.go and
|
|
||||||
// keeps ids deterministic for the split-invariance property below.
|
|
||||||
deltas := parseGemma4Fragments(false, []string{
|
|
||||||
`<|tool_call>call:a{}<tool_call|><|tool_call>call:b{}<tool_call|>`,
|
|
||||||
})
|
|
||||||
_, _, tools := flattenGemma4Deltas(deltas)
|
|
||||||
Expect(tools).To(HaveLen(2))
|
|
||||||
Expect(tools[0].id).To(Equal("call_0"))
|
|
||||||
Expect(tools[1].id).To(Equal("call_1"))
|
|
||||||
})
|
|
||||||
|
|
||||||
// Property: for a fixed full output, EVERY 2-split position must yield
|
|
||||||
// exactly the same flattened result as the unsplit parse. This kills
|
|
||||||
// fragment-boundary bugs (mid-marker, mid-delimiter, mid-payload splits).
|
|
||||||
DescribeTable("2-split fragment invariance",
|
|
||||||
func(startInThought bool, full string) {
|
|
||||||
refContent, refReasoning, refTools := flattenGemma4Deltas(
|
|
||||||
parseGemma4Fragments(startInThought, []string{full}))
|
|
||||||
for i := 0; i <= len(full); i++ {
|
|
||||||
content, reasoning, tools := flattenGemma4Deltas(
|
|
||||||
parseGemma4Fragments(startInThought, []string{full[:i], full[i:]}))
|
|
||||||
Expect(content).To(Equal(refContent), fmt.Sprintf("content diverged at split %d", i))
|
|
||||||
Expect(reasoning).To(Equal(refReasoning), fmt.Sprintf("reasoning diverged at split %d", i))
|
|
||||||
Expect(tools).To(Equal(refTools), fmt.Sprintf("tool calls diverged at split %d", i))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Entry("thought + content + two tool calls + turn end", false,
|
|
||||||
"<|channel>thought\nPondering the request...\n<channel|>Sure - calling tools now. "+
|
|
||||||
`<|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>,days:3,detailed:true}<tool_call|>`+
|
|
||||||
`<|tool_call>call:get_time{timezone:<|"|>Europe/Lisbon<|"|>,nested:{flag:false,vals:[1,2.5,<|"|>x<|"|>]}}<tool_call|>`+
|
|
||||||
"Done.<turn|>ignored tail"),
|
|
||||||
Entry("startInThought + tool call + trailing partial marker", true,
|
|
||||||
`Deep thought<channel|>final answer <|tool_call>call:noop{}<tool_call|> trailing <|tool`),
|
|
||||||
Entry("malformed payload fallback", false,
|
|
||||||
`pre <|tool_call>not a call<tool_call|> post`),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Decoder-level ports of vLLM's TestParseGemma4Args / TestParseGemma4Array
|
|
||||||
// (non-partial mode; the partial-withholding tests do not apply because this
|
|
||||||
// parser only ever decodes COMPLETE payloads, see gemma4_parser.go).
|
|
||||||
var _ = Describe("decodeGemma4Args", func() {
|
|
||||||
DescribeTable("decodes the gemma4 call syntax into JSON arguments",
|
|
||||||
func(in, wantJSON string) {
|
|
||||||
Expect(decodeGemma4Args(in, 0)).To(MatchJSON(wantJSON))
|
|
||||||
},
|
|
||||||
// vLLM: test_empty_string / test_whitespace_only
|
|
||||||
Entry("empty string", "", `{}`),
|
|
||||||
Entry("whitespace only", " ", `{}`),
|
|
||||||
// vLLM: test_single_string_value
|
|
||||||
Entry("single string value", `location:<|"|>Paris<|"|>`, `{"location":"Paris"}`),
|
|
||||||
// vLLM: test_string_value_with_comma
|
|
||||||
Entry("string value with comma", `location:<|"|>Paris, France<|"|>`, `{"location":"Paris, France"}`),
|
|
||||||
// vLLM: test_multiple_string_values
|
|
||||||
Entry("multiple string values", `location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>`, `{"location":"San Francisco","unit":"celsius"}`),
|
|
||||||
// vLLM: test_integer_value / test_float_value
|
|
||||||
Entry("integer value", "count:42", `{"count":42}`),
|
|
||||||
Entry("float value", "score:3.14", `{"score":3.14}`),
|
|
||||||
// vLLM: test_boolean_true / test_boolean_false
|
|
||||||
Entry("boolean true", "flag:true", `{"flag":true}`),
|
|
||||||
Entry("boolean false", "flag:false", `{"flag":false}`),
|
|
||||||
// vLLM: test_null_value (bare null must become JSON null, not "null")
|
|
||||||
Entry("null value", "param:null", `{"param":null}`),
|
|
||||||
// vLLM: test_mixed_types
|
|
||||||
Entry("mixed types", `name:<|"|>test<|"|>,count:42,active:true,score:3.14`,
|
|
||||||
`{"name":"test","count":42,"active":true,"score":3.14}`),
|
|
||||||
// vLLM: test_nested_object
|
|
||||||
Entry("nested object", `nested:{inner:<|"|>value<|"|>}`, `{"nested":{"inner":"value"}}`),
|
|
||||||
// vLLM: test_array_of_strings
|
|
||||||
Entry("array of strings", `items:[<|"|>a<|"|>,<|"|>b<|"|>]`, `{"items":["a","b"]}`),
|
|
||||||
// vLLM: test_unterminated_string (take everything after the delimiter)
|
|
||||||
Entry("unterminated string", `key:<|"|>unterminated`, `{"key":"unterminated"}`),
|
|
||||||
// vLLM: test_empty_value (key with no value after colon)
|
|
||||||
Entry("empty value", "key:", `{"key":""}`),
|
|
||||||
// vLLM: test_trailing_dot_float_partial_withheld, non-partial branch
|
|
||||||
// (trailing-dot floats parse normally outside streaming).
|
|
||||||
Entry("trailing dot float, complete payload", "left:108.,right:22.8", `{"left":108.0,"right":22.8}`),
|
|
||||||
)
|
|
||||||
|
|
||||||
It("terminates and yields valid JSON on malformed input", func() {
|
|
||||||
// vLLM: test_malformed_partial_array (the assertion there is only
|
|
||||||
// "returns a dict without hanging"; ours is "valid JSON object").
|
|
||||||
out := decodeGemma4Args(":[t:[]", 0)
|
|
||||||
var v map[string]any
|
|
||||||
Expect(json.Unmarshal([]byte(out), &v)).To(Succeed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("degrades nesting beyond the recursion cap to a string value", func() {
|
|
||||||
// 200 levels of a:{a:{...a:1...}}. Without the depth cap the mutual
|
|
||||||
// recursion would grow the stack with the model's output; a Go stack
|
|
||||||
// overflow is a fatal process kill, so levels past gemma4MaxArgsDepth
|
|
||||||
// must gracefully fall back to the raw inner text as a JSON string.
|
|
||||||
const depth = 200
|
|
||||||
body := strings.Repeat("a:{", depth-1) + "a:1" + strings.Repeat("}", depth-1)
|
|
||||||
out := decodeGemma4Args(body, 0)
|
|
||||||
var v map[string]any
|
|
||||||
Expect(json.Unmarshal([]byte(out), &v)).To(Succeed())
|
|
||||||
levels := 0
|
|
||||||
var cur any = v
|
|
||||||
for {
|
|
||||||
m, ok := cur.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
Expect(m).To(HaveKey("a"))
|
|
||||||
cur = m["a"]
|
|
||||||
levels++
|
|
||||||
}
|
|
||||||
Expect(levels).To(Equal(gemma4MaxArgsDepth + 1))
|
|
||||||
Expect(cur).To(BeAssignableToTypeOf(""))
|
|
||||||
Expect(cur).To(ContainSubstring("a:{"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
var _ = Describe("decodeGemma4Array", func() {
|
|
||||||
DescribeTable("decodes gemma4 array bodies into JSON arrays",
|
|
||||||
func(in, wantJSON string) {
|
|
||||||
Expect(decodeGemma4Array(in, 0)).To(MatchJSON(wantJSON))
|
|
||||||
},
|
|
||||||
// vLLM: test_string_array / test_empty_array / test_bare_values
|
|
||||||
Entry("string array", `<|"|>a<|"|>,<|"|>b<|"|>`, `["a","b"]`),
|
|
||||||
Entry("empty array", "", `[]`),
|
|
||||||
Entry("bare values", "42,true,3.14", `[42,true,3.14]`),
|
|
||||||
// vLLM: test_string_element_with_closing_bracket (a ']' inside a
|
|
||||||
// delimited string must not close the array)
|
|
||||||
Entry("string element with closing bracket", `[<|"|>a]b<|"|>,<|"|>c<|"|>],<|"|>tail<|"|>`, `[["a]b","c"],"tail"]`),
|
|
||||||
// vLLM: test_stray_closing_bracket (no-progress abort, keep prefix)
|
|
||||||
Entry("stray closing bracket", "42,]trailing", `[42]`),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,406 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
// Renderer specs for RenderGemma4 against the canonical gemma4 chat template
|
|
||||||
// (see the normative template comment in gemma4_renderer.go).
|
|
||||||
//
|
|
||||||
// Fixture provenance:
|
|
||||||
// - "single user message" and "enable_thinking" are the EXACT expected
|
|
||||||
// decodes from transformers tests/models/diffusion_gemma/
|
|
||||||
// test_modeling_diffusion_gemma.py (test_diffusion_gemma_chat_template
|
|
||||||
// and ..._with_thinking) with ONE difference: the transformers fixtures
|
|
||||||
// start with "<bos>" because apply_chat_template tokenizes the rendered
|
|
||||||
// text with add_bos. Our prompt goes through dllm_capi_generate, whose
|
|
||||||
// run_generate already tokenizes with prepend_bos = vocab.add_bos
|
|
||||||
// (dllm.cpp src/capi.cpp:230-231, true for gemma4), so the renderer must
|
|
||||||
// NOT emit a literal <bos> (it would double) and every expected string
|
|
||||||
// here drops that leading token.
|
|
||||||
// - All other expected strings were produced by rendering the verbatim
|
|
||||||
// GGUF template with jinja2 3.1.2 (bos_token="<bos>") and dropping the
|
|
||||||
// leading "<bos>" for the same reason.
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Two-function tools array used by the tool fixtures (OpenAI wire shape, as
|
|
||||||
// LocalAI passes it through PredictOptions.Tools).
|
|
||||||
const testToolsJSON = `[{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a location.","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city name."},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}},{"type":"function","function":{"name":"get_time","description":"Get the current time in a timezone.","parameters":{"type":"object","properties":{"timezone":{"type":"string","description":"IANA timezone name."}},"required":["timezone"]}}}]`
|
|
||||||
|
|
||||||
// The <|tool>...<tool|> block the template renders for testToolsJSON inside
|
|
||||||
// the system turn (jinja2-verified).
|
|
||||||
const testToolsBlock = `<|tool>declaration:get_weather{description:<|"|>Get the current weather in a location.<|"|>,parameters:{properties:{location:{description:<|"|>The city name.<|"|>,type:<|"|>STRING<|"|>},unit:{enum:[<|"|>celsius<|"|>,<|"|>fahrenheit<|"|>],type:<|"|>STRING<|"|>}},required:[<|"|>location<|"|>],type:<|"|>OBJECT<|"|>}}<tool|><|tool>declaration:get_time{description:<|"|>Get the current time in a timezone.<|"|>,parameters:{properties:{timezone:{description:<|"|>IANA timezone name.<|"|>,type:<|"|>STRING<|"|>}},required:[<|"|>timezone<|"|>],type:<|"|>OBJECT<|"|>}}<tool|>`
|
|
||||||
|
|
||||||
// A single tool exercising the deep format_parameters branches: array items
|
|
||||||
// (string-typed and nested-array), nullable, enum+nullable, nested object
|
|
||||||
// properties/required, and a response declaration.
|
|
||||||
const complexToolsJSON = `[{"type":"function","function":{"name":"complex_tool","description":"A complex tool.","parameters":{"type":"object","properties":{"tags":{"type":"array","description":"Tags.","items":{"type":"string"}},"matrix":{"type":"array","items":{"type":"array","items":{"type":"number"}}},"opts":{"type":"object","description":"Options.","properties":{"depth":{"type":"integer","nullable":true}},"required":["depth"]},"mode":{"type":"string","enum":["a","b"],"nullable":true}},"required":["tags","opts"]},"response":{"description":"The result.","type":"object"}}}]`
|
|
||||||
|
|
||||||
// jinja2-verified render of complexToolsJSON. Notable template quirks pinned
|
|
||||||
// here: nested array items go through format_argument with ESCAPED keys and
|
|
||||||
// an un-uppercased type (<|"|>type<|"|>:<|"|>number<|"|>), while direct item
|
|
||||||
// types are uppercased; properties dictsort case-insensitively.
|
|
||||||
const complexToolsBlock = `<|tool>declaration:complex_tool{description:<|"|>A complex tool.<|"|>,parameters:{properties:{matrix:{items:{items:{<|"|>type<|"|>:<|"|>number<|"|>},type:<|"|>ARRAY<|"|>},type:<|"|>ARRAY<|"|>},mode:{enum:[<|"|>a<|"|>,<|"|>b<|"|>],nullable:true,type:<|"|>STRING<|"|>},opts:{description:<|"|>Options.<|"|>,properties:{depth:{nullable:true,type:<|"|>INTEGER<|"|>}},required:[<|"|>depth<|"|>],type:<|"|>OBJECT<|"|>},tags:{description:<|"|>Tags.<|"|>,items:{type:<|"|>STRING<|"|>},type:<|"|>ARRAY<|"|>}},required:[<|"|>tags<|"|>,<|"|>opts<|"|>],type:<|"|>OBJECT<|"|>},response:{description:<|"|>The result.<|"|>,type:<|"|>OBJECT<|"|>}}<tool|>`
|
|
||||||
|
|
||||||
type renderGemma4Case struct {
|
|
||||||
msgs []*pb.Message
|
|
||||||
toolsJSON string
|
|
||||||
// nImages mirrors len(PredictOptions.Images): the OpenAI layer strips
|
|
||||||
// image content parts out of the messages, so the renderer re-injects
|
|
||||||
// one engine marker per image on the last user message (see the IMAGE
|
|
||||||
// NOTE on RenderGemma4).
|
|
||||||
nImages int
|
|
||||||
enableThinking bool
|
|
||||||
noGenerationPrompt bool // inverted so the zero value is the common case
|
|
||||||
expected string
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("RenderGemma4", func() {
|
|
||||||
DescribeTable("renders the canonical gemma4 prompt",
|
|
||||||
func(c renderGemma4Case) {
|
|
||||||
out, err := RenderGemma4(c.msgs, c.toolsJSON, c.nImages, c.enableThinking, !c.noGenerationPrompt)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(out).To(Equal(c.expected))
|
|
||||||
// The C-ABI generate prepends BOS itself: a literal <bos>
|
|
||||||
// anywhere in the rendered prompt would double-encode it.
|
|
||||||
Expect(out).ToNot(ContainSubstring("<bos>"))
|
|
||||||
},
|
|
||||||
|
|
||||||
// transformers fixture (test_diffusion_gemma_chat_template), sans <bos>:
|
|
||||||
// default thinking pre-opens an EMPTY thought channel in the
|
|
||||||
// generation prompt.
|
|
||||||
Entry("single user message, default (no thinking)", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "Write a long essay about Portugal."},
|
|
||||||
},
|
|
||||||
expected: "<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// transformers fixture (test_diffusion_gemma_chat_template_with_thinking),
|
|
||||||
// sans <bos>: a system turn carrying <|think|> and NO auto-opened
|
|
||||||
// thought channel.
|
|
||||||
Entry("enable_thinking=true", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "Write a long essay about Portugal."},
|
|
||||||
},
|
|
||||||
enableThinking: true,
|
|
||||||
expected: "<|turn>system\n<|think|>\n<turn|>\n<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n",
|
|
||||||
}),
|
|
||||||
|
|
||||||
Entry("multi-turn user/assistant/user", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "Hello, who are you?"},
|
|
||||||
{Role: "assistant", Content: "I am Gemma, a helpful assistant."},
|
|
||||||
{Role: "user", Content: "Tell me a joke."},
|
|
||||||
},
|
|
||||||
expected: "<|turn>user\nHello, who are you?<turn|>\n<|turn>model\nI am Gemma, a helpful assistant.<turn|>\n<|turn>user\nTell me a joke.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// tpl L178-L195: a leading system message is folded into the system
|
|
||||||
// turn (trimmed) and consumed from the loop.
|
|
||||||
Entry("system message folds into the system turn", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "system", Content: "You are a pirate."},
|
|
||||||
{Role: "user", Content: "Hello!"},
|
|
||||||
},
|
|
||||||
expected: "<|turn>system\nYou are a pirate.<turn|>\n<|turn>user\nHello!<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// tpl L182-L185: <|think|> goes at the very top of the SAME system
|
|
||||||
// turn, before the system prompt text.
|
|
||||||
Entry("system message with enable_thinking shares the turn", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "system", Content: "You are a pirate."},
|
|
||||||
{Role: "user", Content: "Hello!"},
|
|
||||||
},
|
|
||||||
enableThinking: true,
|
|
||||||
expected: "<|turn>system\n<|think|>\nYou are a pirate.<turn|>\n<|turn>user\nHello!<turn|>\n<|turn>model\n",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// tpl L196-L203: tool declarations render in the system turn, one
|
|
||||||
// <|tool>declaration:...<tool|> block per tool, no separators.
|
|
||||||
Entry("tools array (two functions)", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "What is the weather in Tokyo?"},
|
|
||||||
},
|
|
||||||
toolsJSON: testToolsJSON,
|
|
||||||
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// format_parameters deep branches (tpl L1-L85) + response declaration
|
|
||||||
// (tpl L106-L116).
|
|
||||||
Entry("complex tool schema (array items, nullable, nested object, response)", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "go"},
|
|
||||||
},
|
|
||||||
toolsJSON: complexToolsJSON,
|
|
||||||
expected: "<|turn>system\n" + complexToolsBlock + "<turn|>\n<|turn>user\ngo<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// tpl L243-L313: assistant tool_calls render as
|
|
||||||
// <|tool_call>call:name{args}<tool_call|>; the following role=tool
|
|
||||||
// message renders inline as <|tool_response>response:name{value:..}
|
|
||||||
// <tool_response|>; the model turn stays OPEN (no <turn|>, no new
|
|
||||||
// generation prompt) so the model continues after the response.
|
|
||||||
Entry("assistant tool_calls + role=tool result", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "What is the weather in Tokyo?"},
|
|
||||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`},
|
|
||||||
{Role: "tool", ToolCallId: "call_1", Content: "Sunny, 22 degrees celsius."},
|
|
||||||
},
|
|
||||||
toolsJSON: testToolsJSON,
|
|
||||||
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny, 22 degrees celsius.<|"|>}<tool_response|>`,
|
|
||||||
}),
|
|
||||||
|
|
||||||
// tpl L348-L349: a tool_calls turn with no rendered responses ends
|
|
||||||
// on an OPEN <|tool_response> marker for the runtime to fill, and
|
|
||||||
// add_generation_prompt adds nothing (tpl L357).
|
|
||||||
Entry("assistant tool_calls without a result leaves <|tool_response> open", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "What is the weather in Tokyo?"},
|
|
||||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`},
|
|
||||||
},
|
|
||||||
toolsJSON: testToolsJSON,
|
|
||||||
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>`,
|
|
||||||
}),
|
|
||||||
|
|
||||||
// tpl L237-L241: reasoning_content renders as a thought channel only
|
|
||||||
// on a tool-calling turn after the last user message.
|
|
||||||
Entry("reasoning_content with tool_calls renders the thought channel", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "weather?"},
|
|
||||||
{Role: "assistant", Content: "", ReasoningContent: "I should call the tool", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`},
|
|
||||||
{Role: "tool", ToolCallId: "c1", Content: "Sunny"},
|
|
||||||
},
|
|
||||||
expected: "<|turn>user\nweather?<turn|>\n<|turn>model\n<|channel>thought\nI should call the tool\n<channel|>" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}<tool_response|>`,
|
|
||||||
}),
|
|
||||||
|
|
||||||
// tpl L220-L235: the assistant answer following its own tool round
|
|
||||||
// continues the SAME model turn (no second <|turn>model).
|
|
||||||
Entry("tool round then final assistant answer then user", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "weather?"},
|
|
||||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`},
|
|
||||||
{Role: "tool", ToolCallId: "c1", Content: "Sunny"},
|
|
||||||
{Role: "assistant", Content: "It is sunny."},
|
|
||||||
{Role: "user", Content: "thanks"},
|
|
||||||
},
|
|
||||||
expected: "<|turn>user\nweather?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}<tool_response|>` + "It is sunny.<turn|>\n<|turn>user\nthanks<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// format_argument (tpl L118-L147): numbers keep their JSON literal,
|
|
||||||
// booleans lower-case, nested maps have unquoted dictsorted keys,
|
|
||||||
// arrays bracketed; top-level args are dictsorted case-insensitively.
|
|
||||||
Entry("tool_call argument types (number/bool/nested/array)", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "go"},
|
|
||||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"count\":42,\"ratio\":3.5,\"flag\":true,\"off\":false,\"nested\":{\"x\":\"y\",\"n\":7},\"list\":[\"a\",1,true]}"}}]`},
|
|
||||||
},
|
|
||||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n" + `<|tool_call>call:f{count:42,flag:true,list:[<|"|>a<|"|>,1,true],nested:{n:7,x:<|"|>y<|"|>},off:false,ratio:3.5}<tool_call|><|tool_response>`,
|
|
||||||
}),
|
|
||||||
|
|
||||||
// jinja dictsort is case-insensitive: alpha sorts before Beta.
|
|
||||||
Entry("tool_call argument dictsort is case-insensitive", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "go"},
|
|
||||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"Beta\":1,\"alpha\":2}"}}]`},
|
|
||||||
},
|
|
||||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{alpha:2,Beta:1}<tool_call|><|tool_response>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// jinja renders Python None as "None" (round-trips through vLLM's
|
|
||||||
// parser, which lowers "none" back to null).
|
|
||||||
Entry("tool_call null argument renders as None", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "go"},
|
|
||||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"maybe\":null}"}}]`},
|
|
||||||
},
|
|
||||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{maybe:None}<tool_call|><|tool_response>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
Entry("tool_call empty arguments render empty braces", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "go"},
|
|
||||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`},
|
|
||||||
},
|
|
||||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{}<tool_call|><|tool_response>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// tpl L253-L254: a non-object arguments string renders verbatim.
|
|
||||||
Entry("tool_call non-object string arguments render verbatim", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "go"},
|
|
||||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"just text"}}]`},
|
|
||||||
},
|
|
||||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{just text}<tool_call|><|tool_response>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// tpl L278-L285: unmatched tool_call_id falls back to the tool
|
|
||||||
// message's own name.
|
|
||||||
Entry("tool result name falls back when tool_call_id does not match", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "go"},
|
|
||||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`},
|
|
||||||
{Role: "tool", ToolCallId: "OTHER", Name: "named_tool", Content: "out"},
|
|
||||||
},
|
|
||||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n" + `<|tool_call>call:f{}<tool_call|><|tool_response>response:named_tool{value:<|"|>out<|"|>}<tool_response|>`,
|
|
||||||
}),
|
|
||||||
|
|
||||||
// strip_thinking (tpl L148-L158): historical assistant content loses
|
|
||||||
// its <|channel>...<channel|> spans.
|
|
||||||
Entry("assistant content thinking channels are stripped", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "hi"},
|
|
||||||
{Role: "assistant", Content: "<|channel>thought\nsecret\n<channel|>visible answer"},
|
|
||||||
{Role: "user", Content: "more"},
|
|
||||||
},
|
|
||||||
expected: "<|turn>user\nhi<turn|>\n<|turn>model\nvisible answer<turn|>\n<|turn>user\nmore<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// tpl L220-L235: consecutive assistant messages suppress the second
|
|
||||||
// <|turn>model (continuation), but each still closes with <turn|>.
|
|
||||||
Entry("consecutive assistant messages continue the model turn", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "hi"},
|
|
||||||
{Role: "assistant", Content: "part one"},
|
|
||||||
{Role: "assistant", Content: "part two"},
|
|
||||||
{Role: "user", Content: "ok"},
|
|
||||||
},
|
|
||||||
expected: "<|turn>user\nhi<turn|>\n<|turn>model\npart one<turn|>\npart two<turn|>\n<|turn>user\nok<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
Entry("add_generation_prompt=false renders no model turn", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "hi"},
|
|
||||||
},
|
|
||||||
noGenerationPrompt: true,
|
|
||||||
expected: "<|turn>user\nhi<turn|>\n",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// One engine marker per image, appended directly after the user
|
|
||||||
// text with no separator (tpl L323-L341 emits parts back-to-back;
|
|
||||||
// "<image>" is dllm_capi.h's splice marker, not the template's
|
|
||||||
// <|image|> text token - see the IMAGE NOTE on RenderGemma4).
|
|
||||||
Entry("one image appends one engine marker to the user message", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "What is in this picture?"},
|
|
||||||
},
|
|
||||||
nImages: 1,
|
|
||||||
expected: "<|turn>user\nWhat is in this picture?<image><turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
Entry("multiple images append markers in image order", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "Compare these."},
|
|
||||||
},
|
|
||||||
nImages: 3,
|
|
||||||
expected: "<|turn>user\nCompare these.<image><image><image><turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// Flattened delivery loses per-message attribution, so all images
|
|
||||||
// attach to the LAST user message (llama.cpp grpc-server convention).
|
|
||||||
Entry("images attach to the last user message in multi-turn", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: "hi"},
|
|
||||||
{Role: "assistant", Content: "hello"},
|
|
||||||
{Role: "user", Content: "and this?"},
|
|
||||||
},
|
|
||||||
nImages: 1,
|
|
||||||
expected: "<|turn>user\nhi<turn|>\n<|turn>model\nhello<turn|>\n<|turn>user\nand this?<image><turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
|
||||||
}),
|
|
||||||
|
|
||||||
// tpl L346: the markers count as captured_content, so an image-only
|
|
||||||
// user message still has content and closes its turn normally.
|
|
||||||
Entry("image with empty user text still closes the turn", renderGemma4Case{
|
|
||||||
msgs: []*pb.Message{
|
|
||||||
{Role: "user", Content: ""},
|
|
||||||
},
|
|
||||||
nImages: 1,
|
|
||||||
expected: "<|turn>user\n<image><turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
|
|
||||||
Describe("error handling", func() {
|
|
||||||
It("fails loud on an unknown role", func() {
|
|
||||||
_, err := RenderGemma4([]*pb.Message{
|
|
||||||
{Role: "narrator", Content: "Meanwhile..."},
|
|
||||||
}, "", 0, false, true)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(ContainSubstring(`unknown role "narrator"`))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails on invalid tools JSON", func() {
|
|
||||||
_, err := RenderGemma4([]*pb.Message{
|
|
||||||
{Role: "user", Content: "hi"},
|
|
||||||
}, "{not json", 0, false, true)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(ContainSubstring("tools JSON"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails on invalid tool_calls JSON", func() {
|
|
||||||
_, err := RenderGemma4([]*pb.Message{
|
|
||||||
{Role: "user", Content: "hi"},
|
|
||||||
{Role: "assistant", Content: "", ToolCalls: "{not json"},
|
|
||||||
}, "", 0, false, true)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(ContainSubstring("tool_calls JSON"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails on an orphan tool message, naming its index", func() {
|
|
||||||
// A role:tool message with no preceding assistant tool_calls turn
|
|
||||||
// would be silently dropped by the jinja; we fail loud instead.
|
|
||||||
_, err := RenderGemma4([]*pb.Message{
|
|
||||||
{Role: "user", Content: "hi"},
|
|
||||||
{Role: "tool", Content: `{"temp": 20}`, ToolCallId: "call_1"},
|
|
||||||
}, "", 0, false, true)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(ContainSubstring("orphan tool message 1"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails on trailing garbage after the tools JSON array", func() {
|
|
||||||
_, err := RenderGemma4([]*pb.Message{
|
|
||||||
{Role: "user", Content: "hi"},
|
|
||||||
}, "[] junk", 0, false, true)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(ContainSubstring("tools JSON"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails when the tools JSON is not an array", func() {
|
|
||||||
_, err := RenderGemma4([]*pb.Message{
|
|
||||||
{Role: "user", Content: "hi"},
|
|
||||||
}, `{"type":"function"}`, 0, false, true)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(ContainSubstring("tools JSON is not an array"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails when a tools array element is not an object", func() {
|
|
||||||
_, err := RenderGemma4([]*pb.Message{
|
|
||||||
{Role: "user", Content: "hi"},
|
|
||||||
}, `[42]`, 0, false, true)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(ContainSubstring("tools[0] is not an object"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects a nil message via the unknown-role check", func() {
|
|
||||||
// Pins current behavior: pb getters are nil-safe, so a nil message
|
|
||||||
// reads as role "" and trips the fail-loud unknown-role guard.
|
|
||||||
_, err := RenderGemma4([]*pb.Message{nil}, "", 0, false, true)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(ContainSubstring(`unknown role "" in message 0`))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails loud on images with no user message to attach them to", func() {
|
|
||||||
// The engine would reject the markerless prompt anyway
|
|
||||||
// (marker/image count mismatch); the renderer surfaces the bad
|
|
||||||
// request with a usable message instead.
|
|
||||||
_, err := RenderGemma4([]*pb.Message{
|
|
||||||
{Role: "system", Content: "sys"},
|
|
||||||
{Role: "assistant", Content: "hi"},
|
|
||||||
}, "", 1, false, true)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(ContainSubstring("no user message"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
// Started internally by LocalAI - one gRPC server per loaded model.
|
|
||||||
//
|
|
||||||
// Loads libdllm.so via purego and registers the flat C-ABI declared in
|
|
||||||
// dllm.cpp's include/dllm_capi.h (ABI v1): 9 mandatory symbols plus the
|
|
||||||
// Dlsym-probed optional multimodal pair. The library name can
|
|
||||||
// be overridden with DLLM_LIBRARY (mirrors the PARAKEET_LIBRARY /
|
|
||||||
// WHISPER_LIBRARY convention in the sibling backends); the default looks
|
|
||||||
// for the .so next to this binary (run.sh puts the package dir on
|
|
||||||
// LD_LIBRARY_PATH).
|
|
||||||
import (
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/ebitengine/purego"
|
|
||||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
|
||||||
)
|
|
||||||
|
|
||||||
type LibFuncs struct {
|
|
||||||
FuncPtr any
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadCAPI dlopens libName and binds the 9 dllm_capi_* entry points 1:1 to
|
|
||||||
// dllm_capi.h, so an `nm libdllm.so | grep dllm_capi` is enough to spot
|
|
||||||
// drift. Shared with the test suite (ensureLibLoaded), which drives the
|
|
||||||
// bridge without the gRPC server.
|
|
||||||
//
|
|
||||||
// The C-ABI returns malloc'd char* buffers from tokenize_json/generate; we
|
|
||||||
// register those as uintptr so we get the raw pointer back and can call
|
|
||||||
// dllm_capi_free_string on it (purego's string return would copy and forget
|
|
||||||
// the original pointer, leaking it on every call). last_error returns a
|
|
||||||
// BORROWED pointer instead, so it is registered as a plain string: purego
|
|
||||||
// copies it and nothing must be freed.
|
|
||||||
func loadCAPI(libName string) error {
|
|
||||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("dllm: dlopen %q: %w", libName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
libFuncs := []LibFuncs{
|
|
||||||
{&cppAbiVersion, "dllm_capi_abi_version"},
|
|
||||||
{&cppLoad, "dllm_capi_load"},
|
|
||||||
{&cppFree, "dllm_capi_free"},
|
|
||||||
{&cppLastError, "dllm_capi_last_error"},
|
|
||||||
{&cppFreeString, "dllm_capi_free_string"},
|
|
||||||
{&cppTokenizeJSON, "dllm_capi_tokenize_json"},
|
|
||||||
{&cppGenerate, "dllm_capi_generate"},
|
|
||||||
{&cppGenerateStream, "dllm_capi_generate_stream"},
|
|
||||||
{&cppCancel, "dllm_capi_cancel"},
|
|
||||||
}
|
|
||||||
for _, lf := range libFuncs {
|
|
||||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Multimodal entry points (dllm_capi.h's P4 surface). Additive: the ABI
|
|
||||||
// version stays 1 and consumers detect the surface by probing the symbols
|
|
||||||
// (the parakeet-cpp optional-symbol pattern), so the backend still loads
|
|
||||||
// against an older text-only libdllm.so - image requests then fail with
|
|
||||||
// errMMUnsupported instead of a boot failure.
|
|
||||||
if sym, err := purego.Dlsym(lib, "dllm_capi_generate_mm"); err == nil && sym != 0 {
|
|
||||||
purego.RegisterLibFunc(&cppGenerateMM, lib, "dllm_capi_generate_mm")
|
|
||||||
}
|
|
||||||
if sym, err := purego.Dlsym(lib, "dllm_capi_generate_stream_mm"); err == nil && sym != 0 {
|
|
||||||
purego.RegisterLibFunc(&cppGenerateStreamMM, lib, "dllm_capi_generate_stream_mm")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
libName := os.Getenv("DLLM_LIBRARY")
|
|
||||||
if libName == "" {
|
|
||||||
libName = "libdllm.so"
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := loadCAPI(libName); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hard-fail on an ABI mismatch: the flat-pointer bindings above would
|
|
||||||
// otherwise misbehave silently against a future libdllm.so.
|
|
||||||
if v := cAbiVersion(); v != dllmABIVersion {
|
|
||||||
panic(fmt.Errorf("dllm: libdllm.so ABI=%d, this backend speaks ABI=%d", v, dllmABIVersion))
|
|
||||||
}
|
|
||||||
fmt.Fprintf(os.Stderr, "[dllm] ABI=%d multimodal=%t\n", cAbiVersion(), cMMSupported())
|
|
||||||
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
if err := grpc.StartServer(*addr, &Dllm{}); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
#
|
|
||||||
# T1 packaging stub: copy the binary, run.sh and libdllm.so into package/.
|
|
||||||
# The full ldd walk (libc, libstdc++, libgomp, GPU runtimes, arch
|
|
||||||
# detection) lands with the registration task, mirroring
|
|
||||||
# backend/go/whisper/package.sh.
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
CURDIR=$(dirname "$(realpath "$0")")
|
|
||||||
|
|
||||||
mkdir -p "$CURDIR/package/lib"
|
|
||||||
|
|
||||||
cp -avf "$CURDIR/dllm-grpc" "$CURDIR/package/"
|
|
||||||
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
|
|
||||||
|
|
||||||
# libdllm.so + any soname symlinks, should upstream ever add them.
|
|
||||||
cp -avf "$CURDIR"/libdllm.so* "$CURDIR/package/lib/" 2>/dev/null || {
|
|
||||||
echo "ERROR: libdllm.so not found in $CURDIR, run 'make' first" >&2
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
echo "T1 package layout (full ldd walk lands with registration):"
|
|
||||||
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"
|
|
||||||
18
backend/go/face-detect/.gitignore
vendored
Normal file
18
backend/go/face-detect/.gitignore
vendored
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# Fetched upstream sources
|
||||||
|
sources/
|
||||||
|
|
||||||
|
# CMake build directories
|
||||||
|
build*/
|
||||||
|
|
||||||
|
# build artifacts staged in-tree by the Makefile (cp from sources/) or
|
||||||
|
# symlinked for local dev; the real sources live in face-detect.cpp upstream.
|
||||||
|
*.so
|
||||||
|
*.so.*
|
||||||
|
facedetect_capi.h
|
||||||
|
compile_commands.json
|
||||||
|
|
||||||
|
# Compiled backend binary
|
||||||
|
face-detect-grpc
|
||||||
|
|
||||||
|
# Packaging output
|
||||||
|
package/
|
||||||
97
backend/go/face-detect/Makefile
Normal file
97
backend/go/face-detect/Makefile
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
# face-detect backend Makefile.
|
||||||
|
#
|
||||||
|
# Upstream pin lives below as FACEDETECT_VERSION?=9c8adb7... (.github/bump_deps.sh
|
||||||
|
# can find and update it - matches the voice-detect / parakeet.cpp / whisper.cpp
|
||||||
|
# convention).
|
||||||
|
#
|
||||||
|
# Local dev shortcut: if you already have an out-of-tree face-detect.cpp build,
|
||||||
|
# symlink the .so + header into this directory and skip the clone/cmake steps:
|
||||||
|
#
|
||||||
|
# ln -sf /path/to/face-detect.cpp/build-shared/libfacedetect.so .
|
||||||
|
# ln -sf /path/to/face-detect.cpp/include/facedetect_capi.h .
|
||||||
|
# go build -o face-detect-grpc .
|
||||||
|
#
|
||||||
|
# The default target below does the proper clone-at-pin + cmake build so CI does
|
||||||
|
# not need a side-checkout.
|
||||||
|
|
||||||
|
FACEDETECT_VERSION?=9c8adb748f1f02d7fc0430a883234aef4b343a34
|
||||||
|
FACEDETECT_REPO?=https://github.com/mudler/face-detect.cpp
|
||||||
|
|
||||||
|
GOCMD?=go
|
||||||
|
GO_TAGS?=
|
||||||
|
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
||||||
|
|
||||||
|
BUILD_TYPE?=
|
||||||
|
NATIVE?=false
|
||||||
|
|
||||||
|
# Build ggml + the vendored libjpeg-turbo statically into libfacedetect.so (PIC)
|
||||||
|
# so the shared lib is self-contained: dlopen needs no libggml*.so alongside it,
|
||||||
|
# only system libs (libstdc++/libgomp/libc) the runtime image already provides.
|
||||||
|
# The vendored jpeg symbols are hidden via -Wl,--exclude-libs,ALL on the C++
|
||||||
|
# side, so only the facedetect_capi_* surface is exported.
|
||||||
|
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DFACEDETECT_SHARED=ON -DFACEDETECT_BUILD_CLI=OFF -DFACEDETECT_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||||
|
|
||||||
|
ifeq ($(NATIVE),false)
|
||||||
|
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||||
|
endif
|
||||||
|
|
||||||
|
# face-detect.cpp gates its GGML backends behind FACEDETECT_GGML_* options and
|
||||||
|
# does set(GGML_CUDA ${FACEDETECT_GGML_CUDA} CACHE BOOL "" FORCE), so a bare
|
||||||
|
# -DGGML_CUDA=ON is overwritten back to OFF. Forward the FACEDETECT_GGML_*
|
||||||
|
# options instead. (openblas is not gated, so -DGGML_BLAS passes through.)
|
||||||
|
ifeq ($(BUILD_TYPE),cublas)
|
||||||
|
CMAKE_ARGS+=-DFACEDETECT_GGML_CUDA=ON
|
||||||
|
else ifeq ($(BUILD_TYPE),openblas)
|
||||||
|
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||||
|
else ifeq ($(BUILD_TYPE),hipblas)
|
||||||
|
CMAKE_ARGS+=-DFACEDETECT_GGML_HIP=ON
|
||||||
|
else ifeq ($(BUILD_TYPE),vulkan)
|
||||||
|
CMAKE_ARGS+=-DFACEDETECT_GGML_VULKAN=ON
|
||||||
|
else ifeq ($(BUILD_TYPE),metal)
|
||||||
|
CMAKE_ARGS+=-DFACEDETECT_GGML_METAL=ON
|
||||||
|
endif
|
||||||
|
|
||||||
|
.PHONY: face-detect-grpc package build clean purge test all
|
||||||
|
|
||||||
|
all: face-detect-grpc
|
||||||
|
|
||||||
|
# Clone the upstream face-detect.cpp source at the pinned commit. Directory acts
|
||||||
|
# as the target so make only re-clones when missing. After a FACEDETECT_VERSION
|
||||||
|
# bump, run 'make purge && make' to refetch.
|
||||||
|
sources/face-detect.cpp:
|
||||||
|
mkdir -p sources/face-detect.cpp
|
||||||
|
cd sources/face-detect.cpp && \
|
||||||
|
git init -q && \
|
||||||
|
git remote add origin $(FACEDETECT_REPO) && \
|
||||||
|
git fetch --depth 1 origin $(FACEDETECT_VERSION) && \
|
||||||
|
git checkout FETCH_HEAD && \
|
||||||
|
git submodule update --init --recursive --depth 1 --single-branch
|
||||||
|
|
||||||
|
# Build the shared lib + header out-of-tree, then stage them next to the Go
|
||||||
|
# sources so purego.Dlopen("libfacedetect.so") and the cgo-less build both pick
|
||||||
|
# them up.
|
||||||
|
libfacedetect.so: sources/face-detect.cpp
|
||||||
|
cmake -B sources/face-detect.cpp/build-shared -S sources/face-detect.cpp $(CMAKE_ARGS)
|
||||||
|
cmake --build sources/face-detect.cpp/build-shared --config Release -j$(JOBS) --target facedetect
|
||||||
|
cp -fv sources/face-detect.cpp/build-shared/libfacedetect.so* ./ 2>/dev/null || true
|
||||||
|
cp -fv sources/face-detect.cpp/include/facedetect_capi.h ./
|
||||||
|
|
||||||
|
face-detect-grpc: libfacedetect.so main.go gofacedetect.go options.go
|
||||||
|
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o face-detect-grpc .
|
||||||
|
|
||||||
|
package: face-detect-grpc
|
||||||
|
bash package.sh
|
||||||
|
|
||||||
|
build: package
|
||||||
|
|
||||||
|
# Test target. The embed/detect/verify/analyze smoke specs are gated on
|
||||||
|
# FACEDETECT_BACKEND_TEST_MODEL + FACEDETECT_BACKEND_TEST_IMAGE; without them the
|
||||||
|
# heavy specs auto-skip and only the pure-Go parsing specs run.
|
||||||
|
test:
|
||||||
|
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
|
||||||
|
|
||||||
|
clean: purge
|
||||||
|
rm -rf libfacedetect.so* facedetect_capi.h package face-detect-grpc
|
||||||
|
|
||||||
|
purge:
|
||||||
|
rm -rf sources/face-detect.cpp
|
||||||
431
backend/go/face-detect/gofacedetect.go
Normal file
431
backend/go/face-detect/gofacedetect.go
Normal file
@@ -0,0 +1,431 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// purego-bound entry points from libfacedetect.so. Names match
|
||||||
|
// facedetect_capi.h exactly so a `nm libfacedetect.so | grep facedetect_capi`
|
||||||
|
// is enough to spot drift.
|
||||||
|
//
|
||||||
|
// The opaque ctx and the malloc'd char*/float* return values are declared as
|
||||||
|
// uintptr so we get the raw pointer back and can release it via the matching
|
||||||
|
// capi free function. purego's native string/[]float32 returns would copy and
|
||||||
|
// forget the original pointer, leaking the C-owned buffer on every call.
|
||||||
|
var (
|
||||||
|
CppAbiVersion func() int32
|
||||||
|
CppLoad func(ggufPath string) uintptr
|
||||||
|
CppFree func(ctx uintptr)
|
||||||
|
CppLastError func(ctx uintptr) string
|
||||||
|
CppFreeString func(s uintptr)
|
||||||
|
CppFreeVec func(v uintptr)
|
||||||
|
CppEmbedPath func(ctx uintptr, imagePath string, outVec, outDim unsafe.Pointer) int32
|
||||||
|
CppEmbedRGB func(ctx uintptr, rgb []byte, width, height int32, outVec, outDim unsafe.Pointer) int32
|
||||||
|
CppDetectJSON func(ctx uintptr, imagePath string) uintptr
|
||||||
|
CppVerifyPaths func(ctx uintptr, a, b string, threshold float32, antiSpoof int32, outDistance, outVerified unsafe.Pointer) int32
|
||||||
|
CppAnalyzeJSON func(ctx uintptr, imagePath string) uintptr
|
||||||
|
)
|
||||||
|
|
||||||
|
// FaceDetect implements the face-recognition (biometric) subset of the Backend
|
||||||
|
// gRPC service over libfacedetect.so. The C side keeps a single loaded model
|
||||||
|
// pack plus a per-ctx last-error buffer and is not reentrant, so
|
||||||
|
// base.SingleThread serializes every call.
|
||||||
|
type FaceDetect struct {
|
||||||
|
base.SingleThread
|
||||||
|
opts loadOptions
|
||||||
|
ctxPtr uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FaceDetect) Load(opts *pb.ModelOptions) error {
|
||||||
|
model := opts.ModelFile
|
||||||
|
if model == "" {
|
||||||
|
model = opts.ModelPath
|
||||||
|
}
|
||||||
|
if !filepath.IsAbs(model) && opts.ModelPath != "" {
|
||||||
|
model = filepath.Join(opts.ModelPath, model)
|
||||||
|
}
|
||||||
|
if model == "" {
|
||||||
|
return errors.New("face-detect: ModelFile is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
f.opts = parseOptions(opts.Options)
|
||||||
|
if f.opts.modelName == "" {
|
||||||
|
f.opts.modelName = filepath.Base(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Propagate LocalAI's per-model thread budget to the engine. LocalAI spawns
|
||||||
|
// one backend process per model and serves requests concurrently, so the
|
||||||
|
// engine's own min(hardware_concurrency, 8) default can oversubscribe cores.
|
||||||
|
// FACEDETECT_THREADS is read by the engine at backend construction, so it
|
||||||
|
// must be set before the capi load. A non-positive Threads means "unset":
|
||||||
|
// leave the env alone so the engine keeps its sane default.
|
||||||
|
threads := opts.Threads
|
||||||
|
if threads > 0 {
|
||||||
|
if err := os.Setenv("FACEDETECT_THREADS", strconv.Itoa(int(threads))); err != nil {
|
||||||
|
return fmt.Errorf("face-detect: set FACEDETECT_THREADS: %w", err)
|
||||||
|
}
|
||||||
|
xlog.Info("face-detect: applying LocalAI thread budget", "threads", threads)
|
||||||
|
}
|
||||||
|
|
||||||
|
xlog.Info("face-detect: loading model", "model", model,
|
||||||
|
"verify_threshold", f.opts.verifyThreshold, "abi", CppAbiVersion())
|
||||||
|
|
||||||
|
ctx := CppLoad(model)
|
||||||
|
if ctx == 0 {
|
||||||
|
// The last-error buffer lives on the ctx that was never returned, so
|
||||||
|
// surface the path the operator tried to load instead.
|
||||||
|
return fmt.Errorf("face-detect: facedetect_capi_load failed for %q", model)
|
||||||
|
}
|
||||||
|
f.ctxPtr = ctx
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Embeddings returns the L2-normalized ArcFace embedding of the primary face in
|
||||||
|
// the supplied image. Mirroring the Python face backend, the image is read from
|
||||||
|
// Images[0] as a base64 payload; materializeImage decodes it to a temp file so
|
||||||
|
// the path-based C-API can run its own decode (cv2.imread parity). The gRPC
|
||||||
|
// server wraps the returned slice in an EmbeddingResult.
|
||||||
|
func (f *FaceDetect) Embeddings(req *pb.PredictOptions) ([]float32, error) {
|
||||||
|
if f.ctxPtr == 0 {
|
||||||
|
return nil, errors.New("face-detect: model not loaded")
|
||||||
|
}
|
||||||
|
if len(req.Images) == 0 || req.Images[0] == "" {
|
||||||
|
return nil, errors.New("face-detect: Embedding requires Images[0] to be a base64 image")
|
||||||
|
}
|
||||||
|
|
||||||
|
path, cleanup, err := materializeImage(req.Images[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
return f.embedPath(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FaceDetect) embedPath(path string) ([]float32, error) {
|
||||||
|
var vec uintptr
|
||||||
|
var dim int32
|
||||||
|
rc := CppEmbedPath(f.ctxPtr, path, unsafe.Pointer(&vec), unsafe.Pointer(&dim))
|
||||||
|
if rc != 0 || vec == 0 || dim <= 0 {
|
||||||
|
return nil, f.lastErr("embed", path)
|
||||||
|
}
|
||||||
|
defer CppFreeVec(vec)
|
||||||
|
// Copy out of the C-owned malloc'd buffer before freeing it. The
|
||||||
|
// uintptr->Pointer conversion trips vet's unsafeptr check, which can't tell
|
||||||
|
// a C heap pointer from Go-managed memory; safe here, the GC neither tracks
|
||||||
|
// nor moves this buffer and we copy immediately.
|
||||||
|
src := unsafe.Slice((*float32)(unsafe.Pointer(vec)), int(dim)) //nolint:govet // C-owned malloc'd vector, copied out before free
|
||||||
|
out := make([]float32, int(dim))
|
||||||
|
copy(out, src)
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect runs SCRFD over the image and returns one Detection per face. The
|
||||||
|
// C-API emits a box as [x1,y1,x2,y2] in pixels; the proto carries x/y plus
|
||||||
|
// width/height, so the corners are converted. The 5 facial landmarks the engine
|
||||||
|
// also returns are dropped: the Detection message has no field for them.
|
||||||
|
func (f *FaceDetect) Detect(req *pb.DetectOptions) (pb.DetectResponse, error) {
|
||||||
|
if f.ctxPtr == 0 {
|
||||||
|
return pb.DetectResponse{}, errors.New("face-detect: model not loaded")
|
||||||
|
}
|
||||||
|
if req.Src == "" {
|
||||||
|
return pb.DetectResponse{}, errors.New("face-detect: src image is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
path, cleanup, err := materializeImage(req.Src)
|
||||||
|
if err != nil {
|
||||||
|
return pb.DetectResponse{}, err
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
faces, err := f.detectFaces(path)
|
||||||
|
if err != nil {
|
||||||
|
return pb.DetectResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dets := make([]*pb.Detection, 0, len(faces))
|
||||||
|
for _, fc := range faces {
|
||||||
|
if req.Threshold > 0 && fc.Score < req.Threshold {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
x, y, w, h := fc.xywh()
|
||||||
|
dets = append(dets, &pb.Detection{
|
||||||
|
X: x,
|
||||||
|
Y: y,
|
||||||
|
Width: w,
|
||||||
|
Height: h,
|
||||||
|
Confidence: fc.Score,
|
||||||
|
ClassName: "face",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return pb.DetectResponse{Detections: dets}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FaceVerify embeds the primary face in each image and reports whether they are
|
||||||
|
// the same identity by cosine distance against a threshold. A request threshold
|
||||||
|
// <= 0 falls back to the model-configured default (verify_threshold option,
|
||||||
|
// 0.35 if unset). When anti_spoofing is set, the C-API applies a MiniFASNet
|
||||||
|
// veto internally (verified forced false on a spoof); the per-image liveness
|
||||||
|
// scores are not exposed by the verify entry point, so img*_is_real /
|
||||||
|
// img*_antispoof_score stay at their zero values.
|
||||||
|
func (f *FaceDetect) FaceVerify(req *pb.FaceVerifyRequest) (pb.FaceVerifyResponse, error) {
|
||||||
|
if f.ctxPtr == 0 {
|
||||||
|
return pb.FaceVerifyResponse{}, errors.New("face-detect: model not loaded")
|
||||||
|
}
|
||||||
|
if req.Img1 == "" || req.Img2 == "" {
|
||||||
|
return pb.FaceVerifyResponse{}, errors.New("face-detect: img1 and img2 are required")
|
||||||
|
}
|
||||||
|
|
||||||
|
path1, cleanup1, err := materializeImage(req.Img1)
|
||||||
|
if err != nil {
|
||||||
|
return pb.FaceVerifyResponse{}, err
|
||||||
|
}
|
||||||
|
defer cleanup1()
|
||||||
|
path2, cleanup2, err := materializeImage(req.Img2)
|
||||||
|
if err != nil {
|
||||||
|
return pb.FaceVerifyResponse{}, err
|
||||||
|
}
|
||||||
|
defer cleanup2()
|
||||||
|
|
||||||
|
threshold := req.Threshold
|
||||||
|
if threshold <= 0 {
|
||||||
|
threshold = f.opts.verifyThreshold
|
||||||
|
}
|
||||||
|
|
||||||
|
antiSpoof := int32(0)
|
||||||
|
if req.AntiSpoofing {
|
||||||
|
antiSpoof = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
started := time.Now()
|
||||||
|
var distance float32
|
||||||
|
var verified int32
|
||||||
|
rc := CppVerifyPaths(f.ctxPtr, path1, path2, threshold, antiSpoof,
|
||||||
|
unsafe.Pointer(&distance), unsafe.Pointer(&verified))
|
||||||
|
if rc != 0 {
|
||||||
|
return pb.FaceVerifyResponse{}, f.lastErr("verify", req.Img1[:min(8, len(req.Img1))]+"...")
|
||||||
|
}
|
||||||
|
elapsedMs := float32(time.Since(started).Seconds() * 1000.0)
|
||||||
|
|
||||||
|
// Confidence decays linearly from 100 at distance 0 to 0 at the threshold,
|
||||||
|
// matching the Python face backend's reporting.
|
||||||
|
confidence := float32(0)
|
||||||
|
if threshold > 0 {
|
||||||
|
confidence = float32(math.Max(0, math.Min(100, (1.0-float64(distance)/float64(threshold))*100.0)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return pb.FaceVerifyResponse{
|
||||||
|
Verified: verified != 0,
|
||||||
|
Distance: distance,
|
||||||
|
Threshold: threshold,
|
||||||
|
Confidence: confidence,
|
||||||
|
Model: f.opts.modelName,
|
||||||
|
Img1Area: f.bestArea(path1),
|
||||||
|
Img2Area: f.bestArea(path2),
|
||||||
|
ProcessingTimeMs: elapsedMs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FaceAnalyze runs the genderage head on every detected face. The C-API returns
|
||||||
|
// "M"/"F" gender labels and a rounded age; the labels are normalized to the
|
||||||
|
// "Man"/"Woman" values the proto documents.
|
||||||
|
func (f *FaceDetect) FaceAnalyze(req *pb.FaceAnalyzeRequest) (pb.FaceAnalyzeResponse, error) {
|
||||||
|
if f.ctxPtr == 0 {
|
||||||
|
return pb.FaceAnalyzeResponse{}, errors.New("face-detect: model not loaded")
|
||||||
|
}
|
||||||
|
if req.Img == "" {
|
||||||
|
return pb.FaceAnalyzeResponse{}, errors.New("face-detect: img is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
path, cleanup, err := materializeImage(req.Img)
|
||||||
|
if err != nil {
|
||||||
|
return pb.FaceAnalyzeResponse{}, err
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
ptr := CppAnalyzeJSON(f.ctxPtr, path)
|
||||||
|
if ptr == 0 {
|
||||||
|
return pb.FaceAnalyzeResponse{}, f.lastErr("analyze", path)
|
||||||
|
}
|
||||||
|
defer CppFreeString(ptr)
|
||||||
|
|
||||||
|
faces, err := parseAnalyzeJSON(goStringFromCPtr(ptr))
|
||||||
|
if err != nil {
|
||||||
|
return pb.FaceAnalyzeResponse{}, fmt.Errorf("face-detect: analyze JSON: %w", err)
|
||||||
|
}
|
||||||
|
return pb.FaceAnalyzeResponse{Faces: faces}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// faceBox is one entry of the detect/analyze JSON documents the engine emits.
|
||||||
|
type faceBox struct {
|
||||||
|
Score float32 `json:"score"`
|
||||||
|
Box []float32 `json:"box"`
|
||||||
|
Age float32 `json:"age"`
|
||||||
|
Gender string `json:"gender"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// xywh converts the engine's [x1,y1,x2,y2] box into the x/y/width/height the
|
||||||
|
// proto carries. A short or missing box yields zeros.
|
||||||
|
func (b faceBox) xywh() (x, y, w, h float32) {
|
||||||
|
if len(b.Box) < 4 {
|
||||||
|
return 0, 0, 0, 0
|
||||||
|
}
|
||||||
|
return b.Box[0], b.Box[1], b.Box[2] - b.Box[0], b.Box[3] - b.Box[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
type facesJSON struct {
|
||||||
|
Faces []faceBox `json:"faces"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FaceDetect) detectFaces(path string) ([]faceBox, error) {
|
||||||
|
ptr := CppDetectJSON(f.ctxPtr, path)
|
||||||
|
if ptr == 0 {
|
||||||
|
return nil, f.lastErr("detect", path)
|
||||||
|
}
|
||||||
|
defer CppFreeString(ptr)
|
||||||
|
|
||||||
|
var doc facesJSON
|
||||||
|
if err := json.Unmarshal([]byte(goStringFromCPtr(ptr)), &doc); err != nil {
|
||||||
|
return nil, fmt.Errorf("face-detect: detect JSON: %w", err)
|
||||||
|
}
|
||||||
|
return doc.Faces, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// bestArea returns the FacialArea of the highest-scoring face in an image, or an
|
||||||
|
// empty area when detection fails or finds nothing. Best-effort: verify already
|
||||||
|
// succeeded, so a missing region must not turn a valid match into an error.
|
||||||
|
func (f *FaceDetect) bestArea(path string) *pb.FacialArea {
|
||||||
|
faces, err := f.detectFaces(path)
|
||||||
|
if err != nil || len(faces) == 0 {
|
||||||
|
return &pb.FacialArea{}
|
||||||
|
}
|
||||||
|
best := faces[0]
|
||||||
|
for _, fc := range faces[1:] {
|
||||||
|
if fc.Score > best.Score {
|
||||||
|
best = fc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
x, y, w, h := best.xywh()
|
||||||
|
return &pb.FacialArea{X: x, Y: y, W: w, H: h}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAnalyzeJSON maps the engine's analyze document onto FaceAnalysis entries.
|
||||||
|
// The engine reports gender as "M"/"F"; both the dominant label and the score
|
||||||
|
// map are filled with the "Man"/"Woman" form the proto documents.
|
||||||
|
func parseAnalyzeJSON(doc string) ([]*pb.FaceAnalysis, error) {
|
||||||
|
var parsed facesJSON
|
||||||
|
if err := json.Unmarshal([]byte(doc), &parsed); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]*pb.FaceAnalysis, 0, len(parsed.Faces))
|
||||||
|
for _, fc := range parsed.Faces {
|
||||||
|
x, y, w, h := fc.xywh()
|
||||||
|
fa := &pb.FaceAnalysis{
|
||||||
|
Region: &pb.FacialArea{X: x, Y: y, W: w, H: h},
|
||||||
|
FaceConfidence: fc.Score,
|
||||||
|
Age: fc.Age,
|
||||||
|
}
|
||||||
|
if label := normalizeGender(fc.Gender); label != "" {
|
||||||
|
fa.DominantGender = label
|
||||||
|
fa.Gender = map[string]float32{label: 1.0}
|
||||||
|
}
|
||||||
|
out = append(out, fa)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeGender maps the engine's "M"/"F" code to the "Man"/"Woman" labels the
|
||||||
|
// proto documents. Unknown codes pass through unchanged.
|
||||||
|
func normalizeGender(g string) string {
|
||||||
|
switch strings.ToUpper(strings.TrimSpace(g)) {
|
||||||
|
case "M":
|
||||||
|
return "Man"
|
||||||
|
case "F":
|
||||||
|
return "Woman"
|
||||||
|
case "":
|
||||||
|
return ""
|
||||||
|
default:
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// materializeImage decodes a base64 image payload into a temp file and returns
|
||||||
|
// its path plus a cleanup func. As a convenience for callers that already pass a
|
||||||
|
// filesystem path (e.g. a test fixture), an existing path is used as-is with a
|
||||||
|
// no-op cleanup. data: URI prefixes are stripped before decoding.
|
||||||
|
func materializeImage(src string) (path string, cleanup func(), err error) {
|
||||||
|
noop := func() {}
|
||||||
|
if src == "" {
|
||||||
|
return "", noop, errors.New("face-detect: empty image input")
|
||||||
|
}
|
||||||
|
if _, statErr := os.Stat(src); statErr == nil {
|
||||||
|
return src, noop, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := src
|
||||||
|
if i := strings.Index(payload, ","); strings.HasPrefix(payload, "data:") && i >= 0 {
|
||||||
|
payload = payload[i+1:]
|
||||||
|
}
|
||||||
|
data, decErr := base64.StdEncoding.DecodeString(strings.TrimSpace(payload))
|
||||||
|
if decErr != nil || len(data) == 0 {
|
||||||
|
return "", noop, errors.New("face-detect: image is neither an existing path nor valid base64")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmp, createErr := os.CreateTemp("", "face-detect-*.img")
|
||||||
|
if createErr != nil {
|
||||||
|
return "", noop, fmt.Errorf("face-detect: create temp image: %w", createErr)
|
||||||
|
}
|
||||||
|
cleanup = func() { _ = os.Remove(tmp.Name()) }
|
||||||
|
if _, wErr := tmp.Write(data); wErr != nil {
|
||||||
|
_ = tmp.Close()
|
||||||
|
cleanup()
|
||||||
|
return "", noop, fmt.Errorf("face-detect: write temp image: %w", wErr)
|
||||||
|
}
|
||||||
|
if cErr := tmp.Close(); cErr != nil {
|
||||||
|
cleanup()
|
||||||
|
return "", noop, fmt.Errorf("face-detect: close temp image: %w", cErr)
|
||||||
|
}
|
||||||
|
return tmp.Name(), cleanup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// lastErr wraps the C-API's per-ctx last-error buffer into a Go error.
|
||||||
|
func (f *FaceDetect) lastErr(op, subject string) error {
|
||||||
|
msg := strings.TrimSpace(CppLastError(f.ctxPtr))
|
||||||
|
if msg == "" {
|
||||||
|
msg = "no error detail"
|
||||||
|
}
|
||||||
|
return fmt.Errorf("face-detect: %s failed for %q: %s", op, subject, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// goStringFromCPtr copies a NUL-terminated C string into Go memory. cptr is a
|
||||||
|
// malloc'd buffer the caller owns; release it via CppFreeString after the copy.
|
||||||
|
//
|
||||||
|
// The uintptr->Pointer conversion trips vet's unsafeptr check, which can't tell
|
||||||
|
// a C heap pointer from Go-managed memory. Safe here: the GC neither tracks nor
|
||||||
|
// moves the buffer and we dereference it immediately to copy the bytes out.
|
||||||
|
func goStringFromCPtr(cptr uintptr) string {
|
||||||
|
if cptr == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
p := unsafe.Pointer(cptr) //nolint:govet // C-owned malloc'd buffer, not Go-GC memory (see doc above)
|
||||||
|
n := 0
|
||||||
|
for *(*byte)(unsafe.Add(p, n)) != 0 {
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
return string(unsafe.Slice((*byte)(p), n))
|
||||||
|
}
|
||||||
230
backend/go/face-detect/gofacedetect_test.go
Normal file
230
backend/go/face-detect/gofacedetect_test.go
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFaceDetect(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "face-detect Backend Suite")
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
libLoadOnce sync.Once
|
||||||
|
libLoadErr error
|
||||||
|
)
|
||||||
|
|
||||||
|
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the C-API
|
||||||
|
// bridge without spinning up the gRPC server. Records the error (the smoke
|
||||||
|
// specs skip themselves) when libfacedetect.so is not loadable from cwd
|
||||||
|
// (LD_LIBRARY_PATH or a symlink in ./).
|
||||||
|
func ensureLibLoaded() error {
|
||||||
|
libLoadOnce.Do(func() {
|
||||||
|
libName := os.Getenv("FACEDETECT_LIBRARY")
|
||||||
|
if libName == "" {
|
||||||
|
libName = "libfacedetect.so"
|
||||||
|
}
|
||||||
|
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
libLoadErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
purego.RegisterLibFunc(&CppAbiVersion, lib, "facedetect_capi_abi_version")
|
||||||
|
purego.RegisterLibFunc(&CppLoad, lib, "facedetect_capi_load")
|
||||||
|
purego.RegisterLibFunc(&CppFree, lib, "facedetect_capi_free")
|
||||||
|
purego.RegisterLibFunc(&CppLastError, lib, "facedetect_capi_last_error")
|
||||||
|
purego.RegisterLibFunc(&CppFreeString, lib, "facedetect_capi_free_string")
|
||||||
|
purego.RegisterLibFunc(&CppFreeVec, lib, "facedetect_capi_free_vec")
|
||||||
|
purego.RegisterLibFunc(&CppEmbedPath, lib, "facedetect_capi_embed_path")
|
||||||
|
purego.RegisterLibFunc(&CppEmbedRGB, lib, "facedetect_capi_embed_rgb")
|
||||||
|
purego.RegisterLibFunc(&CppDetectJSON, lib, "facedetect_capi_detect_path_json")
|
||||||
|
purego.RegisterLibFunc(&CppVerifyPaths, lib, "facedetect_capi_verify_paths")
|
||||||
|
purego.RegisterLibFunc(&CppAnalyzeJSON, lib, "facedetect_capi_analyze_path_json")
|
||||||
|
})
|
||||||
|
return libLoadErr
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = Describe("parseOptions", func() {
|
||||||
|
It("defaults verify_threshold to 0.35", func() {
|
||||||
|
o := parseOptions(nil)
|
||||||
|
Expect(o.verifyThreshold).To(Equal(float32(0.35)))
|
||||||
|
Expect(o.modelName).To(Equal(""))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("parses verify_threshold, threshold alias and model_name", func() {
|
||||||
|
o := parseOptions([]string{"verify_threshold:0.4", "model_name:buffalo_l", "unknown:x"})
|
||||||
|
Expect(o.verifyThreshold).To(Equal(float32(0.4)))
|
||||||
|
Expect(o.modelName).To(Equal("buffalo_l"))
|
||||||
|
|
||||||
|
o2 := parseOptions([]string{"threshold:0.3"})
|
||||||
|
Expect(o2.verifyThreshold).To(Equal(float32(0.3)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("ignores non-positive thresholds and keeps the default", func() {
|
||||||
|
o := parseOptions([]string{"verify_threshold:0", "threshold:-1"})
|
||||||
|
Expect(o.verifyThreshold).To(Equal(float32(0.35)))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
var _ = Describe("normalizeGender", func() {
|
||||||
|
It("maps M/F codes to Man/Woman", func() {
|
||||||
|
Expect(normalizeGender("M")).To(Equal("Man"))
|
||||||
|
Expect(normalizeGender("f")).To(Equal("Woman"))
|
||||||
|
Expect(normalizeGender(" m ")).To(Equal("Man"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("passes empty and unknown codes through", func() {
|
||||||
|
Expect(normalizeGender("")).To(Equal(""))
|
||||||
|
Expect(normalizeGender("nonbinary")).To(Equal("nonbinary"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
var _ = Describe("faceBox.xywh", func() {
|
||||||
|
It("converts an [x1,y1,x2,y2] box to x/y/width/height", func() {
|
||||||
|
b := faceBox{Box: []float32{10, 20, 50, 80}}
|
||||||
|
x, y, w, h := b.xywh()
|
||||||
|
Expect(x).To(Equal(float32(10)))
|
||||||
|
Expect(y).To(Equal(float32(20)))
|
||||||
|
Expect(w).To(Equal(float32(40)))
|
||||||
|
Expect(h).To(Equal(float32(60)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns zeros for a short box", func() {
|
||||||
|
x, y, w, h := faceBox{Box: []float32{1, 2}}.xywh()
|
||||||
|
Expect([]float32{x, y, w, h}).To(Equal([]float32{0, 0, 0, 0}))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
var _ = Describe("parseAnalyzeJSON", func() {
|
||||||
|
It("maps region, age and gender for each face", func() {
|
||||||
|
doc := `{"faces":[
|
||||||
|
{"score":0.997,"box":[10,20,50,80],"age":31,"gender":"M"},
|
||||||
|
{"score":0.81,"box":[0,0,40,40],"age":24,"gender":"F"}]}`
|
||||||
|
faces, err := parseAnalyzeJSON(doc)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(faces).To(HaveLen(2))
|
||||||
|
|
||||||
|
Expect(faces[0].FaceConfidence).To(BeNumerically("~", 0.997, 1e-4))
|
||||||
|
Expect(faces[0].Age).To(BeNumerically("~", 31, 1e-4))
|
||||||
|
Expect(faces[0].DominantGender).To(Equal("Man"))
|
||||||
|
Expect(faces[0].Gender).To(HaveKeyWithValue("Man", float32(1.0)))
|
||||||
|
Expect(faces[0].Region.W).To(Equal(float32(40)))
|
||||||
|
Expect(faces[0].Region.H).To(Equal(float32(60)))
|
||||||
|
|
||||||
|
Expect(faces[1].DominantGender).To(Equal("Woman"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("tolerates a missing gender field", func() {
|
||||||
|
faces, err := parseAnalyzeJSON(`{"faces":[{"score":0.5,"box":[0,0,10,10],"age":40}]}`)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(faces).To(HaveLen(1))
|
||||||
|
Expect(faces[0].DominantGender).To(Equal(""))
|
||||||
|
Expect(faces[0].Gender).To(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns no faces for an empty document", func() {
|
||||||
|
faces, err := parseAnalyzeJSON(`{"faces":[]}`)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(faces).To(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns an error on malformed JSON", func() {
|
||||||
|
_, err := parseAnalyzeJSON(`{not-json`)
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
var _ = Describe("materializeImage", func() {
|
||||||
|
It("decodes a base64 payload to a temp file", func() {
|
||||||
|
payload := base64.StdEncoding.EncodeToString([]byte("\xff\xd8\xff\xe0fake-jpeg"))
|
||||||
|
path, cleanup, err := materializeImage(payload)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer cleanup()
|
||||||
|
data, rerr := os.ReadFile(path)
|
||||||
|
Expect(rerr).ToNot(HaveOccurred())
|
||||||
|
Expect(data).To(Equal([]byte("\xff\xd8\xff\xe0fake-jpeg")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("strips a data: URI prefix before decoding", func() {
|
||||||
|
payload := "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte("hello"))
|
||||||
|
path, cleanup, err := materializeImage(payload)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer cleanup()
|
||||||
|
data, rerr := os.ReadFile(path)
|
||||||
|
Expect(rerr).ToNot(HaveOccurred())
|
||||||
|
Expect(data).To(Equal([]byte("hello")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("uses an existing path as-is", func() {
|
||||||
|
tmp, err := os.CreateTemp("", "face-detect-fixture-*.bin")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer func() { _ = os.Remove(tmp.Name()) }()
|
||||||
|
Expect(tmp.Close()).To(Succeed())
|
||||||
|
|
||||||
|
path, cleanup, err := materializeImage(tmp.Name())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer cleanup()
|
||||||
|
Expect(path).To(Equal(tmp.Name()))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors on input that is neither a path nor base64", func() {
|
||||||
|
_, _, err := materializeImage("not base64!!!")
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// The specs below exercise the real C-API end to end. They run only when both a
|
||||||
|
// model GGUF and a test image are provided, and skip cleanly otherwise so the
|
||||||
|
// suite stays green without large assets.
|
||||||
|
var _ = Describe("FaceDetect end-to-end", Ordered, func() {
|
||||||
|
var (
|
||||||
|
f *FaceDetect
|
||||||
|
modelPath = os.Getenv("FACEDETECT_BACKEND_TEST_MODEL")
|
||||||
|
imagePath = os.Getenv("FACEDETECT_BACKEND_TEST_IMAGE")
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeAll(func() {
|
||||||
|
if modelPath == "" || imagePath == "" {
|
||||||
|
Skip("set FACEDETECT_BACKEND_TEST_MODEL and FACEDETECT_BACKEND_TEST_IMAGE to run the e2e specs")
|
||||||
|
}
|
||||||
|
if err := ensureLibLoaded(); err != nil {
|
||||||
|
Skip("libfacedetect.so not loadable: " + err.Error())
|
||||||
|
}
|
||||||
|
f = &FaceDetect{}
|
||||||
|
Expect(f.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("embeds the primary face in an image", func() {
|
||||||
|
emb, err := f.Embeddings(&pb.PredictOptions{Images: []string{imagePath}})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(emb).ToNot(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("detects at least one face", func() {
|
||||||
|
resp, err := f.Detect(&pb.DetectOptions{Src: imagePath})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.Detections).ToNot(BeEmpty())
|
||||||
|
Expect(resp.Detections[0].ClassName).To(Equal("face"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("verifies an image against itself as the same identity", func() {
|
||||||
|
resp, err := f.FaceVerify(&pb.FaceVerifyRequest{Img1: imagePath, Img2: imagePath})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.Verified).To(BeTrue())
|
||||||
|
Expect(resp.Distance).To(BeNumerically("<=", resp.Threshold))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("analyzes age/gender for each face", func() {
|
||||||
|
resp, err := f.FaceAnalyze(&pb.FaceAnalyzeRequest{Img: imagePath})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.Faces).ToNot(BeEmpty())
|
||||||
|
})
|
||||||
|
})
|
||||||
65
backend/go/face-detect/main.go
Normal file
65
backend/go/face-detect/main.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// Started internally by LocalAI - one gRPC server per loaded model.
|
||||||
|
//
|
||||||
|
// Loads libfacedetect.so via purego and registers the flat C-API entry points
|
||||||
|
// declared in facedetect_capi.h. The library name can be overridden with
|
||||||
|
// FACEDETECT_LIBRARY (mirrors the VOICEDETECT_LIBRARY / PARAKEET_LIBRARY
|
||||||
|
// convention in the sibling backends); the default looks for the .so next to
|
||||||
|
// this binary (resolved via LD_LIBRARY_PATH by run.sh).
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||||
|
)
|
||||||
|
|
||||||
|
type LibFuncs struct {
|
||||||
|
FuncPtr any
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
libName := os.Getenv("FACEDETECT_LIBRARY")
|
||||||
|
if libName == "" {
|
||||||
|
libName = "libfacedetect.so"
|
||||||
|
}
|
||||||
|
|
||||||
|
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("face-detect: dlopen %q: %w", libName, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bound 1:1 to facedetect_capi.h. char*/float* returns are registered as
|
||||||
|
// uintptr so the raw pointer can be freed via the matching capi free fn.
|
||||||
|
libFuncs := []LibFuncs{
|
||||||
|
{&CppAbiVersion, "facedetect_capi_abi_version"},
|
||||||
|
{&CppLoad, "facedetect_capi_load"},
|
||||||
|
{&CppFree, "facedetect_capi_free"},
|
||||||
|
{&CppLastError, "facedetect_capi_last_error"},
|
||||||
|
{&CppFreeString, "facedetect_capi_free_string"},
|
||||||
|
{&CppFreeVec, "facedetect_capi_free_vec"},
|
||||||
|
{&CppEmbedPath, "facedetect_capi_embed_path"},
|
||||||
|
{&CppEmbedRGB, "facedetect_capi_embed_rgb"},
|
||||||
|
{&CppDetectJSON, "facedetect_capi_detect_path_json"},
|
||||||
|
{&CppVerifyPaths, "facedetect_capi_verify_paths"},
|
||||||
|
{&CppAnalyzeJSON, "facedetect_capi_analyze_path_json"},
|
||||||
|
}
|
||||||
|
for _, lf := range libFuncs {
|
||||||
|
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "[face-detect] ABI=%d\n", CppAbiVersion())
|
||||||
|
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &FaceDetect{}); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
47
backend/go/face-detect/options.go
Normal file
47
backend/go/face-detect/options.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultVerifyThreshold is the cosine-distance cutoff used when a request does
|
||||||
|
// not set one. Matches the insightface buffalo_l ArcFace R50 default the Python
|
||||||
|
// face backend ships with so the two implementations agree on verdicts out of
|
||||||
|
// the box.
|
||||||
|
const defaultVerifyThreshold float32 = 0.35
|
||||||
|
|
||||||
|
// loadOptions holds the parsed model-level options for face-detect.
|
||||||
|
type loadOptions struct {
|
||||||
|
verifyThreshold float32
|
||||||
|
modelName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitOption(o string) (key, value string, ok bool) {
|
||||||
|
i := strings.Index(o, ":")
|
||||||
|
if i < 0 {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(o[:i]), strings.TrimSpace(o[i+1:]), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOptions reads the backend "key:value" option slice. Unknown keys are
|
||||||
|
// ignored. Defaults: verify_threshold 0.35, model_name derived from the file.
|
||||||
|
func parseOptions(opts []string) loadOptions {
|
||||||
|
o := loadOptions{verifyThreshold: defaultVerifyThreshold}
|
||||||
|
for _, oo := range opts {
|
||||||
|
key, value, ok := splitOption(oo)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch key {
|
||||||
|
case "verify_threshold", "threshold":
|
||||||
|
if f, err := strconv.ParseFloat(value, 32); err == nil && f > 0 {
|
||||||
|
o.verifyThreshold = float32(f)
|
||||||
|
}
|
||||||
|
case "model_name":
|
||||||
|
o.modelName = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return o
|
||||||
|
}
|
||||||
68
backend/go/face-detect/package.sh
Normal file
68
backend/go/face-detect/package.sh
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Bundle the face-detect-grpc binary, libfacedetect.so, the core runtime libs
|
||||||
|
# (libc/libstdc++/libgomp + ld.so) and the GPU runtime for the active BUILD_TYPE
|
||||||
|
# so the package is self-contained. Mirrors backend/go/voice-detect/package.sh;
|
||||||
|
# run.sh routes the (CGO_ENABLED=0) binary through lib/ld.so so the packaged libc
|
||||||
|
# is used instead of the host's.
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath "$0")")
|
||||||
|
REPO_ROOT="${CURDIR}/../../.."
|
||||||
|
|
||||||
|
mkdir -p "$CURDIR/package/lib"
|
||||||
|
|
||||||
|
cp -avf "$CURDIR/face-detect-grpc" "$CURDIR/package/"
|
||||||
|
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
|
||||||
|
|
||||||
|
# libfacedetect.so + any soname symlinks. purego.Dlopen resolves it via
|
||||||
|
# LD_LIBRARY_PATH, which run.sh points at lib/.
|
||||||
|
cp -avf "$CURDIR"/libfacedetect.so* "$CURDIR/package/lib/" 2>/dev/null || {
|
||||||
|
echo "ERROR: libfacedetect.so not found in $CURDIR, run 'make' first" >&2
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Detect architecture and copy the core runtime libs libfacedetect.so links
|
||||||
|
# against, plus the matching dynamic loader as lib/ld.so.
|
||||||
|
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||||
|
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||||
|
cp -arfLv /lib64/ld-linux-x86-64.so.2 "$CURDIR/package/lib/ld.so"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 "$CURDIR/package/lib/libc.so.6"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 "$CURDIR/package/lib/libgcc_s.so.1"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 "$CURDIR/package/lib/libstdc++.so.6"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 "$CURDIR/package/lib/libm.so.6"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 "$CURDIR/package/lib/libgomp.so.1"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 "$CURDIR/package/lib/libdl.so.2"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 "$CURDIR/package/lib/librt.so.1"
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 "$CURDIR/package/lib/libpthread.so.0"
|
||||||
|
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||||
|
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||||
|
cp -arfLv /lib/ld-linux-aarch64.so.1 "$CURDIR/package/lib/ld.so"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 "$CURDIR/package/lib/libc.so.6"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 "$CURDIR/package/lib/libgcc_s.so.1"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 "$CURDIR/package/lib/libstdc++.so.6"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 "$CURDIR/package/lib/libm.so.6"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 "$CURDIR/package/lib/libgomp.so.1"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 "$CURDIR/package/lib/libdl.so.2"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 "$CURDIR/package/lib/librt.so.1"
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 "$CURDIR/package/lib/libpthread.so.0"
|
||||||
|
elif [ "$(uname -s)" = "Darwin" ]; then
|
||||||
|
echo "Detected Darwin"
|
||||||
|
else
|
||||||
|
echo "Error: Could not detect architecture"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Package GPU libraries (CUDA/ROCm/Intel/Vulkan loader + ICDs + drivers) based on
|
||||||
|
# BUILD_TYPE so the backend can reach the GPU without the runtime base image
|
||||||
|
# shipping those drivers.
|
||||||
|
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||||
|
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||||
|
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||||
|
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||||
|
package_gpu_libs
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Packaging completed successfully"
|
||||||
|
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"
|
||||||
16
backend/go/face-detect/run.sh
Normal file
16
backend/go/face-detect/run.sh
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath "$0")")
|
||||||
|
|
||||||
|
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
|
||||||
|
|
||||||
|
# If a self-contained ld.so was packaged, route through it so the packaged
|
||||||
|
# libc / libstdc++ are used instead of the host's (matches the voice-detect /
|
||||||
|
# whisper / parakeet backends' runtime layout).
|
||||||
|
if [ -f "$CURDIR/lib/ld.so" ]; then
|
||||||
|
echo "Using lib/ld.so"
|
||||||
|
exec "$CURDIR/lib/ld.so" "$CURDIR/face-detect-grpc" "$@"
|
||||||
|
fi
|
||||||
|
|
||||||
|
exec "$CURDIR/face-detect-grpc" "$@"
|
||||||
15
backend/go/face-detect/test.sh
Normal file
15
backend/go/face-detect/test.sh
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath "$0")")
|
||||||
|
cd "$CURDIR"
|
||||||
|
|
||||||
|
echo "Running face-detect backend tests..."
|
||||||
|
|
||||||
|
# The pure-Go parsing specs always run. The embed/detect/verify/analyze smoke
|
||||||
|
# specs run only when a model + image are provided via
|
||||||
|
# FACEDETECT_BACKEND_TEST_MODEL and FACEDETECT_BACKEND_TEST_IMAGE; otherwise they
|
||||||
|
# auto-skip.
|
||||||
|
LD_LIBRARY_PATH="$CURDIR:${LD_LIBRARY_PATH:-}" go test -v -timeout 1200s .
|
||||||
|
|
||||||
|
echo "face-detect tests completed."
|
||||||
7
backend/go/locate-anything-cpp/.gitignore
vendored
Normal file
7
backend/go/locate-anything-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
sources/
|
||||||
|
build*/
|
||||||
|
package/
|
||||||
|
liblocateanythingcpp*.so
|
||||||
|
locate-anything-cpp
|
||||||
|
test-models/
|
||||||
|
test-data/
|
||||||
57
backend/go/locate-anything-cpp/CMakeLists.txt
Normal file
57
backend/go/locate-anything-cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.18)
|
||||||
|
project(liblocateanythingcpp LANGUAGES C CXX)
|
||||||
|
|
||||||
|
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
|
||||||
|
# Static-link ggml + locate_anything so the resulting .so has no runtime
|
||||||
|
# dependency on extra ggml/locate_anything shared libraries — only on
|
||||||
|
# libc/libstdc++/libgomp, which the LocalAI package step bundles into the
|
||||||
|
# docker image.
|
||||||
|
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build static libraries" FORCE)
|
||||||
|
|
||||||
|
# locate-anything.cpp build switches: skip CLI/tests, keep static lib.
|
||||||
|
set(LA_BUILD_CLI OFF CACHE BOOL "Disable locate-anything CLI" FORCE)
|
||||||
|
set(LA_BUILD_TESTS OFF CACHE BOOL "Disable locate-anything tests" FORCE)
|
||||||
|
set(LA_SHARED OFF CACHE BOOL "Build locate_anything as static lib" FORCE)
|
||||||
|
|
||||||
|
# Unlike rt-detr.cpp, locate-anything.cpp ships no in-tree ggml patches, so
|
||||||
|
# there is no apply_ggml_patches.sh hook to shim here.
|
||||||
|
add_subdirectory(./sources/locate-anything.cpp)
|
||||||
|
|
||||||
|
# locate-anything.cpp's top-level CMakeLists points its own target's include
|
||||||
|
# dirs at ${CMAKE_SOURCE_DIR}/{include,src,third_party,...}. CMAKE_SOURCE_DIR
|
||||||
|
# is the *top-level* source dir of the whole CMake tree, so when we pull it in
|
||||||
|
# via add_subdirectory it resolves to OUR directory, not theirs, and the
|
||||||
|
# locate_anything target fails to find its own headers (la_capi.h, stb_image.h,
|
||||||
|
# la_gguf_keys.h). Re-add the correct, subdir-relative include paths to the
|
||||||
|
# already-defined target so it compiles regardless of where it's nested.
|
||||||
|
set(LA_SRC ${CMAKE_CURRENT_SOURCE_DIR}/sources/locate-anything.cpp)
|
||||||
|
target_include_directories(locate_anything PRIVATE
|
||||||
|
${LA_SRC}/include
|
||||||
|
${LA_SRC}/src
|
||||||
|
${LA_SRC}/third_party
|
||||||
|
${LA_SRC}/third_party/stb)
|
||||||
|
|
||||||
|
# locate-anything.cpp's C-API symbols already live inside liblocate_anything
|
||||||
|
# (src/la_capi.cpp is compiled into the lib). We re-export them via a MODULE
|
||||||
|
# library that links locate_anything so the symbols are visible at dlopen time.
|
||||||
|
add_library(locateanythingcpp MODULE
|
||||||
|
sources/locate-anything.cpp/src/la_capi.cpp)
|
||||||
|
|
||||||
|
target_include_directories(locateanythingcpp PRIVATE
|
||||||
|
sources/locate-anything.cpp/include
|
||||||
|
sources/locate-anything.cpp/src
|
||||||
|
sources/locate-anything.cpp/third_party
|
||||||
|
sources/locate-anything.cpp/third_party/stb
|
||||||
|
)
|
||||||
|
|
||||||
|
target_link_libraries(locateanythingcpp PRIVATE locate_anything ggml)
|
||||||
|
|
||||||
|
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
|
||||||
|
target_link_libraries(locateanythingcpp PRIVATE stdc++fs)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set_property(TARGET locateanythingcpp PROPERTY CXX_STANDARD 17)
|
||||||
|
set_target_properties(locateanythingcpp PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||||
134
backend/go/locate-anything-cpp/Makefile
Normal file
134
backend/go/locate-anything-cpp/Makefile
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
CMAKE_ARGS?=
|
||||||
|
BUILD_TYPE?=
|
||||||
|
NATIVE?=false
|
||||||
|
|
||||||
|
GOCMD?=go
|
||||||
|
GO_TAGS?=
|
||||||
|
JOBS?=$(shell nproc --ignore=1)
|
||||||
|
|
||||||
|
# locate-anything.cpp. Pin to a specific commit for a stable build; leaving
|
||||||
|
# this on `master` always picks up the latest C-API surface (incl. the
|
||||||
|
# per-detection accessor functions used by golocateanythingcpp.go).
|
||||||
|
LOCATEANYTHING_REPO?=https://github.com/mudler/locate-anything.cpp.git
|
||||||
|
LOCATEANYTHING_VERSION?=92c1682da792c1e8a5dec91acc2be4b02c742ded
|
||||||
|
|
||||||
|
ifeq ($(NATIVE),false)
|
||||||
|
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||||
|
endif
|
||||||
|
|
||||||
|
# Forward LocalAI's BUILD_TYPE to the matching ggml backend switch.
|
||||||
|
ifeq ($(BUILD_TYPE),cublas)
|
||||||
|
CMAKE_ARGS+=-DGGML_CUDA=ON -DLA_GGML_CUDA=ON
|
||||||
|
else ifeq ($(BUILD_TYPE),openblas)
|
||||||
|
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||||
|
else ifeq ($(BUILD_TYPE),clblas)
|
||||||
|
CMAKE_ARGS+=-DGGML_CLBLAST=ON
|
||||||
|
else ifeq ($(BUILD_TYPE),hipblas)
|
||||||
|
ROCM_HOME ?= /opt/rocm
|
||||||
|
ROCM_PATH ?= /opt/rocm
|
||||||
|
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
||||||
|
export CC=$(ROCM_HOME)/llvm/bin/clang
|
||||||
|
AMDGPU_TARGETS?=gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201
|
||||||
|
CMAKE_ARGS+=-DGGML_HIPBLAS=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
|
||||||
|
else ifeq ($(BUILD_TYPE),vulkan)
|
||||||
|
CMAKE_ARGS+=-DGGML_VULKAN=ON -DLA_GGML_VULKAN=ON
|
||||||
|
else ifeq ($(OS),Darwin)
|
||||||
|
ifneq ($(BUILD_TYPE),metal)
|
||||||
|
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||||
|
else
|
||||||
|
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||||
|
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||||
|
CMAKE_ARGS+=-DLA_GGML_METAL=ON
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||||
|
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||||
|
-DCMAKE_C_COMPILER=icx \
|
||||||
|
-DCMAKE_CXX_COMPILER=icpx \
|
||||||
|
-DGGML_SYCL_F16=ON
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||||
|
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||||
|
-DCMAKE_C_COMPILER=icx \
|
||||||
|
-DCMAKE_CXX_COMPILER=icpx
|
||||||
|
endif
|
||||||
|
|
||||||
|
sources/locate-anything.cpp:
|
||||||
|
mkdir -p sources && \
|
||||||
|
git clone --recursive $(LOCATEANYTHING_REPO) sources/locate-anything.cpp && \
|
||||||
|
cd sources/locate-anything.cpp && \
|
||||||
|
git checkout $(LOCATEANYTHING_VERSION) && \
|
||||||
|
git submodule update --init --recursive --depth 1 --single-branch
|
||||||
|
|
||||||
|
# Detect OS
|
||||||
|
UNAME_S := $(shell uname -s)
|
||||||
|
|
||||||
|
# Only build CPU variants on Linux
|
||||||
|
ifeq ($(UNAME_S),Linux)
|
||||||
|
VARIANT_TARGETS = liblocateanythingcpp-avx.so liblocateanythingcpp-avx2.so liblocateanythingcpp-avx512.so liblocateanythingcpp-fallback.so
|
||||||
|
else
|
||||||
|
# On non-Linux (e.g., Darwin), build only fallback variant
|
||||||
|
VARIANT_TARGETS = liblocateanythingcpp-fallback.so
|
||||||
|
endif
|
||||||
|
|
||||||
|
locate-anything-cpp: main.go golocateanythingcpp.go $(VARIANT_TARGETS)
|
||||||
|
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o locate-anything-cpp ./
|
||||||
|
|
||||||
|
package: locate-anything-cpp
|
||||||
|
bash package.sh
|
||||||
|
|
||||||
|
build: package
|
||||||
|
|
||||||
|
clean: purge
|
||||||
|
rm -rf liblocateanythingcpp*.so locate-anything-cpp package sources
|
||||||
|
|
||||||
|
purge:
|
||||||
|
rm -rf build*
|
||||||
|
|
||||||
|
# Build all variants (Linux only)
|
||||||
|
ifeq ($(UNAME_S),Linux)
|
||||||
|
liblocateanythingcpp-avx.so: sources/locate-anything.cpp
|
||||||
|
rm -rfv build-$@
|
||||||
|
$(info ${GREEN}I locate-anything-cpp build info:avx${RESET})
|
||||||
|
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) liblocateanythingcpp-custom
|
||||||
|
rm -rfv build-$@
|
||||||
|
|
||||||
|
liblocateanythingcpp-avx2.so: sources/locate-anything.cpp
|
||||||
|
rm -rfv build-$@
|
||||||
|
$(info ${GREEN}I locate-anything-cpp build info:avx2${RESET})
|
||||||
|
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) liblocateanythingcpp-custom
|
||||||
|
rm -rfv build-$@
|
||||||
|
|
||||||
|
liblocateanythingcpp-avx512.so: sources/locate-anything.cpp
|
||||||
|
rm -rfv build-$@
|
||||||
|
$(info ${GREEN}I locate-anything-cpp build info:avx512${RESET})
|
||||||
|
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) liblocateanythingcpp-custom
|
||||||
|
rm -rfv build-$@
|
||||||
|
endif
|
||||||
|
|
||||||
|
# Build fallback variant (all platforms)
|
||||||
|
liblocateanythingcpp-fallback.so: sources/locate-anything.cpp
|
||||||
|
rm -rfv build-$@
|
||||||
|
$(info ${GREEN}I locate-anything-cpp build info:fallback${RESET})
|
||||||
|
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) liblocateanythingcpp-custom
|
||||||
|
rm -rfv build-$@
|
||||||
|
|
||||||
|
liblocateanythingcpp-custom: CMakeLists.txt
|
||||||
|
mkdir -p build-$(SO_TARGET) && \
|
||||||
|
cd build-$(SO_TARGET) && \
|
||||||
|
cmake .. $(CMAKE_ARGS) && \
|
||||||
|
cmake --build . --config Release -j$(JOBS) && \
|
||||||
|
cd .. && \
|
||||||
|
mv build-$(SO_TARGET)/liblocateanythingcpp.so ./$(SO_TARGET)
|
||||||
|
|
||||||
|
all: locate-anything-cpp package
|
||||||
|
|
||||||
|
# `test` is invoked by the top-level Makefile's `test-extra` target. It builds
|
||||||
|
# the backend binary + the fallback shared library (needed for dlopen at
|
||||||
|
# runtime), then runs test.sh which downloads the q8_0 GGUF + COCO image and
|
||||||
|
# exercises the gRPC Load/Detect wire path via the Go smoke test in
|
||||||
|
# main_test.go.
|
||||||
|
test: locate-anything-cpp liblocateanythingcpp-fallback.so
|
||||||
|
bash test.sh
|
||||||
174
backend/go/locate-anything-cpp/golocateanythingcpp.go
Normal file
174
backend/go/locate-anything-cpp/golocateanythingcpp.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// golocateanythingcpp.go - gRPC handlers (Load, Detect) for the
|
||||||
|
// locate-anything-cpp backend.
|
||||||
|
//
|
||||||
|
// Embeds base.SingleThread to default unimplemented RPCs to "not supported"
|
||||||
|
// while we only implement open-vocabulary object detection (Detect).
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// la_ctx* is an opaque handle. la_capi_load returns it directly (0 == failure),
|
||||||
|
// unlike rfdetr's out-parameter convention.
|
||||||
|
var (
|
||||||
|
// la_capi_load(const char* gguf_path, int n_threads) -> la_ctx* (0 = fail)
|
||||||
|
CapiLoad func(gguf string, nThreads int32) uintptr
|
||||||
|
// la_capi_free(la_ctx* ctx)
|
||||||
|
CapiFree func(handle uintptr)
|
||||||
|
// la_capi_locate_path(ctx, image_path, prompt, mode) -> char* json (0 = err)
|
||||||
|
CapiLocatePath func(handle uintptr, imagePath string, prompt string, mode int32) uintptr
|
||||||
|
// la_capi_locate_buffer(ctx, bytes, len, prompt, mode) -> char* json (0 = err)
|
||||||
|
CapiLocateBuffer func(handle uintptr, bytes uintptr, length uintptr, prompt string, mode int32) uintptr
|
||||||
|
// la_capi_get_n_detections(ctx) -> int
|
||||||
|
CapiGetNDetections func(handle uintptr) int32
|
||||||
|
// la_capi_get_detection_box(ctx, i, out_xyxy[4]) -> int (0 on success)
|
||||||
|
CapiGetDetectionBox func(handle uintptr, i int32, outXYXY uintptr) int32
|
||||||
|
// la_capi_get_detection_label(ctx, i, buf, buf_size) -> int (required size incl NUL; two-call sizing)
|
||||||
|
CapiGetDetectionLabel func(handle uintptr, i int32, buf uintptr, bufSize int32) int32
|
||||||
|
// la_capi_free_string(char* s)
|
||||||
|
CapiFreeString func(s uintptr)
|
||||||
|
// la_capi_last_error(ctx) -> const char* (owned by ctx, "" if none / null ctx).
|
||||||
|
// purego marshals the returned C string into a Go string (a copy), so we
|
||||||
|
// never free it and avoid raw pointer arithmetic.
|
||||||
|
CapiLastError func(handle uintptr) string
|
||||||
|
)
|
||||||
|
|
||||||
|
type LocateAnythingCpp struct {
|
||||||
|
base.SingleThread
|
||||||
|
handle uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load loads the GGUF model at opts.ModelFile (joined with opts.ModelPath if
|
||||||
|
// relative) and stores the la_ctx handle for later Detect calls.
|
||||||
|
func (r *LocateAnythingCpp) Load(opts *pb.ModelOptions) error {
|
||||||
|
modelFile := opts.ModelFile
|
||||||
|
if modelFile == "" {
|
||||||
|
modelFile = opts.Model
|
||||||
|
}
|
||||||
|
if modelFile == "" {
|
||||||
|
return fmt.Errorf("locate-anything-cpp: ModelFile is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
var modelPath string
|
||||||
|
if filepath.IsAbs(modelFile) {
|
||||||
|
modelPath = modelFile
|
||||||
|
} else {
|
||||||
|
modelPath = filepath.Join(opts.ModelPath, modelFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(modelPath); err != nil {
|
||||||
|
return fmt.Errorf("locate-anything-cpp: model file not found: %s: %w", modelPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
threads := opts.Threads
|
||||||
|
if threads <= 0 {
|
||||||
|
threads = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release previous model if any (re-Load).
|
||||||
|
if r.handle != 0 {
|
||||||
|
CapiFree(r.handle)
|
||||||
|
r.handle = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
h := CapiLoad(modelPath, threads)
|
||||||
|
if h == 0 {
|
||||||
|
// la_capi_last_error needs a ctx; on a failed load we have none (it
|
||||||
|
// returns "" for a null ctx), so the text is best-effort. Surface it
|
||||||
|
// when present.
|
||||||
|
if msg := CapiLastError(0); msg != "" {
|
||||||
|
return fmt.Errorf("locate-anything-cpp: la_capi_load failed for %s: %s", modelPath, msg)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("locate-anything-cpp: la_capi_load failed for %s", modelPath)
|
||||||
|
}
|
||||||
|
r.handle = h
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect runs open-vocabulary detection on the base64-encoded image in opts.Src
|
||||||
|
// using the required text prompt in opts.Prompt, returning one pb.Detection per
|
||||||
|
// located object with its predicted label as ClassName.
|
||||||
|
func (r *LocateAnythingCpp) Detect(opts *pb.DetectOptions) (pb.DetectResponse, error) {
|
||||||
|
if r.handle == 0 {
|
||||||
|
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: model not loaded")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open-vocabulary detection is prompt-driven; without a prompt there is
|
||||||
|
// nothing to locate.
|
||||||
|
prompt := opts.Prompt
|
||||||
|
if prompt == "" {
|
||||||
|
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: a text prompt is required (open-vocabulary detection)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode base64 image and write to temp file.
|
||||||
|
imgData, err := base64.StdEncoding.DecodeString(opts.Src)
|
||||||
|
if err != nil {
|
||||||
|
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: failed to decode base64 image: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpFile, err := os.CreateTemp("", "locate-anything-*.img")
|
||||||
|
if err != nil {
|
||||||
|
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: failed to create temp file: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||||
|
|
||||||
|
if _, err := tmpFile.Write(imgData); err != nil {
|
||||||
|
_ = tmpFile.Close()
|
||||||
|
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: failed to write temp file: %w", err)
|
||||||
|
}
|
||||||
|
if err := tmpFile.Close(); err != nil {
|
||||||
|
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: failed to close temp file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// mode 0 = hybrid (Parallel Box Decoding). The JSON return value is unused:
|
||||||
|
// structured detections are read via the accessor functions. Still must
|
||||||
|
// free the returned string.
|
||||||
|
jsonPtr := CapiLocatePath(r.handle, tmpFile.Name(), prompt, 0)
|
||||||
|
if jsonPtr != 0 {
|
||||||
|
CapiFreeString(jsonPtr)
|
||||||
|
}
|
||||||
|
|
||||||
|
n := CapiGetNDetections(r.handle)
|
||||||
|
if n < 0 {
|
||||||
|
return pb.DetectResponse{}, fmt.Errorf("locate-anything-cpp: invalid n_detections=%d", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
detections := make([]*pb.Detection, 0, n)
|
||||||
|
for i := int32(0); i < n; i++ {
|
||||||
|
var xyxy [4]float32 // x1, y1, x2, y2
|
||||||
|
if CapiGetDetectionBox(r.handle, i, uintptr(unsafe.Pointer(&xyxy[0]))) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Two-call sizing for the label string.
|
||||||
|
label := ""
|
||||||
|
need := CapiGetDetectionLabel(r.handle, i, 0, 0)
|
||||||
|
if need > 0 {
|
||||||
|
buf := make([]byte, need)
|
||||||
|
CapiGetDetectionLabel(r.handle, i, uintptr(unsafe.Pointer(&buf[0])), need)
|
||||||
|
label = string(buf[:need-1])
|
||||||
|
}
|
||||||
|
|
||||||
|
detections = append(detections, &pb.Detection{
|
||||||
|
X: xyxy[0],
|
||||||
|
Y: xyxy[1],
|
||||||
|
Width: xyxy[2] - xyxy[0],
|
||||||
|
Height: xyxy[3] - xyxy[1],
|
||||||
|
Confidence: 1.0,
|
||||||
|
ClassName: label,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return pb.DetectResponse{
|
||||||
|
Detections: detections,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
59
backend/go/locate-anything-cpp/main.go
Normal file
59
backend/go/locate-anything-cpp/main.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// main.go - entry point for the locate-anything-cpp gRPC backend.
|
||||||
|
//
|
||||||
|
// Dlopens liblocateanythingcpp-<variant>.so via purego at the path in
|
||||||
|
// LOCATEANYTHING_LIBRARY (set by run.sh based on /proc/cpuinfo), registers
|
||||||
|
// the la_capi_* C ABI symbols, then starts the gRPC server.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||||
|
)
|
||||||
|
|
||||||
|
type LibFuncs struct {
|
||||||
|
FuncPtr any
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Get library name from environment variable, default to fallback
|
||||||
|
libName := os.Getenv("LOCATEANYTHING_LIBRARY")
|
||||||
|
if libName == "" {
|
||||||
|
libName = "./liblocateanythingcpp-fallback.so"
|
||||||
|
}
|
||||||
|
|
||||||
|
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
libFuncs := []LibFuncs{
|
||||||
|
{&CapiLoad, "la_capi_load"},
|
||||||
|
{&CapiFree, "la_capi_free"},
|
||||||
|
{&CapiLocatePath, "la_capi_locate_path"},
|
||||||
|
{&CapiLocateBuffer, "la_capi_locate_buffer"},
|
||||||
|
{&CapiGetNDetections, "la_capi_get_n_detections"},
|
||||||
|
{&CapiGetDetectionBox, "la_capi_get_detection_box"},
|
||||||
|
{&CapiGetDetectionLabel, "la_capi_get_detection_label"},
|
||||||
|
{&CapiFreeString, "la_capi_free_string"},
|
||||||
|
{&CapiLastError, "la_capi_last_error"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, lf := range libFuncs {
|
||||||
|
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &LocateAnythingCpp{}); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
176
backend/go/locate-anything-cpp/main_test.go
Normal file
176
backend/go/locate-anything-cpp/main_test.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// main_test.go - end-to-end smoke test for the locate-anything-cpp gRPC backend.
|
||||||
|
//
|
||||||
|
// Spawns the compiled locate-anything-cpp binary on a free local port, dials it
|
||||||
|
// via gRPC, and exercises LoadModel + Detect against the test fixtures
|
||||||
|
// downloaded by test.sh: the q8_0 GGUF of nvidia/LocateAnything-3B and a real
|
||||||
|
// COCO image with people + cars. Asserts that open-vocabulary detection driven
|
||||||
|
// by a text prompt returns at least one detection, each carrying a non-empty
|
||||||
|
// class name and a bounding box of non-zero size.
|
||||||
|
//
|
||||||
|
// The spec Skip()s cleanly if its fixtures (the ~6.3 GB model, the test image,
|
||||||
|
// the built binary, or the fallback .so) are missing, so the test target stays
|
||||||
|
// usable on a fresh checkout / on CI runners where the large model hasn't been
|
||||||
|
// downloaded.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDetect(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "locate-anything-cpp backend smoke suite")
|
||||||
|
}
|
||||||
|
|
||||||
|
// freePort grabs an ephemeral TCP port and immediately releases it so the
|
||||||
|
// spawned backend can bind to it. There is a tiny TOCTOU window here but in
|
||||||
|
// practice it's adequate for a smoke test on a quiet runner.
|
||||||
|
func freePort() int {
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
Expect(err).ToNot(HaveOccurred(), "freePort listen")
|
||||||
|
port := l.Addr().(*net.TCPAddr).Port
|
||||||
|
Expect(l.Close()).To(Succeed())
|
||||||
|
return port
|
||||||
|
}
|
||||||
|
|
||||||
|
// startBackend spawns the locate-anything-cpp binary on the given port and
|
||||||
|
// waits until it accepts TCP connections (up to 10s). It mirrors how main.go
|
||||||
|
// resolves the purego library: the LOCATEANYTHING_LIBRARY env var points the
|
||||||
|
// dlopen at the freshly built fallback .so, and the la_capi_* symbols are
|
||||||
|
// registered there. The returned cleanup func kills the process and reaps it.
|
||||||
|
func startBackend(port int) func() {
|
||||||
|
binary, err := filepath.Abs("./locate-anything-cpp")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
if _, err := os.Stat(binary); err != nil {
|
||||||
|
Skip(fmt.Sprintf("backend binary not built: %s (run `make locate-anything-cpp` first)", binary))
|
||||||
|
}
|
||||||
|
|
||||||
|
libPath, err := filepath.Abs("./liblocateanythingcpp-fallback.so")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
if _, err := os.Stat(libPath); err != nil {
|
||||||
|
Skip(fmt.Sprintf("fallback library not built: %s (run `make liblocateanythingcpp-fallback.so` first)", libPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||||
|
cmd := exec.Command(binary, "--addr", addr)
|
||||||
|
cmd.Env = append(os.Environ(), "LOCATEANYTHING_LIBRARY="+libPath)
|
||||||
|
cmd.Stdout = os.Stderr
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
Expect(cmd.Start()).To(Succeed())
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
if cmd.Process != nil {
|
||||||
|
_ = cmd.Process.Kill()
|
||||||
|
_, _ = cmd.Process.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(10 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
c, err := net.DialTimeout("tcp", addr, 200*time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
_ = c.Close()
|
||||||
|
return cleanup
|
||||||
|
}
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
Fail(fmt.Sprintf("backend did not become ready on %s within 10s", addr))
|
||||||
|
return func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadTestImage reads the COCO test image downloaded by test.sh and returns its
|
||||||
|
// base64-encoded content (the wire format accepted by the Detect RPC).
|
||||||
|
func loadTestImage() string {
|
||||||
|
imgPath, err := filepath.Abs("test-data/test.jpg")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
imgBytes, err := os.ReadFile(imgPath)
|
||||||
|
if err != nil {
|
||||||
|
Skip(fmt.Sprintf("test image not present: %s (run test.sh first)", imgPath))
|
||||||
|
}
|
||||||
|
return base64.StdEncoding.EncodeToString(imgBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dialBackend opens a gRPC client connection to the spawned backend.
|
||||||
|
func dialBackend(port int) (pb.BackendClient, func()) {
|
||||||
|
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||||
|
conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
return pb.NewBackendClient(conn), func() { _ = conn.Close() }
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelPathOrSkip resolves the model file under ./test-models/ and Skip()s the
|
||||||
|
// current spec if it's missing (the ~6.3 GB GGUF is not present on a fresh
|
||||||
|
// checkout / on CI runners without the download).
|
||||||
|
func modelPathOrSkip(name string) string {
|
||||||
|
modelDir, err := filepath.Abs("test-models")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
modelPath := filepath.Join(modelDir, name)
|
||||||
|
if _, err := os.Stat(modelPath); err != nil {
|
||||||
|
Skip(fmt.Sprintf("model not present: %s (run test.sh first)", modelPath))
|
||||||
|
}
|
||||||
|
return modelPath
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = Describe("locate-anything-cpp backend", func() {
|
||||||
|
It("runs open-vocabulary detection against a known-good COCO image", func() {
|
||||||
|
modelPath := modelPathOrSkip("locate-anything-q8_0.gguf")
|
||||||
|
imgB64 := loadTestImage()
|
||||||
|
|
||||||
|
port := freePort()
|
||||||
|
cleanup := startBackend(port)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
client, closeConn := dialBackend(port)
|
||||||
|
defer closeConn()
|
||||||
|
|
||||||
|
// The q8_0 model is ~6.3 GB and hybrid Parallel Box Decoding on CPU is
|
||||||
|
// not cheap, so give LoadModel + Detect a generous deadline.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
loadResp, err := client.LoadModel(ctx, &pb.ModelOptions{
|
||||||
|
Model: "locate-anything-q8_0.gguf",
|
||||||
|
ModelFile: modelPath,
|
||||||
|
Threads: 4,
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred(), "LoadModel")
|
||||||
|
Expect(loadResp.GetSuccess()).To(BeTrue(), "LoadModel reported failure: %s", loadResp.GetMessage())
|
||||||
|
|
||||||
|
// Open-vocabulary detection is prompt-driven; the prompt names the
|
||||||
|
// classes to locate (people + cars), separated by the </c> control token.
|
||||||
|
detResp, err := client.Detect(ctx, &pb.DetectOptions{
|
||||||
|
Src: imgB64,
|
||||||
|
Prompt: "Locate all the instances that matches the following description: person</c>car.",
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred(), "Detect")
|
||||||
|
Expect(detResp.GetDetections()).ToNot(BeEmpty(), "no detections returned on a known-good COCO image")
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(GinkgoWriter, "detection OK: %d detections\n", len(detResp.GetDetections()))
|
||||||
|
for i, d := range detResp.GetDetections() {
|
||||||
|
Expect(d.GetClassName()).ToNot(BeEmpty(), "detection %d has empty class_name", i)
|
||||||
|
Expect(d.GetWidth()).To(BeNumerically(">", float32(0)),
|
||||||
|
"detection %d has non-positive width", i)
|
||||||
|
Expect(d.GetHeight()).To(BeNumerically(">", float32(0)),
|
||||||
|
"detection %d has non-positive height", i)
|
||||||
|
_, _ = fmt.Fprintf(GinkgoWriter, " [%d] %s box=(%.1f,%.1f,%.1fx%.1f)\n",
|
||||||
|
i, d.GetClassName(), d.GetX(), d.GetY(), d.GetWidth(), d.GetHeight())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
59
backend/go/locate-anything-cpp/package.sh
Executable file
59
backend/go/locate-anything-cpp/package.sh
Executable file
@@ -0,0 +1,59 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Script to copy the appropriate libraries based on architecture
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
REPO_ROOT="${CURDIR}/../../.."
|
||||||
|
|
||||||
|
# Create lib directory
|
||||||
|
mkdir -p $CURDIR/package/lib
|
||||||
|
|
||||||
|
cp -avf $CURDIR/liblocateanythingcpp-*.so $CURDIR/package/
|
||||||
|
cp -avf $CURDIR/locate-anything-cpp $CURDIR/package/
|
||||||
|
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||||
|
|
||||||
|
# Detect architecture and copy appropriate libraries
|
||||||
|
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||||
|
# x86_64 architecture
|
||||||
|
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||||
|
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||||
|
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||||
|
# ARM64 architecture
|
||||||
|
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||||
|
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||||
|
elif [ $(uname -s) = "Darwin" ]; then
|
||||||
|
echo "Detected Darwin"
|
||||||
|
else
|
||||||
|
echo "Error: Could not detect architecture"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Package GPU libraries based on BUILD_TYPE
|
||||||
|
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||||
|
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||||
|
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||||
|
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||||
|
package_gpu_libs
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Packaging completed successfully"
|
||||||
|
ls -liah $CURDIR/package/
|
||||||
|
ls -liah $CURDIR/package/lib/
|
||||||
52
backend/go/locate-anything-cpp/run.sh
Executable file
52
backend/go/locate-anything-cpp/run.sh
Executable file
@@ -0,0 +1,52 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
# Get the absolute current dir where the script is located
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
|
||||||
|
cd /
|
||||||
|
|
||||||
|
echo "CPU info:"
|
||||||
|
if [ "$(uname)" != "Darwin" ]; then
|
||||||
|
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||||
|
grep -e "flags" /proc/cpuinfo | head -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
LIBRARY="$CURDIR/liblocateanythingcpp-fallback.so"
|
||||||
|
|
||||||
|
if [ "$(uname)" != "Darwin" ]; then
|
||||||
|
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX found OK"
|
||||||
|
if [ -e $CURDIR/liblocateanythingcpp-avx.so ]; then
|
||||||
|
LIBRARY="$CURDIR/liblocateanythingcpp-avx.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX2 found OK"
|
||||||
|
if [ -e $CURDIR/liblocateanythingcpp-avx2.so ]; then
|
||||||
|
LIBRARY="$CURDIR/liblocateanythingcpp-avx2.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check avx 512
|
||||||
|
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX512F found OK"
|
||||||
|
if [ -e $CURDIR/liblocateanythingcpp-avx512.so ]; then
|
||||||
|
LIBRARY="$CURDIR/liblocateanythingcpp-avx512.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||||
|
export LOCATEANYTHING_LIBRARY=$LIBRARY
|
||||||
|
|
||||||
|
# If there is a lib/ld.so, use it
|
||||||
|
if [ -f $CURDIR/lib/ld.so ]; then
|
||||||
|
echo "Using lib/ld.so"
|
||||||
|
echo "Using library: $LIBRARY"
|
||||||
|
exec $CURDIR/lib/ld.so $CURDIR/locate-anything-cpp "$@"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Using library: $LIBRARY"
|
||||||
|
exec $CURDIR/locate-anything-cpp "$@"
|
||||||
47
backend/go/locate-anything-cpp/test.sh
Executable file
47
backend/go/locate-anything-cpp/test.sh
Executable file
@@ -0,0 +1,47 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
|
||||||
|
echo "Running locate-anything-cpp backend tests..."
|
||||||
|
|
||||||
|
# Test model from the mudler/locate-anything.cpp-gguf HuggingFace repo. This is
|
||||||
|
# the q8_0 quantization of nvidia/LocateAnything-3B (~6.3 GB), so the download
|
||||||
|
# is the slow step. It is resumed with `curl -C -` and skipped entirely if the
|
||||||
|
# file is already present.
|
||||||
|
LOCATEANYTHING_MODEL_DIR="${LOCATEANYTHING_MODEL_DIR:-$CURDIR/test-models}"
|
||||||
|
|
||||||
|
LOCATEANYTHING_MODEL_FILE="${LOCATEANYTHING_MODEL_FILE:-locate-anything-q8_0.gguf}"
|
||||||
|
LOCATEANYTHING_MODEL_URL="${LOCATEANYTHING_MODEL_URL:-https://huggingface.co/mudler/locate-anything.cpp-gguf/resolve/main/locate-anything-q8_0.gguf}"
|
||||||
|
|
||||||
|
mkdir -p "$LOCATEANYTHING_MODEL_DIR"
|
||||||
|
|
||||||
|
if [ ! -f "$LOCATEANYTHING_MODEL_DIR/$LOCATEANYTHING_MODEL_FILE" ]; then
|
||||||
|
echo "Downloading locate-anything q8_0 model (~6.3 GB, this is slow)..."
|
||||||
|
# -C - resumes a partial download so an interrupted run doesn't restart from 0.
|
||||||
|
curl -L -C - -o "$LOCATEANYTHING_MODEL_DIR/$LOCATEANYTHING_MODEL_FILE" "$LOCATEANYTHING_MODEL_URL" --progress-bar
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Use a real COCO test image (people + cars) from the upstream rf-detr.cpp repo
|
||||||
|
# (~46 KB). Open-vocabulary detection needs real content to locate, so a
|
||||||
|
# synthetic image would trivially yield zero detections.
|
||||||
|
TEST_IMAGE_DIR="$CURDIR/test-data"
|
||||||
|
TEST_IMAGE_FILE="$TEST_IMAGE_DIR/test.jpg"
|
||||||
|
TEST_IMAGE_URL="${TEST_IMAGE_URL:-https://raw.githubusercontent.com/mudler/rf-detr.cpp/main/tests/fixtures/ci/test_image.jpg}"
|
||||||
|
|
||||||
|
mkdir -p "$TEST_IMAGE_DIR"
|
||||||
|
if [ ! -f "$TEST_IMAGE_FILE" ]; then
|
||||||
|
echo "Downloading COCO test image..."
|
||||||
|
curl -L -o "$TEST_IMAGE_FILE" "$TEST_IMAGE_URL" --progress-bar
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "locate-anything-cpp test setup complete."
|
||||||
|
echo " model: $LOCATEANYTHING_MODEL_DIR/$LOCATEANYTHING_MODEL_FILE"
|
||||||
|
echo " test image: $TEST_IMAGE_FILE"
|
||||||
|
|
||||||
|
# Run the Go smoke test: spawns the backend binary on a free port, calls
|
||||||
|
# LoadModel + Detect via gRPC against the downloaded GGUF + COCO image.
|
||||||
|
echo ""
|
||||||
|
echo "Running Go smoke test..."
|
||||||
|
cd "$CURDIR"
|
||||||
|
go test -v -timeout 30m ./...
|
||||||
17
backend/go/omnivoice-cpp/.gitignore
vendored
Normal file
17
backend/go/omnivoice-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# Fetched upstream sources
|
||||||
|
sources/
|
||||||
|
|
||||||
|
# CMake build directories
|
||||||
|
build*/
|
||||||
|
|
||||||
|
# Compiled shared libraries
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Compiled backend binary
|
||||||
|
omnivoice-cpp
|
||||||
|
|
||||||
|
# Packaging output
|
||||||
|
package/
|
||||||
|
|
||||||
|
# Downloaded e2e models
|
||||||
|
omnivoice-models/
|
||||||
53
backend/go/omnivoice-cpp/CMakeLists.txt
Normal file
53
backend/go/omnivoice-cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.14)
|
||||||
|
project(gomnivoicecpp LANGUAGES C CXX)
|
||||||
|
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||||
|
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||||
|
|
||||||
|
set(OMNIVOICE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/sources/omnivoice.cpp)
|
||||||
|
|
||||||
|
# Override upstream's CMAKE_CUDA_ARCHITECTURES before add_subdirectory.
|
||||||
|
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
|
||||||
|
set(CMAKE_CUDA_ARCHITECTURES "75-virtual;80-virtual;86-real;89-real")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Add the upstream project. Its own CMakeLists adds ggml + builds
|
||||||
|
# omnivoice-core (STATIC, contains src/omnivoice.cpp i.e. the ov_* impl).
|
||||||
|
# EXCLUDE_FROM_ALL keeps its CLI tools/tests from building unless referenced.
|
||||||
|
add_subdirectory(${OMNIVOICE_DIR} omnivoice EXCLUDE_FROM_ALL)
|
||||||
|
|
||||||
|
# Upstream generates version.h into its own CMAKE_CURRENT_BINARY_DIR and adds
|
||||||
|
# the top-level ${CMAKE_BINARY_DIR} to omnivoice-core's include path. When the
|
||||||
|
# project is nested under add_subdirectory those two directories differ
|
||||||
|
# (<build>/omnivoice vs <build>), so omnivoice.cpp cannot find version.h. Point
|
||||||
|
# omnivoice-core at the subproject binary dir where version.h is actually
|
||||||
|
# generated. (Fix lives here, never in the fetched upstream checkout.)
|
||||||
|
target_include_directories(omnivoice-core PRIVATE ${CMAKE_BINARY_DIR}/omnivoice)
|
||||||
|
|
||||||
|
add_library(gomnivoicecpp MODULE cpp/gomnivoicecpp.cpp)
|
||||||
|
target_link_libraries(gomnivoicecpp PRIVATE omnivoice-core)
|
||||||
|
|
||||||
|
target_include_directories(gomnivoicecpp PRIVATE ${OMNIVOICE_DIR}/src)
|
||||||
|
target_include_directories(gomnivoicecpp SYSTEM PRIVATE ${OMNIVOICE_DIR}/ggml/include)
|
||||||
|
|
||||||
|
# Link GPU backends if the upstream ggml created them.
|
||||||
|
foreach(backend blas cuda metal vulkan sycl)
|
||||||
|
if(TARGET ggml-${backend})
|
||||||
|
target_link_libraries(gomnivoicecpp PRIVATE ggml-${backend})
|
||||||
|
if(backend STREQUAL "cuda")
|
||||||
|
find_package(CUDAToolkit QUIET)
|
||||||
|
if(CUDAToolkit_FOUND)
|
||||||
|
target_link_libraries(gomnivoicecpp PRIVATE CUDA::cudart)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
if(MSVC)
|
||||||
|
target_compile_options(gomnivoicecpp PRIVATE /W4 /wd4100 /wd4505)
|
||||||
|
else()
|
||||||
|
target_compile_options(gomnivoicecpp PRIVATE -Wall -Wextra
|
||||||
|
-Wno-unused-parameter -Wno-unused-function)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set_property(TARGET gomnivoicecpp PROPERTY CXX_STANDARD 17)
|
||||||
|
set_target_properties(gomnivoicecpp PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||||
122
backend/go/omnivoice-cpp/Makefile
Normal file
122
backend/go/omnivoice-cpp/Makefile
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
CMAKE_ARGS?=
|
||||||
|
BUILD_TYPE?=
|
||||||
|
NATIVE?=false
|
||||||
|
|
||||||
|
GOCMD?=go
|
||||||
|
GO_TAGS?=
|
||||||
|
JOBS?=$(shell nproc --ignore=1)
|
||||||
|
|
||||||
|
# omnivoice.cpp version
|
||||||
|
OMNIVOICE_REPO?=https://github.com/ServeurpersoCom/omnivoice.cpp
|
||||||
|
OMNIVOICE_VERSION?=96d30169afd5e6bb3fd6a0e9be0eb505bfe81fcd
|
||||||
|
SO_TARGET?=libgomnivoicecpp.so
|
||||||
|
|
||||||
|
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||||
|
|
||||||
|
ifeq ($(NATIVE),false)
|
||||||
|
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifeq ($(BUILD_TYPE),cublas)
|
||||||
|
CMAKE_ARGS+=-DGGML_CUDA=ON
|
||||||
|
else ifeq ($(BUILD_TYPE),openblas)
|
||||||
|
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||||
|
else ifeq ($(BUILD_TYPE),clblas)
|
||||||
|
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||||
|
else ifeq ($(BUILD_TYPE),hipblas)
|
||||||
|
CMAKE_ARGS+=-DGGML_HIPBLAS=ON
|
||||||
|
else ifeq ($(BUILD_TYPE),vulkan)
|
||||||
|
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||||
|
else ifeq ($(OS),Darwin)
|
||||||
|
ifneq ($(BUILD_TYPE),metal)
|
||||||
|
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||||
|
else
|
||||||
|
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||||
|
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||||
|
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||||
|
-DCMAKE_C_COMPILER=icx \
|
||||||
|
-DCMAKE_CXX_COMPILER=icpx \
|
||||||
|
-DGGML_SYCL_F16=ON
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||||
|
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||||
|
-DCMAKE_C_COMPILER=icx \
|
||||||
|
-DCMAKE_CXX_COMPILER=icpx
|
||||||
|
endif
|
||||||
|
|
||||||
|
sources/omnivoice.cpp:
|
||||||
|
mkdir -p sources/omnivoice.cpp
|
||||||
|
cd sources/omnivoice.cpp && \
|
||||||
|
git init && \
|
||||||
|
git remote add origin $(OMNIVOICE_REPO) && \
|
||||||
|
git fetch origin && \
|
||||||
|
git checkout $(OMNIVOICE_VERSION) && \
|
||||||
|
git submodule update --init --recursive --depth 1 --single-branch
|
||||||
|
|
||||||
|
# Detect OS
|
||||||
|
UNAME_S := $(shell uname -s)
|
||||||
|
|
||||||
|
# Only build CPU variants on Linux
|
||||||
|
ifeq ($(UNAME_S),Linux)
|
||||||
|
VARIANT_TARGETS = libgomnivoicecpp-avx.so libgomnivoicecpp-avx2.so libgomnivoicecpp-avx512.so libgomnivoicecpp-fallback.so
|
||||||
|
else
|
||||||
|
VARIANT_TARGETS = libgomnivoicecpp-fallback.so
|
||||||
|
endif
|
||||||
|
|
||||||
|
omnivoice-cpp: main.go gomnivoicecpp.go $(VARIANT_TARGETS)
|
||||||
|
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o omnivoice-cpp ./
|
||||||
|
|
||||||
|
package: omnivoice-cpp
|
||||||
|
bash package.sh
|
||||||
|
|
||||||
|
build: package
|
||||||
|
|
||||||
|
clean: purge
|
||||||
|
rm -rf libgomnivoicecpp*.so package sources/omnivoice.cpp omnivoice-cpp
|
||||||
|
|
||||||
|
purge:
|
||||||
|
rm -rf build*
|
||||||
|
|
||||||
|
.NOTPARALLEL:
|
||||||
|
|
||||||
|
ifeq ($(UNAME_S),Linux)
|
||||||
|
libgomnivoicecpp-avx.so: sources/omnivoice.cpp
|
||||||
|
$(info ${GREEN}I omnivoice-cpp build info:avx${RESET})
|
||||||
|
SO_TARGET=libgomnivoicecpp-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgomnivoicecpp-custom
|
||||||
|
rm -rf build-libgomnivoicecpp-avx.so
|
||||||
|
|
||||||
|
libgomnivoicecpp-avx2.so: sources/omnivoice.cpp
|
||||||
|
$(info ${GREEN}I omnivoice-cpp build info:avx2${RESET})
|
||||||
|
SO_TARGET=libgomnivoicecpp-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgomnivoicecpp-custom
|
||||||
|
rm -rf build-libgomnivoicecpp-avx2.so
|
||||||
|
|
||||||
|
libgomnivoicecpp-avx512.so: sources/omnivoice.cpp
|
||||||
|
$(info ${GREEN}I omnivoice-cpp build info:avx512${RESET})
|
||||||
|
SO_TARGET=libgomnivoicecpp-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgomnivoicecpp-custom
|
||||||
|
rm -rf build-libgomnivoicecpp-avx512.so
|
||||||
|
endif
|
||||||
|
|
||||||
|
libgomnivoicecpp-fallback.so: sources/omnivoice.cpp
|
||||||
|
$(info ${GREEN}I omnivoice-cpp build info:fallback${RESET})
|
||||||
|
SO_TARGET=libgomnivoicecpp-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgomnivoicecpp-custom
|
||||||
|
rm -rf build-libgomnivoicecpp-fallback.so
|
||||||
|
|
||||||
|
libgomnivoicecpp-custom: CMakeLists.txt cpp/gomnivoicecpp.cpp cpp/gomnivoicecpp.h
|
||||||
|
mkdir -p build-$(SO_TARGET) && \
|
||||||
|
cd build-$(SO_TARGET) && \
|
||||||
|
cmake .. $(CMAKE_ARGS) && \
|
||||||
|
cmake --build . --config Release -j$(JOBS) --target gomnivoicecpp && \
|
||||||
|
cd .. && \
|
||||||
|
mv build-$(SO_TARGET)/libgomnivoicecpp.so ./$(SO_TARGET)
|
||||||
|
|
||||||
|
test: omnivoice-cpp
|
||||||
|
@echo "Running omnivoice-cpp tests..."
|
||||||
|
bash test.sh
|
||||||
|
@echo "omnivoice-cpp tests completed."
|
||||||
|
|
||||||
|
all: omnivoice-cpp package
|
||||||
129
backend/go/omnivoice-cpp/audio.go
Normal file
129
backend/go/omnivoice-cpp/audio.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/go-audio/audio"
|
||||||
|
"github.com/go-audio/wav"
|
||||||
|
)
|
||||||
|
|
||||||
|
const omnivoiceSampleRate = 24000
|
||||||
|
|
||||||
|
// wavHeader24k returns a 44-byte WAV header for a streaming 24 kHz mono 16-bit
|
||||||
|
// PCM stream, with placeholder (0xFFFFFFFF) sizes since the total length is
|
||||||
|
// unknown up front. Emitted as the first chunk of TTSStream so the HTTP layer
|
||||||
|
// receives a self-describing WAV (the gRPC TTSStream path never sets Message,
|
||||||
|
// so the backend owns the header - see core/backend/tts.go:ModelTTSStream).
|
||||||
|
func wavHeader24k() []byte {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
w := func(v any) { _ = binary.Write(&buf, binary.LittleEndian, v) }
|
||||||
|
buf.WriteString("RIFF")
|
||||||
|
w(uint32(0xFFFFFFFF))
|
||||||
|
buf.WriteString("WAVE")
|
||||||
|
buf.WriteString("fmt ")
|
||||||
|
w(uint32(16)) // Subchunk1Size
|
||||||
|
w(uint16(1)) // PCM
|
||||||
|
w(uint16(1)) // mono
|
||||||
|
w(uint32(omnivoiceSampleRate)) // sample rate
|
||||||
|
w(uint32(omnivoiceSampleRate * 2)) // byte rate = SR * blockAlign
|
||||||
|
w(uint16(2)) // block align (16-bit mono)
|
||||||
|
w(uint16(16)) // bits per sample
|
||||||
|
buf.WriteString("data")
|
||||||
|
w(uint32(0xFFFFFFFF))
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// floatToPCM16LE clamps each sample to [-1,1] and encodes it as little-endian
|
||||||
|
// signed 16-bit PCM.
|
||||||
|
func floatToPCM16LE(samples []float32) []byte {
|
||||||
|
out := make([]byte, len(samples)*2)
|
||||||
|
for i, s := range samples {
|
||||||
|
if s > 1 {
|
||||||
|
s = 1
|
||||||
|
} else if s < -1 {
|
||||||
|
s = -1
|
||||||
|
}
|
||||||
|
v := int16(s * 32767)
|
||||||
|
out[i*2] = byte(v)
|
||||||
|
out[i*2+1] = byte(v >> 8)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeWAV24k writes samples as a finalized 24 kHz mono 16-bit WAV at dst.
|
||||||
|
func writeWAV24k(dst string, samples []float32) error {
|
||||||
|
f, err := os.Create(dst)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("omnivoice: create %q: %w", dst, err)
|
||||||
|
}
|
||||||
|
enc := wav.NewEncoder(f, omnivoiceSampleRate, 16, 1, 1)
|
||||||
|
ints := make([]int, len(samples))
|
||||||
|
for i, s := range samples {
|
||||||
|
if s > 1 {
|
||||||
|
s = 1
|
||||||
|
} else if s < -1 {
|
||||||
|
s = -1
|
||||||
|
}
|
||||||
|
ints[i] = int(s * 32767)
|
||||||
|
}
|
||||||
|
b := &audio.IntBuffer{
|
||||||
|
Format: &audio.Format{NumChannels: 1, SampleRate: omnivoiceSampleRate},
|
||||||
|
Data: ints,
|
||||||
|
SourceBitDepth: 16,
|
||||||
|
}
|
||||||
|
if err := enc.Write(b); err != nil {
|
||||||
|
_ = enc.Close()
|
||||||
|
_ = f.Close()
|
||||||
|
return fmt.Errorf("omnivoice: encode WAV: %w", err)
|
||||||
|
}
|
||||||
|
if err := enc.Close(); err != nil {
|
||||||
|
_ = f.Close()
|
||||||
|
return fmt.Errorf("omnivoice: finalize WAV: %w", err)
|
||||||
|
}
|
||||||
|
return f.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// readWAVAsFloat decodes a WAV file (any sample rate/channels) to a mono
|
||||||
|
// float32 slice in [-1,1] for use as reference audio. OmniVoice expects 24 kHz;
|
||||||
|
// callers should supply 24 kHz reference clips.
|
||||||
|
func readWAVAsFloat(path string) ([]float32, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("omnivoice: open ref %q: %w", path, err)
|
||||||
|
}
|
||||||
|
defer func() { _ = f.Close() }()
|
||||||
|
|
||||||
|
dec := wav.NewDecoder(f)
|
||||||
|
buf, err := dec.FullPCMBuffer()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("omnivoice: decode ref %q: %w", path, err)
|
||||||
|
}
|
||||||
|
ch := int(buf.Format.NumChannels)
|
||||||
|
if ch < 1 {
|
||||||
|
ch = 1
|
||||||
|
}
|
||||||
|
bitDepth := int(buf.SourceBitDepth)
|
||||||
|
if bitDepth == 0 {
|
||||||
|
bitDepth = 16
|
||||||
|
}
|
||||||
|
scale := float32(int64(1) << uint(bitDepth-1))
|
||||||
|
n := len(buf.Data) / ch
|
||||||
|
out := make([]float32, n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
// Downmix to mono by averaging channels.
|
||||||
|
var acc int
|
||||||
|
for c := 0; c < ch; c++ {
|
||||||
|
acc += buf.Data[i*ch+c]
|
||||||
|
}
|
||||||
|
out[i] = float32(acc) / float32(ch) / scale
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runtimeKeepAlive prevents the GC from reclaiming the reference-audio slice
|
||||||
|
// while its backing pointer is in use across the C call.
|
||||||
|
func runtimeKeepAlive(v any) { runtime.KeepAlive(v) }
|
||||||
166
backend/go/omnivoice-cpp/cpp/gomnivoicecpp.cpp
Normal file
166
backend/go/omnivoice-cpp/cpp/gomnivoicecpp.cpp
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
#include "gomnivoicecpp.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
#include "omnivoice.h"
|
||||||
|
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
|
static ov_context *g_ctx = nullptr;
|
||||||
|
|
||||||
|
static void ggml_log_cb(enum ggml_log_level level, const char *log,
|
||||||
|
void * /*data*/) {
|
||||||
|
if (!log)
|
||||||
|
return;
|
||||||
|
const char *lvl = "?????";
|
||||||
|
switch (level) {
|
||||||
|
case GGML_LOG_LEVEL_DEBUG: lvl = "DEBUG"; break;
|
||||||
|
case GGML_LOG_LEVEL_INFO: lvl = "INFO"; break;
|
||||||
|
case GGML_LOG_LEVEL_WARN: lvl = "WARN"; break;
|
||||||
|
case GGML_LOG_LEVEL_ERROR: lvl = "ERROR"; break;
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
fprintf(stderr, "[%-5s] %s", lvl, log);
|
||||||
|
fflush(stderr);
|
||||||
|
}
|
||||||
|
|
||||||
|
int omni_load(const char *model_path, const char *codec_path, int use_fa,
|
||||||
|
int clamp_fp16) {
|
||||||
|
ggml_log_set(ggml_log_cb, nullptr);
|
||||||
|
ggml_backend_load_all();
|
||||||
|
|
||||||
|
if (!model_path || model_path[0] == '\0') {
|
||||||
|
fprintf(stderr, "[omnivoice-cpp] ERROR: model_path is required\n");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
if (!codec_path || codec_path[0] == '\0') {
|
||||||
|
fprintf(stderr, "[omnivoice-cpp] ERROR: codec_path is required\n");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
ov_init_params p;
|
||||||
|
ov_init_default_params(&p);
|
||||||
|
p.model_path = model_path;
|
||||||
|
p.codec_path = codec_path;
|
||||||
|
p.use_fa = use_fa != 0;
|
||||||
|
p.clamp_fp16 = clamp_fp16 != 0;
|
||||||
|
|
||||||
|
fprintf(stderr, "[omnivoice-cpp] Loading model=%s codec=%s\n", model_path,
|
||||||
|
codec_path);
|
||||||
|
|
||||||
|
g_ctx = ov_init(&p);
|
||||||
|
if (!g_ctx) {
|
||||||
|
fprintf(stderr, "[omnivoice-cpp] FATAL: ov_init failed: %s\n",
|
||||||
|
ov_last_error());
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
fprintf(stderr, "[omnivoice-cpp] Model loaded (%s)\n", ov_version());
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill an ov_tts_params from the flat wrapper arguments.
|
||||||
|
static void fill_params(ov_tts_params *tp, const char *text, const char *lang,
|
||||||
|
const char *instruct, const float *ref_samples,
|
||||||
|
int ref_n, const char *ref_text, long long seed,
|
||||||
|
int denoise) {
|
||||||
|
ov_tts_default_params(tp);
|
||||||
|
tp->text = text ? text : "";
|
||||||
|
tp->lang = lang ? lang : "";
|
||||||
|
if (instruct && instruct[0] != '\0')
|
||||||
|
tp->instruct = instruct;
|
||||||
|
if (ref_samples && ref_n > 0) {
|
||||||
|
tp->ref_audio_24k = ref_samples;
|
||||||
|
tp->ref_n_samples = ref_n;
|
||||||
|
if (ref_text && ref_text[0] != '\0')
|
||||||
|
tp->ref_text = ref_text;
|
||||||
|
tp->denoise = denoise != 0;
|
||||||
|
}
|
||||||
|
if (seed >= 0)
|
||||||
|
tp->mg_seed = (uint64_t)seed;
|
||||||
|
}
|
||||||
|
|
||||||
|
float *omni_tts(const char *text, const char *lang, const char *instruct,
|
||||||
|
const float *ref_samples, int ref_n, const char *ref_text,
|
||||||
|
long long seed, int denoise, int *out_n) {
|
||||||
|
if (out_n)
|
||||||
|
*out_n = 0;
|
||||||
|
if (!g_ctx) {
|
||||||
|
fprintf(stderr, "[omnivoice-cpp] ERROR: model not loaded\n");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (!text || text[0] == '\0') {
|
||||||
|
fprintf(stderr, "[omnivoice-cpp] ERROR: text is required\n");
|
||||||
|
return nullptr; // omni_tts: out_n already 0
|
||||||
|
}
|
||||||
|
ov_tts_params tp;
|
||||||
|
fill_params(&tp, text, lang, instruct, ref_samples, ref_n, ref_text, seed,
|
||||||
|
denoise);
|
||||||
|
|
||||||
|
ov_audio out = {0};
|
||||||
|
enum ov_status rc = ov_synthesize(g_ctx, &tp, &out);
|
||||||
|
if (rc != OV_STATUS_OK || out.n_samples <= 0 || !out.samples) {
|
||||||
|
fprintf(stderr, "[omnivoice-cpp] ERROR: synthesize failed (rc=%d): %s\n",
|
||||||
|
(int)rc, ov_last_error());
|
||||||
|
ov_audio_free(&out);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy into a plain malloc buffer the Go side can free symmetrically via
|
||||||
|
// omni_pcm_free; then release the ov_audio-owned buffer.
|
||||||
|
size_t bytes = (size_t)out.n_samples * sizeof(float);
|
||||||
|
float *buf = (float *)malloc(bytes);
|
||||||
|
if (!buf) {
|
||||||
|
fprintf(stderr, "[omnivoice-cpp] ERROR: malloc(%zu) failed\n", bytes);
|
||||||
|
ov_audio_free(&out);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
memcpy(buf, out.samples, bytes);
|
||||||
|
if (out_n)
|
||||||
|
*out_n = out.n_samples;
|
||||||
|
ov_audio_free(&out);
|
||||||
|
return buf;
|
||||||
|
}
|
||||||
|
|
||||||
|
int omni_tts_stream(const char *text, const char *lang, const char *instruct,
|
||||||
|
const float *ref_samples, int ref_n, const char *ref_text,
|
||||||
|
long long seed, int denoise, omni_pcm_chunk_cb cb,
|
||||||
|
void *user_data) {
|
||||||
|
if (!g_ctx) {
|
||||||
|
fprintf(stderr, "[omnivoice-cpp] ERROR: model not loaded\n");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
if (!cb) {
|
||||||
|
fprintf(stderr, "[omnivoice-cpp] ERROR: stream callback is null\n");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
if (!text || text[0] == '\0') {
|
||||||
|
fprintf(stderr, "[omnivoice-cpp] ERROR: text is required\n");
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
ov_tts_params tp;
|
||||||
|
fill_params(&tp, text, lang, instruct, ref_samples, ref_n, ref_text, seed,
|
||||||
|
denoise);
|
||||||
|
// ov_audio_chunk_cb has the identical signature to omni_pcm_chunk_cb
|
||||||
|
// (bool vs int return are ABI-compatible; non-zero == true).
|
||||||
|
tp.on_chunk = (ov_audio_chunk_cb)cb;
|
||||||
|
tp.on_chunk_user_data = user_data;
|
||||||
|
|
||||||
|
ov_audio out = {0}; // stays empty in streaming mode
|
||||||
|
enum ov_status rc = ov_synthesize(g_ctx, &tp, &out);
|
||||||
|
ov_audio_free(&out);
|
||||||
|
if (rc != OV_STATUS_OK && rc != OV_STATUS_CANCELLED) {
|
||||||
|
fprintf(stderr, "[omnivoice-cpp] ERROR: stream synth failed (rc=%d): %s\n",
|
||||||
|
(int)rc, ov_last_error());
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void omni_pcm_free(float *p) { free(p); }
|
||||||
|
|
||||||
|
void omni_unload(void) {
|
||||||
|
if (g_ctx) {
|
||||||
|
ov_free(g_ctx);
|
||||||
|
g_ctx = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
38
backend/go/omnivoice-cpp/cpp/gomnivoicecpp.h
Normal file
38
backend/go/omnivoice-cpp/cpp/gomnivoicecpp.h
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
// Streaming PCM chunk callback. samples is mono float PCM at 24 kHz, valid
|
||||||
|
// only for the duration of the call. Return non-zero to continue, 0 to abort.
|
||||||
|
typedef int (*omni_pcm_chunk_cb)(const float *samples, int n_samples,
|
||||||
|
void *user_data);
|
||||||
|
|
||||||
|
// Load the LM (model_path) + codec (codec_path) GGUFs. use_fa / clamp_fp16
|
||||||
|
// map to ov_init_params. Returns 0 on success, non-zero on failure.
|
||||||
|
int omni_load(const char *model_path, const char *codec_path, int use_fa,
|
||||||
|
int clamp_fp16);
|
||||||
|
|
||||||
|
// Synthesize to a malloc'd float PCM buffer (caller frees via omni_pcm_free).
|
||||||
|
// ref_samples != null && ref_n > 0 => voice cloning (ref_text optional).
|
||||||
|
// instruct != null && non-empty => voice design. seed < 0 keeps the default
|
||||||
|
// MaskGIT seed. denoise toggles the <|denoise|> marker (only with a reference).
|
||||||
|
// Writes the sample count to *out_n. Returns NULL on failure (out_n set to 0).
|
||||||
|
float *omni_tts(const char *text, const char *lang, const char *instruct,
|
||||||
|
const float *ref_samples, int ref_n, const char *ref_text,
|
||||||
|
long long seed, int denoise, int *out_n);
|
||||||
|
|
||||||
|
// Streaming synthesis: cb is invoked per PCM chunk as audio is produced.
|
||||||
|
// Same reference/design/seed semantics as omni_tts. Returns 0 on success.
|
||||||
|
int omni_tts_stream(const char *text, const char *lang, const char *instruct,
|
||||||
|
const float *ref_samples, int ref_n, const char *ref_text,
|
||||||
|
long long seed, int denoise, omni_pcm_chunk_cb cb,
|
||||||
|
void *user_data);
|
||||||
|
|
||||||
|
// Free a buffer returned by omni_tts.
|
||||||
|
void omni_pcm_free(float *p);
|
||||||
|
|
||||||
|
// Release the OmniVoice context.
|
||||||
|
void omni_unload(void);
|
||||||
|
}
|
||||||
74
backend/go/omnivoice-cpp/e2e_test.go
Normal file
74
backend/go/omnivoice-cpp/e2e_test.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ttsReq(text, voice string, lang *string, dst string) *pb.TTSRequest {
|
||||||
|
return &pb.TTSRequest{Text: text, Voice: voice, Language: lang, Dst: dst}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = Describe("OmniVoice e2e", Label("e2e"), func() {
|
||||||
|
var loaded bool
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
modelPath := os.Getenv("OMNIVOICE_MODEL")
|
||||||
|
codecPath := os.Getenv("OMNIVOICE_CODEC")
|
||||||
|
if modelPath == "" || codecPath == "" {
|
||||||
|
Skip("OMNIVOICE_MODEL / OMNIVOICE_CODEC not set; skipping e2e")
|
||||||
|
}
|
||||||
|
if !loaded {
|
||||||
|
lib := os.Getenv("OMNIVOICE_LIBRARY")
|
||||||
|
if lib == "" {
|
||||||
|
lib = "./libgomnivoicecpp-fallback.so"
|
||||||
|
}
|
||||||
|
h, err := purego.Dlopen(lib, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
purego.RegisterLibFunc(&CppLoad, h, "omni_load")
|
||||||
|
purego.RegisterLibFunc(&CppTTS, h, "omni_tts")
|
||||||
|
purego.RegisterLibFunc(&CppTTSStream, h, "omni_tts_stream")
|
||||||
|
purego.RegisterLibFunc(&CppPCMFree, h, "omni_pcm_free")
|
||||||
|
purego.RegisterLibFunc(&CppUnload, h, "omni_unload")
|
||||||
|
Expect(CppLoad(modelPath, codecPath, 0, 0)).To(Equal(0))
|
||||||
|
loaded = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("synthesizes a WAV file via TTS", func() {
|
||||||
|
b := &OmnivoiceCpp{opts: loadOptions{seed: 42, denoise: true}}
|
||||||
|
dst := GinkgoT().TempDir() + "/out.wav"
|
||||||
|
lang := "en"
|
||||||
|
err := b.TTS(ttsReq("Hello world.", "", &lang, dst))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
fi, err := os.Stat(dst)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(fi.Size()).To(BeNumerically(">", int64(44)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("streams audio chunks via TTSStream", func() {
|
||||||
|
b := &OmnivoiceCpp{opts: loadOptions{seed: 42, denoise: true}}
|
||||||
|
results := make(chan []byte, 1024)
|
||||||
|
lang := "en"
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() { done <- b.TTSStream(ttsReq("Hello there, streaming test.", "", &lang, ""), results) }()
|
||||||
|
|
||||||
|
var chunks int
|
||||||
|
var first []byte
|
||||||
|
for c := range results {
|
||||||
|
if chunks == 0 {
|
||||||
|
first = c
|
||||||
|
}
|
||||||
|
chunks++
|
||||||
|
}
|
||||||
|
Expect(<-done).ToNot(HaveOccurred())
|
||||||
|
Expect(chunks).To(BeNumerically(">=", 2))
|
||||||
|
Expect(string(first[0:4])).To(Equal("RIFF"))
|
||||||
|
Expect(strings.HasPrefix(string(first[8:12]), "WAVE")).To(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
246
backend/go/omnivoice-cpp/gomnivoicecpp.go
Normal file
246
backend/go/omnivoice-cpp/gomnivoicecpp.go
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// omni_load(model_path, codec_path, use_fa, clamp_fp16) int
|
||||||
|
CppLoad func(modelPath, codecPath string, useFA, clampFP16 int) int
|
||||||
|
// omni_tts(text, lang, instruct, ref_samples, ref_n, ref_text, seed, denoise, out_n) -> float* (uintptr)
|
||||||
|
CppTTS func(text, lang, instruct string, refSamples unsafe.Pointer, refN int,
|
||||||
|
refText string, seed int64, denoise int, outN unsafe.Pointer) uintptr
|
||||||
|
// omni_tts_stream(text, lang, instruct, ref_samples, ref_n, ref_text, seed, denoise, cb, user) int
|
||||||
|
CppTTSStream func(text, lang, instruct string, refSamples unsafe.Pointer, refN int,
|
||||||
|
refText string, seed int64, denoise int, cb uintptr, user uintptr) int
|
||||||
|
CppPCMFree func(ptr uintptr)
|
||||||
|
CppUnload func()
|
||||||
|
)
|
||||||
|
|
||||||
|
type OmnivoiceCpp struct {
|
||||||
|
base.SingleThread
|
||||||
|
opts loadOptions
|
||||||
|
// audioPath is the model-config reference voice (tts.audio_path), used as
|
||||||
|
// the default voice-cloning reference when a request does not set Voice.
|
||||||
|
audioPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OmnivoiceCpp) Load(opts *pb.ModelOptions) error {
|
||||||
|
model := opts.ModelFile
|
||||||
|
if model == "" {
|
||||||
|
model = opts.ModelPath
|
||||||
|
}
|
||||||
|
if !filepath.IsAbs(model) && opts.ModelPath != "" {
|
||||||
|
model = filepath.Join(opts.ModelPath, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
o.opts = parseOptions(opts.Options)
|
||||||
|
|
||||||
|
// Resolve the codec/tokenizer GGUF: explicit option, else auto-discover a
|
||||||
|
// *tokenizer*.gguf sibling of the base model.
|
||||||
|
codec := o.opts.codecPath
|
||||||
|
if codec != "" && !filepath.IsAbs(codec) {
|
||||||
|
codec = filepath.Join(filepath.Dir(model), codec)
|
||||||
|
}
|
||||||
|
if codec == "" {
|
||||||
|
codec = discoverTokenizer(filepath.Dir(model))
|
||||||
|
}
|
||||||
|
if codec == "" {
|
||||||
|
return fmt.Errorf("omnivoice: no codec/tokenizer GGUF found; set option 'tokenizer:<file>'")
|
||||||
|
}
|
||||||
|
o.opts.codecPath = codec
|
||||||
|
|
||||||
|
// tts.audio_path (ModelOptions.AudioPath) is the config-level voice-cloning
|
||||||
|
// reference: a default reference WAV used when a request omits Voice.
|
||||||
|
// Resolved relative to the model directory like the codec.
|
||||||
|
o.audioPath = opts.AudioPath
|
||||||
|
if o.audioPath != "" && !filepath.IsAbs(o.audioPath) {
|
||||||
|
o.audioPath = filepath.Join(filepath.Dir(model), o.audioPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
useFA := boolToInt(o.opts.useFA)
|
||||||
|
clamp := boolToInt(o.opts.clampFP16)
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "[omnivoice-cpp] Load model=%s codec=%s use_fa=%d clamp_fp16=%d\n",
|
||||||
|
model, codec, useFA, clamp)
|
||||||
|
|
||||||
|
if rc := CppLoad(model, codec, useFA, clamp); rc != 0 {
|
||||||
|
return fmt.Errorf("omnivoice: failed to load model (rc=%d)", rc)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// discoverTokenizer returns the first *tokenizer*.gguf in dir, or "".
|
||||||
|
func discoverTokenizer(dir string) string {
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, e := range entries {
|
||||||
|
name := strings.ToLower(e.Name())
|
||||||
|
if strings.Contains(name, "tokenizer") && strings.HasSuffix(name, ".gguf") {
|
||||||
|
return filepath.Join(dir, e.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolToInt(b bool) int {
|
||||||
|
if b {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// refAudio loads the reference WAV (voice cloning) if voice points to a file.
|
||||||
|
// Returns nil if no cloning (empty or non-path - voice design uses Instructions).
|
||||||
|
func (o *OmnivoiceCpp) refAudio(voice string) ([]float32, error) {
|
||||||
|
v := strings.TrimSpace(voice)
|
||||||
|
if v == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(v); err != nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return readWAVAsFloat(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// refAudioFor resolves the cloning reference for a request: the per-request
|
||||||
|
// Voice takes precedence, falling back to the model-config audio_path. Empty
|
||||||
|
// result means no cloning (voice design via Instructions still applies).
|
||||||
|
func (o *OmnivoiceCpp) refAudioFor(req *pb.TTSRequest) ([]float32, error) {
|
||||||
|
voice := strings.TrimSpace(req.Voice)
|
||||||
|
if voice == "" {
|
||||||
|
voice = o.audioPath
|
||||||
|
}
|
||||||
|
return o.refAudio(voice)
|
||||||
|
}
|
||||||
|
|
||||||
|
func reqParam(req *pb.TTSRequest, key string) string {
|
||||||
|
if req.Params == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return req.Params[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OmnivoiceCpp) seedFor(req *pb.TTSRequest) int64 {
|
||||||
|
if s := reqParam(req, "seed"); s != "" {
|
||||||
|
var n int64
|
||||||
|
if _, err := fmt.Sscan(s, &n); err == nil {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return o.opts.seed
|
||||||
|
}
|
||||||
|
|
||||||
|
func optStr(p *string) string {
|
||||||
|
if p == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OmnivoiceCpp) TTS(req *pb.TTSRequest) error {
|
||||||
|
if req.Dst == "" {
|
||||||
|
return fmt.Errorf("omnivoice: TTS requires a destination path")
|
||||||
|
}
|
||||||
|
lang := normalizeLanguage(optStr(req.Language))
|
||||||
|
instruct := optStr(req.Instructions)
|
||||||
|
refText := reqParam(req, "ref_text")
|
||||||
|
seed := o.seedFor(req)
|
||||||
|
|
||||||
|
ref, err := o.refAudioFor(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var refPtr unsafe.Pointer
|
||||||
|
if len(ref) > 0 {
|
||||||
|
refPtr = unsafe.Pointer(&ref[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int32
|
||||||
|
ptr := CppTTS(req.Text, lang, instruct, refPtr, len(ref), refText, seed,
|
||||||
|
boolToInt(o.opts.denoise), unsafe.Pointer(&n))
|
||||||
|
runtimeKeepAlive(ref)
|
||||||
|
if ptr == 0 || n <= 0 {
|
||||||
|
return fmt.Errorf("omnivoice: synthesis failed")
|
||||||
|
}
|
||||||
|
defer CppPCMFree(ptr)
|
||||||
|
src := unsafe.Slice((*float32)(unsafe.Pointer(ptr)), int(n)) //nolint:govet // C-allocated PCM, copied out before free
|
||||||
|
out := make([]float32, int(n))
|
||||||
|
copy(out, src)
|
||||||
|
return writeWAV24k(req.Dst, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// streamState carries the active TTSStream channel to the single shared C
|
||||||
|
// callback. base.SingleThread serializes TTS/TTSStream, so one global slot is
|
||||||
|
// safe and avoids leaking a purego callback per request (purego callbacks
|
||||||
|
// cannot be freed and are capped).
|
||||||
|
var (
|
||||||
|
streamMu sync.Mutex
|
||||||
|
streamChan chan []byte
|
||||||
|
streamCbOnce sync.Once
|
||||||
|
streamCbPtr uintptr
|
||||||
|
)
|
||||||
|
|
||||||
|
// streamCallback is registered once and forwards each PCM chunk to streamChan.
|
||||||
|
func streamCallback(samples *float32, nSamples int32, _ uintptr) uintptr {
|
||||||
|
if nSamples <= 0 || samples == nil || streamChan == nil {
|
||||||
|
return 1 // continue
|
||||||
|
}
|
||||||
|
src := unsafe.Slice(samples, int(nSamples))
|
||||||
|
cp := make([]float32, int(nSamples)) // copy out of C memory before returning
|
||||||
|
copy(cp, src)
|
||||||
|
streamChan <- floatToPCM16LE(cp)
|
||||||
|
return 1 // continue
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OmnivoiceCpp) TTSStream(req *pb.TTSRequest, results chan []byte) error {
|
||||||
|
defer close(results)
|
||||||
|
if req.Text == "" {
|
||||||
|
return fmt.Errorf("omnivoice: TTSStream requires text")
|
||||||
|
}
|
||||||
|
|
||||||
|
streamCbOnce.Do(func() {
|
||||||
|
streamCbPtr = purego.NewCallback(streamCallback)
|
||||||
|
})
|
||||||
|
|
||||||
|
lang := normalizeLanguage(optStr(req.Language))
|
||||||
|
instruct := optStr(req.Instructions)
|
||||||
|
refText := reqParam(req, "ref_text")
|
||||||
|
seed := o.seedFor(req)
|
||||||
|
|
||||||
|
ref, err := o.refAudioFor(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var refPtr unsafe.Pointer
|
||||||
|
if len(ref) > 0 {
|
||||||
|
refPtr = unsafe.Pointer(&ref[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit the WAV header first so the HTTP layer gets a self-describing stream.
|
||||||
|
results <- wavHeader24k()
|
||||||
|
|
||||||
|
streamMu.Lock()
|
||||||
|
streamChan = results
|
||||||
|
rc := CppTTSStream(req.Text, lang, instruct, refPtr, len(ref), refText, seed,
|
||||||
|
boolToInt(o.opts.denoise), streamCbPtr, 0)
|
||||||
|
streamChan = nil
|
||||||
|
streamMu.Unlock()
|
||||||
|
runtimeKeepAlive(ref)
|
||||||
|
|
||||||
|
if rc != 0 {
|
||||||
|
return fmt.Errorf("omnivoice: streaming synthesis failed (rc=%d)", rc)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
90
backend/go/omnivoice-cpp/gomnivoicecpp_test.go
Normal file
90
backend/go/omnivoice-cpp/gomnivoicecpp_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOmnivoiceCpp(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "omnivoice-cpp suite")
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = Describe("normalizeLanguage", func() {
|
||||||
|
DescribeTable("maps caller language to OmniVoice codes",
|
||||||
|
func(in, want string) {
|
||||||
|
Expect(normalizeLanguage(in)).To(Equal(want))
|
||||||
|
},
|
||||||
|
Entry("empty stays empty", "", ""),
|
||||||
|
Entry("english full name", "English", "en"),
|
||||||
|
Entry("chinese full name", "Chinese", "zh"),
|
||||||
|
Entry("locale suffix stripped", "en-US", "en"),
|
||||||
|
Entry("underscore locale", "zh_CN", "zh"),
|
||||||
|
Entry("already a code", "en", "en"),
|
||||||
|
Entry("unknown passes through normalized", "xx", "xx"),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
var _ = Describe("parseOptions", func() {
|
||||||
|
It("extracts codec, use_fa, clamp_fp16, seed, denoise", func() {
|
||||||
|
o := parseOptions([]string{
|
||||||
|
"tokenizer:tok.gguf",
|
||||||
|
"use_fa:true",
|
||||||
|
"clamp_fp16:true",
|
||||||
|
"seed:7",
|
||||||
|
"denoise:false",
|
||||||
|
"unknown:ignored",
|
||||||
|
})
|
||||||
|
Expect(o.codecPath).To(Equal("tok.gguf"))
|
||||||
|
Expect(o.useFA).To(BeTrue())
|
||||||
|
Expect(o.clampFP16).To(BeTrue())
|
||||||
|
Expect(o.seed).To(Equal(int64(7)))
|
||||||
|
Expect(o.denoise).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("accepts codec: as an alias for tokenizer:", func() {
|
||||||
|
o := parseOptions([]string{"codec:c.gguf"})
|
||||||
|
Expect(o.codecPath).To(Equal("c.gguf"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("defaults seed to -1 and denoise to true", func() {
|
||||||
|
o := parseOptions(nil)
|
||||||
|
Expect(o.seed).To(Equal(int64(-1)))
|
||||||
|
Expect(o.denoise).To(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
var _ = Describe("wavHeader24k", func() {
|
||||||
|
It("emits a 44-byte streaming WAV header at 24 kHz mono 16-bit", func() {
|
||||||
|
h := wavHeader24k()
|
||||||
|
Expect(h).To(HaveLen(44))
|
||||||
|
Expect(string(h[0:4])).To(Equal("RIFF"))
|
||||||
|
Expect(string(h[8:12])).To(Equal("WAVE"))
|
||||||
|
Expect(string(h[12:16])).To(Equal("fmt "))
|
||||||
|
Expect(string(h[36:40])).To(Equal("data"))
|
||||||
|
var sampleRate uint32
|
||||||
|
Expect(binary.Read(bytes.NewReader(h[24:28]), binary.LittleEndian, &sampleRate)).To(Succeed())
|
||||||
|
Expect(sampleRate).To(Equal(uint32(24000)))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
var _ = Describe("floatToPCM16LE", func() {
|
||||||
|
It("clamps and converts float PCM to little-endian int16 bytes", func() {
|
||||||
|
b := floatToPCM16LE([]float32{0, 1.0, -1.0, 2.0, -2.0})
|
||||||
|
Expect(b).To(HaveLen(10)) // 5 samples * 2 bytes
|
||||||
|
read := func(off int) int16 {
|
||||||
|
var v int16
|
||||||
|
_ = binary.Read(bytes.NewReader(b[off:off+2]), binary.LittleEndian, &v)
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
Expect(read(0)).To(Equal(int16(0)))
|
||||||
|
Expect(read(2)).To(Equal(int16(32767)))
|
||||||
|
Expect(read(4)).To(Equal(int16(-32767)))
|
||||||
|
Expect(read(6)).To(Equal(int16(32767))) // clamped from 2.0
|
||||||
|
Expect(read(8)).To(Equal(int16(-32767))) // clamped from -2.0
|
||||||
|
})
|
||||||
|
})
|
||||||
48
backend/go/omnivoice-cpp/main.go
Normal file
48
backend/go/omnivoice-cpp/main.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||||
|
)
|
||||||
|
|
||||||
|
type LibFuncs struct {
|
||||||
|
FuncPtr any
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
libName := os.Getenv("OMNIVOICE_LIBRARY")
|
||||||
|
if libName == "" {
|
||||||
|
libName = "./libgomnivoicecpp-fallback.so"
|
||||||
|
}
|
||||||
|
|
||||||
|
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
libFuncs := []LibFuncs{
|
||||||
|
{&CppLoad, "omni_load"},
|
||||||
|
{&CppTTS, "omni_tts"},
|
||||||
|
{&CppTTSStream, "omni_tts_stream"},
|
||||||
|
{&CppPCMFree, "omni_pcm_free"},
|
||||||
|
{&CppUnload, "omni_unload"},
|
||||||
|
}
|
||||||
|
for _, lf := range libFuncs {
|
||||||
|
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &OmnivoiceCpp{}); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
74
backend/go/omnivoice-cpp/options.go
Normal file
74
backend/go/omnivoice-cpp/options.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// loadOptions holds the parsed model-level options for OmniVoice.
|
||||||
|
type loadOptions struct {
|
||||||
|
codecPath string
|
||||||
|
useFA bool
|
||||||
|
clampFP16 bool
|
||||||
|
seed int64
|
||||||
|
denoise bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitOption(o string) (key, value string, ok bool) {
|
||||||
|
i := strings.Index(o, ":")
|
||||||
|
if i < 0 {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(o[:i]), strings.TrimSpace(o[i+1:]), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOptions reads the backend "key:value" option slice. Unknown keys are
|
||||||
|
// ignored. Defaults: seed -1 (engine default), denoise true.
|
||||||
|
func parseOptions(opts []string) loadOptions {
|
||||||
|
o := loadOptions{seed: -1, denoise: true}
|
||||||
|
for _, oo := range opts {
|
||||||
|
key, value, ok := splitOption(oo)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch key {
|
||||||
|
case "tokenizer", "codec":
|
||||||
|
o.codecPath = value
|
||||||
|
case "use_fa":
|
||||||
|
o.useFA = value == "true" || value == "1"
|
||||||
|
case "clamp_fp16":
|
||||||
|
o.clampFP16 = value == "true" || value == "1"
|
||||||
|
case "seed":
|
||||||
|
if n, err := strconv.ParseInt(value, 10, 64); err == nil {
|
||||||
|
o.seed = n
|
||||||
|
}
|
||||||
|
case "denoise":
|
||||||
|
o.denoise = value == "true" || value == "1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return o
|
||||||
|
}
|
||||||
|
|
||||||
|
// languageNameAliases maps full language names to OmniVoice codes. OmniVoice's
|
||||||
|
// lang hint accepts "" (auto), "en", "zh" per the upstream convention; other
|
||||||
|
// codes pass through and the engine treats unknown hints as auto.
|
||||||
|
var languageNameAliases = map[string]string{
|
||||||
|
"english": "en",
|
||||||
|
"chinese": "zh",
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeLanguage lowercases, trims, strips a region/locale suffix, and
|
||||||
|
// resolves common full names. Empty stays empty so the engine auto-detects.
|
||||||
|
func normalizeLanguage(lang string) string {
|
||||||
|
lang = strings.ToLower(strings.TrimSpace(lang))
|
||||||
|
if lang == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if i := strings.IndexAny(lang, "-_."); i >= 0 {
|
||||||
|
lang = lang[:i]
|
||||||
|
}
|
||||||
|
if code, ok := languageNameAliases[lang]; ok {
|
||||||
|
return code
|
||||||
|
}
|
||||||
|
return lang
|
||||||
|
}
|
||||||
64
backend/go/omnivoice-cpp/package.sh
Executable file
64
backend/go/omnivoice-cpp/package.sh
Executable file
@@ -0,0 +1,64 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Script to copy the appropriate libraries based on architecture
|
||||||
|
# This script is used in the final stage of the Dockerfile
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
REPO_ROOT="${CURDIR}/../../.."
|
||||||
|
|
||||||
|
# Create lib directory
|
||||||
|
mkdir -p $CURDIR/package/lib
|
||||||
|
|
||||||
|
cp -avf $CURDIR/omnivoice-cpp $CURDIR/package/
|
||||||
|
cp -fv $CURDIR/libgomnivoicecpp-*.so $CURDIR/package/
|
||||||
|
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||||
|
|
||||||
|
# Detect architecture and copy appropriate libraries
|
||||||
|
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||||
|
# x86_64 architecture
|
||||||
|
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||||
|
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||||
|
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||||
|
# ARM64 architecture
|
||||||
|
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||||
|
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||||
|
elif [ $(uname -s) = "Darwin" ]; then
|
||||||
|
echo "Detected Darwin"
|
||||||
|
else
|
||||||
|
echo "Error: Could not detect architecture"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Package GPU libraries based on BUILD_TYPE
|
||||||
|
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||||
|
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||||
|
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||||
|
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||||
|
package_gpu_libs
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Packaging completed successfully"
|
||||||
|
ls -liah $CURDIR/package/
|
||||||
|
ls -liah $CURDIR/package/lib/
|
||||||
52
backend/go/omnivoice-cpp/run.sh
Executable file
52
backend/go/omnivoice-cpp/run.sh
Executable file
@@ -0,0 +1,52 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
# Get the absolute current dir where the script is located
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
|
||||||
|
cd /
|
||||||
|
|
||||||
|
echo "CPU info:"
|
||||||
|
if [ "$(uname)" != "Darwin" ]; then
|
||||||
|
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||||
|
grep -e "flags" /proc/cpuinfo | head -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
LIBRARY="$CURDIR/libgomnivoicecpp-fallback.so"
|
||||||
|
|
||||||
|
if [ "$(uname)" != "Darwin" ]; then
|
||||||
|
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX found OK"
|
||||||
|
if [ -e $CURDIR/libgomnivoicecpp-avx.so ]; then
|
||||||
|
LIBRARY="$CURDIR/libgomnivoicecpp-avx.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX2 found OK"
|
||||||
|
if [ -e $CURDIR/libgomnivoicecpp-avx2.so ]; then
|
||||||
|
LIBRARY="$CURDIR/libgomnivoicecpp-avx2.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check avx 512
|
||||||
|
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX512F found OK"
|
||||||
|
if [ -e $CURDIR/libgomnivoicecpp-avx512.so ]; then
|
||||||
|
LIBRARY="$CURDIR/libgomnivoicecpp-avx512.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||||
|
export OMNIVOICE_LIBRARY=$LIBRARY
|
||||||
|
|
||||||
|
# If there is a lib/ld.so, use it
|
||||||
|
if [ -f $CURDIR/lib/ld.so ]; then
|
||||||
|
echo "Using lib/ld.so"
|
||||||
|
echo "Using library: $LIBRARY"
|
||||||
|
exec $CURDIR/lib/ld.so $CURDIR/omnivoice-cpp "$@"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Using library: $LIBRARY"
|
||||||
|
exec $CURDIR/omnivoice-cpp "$@"
|
||||||
30
backend/go/omnivoice-cpp/test.sh
Executable file
30
backend/go/omnivoice-cpp/test.sh
Executable file
@@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
cd "$CURDIR"
|
||||||
|
|
||||||
|
echo "Running omnivoice-cpp backend tests..."
|
||||||
|
|
||||||
|
if [ -z "$OMNIVOICE_MODEL" ]; then
|
||||||
|
MODEL_DIR="./omnivoice-models"
|
||||||
|
mkdir -p "$MODEL_DIR"
|
||||||
|
REPO_ID="Serveurperso/OmniVoice-GGUF"
|
||||||
|
BASE_URL="https://huggingface.co/${REPO_ID}/resolve/main"
|
||||||
|
FILES=( "omnivoice-base-Q4_K_M.gguf" "omnivoice-tokenizer-Q4_K_M.gguf" )
|
||||||
|
for file in "${FILES[@]}"; do
|
||||||
|
dest="${MODEL_DIR}/${file}"
|
||||||
|
if [ -f "${dest}" ]; then
|
||||||
|
echo " [skip] ${file}"
|
||||||
|
else
|
||||||
|
echo " [download] ${file}..."
|
||||||
|
curl -L -o "${dest}" "${BASE_URL}/${file}" --progress-bar
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
export OMNIVOICE_MODEL="${MODEL_DIR}/omnivoice-base-Q4_K_M.gguf"
|
||||||
|
export OMNIVOICE_CODEC="${MODEL_DIR}/omnivoice-tokenizer-Q4_K_M.gguf"
|
||||||
|
fi
|
||||||
|
|
||||||
|
go test -v -timeout 1200s .
|
||||||
|
|
||||||
|
echo "All omnivoice-cpp e2e tests passed."
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
# parakeet-cpp backend Makefile.
|
# parakeet-cpp backend Makefile.
|
||||||
#
|
#
|
||||||
# Upstream pin lives below as PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
|
# Upstream pin lives below as PARAKEET_VERSION?=db755a78d39f789bb7d4e3935158a9e8105dbe36
|
||||||
# (.github/bump_deps.sh) can find and update it - matches the
|
# (.github/bump_deps.sh) can find and update it - matches the
|
||||||
# whisper.cpp / ds4 / vibevoice-cpp convention.
|
# whisper.cpp / ds4 / vibevoice-cpp convention.
|
||||||
#
|
#
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
# That's what the L0 smoke test uses. The default target below does the
|
# That's what the L0 smoke test uses. The default target below does the
|
||||||
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
|
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
|
||||||
|
|
||||||
PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
|
PARAKEET_VERSION?=db755a78d39f789bb7d4e3935158a9e8105dbe36
|
||||||
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
|
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
|
||||||
|
|
||||||
GOCMD?=go
|
GOCMD?=go
|
||||||
@@ -39,7 +39,10 @@ endif
|
|||||||
# is overwritten back to OFF and the build silently falls back to CPU. Forward the
|
# is overwritten back to OFF and the build silently falls back to CPU. Forward the
|
||||||
# PARAKEET_GGML_* options instead. (openblas is not gated, so -DGGML_BLAS passes through.)
|
# PARAKEET_GGML_* options instead. (openblas is not gated, so -DGGML_BLAS passes through.)
|
||||||
ifeq ($(BUILD_TYPE),cublas)
|
ifeq ($(BUILD_TYPE),cublas)
|
||||||
CMAKE_ARGS+=-DPARAKEET_GGML_CUDA=ON
|
# GGML_CUDA_GRAPHS is OFF by ggml default; enabling it gives a small free
|
||||||
|
# speedup (~1% measured on GB10, never negative) by capturing/replaying the
|
||||||
|
# CUDA graph. Not gated by parakeet.cpp, so it passes straight through to ggml.
|
||||||
|
CMAKE_ARGS+=-DPARAKEET_GGML_CUDA=ON -DGGML_CUDA_GRAPHS=ON
|
||||||
else ifeq ($(BUILD_TYPE),openblas)
|
else ifeq ($(BUILD_TYPE),openblas)
|
||||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||||
else ifeq ($(BUILD_TYPE),hipblas)
|
else ifeq ($(BUILD_TYPE),hipblas)
|
||||||
|
|||||||
@@ -98,17 +98,21 @@ type transcriptJSON struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// streamFeedJSON mirrors the document returned by
|
// streamFeedJSON mirrors the document returned by
|
||||||
// parakeet_capi_stream_feed_json / parakeet_capi_stream_finalize_json (ABI v4):
|
// parakeet_capi_stream_feed_json / parakeet_capi_stream_finalize_json (ABI v5):
|
||||||
//
|
//
|
||||||
// {"text":"...","eou":0,"frame_sec":0.080000,
|
// {"text":"...","eou":0,"eob":0,"frame_sec":0.080000,
|
||||||
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...]}
|
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...]}
|
||||||
//
|
//
|
||||||
// "text" is the newly-finalized text since the last call; "eou" is 1 when an
|
// "text" is the newly-finalized text since the last call; "eou" is 1 when an
|
||||||
// <EOU>/<EOB> fired this feed; "words" are the words finalized this call with
|
// <EOU> (end of utterance) fired this feed and "eob" is 1 when an <EOB>
|
||||||
// absolute (stream-relative) start/end seconds.
|
// (backchannel) fired. ABI v4 conflated the two into "eou"; v5 split them, so
|
||||||
|
// we read both and treat either as an utterance boundary for segmentation.
|
||||||
|
// "words" are the words finalized this call with absolute (stream-relative)
|
||||||
|
// start/end seconds.
|
||||||
type streamFeedJSON struct {
|
type streamFeedJSON struct {
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
Eou int `json:"eou"`
|
Eou int `json:"eou"`
|
||||||
|
Eob int `json:"eob"`
|
||||||
FrameSec float64 `json:"frame_sec"`
|
FrameSec float64 `json:"frame_sec"`
|
||||||
Words []transcriptWord `json:"words"`
|
Words []transcriptWord `json:"words"`
|
||||||
}
|
}
|
||||||
@@ -483,7 +487,10 @@ type streamSegmenter struct {
|
|||||||
|
|
||||||
func (s *streamSegmenter) add(doc streamFeedJSON) {
|
func (s *streamSegmenter) add(doc streamFeedJSON) {
|
||||||
s.cur = append(s.cur, doc.Words...)
|
s.cur = append(s.cur, doc.Words...)
|
||||||
if doc.Eou != 0 {
|
// Close the segment on either turn signal: <EOU> (end of utterance) or
|
||||||
|
// <EOB> (backchannel). ABI v4 reported both via "eou"; v5 split them, so we
|
||||||
|
// OR them here to keep the v4 segmentation boundaries.
|
||||||
|
if doc.Eou != 0 || doc.Eob != 0 {
|
||||||
s.flush()
|
s.flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -671,11 +678,12 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// streamJSON drives the ABI v4 streaming JSON entry points: each feed/finalize
|
// streamJSON drives the streaming JSON entry points (present since ABI v4): each
|
||||||
// returns a {text,eou,frame_sec,words} document. The newly-finalized text is
|
// feed/finalize returns a {text,eou,eob,frame_sec,words} document. The
|
||||||
// emitted as a delta (unchanged streaming contract) while words are accumulated
|
// newly-finalized text is emitted as a delta (unchanged streaming contract)
|
||||||
// into per-utterance segments (closed on EOU) so the closing FinalResult carries
|
// while words are accumulated into per-utterance segments (closed on <EOU> or
|
||||||
// timestamped segments. Runs under engineMu (already held by the caller).
|
// <EOB>) so the closing FinalResult carries timestamped segments. Runs under
|
||||||
|
// engineMu (already held by the caller).
|
||||||
func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []float32,
|
func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []float32,
|
||||||
duration float32, results chan *pb.TranscriptStreamResponse) error {
|
duration float32, results chan *pb.TranscriptStreamResponse) error {
|
||||||
var (
|
var (
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user