Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
8cf0448304 chore(deps): bump actions/upload-artifact from 4 to 7
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4 to 7.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](https://github.com/actions/upload-artifact/compare/v4...v7)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-version: '7'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-06-03 03:23:39 +00:00
208 changed files with 925 additions and 14850 deletions

View File

@@ -1,138 +0,0 @@
# Working on the dllm Backend
`mudler/dllm.cpp` is a standalone C++/ggml engine for DiffusionGemma
block-diffusion models. LocalAI wraps it with a **pure-Go** backend at
`backend/go/dllm/` that dlopens `libdllm.so` via purego (ebitengine/purego) -
NOT cgo, and NOT a C++ grpc-server fork. The Go side owns chat templating
(gemma4 renderer) and output parsing (gemma4 streaming parser) and implements
the rich gRPC interface (`PredictRich`/`PredictStreamRich`, ChatDelta replies).
> NOTE: github.com/mudler/dllm.cpp is still **private** (publishing is
> planned). Until then the Makefile's anonymous clone fails; use the local-dev
> symlink shortcut documented at the top of `backend/go/dllm/Makefile`
> (symlink an out-of-tree `build/libdllm.so` into the backend dir and skip the
> clone), or a git credential helper with repo access.
## Pin
`backend/go/dllm/Makefile` pins `DLLM_VERSION?=<sha>` at the top
(whisper / parakeet-cpp / ds4 convention). The bump-deps bot
(`.github/workflows/bump_deps.yaml`) tracks `mudler/dllm.cpp` `main` and
rewrites that variable. After a manual bump: `make -C backend/go/dllm purge &&
make -C backend/go/dllm` (the clone is keyed on the directory existing, not
the sha).
## C-ABI and the serialization contract
The binding covers the 9-symbol flat C-ABI from dllm.cpp's
`include/dllm_capi.h` (ABI v1; `main.go` hard-fails on a version mismatch):
`abi_version, load, free, last_error, free_string, tokenize_json, generate,
generate_stream, cancel`. Contract points the Go wiring encodes (`capi.go`
header comment has the full list):
- **One ctx = one concurrent generate/tokenize.** A per-model worker
goroutine (`Dllm.jobs` in `dllm.go`) owns ALL C calls, making the
serialization structural instead of lock discipline.
- **`dllm_capi_cancel` is the ONE exception**: it only flips an atomic and may
be called from any goroutine mid-generate, so `Dllm.Cancel` bypasses the
worker queue. The flag resets at the start of each generate, so a watchdog
racing a new generate must re-issue cancel.
- **`last_error` is a borrowed pointer** and must only be read AFTER the
failing call returned (never while a generate is in flight on the same ctx).
- **Free vs in-flight requests**: requests hold `genMu.RLock` for their full
duration; `Free` takes the write lock, so it only runs when nothing is in
flight, then drains and closes the worker. Post-Free requests get a clean
"model not loaded" error.
- `tokenize_json`/`generate` return malloc'd `char*` (bound as `uintptr`,
copied, then `dllm_capi_free_string`d); opts/params JSON must be a FLAT
object of scalars (`buildOptsJSON` rejects anything else).
## Wire shape
| RPC | Implementation |
|---|---|
| LoadModel | `dllm_capi_load` (params: `n_gpu_layers`, `n_threads`, `ctx_len`); `Options[]` parsed into per-request gen opts (`eb_*`, `blocks`, `kv_cache`) by `parseModelGenOpts` |
| PredictRich | render (if templated) → `dllm_capi_generate` → parse → ONE Reply with aggregated ChatDeltas + legacy `Message` bytes |
| PredictStreamRich | `dllm_capi_generate_stream`; per committed diffusion block → UTF-8 holdback → parser.Feed → one Reply per non-empty delta batch (channel closed by the CALLER, per `pkg/grpc/interface.go`) |
| Predict / PredictStream | Legacy paths, delegate to the rich pair (legacy stream INVERTS channel ownership: the impl closes) |
| TokenizeString | `dllm_capi_tokenize_json` (C side prepends BOS per `vocab.add_bos`) |
| Cancel | `dllm_capi_cancel`, exposed as the `grpc.Cancellable` capability (`pkg/grpc/interface.go`): the gRPC server arms it via `context.AfterFunc` on the Predict/PredictStream context, so client disconnects/timeouts abort the in-flight generate - llama.cpp `IsCancelled()` parity for Go backends |
`n_threads` and `ctx_len` are accepted-but-ignored by the engine at the
current pin (the context bound comes from GGUF `n_ctx_train`); they are sent
for forward compatibility.
## Renderer / parser (the templated chat path)
With `use_tokenizer_template` + raw Messages, the backend owns templating and
parsing (the ds4 precedent, but in Go):
- `gemma4_renderer.go` - `RenderGemma4(msgs, toolsJSON, enableThinking,
addGenerationPrompt)`. The file embeds the FULL `tokenizer.chat_template`
jinja (17466 bytes, md5 `8c34cf93c7a7815b3fdb300a009c4c17`) extracted
verbatim from `diffusiongemma-26B-A4B-it-BF16.gguf` via gguf-py - e.g.
`python scripts/dump_gguf.py model.gguf | grep -A400 chat_template` in the
dllm.cpp checkout - as a numbered comment block; every Go rule cites its
"tpl L<n>" line. Re-verify the md5 before blaming the renderer for a
mismatch with a new GGUF. **BOS exception**: the template emits
`{{- bos_token -}}` but the renderer deliberately does NOT - dllm.cpp's
`run_generate` tokenizes with `prepend_bos = vocab.add_bos` (true for
gemma4), so a literal `<bos>` would double it.
- `gemma4_parser.go` - streaming state machine turning raw model text
(fragments can split anywhere, including mid-marker) into ChatDeltas:
thought channels → `reasoning_content`, `<|tool_call>call:name{...}` →
ToolCallDelta, `<turn|>` → done. Marker grammar cross-checked against vLLM
PR #45163's gemma4 tool/reasoning parsers. Malformed payloads are re-emitted
raw as content, never dropped.
- Thinking is **opt-in** for this family (`Metadata["enable_thinking"]`,
default OFF - the inverse of ds4): the template gates every thinking branch
on `enable_thinking`, and the no-thinking render pre-closes an empty thought
channel, so the parser always starts in content state.
- **UTF-8 boundary holdback** (`splitValidUTF8` in `dllm.go`): per-block
detokenization can split a multi-byte character across block boundaries, and
grpc-go refuses to marshal invalid UTF-8 in proto3 strings. An incomplete
trailing sequence (at most 3 bytes) is carried into the next block; genuinely
undecodable bytes become U+FFFD.
Without `use_tokenizer_template`, the prompt passes through verbatim and the
output is NOT gemma4-parsed (plain content, like any non-autoparsing backend).
## Tests
| Layer | Gate | What |
|---|---|---|
| `backend/go/dllm/*_test.go` (renderer/parser/wiring) | none - run in plain `go test ./backend/go/dllm/...` | Ginkgo specs over a fake `generator` seam; canonical renderer fixtures from transformers' `test_modeling_diffusion_gemma.py`, parser tables from the vLLM gemma4 parsers |
| `backend/go/dllm/dllm_test.go` C-ABI smoke | `DLLM_TEST_LIBRARY` + `DLLM_TEST_TINY_MODEL` (dllm.cpp's `tests/fixtures/tiny_with_vocab.gguf`); Skips when unset | Drives the real `libdllm.so`: ABI check, load, tokenize `[2,18]`, deterministic generate, cancel (incl. mid-stream `Dllm.Cancel` aborting a deliberately slow `eb_max_steps:256` run in ~10ms) |
| `tests/e2e-backends/dllm_test.go` | `BACKEND_TEST_DLLM=1` + `BACKEND_BINARY` (packaged run.sh) + `BACKEND_TEST_MODEL_FILE` (tiny fixture) | Templated chat round trip (Messages + UseTokenizerTemplate) over the real gRPC binary, non-streaming + streaming; plus client-context cancellation mid-stream (proves the `Cancellable` server plumbing end to end) |
| Real-model e2e | `BACKEND_TEST_DLLM_REAL_MODEL_FILE` (26B BF16, ~50 GB) + `BACKEND_TEST_DLLM_REAL_GPU_LAYERS` | CUDA-13-class hardware only |
Tool-call e2e is deliberately absent from the tiny-model spec: the fixture has
random weights and cannot be coaxed into emitting tool markup; the unit tables
carry that coverage.
## Build matrix
`cpu-dllm` (amd64 + arm64), `cuda13-dllm` (amd64), and
`cuda13-nvidia-l4t-arm64-dllm` (arm64 CUDA: Jetson / DGX Spark GB10), via
`.github/backend-matrix.yml`. No darwin/Metal. CUDA builds forward
`-DDLLM_CUDA=ON` (dllm.cpp gates ggml's CUDA behind its own flag - a bare
`-DGGML_CUDA=ON` is overridden by the cache FORCE). `libdllm.so` is
self-contained (ggml statically absorbed, PIC), so `package.sh` only ships
the binary, `run.sh` and that one .so (the parakeet-cpp-style stub layout;
no ldd walk yet).
## Known limitations
- **Cancel granularity**: the C-ABI cancel flag is per-ctx and resets on
every generate entry, so a Cancel racing a NEW generate can be lost, and
with requests queued on the worker it aborts whichever generate is
currently running (acceptable: the server de-registers the hook on normal
completion, one process serves one model).
- **Throughput**: ~0.15 tok/s on the 26B at default settings (GB10) - every
denoise step recomputes the full prompt+canvas. The upstream prefix-KV
cache (dllm.cpp P3) is the fix; `kv_cache:on` errors until it lands
(`auto`/`off` are accepted no-ops).
- **Repo privacy**: see the note at the top - CI clone of dllm.cpp needs the
repo published (or credentials) before the backend images can build.
- Engine spec/validation references: dllm.cpp `docs/validation.md` and
LocalAI `docs/superpowers/specs/2026-06-10-dllm-cpp-design.md`.

View File

@@ -1608,19 +1608,6 @@ include:
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-13-dllm'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "dllm"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
@@ -1660,19 +1647,6 @@ include:
backend: "parakeet-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
platforms: 'linux/arm64'
skip-drivers: 'false'
tag-latest: 'auto'
tag-suffix: '-nvidia-l4t-cuda-13-arm64-dllm'
base-image: "ubuntu:24.04"
ubuntu-version: '2404'
runs-on: 'ubuntu-24.04-arm'
backend: "dllm"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
@@ -1792,6 +1766,20 @@ include:
dockerfile: "./backend/Dockerfile.llama-cpp"
context: "./"
ubuntu-version: '2404'
- build-type: 'hipblas'
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-rocm-hipblas-turboquant'
builder-base-image: 'quay.io/go-skynet/ci-cache:base-grpc-rocm-amd64'
runs-on: 'ubuntu-latest'
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
skip-drivers: 'false'
backend: "turboquant"
dockerfile: "./backend/Dockerfile.turboquant"
context: "./"
ubuntu-version: '2404'
- build-type: 'hipblas'
cuda-major-version: ""
cuda-minor-version: ""
@@ -3171,35 +3159,6 @@ include:
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
# dllm
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
platform-tag: 'amd64'
tag-latest: 'auto'
tag-suffix: '-cpu-dllm'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "dllm"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/arm64'
platform-tag: 'arm64'
tag-latest: 'auto'
tag-suffix: '-cpu-dllm'
runs-on: 'ubuntu-24.04-arm'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "dllm"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'sycl_f32'
cuda-major-version: ""
cuda-minor-version: ""

View File

@@ -3,7 +3,6 @@ package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
@@ -114,17 +113,6 @@ func main() {
fmt.Println("Searching for trending models on HuggingFace...")
rawModels, err := client.GetTrending(searchTerm, limit)
if err != nil {
if errors.Is(err, hfapi.ErrRateLimited) {
fmt.Printf("HuggingFace API is rate limited after retries, skipping this run: %v\n", err)
writeSummary(AddedModelSummary{
SearchTerm: searchTerm,
TotalFound: 0,
ModelsAdded: 0,
Quantization: quantization,
ProcessingTime: time.Since(startTime).String(),
})
return
}
fmt.Fprintf(os.Stderr, "Error fetching models: %v\n", err)
os.Exit(1)
}
@@ -289,3 +277,4 @@ func truncateString(s string, maxLen int) string {
}
return s[:maxLen] + "..."
}

View File

@@ -38,10 +38,6 @@ jobs:
variable: "PARAKEET_VERSION"
branch: "master"
file: "backend/go/parakeet-cpp/Makefile"
- repository: "mudler/dllm.cpp"
variable: "DLLM_VERSION"
branch: "main"
file: "backend/go/dllm/Makefile"
- repository: "leejet/stable-diffusion.cpp"
variable: "STABLEDIFFUSION_GGML_VERSION"
branch: "master"

View File

@@ -18,7 +18,7 @@ jobs:
if: ${{ github.actor != 'dependabot[bot]' }}
- name: Run Gosec Security Scanner
if: ${{ github.actor != 'dependabot[bot]' }}
uses: securego/gosec@v2.27.1
uses: securego/gosec@v2.22.9
with:
# we let the report trigger content trigger a failure using the GitHub Security features.
args: '-no-fail -fmt sarif -out results.sarif ./...'

View File

@@ -62,7 +62,7 @@ jobs:
PATH="$PATH:/root/go/bin" make --jobs 5 --output-sync=target test-coverage-check
- name: Upload coverage report
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v7
with:
name: coverage-linux
path: |

View File

@@ -26,7 +26,6 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
| [.agents/vllm-backend.md](.agents/vllm-backend.md) | Working on the vLLM / vLLM-omni backends — native parsers, ChatDelta, CPU build, libnuma packaging, backend hooks |
| [.agents/sglang-backend.md](.agents/sglang-backend.md) | Working on the SGLang backend — `engine_args` validation against ServerArgs, speculative-decoding (EAGLE/EAGLE3/DFLASH/MTP) recipes, parser handling |
| [.agents/ds4-backend.md](.agents/ds4-backend.md) | Working on the ds4 backend - DSML state machine, thinking modes, KV cache, Metal+CUDA matrix |
| [.agents/dllm-backend.md](.agents/dllm-backend.md) | Working on the dllm backend (DiffusionGemma block-diffusion) - purego C-ABI binding, per-ctx serialization contract, gemma4 renderer/parser, gated test layers |
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |

View File

@@ -1,5 +1,5 @@
# Disable parallel execution for backend builds
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/dllm backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
GOCMD=go
GOTEST=$(GOCMD) test
@@ -180,7 +180,7 @@ osx-signed: build
## Run
run: ## run local-ai
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./cmd/local-ai
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
prepare-test: protogen-go build-mock-backend
@@ -309,20 +309,13 @@ run-e2e-aio: protogen-go
@echo 'Running e2e AIO tests'
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
# Distributed architecture e2e (PostgreSQL + NATS via testcontainers).
# Includes NatsJWT specs (JWT-enabled NATS). Requires Docker.
# VLLMMultinode is excluded here; use test-e2e-vllm-multinode for that.
test-e2e-distributed: protogen-go
@echo 'Running distributed e2e tests (label Distributed, incl. NatsJWT)'
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter='Distributed && !VLLMMultinode' --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e/distributed
# vLLM multi-node DP smoke (CPU). Builds local-ai:tests and the
# cpu-vllm backend from the current working tree, then drives a
# head + headless follower via testcontainers-go and asserts a chat
# completion. BuildKit caches both images, so re-runs only rebuild
# what changed. The test lives under tests/e2e/distributed and is
# selected by the VLLMMultinode label so it doesn't run alongside
# test-e2e-distributed.
# the other distributed-suite tests by default.
test-e2e-vllm-multinode: docker-build-e2e extract-backend-vllm protogen-go
@echo 'Running e2e vLLM multi-node DP test'
LOCALAI_IMAGE=local-ai \
@@ -1171,9 +1164,6 @@ BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|tr
BACKEND_WHISPER = whisper|golang|.|false|true
BACKEND_CRISPASR = crispasr|golang|.|false|true
BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true
# dllm is mudler/dllm.cpp, the DiffusionGemma block-diffusion engine,
# wrapped by the purego backend at backend/go/dllm.
BACKEND_DLLM = dllm|golang|.|false|true
BACKEND_VOXTRAL = voxtral|golang|.|false|true
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true
@@ -1263,7 +1253,6 @@ $(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
$(eval $(call generate-docker-build-target,$(BACKEND_CRISPASR)))
$(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP)))
$(eval $(call generate-docker-build-target,$(BACKEND_DLLM)))
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))

View File

@@ -149,16 +149,6 @@ local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
local-ai run oci://localai/phi-2:latest
```
To test a running LocalAI server from the terminal, open an interactive chat session from another shell. Inside the prompt, `/models` lists installed models and `/model <name>` switches between them.
```bash
# Terminal 1
local-ai run llama-3.2-1b-instruct:q4_k_m
# Terminal 2
local-ai chat --model llama-3.2-1b-instruct:q4_k_m
```
> **Automatic Backend Detection**: LocalAI automatically detects your GPU capabilities and downloads the appropriate backend. For advanced options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/).
For more details, see the [Getting Started guide](https://localai.io/basics/getting_started/).

View File

@@ -537,15 +537,6 @@ message TTSRequest {
string dst = 3;
string voice = 4;
optional string language = 5;
// instructions is a free-form, per-request style/voice description (maps to
// the OpenAI `instructions` field). Backends that support expressive synthesis
// (e.g. Qwen3-TTS CustomVoice/VoiceDesign) prefer this over the static YAML
// option when set; backends that don't simply ignore it.
optional string instructions = 6;
// params carries optional, backend-specific per-request generation parameters
// (e.g. Chatterbox exaggeration/cfg_weight/temperature). Values are strings and
// coerced by the backend; unset leaves the backend's configured defaults.
map<string, string> params = 7;
}
message VADRequest {

View File

@@ -60,12 +60,10 @@ elseif(DS4_GPU STREQUAL "cpu")
set(DS4_OBJS "${DS4_DIR}/ds4_cpu.o")
endif()
# ds4.c now references ds4_distributed.c (distributed inference) and ds4_ssd.c
# (SSD expert-cache), each split into its own translation unit upstream. Both
# are GPU-agnostic objects shared by every GPU mode, so link them in regardless
# of DS4_GPU.
# ds4.c now references ds4_distributed.c (distributed inference was split into
# its own translation unit upstream). It is a single GPU-agnostic object shared
# by every GPU mode, so link it in regardless of DS4_GPU.
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_distributed.o")
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_ssd.o")
add_executable(${TARGET}
grpc-server.cpp

View File

@@ -1,10 +1,10 @@
# ds4 backend Makefile.
#
# Upstream pin lives below as DS4_VERSION?=8384adf0f9fa0f3bb342dd925372de778b95b263
# Upstream pin lives below as DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
# (.github/bump_deps.sh) can find and update it - matches the
# llama-cpp / ik-llama-cpp / turboquant convention.
DS4_VERSION?=8384adf0f9fa0f3bb342dd925372de778b95b263
DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
DS4_REPO?=https://github.com/antirez/ds4
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
@@ -18,20 +18,19 @@ UNAME_S := $(shell uname -s)
CMAKE_ARGS ?= -DCMAKE_BUILD_TYPE=Release
# ds4_distributed.o and ds4_ssd.o are GPU-agnostic translation units that
# ds4.c/ds4_cpu.o now reference (upstream split distributed inference and the
# SSD expert-cache into their own .c files). Both objects are shared by every
# GPU mode, so they are appended unconditionally below.
# ds4_distributed.o is a GPU-agnostic translation unit that ds4.c/ds4_cpu.o now
# reference (upstream split distributed inference into its own .c). The same
# object is shared by every GPU mode, so it is appended unconditionally below.
ifeq ($(BUILD_TYPE),cublas)
CMAKE_ARGS += -DDS4_GPU=cuda
DS4_OBJ_TARGET := ds4.o ds4_cuda.o ds4_distributed.o ds4_ssd.o
DS4_OBJ_TARGET := ds4.o ds4_cuda.o ds4_distributed.o
else ifeq ($(UNAME_S),Darwin)
CMAKE_ARGS += -DDS4_GPU=metal
DS4_OBJ_TARGET := ds4.o ds4_metal.o ds4_distributed.o ds4_ssd.o
DS4_OBJ_TARGET := ds4.o ds4_metal.o ds4_distributed.o
else
# CPU reference path (Linux only - macOS CPU path is broken by VM bug per ds4 README).
CMAKE_ARGS += -DDS4_GPU=cpu
DS4_OBJ_TARGET := ds4_cpu.o ds4_distributed.o ds4_ssd.o
DS4_OBJ_TARGET := ds4_cpu.o ds4_distributed.o
endif
ifneq ($(NATIVE),true)
@@ -56,11 +55,11 @@ ds4:
# the right per-platform compile flags (Objective-C/Metal on Darwin, nvcc on Linux+CUDA).
ds4/ds4.o: ds4
ifeq ($(BUILD_TYPE),cublas)
+$(MAKE) -C ds4 ds4.o ds4_cuda.o ds4_distributed.o ds4_ssd.o
+$(MAKE) -C ds4 ds4.o ds4_cuda.o ds4_distributed.o
else ifeq ($(UNAME_S),Darwin)
+$(MAKE) -C ds4 ds4.o ds4_metal.o ds4_distributed.o ds4_ssd.o
+$(MAKE) -C ds4 ds4.o ds4_metal.o ds4_distributed.o
else
+$(MAKE) -C ds4 ds4_cpu.o ds4_distributed.o ds4_ssd.o
+$(MAKE) -C ds4 ds4_cpu.o ds4_distributed.o
endif
grpc-server: ds4/ds4.o

View File

@@ -1,5 +1,5 @@
IK_LLAMA_VERSION?=e6f8112f3ba126eed3ff5b30cdd08085414a7516
IK_LLAMA_VERSION?=3f40e73c367ad9f0c1b1819f28c7348c26aa340d
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
CMAKE_ARGS?=

View File

@@ -1,5 +1,5 @@
LLAMA_VERSION?=039e20a2db9e87b2477c76cc04905f3e1acad77f
LLAMA_VERSION?=5dcb71166686799f0d873eab7386234302d05ecf
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
CMAKE_ARGS?=

View File

@@ -381,15 +381,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
});
}
// for each video in the request, add the video data
for (int i = 0; i < predict->videos_size(); i++) {
data["video_data"].push_back(json
{
{"id", i},
{"data", predict->videos(i)},
});
}
data["stop"] = predict->stopprompts();
// data["n_probs"] = predict->nprobs();
//TODO: images,
@@ -491,13 +482,23 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
if (!request->draftmodel().empty()) {
params.speculative.draft.mparams.path = request->draftmodel();
// Default to draft type if a draft model is set but no explicit type.
// Upstream made the speculative type a vector (ggml-org/llama.cpp#22838)
// and renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE (#22964).
// Upstream (post ggml-org/llama.cpp#22838) made the speculative type a
// vector; the turboquant fork still uses the legacy scalar. The
// LOCALAI_LEGACY_LLAMA_CPP_SPEC macro is injected by
// backend/cpp/turboquant/patch-grpc-server.sh for fork builds only.
// Upstream renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE
// in ggml-org/llama.cpp#22964; the fork still uses the old name.
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
}
#else
const bool no_spec_type = params.speculative.types.empty() ||
(params.speculative.types.size() == 1 && params.speculative.types[0] == COMMON_SPECULATIVE_TYPE_NONE);
if (no_spec_type) {
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE };
}
#endif
}
// params.model_alias ??
@@ -573,10 +574,9 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
// tokens (0 disables the minimum). Match upstream's default (256). This
// field was renamed from `checkpoint_every_nt` in llama.cpp; the semantics
// also shifted from a fixed cadence to a minimum spacing. The turboquant
// fork still lacks common_params::checkpoint_min_step, so skip it there
// (LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP is injected by
// backend/cpp/turboquant/patch-grpc-server.sh).
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
// fork branched before the field existed, so skip it on the legacy path
// (LOCALAI_LEGACY_LLAMA_CPP_SPEC is injected by patch-grpc-server.sh).
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
params.checkpoint_min_step = 256;
#endif
@@ -752,7 +752,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
params.cache_idle_slots = false;
}
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
// --- minimum context-checkpoint spacing (upstream -cms / --checkpoint-min-step) ---
// 0 disables the minimum-spacing gate. Old option names (`checkpoint_every_nt`,
// `checkpoint_every_n_tokens`) are kept as aliases for backward compatibility
@@ -906,6 +906,17 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
// Speculative decoding options
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
// Fork only knows a single scalar `type`. Take the first comma-
// separated value and assign it via the singular helper.
std::string first = optval_str;
const auto comma = first.find(',');
if (comma != std::string::npos) first = first.substr(0, comma);
auto type = common_speculative_type_from_name(first);
if (type != COMMON_SPECULATIVE_TYPE_COUNT) {
params.speculative.type = type;
}
#else
// Upstream switched to a vector of types (comma-separated for multi-type
// chaining via common_speculative_types_from_names). We keep accepting a
// single value here, but also tolerate comma-separated lists.
@@ -934,6 +945,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
if (!parsed.empty()) {
params.speculative.types = parsed;
}
#endif
} else if (!strcmp(optname, "spec_n_max") || !strcmp(optname, "draft_max")) {
if (optval != NULL) {
try { params.speculative.draft.n_max = std::stoi(optval_str); } catch (...) {}
@@ -971,6 +983,21 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
// shares the target context size. Accept the option for backward
// compatibility but silently ignore it.
// Everything below relies on struct shape introduced in ggml-org/llama.cpp#22838
// (parallel drafting): `ngram_mod`, `ngram_map_k`, `ngram_map_k4v`,
// `ngram_cache`, and the `draft.{cache_type_*, cpuparams*, tensor_buft_overrides}`
// fields. The turboquant fork branched before that, so its build defines
// LOCALAI_LEGACY_LLAMA_CPP_SPEC via patch-grpc-server.sh and these option
// keys become unrecognized (silently dropped, like any unknown opt) for it.
//
// The `#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC` / `#else` split below sits at the
// closing-brace position of the `draft_ctx_size` branch on purpose: in the
// legacy build the chain ends here (the brace closes draft_ctx_size), and in
// the modern build the chain continues with `} else if (...)` instead, so the
// brace count stays balanced under both branches of the preprocessor.
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
}
#else
// --- ngram_mod family (upstream --spec-ngram-mod-*) ---
} else if (!strcmp(optname, "spec_ngram_mod_n_min")) {
if (optval != NULL) {
@@ -1100,6 +1127,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
}
if (!cur.empty()) flush(cur);
}
#endif // LOCALAI_LEGACY_LLAMA_CPP_SPEC — closes the `else`/`#ifdef` opened at draft_ctx_size
}
// Set params.n_parallel from environment variable if not set via options (fallback)
@@ -1149,11 +1177,15 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}
}
// Terminate the draft tensor_buft_overrides list with a sentinel, mirroring
// the main-model handling above.
// The draft tensor_buft_overrides are only populated under the modern
// (post-#22838) layout, whose population code is itself gated by
// LOCALAI_LEGACY_LLAMA_CPP_SPEC above. The turboquant fork lacks
// common_params_speculative::draft entirely, so skip the sentinel there too.
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
if (!params.speculative.draft.tensor_buft_overrides.empty()) {
params.speculative.draft.tensor_buft_overrides.push_back({nullptr, nullptr});
}
#endif
// TODO: Add yarn
@@ -1512,7 +1544,7 @@ public:
msg_json["role"] = msg.role();
bool is_last_user_msg = (i == last_user_msg_idx);
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
// Handle content - can be string, null, or array
// For multimodal content, we'll embed images/audio from separate fields
@@ -1563,16 +1595,6 @@ public:
content_array.push_back(audio_chunk);
}
}
if (request->videos_size() > 0) {
for (int j = 0; j < request->videos_size(); j++) {
json video_chunk;
video_chunk["type"] = "input_video";
json input_video;
input_video["data"] = request->videos(j);
video_chunk["input_video"] = input_video;
content_array.push_back(video_chunk);
}
}
msg_json["content"] = content_array;
} else {
// Use content as-is (already array or not last user message)
@@ -1607,16 +1629,6 @@ public:
content_array.push_back(audio_chunk);
}
}
if (request->videos_size() > 0) {
for (int j = 0; j < request->videos_size(); j++) {
json video_chunk;
video_chunk["type"] = "input_video";
json input_video;
input_video["data"] = request->videos(j);
video_chunk["input_video"] = input_video;
content_array.push_back(video_chunk);
}
}
msg_json["content"] = content_array;
} else if (msg.role() == "tool") {
// Tool role messages must have content field set, even if empty
@@ -1932,17 +1944,6 @@ public:
body_json["chat_template_kwargs"]["enable_thinking"] = (et_it->second == "true");
}
// Pass reasoning_effort via chat_template_kwargs too: the lever
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
// from enable_thinking which those templates ignore.
auto re_it = metadata.find("reasoning_effort");
if (re_it != metadata.end() && !re_it->second.empty()) {
if (!body_json.contains("chat_template_kwargs")) {
body_json["chat_template_kwargs"] = json::object();
}
body_json["chat_template_kwargs"]["reasoning_effort"] = re_it->second;
}
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
@@ -2068,16 +2069,6 @@ public:
files.push_back(decoded_data);
}
}
const auto &video_data = data.find("video_data");
if (video_data != data.end() && video_data->is_array())
{
for (const auto &video : *video_data)
{
auto decoded_data = base64_decode(video["data"].get<std::string>());
files.push_back(decoded_data);
}
}
}
const bool has_mtmd = ctx_server.impl->mctx != nullptr;
@@ -2330,7 +2321,7 @@ public:
}
bool is_last_user_msg = (i == last_user_msg_idx);
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
// Handle content - can be string, null, or array
// For multimodal content, we'll embed images/audio from separate fields
@@ -2383,16 +2374,6 @@ public:
content_array.push_back(audio_chunk);
}
}
if (request->videos_size() > 0) {
for (int j = 0; j < request->videos_size(); j++) {
json video_chunk;
video_chunk["type"] = "input_video";
json input_video;
input_video["data"] = request->videos(j);
video_chunk["input_video"] = input_video;
content_array.push_back(video_chunk);
}
}
msg_json["content"] = content_array;
} else {
// Use content as-is (already array or not last user message)
@@ -2432,16 +2413,6 @@ public:
content_array.push_back(audio_chunk);
}
}
if (request->videos_size() > 0) {
for (int j = 0; j < request->videos_size(); j++) {
json video_chunk;
video_chunk["type"] = "input_video";
json input_video;
input_video["data"] = request->videos(j);
video_chunk["input_video"] = input_video;
content_array.push_back(video_chunk);
}
}
msg_json["content"] = content_array;
SRV_INF("[CONTENT DEBUG] Predict: Message %d created content array with media\n", i);
} else if (!msg.tool_calls().empty()) {
@@ -2766,17 +2737,6 @@ public:
body_json["chat_template_kwargs"]["enable_thinking"] = (predict_et_it->second == "true");
}
// Pass reasoning_effort via chat_template_kwargs too: the lever
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
// from enable_thinking which those templates ignore.
auto predict_re_it = predict_metadata.find("reasoning_effort");
if (predict_re_it != predict_metadata.end() && !predict_re_it->second.empty()) {
if (!body_json.contains("chat_template_kwargs")) {
body_json["chat_template_kwargs"] = json::object();
}
body_json["chat_template_kwargs"]["reasoning_effort"] = predict_re_it->second;
}
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
SRV_DBG("[CONVERSATION DEBUG] Predict: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
@@ -2904,16 +2864,6 @@ public:
files.push_back(decoded_data);
}
}
const auto &video_data = data.find("video_data");
if (video_data != data.end() && video_data->is_array())
{
for (const auto &video : *video_data)
{
auto decoded_data = base64_decode(video["data"].get<std::string>());
files.push_back(decoded_data);
}
}
}
// process files

View File

@@ -1,7 +1,7 @@
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
TURBOQUANT_VERSION?=7d9715f1f071fa07c7b2ad3dbfd320b314139e65
TURBOQUANT_VERSION?=5aeb2fdbe26cd4c534c6fa15de73cb5749bd0403
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
CMAKE_ARGS?=

View File

@@ -4,19 +4,21 @@
#
# 1. Augment the kv_cache_types[] allow-list so `LoadModel` accepts the
# fork-specific `turbo2` / `turbo3` / `turbo4` cache types.
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file
# so the grpc-server option parser skips the two references to
# common_params::checkpoint_min_step (the default and the option handler).
# That field does not exist in the fork yet; drop this once it does.
#
# The fork used to lag upstream on the whole common_params_speculative refactor
# (ggml-org/llama.cpp#22397/#22838/#22964), the model_tgt rename (#22838) and
# get_media_marker (#21962), which required a much larger compat shim here
# (flat-field sed renames + a coarse LOCALAI_LEGACY_LLAMA_CPP_SPEC define). The
# fork has since rebased past all of those, so the only remaining gap is
# checkpoint_min_step. If a future bump reintroduces a divergence, add a narrow
# guard in grpc-server.cpp keyed on a fork-specific macro and inject it here
# rather than resurrecting the coarse one.
# 2. Replace `get_media_marker()` (added upstream in ggml-org/llama.cpp#21962,
# server-side random per-instance marker) with the legacy "<__media__>"
# literal. The fork branched before that PR, so server-common.cpp has no
# get_media_marker symbol. The fork's mtmd_default_marker() still returns
# "<__media__>", and Go-side tooling falls back to that sentinel when the
# backend does not expose media_marker, so substituting the literal keeps
# behavior identical on the turboquant path.
# 3. Revert the `common_params_speculative` field references to the
# pre-refactor flat layout. Upstream ggml-org/llama.cpp#22397 split the
# struct into nested `draft` / `ngram_simple` / `ngram_mod` / etc. members;
# the turboquant fork branched before that PR and still exposes the flat
# `n_max`, `mparams_dft`, `ngram_size_n`, ... fields. The substitutions
# below map the new nested paths back to the legacy flat names so the
# shared grpc-server.cpp keeps compiling against the fork's common.h.
# Drop this block once the fork rebases past #22397.
#
# We patch the *copy* sitting in turboquant-<flavor>-build/, never the original
# under backend/cpp/llama-cpp/, so the stock llama-cpp build keeps compiling
@@ -70,20 +72,72 @@ else
echo "==> KV allow-list patch OK"
fi
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file so
# the grpc-server option parser skips the two references to
# common_params::checkpoint_min_step (the default assignment and the option
# handler). That field does not exist in the fork yet. Drop this block once
# the fork rebases past the bump that added checkpoint_min_step.
if grep -q '^#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP' "$SRC"; then
echo "==> $SRC already defines LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP, skipping"
if grep -q 'get_media_marker()' "$SRC"; then
echo "==> patching $SRC to replace get_media_marker() with legacy \"<__media__>\" literal"
# Only one call site today (ModelMetadata), but replace all occurrences to
# stay robust if upstream adds more. Use a temp file to avoid relying on
# sed -i portability (the builder image uses GNU sed, but keeping this
# consistent with the awk block above).
sed 's/get_media_marker()/"<__media__>"/g' "$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> get_media_marker() substitution OK"
else
echo "==> patching $SRC to define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top"
# Insert the define before the very first `#include` so it precedes the
# checkpoint_min_step references.
echo "==> $SRC has no get_media_marker() call, skipping media-marker patch"
fi
if grep -q 'params\.speculative\.draft\.\|params\.speculative\.ngram_simple\.' "$SRC"; then
echo "==> patching $SRC to revert common_params_speculative refs to pre-#22397 flat layout"
# Each substitution is the exact post-refactor path → legacy flat field.
# Order doesn't matter because the source paths are disjoint, but we keep
# the most-specific (mparams.path) first for readability.
sed -E \
-e 's/params\.speculative\.draft\.mparams\.path/params.speculative.mparams_dft.path/g' \
-e 's/params\.speculative\.draft\.n_max/params.speculative.n_max/g' \
-e 's/params\.speculative\.draft\.n_min/params.speculative.n_min/g' \
-e 's/params\.speculative\.draft\.p_min/params.speculative.p_min/g' \
-e 's/params\.speculative\.draft\.p_split/params.speculative.p_split/g' \
-e 's/params\.speculative\.draft\.n_gpu_layers/params.speculative.n_gpu_layers/g' \
-e 's/params\.speculative\.draft\.n_ctx/params.speculative.n_ctx/g' \
-e 's/params\.speculative\.ngram_simple\.size_n/params.speculative.ngram_size_n/g' \
-e 's/params\.speculative\.ngram_simple\.size_m/params.speculative.ngram_size_m/g' \
-e 's/params\.speculative\.ngram_simple\.min_hits/params.speculative.ngram_min_hits/g' \
"$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> speculative field rename OK"
else
echo "==> $SRC has no post-#22397 speculative field refs, skipping spec rename patch"
fi
# 4. Revert the `ctx_server.impl->model_tgt` rename introduced by upstream
# ggml-org/llama.cpp#22838 (parallel drafting). The turboquant fork still
# exposes the field as `model` on `server_context_impl`. The two call sites
# are in the Rerank and ModelMetadata RPC handlers.
if grep -q 'ctx_server\.impl->model_tgt' "$SRC"; then
echo "==> patching $SRC to revert ctx_server.impl->model_tgt -> ctx_server.impl->model"
sed -E 's/ctx_server\.impl->model_tgt/ctx_server.impl->model/g' "$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> model_tgt rename OK"
else
echo "==> $SRC has no ctx_server.impl->model_tgt refs, skipping model_tgt rename patch"
fi
# 5. Define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top of the file so the
# grpc-server option parser skips the new option-handler blocks (ngram_mod,
# ngram_map_k, ngram_map_k4v, ngram_cache, draft.cache_type_*, draft.cpuparams*,
# draft.tensor_buft_overrides) introduced for the post-#22838 layout, the
# draft.tensor_buft_overrides sentinel termination, and the
# common_params::checkpoint_min_step default/option (added with the
# 35c9b1f3 bump). Those blocks reference struct fields that simply do not
# exist in the fork.
if grep -q '^#define LOCALAI_LEGACY_LLAMA_CPP_SPEC' "$SRC"; then
echo "==> $SRC already defines LOCALAI_LEGACY_LLAMA_CPP_SPEC, skipping"
else
echo "==> patching $SRC to define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top"
# Insert the define before the very first `#include` so it precedes all the
# speculative-decoding code paths.
awk '
!done && /^#include/ {
print "#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP 1"
print "#define LOCALAI_LEGACY_LLAMA_CPP_SPEC 1"
print "// ^ injected by backend/cpp/turboquant/patch-grpc-server.sh"
print ""
done = 1
@@ -91,13 +145,13 @@ else
{ print }
END {
if (!done) {
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP" > "/dev/stderr"
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_LEGACY_LLAMA_CPP_SPEC" > "/dev/stderr"
exit 1
}
}
' "$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP define OK"
echo "==> LOCALAI_LEGACY_LLAMA_CPP_SPEC define OK"
fi
echo "==> all patches applied"

View File

