mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-12 10:47:23 -04:00
Compare commits
83 Commits
feat/p2p-f
...
feat/dllm-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b40843cf62 | ||
|
|
c9c6040fe8 | ||
|
|
8134d6db37 | ||
|
|
ad6d1dbc8b | ||
|
|
eb61e1d770 | ||
|
|
aba9c4794a | ||
|
|
04d6f66a9a | ||
|
|
52b3b68cea | ||
|
|
99184809fa | ||
|
|
294c04ae2f | ||
|
|
778f85c2a0 | ||
|
|
af0db1419c | ||
|
|
892fc49949 | ||
|
|
228a6dfe79 | ||
|
|
51a92b6093 | ||
|
|
b5964d385d | ||
|
|
fba8c9c498 | ||
|
|
6b2badb837 | ||
|
|
8b8506d01a | ||
|
|
6910a0bb48 | ||
|
|
cffd03b522 | ||
|
|
bf448d3794 | ||
|
|
1d4a12f7c0 | ||
|
|
186d62801d | ||
|
|
da4ed05429 | ||
|
|
ec1eea4f45 | ||
|
|
b203b32e57 | ||
|
|
48a8ce98aa | ||
|
|
8344d1c865 | ||
|
|
d2e6b93369 | ||
|
|
e1ec03d33f | ||
|
|
9323f4b5ca | ||
|
|
c20225fc13 | ||
|
|
337acc4c37 | ||
|
|
618e90cd13 | ||
|
|
92dea961c2 | ||
|
|
2e93186043 | ||
|
|
d07037e817 | ||
|
|
f6cc90d258 | ||
|
|
2c804bef5a | ||
|
|
6070402477 | ||
|
|
67f80a152b | ||
|
|
a7cb587d96 | ||
|
|
f7c74ad2da | ||
|
|
7402d1fd20 | ||
|
|
8c42695ef8 | ||
|
|
72e3241431 | ||
|
|
cd2bf95862 | ||
|
|
f64b72dd7d | ||
|
|
03c84cff28 | ||
|
|
9bc69c9e5f | ||
|
|
1e6c9cfd60 | ||
|
|
0e6712f734 | ||
|
|
0e4cee9a97 | ||
|
|
352b7ec604 | ||
|
|
ba706422fb | ||
|
|
e837921c2c | ||
|
|
73385713ca | ||
|
|
a4e671779a | ||
|
|
7051b2e0a1 | ||
|
|
469737101a | ||
|
|
858257eaf0 | ||
|
|
ef80a0e825 | ||
|
|
92726f7631 | ||
|
|
994063ba9a | ||
|
|
c1a55cf72d | ||
|
|
96758841d8 | ||
|
|
7a59260621 | ||
|
|
27e63b9a78 | ||
|
|
55c0911c23 | ||
|
|
f6cb6ab6d9 | ||
|
|
9f11b09c6a | ||
|
|
a5c4f822f0 | ||
|
|
fb36c262fe | ||
|
|
0e4e8980e6 | ||
|
|
3a932a9803 | ||
|
|
9d10418593 | ||
|
|
5470051d4d | ||
|
|
68c5eeebc3 | ||
|
|
1531fabe23 | ||
|
|
b7673d5b76 | ||
|
|
b64bdaf406 | ||
|
|
eebf08ff1d |
138
.agents/dllm-backend.md
Normal file
138
.agents/dllm-backend.md
Normal file
@@ -0,0 +1,138 @@
|
||||
# Working on the dllm Backend
|
||||
|
||||
`mudler/dllm.cpp` is a standalone C++/ggml engine for DiffusionGemma
|
||||
block-diffusion models. LocalAI wraps it with a **pure-Go** backend at
|
||||
`backend/go/dllm/` that dlopens `libdllm.so` via purego (ebitengine/purego) -
|
||||
NOT cgo, and NOT a C++ grpc-server fork. The Go side owns chat templating
|
||||
(gemma4 renderer) and output parsing (gemma4 streaming parser) and implements
|
||||
the rich gRPC interface (`PredictRich`/`PredictStreamRich`, ChatDelta replies).
|
||||
|
||||
> NOTE: github.com/mudler/dllm.cpp is still **private** (publishing is
|
||||
> planned). Until then the Makefile's anonymous clone fails; use the local-dev
|
||||
> symlink shortcut documented at the top of `backend/go/dllm/Makefile`
|
||||
> (symlink an out-of-tree `build/libdllm.so` into the backend dir and skip the
|
||||
> clone), or a git credential helper with repo access.
|
||||
|
||||
## Pin
|
||||
|
||||
`backend/go/dllm/Makefile` pins `DLLM_VERSION?=<sha>` at the top
|
||||
(whisper / parakeet-cpp / ds4 convention). The bump-deps bot
|
||||
(`.github/workflows/bump_deps.yaml`) tracks `mudler/dllm.cpp` `main` and
|
||||
rewrites that variable. After a manual bump: `make -C backend/go/dllm purge &&
|
||||
make -C backend/go/dllm` (the clone is keyed on the directory existing, not
|
||||
the sha).
|
||||
|
||||
## C-ABI and the serialization contract
|
||||
|
||||
The binding covers the 9-symbol flat C-ABI from dllm.cpp's
|
||||
`include/dllm_capi.h` (ABI v1; `main.go` hard-fails on a version mismatch):
|
||||
`abi_version, load, free, last_error, free_string, tokenize_json, generate,
|
||||
generate_stream, cancel`. Contract points the Go wiring encodes (`capi.go`
|
||||
header comment has the full list):
|
||||
|
||||
- **One ctx = one concurrent generate/tokenize.** A per-model worker
|
||||
goroutine (`Dllm.jobs` in `dllm.go`) owns ALL C calls, making the
|
||||
serialization structural instead of lock discipline.
|
||||
- **`dllm_capi_cancel` is the ONE exception**: it only flips an atomic and may
|
||||
be called from any goroutine mid-generate, so `Dllm.Cancel` bypasses the
|
||||
worker queue. The flag resets at the start of each generate, so a watchdog
|
||||
racing a new generate must re-issue cancel.
|
||||
- **`last_error` is a borrowed pointer** and must only be read AFTER the
|
||||
failing call returned (never while a generate is in flight on the same ctx).
|
||||
- **Free vs in-flight requests**: requests hold `genMu.RLock` for their full
|
||||
duration; `Free` takes the write lock, so it only runs when nothing is in
|
||||
flight, then drains and closes the worker. Post-Free requests get a clean
|
||||
"model not loaded" error.
|
||||
- `tokenize_json`/`generate` return malloc'd `char*` (bound as `uintptr`,
|
||||
copied, then `dllm_capi_free_string`d); opts/params JSON must be a FLAT
|
||||
object of scalars (`buildOptsJSON` rejects anything else).
|
||||
|
||||
## Wire shape
|
||||
|
||||
| RPC | Implementation |
|
||||
|---|---|
|
||||
| LoadModel | `dllm_capi_load` (params: `n_gpu_layers`, `n_threads`, `ctx_len`); `Options[]` parsed into per-request gen opts (`eb_*`, `blocks`, `kv_cache`) by `parseModelGenOpts` |
|
||||
| PredictRich | render (if templated) → `dllm_capi_generate` → parse → ONE Reply with aggregated ChatDeltas + legacy `Message` bytes |
|
||||
| PredictStreamRich | `dllm_capi_generate_stream`; per committed diffusion block → UTF-8 holdback → parser.Feed → one Reply per non-empty delta batch (channel closed by the CALLER, per `pkg/grpc/interface.go`) |
|
||||
| Predict / PredictStream | Legacy paths, delegate to the rich pair (legacy stream INVERTS channel ownership: the impl closes) |
|
||||
| TokenizeString | `dllm_capi_tokenize_json` (C side prepends BOS per `vocab.add_bos`) |
|
||||
| Cancel | `dllm_capi_cancel`, exposed as the `grpc.Cancellable` capability (`pkg/grpc/interface.go`): the gRPC server arms it via `context.AfterFunc` on the Predict/PredictStream context, so client disconnects/timeouts abort the in-flight generate - llama.cpp `IsCancelled()` parity for Go backends |
|
||||
|
||||
`n_threads` and `ctx_len` are accepted-but-ignored by the engine at the
|
||||
current pin (the context bound comes from GGUF `n_ctx_train`); they are sent
|
||||
for forward compatibility.
|
||||
|
||||
## Renderer / parser (the templated chat path)
|
||||
|
||||
With `use_tokenizer_template` + raw Messages, the backend owns templating and
|
||||
parsing (the ds4 precedent, but in Go):
|
||||
|
||||
- `gemma4_renderer.go` - `RenderGemma4(msgs, toolsJSON, enableThinking,
|
||||
addGenerationPrompt)`. The file embeds the FULL `tokenizer.chat_template`
|
||||
jinja (17466 bytes, md5 `8c34cf93c7a7815b3fdb300a009c4c17`) extracted
|
||||
verbatim from `diffusiongemma-26B-A4B-it-BF16.gguf` via gguf-py - e.g.
|
||||
`python scripts/dump_gguf.py model.gguf | grep -A400 chat_template` in the
|
||||
dllm.cpp checkout - as a numbered comment block; every Go rule cites its
|
||||
"tpl L<n>" line. Re-verify the md5 before blaming the renderer for a
|
||||
mismatch with a new GGUF. **BOS exception**: the template emits
|
||||
`{{- bos_token -}}` but the renderer deliberately does NOT - dllm.cpp's
|
||||
`run_generate` tokenizes with `prepend_bos = vocab.add_bos` (true for
|
||||
gemma4), so a literal `<bos>` would double it.
|
||||
- `gemma4_parser.go` - streaming state machine turning raw model text
|
||||
(fragments can split anywhere, including mid-marker) into ChatDeltas:
|
||||
thought channels → `reasoning_content`, `<|tool_call>call:name{...}` →
|
||||
ToolCallDelta, `<turn|>` → done. Marker grammar cross-checked against vLLM
|
||||
PR #45163's gemma4 tool/reasoning parsers. Malformed payloads are re-emitted
|
||||
raw as content, never dropped.
|
||||
- Thinking is **opt-in** for this family (`Metadata["enable_thinking"]`,
|
||||
default OFF - the inverse of ds4): the template gates every thinking branch
|
||||
on `enable_thinking`, and the no-thinking render pre-closes an empty thought
|
||||
channel, so the parser always starts in content state.
|
||||
- **UTF-8 boundary holdback** (`splitValidUTF8` in `dllm.go`): per-block
|
||||
detokenization can split a multi-byte character across block boundaries, and
|
||||
grpc-go refuses to marshal invalid UTF-8 in proto3 strings. An incomplete
|
||||
trailing sequence (at most 3 bytes) is carried into the next block; genuinely
|
||||
undecodable bytes become U+FFFD.
|
||||
|
||||
Without `use_tokenizer_template`, the prompt passes through verbatim and the
|
||||
output is NOT gemma4-parsed (plain content, like any non-autoparsing backend).
|
||||
|
||||
## Tests
|
||||
|
||||
| Layer | Gate | What |
|
||||
|---|---|---|
|
||||
| `backend/go/dllm/*_test.go` (renderer/parser/wiring) | none - run in plain `go test ./backend/go/dllm/...` | Ginkgo specs over a fake `generator` seam; canonical renderer fixtures from transformers' `test_modeling_diffusion_gemma.py`, parser tables from the vLLM gemma4 parsers |
|
||||
| `backend/go/dllm/dllm_test.go` C-ABI smoke | `DLLM_TEST_LIBRARY` + `DLLM_TEST_TINY_MODEL` (dllm.cpp's `tests/fixtures/tiny_with_vocab.gguf`); Skips when unset | Drives the real `libdllm.so`: ABI check, load, tokenize `[2,18]`, deterministic generate, cancel (incl. mid-stream `Dllm.Cancel` aborting a deliberately slow `eb_max_steps:256` run in ~10ms) |
|
||||
| `tests/e2e-backends/dllm_test.go` | `BACKEND_TEST_DLLM=1` + `BACKEND_BINARY` (packaged run.sh) + `BACKEND_TEST_MODEL_FILE` (tiny fixture) | Templated chat round trip (Messages + UseTokenizerTemplate) over the real gRPC binary, non-streaming + streaming; plus client-context cancellation mid-stream (proves the `Cancellable` server plumbing end to end) |
|
||||
| Real-model e2e | `BACKEND_TEST_DLLM_REAL_MODEL_FILE` (26B BF16, ~50 GB) + `BACKEND_TEST_DLLM_REAL_GPU_LAYERS` | CUDA-13-class hardware only |
|
||||
|
||||
Tool-call e2e is deliberately absent from the tiny-model spec: the fixture has
|
||||
random weights and cannot be coaxed into emitting tool markup; the unit tables
|
||||
carry that coverage.
|
||||
|
||||
## Build matrix
|
||||
|
||||
`cpu-dllm` (amd64 + arm64), `cuda13-dllm` (amd64), and
|
||||
`cuda13-nvidia-l4t-arm64-dllm` (arm64 CUDA: Jetson / DGX Spark GB10), via
|
||||
`.github/backend-matrix.yml`. No darwin/Metal. CUDA builds forward
|
||||
`-DDLLM_CUDA=ON` (dllm.cpp gates ggml's CUDA behind its own flag - a bare
|
||||
`-DGGML_CUDA=ON` is overridden by the cache FORCE). `libdllm.so` is
|
||||
self-contained (ggml statically absorbed, PIC), so `package.sh` only ships
|
||||
the binary, `run.sh` and that one .so (the parakeet-cpp-style stub layout;
|
||||
no ldd walk yet).
|
||||
|
||||
## Known limitations
|
||||
|
||||
- **Cancel granularity**: the C-ABI cancel flag is per-ctx and resets on
|
||||
every generate entry, so a Cancel racing a NEW generate can be lost, and
|
||||
with requests queued on the worker it aborts whichever generate is
|
||||
currently running (acceptable: the server de-registers the hook on normal
|
||||
completion, one process serves one model).
|
||||
- **Throughput**: ~0.15 tok/s on the 26B at default settings (GB10) - every
|
||||
denoise step recomputes the full prompt+canvas. The upstream prefix-KV
|
||||
cache (dllm.cpp P3) is the fix; `kv_cache:on` errors until it lands
|
||||
(`auto`/`off` are accepted no-ops).
|
||||
- **Repo privacy**: see the note at the top - CI clone of dllm.cpp needs the
|
||||
repo published (or credentials) before the backend images can build.
|
||||
- Engine spec/validation references: dllm.cpp `docs/validation.md` and
|
||||
LocalAI `docs/superpowers/specs/2026-06-10-dllm-cpp-design.md`.
|
||||
69
.github/backend-matrix.yml
vendored
69
.github/backend-matrix.yml
vendored
@@ -1608,6 +1608,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-dllm'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "dllm"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1647,6 +1660,19 @@ include:
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-dllm'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "dllm"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1766,20 +1792,6 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-turboquant'
|
||||
builder-base-image: 'quay.io/go-skynet/ci-cache:base-grpc-rocm-amd64'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -3159,6 +3171,35 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# dllm
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-dllm'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "dllm"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-dllm'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "dllm"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
|
||||
13
.github/gallery-agent/main.go
vendored
13
.github/gallery-agent/main.go
vendored
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -113,6 +114,17 @@ func main() {
|
||||
fmt.Println("Searching for trending models on HuggingFace...")
|
||||
rawModels, err := client.GetTrending(searchTerm, limit)
|
||||
if err != nil {
|
||||
if errors.Is(err, hfapi.ErrRateLimited) {
|
||||
fmt.Printf("HuggingFace API is rate limited after retries, skipping this run: %v\n", err)
|
||||
writeSummary(AddedModelSummary{
|
||||
SearchTerm: searchTerm,
|
||||
TotalFound: 0,
|
||||
ModelsAdded: 0,
|
||||
Quantization: quantization,
|
||||
ProcessingTime: time.Since(startTime).String(),
|
||||
})
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Error fetching models: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -277,4 +289,3 @@ func truncateString(s string, maxLen int) string {
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
|
||||
4
.github/workflows/bump_deps.yaml
vendored
4
.github/workflows/bump_deps.yaml
vendored
@@ -38,6 +38,10 @@ jobs:
|
||||
variable: "PARAKEET_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/parakeet-cpp/Makefile"
|
||||
- repository: "mudler/dllm.cpp"
|
||||
variable: "DLLM_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/dllm/Makefile"
|
||||
- repository: "leejet/stable-diffusion.cpp"
|
||||
variable: "STABLEDIFFUSION_GGML_VERSION"
|
||||
branch: "master"
|
||||
|
||||
2
.github/workflows/secscan.yaml
vendored
2
.github/workflows/secscan.yaml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||
- name: Run Gosec Security Scanner
|
||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||
uses: securego/gosec@v2.22.9
|
||||
uses: securego/gosec@v2.27.1
|
||||
with:
|
||||
# we let the report trigger content trigger a failure using the GitHub Security features.
|
||||
args: '-no-fail -fmt sarif -out results.sarif ./...'
|
||||
|
||||
@@ -26,6 +26,7 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
||||
| [.agents/vllm-backend.md](.agents/vllm-backend.md) | Working on the vLLM / vLLM-omni backends — native parsers, ChatDelta, CPU build, libnuma packaging, backend hooks |
|
||||
| [.agents/sglang-backend.md](.agents/sglang-backend.md) | Working on the SGLang backend — `engine_args` validation against ServerArgs, speculative-decoding (EAGLE/EAGLE3/DFLASH/MTP) recipes, parser handling |
|
||||
| [.agents/ds4-backend.md](.agents/ds4-backend.md) | Working on the ds4 backend - DSML state machine, thinking modes, KV cache, Metal+CUDA matrix |
|
||||
| [.agents/dllm-backend.md](.agents/dllm-backend.md) | Working on the dllm backend (DiffusionGemma block-diffusion) - purego C-ABI binding, per-ctx serialization contract, gemma4 renderer/parser, gated test layers |
|
||||
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
||||
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||
|
||||
17
Makefile
17
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/dllm backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -180,7 +180,7 @@ osx-signed: build
|
||||
|
||||
## Run
|
||||
run: ## run local-ai
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./cmd/local-ai
|
||||
|
||||
prepare-test: protogen-go build-mock-backend
|
||||
|
||||
@@ -309,13 +309,20 @@ run-e2e-aio: protogen-go
|
||||
@echo 'Running e2e AIO tests'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
|
||||
|
||||
# Distributed architecture e2e (PostgreSQL + NATS via testcontainers).
|
||||
# Includes NatsJWT specs (JWT-enabled NATS). Requires Docker.
|
||||
# VLLMMultinode is excluded here; use test-e2e-vllm-multinode for that.
|
||||
test-e2e-distributed: protogen-go
|
||||
@echo 'Running distributed e2e tests (label Distributed, incl. NatsJWT)'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter='Distributed && !VLLMMultinode' --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e/distributed
|
||||
|
||||
# vLLM multi-node DP smoke (CPU). Builds local-ai:tests and the
|
||||
# cpu-vllm backend from the current working tree, then drives a
|
||||
# head + headless follower via testcontainers-go and asserts a chat
|
||||
# completion. BuildKit caches both images, so re-runs only rebuild
|
||||
# what changed. The test lives under tests/e2e/distributed and is
|
||||
# selected by the VLLMMultinode label so it doesn't run alongside
|
||||
# the other distributed-suite tests by default.
|
||||
# test-e2e-distributed.
|
||||
test-e2e-vllm-multinode: docker-build-e2e extract-backend-vllm protogen-go
|
||||
@echo 'Running e2e vLLM multi-node DP test'
|
||||
LOCALAI_IMAGE=local-ai \
|
||||
@@ -1164,6 +1171,9 @@ BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|tr
|
||||
BACKEND_WHISPER = whisper|golang|.|false|true
|
||||
BACKEND_CRISPASR = crispasr|golang|.|false|true
|
||||
BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true
|
||||
# dllm is mudler/dllm.cpp, the DiffusionGemma block-diffusion engine,
|
||||
# wrapped by the purego backend at backend/go/dllm.
|
||||
BACKEND_DLLM = dllm|golang|.|false|true
|
||||
BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
||||
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
|
||||
BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true
|
||||
@@ -1253,6 +1263,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CRISPASR)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_DLLM)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))
|
||||
|
||||
10
README.md
10
README.md
@@ -149,6 +149,16 @@ local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
|
||||
local-ai run oci://localai/phi-2:latest
|
||||
```
|
||||
|
||||
To test a running LocalAI server from the terminal, open an interactive chat session from another shell. Inside the prompt, `/models` lists installed models and `/model <name>` switches between them.
|
||||
|
||||
```bash
|
||||
# Terminal 1
|
||||
local-ai run llama-3.2-1b-instruct:q4_k_m
|
||||
|
||||
# Terminal 2
|
||||
local-ai chat --model llama-3.2-1b-instruct:q4_k_m
|
||||
```
|
||||
|
||||
> **Automatic Backend Detection**: LocalAI automatically detects your GPU capabilities and downloads the appropriate backend. For advanced options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/).
|
||||
|
||||
For more details, see the [Getting Started guide](https://localai.io/basics/getting_started/).
|
||||
|
||||
@@ -537,6 +537,15 @@ message TTSRequest {
|
||||
string dst = 3;
|
||||
string voice = 4;
|
||||
optional string language = 5;
|
||||
// instructions is a free-form, per-request style/voice description (maps to
|
||||
// the OpenAI `instructions` field). Backends that support expressive synthesis
|
||||
// (e.g. Qwen3-TTS CustomVoice/VoiceDesign) prefer this over the static YAML
|
||||
// option when set; backends that don't simply ignore it.
|
||||
optional string instructions = 6;
|
||||
// params carries optional, backend-specific per-request generation parameters
|
||||
// (e.g. Chatterbox exaggeration/cfg_weight/temperature). Values are strings and
|
||||
// coerced by the backend; unset leaves the backend's configured defaults.
|
||||
map<string, string> params = 7;
|
||||
}
|
||||
|
||||
message VADRequest {
|
||||
|
||||
@@ -60,10 +60,12 @@ elseif(DS4_GPU STREQUAL "cpu")
|
||||
set(DS4_OBJS "${DS4_DIR}/ds4_cpu.o")
|
||||
endif()
|
||||
|
||||
# ds4.c now references ds4_distributed.c (distributed inference was split into
|
||||
# its own translation unit upstream). It is a single GPU-agnostic object shared
|
||||
# by every GPU mode, so link it in regardless of DS4_GPU.
|
||||
# ds4.c now references ds4_distributed.c (distributed inference) and ds4_ssd.c
|
||||
# (SSD expert-cache), each split into its own translation unit upstream. Both
|
||||
# are GPU-agnostic objects shared by every GPU mode, so link them in regardless
|
||||
# of DS4_GPU.
|
||||
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_distributed.o")
|
||||
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_ssd.o")
|
||||
|
||||
add_executable(${TARGET}
|
||||
grpc-server.cpp
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
|
||||
# Upstream pin lives below as DS4_VERSION?=8384adf0f9fa0f3bb342dd925372de778b95b263
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
|
||||
DS4_VERSION?=8384adf0f9fa0f3bb342dd925372de778b95b263
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
@@ -18,19 +18,20 @@ UNAME_S := $(shell uname -s)
|
||||
|
||||
CMAKE_ARGS ?= -DCMAKE_BUILD_TYPE=Release
|
||||
|
||||
# ds4_distributed.o is a GPU-agnostic translation unit that ds4.c/ds4_cpu.o now
|
||||
# reference (upstream split distributed inference into its own .c). The same
|
||||
# object is shared by every GPU mode, so it is appended unconditionally below.
|
||||
# ds4_distributed.o and ds4_ssd.o are GPU-agnostic translation units that
|
||||
# ds4.c/ds4_cpu.o now reference (upstream split distributed inference and the
|
||||
# SSD expert-cache into their own .c files). Both objects are shared by every
|
||||
# GPU mode, so they are appended unconditionally below.
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS += -DDS4_GPU=cuda
|
||||
DS4_OBJ_TARGET := ds4.o ds4_cuda.o ds4_distributed.o
|
||||
DS4_OBJ_TARGET := ds4.o ds4_cuda.o ds4_distributed.o ds4_ssd.o
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
CMAKE_ARGS += -DDS4_GPU=metal
|
||||
DS4_OBJ_TARGET := ds4.o ds4_metal.o ds4_distributed.o
|
||||
DS4_OBJ_TARGET := ds4.o ds4_metal.o ds4_distributed.o ds4_ssd.o
|
||||
else
|
||||
# CPU reference path (Linux only - macOS CPU path is broken by VM bug per ds4 README).
|
||||
CMAKE_ARGS += -DDS4_GPU=cpu
|
||||
DS4_OBJ_TARGET := ds4_cpu.o ds4_distributed.o
|
||||
DS4_OBJ_TARGET := ds4_cpu.o ds4_distributed.o ds4_ssd.o
|
||||
endif
|
||||
|
||||
ifneq ($(NATIVE),true)
|
||||
@@ -55,11 +56,11 @@ ds4:
|
||||
# the right per-platform compile flags (Objective-C/Metal on Darwin, nvcc on Linux+CUDA).
|
||||
ds4/ds4.o: ds4
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
+$(MAKE) -C ds4 ds4.o ds4_cuda.o ds4_distributed.o
|
||||
+$(MAKE) -C ds4 ds4.o ds4_cuda.o ds4_distributed.o ds4_ssd.o
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
+$(MAKE) -C ds4 ds4.o ds4_metal.o ds4_distributed.o
|
||||
+$(MAKE) -C ds4 ds4.o ds4_metal.o ds4_distributed.o ds4_ssd.o
|
||||
else
|
||||
+$(MAKE) -C ds4 ds4_cpu.o ds4_distributed.o
|
||||
+$(MAKE) -C ds4 ds4_cpu.o ds4_distributed.o ds4_ssd.o
|
||||
endif
|
||||
|
||||
grpc-server: ds4/ds4.o
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=3f40e73c367ad9f0c1b1819f28c7348c26aa340d
|
||||
IK_LLAMA_VERSION?=e6f8112f3ba126eed3ff5b30cdd08085414a7516
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=5dcb71166686799f0d873eab7386234302d05ecf
|
||||
LLAMA_VERSION?=039e20a2db9e87b2477c76cc04905f3e1acad77f
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -381,6 +381,15 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
||||
});
|
||||
}
|
||||
|
||||
// for each video in the request, add the video data
|
||||
for (int i = 0; i < predict->videos_size(); i++) {
|
||||
data["video_data"].push_back(json
|
||||
{
|
||||
{"id", i},
|
||||
{"data", predict->videos(i)},
|
||||
});
|
||||
}
|
||||
|
||||
data["stop"] = predict->stopprompts();
|
||||
// data["n_probs"] = predict->nprobs();
|
||||
//TODO: images,
|
||||
@@ -482,23 +491,13 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!request->draftmodel().empty()) {
|
||||
params.speculative.draft.mparams.path = request->draftmodel();
|
||||
// Default to draft type if a draft model is set but no explicit type.
|
||||
// Upstream (post ggml-org/llama.cpp#22838) made the speculative type a
|
||||
// vector; the turboquant fork still uses the legacy scalar. The
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC macro is injected by
|
||||
// backend/cpp/turboquant/patch-grpc-server.sh for fork builds only.
|
||||
// Upstream renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE
|
||||
// in ggml-org/llama.cpp#22964; the fork still uses the old name.
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
||||
}
|
||||
#else
|
||||
// Upstream made the speculative type a vector (ggml-org/llama.cpp#22838)
|
||||
// and renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE (#22964).
|
||||
const bool no_spec_type = params.speculative.types.empty() ||
|
||||
(params.speculative.types.size() == 1 && params.speculative.types[0] == COMMON_SPECULATIVE_TYPE_NONE);
|
||||
if (no_spec_type) {
|
||||
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE };
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// params.model_alias ??
|
||||
@@ -574,9 +573,10 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// tokens (0 disables the minimum). Match upstream's default (256). This
|
||||
// field was renamed from `checkpoint_every_nt` in llama.cpp; the semantics
|
||||
// also shifted from a fixed cadence to a minimum spacing. The turboquant
|
||||
// fork branched before the field existed, so skip it on the legacy path
|
||||
// (LOCALAI_LEGACY_LLAMA_CPP_SPEC is injected by patch-grpc-server.sh).
|
||||
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
// fork still lacks common_params::checkpoint_min_step, so skip it there
|
||||
// (LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP is injected by
|
||||
// backend/cpp/turboquant/patch-grpc-server.sh).
|
||||
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
|
||||
params.checkpoint_min_step = 256;
|
||||
#endif
|
||||
|
||||
@@ -752,7 +752,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.cache_idle_slots = false;
|
||||
}
|
||||
|
||||
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
|
||||
// --- minimum context-checkpoint spacing (upstream -cms / --checkpoint-min-step) ---
|
||||
// 0 disables the minimum-spacing gate. Old option names (`checkpoint_every_nt`,
|
||||
// `checkpoint_every_n_tokens`) are kept as aliases for backward compatibility
|
||||
@@ -906,17 +906,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
|
||||
// Speculative decoding options
|
||||
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
// Fork only knows a single scalar `type`. Take the first comma-
|
||||
// separated value and assign it via the singular helper.
|
||||
std::string first = optval_str;
|
||||
const auto comma = first.find(',');
|
||||
if (comma != std::string::npos) first = first.substr(0, comma);
|
||||
auto type = common_speculative_type_from_name(first);
|
||||
if (type != COMMON_SPECULATIVE_TYPE_COUNT) {
|
||||
params.speculative.type = type;
|
||||
}
|
||||
#else
|
||||
// Upstream switched to a vector of types (comma-separated for multi-type
|
||||
// chaining via common_speculative_types_from_names). We keep accepting a
|
||||
// single value here, but also tolerate comma-separated lists.
|
||||
@@ -945,7 +934,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!parsed.empty()) {
|
||||
params.speculative.types = parsed;
|
||||
}
|
||||
#endif
|
||||
} else if (!strcmp(optname, "spec_n_max") || !strcmp(optname, "draft_max")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.draft.n_max = std::stoi(optval_str); } catch (...) {}
|
||||
@@ -983,21 +971,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// shares the target context size. Accept the option for backward
|
||||
// compatibility but silently ignore it.
|
||||
|
||||
// Everything below relies on struct shape introduced in ggml-org/llama.cpp#22838
|
||||
// (parallel drafting): `ngram_mod`, `ngram_map_k`, `ngram_map_k4v`,
|
||||
// `ngram_cache`, and the `draft.{cache_type_*, cpuparams*, tensor_buft_overrides}`
|
||||
// fields. The turboquant fork branched before that, so its build defines
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC via patch-grpc-server.sh and these option
|
||||
// keys become unrecognized (silently dropped, like any unknown opt) for it.
|
||||
//
|
||||
// The `#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC` / `#else` split below sits at the
|
||||
// closing-brace position of the `draft_ctx_size` branch on purpose: in the
|
||||
// legacy build the chain ends here (the brace closes draft_ctx_size), and in
|
||||
// the modern build the chain continues with `} else if (...)` instead, so the
|
||||
// brace count stays balanced under both branches of the preprocessor.
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
}
|
||||
#else
|
||||
// --- ngram_mod family (upstream --spec-ngram-mod-*) ---
|
||||
} else if (!strcmp(optname, "spec_ngram_mod_n_min")) {
|
||||
if (optval != NULL) {
|
||||
@@ -1127,7 +1100,6 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
}
|
||||
if (!cur.empty()) flush(cur);
|
||||
}
|
||||
#endif // LOCALAI_LEGACY_LLAMA_CPP_SPEC — closes the `else`/`#ifdef` opened at draft_ctx_size
|
||||
}
|
||||
|
||||
// Set params.n_parallel from environment variable if not set via options (fallback)
|
||||
@@ -1177,15 +1149,11 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
}
|
||||
// The draft tensor_buft_overrides are only populated under the modern
|
||||
// (post-#22838) layout, whose population code is itself gated by
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC above. The turboquant fork lacks
|
||||
// common_params_speculative::draft entirely, so skip the sentinel there too.
|
||||
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
// Terminate the draft tensor_buft_overrides list with a sentinel, mirroring
|
||||
// the main-model handling above.
|
||||
if (!params.speculative.draft.tensor_buft_overrides.empty()) {
|
||||
params.speculative.draft.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
#endif
|
||||
|
||||
// TODO: Add yarn
|
||||
|
||||
@@ -1544,7 +1512,7 @@ public:
|
||||
msg_json["role"] = msg.role();
|
||||
|
||||
bool is_last_user_msg = (i == last_user_msg_idx);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
|
||||
|
||||
// Handle content - can be string, null, or array
|
||||
// For multimodal content, we'll embed images/audio from separate fields
|
||||
@@ -1595,6 +1563,16 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
} else {
|
||||
// Use content as-is (already array or not last user message)
|
||||
@@ -1629,6 +1607,16 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
} else if (msg.role() == "tool") {
|
||||
// Tool role messages must have content field set, even if empty
|
||||
@@ -1944,6 +1932,17 @@ public:
|
||||
body_json["chat_template_kwargs"]["enable_thinking"] = (et_it->second == "true");
|
||||
}
|
||||
|
||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
||||
// from enable_thinking which those templates ignore.
|
||||
auto re_it = metadata.find("reasoning_effort");
|
||||
if (re_it != metadata.end() && !re_it->second.empty()) {
|
||||
if (!body_json.contains("chat_template_kwargs")) {
|
||||
body_json["chat_template_kwargs"] = json::object();
|
||||
}
|
||||
body_json["chat_template_kwargs"]["reasoning_effort"] = re_it->second;
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
@@ -2069,6 +2068,16 @@ public:
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
|
||||
const auto &video_data = data.find("video_data");
|
||||
if (video_data != data.end() && video_data->is_array())
|
||||
{
|
||||
for (const auto &video : *video_data)
|
||||
{
|
||||
auto decoded_data = base64_decode(video["data"].get<std::string>());
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const bool has_mtmd = ctx_server.impl->mctx != nullptr;
|
||||
@@ -2321,7 +2330,7 @@ public:
|
||||
}
|
||||
|
||||
bool is_last_user_msg = (i == last_user_msg_idx);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
|
||||
|
||||
// Handle content - can be string, null, or array
|
||||
// For multimodal content, we'll embed images/audio from separate fields
|
||||
@@ -2374,6 +2383,16 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
} else {
|
||||
// Use content as-is (already array or not last user message)
|
||||
@@ -2413,6 +2432,16 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d created content array with media\n", i);
|
||||
} else if (!msg.tool_calls().empty()) {
|
||||
@@ -2737,6 +2766,17 @@ public:
|
||||
body_json["chat_template_kwargs"]["enable_thinking"] = (predict_et_it->second == "true");
|
||||
}
|
||||
|
||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
||||
// from enable_thinking which those templates ignore.
|
||||
auto predict_re_it = predict_metadata.find("reasoning_effort");
|
||||
if (predict_re_it != predict_metadata.end() && !predict_re_it->second.empty()) {
|
||||
if (!body_json.contains("chat_template_kwargs")) {
|
||||
body_json["chat_template_kwargs"] = json::object();
|
||||
}
|
||||
body_json["chat_template_kwargs"]["reasoning_effort"] = predict_re_it->second;
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] Predict: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
@@ -2864,6 +2904,16 @@ public:
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
|
||||
const auto &video_data = data.find("video_data");
|
||||
if (video_data != data.end() && video_data->is_array())
|
||||
{
|
||||
for (const auto &video : *video_data)
|
||||
{
|
||||
auto decoded_data = base64_decode(video["data"].get<std::string>());
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process files
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
|
||||
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
|
||||
TURBOQUANT_VERSION?=5aeb2fdbe26cd4c534c6fa15de73cb5749bd0403
|
||||
TURBOQUANT_VERSION?=7d9715f1f071fa07c7b2ad3dbfd320b314139e65
|
||||
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -4,21 +4,19 @@
|
||||
#
|
||||
# 1. Augment the kv_cache_types[] allow-list so `LoadModel` accepts the
|
||||
# fork-specific `turbo2` / `turbo3` / `turbo4` cache types.
|
||||
# 2. Replace `get_media_marker()` (added upstream in ggml-org/llama.cpp#21962,
|
||||
# server-side random per-instance marker) with the legacy "<__media__>"
|
||||
# literal. The fork branched before that PR, so server-common.cpp has no
|
||||
# get_media_marker symbol. The fork's mtmd_default_marker() still returns
|
||||
# "<__media__>", and Go-side tooling falls back to that sentinel when the
|
||||
# backend does not expose media_marker, so substituting the literal keeps
|
||||
# behavior identical on the turboquant path.
|
||||
# 3. Revert the `common_params_speculative` field references to the
|
||||
# pre-refactor flat layout. Upstream ggml-org/llama.cpp#22397 split the
|
||||
# struct into nested `draft` / `ngram_simple` / `ngram_mod` / etc. members;
|
||||
# the turboquant fork branched before that PR and still exposes the flat
|
||||
# `n_max`, `mparams_dft`, `ngram_size_n`, ... fields. The substitutions
|
||||
# below map the new nested paths back to the legacy flat names so the
|
||||
# shared grpc-server.cpp keeps compiling against the fork's common.h.
|
||||
# Drop this block once the fork rebases past #22397.
|
||||
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file
|
||||
# so the grpc-server option parser skips the two references to
|
||||
# common_params::checkpoint_min_step (the default and the option handler).
|
||||
# That field does not exist in the fork yet; drop this once it does.
|
||||
#
|
||||
# The fork used to lag upstream on the whole common_params_speculative refactor
|
||||
# (ggml-org/llama.cpp#22397/#22838/#22964), the model_tgt rename (#22838) and
|
||||
# get_media_marker (#21962), which required a much larger compat shim here
|
||||
# (flat-field sed renames + a coarse LOCALAI_LEGACY_LLAMA_CPP_SPEC define). The
|
||||
# fork has since rebased past all of those, so the only remaining gap is
|
||||
# checkpoint_min_step. If a future bump reintroduces a divergence, add a narrow
|
||||
# guard in grpc-server.cpp keyed on a fork-specific macro and inject it here
|
||||
# rather than resurrecting the coarse one.
|
||||
#
|
||||
# We patch the *copy* sitting in turboquant-<flavor>-build/, never the original
|
||||
# under backend/cpp/llama-cpp/, so the stock llama-cpp build keeps compiling
|
||||
@@ -72,72 +70,20 @@ else
|
||||
echo "==> KV allow-list patch OK"
|
||||
fi
|
||||
|
||||
if grep -q 'get_media_marker()' "$SRC"; then
|
||||
echo "==> patching $SRC to replace get_media_marker() with legacy \"<__media__>\" literal"
|
||||
# Only one call site today (ModelMetadata), but replace all occurrences to
|
||||
# stay robust if upstream adds more. Use a temp file to avoid relying on
|
||||
# sed -i portability (the builder image uses GNU sed, but keeping this
|
||||
# consistent with the awk block above).
|
||||
sed 's/get_media_marker()/"<__media__>"/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> get_media_marker() substitution OK"
|
||||
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file so
|
||||
# the grpc-server option parser skips the two references to
|
||||
# common_params::checkpoint_min_step (the default assignment and the option
|
||||
# handler). That field does not exist in the fork yet. Drop this block once
|
||||
# the fork rebases past the bump that added checkpoint_min_step.
|
||||
if grep -q '^#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP' "$SRC"; then
|
||||
echo "==> $SRC already defines LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP, skipping"
|
||||
else
|
||||
echo "==> $SRC has no get_media_marker() call, skipping media-marker patch"
|
||||
fi
|
||||
|
||||
if grep -q 'params\.speculative\.draft\.\|params\.speculative\.ngram_simple\.' "$SRC"; then
|
||||
echo "==> patching $SRC to revert common_params_speculative refs to pre-#22397 flat layout"
|
||||
# Each substitution is the exact post-refactor path → legacy flat field.
|
||||
# Order doesn't matter because the source paths are disjoint, but we keep
|
||||
# the most-specific (mparams.path) first for readability.
|
||||
sed -E \
|
||||
-e 's/params\.speculative\.draft\.mparams\.path/params.speculative.mparams_dft.path/g' \
|
||||
-e 's/params\.speculative\.draft\.n_max/params.speculative.n_max/g' \
|
||||
-e 's/params\.speculative\.draft\.n_min/params.speculative.n_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_min/params.speculative.p_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_split/params.speculative.p_split/g' \
|
||||
-e 's/params\.speculative\.draft\.n_gpu_layers/params.speculative.n_gpu_layers/g' \
|
||||
-e 's/params\.speculative\.draft\.n_ctx/params.speculative.n_ctx/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_n/params.speculative.ngram_size_n/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_m/params.speculative.ngram_size_m/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.min_hits/params.speculative.ngram_min_hits/g' \
|
||||
"$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> speculative field rename OK"
|
||||
else
|
||||
echo "==> $SRC has no post-#22397 speculative field refs, skipping spec rename patch"
|
||||
fi
|
||||
|
||||
# 4. Revert the `ctx_server.impl->model_tgt` rename introduced by upstream
|
||||
# ggml-org/llama.cpp#22838 (parallel drafting). The turboquant fork still
|
||||
# exposes the field as `model` on `server_context_impl`. The two call sites
|
||||
# are in the Rerank and ModelMetadata RPC handlers.
|
||||
if grep -q 'ctx_server\.impl->model_tgt' "$SRC"; then
|
||||
echo "==> patching $SRC to revert ctx_server.impl->model_tgt -> ctx_server.impl->model"
|
||||
sed -E 's/ctx_server\.impl->model_tgt/ctx_server.impl->model/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> model_tgt rename OK"
|
||||
else
|
||||
echo "==> $SRC has no ctx_server.impl->model_tgt refs, skipping model_tgt rename patch"
|
||||
fi
|
||||
|
||||
# 5. Define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top of the file so the
|
||||
# grpc-server option parser skips the new option-handler blocks (ngram_mod,
|
||||
# ngram_map_k, ngram_map_k4v, ngram_cache, draft.cache_type_*, draft.cpuparams*,
|
||||
# draft.tensor_buft_overrides) introduced for the post-#22838 layout, the
|
||||
# draft.tensor_buft_overrides sentinel termination, and the
|
||||
# common_params::checkpoint_min_step default/option (added with the
|
||||
# 35c9b1f3 bump). Those blocks reference struct fields that simply do not
|
||||
# exist in the fork.
|
||||
if grep -q '^#define LOCALAI_LEGACY_LLAMA_CPP_SPEC' "$SRC"; then
|
||||
echo "==> $SRC already defines LOCALAI_LEGACY_LLAMA_CPP_SPEC, skipping"
|
||||
else
|
||||
echo "==> patching $SRC to define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top"
|
||||
# Insert the define before the very first `#include` so it precedes all the
|
||||
# speculative-decoding code paths.
|
||||
echo "==> patching $SRC to define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top"
|
||||
# Insert the define before the very first `#include` so it precedes the
|
||||
# checkpoint_min_step references.
|
||||
awk '
|
||||
!done && /^#include/ {
|
||||
print "#define LOCALAI_LEGACY_LLAMA_CPP_SPEC 1"
|
||||
print "#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP 1"
|
||||
print "// ^ injected by backend/cpp/turboquant/patch-grpc-server.sh"
|
||||
print ""
|
||||
done = 1
|
||||
@@ -145,13 +91,13 @@ else
|
||||
{ print }
|
||||
END {
|
||||
if (!done) {
|
||||
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_LEGACY_LLAMA_CPP_SPEC" > "/dev/stderr"
|
||||
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP" > "/dev/stderr"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> LOCALAI_LEGACY_LLAMA_CPP_SPEC define OK"
|
||||
echo "==> LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP define OK"
|
||||
fi
|
||||
|
||||
echo "==> all patches applied"
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
hip: port the turboquant CUDA additions that ggml's HIP shim doesn't cover
|
||||
|
||||
The turboquant fork adds/modifies a few ggml-cuda.cu spots with CUDA APIs
|
||||
that ggml's HIP (and MUSA) compatibility layer does not provide, breaking
|
||||
the -gpu-rocm-hipblas-turboquant build:
|
||||
|
||||
1. ggml_cuda_copy2d_across_devices() (host-staged cross-device copy for
|
||||
split mul_mat output) uses the CUDA 3D-peer copy APIs
|
||||
cudaMemcpy3DPeerParms / make_cudaPitchedPtr / make_cudaExtent /
|
||||
cudaMemcpy3DPeerAsync. HIP genuinely does not support these (see the
|
||||
fork's own comment "HIP does not support cudaMemcpy3DPeerAsync"), so
|
||||
guard the peer fast path with #if !defined(GGML_USE_HIP) &&
|
||||
!defined(GGML_USE_MUSA) -- matching how the fork already guards the
|
||||
same API for the sibling 2D copy -- and fall through to the existing
|
||||
cudaMemcpyAsync staging fallback below (functionally identical,
|
||||
slightly slower on multi-GPU ROCm).
|
||||
|
||||
2. ggml_backend_cuda_device_event_new() creates its event with plain
|
||||
cudaEventCreate, which ggml's HIP shim does not alias (it only aliases
|
||||
cudaEventCreateWithFlags). Use cudaEventCreateWithFlags(...,
|
||||
cudaEventDisableTiming) -- exactly what the rest of this file already
|
||||
does (cf. lines ~1034, ~3461) and HIP-safe.
|
||||
|
||||
CUDA builds are unaffected. Drop the relevant hunk once the fork HIP-ports
|
||||
these; apply-patches.sh fails fast if an anchor goes stale.
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 0427e6b..6352e6a 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -1933,6 +1933,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
|
||||
size_t width, size_t height, cudaStream_t dst_stream, cudaStream_t src_stream) {
|
||||
|
||||
const auto & info = ggml_cuda_info();
|
||||
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) // 3D-peer copy types unmapped by ggml's HIP/MUSA shim; use staging fallback below
|
||||
if (info.peer_access[src_device][dst_device]) {
|
||||
cudaMemcpy3DPeerParms p = {};
|
||||
p.dstDevice = dst_device;
|
||||
@@ -1942,6 +1943,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
|
||||
p.extent = make_cudaExtent(width, height, 1);
|
||||
return cudaMemcpy3DPeerAsync(&p, dst_stream);
|
||||
}
|
||||
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
|
||||
// Fallback: stage all rows through a single contiguous pinned buffer
|
||||
int prev_device = ggml_cuda_get_device();
|
||||
@@ -5714,7 +5716,7 @@ static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_
|
||||
ggml_cuda_set_device(dev_ctx->device);
|
||||
|
||||
cudaEvent_t event;
|
||||
- CUDA_CHECK(cudaEventCreate(&event));
|
||||
+ CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
|
||||
|
||||
return new ggml_backend_event {
|
||||
/* .device = */ dev,
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
@@ -145,7 +146,7 @@ func resolveAPIKey(envName, filePath string) (string, error) {
|
||||
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cloud-proxy: model not loaded")
|
||||
return nil, grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -175,7 +176,7 @@ func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err
|
||||
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -269,7 +270,7 @@ func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest,
|
||||
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modePassthrough {
|
||||
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
|
||||
|
||||
@@ -14,7 +14,7 @@ target_include_directories(gocrispasr PRIVATE
|
||||
# whisper. crispasr is the referencer; the backend static libs supply the
|
||||
# per-architecture symbols; ggml is the math/runtime base.
|
||||
target_link_libraries(gocrispasr PRIVATE
|
||||
crispasr
|
||||
crispasr-lib
|
||||
parakeet canary canary_ctc cohere granite_speech granite_nle
|
||||
voxtral voxtral4b qwen3_asr qwen3_tts orpheus chatterbox indextts
|
||||
kokoro voxcpm2_tts m2m100 t5_translate wav2vec2-ggml vibevoice
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# CrispASR version (release tag)
|
||||
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
||||
CRISPASR_VERSION?=05e60432bcb5bc2113f8c395a41e86497c11504a
|
||||
CRISPASR_VERSION?=c29f6653a516a3001d923944dad8892072cc7334
|
||||
SO_TARGET?=libgocrispasr.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
10
backend/go/dllm/.gitignore
vendored
Normal file
10
backend/go/dllm/.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
.cache/
|
||||
sources/
|
||||
build/
|
||||
package/
|
||||
dllm-grpc
|
||||
# build artifacts staged in-tree by the Makefile (cp from sources/) or
|
||||
# symlinked for local dev; the real sources live in dllm.cpp upstream.
|
||||
*.so
|
||||
*.so.*
|
||||
compile_commands.json
|
||||
97
backend/go/dllm/Makefile
Normal file
97
backend/go/dllm/Makefile
Normal file
@@ -0,0 +1,97 @@
|
||||
# dllm backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DLLM_VERSION?=<sha> so .github/bump_deps.sh
|
||||
# can find and update it - matches the whisper.cpp / parakeet-cpp / ds4
|
||||
# convention.
|
||||
#
|
||||
# Local dev shortcut: if you already have an out-of-tree dllm.cpp build,
|
||||
# you can symlink the .so into this directory and skip the clone/cmake
|
||||
# steps entirely, e.g.:
|
||||
#
|
||||
# ln -sf /path/to/dllm.cpp/build/libdllm.so .
|
||||
# go build -o dllm-grpc .
|
||||
#
|
||||
# That's what the gated C-ABI binding smoke uses (DLLM_TEST_LIBRARY). The
|
||||
# default target below does the proper clone-at-pin + cmake build so CI
|
||||
# doesn't need a side-checkout.
|
||||
#
|
||||
# NOTE: github.com/mudler/dllm.cpp is still private (publishing is planned);
|
||||
# until then the anonymous clone below fails. Use the symlink shortcut above
|
||||
# with a local checkout, or a git credential helper with access to the repo.
|
||||
|
||||
# The pin below is the first commit carrying the multimodal C-ABI entry
|
||||
# points (dllm_capi_generate_mm / dllm_capi_generate_stream_mm) the
|
||||
# image-input path probes for; older libs still load, but image requests
|
||||
# then fail with "library predates the multimodal entry points".
|
||||
DLLM_VERSION?=e6dcf44cddd65845e3a0814a1c2282a5d90ee98a
|
||||
DLLM_REPO?=https://github.com/mudler/dllm.cpp
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
||||
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
# libdllm.so is self-contained: dllm.cpp's CMakeLists statically absorbs ggml
|
||||
# (BUILD_SHARED_LIBS=OFF + PIC) into the shared lib, so dlopen needs no
|
||||
# libggml*.so alongside it, only system libs (libstdc++/libgomp/libc) the
|
||||
# runtime image already provides. Tests/CLI are upstream-only concerns.
|
||||
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DDLLM_BUILD_TESTS=OFF
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# Same arch set the sibling ggml backends (acestep/vibevoice/qwen3-tts) bake
|
||||
# for their cublas images; override for a native build.
|
||||
CUDA_ARCHITECTURES?=75-virtual;80-virtual;86-real;89-real
|
||||
|
||||
# dllm.cpp gates CUDA behind DLLM_CUDA (set(GGML_CUDA ... CACHE FORCE)), so
|
||||
# forward that instead of a bare -DGGML_CUDA=ON.
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DDLLM_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="$(CUDA_ARCHITECTURES)"
|
||||
endif
|
||||
|
||||
.PHONY: dllm-grpc package build clean purge test all
|
||||
|
||||
all: dllm-grpc
|
||||
|
||||
# Clone the upstream dllm.cpp source at the pinned commit (ggml comes in as
|
||||
# a submodule). Directory acts as the target so make only re-clones when
|
||||
# missing. After a DLLM_VERSION bump, run 'make purge && make' to refetch.
|
||||
sources/dllm.cpp:
|
||||
mkdir -p sources/dllm.cpp
|
||||
cd sources/dllm.cpp && \
|
||||
git init -q && \
|
||||
git remote add origin $(DLLM_REPO) && \
|
||||
git fetch --depth 1 origin $(DLLM_VERSION) && \
|
||||
git checkout FETCH_HEAD && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Build the shared lib out-of-tree, then stage it next to the Go sources so
|
||||
# purego.Dlopen("libdllm.so") and the packaging step both pick it up.
|
||||
libdllm.so: sources/dllm.cpp
|
||||
cmake -B sources/dllm.cpp/build -S sources/dllm.cpp $(CMAKE_ARGS)
|
||||
cmake --build sources/dllm.cpp/build --config Release -j$(JOBS)
|
||||
cp -fv sources/dllm.cpp/build/libdllm.so ./
|
||||
|
||||
dllm-grpc: libdllm.so main.go capi.go
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o dllm-grpc .
|
||||
|
||||
package: dllm-grpc
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
# Test target. The C-ABI binding smoke is gated on DLLM_TEST_LIBRARY +
|
||||
# DLLM_TEST_TINY_MODEL; without them the gated specs auto-skip and only the
|
||||
# pure-Go helper specs run.
|
||||
test:
|
||||
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
|
||||
|
||||
clean: purge
|
||||
rm -rf libdllm.so* package dllm-grpc
|
||||
|
||||
purge:
|
||||
rm -rf sources/dllm.cpp
|
||||
326
backend/go/dllm/capi.go
Normal file
326
backend/go/dllm/capi.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package main
|
||||
|
||||
// Typed Go wrappers over dllm.cpp's flat C-ABI (include/dllm_capi.h, ABI v1).
|
||||
//
|
||||
// Contract highlights the wrappers encode (see the header + src/capi.cpp):
|
||||
// - tokenize_json/generate return malloc'd char* the CALLER owns: bound as
|
||||
// uintptr, copied with goStringFromCPtr, released via dllm_capi_free_string.
|
||||
// - last_error returns a BORROWED pointer (valid until the next call on the
|
||||
// same ctx): bound as a plain string (purego copies), never freed, and only
|
||||
// read AFTER the failing call has returned - reading it while a generate is
|
||||
// in flight on the same ctx violates the per-ctx serialization contract.
|
||||
// - All entry points except dllm_capi_cancel must be externally serialized
|
||||
// per ctx (one ctx = one concurrent generate/tokenize). Cancel only flips
|
||||
// an atomic and may be called from any goroutine mid-generate.
|
||||
// - No C++ exception crosses the boundary; failures land in last_error.
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
)
|
||||
|
||||
// dllmABIVersion is the DLLM_CAPI_ABI_VERSION this binding was written
|
||||
// against; main.go refuses to start against a libdllm.so reporting another.
|
||||
const dllmABIVersion = 1
|
||||
|
||||
// purego-bound entry points from libdllm.so. Names match dllm_capi.h
|
||||
// exactly; loadCAPI (main.go) fills these in at boot.
|
||||
var (
|
||||
cppAbiVersion func() int32
|
||||
cppLoad func(ggufPath, paramsJSON string) uintptr
|
||||
cppFree func(ctx uintptr)
|
||||
cppLastError func(ctx uintptr) string // borrowed pointer: purego copies, do NOT free
|
||||
cppFreeString func(s uintptr)
|
||||
// malloc'd char* returns, hence uintptr (see loadCAPI's doc comment).
|
||||
cppTokenizeJSON func(ctx uintptr, text string) uintptr
|
||||
cppGenerate func(ctx uintptr, prompt, optsJSON string) uintptr
|
||||
// on_block/on_step are C function pointers produced by purego.NewCallback;
|
||||
// userData carries the streamCallStates registry key.
|
||||
cppGenerateStream func(ctx uintptr, prompt, optsJSON string, onBlock, onStep, userData uintptr) int32
|
||||
cppCancel func(ctx uintptr)
|
||||
)
|
||||
|
||||
// Optional multimodal entry points (dllm_capi.h's P4 surface). The ABI
|
||||
// version stays 1: presence is detected by PROBING the symbols with Dlsym at
|
||||
// boot (loadCAPI, mirroring the parakeet-cpp optional-symbol pattern). nil
|
||||
// means the loaded libdllm.so predates the mm surface; the wrappers below
|
||||
// then fail with errMMUnsupported instead of crashing on a nil call.
|
||||
var (
|
||||
cppGenerateMM func(ctx uintptr, prompt, imagesJSON, optsJSON string) uintptr
|
||||
cppGenerateStreamMM func(ctx uintptr, prompt, imagesJSON, optsJSON string, onBlock, onStep, userData uintptr) int32
|
||||
)
|
||||
|
||||
// mmImageMarker is the literal placeholder dllm_capi_generate_mm expands to
|
||||
// <boi> + soft-token placeholders + <eoi> (dllm_capi.h placeholder contract;
|
||||
// capi.cpp MM_MARKER). The prompt must carry exactly one marker per
|
||||
// images_json entry, in image order.
|
||||
const mmImageMarker = "<image>"
|
||||
|
||||
// errMMUnsupported is returned for image-bearing requests against an old
|
||||
// text-only libdllm.so (the Dlsym probe found no mm symbols).
|
||||
var errMMUnsupported = errors.New(
|
||||
"dllm: image input requires libdllm.so with the multimodal entry points (dllm_capi_generate_mm), but the loaded library predates them - rebuild/upgrade the dllm backend to use images")
|
||||
|
||||
// cMMSupported reports whether the loaded libdllm.so carries the multimodal
|
||||
// generate pair. Both symbols ship together (same dllm.cpp commit), but the
|
||||
// guard requires both anyway so a half-present surface can never dispatch.
|
||||
func cMMSupported() bool {
|
||||
return cppGenerateMM != nil && cppGenerateStreamMM != nil
|
||||
}
|
||||
|
||||
// cAbiVersion returns the library's DLLM_CAPI_ABI_VERSION.
|
||||
func cAbiVersion() int32 {
|
||||
return cppAbiVersion()
|
||||
}
|
||||
|
||||
// cLoad opens the GGUF at path with the flat params JSON (e.g.
|
||||
// {"n_gpu_layers":99}). Returns 0 on failure; per the header contract there
|
||||
// is no ctx to carry the reason, the C side logs it to stderr (and
|
||||
// cLastError(0) only yields the static NULL-ctx message).
|
||||
func cLoad(path, paramsJSON string) uintptr {
|
||||
return cppLoad(path, paramsJSON)
|
||||
}
|
||||
|
||||
// cFree releases a ctx; safe on 0 (delete nullptr).
|
||||
func cFree(h uintptr) {
|
||||
cppFree(h)
|
||||
}
|
||||
|
||||
// cLastError returns the ctx's last error message (or the static NULL-ctx
|
||||
// message for h==0). The C pointer is borrowed and only valid until the next
|
||||
// call on the same ctx; purego's string return copies it immediately, so the
|
||||
// returned Go string is safe to keep. Must not be called while another call
|
||||
// on the same ctx is in flight.
|
||||
func cLastError(h uintptr) string {
|
||||
return cppLastError(h)
|
||||
}
|
||||
|
||||
// lastErrorOr is cLastError with a fallback for the empty-message case, so
|
||||
// wrapped errors never end in ": ".
|
||||
func lastErrorOr(h uintptr, fallback string) string {
|
||||
if msg := cLastError(h); msg != "" {
|
||||
return msg
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// cTokenizeJSON tokenizes text (the C side prepends bos per vocab.add_bos)
|
||||
// and returns the token ids as a JSON array string, e.g. "[2,18]".
|
||||
func cTokenizeJSON(h uintptr, text string) (string, error) {
|
||||
ret := cppTokenizeJSON(h, text)
|
||||
if ret == 0 {
|
||||
return "", fmt.Errorf("dllm: tokenize failed: %s", lastErrorOr(h, "unknown error"))
|
||||
}
|
||||
out := goStringFromCPtr(ret)
|
||||
cppFreeString(ret)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// cGenerate runs a blocking generation and returns the detokenized text.
|
||||
// optsJSON must be a FLAT JSON object of scalars (use buildOptsJSON); the C
|
||||
// parser rejects nested objects/arrays. NULL return -> last_error (read only
|
||||
// after the call returned, per the serialization contract); a cancelled call
|
||||
// surfaces as the "cancelled" message.
|
||||
func cGenerate(h uintptr, prompt, optsJSON string) (string, error) {
|
||||
ret := cppGenerate(h, prompt, optsJSON)
|
||||
if ret == 0 {
|
||||
return "", fmt.Errorf("dllm: generate failed: %s", lastErrorOr(h, "unknown error"))
|
||||
}
|
||||
out := goStringFromCPtr(ret)
|
||||
cppFreeString(ret)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// cGenerateMM is cGenerate's multimodal counterpart. imagesJSON is the flat
|
||||
// JSON array of image entries (data: base64 URIs here; the C side also takes
|
||||
// file paths) and the prompt must carry one mmImageMarker per entry - the
|
||||
// engine enforces the 1:1 match and reports mismatches through last_error.
|
||||
func cGenerateMM(h uintptr, prompt, imagesJSON, optsJSON string) (string, error) {
|
||||
if !cMMSupported() {
|
||||
return "", errMMUnsupported
|
||||
}
|
||||
ret := cppGenerateMM(h, prompt, imagesJSON, optsJSON)
|
||||
if ret == 0 {
|
||||
return "", fmt.Errorf("dllm: generate_mm failed: %s", lastErrorOr(h, "unknown error"))
|
||||
}
|
||||
out := goStringFromCPtr(ret)
|
||||
cppFreeString(ret)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// streamCallState carries the Go callbacks for one in-flight
|
||||
// cGenerateStream call; the registry key travels through C as user_data.
|
||||
// The map shape mirrors the whisper backend's streamCallStates: only one
|
||||
// entry per ctx is ever live (the C-ABI is serialized per ctx), but keying
|
||||
// by call survives multiple models/processes sharing the package.
|
||||
type streamCallState struct {
|
||||
onBlock func(text string)
|
||||
onStep func(step, total int, preview string)
|
||||
}
|
||||
|
||||
var (
|
||||
streamCallStates sync.Map // uint64 -> *streamCallState
|
||||
streamCallSeq atomic.Uint64
|
||||
|
||||
// purego.NewCallback allocates a finite, never-released callback slot, so
|
||||
// the two trampolines are created exactly once and reused across calls.
|
||||
streamCbOnce sync.Once
|
||||
blockCbPtr uintptr
|
||||
stepCbPtr uintptr
|
||||
)
|
||||
|
||||
// onBlockTrampoline is the Go side of dllm_block_cb. It runs on the C
|
||||
// calling thread, mid-generate: keep it tiny and non-blocking (callers that
|
||||
// bridge to goroutines must hand off via buffered channels). The text
|
||||
// pointer is only valid for the duration of the invocation, so it is copied
|
||||
// to a Go string immediately.
|
||||
func onBlockTrampoline(text uintptr, userData uintptr) {
|
||||
v, ok := streamCallStates.Load(uint64(userData))
|
||||
if !ok {
|
||||
return // call already torn down
|
||||
}
|
||||
state := v.(*streamCallState)
|
||||
if state.onBlock != nil {
|
||||
state.onBlock(goStringFromCPtr(text))
|
||||
}
|
||||
}
|
||||
|
||||
// onStepTrampoline is the Go side of dllm_step_cb; same threading and
|
||||
// lifetime caveats as onBlockTrampoline.
|
||||
func onStepTrampoline(step int32, totalSteps int32, canvasPreview uintptr, userData uintptr) {
|
||||
v, ok := streamCallStates.Load(uint64(userData))
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
state := v.(*streamCallState)
|
||||
if state.onStep != nil {
|
||||
state.onStep(int(step), int(totalSteps), goStringFromCPtr(canvasPreview))
|
||||
}
|
||||
}
|
||||
|
||||
// withStreamCallbacks registers onBlock/onStep in the trampoline registry
|
||||
// for the duration of one streaming C call and invokes call with the C
|
||||
// function pointers (NULL for absent callbacks, so the C side skips the
|
||||
// per-block / per-step detokenize work entirely) plus the registry key to
|
||||
// pass as user_data. Shared by the text and multimodal stream wrappers.
|
||||
func withStreamCallbacks(onBlock func(text string), onStep func(step, total int, preview string), call func(blockPtr, stepPtr, userData uintptr) int32) int32 {
|
||||
streamCbOnce.Do(func() {
|
||||
blockCbPtr = purego.NewCallback(onBlockTrampoline)
|
||||
stepCbPtr = purego.NewCallback(onStepTrampoline)
|
||||
})
|
||||
|
||||
id := streamCallSeq.Add(1)
|
||||
streamCallStates.Store(id, &streamCallState{onBlock: onBlock, onStep: onStep})
|
||||
defer streamCallStates.Delete(id)
|
||||
|
||||
var blockPtr, stepPtr uintptr
|
||||
if onBlock != nil {
|
||||
blockPtr = blockCbPtr
|
||||
}
|
||||
if onStep != nil {
|
||||
stepPtr = stepCbPtr
|
||||
}
|
||||
return call(blockPtr, stepPtr, uintptr(id))
|
||||
}
|
||||
|
||||
// cGenerateStream runs a generation with per-committed-block (onBlock) and
|
||||
// per-denoising-step (onStep) callbacks; either may be nil. The callbacks
|
||||
// run on the C thread (see the trampoline docs). Returns an error carrying
|
||||
// last_error on failure; cancellation surfaces as the "cancelled" message.
|
||||
func cGenerateStream(h uintptr, prompt, optsJSON string, onBlock func(text string), onStep func(step, total int, preview string)) error {
|
||||
rc := withStreamCallbacks(onBlock, onStep, func(blockPtr, stepPtr, userData uintptr) int32 {
|
||||
return cppGenerateStream(h, prompt, optsJSON, blockPtr, stepPtr, userData)
|
||||
})
|
||||
if rc != 0 {
|
||||
return fmt.Errorf("dllm: generate_stream failed: %s", lastErrorOr(h, "unknown error"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cGenerateStreamMM is cGenerateStream's multimodal counterpart; see
|
||||
// cGenerateMM for the imagesJSON/marker contract.
|
||||
func cGenerateStreamMM(h uintptr, prompt, imagesJSON, optsJSON string, onBlock func(text string), onStep func(step, total int, preview string)) error {
|
||||
if !cMMSupported() {
|
||||
return errMMUnsupported
|
||||
}
|
||||
rc := withStreamCallbacks(onBlock, onStep, func(blockPtr, stepPtr, userData uintptr) int32 {
|
||||
return cppGenerateStreamMM(h, prompt, imagesJSON, optsJSON, blockPtr, stepPtr, userData)
|
||||
})
|
||||
if rc != 0 {
|
||||
return fmt.Errorf("dllm: generate_stream_mm failed: %s", lastErrorOr(h, "unknown error"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cCancel requests cancellation of the in-flight generate on h. This is the
|
||||
// ONE entry point safe to call from any goroutine while a generate runs (it
|
||||
// only flips an atomic). Note the cancel-reset race from the header: each
|
||||
// generate resets the flag on entry, so a watchdog should re-issue cancel if
|
||||
// the call has not returned.
|
||||
func cCancel(h uintptr) {
|
||||
cppCancel(h)
|
||||
}
|
||||
|
||||
// buildOptsJSON renders generation options as the flat JSON object the
|
||||
// C-ABI expects (known keys: n_predict, blocks, seed, eb_*, kv_cache). The
|
||||
// C-side scanner only understands scalar number/string values and rejects
|
||||
// nested objects/arrays loudly; bools are rejected here too because the
|
||||
// scanner has no concept of them. Fail loud rather than let an option be
|
||||
// silently misread.
|
||||
//
|
||||
// CAVEAT: json.Marshal HTML-escapes <, > and & inside string values (e.g.
|
||||
// "<" becomes the six-byte \u003c sequence). None of the known string-valued keys
|
||||
// (kv_cache: auto|on|off) can contain those bytes today; if one ever does,
|
||||
// switch to an Encoder with SetEscapeHTML(false) like gemma4JSONString.
|
||||
func buildOptsJSON(opts map[string]any) (string, error) {
|
||||
if len(opts) == 0 {
|
||||
return "{}", nil
|
||||
}
|
||||
for k, v := range opts {
|
||||
switch v.(type) {
|
||||
case string,
|
||||
int, int8, int16, int32, int64,
|
||||
uint, uint8, uint16, uint32, uint64,
|
||||
float32, float64,
|
||||
json.Number:
|
||||
// scalar: fine
|
||||
default:
|
||||
return "", fmt.Errorf("dllm: opts key %q has non-scalar value %T (the C-ABI only accepts flat number/string scalars)", k, v)
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(opts)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("dllm: marshal opts: %w", err)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// goStringFromCPtr copies a NUL-terminated C string into Go memory. cptr is
|
||||
// the raw pointer returned by purego from the C-ABI (a malloc'd buffer the
|
||||
// caller owns, or a callback argument only valid during the invocation);
|
||||
// owning callers must free it via cppFreeString after the copy lands.
|
||||
//
|
||||
// A direct unsafe.Pointer(cptr) conversion trips go vet's unsafeptr check,
|
||||
// which can't distinguish a C-owned heap pointer from Go-managed memory (the
|
||||
// parakeet-cpp and whisper backends tolerate that warning). Reinterpreting
|
||||
// through &cptr below is equivalent at runtime and keeps plain `go vet`
|
||||
// clean. It is safe either way: the pointer addresses C memory the Go GC
|
||||
// neither tracks nor moves, and we dereference it immediately to copy the
|
||||
// bytes out.
|
||||
func goStringFromCPtr(cptr uintptr) string {
|
||||
if cptr == 0 {
|
||||
return ""
|
||||
}
|
||||
p := *(*unsafe.Pointer)(unsafe.Pointer(&cptr)) // C-owned buffer, not Go-GC memory (see doc above)
|
||||
n := 0
|
||||
for *(*byte)(unsafe.Add(p, n)) != 0 {
|
||||
n++
|
||||
}
|
||||
return string(unsafe.Slice((*byte)(p), n))
|
||||
}
|
||||
622
backend/go/dllm/dllm.go
Normal file
622
backend/go/dllm/dllm.go
Normal file
@@ -0,0 +1,622 @@
|
||||
package main
|
||||
|
||||
// LocalAI gRPC backend for dllm.cpp (DiffusionGemma block-diffusion models).
|
||||
//
|
||||
// Wiring overview:
|
||||
// - Load opens the GGUF via dllm_capi_load and starts the per-model worker
|
||||
// goroutine that serializes every C call (see submit).
|
||||
// - PredictRich / PredictStreamRich implement grpc.AIModelRich: when the
|
||||
// request carries raw messages (use_tokenizer_template), the backend owns
|
||||
// templating (RenderGemma4) and output parsing (Gemma4Parser) and replies
|
||||
// with ChatDeltas, like the llama.cpp autoparser and the ds4 backend.
|
||||
// - The legacy Predict / PredictStream methods delegate to the rich pair
|
||||
// (cloud-proxy precedent); the gRPC server prefers the rich path anyway.
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// The gRPC server cancels in-flight generations on client disconnect only
|
||||
// for backends advertising the Cancellable capability; keep Dllm pinned to
|
||||
// it so a signature drift fails the build, not the disconnect path.
|
||||
var _ grpc.Cancellable = (*Dllm)(nil)
|
||||
|
||||
// generator is the seam between the backend wiring and the dllm.cpp C-ABI:
|
||||
// the real implementation (capiGenerator) wraps the cGenerate/cTokenizeJSON
|
||||
// family, while tests substitute a fake to exercise prompt construction,
|
||||
// parsing and serialization without libdllm.so.
|
||||
type generator interface {
|
||||
generate(prompt, optsJSON string) (string, error)
|
||||
// generateStream invokes onBlock once per committed diffusion block, on
|
||||
// the thread running the C call, before returning.
|
||||
generateStream(prompt, optsJSON string, onBlock func(text string)) error
|
||||
// generateMM / generateStreamMM are the multimodal counterparts:
|
||||
// imagesJSON is a flat JSON array of data: base64 URIs and the prompt
|
||||
// carries one mmImageMarker per entry (dllm_capi.h placeholder
|
||||
// contract). Against an old text-only libdllm.so they fail with
|
||||
// errMMUnsupported.
|
||||
generateMM(prompt, imagesJSON, optsJSON string) (string, error)
|
||||
generateStreamMM(prompt, imagesJSON, optsJSON string, onBlock func(text string)) error
|
||||
tokenizeJSON(text string) (string, error)
|
||||
// cancel is the ONE entry point safe to call concurrently with an
|
||||
// in-flight generate on the same ctx (dllm_capi.h: it only flips an
|
||||
// atomic; everything else must be externally serialized per ctx).
|
||||
cancel()
|
||||
free()
|
||||
}
|
||||
|
||||
// capiGenerator is the production generator over one dllm_ctx handle.
|
||||
type capiGenerator struct {
|
||||
h uintptr
|
||||
}
|
||||
|
||||
func (g *capiGenerator) generate(prompt, optsJSON string) (string, error) {
|
||||
return cGenerate(g.h, prompt, optsJSON)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
|
||||
// on_step (per-denoise-step canvas preview, dllm.cpp's --visual) is
|
||||
// passed as nil for now: a future progress hook for the React UI can
|
||||
// plumb it through without touching the C binding.
|
||||
return cGenerateStream(g.h, prompt, optsJSON, onBlock, nil)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) generateMM(prompt, imagesJSON, optsJSON string) (string, error) {
|
||||
return cGenerateMM(g.h, prompt, imagesJSON, optsJSON)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) generateStreamMM(prompt, imagesJSON, optsJSON string, onBlock func(text string)) error {
|
||||
// on_step is nil for the same reason as generateStream.
|
||||
return cGenerateStreamMM(g.h, prompt, imagesJSON, optsJSON, onBlock, nil)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) tokenizeJSON(text string) (string, error) {
|
||||
return cTokenizeJSON(g.h, text)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) cancel() {
|
||||
cCancel(g.h)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) free() {
|
||||
cFree(g.h)
|
||||
}
|
||||
|
||||
// Dllm is the gRPC backend instance: one per loaded model (LocalAI starts
|
||||
// one backend process per model).
|
||||
type Dllm struct {
|
||||
base.Base
|
||||
|
||||
gen generator
|
||||
// genOpts holds the model-level generation overrides parsed from
|
||||
// ModelOptions.Options at Load (eb_*, blocks, kv_cache). The C-ABI takes
|
||||
// them per-generate, not per-load, so they are merged into every
|
||||
// request's opts JSON (requestOptsJSON).
|
||||
genOpts map[string]any
|
||||
|
||||
// jobs is the per-model worker queue. dllm_capi.h requires every entry
|
||||
// point EXCEPT dllm_capi_cancel to be externally serialized per ctx (one
|
||||
// ctx = one concurrent generate/tokenize; last_error is unsafe to read
|
||||
// while a call is in flight). A single goroutine owning all C calls makes
|
||||
// that contract structural instead of relying on lock discipline.
|
||||
jobs chan func()
|
||||
workerWG sync.WaitGroup
|
||||
|
||||
// genMu guards gen against Free racing in-flight requests: requests hold
|
||||
// the read lock for their full duration (they stay concurrent with each
|
||||
// other - the worker still serializes the C calls), Free takes the write
|
||||
// lock so it can only run when no request is in flight.
|
||||
genMu sync.RWMutex
|
||||
}
|
||||
|
||||
func (d *Dllm) startWorker() {
|
||||
d.jobs = make(chan func())
|
||||
d.workerWG.Add(1)
|
||||
go func() {
|
||||
defer d.workerWG.Done()
|
||||
for job := range d.jobs {
|
||||
job()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// submit runs job on the worker goroutine and waits for it to finish.
|
||||
// Concurrent gRPC requests therefore queue up and execute one at a time
|
||||
// against the single dllm_ctx.
|
||||
func (d *Dllm) submit(job func()) {
|
||||
done := make(chan struct{})
|
||||
d.jobs <- func() {
|
||||
defer close(done)
|
||||
job()
|
||||
}
|
||||
<-done
|
||||
}
|
||||
|
||||
// Load opens the GGUF and prepares the worker. Load-time engine parameters
|
||||
// travel as the flat params JSON of dllm_capi_load; generation overrides
|
||||
// from Options are stored for per-request opts JSON instead (the C-ABI has
|
||||
// no per-load sampler state).
|
||||
func (d *Dllm) Load(opts *pb.ModelOptions) error {
|
||||
if d.gen != nil {
|
||||
return errors.New("dllm: model already loaded")
|
||||
}
|
||||
|
||||
params := map[string]any{
|
||||
"n_gpu_layers": opts.GetNGPULayers(),
|
||||
}
|
||||
if opts.GetThreads() > 0 {
|
||||
params["n_threads"] = opts.GetThreads()
|
||||
}
|
||||
if opts.GetContextSize() > 0 {
|
||||
params["ctx_len"] = opts.GetContextSize()
|
||||
}
|
||||
paramsJSON, err := buildOptsJSON(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.genOpts = parseModelGenOpts(opts.GetOptions())
|
||||
|
||||
h := cLoad(opts.GetModelFile(), paramsJSON)
|
||||
if h == 0 {
|
||||
// No ctx exists on load failure, so last_error(NULL) only carries the
|
||||
// static NULL-ctx message; the real reason is on the backend's stderr.
|
||||
return fmt.Errorf("dllm: load %q failed: %s (see backend log for details)",
|
||||
opts.GetModelFile(), lastErrorOr(0, "unknown error"))
|
||||
}
|
||||
d.gen = &capiGenerator{h: h}
|
||||
d.startWorker()
|
||||
xlog.Info("dllm: model loaded", "model", opts.GetModelFile(), "params", paramsJSON, "gen_opts", d.genOpts)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Free releases the dllm ctx and stops the worker. Safe when never loaded.
|
||||
//
|
||||
// The write lock is essential: the gRPC server (pkg/grpc/server.go, see the
|
||||
// model-unload path around line 764) calls Free with no locking of its own,
|
||||
// and base.Base provides none either. Without it a request racing Free would
|
||||
// panic sending on the closed jobs channel - or worse, generate on a freed C
|
||||
// ctx. Holding genMu until gen is nil also turns post-Free requests into a
|
||||
// clean "model not loaded" error instead of a crash.
|
||||
func (d *Dllm) Free() error {
|
||||
d.genMu.Lock()
|
||||
defer d.genMu.Unlock()
|
||||
if d.gen == nil {
|
||||
return nil
|
||||
}
|
||||
d.submit(d.gen.free)
|
||||
close(d.jobs)
|
||||
d.workerWG.Wait()
|
||||
d.gen = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel requests cancellation of the in-flight generate (the
|
||||
// grpc.Cancellable capability). The gRPC server arms it via
|
||||
// context.AfterFunc on the request/stream context, so a client
|
||||
// disconnect or timeout aborts the generation server-side - the same
|
||||
// semantics the llama.cpp C++ backend gets from polling IsCancelled().
|
||||
// It deliberately bypasses the worker queue: dllm_capi_cancel is the one
|
||||
// call the C-ABI allows from any goroutine mid-generate (it only flips
|
||||
// an atomic).
|
||||
//
|
||||
// Note dllm_capi.h's cancel-reset race: each generate resets the flag on
|
||||
// entry, so a Cancel racing a NEW generate on the same ctx can be lost
|
||||
// (and, with requests queued on the worker, it aborts whichever generate
|
||||
// is currently running). The single-flag granularity is acceptable here
|
||||
// because the server de-registers the hook on normal completion and one
|
||||
// backend process serves one model.
|
||||
func (d *Dllm) Cancel() {
|
||||
// RLock so a server-side AfterFunc firing in the window between a
|
||||
// request finishing and a model unload cannot touch a freed C ctx
|
||||
// (Free holds the write lock while tearing gen down). cancel() is the
|
||||
// one C call that is safe concurrently with an in-flight generate, so
|
||||
// taking a read lock here cannot deadlock against request holders.
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen != nil {
|
||||
d.gen.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// dllmGenOptKeys are the ModelOptions.Options keys this backend forwards to
|
||||
// the engine. Options is a shared free-form bag (other layers put their own
|
||||
// entries there), so unknown keys are skipped with a warning, not an error.
|
||||
var dllmGenOptKeys = map[string]bool{
|
||||
"blocks": true,
|
||||
"kv_cache": true, // "auto"|"on"|"off"; honored by the engine from P3
|
||||
}
|
||||
|
||||
// parseModelGenOpts parses "key:value" Options entries into the flat scalar
|
||||
// map merged into every generate's opts JSON. eb_* (Entropy-Bound sampler
|
||||
// knobs) and the keys in dllmGenOptKeys are recognized; values are typed by
|
||||
// first successful parse (int, then float, else string) to match the C
|
||||
// scanner's number/string scalars.
|
||||
func parseModelGenOpts(options []string) map[string]any {
|
||||
out := map[string]any{}
|
||||
for _, o := range options {
|
||||
key, val, found := strings.Cut(o, ":")
|
||||
if !found {
|
||||
xlog.Warn("dllm: ignoring malformed option (want key:value)", "option", o)
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(key, "eb_") && !dllmGenOptKeys[key] {
|
||||
xlog.Debug("dllm: ignoring unrecognized option", "key", key)
|
||||
continue
|
||||
}
|
||||
out[key] = parseScalarOpt(val)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseScalarOpt(v string) any {
|
||||
if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return iv
|
||||
}
|
||||
if fv, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return fv
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// metadataEnableThinking reads the enable_thinking gate. Unlike ds4 (default
|
||||
// ON, matching ds4-server), dllm defaults OFF: DiffusionGemma's chat
|
||||
// template guards every thinking branch with `enable_thinking is defined and
|
||||
// enable_thinking`, i.e. thinking is opt-in for this model family, and the
|
||||
// no-thinking render pre-closes an empty thought channel that the OFF
|
||||
// default must produce.
|
||||
func metadataEnableThinking(opts *pb.PredictOptions) bool {
|
||||
v := opts.GetMetadata()["enable_thinking"]
|
||||
return v == "true" || v == "1"
|
||||
}
|
||||
|
||||
// buildPrompt resolves the prompt for a request. With use_tokenizer_template
|
||||
// and raw messages the backend owns templating (RenderGemma4, including the
|
||||
// mmImageMarker injection for opts.Images) and the output is in the known
|
||||
// gemma4 format, so parse=true. Without it the caller templated the prompt
|
||||
// themselves (LocalAI's Go templates + PEG fallback, or a bare completion):
|
||||
// the prompt passes through verbatim - for image requests it must already
|
||||
// carry one literal mmImageMarker per image (the engine enforces the 1:1
|
||||
// match) - and the output is NOT gemma4-parsed - it is emitted as plain
|
||||
// content and the Go side's extraction applies, as for any non-autoparsing
|
||||
// backend.
|
||||
func buildPrompt(opts *pb.PredictOptions) (prompt string, parse bool, err error) {
|
||||
if opts.GetUseTokenizerTemplate() && len(opts.GetMessages()) > 0 {
|
||||
prompt, err = RenderGemma4(opts.GetMessages(), opts.GetTools(), len(opts.GetImages()), metadataEnableThinking(opts), true)
|
||||
return prompt, true, err
|
||||
}
|
||||
return opts.GetPrompt(), false, nil
|
||||
}
|
||||
|
||||
// imagesJSON renders opts.Images as the flat JSON array of data: URIs the mm
|
||||
// C-ABI expects, or "" when the request carries no images. The entries arrive
|
||||
// as RAW base64 payloads: LocalAI's OpenAI layer decodes every image_url /
|
||||
// image content part (URL download or data: URI) to plain base64 via
|
||||
// utils.GetContentURIAsBase64 (core/http/middleware/request.go) and core
|
||||
// flattens them into PredictOptions.Images (core/backend/llm.go). The
|
||||
// hardcoded image/jpeg mime mirrors the llama.cpp backend's re-wrapping
|
||||
// convention (grpc-server.cpp, "data:image/jpeg;base64," + images(i)); the
|
||||
// engine ignores the declared mime and sniffs the real format from the
|
||||
// decoded bytes (stb_image), so PNG/BMP payloads work through it too.
|
||||
func imagesJSON(images []string) (string, error) {
|
||||
if len(images) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
uris := make([]string, len(images))
|
||||
for i, img := range images {
|
||||
// dllm_capi.h: array entries are read VERBATIM up to the closing
|
||||
// quote, with NO escape handling. json.Marshal would escape these
|
||||
// bytes and the C side would misparse the entry, so fail loud (they
|
||||
// can never appear in genuine base64 anyway).
|
||||
if strings.ContainsAny(img, "\"\\") {
|
||||
return "", fmt.Errorf("dllm: image %d is not base64 (contains a quote or backslash; PredictOptions.Images entries must be raw base64 payloads)", i)
|
||||
}
|
||||
uris[i] = "data:image/jpeg;base64," + img
|
||||
}
|
||||
b, err := json.Marshal(uris)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("dllm: marshal images: %w", err)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// requestOptsJSON merges the model-level overrides with the request's
|
||||
// sampling fields into the flat opts JSON for one generate call.
|
||||
func (d *Dllm) requestOptsJSON(opts *pb.PredictOptions) (string, error) {
|
||||
m := make(map[string]any, len(d.genOpts)+2)
|
||||
for k, v := range d.genOpts {
|
||||
m[k] = v
|
||||
}
|
||||
if n := opts.GetTokens(); n > 0 {
|
||||
// The engine rounds n_predict UP to a whole number of diffusion
|
||||
// blocks (the canvas is denoised block-wise), so the completion may
|
||||
// run slightly past the requested budget. Tokens==0 omits the key so
|
||||
// the C-ABI default of 256 applies (hardcoded in capi.cpp's
|
||||
// parse_gen_opts, independent of canvas_length).
|
||||
m["n_predict"] = n
|
||||
}
|
||||
if s := opts.GetSeed(); s > 0 {
|
||||
// The engine seeds mt19937 with explicit non-negative seeds. Seed<=0
|
||||
// is omitted: proto3 cannot distinguish 0 from unset, and negative
|
||||
// values conventionally mean "random" across LocalAI backends.
|
||||
m["seed"] = s
|
||||
}
|
||||
return buildOptsJSON(m)
|
||||
}
|
||||
|
||||
// prepareRequest is the shared prologue of the rich methods: resolve the
|
||||
// prompt (and whether the output gets gemma4-parsed) and build the per-call
|
||||
// opts JSON plus the images JSON ("" for text-only requests, which routes
|
||||
// the call through the text generate entry points).
|
||||
func (d *Dllm) prepareRequest(opts *pb.PredictOptions) (prompt string, parse bool, optsJSON, imgJSON string, err error) {
|
||||
// Fail loud on media the engine has no path for, instead of silently
|
||||
// generating from a prompt that ignores them.
|
||||
if len(opts.GetVideos()) > 0 || len(opts.GetAudios()) > 0 {
|
||||
return "", false, "", "", errors.New("dllm: video/audio input is not supported (images only)")
|
||||
}
|
||||
prompt, parse, err = buildPrompt(opts)
|
||||
if err != nil {
|
||||
return "", false, "", "", err
|
||||
}
|
||||
optsJSON, err = d.requestOptsJSON(opts)
|
||||
if err != nil {
|
||||
return "", false, "", "", err
|
||||
}
|
||||
imgJSON, err = imagesJSON(opts.GetImages())
|
||||
if err != nil {
|
||||
return "", false, "", "", err
|
||||
}
|
||||
return prompt, parse, optsJSON, imgJSON, nil
|
||||
}
|
||||
|
||||
// sanitizeUTF8 makes s safe for a proto3 string field. Block-boundary
|
||||
// detokenization and byte-fallback tokens can produce invalid UTF-8, and
|
||||
// grpc-go refuses to marshal it ("string field contains invalid UTF-8"), so
|
||||
// every string destined for a Reply/ChatDelta must pass through here (or
|
||||
// through splitValidUTF8, which calls it). Lone malformed bytes are genuinely
|
||||
// undecodable: replace with U+FFFD rather than crash the stream.
|
||||
func sanitizeUTF8(s string) string {
|
||||
if utf8.ValidString(s) {
|
||||
return s
|
||||
}
|
||||
return strings.ToValidUTF8(s, "<22>")
|
||||
}
|
||||
|
||||
// utf8SeqLen returns the declared sequence length of a UTF-8 leading byte
|
||||
// (1 for bytes that can never lead a multi-byte sequence, so they are never
|
||||
// held back and fall through to sanitizeUTF8's replacement).
|
||||
func utf8SeqLen(b byte) int {
|
||||
switch {
|
||||
case b&0xE0 == 0xC0:
|
||||
return 2
|
||||
case b&0xF0 == 0xE0:
|
||||
return 3
|
||||
case b&0xF8 == 0xF0:
|
||||
return 4
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
// splitValidUTF8 prepends the previous block's carry to the new block and
|
||||
// splits the result into text safe to emit now and a trailing INCOMPLETE
|
||||
// UTF-8 sequence (at most utf8.UTFMax-1 bytes) to carry into the next block:
|
||||
// the per-block detokenize can split a multi-byte character across block
|
||||
// boundaries (llama.cpp's grpc-server holds back the same way). Only a
|
||||
// suffix that can still become a valid rune is withheld; bytes that are
|
||||
// already undecodable are replaced immediately so the carry stays bounded.
|
||||
func splitValidUTF8(carry, block string) (emit, newCarry string) {
|
||||
s := carry + block
|
||||
cut := len(s)
|
||||
for i := len(s) - 1; i >= 0 && len(s)-i < utf8.UTFMax; i-- {
|
||||
b := s[i]
|
||||
if b < utf8.RuneSelf {
|
||||
break // ASCII: everything before the tail scan is complete
|
||||
}
|
||||
if !utf8.RuneStart(b) {
|
||||
continue // continuation byte: keep looking for its leading byte
|
||||
}
|
||||
// Leading byte: hold the sequence back iff it declares more bytes
|
||||
// than the stream has produced so far (it may complete next block).
|
||||
if utf8SeqLen(b) > len(s)-i {
|
||||
cut = i
|
||||
}
|
||||
break
|
||||
}
|
||||
return sanitizeUTF8(s[:cut]), s[cut:]
|
||||
}
|
||||
|
||||
// PredictRich is the non-streaming inference path (grpc.AIModelRich).
|
||||
// Returns one Reply whose Message is the aggregated assistant content and
|
||||
// whose ChatDeltas carry the parsed content/reasoning/tool-call events.
|
||||
func (d *Dllm) PredictRich(opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen == nil {
|
||||
return nil, grpcerrors.ModelNotLoaded("dllm")
|
||||
}
|
||||
prompt, parse, optsJSON, imgJSON, err := d.prepareRequest(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out string
|
||||
var genErr error
|
||||
d.submit(func() {
|
||||
if imgJSON != "" {
|
||||
out, genErr = d.gen.generateMM(prompt, imgJSON, optsJSON)
|
||||
} else {
|
||||
out, genErr = d.gen.generate(prompt, optsJSON)
|
||||
}
|
||||
})
|
||||
if genErr != nil {
|
||||
return nil, genErr
|
||||
}
|
||||
// Byte-fallback tokens can detokenize to invalid UTF-8; proto3 strings
|
||||
// must be valid or grpc-go fails the whole reply at marshal time.
|
||||
out = sanitizeUTF8(out)
|
||||
|
||||
if !parse {
|
||||
// Raw-prompt mode: plain content, no gemma4 parsing (see buildPrompt).
|
||||
return &pb.Reply{Message: []byte(out), ChatDeltas: []*pb.ChatDelta{{Content: out}}}, nil
|
||||
}
|
||||
|
||||
// The prompt renders with add_generation_prompt; both thinking modes
|
||||
// leave the model starting in content state (see the Gemma4Parser header
|
||||
// comment), hence NewGemma4Parser(false).
|
||||
parser := NewGemma4Parser(false)
|
||||
if reply := replyFromDeltas(append(parser.Feed(out), parser.Close()...)); reply != nil {
|
||||
return reply, nil
|
||||
}
|
||||
// Everything was markers (or out was empty): an empty but non-nil Reply.
|
||||
return &pb.Reply{}, nil
|
||||
}
|
||||
|
||||
// PredictStreamRich is the streaming counterpart (grpc.AIModelRich): one
|
||||
// Reply per committed diffusion block that produced deltas. Per the
|
||||
// interface contract the channel is only sent into here - the gRPC server
|
||||
// closes it after this returns (opposite to legacy PredictStream).
|
||||
func (d *Dllm) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen == nil {
|
||||
return grpcerrors.ModelNotLoaded("dllm")
|
||||
}
|
||||
prompt, parse, optsJSON, imgJSON, err := d.prepareRequest(opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var parser *Gemma4Parser
|
||||
if parse {
|
||||
parser = NewGemma4Parser(false)
|
||||
}
|
||||
// emit runs inside onBlock, i.e. on the thread driving the C generate.
|
||||
// Sending on results can block on a slow consumer, but the server-side
|
||||
// pump (pkg/grpc/server.go PredictStream) drains continuously and drops
|
||||
// undeliverable sends, so this backpressure is brief and bounded - and
|
||||
// pausing the diffusion loop under it is the desired behavior anyway.
|
||||
emit := func(text string) {
|
||||
if !parse {
|
||||
if text != "" {
|
||||
results <- &pb.Reply{Message: []byte(text), ChatDeltas: []*pb.ChatDelta{{Content: text}}}
|
||||
}
|
||||
return
|
||||
}
|
||||
deltas := parser.Feed(text)
|
||||
if reply := replyFromDeltas(deltas); reply != nil {
|
||||
results <- reply
|
||||
}
|
||||
}
|
||||
// onBlock guards emit (and through it the parser) against invalid UTF-8:
|
||||
// a multi-byte character split across block boundaries is held back until
|
||||
// it completes (see splitValidUTF8), so proto3 marshaling never fails.
|
||||
var carry string
|
||||
onBlock := func(block string) {
|
||||
var text string
|
||||
text, carry = splitValidUTF8(carry, block)
|
||||
emit(text)
|
||||
}
|
||||
|
||||
var genErr error
|
||||
d.submit(func() {
|
||||
if imgJSON != "" {
|
||||
genErr = d.gen.generateStreamMM(prompt, imgJSON, optsJSON, onBlock)
|
||||
} else {
|
||||
genErr = d.gen.generateStream(prompt, optsJSON, onBlock)
|
||||
}
|
||||
})
|
||||
if genErr != nil {
|
||||
return genErr
|
||||
}
|
||||
if carry != "" {
|
||||
// The stream ended mid-sequence: the held-back bytes can no longer
|
||||
// complete, so flush them through the U+FFFD last resort.
|
||||
emit(sanitizeUTF8(carry))
|
||||
}
|
||||
if parse {
|
||||
if reply := replyFromDeltas(parser.Close()); reply != nil {
|
||||
results <- reply
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// replyFromDeltas wraps one batch of parsed deltas into a streaming Reply,
|
||||
// or nil when the batch is empty (markers consumed, nothing emitted yet).
|
||||
// Message mirrors the batch's content text so legacy chan-string consumers
|
||||
// see exactly the displayed tokens.
|
||||
func replyFromDeltas(deltas []*pb.ChatDelta) *pb.Reply {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
var content strings.Builder
|
||||
for _, delta := range deltas {
|
||||
content.WriteString(delta.GetContent())
|
||||
}
|
||||
return &pb.Reply{Message: []byte(content.String()), ChatDeltas: deltas}
|
||||
}
|
||||
|
||||
// Predict is the legacy (string, error) signature; the gRPC server prefers
|
||||
// PredictRich, this exists for non-rich callers (cloud-proxy precedent).
|
||||
func (d *Dllm) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
reply, err := d.PredictRich(opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(reply.GetMessage()), nil
|
||||
}
|
||||
|
||||
// PredictStream is the legacy chan-string path: rich replies reduced to
|
||||
// their content text. Note the inverted channel ownership - the LEGACY
|
||||
// contract requires the impl to close the channel.
|
||||
func (d *Dllm) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
defer close(results)
|
||||
richCh := make(chan *pb.Reply)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- d.PredictStreamRich(opts, richCh)
|
||||
close(richCh)
|
||||
}()
|
||||
for reply := range richCh {
|
||||
if msg := reply.GetMessage(); len(msg) > 0 {
|
||||
results <- string(msg)
|
||||
}
|
||||
}
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
// TokenizeString tokenizes opts.Prompt via dllm_capi_tokenize_json (the C
|
||||
// side prepends bos per the vocab) and decodes the returned id array.
|
||||
func (d *Dllm) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen == nil {
|
||||
return pb.TokenizationResponse{}, grpcerrors.ModelNotLoaded("dllm")
|
||||
}
|
||||
var out string
|
||||
var tokErr error
|
||||
d.submit(func() {
|
||||
out, tokErr = d.gen.tokenizeJSON(opts.GetPrompt())
|
||||
})
|
||||
if tokErr != nil {
|
||||
return pb.TokenizationResponse{}, tokErr
|
||||
}
|
||||
var tokens []int32
|
||||
if err := json.Unmarshal([]byte(out), &tokens); err != nil {
|
||||
return pb.TokenizationResponse{}, fmt.Errorf("dllm: decode tokenize result %q: %w", out, err)
|
||||
}
|
||||
return pb.TokenizationResponse{Length: int32(len(tokens)), Tokens: tokens}, nil
|
||||
}
|
||||
1098
backend/go/dllm/dllm_test.go
Normal file
1098
backend/go/dllm/dllm_test.go
Normal file
File diff suppressed because it is too large
Load Diff
562
backend/go/dllm/gemma4_parser.go
Normal file
562
backend/go/dllm/gemma4_parser.go
Normal file
@@ -0,0 +1,562 @@
|
||||
// Gemma4 (DiffusionGemma) streaming output parser: raw model text, fed in
|
||||
// arbitrary fragments (per committed diffusion block; a fragment can split
|
||||
// anywhere, including mid-marker and mid-payload), is turned into
|
||||
// pb.ChatDelta events (content / reasoning_content / tool_calls).
|
||||
//
|
||||
// Normative sources:
|
||||
// - The chat template embedded at the top of gemma4_renderer.go ("tpl L<n>"
|
||||
// citations below refer to its numbered lines). The OUTPUT format mirrors
|
||||
// what the template renders for assistant history: thought channels
|
||||
// (<|channel>thought\n ... <channel|>, tpl L240), tool calls
|
||||
// (<|tool_call>call:name{...}<tool_call|>, tpl L246-L257) and turn ends
|
||||
// (<turn|>, tpl L351).
|
||||
// - vLLM PR #45163: vllm/tool_parsers/gemma4_tool_parser.py (marker
|
||||
// handling, the call:name{...} argument grammar and its decoder, ported
|
||||
// below) and vllm/reasoning/gemma4_reasoning_parser.py (channel markers,
|
||||
// the "thought\n" role label, is_reasoning_end semantics).
|
||||
//
|
||||
// Initial state (derived from the generation prompt, tpl L356-L362, see
|
||||
// RenderGemma4):
|
||||
// - enable_thinking=false: the prompt ends with "<|turn>model\n" +
|
||||
// "<|channel>thought\n<channel|>" - an EMPTY thought channel, pre-opened
|
||||
// AND pre-closed by the template. The model's output therefore starts in
|
||||
// plain content. Use NewGemma4Parser(false).
|
||||
// - enable_thinking=true: the prompt ends at "<|turn>model\n" and the model
|
||||
// opens and closes its own thought channel in the OUTPUT
|
||||
// ("<|channel>thought\n...reasoning...<channel|>final answer", per the
|
||||
// vLLM Gemma4ReasoningParser docstring). The parser still starts in
|
||||
// content state - the channel markers in the output drive the switch.
|
||||
// Use NewGemma4Parser(false) here too.
|
||||
// - NewGemma4Parser(true) is for callers that pre-open the thought channel
|
||||
// in the prompt themselves (appending "<|channel>thought\n" after the
|
||||
// generation prompt to force thinking): the output then begins mid-thought
|
||||
// and everything is reasoning until the first <channel|>.
|
||||
//
|
||||
// State diagram (markers are consumed, never emitted):
|
||||
//
|
||||
// <|channel> \n (channel name dropped: the
|
||||
// [content] --------------> [chan-header] ----> [thought] "thought\n" role
|
||||
// ^ | <channel|> (stray close: swallowed, label, stripped
|
||||
// +-+ strip_thinking semantics, tpl L148-L158) like vLLM does)
|
||||
// ^ <channel|>
|
||||
// +----------------------------------------- [thought]
|
||||
// ^ <tool_call|> | <|tool_call> (implicit
|
||||
// +-------------- [tool-call] <-------------------+ reasoning end, vLLM
|
||||
// | <|tool_call> ^ is_reasoning_end)
|
||||
// +-------------------+
|
||||
// [content]/[thought] --- <turn|> ---> [done] (everything after is dropped)
|
||||
//
|
||||
// Buffering rules:
|
||||
// - content/thought states hold back at most len(longest marker)-1 bytes:
|
||||
// the longest tail that is still a proper prefix of a watched marker.
|
||||
// Content is otherwise emitted immediately (no unbounded buffering).
|
||||
// - the tool-call state buffers the whole payload until <tool_call|>. This
|
||||
// is unbounded in principle but bounded in practice by the model's
|
||||
// diffusion canvas, and is required because the call:name{...} payload
|
||||
// only becomes decodable (and trustworthy) once complete - the same
|
||||
// reason vLLM's parser accumulates before parsing.
|
||||
// - Close() flushes whatever is still held: partial markers come out as
|
||||
// content/reasoning (per the state that held them); an unterminated
|
||||
// channel header or tool-call payload is re-emitted RAW (including its
|
||||
// opening marker) as content - malformed output is never silently
|
||||
// dropped (mirrors vLLM extract_tool_calls returning the raw text as
|
||||
// content when its regex does not match).
|
||||
//
|
||||
// Streaming granularity DIVERGENCE from vLLM: vLLM re-parses the partial
|
||||
// payload on every token and streams argument-JSON diffs (its `partial=True`
|
||||
// decoder mode plus withholding logic exist only for that). Our fragments are
|
||||
// whole committed diffusion blocks, so each completed tool call is emitted
|
||||
// once, as a single ToolCallDelta carrying index + id + name + the full
|
||||
// arguments JSON - exactly the shape backend/python/vllm/backend.py emits
|
||||
// per call and pkg/functions.ToolCallsFromChatDeltas re-accumulates.
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// gemma4CallRE is vLLM's tool_call_regex
|
||||
// (`<\|tool_call>call:([\w\-\.]+)\{(.*?)\}<tool_call\|>`, DOTALL) anchored to
|
||||
// a single already-extracted payload: name charset [\w\-.], braces mandatory.
|
||||
var gemma4CallRE = regexp.MustCompile(`(?s)^call:([\w\-.]+)\{(.*)\}$`)
|
||||
|
||||
type g4State int
|
||||
|
||||
const (
|
||||
g4Content g4State = iota
|
||||
g4ChanHeader
|
||||
g4Thought
|
||||
g4ToolCall
|
||||
g4Done
|
||||
)
|
||||
|
||||
// Markers watched per emitting state. A stray <tool_call|> outside a tool
|
||||
// call is deliberately NOT watched: it passes through verbatim, consistent
|
||||
// with the malformed-payload fallback re-emitting it as content.
|
||||
var (
|
||||
gemma4ContentMarkers = []string{gemma4ChannelOpen, gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd}
|
||||
gemma4ThoughtMarkers = []string{gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd}
|
||||
)
|
||||
|
||||
type Gemma4Parser struct {
|
||||
state g4State
|
||||
// held is the per-state carry-over between Feed calls: a partial marker
|
||||
// (content/thought), a partial channel header (chan-header) or the
|
||||
// payload accumulated so far (tool-call).
|
||||
held string
|
||||
toolIdx int
|
||||
}
|
||||
|
||||
// NewGemma4Parser returns a parser positioned per the initial-state rules in
|
||||
// the header comment: startInThought=true only when the caller pre-opened a
|
||||
// thought channel in the prompt.
|
||||
func NewGemma4Parser(startInThought bool) *Gemma4Parser {
|
||||
state := g4Content
|
||||
if startInThought {
|
||||
state = g4Thought
|
||||
}
|
||||
return &Gemma4Parser{state: state}
|
||||
}
|
||||
|
||||
// Feed consumes the next output fragment and returns the deltas it completes.
|
||||
func (p *Gemma4Parser) Feed(text string) []*pb.ChatDelta {
|
||||
if text == "" || p.state == g4Done {
|
||||
return nil
|
||||
}
|
||||
pending := p.held + text
|
||||
p.held = ""
|
||||
var em g4Emitter
|
||||
for pending != "" {
|
||||
switch p.state {
|
||||
case g4Content, g4Thought:
|
||||
markers := gemma4ContentMarkers
|
||||
if p.state == g4Thought {
|
||||
markers = gemma4ThoughtMarkers
|
||||
}
|
||||
idx, marker := findEarliestGemma4Marker(pending, markers)
|
||||
if idx == -1 {
|
||||
hold := gemma4MarkerHoldback(pending, markers)
|
||||
p.emitText(&em, pending[:len(pending)-hold])
|
||||
p.held = pending[len(pending)-hold:]
|
||||
pending = ""
|
||||
continue
|
||||
}
|
||||
p.emitText(&em, pending[:idx])
|
||||
pending = pending[idx+len(marker):]
|
||||
switch marker {
|
||||
case gemma4ChannelOpen:
|
||||
p.state = g4ChanHeader
|
||||
case gemma4ChannelClose:
|
||||
// In thought: channel ends. In content: stray close,
|
||||
// swallowed (strip_thinking keeps both sides, tpl L148-L158).
|
||||
p.state = g4Content
|
||||
case gemma4ToolCallOpen:
|
||||
p.state = g4ToolCall
|
||||
case gemma4TurnEnd:
|
||||
p.state = g4Done
|
||||
}
|
||||
case g4ChanHeader:
|
||||
// The channel header is "<name>\n"; the template only ever writes
|
||||
// "thought" (tpl L240/L360) and the label is structural, so it is
|
||||
// dropped, not emitted (vLLM strips the same "thought\n" prefix).
|
||||
nl := strings.IndexByte(pending, '\n')
|
||||
if nl == -1 {
|
||||
p.held = pending
|
||||
pending = ""
|
||||
continue
|
||||
}
|
||||
pending = pending[nl+1:]
|
||||
p.state = g4Thought
|
||||
case g4ToolCall:
|
||||
end := strings.Index(pending, gemma4ToolCallClose)
|
||||
if end == -1 {
|
||||
p.held = pending
|
||||
pending = ""
|
||||
continue
|
||||
}
|
||||
p.emitToolCall(&em, pending[:end])
|
||||
pending = pending[end+len(gemma4ToolCallClose):]
|
||||
p.state = g4Content
|
||||
case g4Done:
|
||||
pending = ""
|
||||
}
|
||||
}
|
||||
return em.deltas
|
||||
}
|
||||
|
||||
// Close flushes held-back partials. Incomplete structures (open channel
|
||||
// header, unterminated tool payload) are re-emitted raw as content rather
|
||||
// than dropped. The parser is finished afterwards.
|
||||
func (p *Gemma4Parser) Close() []*pb.ChatDelta {
|
||||
var em g4Emitter
|
||||
switch p.state {
|
||||
case g4Content:
|
||||
em.content(p.held)
|
||||
case g4Thought:
|
||||
em.reasoning(p.held)
|
||||
case g4ChanHeader:
|
||||
em.content(gemma4ChannelOpen + p.held)
|
||||
case g4ToolCall:
|
||||
em.content(gemma4ToolCallOpen + p.held)
|
||||
case g4Done:
|
||||
}
|
||||
p.held = ""
|
||||
p.state = g4Done
|
||||
return em.deltas
|
||||
}
|
||||
|
||||
func (p *Gemma4Parser) emitText(em *g4Emitter, s string) {
|
||||
if p.state == g4Thought {
|
||||
em.reasoning(s)
|
||||
return
|
||||
}
|
||||
em.content(s)
|
||||
}
|
||||
|
||||
// emitToolCall decodes one complete <|tool_call>...<tool_call|> payload. On a
|
||||
// payload that does not match call:name{...} the raw text (markers included)
|
||||
// is emitted as content, mirroring vLLM's extract_tool_calls fallback.
|
||||
func (p *Gemma4Parser) emitToolCall(em *g4Emitter, payload string) {
|
||||
m := gemma4CallRE.FindStringSubmatch(payload)
|
||||
if m == nil {
|
||||
em.content(gemma4ToolCallOpen + payload + gemma4ToolCallClose)
|
||||
return
|
||||
}
|
||||
// Index-based ids: deterministic (the split-invariance property relies
|
||||
// on it) and matching the call_<n> convention of pkg/grpc/rich_test.go;
|
||||
// core only needs ids to be non-empty and unique within the response.
|
||||
em.tool(p.toolIdx, "call_"+strconv.Itoa(p.toolIdx), m[1], decodeGemma4Args(m[2], 0))
|
||||
p.toolIdx++
|
||||
}
|
||||
|
||||
// g4Emitter collects ChatDeltas; empty text events are dropped.
|
||||
type g4Emitter struct {
|
||||
deltas []*pb.ChatDelta
|
||||
}
|
||||
|
||||
func (e *g4Emitter) content(s string) {
|
||||
if s != "" {
|
||||
e.deltas = append(e.deltas, &pb.ChatDelta{Content: s})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *g4Emitter) reasoning(s string) {
|
||||
if s != "" {
|
||||
e.deltas = append(e.deltas, &pb.ChatDelta{ReasoningContent: s})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *g4Emitter) tool(index int, id, name, argsJSON string) {
|
||||
e.deltas = append(e.deltas, &pb.ChatDelta{ToolCalls: []*pb.ToolCallDelta{{
|
||||
Index: int32(index),
|
||||
Id: id,
|
||||
Name: name,
|
||||
Arguments: argsJSON,
|
||||
}}})
|
||||
}
|
||||
|
||||
// findEarliestGemma4Marker returns the position and value of the first
|
||||
// complete marker occurrence, or (-1, "").
|
||||
func findEarliestGemma4Marker(s string, markers []string) (int, string) {
|
||||
best, bestMarker := -1, ""
|
||||
for _, m := range markers {
|
||||
if idx := strings.Index(s, m); idx >= 0 && (best == -1 || idx < best) {
|
||||
best, bestMarker = idx, m
|
||||
}
|
||||
}
|
||||
return best, bestMarker
|
||||
}
|
||||
|
||||
// gemma4MarkerHoldback returns the length of the longest suffix of s that is
|
||||
// a proper prefix of a watched marker - the only bytes that may still grow
|
||||
// into a marker and therefore must not be emitted yet (bounded by the
|
||||
// longest marker, so content is never buffered unboundedly).
|
||||
func gemma4MarkerHoldback(s string, markers []string) int {
|
||||
maxHold := 0
|
||||
for _, m := range markers {
|
||||
if len(m)-1 > maxHold {
|
||||
maxHold = len(m) - 1
|
||||
}
|
||||
}
|
||||
if len(s) < maxHold {
|
||||
maxHold = len(s)
|
||||
}
|
||||
for k := maxHold; k >= 1; k-- {
|
||||
tail := s[len(s)-k:]
|
||||
for _, m := range markers {
|
||||
if strings.HasPrefix(m, tail) {
|
||||
return k
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// call:name{...} argument decoder
|
||||
//
|
||||
// Port of vLLM's _parse_gemma4_args / _parse_gemma4_array /
|
||||
// _parse_gemma4_value (gemma4_tool_parser.py) in non-partial mode only: this
|
||||
// parser decodes exclusively COMPLETE payloads (incomplete ones fall back to
|
||||
// raw content at Close), so vLLM's partial-withholding machinery
|
||||
// (trailing-dot floats, withheld bare tails) is intentionally not ported.
|
||||
//
|
||||
// Grammar (inverse of the renderer's formatGemma4Argument, tpl L118-L147):
|
||||
//
|
||||
// args := pair (',' pair)*
|
||||
// pair := key ':' value (keys unquoted, up to the first ':')
|
||||
// value := string | object | array | bare
|
||||
// string := '<|"|>' ... '<|"|>' (no escapes; unterminated -> rest)
|
||||
// object := '{' args '}' (delimited strings skipped when
|
||||
// array := '[' value,* ']' counting braces/brackets)
|
||||
// bare := true | false | null/none/nil | number | bare-string
|
||||
//
|
||||
// Output is a JSON object/array string with keys in payload order (Python
|
||||
// dict insertion order), built with HTML escaping off so payload text
|
||||
// survives byte-for-byte.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func isGemma4Space(c byte) bool { return c == ' ' || c == '\n' || c == '\t' }
|
||||
|
||||
// gemma4MaxArgsDepth caps the mutual recursion between decodeGemma4Args and
|
||||
// decodeGemma4Array. Defense against model-generated deep nesting: a Go stack
|
||||
// overflow is a fatal process kill, not a recoverable error, so past the cap
|
||||
// a nested body gracefully degrades to a JSON string of its raw text.
|
||||
const gemma4MaxArgsDepth = 100
|
||||
|
||||
// decodeGemma4Args decodes one args body (the text between the outer braces
|
||||
// of call:name{...}) into a JSON object string. depth is the current nesting
|
||||
// level (0 at the payload root); see gemma4MaxArgsDepth.
|
||||
func decodeGemma4Args(s string, depth int) string {
|
||||
if depth > gemma4MaxArgsDepth {
|
||||
return gemma4JSONString(s)
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("{")
|
||||
first := true
|
||||
pair := func(key, val string) {
|
||||
if !first {
|
||||
b.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
b.WriteString(gemma4JSONString(key))
|
||||
b.WriteString(":")
|
||||
b.WriteString(val)
|
||||
}
|
||||
i, n := 0, len(s)
|
||||
for i < n {
|
||||
for i < n && (isGemma4Space(s[i]) || s[i] == ',') {
|
||||
i++
|
||||
}
|
||||
if i >= n {
|
||||
break
|
||||
}
|
||||
keyStart := i
|
||||
for i < n && s[i] != ':' {
|
||||
i++
|
||||
}
|
||||
if i >= n {
|
||||
break // no ':' -> trailing junk, dropped (vLLM does the same)
|
||||
}
|
||||
key := strings.TrimSpace(s[keyStart:i])
|
||||
i++ // skip ':'
|
||||
for i < n && isGemma4Space(s[i]) {
|
||||
i++
|
||||
}
|
||||
if i >= n {
|
||||
pair(key, `""`) // "key:" with nothing after -> empty string
|
||||
break
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(s[i:], gemma4StringDelim):
|
||||
i += len(gemma4StringDelim)
|
||||
if end := strings.Index(s[i:], gemma4StringDelim); end == -1 {
|
||||
pair(key, gemma4JSONString(s[i:])) // unterminated -> take rest
|
||||
i = n
|
||||
} else {
|
||||
pair(key, gemma4JSONString(s[i:i+end]))
|
||||
i += end + len(gemma4StringDelim)
|
||||
}
|
||||
case s[i] == '{':
|
||||
inner, next := scanGemma4Balanced(s, i, '{', '}')
|
||||
pair(key, decodeGemma4Args(inner, depth+1))
|
||||
i = next
|
||||
case s[i] == '[':
|
||||
inner, next := scanGemma4Balanced(s, i, '[', ']')
|
||||
pair(key, decodeGemma4Array(inner, depth+1))
|
||||
i = next
|
||||
default:
|
||||
valStart := i
|
||||
for i < n && s[i] != ',' && s[i] != '}' && s[i] != ']' {
|
||||
i++
|
||||
}
|
||||
if i == valStart {
|
||||
// No progress (value starts on a stray '}'/']'): abort on
|
||||
// malformed input rather than loop, like vLLM.
|
||||
i = n
|
||||
continue
|
||||
}
|
||||
pair(key, decodeGemma4Bare(s[valStart:i]))
|
||||
}
|
||||
}
|
||||
b.WriteString("}")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// decodeGemma4Array decodes one array body (the text between '[' and ']')
|
||||
// into a JSON array string. depth is the current nesting level; see
|
||||
// gemma4MaxArgsDepth.
|
||||
func decodeGemma4Array(s string, depth int) string {
|
||||
if depth > gemma4MaxArgsDepth {
|
||||
return gemma4JSONString(s)
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("[")
|
||||
first := true
|
||||
item := func(val string) {
|
||||
if !first {
|
||||
b.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
b.WriteString(val)
|
||||
}
|
||||
i, n := 0, len(s)
|
||||
for i < n {
|
||||
for i < n && (isGemma4Space(s[i]) || s[i] == ',') {
|
||||
i++
|
||||
}
|
||||
if i >= n {
|
||||
break
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(s[i:], gemma4StringDelim):
|
||||
i += len(gemma4StringDelim)
|
||||
if end := strings.Index(s[i:], gemma4StringDelim); end == -1 {
|
||||
item(gemma4JSONString(s[i:]))
|
||||
i = n
|
||||
} else {
|
||||
item(gemma4JSONString(s[i : i+end]))
|
||||
i += end + len(gemma4StringDelim)
|
||||
}
|
||||
case s[i] == '{':
|
||||
inner, next := scanGemma4Balanced(s, i, '{', '}')
|
||||
item(decodeGemma4Args(inner, depth+1))
|
||||
i = next
|
||||
case s[i] == '[':
|
||||
inner, next := scanGemma4Balanced(s, i, '[', ']')
|
||||
item(decodeGemma4Array(inner, depth+1))
|
||||
i = next
|
||||
default:
|
||||
valStart := i
|
||||
for i < n && s[i] != ',' && s[i] != ']' {
|
||||
i++
|
||||
}
|
||||
if i == valStart {
|
||||
i = n // no progress: abort on malformed input, like vLLM
|
||||
continue
|
||||
}
|
||||
item(decodeGemma4Bare(s[valStart:i]))
|
||||
}
|
||||
}
|
||||
b.WriteString("]")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// scanGemma4Balanced scans a brace/bracket-balanced span starting at the
|
||||
// opener s[start], skipping over <|"|>-delimited strings so structural
|
||||
// characters inside them do not count (vLLM's depth scan). Returns the inner
|
||||
// text and the index just past the closer; an unterminated span yields the
|
||||
// rest of the string (the inner decoder still extracts what is there - this
|
||||
// path is only reachable from genuinely malformed complete payloads).
|
||||
func scanGemma4Balanced(s string, start int, open, close byte) (string, int) {
|
||||
depth := 1
|
||||
i := start + 1
|
||||
innerStart := i
|
||||
n := len(s)
|
||||
for i < n && depth > 0 {
|
||||
if strings.HasPrefix(s[i:], gemma4StringDelim) {
|
||||
i += len(gemma4StringDelim)
|
||||
if nd := strings.Index(s[i:], gemma4StringDelim); nd == -1 {
|
||||
i = n
|
||||
} else {
|
||||
i += nd + len(gemma4StringDelim)
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch s[i] {
|
||||
case open:
|
||||
depth++
|
||||
case close:
|
||||
depth--
|
||||
}
|
||||
i++
|
||||
}
|
||||
if depth > 0 {
|
||||
return s[innerStart:], n
|
||||
}
|
||||
return s[innerStart : i-1], i
|
||||
}
|
||||
|
||||
// decodeGemma4Bare maps an undelimited value to its JSON form: booleans,
|
||||
// null aliases (null/none/nil, case-insensitive - the renderer writes
|
||||
// Python None as "None", tpl L144-L145 via format_argument's else branch),
|
||||
// numbers (vLLM's rule: a '.' tries float, otherwise int; anything that
|
||||
// fails parses as a bare string).
|
||||
func decodeGemma4Bare(raw string) string {
|
||||
v := strings.TrimSpace(raw)
|
||||
if v == "" {
|
||||
return `""`
|
||||
}
|
||||
if v == "true" || v == "false" {
|
||||
return v
|
||||
}
|
||||
switch strings.ToLower(v) {
|
||||
case "null", "none", "nil":
|
||||
return "null"
|
||||
}
|
||||
if strings.Contains(v, ".") {
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return formatGemma4Float(f)
|
||||
}
|
||||
} else if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return strconv.FormatInt(iv, 10)
|
||||
}
|
||||
return gemma4JSONString(v)
|
||||
}
|
||||
|
||||
// formatGemma4Float renders like Python's json.dumps(float): integral floats
|
||||
// keep a ".0" suffix ("108." decodes to 108.0, not 108), so the arguments
|
||||
// JSON matches what vLLM would have produced for the same payload.
|
||||
func formatGemma4Float(f float64) string {
|
||||
s := strconv.FormatFloat(f, 'g', -1, 64)
|
||||
if !strings.ContainsAny(s, ".eE") {
|
||||
s += ".0"
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// gemma4JSONString encodes a JSON string WITHOUT HTML escaping (json.Marshal
|
||||
// would escape the angle brackets in "<div>" to \u003c / \u003e sequences;
|
||||
// payload text should survive
|
||||
// byte-for-byte, like Python's json.dumps(ensure_ascii=False)).
|
||||
func gemma4JSONString(s string) string {
|
||||
var sb strings.Builder
|
||||
enc := json.NewEncoder(&sb)
|
||||
enc.SetEscapeHTML(false)
|
||||
if err := enc.Encode(s); err != nil {
|
||||
// Unreachable for plain strings; fall back to default escaping
|
||||
// rather than emitting invalid JSON.
|
||||
b, mErr := json.Marshal(s)
|
||||
if mErr != nil {
|
||||
return `""`
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
// Encode appends a trailing newline.
|
||||
return strings.TrimSuffix(sb.String(), "\n")
|
||||
}
|
||||
592
backend/go/dllm/gemma4_parser_test.go
Normal file
592
backend/go/dllm/gemma4_parser_test.go
Normal file
@@ -0,0 +1,592 @@
|
||||
package main
|
||||
|
||||
// Parser specs for Gemma4Parser (model output text -> pb.ChatDelta events).
|
||||
//
|
||||
// Fixture provenance:
|
||||
// - Entries marked "vLLM: <name>" are direct ports of the named test from
|
||||
// vLLM PR #45163, tests/tool_parsers/test_gemma4_tool_parser.py (the
|
||||
// authoritative test-suite for the gemma4 tool-call wire format). The
|
||||
// streaming tests' chunk lists are reused verbatim as Feed fragments.
|
||||
// - Decoder entries port the TestParseGemma4Args / TestParseGemma4Array
|
||||
// classes from the same file (non-partial mode only; this parser never
|
||||
// decodes partial payloads, see the divergence note in gemma4_parser.go).
|
||||
// - Channel/turn-marker expectations come from the chat template embedded
|
||||
// in gemma4_renderer.go (tpl L356-L362 generation prompt, L148-L158
|
||||
// strip_thinking) and vLLM's Gemma4ReasoningParser
|
||||
// (vllm/reasoning/gemma4_reasoning_parser.py).
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// flatGemma4Tool is one accumulated tool call, mirroring how LocalAI core
|
||||
// folds ToolCallDelta streams (pkg/functions/chat_deltas.go
|
||||
// ToolCallsFromChatDeltas: name/id latch on first non-empty, arguments
|
||||
// concatenate per index). Tests flatten through the same rules so they
|
||||
// assert exactly what core will reconstruct.
|
||||
type flatGemma4Tool struct {
|
||||
id string
|
||||
name string
|
||||
args string
|
||||
}
|
||||
|
||||
func flattenGemma4Deltas(deltas []*pb.ChatDelta) (string, string, []flatGemma4Tool) {
|
||||
var content, reasoning strings.Builder
|
||||
byIndex := map[int32]*flatGemma4Tool{}
|
||||
maxIdx := int32(-1)
|
||||
for _, d := range deltas {
|
||||
content.WriteString(d.GetContent())
|
||||
reasoning.WriteString(d.GetReasoningContent())
|
||||
for _, tc := range d.GetToolCalls() {
|
||||
acc, ok := byIndex[tc.GetIndex()]
|
||||
if !ok {
|
||||
acc = &flatGemma4Tool{}
|
||||
byIndex[tc.GetIndex()] = acc
|
||||
}
|
||||
if tc.GetName() != "" {
|
||||
acc.name = tc.GetName()
|
||||
}
|
||||
if tc.GetId() != "" {
|
||||
acc.id = tc.GetId()
|
||||
}
|
||||
acc.args += tc.GetArguments()
|
||||
if tc.GetIndex() > maxIdx {
|
||||
maxIdx = tc.GetIndex()
|
||||
}
|
||||
}
|
||||
}
|
||||
var tools []flatGemma4Tool
|
||||
for i := int32(0); i <= maxIdx; i++ {
|
||||
if acc, ok := byIndex[i]; ok {
|
||||
tools = append(tools, *acc)
|
||||
}
|
||||
}
|
||||
return content.String(), reasoning.String(), tools
|
||||
}
|
||||
|
||||
type wantGemma4Tool struct {
|
||||
name string
|
||||
argsJSON string // compared with MatchJSON (key order irrelevant)
|
||||
}
|
||||
|
||||
type parseGemma4Case struct {
|
||||
startInThought bool
|
||||
fragments []string
|
||||
wantContent string
|
||||
wantReasoning string
|
||||
wantTools []wantGemma4Tool
|
||||
}
|
||||
|
||||
func parseGemma4Fragments(startInThought bool, fragments []string) []*pb.ChatDelta {
|
||||
p := NewGemma4Parser(startInThought)
|
||||
var all []*pb.ChatDelta
|
||||
for _, f := range fragments {
|
||||
all = append(all, p.Feed(f)...)
|
||||
}
|
||||
return append(all, p.Close()...)
|
||||
}
|
||||
|
||||
var _ = Describe("Gemma4Parser", func() {
|
||||
DescribeTable("parses streamed gemma4 output into ChatDeltas",
|
||||
func(c parseGemma4Case) {
|
||||
content, reasoning, tools := flattenGemma4Deltas(parseGemma4Fragments(c.startInThought, c.fragments))
|
||||
Expect(content).To(Equal(c.wantContent))
|
||||
Expect(reasoning).To(Equal(c.wantReasoning))
|
||||
Expect(tools).To(HaveLen(len(c.wantTools)))
|
||||
seenIDs := map[string]bool{}
|
||||
for i, want := range c.wantTools {
|
||||
Expect(tools[i].name).To(Equal(want.name), "tool %d name", i)
|
||||
Expect(tools[i].args).To(MatchJSON(want.argsJSON), "tool %d arguments", i)
|
||||
Expect(tools[i].id).ToNot(BeEmpty(), "tool %d id", i)
|
||||
Expect(seenIDs).ToNot(HaveKey(tools[i].id), "tool %d id must be unique", i)
|
||||
seenIDs[tools[i].id] = true
|
||||
}
|
||||
},
|
||||
|
||||
// --- (1) pure content -------------------------------------------------
|
||||
// vLLM: test_no_tool_calls
|
||||
Entry("pure content, single fragment", parseGemma4Case{
|
||||
fragments: []string{"Hello, how can I help you today?"},
|
||||
wantContent: "Hello, how can I help you today?",
|
||||
}),
|
||||
|
||||
// --- (2) thought -> final transition ----------------------------------
|
||||
// enable_thinking render: prompt ends at <|turn>model\n and the model
|
||||
// opens/closes its own thought channel in the OUTPUT (vLLM
|
||||
// Gemma4ReasoningParser docstring; tpl L356-L362). The "thought\n"
|
||||
// role label after <|channel> is structural and must be stripped
|
||||
// (vLLM _THOUGHT_PREFIX handling).
|
||||
Entry("thought channel then final content", parseGemma4Case{
|
||||
fragments: []string{"<|channel>thought\nLet me think about this.\n<channel|>The answer is 42."},
|
||||
wantReasoning: "Let me think about this.\n",
|
||||
wantContent: "The answer is 42.",
|
||||
}),
|
||||
|
||||
// --- (3) startInThought both ways -------------------------------------
|
||||
Entry("startInThought=true routes initial text to reasoning until <channel|>", parseGemma4Case{
|
||||
startInThought: true,
|
||||
fragments: []string{"I am thinking hard.<channel|>Done."},
|
||||
wantReasoning: "I am thinking hard.",
|
||||
wantContent: "Done.",
|
||||
}),
|
||||
// A stray <channel|> with no open channel is swallowed, matching the
|
||||
// template's strip_thinking (tpl L148-L158: the marker is dropped,
|
||||
// text on both sides is kept).
|
||||
Entry("startInThought=false keeps the same text as content, stray <channel|> swallowed", parseGemma4Case{
|
||||
startInThought: false,
|
||||
fragments: []string{"I am thinking hard.<channel|>Done."},
|
||||
wantContent: "I am thinking hard.Done.",
|
||||
}),
|
||||
|
||||
// --- (4) one tool call, full payload type zoo --------------------------
|
||||
Entry("single tool call: strings, numbers, bools, null, nested object and array", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:complex_function{text:<|"|>with, comma and {braces}<|"|>,count:42,score:3.14,yes:true,no:false,nothing:null,obj:{inner:<|"|>v<|"|>,k:1},arr:[<|"|>a<|"|>,2,true]}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{
|
||||
name: "complex_function",
|
||||
argsJSON: `{"text":"with, comma and {braces}","count":42,"score":3.14,"yes":true,"no":false,"nothing":null,"obj":{"inner":"v","k":1},"arr":["a",2,true]}`,
|
||||
}},
|
||||
}),
|
||||
|
||||
// --- (5) payload split across 3 fragments ------------------------------
|
||||
Entry("tool-call payload split across three fragments", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>call:get_weather{loc",
|
||||
`ation:<|"|>Paris, Fra`,
|
||||
`nce<|"|>}<tool_call|>`,
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}},
|
||||
}),
|
||||
|
||||
// --- (6) marker split across fragments ----------------------------------
|
||||
Entry("tool-call open marker split across fragments", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_ca",
|
||||
`ll>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`,
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
Entry("channel open marker split across fragments", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|chan",
|
||||
"nel>thought\ndeep thought<channel|>final",
|
||||
},
|
||||
wantReasoning: "deep thought",
|
||||
wantContent: "final",
|
||||
}),
|
||||
|
||||
// --- (7) trailing partial marker held, flushed by Close -----------------
|
||||
Entry("trailing partial marker is held back and flushed by Close", parseGemma4Case{
|
||||
fragments: []string{"Hello <|tool"},
|
||||
wantContent: "Hello <|tool",
|
||||
}),
|
||||
|
||||
// --- (8) malformed/incomplete payload -> content fallback ---------------
|
||||
// vLLM: test_incomplete_tool_call (no end marker: the whole text stays
|
||||
// content, never silently dropped).
|
||||
Entry("incomplete tool payload at Close is emitted as raw content", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London`},
|
||||
wantContent: `<|tool_call>call:get_weather{location:<|"|>London`,
|
||||
}),
|
||||
Entry("malformed complete payload is emitted as raw content, parsing continues", parseGemma4Case{
|
||||
fragments: []string{"<|tool_call>oops no call syntax<tool_call|> done"},
|
||||
wantContent: "<|tool_call>oops no call syntax<tool_call|> done",
|
||||
}),
|
||||
|
||||
// --- (9) <turn|> ends the turn -------------------------------------------
|
||||
Entry("text after <turn|> is ignored, including later fragments", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"before<turn|>after",
|
||||
`more <|tool_call>call:f{}<tool_call|>`,
|
||||
},
|
||||
wantContent: "before",
|
||||
}),
|
||||
Entry("<turn|> inside a thought channel ends the turn", parseGemma4Case{
|
||||
startInThought: true,
|
||||
fragments: []string{"thinking<turn|>ignored"},
|
||||
wantReasoning: "thinking",
|
||||
}),
|
||||
|
||||
// --- (10) ported vLLM non-streaming cases ---------------------------------
|
||||
// vLLM: test_single_tool_call
|
||||
Entry("vLLM: test_single_tool_call", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
// vLLM: test_multiple_arguments
|
||||
Entry("vLLM: test_multiple_arguments", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"San Francisco","unit":"celsius"}`}},
|
||||
}),
|
||||
// vLLM: test_text_before_tool_call. DIVERGENCE: vLLM's non-streaming
|
||||
// extractor trims the content ("...you."); a streaming parser cannot
|
||||
// retroactively trim already-emitted text, so the trailing space is
|
||||
// kept (vLLM's own streaming path keeps it too, see
|
||||
// test_streaming_text_before_tool_call which only checks a prefix).
|
||||
Entry("vLLM: test_text_before_tool_call (streaming semantics: no trim)", parseGemma4Case{
|
||||
fragments: []string{`Let me check the weather for you. <|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|>`},
|
||||
wantContent: "Let me check the weather for you. ",
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}},
|
||||
}),
|
||||
// vLLM: test_multiple_tool_calls (also covers case 11: multi-tool sequence)
|
||||
Entry("vLLM: test_multiple_tool_calls", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|><|tool_call>call:get_time{location:<|"|>London<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{
|
||||
{name: "get_weather", argsJSON: `{"location":"London"}`},
|
||||
{name: "get_time", argsJSON: `{"location":"London"}`},
|
||||
},
|
||||
}),
|
||||
// vLLM: test_nested_arguments
|
||||
Entry("vLLM: test_nested_arguments", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:complex_function{nested:{inner:<|"|>value<|"|>},list:[<|"|>a<|"|>,<|"|>b<|"|>]}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "complex_function", argsJSON: `{"nested":{"inner":"value"},"list":["a","b"]}`}},
|
||||
}),
|
||||
// vLLM: test_tool_call_with_number_and_boolean
|
||||
Entry("vLLM: test_tool_call_with_number_and_boolean", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:set_status{is_active:true,count:42,score:3.14}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "set_status", argsJSON: `{"is_active":true,"count":42,"score":3.14}`}},
|
||||
}),
|
||||
// vLLM: test_hyphenated_function_name
|
||||
Entry("vLLM: test_hyphenated_function_name", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get-weather{location:<|"|>London<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "get-weather", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
// vLLM: test_dotted_function_name
|
||||
Entry("vLLM: test_dotted_function_name", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:weather.get{location:<|"|>London<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "weather.get", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
// vLLM: test_no_arguments
|
||||
Entry("vLLM: test_no_arguments", parseGemma4Case{
|
||||
fragments: []string{"<|tool_call>call:get_status{}<tool_call|>"},
|
||||
wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}},
|
||||
}),
|
||||
|
||||
// --- ported vLLM streaming cases (chunk lists reused as fragments) --------
|
||||
// vLLM: test_basic_streaming_single_tool
|
||||
Entry("vLLM: test_basic_streaming_single_tool", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
`location:<|"|>Paris`,
|
||||
", France",
|
||||
`<|"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_multi_arg
|
||||
Entry("vLLM: test_streaming_multi_arg", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
`location:<|"|>Tokyo<|"|>,`,
|
||||
`unit:<|"|>celsius<|"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Tokyo","unit":"celsius"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_text_before_tool_call
|
||||
Entry("vLLM: test_streaming_text_before_tool_call", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"Let me check ",
|
||||
"the weather. ",
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
`location:<|"|>London<|"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantContent: "Let me check the weather. ",
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_numeric_args
|
||||
Entry("vLLM: test_streaming_numeric_args", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:set_config{",
|
||||
"count:42,",
|
||||
"active:true}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "set_config", argsJSON: `{"count":42,"active":true}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_boolean_split_across_chunks
|
||||
Entry("vLLM: test_streaming_boolean_split_across_chunks", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:search{input:{all:tru",
|
||||
"e}}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "search", argsJSON: `{"input":{"all":true}}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_false_split_across_chunks
|
||||
Entry("vLLM: test_streaming_false_split_across_chunks", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:set{flag:fals",
|
||||
"e}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"flag":false}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_number_split_across_chunks
|
||||
Entry("vLLM: test_streaming_number_split_across_chunks", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:set{count:4",
|
||||
"2}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"count":42}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_empty_args
|
||||
Entry("vLLM: test_streaming_empty_args", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:get_status{}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_split_delimiter_no_invalid_json (string
|
||||
// delimiter <|"|> split across fragments must not leak fragments).
|
||||
Entry("vLLM: test_streaming_split_delimiter_no_invalid_json", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:todowrite{",
|
||||
`content:<|"|>Buy milk<|`,
|
||||
`"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "todowrite", argsJSON: `{"content":"Buy milk"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call
|
||||
Entry("vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
`location:<|"|>Paris<|"|>}`,
|
||||
"<tool_call|><",
|
||||
"div>",
|
||||
},
|
||||
wantContent: "<div>",
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes
|
||||
Entry("vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:write_file{",
|
||||
`path:<|"|>index.html<|"|>,`,
|
||||
`content:<|"|><!DOCTYPE html>` + "\n<",
|
||||
`html lang="zh-CN">` + "\n<",
|
||||
"head>\n <",
|
||||
`meta charset="UTF-8">` + "\n <",
|
||||
`meta name="viewport" content="width=device-width">` + "\n",
|
||||
`<|"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{
|
||||
name: "write_file",
|
||||
argsJSON: `{"path":"index.html","content":"<!DOCTYPE html>\n<html lang=\"zh-CN\">\n<head>\n <meta charset=\"UTF-8\">\n <meta name=\"viewport\" content=\"width=device-width\">\n"}`,
|
||||
}},
|
||||
}),
|
||||
// vLLM: test_streaming_single_chunk_complete_tool_call
|
||||
Entry("vLLM: test_streaming_single_chunk_complete_tool_call", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:name_a_color{color_hex:<|"|>00ff11<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "name_a_color", argsJSON: `{"color_hex":"00ff11"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_multi_chunk_batched_tool_calls (two complete
|
||||
// calls in ONE fragment; both must come out with distinct indices)
|
||||
Entry("vLLM: test_streaming_multi_chunk_batched_tool_calls", parseGemma4Case{
|
||||
fragments: []string{
|
||||
`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>` +
|
||||
`<|tool_call>call:get_time{timezone:<|"|>GMT<|"|>}<tool_call|>`,
|
||||
},
|
||||
wantTools: []wantGemma4Tool{
|
||||
{name: "get_weather", argsJSON: `{"location":"London"}`},
|
||||
{name: "get_time", argsJSON: `{"timezone":"GMT"}`},
|
||||
},
|
||||
}),
|
||||
// vLLM: test_streaming_trailing_bare_bool_not_duplicated
|
||||
Entry("vLLM: test_streaming_trailing_bare_bool_not_duplicated", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:Edit{",
|
||||
`file_path:<|"|>src/env.py<|"|>,`,
|
||||
`old_string:<|"|>old_val<|"|>,`,
|
||||
`new_string:<|"|>new_val<|"|>,`,
|
||||
"replace_all:",
|
||||
"false}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{
|
||||
name: "Edit",
|
||||
argsJSON: `{"file_path":"src/env.py","old_string":"old_val","new_string":"new_val","replace_all":false}`,
|
||||
}},
|
||||
}),
|
||||
|
||||
// --- implicit reasoning end on <|tool_call> (vLLM is_reasoning_end:
|
||||
// a tool_call token means reasoning is over) -----------------------------
|
||||
Entry("tool call inside an open thought channel ends the reasoning", parseGemma4Case{
|
||||
startInThought: true,
|
||||
fragments: []string{`need the weather<|tool_call>call:get_weather{location:<|"|>Rome<|"|>}<tool_call|>`},
|
||||
wantReasoning: "need the weather",
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Rome"}`}},
|
||||
}),
|
||||
|
||||
// --- (12) empty fragments are no-ops --------------------------------------
|
||||
Entry("empty fragments are no-ops", parseGemma4Case{
|
||||
fragments: []string{"", "Hello", "", "", " world", ""},
|
||||
wantContent: "Hello world",
|
||||
}),
|
||||
)
|
||||
|
||||
It("returns no deltas for an empty fragment and after Close", func() {
|
||||
p := NewGemma4Parser(false)
|
||||
Expect(p.Feed("")).To(BeEmpty())
|
||||
Expect(p.Feed("hi")).ToNot(BeEmpty())
|
||||
Expect(p.Close()).To(BeEmpty()) // nothing held back
|
||||
// The parser is finished after Close: further input is dropped.
|
||||
Expect(p.Feed("more")).To(BeEmpty())
|
||||
Expect(p.Close()).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("generates index-based tool call ids (call_<index>)", func() {
|
||||
// Mirrors the index-based id convention of pkg/grpc/rich_test.go and
|
||||
// keeps ids deterministic for the split-invariance property below.
|
||||
deltas := parseGemma4Fragments(false, []string{
|
||||
`<|tool_call>call:a{}<tool_call|><|tool_call>call:b{}<tool_call|>`,
|
||||
})
|
||||
_, _, tools := flattenGemma4Deltas(deltas)
|
||||
Expect(tools).To(HaveLen(2))
|
||||
Expect(tools[0].id).To(Equal("call_0"))
|
||||
Expect(tools[1].id).To(Equal("call_1"))
|
||||
})
|
||||
|
||||
// Property: for a fixed full output, EVERY 2-split position must yield
|
||||
// exactly the same flattened result as the unsplit parse. This kills
|
||||
// fragment-boundary bugs (mid-marker, mid-delimiter, mid-payload splits).
|
||||
DescribeTable("2-split fragment invariance",
|
||||
func(startInThought bool, full string) {
|
||||
refContent, refReasoning, refTools := flattenGemma4Deltas(
|
||||
parseGemma4Fragments(startInThought, []string{full}))
|
||||
for i := 0; i <= len(full); i++ {
|
||||
content, reasoning, tools := flattenGemma4Deltas(
|
||||
parseGemma4Fragments(startInThought, []string{full[:i], full[i:]}))
|
||||
Expect(content).To(Equal(refContent), fmt.Sprintf("content diverged at split %d", i))
|
||||
Expect(reasoning).To(Equal(refReasoning), fmt.Sprintf("reasoning diverged at split %d", i))
|
||||
Expect(tools).To(Equal(refTools), fmt.Sprintf("tool calls diverged at split %d", i))
|
||||
}
|
||||
},
|
||||
Entry("thought + content + two tool calls + turn end", false,
|
||||
"<|channel>thought\nPondering the request...\n<channel|>Sure - calling tools now. "+
|
||||
`<|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>,days:3,detailed:true}<tool_call|>`+
|
||||
`<|tool_call>call:get_time{timezone:<|"|>Europe/Lisbon<|"|>,nested:{flag:false,vals:[1,2.5,<|"|>x<|"|>]}}<tool_call|>`+
|
||||
"Done.<turn|>ignored tail"),
|
||||
Entry("startInThought + tool call + trailing partial marker", true,
|
||||
`Deep thought<channel|>final answer <|tool_call>call:noop{}<tool_call|> trailing <|tool`),
|
||||
Entry("malformed payload fallback", false,
|
||||
`pre <|tool_call>not a call<tool_call|> post`),
|
||||
)
|
||||
})
|
||||
|
||||
// Decoder-level ports of vLLM's TestParseGemma4Args / TestParseGemma4Array
|
||||
// (non-partial mode; the partial-withholding tests do not apply because this
|
||||
// parser only ever decodes COMPLETE payloads, see gemma4_parser.go).
|
||||
var _ = Describe("decodeGemma4Args", func() {
|
||||
DescribeTable("decodes the gemma4 call syntax into JSON arguments",
|
||||
func(in, wantJSON string) {
|
||||
Expect(decodeGemma4Args(in, 0)).To(MatchJSON(wantJSON))
|
||||
},
|
||||
// vLLM: test_empty_string / test_whitespace_only
|
||||
Entry("empty string", "", `{}`),
|
||||
Entry("whitespace only", " ", `{}`),
|
||||
// vLLM: test_single_string_value
|
||||
Entry("single string value", `location:<|"|>Paris<|"|>`, `{"location":"Paris"}`),
|
||||
// vLLM: test_string_value_with_comma
|
||||
Entry("string value with comma", `location:<|"|>Paris, France<|"|>`, `{"location":"Paris, France"}`),
|
||||
// vLLM: test_multiple_string_values
|
||||
Entry("multiple string values", `location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>`, `{"location":"San Francisco","unit":"celsius"}`),
|
||||
// vLLM: test_integer_value / test_float_value
|
||||
Entry("integer value", "count:42", `{"count":42}`),
|
||||
Entry("float value", "score:3.14", `{"score":3.14}`),
|
||||
// vLLM: test_boolean_true / test_boolean_false
|
||||
Entry("boolean true", "flag:true", `{"flag":true}`),
|
||||
Entry("boolean false", "flag:false", `{"flag":false}`),
|
||||
// vLLM: test_null_value (bare null must become JSON null, not "null")
|
||||
Entry("null value", "param:null", `{"param":null}`),
|
||||
// vLLM: test_mixed_types
|
||||
Entry("mixed types", `name:<|"|>test<|"|>,count:42,active:true,score:3.14`,
|
||||
`{"name":"test","count":42,"active":true,"score":3.14}`),
|
||||
// vLLM: test_nested_object
|
||||
Entry("nested object", `nested:{inner:<|"|>value<|"|>}`, `{"nested":{"inner":"value"}}`),
|
||||
// vLLM: test_array_of_strings
|
||||
Entry("array of strings", `items:[<|"|>a<|"|>,<|"|>b<|"|>]`, `{"items":["a","b"]}`),
|
||||
// vLLM: test_unterminated_string (take everything after the delimiter)
|
||||
Entry("unterminated string", `key:<|"|>unterminated`, `{"key":"unterminated"}`),
|
||||
// vLLM: test_empty_value (key with no value after colon)
|
||||
Entry("empty value", "key:", `{"key":""}`),
|
||||
// vLLM: test_trailing_dot_float_partial_withheld, non-partial branch
|
||||
// (trailing-dot floats parse normally outside streaming).
|
||||
Entry("trailing dot float, complete payload", "left:108.,right:22.8", `{"left":108.0,"right":22.8}`),
|
||||
)
|
||||
|
||||
It("terminates and yields valid JSON on malformed input", func() {
|
||||
// vLLM: test_malformed_partial_array (the assertion there is only
|
||||
// "returns a dict without hanging"; ours is "valid JSON object").
|
||||
out := decodeGemma4Args(":[t:[]", 0)
|
||||
var v map[string]any
|
||||
Expect(json.Unmarshal([]byte(out), &v)).To(Succeed())
|
||||
})
|
||||
|
||||
It("degrades nesting beyond the recursion cap to a string value", func() {
|
||||
// 200 levels of a:{a:{...a:1...}}. Without the depth cap the mutual
|
||||
// recursion would grow the stack with the model's output; a Go stack
|
||||
// overflow is a fatal process kill, so levels past gemma4MaxArgsDepth
|
||||
// must gracefully fall back to the raw inner text as a JSON string.
|
||||
const depth = 200
|
||||
body := strings.Repeat("a:{", depth-1) + "a:1" + strings.Repeat("}", depth-1)
|
||||
out := decodeGemma4Args(body, 0)
|
||||
var v map[string]any
|
||||
Expect(json.Unmarshal([]byte(out), &v)).To(Succeed())
|
||||
levels := 0
|
||||
var cur any = v
|
||||
for {
|
||||
m, ok := cur.(map[string]any)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
Expect(m).To(HaveKey("a"))
|
||||
cur = m["a"]
|
||||
levels++
|
||||
}
|
||||
Expect(levels).To(Equal(gemma4MaxArgsDepth + 1))
|
||||
Expect(cur).To(BeAssignableToTypeOf(""))
|
||||
Expect(cur).To(ContainSubstring("a:{"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("decodeGemma4Array", func() {
|
||||
DescribeTable("decodes gemma4 array bodies into JSON arrays",
|
||||
func(in, wantJSON string) {
|
||||
Expect(decodeGemma4Array(in, 0)).To(MatchJSON(wantJSON))
|
||||
},
|
||||
// vLLM: test_string_array / test_empty_array / test_bare_values
|
||||
Entry("string array", `<|"|>a<|"|>,<|"|>b<|"|>`, `["a","b"]`),
|
||||
Entry("empty array", "", `[]`),
|
||||
Entry("bare values", "42,true,3.14", `[42,true,3.14]`),
|
||||
// vLLM: test_string_element_with_closing_bracket (a ']' inside a
|
||||
// delimited string must not close the array)
|
||||
Entry("string element with closing bracket", `[<|"|>a]b<|"|>,<|"|>c<|"|>],<|"|>tail<|"|>`, `[["a]b","c"],"tail"]`),
|
||||
// vLLM: test_stray_closing_bracket (no-progress abort, keep prefix)
|
||||
Entry("stray closing bracket", "42,]trailing", `[42]`),
|
||||
)
|
||||
})
|
||||
1060
backend/go/dllm/gemma4_renderer.go
Normal file
1060
backend/go/dllm/gemma4_renderer.go
Normal file
File diff suppressed because it is too large
Load Diff
406
backend/go/dllm/gemma4_renderer_test.go
Normal file
406
backend/go/dllm/gemma4_renderer_test.go
Normal file
@@ -0,0 +1,406 @@
|
||||
package main
|
||||
|
||||
// Renderer specs for RenderGemma4 against the canonical gemma4 chat template
|
||||
// (see the normative template comment in gemma4_renderer.go).
|
||||
//
|
||||
// Fixture provenance:
|
||||
// - "single user message" and "enable_thinking" are the EXACT expected
|
||||
// decodes from transformers tests/models/diffusion_gemma/
|
||||
// test_modeling_diffusion_gemma.py (test_diffusion_gemma_chat_template
|
||||
// and ..._with_thinking) with ONE difference: the transformers fixtures
|
||||
// start with "<bos>" because apply_chat_template tokenizes the rendered
|
||||
// text with add_bos. Our prompt goes through dllm_capi_generate, whose
|
||||
// run_generate already tokenizes with prepend_bos = vocab.add_bos
|
||||
// (dllm.cpp src/capi.cpp:230-231, true for gemma4), so the renderer must
|
||||
// NOT emit a literal <bos> (it would double) and every expected string
|
||||
// here drops that leading token.
|
||||
// - All other expected strings were produced by rendering the verbatim
|
||||
// GGUF template with jinja2 3.1.2 (bos_token="<bos>") and dropping the
|
||||
// leading "<bos>" for the same reason.
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// Two-function tools array used by the tool fixtures (OpenAI wire shape, as
|
||||
// LocalAI passes it through PredictOptions.Tools).
|
||||
const testToolsJSON = `[{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a location.","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city name."},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}},{"type":"function","function":{"name":"get_time","description":"Get the current time in a timezone.","parameters":{"type":"object","properties":{"timezone":{"type":"string","description":"IANA timezone name."}},"required":["timezone"]}}}]`
|
||||
|
||||
// The <|tool>...<tool|> block the template renders for testToolsJSON inside
|
||||
// the system turn (jinja2-verified).
|
||||
const testToolsBlock = `<|tool>declaration:get_weather{description:<|"|>Get the current weather in a location.<|"|>,parameters:{properties:{location:{description:<|"|>The city name.<|"|>,type:<|"|>STRING<|"|>},unit:{enum:[<|"|>celsius<|"|>,<|"|>fahrenheit<|"|>],type:<|"|>STRING<|"|>}},required:[<|"|>location<|"|>],type:<|"|>OBJECT<|"|>}}<tool|><|tool>declaration:get_time{description:<|"|>Get the current time in a timezone.<|"|>,parameters:{properties:{timezone:{description:<|"|>IANA timezone name.<|"|>,type:<|"|>STRING<|"|>}},required:[<|"|>timezone<|"|>],type:<|"|>OBJECT<|"|>}}<tool|>`
|
||||
|
||||
// A single tool exercising the deep format_parameters branches: array items
|
||||
// (string-typed and nested-array), nullable, enum+nullable, nested object
|
||||
// properties/required, and a response declaration.
|
||||
const complexToolsJSON = `[{"type":"function","function":{"name":"complex_tool","description":"A complex tool.","parameters":{"type":"object","properties":{"tags":{"type":"array","description":"Tags.","items":{"type":"string"}},"matrix":{"type":"array","items":{"type":"array","items":{"type":"number"}}},"opts":{"type":"object","description":"Options.","properties":{"depth":{"type":"integer","nullable":true}},"required":["depth"]},"mode":{"type":"string","enum":["a","b"],"nullable":true}},"required":["tags","opts"]},"response":{"description":"The result.","type":"object"}}}]`
|
||||
|
||||
// jinja2-verified render of complexToolsJSON. Notable template quirks pinned
|
||||
// here: nested array items go through format_argument with ESCAPED keys and
|
||||
// an un-uppercased type (<|"|>type<|"|>:<|"|>number<|"|>), while direct item
|
||||
// types are uppercased; properties dictsort case-insensitively.
|
||||
const complexToolsBlock = `<|tool>declaration:complex_tool{description:<|"|>A complex tool.<|"|>,parameters:{properties:{matrix:{items:{items:{<|"|>type<|"|>:<|"|>number<|"|>},type:<|"|>ARRAY<|"|>},type:<|"|>ARRAY<|"|>},mode:{enum:[<|"|>a<|"|>,<|"|>b<|"|>],nullable:true,type:<|"|>STRING<|"|>},opts:{description:<|"|>Options.<|"|>,properties:{depth:{nullable:true,type:<|"|>INTEGER<|"|>}},required:[<|"|>depth<|"|>],type:<|"|>OBJECT<|"|>},tags:{description:<|"|>Tags.<|"|>,items:{type:<|"|>STRING<|"|>},type:<|"|>ARRAY<|"|>}},required:[<|"|>tags<|"|>,<|"|>opts<|"|>],type:<|"|>OBJECT<|"|>},response:{description:<|"|>The result.<|"|>,type:<|"|>OBJECT<|"|>}}<tool|>`
|
||||
|
||||
type renderGemma4Case struct {
|
||||
msgs []*pb.Message
|
||||
toolsJSON string
|
||||
// nImages mirrors len(PredictOptions.Images): the OpenAI layer strips
|
||||
// image content parts out of the messages, so the renderer re-injects
|
||||
// one engine marker per image on the last user message (see the IMAGE
|
||||
// NOTE on RenderGemma4).
|
||||
nImages int
|
||||
enableThinking bool
|
||||
noGenerationPrompt bool // inverted so the zero value is the common case
|
||||
expected string
|
||||
}
|
||||
|
||||
var _ = Describe("RenderGemma4", func() {
|
||||
DescribeTable("renders the canonical gemma4 prompt",
|
||||
func(c renderGemma4Case) {
|
||||
out, err := RenderGemma4(c.msgs, c.toolsJSON, c.nImages, c.enableThinking, !c.noGenerationPrompt)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out).To(Equal(c.expected))
|
||||
// The C-ABI generate prepends BOS itself: a literal <bos>
|
||||
// anywhere in the rendered prompt would double-encode it.
|
||||
Expect(out).ToNot(ContainSubstring("<bos>"))
|
||||
},
|
||||
|
||||
// transformers fixture (test_diffusion_gemma_chat_template), sans <bos>:
|
||||
// default thinking pre-opens an EMPTY thought channel in the
|
||||
// generation prompt.
|
||||
Entry("single user message, default (no thinking)", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "Write a long essay about Portugal."},
|
||||
},
|
||||
expected: "<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// transformers fixture (test_diffusion_gemma_chat_template_with_thinking),
|
||||
// sans <bos>: a system turn carrying <|think|> and NO auto-opened
|
||||
// thought channel.
|
||||
Entry("enable_thinking=true", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "Write a long essay about Portugal."},
|
||||
},
|
||||
enableThinking: true,
|
||||
expected: "<|turn>system\n<|think|>\n<turn|>\n<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n",
|
||||
}),
|
||||
|
||||
Entry("multi-turn user/assistant/user", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "Hello, who are you?"},
|
||||
{Role: "assistant", Content: "I am Gemma, a helpful assistant."},
|
||||
{Role: "user", Content: "Tell me a joke."},
|
||||
},
|
||||
expected: "<|turn>user\nHello, who are you?<turn|>\n<|turn>model\nI am Gemma, a helpful assistant.<turn|>\n<|turn>user\nTell me a joke.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// tpl L178-L195: a leading system message is folded into the system
|
||||
// turn (trimmed) and consumed from the loop.
|
||||
Entry("system message folds into the system turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "system", Content: "You are a pirate."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|turn>system\nYou are a pirate.<turn|>\n<|turn>user\nHello!<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// tpl L182-L185: <|think|> goes at the very top of the SAME system
|
||||
// turn, before the system prompt text.
|
||||
Entry("system message with enable_thinking shares the turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "system", Content: "You are a pirate."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
enableThinking: true,
|
||||
expected: "<|turn>system\n<|think|>\nYou are a pirate.<turn|>\n<|turn>user\nHello!<turn|>\n<|turn>model\n",
|
||||
}),
|
||||
|
||||
// tpl L196-L203: tool declarations render in the system turn, one
|
||||
// <|tool>declaration:...<tool|> block per tool, no separators.
|
||||
Entry("tools array (two functions)", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "What is the weather in Tokyo?"},
|
||||
},
|
||||
toolsJSON: testToolsJSON,
|
||||
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// format_parameters deep branches (tpl L1-L85) + response declaration
|
||||
// (tpl L106-L116).
|
||||
Entry("complex tool schema (array items, nullable, nested object, response)", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
},
|
||||
toolsJSON: complexToolsJSON,
|
||||
expected: "<|turn>system\n" + complexToolsBlock + "<turn|>\n<|turn>user\ngo<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// tpl L243-L313: assistant tool_calls render as
|
||||
// <|tool_call>call:name{args}<tool_call|>; the following role=tool
|
||||
// message renders inline as <|tool_response>response:name{value:..}
|
||||
// <tool_response|>; the model turn stays OPEN (no <turn|>, no new
|
||||
// generation prompt) so the model continues after the response.
|
||||
Entry("assistant tool_calls + role=tool result", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "What is the weather in Tokyo?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`},
|
||||
{Role: "tool", ToolCallId: "call_1", Content: "Sunny, 22 degrees celsius."},
|
||||
},
|
||||
toolsJSON: testToolsJSON,
|
||||
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny, 22 degrees celsius.<|"|>}<tool_response|>`,
|
||||
}),
|
||||
|
||||
// tpl L348-L349: a tool_calls turn with no rendered responses ends
|
||||
// on an OPEN <|tool_response> marker for the runtime to fill, and
|
||||
// add_generation_prompt adds nothing (tpl L357).
|
||||
Entry("assistant tool_calls without a result leaves <|tool_response> open", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "What is the weather in Tokyo?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`},
|
||||
},
|
||||
toolsJSON: testToolsJSON,
|
||||
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>`,
|
||||
}),
|
||||
|
||||
// tpl L237-L241: reasoning_content renders as a thought channel only
|
||||
// on a tool-calling turn after the last user message.
|
||||
Entry("reasoning_content with tool_calls renders the thought channel", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "weather?"},
|
||||
{Role: "assistant", Content: "", ReasoningContent: "I should call the tool", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`},
|
||||
{Role: "tool", ToolCallId: "c1", Content: "Sunny"},
|
||||
},
|
||||
expected: "<|turn>user\nweather?<turn|>\n<|turn>model\n<|channel>thought\nI should call the tool\n<channel|>" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}<tool_response|>`,
|
||||
}),
|
||||
|
||||
// tpl L220-L235: the assistant answer following its own tool round
|
||||
// continues the SAME model turn (no second <|turn>model).
|
||||
Entry("tool round then final assistant answer then user", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "weather?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`},
|
||||
{Role: "tool", ToolCallId: "c1", Content: "Sunny"},
|
||||
{Role: "assistant", Content: "It is sunny."},
|
||||
{Role: "user", Content: "thanks"},
|
||||
},
|
||||
expected: "<|turn>user\nweather?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}<tool_response|>` + "It is sunny.<turn|>\n<|turn>user\nthanks<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// format_argument (tpl L118-L147): numbers keep their JSON literal,
|
||||
// booleans lower-case, nested maps have unquoted dictsorted keys,
|
||||
// arrays bracketed; top-level args are dictsorted case-insensitively.
|
||||
Entry("tool_call argument types (number/bool/nested/array)", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"count\":42,\"ratio\":3.5,\"flag\":true,\"off\":false,\"nested\":{\"x\":\"y\",\"n\":7},\"list\":[\"a\",1,true]}"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n" + `<|tool_call>call:f{count:42,flag:true,list:[<|"|>a<|"|>,1,true],nested:{n:7,x:<|"|>y<|"|>},off:false,ratio:3.5}<tool_call|><|tool_response>`,
|
||||
}),
|
||||
|
||||
// jinja dictsort is case-insensitive: alpha sorts before Beta.
|
||||
Entry("tool_call argument dictsort is case-insensitive", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"Beta\":1,\"alpha\":2}"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{alpha:2,Beta:1}<tool_call|><|tool_response>",
|
||||
}),
|
||||
|
||||
// jinja renders Python None as "None" (round-trips through vLLM's
|
||||
// parser, which lowers "none" back to null).
|
||||
Entry("tool_call null argument renders as None", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"maybe\":null}"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{maybe:None}<tool_call|><|tool_response>",
|
||||
}),
|
||||
|
||||
Entry("tool_call empty arguments render empty braces", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{}<tool_call|><|tool_response>",
|
||||
}),
|
||||
|
||||
// tpl L253-L254: a non-object arguments string renders verbatim.
|
||||
Entry("tool_call non-object string arguments render verbatim", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"just text"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{just text}<tool_call|><|tool_response>",
|
||||
}),
|
||||
|
||||
// tpl L278-L285: unmatched tool_call_id falls back to the tool
|
||||
// message's own name.
|
||||
Entry("tool result name falls back when tool_call_id does not match", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`},
|
||||
{Role: "tool", ToolCallId: "OTHER", Name: "named_tool", Content: "out"},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n" + `<|tool_call>call:f{}<tool_call|><|tool_response>response:named_tool{value:<|"|>out<|"|>}<tool_response|>`,
|
||||
}),
|
||||
|
||||
// strip_thinking (tpl L148-L158): historical assistant content loses
|
||||
// its <|channel>...<channel|> spans.
|
||||
Entry("assistant content thinking channels are stripped", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "<|channel>thought\nsecret\n<channel|>visible answer"},
|
||||
{Role: "user", Content: "more"},
|
||||
},
|
||||
expected: "<|turn>user\nhi<turn|>\n<|turn>model\nvisible answer<turn|>\n<|turn>user\nmore<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// tpl L220-L235: consecutive assistant messages suppress the second
|
||||
// <|turn>model (continuation), but each still closes with <turn|>.
|
||||
Entry("consecutive assistant messages continue the model turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "part one"},
|
||||
{Role: "assistant", Content: "part two"},
|
||||
{Role: "user", Content: "ok"},
|
||||
},
|
||||
expected: "<|turn>user\nhi<turn|>\n<|turn>model\npart one<turn|>\npart two<turn|>\n<|turn>user\nok<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
Entry("add_generation_prompt=false renders no model turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
},
|
||||
noGenerationPrompt: true,
|
||||
expected: "<|turn>user\nhi<turn|>\n",
|
||||
}),
|
||||
|
||||
// One engine marker per image, appended directly after the user
|
||||
// text with no separator (tpl L323-L341 emits parts back-to-back;
|
||||
// "<image>" is dllm_capi.h's splice marker, not the template's
|
||||
// <|image|> text token - see the IMAGE NOTE on RenderGemma4).
|
||||
Entry("one image appends one engine marker to the user message", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "What is in this picture?"},
|
||||
},
|
||||
nImages: 1,
|
||||
expected: "<|turn>user\nWhat is in this picture?<image><turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
Entry("multiple images append markers in image order", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "Compare these."},
|
||||
},
|
||||
nImages: 3,
|
||||
expected: "<|turn>user\nCompare these.<image><image><image><turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// Flattened delivery loses per-message attribution, so all images
|
||||
// attach to the LAST user message (llama.cpp grpc-server convention).
|
||||
Entry("images attach to the last user message in multi-turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "hello"},
|
||||
{Role: "user", Content: "and this?"},
|
||||
},
|
||||
nImages: 1,
|
||||
expected: "<|turn>user\nhi<turn|>\n<|turn>model\nhello<turn|>\n<|turn>user\nand this?<image><turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// tpl L346: the markers count as captured_content, so an image-only
|
||||
// user message still has content and closes its turn normally.
|
||||
Entry("image with empty user text still closes the turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: ""},
|
||||
},
|
||||
nImages: 1,
|
||||
expected: "<|turn>user\n<image><turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
)
|
||||
|
||||
Describe("error handling", func() {
|
||||
It("fails loud on an unknown role", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "narrator", Content: "Meanwhile..."},
|
||||
}, "", 0, false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(`unknown role "narrator"`))
|
||||
})
|
||||
|
||||
It("fails on invalid tools JSON", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
}, "{not json", 0, false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tools JSON"))
|
||||
})
|
||||
|
||||
It("fails on invalid tool_calls JSON", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "", ToolCalls: "{not json"},
|
||||
}, "", 0, false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tool_calls JSON"))
|
||||
})
|
||||
|
||||
It("fails on an orphan tool message, naming its index", func() {
|
||||
// A role:tool message with no preceding assistant tool_calls turn
|
||||
// would be silently dropped by the jinja; we fail loud instead.
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "tool", Content: `{"temp": 20}`, ToolCallId: "call_1"},
|
||||
}, "", 0, false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("orphan tool message 1"))
|
||||
})
|
||||
|
||||
It("fails on trailing garbage after the tools JSON array", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
}, "[] junk", 0, false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tools JSON"))
|
||||
})
|
||||
|
||||
It("fails when the tools JSON is not an array", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
}, `{"type":"function"}`, 0, false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tools JSON is not an array"))
|
||||
})
|
||||
|
||||
It("fails when a tools array element is not an object", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
}, `[42]`, 0, false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tools[0] is not an object"))
|
||||
})
|
||||
|
||||
It("rejects a nil message via the unknown-role check", func() {
|
||||
// Pins current behavior: pb getters are nil-safe, so a nil message
|
||||
// reads as role "" and trips the fail-loud unknown-role guard.
|
||||
_, err := RenderGemma4([]*pb.Message{nil}, "", 0, false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(`unknown role "" in message 0`))
|
||||
})
|
||||
|
||||
It("fails loud on images with no user message to attach them to", func() {
|
||||
// The engine would reject the markerless prompt anyway
|
||||
// (marker/image count mismatch); the renderer surfaces the bad
|
||||
// request with a usable message instead.
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "system", Content: "sys"},
|
||||
{Role: "assistant", Content: "hi"},
|
||||
}, "", 1, false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no user message"))
|
||||
})
|
||||
})
|
||||
})
|
||||
98
backend/go/dllm/main.go
Normal file
98
backend/go/dllm/main.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package main
|
||||
|
||||
// Started internally by LocalAI - one gRPC server per loaded model.
|
||||
//
|
||||
// Loads libdllm.so via purego and registers the flat C-ABI declared in
|
||||
// dllm.cpp's include/dllm_capi.h (ABI v1): 9 mandatory symbols plus the
|
||||
// Dlsym-probed optional multimodal pair. The library name can
|
||||
// be overridden with DLLM_LIBRARY (mirrors the PARAKEET_LIBRARY /
|
||||
// WHISPER_LIBRARY convention in the sibling backends); the default looks
|
||||
// for the .so next to this binary (run.sh puts the package dir on
|
||||
// LD_LIBRARY_PATH).
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
// loadCAPI dlopens libName and binds the 9 dllm_capi_* entry points 1:1 to
|
||||
// dllm_capi.h, so an `nm libdllm.so | grep dllm_capi` is enough to spot
|
||||
// drift. Shared with the test suite (ensureLibLoaded), which drives the
|
||||
// bridge without the gRPC server.
|
||||
//
|
||||
// The C-ABI returns malloc'd char* buffers from tokenize_json/generate; we
|
||||
// register those as uintptr so we get the raw pointer back and can call
|
||||
// dllm_capi_free_string on it (purego's string return would copy and forget
|
||||
// the original pointer, leaking it on every call). last_error returns a
|
||||
// BORROWED pointer instead, so it is registered as a plain string: purego
|
||||
// copies it and nothing must be freed.
|
||||
func loadCAPI(libName string) error {
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dllm: dlopen %q: %w", libName, err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&cppAbiVersion, "dllm_capi_abi_version"},
|
||||
{&cppLoad, "dllm_capi_load"},
|
||||
{&cppFree, "dllm_capi_free"},
|
||||
{&cppLastError, "dllm_capi_last_error"},
|
||||
{&cppFreeString, "dllm_capi_free_string"},
|
||||
{&cppTokenizeJSON, "dllm_capi_tokenize_json"},
|
||||
{&cppGenerate, "dllm_capi_generate"},
|
||||
{&cppGenerateStream, "dllm_capi_generate_stream"},
|
||||
{&cppCancel, "dllm_capi_cancel"},
|
||||
}
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
// Multimodal entry points (dllm_capi.h's P4 surface). Additive: the ABI
|
||||
// version stays 1 and consumers detect the surface by probing the symbols
|
||||
// (the parakeet-cpp optional-symbol pattern), so the backend still loads
|
||||
// against an older text-only libdllm.so - image requests then fail with
|
||||
// errMMUnsupported instead of a boot failure.
|
||||
if sym, err := purego.Dlsym(lib, "dllm_capi_generate_mm"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&cppGenerateMM, lib, "dllm_capi_generate_mm")
|
||||
}
|
||||
if sym, err := purego.Dlsym(lib, "dllm_capi_generate_stream_mm"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&cppGenerateStreamMM, lib, "dllm_capi_generate_stream_mm")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("DLLM_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libdllm.so"
|
||||
}
|
||||
|
||||
if err := loadCAPI(libName); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Hard-fail on an ABI mismatch: the flat-pointer bindings above would
|
||||
// otherwise misbehave silently against a future libdllm.so.
|
||||
if v := cAbiVersion(); v != dllmABIVersion {
|
||||
panic(fmt.Errorf("dllm: libdllm.so ABI=%d, this backend speaks ABI=%d", v, dllmABIVersion))
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "[dllm] ABI=%d multimodal=%t\n", cAbiVersion(), cMMSupported())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &Dllm{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
24
backend/go/dllm/package.sh
Executable file
24
backend/go/dllm/package.sh
Executable file
@@ -0,0 +1,24 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# T1 packaging stub: copy the binary, run.sh and libdllm.so into package/.
|
||||
# The full ldd walk (libc, libstdc++, libgomp, GPU runtimes, arch
|
||||
# detection) lands with the registration task, mirroring
|
||||
# backend/go/whisper/package.sh.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
|
||||
cp -avf "$CURDIR/dllm-grpc" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
# libdllm.so + any soname symlinks, should upstream ever add them.
|
||||
cp -avf "$CURDIR"/libdllm.so* "$CURDIR/package/lib/" 2>/dev/null || {
|
||||
echo "ERROR: libdllm.so not found in $CURDIR, run 'make' first" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
echo "T1 package layout (full ldd walk lands with registration):"
|
||||
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"
|
||||
16
backend/go/dllm/run.sh
Executable file
16
backend/go/dllm/run.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
|
||||
|
||||
# If a self-contained ld.so was packaged, route through it so the
|
||||
# packaged libc / libstdc++ are used instead of the host's (matches the
|
||||
# whisper / parakeet-cpp backends' runtime layout).
|
||||
if [ -f "$CURDIR/lib/ld.so" ]; then
|
||||
echo "Using lib/ld.so"
|
||||
exec "$CURDIR/lib/ld.so" "$CURDIR/dllm-grpc" "$@"
|
||||
fi
|
||||
|
||||
exec "$CURDIR/dllm-grpc" "$@"
|
||||
@@ -1,6 +1,6 @@
|
||||
# parakeet-cpp backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=9edf17c3ada66e0f881dcff155492867db7ac4cf
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# whisper.cpp / ds4 / vibevoice-cpp convention.
|
||||
#
|
||||
@@ -15,7 +15,7 @@
|
||||
# That's what the L0 smoke test uses. The default target below does the
|
||||
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
|
||||
|
||||
PARAKEET_VERSION?=9edf17c3ada66e0f881dcff155492867db7ac4cf
|
||||
PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
|
||||
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
|
||||
|
||||
GOCMD?=go
|
||||
|
||||
@@ -7,8 +7,12 @@ import "time"
|
||||
type batchRequest struct {
|
||||
pcm []float32
|
||||
decoder int32
|
||||
tag string
|
||||
reply chan batchReply
|
||||
// language is the per-request target locale ("" means the model default).
|
||||
// parakeet.cpp's batched C-API takes ONE target_lang for the whole batch,
|
||||
// so the dispatcher only coalesces requests that share a language.
|
||||
language string
|
||||
tag string
|
||||
reply chan batchReply
|
||||
}
|
||||
|
||||
// batchReply carries one per-item JSON object string (an element of the C-API's
|
||||
@@ -43,13 +47,25 @@ func newBatcher(maxSize int, maxWait time.Duration, runBatch func([]*batchReques
|
||||
// run is the dispatcher loop: accumulate submitted requests until either maxSize
|
||||
// is reached or maxWait elapses since the first queued request, then dispatch.
|
||||
// Exits when stop is closed (draining any partially-filled batch first).
|
||||
//
|
||||
// A batch carries ONE language (parakeet.cpp's batched C-API takes a single
|
||||
// target_lang), so a request whose language differs from the batch leader is
|
||||
// not coalesced: it is held in carry and becomes the leader of the next batch.
|
||||
// carry is therefore never dropped and its caller never deadlocks: every batch
|
||||
// (including a lone carry on stop) is dispatched, and runBatch replies to all.
|
||||
func (b *batcher) run(stop <-chan struct{}) {
|
||||
var carry *batchRequest
|
||||
for {
|
||||
var first *batchRequest
|
||||
select {
|
||||
case first = <-b.submit:
|
||||
case <-stop:
|
||||
return
|
||||
if carry != nil {
|
||||
// A mismatched request from the previous fill leads this batch.
|
||||
first, carry = carry, nil
|
||||
} else {
|
||||
select {
|
||||
case first = <-b.submit:
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
batch := []*batchRequest{first}
|
||||
|
||||
@@ -64,12 +80,22 @@ func (b *batcher) run(stop <-chan struct{}) {
|
||||
for len(batch) < b.maxSize {
|
||||
select {
|
||||
case r := <-b.submit:
|
||||
if r.language != first.language {
|
||||
// Different language: carry it to the next batch so this
|
||||
// batch stays single-language, then dispatch what we have.
|
||||
carry = r
|
||||
break fill
|
||||
}
|
||||
batch = append(batch, r)
|
||||
case <-timer.C:
|
||||
break fill
|
||||
case <-stop:
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
// Don't strand a carried request's caller on shutdown.
|
||||
if carry != nil {
|
||||
b.runBatch([]*batchRequest{carry})
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,4 +105,60 @@ var _ = Describe("batcher", func() {
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
|
||||
It("never coalesces requests with different languages into one batch", func() {
|
||||
// parakeet.cpp's batched C-API takes ONE target_lang per batch, so the
|
||||
// dispatcher must keep every dispatched batch single-language. Submit a
|
||||
// mix of languages and assert (a) no batch ever carries more than one
|
||||
// distinct language and (b) every submitted request still gets a reply
|
||||
// (the mismatched carry-over is never dropped).
|
||||
var mu sync.Mutex
|
||||
var langsPerBatch [][]string
|
||||
run := func(reqs []*batchRequest) {
|
||||
seen := map[string]struct{}{}
|
||||
var distinct []string
|
||||
for _, r := range reqs {
|
||||
if _, ok := seen[r.language]; !ok {
|
||||
seen[r.language] = struct{}{}
|
||||
distinct = append(distinct, r.language)
|
||||
}
|
||||
}
|
||||
mu.Lock()
|
||||
langsPerBatch = append(langsPerBatch, distinct)
|
||||
mu.Unlock()
|
||||
echoReply(reqs)
|
||||
}
|
||||
// Large window + size so the fill loop stays open across submits and the
|
||||
// language constraint (not the timer) is what splits the batches.
|
||||
b := newBatcher(16, 200*time.Millisecond, run)
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
|
||||
langs := []string{"en", "en", "de", "de", "en", "fr", "fr"}
|
||||
const N = 7
|
||||
var wg sync.WaitGroup
|
||||
got := make([]string, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: string(rune('a' + i)), language: langs[i], reply: rep}
|
||||
got[i] = (<-rep).json
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
// Invariant: every dispatched batch is single-language.
|
||||
for _, distinct := range langsPerBatch {
|
||||
Expect(len(distinct)).To(Equal(1), "a batch coalesced more than one language: %v", distinct)
|
||||
}
|
||||
// Liveness: every request got a reply (carry-over never stranded).
|
||||
for i := 0; i < N; i++ {
|
||||
Expect(got[i]).To(Equal(string(rune('a' + i))))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -47,6 +48,13 @@ var (
|
||||
// side reads them as const float*/const int*.
|
||||
CppTranscribePcmBatchJSON func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32) uintptr
|
||||
|
||||
// CppTranscribePcmBatchJSONLang is the multilingual variant of the batched
|
||||
// JSON entry point: identical, plus a trailing target_lang. "" (the model
|
||||
// default, "auto") is passed for non-prompt models, which ignore it; an
|
||||
// unknown locale on a prompt model returns 0 and sets last_error. Present
|
||||
// only in newer libparakeet.so; nil falls back to CppTranscribePcmBatchJSON.
|
||||
CppTranscribePcmBatchJSONLang func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32, targetLang string) uintptr
|
||||
|
||||
// Cache-aware streaming (RNN-T) entry points. stream_begin returns 0 for
|
||||
// non-streaming models. feed/finalize return a malloc'd char* (uintptr,
|
||||
// freed via CppFreeString); feed writes 1 to *eouOut on an <EOU>/<EOB>.
|
||||
@@ -54,6 +62,18 @@ var (
|
||||
CppStreamFeed func(s uintptr, pcm []float32, nSamples int32, eouOut unsafe.Pointer) uintptr
|
||||
CppStreamFinalize func(s uintptr) uintptr
|
||||
CppStreamFree func(s uintptr)
|
||||
|
||||
// CppStreamBeginLang is the multilingual variant of stream_begin: identical,
|
||||
// plus a trailing target_lang ("" means the model default). Present only in
|
||||
// newer libparakeet.so; nil falls back to CppStreamBegin.
|
||||
CppStreamBeginLang func(ctx uintptr, targetLang string) uintptr
|
||||
|
||||
// Streaming JSON variants (ABI v4): feed/finalize returning a malloc'd char*
|
||||
// JSON document {text,eou,frame_sec,words} (uintptr, freed via CppFreeString)
|
||||
// so streaming segments can carry per-word timestamps. Present only in newer
|
||||
// libparakeet.so; nil falls back to the text-only CppStreamFeed/Finalize path.
|
||||
CppStreamFeedJSON func(s uintptr, pcm []float32, nSamples int32) uintptr
|
||||
CppStreamFinalizeJSON func(s uintptr) uintptr
|
||||
)
|
||||
|
||||
// streamChunkSamples is how much 16 kHz mono PCM we hand to stream_feed per
|
||||
@@ -71,9 +91,26 @@ const streamChunkSamples = 16000
|
||||
//
|
||||
// "start"/"end"/"t" are seconds; "conf" is confidence in (0,1].
|
||||
type transcriptJSON struct {
|
||||
Text string `json:"text"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
Tokens []transcriptToken `json:"tokens"`
|
||||
Text string `json:"text"`
|
||||
FrameSec float64 `json:"frame_sec"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
Tokens []transcriptToken `json:"tokens"`
|
||||
}
|
||||
|
||||
// streamFeedJSON mirrors the document returned by
|
||||
// parakeet_capi_stream_feed_json / parakeet_capi_stream_finalize_json (ABI v4):
|
||||
//
|
||||
// {"text":"...","eou":0,"frame_sec":0.080000,
|
||||
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...]}
|
||||
//
|
||||
// "text" is the newly-finalized text since the last call; "eou" is 1 when an
|
||||
// <EOU>/<EOB> fired this feed; "words" are the words finalized this call with
|
||||
// absolute (stream-relative) start/end seconds.
|
||||
type streamFeedJSON struct {
|
||||
Text string `json:"text"`
|
||||
Eou int `json:"eou"`
|
||||
FrameSec float64 `json:"frame_sec"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
}
|
||||
|
||||
type transcriptWord struct {
|
||||
@@ -102,6 +139,10 @@ type ParakeetCpp struct {
|
||||
engineMu sync.Mutex // sole guard of the one C engine (dispatcher + streaming)
|
||||
bat *batcher
|
||||
batStop chan struct{}
|
||||
// segmentGapFrames is NeMo's segment_gap_threshold in ENCODER FRAMES (model
|
||||
// YAML option, default 0=off). When >0 it adds NeMo's silence-gap split on
|
||||
// top of the punctuation split; converted to seconds via the JSON frame_sec.
|
||||
segmentGapFrames int
|
||||
}
|
||||
|
||||
// Load is the LocalAI gRPC entry point for LoadModel: it calls
|
||||
@@ -131,6 +172,11 @@ func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
|
||||
if maxWaitMs < 0 {
|
||||
maxWaitMs = 0
|
||||
}
|
||||
|
||||
// NeMo's segment_gap_threshold (encoder frames, default 0=off). Off by
|
||||
// default matches NeMo's default (punctuation-only segments); when set it
|
||||
// additionally splits segments on inter-word silence (see transcriptResultFromDoc).
|
||||
p.segmentGapFrames = optInt(opts, "segment_gap_threshold", 0)
|
||||
if CppTranscribePcmBatchJSON != nil {
|
||||
p.batStop = make(chan struct{})
|
||||
p.bat = newBatcher(maxSize, time.Duration(maxWaitMs)*time.Millisecond, p.runBatch)
|
||||
@@ -186,8 +232,19 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
if len(reqs) > 0 {
|
||||
dec = reqs[0].decoder
|
||||
}
|
||||
// All requests in a batch share one language (the batcher coalesces only
|
||||
// same-language requests), so any element's language describes the batch.
|
||||
lang := ""
|
||||
if len(reqs) > 0 {
|
||||
lang = reqs[0].language
|
||||
}
|
||||
p.engineMu.Lock()
|
||||
cstr := CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
|
||||
var cstr uintptr
|
||||
if CppTranscribePcmBatchJSONLang != nil {
|
||||
cstr = CppTranscribePcmBatchJSONLang(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec, lang)
|
||||
} else {
|
||||
cstr = CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
|
||||
}
|
||||
p.engineMu.Unlock()
|
||||
if cstr == 0 {
|
||||
err := fmt.Errorf("parakeet-cpp: batch transcribe failed: %s", CppLastError(p.ctxPtr))
|
||||
@@ -225,21 +282,31 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
// OpenAI API, whose default is segment-level); token ids always populate
|
||||
// Segment.Tokens.
|
||||
//
|
||||
// translate/diarize/prompt/temperature/language/threads are not applicable to
|
||||
// parakeet and are ignored; streaming is handled by AudioTranscriptionStream
|
||||
// translate/diarize/prompt/temperature/threads are not applicable to parakeet
|
||||
// and are ignored; language is honored on the batched + streaming paths (see
|
||||
// opts.GetLanguage() below); streaming is handled by AudioTranscriptionStream
|
||||
// (L2).
|
||||
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if p.ctxPtr == 0 {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: model not loaded")
|
||||
return pb.TranscriptResult{}, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
|
||||
// Fallback when the batched C-API is unavailable: transcribe directly from
|
||||
// the file path (original behavior, no batching).
|
||||
// Fallback when the batched C-API is unavailable: transcribe from a file
|
||||
// path (original behavior, no batching). The C library's audio loader only
|
||||
// understands 16 kHz mono WAV/PCM, so convert the input first - otherwise
|
||||
// any non-WAV upload (MP3, etc.) fails with "failed to load audio". This
|
||||
// mirrors what every other audio backend (whisper, crispasr) does via
|
||||
// utils.AudioToWav before handing the file to the engine.
|
||||
if p.bat == nil {
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, opts.Dst, 0)
|
||||
converted, cleanup, err := convertToWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer cleanup()
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, converted, 0)
|
||||
if cstr == 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", CppLastError(p.ctxPtr))
|
||||
}
|
||||
@@ -249,7 +316,7 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
|
||||
}
|
||||
|
||||
// Batched path: decode to PCM, submit to the batcher, wait for this request's
|
||||
@@ -261,7 +328,7 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
|
||||
}
|
||||
rep := make(chan batchReply, 1)
|
||||
select {
|
||||
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, reply: rep}:
|
||||
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, language: opts.GetLanguage(), reply: rep}:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
@@ -278,34 +345,169 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
|
||||
if err := json.Unmarshal([]byte(res.json), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
|
||||
}
|
||||
|
||||
// segmentSeparators is NeMo's default segment_seperators (sentence-ending
|
||||
// punctuation). Splitting on these matches NeMo's default segment timestamps.
|
||||
var segmentSeparators = []rune{'.', '?', '!'}
|
||||
|
||||
// transcriptResultFromDoc maps a decoded transcriptJSON to a TranscriptResult,
|
||||
// synthesising a single whole-clip segment and attaching word timings only when
|
||||
// the caller requested word granularity. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest) pb.TranscriptResult {
|
||||
// grouping words into NeMo-faithful segments (see splitWordsIntoSegments). The
|
||||
// optional gapFrames (NeMo's segment_gap_threshold, in encoder FRAMES; 0=off)
|
||||
// additionally splits on inter-word silence; it is converted to a seconds gap
|
||||
// with the document's frame_sec. Per-segment word timings are attached only when
|
||||
// the caller requested word granularity; token ids populate each segment's
|
||||
// Tokens by time-window membership. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gapFrames int) pb.TranscriptResult {
|
||||
text := strings.TrimSpace(doc.Text)
|
||||
words := make([]*pb.TranscriptWord, 0, len(doc.Words))
|
||||
for _, w := range doc.Words {
|
||||
words = append(words, &pb.TranscriptWord{Start: secondsToNanos(w.Start), End: secondsToNanos(w.End), Text: w.W})
|
||||
|
||||
// Frame-unit gap threshold -> seconds (NeMo segment_gap_threshold). 0 = off.
|
||||
gapSeconds := 0.0
|
||||
if gapFrames > 0 {
|
||||
if doc.FrameSec > 0 {
|
||||
gapSeconds = float64(gapFrames) * doc.FrameSec
|
||||
} else {
|
||||
xlog.Warn("parakeet-cpp: segment_gap_threshold set but libparakeet.so " +
|
||||
"did not report frame_sec; falling back to punctuation-only segments")
|
||||
}
|
||||
}
|
||||
tokens := make([]int32, 0, len(doc.Tokens))
|
||||
for _, t := range doc.Tokens {
|
||||
tokens = append(tokens, t.ID)
|
||||
|
||||
groups := splitWordsIntoSegments(doc.Words, segmentSeparators, gapSeconds)
|
||||
if len(groups) == 0 {
|
||||
// No words (edge case): single whole-clip text segment.
|
||||
return pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: []*pb.TranscriptSegment{{Id: 0, Text: text}},
|
||||
}
|
||||
}
|
||||
var segStart, segEnd int64
|
||||
if len(words) > 0 {
|
||||
segStart = words[0].Start
|
||||
segEnd = words[len(words)-1].End
|
||||
|
||||
wantWords := wordsRequested(opts.TimestampGranularities)
|
||||
segments := make([]*pb.TranscriptSegment, 0, len(groups))
|
||||
for id, group := range groups {
|
||||
parts := make([]string, len(group))
|
||||
for i, gw := range group {
|
||||
parts[i] = gw.W
|
||||
}
|
||||
seg := &pb.TranscriptSegment{
|
||||
Id: int32(id),
|
||||
Start: secondsToNanos(group[0].Start),
|
||||
End: secondsToNanos(group[len(group)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
Tokens: tokensInWindow(doc.Tokens, group[0].Start, group[len(group)-1].End),
|
||||
}
|
||||
if wantWords {
|
||||
ws := make([]*pb.TranscriptWord, len(group))
|
||||
for i, gw := range group {
|
||||
ws[i] = &pb.TranscriptWord{Start: secondsToNanos(gw.Start), End: secondsToNanos(gw.End), Text: gw.W}
|
||||
}
|
||||
seg.Words = ws
|
||||
}
|
||||
segments = append(segments, seg)
|
||||
}
|
||||
seg := &pb.TranscriptSegment{Id: 0, Start: segStart, End: segEnd, Text: text, Tokens: tokens}
|
||||
if wordsRequested(opts.TimestampGranularities) {
|
||||
seg.Words = words
|
||||
}
|
||||
return pb.TranscriptResult{Text: text, Segments: []*pb.TranscriptSegment{seg}}
|
||||
return pb.TranscriptResult{Text: text, Segments: segments}
|
||||
}
|
||||
|
||||
// splitWordsIntoSegments groups words into segments exactly as NeMo's
|
||||
// get_segment_offsets does (nemo/collections/asr/parts/utils/timestamp_utils.py).
|
||||
// Walking the words, it closes a segment when (1) the gap rule is enabled
|
||||
// (gapSeconds > 0) and the segment already has words and the gap from the
|
||||
// previous word's end to this word's start is >= gapSeconds - the current word
|
||||
// then STARTS a new segment - or, checked only when the gap rule did not apply
|
||||
// (NeMo's elif), (2) the word ends with (or is) a separator, which closes the
|
||||
// segment INCLUDING that word. Trailing words flush into a final segment.
|
||||
// gapSeconds <= 0 disables the gap rule, matching NeMo's default
|
||||
// segment_gap_threshold=None (punctuation-only segments).
|
||||
func splitWordsIntoSegments(words []transcriptWord, separators []rune, gapSeconds float64) [][]transcriptWord {
|
||||
var segments [][]transcriptWord
|
||||
var cur []transcriptWord
|
||||
for i, word := range words {
|
||||
gapActive := gapSeconds > 0 && len(cur) > 0
|
||||
if gapActive && (word.Start-words[i-1].End) >= gapSeconds {
|
||||
segments = append(segments, cur)
|
||||
cur = []transcriptWord{word}
|
||||
continue
|
||||
}
|
||||
if !gapActive && endsWithSeparator(word.W, separators) {
|
||||
cur = append(cur, word)
|
||||
segments = append(segments, cur)
|
||||
cur = nil
|
||||
continue
|
||||
}
|
||||
cur = append(cur, word)
|
||||
}
|
||||
if len(cur) > 0 {
|
||||
segments = append(segments, cur)
|
||||
}
|
||||
return segments
|
||||
}
|
||||
|
||||
// endsWithSeparator reports whether w's last rune is in separators (matching
|
||||
// NeMo's `word[-1] in delims or word in delims`).
|
||||
func endsWithSeparator(w string, separators []rune) bool {
|
||||
r := []rune(strings.TrimSpace(w))
|
||||
if len(r) == 0 {
|
||||
return false
|
||||
}
|
||||
last := r[len(r)-1]
|
||||
for _, s := range separators {
|
||||
if last == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// tokensInWindow returns the ids of tokens whose timestamp t falls in
|
||||
// [start, end] (inclusive), assigning each token to the segment that spans its
|
||||
// time. The last segment's end is the last word end, so the final token is
|
||||
// included.
|
||||
func tokensInWindow(tokens []transcriptToken, start, end float64) []int32 {
|
||||
var ids []int32
|
||||
for _, t := range tokens {
|
||||
if t.T >= start && t.T <= end {
|
||||
ids = append(ids, t.ID)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// streamSegmenter accumulates streaming words into per-utterance segments. EOU
|
||||
// is the model's own utterance boundary; each closed segment takes its start/end
|
||||
// from its first/last accumulated word.
|
||||
type streamSegmenter struct {
|
||||
segs []*pb.TranscriptSegment
|
||||
cur []transcriptWord
|
||||
nextID int32
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) add(doc streamFeedJSON) {
|
||||
s.cur = append(s.cur, doc.Words...)
|
||||
if doc.Eou != 0 {
|
||||
s.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) flush() {
|
||||
if len(s.cur) == 0 {
|
||||
return
|
||||
}
|
||||
parts := make([]string, len(s.cur))
|
||||
for i, w := range s.cur {
|
||||
parts[i] = w.W
|
||||
}
|
||||
s.segs = append(s.segs, &pb.TranscriptSegment{
|
||||
Id: s.nextID,
|
||||
Start: secondsToNanos(s.cur[0].Start),
|
||||
End: secondsToNanos(s.cur[len(s.cur)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
})
|
||||
s.nextID++
|
||||
s.cur = nil
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) segments() []*pb.TranscriptSegment { return s.segs }
|
||||
|
||||
// wordsRequested reports whether the caller asked for word-level timestamps.
|
||||
// The OpenAI transcription API gates word timings behind
|
||||
// timestamp_granularities[] containing "word" and defaults to segment-level
|
||||
@@ -342,7 +544,7 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
defer close(results)
|
||||
|
||||
if p.ctxPtr == 0 {
|
||||
return errors.New("parakeet-cpp: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
@@ -351,7 +553,12 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
stream := CppStreamBegin(p.ctxPtr)
|
||||
var stream uintptr
|
||||
if CppStreamBeginLang != nil {
|
||||
stream = CppStreamBeginLang(p.ctxPtr, opts.GetLanguage())
|
||||
} else {
|
||||
stream = CppStreamBegin(p.ctxPtr)
|
||||
}
|
||||
if stream == 0 {
|
||||
// Not a cache-aware streaming model: run a normal offline
|
||||
// transcription and emit it as one delta + a closing final result.
|
||||
@@ -380,6 +587,14 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
return err
|
||||
}
|
||||
|
||||
// ABI v4: when the streaming JSON entry points are present, drive them so the
|
||||
// per-utterance segments carry per-word start/end timestamps. Falls through to
|
||||
// the text-only loop below against an older libparakeet.so. Runs under the
|
||||
// engineMu already held above.
|
||||
if CppStreamFeedJSON != nil {
|
||||
return p.streamJSON(ctx, stream, data, duration, results)
|
||||
}
|
||||
|
||||
var (
|
||||
full strings.Builder
|
||||
segText strings.Builder
|
||||
@@ -456,21 +671,102 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamJSON drives the ABI v4 streaming JSON entry points: each feed/finalize
|
||||
// returns a {text,eou,frame_sec,words} document. The newly-finalized text is
|
||||
// emitted as a delta (unchanged streaming contract) while words are accumulated
|
||||
// into per-utterance segments (closed on EOU) so the closing FinalResult carries
|
||||
// timestamped segments. Runs under engineMu (already held by the caller).
|
||||
func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []float32,
|
||||
duration float32, results chan *pb.TranscriptStreamResponse) error {
|
||||
var (
|
||||
full strings.Builder
|
||||
seg streamSegmenter
|
||||
)
|
||||
// consume frees the malloc'd char* (a 0 return is an error), parses the JSON,
|
||||
// emits the delta, and routes words through the segmenter.
|
||||
consume := func(ret uintptr) error {
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
raw := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
var doc streamFeedJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return fmt.Errorf("parakeet-cpp: decode stream json: %w", err)
|
||||
}
|
||||
if doc.Text != "" {
|
||||
full.WriteString(doc.Text)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: doc.Text}
|
||||
}
|
||||
seg.add(doc)
|
||||
return nil
|
||||
}
|
||||
|
||||
for off := 0; off < len(data); off += streamChunkSamples {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(data))
|
||||
chunk := data[off:end]
|
||||
if err := consume(CppStreamFeedJSON(stream, chunk, int32(len(chunk)))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := consume(CppStreamFinalizeJSON(stream)); err != nil {
|
||||
return err
|
||||
}
|
||||
seg.flush() // close any trailing utterance that never saw an EOU
|
||||
|
||||
text := strings.TrimSpace(full.String())
|
||||
segments := seg.segments()
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{
|
||||
FinalResult: &pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: segments,
|
||||
Duration: duration,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeWavMono16k converts any input audio to 16 kHz mono PCM and returns the
|
||||
// float samples plus the clip duration in seconds. Mirrors the whisper
|
||||
// backend: utils.AudioToWav (ffmpeg) normalises rate/channels, go-audio
|
||||
// decodes the PCM.
|
||||
func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
// convertToWavMono16k converts an arbitrary audio file to a 16 kHz mono WAV in
|
||||
// a fresh temp dir and returns the path together with a cleanup func the caller
|
||||
// must defer. WAV inputs already at 16 kHz/mono/16-bit are passed through by
|
||||
// utils.AudioToWav (hardlink/copy), everything else is transcoded via ffmpeg.
|
||||
// Used by the direct (non-batched) transcription path, which hands a file path
|
||||
// to the C library's WAV-only audio loader.
|
||||
func convertToWavMono16k(path string) (string, func(), error) {
|
||||
dir, err := os.MkdirTemp("", "parakeet")
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return "", func() {}, err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
cleanup := func() { _ = os.RemoveAll(dir) }
|
||||
|
||||
converted := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(path, converted); err != nil {
|
||||
cleanup()
|
||||
return "", func() {}, err
|
||||
}
|
||||
return converted, cleanup, nil
|
||||
}
|
||||
|
||||
func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
converted, cleanup, err := convertToWavMono16k(path)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
fh, err := os.Open(converted)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,11 +3,14 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -50,6 +53,10 @@ func ensureLibLoaded() {
|
||||
purego.RegisterLibFunc(&CppStreamFeed, lib, "parakeet_capi_stream_feed")
|
||||
purego.RegisterLibFunc(&CppStreamFinalize, lib, "parakeet_capi_stream_finalize")
|
||||
purego.RegisterLibFunc(&CppStreamFree, lib, "parakeet_capi_stream_free")
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
|
||||
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
|
||||
}
|
||||
purego.RegisterLibFunc(&CppFreeString, lib, "parakeet_capi_free_string")
|
||||
purego.RegisterLibFunc(&CppLastError, lib, "parakeet_capi_last_error")
|
||||
})
|
||||
@@ -70,6 +77,24 @@ func fixturesOrSkip() (string, string) {
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
// writeMono16kWav writes `samples` frames of 16 kHz mono 16-bit silence to
|
||||
// path. The result is already in AudioToWav's target format, so the conversion
|
||||
// helper copies it through without invoking ffmpeg.
|
||||
func writeMono16kWav(path string, samples int) {
|
||||
GinkgoHelper()
|
||||
f, err := os.Create(path)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
enc := wav.NewEncoder(f, 16000, 16, 1, 1)
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: 16000},
|
||||
SourceBitDepth: 16,
|
||||
Data: make([]int, samples),
|
||||
}
|
||||
Expect(enc.Write(buf)).To(Succeed())
|
||||
Expect(enc.Close()).To(Succeed())
|
||||
Expect(f.Close()).To(Succeed())
|
||||
}
|
||||
|
||||
var _ = Describe("ParakeetCpp", func() {
|
||||
Context("AudioTranscription", func() {
|
||||
It("transcribes a WAV via the parakeet C-API", func() {
|
||||
@@ -86,13 +111,22 @@ var _ = Describe("ParakeetCpp", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(strings.TrimSpace(res.Text)).ToNot(BeEmpty(),
|
||||
"expected non-empty transcript for %s", audioPath)
|
||||
Expect(res.Segments).To(HaveLen(1),
|
||||
"synthesises a single whole-clip segment")
|
||||
Expect(res.Segments[0].Text).To(Equal(res.Text),
|
||||
"single segment text must equal the top-level text")
|
||||
// Default (no granularities) is segment-level: no per-word timings.
|
||||
Expect(res.Segments[0].Words).To(BeEmpty(),
|
||||
"word timings are opt-in via timestamp_granularities")
|
||||
// NeMo-faithful segmentation: one or more punctuation-delimited
|
||||
// segments, each with text and a monotonically-advancing time span.
|
||||
Expect(res.Segments).ToNot(BeEmpty(), "expected at least one segment")
|
||||
var prevEnd int64
|
||||
for i, seg := range res.Segments {
|
||||
Expect(strings.TrimSpace(seg.Text)).ToNot(BeEmpty(),
|
||||
"segment %d must have text", i)
|
||||
Expect(seg.End).To(BeNumerically(">=", seg.Start),
|
||||
"segment %d end must not precede its start", i)
|
||||
Expect(seg.Start).To(BeNumerically(">=", prevEnd),
|
||||
"segments must be in time order")
|
||||
prevEnd = seg.End
|
||||
// Default (no granularities) is segment-level: no per-word timings.
|
||||
Expect(seg.Words).To(BeEmpty(),
|
||||
"word timings are opt-in via timestamp_granularities")
|
||||
}
|
||||
})
|
||||
|
||||
It("emits word-level timestamps when granularity=word", func() {
|
||||
@@ -108,15 +142,61 @@ var _ = Describe("ParakeetCpp", func() {
|
||||
TimestampGranularities: []string{"word"},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.Segments).To(HaveLen(1))
|
||||
seg := res.Segments[0]
|
||||
Expect(seg.Words).ToNot(BeEmpty(),
|
||||
"expected per-word timestamps with granularity=word")
|
||||
// Monotonic, non-negative timings spanning the segment.
|
||||
Expect(seg.Words[0].Start).To(BeNumerically(">=", int64(0)))
|
||||
Expect(seg.End).To(BeNumerically(">=", seg.Start))
|
||||
Expect(seg.Words[len(seg.Words)-1].End).To(Equal(seg.End),
|
||||
"segment end tracks the last word")
|
||||
Expect(res.Segments).ToNot(BeEmpty())
|
||||
// With word granularity every segment carries its own words, and each
|
||||
// segment's span tracks its first/last word; word starts advance
|
||||
// monotonically across the whole transcript.
|
||||
totalWords := 0
|
||||
var prevStart int64 = -1
|
||||
for i, seg := range res.Segments {
|
||||
Expect(seg.Words).ToNot(BeEmpty(),
|
||||
"segment %d must carry per-word timestamps with granularity=word", i)
|
||||
Expect(seg.Start).To(Equal(seg.Words[0].Start),
|
||||
"segment %d start tracks its first word", i)
|
||||
Expect(seg.End).To(Equal(seg.Words[len(seg.Words)-1].End),
|
||||
"segment %d end tracks its last word", i)
|
||||
for _, w := range seg.Words {
|
||||
Expect(w.End).To(BeNumerically(">=", w.Start))
|
||||
Expect(w.Start).To(BeNumerically(">=", prevStart))
|
||||
prevStart = w.Start
|
||||
totalWords++
|
||||
}
|
||||
}
|
||||
Expect(totalWords).To(BeNumerically(">", 0))
|
||||
Expect(res.Segments[0].Words[0].Start).To(BeNumerically(">=", int64(0)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("convertToWavMono16k", func() {
|
||||
// The non-batched transcription path hands a file path to the C
|
||||
// library's WAV-only audio loader, so it must convert first.
|
||||
// utils.AudioToWav passes an already-16kHz/mono/16-bit WAV through
|
||||
// without ffmpeg, which lets us exercise the helper (and the
|
||||
// regression: the direct path used to skip conversion entirely)
|
||||
// without a model, the C library, or ffmpeg.
|
||||
It("returns a decodable 16kHz mono WAV copy and cleans it up", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
src := filepath.Join(dir, "input.wav")
|
||||
writeMono16kWav(src, 16000) // 1s of silence at 16 kHz
|
||||
|
||||
converted, cleanup, err := convertToWavMono16k(src)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// It must produce a fresh temp file, not return the original path.
|
||||
Expect(converted).ToNot(Equal(src))
|
||||
Expect(converted).To(BeAnExistingFile())
|
||||
|
||||
pcm, _, err := decodeWavMono16k(converted)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(pcm).To(HaveLen(16000), "round-trips the sample count")
|
||||
|
||||
cleanup()
|
||||
Expect(converted).ToNot(BeAnExistingFile(), "cleanup removes the temp dir")
|
||||
})
|
||||
|
||||
It("errors on a non-existent input rather than passing the path through", func() {
|
||||
_, _, err := convertToWavMono16k(filepath.Join(GinkgoT().TempDir(), "missing.mp3"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -65,6 +65,25 @@ func main() {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
|
||||
// Per-request language variants (multilingual nemotron). Same probe pattern:
|
||||
// present only in libparakeet.so built with multilingual support, so the
|
||||
// backend still loads against an older library and falls back to the
|
||||
// non-lang batched + streaming entry points (model default / "auto").
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json_lang"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSONLang, lib, "parakeet_capi_transcribe_pcm_batch_json_lang")
|
||||
}
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_begin_lang"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamBeginLang, lib, "parakeet_capi_stream_begin_lang")
|
||||
}
|
||||
|
||||
// Streaming JSON entry points (ABI v4): surface per-word timestamps on the
|
||||
// streaming path. Same probe pattern; absent in older libparakeet.so, where
|
||||
// the backend falls back to the text-only streaming feed.
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
|
||||
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[parakeet-cpp] ABI=%d\n", CppAbiVersion())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
127
backend/go/parakeet-cpp/segments_test.go
Normal file
127
backend/go/parakeet-cpp/segments_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func tw(text string, start, end float64) transcriptWord {
|
||||
return transcriptWord{W: text, Start: start, End: end}
|
||||
}
|
||||
|
||||
var _ = Describe("splitWordsIntoSegments (NeMo get_segment_offsets parity)", func() {
|
||||
seps := []rune{'.', '?', '!'}
|
||||
|
||||
It("splits on sentence-ending punctuation, including the delimiter word", func() {
|
||||
words := []transcriptWord{tw("hello", 0, 0.4), tw("world.", 0.4, 0.8), tw("bye", 1.0, 1.3)}
|
||||
segs := splitWordsIntoSegments(words, seps, 0)
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2))
|
||||
Expect(segs[0][1].W).To(Equal("world."))
|
||||
Expect(segs[1]).To(HaveLen(1))
|
||||
Expect(segs[1][0].W).To(Equal("bye"))
|
||||
})
|
||||
|
||||
It("keeps a single segment with no terminal punctuation and gap off", func() {
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 0)
|
||||
Expect(segs).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("splits on the gap rule when enabled, the gapped word starting the next segment", func() {
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 1.0) // c is 4.6s after b
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2)) // a b
|
||||
Expect(segs[1]).To(HaveLen(1)) // c
|
||||
Expect(segs[1][0].W).To(Equal("c"))
|
||||
})
|
||||
|
||||
It("checks the gap rule before punctuation (NeMo elif order)", func() {
|
||||
// "b." would terminate, but c is far after it -> gap closes [a b.] at b.
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b.", 0.2, 0.4), tw("c", 9.0, 9.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 1.0)
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2))
|
||||
Expect(segs[1][0].W).To(Equal("c"))
|
||||
})
|
||||
|
||||
It("still splits on punctuation when the gap rule is enabled but does not fire", func() {
|
||||
words := []transcriptWord{tw("hi.", 0, 0.4), tw("bye", 0.4, 0.8)}
|
||||
segs := splitWordsIntoSegments(words, seps, 5.0) // gap never reached
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0][0].W).To(Equal("hi."))
|
||||
})
|
||||
|
||||
It("returns nothing for empty input", func() {
|
||||
Expect(splitWordsIntoSegments(nil, seps, 0)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("transcriptResultFromDoc (multi-segment)", func() {
|
||||
doc := transcriptJSON{
|
||||
Text: "hello world. bye now",
|
||||
FrameSec: 0.08,
|
||||
Words: []transcriptWord{
|
||||
{W: "hello", Start: 0.0, End: 0.4},
|
||||
{W: "world.", Start: 0.4, End: 0.8},
|
||||
{W: "bye", Start: 1.0, End: 1.3},
|
||||
{W: "now", Start: 1.3, End: 1.6},
|
||||
},
|
||||
Tokens: []transcriptToken{{ID: 1, T: 0.1}, {ID: 2, T: 0.5}, {ID: 3, T: 1.1}, {ID: 4, T: 1.4}},
|
||||
}
|
||||
|
||||
It("emits one segment per punctuation-delimited group with start/end", func() {
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments).To(HaveLen(2))
|
||||
Expect(res.Segments[0].Text).To(Equal("hello world."))
|
||||
Expect(res.Segments[0].Start).To(Equal(int64(0)))
|
||||
Expect(res.Segments[0].End).To(Equal(secondsToNanos(0.8)))
|
||||
Expect(res.Segments[1].Text).To(Equal("bye now"))
|
||||
Expect(res.Segments[1].Start).To(Equal(secondsToNanos(1.0)))
|
||||
Expect(res.Segments[1].Id).To(Equal(int32(1)))
|
||||
})
|
||||
|
||||
It("assigns tokens to the segment whose time window contains them", func() {
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments[0].Tokens).To(Equal([]int32{1, 2}))
|
||||
Expect(res.Segments[1].Tokens).To(Equal([]int32{3, 4}))
|
||||
})
|
||||
|
||||
It("attaches per-segment words only when word granularity requested", func() {
|
||||
plain := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(plain.Segments[0].Words).To(BeEmpty())
|
||||
withWords := transcriptResultFromDoc(doc, &pb.TranscriptRequest{TimestampGranularities: []string{"word"}}, 0)
|
||||
Expect(withWords.Segments[0].Words).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("falls back to a single text segment when there are no words", func() {
|
||||
res := transcriptResultFromDoc(transcriptJSON{Text: "hi"}, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments).To(HaveLen(1))
|
||||
Expect(res.Segments[0].Text).To(Equal("hi"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("streaming segment assembly", func() {
|
||||
It("closes a segment with start/end from its words on EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hello world", Eou: 1, Words: []transcriptWord{
|
||||
{W: "hello", Start: 0.0, End: 0.4}, {W: "world", Start: 0.4, End: 0.9},
|
||||
}})
|
||||
segs := acc.segments()
|
||||
Expect(segs).To(HaveLen(1))
|
||||
Expect(segs[0].Text).To(Equal("hello world"))
|
||||
Expect(segs[0].Start).To(Equal(int64(0)))
|
||||
Expect(segs[0].End).To(Equal(secondsToNanos(0.9)))
|
||||
})
|
||||
|
||||
It("buffers words across feeds until EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hi", Eou: 0, Words: []transcriptWord{{W: "hi", Start: 0, End: 0.3}}})
|
||||
Expect(acc.segments()).To(BeEmpty())
|
||||
acc.add(streamFeedJSON{Text: "there", Eou: 1, Words: []transcriptWord{{W: "there", Start: 0.3, End: 0.7}}})
|
||||
Expect(acc.segments()).To(HaveLen(1))
|
||||
Expect(acc.segments()[0].Text).To(Equal("hi there"))
|
||||
})
|
||||
})
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# qwen3-tts.cpp version
|
||||
QWEN3TTS_REPO?=https://github.com/predict-woo/qwen3-tts.cpp
|
||||
QWEN3TTS_CPP_VERSION?=7a762e2ad4bacc6fdda81d81bf10a09ffb546f29
|
||||
QWEN3TTS_CPP_VERSION?=136e5d36c17083da0321fd96512dc7b263f94a44
|
||||
SO_TARGET?=libgoqwen3ttscpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -21,6 +22,43 @@ type Qwen3TtsCpp struct {
|
||||
threads int
|
||||
}
|
||||
|
||||
// languageNameAliases maps common full language names to the canonical
|
||||
// two-letter code understood by the C++ language_to_id table.
|
||||
var languageNameAliases = map[string]string{
|
||||
"english": "en",
|
||||
"russian": "ru",
|
||||
"chinese": "zh",
|
||||
"japanese": "ja",
|
||||
"korean": "ko",
|
||||
"german": "de",
|
||||
"french": "fr",
|
||||
"spanish": "es",
|
||||
"italian": "it",
|
||||
"portuguese": "pt",
|
||||
}
|
||||
|
||||
// normalizeLanguage coerces a caller-supplied language into the canonical code
|
||||
// the model expects. It lowercases, trims, strips any region/locale suffix
|
||||
// (en-US, en_US, ja.JP -> en/ja), and resolves common full names (english -> en).
|
||||
// An empty input stays empty so the C++ side applies its English default; an
|
||||
// unrecognized value is returned normalized so C++ can log it and default.
|
||||
func normalizeLanguage(lang string) string {
|
||||
lang = strings.ToLower(strings.TrimSpace(lang))
|
||||
if lang == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strip region/locale suffix: keep the segment before the first separator.
|
||||
if i := strings.IndexAny(lang, "-_."); i >= 0 {
|
||||
lang = lang[:i]
|
||||
}
|
||||
|
||||
if code, ok := languageNameAliases[lang]; ok {
|
||||
return code
|
||||
}
|
||||
return lang
|
||||
}
|
||||
|
||||
func (q *Qwen3TtsCpp) Load(opts *pb.ModelOptions) error {
|
||||
// ModelFile is the model directory path (containing GGUF files)
|
||||
modelDir := opts.ModelFile
|
||||
@@ -54,7 +92,7 @@ func (q *Qwen3TtsCpp) TTS(req *pb.TTSRequest) error {
|
||||
dst := req.Dst
|
||||
language := ""
|
||||
if req.Language != nil {
|
||||
language = *req.Language
|
||||
language = normalizeLanguage(*req.Language)
|
||||
}
|
||||
|
||||
// Synthesis parameters with sensible defaults
|
||||
|
||||
53
backend/go/qwen3-tts-cpp/language_test.go
Normal file
53
backend/go/qwen3-tts-cpp/language_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestLanguageNormalization(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "qwen3-tts-cpp language normalization")
|
||||
}
|
||||
|
||||
var _ = Describe("normalizeLanguage", func() {
|
||||
DescribeTable("maps caller input to the canonical model language code",
|
||||
func(input, expected string) {
|
||||
Expect(normalizeLanguage(input)).To(Equal(expected))
|
||||
},
|
||||
// Canonical codes pass through unchanged
|
||||
Entry("canonical en", "en", "en"),
|
||||
Entry("canonical zh", "zh", "zh"),
|
||||
Entry("canonical pt", "pt", "pt"),
|
||||
|
||||
// Case-insensitive
|
||||
Entry("uppercase", "EN", "en"),
|
||||
Entry("mixed case", "Ja", "ja"),
|
||||
|
||||
// Surrounding whitespace
|
||||
Entry("trims whitespace", " en ", "en"),
|
||||
|
||||
// Region/locale stripping
|
||||
Entry("BCP-47 region", "en-US", "en"),
|
||||
Entry("underscore region", "en_US", "en"),
|
||||
Entry("dotted locale", "ja.JP", "ja"),
|
||||
Entry("region + case", "ZH-CN", "zh"),
|
||||
|
||||
// Full-name aliases
|
||||
Entry("english name", "english", "en"),
|
||||
Entry("chinese name cased", "Chinese", "zh"),
|
||||
Entry("japanese name", "japanese", "ja"),
|
||||
Entry("russian name", "russian", "ru"),
|
||||
Entry("portuguese name", "portuguese", "pt"),
|
||||
|
||||
// Empty stays empty (C++ applies the English default)
|
||||
Entry("empty", "", ""),
|
||||
Entry("whitespace only", " ", ""),
|
||||
|
||||
// Unknown values pass through normalized so C++ can log + default
|
||||
Entry("unknown code", "klingon", "klingon"),
|
||||
Entry("unknown with region", "xx-YY", "xx"),
|
||||
)
|
||||
})
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=7948df8ac1070f5f6881b8d34675821893eb97d6
|
||||
STABLEDIFFUSION_GGML_VERSION?=19bdfe22d255d5b4dff39d449318b9bc5ea2317f
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -386,6 +386,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
const char *llm_vision_path = "";
|
||||
const char *diffusion_model_path = stableDiffusionModel;
|
||||
const char *high_noise_diffusion_model_path = "";
|
||||
const char *uncond_diffusion_model_path = "";
|
||||
const char *taesd_path = "";
|
||||
const char *control_net_path = "";
|
||||
const char *embedding_dir = "";
|
||||
@@ -472,6 +473,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
if (!strcmp(optname, "llm_vision_path")) llm_vision_path = strdup(optval);
|
||||
if (!strcmp(optname, "diffusion_model_path")) diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "high_noise_diffusion_model_path")) high_noise_diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "uncond_diffusion_model_path")) uncond_diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "taesd_path")) taesd_path = strdup(optval);
|
||||
if (!strcmp(optname, "control_net_path")) control_net_path = strdup(optval);
|
||||
if (!strcmp(optname, "embedding_dir")) {
|
||||
@@ -571,6 +573,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
ctx_params.llm_vision_path = llm_vision_path;
|
||||
ctx_params.diffusion_model_path = diffusion_model_path;
|
||||
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
|
||||
ctx_params.uncond_diffusion_model_path = uncond_diffusion_model_path;
|
||||
ctx_params.vae_path = vae_path;
|
||||
ctx_params.audio_vae_path = audio_vae_path;
|
||||
ctx_params.embeddings_connectors_path = embeddings_connectors_path;
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=23ee03506a91ac3d3f0071b40e66a430eebdfa1d
|
||||
WHISPER_CPP_VERSION?=df7638d8229a243af8a4b5a8ae557e0d74e0a0ae
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -95,6 +95,29 @@
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ds4"
|
||||
metal: "metal-ds4"
|
||||
metal-darwin-arm64: "metal-ds4"
|
||||
- &dllm
|
||||
name: "dllm"
|
||||
alias: "dllm"
|
||||
license: mit
|
||||
description: |
|
||||
mudler/dllm.cpp - DiffusionGemma block-diffusion LLM inference engine
|
||||
(C++/ggml, GGUF weights). Decodes whole token canvases per diffusion
|
||||
round instead of autoregressive sampling. Runs on CPU and NVIDIA CUDA 13
|
||||
(including Jetson/GB10 L4T targets).
|
||||
urls:
|
||||
- https://github.com/mudler/dllm.cpp
|
||||
tags:
|
||||
- text-to-text
|
||||
- LLM
|
||||
- gguf
|
||||
- diffusion
|
||||
- CPU
|
||||
- CUDA
|
||||
capabilities:
|
||||
default: "cpu-dllm"
|
||||
nvidia: "cuda13-dllm"
|
||||
nvidia-cuda-13: "cuda13-dllm"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-dllm"
|
||||
- &whispercpp
|
||||
name: "whisper"
|
||||
alias: "whisper"
|
||||
@@ -1272,6 +1295,13 @@
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ds4-development"
|
||||
metal: "metal-ds4-development"
|
||||
metal-darwin-arm64: "metal-ds4-development"
|
||||
- !!merge <<: *dllm
|
||||
name: "dllm-development"
|
||||
capabilities:
|
||||
default: "cpu-dllm-development"
|
||||
nvidia: "cuda13-dllm-development"
|
||||
nvidia-cuda-13: "cuda13-dllm-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-dllm-development"
|
||||
- !!merge <<: *stablediffusionggml
|
||||
name: "stablediffusion-ggml-development"
|
||||
capabilities:
|
||||
@@ -1859,6 +1889,37 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-ds4"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-ds4
|
||||
## dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cpu-dllm"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cpu-dllm-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cuda13-dllm"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cuda13-dllm-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cuda13-nvidia-l4t-arm64-dllm"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cuda13-nvidia-l4t-arm64-dllm-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-dllm
|
||||
## whisper
|
||||
- !!merge <<: *whispercpp
|
||||
name: "whisper-development"
|
||||
|
||||
@@ -37,6 +37,20 @@ def is_int(s):
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def coerce_param_value(value):
|
||||
"""Coerce a TTSRequest.params value (string on the wire) to the type the
|
||||
Chatterbox generate() kwargs expect (float/int/bool), matching how static
|
||||
YAML options are coerced at load time. Non-string values pass through."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
if is_float(value):
|
||||
return float(value)
|
||||
if is_int(value):
|
||||
return int(value)
|
||||
if value.lower() in ["true", "false"]:
|
||||
return value.lower() == "true"
|
||||
return value
|
||||
|
||||
def split_text_at_word_boundary(text, max_length=250):
|
||||
"""
|
||||
Split text at word boundaries without truncating words.
|
||||
@@ -191,6 +205,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# add options to kwargs
|
||||
kwargs.update(self.options)
|
||||
|
||||
# Merge per-request params (TTSRequest.params), overriding the static
|
||||
# YAML options. This exposes Chatterbox generation knobs (e.g.
|
||||
# exaggeration, cfg_weight, temperature) per request. Values arrive as
|
||||
# strings on the wire and are coerced to float/int/bool.
|
||||
if hasattr(request, "params") and request.params:
|
||||
for key, value in request.params.items():
|
||||
kwargs[key] = coerce_param_value(value)
|
||||
|
||||
# Check if text exceeds 250 characters
|
||||
# (chatterbox does not support long text)
|
||||
# https://github.com/resemble-ai/chatterbox/issues/60
|
||||
|
||||
@@ -47,6 +47,26 @@ def is_int(s):
|
||||
return False
|
||||
|
||||
|
||||
def coerce_param_value(value):
|
||||
"""Coerce a string param value (from the TTSRequest.params map, which is
|
||||
string-typed on the wire) into the most specific Python type the model
|
||||
generation kwargs expect: bool, int, float, else the original string."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
lowered = value.strip().lower()
|
||||
if lowered in ("true", "false"):
|
||||
return lowered == "true"
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
return value
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
@@ -322,6 +342,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def _effective_instruct(self, request):
|
||||
"""Resolve the instruction/style string for this request, preferring the
|
||||
per-request TTSRequest.instructions value and falling back to the static
|
||||
YAML `instruct` option. Empty string means "no instruction"."""
|
||||
req_instruct = (
|
||||
request.instructions
|
||||
if hasattr(request, "instructions") and request.instructions
|
||||
else ""
|
||||
)
|
||||
if req_instruct:
|
||||
return req_instruct
|
||||
return self.options.get("instruct", "") or ""
|
||||
|
||||
def _detect_mode(self, request):
|
||||
"""Detect which mode to use based on request parameters."""
|
||||
# Priority: VoiceClone > VoiceDesign > CustomVoice
|
||||
@@ -338,8 +371,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if self.audio_path or self.voices:
|
||||
return "VoiceClone"
|
||||
|
||||
# VoiceDesign: instruct option is provided
|
||||
if "instruct" in self.options and self.options["instruct"]:
|
||||
# VoiceDesign: instruct provided per-request or via YAML option
|
||||
if self._effective_instruct(request):
|
||||
return "VoiceDesign"
|
||||
|
||||
# Default to CustomVoice
|
||||
@@ -690,10 +723,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if do_sample is not None:
|
||||
generation_kwargs["do_sample"] = do_sample
|
||||
|
||||
instruct = self.options.get("instruct", "")
|
||||
# Prefer the per-request instruction (TTSRequest.instructions) over the
|
||||
# static YAML `instruct` option. This lets clients set a different style
|
||||
# (CustomVoice emotion) or designed voice (VoiceDesign) per request.
|
||||
instruct = self._effective_instruct(request)
|
||||
if instruct is not None and instruct != "":
|
||||
generation_kwargs["instruct"] = instruct
|
||||
|
||||
# Merge any per-request backend-specific params (TTSRequest.params).
|
||||
# Values arrive as strings on the wire; coerce to int/float/bool so the
|
||||
# model receives the types it expects. These override YAML-derived kwargs.
|
||||
if hasattr(request, "params") and request.params:
|
||||
for key, value in request.params.items():
|
||||
generation_kwargs[key] = coerce_param_value(value)
|
||||
|
||||
# Generate audio based on mode
|
||||
if mode == "VoiceClone":
|
||||
# VoiceClone mode
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.81.0
|
||||
protobuf==7.35.0
|
||||
certifi
|
||||
setuptools
|
||||
|
||||
@@ -26,7 +26,10 @@ from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
try:
|
||||
from vllm.tokenizers import get_tokenizer # vLLM >= 0.22
|
||||
except ImportError:
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer # vLLM < 0.22
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.assets.video import VideoAsset
|
||||
import base64
|
||||
|
||||
@@ -3,5 +3,5 @@
|
||||
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
|
||||
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
|
||||
# so uv consults this index alongside PyPI.
|
||||
--extra-index-url https://wheels.vllm.ai/0.22.0/cu130
|
||||
vllm==0.22.0
|
||||
--extra-index-url https://wheels.vllm.ai/0.22.1/cu130
|
||||
vllm==0.22.1
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
grpcio==1.81.0
|
||||
protobuf
|
||||
certifi
|
||||
setuptools
|
||||
|
||||
@@ -102,7 +102,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)
|
||||
|
||||
// Connect to NATS
|
||||
natsClient, err := messaging.New(cfg.Distributed.NatsURL)
|
||||
natsAuth := cfg.Distributed.NatsAuthConfig()
|
||||
if natsAuth.RequireAuth && (natsAuth.ServiceUserJWT == "" || natsAuth.ServiceUserSeed == "") {
|
||||
return nil, fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED")
|
||||
}
|
||||
natsOpts := cfg.Distributed.NatsMessagingOptions("", "")
|
||||
natsClient, err := messaging.New(cfg.Distributed.NatsURL, natsOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connecting to NATS: %w", err)
|
||||
}
|
||||
|
||||
@@ -23,9 +23,9 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -308,10 +308,31 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
application.galleryService.SetNATSClient(distSvc.Nats)
|
||||
if distSvc.DistStores != nil && distSvc.DistStores.Gallery != nil {
|
||||
// Clean up stale in-progress operations from previous crashed instances
|
||||
if err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
|
||||
if _, err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
|
||||
xlog.Warn("Failed to clean stale gallery operations", "error", err)
|
||||
}
|
||||
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
|
||||
|
||||
// Reap stale ops periodically, not just at boot: an op orphaned by
|
||||
// a replica that died mid-install (its foreground handler goroutine
|
||||
// gone) would otherwise linger "processing" in the UI until the next
|
||||
// restart. 30m matches the install/upgrade ceiling so a genuinely
|
||||
// slow op is never reaped out from under itself.
|
||||
gsvc := application.galleryService
|
||||
go func() {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-options.Context.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if _, err := gsvc.ReapStaleOperations(30 * time.Minute); err != nil {
|
||||
xlog.Warn("Failed to reap stale gallery operations", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
// Hydrate from the store first so the wildcard subscriber finds an
|
||||
// already-populated statuses map for any operations still in flight
|
||||
|
||||
@@ -214,7 +214,9 @@ func (uc *UpgradeChecker) runCheck(ctx context.Context) {
|
||||
"from", info.InstalledVersion, "to", info.AvailableVersion)
|
||||
var err error
|
||||
if bm != nil {
|
||||
err = bm.UpgradeBackend(ctx, name, nil)
|
||||
// Background auto-upgrade: no live admin watching a progress bar,
|
||||
// so opID is empty and the distributed path skips progress streaming.
|
||||
err = bm.UpgradeBackend(ctx, "", name, nil)
|
||||
} else {
|
||||
err = gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
|
||||
uc.galleries, name, nil, uc.appConfig.RequireBackendIntegrity)
|
||||
|
||||
@@ -123,14 +123,14 @@ var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
|
||||
})
|
||||
|
||||
It("ModelTTS forwards the request context to the SmartRouter", func() {
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", loader, appCfg, modelCfg)
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTTSStream forwards the request context to the SmartRouter", func() {
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
|
||||
@@ -239,13 +239,13 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
|
||||
|
||||
if c.Backend == "cloud-proxy" {
|
||||
opts.Proxy = &pb.ProxyOptions{
|
||||
UpstreamUrl: c.Proxy.UpstreamURL,
|
||||
Mode: c.Proxy.Mode,
|
||||
Provider: c.Proxy.Provider,
|
||||
ApiKeyEnv: c.Proxy.APIKeyEnv,
|
||||
ApiKeyFile: c.Proxy.APIKeyFile,
|
||||
UpstreamModel: c.Proxy.UpstreamModel,
|
||||
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
|
||||
UpstreamUrl: c.Proxy.UpstreamURL,
|
||||
Mode: c.Proxy.Mode,
|
||||
Provider: c.Proxy.Provider,
|
||||
ApiKeyEnv: c.Proxy.APIKeyEnv,
|
||||
ApiKeyFile: c.Proxy.APIKeyFile,
|
||||
UpstreamModel: c.Proxy.UpstreamModel,
|
||||
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,6 +323,12 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
|
||||
metadata["enable_thinking"] = "true"
|
||||
}
|
||||
}
|
||||
// Forward the effective reasoning effort so the backend can pass it to the
|
||||
// jinja chat template (chat_template_kwargs.reasoning_effort) — the lever
|
||||
// models like gpt-oss / LFM2.5 actually read, distinct from enable_thinking.
|
||||
if c.ReasoningEffort != "" {
|
||||
metadata["reasoning_effort"] = c.ReasoningEffort
|
||||
}
|
||||
pbOpts.Metadata = metadata
|
||||
|
||||
// Logprobs and TopLogprobs are set by the caller if provided
|
||||
|
||||
@@ -75,3 +75,25 @@ var _ = Describe("gRPCPredictOpts enable_thinking metadata", func() {
|
||||
Expect(opts.Metadata).ToNot(HaveKey("enable_thinking"))
|
||||
})
|
||||
})
|
||||
|
||||
// Guards forwarding the effective reasoning_effort into PredictOptions.Metadata,
|
||||
// where the backend passes it to the jinja chat template (chat_template_kwargs)
|
||||
// so models like gpt-oss / LFM2.5 honor it.
|
||||
var _ = Describe("gRPCPredictOpts reasoning_effort metadata", func() {
|
||||
withEffort := func(effort string) config.ModelConfig {
|
||||
cfg := config.ModelConfig{}
|
||||
cfg.SetDefaults()
|
||||
cfg.ReasoningEffort = effort
|
||||
return cfg
|
||||
}
|
||||
|
||||
It("forwards reasoning_effort when set", func() {
|
||||
opts := gRPCPredictOpts(withEffort("none"), "/tmp/models")
|
||||
Expect(opts.Metadata).To(HaveKeyWithValue("reasoning_effort", "none"))
|
||||
})
|
||||
|
||||
It("omits reasoning_effort when empty", func() {
|
||||
opts := gRPCPredictOpts(withEffort(""), "/tmp/models")
|
||||
Expect(opts.Metadata).ToNot(HaveKey("reasoning_effort"))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -20,11 +20,32 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
// newTTSRequest assembles the gRPC TTSRequest from the per-request inputs. The
|
||||
// optional instructions string is only attached when non-empty so backends can
|
||||
// distinguish "no per-request instruction" (fall back to YAML) from an explicit
|
||||
// empty one. params is forwarded as-is (nil when unset).
|
||||
func newTTSRequest(text, modelPath, voice, dst, language, instructions string, params map[string]string) *proto.TTSRequest {
|
||||
req := &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Dst: dst,
|
||||
Language: &language,
|
||||
Params: params,
|
||||
}
|
||||
if instructions != "" {
|
||||
req.Instructions = &instructions
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func ModelTTS(
|
||||
ctx context.Context,
|
||||
text,
|
||||
voice,
|
||||
language string,
|
||||
language,
|
||||
instructions string,
|
||||
params map[string]string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
@@ -74,13 +95,9 @@ func ModelTTS(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := ttsModel.TTS(ctx, &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Dst: filePath,
|
||||
Language: &language,
|
||||
})
|
||||
ttsRequest := newTTSRequest(text, modelPath, voice, filePath, language, instructions, params)
|
||||
|
||||
res, err := ttsModel.TTS(ctx, ttsRequest)
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
@@ -128,7 +145,9 @@ func ModelTTSStream(
|
||||
ctx context.Context,
|
||||
text,
|
||||
voice,
|
||||
language string,
|
||||
language,
|
||||
instructions string,
|
||||
params map[string]string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
@@ -177,12 +196,10 @@ func ModelTTSStream(
|
||||
var totalPCMBytes int
|
||||
snippetCapped := false
|
||||
|
||||
err = ttsModel.TTSStream(ctx, &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Language: &language,
|
||||
}, func(reply *proto.Reply) {
|
||||
// Streaming TTS writes to the HTTP response, not a file, so dst is empty.
|
||||
ttsRequest := newTTSRequest(text, modelPath, voice, "", language, instructions, params)
|
||||
|
||||
err = ttsModel.TTSStream(ctx, ttsRequest, func(reply *proto.Reply) {
|
||||
// First message contains sample rate info
|
||||
if !headerSent && len(reply.Message) > 0 {
|
||||
var info map[string]any
|
||||
|
||||
42
core/backend/tts_test.go
Normal file
42
core/backend/tts_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package backend
|
||||
|
||||
// Specs for the TTSRequest assembly that carries the per-request
|
||||
// instructions/params from the OpenAI `instructions` field (and the LocalAI
|
||||
// `params` extension) through to the gRPC boundary. Before this plumbing the
|
||||
// instruction value was dropped before reaching the backend; these specs pin
|
||||
// that it now survives, and that the empty case stays backward compatible.
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("newTTSRequest", func() {
|
||||
It("attaches the instructions when a per-request value is set", func() {
|
||||
req := newTTSRequest("hi", "/m", "alloy", "/out.wav", "en", "cheerful narrator", nil)
|
||||
Expect(req.Instructions).ToNot(BeNil())
|
||||
Expect(req.GetInstructions()).To(Equal("cheerful narrator"))
|
||||
Expect(req.GetText()).To(Equal("hi"))
|
||||
Expect(req.GetVoice()).To(Equal("alloy"))
|
||||
Expect(req.GetDst()).To(Equal("/out.wav"))
|
||||
Expect(req.GetLanguage()).To(Equal("en"))
|
||||
})
|
||||
|
||||
It("leaves instructions unset when empty so backends fall back to YAML", func() {
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
|
||||
Expect(req.Instructions).To(BeNil())
|
||||
Expect(req.GetInstructions()).To(Equal(""))
|
||||
})
|
||||
|
||||
It("forwards per-request params through to the backend", func() {
|
||||
params := map[string]string{"exaggeration": "0.7", "cfg_weight": "0.3"}
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", params)
|
||||
Expect(req.GetParams()).To(HaveKeyWithValue("exaggeration", "0.7"))
|
||||
Expect(req.GetParams()).To(HaveKeyWithValue("cfg_weight", "0.3"))
|
||||
})
|
||||
|
||||
It("leaves params nil when none are supplied", func() {
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
|
||||
Expect(req.GetParams()).To(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -52,10 +52,28 @@ type AgentWorkerCMD struct {
|
||||
Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
|
||||
Queue string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`
|
||||
|
||||
NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (defaults to nats_jwt from registration)" group:"distributed"`
|
||||
NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user seed override (defaults to nats_user_seed from registration)" group:"distributed"`
|
||||
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"Fallback NATS service JWT when registration does not mint agent JWT" group:"distributed"`
|
||||
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"Fallback NATS service seed paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
|
||||
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed to connect" group:"distributed"`
|
||||
// DistributedRequireAuth is the umbrella switch; for the agent worker (which
|
||||
// has no file-transfer server) it implies NATS auth is required.
|
||||
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch implying --nats-require-auth (agent workers have no file-transfer server)" group:"distributed"`
|
||||
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
|
||||
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
|
||||
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
|
||||
|
||||
// Timeouts
|
||||
MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
|
||||
}
|
||||
|
||||
// natsAuthRequired reports whether NATS JWT credentials must be present — the
|
||||
// granular flag or the umbrella (LOCALAI_DISTRIBUTED_REQUIRE_AUTH).
|
||||
func (cmd *AgentWorkerCMD) natsAuthRequired() bool {
|
||||
return cmd.NatsRequireAuth || cmd.DistributedRequireAuth
|
||||
}
|
||||
|
||||
func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
xlog.Info("Starting agent worker", "nats", sanitize.URL(cmd.NatsURL), "register_to", cmd.RegisterTo)
|
||||
|
||||
@@ -81,15 +99,30 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
registrationBody["token"] = cmd.RegistrationToken
|
||||
}
|
||||
|
||||
nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
|
||||
// Context cancelled on shutdown — used by registration waits, heartbeat, and
|
||||
// other background goroutines.
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
// Acquire credentials via (re)registration. When the bus requires auth and no
|
||||
// static fallback is configured, wait through admin approval until the
|
||||
// frontend mints credentials rather than starting unauthenticated.
|
||||
credMgr := workerregistry.NewNATSCredentialManager(
|
||||
func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
|
||||
return regClient.RegisterFull(ctx, registrationBody)
|
||||
},
|
||||
cmd.natsAuthRequired() && cmd.NatsJWT == "" && cmd.NatsServiceJWT == "",
|
||||
)
|
||||
res, err := credMgr.Acquire(shutdownCtx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("registration failed: %w", err)
|
||||
}
|
||||
nodeID := res.ID
|
||||
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
|
||||
|
||||
// Use provisioned API token if none was set
|
||||
if cmd.APIToken == "" {
|
||||
cmd.APIToken = apiToken
|
||||
cmd.APIToken = res.APIToken
|
||||
}
|
||||
|
||||
// Start heartbeat
|
||||
@@ -98,14 +131,40 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
|
||||
}
|
||||
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
|
||||
// Context cancelled on shutdown — used by heartbeat and other background goroutines
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })
|
||||
|
||||
// Connect to NATS
|
||||
natsClient, err := messaging.New(cmd.NatsURL)
|
||||
// Resolve NATS credentials with precedence: explicit env override, then
|
||||
// frontend-minted (auto-refreshed before expiry), then service fallback.
|
||||
// Each static source must supply JWT and seed together.
|
||||
natsTLS := messaging.TLSFiles{CA: cmd.NatsTLSCA, Cert: cmd.NatsTLSCert, Key: cmd.NatsTLSKey}
|
||||
var natsOpts []messaging.Option
|
||||
switch {
|
||||
case cmd.NatsJWT != "" || cmd.NatsUserSeed != "":
|
||||
if (cmd.NatsJWT == "") != (cmd.NatsUserSeed == "") {
|
||||
return fmt.Errorf("LOCALAI_NATS_JWT and LOCALAI_NATS_USER_SEED must be set together")
|
||||
}
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsJWT, cmd.NatsUserSeed))
|
||||
case credMgr.HasCredentials():
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWTProvider(credMgr.Provider()))
|
||||
go func() {
|
||||
if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
|
||||
xlog.Error("NATS credential refresh permanently failed; shutting down agent worker", "error", err)
|
||||
shutdownCancel()
|
||||
}
|
||||
}()
|
||||
case cmd.NatsServiceJWT != "" || cmd.NatsServiceSeed != "":
|
||||
if (cmd.NatsServiceJWT == "") != (cmd.NatsServiceSeed == "") {
|
||||
return fmt.Errorf("LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED must be set together")
|
||||
}
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsServiceJWT, cmd.NatsServiceSeed))
|
||||
case cmd.natsAuthRequired():
|
||||
return fmt.Errorf("NATS JWT+seed required: enable frontend minting or set LOCALAI_NATS_* env vars")
|
||||
}
|
||||
if natsTLS.Enabled() {
|
||||
natsOpts = append(natsOpts, messaging.WithTLS(natsTLS))
|
||||
}
|
||||
natsClient, err := messaging.New(cmd.NatsURL, natsOpts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connecting to NATS: %w", err)
|
||||
}
|
||||
@@ -183,17 +242,25 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
|
||||
xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)
|
||||
|
||||
// Wait for shutdown
|
||||
// Wait for an OS signal or an internal fatal condition (e.g. NATS
|
||||
// credentials became unrenewable), so the worker restarts and re-acquires
|
||||
// rather than lingering unable to serve.
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
var runErr error
|
||||
select {
|
||||
case <-sigCh:
|
||||
case <-shutdownCtx.Done():
|
||||
runErr = fmt.Errorf("agent worker shutting down: NATS credentials unavailable")
|
||||
xlog.Error("Internal shutdown requested", "error", runErr)
|
||||
}
|
||||
|
||||
xlog.Info("Shutting down agent worker")
|
||||
shutdownCancel() // stop heartbeat loop immediately
|
||||
dispatcher.Stop()
|
||||
mcpTools.CloseAllMCPSessions()
|
||||
regClient.GracefulDeregister(nodeID)
|
||||
return nil
|
||||
return runErr
|
||||
}
|
||||
|
||||
// handleMCPToolRequest handles a NATS request-reply for MCP tool execution.
|
||||
|
||||
30
core/cli/chat/chat.go
Normal file
30
core/cli/chat/chat.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Model string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
In io.Reader
|
||||
Out io.Writer
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, opts Options) error {
|
||||
if opts.In == nil {
|
||||
opts.In = strings.NewReader("")
|
||||
}
|
||||
if opts.Out == nil {
|
||||
opts.Out = io.Discard
|
||||
}
|
||||
|
||||
session, err := newChatSession(ctx, newLocalAIChatClient(opts.BaseURL, opts.APIKey), opts.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return runTerminalChat(ctx, session, opts.In, opts.Out)
|
||||
}
|
||||
13
core/cli/chat/chat_suite_test.go
Normal file
13
core/cli/chat/chat_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestChat(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Chat Suite")
|
||||
}
|
||||
172
core/cli/chat/chat_test.go
Normal file
172
core/cli/chat/chat_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Run chat", func() {
|
||||
It("streams a single chat response", func() {
|
||||
var capturedModel string
|
||||
var capturedAuth string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/v1/models" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
writeResponse(w, `{"object":"list","data":[{"id":"test-model","object":"model"}]}`)
|
||||
return
|
||||
}
|
||||
|
||||
Expect(r.URL.Path).To(Equal("/v1/chat/completions"))
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
|
||||
var body struct {
|
||||
Model string `json:"model"`
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
Expect(json.NewDecoder(r.Body).Decode(&body)).To(Succeed())
|
||||
capturedModel = body.Model
|
||||
Expect(body.Messages).To(HaveLen(1))
|
||||
Expect(body.Messages[0].Role).To(Equal("user"))
|
||||
Expect(body.Messages[0].Content).To(Equal("hello"))
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n")
|
||||
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\"}}]}\n\n")
|
||||
writeResponse(w, "data: [DONE]\n\n")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
var out bytes.Buffer
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
Model: "test-model",
|
||||
BaseURL: server.URL + "/v1",
|
||||
APIKey: "secret",
|
||||
In: strings.NewReader("hello\n/exit\n"),
|
||||
Out: &out,
|
||||
})
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(capturedModel).To(Equal("test-model"))
|
||||
Expect(capturedAuth).To(Equal("Bearer secret"))
|
||||
Expect(out.String()).To(ContainSubstring("assistant: hi!"))
|
||||
Expect(out.String()).To(ContainSubstring("bye"))
|
||||
})
|
||||
|
||||
It("auto-selects the only available model", func() {
|
||||
server := chatTestServer([]string{"solo"}, nil)
|
||||
defer server.Close()
|
||||
|
||||
var out bytes.Buffer
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
BaseURL: server.URL + "/v1",
|
||||
In: strings.NewReader("/exit\n"),
|
||||
Out: &out,
|
||||
})
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out.String()).To(ContainSubstring("LocalAI chat (solo)"))
|
||||
})
|
||||
|
||||
It("returns an actionable error when no models are installed", func() {
|
||||
server := chatTestServer(nil, nil)
|
||||
defer server.Close()
|
||||
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
BaseURL: server.URL + "/v1",
|
||||
In: strings.NewReader(""),
|
||||
})
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no chat models are installed"))
|
||||
Expect(err.Error()).To(ContainSubstring("local-ai models install <model>"))
|
||||
})
|
||||
|
||||
It("returns an actionable error when multiple models are available without a selection", func() {
|
||||
server := chatTestServer([]string{"alpha", "beta"}, nil)
|
||||
defer server.Close()
|
||||
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
BaseURL: server.URL + "/v1",
|
||||
In: strings.NewReader(""),
|
||||
})
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("multiple models are available"))
|
||||
Expect(err.Error()).To(ContainSubstring("--model"))
|
||||
Expect(err.Error()).To(ContainSubstring("alpha"))
|
||||
Expect(err.Error()).To(ContainSubstring("beta"))
|
||||
})
|
||||
|
||||
It("lists and switches models inside the chat", func() {
|
||||
requestedModels := []string{}
|
||||
server := chatTestServer([]string{"alpha", "beta"}, func(model string) {
|
||||
requestedModels = append(requestedModels, model)
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
var out bytes.Buffer
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
Model: "alpha",
|
||||
BaseURL: server.URL + "/v1",
|
||||
In: strings.NewReader("/models\n/model beta\nhello\n/exit\n"),
|
||||
Out: &out,
|
||||
})
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out.String()).To(ContainSubstring("* alpha"))
|
||||
Expect(out.String()).To(ContainSubstring(" beta"))
|
||||
Expect(out.String()).To(ContainSubstring("switched to beta; conversation cleared"))
|
||||
Expect(requestedModels).To(Equal([]string{"beta"}))
|
||||
})
|
||||
})
|
||||
|
||||
func chatTestServer(models []string, onChat func(model string)) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/v1/models":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
writeResponse(w, `{"object":"list","data":[`)
|
||||
for i, model := range models {
|
||||
if i > 0 {
|
||||
writeResponse(w, ",")
|
||||
}
|
||||
writeResponsef(w, `{"id":%q,"object":"model"}`, model)
|
||||
}
|
||||
writeResponse(w, `]}`)
|
||||
case "/v1/chat/completions":
|
||||
var body struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
Expect(json.NewDecoder(r.Body).Decode(&body)).To(Succeed())
|
||||
if onChat != nil {
|
||||
onChat(body.Model)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ok\"}}]}\n\n")
|
||||
writeResponse(w, "data: [DONE]\n\n")
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func writeResponse(w io.Writer, text string) {
|
||||
_, err := fmt.Fprint(w, text)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
func writeResponsef(w io.Writer, format string, args ...any) {
|
||||
_, err := fmt.Fprintf(w, format, args...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
114
core/cli/chat/client.go
Normal file
114
core/cli/chat/client.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type chatClient interface {
|
||||
ListModels(ctx context.Context) ([]string, error)
|
||||
StreamChat(ctx context.Context, model string, messages []chatMessage, out io.Writer) (string, error)
|
||||
}
|
||||
|
||||
type localAIChatClient struct {
|
||||
client *openai.Client
|
||||
}
|
||||
|
||||
func newLocalAIChatClient(baseURL string, apiKey string) *localAIChatClient {
|
||||
cfg := openai.DefaultConfig(apiKey)
|
||||
cfg.BaseURL = baseURL
|
||||
return &localAIChatClient{client: openai.NewClientWithConfig(cfg)}
|
||||
}
|
||||
|
||||
func (c *localAIChatClient) ListModels(ctx context.Context) ([]string, error) {
|
||||
resp, err := c.client.ListModels(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(resp.Models))
|
||||
for _, model := range resp.Models {
|
||||
if model.ID != "" {
|
||||
models = append(models, model.ID)
|
||||
}
|
||||
}
|
||||
sort.Strings(models)
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (c *localAIChatClient) StreamChat(ctx context.Context, model string, messages []chatMessage, out io.Writer) (string, error) {
|
||||
stream, err := c.client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: openAIChatMessages(messages),
|
||||
})
|
||||
if err != nil {
|
||||
return "", friendlyChatError(err, model)
|
||||
}
|
||||
defer func() {
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
var answer strings.Builder
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return answer.String(), friendlyChatError(err, model)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
token := resp.Choices[0].Delta.Content
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
answer.WriteString(token)
|
||||
if _, err := fmt.Fprint(out, token); err != nil {
|
||||
return answer.String(), err
|
||||
}
|
||||
}
|
||||
|
||||
return answer.String(), nil
|
||||
}
|
||||
|
||||
func openAIChatMessages(messages []chatMessage) []openai.ChatCompletionMessage {
|
||||
converted := make([]openai.ChatCompletionMessage, len(messages))
|
||||
for i, message := range messages {
|
||||
converted[i] = openai.ChatCompletionMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
}
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func friendlyChatError(err error, model string) error {
|
||||
var apiErr *openai.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
switch apiErr.HTTPStatusCode {
|
||||
case 404:
|
||||
return fmt.Errorf("model %q is not available. Run `local-ai models list`, install a model with `local-ai models install <model>`, or switch with `/model <name>`", model)
|
||||
case 403:
|
||||
return fmt.Errorf("model %q is disabled. Enable it from LocalAI settings or choose another model with `/model <name>`", model)
|
||||
}
|
||||
if apiErr.Message != "" {
|
||||
return errors.New(apiErr.Message)
|
||||
}
|
||||
}
|
||||
|
||||
msg := err.Error()
|
||||
if strings.Contains(msg, "model") && strings.Contains(msg, "not found") {
|
||||
return fmt.Errorf("model %q is not available. Run `local-ai models list`, install a model with `local-ai models install <model>`, or switch with `/model <name>`", model)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
17
core/cli/chat/models.go
Normal file
17
core/cli/chat/models.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package chat
|
||||
|
||||
import "strings"
|
||||
|
||||
func formatChatModelList(models []string, current string) string {
|
||||
var b strings.Builder
|
||||
for _, model := range models {
|
||||
prefix := " "
|
||||
if model == current {
|
||||
prefix = "* "
|
||||
}
|
||||
b.WriteString(prefix)
|
||||
b.WriteString(model)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
120
core/cli/chat/session.go
Normal file
120
core/cli/chat/session.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
chatRoleUser = "user"
|
||||
chatRoleAssistant = "assistant"
|
||||
)
|
||||
|
||||
type chatMessage struct {
|
||||
Role string
|
||||
Content string
|
||||
}
|
||||
|
||||
type chatSession struct {
|
||||
client chatClient
|
||||
model string
|
||||
models []string
|
||||
messages []chatMessage
|
||||
}
|
||||
|
||||
func newChatSession(ctx context.Context, client chatClient, requestedModel string) (*chatSession, error) {
|
||||
models, err := client.ListModels(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list models: %w", err)
|
||||
}
|
||||
|
||||
model, err := resolveChatModel(requestedModel, models)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &chatSession{
|
||||
client: client,
|
||||
model: model,
|
||||
models: models,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *chatSession) CurrentModel() string {
|
||||
return s.model
|
||||
}
|
||||
|
||||
func (s *chatSession) Models() []string {
|
||||
models := make([]string, len(s.models))
|
||||
copy(models, s.models)
|
||||
return models
|
||||
}
|
||||
|
||||
func (s *chatSession) Clear() {
|
||||
s.messages = nil
|
||||
}
|
||||
|
||||
func (s *chatSession) SwitchModel(model string) error {
|
||||
if !modelExists(s.models, model) {
|
||||
return fmt.Errorf("model %q is not available. Use /models to see installed models", model)
|
||||
}
|
||||
s.model = model
|
||||
s.Clear()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *chatSession) Send(ctx context.Context, prompt string, out io.Writer) error {
|
||||
s.messages = append(s.messages, chatMessage{
|
||||
Role: chatRoleUser,
|
||||
Content: prompt,
|
||||
})
|
||||
|
||||
answer, err := s.client.StreamChat(ctx, s.model, s.messages, out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.messages = append(s.messages, chatMessage{
|
||||
Role: chatRoleAssistant,
|
||||
Content: answer,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveChatModel(requested string, models []string) (string, error) {
|
||||
switch {
|
||||
case requested == "" && len(models) == 0:
|
||||
return "", errors.New(`no chat models are installed.
|
||||
|
||||
Install a model first, for example:
|
||||
local-ai models list
|
||||
local-ai models install <model>
|
||||
local-ai run
|
||||
|
||||
Then start a chat session:
|
||||
local-ai chat --model <model>`)
|
||||
case requested == "" && len(models) == 1:
|
||||
return models[0], nil
|
||||
case requested == "" && len(models) > 1:
|
||||
var b strings.Builder
|
||||
b.WriteString("multiple models are available; choose one with --model:\n")
|
||||
b.WriteString(formatChatModelList(models, ""))
|
||||
return "", errors.New(b.String())
|
||||
case !modelExists(models, requested):
|
||||
return "", fmt.Errorf("model %q is not available. Use `local-ai models list` and `local-ai models install <model>`, or pass an installed model with --model", requested)
|
||||
default:
|
||||
return requested, nil
|
||||
}
|
||||
}
|
||||
|
||||
func modelExists(models []string, name string) bool {
|
||||
for _, model := range models {
|
||||
if model == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
56
core/cli/chat/session_test.go
Normal file
56
core/cli/chat/session_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Chat session", func() {
|
||||
It("keeps model switching and message history out of the terminal adapter", func() {
|
||||
client := &fakeChatClient{
|
||||
models: []string{"alpha", "beta"},
|
||||
answer: "pong",
|
||||
}
|
||||
|
||||
session, err := newChatSession(context.Background(), client, "alpha")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(session.CurrentModel()).To(Equal("alpha"))
|
||||
|
||||
Expect(session.SwitchModel("beta")).To(Succeed())
|
||||
Expect(session.CurrentModel()).To(Equal("beta"))
|
||||
Expect(session.Send(context.Background(), "ping", io.Discard)).To(Succeed())
|
||||
|
||||
Expect(client.requests).To(HaveLen(1))
|
||||
Expect(client.requests[0].model).To(Equal("beta"))
|
||||
Expect(client.requests[0].messages).To(HaveLen(1))
|
||||
Expect(client.requests[0].messages[0].Content).To(Equal("ping"))
|
||||
})
|
||||
})
|
||||
|
||||
type fakeChatClient struct {
|
||||
models []string
|
||||
answer string
|
||||
requests []fakeChatRequest
|
||||
}
|
||||
|
||||
type fakeChatRequest struct {
|
||||
model string
|
||||
messages []chatMessage
|
||||
}
|
||||
|
||||
func (c *fakeChatClient) ListModels(context.Context) ([]string, error) {
|
||||
return c.models, nil
|
||||
}
|
||||
|
||||
func (c *fakeChatClient) StreamChat(_ context.Context, model string, messages []chatMessage, out io.Writer) (string, error) {
|
||||
copied := make([]chatMessage, len(messages))
|
||||
copy(copied, messages)
|
||||
c.requests = append(c.requests, fakeChatRequest{model: model, messages: copied})
|
||||
if _, err := io.WriteString(out, c.answer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return c.answer, nil
|
||||
}
|
||||
93
core/cli/chat/terminal.go
Normal file
93
core/cli/chat/terminal.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func runTerminalChat(ctx context.Context, session *chatSession, in io.Reader, out io.Writer) error {
|
||||
scanner := bufio.NewScanner(in)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
|
||||
|
||||
if err := writeChat(out, "LocalAI chat (%s)\n", session.CurrentModel()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeChat(out, "Type /exit to quit, /clear to reset the conversation, /models to list models.\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
if err := writeChat(out, "\n> "); err != nil {
|
||||
return err
|
||||
}
|
||||
if !scanner.Scan() {
|
||||
break
|
||||
}
|
||||
|
||||
prompt := strings.TrimSpace(scanner.Text())
|
||||
switch prompt {
|
||||
case "":
|
||||
continue
|
||||
case "/bye", "/exit", "/quit":
|
||||
return writeChat(out, "bye\n")
|
||||
case "/clear":
|
||||
session.Clear()
|
||||
if err := writeChat(out, "conversation cleared\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
case "/models":
|
||||
if err := printChatModels(out, session.Models(), session.CurrentModel()); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if nextModel, ok := strings.CutPrefix(prompt, "/model "); ok {
|
||||
nextModel = strings.TrimSpace(nextModel)
|
||||
if nextModel == "" {
|
||||
if err := writeChat(out, "usage: /model <name>\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := session.SwitchModel(nextModel); err != nil {
|
||||
if writeErr := writeChat(out, "%s\n", err); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := writeChat(out, "switched to %s; conversation cleared\n", session.CurrentModel()); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err := writeChat(out, "assistant: "); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := session.Send(ctx, prompt, out); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeChat(out, "\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func printChatModels(out io.Writer, models []string, current string) error {
|
||||
if len(models) == 0 {
|
||||
return writeChat(out, "no models installed\n")
|
||||
}
|
||||
return writeChat(out, "%s", formatChatModelList(models, current))
|
||||
}
|
||||
|
||||
func writeChat(out io.Writer, format string, args ...any) error {
|
||||
_, err := fmt.Fprintf(out, format, args...)
|
||||
return err
|
||||
}
|
||||
25
core/cli/chat_cmd.go
Normal file
25
core/cli/chat_cmd.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
chatcli "github.com/mudler/LocalAI/core/cli/chat"
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
)
|
||||
|
||||
type ChatCMD struct {
|
||||
Model string `short:"m" help:"Model name to use. Defaults to the only model returned by the server when exactly one is available"`
|
||||
Endpoint string `env:"LOCALAI_CHAT_ENDPOINT" default:"http://127.0.0.1:8080" help:"LocalAI server endpoint. The /v1 path is added automatically when omitted"`
|
||||
APIKey string `env:"LOCALAI_API_KEY,API_KEY" help:"API key to use when the LocalAI server requires authentication"`
|
||||
}
|
||||
|
||||
func (c *ChatCMD) Run(ctx *cliContext.Context) error {
|
||||
return chatcli.Run(context.Background(), chatcli.Options{
|
||||
Model: c.Model,
|
||||
BaseURL: chatAPIBaseURL(c.Endpoint),
|
||||
APIKey: c.APIKey,
|
||||
In: os.Stdin,
|
||||
Out: os.Stdout,
|
||||
})
|
||||
}
|
||||
27
core/cli/chat_cmd_test.go
Normal file
27
core/cli/chat_cmd_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Chat command wiring", func() {
|
||||
Describe("chatAPIBaseURL", func() {
|
||||
It("adds /v1 to a root endpoint", func() {
|
||||
Expect(chatAPIBaseURL("http://127.0.0.1:8080")).To(Equal("http://127.0.0.1:8080/v1"))
|
||||
})
|
||||
|
||||
It("keeps endpoints that already include /v1", func() {
|
||||
Expect(chatAPIBaseURL("http://127.0.0.1:8080/v1")).To(Equal("http://127.0.0.1:8080/v1"))
|
||||
Expect(chatAPIBaseURL("http://127.0.0.1:8080/v1/")).To(Equal("http://127.0.0.1:8080/v1"))
|
||||
})
|
||||
|
||||
It("adds a default http scheme", func() {
|
||||
Expect(chatAPIBaseURL("127.0.0.1:8080")).To(Equal("http://127.0.0.1:8080/v1"))
|
||||
})
|
||||
|
||||
It("preserves non-root paths before /v1", func() {
|
||||
Expect(chatAPIBaseURL("http://127.0.0.1:8080/localai")).To(Equal("http://127.0.0.1:8080/localai/v1"))
|
||||
})
|
||||
})
|
||||
})
|
||||
29
core/cli/chat_endpoint.go
Normal file
29
core/cli/chat_endpoint.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func chatAPIBaseURL(endpoint string) string {
|
||||
if !strings.Contains(endpoint, "://") {
|
||||
endpoint = "http://" + endpoint
|
||||
}
|
||||
|
||||
u, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
return strings.TrimRight(endpoint, "/") + "/v1"
|
||||
}
|
||||
|
||||
path := strings.TrimRight(u.Path, "/")
|
||||
if path == "" {
|
||||
u.Path = "/v1"
|
||||
} else if path != "/v1" && !strings.HasSuffix(path, "/v1") {
|
||||
u.Path = path + "/v1"
|
||||
} else {
|
||||
u.Path = path
|
||||
}
|
||||
u.RawQuery = ""
|
||||
u.Fragment = ""
|
||||
return u.String()
|
||||
}
|
||||
@@ -9,6 +9,7 @@ var CLI struct {
|
||||
cliContext.Context `embed:""`
|
||||
|
||||
Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
|
||||
Chat ChatCMD `cmd:"" help:"Open an interactive chat session against a running LocalAI server"`
|
||||
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
|
||||
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
|
||||
Backends BackendsCMD `cmd:"" help:"Manage LocalAI backends and definitions"`
|
||||
|
||||
@@ -30,6 +30,8 @@ type RunCMD struct {
|
||||
ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"`
|
||||
|
||||
ExternalBackends []string `env:"LOCALAI_EXTERNAL_BACKENDS,EXTERNAL_BACKENDS" help:"A list of external backends to load from gallery on boot" group:"backends"`
|
||||
WebRTCNAT1To1IPs []string `env:"LOCALAI_WEBRTC_NAT_1TO1_IPS,WEBRTC_NAT_1TO1_IPS" help:"IPs advertised as the host ICE candidates for /v1/realtime WebRTC instead of every local interface. Set to the reachable host/LAN IP when running under Docker host networking or NAT, where pion otherwise offers unreachable bridge addresses and the connection drops after ICE consent checks fail." group:"api"`
|
||||
WebRTCICEInterfaces []string `env:"LOCALAI_WEBRTC_ICE_INTERFACES,WEBRTC_ICE_INTERFACES" help:"Restrict /v1/realtime WebRTC ICE candidate gathering to these network interfaces (e.g. eth0), filtering out docker0/veth noise." group:"api"`
|
||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
|
||||
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
||||
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
|
||||
@@ -154,11 +156,21 @@ type RunCMD struct {
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
RegistrationRequireAuth bool `env:"LOCALAI_REGISTRATION_REQUIRE_AUTH" default:"false" help:"Fail startup when distributed mode is enabled but LOCALAI_REGISTRATION_TOKEN is empty (node endpoints and worker file-transfer server would otherwise be unauthenticated)" group:"distributed"`
|
||||
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch: require BOTH NATS JWT credentials and a registration token when distributed mode is enabled (implies --nats-require-auth and --registration-require-auth)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
DistributedPrefixCache bool `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE" default:"true" help:"Enable prefix-cache-aware routing in distributed mode (default true). When false, routing falls back to round-robin." group:"distributed"`
|
||||
DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
NatsAccountSeed string `env:"LOCALAI_NATS_ACCOUNT_SEED" help:"NATS account signing seed (SU...) used to mint per-node worker JWTs at registration" group:"distributed"`
|
||||
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"NATS user JWT for the frontend (and agent workers) to publish control-plane messages" group:"distributed"`
|
||||
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"NATS user signing seed (SU...) paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
|
||||
NatsWorkerJWTTTL string `env:"LOCALAI_NATS_WORKER_JWT_TTL" help:"Lifetime of minted per-node NATS JWTs (e.g. 24h, default 24h)" group:"distributed"`
|
||||
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT credentials (service JWT + account seed) when distributed mode is enabled" group:"distributed"`
|
||||
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI); use with tls:// in --nats-url" group:"distributed"`
|
||||
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
|
||||
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
|
||||
Version bool
|
||||
@@ -215,6 +227,8 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
config.WithApiKeys(r.APIKeys),
|
||||
config.WithModelsURL(append(r.Models, r.ModelArgs...)...),
|
||||
config.WithExternalBackends(r.ExternalBackends...),
|
||||
config.WithWebRTCNAT1To1IPs(r.WebRTCNAT1To1IPs...),
|
||||
config.WithWebRTCICEInterfaces(r.WebRTCICEInterfaces...),
|
||||
config.WithOpaqueErrors(r.OpaqueErrors),
|
||||
config.WithEnforcedPredownloadScans(!r.DisablePredownloadScan),
|
||||
config.WithSubtleKeyComparison(r.UseSubtleKeyComparison),
|
||||
@@ -283,6 +297,40 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.RegistrationToken != "" {
|
||||
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
|
||||
}
|
||||
if r.RegistrationRequireAuth {
|
||||
opts = append(opts, config.EnableRegistrationRequireAuth)
|
||||
}
|
||||
if r.DistributedRequireAuth {
|
||||
opts = append(opts, config.EnableDistributedRequireAuth)
|
||||
}
|
||||
if r.NatsAccountSeed != "" {
|
||||
opts = append(opts, config.WithNatsAccountSeed(r.NatsAccountSeed))
|
||||
}
|
||||
if r.NatsServiceJWT != "" {
|
||||
opts = append(opts, config.WithNatsServiceJWT(r.NatsServiceJWT))
|
||||
}
|
||||
if r.NatsServiceSeed != "" {
|
||||
opts = append(opts, config.WithNatsServiceSeed(r.NatsServiceSeed))
|
||||
}
|
||||
if r.NatsWorkerJWTTTL != "" {
|
||||
d, err := time.ParseDuration(r.NatsWorkerJWTTTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_NATS_WORKER_JWT_TTL %q: %w", r.NatsWorkerJWTTTL, err)
|
||||
}
|
||||
opts = append(opts, config.WithNatsWorkerJWTTTL(d))
|
||||
}
|
||||
if r.NatsRequireAuth {
|
||||
opts = append(opts, config.EnableNatsRequireAuth)
|
||||
}
|
||||
if r.NatsTLSCA != "" {
|
||||
opts = append(opts, config.WithNatsTLSCA(r.NatsTLSCA))
|
||||
}
|
||||
if r.NatsTLSCert != "" {
|
||||
opts = append(opts, config.WithNatsTLSCert(r.NatsTLSCert))
|
||||
}
|
||||
if r.NatsTLSKey != "" {
|
||||
opts = append(opts, config.WithNatsTLSKey(r.NatsTLSKey))
|
||||
}
|
||||
if r.AutoApproveNodes {
|
||||
opts = append(opts, config.EnableAutoApproveNodes)
|
||||
}
|
||||
@@ -608,12 +656,12 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
// waitForServerReady polls the given address until the HTTP server is
|
||||
// accepting connections or the context is cancelled.
|
||||
func waitForServerReady(address string, ctx context.Context) {
|
||||
// Ensure the address has a host component for dialing.
|
||||
// Echo accepts ":8080" but net.Dial needs a resolvable host.
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err == nil && host == "" {
|
||||
address = "127.0.0.1:" + port
|
||||
}
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -621,11 +669,17 @@ func waitForServerReady(address string, ctx context.Context) {
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", address, 500*time.Millisecond)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
|
||||
options.Backend = t.Backend
|
||||
options.Model = t.Model
|
||||
|
||||
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, ml, opts, options)
|
||||
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, "", nil, ml, opts, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func (r *VLLMDistributed) Run(ctx *cliContext.Context) error {
|
||||
FrontendURL: r.RegisterTo,
|
||||
RegistrationToken: r.RegistrationToken,
|
||||
}
|
||||
nodeID, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
|
||||
nodeID, _, _, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
|
||||
if regErr != nil {
|
||||
return fmt.Errorf("registering with frontend: %w", regErr)
|
||||
}
|
||||
|
||||
@@ -58,65 +58,77 @@ func (c *RegistrationClient) setAuth(req *http.Request) {
|
||||
|
||||
// RegisterResponse is the JSON body returned by /api/node/register.
|
||||
type RegisterResponse struct {
|
||||
ID string `json:"id"`
|
||||
APIToken string `json:"api_token,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status,omitempty"` // "pending" until an admin approves the node
|
||||
APIToken string `json:"api_token,omitempty"`
|
||||
NatsJWT string `json:"nats_jwt,omitempty"`
|
||||
NatsUserSeed string `json:"nats_user_seed,omitempty"`
|
||||
}
|
||||
|
||||
// Register sends a single registration request and returns the node ID and
|
||||
// (optionally) an auto-provisioned API token.
|
||||
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) {
|
||||
// RegisterFull sends a single registration request and returns the full
|
||||
// response (node ID, approval status, and optional API token / NATS creds).
|
||||
// Re-registration is idempotent: the frontend preserves the node row and mints
|
||||
// a fresh NATS JWT each call, so this doubles as the credential-refresh call.
|
||||
func (c *RegistrationClient) RegisterFull(ctx context.Context, body map[string]any) (*RegisterResponse, error) {
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
url := c.baseURL() + "/api/node/register"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("creating request: %w", err)
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.setAuth(req)
|
||||
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("posting to %s: %w", url, err)
|
||||
return nil, fmt.Errorf("posting to %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
|
||||
return nil, fmt.Errorf("registration failed with status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result RegisterResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", "", fmt.Errorf("decoding response: %w", err)
|
||||
return nil, fmt.Errorf("decoding response: %w", err)
|
||||
}
|
||||
return result.ID, result.APIToken, nil
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Register sends a single registration request and returns the node ID and
|
||||
// optional credentials (API token for agent workers, NATS JWT when configured).
|
||||
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
|
||||
res, err := c.RegisterFull(ctx, body)
|
||||
if err != nil {
|
||||
return "", "", "", "", err
|
||||
}
|
||||
return res.ID, res.APIToken, res.NatsJWT, res.NatsUserSeed, nil
|
||||
}
|
||||
|
||||
// RegisterWithRetry retries registration with exponential backoff.
|
||||
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) {
|
||||
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
|
||||
backoff := 2 * time.Second
|
||||
maxBackoff := 30 * time.Second
|
||||
|
||||
var nodeID, apiToken string
|
||||
var err error
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
nodeID, apiToken, err = c.Register(ctx, body)
|
||||
nodeID, apiToken, natsJWT, natsSeed, err = c.Register(ctx, body)
|
||||
if err == nil {
|
||||
return nodeID, apiToken, nil
|
||||
return nodeID, apiToken, natsJWT, natsSeed, nil
|
||||
}
|
||||
if attempt == maxRetries {
|
||||
return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
|
||||
return "", "", "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", "", ctx.Err()
|
||||
return "", "", "", "", ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
backoff = min(backoff*2, maxBackoff)
|
||||
}
|
||||
return nodeID, apiToken, err
|
||||
return nodeID, apiToken, natsJWT, natsSeed, err
|
||||
}
|
||||
|
||||
// Heartbeat sends a single heartbeat POST with the given body.
|
||||
|
||||
200
core/cli/workerregistry/credentials.go
Normal file
200
core/cli/workerregistry/credentials.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package workerregistry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// statusPending mirrors nodes.StatusPending. It is duplicated rather than
|
||||
// imported so the lightweight registration client does not pull in the nodes
|
||||
// package (and its gorm/DB dependencies).
|
||||
const statusPending = "pending"
|
||||
|
||||
// defaultMaxAttempts bounds how many times Acquire registers (and how many
|
||||
// consecutive times RefreshLoop may fail) before giving up. It is high enough
|
||||
// to ride out a slow admin approval or a transient frontend outage, but finite
|
||||
// so an unauthorized/unapprovable worker exits and surfaces the problem (via a
|
||||
// non-zero exit and the resulting restart) rather than waiting forever.
|
||||
const defaultMaxAttempts = 100
|
||||
|
||||
// RegisterFunc performs one idempotent registration round-trip.
|
||||
type RegisterFunc func(ctx context.Context) (*RegisterResponse, error)
|
||||
|
||||
// NATSCredentialManager acquires NATS credentials at startup — waiting through
|
||||
// admin approval when required — and refreshes them before the minted JWT
|
||||
// expires, by re-registering (which mints a fresh JWT). The live NATS
|
||||
// connection adopts a refreshed JWT on its next reconnect via Provider. Safe
|
||||
// for concurrent use.
|
||||
//
|
||||
// It addresses two failure modes: a worker that needs credentials but registers
|
||||
// while still pending approval (it would otherwise give up and never connect),
|
||||
// and a long-running worker whose 24h JWT expires with no way to renew it.
|
||||
type NATSCredentialManager struct {
|
||||
register RegisterFunc
|
||||
requireCreds bool // block until credentials are present (frontend minting in use)
|
||||
|
||||
// Tunables; defaults set by NewNATSCredentialManager, overridable in tests.
|
||||
initialBackoff time.Duration
|
||||
maxBackoff time.Duration
|
||||
maxAttempts int // bound on Acquire attempts / consecutive refresh failures (<=0 = unlimited)
|
||||
refreshLead float64 // refresh once this fraction of the JWT lifetime has elapsed
|
||||
refreshRetry time.Duration
|
||||
expiryOf func(jwt string) (time.Time, bool)
|
||||
|
||||
mu sync.RWMutex
|
||||
jwt string
|
||||
seed string
|
||||
nodeID string
|
||||
}
|
||||
|
||||
// NewNATSCredentialManager builds a manager over register. When requireCreds is
|
||||
// true, Acquire blocks until the node is approved and credentials are minted.
|
||||
func NewNATSCredentialManager(register RegisterFunc, requireCreds bool) *NATSCredentialManager {
|
||||
return &NATSCredentialManager{
|
||||
register: register,
|
||||
requireCreds: requireCreds,
|
||||
initialBackoff: 2 * time.Second,
|
||||
maxBackoff: 30 * time.Second,
|
||||
maxAttempts: defaultMaxAttempts,
|
||||
refreshLead: 0.75,
|
||||
refreshRetry: 30 * time.Second,
|
||||
expiryOf: jwtExpiry,
|
||||
}
|
||||
}
|
||||
|
||||
// jwtExpiry decodes the expiry of a minted user JWT. ok is false when the token
|
||||
// is empty/undecodable or carries no expiry (e.g. a non-expiring service JWT).
|
||||
func jwtExpiry(token string) (time.Time, bool) {
|
||||
if token == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
uc, err := natsauth.DecodeUserClaims(token)
|
||||
if err != nil || uc.Expires == 0 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return time.Unix(uc.Expires, 0), true
|
||||
}
|
||||
|
||||
func (m *NATSCredentialManager) store(res *RegisterResponse) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.nodeID = res.ID
|
||||
if res.NatsJWT != "" && res.NatsUserSeed != "" {
|
||||
m.jwt, m.seed = res.NatsJWT, res.NatsUserSeed
|
||||
}
|
||||
}
|
||||
|
||||
// Current returns the latest NATS credentials (both empty until acquired).
|
||||
func (m *NATSCredentialManager) Current() (jwt, seed string) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.jwt, m.seed
|
||||
}
|
||||
|
||||
// NodeID returns the node ID from the most recent registration.
|
||||
func (m *NATSCredentialManager) NodeID() string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.nodeID
|
||||
}
|
||||
|
||||
// Provider returns a callback compatible with messaging.WithUserJWTProvider,
|
||||
// supplying the current credentials on each (re)connect.
|
||||
func (m *NATSCredentialManager) Provider() func() (string, string) {
|
||||
return m.Current
|
||||
}
|
||||
|
||||
// HasCredentials reports whether complete NATS credentials have been obtained.
|
||||
func (m *NATSCredentialManager) HasCredentials() bool {
|
||||
jwt, seed := m.Current()
|
||||
return jwt != "" && seed != ""
|
||||
}
|
||||
|
||||
// Acquire registers and, when requireCreds is set, keeps re-registering with
|
||||
// exponential backoff until the node is approved (status != pending) and
|
||||
// credentials are minted. Without requireCreds it returns the first successful
|
||||
// response (the historical one-shot behavior, preserved for anonymous NATS).
|
||||
func (m *NATSCredentialManager) Acquire(ctx context.Context) (*RegisterResponse, error) {
|
||||
backoff := m.initialBackoff
|
||||
var lastReason error
|
||||
for attempt := 1; m.maxAttempts <= 0 || attempt <= m.maxAttempts; attempt++ {
|
||||
res, err := m.register(ctx)
|
||||
switch {
|
||||
case err != nil:
|
||||
lastReason = err
|
||||
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
|
||||
case !m.requireCreds:
|
||||
m.store(res)
|
||||
return res, nil
|
||||
case res.Status == statusPending:
|
||||
lastReason = fmt.Errorf("node %s still pending admin approval", res.ID)
|
||||
xlog.Info("Node pending admin approval; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
|
||||
case res.NatsJWT == "" || res.NatsUserSeed == "":
|
||||
lastReason = fmt.Errorf("node %s approved but NATS credentials not minted", res.ID)
|
||||
xlog.Info("Node approved but NATS credentials not yet minted; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
|
||||
default:
|
||||
m.store(res)
|
||||
return res, nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
backoff = min(backoff*2, m.maxBackoff)
|
||||
}
|
||||
return nil, fmt.Errorf("giving up acquiring NATS credentials after %d attempts: %w", m.maxAttempts, lastReason)
|
||||
}
|
||||
|
||||
// RefreshLoop re-registers to mint a fresh JWT before the current one expires,
|
||||
// updating the credentials returned by Current/Provider so the NATS connection
|
||||
// adopts them on its next reconnect. It returns nil when ctx is cancelled or
|
||||
// when the current credential has no expiry (nothing to refresh), and a non-nil
|
||||
// error after maxAttempts consecutive refresh failures — letting the caller
|
||||
// exit the worker so it restarts and re-acquires (or surfaces the outage)
|
||||
// rather than silently drifting toward an expired, unrenewable JWT.
|
||||
func (m *NATSCredentialManager) RefreshLoop(ctx context.Context) error {
|
||||
failures := 0
|
||||
for {
|
||||
jwt, _ := m.Current()
|
||||
exp, ok := m.expiryOf(jwt)
|
||||
if !ok {
|
||||
xlog.Debug("NATS credential has no expiry; refresh loop exiting")
|
||||
return nil
|
||||
}
|
||||
wait := max(time.Duration(float64(time.Until(exp))*m.refreshLead), 0)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(wait):
|
||||
}
|
||||
|
||||
res, err := m.register(ctx)
|
||||
if err == nil && res.NatsJWT != "" && res.NatsUserSeed != "" {
|
||||
m.store(res)
|
||||
failures = 0
|
||||
xlog.Info("Refreshed NATS credentials", "node", res.ID)
|
||||
continue
|
||||
}
|
||||
failures++
|
||||
if err != nil {
|
||||
xlog.Warn("NATS credential refresh failed; will retry", "attempt", failures, "error", err)
|
||||
} else {
|
||||
xlog.Warn("NATS credential refresh returned no credentials; will retry", "attempt", failures)
|
||||
}
|
||||
if m.maxAttempts > 0 && failures >= m.maxAttempts {
|
||||
return fmt.Errorf("NATS credential refresh failed %d times in a row", failures)
|
||||
}
|
||||
// Back off before retrying so a persistent failure near expiry does not spin.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(m.refreshRetry):
|
||||
}
|
||||
}
|
||||
}
|
||||
198
core/cli/workerregistry/credentials_test.go
Normal file
198
core/cli/workerregistry/credentials_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package workerregistry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/nats-io/nkeys"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestWorkerRegistry(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "WorkerRegistry")
|
||||
}
|
||||
|
||||
// fakeRegister returns a sequence of canned responses/errors, one per call, and
|
||||
// records how many times it was invoked. The last entry repeats once exhausted.
|
||||
type fakeRegister struct {
|
||||
mu sync.Mutex
|
||||
steps []step
|
||||
calls int
|
||||
}
|
||||
|
||||
type step struct {
|
||||
res *RegisterResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeRegister) fn() RegisterFunc {
|
||||
return func(context.Context) (*RegisterResponse, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
i := f.calls
|
||||
f.calls++
|
||||
if i >= len(f.steps) {
|
||||
i = len(f.steps) - 1
|
||||
}
|
||||
return f.steps[i].res, f.steps[i].err
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeRegister) count() int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.calls
|
||||
}
|
||||
|
||||
var _ = Describe("NATSCredentialManager", func() {
|
||||
approved := func(jwt, seed string) *RegisterResponse {
|
||||
return &RegisterResponse{ID: "node-1", Status: "healthy", NatsJWT: jwt, NatsUserSeed: seed}
|
||||
}
|
||||
pending := &RegisterResponse{ID: "node-1", Status: "pending"}
|
||||
|
||||
Describe("Acquire (#4 — wait through admin approval)", func() {
|
||||
It("keeps re-registering until the node is approved and credentials are minted", func() {
|
||||
f := &fakeRegister{steps: []step{
|
||||
{res: pending}, // not approved yet
|
||||
{res: approved("", "")}, // approved but JWT not minted yet
|
||||
{res: approved("jwt-1", "seed-1")}, // finally minted
|
||||
}}
|
||||
m := NewNATSCredentialManager(f.fn(), true /* requireCreds */)
|
||||
m.initialBackoff = time.Millisecond
|
||||
m.maxBackoff = time.Millisecond
|
||||
|
||||
res, err := m.Acquire(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.ID).To(Equal("node-1"))
|
||||
Expect(f.count()).To(Equal(3))
|
||||
|
||||
jwt, seed := m.Current()
|
||||
Expect(jwt).To(Equal("jwt-1"))
|
||||
Expect(seed).To(Equal("seed-1"))
|
||||
Expect(m.HasCredentials()).To(BeTrue())
|
||||
Expect(m.NodeID()).To(Equal("node-1"))
|
||||
})
|
||||
|
||||
It("returns immediately on the first success when credentials are not required (anonymous NATS)", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}}
|
||||
m := NewNATSCredentialManager(f.fn(), false /* requireCreds */)
|
||||
|
||||
res, err := m.Acquire(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.Status).To(Equal("pending"))
|
||||
Expect(f.count()).To(Equal(1))
|
||||
Expect(m.HasCredentials()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("aborts when the context is cancelled while waiting for approval", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.initialBackoff = 10 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err := m.Acquire(ctx)
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
})
|
||||
|
||||
It("gives up after a bounded number of attempts so the worker exits and alerts", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}} // never approved
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.initialBackoff = time.Millisecond
|
||||
m.maxBackoff = time.Millisecond
|
||||
m.maxAttempts = 5
|
||||
|
||||
_, err := m.Acquire(context.Background())
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("after 5 attempts"))
|
||||
Expect(err.Error()).To(ContainSubstring("pending admin approval"))
|
||||
Expect(f.count()).To(Equal(5))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("RefreshLoop (#5 — renew before the JWT expires)", func() {
|
||||
It("re-registers before expiry and updates the credentials served to new connections", func() {
|
||||
f := &fakeRegister{steps: []step{{res: approved("jwt-2", "seed-2")}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.refreshLead = 0.5
|
||||
m.refreshRetry = time.Millisecond
|
||||
// jwt-1 expires soon; jwt-2 is long-lived so the loop then idles.
|
||||
m.expiryOf = func(jwt string) (time.Time, bool) {
|
||||
switch jwt {
|
||||
case "jwt-1":
|
||||
return time.Now().Add(40 * time.Millisecond), true
|
||||
case "jwt-2":
|
||||
return time.Now().Add(time.Hour), true
|
||||
default:
|
||||
return time.Time{}, false
|
||||
}
|
||||
}
|
||||
m.store(approved("jwt-1", "seed-1"))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go func() { _ = m.RefreshLoop(ctx) }()
|
||||
|
||||
Eventually(func() string {
|
||||
jwt, _ := m.Current()
|
||||
return jwt
|
||||
}, "2s", "10ms").Should(Equal("jwt-2"))
|
||||
})
|
||||
|
||||
It("returns an error after the bounded number of consecutive failures so the caller can exit", func() {
|
||||
f := &fakeRegister{steps: []step{{err: context.DeadlineExceeded}}} // refresh always fails
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.refreshLead = 0.5
|
||||
m.refreshRetry = time.Millisecond
|
||||
m.maxAttempts = 3
|
||||
m.expiryOf = func(string) (time.Time, bool) { return time.Now().Add(time.Millisecond), true }
|
||||
m.store(approved("jwt-1", "seed-1"))
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- m.RefreshLoop(context.Background()) }()
|
||||
Eventually(errCh, "2s").Should(Receive(MatchError(ContainSubstring("3 times in a row"))))
|
||||
})
|
||||
|
||||
It("exits promptly when the current credential has no expiry (nothing to refresh)", func() {
|
||||
f := &fakeRegister{steps: []step{{res: approved("x", "y")}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.expiryOf = func(string) (time.Time, bool) { return time.Time{}, false }
|
||||
m.store(approved("static", "seed"))
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() { _ = m.RefreshLoop(context.Background()); close(done) }()
|
||||
Eventually(done, "1s").Should(BeClosed())
|
||||
Expect(f.count()).To(Equal(0)) // never tried to re-register
|
||||
})
|
||||
})
|
||||
|
||||
Describe("jwtExpiry default", func() {
|
||||
It("decodes the expiry of a real minted worker JWT", func() {
|
||||
akp, err := nkeys.CreateAccount()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
seed, err := akp.Seed()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
cfg := natsauth.Config{AccountSeed: string(seed), WorkerJWTTTL: time.Hour}
|
||||
token, _, err := cfg.MintWorkerJWT("node-1", "backend")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
exp, ok := jwtExpiry(token)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(exp).To(BeTemporally("~", time.Now().Add(time.Hour), 2*time.Minute))
|
||||
})
|
||||
|
||||
It("reports no expiry for an empty or undecodable token", func() {
|
||||
_, ok := jwtExpiry("")
|
||||
Expect(ok).To(BeFalse())
|
||||
_, ok = jwtExpiry("not-a-jwt")
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -12,10 +12,19 @@ import (
|
||||
)
|
||||
|
||||
type ApplicationConfig struct {
|
||||
Context context.Context
|
||||
ConfigFile string
|
||||
SystemState *system.SystemState
|
||||
ExternalBackends []string
|
||||
Context context.Context
|
||||
ConfigFile string
|
||||
SystemState *system.SystemState
|
||||
ExternalBackends []string
|
||||
|
||||
// WebRTCNAT1To1IPs, when set, are advertised as the host ICE candidates for
|
||||
// /v1/realtime WebRTC instead of every local interface address. Needed when
|
||||
// the routable address differs from what pion gathers — e.g. Docker host
|
||||
// networking (where pion also offers unreachable bridge IPs) or NAT.
|
||||
WebRTCNAT1To1IPs []string
|
||||
// WebRTCICEInterfaces, when set, restricts ICE candidate gathering to these
|
||||
// network interfaces (e.g. eth0), filtering out docker0/veth noise.
|
||||
WebRTCICEInterfaces []string
|
||||
UploadLimitMB, Threads, ContextSize int
|
||||
F16 bool
|
||||
Debug bool
|
||||
@@ -81,7 +90,6 @@ type ApplicationConfig struct {
|
||||
// file is mode 0600.
|
||||
MITMCADir string
|
||||
|
||||
|
||||
// PIIPatternOverrides applies persisted per-id deltas (action,
|
||||
// disabled) to the live redactor at startup. Loaded from
|
||||
// runtime_settings.json and applied right after pii.NewRedactor.
|
||||
@@ -116,11 +124,11 @@ type ApplicationConfig struct {
|
||||
// --require-backend-integrity / LOCALAI_REQUIRE_BACKEND_INTEGRITY.
|
||||
RequireBackendIntegrity bool
|
||||
|
||||
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
|
||||
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||
WatchDogIdle bool
|
||||
WatchDogBusy bool
|
||||
WatchDog bool
|
||||
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
|
||||
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||
WatchDogIdle bool
|
||||
WatchDogBusy bool
|
||||
WatchDog bool
|
||||
|
||||
// Memory Reclaimer settings (works with GPU if available, otherwise RAM)
|
||||
MemoryReclaimerEnabled bool // Enable memory threshold monitoring
|
||||
@@ -311,6 +319,18 @@ func WithExternalBackends(backends ...string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithWebRTCNAT1To1IPs(ips ...string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.WebRTCNAT1To1IPs = ips
|
||||
}
|
||||
}
|
||||
|
||||
func WithWebRTCICEInterfaces(interfaces ...string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.WebRTCICEInterfaces = interfaces
|
||||
}
|
||||
}
|
||||
|
||||
func WithMachineTag(tag string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.MachineTag = tag
|
||||
@@ -702,7 +722,6 @@ func WithMITMCADir(dir string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func WithDynamicConfigDir(dynamicConfigsDir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.DynamicConfigsDir = dynamicConfigsDir
|
||||
|
||||
@@ -22,9 +22,11 @@ const (
|
||||
UsecaseRerank = "rerank"
|
||||
UsecaseDetection = "detection"
|
||||
UsecaseVAD = "vad"
|
||||
UsecaseAudioTransform = "audio_transform"
|
||||
UsecaseDiarization = "diarization"
|
||||
UsecaseRealtimeAudio = "realtime_audio"
|
||||
UsecaseAudioTransform = "audio_transform"
|
||||
UsecaseDiarization = "diarization"
|
||||
UsecaseRealtimeAudio = "realtime_audio"
|
||||
UsecaseFaceRecognition = "face_recognition"
|
||||
UsecaseSpeakerRecognition = "speaker_recognition"
|
||||
)
|
||||
|
||||
// GRPCMethod identifies a Backend service RPC from backend.proto.
|
||||
@@ -47,6 +49,11 @@ const (
|
||||
MethodAudioTransform GRPCMethod = "AudioTransform"
|
||||
MethodDiarize GRPCMethod = "Diarize"
|
||||
MethodAudioToAudioStream GRPCMethod = "AudioToAudioStream"
|
||||
MethodFaceVerify GRPCMethod = "FaceVerify"
|
||||
MethodFaceAnalyze GRPCMethod = "FaceAnalyze"
|
||||
MethodVoiceVerify GRPCMethod = "VoiceVerify"
|
||||
MethodVoiceEmbed GRPCMethod = "VoiceEmbed"
|
||||
MethodVoiceAnalyze GRPCMethod = "VoiceAnalyze"
|
||||
)
|
||||
|
||||
// UsecaseInfo describes a single known_usecase value and how it maps
|
||||
@@ -154,6 +161,16 @@ var UsecaseInfoMap = map[string]UsecaseInfo{
|
||||
GRPCMethod: MethodAudioToAudioStream,
|
||||
Description: "Self-contained any-to-any audio model for the Realtime API — accepts microphone audio and emits speech + transcript (+ optional function calls) from a single backend via the AudioToAudioStream RPC.",
|
||||
},
|
||||
UsecaseFaceRecognition: {
|
||||
Flag: FLAG_FACE_RECOGNITION,
|
||||
GRPCMethod: MethodFaceVerify,
|
||||
Description: "Face recognition — verify identity, analyze attributes (age/gender/emotion) via FaceVerify and FaceAnalyze RPCs.",
|
||||
},
|
||||
UsecaseSpeakerRecognition: {
|
||||
Flag: FLAG_SPEAKER_RECOGNITION,
|
||||
GRPCMethod: MethodVoiceVerify,
|
||||
Description: "Speaker recognition — verify identity, embed and analyze voice via VoiceVerify, VoiceEmbed and VoiceAnalyze RPCs.",
|
||||
},
|
||||
}
|
||||
|
||||
// BackendCapability describes which gRPC methods and usecases a backend supports.
|
||||
@@ -471,6 +488,21 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseDetection},
|
||||
Description: "RF-DETR C++ object detection",
|
||||
},
|
||||
|
||||
// --- Face and speaker recognition backends ---
|
||||
"insightface": {
|
||||
GRPCMethods: []GRPCMethod{MethodEmbedding, MethodDetect, MethodFaceVerify, MethodFaceAnalyze},
|
||||
PossibleUsecases: []string{UsecaseEmbeddings, UsecaseDetection, UsecaseFaceRecognition},
|
||||
DefaultUsecases: []string{UsecaseFaceRecognition},
|
||||
AcceptsImages: true,
|
||||
Description: "InsightFace — face detection, embedding, verification and attribute analysis",
|
||||
},
|
||||
"speaker-recognition": {
|
||||
GRPCMethods: []GRPCMethod{MethodVoiceVerify, MethodVoiceEmbed, MethodVoiceAnalyze},
|
||||
PossibleUsecases: []string{UsecaseSpeakerRecognition},
|
||||
DefaultUsecases: []string{UsecaseSpeakerRecognition},
|
||||
Description: "Speaker recognition — voice identity verification and analysis",
|
||||
},
|
||||
"silero-vad": {
|
||||
GRPCMethods: []GRPCMethod{MethodVAD},
|
||||
PossibleUsecases: []string{UsecaseVAD},
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -16,7 +18,29 @@ type DistributedConfig struct {
|
||||
NatsURL string // --nats-url / LOCALAI_NATS_URL
|
||||
StorageURL string // --storage-url / LOCALAI_STORAGE_URL (S3 endpoint)
|
||||
RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration)
|
||||
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
|
||||
// RegistrationRequireAuth fails startup when distributed mode is enabled but
|
||||
// RegistrationToken is empty. The default (false) keeps the historical
|
||||
// fail-open behavior with a loud warning; production should set it so the
|
||||
// node-register endpoints and the worker file-transfer server cannot run
|
||||
// unauthenticated. Mirrors NatsRequireAuth for the NATS bus.
|
||||
RegistrationRequireAuth bool // LOCALAI_REGISTRATION_REQUIRE_AUTH
|
||||
// RequireAuth is the umbrella switch (LOCALAI_DISTRIBUTED_REQUIRE_AUTH) for
|
||||
// distributed-mode auth: when true it implies BOTH NatsRequireAuth and
|
||||
// RegistrationRequireAuth, so a single knob locks down the bus and the
|
||||
// registration/file-transfer layer together. The granular flags remain
|
||||
// available to enforce just one layer.
|
||||
RequireAuth bool // LOCALAI_DISTRIBUTED_REQUIRE_AUTH
|
||||
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
|
||||
|
||||
// NATS JWT auth (optional; see pkg/natsauth and docs/features/distributed-mode.md)
|
||||
NatsAccountSeed string // LOCALAI_NATS_ACCOUNT_SEED — account signing seed to mint per-node worker JWTs
|
||||
NatsServiceJWT string // LOCALAI_NATS_SERVICE_JWT — user JWT for frontends / agent workers
|
||||
NatsServiceSeed string // LOCALAI_NATS_SERVICE_SEED — signing seed paired with service JWT
|
||||
NatsWorkerJWTTTL time.Duration // LOCALAI_NATS_WORKER_JWT_TTL — minted worker JWT lifetime (default 24h)
|
||||
NatsRequireAuth bool // LOCALAI_NATS_REQUIRE_AUTH — fail startup if NATS credentials are missing
|
||||
NatsTLSCA string // LOCALAI_NATS_TLS_CA — PEM file for private CA (server verify)
|
||||
NatsTLSCert string // LOCALAI_NATS_TLS_CERT — client cert for NATS mTLS
|
||||
NatsTLSKey string // LOCALAI_NATS_TLS_KEY — client key paired with NatsTLSCert
|
||||
|
||||
// S3 configuration (used when StorageURL is set)
|
||||
StorageBucket string // --storage-bucket / LOCALAI_STORAGE_BUCKET
|
||||
@@ -76,10 +100,23 @@ func (c DistributedConfig) Validate() error {
|
||||
(c.StorageAccessKey == "" && c.StorageSecretKey != "") {
|
||||
return fmt.Errorf("storage-access-key and storage-secret-key must both be set or both empty")
|
||||
}
|
||||
// Warn about missing registration token (not an error)
|
||||
// The registration token guards both the node HTTP register/heartbeat
|
||||
// endpoints and the worker file-transfer server (which fails open on an
|
||||
// empty token). Enforce it when registration auth is required (the granular
|
||||
// flag or the umbrella); otherwise warn.
|
||||
if c.RegistrationToken == "" {
|
||||
xlog.Warn("distributed mode running without registration token — node endpoints are unprotected")
|
||||
if c.RegistrationAuthRequired() {
|
||||
return fmt.Errorf("registration auth is required (LOCALAI_REGISTRATION_REQUIRE_AUTH or LOCALAI_DISTRIBUTED_REQUIRE_AUTH) but LOCALAI_REGISTRATION_TOKEN is empty")
|
||||
}
|
||||
xlog.Warn("distributed mode running without registration token — node endpoints and the worker file-transfer server are unprotected; set LOCALAI_REGISTRATION_TOKEN, or LOCALAI_DISTRIBUTED_REQUIRE_AUTH=true to fail closed")
|
||||
}
|
||||
if err := c.NatsAuthConfig().Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.NatsTLSFiles().Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.NatsAuthConfig().WarnIfInsecure(true)
|
||||
// Check for negative durations
|
||||
for name, d := range map[string]time.Duration{
|
||||
FlagMCPToolTimeout: c.MCPToolTimeout,
|
||||
@@ -123,6 +160,76 @@ func WithRegistrationToken(token string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsAccountSeed(seed string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsAccountSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsServiceJWT(jwt string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsServiceJWT = jwt
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsServiceSeed(seed string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsServiceSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsWorkerJWTTTL(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsWorkerJWTTTL = d
|
||||
}
|
||||
}
|
||||
|
||||
var EnableNatsRequireAuth = func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsRequireAuth = true
|
||||
}
|
||||
|
||||
// EnableRegistrationRequireAuth makes an empty registration token a hard error
|
||||
// in distributed mode (see DistributedConfig.RegistrationRequireAuth).
|
||||
var EnableRegistrationRequireAuth = func(o *ApplicationConfig) {
|
||||
o.Distributed.RegistrationRequireAuth = true
|
||||
}
|
||||
|
||||
// EnableDistributedRequireAuth is the umbrella switch implying both
|
||||
// NatsRequireAuth and RegistrationRequireAuth (see DistributedConfig.RequireAuth).
|
||||
var EnableDistributedRequireAuth = func(o *ApplicationConfig) {
|
||||
o.Distributed.RequireAuth = true
|
||||
}
|
||||
|
||||
// RegistrationAuthRequired reports whether an empty registration token must be
|
||||
// treated as a fatal misconfiguration — the granular flag or the umbrella.
|
||||
func (c DistributedConfig) RegistrationAuthRequired() bool {
|
||||
return c.RegistrationRequireAuth || c.RequireAuth
|
||||
}
|
||||
|
||||
// NatsAuthRequired reports whether NATS JWT credentials must be present — the
|
||||
// granular flag or the umbrella.
|
||||
func (c DistributedConfig) NatsAuthRequired() bool {
|
||||
return c.NatsRequireAuth || c.RequireAuth
|
||||
}
|
||||
|
||||
func WithNatsTLSCA(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSCA = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsTLSCert(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSCert = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsTLSKey(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSKey = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithStorageURL(url string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.StorageURL = url
|
||||
@@ -217,6 +324,44 @@ const (
|
||||
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
|
||||
const DefaultMaxUploadSize int64 = 50 << 30
|
||||
|
||||
// NatsTLSFiles returns NATS TLS/mTLS PEM paths for the messaging client.
|
||||
func (c DistributedConfig) NatsTLSFiles() messaging.TLSFiles {
|
||||
return messaging.TLSFiles{
|
||||
CA: c.NatsTLSCA,
|
||||
Cert: c.NatsTLSCert,
|
||||
Key: c.NatsTLSKey,
|
||||
}
|
||||
}
|
||||
|
||||
// NatsMessagingOptions builds messaging client options (JWT + TLS) for distributed components.
|
||||
// Pass explicit userJWT/userSeed when set (e.g. worker overrides); empty uses service JWT from config.
|
||||
func (c DistributedConfig) NatsMessagingOptions(userJWT, userSeed string) []messaging.Option {
|
||||
var opts []messaging.Option
|
||||
jwt, seed := userJWT, userSeed
|
||||
if jwt == "" && seed == "" {
|
||||
auth := c.NatsAuthConfig()
|
||||
jwt, seed = auth.ServiceUserJWT, auth.ServiceUserSeed
|
||||
}
|
||||
if jwt != "" && seed != "" {
|
||||
opts = append(opts, messaging.WithUserJWT(jwt, seed))
|
||||
}
|
||||
if tls := c.NatsTLSFiles(); tls.Enabled() {
|
||||
opts = append(opts, messaging.WithTLS(tls))
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// NatsAuthConfig builds pkg/natsauth settings from distributed configuration.
|
||||
func (c DistributedConfig) NatsAuthConfig() natsauth.Config {
|
||||
return natsauth.Config{
|
||||
AccountSeed: c.NatsAccountSeed,
|
||||
ServiceUserJWT: c.NatsServiceJWT,
|
||||
ServiceUserSeed: c.NatsServiceSeed,
|
||||
WorkerJWTTTL: c.NatsWorkerJWTTTL,
|
||||
RequireAuth: c.NatsAuthRequired(),
|
||||
}
|
||||
}
|
||||
|
||||
// BackendInstallTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout)
|
||||
|
||||
@@ -88,3 +88,66 @@ var _ = Describe("DistributedConfig.Validate negative-duration errors", func() {
|
||||
Expect(c.Validate()).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("DistributedConfig.Validate registration auth", func() {
|
||||
It("rejects an empty registration token when RequireAuth is set", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
RegistrationRequireAuth: true,
|
||||
}
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_REQUIRE_AUTH"))
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_TOKEN"))
|
||||
})
|
||||
|
||||
It("accepts a set registration token when RequireAuth is set", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
RegistrationToken: "s3cret",
|
||||
RegistrationRequireAuth: true,
|
||||
}
|
||||
Expect(c.Validate()).To(Succeed())
|
||||
})
|
||||
|
||||
It("warns but succeeds with an empty token when RequireAuth is unset", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
}
|
||||
Expect(c.Validate()).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects an empty token when the umbrella RequireAuth is set", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
RequireAuth: true,
|
||||
// Provide NATS creds so only the registration-token gap remains.
|
||||
NatsServiceJWT: "jwt",
|
||||
NatsServiceSeed: "seed",
|
||||
NatsAccountSeed: "acct",
|
||||
}
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_DISTRIBUTED_REQUIRE_AUTH"))
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_TOKEN"))
|
||||
})
|
||||
|
||||
It("the umbrella implies NATS auth is required", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
RegistrationToken: "tok", // registration layer satisfied
|
||||
RequireAuth: true, // umbrella → NATS creds now required
|
||||
}
|
||||
Expect(c.NatsAuthRequired()).To(BeTrue())
|
||||
Expect(c.RegistrationAuthRequired()).To(BeTrue())
|
||||
// Missing NATS service JWT/seed must now be fatal.
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("LOCALAI_NATS_REQUIRE_AUTH"))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -39,7 +39,21 @@ func llamaCppDefaults(cfg *ModelConfig, modelPath string) {
|
||||
}
|
||||
}()
|
||||
|
||||
f, err := gguf.ParseGGUFFile(guessPath)
|
||||
// Startup parses every model's GGUF header to guess defaults. We only need
|
||||
// scalar metadata (architecture, head/ff counts, chat_template, token IDs,
|
||||
// MTP head) plus array *lengths* — never the array *contents*. Two options
|
||||
// keep this cheap, which matters when many models live on slow storage such
|
||||
// as a Docker volume (see https://github.com/mudler/LocalAI/issues/9790):
|
||||
//
|
||||
// - SkipLargeMetadata: seek past large array-valued metadata (the tokenizer
|
||||
// vocab: tokenizer.ggml.tokens/scores/merges, often >100k entries) instead
|
||||
// of reading and allocating every element. Lengths stay populated.
|
||||
// - UseMMap: read the header via a memory map so faulting in a few pages
|
||||
// replaces hundreds of thousands of tiny read() syscalls (measured ~524k
|
||||
// -> 8 for a 256k-token vocab), the dominant cost on slow filesystems.
|
||||
//
|
||||
// The mapping is released when ParseGGUFFile returns.
|
||||
f, err := gguf.ParseGGUFFile(guessPath, gguf.UseMMap(), gguf.SkipLargeMetadata())
|
||||
if err == nil {
|
||||
guessGGUFFromFile(cfg, f, 0)
|
||||
}
|
||||
|
||||
@@ -1,13 +1,76 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// GGUF metadata value type tags (see github.com/gpustack/gguf-parser-go).
|
||||
const (
|
||||
ggufTypeUint32 uint32 = 4
|
||||
ggufTypeString uint32 = 8
|
||||
ggufTypeArray uint32 = 9
|
||||
)
|
||||
|
||||
// writeTestGGUF emits a minimal but valid little-endian GGUF v3 header carrying
|
||||
// the scalar metadata the llama-cpp hook guesses from plus a large string vocab
|
||||
// array (tokenizer.ggml.tokens). The big array is exactly what SkipLargeMetadata
|
||||
// + UseMMap are expected to avoid reading element-by-element, so it must survive a
|
||||
// round-trip through the real hook without corrupting the guessed defaults.
|
||||
func writeTestGGUF(path, chatTemplate string, vocab int) error {
|
||||
wStr := func(b *bytes.Buffer, s string) {
|
||||
binary.Write(b, binary.LittleEndian, uint64(len(s)))
|
||||
b.WriteString(s)
|
||||
}
|
||||
kvStr := func(b *bytes.Buffer, k, v string) {
|
||||
wStr(b, k)
|
||||
binary.Write(b, binary.LittleEndian, ggufTypeString)
|
||||
wStr(b, v)
|
||||
}
|
||||
kvU32 := func(b *bytes.Buffer, k string, v uint32) {
|
||||
wStr(b, k)
|
||||
binary.Write(b, binary.LittleEndian, ggufTypeUint32)
|
||||
binary.Write(b, binary.LittleEndian, v)
|
||||
}
|
||||
|
||||
var meta bytes.Buffer
|
||||
kvStr(&meta, "general.architecture", "llama")
|
||||
kvStr(&meta, "general.name", "ReproModel")
|
||||
kvU32(&meta, "llama.context_length", 4096)
|
||||
kvU32(&meta, "llama.attention.head_count", 32)
|
||||
kvU32(&meta, "llama.feed_forward_length", 11008)
|
||||
kvU32(&meta, "llama.block_count", 32)
|
||||
kvU32(&meta, "tokenizer.ggml.bos_token_id", 1)
|
||||
kvStr(&meta, "tokenizer.chat_template", chatTemplate)
|
||||
|
||||
// large array value — the one the optimization skips reading
|
||||
wStr(&meta, "tokenizer.ggml.tokens")
|
||||
binary.Write(&meta, binary.LittleEndian, ggufTypeArray)
|
||||
binary.Write(&meta, binary.LittleEndian, ggufTypeString)
|
||||
binary.Write(&meta, binary.LittleEndian, uint64(vocab))
|
||||
for i := 0; i < vocab; i++ {
|
||||
wStr(&meta, "token")
|
||||
}
|
||||
|
||||
var out bytes.Buffer
|
||||
binary.Write(&out, binary.LittleEndian, gguf.GGUFMagicGGUFLe)
|
||||
binary.Write(&out, binary.LittleEndian, uint32(3)) // version
|
||||
binary.Write(&out, binary.LittleEndian, uint64(0)) // tensor count
|
||||
binary.Write(&out, binary.LittleEndian, uint64(9)) // metadata kv count
|
||||
out.Write(meta.Bytes())
|
||||
|
||||
return os.WriteFile(path, out.Bytes(), 0o644)
|
||||
}
|
||||
|
||||
var _ = Describe("Backend hooks and parser defaults", func() {
|
||||
Context("MatchParserDefaults", func() {
|
||||
It("matches Qwen3 family", func() {
|
||||
@@ -137,6 +200,58 @@ var _ = Describe("Backend hooks and parser defaults", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("llamaCppDefaults GGUF guessing", func() {
|
||||
// Regression coverage for https://github.com/mudler/LocalAI/issues/9790:
|
||||
// the hook reads GGUF headers with SkipLargeMetadata + UseMMap to avoid
|
||||
// pulling the whole tokenizer vocab off (slow) disk on every startup. This
|
||||
// verifies that skipping the vocab array still yields the correct guessed
|
||||
// defaults from the remaining scalar metadata.
|
||||
const chatTemplate = "{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}"
|
||||
|
||||
It("guesses defaults from a GGUF whose large vocab is skipped", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
modelFile := "repro.gguf"
|
||||
Expect(writeTestGGUF(filepath.Join(dir, modelFile), chatTemplate, 50000)).To(Succeed())
|
||||
|
||||
// A pre-set context size short-circuits the GGUF run-estimate, which
|
||||
// needs full tensor info this header-only fixture deliberately omits;
|
||||
// the metadata-reading path the optimization touches is unaffected.
|
||||
ctxSize := 4096
|
||||
cfg := &ModelConfig{
|
||||
Backend: "llama-cpp",
|
||||
LLMConfig: LLMConfig{ContextSize: &ctxSize},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: modelFile},
|
||||
},
|
||||
}
|
||||
cfg.SetDefaults(ModelPath(dir))
|
||||
|
||||
// chat_template is a scalar string, not part of the skipped array,
|
||||
// so it must be captured verbatim.
|
||||
Expect(cfg.GetModelTemplate()).To(Equal(chatTemplate))
|
||||
// scalar-derived defaults are still applied
|
||||
Expect(cfg.ContextSize).NotTo(BeNil())
|
||||
Expect(cfg.NGPULayers).NotTo(BeNil())
|
||||
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
|
||||
Expect(cfg.KnownUsecaseStrings).To(ContainElement("FLAG_CHAT"))
|
||||
})
|
||||
|
||||
It("falls back to the default context size when the GGUF is unreadable", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
Expect(os.WriteFile(filepath.Join(dir, "bad.gguf"), []byte("not a gguf"), 0o644)).To(Succeed())
|
||||
|
||||
cfg := &ModelConfig{
|
||||
Backend: "llama-cpp",
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: "bad.gguf"},
|
||||
},
|
||||
}
|
||||
cfg.SetDefaults(ModelPath(dir))
|
||||
|
||||
Expect(cfg.ContextSize).NotTo(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("PromptCacheAll default", func() {
|
||||
It("defaults to true when omitted from YAML", func() {
|
||||
cfg := &ModelConfig{}
|
||||
|
||||
@@ -128,6 +128,22 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Advanced: true,
|
||||
Order: 21,
|
||||
},
|
||||
"reasoning_effort": {
|
||||
Section: "llm",
|
||||
Label: "Reasoning Effort",
|
||||
Description: "Default reasoning effort, forwarded to the backend as the reasoning_effort chat_template_kwarg (jinja models like gpt-oss / LFM2.5 honor it). A per-request reasoning_effort overrides it. 'none' also turns thinking off.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "", Label: "Unset (model default)"},
|
||||
{Value: "none", Label: "none (disable thinking)"},
|
||||
{Value: "minimal", Label: "minimal"},
|
||||
{Value: "low", Label: "low"},
|
||||
{Value: "medium", Label: "medium"},
|
||||
{Value: "high", Label: "high"},
|
||||
},
|
||||
Advanced: true,
|
||||
Order: 22,
|
||||
},
|
||||
"cache_type_k": {
|
||||
Section: "llm",
|
||||
Label: "KV Cache Type (K)",
|
||||
@@ -277,6 +293,56 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
AutocompleteProvider: ProviderModelsVAD,
|
||||
Order: 63,
|
||||
},
|
||||
"pipeline.reasoning_effort": {
|
||||
Section: "pipeline",
|
||||
Label: "Reasoning Effort",
|
||||
Description: "Reasoning effort for the pipeline's LLM, forwarded to the backend as the reasoning_effort chat_template_kwarg (jinja models like gpt-oss / LFM2.5 honor it). Overrides the LLM model's own reasoning_effort. 'none' also turns thinking off.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "", Label: "Default (model config)"},
|
||||
{Value: "none", Label: "none (disable thinking)"},
|
||||
{Value: "minimal", Label: "minimal"},
|
||||
{Value: "low", Label: "low"},
|
||||
{Value: "medium", Label: "medium"},
|
||||
{Value: "high", Label: "high"},
|
||||
},
|
||||
Order: 64,
|
||||
},
|
||||
"pipeline.disable_thinking": {
|
||||
Section: "pipeline",
|
||||
Label: "Disable Thinking",
|
||||
Description: "Suppress reasoning/thinking output from the pipeline LLM (sets enable_thinking=false on the underlying model). Use for models that emit <think> blocks you don't want spoken or streamed back to the realtime client.",
|
||||
Component: "toggle",
|
||||
Order: 65,
|
||||
},
|
||||
"pipeline.streaming.llm": {
|
||||
Section: "pipeline",
|
||||
Label: "Stream LLM",
|
||||
Description: "Stream LLM tokens to the realtime client as they are generated instead of waiting for the full response. Emits incremental response.output_audio_transcript.delta / text deltas.",
|
||||
Component: "toggle",
|
||||
Order: 66,
|
||||
},
|
||||
"pipeline.streaming.tts": {
|
||||
Section: "pipeline",
|
||||
Label: "Stream TTS",
|
||||
Description: "Stream synthesized audio chunks to the realtime client as they are produced (requires a TTS backend that implements TTSStream). Falls back to unary synthesis otherwise.",
|
||||
Component: "toggle",
|
||||
Order: 67,
|
||||
},
|
||||
"pipeline.streaming.transcription": {
|
||||
Section: "pipeline",
|
||||
Label: "Stream Transcription",
|
||||
Description: "Stream partial transcription text to the realtime client as the STT backend produces it (requires a transcription backend that implements AudioTranscriptionStream). Falls back to unary transcription otherwise.",
|
||||
Component: "toggle",
|
||||
Order: 68,
|
||||
},
|
||||
"pipeline.streaming.clause_chunking": {
|
||||
Section: "pipeline",
|
||||
Label: "Clause Chunking",
|
||||
Description: "Split the streamed reply into speakable clauses and synthesize each as soon as it completes, instead of buffering the whole message before TTS — lower time-to-first-audio. Script-aware (handles CJK 。!? and Thai/Lao spaces), so it does not whitespace-split. Requires Stream LLM; off buffers the whole message.",
|
||||
Component: "toggle",
|
||||
Order: 69,
|
||||
},
|
||||
|
||||
// --- Functions ---
|
||||
"function.grammar.parallel_calls": {
|
||||
|
||||
@@ -63,6 +63,13 @@ type ModelConfig struct {
|
||||
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
|
||||
ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
|
||||
|
||||
// ReasoningEffort is the default reasoning effort (none|minimal|low|medium|high)
|
||||
// for this model. A per-request reasoning_effort overrides it. It is forwarded
|
||||
// to the backend as the reasoning_effort chat_template_kwarg (see
|
||||
// gRPCPredictOpts), so jinja-templated models that key on it — e.g. gpt-oss
|
||||
// (Harmony) or LFM2.5 — honor it; "none" also toggles enable_thinking off.
|
||||
ReasoningEffort string `yaml:"reasoning_effort,omitempty" json:"reasoning_effort,omitempty"`
|
||||
|
||||
FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
|
||||
// LLM configs (GPT4ALL, Llama.cpp, ...)
|
||||
LLMConfig `yaml:",inline" json:",inline"`
|
||||
@@ -487,6 +494,85 @@ type Pipeline struct {
|
||||
LLM string `yaml:"llm,omitempty" json:"llm,omitempty"`
|
||||
Transcription string `yaml:"transcription,omitempty" json:"transcription,omitempty"`
|
||||
VAD string `yaml:"vad,omitempty" json:"vad,omitempty"`
|
||||
|
||||
// ReasoningEffort sets the reasoning effort (none|minimal|low|medium|high) for
|
||||
// the pipeline's LLM without editing the LLM model config. Overrides the LLM's
|
||||
// own reasoning_effort. Unset leaves the LLM model config in charge.
|
||||
ReasoningEffort string `yaml:"reasoning_effort,omitempty" json:"reasoning_effort,omitempty"`
|
||||
|
||||
// Streaming opts each pipeline stage into incremental delivery (LLM tokens,
|
||||
// TTS audio chunks, transcription text). Unset stages keep the blocking
|
||||
// unary path, so existing configs are unaffected.
|
||||
Streaming PipelineStreaming `yaml:"streaming,omitempty" json:"streaming,omitempty"`
|
||||
|
||||
// DisableThinking suppresses reasoning/thinking for the pipeline LLM (maps
|
||||
// to enable_thinking=false backend metadata) without editing the underlying
|
||||
// LLM model config. Unset leaves the LLM model config in charge.
|
||||
DisableThinking *bool `yaml:"disable_thinking,omitempty" json:"disable_thinking,omitempty"`
|
||||
}
|
||||
|
||||
// ApplyReasoningEffort resolves the effective reasoning effort — a per-request
|
||||
// value (requestEffort) overrides the config's own ReasoningEffort default —
|
||||
// stores it on the config so gRPCPredictOpts forwards it to the backend as the
|
||||
// reasoning_effort chat_template_kwarg, and maps it onto the enable_thinking
|
||||
// toggle the backend also reads:
|
||||
// - "none" always disables thinking.
|
||||
// - any explicit level enables it, UNLESS the config already disabled reasoning
|
||||
// (an operator's explicit disable wins over a request asking to think).
|
||||
//
|
||||
// An empty requestEffort keeps the config's own default. With no effort set
|
||||
// anywhere it is a no-op, leaving the model's reasoning settings untouched.
|
||||
func (c *ModelConfig) ApplyReasoningEffort(requestEffort string) {
|
||||
effort := requestEffort
|
||||
if effort == "" {
|
||||
effort = c.ReasoningEffort
|
||||
}
|
||||
c.ReasoningEffort = effort
|
||||
switch strings.ToLower(effort) {
|
||||
case "none":
|
||||
disable := true
|
||||
c.ReasoningConfig.DisableReasoning = &disable
|
||||
case "minimal", "low", "medium", "high":
|
||||
if c.ReasoningConfig.DisableReasoning == nil || !*c.ReasoningConfig.DisableReasoning {
|
||||
enable := false
|
||||
c.ReasoningConfig.DisableReasoning = &enable
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// @Description PipelineStreaming toggles incremental delivery per realtime stage.
|
||||
type PipelineStreaming struct {
|
||||
LLM *bool `yaml:"llm,omitempty" json:"llm,omitempty"`
|
||||
TTS *bool `yaml:"tts,omitempty" json:"tts,omitempty"`
|
||||
Transcription *bool `yaml:"transcription,omitempty" json:"transcription,omitempty"`
|
||||
// ClauseChunking splits the streamed LLM reply into speakable clauses and
|
||||
// synthesizes each as soon as it completes, instead of buffering the whole
|
||||
// message before TTS. Script-aware (CJK/Thai), so it does not rely on
|
||||
// whitespace sentence boundaries. Requires LLM streaming; unset buffers the
|
||||
// whole message (today's default).
|
||||
ClauseChunking *bool `yaml:"clause_chunking,omitempty" json:"clause_chunking,omitempty"`
|
||||
}
|
||||
|
||||
// StreamLLM reports whether LLM tokens should be streamed for this pipeline.
|
||||
func (p Pipeline) StreamLLM() bool { return p.Streaming.LLM != nil && *p.Streaming.LLM }
|
||||
|
||||
// StreamTTS reports whether TTS audio should be streamed for this pipeline.
|
||||
func (p Pipeline) StreamTTS() bool { return p.Streaming.TTS != nil && *p.Streaming.TTS }
|
||||
|
||||
// StreamTranscription reports whether transcription text should be streamed.
|
||||
func (p Pipeline) StreamTranscription() bool {
|
||||
return p.Streaming.Transcription != nil && *p.Streaming.Transcription
|
||||
}
|
||||
|
||||
// ChunkClauses reports whether the streamed reply should be split into
|
||||
// script-aware clauses and synthesized incrementally rather than buffered whole.
|
||||
func (p Pipeline) ChunkClauses() bool {
|
||||
return p.Streaming.ClauseChunking != nil && *p.Streaming.ClauseChunking
|
||||
}
|
||||
|
||||
// ThinkingDisabled reports whether the pipeline forces the LLM's thinking off.
|
||||
func (p Pipeline) ThinkingDisabled() bool {
|
||||
return p.DisableThinking != nil && *p.DisableThinking
|
||||
}
|
||||
|
||||
// @Description File configuration for model downloads
|
||||
|
||||
@@ -30,11 +30,26 @@ func MTPSpecOptions() []string {
|
||||
return out
|
||||
}
|
||||
|
||||
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a Multi-Token
|
||||
// Prediction head. Detection reads `<arch>.nextn_predict_layers`, which is
|
||||
// what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
|
||||
// isDraftOnlyAssistantArch reports whether an architecture names a standalone
|
||||
// MTP *draft* model rather than a self-speculating trunk. Upstream's Gemma4 MTP
|
||||
// (ggml-org/llama.cpp#23398) registers the head as a separate `gemma4-assistant`
|
||||
// architecture whose GGUF still carries `nextn_predict_layers`, but which cannot
|
||||
// run alone: it requires a paired target context (`ctx_other`). Such archs must
|
||||
// not trigger the embedded-head self-speculation defaults. The `-assistant`
|
||||
// suffix is upstream's naming convention for these draft-only checkpoints.
|
||||
func isDraftOnlyAssistantArch(arch string) bool {
|
||||
return strings.HasSuffix(arch, "-assistant")
|
||||
}
|
||||
|
||||
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a self-speculating
|
||||
// Multi-Token Prediction head. Detection reads `<arch>.nextn_predict_layers`,
|
||||
// which is what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
|
||||
// `conversion/qwen.py` MTP mixin. A positive layer count means the head is
|
||||
// present in the same GGUF as the trunk.
|
||||
//
|
||||
// Draft-only assistant architectures (e.g. Gemma4's `gemma4-assistant`) carry
|
||||
// the same key but are separate draft checkpoints meant to be paired with a
|
||||
// target model, so they are deliberately excluded here.
|
||||
func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
|
||||
if f == nil {
|
||||
return 0, false
|
||||
@@ -43,6 +58,9 @@ func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
|
||||
if arch == "" {
|
||||
return 0, false
|
||||
}
|
||||
if isDraftOnlyAssistantArch(arch) {
|
||||
return 0, false
|
||||
}
|
||||
v, ok := f.Header.MetadataKV.Get(arch + ".nextn_predict_layers")
|
||||
if !ok {
|
||||
return 0, false
|
||||
|
||||
@@ -3,10 +3,33 @@ package config_test
|
||||
import (
|
||||
. "github.com/mudler/LocalAI/core/config"
|
||||
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// ggufWithArch fabricates a minimal in-memory GGUF carrying the given
|
||||
// `general.architecture` and a positive `<arch>.nextn_predict_layers` count,
|
||||
// so HasEmbeddedMTPHead can be exercised without a real model file.
|
||||
func ggufWithArch(arch string, nextn uint32) *gguf.GGUFFile {
|
||||
return &gguf.GGUFFile{
|
||||
Header: gguf.GGUFHeader{
|
||||
MetadataKV: gguf.GGUFMetadataKVs{
|
||||
{
|
||||
Key: "general.architecture",
|
||||
ValueType: gguf.GGUFMetadataValueTypeString,
|
||||
Value: arch,
|
||||
},
|
||||
{
|
||||
Key: arch + ".nextn_predict_layers",
|
||||
ValueType: gguf.GGUFMetadataValueTypeUint32,
|
||||
Value: nextn,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("MTP auto-defaults", func() {
|
||||
Context("MTPSpecOptions", func() {
|
||||
It("returns the upstream-recommended speculative tuple", func() {
|
||||
@@ -82,5 +105,20 @@ var _ = Describe("MTP auto-defaults", func() {
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(n).To(BeZero())
|
||||
})
|
||||
|
||||
It("detects a same-GGUF embedded head (DeepSeek/Qwen style)", func() {
|
||||
n, ok := HasEmbeddedMTPHead(ggufWithArch("qwen3moe", 1))
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(n).To(Equal(uint32(1)))
|
||||
})
|
||||
|
||||
It("ignores a gemma4-assistant draft-only model", func() {
|
||||
// The assistant GGUF carries nextn_predict_layers but is a separate
|
||||
// draft model that requires a paired target (ctx_other); it cannot
|
||||
// self-speculate, so it must not trigger the embedded-head defaults.
|
||||
n, ok := HasEmbeddedMTPHead(ggufWithArch("gemma4-assistant", 48))
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(n).To(BeZero())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
57
core/config/pipeline_streaming_test.go
Normal file
57
core/config/pipeline_streaming_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// The realtime pipeline can stream each stage (LLM tokens, TTS audio,
|
||||
// transcription text) and can disable model "thinking" for the LLM. These are
|
||||
// opt-in per pipeline; everything defaults to off so existing configs keep the
|
||||
// unary behaviour.
|
||||
var _ = Describe("Pipeline streaming config", func() {
|
||||
It("defaults every streaming + thinking helper to false when unset", func() {
|
||||
var p Pipeline
|
||||
Expect(p.StreamLLM()).To(BeFalse())
|
||||
Expect(p.StreamTTS()).To(BeFalse())
|
||||
Expect(p.StreamTranscription()).To(BeFalse())
|
||||
Expect(p.ChunkClauses()).To(BeFalse())
|
||||
Expect(p.ThinkingDisabled()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("parses the nested streaming block and disable_thinking from YAML", func() {
|
||||
var c ModelConfig
|
||||
err := yaml.Unmarshal([]byte(`
|
||||
name: gpt-realtime
|
||||
pipeline:
|
||||
llm: my-llm
|
||||
tts: my-tts
|
||||
transcription: my-stt
|
||||
streaming:
|
||||
llm: true
|
||||
tts: true
|
||||
transcription: true
|
||||
clause_chunking: true
|
||||
disable_thinking: true
|
||||
`), &c)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c.Pipeline.StreamLLM()).To(BeTrue())
|
||||
Expect(c.Pipeline.StreamTTS()).To(BeTrue())
|
||||
Expect(c.Pipeline.StreamTranscription()).To(BeTrue())
|
||||
Expect(c.Pipeline.ChunkClauses()).To(BeTrue())
|
||||
Expect(c.Pipeline.ThinkingDisabled()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("treats an explicit false in the streaming block as disabled", func() {
|
||||
var c ModelConfig
|
||||
err := yaml.Unmarshal([]byte(`
|
||||
name: gpt-realtime
|
||||
pipeline:
|
||||
streaming:
|
||||
tts: false
|
||||
`), &c)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c.Pipeline.StreamTTS()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
52
core/config/reasoning_effort_test.go
Normal file
52
core/config/reasoning_effort_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
// ApplyReasoningEffort resolves the effective reasoning effort (request value
|
||||
// overrides the model config default), stores it on the config so it reaches the
|
||||
// backend, and maps it onto the enable_thinking toggle.
|
||||
var _ = Describe("ModelConfig.ApplyReasoningEffort", func() {
|
||||
It("uses the request value over the config default", func() {
|
||||
c := &config.ModelConfig{ReasoningEffort: "high"}
|
||||
c.ApplyReasoningEffort("none")
|
||||
Expect(c.ReasoningEffort).To(Equal("none"))
|
||||
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
|
||||
})
|
||||
|
||||
It("falls back to the config default when the request omits it", func() {
|
||||
c := &config.ModelConfig{ReasoningEffort: "none"}
|
||||
c.ApplyReasoningEffort("")
|
||||
Expect(c.ReasoningEffort).To(Equal("none"))
|
||||
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
|
||||
})
|
||||
|
||||
It("enables thinking for an explicit effort level", func() {
|
||||
c := &config.ModelConfig{}
|
||||
c.ApplyReasoningEffort("medium")
|
||||
Expect(c.ReasoningEffort).To(Equal("medium"))
|
||||
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*c.ReasoningConfig.DisableReasoning).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does not let a level override an operator's config-level disable", func() {
|
||||
disabled := true
|
||||
c := &config.ModelConfig{}
|
||||
c.ReasoningConfig.DisableReasoning = &disabled
|
||||
c.ApplyReasoningEffort("high")
|
||||
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is a no-op on the toggle when no effort is set anywhere", func() {
|
||||
c := &config.ModelConfig{}
|
||||
c.ApplyReasoningEffort("")
|
||||
Expect(c.ReasoningEffort).To(Equal(""))
|
||||
Expect(c.ReasoningConfig.DisableReasoning).To(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -420,8 +420,9 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
remoteUnloader = d.Router.Unloader()
|
||||
}
|
||||
}
|
||||
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
|
||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
|
||||
natsCfg := distCfg.NatsAuthConfig()
|
||||
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, natsCfg)
|
||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken, natsCfg)
|
||||
|
||||
// Distributed SSE routes (job progress + agent events via NATS)
|
||||
if d := application.Distributed(); d != nil {
|
||||
|
||||
@@ -383,13 +383,13 @@ var _ = Describe("API test", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
|
||||
xlog.Error("server error", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig(apiKey)
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
defaultConfig.BaseURL = testHTTPBase + "/v1"
|
||||
|
||||
client2 = openaigo.NewClient("")
|
||||
client2.BaseURL = defaultConfig.BaseURL
|
||||
@@ -418,7 +418,7 @@ var _ = Describe("API test", func() {
|
||||
|
||||
Context("Auth Tests", func() {
|
||||
It("Should fail if the api key is missing", func() {
|
||||
err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available")
|
||||
err, sc := postInvalidRequest(testHTTPBase + "/models/available")
|
||||
Expect(err).ToNot(BeNil())
|
||||
Expect(sc).To(Equal(401))
|
||||
})
|
||||
@@ -427,7 +427,7 @@ var _ = Describe("API test", func() {
|
||||
Context("URL routing Tests", func() {
|
||||
It("Should support reverse-proxy when unauthenticated", func() {
|
||||
|
||||
err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
|
||||
err, sc, body := getRequest(testHTTPBase+"/myprefix/", http.Header{
|
||||
"X-Forwarded-Proto": {"https"},
|
||||
"X-Forwarded-Host": {"example.org"},
|
||||
"X-Forwarded-Prefix": {"/myprefix/"},
|
||||
@@ -441,7 +441,7 @@ var _ = Describe("API test", func() {
|
||||
|
||||
It("Should support reverse-proxy when authenticated", func() {
|
||||
|
||||
err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
|
||||
err, sc, body := getRequest(testHTTPBase+"/myprefix/", http.Header{
|
||||
"Authorization": {bearerKey},
|
||||
"X-Forwarded-Proto": {"https"},
|
||||
"X-Forwarded-Host": {"example.org"},
|
||||
@@ -459,7 +459,7 @@ var _ = Describe("API test", func() {
|
||||
// requests them through the proxy.
|
||||
It("Should support reverse-proxy when prefix is stripped by the proxy", func() {
|
||||
|
||||
err, sc, body := getRequest("http://127.0.0.1:9090/app", http.Header{
|
||||
err, sc, body := getRequest(testHTTPBase+"/app", http.Header{
|
||||
"X-Forwarded-Proto": {"https"},
|
||||
"X-Forwarded-Host": {"example.org"},
|
||||
"X-Forwarded-Prefix": {"/myprefix"},
|
||||
@@ -477,7 +477,7 @@ var _ = Describe("API test", func() {
|
||||
// from a foreign origin. BasePathPrefix must reject these via
|
||||
// SafeForwardedPrefix and fall back to "/".
|
||||
It("Should ignore an unsafe X-Forwarded-Prefix and not poison asset URLs", func() {
|
||||
err, sc, body := getRequest("http://127.0.0.1:9090/app", http.Header{
|
||||
err, sc, body := getRequest(testHTTPBase+"/app", http.Header{
|
||||
"X-Forwarded-Proto": {"https"},
|
||||
"X-Forwarded-Host": {"example.org"},
|
||||
"X-Forwarded-Prefix": {"//evil.com"},
|
||||
@@ -492,13 +492,13 @@ var _ = Describe("API test", func() {
|
||||
Context("Applying models", func() {
|
||||
|
||||
It("applies models from a gallery", func() {
|
||||
models, err := getModels("http://127.0.0.1:9090/models/available")
|
||||
models, err := getModels(testHTTPBase + "/models/available")
|
||||
Expect(err).To(BeNil())
|
||||
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
|
||||
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
|
||||
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
|
||||
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{
|
||||
ID: "test@bert2",
|
||||
})
|
||||
|
||||
@@ -507,7 +507,7 @@ var _ = Describe("API test", func() {
|
||||
uuid := response["uuid"].(string)
|
||||
resp := map[string]any{}
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
resp = response
|
||||
return response["processed"].(bool)
|
||||
@@ -526,7 +526,7 @@ var _ = Describe("API test", func() {
|
||||
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
|
||||
Expect(content["foo"]).To(Equal("bar"))
|
||||
|
||||
models, err = getModels("http://127.0.0.1:9090/models/available")
|
||||
models, err = getModels(testHTTPBase + "/models/available")
|
||||
Expect(err).To(BeNil())
|
||||
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
|
||||
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
|
||||
@@ -541,7 +541,7 @@ var _ = Describe("API test", func() {
|
||||
})
|
||||
It("overrides models", func() {
|
||||
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{
|
||||
URL: bertEmbeddingsURL,
|
||||
Name: "bert",
|
||||
Overrides: map[string]any{
|
||||
@@ -554,7 +554,7 @@ var _ = Describe("API test", func() {
|
||||
uuid := response["uuid"].(string)
|
||||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
|
||||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
@@ -567,7 +567,7 @@ var _ = Describe("API test", func() {
|
||||
Expect(content["backend"]).To(Equal("llama"))
|
||||
})
|
||||
It("apply models without overrides", func() {
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{
|
||||
URL: bertEmbeddingsURL,
|
||||
Name: "bert",
|
||||
Overrides: map[string]any{},
|
||||
@@ -578,7 +578,7 @@ var _ = Describe("API test", func() {
|
||||
uuid := response["uuid"].(string)
|
||||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
|
||||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
@@ -622,14 +622,14 @@ parameters:
|
||||
}
|
||||
|
||||
var response schema.GalleryResponse
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
|
||||
err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(response.ID).ToNot(BeEmpty())
|
||||
|
||||
uuid := response.ID
|
||||
resp := map[string]any{}
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
|
||||
resp = response
|
||||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
@@ -657,7 +657,7 @@ parameters:
|
||||
}
|
||||
|
||||
var response schema.GalleryResponse
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
|
||||
err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response)
|
||||
// The endpoint should return an error immediately
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("failed to discover model config"))
|
||||
@@ -693,14 +693,14 @@ parameters:
|
||||
}
|
||||
|
||||
var response schema.GalleryResponse
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
|
||||
err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(response.ID).ToNot(BeEmpty())
|
||||
|
||||
uuid := response.ID
|
||||
resp := map[string]any{}
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
|
||||
resp = response
|
||||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
@@ -751,13 +751,13 @@ parameters:
|
||||
app, err = API(localAIApp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
|
||||
xlog.Error("server error", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
defaultConfig.BaseURL = testHTTPBase + "/v1"
|
||||
|
||||
client2 = openaigo.NewClient("")
|
||||
client2.BaseURL = defaultConfig.BaseURL
|
||||
@@ -801,7 +801,7 @@ parameters:
|
||||
// Mock-backend is registered via SetExternalBackend so it appears
|
||||
// alongside any built-in entries; verifying that string proves the
|
||||
// endpoint is wired up regardless of which real backends exist.
|
||||
resp, err := http.Get("http://127.0.0.1:9090/system")
|
||||
resp, err := http.Get(testHTTPBase + "/system")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
dat, err := io.ReadAll(resp.Body)
|
||||
@@ -824,14 +824,14 @@ parameters:
|
||||
}
|
||||
|
||||
var createResp map[string]any
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
||||
err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(createResp["id"]).ToNot(BeEmpty())
|
||||
taskID := createResp["id"].(string)
|
||||
|
||||
// Get the task
|
||||
var task schema.Task
|
||||
resp, err := http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID)
|
||||
resp, err := http.Get(testHTTPBase + "/api/agent/tasks/" + taskID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -839,7 +839,7 @@ parameters:
|
||||
Expect(task.Name).To(Equal("Test Task"))
|
||||
|
||||
// List tasks
|
||||
resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks")
|
||||
resp, err = http.Get(testHTTPBase + "/api/agent/tasks")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
var tasks []schema.Task
|
||||
@@ -849,18 +849,18 @@ parameters:
|
||||
|
||||
// Update task
|
||||
taskBody["name"] = "Updated Task"
|
||||
err = putRequestJSON("http://127.0.0.1:9090/api/agent/tasks/"+taskID, &taskBody)
|
||||
err = putRequestJSON(testHTTPBase+"/api/agent/tasks/"+taskID, &taskBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Verify update
|
||||
resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID)
|
||||
resp, err = http.Get(testHTTPBase + "/api/agent/tasks/" + taskID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
body, _ = io.ReadAll(resp.Body)
|
||||
json.Unmarshal(body, &task)
|
||||
Expect(task.Name).To(Equal("Updated Task"))
|
||||
|
||||
// Delete task
|
||||
req, _ := http.NewRequest("DELETE", "http://127.0.0.1:9090/api/agent/tasks/"+taskID, nil)
|
||||
req, _ := http.NewRequest("DELETE", testHTTPBase+"/api/agent/tasks/"+taskID, nil)
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -877,7 +877,7 @@ parameters:
|
||||
}
|
||||
|
||||
var createResp map[string]any
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
||||
err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
taskID := createResp["id"].(string)
|
||||
|
||||
@@ -888,14 +888,14 @@ parameters:
|
||||
}
|
||||
|
||||
var jobResp schema.JobExecutionResponse
|
||||
err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/jobs/execute", &jobBody, &jobResp)
|
||||
err = postRequestResponseJSON(testHTTPBase+"/api/agent/jobs/execute", &jobBody, &jobResp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(jobResp.JobID).ToNot(BeEmpty())
|
||||
jobID := jobResp.JobID
|
||||
|
||||
// Get job status
|
||||
var job schema.Job
|
||||
resp, err := http.Get("http://127.0.0.1:9090/api/agent/jobs/" + jobID)
|
||||
resp, err := http.Get(testHTTPBase + "/api/agent/jobs/" + jobID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -904,7 +904,7 @@ parameters:
|
||||
Expect(job.TaskID).To(Equal(taskID))
|
||||
|
||||
// List jobs
|
||||
resp, err = http.Get("http://127.0.0.1:9090/api/agent/jobs")
|
||||
resp, err = http.Get(testHTTPBase + "/api/agent/jobs")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
var jobs []schema.Job
|
||||
@@ -914,7 +914,7 @@ parameters:
|
||||
|
||||
// Cancel job (if still pending/running)
|
||||
if job.Status == schema.JobStatusPending || job.Status == schema.JobStatusRunning {
|
||||
req, _ := http.NewRequest("POST", "http://127.0.0.1:9090/api/agent/jobs/"+jobID+"/cancel", nil)
|
||||
req, _ := http.NewRequest("POST", testHTTPBase+"/api/agent/jobs/"+jobID+"/cancel", nil)
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -932,13 +932,13 @@ parameters:
|
||||
}
|
||||
|
||||
var createResp map[string]any
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
||||
err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Execute by name
|
||||
paramsBody := map[string]string{"param1": "value1"}
|
||||
var jobResp schema.JobExecutionResponse
|
||||
err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks/Named Task/execute", ¶msBody, &jobResp)
|
||||
err = postRequestResponseJSON(testHTTPBase+"/api/agent/tasks/Named Task/execute", ¶msBody, &jobResp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(jobResp.JobID).ToNot(BeEmpty())
|
||||
})
|
||||
@@ -998,13 +998,13 @@ parameters:
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
|
||||
xlog.Error("server error", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
defaultConfig.BaseURL = testHTTPBase + "/v1"
|
||||
client2 = openaigo.NewClient("")
|
||||
client2.BaseURL = defaultConfig.BaseURL
|
||||
// Wait for API to be ready
|
||||
|
||||
@@ -37,7 +37,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
|
||||
xlog.Debug("elevenlabs TTS request received", "modelName", input.ModelID)
|
||||
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, "", nil, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -25,6 +25,10 @@ var knownPrefOnlyBackends = []schema.KnownBackend{
|
||||
// Text LLM
|
||||
// ds4: antirez/ds4 - single-model DeepSeek V4 Flash engine; auto-detected via DS4Importer
|
||||
{Name: "ds4", Modality: "text", AutoDetect: false, Description: "antirez/ds4 DeepSeek V4 Flash engine (auto-detected; pref-only fallback)"},
|
||||
// dllm consumes GGUF weights like llama-cpp does, but only for the
|
||||
// DiffusionGemma architecture - auto-detecting on .gguf would shadow
|
||||
// llama-cpp, so it stays preference-only.
|
||||
{Name: "dllm", Modality: "text", AutoDetect: false, Description: "dllm.cpp DiffusionGemma block-diffusion engine (preference-only)"},
|
||||
{Name: "sglang", Modality: "text", AutoDetect: false, Description: "SGLang runtime (preference-only)"},
|
||||
{Name: "tinygrad", Modality: "text", AutoDetect: false, Description: "tinygrad runtime (preference-only)"},
|
||||
{Name: "trl", Modality: "text", AutoDetect: false, Description: "Transformers Reinforcement Learning (preference-only)"},
|
||||
|
||||
@@ -135,6 +135,7 @@ var _ = Describe("Backend Endpoints", func() {
|
||||
Expect(entry.Modality).To(Equal(modality))
|
||||
}
|
||||
|
||||
expectPrefOnly("dllm", "text")
|
||||
expectPrefOnly("sglang", "text")
|
||||
expectPrefOnly("tinygrad", "text")
|
||||
expectPrefOnly("trl", "text")
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
)
|
||||
|
||||
// nodeError builds a schema.ErrorResponse for node endpoints.
|
||||
@@ -89,7 +90,7 @@ type RegisterNodeRequest struct {
|
||||
// RegisterNodeEndpoint registers a new backend node.
|
||||
// expectedToken is the registration token configured on the frontend (may be empty to disable auth).
|
||||
// autoApprove controls whether new nodes go directly to "healthy" or require admin approval.
|
||||
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
|
||||
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req RegisterNodeRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
@@ -217,13 +218,15 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au
|
||||
}
|
||||
}
|
||||
|
||||
attachNatsJWT(response, node, natsCfg)
|
||||
|
||||
return c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
}
|
||||
|
||||
// ApproveNodeEndpoint approves a pending node, setting its status to healthy.
|
||||
// For agent workers, it also provisions an API key so they can call the inference API.
|
||||
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
|
||||
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
id := c.Param("id")
|
||||
@@ -253,10 +256,26 @@ func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecr
|
||||
}
|
||||
}
|
||||
|
||||
attachNatsJWT(response, node, natsCfg)
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
}
|
||||
|
||||
// attachNatsJWT adds a per-node NATS user JWT to a register/approve response when minting is enabled.
|
||||
func attachNatsJWT(response map[string]any, node *nodes.BackendNode, natsCfg natsauth.Config) {
|
||||
if !natsCfg.CanMintWorkers() || node == nil || node.Status == nodes.StatusPending {
|
||||
return
|
||||
}
|
||||
jwt, seed, err := natsCfg.MintWorkerJWT(node.ID, node.NodeType)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to mint NATS JWT for node", "node", node.Name, "id", node.ID, "error", err)
|
||||
return
|
||||
}
|
||||
response["nats_jwt"] = jwt
|
||||
response["nats_user_seed"] = seed
|
||||
}
|
||||
|
||||
// provisionAgentWorkerKey creates a dedicated user and API key for an agent worker node.
|
||||
// Returns the plaintext API key on success.
|
||||
func provisionAgentWorkerKey(ctx context.Context, authDB *gorm.DB, registry *nodes.NodeRegistry, node *nodes.BackendNode, hmacSecret string) (string, error) {
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/nats-io/nkeys"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -63,7 +65,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
@@ -74,6 +76,29 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
Expect(resp["status"]).To(Equal(nodes.StatusHealthy))
|
||||
})
|
||||
|
||||
It("returns nats_jwt when account seed is configured", func() {
|
||||
akp, err := nkeys.CreateAccount()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
seed, err := akp.Seed()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
e := echo.New()
|
||||
body := `{"name":"worker-nats","address":"10.0.0.2:50051"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
natsCfg := natsauth.Config{AccountSeed: string(seed)}
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsCfg)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp["nats_jwt"]).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns 400 when name is missing", func() {
|
||||
e := echo.New()
|
||||
body := `{"address":"10.0.0.1:50051"}`
|
||||
@@ -82,7 +107,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -102,7 +127,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -121,7 +146,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -140,7 +165,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -159,7 +184,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
|
||||
})
|
||||
@@ -172,7 +197,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", false, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", false, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
@@ -195,7 +220,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body1))
|
||||
req1.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(e.NewContext(req1, rec1))).To(Succeed())
|
||||
Expect(rec1.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Stream audio chunks as they're generated
|
||||
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
|
||||
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, input.Instructions, input.Params, ml, appConfig, *cfg, func(audioChunk []byte) error {
|
||||
_, writeErr := c.Response().Write(audioChunk)
|
||||
if writeErr != nil {
|
||||
return writeErr
|
||||
@@ -75,7 +75,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
}
|
||||
|
||||
// Non-streaming TTS (existing behavior)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, input.Instructions, input.Params, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -103,7 +103,12 @@ func applyAutoparserOverride(
|
||||
// blocks like "<think></think>" that some models emit when reasoning
|
||||
// is disabled.
|
||||
if deltaReasoning == "" && deltaContent != "" {
|
||||
deltaReasoning, deltaContent = reason.ExtractReasoningWithConfig(deltaContent, thinkingStartToken, reasoningConfig)
|
||||
// Complete-response extraction: only honor a prefilled <think> start
|
||||
// token when deltaContent actually closes the reasoning block. Without
|
||||
// it the model answered directly and the whole answer must stay in
|
||||
// content rather than be swallowed as unclosed reasoning. See
|
||||
// reason.ExtractReasoningComplete.
|
||||
deltaReasoning, deltaContent = reason.ExtractReasoningComplete(deltaContent, thinkingStartToken, reasoningConfig)
|
||||
}
|
||||
xlog.Debug("[ChatDeltas] non-SSE no-tools: overriding result with C++ autoparser deltas",
|
||||
"content_len", len(deltaContent), "reasoning_len", len(deltaReasoning))
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user