@@ -1,55 +0,0 @@
hip: port the turboquant CUDA additions that ggml's HIP shim doesn't cover
The turboquant fork adds/modifies a few ggml-cuda.cu spots with CUDA APIs
that ggml's HIP (and MUSA) compatibility layer does not provide, breaking
the -gpu-rocm-hipblas-turboquant build:
1. ggml_cuda_copy2d_across_devices() (host-staged cross-device copy for
split mul_mat output) uses the CUDA 3D-peer copy APIs
cudaMemcpy3DPeerParms / make_cudaPitchedPtr / make_cudaExtent /
cudaMemcpy3DPeerAsync. HIP genuinely does not support these (see the
fork's own comment "HIP does not support cudaMemcpy3DPeerAsync"), so
guard the peer fast path with #if !defined(GGML_USE_HIP) &&
!defined(GGML_USE_MUSA) -- matching how the fork already guards the
same API for the sibling 2D copy -- and fall through to the existing
cudaMemcpyAsync staging fallback below (functionally identical,
slightly slower on multi-GPU ROCm).
2. ggml_backend_cuda_device_event_new() creates its event with plain
cudaEventCreate, which ggml's HIP shim does not alias (it only aliases
cudaEventCreateWithFlags). Use cudaEventCreateWithFlags(...,
cudaEventDisableTiming) -- exactly what the rest of this file already
does (cf. lines ~1034, ~3461) and HIP-safe.
CUDA builds are unaffected. Drop the relevant hunk once the fork HIP-ports
these; apply-patches.sh fails fast if an anchor goes stale.
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 0427e6b..6352e6a 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1933,6 +1933,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
size_t width, size_t height, cudaStream_t dst_stream, cudaStream_t src_stream) {
const auto & info = ggml_cuda_info();
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) // 3D-peer copy types unmapped by ggml's HIP/MUSA shim; use staging fallback below
if (info.peer_access[src_device][dst_device]) {
cudaMemcpy3DPeerParms p = {};
p.dstDevice = dst_device;
@@ -1942,6 +1943,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
p.extent = make_cudaExtent(width, height, 1);
return cudaMemcpy3DPeerAsync(&p, dst_stream);
}
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
// Fallback: stage all rows through a single contiguous pinned buffer
int prev_device = ggml_cuda_get_device();
@@ -5714,7 +5716,7 @@ static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_
ggml_cuda_set_device(dev_ctx->device);
cudaEvent_t event;
- CUDA_CHECK(cudaEventCreate(&event));
+ CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
return new ggml_backend_event {
/* .device = */ dev,

View File

@@ -14,7 +14,6 @@ import (
"github.com/mudler/xlog"
"github.com/mudler/LocalAI/pkg/grpc/base"
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/httpclient"
)
@@ -146,7 +145,7 @@ func resolveAPIKey(envName, filePath string) (string, error) {
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
cfg := c.cfg.Load()
if cfg == nil {
return nil, grpcerrors.ModelNotLoaded("cloud-proxy")
return nil, errors.New("cloud-proxy: model not loaded")
}
if cfg.mode != modeTranslate {
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
@@ -176,7 +175,7 @@ func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
cfg := c.cfg.Load()
if cfg == nil {
return grpcerrors.ModelNotLoaded("cloud-proxy")
return errors.New("cloud-proxy: model not loaded")
}
if cfg.mode != modeTranslate {
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
@@ -270,7 +269,7 @@ func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest,
cfg := c.cfg.Load()
if cfg == nil {
return grpcerrors.ModelNotLoaded("cloud-proxy")
return errors.New("cloud-proxy: model not loaded")
}
if cfg.mode != modePassthrough {
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)

View File

@@ -14,7 +14,7 @@ target_include_directories(gocrispasr PRIVATE
# whisper. crispasr is the referencer; the backend static libs supply the
# per-architecture symbols; ggml is the math/runtime base.
target_link_libraries(gocrispasr PRIVATE
crispasr-lib
crispasr
parakeet canary canary_ctc cohere granite_speech granite_nle
voxtral voxtral4b qwen3_asr qwen3_tts orpheus chatterbox indextts
kokoro voxcpm2_tts m2m100 t5_translate wav2vec2-ggml vibevoice

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# CrispASR version (release tag)
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
CRISPASR_VERSION?=c29f6653a516a3001d923944dad8892072cc7334
CRISPASR_VERSION?=05e60432bcb5bc2113f8c395a41e86497c11504a
SO_TARGET?=libgocrispasr.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -1,10 +0,0 @@
.cache/
sources/
build/
package/
dllm-grpc
# build artifacts staged in-tree by the Makefile (cp from sources/) or
# symlinked for local dev; the real sources live in dllm.cpp upstream.
*.so
*.so.*
compile_commands.json

View File

@@ -1,93 +0,0 @@
# dllm backend Makefile.
#
# Upstream pin lives below as DLLM_VERSION?=<sha> so .github/bump_deps.sh
# can find and update it - matches the whisper.cpp / parakeet-cpp / ds4
# convention.
#
# Local dev shortcut: if you already have an out-of-tree dllm.cpp build,
# you can symlink the .so into this directory and skip the clone/cmake
# steps entirely, e.g.:
#
# ln -sf /path/to/dllm.cpp/build/libdllm.so .
# go build -o dllm-grpc .
#
# That's what the gated C-ABI binding smoke uses (DLLM_TEST_LIBRARY). The
# default target below does the proper clone-at-pin + cmake build so CI
# doesn't need a side-checkout.
#
# NOTE: github.com/mudler/dllm.cpp is still private (publishing is planned);
# until then the anonymous clone below fails. Use the symlink shortcut above
# with a local checkout, or a git credential helper with access to the repo.
DLLM_VERSION?=b22fcebebfb225131113188599a9ae542b2935d7
DLLM_REPO?=https://github.com/mudler/dllm.cpp
GOCMD?=go
GO_TAGS?=
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
BUILD_TYPE?=
NATIVE?=false
# libdllm.so is self-contained: dllm.cpp's CMakeLists statically absorbs ggml
# (BUILD_SHARED_LIBS=OFF + PIC) into the shared lib, so dlopen needs no
# libggml*.so alongside it, only system libs (libstdc++/libgomp/libc) the
# runtime image already provides. Tests/CLI are upstream-only concerns.
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DDLLM_BUILD_TESTS=OFF
ifeq ($(NATIVE),false)
CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif
# Same arch set the sibling ggml backends (acestep/vibevoice/qwen3-tts) bake
# for their cublas images; override for a native build.
CUDA_ARCHITECTURES?=75-virtual;80-virtual;86-real;89-real
# dllm.cpp gates CUDA behind DLLM_CUDA (set(GGML_CUDA ... CACHE FORCE)), so
# forward that instead of a bare -DGGML_CUDA=ON.
ifeq ($(BUILD_TYPE),cublas)
CMAKE_ARGS+=-DDLLM_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="$(CUDA_ARCHITECTURES)"
endif
.PHONY: dllm-grpc package build clean purge test all
all: dllm-grpc
# Clone the upstream dllm.cpp source at the pinned commit (ggml comes in as
# a submodule). Directory acts as the target so make only re-clones when
# missing. After a DLLM_VERSION bump, run 'make purge && make' to refetch.
sources/dllm.cpp:
mkdir -p sources/dllm.cpp
cd sources/dllm.cpp && \
git init -q && \
git remote add origin $(DLLM_REPO) && \
git fetch --depth 1 origin $(DLLM_VERSION) && \
git checkout FETCH_HEAD && \
git submodule update --init --recursive --depth 1 --single-branch
# Build the shared lib out-of-tree, then stage it next to the Go sources so
# purego.Dlopen("libdllm.so") and the packaging step both pick it up.
libdllm.so: sources/dllm.cpp
cmake -B sources/dllm.cpp/build -S sources/dllm.cpp $(CMAKE_ARGS)
cmake --build sources/dllm.cpp/build --config Release -j$(JOBS)
cp -fv sources/dllm.cpp/build/libdllm.so ./
dllm-grpc: libdllm.so main.go capi.go
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o dllm-grpc .
package: dllm-grpc
bash package.sh
build: package
# Test target. The C-ABI binding smoke is gated on DLLM_TEST_LIBRARY +
# DLLM_TEST_TINY_MODEL; without them the gated specs auto-skip and only the
# pure-Go helper specs run.
test:
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
clean: purge
rm -rf libdllm.so* package dllm-grpc
purge:
rm -rf sources/dllm.cpp

View File

@@ -1,256 +0,0 @@
package main
// Typed Go wrappers over dllm.cpp's flat C-ABI (include/dllm_capi.h, ABI v1).
//
// Contract highlights the wrappers encode (see the header + src/capi.cpp):
// - tokenize_json/generate return malloc'd char* the CALLER owns: bound as
// uintptr, copied with goStringFromCPtr, released via dllm_capi_free_string.
// - last_error returns a BORROWED pointer (valid until the next call on the
// same ctx): bound as a plain string (purego copies), never freed, and only
// read AFTER the failing call has returned - reading it while a generate is
// in flight on the same ctx violates the per-ctx serialization contract.
// - All entry points except dllm_capi_cancel must be externally serialized
// per ctx (one ctx = one concurrent generate/tokenize). Cancel only flips
// an atomic and may be called from any goroutine mid-generate.
// - No C++ exception crosses the boundary; failures land in last_error.
import (
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"unsafe"
"github.com/ebitengine/purego"
)
// dllmABIVersion is the DLLM_CAPI_ABI_VERSION this binding was written
// against; main.go refuses to start against a libdllm.so reporting another.
const dllmABIVersion = 1
// purego-bound entry points from libdllm.so. Names match dllm_capi.h
// exactly; loadCAPI (main.go) fills these in at boot.
var (
cppAbiVersion func() int32
cppLoad func(ggufPath, paramsJSON string) uintptr
cppFree func(ctx uintptr)
cppLastError func(ctx uintptr) string // borrowed pointer: purego copies, do NOT free
cppFreeString func(s uintptr)
// malloc'd char* returns, hence uintptr (see loadCAPI's doc comment).
cppTokenizeJSON func(ctx uintptr, text string) uintptr
cppGenerate func(ctx uintptr, prompt, optsJSON string) uintptr
// on_block/on_step are C function pointers produced by purego.NewCallback;
// userData carries the streamCallStates registry key.
cppGenerateStream func(ctx uintptr, prompt, optsJSON string, onBlock, onStep, userData uintptr) int32
cppCancel func(ctx uintptr)
)
// cAbiVersion returns the library's DLLM_CAPI_ABI_VERSION.
func cAbiVersion() int32 {
return cppAbiVersion()
}
// cLoad opens the GGUF at path with the flat params JSON (e.g.
// {"n_gpu_layers":99}). Returns 0 on failure; per the header contract there
// is no ctx to carry the reason, the C side logs it to stderr (and
// cLastError(0) only yields the static NULL-ctx message).
func cLoad(path, paramsJSON string) uintptr {
return cppLoad(path, paramsJSON)
}
// cFree releases a ctx; safe on 0 (delete nullptr).
func cFree(h uintptr) {
cppFree(h)
}
// cLastError returns the ctx's last error message (or the static NULL-ctx
// message for h==0). The C pointer is borrowed and only valid until the next
// call on the same ctx; purego's string return copies it immediately, so the
// returned Go string is safe to keep. Must not be called while another call
// on the same ctx is in flight.
func cLastError(h uintptr) string {
return cppLastError(h)
}
// lastErrorOr is cLastError with a fallback for the empty-message case, so
// wrapped errors never end in ": ".
func lastErrorOr(h uintptr, fallback string) string {
if msg := cLastError(h); msg != "" {
return msg
}
return fallback
}
// cTokenizeJSON tokenizes text (the C side prepends bos per vocab.add_bos)
// and returns the token ids as a JSON array string, e.g. "[2,18]".
func cTokenizeJSON(h uintptr, text string) (string, error) {
ret := cppTokenizeJSON(h, text)
if ret == 0 {
return "", fmt.Errorf("dllm: tokenize failed: %s", lastErrorOr(h, "unknown error"))
}
out := goStringFromCPtr(ret)
cppFreeString(ret)
return out, nil
}
// cGenerate runs a blocking generation and returns the detokenized text.
// optsJSON must be a FLAT JSON object of scalars (use buildOptsJSON); the C
// parser rejects nested objects/arrays. NULL return -> last_error (read only
// after the call returned, per the serialization contract); a cancelled call
// surfaces as the "cancelled" message.
func cGenerate(h uintptr, prompt, optsJSON string) (string, error) {
ret := cppGenerate(h, prompt, optsJSON)
if ret == 0 {
return "", fmt.Errorf("dllm: generate failed: %s", lastErrorOr(h, "unknown error"))
}
out := goStringFromCPtr(ret)
cppFreeString(ret)
return out, nil
}
// streamCallState carries the Go callbacks for one in-flight
// cGenerateStream call; the registry key travels through C as user_data.
// The map shape mirrors the whisper backend's streamCallStates: only one
// entry per ctx is ever live (the C-ABI is serialized per ctx), but keying
// by call survives multiple models/processes sharing the package.
type streamCallState struct {
onBlock func(text string)
onStep func(step, total int, preview string)
}
var (
streamCallStates sync.Map // uint64 -> *streamCallState
streamCallSeq atomic.Uint64
// purego.NewCallback allocates a finite, never-released callback slot, so
// the two trampolines are created exactly once and reused across calls.
streamCbOnce sync.Once
blockCbPtr uintptr
stepCbPtr uintptr
)
// onBlockTrampoline is the Go side of dllm_block_cb. It runs on the C
// calling thread, mid-generate: keep it tiny and non-blocking (callers that
// bridge to goroutines must hand off via buffered channels). The text
// pointer is only valid for the duration of the invocation, so it is copied
// to a Go string immediately.
func onBlockTrampoline(text uintptr, userData uintptr) {
v, ok := streamCallStates.Load(uint64(userData))
if !ok {
return // call already torn down
}
state := v.(*streamCallState)
if state.onBlock != nil {
state.onBlock(goStringFromCPtr(text))
}
}
// onStepTrampoline is the Go side of dllm_step_cb; same threading and
// lifetime caveats as onBlockTrampoline.
func onStepTrampoline(step int32, totalSteps int32, canvasPreview uintptr, userData uintptr) {
v, ok := streamCallStates.Load(uint64(userData))
if !ok {
return
}
state := v.(*streamCallState)
if state.onStep != nil {
state.onStep(int(step), int(totalSteps), goStringFromCPtr(canvasPreview))
}
}
// cGenerateStream runs a generation with per-committed-block (onBlock) and
// per-denoising-step (onStep) callbacks; either may be nil. The callbacks
// run on the C thread (see the trampoline docs). Returns an error carrying
// last_error on failure; cancellation surfaces as the "cancelled" message.
func cGenerateStream(h uintptr, prompt, optsJSON string, onBlock func(text string), onStep func(step, total int, preview string)) error {
streamCbOnce.Do(func() {
blockCbPtr = purego.NewCallback(onBlockTrampoline)
stepCbPtr = purego.NewCallback(onStepTrampoline)
})
id := streamCallSeq.Add(1)
streamCallStates.Store(id, &streamCallState{onBlock: onBlock, onStep: onStep})
defer streamCallStates.Delete(id)
// Pass NULL for absent callbacks so the C side skips the per-block /
// per-step detokenize work entirely.
var blockPtr, stepPtr uintptr
if onBlock != nil {
blockPtr = blockCbPtr
}
if onStep != nil {
stepPtr = stepCbPtr
}
if rc := cppGenerateStream(h, prompt, optsJSON, blockPtr, stepPtr, uintptr(id)); rc != 0 {
return fmt.Errorf("dllm: generate_stream failed: %s", lastErrorOr(h, "unknown error"))
}
return nil
}
// cCancel requests cancellation of the in-flight generate on h. This is the
// ONE entry point safe to call from any goroutine while a generate runs (it
// only flips an atomic). Note the cancel-reset race from the header: each
// generate resets the flag on entry, so a watchdog should re-issue cancel if
// the call has not returned.
func cCancel(h uintptr) {
cppCancel(h)
}
// buildOptsJSON renders generation options as the flat JSON object the
// C-ABI expects (known keys: n_predict, blocks, seed, eb_*, kv_cache). The
// C-side scanner only understands scalar number/string values and rejects
// nested objects/arrays loudly; bools are rejected here too because the
// scanner has no concept of them. Fail loud rather than let an option be
// silently misread.
//
// CAVEAT: json.Marshal HTML-escapes <, > and & inside string values (e.g.
// "<" becomes the six-byte \u003c sequence). None of the known string-valued keys
// (kv_cache: auto|on|off) can contain those bytes today; if one ever does,
// switch to an Encoder with SetEscapeHTML(false) like gemma4JSONString.
func buildOptsJSON(opts map[string]any) (string, error) {
if len(opts) == 0 {
return "{}", nil
}
for k, v := range opts {
switch v.(type) {
case string,
int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64,
float32, float64,
json.Number:
// scalar: fine
default:
return "", fmt.Errorf("dllm: opts key %q has non-scalar value %T (the C-ABI only accepts flat number/string scalars)", k, v)
}
}
b, err := json.Marshal(opts)
if err != nil {
return "", fmt.Errorf("dllm: marshal opts: %w", err)
}
return string(b), nil
}
// goStringFromCPtr copies a NUL-terminated C string into Go memory. cptr is
// the raw pointer returned by purego from the C-ABI (a malloc'd buffer the
// caller owns, or a callback argument only valid during the invocation);
// owning callers must free it via cppFreeString after the copy lands.
//
// A direct unsafe.Pointer(cptr) conversion trips go vet's unsafeptr check,
// which can't distinguish a C-owned heap pointer from Go-managed memory (the
// parakeet-cpp and whisper backends tolerate that warning). Reinterpreting
// through &cptr below is equivalent at runtime and keeps plain `go vet`
// clean. It is safe either way: the pointer addresses C memory the Go GC
// neither tracks nor moves, and we dereference it immediately to copy the
// bytes out.
func goStringFromCPtr(cptr uintptr) string {
if cptr == 0 {
return ""
}
p := *(*unsafe.Pointer)(unsafe.Pointer(&cptr)) // C-owned buffer, not Go-GC memory (see doc above)
n := 0
for *(*byte)(unsafe.Add(p, n)) != 0 {
n++
}
return string(unsafe.Slice((*byte)(p), n))
}

View File

@@ -1,553 +0,0 @@
package main
// LocalAI gRPC backend for dllm.cpp (DiffusionGemma block-diffusion models).
//
// Wiring overview:
// - Load opens the GGUF via dllm_capi_load and starts the per-model worker
// goroutine that serializes every C call (see submit).
// - PredictRich / PredictStreamRich implement grpc.AIModelRich: when the
// request carries raw messages (use_tokenizer_template), the backend owns
// templating (RenderGemma4) and output parsing (Gemma4Parser) and replies
// with ChatDeltas, like the llama.cpp autoparser and the ds4 backend.
// - The legacy Predict / PredictStream methods delegate to the rich pair
// (cloud-proxy precedent); the gRPC server prefers the rich path anyway.
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"unicode/utf8"
grpc "github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/grpc/base"
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/xlog"
)
// The gRPC server cancels in-flight generations on client disconnect only
// for backends advertising the Cancellable capability; keep Dllm pinned to
// it so a signature drift fails the build, not the disconnect path.
var _ grpc.Cancellable = (*Dllm)(nil)
// generator is the seam between the backend wiring and the dllm.cpp C-ABI:
// the real implementation (capiGenerator) wraps the cGenerate/cTokenizeJSON
// family, while tests substitute a fake to exercise prompt construction,
// parsing and serialization without libdllm.so.
type generator interface {
generate(prompt, optsJSON string) (string, error)
// generateStream invokes onBlock once per committed diffusion block, on
// the thread running the C call, before returning.
generateStream(prompt, optsJSON string, onBlock func(text string)) error
tokenizeJSON(text string) (string, error)
// cancel is the ONE entry point safe to call concurrently with an
// in-flight generate on the same ctx (dllm_capi.h: it only flips an
// atomic; everything else must be externally serialized per ctx).
cancel()
free()
}
// capiGenerator is the production generator over one dllm_ctx handle.
type capiGenerator struct {
h uintptr
}
func (g *capiGenerator) generate(prompt, optsJSON string) (string, error) {
return cGenerate(g.h, prompt, optsJSON)
}
func (g *capiGenerator) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
// on_step (per-denoise-step canvas preview, dllm.cpp's --visual) is
// passed as nil for now: a future progress hook for the React UI can
// plumb it through without touching the C binding.
return cGenerateStream(g.h, prompt, optsJSON, onBlock, nil)
}
func (g *capiGenerator) tokenizeJSON(text string) (string, error) {
return cTokenizeJSON(g.h, text)
}
func (g *capiGenerator) cancel() {
cCancel(g.h)
}
func (g *capiGenerator) free() {
cFree(g.h)
}
// Dllm is the gRPC backend instance: one per loaded model (LocalAI starts
// one backend process per model).
type Dllm struct {
base.Base
gen generator
// genOpts holds the model-level generation overrides parsed from
// ModelOptions.Options at Load (eb_*, blocks, kv_cache). The C-ABI takes
// them per-generate, not per-load, so they are merged into every
// request's opts JSON (requestOptsJSON).
genOpts map[string]any
// jobs is the per-model worker queue. dllm_capi.h requires every entry
// point EXCEPT dllm_capi_cancel to be externally serialized per ctx (one
// ctx = one concurrent generate/tokenize; last_error is unsafe to read
// while a call is in flight). A single goroutine owning all C calls makes
// that contract structural instead of relying on lock discipline.
jobs chan func()
workerWG sync.WaitGroup
// genMu guards gen against Free racing in-flight requests: requests hold
// the read lock for their full duration (they stay concurrent with each
// other - the worker still serializes the C calls), Free takes the write
// lock so it can only run when no request is in flight.
genMu sync.RWMutex
}
func (d *Dllm) startWorker() {
d.jobs = make(chan func())
d.workerWG.Add(1)
go func() {
defer d.workerWG.Done()
for job := range d.jobs {
job()
}
}()
}
// submit runs job on the worker goroutine and waits for it to finish.
// Concurrent gRPC requests therefore queue up and execute one at a time
// against the single dllm_ctx.
func (d *Dllm) submit(job func()) {
done := make(chan struct{})
d.jobs <- func() {
defer close(done)
job()
}
<-done
}
// Load opens the GGUF and prepares the worker. Load-time engine parameters
// travel as the flat params JSON of dllm_capi_load; generation overrides
// from Options are stored for per-request opts JSON instead (the C-ABI has
// no per-load sampler state).
func (d *Dllm) Load(opts *pb.ModelOptions) error {
if d.gen != nil {
return errors.New("dllm: model already loaded")
}
params := map[string]any{
"n_gpu_layers": opts.GetNGPULayers(),
}
if opts.GetThreads() > 0 {
params["n_threads"] = opts.GetThreads()
}
if opts.GetContextSize() > 0 {
params["ctx_len"] = opts.GetContextSize()
}
paramsJSON, err := buildOptsJSON(params)
if err != nil {
return err
}
d.genOpts = parseModelGenOpts(opts.GetOptions())
h := cLoad(opts.GetModelFile(), paramsJSON)
if h == 0 {
// No ctx exists on load failure, so last_error(NULL) only carries the
// static NULL-ctx message; the real reason is on the backend's stderr.
return fmt.Errorf("dllm: load %q failed: %s (see backend log for details)",
opts.GetModelFile(), lastErrorOr(0, "unknown error"))
}
d.gen = &capiGenerator{h: h}
d.startWorker()
xlog.Info("dllm: model loaded", "model", opts.GetModelFile(), "params", paramsJSON, "gen_opts", d.genOpts)
return nil
}
// Free releases the dllm ctx and stops the worker. Safe when never loaded.
//
// The write lock is essential: the gRPC server (pkg/grpc/server.go, see the
// model-unload path around line 764) calls Free with no locking of its own,
// and base.Base provides none either. Without it a request racing Free would
// panic sending on the closed jobs channel - or worse, generate on a freed C
// ctx. Holding genMu until gen is nil also turns post-Free requests into a
// clean "model not loaded" error instead of a crash.
func (d *Dllm) Free() error {
d.genMu.Lock()
defer d.genMu.Unlock()
if d.gen == nil {
return nil
}
d.submit(d.gen.free)
close(d.jobs)
d.workerWG.Wait()
d.gen = nil
return nil
}
// Cancel requests cancellation of the in-flight generate (the
// grpc.Cancellable capability). The gRPC server arms it via
// context.AfterFunc on the request/stream context, so a client
// disconnect or timeout aborts the generation server-side - the same
// semantics the llama.cpp C++ backend gets from polling IsCancelled().
// It deliberately bypasses the worker queue: dllm_capi_cancel is the one
// call the C-ABI allows from any goroutine mid-generate (it only flips
// an atomic).
//
// Note dllm_capi.h's cancel-reset race: each generate resets the flag on
// entry, so a Cancel racing a NEW generate on the same ctx can be lost
// (and, with requests queued on the worker, it aborts whichever generate
// is currently running). The single-flag granularity is acceptable here
// because the server de-registers the hook on normal completion and one
// backend process serves one model.
func (d *Dllm) Cancel() {
// RLock so a server-side AfterFunc firing in the window between a
// request finishing and a model unload cannot touch a freed C ctx
// (Free holds the write lock while tearing gen down). cancel() is the
// one C call that is safe concurrently with an in-flight generate, so
// taking a read lock here cannot deadlock against request holders.
d.genMu.RLock()
defer d.genMu.RUnlock()
if d.gen != nil {
d.gen.cancel()
}
}
// dllmGenOptKeys are the ModelOptions.Options keys this backend forwards to
// the engine. Options is a shared free-form bag (other layers put their own
// entries there), so unknown keys are skipped with a warning, not an error.
var dllmGenOptKeys = map[string]bool{
"blocks": true,
"kv_cache": true, // "auto"|"on"|"off"; honored by the engine from P3
}
// parseModelGenOpts parses "key:value" Options entries into the flat scalar
// map merged into every generate's opts JSON. eb_* (Entropy-Bound sampler
// knobs) and the keys in dllmGenOptKeys are recognized; values are typed by
// first successful parse (int, then float, else string) to match the C
// scanner's number/string scalars.
func parseModelGenOpts(options []string) map[string]any {
out := map[string]any{}
for _, o := range options {
key, val, found := strings.Cut(o, ":")
if !found {
xlog.Warn("dllm: ignoring malformed option (want key:value)", "option", o)
continue
}
if !strings.HasPrefix(key, "eb_") && !dllmGenOptKeys[key] {
xlog.Debug("dllm: ignoring unrecognized option", "key", key)
continue
}
out[key] = parseScalarOpt(val)
}
return out
}
func parseScalarOpt(v string) any {
if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
return iv
}
if fv, err := strconv.ParseFloat(v, 64); err == nil {
return fv
}
return v
}
// metadataEnableThinking reads the enable_thinking gate. Unlike ds4 (default
// ON, matching ds4-server), dllm defaults OFF: DiffusionGemma's chat
// template guards every thinking branch with `enable_thinking is defined and
// enable_thinking`, i.e. thinking is opt-in for this model family, and the
// no-thinking render pre-closes an empty thought channel that the OFF
// default must produce.
func metadataEnableThinking(opts *pb.PredictOptions) bool {
v := opts.GetMetadata()["enable_thinking"]
return v == "true" || v == "1"
}
// buildPrompt resolves the prompt for a request. With use_tokenizer_template
// and raw messages the backend owns templating (RenderGemma4) and the output
// is in the known gemma4 format, so parse=true. Without it the caller
// templated the prompt themselves (LocalAI's Go templates + PEG fallback, or
// a bare completion): the prompt passes through verbatim and the output is
// NOT gemma4-parsed - it is emitted as plain content and the Go side's
// extraction applies, as for any non-autoparsing backend.
func buildPrompt(opts *pb.PredictOptions) (prompt string, parse bool, err error) {
if opts.GetUseTokenizerTemplate() && len(opts.GetMessages()) > 0 {
prompt, err = RenderGemma4(opts.GetMessages(), opts.GetTools(), metadataEnableThinking(opts), true)
return prompt, true, err
}
return opts.GetPrompt(), false, nil
}
// requestOptsJSON merges the model-level overrides with the request's
// sampling fields into the flat opts JSON for one generate call.
func (d *Dllm) requestOptsJSON(opts *pb.PredictOptions) (string, error) {
m := make(map[string]any, len(d.genOpts)+2)
for k, v := range d.genOpts {
m[k] = v
}
if n := opts.GetTokens(); n > 0 {
// The engine rounds n_predict UP to a whole number of diffusion
// blocks (the canvas is denoised block-wise), so the completion may
// run slightly past the requested budget. Tokens==0 omits the key so
// the C-ABI default of 256 applies (hardcoded in capi.cpp's
// parse_gen_opts, independent of canvas_length).
m["n_predict"] = n
}
if s := opts.GetSeed(); s > 0 {
// The engine seeds mt19937 with explicit non-negative seeds. Seed<=0
// is omitted: proto3 cannot distinguish 0 from unset, and negative
// values conventionally mean "random" across LocalAI backends.
m["seed"] = s
}
return buildOptsJSON(m)
}
// prepareRequest is the shared prologue of the rich methods: resolve the
// prompt (and whether the output gets gemma4-parsed) and build the per-call
// opts JSON.
func (d *Dllm) prepareRequest(opts *pb.PredictOptions) (prompt string, parse bool, optsJSON string, err error) {
prompt, parse, err = buildPrompt(opts)
if err != nil {
return "", false, "", err
}
optsJSON, err = d.requestOptsJSON(opts)
if err != nil {
return "", false, "", err
}
return prompt, parse, optsJSON, nil
}
// sanitizeUTF8 makes s safe for a proto3 string field. Block-boundary
// detokenization and byte-fallback tokens can produce invalid UTF-8, and
// grpc-go refuses to marshal it ("string field contains invalid UTF-8"), so
// every string destined for a Reply/ChatDelta must pass through here (or
// through splitValidUTF8, which calls it). Lone malformed bytes are genuinely
// undecodable: replace with U+FFFD rather than crash the stream.
func sanitizeUTF8(s string) string {
if utf8.ValidString(s) {
return s
}
return strings.ToValidUTF8(s, "<22>")
}
// utf8SeqLen returns the declared sequence length of a UTF-8 leading byte
// (1 for bytes that can never lead a multi-byte sequence, so they are never
// held back and fall through to sanitizeUTF8's replacement).
func utf8SeqLen(b byte) int {
switch {
case b&0xE0 == 0xC0:
return 2
case b&0xF0 == 0xE0:
return 3
case b&0xF8 == 0xF0:
return 4
default:
return 1
}
}
// splitValidUTF8 prepends the previous block's carry to the new block and
// splits the result into text safe to emit now and a trailing INCOMPLETE
// UTF-8 sequence (at most utf8.UTFMax-1 bytes) to carry into the next block:
// the per-block detokenize can split a multi-byte character across block
// boundaries (llama.cpp's grpc-server holds back the same way). Only a
// suffix that can still become a valid rune is withheld; bytes that are
// already undecodable are replaced immediately so the carry stays bounded.
func splitValidUTF8(carry, block string) (emit, newCarry string) {
s := carry + block
cut := len(s)
for i := len(s) - 1; i >= 0 && len(s)-i < utf8.UTFMax; i-- {
b := s[i]
if b < utf8.RuneSelf {
break // ASCII: everything before the tail scan is complete
}
if !utf8.RuneStart(b) {
continue // continuation byte: keep looking for its leading byte
}
// Leading byte: hold the sequence back iff it declares more bytes
// than the stream has produced so far (it may complete next block).
if utf8SeqLen(b) > len(s)-i {
cut = i
}
break
}
return sanitizeUTF8(s[:cut]), s[cut:]
}
// PredictRich is the non-streaming inference path (grpc.AIModelRich).
// Returns one Reply whose Message is the aggregated assistant content and
// whose ChatDeltas carry the parsed content/reasoning/tool-call events.
func (d *Dllm) PredictRich(opts *pb.PredictOptions) (*pb.Reply, error) {
d.genMu.RLock()
defer d.genMu.RUnlock()
if d.gen == nil {
return nil, grpcerrors.ModelNotLoaded("dllm")
}
prompt, parse, optsJSON, err := d.prepareRequest(opts)
if err != nil {
return nil, err
}
var out string
var genErr error
d.submit(func() {
out, genErr = d.gen.generate(prompt, optsJSON)
})
if genErr != nil {
return nil, genErr
}
// Byte-fallback tokens can detokenize to invalid UTF-8; proto3 strings
// must be valid or grpc-go fails the whole reply at marshal time.
out = sanitizeUTF8(out)
if !parse {
// Raw-prompt mode: plain content, no gemma4 parsing (see buildPrompt).
return &pb.Reply{Message: []byte(out), ChatDeltas: []*pb.ChatDelta{{Content: out}}}, nil
}
// The prompt renders with add_generation_prompt; both thinking modes
// leave the model starting in content state (see the Gemma4Parser header
// comment), hence NewGemma4Parser(false).
parser := NewGemma4Parser(false)
if reply := replyFromDeltas(append(parser.Feed(out), parser.Close()...)); reply != nil {
return reply, nil
}
// Everything was markers (or out was empty): an empty but non-nil Reply.
return &pb.Reply{}, nil
}
// PredictStreamRich is the streaming counterpart (grpc.AIModelRich): one
// Reply per committed diffusion block that produced deltas. Per the
// interface contract the channel is only sent into here - the gRPC server
// closes it after this returns (opposite to legacy PredictStream).
func (d *Dllm) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) error {
d.genMu.RLock()
defer d.genMu.RUnlock()
if d.gen == nil {
return grpcerrors.ModelNotLoaded("dllm")
}
prompt, parse, optsJSON, err := d.prepareRequest(opts)
if err != nil {
return err
}
var parser *Gemma4Parser
if parse {
parser = NewGemma4Parser(false)
}
// emit runs inside onBlock, i.e. on the thread driving the C generate.
// Sending on results can block on a slow consumer, but the server-side
// pump (pkg/grpc/server.go PredictStream) drains continuously and drops
// undeliverable sends, so this backpressure is brief and bounded - and
// pausing the diffusion loop under it is the desired behavior anyway.
emit := func(text string) {
if !parse {
if text != "" {
results <- &pb.Reply{Message: []byte(text), ChatDeltas: []*pb.ChatDelta{{Content: text}}}
}
return
}
deltas := parser.Feed(text)
if reply := replyFromDeltas(deltas); reply != nil {
results <- reply
}
}
// onBlock guards emit (and through it the parser) against invalid UTF-8:
// a multi-byte character split across block boundaries is held back until
// it completes (see splitValidUTF8), so proto3 marshaling never fails.
var carry string
onBlock := func(block string) {
var text string
text, carry = splitValidUTF8(carry, block)
emit(text)
}
var genErr error
d.submit(func() {
genErr = d.gen.generateStream(prompt, optsJSON, onBlock)
})
if genErr != nil {
return genErr
}
if carry != "" {
// The stream ended mid-sequence: the held-back bytes can no longer
// complete, so flush them through the U+FFFD last resort.
emit(sanitizeUTF8(carry))
}
if parse {
if reply := replyFromDeltas(parser.Close()); reply != nil {
results <- reply
}
}
return nil
}
// replyFromDeltas wraps one batch of parsed deltas into a streaming Reply,
// or nil when the batch is empty (markers consumed, nothing emitted yet).
// Message mirrors the batch's content text so legacy chan-string consumers
// see exactly the displayed tokens.
func replyFromDeltas(deltas []*pb.ChatDelta) *pb.Reply {
if len(deltas) == 0 {
return nil
}
var content strings.Builder
for _, delta := range deltas {
content.WriteString(delta.GetContent())
}
return &pb.Reply{Message: []byte(content.String()), ChatDeltas: deltas}
}
// Predict is the legacy (string, error) signature; the gRPC server prefers
// PredictRich, this exists for non-rich callers (cloud-proxy precedent).
func (d *Dllm) Predict(opts *pb.PredictOptions) (string, error) {
reply, err := d.PredictRich(opts)
if err != nil {
return "", err
}
return string(reply.GetMessage()), nil
}
// PredictStream is the legacy chan-string path: rich replies reduced to
// their content text. Note the inverted channel ownership - the LEGACY
// contract requires the impl to close the channel.
func (d *Dllm) PredictStream(opts *pb.PredictOptions, results chan string) error {
defer close(results)
richCh := make(chan *pb.Reply)
errCh := make(chan error, 1)
go func() {
errCh <- d.PredictStreamRich(opts, richCh)
close(richCh)
}()
for reply := range richCh {
if msg := reply.GetMessage(); len(msg) > 0 {
results <- string(msg)
}
}
return <-errCh
}
// TokenizeString tokenizes opts.Prompt via dllm_capi_tokenize_json (the C
// side prepends bos per the vocab) and decodes the returned id array.
func (d *Dllm) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
d.genMu.RLock()
defer d.genMu.RUnlock()
if d.gen == nil {
return pb.TokenizationResponse{}, grpcerrors.ModelNotLoaded("dllm")
}
var out string
var tokErr error
d.submit(func() {
out, tokErr = d.gen.tokenizeJSON(opts.GetPrompt())
})
if tokErr != nil {
return pb.TokenizationResponse{}, tokErr
}
var tokens []int32
if err := json.Unmarshal([]byte(out), &tokens); err != nil {
return pb.TokenizationResponse{}, fmt.Errorf("dllm: decode tokenize result %q: %w", out, err)
}
return pb.TokenizationResponse{Length: int32(len(tokens)), Tokens: tokens}, nil
}

View File

@@ -1,807 +0,0 @@
package main
import (
"errors"
"os"
"runtime"
"sync"
"testing"
"time"
"unicode/utf8"
"unsafe"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
func TestDllm(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "dllm Backend Suite")
}
var (
libLoadOnce sync.Once
libLoadErr error
)
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
// C-ABI bridge without spinning up the gRPC server. The library path comes
// from DLLM_TEST_LIBRARY (gated specs Skip when it is unset).
func ensureLibLoaded() {
libLoadOnce.Do(func() {
libLoadErr = loadCAPI(os.Getenv("DLLM_TEST_LIBRARY"))
})
}
// C-ABI binding smoke: drives the real libdllm.so against the tiny GGUF
// fixture from dllm.cpp (tests/fixtures/tiny_with_vocab.gguf). Gated on:
//
// DLLM_TEST_LIBRARY absolute path to libdllm.so
// DLLM_TEST_TINY_MODEL absolute path to tiny_with_vocab.gguf
var _ = Describe("C-ABI binding", func() {
BeforeEach(func() {
if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" {
Skip("set DLLM_TEST_LIBRARY and DLLM_TEST_TINY_MODEL to run the C-ABI binding smoke")
}
ensureLibLoaded()
Expect(libLoadErr).ToNot(HaveOccurred())
})
It("binds the 9 symbols and round-trips the tiny model", func() {
Expect(cAbiVersion()).To(Equal(int32(1)))
h := cLoad(os.Getenv("DLLM_TEST_TINY_MODEL"), "{}")
Expect(h).ToNot(BeZero(), "dllm_capi_load of the tiny fixture")
// Tiny fixture vocab: "hello" tokenizes to ids [2,18] (bos prepended
// by the C side: vocab.add_bos).
toks, err := cTokenizeJSON(h, "hello")
Expect(err).ToNot(HaveOccurred())
Expect(toks).To(Equal("[2,18]"))
// Deterministic generation: an explicit non-negative seed seeds
// mt19937, so two identical calls must produce identical text.
out1, err := cGenerate(h, "hello", `{"n_predict":16,"seed":7}`)
Expect(err).ToNot(HaveOccurred())
Expect(out1).ToNot(BeEmpty())
// Cancel with no call in flight is dropped: each generate resets the
// cancel flag on entry (header contract), so this must not affect
// the next call. Also binds the 9th symbol; safe on NULL too.
cCancel(h)
cCancel(0)
out2, err := cGenerate(h, "hello", `{"n_predict":16,"seed":7}`)
Expect(err).ToNot(HaveOccurred())
Expect(out2).To(Equal(out1))
// Streaming variant: same opts, blocks arrive via the purego
// callback trampoline. The per-block detokenize can differ from the
// seamless full-text decode at block boundaries, so only assert that
// blocks arrived and were non-trivial, not byte equality with out1.
var blocks []string
var steps int
err = cGenerateStream(h, "hello", `{"n_predict":16,"seed":7}`,
func(text string) { blocks = append(blocks, text) },
func(step, total int, preview string) { steps++ },
)
Expect(err).ToNot(HaveOccurred())
Expect(blocks).ToNot(BeEmpty())
Expect(steps).To(BeNumerically(">", 0))
// Load failure path: NULL ctx back, and last_error(NULL) returns the
// static NULL-ctx message (there is no ctx to carry the real reason).
bad := cLoad("/nonexistent/dllm-model.gguf", "{}")
Expect(bad).To(BeZero())
Expect(cLastError(0)).ToNot(BeEmpty())
// Free is safe on a live handle and a NULL one (delete nullptr).
cFree(h)
cFree(0)
})
})
// Ungated specs for the pure-Go helpers (no libdllm.so required).
var _ = Describe("buildOptsJSON", func() {
It("renders flat scalars as a JSON object", func() {
out, err := buildOptsJSON(map[string]any{
"n_predict": 16,
"seed": int64(7),
"eb_t_min": 0.5,
"kv_cache": "auto",
})
Expect(err).ToNot(HaveOccurred())
Expect(out).To(MatchJSON(`{"n_predict":16,"seed":7,"eb_t_min":0.5,"kv_cache":"auto"}`))
})
It("renders an empty object for no options", func() {
out, err := buildOptsJSON(nil)
Expect(err).ToNot(HaveOccurred())
Expect(out).To(Equal("{}"))
})
It("rejects nested objects (the C-side scanner only reads flat scalars)", func() {
_, err := buildOptsJSON(map[string]any{"sampler": map[string]any{"seed": 1}})
Expect(err).To(HaveOccurred())
})
It("rejects arrays", func() {
_, err := buildOptsJSON(map[string]any{"stop": []string{"a"}})
Expect(err).To(HaveOccurred())
})
It("rejects booleans (the C-side scanner only understands numbers and strings)", func() {
_, err := buildOptsJSON(map[string]any{"flag": true})
Expect(err).To(HaveOccurred())
})
})
var _ = Describe("splitValidUTF8", func() {
It("holds back a trailing incomplete sequence and completes it next block", func() {
emit, carry := splitValidUTF8("", "caf\xe2")
Expect(emit).To(Equal("caf"))
Expect(carry).To(Equal("\xe2"))
emit, carry = splitValidUTF8(carry, "\x82")
Expect(emit).To(BeEmpty())
Expect(carry).To(Equal("\xe2\x82"))
emit, carry = splitValidUTF8(carry, "\xac!")
Expect(emit).To(Equal("€!"))
Expect(carry).To(BeEmpty())
})
It("holds back up to 3 bytes of a 4-byte sequence", func() {
emit, carry := splitValidUTF8("", "x\xf0\x9f\x98") // 😀 missing its last byte
Expect(emit).To(Equal("x"))
Expect(carry).To(Equal("\xf0\x9f\x98"))
emit, carry = splitValidUTF8(carry, "\x80")
Expect(emit).To(Equal("😀"))
Expect(carry).To(BeEmpty())
})
It("replaces undecodable bytes immediately instead of carrying them", func() {
// A mid-string invalid byte can never complete: carrying it would let
// the carry grow unboundedly, so it is substituted on the spot.
emit, carry := splitValidUTF8("", "a\xe2bc")
Expect(emit).To(Equal("a<>bc"))
Expect(carry).To(BeEmpty())
// Orphan continuation bytes at the end have no leading byte to wait
// for either.
emit, carry = splitValidUTF8("", "a\x82")
Expect(emit).To(Equal("a<>"))
Expect(carry).To(BeEmpty())
})
It("passes pure ASCII and complete UTF-8 through untouched", func() {
emit, carry := splitValidUTF8("", "héllo €")
Expect(emit).To(Equal("héllo €"))
Expect(carry).To(BeEmpty())
})
})
var _ = Describe("goStringFromCPtr", func() {
It("copies a NUL-terminated buffer", func() {
buf := []byte("dllm\x00")
s := goStringFromCPtr(uintptr(unsafe.Pointer(&buf[0])))
// The uintptr round-trip hides buf from the GC's liveness analysis;
// keep it reachable until after the copy.
runtime.KeepAlive(buf)
Expect(s).To(Equal("dllm"))
})
It("returns the empty string for NULL", func() {
Expect(goStringFromCPtr(0)).To(Equal(""))
})
})
// ---------------------------------------------------------------------------
// Backend wiring (T4): fake-generator specs, no libdllm.so required.
// ---------------------------------------------------------------------------
type fakeGenCall struct {
prompt string
optsJSON string
}
// fakeGen implements generator in-process. It records every call (prompt +
// opts JSON), tracks concurrent in-flight calls to prove worker
// serialization, and replays canned output (out for generate/tokenize,
// blocks for generateStream).
type fakeGen struct {
mu sync.Mutex
calls []fakeGenCall
inFlight int
maxInFlight int
out string
blocks []string
err error
delay time.Duration
}
func (f *fakeGen) begin(prompt, optsJSON string) {
f.mu.Lock()
defer f.mu.Unlock()
f.calls = append(f.calls, fakeGenCall{prompt: prompt, optsJSON: optsJSON})
f.inFlight++
if f.inFlight > f.maxInFlight {
f.maxInFlight = f.inFlight
}
}
func (f *fakeGen) end() {
f.mu.Lock()
defer f.mu.Unlock()
f.inFlight--
}
func (f *fakeGen) snapshot() (calls []fakeGenCall, maxInFlight int) {
f.mu.Lock()
defer f.mu.Unlock()
return append([]fakeGenCall(nil), f.calls...), f.maxInFlight
}
func (f *fakeGen) generate(prompt, optsJSON string) (string, error) {
f.begin(prompt, optsJSON)
defer f.end()
if f.delay > 0 {
time.Sleep(f.delay)
}
return f.out, f.err
}
func (f *fakeGen) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
f.begin(prompt, optsJSON)
defer f.end()
if f.err != nil {
return f.err
}
for _, b := range f.blocks {
onBlock(b)
}
return nil
}
func (f *fakeGen) tokenizeJSON(text string) (string, error) {
f.begin(text, "")
defer f.end()
return f.out, f.err
}
func (f *fakeGen) cancel() {}
func (f *fakeGen) free() {}
// newTestDllm assembles a backend around a fake generator (bypassing Load,
// which needs libdllm.so) and registers cleanup of the worker goroutine.
func newTestDllm(g generator, genOpts map[string]any) *Dllm {
d := &Dllm{gen: g, genOpts: genOpts}
d.startWorker()
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
return d
}
// drainReplies empties ch without blocking, failing the spec if the channel
// was closed (PredictStreamRich must NOT close it - interface.go contract).
// Size ch above the expected reply count: an overflow deadlocks the spec on
// the producer's send instead of failing it.
func drainReplies(ch chan *pb.Reply) []*pb.Reply {
var out []*pb.Reply
for {
select {
case r, ok := <-ch:
if !ok {
Fail("PredictStreamRich closed the results channel (the gRPC server owns the close)")
}
expectValidUTF8Reply(r)
out = append(out, r)
default:
return out
}
}
}
// expectValidUTF8Reply is the blanket guard for the proto3 marshaling
// contract: grpc-go rejects any string field carrying invalid UTF-8, so every
// reply field that ends up in a proto string must validate.
func expectValidUTF8Reply(r *pb.Reply) {
GinkgoHelper()
Expect(utf8.ValidString(string(r.GetMessage()))).To(BeTrue(), "Reply.Message carries invalid UTF-8")
for _, delta := range r.GetChatDeltas() {
Expect(utf8.ValidString(delta.GetContent())).To(BeTrue(), "ChatDelta.Content carries invalid UTF-8")
Expect(utf8.ValidString(delta.GetReasoningContent())).To(BeTrue(), "ChatDelta.ReasoningContent carries invalid UTF-8")
for _, tc := range delta.GetToolCalls() {
Expect(utf8.ValidString(tc.GetName())).To(BeTrue(), "ToolCallDelta.Name carries invalid UTF-8")
Expect(utf8.ValidString(tc.GetArguments())).To(BeTrue(), "ToolCallDelta.Arguments carries invalid UTF-8")
}
}
}
var _ = Describe("Dllm backend wiring", func() {
Describe("PredictRich", func() {
It("renders gemma4 from raw messages and parses the output when use_tokenizer_template is set", func() {
fake := &fakeGen{out: "<|channel>thought\npondering<channel|>The answer.<turn|>"}
d := newTestDllm(fake, nil)
reply, err := d.PredictRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}},
Metadata: map[string]string{"enable_thinking": "true"},
})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
Expect(calls).To(HaveLen(1))
// The enable_thinking=true render from the transformers fixture.
Expect(calls[0].prompt).To(Equal(
"<|turn>system\n<|think|>\n<turn|>\n<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n"))
Expect(string(reply.GetMessage())).To(Equal("The answer."))
Expect(reply.GetChatDeltas()).To(HaveLen(2))
Expect(reply.GetChatDeltas()[0].GetReasoningContent()).To(Equal("pondering"))
Expect(reply.GetChatDeltas()[1].GetContent()).To(Equal("The answer."))
})
It("defaults enable_thinking OFF (the gemma4 template treats thinking as opt-in)", func() {
fake := &fakeGen{out: "hi"}
d := newTestDllm(fake, nil)
_, err := d.PredictRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}},
})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
// No-thinking render: the template pre-opens AND pre-closes an
// empty thought channel in the generation prompt.
Expect(calls[0].prompt).To(Equal(
"<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"))
})
It("passes the raw prompt verbatim and skips gemma4 parsing without use_tokenizer_template", func() {
// Marker-looking text must survive untouched: in raw-prompt mode
// the caller templates themselves and the Go-side extraction
// applies, so the backend must not interpret the output.
fake := &fakeGen{out: "<|channel>thought\nnot parsed<channel|>tail"}
d := newTestDllm(fake, nil)
reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "my raw prompt"})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
Expect(calls[0].prompt).To(Equal("my raw prompt"))
Expect(string(reply.GetMessage())).To(Equal(fake.out))
Expect(reply.GetChatDeltas()).To(HaveLen(1))
Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal(fake.out))
})
It("sanitizes invalid UTF-8 in the non-streaming output", func() {
// Byte-fallback tokens can decode to lone malformed bytes; the
// whole-output sanitize must replace them so proto3 marshaling of
// Message/ChatDeltas cannot fail.
fake := &fakeGen{out: "a\xe2b"}
d := newTestDllm(fake, nil)
reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).ToNot(HaveOccurred())
expectValidUTF8Reply(reply)
Expect(string(reply.GetMessage())).To(Equal("a<>b"))
Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal("a<>b"))
})
It("maps Tokens and Seed into the opts JSON on top of the model-level overrides", func() {
fake := &fakeGen{out: "x"}
d := newTestDllm(fake, map[string]any{"eb_t_min": 0.5, "kv_cache": "auto"})
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p", Tokens: 32, Seed: 7})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
Expect(calls[0].optsJSON).To(MatchJSON(`{"n_predict":32,"seed":7,"eb_t_min":0.5,"kv_cache":"auto"}`))
})
It("omits n_predict and seed when unset so the engine defaults apply", func() {
fake := &fakeGen{out: "x"}
d := newTestDllm(fake, nil)
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
Expect(calls[0].optsJSON).To(MatchJSON(`{}`))
})
It("surfaces generator errors", func() {
fake := &fakeGen{err: errors.New("boom")}
d := newTestDllm(fake, nil)
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).To(MatchError("boom"))
})
It("errors before generating when no model is loaded", func() {
d := &Dllm{} // no Load, no worker: must fail fast, not hang
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).To(HaveOccurred())
})
It("makes a concurrent Free wait for the in-flight request (both finish cleanly)", func() {
// server.go's Free has no locking of its own: the backend's genMu
// must hold Free back until the racing generate drains, instead of
// closing the jobs channel (panic) or freeing the C ctx under it.
fake := &fakeGen{out: "x", delay: 50 * time.Millisecond}
d := newTestDllm(fake, nil)
predictDone := make(chan error, 1)
go func() {
defer GinkgoRecover()
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
predictDone <- err
}()
// Wait until the fake generate is actually in flight (the read
// lock is held from before submit until PredictRich returns).
Eventually(func() int {
_, maxInFlight := fake.snapshot()
return maxInFlight
}).Should(Equal(1))
Expect(d.Free()).To(Succeed())
// Free's write lock means the request finished before Free did.
var predictErr error
Eventually(predictDone).Should(Receive(&predictErr))
Expect(predictErr).ToNot(HaveOccurred())
})
It("returns model-not-loaded for requests after Free", func() {
fake := &fakeGen{out: "x"}
d := newTestDllm(fake, nil)
Expect(d.Free()).To(Succeed())
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
})
It("serializes concurrent requests through the worker goroutine", func() {
// dllm_capi.h: one ctx = one concurrent generate. Two overlapping
// PredictRich calls must execute the C calls one at a time.
fake := &fakeGen{out: "x", delay: 30 * time.Millisecond}
d := newTestDllm(fake, nil)
var wg sync.WaitGroup
for range 2 {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).ToNot(HaveOccurred())
}()
}
wg.Wait()
calls, maxInFlight := fake.snapshot()
Expect(calls).To(HaveLen(2))
Expect(maxInFlight).To(Equal(1), "generate calls overlapped despite the worker queue")
})
})
Describe("PredictStreamRich", func() {
It("emits one reply per delta-producing block and leaves the channel open", func() {
// Blocks split mid-marker and mid-payload: the parser's holdback
// must keep marker fragments out of the emitted deltas.
fake := &fakeGen{blocks: []string{
"<|channel>thou", // partial channel open: no deltas yet
"ght\nponder", // header completes, reasoning starts
"ing<channel|>Hi ", // reasoning ends, content starts
"there<turn|>discarded", // turn ends: trailing text dropped
}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
}, ch)
Expect(err).ToNot(HaveOccurred())
replies := drainReplies(ch)
Expect(replies).To(HaveLen(3), "block 1 completes no delta and must not produce a reply")
var content, reasoning string
for _, r := range replies {
for _, delta := range r.GetChatDeltas() {
content += delta.GetContent()
reasoning += delta.GetReasoningContent()
}
}
Expect(reasoning).To(Equal("pondering"))
Expect(content).To(Equal("Hi there"))
// Message mirrors each reply's content so legacy consumers see
// exactly the displayed tokens.
Expect(string(replies[1].GetMessage())).To(Equal("Hi "))
Expect(string(replies[2].GetMessage())).To(Equal("there"))
})
It("streams raw blocks verbatim without use_tokenizer_template", func() {
fake := &fakeGen{blocks: []string{"abc", "", "<|channel>def"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
Expect(err).ToNot(HaveOccurred())
replies := drainReplies(ch)
Expect(replies).To(HaveLen(2), "empty blocks produce no reply")
Expect(string(replies[0].GetMessage())).To(Equal("abc"))
Expect(string(replies[1].GetMessage())).To(Equal("<|channel>def"))
Expect(replies[1].GetChatDeltas()).To(HaveLen(1))
})
It("flushes parser holdback after the stream ends", func() {
// The unterminated partial marker "<chan" is held back during the
// stream and must come out as content on the final flush.
fake := &fakeGen{blocks: []string{"tail<chan"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
}, ch)
Expect(err).ToNot(HaveOccurred())
var content string
for _, r := range drainReplies(ch) {
content += string(r.GetMessage())
}
Expect(content).To(Equal("tail<chan"))
})
It("reassembles a multi-byte character split across block boundaries", func() {
// Per-block detokenize can split "€" (E2 82 AC) as E2 | 82 AC.
// Emitting the lone E2 would make grpc-go fail the marshal of the
// whole reply; the trailing incomplete sequence must be held back
// and completed by the next block.
fake := &fakeGen{blocks: []string{"caf\xe2", "\x82\xac ok"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
Expect(err).ToNot(HaveOccurred())
var content string
for _, r := range drainReplies(ch) { // drain asserts ValidString per reply
content += string(r.GetMessage())
}
Expect(content).To(Equal("caf€ ok"))
})
It("reassembles a split multi-byte character in parsed (gemma4) mode too", func() {
fake := &fakeGen{blocks: []string{"caf\xe2", "\x82\xac<turn|>"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
}, ch)
Expect(err).ToNot(HaveOccurred())
var content string
for _, r := range drainReplies(ch) {
for _, delta := range r.GetChatDeltas() {
content += delta.GetContent()
}
}
Expect(content).To(Equal("caf€"))
})
It("replaces an incomplete sequence left at stream end with U+FFFD", func() {
// A byte-fallback token can leave a lone leading byte (0xE2) that
// no later block completes: the final flush must substitute it,
// never emit it raw and never drop into a marshal error.
fake := &fakeGen{blocks: []string{"ok\xe2"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
Expect(err).ToNot(HaveOccurred())
var content string
for _, r := range drainReplies(ch) {
content += string(r.GetMessage())
}
Expect(content).To(Equal("ok<6F>"))
})
It("surfaces generator errors without sending replies", func() {
fake := &fakeGen{err: errors.New("stream boom")}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch)
Expect(err).To(MatchError("stream boom"))
Expect(drainReplies(ch)).To(BeEmpty())
})
It("errors before generating when no model is loaded", func() {
d := &Dllm{} // no Load, no worker: must fail fast, not hang
ch := make(chan *pb.Reply, 1)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch)
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
Expect(drainReplies(ch)).To(BeEmpty())
})
})
Describe("legacy Predict/PredictStream adapters", func() {
It("Predict returns the aggregated content string", func() {
fake := &fakeGen{out: "plain text"}
d := newTestDllm(fake, nil)
out, err := d.Predict(&pb.PredictOptions{Prompt: "p"})
Expect(err).ToNot(HaveOccurred())
Expect(out).To(Equal("plain text"))
})
It("PredictStream forwards content strings and closes the channel (legacy ownership)", func() {
fake := &fakeGen{blocks: []string{"a", "b"}}
d := newTestDllm(fake, nil)
ch := make(chan string, 16)
Expect(d.PredictStream(&pb.PredictOptions{Prompt: "p"}, ch)).To(Succeed())
var got []string
for s := range ch { // terminates only if the impl closed ch
got = append(got, s)
}
Expect(got).To(Equal([]string{"a", "b"}))
})
})
Describe("TokenizeString", func() {
It("decodes the C-side JSON id array", func() {
fake := &fakeGen{out: "[2,18]"}
d := newTestDllm(fake, nil)
resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
Expect(err).ToNot(HaveOccurred())
Expect(resp.Length).To(Equal(int32(2)))
Expect(resp.Tokens).To(Equal([]int32{2, 18}))
calls, _ := fake.snapshot()
Expect(calls[0].prompt).To(Equal("hello"))
})
It("fails loud on a malformed id array", func() {
fake := &fakeGen{out: "not json"}
d := newTestDllm(fake, nil)
_, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
Expect(err).To(HaveOccurred())
})
It("errors before tokenizing when no model is loaded", func() {
d := &Dllm{} // no Load, no worker: must fail fast, not hang
_, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
})
})
Describe("parseModelGenOpts", func() {
It("parses eb_*/blocks/kv_cache entries and types values by first successful parse", func() {
got := parseModelGenOpts([]string{
"eb_max_steps:16",
"eb_t_min:0.25",
"kv_cache:auto",
"blocks:4",
"unrelated_key:1", // other layers' options: skipped
"malformed", // no colon: skipped
})
Expect(got).To(Equal(map[string]any{
"eb_max_steps": int64(16),
"eb_t_min": 0.25,
"kv_cache": "auto",
"blocks": int64(4),
}))
})
It("round-trips through buildOptsJSON (only flat scalars are produced)", func() {
got := parseModelGenOpts([]string{"eb_entropy_bound:0.8", "kv_cache:off"})
out, err := buildOptsJSON(got)
Expect(err).ToNot(HaveOccurred())
Expect(out).To(MatchJSON(`{"eb_entropy_bound":0.8,"kv_cache":"off"}`))
})
})
})
// ---------------------------------------------------------------------------
// Gated backend round-trip against the real libdllm.so + tiny GGUF fixture.
// ---------------------------------------------------------------------------
var _ = Describe("Dllm backend (real tiny model)", func() {
BeforeEach(func() {
if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" {
Skip("set DLLM_TEST_LIBRARY and DLLM_TEST_TINY_MODEL to run the backend round-trip")
}
ensureLibLoaded()
Expect(libLoadErr).ToNot(HaveOccurred())
})
It("round-trips Load, PredictRich, PredictStreamRich and TokenizeString", func() {
d := &Dllm{}
Expect(d.Load(&pb.ModelOptions{ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL")})).To(Succeed())
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
// TokenizeString: tiny fixture vocab tokenizes "hello" to [2,18].
resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
Expect(err).ToNot(HaveOccurred())
Expect(resp.Tokens).To(Equal([]int32{2, 18}))
Expect(resp.Length).To(Equal(int32(2)))
req := &pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "hello"}},
Tokens: 16,
Seed: 7,
}
// Non-streaming: the tiny random-weight model emits arbitrary vocab
// words; with no gemma4 markers in them everything is content.
reply, err := d.PredictRich(req)
Expect(err).ToNot(HaveOccurred())
Expect(string(reply.GetMessage())).ToNot(BeEmpty())
Expect(reply.GetChatDeltas()).ToNot(BeEmpty())
// Streaming: at least one reply, and the channel-ownership rule is
// honored (drainReplies fails the spec on a closed channel).
ch := make(chan *pb.Reply, 64)
Expect(d.PredictStreamRich(req, ch)).To(Succeed())
replies := drainReplies(ch)
Expect(replies).ToNot(BeEmpty())
var streamed string
for _, r := range replies {
streamed += string(r.GetMessage())
}
Expect(streamed).ToNot(BeEmpty())
})
It("aborts an in-flight generation promptly on Cancel", func() {
d := &Dllm{}
// eb_max_steps inflates the per-block denoise loop so the full run
// takes ~10s on the tiny fixture (vs ~40ms at engine defaults; 16
// blocks, first block after ~0.7s) - long enough that a prompt
// post-cancel return is distinguishable from the generation simply
// finishing.
Expect(d.Load(&pb.ModelOptions{
ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL"),
Options: []string{"eb_max_steps:256"},
})).To(Succeed())
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
ch := make(chan *pb.Reply, 64)
errCh := make(chan error, 1)
go func() {
defer GinkgoRecover()
errCh <- d.PredictStreamRich(&pb.PredictOptions{Prompt: "hello", Tokens: 256, Seed: 7}, ch)
}()
// Cancel only once the first block proves the generate is in
// flight: the C side resets the cancel flag on generate entry, so
// an earlier Cancel would be swallowed (dllm_capi.h race note).
Eventually(ch, "60s").Should(Receive())
cancelAt := time.Now()
d.Cancel()
// Uncancelled, ~10s of generation remain; the cancelled call must
// come back in milliseconds (the flag is checked per denoise step).
var genErr error
Eventually(errCh, "5s").Should(Receive(&genErr))
latency := time.Since(cancelAt)
Expect(genErr).To(MatchError(ContainSubstring("cancelled")))
GinkgoWriter.Printf("dllm cancel: PredictStreamRich returned %v after Cancel\n", latency)
})
})

View File

@@ -1,562 +0,0 @@
// Gemma4 (DiffusionGemma) streaming output parser: raw model text, fed in
// arbitrary fragments (per committed diffusion block; a fragment can split
// anywhere, including mid-marker and mid-payload), is turned into
// pb.ChatDelta events (content / reasoning_content / tool_calls).
//
// Normative sources:
// - The chat template embedded at the top of gemma4_renderer.go ("tpl L<n>"
// citations below refer to its numbered lines). The OUTPUT format mirrors
// what the template renders for assistant history: thought channels
// (<|channel>thought\n ... <channel|>, tpl L240), tool calls
// (<|tool_call>call:name{...}<tool_call|>, tpl L246-L257) and turn ends
// (<turn|>, tpl L351).
// - vLLM PR #45163: vllm/tool_parsers/gemma4_tool_parser.py (marker
// handling, the call:name{...} argument grammar and its decoder, ported
// below) and vllm/reasoning/gemma4_reasoning_parser.py (channel markers,
// the "thought\n" role label, is_reasoning_end semantics).
//
// Initial state (derived from the generation prompt, tpl L356-L362, see
// RenderGemma4):
// - enable_thinking=false: the prompt ends with "<|turn>model\n" +
// "<|channel>thought\n<channel|>" - an EMPTY thought channel, pre-opened
// AND pre-closed by the template. The model's output therefore starts in
// plain content. Use NewGemma4Parser(false).
// - enable_thinking=true: the prompt ends at "<|turn>model\n" and the model
// opens and closes its own thought channel in the OUTPUT
// ("<|channel>thought\n...reasoning...<channel|>final answer", per the
// vLLM Gemma4ReasoningParser docstring). The parser still starts in
// content state - the channel markers in the output drive the switch.
// Use NewGemma4Parser(false) here too.
// - NewGemma4Parser(true) is for callers that pre-open the thought channel
// in the prompt themselves (appending "<|channel>thought\n" after the
// generation prompt to force thinking): the output then begins mid-thought
// and everything is reasoning until the first <channel|>.
//
// State diagram (markers are consumed, never emitted):
//
// <|channel> \n (channel name dropped: the
// [content] --------------> [chan-header] ----> [thought] "thought\n" role
// ^ | <channel|> (stray close: swallowed, label, stripped
// +-+ strip_thinking semantics, tpl L148-L158) like vLLM does)
// ^ <channel|>
// +----------------------------------------- [thought]
// ^ <tool_call|> | <|tool_call> (implicit
// +-------------- [tool-call] <-------------------+ reasoning end, vLLM
// | <|tool_call> ^ is_reasoning_end)
// +-------------------+
// [content]/[thought] --- <turn|> ---> [done] (everything after is dropped)
//
// Buffering rules:
// - content/thought states hold back at most len(longest marker)-1 bytes:
// the longest tail that is still a proper prefix of a watched marker.
// Content is otherwise emitted immediately (no unbounded buffering).
// - the tool-call state buffers the whole payload until <tool_call|>. This
// is unbounded in principle but bounded in practice by the model's
// diffusion canvas, and is required because the call:name{...} payload
// only becomes decodable (and trustworthy) once complete - the same
// reason vLLM's parser accumulates before parsing.
// - Close() flushes whatever is still held: partial markers come out as
// content/reasoning (per the state that held them); an unterminated
// channel header or tool-call payload is re-emitted RAW (including its
// opening marker) as content - malformed output is never silently
// dropped (mirrors vLLM extract_tool_calls returning the raw text as
// content when its regex does not match).
//
// Streaming granularity DIVERGENCE from vLLM: vLLM re-parses the partial
// payload on every token and streams argument-JSON diffs (its `partial=True`
// decoder mode plus withholding logic exist only for that). Our fragments are
// whole committed diffusion blocks, so each completed tool call is emitted
// once, as a single ToolCallDelta carrying index + id + name + the full
// arguments JSON - exactly the shape backend/python/vllm/backend.py emits
// per call and pkg/functions.ToolCallsFromChatDeltas re-accumulates.
package main
import (
"encoding/json"
"regexp"
"strconv"
"strings"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
// gemma4CallRE is vLLM's tool_call_regex
// (`<\|tool_call>call:([\w\-\.]+)\{(.*?)\}<tool_call\|>`, DOTALL) anchored to
// a single already-extracted payload: name charset [\w\-.], braces mandatory.
var gemma4CallRE = regexp.MustCompile(`(?s)^call:([\w\-.]+)\{(.*)\}$`)
type g4State int
const (
g4Content g4State = iota
g4ChanHeader
g4Thought
g4ToolCall
g4Done
)
// Markers watched per emitting state. A stray <tool_call|> outside a tool
// call is deliberately NOT watched: it passes through verbatim, consistent
// with the malformed-payload fallback re-emitting it as content.
var (
gemma4ContentMarkers = []string{gemma4ChannelOpen, gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd}
gemma4ThoughtMarkers = []string{gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd}
)
type Gemma4Parser struct {
state g4State
// held is the per-state carry-over between Feed calls: a partial marker
// (content/thought), a partial channel header (chan-header) or the
// payload accumulated so far (tool-call).
held string
toolIdx int
}
// NewGemma4Parser returns a parser positioned per the initial-state rules in
// the header comment: startInThought=true only when the caller pre-opened a
// thought channel in the prompt.
func NewGemma4Parser(startInThought bool) *Gemma4Parser {
state := g4Content
if startInThought {
state = g4Thought
}
return &Gemma4Parser{state: state}
}
// Feed consumes the next output fragment and returns the deltas it completes.
func (p *Gemma4Parser) Feed(text string) []*pb.ChatDelta {
if text == "" || p.state == g4Done {
return nil
}
pending := p.held + text
p.held = ""
var em g4Emitter
for pending != "" {
switch p.state {
case g4Content, g4Thought:
markers := gemma4ContentMarkers
if p.state == g4Thought {
markers = gemma4ThoughtMarkers
}
idx, marker := findEarliestGemma4Marker(pending, markers)
if idx == -1 {
hold := gemma4MarkerHoldback(pending, markers)
p.emitText(&em, pending[:len(pending)-hold])
p.held = pending[len(pending)-hold:]
pending = ""
continue
}
p.emitText(&em, pending[:idx])
pending = pending[idx+len(marker):]
switch marker {
case gemma4ChannelOpen:
p.state = g4ChanHeader
case gemma4ChannelClose:
// In thought: channel ends. In content: stray close,
// swallowed (strip_thinking keeps both sides, tpl L148-L158).
p.state = g4Content
case gemma4ToolCallOpen:
p.state = g4ToolCall
case gemma4TurnEnd:
p.state = g4Done
}
case g4ChanHeader:
// The channel header is "<name>\n"; the template only ever writes
// "thought" (tpl L240/L360) and the label is structural, so it is
// dropped, not emitted (vLLM strips the same "thought\n" prefix).
nl := strings.IndexByte(pending, '\n')
if nl == -1 {
p.held = pending
pending = ""
continue
}
pending = pending[nl+1:]
p.state = g4Thought
case g4ToolCall:
end := strings.Index(pending, gemma4ToolCallClose)
if end == -1 {
p.held = pending
pending = ""
continue
}
p.emitToolCall(&em, pending[:end])
pending = pending[end+len(gemma4ToolCallClose):]
p.state = g4Content
case g4Done:
pending = ""
}
}
return em.deltas
}
// Close flushes held-back partials. Incomplete structures (open channel
// header, unterminated tool payload) are re-emitted raw as content rather
// than dropped. The parser is finished afterwards.
func (p *Gemma4Parser) Close() []*pb.ChatDelta {
var em g4Emitter
switch p.state {
case g4Content:
em.content(p.held)
case g4Thought:
em.reasoning(p.held)
case g4ChanHeader:
em.content(gemma4ChannelOpen + p.held)
case g4ToolCall:
em.content(gemma4ToolCallOpen + p.held)
case g4Done:
}
p.held = ""
p.state = g4Done
return em.deltas
}
func (p *Gemma4Parser) emitText(em *g4Emitter, s string) {
if p.state == g4Thought {
em.reasoning(s)
return
}
em.content(s)
}
// emitToolCall decodes one complete <|tool_call>...<tool_call|> payload. On a
// payload that does not match call:name{...} the raw text (markers included)
// is emitted as content, mirroring vLLM's extract_tool_calls fallback.
func (p *Gemma4Parser) emitToolCall(em *g4Emitter, payload string) {
m := gemma4CallRE.FindStringSubmatch(payload)
if m == nil {
em.content(gemma4ToolCallOpen + payload + gemma4ToolCallClose)
return
}
// Index-based ids: deterministic (the split-invariance property relies
// on it) and matching the call_<n> convention of pkg/grpc/rich_test.go;
// core only needs ids to be non-empty and unique within the response.
em.tool(p.toolIdx, "call_"+strconv.Itoa(p.toolIdx), m[1], decodeGemma4Args(m[2], 0))
p.toolIdx++
}
// g4Emitter collects ChatDeltas; empty text events are dropped.
type g4Emitter struct {
deltas []*pb.ChatDelta
}
func (e *g4Emitter) content(s string) {
if s != "" {
e.deltas = append(e.deltas, &pb.ChatDelta{Content: s})
}
}
func (e *g4Emitter) reasoning(s string) {
if s != "" {
e.deltas = append(e.deltas, &pb.ChatDelta{ReasoningContent: s})
}
}
func (e *g4Emitter) tool(index int, id, name, argsJSON string) {
e.deltas = append(e.deltas, &pb.ChatDelta{ToolCalls: []*pb.ToolCallDelta{{
Index: int32(index),
Id: id,
Name: name,
Arguments: argsJSON,
}}})
}
// findEarliestGemma4Marker returns the position and value of the first
// complete marker occurrence, or (-1, "").
func findEarliestGemma4Marker(s string, markers []string) (int, string) {
best, bestMarker := -1, ""
for _, m := range markers {
if idx := strings.Index(s, m); idx >= 0 && (best == -1 || idx < best) {
best, bestMarker = idx, m
}
}
return best, bestMarker
}
// gemma4MarkerHoldback returns the length of the longest suffix of s that is
// a proper prefix of a watched marker - the only bytes that may still grow
// into a marker and therefore must not be emitted yet (bounded by the
// longest marker, so content is never buffered unboundedly).
func gemma4MarkerHoldback(s string, markers []string) int {
maxHold := 0
for _, m := range markers {
if len(m)-1 > maxHold {
maxHold = len(m) - 1
}
}
if len(s) < maxHold {
maxHold = len(s)
}
for k := maxHold; k >= 1; k-- {
tail := s[len(s)-k:]
for _, m := range markers {
if strings.HasPrefix(m, tail) {
return k
}
}
}
return 0
}
// ---------------------------------------------------------------------------
// call:name{...} argument decoder
//
// Port of vLLM's _parse_gemma4_args / _parse_gemma4_array /
// _parse_gemma4_value (gemma4_tool_parser.py) in non-partial mode only: this
// parser decodes exclusively COMPLETE payloads (incomplete ones fall back to
// raw content at Close), so vLLM's partial-withholding machinery
// (trailing-dot floats, withheld bare tails) is intentionally not ported.
//
// Grammar (inverse of the renderer's formatGemma4Argument, tpl L118-L147):
//
// args := pair (',' pair)*
// pair := key ':' value (keys unquoted, up to the first ':')
// value := string | object | array | bare
// string := '<|"|>' ... '<|"|>' (no escapes; unterminated -> rest)
// object := '{' args '}' (delimited strings skipped when
// array := '[' value,* ']' counting braces/brackets)
// bare := true | false | null/none/nil | number | bare-string
//
// Output is a JSON object/array string with keys in payload order (Python
// dict insertion order), built with HTML escaping off so payload text
// survives byte-for-byte.
// ---------------------------------------------------------------------------
func isGemma4Space(c byte) bool { return c == ' ' || c == '\n' || c == '\t' }
// gemma4MaxArgsDepth caps the mutual recursion between decodeGemma4Args and
// decodeGemma4Array. Defense against model-generated deep nesting: a Go stack
// overflow is a fatal process kill, not a recoverable error, so past the cap
// a nested body gracefully degrades to a JSON string of its raw text.
const gemma4MaxArgsDepth = 100
// decodeGemma4Args decodes one args body (the text between the outer braces
// of call:name{...}) into a JSON object string. depth is the current nesting
// level (0 at the payload root); see gemma4MaxArgsDepth.
func decodeGemma4Args(s string, depth int) string {
if depth > gemma4MaxArgsDepth {
return gemma4JSONString(s)
}
var b strings.Builder
b.WriteString("{")
first := true
pair := func(key, val string) {
if !first {
b.WriteString(",")
}
first = false
b.WriteString(gemma4JSONString(key))
b.WriteString(":")
b.WriteString(val)
}
i, n := 0, len(s)
for i < n {
for i < n && (isGemma4Space(s[i]) || s[i] == ',') {
i++
}
if i >= n {
break
}
keyStart := i
for i < n && s[i] != ':' {
i++
}
if i >= n {
break // no ':' -> trailing junk, dropped (vLLM does the same)
}
key := strings.TrimSpace(s[keyStart:i])
i++ // skip ':'
for i < n && isGemma4Space(s[i]) {
i++
}
if i >= n {
pair(key, `""`) // "key:" with nothing after -> empty string
break
}
switch {
case strings.HasPrefix(s[i:], gemma4StringDelim):
i += len(gemma4StringDelim)
if end := strings.Index(s[i:], gemma4StringDelim); end == -1 {
pair(key, gemma4JSONString(s[i:])) // unterminated -> take rest
i = n
} else {
pair(key, gemma4JSONString(s[i:i+end]))
i += end + len(gemma4StringDelim)
}
case s[i] == '{':
inner, next := scanGemma4Balanced(s, i, '{', '}')
pair(key, decodeGemma4Args(inner, depth+1))
i = next
case s[i] == '[':
inner, next := scanGemma4Balanced(s, i, '[', ']')
pair(key, decodeGemma4Array(inner, depth+1))
i = next
default:
valStart := i
for i < n && s[i] != ',' && s[i] != '}' && s[i] != ']' {
i++
}
if i == valStart {
// No progress (value starts on a stray '}'/']'): abort on
// malformed input rather than loop, like vLLM.
i = n
continue
}
pair(key, decodeGemma4Bare(s[valStart:i]))
}
}
b.WriteString("}")
return b.String()
}
// decodeGemma4Array decodes one array body (the text between '[' and ']')
// into a JSON array string. depth is the current nesting level; see
// gemma4MaxArgsDepth.
func decodeGemma4Array(s string, depth int) string {
if depth > gemma4MaxArgsDepth {
return gemma4JSONString(s)
}
var b strings.Builder
b.WriteString("[")
first := true
item := func(val string) {
if !first {
b.WriteString(",")
}
first = false
b.WriteString(val)
}
i, n := 0, len(s)
for i < n {
for i < n && (isGemma4Space(s[i]) || s[i] == ',') {
i++
}
if i >= n {
break
}
switch {
case strings.HasPrefix(s[i:], gemma4StringDelim):
i += len(gemma4StringDelim)
if end := strings.Index(s[i:], gemma4StringDelim); end == -1 {
item(gemma4JSONString(s[i:]))
i = n
} else {
item(gemma4JSONString(s[i : i+end]))
i += end + len(gemma4StringDelim)
}
case s[i] == '{':
inner, next := scanGemma4Balanced(s, i, '{', '}')
item(decodeGemma4Args(inner, depth+1))
i = next
case s[i] == '[':
inner, next := scanGemma4Balanced(s, i, '[', ']')
item(decodeGemma4Array(inner, depth+1))
i = next
default:
valStart := i
for i < n && s[i] != ',' && s[i] != ']' {
i++
}
if i == valStart {
i = n // no progress: abort on malformed input, like vLLM
continue
}
item(decodeGemma4Bare(s[valStart:i]))
}
}
b.WriteString("]")
return b.String()
}
// scanGemma4Balanced scans a brace/bracket-balanced span starting at the
// opener s[start], skipping over <|"|>-delimited strings so structural
// characters inside them do not count (vLLM's depth scan). Returns the inner
// text and the index just past the closer; an unterminated span yields the
// rest of the string (the inner decoder still extracts what is there - this
// path is only reachable from genuinely malformed complete payloads).
func scanGemma4Balanced(s string, start int, open, close byte) (string, int) {
depth := 1
i := start + 1
innerStart := i
n := len(s)
for i < n && depth > 0 {
if strings.HasPrefix(s[i:], gemma4StringDelim) {
i += len(gemma4StringDelim)
if nd := strings.Index(s[i:], gemma4StringDelim); nd == -1 {
i = n
} else {
i += nd + len(gemma4StringDelim)
}
continue
}
switch s[i] {
case open:
depth++
case close:
depth--
}
i++
}
if depth > 0 {
return s[innerStart:], n
}
return s[innerStart : i-1], i
}
// decodeGemma4Bare maps an undelimited value to its JSON form: booleans,
// null aliases (null/none/nil, case-insensitive - the renderer writes
// Python None as "None", tpl L144-L145 via format_argument's else branch),
// numbers (vLLM's rule: a '.' tries float, otherwise int; anything that
// fails parses as a bare string).
func decodeGemma4Bare(raw string) string {
v := strings.TrimSpace(raw)
if v == "" {
return `""`
}
if v == "true" || v == "false" {
return v
}
switch strings.ToLower(v) {
case "null", "none", "nil":
return "null"
}
if strings.Contains(v, ".") {
if f, err := strconv.ParseFloat(v, 64); err == nil {
return formatGemma4Float(f)
}
} else if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
return strconv.FormatInt(iv, 10)
}
return gemma4JSONString(v)
}
// formatGemma4Float renders like Python's json.dumps(float): integral floats
// keep a ".0" suffix ("108." decodes to 108.0, not 108), so the arguments
// JSON matches what vLLM would have produced for the same payload.
func formatGemma4Float(f float64) string {
s := strconv.FormatFloat(f, 'g', -1, 64)
if !strings.ContainsAny(s, ".eE") {
s += ".0"
}
return s
}
// gemma4JSONString encodes a JSON string WITHOUT HTML escaping (json.Marshal
// would escape the angle brackets in "<div>" to \u003c / \u003e sequences;
// payload text should survive
// byte-for-byte, like Python's json.dumps(ensure_ascii=False)).
func gemma4JSONString(s string) string {
var sb strings.Builder
enc := json.NewEncoder(&sb)
enc.SetEscapeHTML(false)
if err := enc.Encode(s); err != nil {
// Unreachable for plain strings; fall back to default escaping
// rather than emitting invalid JSON.
b, mErr := json.Marshal(s)
if mErr != nil {
return `""`
}
return string(b)
}
// Encode appends a trailing newline.
return strings.TrimSuffix(sb.String(), "\n")
}

View File

@@ -1,592 +0,0 @@
package main
// Parser specs for Gemma4Parser (model output text -> pb.ChatDelta events).
//
// Fixture provenance:
// - Entries marked "vLLM: <name>" are direct ports of the named test from
// vLLM PR #45163, tests/tool_parsers/test_gemma4_tool_parser.py (the
// authoritative test-suite for the gemma4 tool-call wire format). The
// streaming tests' chunk lists are reused verbatim as Feed fragments.
// - Decoder entries port the TestParseGemma4Args / TestParseGemma4Array
// classes from the same file (non-partial mode only; this parser never
// decodes partial payloads, see the divergence note in gemma4_parser.go).
// - Channel/turn-marker expectations come from the chat template embedded
// in gemma4_renderer.go (tpl L356-L362 generation prompt, L148-L158
// strip_thinking) and vLLM's Gemma4ReasoningParser
// (vllm/reasoning/gemma4_reasoning_parser.py).
import (
"encoding/json"
"fmt"
"strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
// flatGemma4Tool is one accumulated tool call, mirroring how LocalAI core
// folds ToolCallDelta streams (pkg/functions/chat_deltas.go
// ToolCallsFromChatDeltas: name/id latch on first non-empty, arguments
// concatenate per index). Tests flatten through the same rules so they
// assert exactly what core will reconstruct.
type flatGemma4Tool struct {
id string
name string
args string
}
func flattenGemma4Deltas(deltas []*pb.ChatDelta) (string, string, []flatGemma4Tool) {
var content, reasoning strings.Builder
byIndex := map[int32]*flatGemma4Tool{}
maxIdx := int32(-1)
for _, d := range deltas {
content.WriteString(d.GetContent())
reasoning.WriteString(d.GetReasoningContent())
for _, tc := range d.GetToolCalls() {
acc, ok := byIndex[tc.GetIndex()]
if !ok {
acc = &flatGemma4Tool{}
byIndex[tc.GetIndex()] = acc
}
if tc.GetName() != "" {
acc.name = tc.GetName()
}
if tc.GetId() != "" {
acc.id = tc.GetId()
}
acc.args += tc.GetArguments()
if tc.GetIndex() > maxIdx {
maxIdx = tc.GetIndex()
}
}
}
var tools []flatGemma4Tool
for i := int32(0); i <= maxIdx; i++ {
if acc, ok := byIndex[i]; ok {
tools = append(tools, *acc)
}
}
return content.String(), reasoning.String(), tools
}
type wantGemma4Tool struct {
name string
argsJSON string // compared with MatchJSON (key order irrelevant)
}
type parseGemma4Case struct {
startInThought bool
fragments []string
wantContent string
wantReasoning string
wantTools []wantGemma4Tool
}
func parseGemma4Fragments(startInThought bool, fragments []string) []*pb.ChatDelta {
p := NewGemma4Parser(startInThought)
var all []*pb.ChatDelta
for _, f := range fragments {
all = append(all, p.Feed(f)...)
}
return append(all, p.Close()...)
}
var _ = Describe("Gemma4Parser", func() {
DescribeTable("parses streamed gemma4 output into ChatDeltas",
func(c parseGemma4Case) {
content, reasoning, tools := flattenGemma4Deltas(parseGemma4Fragments(c.startInThought, c.fragments))
Expect(content).To(Equal(c.wantContent))
Expect(reasoning).To(Equal(c.wantReasoning))
Expect(tools).To(HaveLen(len(c.wantTools)))
seenIDs := map[string]bool{}
for i, want := range c.wantTools {
Expect(tools[i].name).To(Equal(want.name), "tool %d name", i)
Expect(tools[i].args).To(MatchJSON(want.argsJSON), "tool %d arguments", i)
Expect(tools[i].id).ToNot(BeEmpty(), "tool %d id", i)
Expect(seenIDs).ToNot(HaveKey(tools[i].id), "tool %d id must be unique", i)
seenIDs[tools[i].id] = true
}
},
// --- (1) pure content -------------------------------------------------
// vLLM: test_no_tool_calls
Entry("pure content, single fragment", parseGemma4Case{
fragments: []string{"Hello, how can I help you today?"},
wantContent: "Hello, how can I help you today?",
}),
// --- (2) thought -> final transition ----------------------------------
// enable_thinking render: prompt ends at <|turn>model\n and the model
// opens/closes its own thought channel in the OUTPUT (vLLM
// Gemma4ReasoningParser docstring; tpl L356-L362). The "thought\n"
// role label after <|channel> is structural and must be stripped
// (vLLM _THOUGHT_PREFIX handling).
Entry("thought channel then final content", parseGemma4Case{
fragments: []string{"<|channel>thought\nLet me think about this.\n<channel|>The answer is 42."},
wantReasoning: "Let me think about this.\n",
wantContent: "The answer is 42.",
}),
// --- (3) startInThought both ways -------------------------------------
Entry("startInThought=true routes initial text to reasoning until <channel|>", parseGemma4Case{
startInThought: true,
fragments: []string{"I am thinking hard.<channel|>Done."},
wantReasoning: "I am thinking hard.",
wantContent: "Done.",
}),
// A stray <channel|> with no open channel is swallowed, matching the
// template's strip_thinking (tpl L148-L158: the marker is dropped,
// text on both sides is kept).
Entry("startInThought=false keeps the same text as content, stray <channel|> swallowed", parseGemma4Case{
startInThought: false,
fragments: []string{"I am thinking hard.<channel|>Done."},
wantContent: "I am thinking hard.Done.",
}),
// --- (4) one tool call, full payload type zoo --------------------------
Entry("single tool call: strings, numbers, bools, null, nested object and array", parseGemma4Case{
fragments: []string{`<|tool_call>call:complex_function{text:<|"|>with, comma and {braces}<|"|>,count:42,score:3.14,yes:true,no:false,nothing:null,obj:{inner:<|"|>v<|"|>,k:1},arr:[<|"|>a<|"|>,2,true]}<tool_call|>`},
wantTools: []wantGemma4Tool{{
name: "complex_function",
argsJSON: `{"text":"with, comma and {braces}","count":42,"score":3.14,"yes":true,"no":false,"nothing":null,"obj":{"inner":"v","k":1},"arr":["a",2,true]}`,
}},
}),
// --- (5) payload split across 3 fragments ------------------------------
Entry("tool-call payload split across three fragments", parseGemma4Case{
fragments: []string{
"<|tool_call>call:get_weather{loc",
`ation:<|"|>Paris, Fra`,
`nce<|"|>}<tool_call|>`,
},
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}},
}),
// --- (6) marker split across fragments ----------------------------------
Entry("tool-call open marker split across fragments", parseGemma4Case{
fragments: []string{
"<|tool_ca",
`ll>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`,
},
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
}),
Entry("channel open marker split across fragments", parseGemma4Case{
fragments: []string{
"<|chan",
"nel>thought\ndeep thought<channel|>final",
},
wantReasoning: "deep thought",
wantContent: "final",
}),
// --- (7) trailing partial marker held, flushed by Close -----------------
Entry("trailing partial marker is held back and flushed by Close", parseGemma4Case{
fragments: []string{"Hello <|tool"},
wantContent: "Hello <|tool",
}),
// --- (8) malformed/incomplete payload -> content fallback ---------------
// vLLM: test_incomplete_tool_call (no end marker: the whole text stays
// content, never silently dropped).
Entry("incomplete tool payload at Close is emitted as raw content", parseGemma4Case{
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London`},
wantContent: `<|tool_call>call:get_weather{location:<|"|>London`,
}),
Entry("malformed complete payload is emitted as raw content, parsing continues", parseGemma4Case{
fragments: []string{"<|tool_call>oops no call syntax<tool_call|> done"},
wantContent: "<|tool_call>oops no call syntax<tool_call|> done",
}),
// --- (9) <turn|> ends the turn -------------------------------------------
Entry("text after <turn|> is ignored, including later fragments", parseGemma4Case{
fragments: []string{
"before<turn|>after",
`more <|tool_call>call:f{}<tool_call|>`,
},
wantContent: "before",
}),
Entry("<turn|> inside a thought channel ends the turn", parseGemma4Case{
startInThought: true,
fragments: []string{"thinking<turn|>ignored"},
wantReasoning: "thinking",
}),
// --- (10) ported vLLM non-streaming cases ---------------------------------
// vLLM: test_single_tool_call
Entry("vLLM: test_single_tool_call", parseGemma4Case{
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
}),
// vLLM: test_multiple_arguments
Entry("vLLM: test_multiple_arguments", parseGemma4Case{
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"San Francisco","unit":"celsius"}`}},
}),
// vLLM: test_text_before_tool_call. DIVERGENCE: vLLM's non-streaming
// extractor trims the content ("...you."); a streaming parser cannot
// retroactively trim already-emitted text, so the trailing space is
// kept (vLLM's own streaming path keeps it too, see
// test_streaming_text_before_tool_call which only checks a prefix).
Entry("vLLM: test_text_before_tool_call (streaming semantics: no trim)", parseGemma4Case{
fragments: []string{`Let me check the weather for you. <|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|>`},
wantContent: "Let me check the weather for you. ",
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}},
}),
// vLLM: test_multiple_tool_calls (also covers case 11: multi-tool sequence)
Entry("vLLM: test_multiple_tool_calls", parseGemma4Case{
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|><|tool_call>call:get_time{location:<|"|>London<|"|>}<tool_call|>`},
wantTools: []wantGemma4Tool{
{name: "get_weather", argsJSON: `{"location":"London"}`},
{name: "get_time", argsJSON: `{"location":"London"}`},
},
}),
// vLLM: test_nested_arguments
Entry("vLLM: test_nested_arguments", parseGemma4Case{
fragments: []string{`<|tool_call>call:complex_function{nested:{inner:<|"|>value<|"|>},list:[<|"|>a<|"|>,<|"|>b<|"|>]}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "complex_function", argsJSON: `{"nested":{"inner":"value"},"list":["a","b"]}`}},
}),
// vLLM: test_tool_call_with_number_and_boolean
Entry("vLLM: test_tool_call_with_number_and_boolean", parseGemma4Case{
fragments: []string{`<|tool_call>call:set_status{is_active:true,count:42,score:3.14}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "set_status", argsJSON: `{"is_active":true,"count":42,"score":3.14}`}},
}),
// vLLM: test_hyphenated_function_name
Entry("vLLM: test_hyphenated_function_name", parseGemma4Case{
fragments: []string{`<|tool_call>call:get-weather{location:<|"|>London<|"|>}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "get-weather", argsJSON: `{"location":"London"}`}},
}),
// vLLM: test_dotted_function_name
Entry("vLLM: test_dotted_function_name", parseGemma4Case{
fragments: []string{`<|tool_call>call:weather.get{location:<|"|>London<|"|>}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "weather.get", argsJSON: `{"location":"London"}`}},
}),
// vLLM: test_no_arguments
Entry("vLLM: test_no_arguments", parseGemma4Case{
fragments: []string{"<|tool_call>call:get_status{}<tool_call|>"},
wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}},
}),
// --- ported vLLM streaming cases (chunk lists reused as fragments) --------
// vLLM: test_basic_streaming_single_tool
Entry("vLLM: test_basic_streaming_single_tool", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:get_weather{",
`location:<|"|>Paris`,
", France",
`<|"|>}`,
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}},
}),
// vLLM: test_streaming_multi_arg
Entry("vLLM: test_streaming_multi_arg", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:get_weather{",
`location:<|"|>Tokyo<|"|>,`,
`unit:<|"|>celsius<|"|>}`,
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Tokyo","unit":"celsius"}`}},
}),
// vLLM: test_streaming_text_before_tool_call
Entry("vLLM: test_streaming_text_before_tool_call", parseGemma4Case{
fragments: []string{
"Let me check ",
"the weather. ",
"<|tool_call>",
"call:get_weather{",
`location:<|"|>London<|"|>}`,
"<tool_call|>",
},
wantContent: "Let me check the weather. ",
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
}),
// vLLM: test_streaming_numeric_args
Entry("vLLM: test_streaming_numeric_args", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:set_config{",
"count:42,",
"active:true}",
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "set_config", argsJSON: `{"count":42,"active":true}`}},
}),
// vLLM: test_streaming_boolean_split_across_chunks
Entry("vLLM: test_streaming_boolean_split_across_chunks", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:search{input:{all:tru",
"e}}",
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "search", argsJSON: `{"input":{"all":true}}`}},
}),
// vLLM: test_streaming_false_split_across_chunks
Entry("vLLM: test_streaming_false_split_across_chunks", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:set{flag:fals",
"e}",
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"flag":false}`}},
}),
// vLLM: test_streaming_number_split_across_chunks
Entry("vLLM: test_streaming_number_split_across_chunks", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:set{count:4",
"2}",
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"count":42}`}},
}),
// vLLM: test_streaming_empty_args
Entry("vLLM: test_streaming_empty_args", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:get_status{}",
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}},
}),
// vLLM: test_streaming_split_delimiter_no_invalid_json (string
// delimiter <|"|> split across fragments must not leak fragments).
Entry("vLLM: test_streaming_split_delimiter_no_invalid_json", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:todowrite{",
`content:<|"|>Buy milk<|`,
`"|>}`,
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "todowrite", argsJSON: `{"content":"Buy milk"}`}},
}),
// vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call
Entry("vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:get_weather{",
`location:<|"|>Paris<|"|>}`,
"<tool_call|><",
"div>",
},
wantContent: "<div>",
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}},
}),
// vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes
Entry("vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:write_file{",
`path:<|"|>index.html<|"|>,`,
`content:<|"|><!DOCTYPE html>` + "\n<",
`html lang="zh-CN">` + "\n<",
"head>\n <",
`meta charset="UTF-8">` + "\n <",
`meta name="viewport" content="width=device-width">` + "\n",
`<|"|>}`,
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{
name: "write_file",
argsJSON: `{"path":"index.html","content":"<!DOCTYPE html>\n<html lang=\"zh-CN\">\n<head>\n <meta charset=\"UTF-8\">\n <meta name=\"viewport\" content=\"width=device-width\">\n"}`,
}},
}),
// vLLM: test_streaming_single_chunk_complete_tool_call
Entry("vLLM: test_streaming_single_chunk_complete_tool_call", parseGemma4Case{
fragments: []string{`<|tool_call>call:name_a_color{color_hex:<|"|>00ff11<|"|>}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "name_a_color", argsJSON: `{"color_hex":"00ff11"}`}},
}),
// vLLM: test_streaming_multi_chunk_batched_tool_calls (two complete
// calls in ONE fragment; both must come out with distinct indices)
Entry("vLLM: test_streaming_multi_chunk_batched_tool_calls", parseGemma4Case{
fragments: []string{
`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>` +
`<|tool_call>call:get_time{timezone:<|"|>GMT<|"|>}<tool_call|>`,
},
wantTools: []wantGemma4Tool{
{name: "get_weather", argsJSON: `{"location":"London"}`},
{name: "get_time", argsJSON: `{"timezone":"GMT"}`},
},
}),
// vLLM: test_streaming_trailing_bare_bool_not_duplicated
Entry("vLLM: test_streaming_trailing_bare_bool_not_duplicated", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:Edit{",
`file_path:<|"|>src/env.py<|"|>,`,
`old_string:<|"|>old_val<|"|>,`,
`new_string:<|"|>new_val<|"|>,`,
"replace_all:",
"false}",
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{
name: "Edit",
argsJSON: `{"file_path":"src/env.py","old_string":"old_val","new_string":"new_val","replace_all":false}`,
}},
}),
// --- implicit reasoning end on <|tool_call> (vLLM is_reasoning_end:
// a tool_call token means reasoning is over) -----------------------------
Entry("tool call inside an open thought channel ends the reasoning", parseGemma4Case{
startInThought: true,
fragments: []string{`need the weather<|tool_call>call:get_weather{location:<|"|>Rome<|"|>}<tool_call|>`},
wantReasoning: "need the weather",
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Rome"}`}},
}),
// --- (12) empty fragments are no-ops --------------------------------------
Entry("empty fragments are no-ops", parseGemma4Case{
fragments: []string{"", "Hello", "", "", " world", ""},
wantContent: "Hello world",
}),
)
It("returns no deltas for an empty fragment and after Close", func() {
p := NewGemma4Parser(false)
Expect(p.Feed("")).To(BeEmpty())
Expect(p.Feed("hi")).ToNot(BeEmpty())
Expect(p.Close()).To(BeEmpty()) // nothing held back
// The parser is finished after Close: further input is dropped.
Expect(p.Feed("more")).To(BeEmpty())
Expect(p.Close()).To(BeEmpty())
})
It("generates index-based tool call ids (call_<index>)", func() {
// Mirrors the index-based id convention of pkg/grpc/rich_test.go and
// keeps ids deterministic for the split-invariance property below.
deltas := parseGemma4Fragments(false, []string{
`<|tool_call>call:a{}<tool_call|><|tool_call>call:b{}<tool_call|>`,
})
_, _, tools := flattenGemma4Deltas(deltas)
Expect(tools).To(HaveLen(2))
Expect(tools[0].id).To(Equal("call_0"))
Expect(tools[1].id).To(Equal("call_1"))
})
// Property: for a fixed full output, EVERY 2-split position must yield
// exactly the same flattened result as the unsplit parse. This kills
// fragment-boundary bugs (mid-marker, mid-delimiter, mid-payload splits).
DescribeTable("2-split fragment invariance",
func(startInThought bool, full string) {
refContent, refReasoning, refTools := flattenGemma4Deltas(
parseGemma4Fragments(startInThought, []string{full}))
for i := 0; i <= len(full); i++ {
content, reasoning, tools := flattenGemma4Deltas(
parseGemma4Fragments(startInThought, []string{full[:i], full[i:]}))
Expect(content).To(Equal(refContent), fmt.Sprintf("content diverged at split %d", i))
Expect(reasoning).To(Equal(refReasoning), fmt.Sprintf("reasoning diverged at split %d", i))
Expect(tools).To(Equal(refTools), fmt.Sprintf("tool calls diverged at split %d", i))
}
},
Entry("thought + content + two tool calls + turn end", false,
"<|channel>thought\nPondering the request...\n<channel|>Sure - calling tools now. "+
`<|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>,days:3,detailed:true}<tool_call|>`+
`<|tool_call>call:get_time{timezone:<|"|>Europe/Lisbon<|"|>,nested:{flag:false,vals:[1,2.5,<|"|>x<|"|>]}}<tool_call|>`+
"Done.<turn|>ignored tail"),
Entry("startInThought + tool call + trailing partial marker", true,
`Deep thought<channel|>final answer <|tool_call>call:noop{}<tool_call|> trailing <|tool`),
Entry("malformed payload fallback", false,
`pre <|tool_call>not a call<tool_call|> post`),
)
})
// Decoder-level ports of vLLM's TestParseGemma4Args / TestParseGemma4Array
// (non-partial mode; the partial-withholding tests do not apply because this
// parser only ever decodes COMPLETE payloads, see gemma4_parser.go).
var _ = Describe("decodeGemma4Args", func() {
DescribeTable("decodes the gemma4 call syntax into JSON arguments",
func(in, wantJSON string) {
Expect(decodeGemma4Args(in, 0)).To(MatchJSON(wantJSON))
},
// vLLM: test_empty_string / test_whitespace_only
Entry("empty string", "", `{}`),
Entry("whitespace only", " ", `{}`),
// vLLM: test_single_string_value
Entry("single string value", `location:<|"|>Paris<|"|>`, `{"location":"Paris"}`),
// vLLM: test_string_value_with_comma
Entry("string value with comma", `location:<|"|>Paris, France<|"|>`, `{"location":"Paris, France"}`),
// vLLM: test_multiple_string_values
Entry("multiple string values", `location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>`, `{"location":"San Francisco","unit":"celsius"}`),
// vLLM: test_integer_value / test_float_value
Entry("integer value", "count:42", `{"count":42}`),
Entry("float value", "score:3.14", `{"score":3.14}`),
// vLLM: test_boolean_true / test_boolean_false
Entry("boolean true", "flag:true", `{"flag":true}`),
Entry("boolean false", "flag:false", `{"flag":false}`),
// vLLM: test_null_value (bare null must become JSON null, not "null")
Entry("null value", "param:null", `{"param":null}`),
// vLLM: test_mixed_types
Entry("mixed types", `name:<|"|>test<|"|>,count:42,active:true,score:3.14`,
`{"name":"test","count":42,"active":true,"score":3.14}`),
// vLLM: test_nested_object
Entry("nested object", `nested:{inner:<|"|>value<|"|>}`, `{"nested":{"inner":"value"}}`),
// vLLM: test_array_of_strings
Entry("array of strings", `items:[<|"|>a<|"|>,<|"|>b<|"|>]`, `{"items":["a","b"]}`),
// vLLM: test_unterminated_string (take everything after the delimiter)
Entry("unterminated string", `key:<|"|>unterminated`, `{"key":"unterminated"}`),
// vLLM: test_empty_value (key with no value after colon)
Entry("empty value", "key:", `{"key":""}`),
// vLLM: test_trailing_dot_float_partial_withheld, non-partial branch
// (trailing-dot floats parse normally outside streaming).
Entry("trailing dot float, complete payload", "left:108.,right:22.8", `{"left":108.0,"right":22.8}`),
)
It("terminates and yields valid JSON on malformed input", func() {
// vLLM: test_malformed_partial_array (the assertion there is only
// "returns a dict without hanging"; ours is "valid JSON object").
out := decodeGemma4Args(":[t:[]", 0)
var v map[string]any
Expect(json.Unmarshal([]byte(out), &v)).To(Succeed())
})
It("degrades nesting beyond the recursion cap to a string value", func() {
// 200 levels of a:{a:{...a:1...}}. Without the depth cap the mutual
// recursion would grow the stack with the model's output; a Go stack
// overflow is a fatal process kill, so levels past gemma4MaxArgsDepth
// must gracefully fall back to the raw inner text as a JSON string.
const depth = 200
body := strings.Repeat("a:{", depth-1) + "a:1" + strings.Repeat("}", depth-1)
out := decodeGemma4Args(body, 0)
var v map[string]any
Expect(json.Unmarshal([]byte(out), &v)).To(Succeed())
levels := 0
var cur any = v
for {
m, ok := cur.(map[string]any)
if !ok {
break
}
Expect(m).To(HaveKey("a"))
cur = m["a"]
levels++
}
Expect(levels).To(Equal(gemma4MaxArgsDepth + 1))
Expect(cur).To(BeAssignableToTypeOf(""))
Expect(cur).To(ContainSubstring("a:{"))
})
})
var _ = Describe("decodeGemma4Array", func() {
DescribeTable("decodes gemma4 array bodies into JSON arrays",
func(in, wantJSON string) {
Expect(decodeGemma4Array(in, 0)).To(MatchJSON(wantJSON))
},
// vLLM: test_string_array / test_empty_array / test_bare_values
Entry("string array", `<|"|>a<|"|>,<|"|>b<|"|>`, `["a","b"]`),
Entry("empty array", "", `[]`),
Entry("bare values", "42,true,3.14", `[42,true,3.14]`),
// vLLM: test_string_element_with_closing_bracket (a ']' inside a
// delimited string must not close the array)
Entry("string element with closing bracket", `[<|"|>a]b<|"|>,<|"|>c<|"|>],<|"|>tail<|"|>`, `[["a]b","c"],"tail"]`),
// vLLM: test_stray_closing_bracket (no-progress abort, keep prefix)
Entry("stray closing bracket", "42,]trailing", `[42]`),
)
})

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,347 +0,0 @@
package main
// Renderer specs for RenderGemma4 against the canonical gemma4 chat template
// (see the normative template comment in gemma4_renderer.go).
//
// Fixture provenance:
// - "single user message" and "enable_thinking" are the EXACT expected
// decodes from transformers tests/models/diffusion_gemma/
// test_modeling_diffusion_gemma.py (test_diffusion_gemma_chat_template
// and ..._with_thinking) with ONE difference: the transformers fixtures
// start with "<bos>" because apply_chat_template tokenizes the rendered
// text with add_bos. Our prompt goes through dllm_capi_generate, whose
// run_generate already tokenizes with prepend_bos = vocab.add_bos
// (dllm.cpp src/capi.cpp:230-231, true for gemma4), so the renderer must
// NOT emit a literal <bos> (it would double) and every expected string
// here drops that leading token.
// - All other expected strings were produced by rendering the verbatim
// GGUF template with jinja2 3.1.2 (bos_token="<bos>") and dropping the
// leading "<bos>" for the same reason.
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
// Two-function tools array used by the tool fixtures (OpenAI wire shape, as
// LocalAI passes it through PredictOptions.Tools).
const testToolsJSON = `[{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a location.","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city name."},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}},{"type":"function","function":{"name":"get_time","description":"Get the current time in a timezone.","parameters":{"type":"object","properties":{"timezone":{"type":"string","description":"IANA timezone name."}},"required":["timezone"]}}}]`
// The <|tool>...<tool|> block the template renders for testToolsJSON inside
// the system turn (jinja2-verified).
const testToolsBlock = `<|tool>declaration:get_weather{description:<|"|>Get the current weather in a location.<|"|>,parameters:{properties:{location:{description:<|"|>The city name.<|"|>,type:<|"|>STRING<|"|>},unit:{enum:[<|"|>celsius<|"|>,<|"|>fahrenheit<|"|>],type:<|"|>STRING<|"|>}},required:[<|"|>location<|"|>],type:<|"|>OBJECT<|"|>}}<tool|><|tool>declaration:get_time{description:<|"|>Get the current time in a timezone.<|"|>,parameters:{properties:{timezone:{description:<|"|>IANA timezone name.<|"|>,type:<|"|>STRING<|"|>}},required:[<|"|>timezone<|"|>],type:<|"|>OBJECT<|"|>}}<tool|>`
// A single tool exercising the deep format_parameters branches: array items
// (string-typed and nested-array), nullable, enum+nullable, nested object
// properties/required, and a response declaration.
const complexToolsJSON = `[{"type":"function","function":{"name":"complex_tool","description":"A complex tool.","parameters":{"type":"object","properties":{"tags":{"type":"array","description":"Tags.","items":{"type":"string"}},"matrix":{"type":"array","items":{"type":"array","items":{"type":"number"}}},"opts":{"type":"object","description":"Options.","properties":{"depth":{"type":"integer","nullable":true}},"required":["depth"]},"mode":{"type":"string","enum":["a","b"],"nullable":true}},"required":["tags","opts"]},"response":{"description":"The result.","type":"object"}}}]`
// jinja2-verified render of complexToolsJSON. Notable template quirks pinned
// here: nested array items go through format_argument with ESCAPED keys and
// an un-uppercased type (<|"|>type<|"|>:<|"|>number<|"|>), while direct item
// types are uppercased; properties dictsort case-insensitively.
const complexToolsBlock = `<|tool>declaration:complex_tool{description:<|"|>A complex tool.<|"|>,parameters:{properties:{matrix:{items:{items:{<|"|>type<|"|>:<|"|>number<|"|>},type:<|"|>ARRAY<|"|>},type:<|"|>ARRAY<|"|>},mode:{enum:[<|"|>a<|"|>,<|"|>b<|"|>],nullable:true,type:<|"|>STRING<|"|>},opts:{description:<|"|>Options.<|"|>,properties:{depth:{nullable:true,type:<|"|>INTEGER<|"|>}},required:[<|"|>depth<|"|>],type:<|"|>OBJECT<|"|>},tags:{description:<|"|>Tags.<|"|>,items:{type:<|"|>STRING<|"|>},type:<|"|>ARRAY<|"|>}},required:[<|"|>tags<|"|>,<|"|>opts<|"|>],type:<|"|>OBJECT<|"|>},response:{description:<|"|>The result.<|"|>,type:<|"|>OBJECT<|"|>}}<tool|>`
type renderGemma4Case struct {
msgs []*pb.Message
toolsJSON string
enableThinking bool
noGenerationPrompt bool // inverted so the zero value is the common case
expected string
}
var _ = Describe("RenderGemma4", func() {
DescribeTable("renders the canonical gemma4 prompt",
func(c renderGemma4Case) {
out, err := RenderGemma4(c.msgs, c.toolsJSON, c.enableThinking, !c.noGenerationPrompt)
Expect(err).ToNot(HaveOccurred())
Expect(out).To(Equal(c.expected))
// The C-ABI generate prepends BOS itself: a literal <bos>
// anywhere in the rendered prompt would double-encode it.
Expect(out).ToNot(ContainSubstring("<bos>"))
},
// transformers fixture (test_diffusion_gemma_chat_template), sans <bos>:
// default thinking pre-opens an EMPTY thought channel in the
// generation prompt.
Entry("single user message, default (no thinking)", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "Write a long essay about Portugal."},
},
expected: "<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// transformers fixture (test_diffusion_gemma_chat_template_with_thinking),
// sans <bos>: a system turn carrying <|think|> and NO auto-opened
// thought channel.
Entry("enable_thinking=true", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "Write a long essay about Portugal."},
},
enableThinking: true,
expected: "<|turn>system\n<|think|>\n<turn|>\n<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n",
}),
Entry("multi-turn user/assistant/user", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "Hello, who are you?"},
{Role: "assistant", Content: "I am Gemma, a helpful assistant."},
{Role: "user", Content: "Tell me a joke."},
},
expected: "<|turn>user\nHello, who are you?<turn|>\n<|turn>model\nI am Gemma, a helpful assistant.<turn|>\n<|turn>user\nTell me a joke.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// tpl L178-L195: a leading system message is folded into the system
// turn (trimmed) and consumed from the loop.
Entry("system message folds into the system turn", renderGemma4Case{
msgs: []*pb.Message{
{Role: "system", Content: "You are a pirate."},
{Role: "user", Content: "Hello!"},
},
expected: "<|turn>system\nYou are a pirate.<turn|>\n<|turn>user\nHello!<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// tpl L182-L185: <|think|> goes at the very top of the SAME system
// turn, before the system prompt text.
Entry("system message with enable_thinking shares the turn", renderGemma4Case{
msgs: []*pb.Message{
{Role: "system", Content: "You are a pirate."},
{Role: "user", Content: "Hello!"},
},
enableThinking: true,
expected: "<|turn>system\n<|think|>\nYou are a pirate.<turn|>\n<|turn>user\nHello!<turn|>\n<|turn>model\n",
}),
// tpl L196-L203: tool declarations render in the system turn, one
// <|tool>declaration:...<tool|> block per tool, no separators.
Entry("tools array (two functions)", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "What is the weather in Tokyo?"},
},
toolsJSON: testToolsJSON,
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// format_parameters deep branches (tpl L1-L85) + response declaration
// (tpl L106-L116).
Entry("complex tool schema (array items, nullable, nested object, response)", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
},
toolsJSON: complexToolsJSON,
expected: "<|turn>system\n" + complexToolsBlock + "<turn|>\n<|turn>user\ngo<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// tpl L243-L313: assistant tool_calls render as
// <|tool_call>call:name{args}<tool_call|>; the following role=tool
// message renders inline as <|tool_response>response:name{value:..}
// <tool_response|>; the model turn stays OPEN (no <turn|>, no new
// generation prompt) so the model continues after the response.
Entry("assistant tool_calls + role=tool result", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "What is the weather in Tokyo?"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`},
{Role: "tool", ToolCallId: "call_1", Content: "Sunny, 22 degrees celsius."},
},
toolsJSON: testToolsJSON,
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny, 22 degrees celsius.<|"|>}<tool_response|>`,
}),
// tpl L348-L349: a tool_calls turn with no rendered responses ends
// on an OPEN <|tool_response> marker for the runtime to fill, and
// add_generation_prompt adds nothing (tpl L357).
Entry("assistant tool_calls without a result leaves <|tool_response> open", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "What is the weather in Tokyo?"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`},
},
toolsJSON: testToolsJSON,
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>`,
}),
// tpl L237-L241: reasoning_content renders as a thought channel only
// on a tool-calling turn after the last user message.
Entry("reasoning_content with tool_calls renders the thought channel", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "weather?"},
{Role: "assistant", Content: "", ReasoningContent: "I should call the tool", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`},
{Role: "tool", ToolCallId: "c1", Content: "Sunny"},
},
expected: "<|turn>user\nweather?<turn|>\n<|turn>model\n<|channel>thought\nI should call the tool\n<channel|>" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}<tool_response|>`,
}),
// tpl L220-L235: the assistant answer following its own tool round
// continues the SAME model turn (no second <|turn>model).
Entry("tool round then final assistant answer then user", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "weather?"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`},
{Role: "tool", ToolCallId: "c1", Content: "Sunny"},
{Role: "assistant", Content: "It is sunny."},
{Role: "user", Content: "thanks"},
},
expected: "<|turn>user\nweather?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}<tool_response|>` + "It is sunny.<turn|>\n<|turn>user\nthanks<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// format_argument (tpl L118-L147): numbers keep their JSON literal,
// booleans lower-case, nested maps have unquoted dictsorted keys,
// arrays bracketed; top-level args are dictsorted case-insensitively.
Entry("tool_call argument types (number/bool/nested/array)", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"count\":42,\"ratio\":3.5,\"flag\":true,\"off\":false,\"nested\":{\"x\":\"y\",\"n\":7},\"list\":[\"a\",1,true]}"}}]`},
},
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n" + `<|tool_call>call:f{count:42,flag:true,list:[<|"|>a<|"|>,1,true],nested:{n:7,x:<|"|>y<|"|>},off:false,ratio:3.5}<tool_call|><|tool_response>`,
}),
// jinja dictsort is case-insensitive: alpha sorts before Beta.
Entry("tool_call argument dictsort is case-insensitive", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"Beta\":1,\"alpha\":2}"}}]`},
},
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{alpha:2,Beta:1}<tool_call|><|tool_response>",
}),
// jinja renders Python None as "None" (round-trips through vLLM's
// parser, which lowers "none" back to null).
Entry("tool_call null argument renders as None", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"maybe\":null}"}}]`},
},
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{maybe:None}<tool_call|><|tool_response>",
}),
Entry("tool_call empty arguments render empty braces", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`},
},
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{}<tool_call|><|tool_response>",
}),
// tpl L253-L254: a non-object arguments string renders verbatim.
Entry("tool_call non-object string arguments render verbatim", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"just text"}}]`},
},
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{just text}<tool_call|><|tool_response>",
}),
// tpl L278-L285: unmatched tool_call_id falls back to the tool
// message's own name.
Entry("tool result name falls back when tool_call_id does not match", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`},
{Role: "tool", ToolCallId: "OTHER", Name: "named_tool", Content: "out"},
},
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n" + `<|tool_call>call:f{}<tool_call|><|tool_response>response:named_tool{value:<|"|>out<|"|>}<tool_response|>`,
}),
// strip_thinking (tpl L148-L158): historical assistant content loses
// its <|channel>...<channel|> spans.
Entry("assistant content thinking channels are stripped", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "hi"},
{Role: "assistant", Content: "<|channel>thought\nsecret\n<channel|>visible answer"},
{Role: "user", Content: "more"},
},
expected: "<|turn>user\nhi<turn|>\n<|turn>model\nvisible answer<turn|>\n<|turn>user\nmore<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// tpl L220-L235: consecutive assistant messages suppress the second
// <|turn>model (continuation), but each still closes with <turn|>.
Entry("consecutive assistant messages continue the model turn", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "hi"},
{Role: "assistant", Content: "part one"},
{Role: "assistant", Content: "part two"},
{Role: "user", Content: "ok"},
},
expected: "<|turn>user\nhi<turn|>\n<|turn>model\npart one<turn|>\npart two<turn|>\n<|turn>user\nok<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
Entry("add_generation_prompt=false renders no model turn", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "hi"},
},
noGenerationPrompt: true,
expected: "<|turn>user\nhi<turn|>\n",
}),
)
Describe("error handling", func() {
It("fails loud on an unknown role", func() {
_, err := RenderGemma4([]*pb.Message{
{Role: "narrator", Content: "Meanwhile..."},
}, "", false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(`unknown role "narrator"`))
})
It("fails on invalid tools JSON", func() {
_, err := RenderGemma4([]*pb.Message{
{Role: "user", Content: "hi"},
}, "{not json", false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tools JSON"))
})
It("fails on invalid tool_calls JSON", func() {
_, err := RenderGemma4([]*pb.Message{
{Role: "user", Content: "hi"},
{Role: "assistant", Content: "", ToolCalls: "{not json"},
}, "", false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tool_calls JSON"))
})
It("fails on an orphan tool message, naming its index", func() {
// A role:tool message with no preceding assistant tool_calls turn
// would be silently dropped by the jinja; we fail loud instead.
_, err := RenderGemma4([]*pb.Message{
{Role: "user", Content: "hi"},
{Role: "tool", Content: `{"temp": 20}`, ToolCallId: "call_1"},
}, "", false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("orphan tool message 1"))
})
It("fails on trailing garbage after the tools JSON array", func() {
_, err := RenderGemma4([]*pb.Message{
{Role: "user", Content: "hi"},
}, "[] junk", false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tools JSON"))
})
It("fails when the tools JSON is not an array", func() {
_, err := RenderGemma4([]*pb.Message{
{Role: "user", Content: "hi"},
}, `{"type":"function"}`, false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tools JSON is not an array"))
})
It("fails when a tools array element is not an object", func() {
_, err := RenderGemma4([]*pb.Message{
{Role: "user", Content: "hi"},
}, `[42]`, false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tools[0] is not an object"))
})
It("rejects a nil message via the unknown-role check", func() {
// Pins current behavior: pb getters are nil-safe, so a nil message
// reads as role "" and trips the fail-loud unknown-role guard.
_, err := RenderGemma4([]*pb.Message{nil}, "", false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(`unknown role "" in message 0`))
})
})
})

View File

@@ -1,85 +0,0 @@
package main
// Started internally by LocalAI - one gRPC server per loaded model.
//
// Loads libdllm.so via purego and registers the 9-symbol flat C-ABI
// declared in dllm.cpp's include/dllm_capi.h (ABI v1). The library name can
// be overridden with DLLM_LIBRARY (mirrors the PARAKEET_LIBRARY /
// WHISPER_LIBRARY convention in the sibling backends); the default looks
// for the .so next to this binary (run.sh puts the package dir on
// LD_LIBRARY_PATH).
import (
"flag"
"fmt"
"os"
"github.com/ebitengine/purego"
grpc "github.com/mudler/LocalAI/pkg/grpc"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
type LibFuncs struct {
FuncPtr any
Name string
}
// loadCAPI dlopens libName and binds the 9 dllm_capi_* entry points 1:1 to
// dllm_capi.h, so an `nm libdllm.so | grep dllm_capi` is enough to spot
// drift. Shared with the test suite (ensureLibLoaded), which drives the
// bridge without the gRPC server.
//
// The C-ABI returns malloc'd char* buffers from tokenize_json/generate; we
// register those as uintptr so we get the raw pointer back and can call
// dllm_capi_free_string on it (purego's string return would copy and forget
// the original pointer, leaking it on every call). last_error returns a
// BORROWED pointer instead, so it is registered as a plain string: purego
// copies it and nothing must be freed.
func loadCAPI(libName string) error {
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
return fmt.Errorf("dllm: dlopen %q: %w", libName, err)
}
libFuncs := []LibFuncs{
{&cppAbiVersion, "dllm_capi_abi_version"},
{&cppLoad, "dllm_capi_load"},
{&cppFree, "dllm_capi_free"},
{&cppLastError, "dllm_capi_last_error"},
{&cppFreeString, "dllm_capi_free_string"},
{&cppTokenizeJSON, "dllm_capi_tokenize_json"},
{&cppGenerate, "dllm_capi_generate"},
{&cppGenerateStream, "dllm_capi_generate_stream"},
{&cppCancel, "dllm_capi_cancel"},
}
for _, lf := range libFuncs {
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
}
return nil
}
func main() {
libName := os.Getenv("DLLM_LIBRARY")
if libName == "" {
libName = "libdllm.so"
}
if err := loadCAPI(libName); err != nil {
panic(err)
}
// Hard-fail on an ABI mismatch: the flat-pointer bindings above would
// otherwise misbehave silently against a future libdllm.so.
if v := cAbiVersion(); v != dllmABIVersion {
panic(fmt.Errorf("dllm: libdllm.so ABI=%d, this backend speaks ABI=%d", v, dllmABIVersion))
}
fmt.Fprintf(os.Stderr, "[dllm] ABI=%d\n", cAbiVersion())
flag.Parse()
if err := grpc.StartServer(*addr, &Dllm{}); err != nil {
panic(err)
}
}

View File

@@ -1,24 +0,0 @@
#!/bin/bash
#
# T1 packaging stub: copy the binary, run.sh and libdllm.so into package/.
# The full ldd walk (libc, libstdc++, libgomp, GPU runtimes, arch
# detection) lands with the registration task, mirroring
# backend/go/whisper/package.sh.
set -e
CURDIR=$(dirname "$(realpath "$0")")
mkdir -p "$CURDIR/package/lib"
cp -avf "$CURDIR/dllm-grpc" "$CURDIR/package/"
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
# libdllm.so + any soname symlinks, should upstream ever add them.
cp -avf "$CURDIR"/libdllm.so* "$CURDIR/package/lib/" 2>/dev/null || {
echo "ERROR: libdllm.so not found in $CURDIR, run 'make' first" >&2
exit 1
}
echo "T1 package layout (full ldd walk lands with registration):"
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"

View File

@@ -1,16 +0,0 @@
#!/bin/bash
set -e
CURDIR=$(dirname "$(realpath "$0")")
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
# If a self-contained ld.so was packaged, route through it so the
# packaged libc / libstdc++ are used instead of the host's (matches the
# whisper / parakeet-cpp backends' runtime layout).
if [ -f "$CURDIR/lib/ld.so" ]; then
echo "Using lib/ld.so"
exec "$CURDIR/lib/ld.so" "$CURDIR/dllm-grpc" "$@"
fi
exec "$CURDIR/dllm-grpc" "$@"

View File

@@ -1,6 +1,6 @@
# parakeet-cpp backend Makefile.
#
# Upstream pin lives below as PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
# Upstream pin lives below as PARAKEET_VERSION?=8a7c48209d7882a7ce79a6b306270e4703194543
# (.github/bump_deps.sh) can find and update it - matches the
# whisper.cpp / ds4 / vibevoice-cpp convention.
#
@@ -15,7 +15,7 @@
# That's what the L0 smoke test uses. The default target below does the
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
PARAKEET_VERSION?=8a7c48209d7882a7ce79a6b306270e4703194543
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
GOCMD?=go

View File

@@ -7,12 +7,8 @@ import "time"
type batchRequest struct {
pcm []float32
decoder int32
// language is the per-request target locale ("" means the model default).
// parakeet.cpp's batched C-API takes ONE target_lang for the whole batch,
// so the dispatcher only coalesces requests that share a language.
language string
tag string
reply chan batchReply
tag string
reply chan batchReply
}
// batchReply carries one per-item JSON object string (an element of the C-API's
@@ -47,25 +43,13 @@ func newBatcher(maxSize int, maxWait time.Duration, runBatch func([]*batchReques
// run is the dispatcher loop: accumulate submitted requests until either maxSize
// is reached or maxWait elapses since the first queued request, then dispatch.
// Exits when stop is closed (draining any partially-filled batch first).
//
// A batch carries ONE language (parakeet.cpp's batched C-API takes a single
// target_lang), so a request whose language differs from the batch leader is
// not coalesced: it is held in carry and becomes the leader of the next batch.
// carry is therefore never dropped and its caller never deadlocks: every batch
// (including a lone carry on stop) is dispatched, and runBatch replies to all.
func (b *batcher) run(stop <-chan struct{}) {
var carry *batchRequest
for {
var first *batchRequest
if carry != nil {
// A mismatched request from the previous fill leads this batch.
first, carry = carry, nil
} else {
select {
case first = <-b.submit:
case <-stop:
return
}
select {
case first = <-b.submit:
case <-stop:
return
}
batch := []*batchRequest{first}
@@ -80,22 +64,12 @@ func (b *batcher) run(stop <-chan struct{}) {
for len(batch) < b.maxSize {
select {
case r := <-b.submit:
if r.language != first.language {
// Different language: carry it to the next batch so this
// batch stays single-language, then dispatch what we have.
carry = r
break fill
}
batch = append(batch, r)
case <-timer.C:
break fill
case <-stop:
timer.Stop()
b.runBatch(batch)
// Don't strand a carried request's caller on shutdown.
if carry != nil {
b.runBatch([]*batchRequest{carry})
}
return
}
}

View File

@@ -105,60 +105,4 @@ var _ = Describe("batcher", func() {
go func() { <-rep }()
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
})
It("never coalesces requests with different languages into one batch", func() {
// parakeet.cpp's batched C-API takes ONE target_lang per batch, so the
// dispatcher must keep every dispatched batch single-language. Submit a
// mix of languages and assert (a) no batch ever carries more than one
// distinct language and (b) every submitted request still gets a reply
// (the mismatched carry-over is never dropped).
var mu sync.Mutex
var langsPerBatch [][]string
run := func(reqs []*batchRequest) {
seen := map[string]struct{}{}
var distinct []string
for _, r := range reqs {
if _, ok := seen[r.language]; !ok {
seen[r.language] = struct{}{}
distinct = append(distinct, r.language)
}
}
mu.Lock()
langsPerBatch = append(langsPerBatch, distinct)
mu.Unlock()
echoReply(reqs)
}
// Large window + size so the fill loop stays open across submits and the
// language constraint (not the timer) is what splits the batches.
b := newBatcher(16, 200*time.Millisecond, run)
stop := make(chan struct{})
go b.run(stop)
defer close(stop)
langs := []string{"en", "en", "de", "de", "en", "fr", "fr"}
const N = 7
var wg sync.WaitGroup
got := make([]string, N)
for i := 0; i < N; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
rep := make(chan batchReply, 1)
b.submit <- &batchRequest{tag: string(rune('a' + i)), language: langs[i], reply: rep}
got[i] = (<-rep).json
}(i)
}
wg.Wait()
mu.Lock()
defer mu.Unlock()
// Invariant: every dispatched batch is single-language.
for _, distinct := range langsPerBatch {
Expect(len(distinct)).To(Equal(1), "a batch coalesced more than one language: %v", distinct)
}
// Liveness: every request got a reply (carry-over never stranded).
for i := 0; i < N; i++ {
Expect(got[i]).To(Equal(string(rune('a' + i))))
}
})
})

View File

@@ -15,7 +15,6 @@ import (
"github.com/go-audio/wav"
"github.com/mudler/LocalAI/pkg/grpc/base"
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/xlog"
@@ -48,13 +47,6 @@ var (
// side reads them as const float*/const int*.
CppTranscribePcmBatchJSON func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32) uintptr
// CppTranscribePcmBatchJSONLang is the multilingual variant of the batched
// JSON entry point: identical, plus a trailing target_lang. "" (the model
// default, "auto") is passed for non-prompt models, which ignore it; an
// unknown locale on a prompt model returns 0 and sets last_error. Present
// only in newer libparakeet.so; nil falls back to CppTranscribePcmBatchJSON.
CppTranscribePcmBatchJSONLang func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32, targetLang string) uintptr
// Cache-aware streaming (RNN-T) entry points. stream_begin returns 0 for
// non-streaming models. feed/finalize return a malloc'd char* (uintptr,
// freed via CppFreeString); feed writes 1 to *eouOut on an <EOU>/<EOB>.
@@ -62,18 +54,6 @@ var (
CppStreamFeed func(s uintptr, pcm []float32, nSamples int32, eouOut unsafe.Pointer) uintptr
CppStreamFinalize func(s uintptr) uintptr
CppStreamFree func(s uintptr)
// CppStreamBeginLang is the multilingual variant of stream_begin: identical,
// plus a trailing target_lang ("" means the model default). Present only in
// newer libparakeet.so; nil falls back to CppStreamBegin.
CppStreamBeginLang func(ctx uintptr, targetLang string) uintptr
// Streaming JSON variants (ABI v4): feed/finalize returning a malloc'd char*
// JSON document {text,eou,frame_sec,words} (uintptr, freed via CppFreeString)
// so streaming segments can carry per-word timestamps. Present only in newer
// libparakeet.so; nil falls back to the text-only CppStreamFeed/Finalize path.
CppStreamFeedJSON func(s uintptr, pcm []float32, nSamples int32) uintptr
CppStreamFinalizeJSON func(s uintptr) uintptr
)
// streamChunkSamples is how much 16 kHz mono PCM we hand to stream_feed per
@@ -91,26 +71,9 @@ const streamChunkSamples = 16000
//
// "start"/"end"/"t" are seconds; "conf" is confidence in (0,1].
type transcriptJSON struct {
Text string `json:"text"`
FrameSec float64 `json:"frame_sec"`
Words []transcriptWord `json:"words"`
Tokens []transcriptToken `json:"tokens"`
}
// streamFeedJSON mirrors the document returned by
// parakeet_capi_stream_feed_json / parakeet_capi_stream_finalize_json (ABI v4):
//
// {"text":"...","eou":0,"frame_sec":0.080000,
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...]}
//
// "text" is the newly-finalized text since the last call; "eou" is 1 when an
// <EOU>/<EOB> fired this feed; "words" are the words finalized this call with
// absolute (stream-relative) start/end seconds.
type streamFeedJSON struct {
Text string `json:"text"`
Eou int `json:"eou"`
FrameSec float64 `json:"frame_sec"`
Words []transcriptWord `json:"words"`
Text string `json:"text"`
Words []transcriptWord `json:"words"`
Tokens []transcriptToken `json:"tokens"`
}
type transcriptWord struct {
@@ -139,10 +102,6 @@ type ParakeetCpp struct {
engineMu sync.Mutex // sole guard of the one C engine (dispatcher + streaming)
bat *batcher
batStop chan struct{}
// segmentGapFrames is NeMo's segment_gap_threshold in ENCODER FRAMES (model
// YAML option, default 0=off). When >0 it adds NeMo's silence-gap split on
// top of the punctuation split; converted to seconds via the JSON frame_sec.
segmentGapFrames int
}
// Load is the LocalAI gRPC entry point for LoadModel: it calls
@@ -172,11 +131,6 @@ func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
if maxWaitMs < 0 {
maxWaitMs = 0
}
// NeMo's segment_gap_threshold (encoder frames, default 0=off). Off by
// default matches NeMo's default (punctuation-only segments); when set it
// additionally splits segments on inter-word silence (see transcriptResultFromDoc).
p.segmentGapFrames = optInt(opts, "segment_gap_threshold", 0)
if CppTranscribePcmBatchJSON != nil {
p.batStop = make(chan struct{})
p.bat = newBatcher(maxSize, time.Duration(maxWaitMs)*time.Millisecond, p.runBatch)
@@ -232,19 +186,8 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
if len(reqs) > 0 {
dec = reqs[0].decoder
}
// All requests in a batch share one language (the batcher coalesces only
// same-language requests), so any element's language describes the batch.
lang := ""
if len(reqs) > 0 {
lang = reqs[0].language
}
p.engineMu.Lock()
var cstr uintptr
if CppTranscribePcmBatchJSONLang != nil {
cstr = CppTranscribePcmBatchJSONLang(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec, lang)
} else {
cstr = CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
}
cstr := CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
p.engineMu.Unlock()
if cstr == 0 {
err := fmt.Errorf("parakeet-cpp: batch transcribe failed: %s", CppLastError(p.ctxPtr))
@@ -282,31 +225,21 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
// OpenAI API, whose default is segment-level); token ids always populate
// Segment.Tokens.
//
// translate/diarize/prompt/temperature/threads are not applicable to parakeet
// and are ignored; language is honored on the batched + streaming paths (see
// opts.GetLanguage() below); streaming is handled by AudioTranscriptionStream
// translate/diarize/prompt/temperature/language/threads are not applicable to
// parakeet and are ignored; streaming is handled by AudioTranscriptionStream
// (L2).
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
if p.ctxPtr == 0 {
return pb.TranscriptResult{}, grpcerrors.ModelNotLoaded("parakeet-cpp")
return pb.TranscriptResult{}, errors.New("parakeet-cpp: model not loaded")
}
if opts.Dst == "" {
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
}
// Fallback when the batched C-API is unavailable: transcribe from a file
// path (original behavior, no batching). The C library's audio loader only
// understands 16 kHz mono WAV/PCM, so convert the input first - otherwise
// any non-WAV upload (MP3, etc.) fails with "failed to load audio". This
// mirrors what every other audio backend (whisper, crispasr) does via
// utils.AudioToWav before handing the file to the engine.
// Fallback when the batched C-API is unavailable: transcribe directly from
// the file path (original behavior, no batching).
if p.bat == nil {
converted, cleanup, err := convertToWavMono16k(opts.Dst)
if err != nil {
return pb.TranscriptResult{}, err
}
defer cleanup()
cstr := CppTranscribePathJSON(p.ctxPtr, converted, 0)
cstr := CppTranscribePathJSON(p.ctxPtr, opts.Dst, 0)
if cstr == 0 {
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", CppLastError(p.ctxPtr))
}
@@ -316,7 +249,7 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
}
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
return transcriptResultFromDoc(doc, opts), nil
}
// Batched path: decode to PCM, submit to the batcher, wait for this request's
@@ -328,7 +261,7 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
}
rep := make(chan batchReply, 1)
select {
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, language: opts.GetLanguage(), reply: rep}:
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, reply: rep}:
case <-ctx.Done():
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
}
@@ -345,169 +278,34 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
if err := json.Unmarshal([]byte(res.json), &doc); err != nil {
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
}
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
return transcriptResultFromDoc(doc, opts), nil
}
// segmentSeparators is NeMo's default segment_seperators (sentence-ending
// punctuation). Splitting on these matches NeMo's default segment timestamps.
var segmentSeparators = []rune{'.', '?', '!'}
// transcriptResultFromDoc maps a decoded transcriptJSON to a TranscriptResult,
// grouping words into NeMo-faithful segments (see splitWordsIntoSegments). The
// optional gapFrames (NeMo's segment_gap_threshold, in encoder FRAMES; 0=off)
// additionally splits on inter-word silence; it is converted to a seconds gap
// with the document's frame_sec. Per-segment word timings are attached only when
// the caller requested word granularity; token ids populate each segment's
// Tokens by time-window membership. Shared by the batched and direct paths.
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gapFrames int) pb.TranscriptResult {
// synthesising a single whole-clip segment and attaching word timings only when
// the caller requested word granularity. Shared by the batched and direct paths.
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest) pb.TranscriptResult {
text := strings.TrimSpace(doc.Text)
// Frame-unit gap threshold -> seconds (NeMo segment_gap_threshold). 0 = off.
gapSeconds := 0.0
if gapFrames > 0 {
if doc.FrameSec > 0 {
gapSeconds = float64(gapFrames) * doc.FrameSec
} else {
xlog.Warn("parakeet-cpp: segment_gap_threshold set but libparakeet.so " +
"did not report frame_sec; falling back to punctuation-only segments")
}
words := make([]*pb.TranscriptWord, 0, len(doc.Words))
for _, w := range doc.Words {
words = append(words, &pb.TranscriptWord{Start: secondsToNanos(w.Start), End: secondsToNanos(w.End), Text: w.W})
}
groups := splitWordsIntoSegments(doc.Words, segmentSeparators, gapSeconds)
if len(groups) == 0 {
// No words (edge case): single whole-clip text segment.
return pb.TranscriptResult{
Text: text,
Segments: []*pb.TranscriptSegment{{Id: 0, Text: text}},
}
tokens := make([]int32, 0, len(doc.Tokens))
for _, t := range doc.Tokens {
tokens = append(tokens, t.ID)
}
wantWords := wordsRequested(opts.TimestampGranularities)
segments := make([]*pb.TranscriptSegment, 0, len(groups))
for id, group := range groups {
parts := make([]string, len(group))
for i, gw := range group {
parts[i] = gw.W
}
seg := &pb.TranscriptSegment{
Id: int32(id),
Start: secondsToNanos(group[0].Start),
End: secondsToNanos(group[len(group)-1].End),
Text: strings.TrimSpace(strings.Join(parts, " ")),
Tokens: tokensInWindow(doc.Tokens, group[0].Start, group[len(group)-1].End),
}
if wantWords {
ws := make([]*pb.TranscriptWord, len(group))
for i, gw := range group {
ws[i] = &pb.TranscriptWord{Start: secondsToNanos(gw.Start), End: secondsToNanos(gw.End), Text: gw.W}
}
seg.Words = ws
}
segments = append(segments, seg)
var segStart, segEnd int64
if len(words) > 0 {
segStart = words[0].Start
segEnd = words[len(words)-1].End
}
return pb.TranscriptResult{Text: text, Segments: segments}
seg := &pb.TranscriptSegment{Id: 0, Start: segStart, End: segEnd, Text: text, Tokens: tokens}
if wordsRequested(opts.TimestampGranularities) {
seg.Words = words
}
return pb.TranscriptResult{Text: text, Segments: []*pb.TranscriptSegment{seg}}
}
// splitWordsIntoSegments groups words into segments exactly as NeMo's
// get_segment_offsets does (nemo/collections/asr/parts/utils/timestamp_utils.py).
// Walking the words, it closes a segment when (1) the gap rule is enabled
// (gapSeconds > 0) and the segment already has words and the gap from the
// previous word's end to this word's start is >= gapSeconds - the current word
// then STARTS a new segment - or, checked only when the gap rule did not apply
// (NeMo's elif), (2) the word ends with (or is) a separator, which closes the
// segment INCLUDING that word. Trailing words flush into a final segment.
// gapSeconds <= 0 disables the gap rule, matching NeMo's default
// segment_gap_threshold=None (punctuation-only segments).
func splitWordsIntoSegments(words []transcriptWord, separators []rune, gapSeconds float64) [][]transcriptWord {
var segments [][]transcriptWord
var cur []transcriptWord
for i, word := range words {
gapActive := gapSeconds > 0 && len(cur) > 0
if gapActive && (word.Start-words[i-1].End) >= gapSeconds {
segments = append(segments, cur)
cur = []transcriptWord{word}
continue
}
if !gapActive && endsWithSeparator(word.W, separators) {
cur = append(cur, word)
segments = append(segments, cur)
cur = nil
continue
}
cur = append(cur, word)
}
if len(cur) > 0 {
segments = append(segments, cur)
}
return segments
}
// endsWithSeparator reports whether w's last rune is in separators (matching
// NeMo's `word[-1] in delims or word in delims`).
func endsWithSeparator(w string, separators []rune) bool {
r := []rune(strings.TrimSpace(w))
if len(r) == 0 {
return false
}
last := r[len(r)-1]
for _, s := range separators {
if last == s {
return true
}
}
return false
}
// tokensInWindow returns the ids of tokens whose timestamp t falls in
// [start, end] (inclusive), assigning each token to the segment that spans its
// time. The last segment's end is the last word end, so the final token is
// included.
func tokensInWindow(tokens []transcriptToken, start, end float64) []int32 {
var ids []int32
for _, t := range tokens {
if t.T >= start && t.T <= end {
ids = append(ids, t.ID)
}
}
return ids
}
// streamSegmenter accumulates streaming words into per-utterance segments. EOU
// is the model's own utterance boundary; each closed segment takes its start/end
// from its first/last accumulated word.
type streamSegmenter struct {
segs []*pb.TranscriptSegment
cur []transcriptWord
nextID int32
}
func (s *streamSegmenter) add(doc streamFeedJSON) {
s.cur = append(s.cur, doc.Words...)
if doc.Eou != 0 {
s.flush()
}
}
func (s *streamSegmenter) flush() {
if len(s.cur) == 0 {
return
}
parts := make([]string, len(s.cur))
for i, w := range s.cur {
parts[i] = w.W
}
s.segs = append(s.segs, &pb.TranscriptSegment{
Id: s.nextID,
Start: secondsToNanos(s.cur[0].Start),
End: secondsToNanos(s.cur[len(s.cur)-1].End),
Text: strings.TrimSpace(strings.Join(parts, " ")),
})
s.nextID++
s.cur = nil
}
func (s *streamSegmenter) segments() []*pb.TranscriptSegment { return s.segs }
// wordsRequested reports whether the caller asked for word-level timestamps.
// The OpenAI transcription API gates word timings behind
// timestamp_granularities[] containing "word" and defaults to segment-level
@@ -544,7 +342,7 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
defer close(results)
if p.ctxPtr == 0 {
return grpcerrors.ModelNotLoaded("parakeet-cpp")
return errors.New("parakeet-cpp: model not loaded")
}
if opts.Dst == "" {
return errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
@@ -553,12 +351,7 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
return status.Error(codes.Canceled, "transcription cancelled")
}
var stream uintptr
if CppStreamBeginLang != nil {
stream = CppStreamBeginLang(p.ctxPtr, opts.GetLanguage())
} else {
stream = CppStreamBegin(p.ctxPtr)
}
stream := CppStreamBegin(p.ctxPtr)
if stream == 0 {
// Not a cache-aware streaming model: run a normal offline
// transcription and emit it as one delta + a closing final result.
@@ -587,14 +380,6 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
return err
}
// ABI v4: when the streaming JSON entry points are present, drive them so the
// per-utterance segments carry per-word start/end timestamps. Falls through to
// the text-only loop below against an older libparakeet.so. Runs under the
// engineMu already held above.
if CppStreamFeedJSON != nil {
return p.streamJSON(ctx, stream, data, duration, results)
}
var (
full strings.Builder
segText strings.Builder
@@ -671,102 +456,21 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
return nil
}
// streamJSON drives the ABI v4 streaming JSON entry points: each feed/finalize
// returns a {text,eou,frame_sec,words} document. The newly-finalized text is
// emitted as a delta (unchanged streaming contract) while words are accumulated
// into per-utterance segments (closed on EOU) so the closing FinalResult carries
// timestamped segments. Runs under engineMu (already held by the caller).
func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []float32,
duration float32, results chan *pb.TranscriptStreamResponse) error {
var (
full strings.Builder
seg streamSegmenter
)
// consume frees the malloc'd char* (a 0 return is an error), parses the JSON,
// emits the delta, and routes words through the segmenter.
consume := func(ret uintptr) error {
if ret == 0 {
msg := CppLastError(p.ctxPtr)
if msg == "" {
msg = "unknown error"
}
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
}
raw := goStringFromCPtr(ret)
CppFreeString(ret)
var doc streamFeedJSON
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
return fmt.Errorf("parakeet-cpp: decode stream json: %w", err)
}
if doc.Text != "" {
full.WriteString(doc.Text)
results <- &pb.TranscriptStreamResponse{Delta: doc.Text}
}
seg.add(doc)
return nil
}
for off := 0; off < len(data); off += streamChunkSamples {
if err := ctx.Err(); err != nil {
return status.Error(codes.Canceled, "transcription cancelled")
}
end := min(off+streamChunkSamples, len(data))
chunk := data[off:end]
if err := consume(CppStreamFeedJSON(stream, chunk, int32(len(chunk)))); err != nil {
return err
}
}
if err := consume(CppStreamFinalizeJSON(stream)); err != nil {
return err
}
seg.flush() // close any trailing utterance that never saw an EOU
text := strings.TrimSpace(full.String())
segments := seg.segments()
if len(segments) == 0 && text != "" {
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
}
results <- &pb.TranscriptStreamResponse{
FinalResult: &pb.TranscriptResult{
Text: text,
Segments: segments,
Duration: duration,
},
}
return nil
}
// decodeWavMono16k converts any input audio to 16 kHz mono PCM and returns the
// float samples plus the clip duration in seconds. Mirrors the whisper
// backend: utils.AudioToWav (ffmpeg) normalises rate/channels, go-audio
// decodes the PCM.
// convertToWavMono16k converts an arbitrary audio file to a 16 kHz mono WAV in
// a fresh temp dir and returns the path together with a cleanup func the caller
// must defer. WAV inputs already at 16 kHz/mono/16-bit are passed through by
// utils.AudioToWav (hardlink/copy), everything else is transcoded via ffmpeg.
// Used by the direct (non-batched) transcription path, which hands a file path
// to the C library's WAV-only audio loader.
func convertToWavMono16k(path string) (string, func(), error) {
dir, err := os.MkdirTemp("", "parakeet")
if err != nil {
return "", func() {}, err
}
cleanup := func() { _ = os.RemoveAll(dir) }
converted := filepath.Join(dir, "converted.wav")
if err := utils.AudioToWav(path, converted); err != nil {
cleanup()
return "", func() {}, err
}
return converted, cleanup, nil
}
func decodeWavMono16k(path string) ([]float32, float32, error) {
converted, cleanup, err := convertToWavMono16k(path)
dir, err := os.MkdirTemp("", "parakeet")
if err != nil {
return nil, 0, err
}
defer cleanup()
defer func() { _ = os.RemoveAll(dir) }()
converted := filepath.Join(dir, "converted.wav")
if err := utils.AudioToWav(path, converted); err != nil {
return nil, 0, err
}
fh, err := os.Open(converted)
if err != nil {

View File

@@ -3,14 +3,11 @@ package main
import (
"context"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"github.com/ebitengine/purego"
"github.com/go-audio/audio"
"github.com/go-audio/wav"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -53,10 +50,6 @@ func ensureLibLoaded() {
purego.RegisterLibFunc(&CppStreamFeed, lib, "parakeet_capi_stream_feed")
purego.RegisterLibFunc(&CppStreamFinalize, lib, "parakeet_capi_stream_finalize")
purego.RegisterLibFunc(&CppStreamFree, lib, "parakeet_capi_stream_free")
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
}
purego.RegisterLibFunc(&CppFreeString, lib, "parakeet_capi_free_string")
purego.RegisterLibFunc(&CppLastError, lib, "parakeet_capi_last_error")
})
@@ -77,24 +70,6 @@ func fixturesOrSkip() (string, string) {
return modelPath, audioPath
}
// writeMono16kWav writes `samples` frames of 16 kHz mono 16-bit silence to
// path. The result is already in AudioToWav's target format, so the conversion
// helper copies it through without invoking ffmpeg.
func writeMono16kWav(path string, samples int) {
GinkgoHelper()
f, err := os.Create(path)
Expect(err).ToNot(HaveOccurred())
enc := wav.NewEncoder(f, 16000, 16, 1, 1)
buf := &audio.IntBuffer{
Format: &audio.Format{NumChannels: 1, SampleRate: 16000},
SourceBitDepth: 16,
Data: make([]int, samples),
}
Expect(enc.Write(buf)).To(Succeed())
Expect(enc.Close()).To(Succeed())
Expect(f.Close()).To(Succeed())
}
var _ = Describe("ParakeetCpp", func() {
Context("AudioTranscription", func() {
It("transcribes a WAV via the parakeet C-API", func() {
@@ -111,22 +86,13 @@ var _ = Describe("ParakeetCpp", func() {
Expect(err).ToNot(HaveOccurred())
Expect(strings.TrimSpace(res.Text)).ToNot(BeEmpty(),
"expected non-empty transcript for %s", audioPath)
// NeMo-faithful segmentation: one or more punctuation-delimited
// segments, each with text and a monotonically-advancing time span.
Expect(res.Segments).ToNot(BeEmpty(), "expected at least one segment")
var prevEnd int64
for i, seg := range res.Segments {
Expect(strings.TrimSpace(seg.Text)).ToNot(BeEmpty(),
"segment %d must have text", i)
Expect(seg.End).To(BeNumerically(">=", seg.Start),
"segment %d end must not precede its start", i)
Expect(seg.Start).To(BeNumerically(">=", prevEnd),
"segments must be in time order")
prevEnd = seg.End
// Default (no granularities) is segment-level: no per-word timings.
Expect(seg.Words).To(BeEmpty(),
"word timings are opt-in via timestamp_granularities")
}
Expect(res.Segments).To(HaveLen(1),
"synthesises a single whole-clip segment")
Expect(res.Segments[0].Text).To(Equal(res.Text),
"single segment text must equal the top-level text")
// Default (no granularities) is segment-level: no per-word timings.
Expect(res.Segments[0].Words).To(BeEmpty(),
"word timings are opt-in via timestamp_granularities")
})
It("emits word-level timestamps when granularity=word", func() {
@@ -142,61 +108,15 @@ var _ = Describe("ParakeetCpp", func() {
TimestampGranularities: []string{"word"},
})
Expect(err).ToNot(HaveOccurred())
Expect(res.Segments).ToNot(BeEmpty())
// With word granularity every segment carries its own words, and each
// segment's span tracks its first/last word; word starts advance
// monotonically across the whole transcript.
totalWords := 0
var prevStart int64 = -1
for i, seg := range res.Segments {
Expect(seg.Words).ToNot(BeEmpty(),
"segment %d must carry per-word timestamps with granularity=word", i)
Expect(seg.Start).To(Equal(seg.Words[0].Start),
"segment %d start tracks its first word", i)
Expect(seg.End).To(Equal(seg.Words[len(seg.Words)-1].End),
"segment %d end tracks its last word", i)
for _, w := range seg.Words {
Expect(w.End).To(BeNumerically(">=", w.Start))
Expect(w.Start).To(BeNumerically(">=", prevStart))
prevStart = w.Start
totalWords++
}
}
Expect(totalWords).To(BeNumerically(">", 0))
Expect(res.Segments[0].Words[0].Start).To(BeNumerically(">=", int64(0)))
})
})
Context("convertToWavMono16k", func() {
// The non-batched transcription path hands a file path to the C
// library's WAV-only audio loader, so it must convert first.
// utils.AudioToWav passes an already-16kHz/mono/16-bit WAV through
// without ffmpeg, which lets us exercise the helper (and the
// regression: the direct path used to skip conversion entirely)
// without a model, the C library, or ffmpeg.
It("returns a decodable 16kHz mono WAV copy and cleans it up", func() {
dir := GinkgoT().TempDir()
src := filepath.Join(dir, "input.wav")
writeMono16kWav(src, 16000) // 1s of silence at 16 kHz
converted, cleanup, err := convertToWavMono16k(src)
Expect(err).ToNot(HaveOccurred())
// It must produce a fresh temp file, not return the original path.
Expect(converted).ToNot(Equal(src))
Expect(converted).To(BeAnExistingFile())
pcm, _, err := decodeWavMono16k(converted)
Expect(err).ToNot(HaveOccurred())
Expect(pcm).To(HaveLen(16000), "round-trips the sample count")
cleanup()
Expect(converted).ToNot(BeAnExistingFile(), "cleanup removes the temp dir")
})
It("errors on a non-existent input rather than passing the path through", func() {
_, _, err := convertToWavMono16k(filepath.Join(GinkgoT().TempDir(), "missing.mp3"))
Expect(err).To(HaveOccurred())
Expect(res.Segments).To(HaveLen(1))
seg := res.Segments[0]
Expect(seg.Words).ToNot(BeEmpty(),
"expected per-word timestamps with granularity=word")
// Monotonic, non-negative timings spanning the segment.
Expect(seg.Words[0].Start).To(BeNumerically(">=", int64(0)))
Expect(seg.End).To(BeNumerically(">=", seg.Start))
Expect(seg.Words[len(seg.Words)-1].End).To(Equal(seg.End),
"segment end tracks the last word")
})
})

View File

@@ -65,25 +65,6 @@ func main() {
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
}
// Per-request language variants (multilingual nemotron). Same probe pattern:
// present only in libparakeet.so built with multilingual support, so the
// backend still loads against an older library and falls back to the
// non-lang batched + streaming entry points (model default / "auto").
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json_lang"); err == nil && sym != 0 {
purego.RegisterLibFunc(&CppTranscribePcmBatchJSONLang, lib, "parakeet_capi_transcribe_pcm_batch_json_lang")
}
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_begin_lang"); err == nil && sym != 0 {
purego.RegisterLibFunc(&CppStreamBeginLang, lib, "parakeet_capi_stream_begin_lang")
}
// Streaming JSON entry points (ABI v4): surface per-word timestamps on the
// streaming path. Same probe pattern; absent in older libparakeet.so, where
// the backend falls back to the text-only streaming feed.
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
}
fmt.Fprintf(os.Stderr, "[parakeet-cpp] ABI=%d\n", CppAbiVersion())
flag.Parse()

View File

@@ -1,127 +0,0 @@
package main
import (
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func tw(text string, start, end float64) transcriptWord {
return transcriptWord{W: text, Start: start, End: end}
}
var _ = Describe("splitWordsIntoSegments (NeMo get_segment_offsets parity)", func() {
seps := []rune{'.', '?', '!'}
It("splits on sentence-ending punctuation, including the delimiter word", func() {
words := []transcriptWord{tw("hello", 0, 0.4), tw("world.", 0.4, 0.8), tw("bye", 1.0, 1.3)}
segs := splitWordsIntoSegments(words, seps, 0)
Expect(segs).To(HaveLen(2))
Expect(segs[0]).To(HaveLen(2))
Expect(segs[0][1].W).To(Equal("world."))
Expect(segs[1]).To(HaveLen(1))
Expect(segs[1][0].W).To(Equal("bye"))
})
It("keeps a single segment with no terminal punctuation and gap off", func() {
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
segs := splitWordsIntoSegments(words, seps, 0)
Expect(segs).To(HaveLen(1))
})
It("splits on the gap rule when enabled, the gapped word starting the next segment", func() {
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
segs := splitWordsIntoSegments(words, seps, 1.0) // c is 4.6s after b
Expect(segs).To(HaveLen(2))
Expect(segs[0]).To(HaveLen(2)) // a b
Expect(segs[1]).To(HaveLen(1)) // c
Expect(segs[1][0].W).To(Equal("c"))
})
It("checks the gap rule before punctuation (NeMo elif order)", func() {
// "b." would terminate, but c is far after it -> gap closes [a b.] at b.
words := []transcriptWord{tw("a", 0, 0.2), tw("b.", 0.2, 0.4), tw("c", 9.0, 9.2)}
segs := splitWordsIntoSegments(words, seps, 1.0)
Expect(segs).To(HaveLen(2))
Expect(segs[0]).To(HaveLen(2))
Expect(segs[1][0].W).To(Equal("c"))
})
It("still splits on punctuation when the gap rule is enabled but does not fire", func() {
words := []transcriptWord{tw("hi.", 0, 0.4), tw("bye", 0.4, 0.8)}
segs := splitWordsIntoSegments(words, seps, 5.0) // gap never reached
Expect(segs).To(HaveLen(2))
Expect(segs[0][0].W).To(Equal("hi."))
})
It("returns nothing for empty input", func() {
Expect(splitWordsIntoSegments(nil, seps, 0)).To(BeEmpty())
})
})
var _ = Describe("transcriptResultFromDoc (multi-segment)", func() {
doc := transcriptJSON{
Text: "hello world. bye now",
FrameSec: 0.08,
Words: []transcriptWord{
{W: "hello", Start: 0.0, End: 0.4},
{W: "world.", Start: 0.4, End: 0.8},
{W: "bye", Start: 1.0, End: 1.3},
{W: "now", Start: 1.3, End: 1.6},
},
Tokens: []transcriptToken{{ID: 1, T: 0.1}, {ID: 2, T: 0.5}, {ID: 3, T: 1.1}, {ID: 4, T: 1.4}},
}
It("emits one segment per punctuation-delimited group with start/end", func() {
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
Expect(res.Segments).To(HaveLen(2))
Expect(res.Segments[0].Text).To(Equal("hello world."))
Expect(res.Segments[0].Start).To(Equal(int64(0)))
Expect(res.Segments[0].End).To(Equal(secondsToNanos(0.8)))
Expect(res.Segments[1].Text).To(Equal("bye now"))
Expect(res.Segments[1].Start).To(Equal(secondsToNanos(1.0)))
Expect(res.Segments[1].Id).To(Equal(int32(1)))
})
It("assigns tokens to the segment whose time window contains them", func() {
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
Expect(res.Segments[0].Tokens).To(Equal([]int32{1, 2}))
Expect(res.Segments[1].Tokens).To(Equal([]int32{3, 4}))
})
It("attaches per-segment words only when word granularity requested", func() {
plain := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
Expect(plain.Segments[0].Words).To(BeEmpty())
withWords := transcriptResultFromDoc(doc, &pb.TranscriptRequest{TimestampGranularities: []string{"word"}}, 0)
Expect(withWords.Segments[0].Words).To(HaveLen(2))
})
It("falls back to a single text segment when there are no words", func() {
res := transcriptResultFromDoc(transcriptJSON{Text: "hi"}, &pb.TranscriptRequest{}, 0)
Expect(res.Segments).To(HaveLen(1))
Expect(res.Segments[0].Text).To(Equal("hi"))
})
})
var _ = Describe("streaming segment assembly", func() {
It("closes a segment with start/end from its words on EOU", func() {
acc := &streamSegmenter{}
acc.add(streamFeedJSON{Text: "hello world", Eou: 1, Words: []transcriptWord{
{W: "hello", Start: 0.0, End: 0.4}, {W: "world", Start: 0.4, End: 0.9},
}})
segs := acc.segments()
Expect(segs).To(HaveLen(1))
Expect(segs[0].Text).To(Equal("hello world"))
Expect(segs[0].Start).To(Equal(int64(0)))
Expect(segs[0].End).To(Equal(secondsToNanos(0.9)))
})
It("buffers words across feeds until EOU", func() {
acc := &streamSegmenter{}
acc.add(streamFeedJSON{Text: "hi", Eou: 0, Words: []transcriptWord{{W: "hi", Start: 0, End: 0.3}}})
Expect(acc.segments()).To(BeEmpty())
acc.add(streamFeedJSON{Text: "there", Eou: 1, Words: []transcriptWord{{W: "there", Start: 0.3, End: 0.7}}})
Expect(acc.segments()).To(HaveLen(1))
Expect(acc.segments()[0].Text).To(Equal("hi there"))
})
})

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# qwen3-tts.cpp version
QWEN3TTS_REPO?=https://github.com/predict-woo/qwen3-tts.cpp
QWEN3TTS_CPP_VERSION?=136e5d36c17083da0321fd96512dc7b263f94a44
QWEN3TTS_CPP_VERSION?=7a762e2ad4bacc6fdda81d81bf10a09ffb546f29
SO_TARGET?=libgoqwen3ttscpp.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
@@ -22,43 +21,6 @@ type Qwen3TtsCpp struct {
threads int
}
// languageNameAliases maps common full language names to the canonical
// two-letter code understood by the C++ language_to_id table.
var languageNameAliases = map[string]string{
"english": "en",
"russian": "ru",
"chinese": "zh",
"japanese": "ja",
"korean": "ko",
"german": "de",
"french": "fr",
"spanish": "es",
"italian": "it",
"portuguese": "pt",
}
// normalizeLanguage coerces a caller-supplied language into the canonical code
// the model expects. It lowercases, trims, strips any region/locale suffix
// (en-US, en_US, ja.JP -> en/ja), and resolves common full names (english -> en).
// An empty input stays empty so the C++ side applies its English default; an
// unrecognized value is returned normalized so C++ can log it and default.
func normalizeLanguage(lang string) string {
lang = strings.ToLower(strings.TrimSpace(lang))
if lang == "" {
return ""
}
// Strip region/locale suffix: keep the segment before the first separator.
if i := strings.IndexAny(lang, "-_."); i >= 0 {
lang = lang[:i]
}
if code, ok := languageNameAliases[lang]; ok {
return code
}
return lang
}
func (q *Qwen3TtsCpp) Load(opts *pb.ModelOptions) error {
// ModelFile is the model directory path (containing GGUF files)
modelDir := opts.ModelFile
@@ -92,7 +54,7 @@ func (q *Qwen3TtsCpp) TTS(req *pb.TTSRequest) error {
dst := req.Dst
language := ""
if req.Language != nil {
language = normalizeLanguage(*req.Language)
language = *req.Language
}
// Synthesis parameters with sensible defaults

View File

@@ -1,53 +0,0 @@
package main
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestLanguageNormalization(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "qwen3-tts-cpp language normalization")
}
var _ = Describe("normalizeLanguage", func() {
DescribeTable("maps caller input to the canonical model language code",
func(input, expected string) {
Expect(normalizeLanguage(input)).To(Equal(expected))
},
// Canonical codes pass through unchanged
Entry("canonical en", "en", "en"),
Entry("canonical zh", "zh", "zh"),
Entry("canonical pt", "pt", "pt"),
// Case-insensitive
Entry("uppercase", "EN", "en"),
Entry("mixed case", "Ja", "ja"),
// Surrounding whitespace
Entry("trims whitespace", " en ", "en"),
// Region/locale stripping
Entry("BCP-47 region", "en-US", "en"),
Entry("underscore region", "en_US", "en"),
Entry("dotted locale", "ja.JP", "ja"),
Entry("region + case", "ZH-CN", "zh"),
// Full-name aliases
Entry("english name", "english", "en"),
Entry("chinese name cased", "Chinese", "zh"),
Entry("japanese name", "japanese", "ja"),
Entry("russian name", "russian", "ru"),
Entry("portuguese name", "portuguese", "pt"),
// Empty stays empty (C++ applies the English default)
Entry("empty", "", ""),
Entry("whitespace only", " ", ""),
// Unknown values pass through normalized so C++ can log + default
Entry("unknown code", "klingon", "klingon"),
Entry("unknown with region", "xx-YY", "xx"),
)
})

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# stablediffusion.cpp (ggml)
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
STABLEDIFFUSION_GGML_VERSION?=19bdfe22d255d5b4dff39d449318b9bc5ea2317f
STABLEDIFFUSION_GGML_VERSION?=7948df8ac1070f5f6881b8d34675821893eb97d6
CMAKE_ARGS+=-DGGML_MAX_NAME=128

View File

@@ -386,7 +386,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
const char *llm_vision_path = "";
const char *diffusion_model_path = stableDiffusionModel;
const char *high_noise_diffusion_model_path = "";
const char *uncond_diffusion_model_path = "";
const char *taesd_path = "";
const char *control_net_path = "";
const char *embedding_dir = "";
@@ -473,7 +472,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
if (!strcmp(optname, "llm_vision_path")) llm_vision_path = strdup(optval);
if (!strcmp(optname, "diffusion_model_path")) diffusion_model_path = strdup(optval);
if (!strcmp(optname, "high_noise_diffusion_model_path")) high_noise_diffusion_model_path = strdup(optval);
if (!strcmp(optname, "uncond_diffusion_model_path")) uncond_diffusion_model_path = strdup(optval);
if (!strcmp(optname, "taesd_path")) taesd_path = strdup(optval);
if (!strcmp(optname, "control_net_path")) control_net_path = strdup(optval);
if (!strcmp(optname, "embedding_dir")) {
@@ -573,7 +571,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
ctx_params.llm_vision_path = llm_vision_path;
ctx_params.diffusion_model_path = diffusion_model_path;
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
ctx_params.uncond_diffusion_model_path = uncond_diffusion_model_path;
ctx_params.vae_path = vae_path;
ctx_params.audio_vae_path = audio_vae_path;
ctx_params.embeddings_connectors_path = embeddings_connectors_path;

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
WHISPER_CPP_VERSION?=df7638d8229a243af8a4b5a8ae557e0d74e0a0ae
WHISPER_CPP_VERSION?=23ee03506a91ac3d3f0071b40e66a430eebdfa1d
SO_TARGET?=libgowhisper.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -95,29 +95,6 @@
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ds4"
metal: "metal-ds4"
metal-darwin-arm64: "metal-ds4"
- &dllm
name: "dllm"
alias: "dllm"
license: mit
description: |
mudler/dllm.cpp - DiffusionGemma block-diffusion LLM inference engine
(C++/ggml, GGUF weights). Decodes whole token canvases per diffusion
round instead of autoregressive sampling. Runs on CPU and NVIDIA CUDA 13
(including Jetson/GB10 L4T targets).
urls:
- https://github.com/mudler/dllm.cpp
tags:
- text-to-text
- LLM
- gguf
- diffusion
- CPU
- CUDA
capabilities:
default: "cpu-dllm"
nvidia: "cuda13-dllm"
nvidia-cuda-13: "cuda13-dllm"
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-dllm"
- &whispercpp
name: "whisper"
alias: "whisper"
@@ -1295,13 +1272,6 @@
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ds4-development"
metal: "metal-ds4-development"
metal-darwin-arm64: "metal-ds4-development"
- !!merge <<: *dllm
name: "dllm-development"
capabilities:
default: "cpu-dllm-development"
nvidia: "cuda13-dllm-development"
nvidia-cuda-13: "cuda13-dllm-development"
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-dllm-development"
- !!merge <<: *stablediffusionggml
name: "stablediffusion-ggml-development"
capabilities:
@@ -1889,37 +1859,6 @@
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-ds4"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-ds4
## dllm
- !!merge <<: *dllm
name: "cpu-dllm"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-dllm"
mirrors:
- localai/localai-backends:latest-cpu-dllm
- !!merge <<: *dllm
name: "cpu-dllm-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-dllm"
mirrors:
- localai/localai-backends:master-cpu-dllm
- !!merge <<: *dllm
name: "cuda13-dllm"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-dllm"
mirrors:
- localai/localai-backends:latest-gpu-nvidia-cuda-13-dllm
- !!merge <<: *dllm
name: "cuda13-dllm-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-dllm"
mirrors:
- localai/localai-backends:master-gpu-nvidia-cuda-13-dllm
- !!merge <<: *dllm
name: "cuda13-nvidia-l4t-arm64-dllm"
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-dllm"
mirrors:
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-dllm
- !!merge <<: *dllm
name: "cuda13-nvidia-l4t-arm64-dllm-development"
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-dllm"
mirrors:
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-dllm
## whisper
- !!merge <<: *whispercpp
name: "whisper-development"

View File

@@ -37,20 +37,6 @@ def is_int(s):
except ValueError:
return False
def coerce_param_value(value):
"""Coerce a TTSRequest.params value (string on the wire) to the type the
Chatterbox generate() kwargs expect (float/int/bool), matching how static
YAML options are coerced at load time. Non-string values pass through."""
if not isinstance(value, str):
return value
if is_float(value):
return float(value)
if is_int(value):
return int(value)
if value.lower() in ["true", "false"]:
return value.lower() == "true"
return value
def split_text_at_word_boundary(text, max_length=250):
"""
Split text at word boundaries without truncating words.
@@ -205,14 +191,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# add options to kwargs
kwargs.update(self.options)
# Merge per-request params (TTSRequest.params), overriding the static
# YAML options. This exposes Chatterbox generation knobs (e.g.
# exaggeration, cfg_weight, temperature) per request. Values arrive as
# strings on the wire and are coerced to float/int/bool.
if hasattr(request, "params") and request.params:
for key, value in request.params.items():
kwargs[key] = coerce_param_value(value)
# Check if text exceeds 250 characters
# (chatterbox does not support long text)
# https://github.com/resemble-ai/chatterbox/issues/60

View File

@@ -47,26 +47,6 @@ def is_int(s):
return False
def coerce_param_value(value):
"""Coerce a string param value (from the TTSRequest.params map, which is
string-typed on the wire) into the most specific Python type the model
generation kwargs expect: bool, int, float, else the original string."""
if not isinstance(value, str):
return value
lowered = value.strip().lower()
if lowered in ("true", "false"):
return lowered == "true"
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
@@ -342,19 +322,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(message="Model loaded successfully", success=True)
def _effective_instruct(self, request):
"""Resolve the instruction/style string for this request, preferring the
per-request TTSRequest.instructions value and falling back to the static
YAML `instruct` option. Empty string means "no instruction"."""
req_instruct = (
request.instructions
if hasattr(request, "instructions") and request.instructions
else ""
)
if req_instruct:
return req_instruct
return self.options.get("instruct", "") or ""
def _detect_mode(self, request):
"""Detect which mode to use based on request parameters."""
# Priority: VoiceClone > VoiceDesign > CustomVoice
@@ -371,8 +338,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if self.audio_path or self.voices:
return "VoiceClone"
# VoiceDesign: instruct provided per-request or via YAML option
if self._effective_instruct(request):
# VoiceDesign: instruct option is provided
if "instruct" in self.options and self.options["instruct"]:
return "VoiceDesign"
# Default to CustomVoice
@@ -723,20 +690,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if do_sample is not None:
generation_kwargs["do_sample"] = do_sample
# Prefer the per-request instruction (TTSRequest.instructions) over the
# static YAML `instruct` option. This lets clients set a different style
# (CustomVoice emotion) or designed voice (VoiceDesign) per request.
instruct = self._effective_instruct(request)
instruct = self.options.get("instruct", "")
if instruct is not None and instruct != "":
generation_kwargs["instruct"] = instruct
# Merge any per-request backend-specific params (TTSRequest.params).
# Values arrive as strings on the wire; coerce to int/float/bool so the
# model receives the types it expects. These override YAML-derived kwargs.
if hasattr(request, "params") and request.params:
for key, value in request.params.items():
generation_kwargs[key] = coerce_param_value(value)
# Generate audio based on mode
if mode == "VoiceClone":
# VoiceClone mode

View File

@@ -1,4 +1,4 @@
grpcio==1.81.0
grpcio==1.80.0
protobuf==7.35.0
certifi
setuptools

View File

@@ -26,10 +26,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
try:
from vllm.tokenizers import get_tokenizer # vLLM >= 0.22
except ImportError:
from vllm.transformers_utils.tokenizer import get_tokenizer # vLLM < 0.22
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.multimodal.utils import fetch_image
from vllm.assets.video import VideoAsset
import base64

View File

@@ -3,5 +3,5 @@
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
# so uv consults this index alongside PyPI.
--extra-index-url https://wheels.vllm.ai/0.22.1/cu130
vllm==0.22.1
--extra-index-url https://wheels.vllm.ai/0.22.0/cu130
vllm==0.22.0

View File

@@ -1,4 +1,4 @@
grpcio==1.81.0
grpcio==1.80.0
protobuf
certifi
setuptools

View File

@@ -102,12 +102,7 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)
// Connect to NATS
natsAuth := cfg.Distributed.NatsAuthConfig()
if natsAuth.RequireAuth && (natsAuth.ServiceUserJWT == "" || natsAuth.ServiceUserSeed == "") {
return nil, fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED")
}
natsOpts := cfg.Distributed.NatsMessagingOptions("", "")
natsClient, err := messaging.New(cfg.Distributed.NatsURL, natsOpts...)
natsClient, err := messaging.New(cfg.Distributed.NatsURL)
if err != nil {
return nil, fmt.Errorf("connecting to NATS: %w", err)
}

View File

@@ -23,9 +23,9 @@ import (
"github.com/mudler/LocalAI/core/services/routing/pii"
"github.com/mudler/LocalAI/core/services/routing/router"
"github.com/mudler/LocalAI/core/services/storage"
"github.com/mudler/LocalAI/pkg/signals"
coreStartup "github.com/mudler/LocalAI/core/startup"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/signals"
"github.com/mudler/LocalAI/pkg/vram"
"github.com/mudler/LocalAI/pkg/model"
@@ -308,31 +308,10 @@ func New(opts ...config.AppOption) (*Application, error) {
application.galleryService.SetNATSClient(distSvc.Nats)
if distSvc.DistStores != nil && distSvc.DistStores.Gallery != nil {
// Clean up stale in-progress operations from previous crashed instances
if _, err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
if err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
xlog.Warn("Failed to clean stale gallery operations", "error", err)
}
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
// Reap stale ops periodically, not just at boot: an op orphaned by
// a replica that died mid-install (its foreground handler goroutine
// gone) would otherwise linger "processing" in the UI until the next
// restart. 30m matches the install/upgrade ceiling so a genuinely
// slow op is never reaped out from under itself.
gsvc := application.galleryService
go func() {
ticker := time.NewTicker(15 * time.Minute)
defer ticker.Stop()
for {
select {
case <-options.Context.Done():
return
case <-ticker.C:
if _, err := gsvc.ReapStaleOperations(30 * time.Minute); err != nil {
xlog.Warn("Failed to reap stale gallery operations", "error", err)
}
}
}
}()
}
// Hydrate from the store first so the wildcard subscriber finds an
// already-populated statuses map for any operations still in flight

View File

@@ -214,9 +214,7 @@ func (uc *UpgradeChecker) runCheck(ctx context.Context) {
"from", info.InstalledVersion, "to", info.AvailableVersion)
var err error
if bm != nil {
// Background auto-upgrade: no live admin watching a progress bar,
// so opID is empty and the distributed path skips progress streaming.
err = bm.UpgradeBackend(ctx, "", name, nil)
err = bm.UpgradeBackend(ctx, name, nil)
} else {
err = gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
uc.galleries, name, nil, uc.appConfig.RequireBackendIntegrity)

View File

@@ -123,14 +123,14 @@ var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
})
It("ModelTTS forwards the request context to the SmartRouter", func() {
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg)
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", loader, appCfg, modelCfg)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
stampViaRouterCtx()
})
It("ModelTTSStream forwards the request context to the SmartRouter", func() {
err := backend.ModelTTSStream(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg, func([]byte) error { return nil })
err := backend.ModelTTSStream(reqCtx, "hello", "", "", loader, appCfg, modelCfg, func([]byte) error { return nil })
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
stampViaRouterCtx()

View File

@@ -239,13 +239,13 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
if c.Backend == "cloud-proxy" {
opts.Proxy = &pb.ProxyOptions{
UpstreamUrl: c.Proxy.UpstreamURL,
Mode: c.Proxy.Mode,
Provider: c.Proxy.Provider,
ApiKeyEnv: c.Proxy.APIKeyEnv,
ApiKeyFile: c.Proxy.APIKeyFile,
UpstreamModel: c.Proxy.UpstreamModel,
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
UpstreamUrl: c.Proxy.UpstreamURL,
Mode: c.Proxy.Mode,
Provider: c.Proxy.Provider,
ApiKeyEnv: c.Proxy.APIKeyEnv,
ApiKeyFile: c.Proxy.APIKeyFile,
UpstreamModel: c.Proxy.UpstreamModel,
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
}
}
@@ -323,12 +323,6 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
metadata["enable_thinking"] = "true"
}
}
// Forward the effective reasoning effort so the backend can pass it to the
// jinja chat template (chat_template_kwargs.reasoning_effort) — the lever
// models like gpt-oss / LFM2.5 actually read, distinct from enable_thinking.
if c.ReasoningEffort != "" {
metadata["reasoning_effort"] = c.ReasoningEffort
}
pbOpts.Metadata = metadata
// Logprobs and TopLogprobs are set by the caller if provided

View File

@@ -75,25 +75,3 @@ var _ = Describe("gRPCPredictOpts enable_thinking metadata", func() {
Expect(opts.Metadata).ToNot(HaveKey("enable_thinking"))
})
})
// Guards forwarding the effective reasoning_effort into PredictOptions.Metadata,
// where the backend passes it to the jinja chat template (chat_template_kwargs)
// so models like gpt-oss / LFM2.5 honor it.
var _ = Describe("gRPCPredictOpts reasoning_effort metadata", func() {
withEffort := func(effort string) config.ModelConfig {
cfg := config.ModelConfig{}
cfg.SetDefaults()
cfg.ReasoningEffort = effort
return cfg
}
It("forwards reasoning_effort when set", func() {
opts := gRPCPredictOpts(withEffort("none"), "/tmp/models")
Expect(opts.Metadata).To(HaveKeyWithValue("reasoning_effort", "none"))
})
It("omits reasoning_effort when empty", func() {
opts := gRPCPredictOpts(withEffort(""), "/tmp/models")
Expect(opts.Metadata).ToNot(HaveKey("reasoning_effort"))
})
})

View File

@@ -20,32 +20,11 @@ import (
"github.com/mudler/LocalAI/pkg/utils"
)
// newTTSRequest assembles the gRPC TTSRequest from the per-request inputs. The
// optional instructions string is only attached when non-empty so backends can
// distinguish "no per-request instruction" (fall back to YAML) from an explicit
// empty one. params is forwarded as-is (nil when unset).
func newTTSRequest(text, modelPath, voice, dst, language, instructions string, params map[string]string) *proto.TTSRequest {
req := &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
Dst: dst,
Language: &language,
Params: params,
}
if instructions != "" {
req.Instructions = &instructions
}
return req
}
func ModelTTS(
ctx context.Context,
text,
voice,
language,
instructions string,
params map[string]string,
language string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
modelConfig config.ModelConfig,
@@ -95,9 +74,13 @@ func ModelTTS(
startTime = time.Now()
}
ttsRequest := newTTSRequest(text, modelPath, voice, filePath, language, instructions, params)
res, err := ttsModel.TTS(ctx, ttsRequest)
res, err := ttsModel.TTS(ctx, &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
Dst: filePath,
Language: &language,
})
if appConfig.EnableTracing {
errStr := ""
@@ -145,9 +128,7 @@ func ModelTTSStream(
ctx context.Context,
text,
voice,
language,
instructions string,
params map[string]string,
language string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
modelConfig config.ModelConfig,
@@ -196,10 +177,12 @@ func ModelTTSStream(
var totalPCMBytes int
snippetCapped := false
// Streaming TTS writes to the HTTP response, not a file, so dst is empty.
ttsRequest := newTTSRequest(text, modelPath, voice, "", language, instructions, params)
err = ttsModel.TTSStream(ctx, ttsRequest, func(reply *proto.Reply) {
err = ttsModel.TTSStream(ctx, &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
Language: &language,
}, func(reply *proto.Reply) {
// First message contains sample rate info
if !headerSent && len(reply.Message) > 0 {
var info map[string]any

View File

@@ -1,42 +0,0 @@
package backend
// Specs for the TTSRequest assembly that carries the per-request
// instructions/params from the OpenAI `instructions` field (and the LocalAI
// `params` extension) through to the gRPC boundary. Before this plumbing the
// instruction value was dropped before reaching the backend; these specs pin
// that it now survives, and that the empty case stays backward compatible.
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("newTTSRequest", func() {
It("attaches the instructions when a per-request value is set", func() {
req := newTTSRequest("hi", "/m", "alloy", "/out.wav", "en", "cheerful narrator", nil)
Expect(req.Instructions).ToNot(BeNil())
Expect(req.GetInstructions()).To(Equal("cheerful narrator"))
Expect(req.GetText()).To(Equal("hi"))
Expect(req.GetVoice()).To(Equal("alloy"))
Expect(req.GetDst()).To(Equal("/out.wav"))
Expect(req.GetLanguage()).To(Equal("en"))
})
It("leaves instructions unset when empty so backends fall back to YAML", func() {
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
Expect(req.Instructions).To(BeNil())
Expect(req.GetInstructions()).To(Equal(""))
})
It("forwards per-request params through to the backend", func() {
params := map[string]string{"exaggeration": "0.7", "cfg_weight": "0.3"}
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", params)
Expect(req.GetParams()).To(HaveKeyWithValue("exaggeration", "0.7"))
Expect(req.GetParams()).To(HaveKeyWithValue("cfg_weight", "0.3"))
})
It("leaves params nil when none are supplied", func() {
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
Expect(req.GetParams()).To(BeNil())
})
})

View File

@@ -52,28 +52,10 @@ type AgentWorkerCMD struct {
Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
Queue string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`
NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (defaults to nats_jwt from registration)" group:"distributed"`
NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user seed override (defaults to nats_user_seed from registration)" group:"distributed"`
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"Fallback NATS service JWT when registration does not mint agent JWT" group:"distributed"`
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"Fallback NATS service seed paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed to connect" group:"distributed"`
// DistributedRequireAuth is the umbrella switch; for the agent worker (which
// has no file-transfer server) it implies NATS auth is required.
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch implying --nats-require-auth (agent workers have no file-transfer server)" group:"distributed"`
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
// Timeouts
MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
}
// natsAuthRequired reports whether NATS JWT credentials must be present — the
// granular flag or the umbrella (LOCALAI_DISTRIBUTED_REQUIRE_AUTH).
func (cmd *AgentWorkerCMD) natsAuthRequired() bool {
return cmd.NatsRequireAuth || cmd.DistributedRequireAuth
}
func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
xlog.Info("Starting agent worker", "nats", sanitize.URL(cmd.NatsURL), "register_to", cmd.RegisterTo)
@@ -99,30 +81,15 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
registrationBody["token"] = cmd.RegistrationToken
}
// Context cancelled on shutdown — used by registration waits, heartbeat, and
// other background goroutines.
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
defer shutdownCancel()
// Acquire credentials via (re)registration. When the bus requires auth and no
// static fallback is configured, wait through admin approval until the
// frontend mints credentials rather than starting unauthenticated.
credMgr := workerregistry.NewNATSCredentialManager(
func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
return regClient.RegisterFull(ctx, registrationBody)
},
cmd.natsAuthRequired() && cmd.NatsJWT == "" && cmd.NatsServiceJWT == "",
)
res, err := credMgr.Acquire(shutdownCtx)
nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
if err != nil {
return fmt.Errorf("registration failed: %w", err)
}
nodeID := res.ID
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
// Use provisioned API token if none was set
if cmd.APIToken == "" {
cmd.APIToken = res.APIToken
cmd.APIToken = apiToken
}
// Start heartbeat
@@ -131,40 +98,14 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
}
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
// Context cancelled on shutdown — used by heartbeat and other background goroutines
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
defer shutdownCancel()
go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })
// Resolve NATS credentials with precedence: explicit env override, then
// frontend-minted (auto-refreshed before expiry), then service fallback.
// Each static source must supply JWT and seed together.
natsTLS := messaging.TLSFiles{CA: cmd.NatsTLSCA, Cert: cmd.NatsTLSCert, Key: cmd.NatsTLSKey}
var natsOpts []messaging.Option
switch {
case cmd.NatsJWT != "" || cmd.NatsUserSeed != "":
if (cmd.NatsJWT == "") != (cmd.NatsUserSeed == "") {
return fmt.Errorf("LOCALAI_NATS_JWT and LOCALAI_NATS_USER_SEED must be set together")
}
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsJWT, cmd.NatsUserSeed))
case credMgr.HasCredentials():
natsOpts = append(natsOpts, messaging.WithUserJWTProvider(credMgr.Provider()))
go func() {
if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
xlog.Error("NATS credential refresh permanently failed; shutting down agent worker", "error", err)
shutdownCancel()
}
}()
case cmd.NatsServiceJWT != "" || cmd.NatsServiceSeed != "":
if (cmd.NatsServiceJWT == "") != (cmd.NatsServiceSeed == "") {
return fmt.Errorf("LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED must be set together")
}
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsServiceJWT, cmd.NatsServiceSeed))
case cmd.natsAuthRequired():
return fmt.Errorf("NATS JWT+seed required: enable frontend minting or set LOCALAI_NATS_* env vars")
}
if natsTLS.Enabled() {
natsOpts = append(natsOpts, messaging.WithTLS(natsTLS))
}
natsClient, err := messaging.New(cmd.NatsURL, natsOpts...)
// Connect to NATS
natsClient, err := messaging.New(cmd.NatsURL)
if err != nil {
return fmt.Errorf("connecting to NATS: %w", err)
}
@@ -242,25 +183,17 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)
// Wait for an OS signal or an internal fatal condition (e.g. NATS
// credentials became unrenewable), so the worker restarts and re-acquires
// rather than lingering unable to serve.
// Wait for shutdown
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
var runErr error
select {
case <-sigCh:
case <-shutdownCtx.Done():
runErr = fmt.Errorf("agent worker shutting down: NATS credentials unavailable")
xlog.Error("Internal shutdown requested", "error", runErr)
}
<-sigCh
xlog.Info("Shutting down agent worker")
shutdownCancel() // stop heartbeat loop immediately
dispatcher.Stop()
mcpTools.CloseAllMCPSessions()
regClient.GracefulDeregister(nodeID)
return runErr
return nil
}
// handleMCPToolRequest handles a NATS request-reply for MCP tool execution.

View File

@@ -1,30 +0,0 @@
package chat
import (
"context"
"io"
"strings"
)
type Options struct {
Model string
BaseURL string
APIKey string
In io.Reader
Out io.Writer
}
func Run(ctx context.Context, opts Options) error {
if opts.In == nil {
opts.In = strings.NewReader("")
}
if opts.Out == nil {
opts.Out = io.Discard
}
session, err := newChatSession(ctx, newLocalAIChatClient(opts.BaseURL, opts.APIKey), opts.Model)
if err != nil {
return err
}
return runTerminalChat(ctx, session, opts.In, opts.Out)
}

View File

@@ -1,13 +0,0 @@
package chat
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestChat(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Chat Suite")
}

View File

@@ -1,172 +0,0 @@
package chat
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Run chat", func() {
It("streams a single chat response", func() {
var capturedModel string
var capturedAuth string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v1/models" {
w.Header().Set("Content-Type", "application/json")
writeResponse(w, `{"object":"list","data":[{"id":"test-model","object":"model"}]}`)
return
}
Expect(r.URL.Path).To(Equal("/v1/chat/completions"))
capturedAuth = r.Header.Get("Authorization")
var body struct {
Model string `json:"model"`
Messages []struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"messages"`
}
Expect(json.NewDecoder(r.Body).Decode(&body)).To(Succeed())
capturedModel = body.Model
Expect(body.Messages).To(HaveLen(1))
Expect(body.Messages[0].Role).To(Equal("user"))
Expect(body.Messages[0].Content).To(Equal("hello"))
w.Header().Set("Content-Type", "text/event-stream")
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n")
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\"}}]}\n\n")
writeResponse(w, "data: [DONE]\n\n")
}))
defer server.Close()
var out bytes.Buffer
err := Run(GinkgoT().Context(), Options{
Model: "test-model",
BaseURL: server.URL + "/v1",
APIKey: "secret",
In: strings.NewReader("hello\n/exit\n"),
Out: &out,
})
Expect(err).ToNot(HaveOccurred())
Expect(capturedModel).To(Equal("test-model"))
Expect(capturedAuth).To(Equal("Bearer secret"))
Expect(out.String()).To(ContainSubstring("assistant: hi!"))
Expect(out.String()).To(ContainSubstring("bye"))
})
It("auto-selects the only available model", func() {
server := chatTestServer([]string{"solo"}, nil)
defer server.Close()
var out bytes.Buffer
err := Run(GinkgoT().Context(), Options{
BaseURL: server.URL + "/v1",
In: strings.NewReader("/exit\n"),
Out: &out,
})
Expect(err).ToNot(HaveOccurred())
Expect(out.String()).To(ContainSubstring("LocalAI chat (solo)"))
})
It("returns an actionable error when no models are installed", func() {
server := chatTestServer(nil, nil)
defer server.Close()
err := Run(GinkgoT().Context(), Options{
BaseURL: server.URL + "/v1",
In: strings.NewReader(""),
})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no chat models are installed"))
Expect(err.Error()).To(ContainSubstring("local-ai models install <model>"))
})
It("returns an actionable error when multiple models are available without a selection", func() {
server := chatTestServer([]string{"alpha", "beta"}, nil)
defer server.Close()
err := Run(GinkgoT().Context(), Options{
BaseURL: server.URL + "/v1",
In: strings.NewReader(""),
})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("multiple models are available"))
Expect(err.Error()).To(ContainSubstring("--model"))
Expect(err.Error()).To(ContainSubstring("alpha"))
Expect(err.Error()).To(ContainSubstring("beta"))
})
It("lists and switches models inside the chat", func() {
requestedModels := []string{}
server := chatTestServer([]string{"alpha", "beta"}, func(model string) {
requestedModels = append(requestedModels, model)
})
defer server.Close()
var out bytes.Buffer
err := Run(GinkgoT().Context(), Options{
Model: "alpha",
BaseURL: server.URL + "/v1",
In: strings.NewReader("/models\n/model beta\nhello\n/exit\n"),
Out: &out,
})
Expect(err).ToNot(HaveOccurred())
Expect(out.String()).To(ContainSubstring("* alpha"))
Expect(out.String()).To(ContainSubstring(" beta"))
Expect(out.String()).To(ContainSubstring("switched to beta; conversation cleared"))
Expect(requestedModels).To(Equal([]string{"beta"}))
})
})
func chatTestServer(models []string, onChat func(model string)) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/v1/models":
w.Header().Set("Content-Type", "application/json")
writeResponse(w, `{"object":"list","data":[`)
for i, model := range models {
if i > 0 {
writeResponse(w, ",")
}
writeResponsef(w, `{"id":%q,"object":"model"}`, model)
}
writeResponse(w, `]}`)
case "/v1/chat/completions":
var body struct {
Model string `json:"model"`
}
Expect(json.NewDecoder(r.Body).Decode(&body)).To(Succeed())
if onChat != nil {
onChat(body.Model)
}
w.Header().Set("Content-Type", "text/event-stream")
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ok\"}}]}\n\n")
writeResponse(w, "data: [DONE]\n\n")
default:
w.WriteHeader(http.StatusNotFound)
}
}))
}
func writeResponse(w io.Writer, text string) {
_, err := fmt.Fprint(w, text)
Expect(err).ToNot(HaveOccurred())
}
func writeResponsef(w io.Writer, format string, args ...any) {
_, err := fmt.Fprintf(w, format, args...)
Expect(err).ToNot(HaveOccurred())
}

View File

@@ -1,114 +0,0 @@
package chat
import (
"context"
"errors"
"fmt"
"io"
"sort"
"strings"
openai "github.com/sashabaranov/go-openai"
)
type chatClient interface {
ListModels(ctx context.Context) ([]string, error)
StreamChat(ctx context.Context, model string, messages []chatMessage, out io.Writer) (string, error)
}
type localAIChatClient struct {
client *openai.Client
}
func newLocalAIChatClient(baseURL string, apiKey string) *localAIChatClient {
cfg := openai.DefaultConfig(apiKey)
cfg.BaseURL = baseURL
return &localAIChatClient{client: openai.NewClientWithConfig(cfg)}
}
func (c *localAIChatClient) ListModels(ctx context.Context) ([]string, error) {
resp, err := c.client.ListModels(ctx)
if err != nil {
return nil, err
}
models := make([]string, 0, len(resp.Models))
for _, model := range resp.Models {
if model.ID != "" {
models = append(models, model.ID)
}
}
sort.Strings(models)
return models, nil
}
func (c *localAIChatClient) StreamChat(ctx context.Context, model string, messages []chatMessage, out io.Writer) (string, error) {
stream, err := c.client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
Model: model,
Messages: openAIChatMessages(messages),
})
if err != nil {
return "", friendlyChatError(err, model)
}
defer func() {
_ = stream.Close()
}()
var answer strings.Builder
for {
resp, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return answer.String(), friendlyChatError(err, model)
}
if len(resp.Choices) == 0 {
continue
}
token := resp.Choices[0].Delta.Content
if token == "" {
continue
}
answer.WriteString(token)
if _, err := fmt.Fprint(out, token); err != nil {
return answer.String(), err
}
}
return answer.String(), nil
}
func openAIChatMessages(messages []chatMessage) []openai.ChatCompletionMessage {
converted := make([]openai.ChatCompletionMessage, len(messages))
for i, message := range messages {
converted[i] = openai.ChatCompletionMessage{
Role: message.Role,
Content: message.Content,
}
}
return converted
}
func friendlyChatError(err error, model string) error {
var apiErr *openai.APIError
if errors.As(err, &apiErr) {
switch apiErr.HTTPStatusCode {
case 404:
return fmt.Errorf("model %q is not available. Run `local-ai models list`, install a model with `local-ai models install <model>`, or switch with `/model <name>`", model)
case 403:
return fmt.Errorf("model %q is disabled. Enable it from LocalAI settings or choose another model with `/model <name>`", model)
}
if apiErr.Message != "" {
return errors.New(apiErr.Message)
}
}
msg := err.Error()
if strings.Contains(msg, "model") && strings.Contains(msg, "not found") {
return fmt.Errorf("model %q is not available. Run `local-ai models list`, install a model with `local-ai models install <model>`, or switch with `/model <name>`", model)
}
return err
}

View File

@@ -1,17 +0,0 @@
package chat
import "strings"
func formatChatModelList(models []string, current string) string {
var b strings.Builder
for _, model := range models {
prefix := " "
if model == current {
prefix = "* "
}
b.WriteString(prefix)
b.WriteString(model)
b.WriteByte('\n')
}
return b.String()
}

View File

@@ -1,120 +0,0 @@
package chat
import (
"context"
"errors"
"fmt"
"io"
"strings"
)
const (
chatRoleUser = "user"
chatRoleAssistant = "assistant"
)
type chatMessage struct {
Role string
Content string
}
type chatSession struct {
client chatClient
model string
models []string
messages []chatMessage
}
func newChatSession(ctx context.Context, client chatClient, requestedModel string) (*chatSession, error) {
models, err := client.ListModels(ctx)
if err != nil {
return nil, fmt.Errorf("list models: %w", err)
}
model, err := resolveChatModel(requestedModel, models)
if err != nil {
return nil, err
}
return &chatSession{
client: client,
model: model,
models: models,
}, nil
}
func (s *chatSession) CurrentModel() string {
return s.model
}
func (s *chatSession) Models() []string {
models := make([]string, len(s.models))
copy(models, s.models)
return models
}
func (s *chatSession) Clear() {
s.messages = nil
}
func (s *chatSession) SwitchModel(model string) error {
if !modelExists(s.models, model) {
return fmt.Errorf("model %q is not available. Use /models to see installed models", model)
}
s.model = model
s.Clear()
return nil
}
func (s *chatSession) Send(ctx context.Context, prompt string, out io.Writer) error {
s.messages = append(s.messages, chatMessage{
Role: chatRoleUser,
Content: prompt,
})
answer, err := s.client.StreamChat(ctx, s.model, s.messages, out)
if err != nil {
return err
}
s.messages = append(s.messages, chatMessage{
Role: chatRoleAssistant,
Content: answer,
})
return nil
}
func resolveChatModel(requested string, models []string) (string, error) {
switch {
case requested == "" && len(models) == 0:
return "", errors.New(`no chat models are installed.
Install a model first, for example:
local-ai models list
local-ai models install <model>
local-ai run
Then start a chat session:
local-ai chat --model <model>`)
case requested == "" && len(models) == 1:
return models[0], nil
case requested == "" && len(models) > 1:
var b strings.Builder
b.WriteString("multiple models are available; choose one with --model:\n")
b.WriteString(formatChatModelList(models, ""))
return "", errors.New(b.String())
case !modelExists(models, requested):
return "", fmt.Errorf("model %q is not available. Use `local-ai models list` and `local-ai models install <model>`, or pass an installed model with --model", requested)
default:
return requested, nil
}
}
func modelExists(models []string, name string) bool {
for _, model := range models {
if model == name {
return true
}
}
return false
}

View File

@@ -1,56 +0,0 @@
package chat
import (
"context"
"io"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Chat session", func() {
It("keeps model switching and message history out of the terminal adapter", func() {
client := &fakeChatClient{
models: []string{"alpha", "beta"},
answer: "pong",
}
session, err := newChatSession(context.Background(), client, "alpha")
Expect(err).ToNot(HaveOccurred())
Expect(session.CurrentModel()).To(Equal("alpha"))
Expect(session.SwitchModel("beta")).To(Succeed())
Expect(session.CurrentModel()).To(Equal("beta"))
Expect(session.Send(context.Background(), "ping", io.Discard)).To(Succeed())
Expect(client.requests).To(HaveLen(1))
Expect(client.requests[0].model).To(Equal("beta"))
Expect(client.requests[0].messages).To(HaveLen(1))
Expect(client.requests[0].messages[0].Content).To(Equal("ping"))
})
})
type fakeChatClient struct {
models []string
answer string
requests []fakeChatRequest
}
type fakeChatRequest struct {
model string
messages []chatMessage
}
func (c *fakeChatClient) ListModels(context.Context) ([]string, error) {
return c.models, nil
}
func (c *fakeChatClient) StreamChat(_ context.Context, model string, messages []chatMessage, out io.Writer) (string, error) {
copied := make([]chatMessage, len(messages))
copy(copied, messages)
c.requests = append(c.requests, fakeChatRequest{model: model, messages: copied})
if _, err := io.WriteString(out, c.answer); err != nil {
return "", err
}
return c.answer, nil
}

View File

@@ -1,93 +0,0 @@
package chat
import (
"bufio"
"context"
"fmt"
"io"
"strings"
)
func runTerminalChat(ctx context.Context, session *chatSession, in io.Reader, out io.Writer) error {
scanner := bufio.NewScanner(in)
scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
if err := writeChat(out, "LocalAI chat (%s)\n", session.CurrentModel()); err != nil {
return err
}
if err := writeChat(out, "Type /exit to quit, /clear to reset the conversation, /models to list models.\n"); err != nil {
return err
}
for {
if err := writeChat(out, "\n> "); err != nil {
return err
}
if !scanner.Scan() {
break
}
prompt := strings.TrimSpace(scanner.Text())
switch prompt {
case "":
continue
case "/bye", "/exit", "/quit":
return writeChat(out, "bye\n")
case "/clear":
session.Clear()
if err := writeChat(out, "conversation cleared\n"); err != nil {
return err
}
continue
case "/models":
if err := printChatModels(out, session.Models(), session.CurrentModel()); err != nil {
return err
}
continue
}
if nextModel, ok := strings.CutPrefix(prompt, "/model "); ok {
nextModel = strings.TrimSpace(nextModel)
if nextModel == "" {
if err := writeChat(out, "usage: /model <name>\n"); err != nil {
return err
}
continue
}
if err := session.SwitchModel(nextModel); err != nil {
if writeErr := writeChat(out, "%s\n", err); writeErr != nil {
return writeErr
}
continue
}
if err := writeChat(out, "switched to %s; conversation cleared\n", session.CurrentModel()); err != nil {
return err
}
continue
}
if err := writeChat(out, "assistant: "); err != nil {
return err
}
if err := session.Send(ctx, prompt, out); err != nil {
return err
}
if err := writeChat(out, "\n"); err != nil {
return err
}
}
return scanner.Err()
}
func printChatModels(out io.Writer, models []string, current string) error {
if len(models) == 0 {
return writeChat(out, "no models installed\n")
}
return writeChat(out, "%s", formatChatModelList(models, current))
}
func writeChat(out io.Writer, format string, args ...any) error {
_, err := fmt.Fprintf(out, format, args...)
return err
}

View File

@@ -1,25 +0,0 @@
package cli
import (
"context"
"os"
chatcli "github.com/mudler/LocalAI/core/cli/chat"
cliContext "github.com/mudler/LocalAI/core/cli/context"
)
type ChatCMD struct {
Model string `short:"m" help:"Model name to use. Defaults to the only model returned by the server when exactly one is available"`
Endpoint string `env:"LOCALAI_CHAT_ENDPOINT" default:"http://127.0.0.1:8080" help:"LocalAI server endpoint. The /v1 path is added automatically when omitted"`
APIKey string `env:"LOCALAI_API_KEY,API_KEY" help:"API key to use when the LocalAI server requires authentication"`
}
func (c *ChatCMD) Run(ctx *cliContext.Context) error {
return chatcli.Run(context.Background(), chatcli.Options{
Model: c.Model,
BaseURL: chatAPIBaseURL(c.Endpoint),
APIKey: c.APIKey,
In: os.Stdin,
Out: os.Stdout,
})
}

View File

@@ -1,27 +0,0 @@
package cli
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Chat command wiring", func() {
Describe("chatAPIBaseURL", func() {
It("adds /v1 to a root endpoint", func() {
Expect(chatAPIBaseURL("http://127.0.0.1:8080")).To(Equal("http://127.0.0.1:8080/v1"))
})
It("keeps endpoints that already include /v1", func() {
Expect(chatAPIBaseURL("http://127.0.0.1:8080/v1")).To(Equal("http://127.0.0.1:8080/v1"))
Expect(chatAPIBaseURL("http://127.0.0.1:8080/v1/")).To(Equal("http://127.0.0.1:8080/v1"))
})
It("adds a default http scheme", func() {
Expect(chatAPIBaseURL("127.0.0.1:8080")).To(Equal("http://127.0.0.1:8080/v1"))
})
It("preserves non-root paths before /v1", func() {
Expect(chatAPIBaseURL("http://127.0.0.1:8080/localai")).To(Equal("http://127.0.0.1:8080/localai/v1"))
})
})
})

View File

@@ -1,29 +0,0 @@
package cli
import (
"net/url"
"strings"
)
func chatAPIBaseURL(endpoint string) string {
if !strings.Contains(endpoint, "://") {
endpoint = "http://" + endpoint
}
u, err := url.Parse(endpoint)
if err != nil {
return strings.TrimRight(endpoint, "/") + "/v1"
}
path := strings.TrimRight(u.Path, "/")
if path == "" {
u.Path = "/v1"
} else if path != "/v1" && !strings.HasSuffix(path, "/v1") {
u.Path = path + "/v1"
} else {
u.Path = path
}
u.RawQuery = ""
u.Fragment = ""
return u.String()
}

View File

@@ -9,7 +9,6 @@ var CLI struct {
cliContext.Context `embed:""`
Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
Chat ChatCMD `cmd:"" help:"Open an interactive chat session against a running LocalAI server"`
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
Backends BackendsCMD `cmd:"" help:"Manage LocalAI backends and definitions"`

View File

@@ -30,8 +30,6 @@ type RunCMD struct {
ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"`
ExternalBackends []string `env:"LOCALAI_EXTERNAL_BACKENDS,EXTERNAL_BACKENDS" help:"A list of external backends to load from gallery on boot" group:"backends"`
WebRTCNAT1To1IPs []string `env:"LOCALAI_WEBRTC_NAT_1TO1_IPS,WEBRTC_NAT_1TO1_IPS" help:"IPs advertised as the host ICE candidates for /v1/realtime WebRTC instead of every local interface. Set to the reachable host/LAN IP when running under Docker host networking or NAT, where pion otherwise offers unreachable bridge addresses and the connection drops after ICE consent checks fail." group:"api"`
WebRTCICEInterfaces []string `env:"LOCALAI_WEBRTC_ICE_INTERFACES,WEBRTC_ICE_INTERFACES" help:"Restrict /v1/realtime WebRTC ICE candidate gathering to these network interfaces (e.g. eth0), filtering out docker0/veth noise." group:"api"`
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
@@ -156,21 +154,11 @@ type RunCMD struct {
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
RegistrationRequireAuth bool `env:"LOCALAI_REGISTRATION_REQUIRE_AUTH" default:"false" help:"Fail startup when distributed mode is enabled but LOCALAI_REGISTRATION_TOKEN is empty (node endpoints and worker file-transfer server would otherwise be unauthenticated)" group:"distributed"`
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch: require BOTH NATS JWT credentials and a registration token when distributed mode is enabled (implies --nats-require-auth and --registration-require-auth)" group:"distributed"`
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
DistributedPrefixCache bool `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE" default:"true" help:"Enable prefix-cache-aware routing in distributed mode (default true). When false, routing falls back to round-robin." group:"distributed"`
DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
NatsAccountSeed string `env:"LOCALAI_NATS_ACCOUNT_SEED" help:"NATS account signing seed (SU...) used to mint per-node worker JWTs at registration" group:"distributed"`
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"NATS user JWT for the frontend (and agent workers) to publish control-plane messages" group:"distributed"`
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"NATS user signing seed (SU...) paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
NatsWorkerJWTTTL string `env:"LOCALAI_NATS_WORKER_JWT_TTL" help:"Lifetime of minted per-node NATS JWTs (e.g. 24h, default 24h)" group:"distributed"`
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT credentials (service JWT + account seed) when distributed mode is enabled" group:"distributed"`
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI); use with tls:// in --nats-url" group:"distributed"`
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
Version bool
@@ -227,8 +215,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithApiKeys(r.APIKeys),
config.WithModelsURL(append(r.Models, r.ModelArgs...)...),
config.WithExternalBackends(r.ExternalBackends...),
config.WithWebRTCNAT1To1IPs(r.WebRTCNAT1To1IPs...),
config.WithWebRTCICEInterfaces(r.WebRTCICEInterfaces...),
config.WithOpaqueErrors(r.OpaqueErrors),
config.WithEnforcedPredownloadScans(!r.DisablePredownloadScan),
config.WithSubtleKeyComparison(r.UseSubtleKeyComparison),
@@ -297,40 +283,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
if r.RegistrationToken != "" {
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
}
if r.RegistrationRequireAuth {
opts = append(opts, config.EnableRegistrationRequireAuth)
}
if r.DistributedRequireAuth {
opts = append(opts, config.EnableDistributedRequireAuth)
}
if r.NatsAccountSeed != "" {
opts = append(opts, config.WithNatsAccountSeed(r.NatsAccountSeed))
}
if r.NatsServiceJWT != "" {
opts = append(opts, config.WithNatsServiceJWT(r.NatsServiceJWT))
}
if r.NatsServiceSeed != "" {
opts = append(opts, config.WithNatsServiceSeed(r.NatsServiceSeed))
}
if r.NatsWorkerJWTTTL != "" {
d, err := time.ParseDuration(r.NatsWorkerJWTTTL)
if err != nil {
return fmt.Errorf("invalid LOCALAI_NATS_WORKER_JWT_TTL %q: %w", r.NatsWorkerJWTTTL, err)
}
opts = append(opts, config.WithNatsWorkerJWTTTL(d))
}
if r.NatsRequireAuth {
opts = append(opts, config.EnableNatsRequireAuth)
}
if r.NatsTLSCA != "" {
opts = append(opts, config.WithNatsTLSCA(r.NatsTLSCA))
}
if r.NatsTLSCert != "" {
opts = append(opts, config.WithNatsTLSCert(r.NatsTLSCert))
}
if r.NatsTLSKey != "" {
opts = append(opts, config.WithNatsTLSKey(r.NatsTLSKey))
}
if r.AutoApproveNodes {
opts = append(opts, config.EnableAutoApproveNodes)
}
@@ -656,12 +608,12 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
// waitForServerReady polls the given address until the HTTP server is
// accepting connections or the context is cancelled.
func waitForServerReady(address string, ctx context.Context) {
// Ensure the address has a host component for dialing.
// Echo accepts ":8080" but net.Dial needs a resolvable host.
host, port, err := net.SplitHostPort(address)
if err == nil && host == "" {
address = "127.0.0.1:" + port
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
for {
select {
@@ -669,17 +621,11 @@ func waitForServerReady(address string, ctx context.Context) {
return
default:
}
conn, err := net.DialTimeout("tcp", address, 500*time.Millisecond)
if err == nil {
conn.Close()
return
}
select {
case <-ctx.Done():
return
case <-ticker.C:
}
time.Sleep(250 * time.Millisecond)
}
}

View File

@@ -62,7 +62,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
options.Backend = t.Backend
options.Model = t.Model
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, "", nil, ml, opts, options)
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, ml, opts, options)
if err != nil {
return err
}

View File

@@ -96,7 +96,7 @@ func (r *VLLMDistributed) Run(ctx *cliContext.Context) error {
FrontendURL: r.RegisterTo,
RegistrationToken: r.RegistrationToken,
}
nodeID, _, _, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
nodeID, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
if regErr != nil {
return fmt.Errorf("registering with frontend: %w", regErr)
}

View File

@@ -58,77 +58,65 @@ func (c *RegistrationClient) setAuth(req *http.Request) {
// RegisterResponse is the JSON body returned by /api/node/register.
type RegisterResponse struct {
ID string `json:"id"`
Status string `json:"status,omitempty"` // "pending" until an admin approves the node
APIToken string `json:"api_token,omitempty"`
NatsJWT string `json:"nats_jwt,omitempty"`
NatsUserSeed string `json:"nats_user_seed,omitempty"`
ID string `json:"id"`
APIToken string `json:"api_token,omitempty"`
}
// RegisterFull sends a single registration request and returns the full
// response (node ID, approval status, and optional API token / NATS creds).
// Re-registration is idempotent: the frontend preserves the node row and mints
// a fresh NATS JWT each call, so this doubles as the credential-refresh call.
func (c *RegistrationClient) RegisterFull(ctx context.Context, body map[string]any) (*RegisterResponse, error) {
// Register sends a single registration request and returns the node ID and
// (optionally) an auto-provisioned API token.
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) {
jsonBody, _ := json.Marshal(body)
url := c.baseURL() + "/api/node/register"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
return "", "", fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
c.setAuth(req)
resp, err := c.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("posting to %s: %w", url, err)
return "", "", fmt.Errorf("posting to %s: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("registration failed with status %d", resp.StatusCode)
return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
}
var result RegisterResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("decoding response: %w", err)
return "", "", fmt.Errorf("decoding response: %w", err)
}
return &result, nil
}
// Register sends a single registration request and returns the node ID and
// optional credentials (API token for agent workers, NATS JWT when configured).
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
res, err := c.RegisterFull(ctx, body)
if err != nil {
return "", "", "", "", err
}
return res.ID, res.APIToken, res.NatsJWT, res.NatsUserSeed, nil
return result.ID, result.APIToken, nil
}
// RegisterWithRetry retries registration with exponential backoff.
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) {
backoff := 2 * time.Second
maxBackoff := 30 * time.Second
var nodeID, apiToken string
var err error
for attempt := 1; attempt <= maxRetries; attempt++ {
nodeID, apiToken, natsJWT, natsSeed, err = c.Register(ctx, body)
nodeID, apiToken, err = c.Register(ctx, body)
if err == nil {
return nodeID, apiToken, natsJWT, natsSeed, nil
return nodeID, apiToken, nil
}
if attempt == maxRetries {
return "", "", "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
}
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
select {
case <-ctx.Done():
return "", "", "", "", ctx.Err()
return "", "", ctx.Err()
case <-time.After(backoff):
}
backoff = min(backoff*2, maxBackoff)
}
return nodeID, apiToken, natsJWT, natsSeed, err
return nodeID, apiToken, err
}
// Heartbeat sends a single heartbeat POST with the given body.

View File

@@ -1,200 +0,0 @@
package workerregistry
import (
"context"
"fmt"
"sync"
"time"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/mudler/xlog"
)
// statusPending mirrors nodes.StatusPending. It is duplicated rather than
// imported so the lightweight registration client does not pull in the nodes
// package (and its gorm/DB dependencies).
const statusPending = "pending"
// defaultMaxAttempts bounds how many times Acquire registers (and how many
// consecutive times RefreshLoop may fail) before giving up. It is high enough
// to ride out a slow admin approval or a transient frontend outage, but finite
// so an unauthorized/unapprovable worker exits and surfaces the problem (via a
// non-zero exit and the resulting restart) rather than waiting forever.
const defaultMaxAttempts = 100
// RegisterFunc performs one idempotent registration round-trip.
type RegisterFunc func(ctx context.Context) (*RegisterResponse, error)
// NATSCredentialManager acquires NATS credentials at startup — waiting through
// admin approval when required — and refreshes them before the minted JWT
// expires, by re-registering (which mints a fresh JWT). The live NATS
// connection adopts a refreshed JWT on its next reconnect via Provider. Safe
// for concurrent use.
//
// It addresses two failure modes: a worker that needs credentials but registers
// while still pending approval (it would otherwise give up and never connect),
// and a long-running worker whose 24h JWT expires with no way to renew it.
type NATSCredentialManager struct {
register RegisterFunc
requireCreds bool // block until credentials are present (frontend minting in use)
// Tunables; defaults set by NewNATSCredentialManager, overridable in tests.
initialBackoff time.Duration
maxBackoff time.Duration
maxAttempts int // bound on Acquire attempts / consecutive refresh failures (<=0 = unlimited)
refreshLead float64 // refresh once this fraction of the JWT lifetime has elapsed
refreshRetry time.Duration
expiryOf func(jwt string) (time.Time, bool)
mu sync.RWMutex
jwt string
seed string
nodeID string
}
// NewNATSCredentialManager builds a manager over register. When requireCreds is
// true, Acquire blocks until the node is approved and credentials are minted.
func NewNATSCredentialManager(register RegisterFunc, requireCreds bool) *NATSCredentialManager {
return &NATSCredentialManager{
register: register,
requireCreds: requireCreds,
initialBackoff: 2 * time.Second,
maxBackoff: 30 * time.Second,
maxAttempts: defaultMaxAttempts,
refreshLead: 0.75,
refreshRetry: 30 * time.Second,
expiryOf: jwtExpiry,
}
}
// jwtExpiry decodes the expiry of a minted user JWT. ok is false when the token
// is empty/undecodable or carries no expiry (e.g. a non-expiring service JWT).
func jwtExpiry(token string) (time.Time, bool) {
if token == "" {
return time.Time{}, false
}
uc, err := natsauth.DecodeUserClaims(token)
if err != nil || uc.Expires == 0 {
return time.Time{}, false
}
return time.Unix(uc.Expires, 0), true
}
func (m *NATSCredentialManager) store(res *RegisterResponse) {
m.mu.Lock()
defer m.mu.Unlock()
m.nodeID = res.ID
if res.NatsJWT != "" && res.NatsUserSeed != "" {
m.jwt, m.seed = res.NatsJWT, res.NatsUserSeed
}
}
// Current returns the latest NATS credentials (both empty until acquired).
func (m *NATSCredentialManager) Current() (jwt, seed string) {
m.mu.RLock()
defer m.mu.RUnlock()
return m.jwt, m.seed
}
// NodeID returns the node ID from the most recent registration.
func (m *NATSCredentialManager) NodeID() string {
m.mu.RLock()
defer m.mu.RUnlock()
return m.nodeID
}
// Provider returns a callback compatible with messaging.WithUserJWTProvider,
// supplying the current credentials on each (re)connect.
func (m *NATSCredentialManager) Provider() func() (string, string) {
return m.Current
}
// HasCredentials reports whether complete NATS credentials have been obtained.
func (m *NATSCredentialManager) HasCredentials() bool {
jwt, seed := m.Current()
return jwt != "" && seed != ""
}
// Acquire registers and, when requireCreds is set, keeps re-registering with
// exponential backoff until the node is approved (status != pending) and
// credentials are minted. Without requireCreds it returns the first successful
// response (the historical one-shot behavior, preserved for anonymous NATS).
func (m *NATSCredentialManager) Acquire(ctx context.Context) (*RegisterResponse, error) {
backoff := m.initialBackoff
var lastReason error
for attempt := 1; m.maxAttempts <= 0 || attempt <= m.maxAttempts; attempt++ {
res, err := m.register(ctx)
switch {
case err != nil:
lastReason = err
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
case !m.requireCreds:
m.store(res)
return res, nil
case res.Status == statusPending:
lastReason = fmt.Errorf("node %s still pending admin approval", res.ID)
xlog.Info("Node pending admin approval; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
case res.NatsJWT == "" || res.NatsUserSeed == "":
lastReason = fmt.Errorf("node %s approved but NATS credentials not minted", res.ID)
xlog.Info("Node approved but NATS credentials not yet minted; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
default:
m.store(res)
return res, nil
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(backoff):
}
backoff = min(backoff*2, m.maxBackoff)
}
return nil, fmt.Errorf("giving up acquiring NATS credentials after %d attempts: %w", m.maxAttempts, lastReason)
}
// RefreshLoop re-registers to mint a fresh JWT before the current one expires,
// updating the credentials returned by Current/Provider so the NATS connection
// adopts them on its next reconnect. It returns nil when ctx is cancelled or
// when the current credential has no expiry (nothing to refresh), and a non-nil
// error after maxAttempts consecutive refresh failures — letting the caller
// exit the worker so it restarts and re-acquires (or surfaces the outage)
// rather than silently drifting toward an expired, unrenewable JWT.
func (m *NATSCredentialManager) RefreshLoop(ctx context.Context) error {
failures := 0
for {
jwt, _ := m.Current()
exp, ok := m.expiryOf(jwt)
if !ok {
xlog.Debug("NATS credential has no expiry; refresh loop exiting")
return nil
}
wait := max(time.Duration(float64(time.Until(exp))*m.refreshLead), 0)
select {
case <-ctx.Done():
return nil
case <-time.After(wait):
}
res, err := m.register(ctx)
if err == nil && res.NatsJWT != "" && res.NatsUserSeed != "" {
m.store(res)
failures = 0
xlog.Info("Refreshed NATS credentials", "node", res.ID)
continue
}
failures++
if err != nil {
xlog.Warn("NATS credential refresh failed; will retry", "attempt", failures, "error", err)
} else {
xlog.Warn("NATS credential refresh returned no credentials; will retry", "attempt", failures)
}
if m.maxAttempts > 0 && failures >= m.maxAttempts {
return fmt.Errorf("NATS credential refresh failed %d times in a row", failures)
}
// Back off before retrying so a persistent failure near expiry does not spin.
select {
case <-ctx.Done():
return nil
case <-time.After(m.refreshRetry):
}
}
}

View File

@@ -1,198 +0,0 @@
package workerregistry
import (
"context"
"sync"
"testing"
"time"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/nats-io/nkeys"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestWorkerRegistry(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "WorkerRegistry")
}
// fakeRegister returns a sequence of canned responses/errors, one per call, and
// records how many times it was invoked. The last entry repeats once exhausted.
type fakeRegister struct {
mu sync.Mutex
steps []step
calls int
}
type step struct {
res *RegisterResponse
err error
}
func (f *fakeRegister) fn() RegisterFunc {
return func(context.Context) (*RegisterResponse, error) {
f.mu.Lock()
defer f.mu.Unlock()
i := f.calls
f.calls++
if i >= len(f.steps) {
i = len(f.steps) - 1
}
return f.steps[i].res, f.steps[i].err
}
}
func (f *fakeRegister) count() int {
f.mu.Lock()
defer f.mu.Unlock()
return f.calls
}
var _ = Describe("NATSCredentialManager", func() {
approved := func(jwt, seed string) *RegisterResponse {
return &RegisterResponse{ID: "node-1", Status: "healthy", NatsJWT: jwt, NatsUserSeed: seed}
}
pending := &RegisterResponse{ID: "node-1", Status: "pending"}
Describe("Acquire (#4 — wait through admin approval)", func() {
It("keeps re-registering until the node is approved and credentials are minted", func() {
f := &fakeRegister{steps: []step{
{res: pending}, // not approved yet
{res: approved("", "")}, // approved but JWT not minted yet
{res: approved("jwt-1", "seed-1")}, // finally minted
}}
m := NewNATSCredentialManager(f.fn(), true /* requireCreds */)
m.initialBackoff = time.Millisecond
m.maxBackoff = time.Millisecond
res, err := m.Acquire(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(res.ID).To(Equal("node-1"))
Expect(f.count()).To(Equal(3))
jwt, seed := m.Current()
Expect(jwt).To(Equal("jwt-1"))
Expect(seed).To(Equal("seed-1"))
Expect(m.HasCredentials()).To(BeTrue())
Expect(m.NodeID()).To(Equal("node-1"))
})
It("returns immediately on the first success when credentials are not required (anonymous NATS)", func() {
f := &fakeRegister{steps: []step{{res: pending}}}
m := NewNATSCredentialManager(f.fn(), false /* requireCreds */)
res, err := m.Acquire(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(res.Status).To(Equal("pending"))
Expect(f.count()).To(Equal(1))
Expect(m.HasCredentials()).To(BeFalse())
})
It("aborts when the context is cancelled while waiting for approval", func() {
f := &fakeRegister{steps: []step{{res: pending}}}
m := NewNATSCredentialManager(f.fn(), true)
m.initialBackoff = 10 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := m.Acquire(ctx)
Expect(err).To(MatchError(context.Canceled))
})
It("gives up after a bounded number of attempts so the worker exits and alerts", func() {
f := &fakeRegister{steps: []step{{res: pending}}} // never approved
m := NewNATSCredentialManager(f.fn(), true)
m.initialBackoff = time.Millisecond
m.maxBackoff = time.Millisecond
m.maxAttempts = 5
_, err := m.Acquire(context.Background())
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("after 5 attempts"))
Expect(err.Error()).To(ContainSubstring("pending admin approval"))
Expect(f.count()).To(Equal(5))
})
})
Describe("RefreshLoop (#5 — renew before the JWT expires)", func() {
It("re-registers before expiry and updates the credentials served to new connections", func() {
f := &fakeRegister{steps: []step{{res: approved("jwt-2", "seed-2")}}}
m := NewNATSCredentialManager(f.fn(), true)
m.refreshLead = 0.5
m.refreshRetry = time.Millisecond
// jwt-1 expires soon; jwt-2 is long-lived so the loop then idles.
m.expiryOf = func(jwt string) (time.Time, bool) {
switch jwt {
case "jwt-1":
return time.Now().Add(40 * time.Millisecond), true
case "jwt-2":
return time.Now().Add(time.Hour), true
default:
return time.Time{}, false
}
}
m.store(approved("jwt-1", "seed-1"))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = m.RefreshLoop(ctx) }()
Eventually(func() string {
jwt, _ := m.Current()
return jwt
}, "2s", "10ms").Should(Equal("jwt-2"))
})
It("returns an error after the bounded number of consecutive failures so the caller can exit", func() {
f := &fakeRegister{steps: []step{{err: context.DeadlineExceeded}}} // refresh always fails
m := NewNATSCredentialManager(f.fn(), true)
m.refreshLead = 0.5
m.refreshRetry = time.Millisecond
m.maxAttempts = 3
m.expiryOf = func(string) (time.Time, bool) { return time.Now().Add(time.Millisecond), true }
m.store(approved("jwt-1", "seed-1"))
errCh := make(chan error, 1)
go func() { errCh <- m.RefreshLoop(context.Background()) }()
Eventually(errCh, "2s").Should(Receive(MatchError(ContainSubstring("3 times in a row"))))
})
It("exits promptly when the current credential has no expiry (nothing to refresh)", func() {
f := &fakeRegister{steps: []step{{res: approved("x", "y")}}}
m := NewNATSCredentialManager(f.fn(), true)
m.expiryOf = func(string) (time.Time, bool) { return time.Time{}, false }
m.store(approved("static", "seed"))
done := make(chan struct{})
go func() { _ = m.RefreshLoop(context.Background()); close(done) }()
Eventually(done, "1s").Should(BeClosed())
Expect(f.count()).To(Equal(0)) // never tried to re-register
})
})
Describe("jwtExpiry default", func() {
It("decodes the expiry of a real minted worker JWT", func() {
akp, err := nkeys.CreateAccount()
Expect(err).ToNot(HaveOccurred())
seed, err := akp.Seed()
Expect(err).ToNot(HaveOccurred())
cfg := natsauth.Config{AccountSeed: string(seed), WorkerJWTTTL: time.Hour}
token, _, err := cfg.MintWorkerJWT("node-1", "backend")
Expect(err).ToNot(HaveOccurred())
exp, ok := jwtExpiry(token)
Expect(ok).To(BeTrue())
Expect(exp).To(BeTemporally("~", time.Now().Add(time.Hour), 2*time.Minute))
})
It("reports no expiry for an empty or undecodable token", func() {
_, ok := jwtExpiry("")
Expect(ok).To(BeFalse())
_, ok = jwtExpiry("not-a-jwt")
Expect(ok).To(BeFalse())
})
})
})

View File

@@ -12,19 +12,10 @@ import (
)
type ApplicationConfig struct {
Context context.Context
ConfigFile string
SystemState *system.SystemState
ExternalBackends []string
// WebRTCNAT1To1IPs, when set, are advertised as the host ICE candidates for
// /v1/realtime WebRTC instead of every local interface address. Needed when
// the routable address differs from what pion gathers — e.g. Docker host
// networking (where pion also offers unreachable bridge IPs) or NAT.
WebRTCNAT1To1IPs []string
// WebRTCICEInterfaces, when set, restricts ICE candidate gathering to these
// network interfaces (e.g. eth0), filtering out docker0/veth noise.
WebRTCICEInterfaces []string
Context context.Context
ConfigFile string
SystemState *system.SystemState
ExternalBackends []string
UploadLimitMB, Threads, ContextSize int
F16 bool
Debug bool
@@ -90,6 +81,7 @@ type ApplicationConfig struct {
// file is mode 0600.
MITMCADir string
// PIIPatternOverrides applies persisted per-id deltas (action,
// disabled) to the live redactor at startup. Loaded from
// runtime_settings.json and applied right after pii.NewRedactor.
@@ -124,11 +116,11 @@ type ApplicationConfig struct {
// --require-backend-integrity / LOCALAI_REQUIRE_BACKEND_INTEGRITY.
RequireBackendIntegrity bool
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
WatchDogIdle bool
WatchDogBusy bool
WatchDog bool
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
WatchDogIdle bool
WatchDogBusy bool
WatchDog bool
// Memory Reclaimer settings (works with GPU if available, otherwise RAM)
MemoryReclaimerEnabled bool // Enable memory threshold monitoring
@@ -319,18 +311,6 @@ func WithExternalBackends(backends ...string) AppOption {
}
}
func WithWebRTCNAT1To1IPs(ips ...string) AppOption {
return func(o *ApplicationConfig) {
o.WebRTCNAT1To1IPs = ips
}
}
func WithWebRTCICEInterfaces(interfaces ...string) AppOption {
return func(o *ApplicationConfig) {
o.WebRTCICEInterfaces = interfaces
}
}
func WithMachineTag(tag string) AppOption {
return func(o *ApplicationConfig) {
o.MachineTag = tag
@@ -722,6 +702,7 @@ func WithMITMCADir(dir string) AppOption {
}
}
func WithDynamicConfigDir(dynamicConfigsDir string) AppOption {
return func(o *ApplicationConfig) {
o.DynamicConfigsDir = dynamicConfigsDir

View File

@@ -22,11 +22,9 @@ const (
UsecaseRerank = "rerank"
UsecaseDetection = "detection"
UsecaseVAD = "vad"
UsecaseAudioTransform = "audio_transform"
UsecaseDiarization = "diarization"
UsecaseRealtimeAudio = "realtime_audio"
UsecaseFaceRecognition = "face_recognition"
UsecaseSpeakerRecognition = "speaker_recognition"
UsecaseAudioTransform = "audio_transform"
UsecaseDiarization = "diarization"
UsecaseRealtimeAudio = "realtime_audio"
)
// GRPCMethod identifies a Backend service RPC from backend.proto.
@@ -49,11 +47,6 @@ const (
MethodAudioTransform GRPCMethod = "AudioTransform"
MethodDiarize GRPCMethod = "Diarize"
MethodAudioToAudioStream GRPCMethod = "AudioToAudioStream"
MethodFaceVerify GRPCMethod = "FaceVerify"
MethodFaceAnalyze GRPCMethod = "FaceAnalyze"
MethodVoiceVerify GRPCMethod = "VoiceVerify"
MethodVoiceEmbed GRPCMethod = "VoiceEmbed"
MethodVoiceAnalyze GRPCMethod = "VoiceAnalyze"
)
// UsecaseInfo describes a single known_usecase value and how it maps
@@ -161,16 +154,6 @@ var UsecaseInfoMap = map[string]UsecaseInfo{
GRPCMethod: MethodAudioToAudioStream,
Description: "Self-contained any-to-any audio model for the Realtime API — accepts microphone audio and emits speech + transcript (+ optional function calls) from a single backend via the AudioToAudioStream RPC.",
},
UsecaseFaceRecognition: {
Flag: FLAG_FACE_RECOGNITION,
GRPCMethod: MethodFaceVerify,
Description: "Face recognition — verify identity, analyze attributes (age/gender/emotion) via FaceVerify and FaceAnalyze RPCs.",
},
UsecaseSpeakerRecognition: {
Flag: FLAG_SPEAKER_RECOGNITION,
GRPCMethod: MethodVoiceVerify,
Description: "Speaker recognition — verify identity, embed and analyze voice via VoiceVerify, VoiceEmbed and VoiceAnalyze RPCs.",
},
}
// BackendCapability describes which gRPC methods and usecases a backend supports.
@@ -488,21 +471,6 @@ var BackendCapabilities = map[string]BackendCapability{
DefaultUsecases: []string{UsecaseDetection},
Description: "RF-DETR C++ object detection",
},
// --- Face and speaker recognition backends ---
"insightface": {
GRPCMethods: []GRPCMethod{MethodEmbedding, MethodDetect, MethodFaceVerify, MethodFaceAnalyze},
PossibleUsecases: []string{UsecaseEmbeddings, UsecaseDetection, UsecaseFaceRecognition},
DefaultUsecases: []string{UsecaseFaceRecognition},
AcceptsImages: true,
Description: "InsightFace — face detection, embedding, verification and attribute analysis",
},
"speaker-recognition": {
GRPCMethods: []GRPCMethod{MethodVoiceVerify, MethodVoiceEmbed, MethodVoiceAnalyze},
PossibleUsecases: []string{UsecaseSpeakerRecognition},
DefaultUsecases: []string{UsecaseSpeakerRecognition},
Description: "Speaker recognition — voice identity verification and analysis",
},
"silero-vad": {
GRPCMethods: []GRPCMethod{MethodVAD},
PossibleUsecases: []string{UsecaseVAD},

View File

@@ -5,8 +5,6 @@ import (
"fmt"
"time"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/mudler/xlog"
)
@@ -18,29 +16,7 @@ type DistributedConfig struct {
NatsURL string // --nats-url / LOCALAI_NATS_URL
StorageURL string // --storage-url / LOCALAI_STORAGE_URL (S3 endpoint)
RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration)
// RegistrationRequireAuth fails startup when distributed mode is enabled but
// RegistrationToken is empty. The default (false) keeps the historical
// fail-open behavior with a loud warning; production should set it so the
// node-register endpoints and the worker file-transfer server cannot run
// unauthenticated. Mirrors NatsRequireAuth for the NATS bus.
RegistrationRequireAuth bool // LOCALAI_REGISTRATION_REQUIRE_AUTH
// RequireAuth is the umbrella switch (LOCALAI_DISTRIBUTED_REQUIRE_AUTH) for
// distributed-mode auth: when true it implies BOTH NatsRequireAuth and
// RegistrationRequireAuth, so a single knob locks down the bus and the
// registration/file-transfer layer together. The granular flags remain
// available to enforce just one layer.
RequireAuth bool // LOCALAI_DISTRIBUTED_REQUIRE_AUTH
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
// NATS JWT auth (optional; see pkg/natsauth and docs/features/distributed-mode.md)
NatsAccountSeed string // LOCALAI_NATS_ACCOUNT_SEED — account signing seed to mint per-node worker JWTs
NatsServiceJWT string // LOCALAI_NATS_SERVICE_JWT — user JWT for frontends / agent workers
NatsServiceSeed string // LOCALAI_NATS_SERVICE_SEED — signing seed paired with service JWT
NatsWorkerJWTTTL time.Duration // LOCALAI_NATS_WORKER_JWT_TTL — minted worker JWT lifetime (default 24h)
NatsRequireAuth bool // LOCALAI_NATS_REQUIRE_AUTH — fail startup if NATS credentials are missing
NatsTLSCA string // LOCALAI_NATS_TLS_CA — PEM file for private CA (server verify)
NatsTLSCert string // LOCALAI_NATS_TLS_CERT — client cert for NATS mTLS
NatsTLSKey string // LOCALAI_NATS_TLS_KEY — client key paired with NatsTLSCert
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
// S3 configuration (used when StorageURL is set)
StorageBucket string // --storage-bucket / LOCALAI_STORAGE_BUCKET
@@ -100,23 +76,10 @@ func (c DistributedConfig) Validate() error {
(c.StorageAccessKey == "" && c.StorageSecretKey != "") {
return fmt.Errorf("storage-access-key and storage-secret-key must both be set or both empty")
}
// The registration token guards both the node HTTP register/heartbeat
// endpoints and the worker file-transfer server (which fails open on an
// empty token). Enforce it when registration auth is required (the granular
// flag or the umbrella); otherwise warn.
// Warn about missing registration token (not an error)
if c.RegistrationToken == "" {
if c.RegistrationAuthRequired() {
return fmt.Errorf("registration auth is required (LOCALAI_REGISTRATION_REQUIRE_AUTH or LOCALAI_DISTRIBUTED_REQUIRE_AUTH) but LOCALAI_REGISTRATION_TOKEN is empty")
}
xlog.Warn("distributed mode running without registration token — node endpoints and the worker file-transfer server are unprotected; set LOCALAI_REGISTRATION_TOKEN, or LOCALAI_DISTRIBUTED_REQUIRE_AUTH=true to fail closed")
xlog.Warn("distributed mode running without registration token — node endpoints are unprotected")
}
if err := c.NatsAuthConfig().Validate(); err != nil {
return err
}
if err := c.NatsTLSFiles().Validate(); err != nil {
return err
}
c.NatsAuthConfig().WarnIfInsecure(true)
// Check for negative durations
for name, d := range map[string]time.Duration{
FlagMCPToolTimeout: c.MCPToolTimeout,
@@ -160,76 +123,6 @@ func WithRegistrationToken(token string) AppOption {
}
}
func WithNatsAccountSeed(seed string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsAccountSeed = seed
}
}
func WithNatsServiceJWT(jwt string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsServiceJWT = jwt
}
}
func WithNatsServiceSeed(seed string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsServiceSeed = seed
}
}
func WithNatsWorkerJWTTTL(d time.Duration) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsWorkerJWTTTL = d
}
}
var EnableNatsRequireAuth = func(o *ApplicationConfig) {
o.Distributed.NatsRequireAuth = true
}
// EnableRegistrationRequireAuth makes an empty registration token a hard error
// in distributed mode (see DistributedConfig.RegistrationRequireAuth).
var EnableRegistrationRequireAuth = func(o *ApplicationConfig) {
o.Distributed.RegistrationRequireAuth = true
}
// EnableDistributedRequireAuth is the umbrella switch implying both
// NatsRequireAuth and RegistrationRequireAuth (see DistributedConfig.RequireAuth).
var EnableDistributedRequireAuth = func(o *ApplicationConfig) {
o.Distributed.RequireAuth = true
}
// RegistrationAuthRequired reports whether an empty registration token must be
// treated as a fatal misconfiguration — the granular flag or the umbrella.
func (c DistributedConfig) RegistrationAuthRequired() bool {
return c.RegistrationRequireAuth || c.RequireAuth
}
// NatsAuthRequired reports whether NATS JWT credentials must be present — the
// granular flag or the umbrella.
func (c DistributedConfig) NatsAuthRequired() bool {
return c.NatsRequireAuth || c.RequireAuth
}
func WithNatsTLSCA(path string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsTLSCA = path
}
}
func WithNatsTLSCert(path string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsTLSCert = path
}
}
func WithNatsTLSKey(path string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsTLSKey = path
}
}
func WithStorageURL(url string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.StorageURL = url
@@ -324,44 +217,6 @@ const (
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
const DefaultMaxUploadSize int64 = 50 << 30
// NatsTLSFiles returns NATS TLS/mTLS PEM paths for the messaging client.
func (c DistributedConfig) NatsTLSFiles() messaging.TLSFiles {
return messaging.TLSFiles{
CA: c.NatsTLSCA,
Cert: c.NatsTLSCert,
Key: c.NatsTLSKey,
}
}
// NatsMessagingOptions builds messaging client options (JWT + TLS) for distributed components.
// Pass explicit userJWT/userSeed when set (e.g. worker overrides); empty uses service JWT from config.
func (c DistributedConfig) NatsMessagingOptions(userJWT, userSeed string) []messaging.Option {
var opts []messaging.Option
jwt, seed := userJWT, userSeed
if jwt == "" && seed == "" {
auth := c.NatsAuthConfig()
jwt, seed = auth.ServiceUserJWT, auth.ServiceUserSeed
}
if jwt != "" && seed != "" {
opts = append(opts, messaging.WithUserJWT(jwt, seed))
}
if tls := c.NatsTLSFiles(); tls.Enabled() {
opts = append(opts, messaging.WithTLS(tls))
}
return opts
}
// NatsAuthConfig builds pkg/natsauth settings from distributed configuration.
func (c DistributedConfig) NatsAuthConfig() natsauth.Config {
return natsauth.Config{
AccountSeed: c.NatsAccountSeed,
ServiceUserJWT: c.NatsServiceJWT,
ServiceUserSeed: c.NatsServiceSeed,
WorkerJWTTTL: c.NatsWorkerJWTTTL,
RequireAuth: c.NatsAuthRequired(),
}
}
// BackendInstallTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration {
return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout)

View File

@@ -88,66 +88,3 @@ var _ = Describe("DistributedConfig.Validate negative-duration errors", func() {
Expect(c.Validate()).To(Succeed())
})
})
var _ = Describe("DistributedConfig.Validate registration auth", func() {
It("rejects an empty registration token when RequireAuth is set", func() {
c := config.DistributedConfig{
Enabled: true,
NatsURL: "nats://localhost:4222",
RegistrationRequireAuth: true,
}
err := c.Validate()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_REQUIRE_AUTH"))
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_TOKEN"))
})
It("accepts a set registration token when RequireAuth is set", func() {
c := config.DistributedConfig{
Enabled: true,
NatsURL: "nats://localhost:4222",
RegistrationToken: "s3cret",
RegistrationRequireAuth: true,
}
Expect(c.Validate()).To(Succeed())
})
It("warns but succeeds with an empty token when RequireAuth is unset", func() {
c := config.DistributedConfig{
Enabled: true,
NatsURL: "nats://localhost:4222",
}
Expect(c.Validate()).To(Succeed())
})
It("rejects an empty token when the umbrella RequireAuth is set", func() {
c := config.DistributedConfig{
Enabled: true,
NatsURL: "nats://localhost:4222",
RequireAuth: true,
// Provide NATS creds so only the registration-token gap remains.
NatsServiceJWT: "jwt",
NatsServiceSeed: "seed",
NatsAccountSeed: "acct",
}
err := c.Validate()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("LOCALAI_DISTRIBUTED_REQUIRE_AUTH"))
Expect(err.Error()).To(ContainSubstring("LOCALAI_REGISTRATION_TOKEN"))
})
It("the umbrella implies NATS auth is required", func() {
c := config.DistributedConfig{
Enabled: true,
NatsURL: "nats://localhost:4222",
RegistrationToken: "tok", // registration layer satisfied
RequireAuth: true, // umbrella → NATS creds now required
}
Expect(c.NatsAuthRequired()).To(BeTrue())
Expect(c.RegistrationAuthRequired()).To(BeTrue())
// Missing NATS service JWT/seed must now be fatal.
err := c.Validate()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("LOCALAI_NATS_REQUIRE_AUTH"))
})
})

View File

@@ -39,21 +39,7 @@ func llamaCppDefaults(cfg *ModelConfig, modelPath string) {
}
}()
// Startup parses every model's GGUF header to guess defaults. We only need
// scalar metadata (architecture, head/ff counts, chat_template, token IDs,
// MTP head) plus array *lengths* — never the array *contents*. Two options
// keep this cheap, which matters when many models live on slow storage such
// as a Docker volume (see https://github.com/mudler/LocalAI/issues/9790):
//
// - SkipLargeMetadata: seek past large array-valued metadata (the tokenizer
// vocab: tokenizer.ggml.tokens/scores/merges, often >100k entries) instead
// of reading and allocating every element. Lengths stay populated.
// - UseMMap: read the header via a memory map so faulting in a few pages
// replaces hundreds of thousands of tiny read() syscalls (measured ~524k
// -> 8 for a 256k-token vocab), the dominant cost on slow filesystems.
//
// The mapping is released when ParseGGUFFile returns.
f, err := gguf.ParseGGUFFile(guessPath, gguf.UseMMap(), gguf.SkipLargeMetadata())
f, err := gguf.ParseGGUFFile(guessPath)
if err == nil {
guessGGUFFromFile(cfg, f, 0)
}

View File

@@ -1,76 +1,13 @@
package config_test
import (
"bytes"
"encoding/binary"
"os"
"path/filepath"
. "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
gguf "github.com/gpustack/gguf-parser-go"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// GGUF metadata value type tags (see github.com/gpustack/gguf-parser-go).
const (
ggufTypeUint32 uint32 = 4
ggufTypeString uint32 = 8
ggufTypeArray uint32 = 9
)
// writeTestGGUF emits a minimal but valid little-endian GGUF v3 header carrying
// the scalar metadata the llama-cpp hook guesses from plus a large string vocab
// array (tokenizer.ggml.tokens). The big array is exactly what SkipLargeMetadata
// + UseMMap are expected to avoid reading element-by-element, so it must survive a
// round-trip through the real hook without corrupting the guessed defaults.
func writeTestGGUF(path, chatTemplate string, vocab int) error {
wStr := func(b *bytes.Buffer, s string) {
binary.Write(b, binary.LittleEndian, uint64(len(s)))
b.WriteString(s)
}
kvStr := func(b *bytes.Buffer, k, v string) {
wStr(b, k)
binary.Write(b, binary.LittleEndian, ggufTypeString)
wStr(b, v)
}
kvU32 := func(b *bytes.Buffer, k string, v uint32) {
wStr(b, k)
binary.Write(b, binary.LittleEndian, ggufTypeUint32)
binary.Write(b, binary.LittleEndian, v)
}
var meta bytes.Buffer
kvStr(&meta, "general.architecture", "llama")
kvStr(&meta, "general.name", "ReproModel")
kvU32(&meta, "llama.context_length", 4096)
kvU32(&meta, "llama.attention.head_count", 32)
kvU32(&meta, "llama.feed_forward_length", 11008)
kvU32(&meta, "llama.block_count", 32)
kvU32(&meta, "tokenizer.ggml.bos_token_id", 1)
kvStr(&meta, "tokenizer.chat_template", chatTemplate)
// large array value — the one the optimization skips reading
wStr(&meta, "tokenizer.ggml.tokens")
binary.Write(&meta, binary.LittleEndian, ggufTypeArray)
binary.Write(&meta, binary.LittleEndian, ggufTypeString)
binary.Write(&meta, binary.LittleEndian, uint64(vocab))
for i := 0; i < vocab; i++ {
wStr(&meta, "token")
}
var out bytes.Buffer
binary.Write(&out, binary.LittleEndian, gguf.GGUFMagicGGUFLe)
binary.Write(&out, binary.LittleEndian, uint32(3)) // version
binary.Write(&out, binary.LittleEndian, uint64(0)) // tensor count
binary.Write(&out, binary.LittleEndian, uint64(9)) // metadata kv count
out.Write(meta.Bytes())
return os.WriteFile(path, out.Bytes(), 0o644)
}
var _ = Describe("Backend hooks and parser defaults", func() {
Context("MatchParserDefaults", func() {
It("matches Qwen3 family", func() {
@@ -200,58 +137,6 @@ var _ = Describe("Backend hooks and parser defaults", func() {
})
})
Context("llamaCppDefaults GGUF guessing", func() {
// Regression coverage for https://github.com/mudler/LocalAI/issues/9790:
// the hook reads GGUF headers with SkipLargeMetadata + UseMMap to avoid
// pulling the whole tokenizer vocab off (slow) disk on every startup. This
// verifies that skipping the vocab array still yields the correct guessed
// defaults from the remaining scalar metadata.
const chatTemplate = "{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}"
It("guesses defaults from a GGUF whose large vocab is skipped", func() {
dir := GinkgoT().TempDir()
modelFile := "repro.gguf"
Expect(writeTestGGUF(filepath.Join(dir, modelFile), chatTemplate, 50000)).To(Succeed())
// A pre-set context size short-circuits the GGUF run-estimate, which
// needs full tensor info this header-only fixture deliberately omits;
// the metadata-reading path the optimization touches is unaffected.
ctxSize := 4096
cfg := &ModelConfig{
Backend: "llama-cpp",
LLMConfig: LLMConfig{ContextSize: &ctxSize},
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{Model: modelFile},
},
}
cfg.SetDefaults(ModelPath(dir))
// chat_template is a scalar string, not part of the skipped array,
// so it must be captured verbatim.
Expect(cfg.GetModelTemplate()).To(Equal(chatTemplate))
// scalar-derived defaults are still applied
Expect(cfg.ContextSize).NotTo(BeNil())
Expect(cfg.NGPULayers).NotTo(BeNil())
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
Expect(cfg.KnownUsecaseStrings).To(ContainElement("FLAG_CHAT"))
})
It("falls back to the default context size when the GGUF is unreadable", func() {
dir := GinkgoT().TempDir()
Expect(os.WriteFile(filepath.Join(dir, "bad.gguf"), []byte("not a gguf"), 0o644)).To(Succeed())
cfg := &ModelConfig{
Backend: "llama-cpp",
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{Model: "bad.gguf"},
},
}
cfg.SetDefaults(ModelPath(dir))
Expect(cfg.ContextSize).NotTo(BeNil())
})
})
Context("PromptCacheAll default", func() {
It("defaults to true when omitted from YAML", func() {
cfg := &ModelConfig{}

View File

@@ -128,22 +128,6 @@ func DefaultRegistry() map[string]FieldMetaOverride {
Advanced: true,
Order: 21,
},
"reasoning_effort": {
Section: "llm",
Label: "Reasoning Effort",
Description: "Default reasoning effort, forwarded to the backend as the reasoning_effort chat_template_kwarg (jinja models like gpt-oss / LFM2.5 honor it). A per-request reasoning_effort overrides it. 'none' also turns thinking off.",
Component: "select",
Options: []FieldOption{
{Value: "", Label: "Unset (model default)"},
{Value: "none", Label: "none (disable thinking)"},
{Value: "minimal", Label: "minimal"},
{Value: "low", Label: "low"},
{Value: "medium", Label: "medium"},
{Value: "high", Label: "high"},
},
Advanced: true,
Order: 22,
},
"cache_type_k": {
Section: "llm",
Label: "KV Cache Type (K)",
@@ -293,56 +277,6 @@ func DefaultRegistry() map[string]FieldMetaOverride {
AutocompleteProvider: ProviderModelsVAD,
Order: 63,
},
"pipeline.reasoning_effort": {
Section: "pipeline",
Label: "Reasoning Effort",
Description: "Reasoning effort for the pipeline's LLM, forwarded to the backend as the reasoning_effort chat_template_kwarg (jinja models like gpt-oss / LFM2.5 honor it). Overrides the LLM model's own reasoning_effort. 'none' also turns thinking off.",
Component: "select",
Options: []FieldOption{
{Value: "", Label: "Default (model config)"},
{Value: "none", Label: "none (disable thinking)"},
{Value: "minimal", Label: "minimal"},
{Value: "low", Label: "low"},
{Value: "medium", Label: "medium"},
{Value: "high", Label: "high"},
},
Order: 64,
},
"pipeline.disable_thinking": {
Section: "pipeline",
Label: "Disable Thinking",
Description: "Suppress reasoning/thinking output from the pipeline LLM (sets enable_thinking=false on the underlying model). Use for models that emit <think> blocks you don't want spoken or streamed back to the realtime client.",
Component: "toggle",
Order: 65,
},
"pipeline.streaming.llm": {
Section: "pipeline",
Label: "Stream LLM",
Description: "Stream LLM tokens to the realtime client as they are generated instead of waiting for the full response. Emits incremental response.output_audio_transcript.delta / text deltas.",
Component: "toggle",
Order: 66,
},
"pipeline.streaming.tts": {
Section: "pipeline",
Label: "Stream TTS",
Description: "Stream synthesized audio chunks to the realtime client as they are produced (requires a TTS backend that implements TTSStream). Falls back to unary synthesis otherwise.",
Component: "toggle",
Order: 67,
},
"pipeline.streaming.transcription": {
Section: "pipeline",
Label: "Stream Transcription",
Description: "Stream partial transcription text to the realtime client as the STT backend produces it (requires a transcription backend that implements AudioTranscriptionStream). Falls back to unary transcription otherwise.",
Component: "toggle",
Order: 68,
},
"pipeline.streaming.clause_chunking": {
Section: "pipeline",
Label: "Clause Chunking",
Description: "Split the streamed reply into speakable clauses and synthesize each as soon as it completes, instead of buffering the whole message before TTS — lower time-to-first-audio. Script-aware (handles CJK 。!? and Thai/Lao spaces), so it does not whitespace-split. Requires Stream LLM; off buffers the whole message.",
Component: "toggle",
Order: 69,
},
// --- Functions ---
"function.grammar.parallel_calls": {

View File

@@ -63,13 +63,6 @@ type ModelConfig struct {
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
// ReasoningEffort is the default reasoning effort (none|minimal|low|medium|high)
// for this model. A per-request reasoning_effort overrides it. It is forwarded
// to the backend as the reasoning_effort chat_template_kwarg (see
// gRPCPredictOpts), so jinja-templated models that key on it — e.g. gpt-oss
// (Harmony) or LFM2.5 — honor it; "none" also toggles enable_thinking off.
ReasoningEffort string `yaml:"reasoning_effort,omitempty" json:"reasoning_effort,omitempty"`
FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
// LLM configs (GPT4ALL, Llama.cpp, ...)
LLMConfig `yaml:",inline" json:",inline"`
@@ -494,85 +487,6 @@ type Pipeline struct {
LLM string `yaml:"llm,omitempty" json:"llm,omitempty"`
Transcription string `yaml:"transcription,omitempty" json:"transcription,omitempty"`
VAD string `yaml:"vad,omitempty" json:"vad,omitempty"`
// ReasoningEffort sets the reasoning effort (none|minimal|low|medium|high) for
// the pipeline's LLM without editing the LLM model config. Overrides the LLM's
// own reasoning_effort. Unset leaves the LLM model config in charge.
ReasoningEffort string `yaml:"reasoning_effort,omitempty" json:"reasoning_effort,omitempty"`
// Streaming opts each pipeline stage into incremental delivery (LLM tokens,
// TTS audio chunks, transcription text). Unset stages keep the blocking
// unary path, so existing configs are unaffected.
Streaming PipelineStreaming `yaml:"streaming,omitempty" json:"streaming,omitempty"`
// DisableThinking suppresses reasoning/thinking for the pipeline LLM (maps
// to enable_thinking=false backend metadata) without editing the underlying
// LLM model config. Unset leaves the LLM model config in charge.
DisableThinking *bool `yaml:"disable_thinking,omitempty" json:"disable_thinking,omitempty"`
}
// ApplyReasoningEffort resolves the effective reasoning effort — a per-request
// value (requestEffort) overrides the config's own ReasoningEffort default —
// stores it on the config so gRPCPredictOpts forwards it to the backend as the
// reasoning_effort chat_template_kwarg, and maps it onto the enable_thinking
// toggle the backend also reads:
// - "none" always disables thinking.
// - any explicit level enables it, UNLESS the config already disabled reasoning
// (an operator's explicit disable wins over a request asking to think).
//
// An empty requestEffort keeps the config's own default. With no effort set
// anywhere it is a no-op, leaving the model's reasoning settings untouched.
func (c *ModelConfig) ApplyReasoningEffort(requestEffort string) {
effort := requestEffort
if effort == "" {
effort = c.ReasoningEffort
}
c.ReasoningEffort = effort
switch strings.ToLower(effort) {
case "none":
disable := true
c.ReasoningConfig.DisableReasoning = &disable
case "minimal", "low", "medium", "high":
if c.ReasoningConfig.DisableReasoning == nil || !*c.ReasoningConfig.DisableReasoning {
enable := false
c.ReasoningConfig.DisableReasoning = &enable
}
}
}
// @Description PipelineStreaming toggles incremental delivery per realtime stage.
type PipelineStreaming struct {
LLM *bool `yaml:"llm,omitempty" json:"llm,omitempty"`
TTS *bool `yaml:"tts,omitempty" json:"tts,omitempty"`
Transcription *bool `yaml:"transcription,omitempty" json:"transcription,omitempty"`
// ClauseChunking splits the streamed LLM reply into speakable clauses and
// synthesizes each as soon as it completes, instead of buffering the whole
// message before TTS. Script-aware (CJK/Thai), so it does not rely on
// whitespace sentence boundaries. Requires LLM streaming; unset buffers the
// whole message (today's default).
ClauseChunking *bool `yaml:"clause_chunking,omitempty" json:"clause_chunking,omitempty"`
}
// StreamLLM reports whether LLM tokens should be streamed for this pipeline.
func (p Pipeline) StreamLLM() bool { return p.Streaming.LLM != nil && *p.Streaming.LLM }
// StreamTTS reports whether TTS audio should be streamed for this pipeline.
func (p Pipeline) StreamTTS() bool { return p.Streaming.TTS != nil && *p.Streaming.TTS }
// StreamTranscription reports whether transcription text should be streamed.
func (p Pipeline) StreamTranscription() bool {
return p.Streaming.Transcription != nil && *p.Streaming.Transcription
}
// ChunkClauses reports whether the streamed reply should be split into
// script-aware clauses and synthesized incrementally rather than buffered whole.
func (p Pipeline) ChunkClauses() bool {
return p.Streaming.ClauseChunking != nil && *p.Streaming.ClauseChunking
}
// ThinkingDisabled reports whether the pipeline forces the LLM's thinking off.
func (p Pipeline) ThinkingDisabled() bool {
return p.DisableThinking != nil && *p.DisableThinking
}
// @Description File configuration for model downloads

View File

@@ -30,26 +30,11 @@ func MTPSpecOptions() []string {
return out
}
// isDraftOnlyAssistantArch reports whether an architecture names a standalone
// MTP *draft* model rather than a self-speculating trunk. Upstream's Gemma4 MTP
// (ggml-org/llama.cpp#23398) registers the head as a separate `gemma4-assistant`
// architecture whose GGUF still carries `nextn_predict_layers`, but which cannot
// run alone: it requires a paired target context (`ctx_other`). Such archs must
// not trigger the embedded-head self-speculation defaults. The `-assistant`
// suffix is upstream's naming convention for these draft-only checkpoints.
func isDraftOnlyAssistantArch(arch string) bool {
return strings.HasSuffix(arch, "-assistant")
}
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a self-speculating
// Multi-Token Prediction head. Detection reads `<arch>.nextn_predict_layers`,
// which is what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a Multi-Token
// Prediction head. Detection reads `<arch>.nextn_predict_layers`, which is
// what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
// `conversion/qwen.py` MTP mixin. A positive layer count means the head is
// present in the same GGUF as the trunk.
//
// Draft-only assistant architectures (e.g. Gemma4's `gemma4-assistant`) carry
// the same key but are separate draft checkpoints meant to be paired with a
// target model, so they are deliberately excluded here.
func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
if f == nil {
return 0, false
@@ -58,9 +43,6 @@ func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
if arch == "" {
return 0, false
}
if isDraftOnlyAssistantArch(arch) {
return 0, false
}
v, ok := f.Header.MetadataKV.Get(arch + ".nextn_predict_layers")
if !ok {
return 0, false

View File

@@ -3,33 +3,10 @@ package config_test
import (
. "github.com/mudler/LocalAI/core/config"
gguf "github.com/gpustack/gguf-parser-go"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// ggufWithArch fabricates a minimal in-memory GGUF carrying the given
// `general.architecture` and a positive `<arch>.nextn_predict_layers` count,
// so HasEmbeddedMTPHead can be exercised without a real model file.
func ggufWithArch(arch string, nextn uint32) *gguf.GGUFFile {
return &gguf.GGUFFile{
Header: gguf.GGUFHeader{
MetadataKV: gguf.GGUFMetadataKVs{
{
Key: "general.architecture",
ValueType: gguf.GGUFMetadataValueTypeString,
Value: arch,
},
{
Key: arch + ".nextn_predict_layers",
ValueType: gguf.GGUFMetadataValueTypeUint32,
Value: nextn,
},
},
},
}
}
var _ = Describe("MTP auto-defaults", func() {
Context("MTPSpecOptions", func() {
It("returns the upstream-recommended speculative tuple", func() {
@@ -105,20 +82,5 @@ var _ = Describe("MTP auto-defaults", func() {
Expect(ok).To(BeFalse())
Expect(n).To(BeZero())
})
It("detects a same-GGUF embedded head (DeepSeek/Qwen style)", func() {
n, ok := HasEmbeddedMTPHead(ggufWithArch("qwen3moe", 1))
Expect(ok).To(BeTrue())
Expect(n).To(Equal(uint32(1)))
})
It("ignores a gemma4-assistant draft-only model", func() {
// The assistant GGUF carries nextn_predict_layers but is a separate
// draft model that requires a paired target (ctx_other); it cannot
// self-speculate, so it must not trigger the embedded-head defaults.
n, ok := HasEmbeddedMTPHead(ggufWithArch("gemma4-assistant", 48))
Expect(ok).To(BeFalse())
Expect(n).To(BeZero())
})
})
})

View File

@@ -1,57 +0,0 @@
package config
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gopkg.in/yaml.v3"
)
// The realtime pipeline can stream each stage (LLM tokens, TTS audio,
// transcription text) and can disable model "thinking" for the LLM. These are
// opt-in per pipeline; everything defaults to off so existing configs keep the
// unary behaviour.
var _ = Describe("Pipeline streaming config", func() {
It("defaults every streaming + thinking helper to false when unset", func() {
var p Pipeline
Expect(p.StreamLLM()).To(BeFalse())
Expect(p.StreamTTS()).To(BeFalse())
Expect(p.StreamTranscription()).To(BeFalse())
Expect(p.ChunkClauses()).To(BeFalse())
Expect(p.ThinkingDisabled()).To(BeFalse())
})
It("parses the nested streaming block and disable_thinking from YAML", func() {
var c ModelConfig
err := yaml.Unmarshal([]byte(`
name: gpt-realtime
pipeline:
llm: my-llm
tts: my-tts
transcription: my-stt
streaming:
llm: true
tts: true
transcription: true
clause_chunking: true
disable_thinking: true
`), &c)
Expect(err).ToNot(HaveOccurred())
Expect(c.Pipeline.StreamLLM()).To(BeTrue())
Expect(c.Pipeline.StreamTTS()).To(BeTrue())
Expect(c.Pipeline.StreamTranscription()).To(BeTrue())
Expect(c.Pipeline.ChunkClauses()).To(BeTrue())
Expect(c.Pipeline.ThinkingDisabled()).To(BeTrue())
})
It("treats an explicit false in the streaming block as disabled", func() {
var c ModelConfig
err := yaml.Unmarshal([]byte(`
name: gpt-realtime
pipeline:
streaming:
tts: false
`), &c)
Expect(err).ToNot(HaveOccurred())
Expect(c.Pipeline.StreamTTS()).To(BeFalse())
})
})

View File

@@ -1,52 +0,0 @@
package config_test
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/config"
)
// ApplyReasoningEffort resolves the effective reasoning effort (request value
// overrides the model config default), stores it on the config so it reaches the
// backend, and maps it onto the enable_thinking toggle.
var _ = Describe("ModelConfig.ApplyReasoningEffort", func() {
It("uses the request value over the config default", func() {
c := &config.ModelConfig{ReasoningEffort: "high"}
c.ApplyReasoningEffort("none")
Expect(c.ReasoningEffort).To(Equal("none"))
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
})
It("falls back to the config default when the request omits it", func() {
c := &config.ModelConfig{ReasoningEffort: "none"}
c.ApplyReasoningEffort("")
Expect(c.ReasoningEffort).To(Equal("none"))
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
})
It("enables thinking for an explicit effort level", func() {
c := &config.ModelConfig{}
c.ApplyReasoningEffort("medium")
Expect(c.ReasoningEffort).To(Equal("medium"))
Expect(c.ReasoningConfig.DisableReasoning).ToNot(BeNil())
Expect(*c.ReasoningConfig.DisableReasoning).To(BeFalse())
})
It("does not let a level override an operator's config-level disable", func() {
disabled := true
c := &config.ModelConfig{}
c.ReasoningConfig.DisableReasoning = &disabled
c.ApplyReasoningEffort("high")
Expect(*c.ReasoningConfig.DisableReasoning).To(BeTrue())
})
It("is a no-op on the toggle when no effort is set anywhere", func() {
c := &config.ModelConfig{}
c.ApplyReasoningEffort("")
Expect(c.ReasoningEffort).To(Equal(""))
Expect(c.ReasoningConfig.DisableReasoning).To(BeNil())
})
})

View File

@@ -420,9 +420,8 @@ func API(application *application.Application) (*echo.Echo, error) {
remoteUnloader = d.Router.Unloader()
}
}
natsCfg := distCfg.NatsAuthConfig()
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, natsCfg)
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken, natsCfg)
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
// Distributed SSE routes (job progress + agent events via NATS)
if d := application.Distributed(); d != nil {

View File

@@ -383,13 +383,13 @@ var _ = Describe("API test", func() {
Expect(err).ToNot(HaveOccurred())
go func() {
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
xlog.Error("server error", "error", err)
}
}()
defaultConfig := openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = testHTTPBase + "/v1"
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
@@ -418,7 +418,7 @@ var _ = Describe("API test", func() {
Context("Auth Tests", func() {
It("Should fail if the api key is missing", func() {
err, sc := postInvalidRequest(testHTTPBase + "/models/available")
err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available")
Expect(err).ToNot(BeNil())
Expect(sc).To(Equal(401))
})
@@ -427,7 +427,7 @@ var _ = Describe("API test", func() {
Context("URL routing Tests", func() {
It("Should support reverse-proxy when unauthenticated", func() {
err, sc, body := getRequest(testHTTPBase+"/myprefix/", http.Header{
err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
"X-Forwarded-Proto": {"https"},
"X-Forwarded-Host": {"example.org"},
"X-Forwarded-Prefix": {"/myprefix/"},
@@ -441,7 +441,7 @@ var _ = Describe("API test", func() {
It("Should support reverse-proxy when authenticated", func() {
err, sc, body := getRequest(testHTTPBase+"/myprefix/", http.Header{
err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
"Authorization": {bearerKey},
"X-Forwarded-Proto": {"https"},
"X-Forwarded-Host": {"example.org"},
@@ -459,7 +459,7 @@ var _ = Describe("API test", func() {
// requests them through the proxy.
It("Should support reverse-proxy when prefix is stripped by the proxy", func() {
err, sc, body := getRequest(testHTTPBase+"/app", http.Header{
err, sc, body := getRequest("http://127.0.0.1:9090/app", http.Header{
"X-Forwarded-Proto": {"https"},
"X-Forwarded-Host": {"example.org"},
"X-Forwarded-Prefix": {"/myprefix"},
@@ -477,7 +477,7 @@ var _ = Describe("API test", func() {
// from a foreign origin. BasePathPrefix must reject these via
// SafeForwardedPrefix and fall back to "/".
It("Should ignore an unsafe X-Forwarded-Prefix and not poison asset URLs", func() {
err, sc, body := getRequest(testHTTPBase+"/app", http.Header{
err, sc, body := getRequest("http://127.0.0.1:9090/app", http.Header{
"X-Forwarded-Proto": {"https"},
"X-Forwarded-Host": {"example.org"},
"X-Forwarded-Prefix": {"//evil.com"},
@@ -492,13 +492,13 @@ var _ = Describe("API test", func() {
Context("Applying models", func() {
It("applies models from a gallery", func() {
models, err := getModels(testHTTPBase + "/models/available")
models, err := getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "test@bert2",
})
@@ -507,7 +507,7 @@ var _ = Describe("API test", func() {
uuid := response["uuid"].(string)
resp := map[string]any{}
Eventually(func() bool {
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
resp = response
return response["processed"].(bool)
@@ -526,7 +526,7 @@ var _ = Describe("API test", func() {
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
Expect(content["foo"]).To(Equal("bar"))
models, err = getModels(testHTTPBase + "/models/available")
models, err = getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
@@ -541,7 +541,7 @@ var _ = Describe("API test", func() {
})
It("overrides models", func() {
response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: bertEmbeddingsURL,
Name: "bert",
Overrides: map[string]any{
@@ -554,7 +554,7 @@ var _ = Describe("API test", func() {
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
@@ -567,7 +567,7 @@ var _ = Describe("API test", func() {
Expect(content["backend"]).To(Equal("llama"))
})
It("apply models without overrides", func() {
response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: bertEmbeddingsURL,
Name: "bert",
Overrides: map[string]any{},
@@ -578,7 +578,7 @@ var _ = Describe("API test", func() {
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
@@ -622,14 +622,14 @@ parameters:
}
var response schema.GalleryResponse
err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response)
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
Expect(err).ToNot(HaveOccurred())
Expect(response.ID).ToNot(BeEmpty())
uuid := response.ID
resp := map[string]any{}
Eventually(func() bool {
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
resp = response
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
@@ -657,7 +657,7 @@ parameters:
}
var response schema.GalleryResponse
err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response)
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
// The endpoint should return an error immediately
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("failed to discover model config"))
@@ -693,14 +693,14 @@ parameters:
}
var response schema.GalleryResponse
err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response)
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
Expect(err).ToNot(HaveOccurred())
Expect(response.ID).ToNot(BeEmpty())
uuid := response.ID
resp := map[string]any{}
Eventually(func() bool {
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
resp = response
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
@@ -751,13 +751,13 @@ parameters:
app, err = API(localAIApp)
Expect(err).ToNot(HaveOccurred())
go func() {
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
xlog.Error("server error", "error", err)
}
}()
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = testHTTPBase + "/v1"
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
@@ -801,7 +801,7 @@ parameters:
// Mock-backend is registered via SetExternalBackend so it appears
// alongside any built-in entries; verifying that string proves the
// endpoint is wired up regardless of which real backends exist.
resp, err := http.Get(testHTTPBase + "/system")
resp, err := http.Get("http://127.0.0.1:9090/system")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
dat, err := io.ReadAll(resp.Body)
@@ -824,14 +824,14 @@ parameters:
}
var createResp map[string]any
err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp)
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred())
Expect(createResp["id"]).ToNot(BeEmpty())
taskID := createResp["id"].(string)
// Get the task
var task schema.Task
resp, err := http.Get(testHTTPBase + "/api/agent/tasks/" + taskID)
resp, err := http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, _ := io.ReadAll(resp.Body)
@@ -839,7 +839,7 @@ parameters:
Expect(task.Name).To(Equal("Test Task"))
// List tasks
resp, err = http.Get(testHTTPBase + "/api/agent/tasks")
resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
var tasks []schema.Task
@@ -849,18 +849,18 @@ parameters:
// Update task
taskBody["name"] = "Updated Task"
err = putRequestJSON(testHTTPBase+"/api/agent/tasks/"+taskID, &taskBody)
err = putRequestJSON("http://127.0.0.1:9090/api/agent/tasks/"+taskID, &taskBody)
Expect(err).ToNot(HaveOccurred())
// Verify update
resp, err = http.Get(testHTTPBase + "/api/agent/tasks/" + taskID)
resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID)
Expect(err).ToNot(HaveOccurred())
body, _ = io.ReadAll(resp.Body)
json.Unmarshal(body, &task)
Expect(task.Name).To(Equal("Updated Task"))
// Delete task
req, _ := http.NewRequest("DELETE", testHTTPBase+"/api/agent/tasks/"+taskID, nil)
req, _ := http.NewRequest("DELETE", "http://127.0.0.1:9090/api/agent/tasks/"+taskID, nil)
req.Header.Set("Authorization", bearerKey)
resp, err = http.DefaultClient.Do(req)
Expect(err).ToNot(HaveOccurred())
@@ -877,7 +877,7 @@ parameters:
}
var createResp map[string]any
err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp)
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred())
taskID := createResp["id"].(string)
@@ -888,14 +888,14 @@ parameters:
}
var jobResp schema.JobExecutionResponse
err = postRequestResponseJSON(testHTTPBase+"/api/agent/jobs/execute", &jobBody, &jobResp)
err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/jobs/execute", &jobBody, &jobResp)
Expect(err).ToNot(HaveOccurred())
Expect(jobResp.JobID).ToNot(BeEmpty())
jobID := jobResp.JobID
// Get job status
var job schema.Job
resp, err := http.Get(testHTTPBase + "/api/agent/jobs/" + jobID)
resp, err := http.Get("http://127.0.0.1:9090/api/agent/jobs/" + jobID)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, _ := io.ReadAll(resp.Body)
@@ -904,7 +904,7 @@ parameters:
Expect(job.TaskID).To(Equal(taskID))
// List jobs
resp, err = http.Get(testHTTPBase + "/api/agent/jobs")
resp, err = http.Get("http://127.0.0.1:9090/api/agent/jobs")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
var jobs []schema.Job
@@ -914,7 +914,7 @@ parameters:
// Cancel job (if still pending/running)
if job.Status == schema.JobStatusPending || job.Status == schema.JobStatusRunning {
req, _ := http.NewRequest("POST", testHTTPBase+"/api/agent/jobs/"+jobID+"/cancel", nil)
req, _ := http.NewRequest("POST", "http://127.0.0.1:9090/api/agent/jobs/"+jobID+"/cancel", nil)
req.Header.Set("Authorization", bearerKey)
resp, err = http.DefaultClient.Do(req)
Expect(err).ToNot(HaveOccurred())
@@ -932,13 +932,13 @@ parameters:
}
var createResp map[string]any
err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp)
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred())
// Execute by name
paramsBody := map[string]string{"param1": "value1"}
var jobResp schema.JobExecutionResponse
err = postRequestResponseJSON(testHTTPBase+"/api/agent/tasks/Named Task/execute", &paramsBody, &jobResp)
err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks/Named Task/execute", &paramsBody, &jobResp)
Expect(err).ToNot(HaveOccurred())
Expect(jobResp.JobID).ToNot(BeEmpty())
})
@@ -998,13 +998,13 @@ parameters:
Expect(err).ToNot(HaveOccurred())
go func() {
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
xlog.Error("server error", "error", err)
}
}()
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = testHTTPBase + "/v1"
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
// Wait for API to be ready

View File

@@ -37,7 +37,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
xlog.Debug("elevenlabs TTS request received", "modelName", input.ModelID)
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, "", nil, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@@ -25,10 +25,6 @@ var knownPrefOnlyBackends = []schema.KnownBackend{
// Text LLM
// ds4: antirez/ds4 - single-model DeepSeek V4 Flash engine; auto-detected via DS4Importer
{Name: "ds4", Modality: "text", AutoDetect: false, Description: "antirez/ds4 DeepSeek V4 Flash engine (auto-detected; pref-only fallback)"},
// dllm consumes GGUF weights like llama-cpp does, but only for the
// DiffusionGemma architecture - auto-detecting on .gguf would shadow
// llama-cpp, so it stays preference-only.
{Name: "dllm", Modality: "text", AutoDetect: false, Description: "dllm.cpp DiffusionGemma block-diffusion engine (preference-only)"},
{Name: "sglang", Modality: "text", AutoDetect: false, Description: "SGLang runtime (preference-only)"},
{Name: "tinygrad", Modality: "text", AutoDetect: false, Description: "tinygrad runtime (preference-only)"},
{Name: "trl", Modality: "text", AutoDetect: false, Description: "Transformers Reinforcement Learning (preference-only)"},

View File

@@ -135,7 +135,6 @@ var _ = Describe("Backend Endpoints", func() {
Expect(entry.Modality).To(Equal(modality))
}
expectPrefOnly("dllm", "text")
expectPrefOnly("sglang", "text")
expectPrefOnly("tinygrad", "text")
expectPrefOnly("trl", "text")

View File

@@ -28,7 +28,6 @@ import (
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
"github.com/mudler/LocalAI/pkg/httpclient"
"github.com/mudler/LocalAI/pkg/natsauth"
)
// nodeError builds a schema.ErrorResponse for node endpoints.
@@ -90,7 +89,7 @@ type RegisterNodeRequest struct {
// RegisterNodeEndpoint registers a new backend node.
// expectedToken is the registration token configured on the frontend (may be empty to disable auth).
// autoApprove controls whether new nodes go directly to "healthy" or require admin approval.
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
return func(c echo.Context) error {
var req RegisterNodeRequest
if err := c.Bind(&req); err != nil {
@@ -218,15 +217,13 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au
}
}
attachNatsJWT(response, node, natsCfg)
return c.JSON(http.StatusCreated, response)
}
}
// ApproveNodeEndpoint approves a pending node, setting its status to healthy.
// For agent workers, it also provisions an API key so they can call the inference API.
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
id := c.Param("id")
@@ -256,26 +253,10 @@ func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecr
}
}
attachNatsJWT(response, node, natsCfg)
return c.JSON(http.StatusOK, response)
}
}
// attachNatsJWT adds a per-node NATS user JWT to a register/approve response when minting is enabled.
func attachNatsJWT(response map[string]any, node *nodes.BackendNode, natsCfg natsauth.Config) {
if !natsCfg.CanMintWorkers() || node == nil || node.Status == nodes.StatusPending {
return
}
jwt, seed, err := natsCfg.MintWorkerJWT(node.ID, node.NodeType)
if err != nil {
xlog.Warn("Failed to mint NATS JWT for node", "node", node.Name, "id", node.ID, "error", err)
return
}
response["nats_jwt"] = jwt
response["nats_user_seed"] = seed
}
// provisionAgentWorkerKey creates a dedicated user and API key for an agent worker node.
// Returns the plaintext API key on success.
func provisionAgentWorkerKey(ctx context.Context, authDB *gorm.DB, registry *nodes.NodeRegistry, node *nodes.BackendNode, hmacSecret string) (string, error) {

View File

@@ -12,8 +12,6 @@ import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/testutil"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/nats-io/nkeys"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -65,7 +63,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusCreated))
@@ -76,29 +74,6 @@ var _ = Describe("Node HTTP handlers", func() {
Expect(resp["status"]).To(Equal(nodes.StatusHealthy))
})
It("returns nats_jwt when account seed is configured", func() {
akp, err := nkeys.CreateAccount()
Expect(err).ToNot(HaveOccurred())
seed, err := akp.Seed()
Expect(err).ToNot(HaveOccurred())
e := echo.New()
body := `{"name":"worker-nats","address":"10.0.0.2:50051"}`
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
natsCfg := natsauth.Config{AccountSeed: string(seed)}
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsCfg)
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusCreated))
var resp map[string]any
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
Expect(resp["nats_jwt"]).ToNot(BeEmpty())
})
It("returns 400 when name is missing", func() {
e := echo.New()
body := `{"address":"10.0.0.1:50051"}`
@@ -107,7 +82,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusBadRequest))
@@ -127,7 +102,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusBadRequest))
@@ -146,7 +121,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusBadRequest))
@@ -165,7 +140,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusBadRequest))
@@ -184,7 +159,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "", natsauth.Config{})
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "")
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
})
@@ -197,7 +172,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", false, nil, "", natsauth.Config{})
handler := RegisterNodeEndpoint(registry, "", false, nil, "")
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusCreated))
@@ -220,7 +195,7 @@ var _ = Describe("Node HTTP handlers", func() {
req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body1))
req1.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
rec1 := httptest.NewRecorder()
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
Expect(handler(e.NewContext(req1, rec1))).To(Succeed())
Expect(rec1.Code).To(Equal(http.StatusCreated))

View File

@@ -59,7 +59,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
c.Response().Header().Set("Connection", "keep-alive")
// Stream audio chunks as they're generated
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, input.Instructions, input.Params, ml, appConfig, *cfg, func(audioChunk []byte) error {
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
_, writeErr := c.Response().Write(audioChunk)
if writeErr != nil {
return writeErr
@@ -75,7 +75,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
}
// Non-streaming TTS (existing behavior)
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, input.Instructions, input.Params, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
if err != nil {
return err
}

Some files were not shown because too many files have changed in this diff Show More