Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
b15627c864 chore(deps): bump the pip group across 1 directory with 2 updates
Bumps the pip group with 2 updates in the /backend/python/coqui directory: [transformers](https://github.com/huggingface/transformers) and torch.


Updates `transformers` from 4.48.3 to 5.0.0rc3
- [Release notes](https://github.com/huggingface/transformers/releases)
- [Commits](https://github.com/huggingface/transformers/compare/v4.48.3...v5.0.0rc3)

Updates `torch` from 2.4.1 to 2.7.1+cpu

---
updated-dependencies:
- dependency-name: transformers
  dependency-version: 5.0.0rc3
  dependency-type: direct:production
  dependency-group: pip
- dependency-name: torch
  dependency-version: 2.7.1+cpu
  dependency-type: direct:production
  dependency-group: pip
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-06-05 23:31:21 +00:00
142 changed files with 610 additions and 11320 deletions

View File

@@ -1,138 +0,0 @@
# Working on the dllm Backend
`mudler/dllm.cpp` is a standalone C++/ggml engine for DiffusionGemma
block-diffusion models. LocalAI wraps it with a **pure-Go** backend at
`backend/go/dllm/` that dlopens `libdllm.so` via purego (ebitengine/purego) -
NOT cgo, and NOT a C++ grpc-server fork. The Go side owns chat templating
(gemma4 renderer) and output parsing (gemma4 streaming parser) and implements
the rich gRPC interface (`PredictRich`/`PredictStreamRich`, ChatDelta replies).
> NOTE: github.com/mudler/dllm.cpp is still **private** (publishing is
> planned). Until then the Makefile's anonymous clone fails; use the local-dev
> symlink shortcut documented at the top of `backend/go/dllm/Makefile`
> (symlink an out-of-tree `build/libdllm.so` into the backend dir and skip the
> clone), or a git credential helper with repo access.
## Pin
`backend/go/dllm/Makefile` pins `DLLM_VERSION?=<sha>` at the top
(whisper / parakeet-cpp / ds4 convention). The bump-deps bot
(`.github/workflows/bump_deps.yaml`) tracks `mudler/dllm.cpp` `main` and
rewrites that variable. After a manual bump: `make -C backend/go/dllm purge &&
make -C backend/go/dllm` (the clone is keyed on the directory existing, not
the sha).
## C-ABI and the serialization contract
The binding covers the 9-symbol flat C-ABI from dllm.cpp's
`include/dllm_capi.h` (ABI v1; `main.go` hard-fails on a version mismatch):
`abi_version, load, free, last_error, free_string, tokenize_json, generate,
generate_stream, cancel`. Contract points the Go wiring encodes (`capi.go`
header comment has the full list):
- **One ctx = one concurrent generate/tokenize.** A per-model worker
goroutine (`Dllm.jobs` in `dllm.go`) owns ALL C calls, making the
serialization structural instead of lock discipline.
- **`dllm_capi_cancel` is the ONE exception**: it only flips an atomic and may
be called from any goroutine mid-generate, so `Dllm.Cancel` bypasses the
worker queue. The flag resets at the start of each generate, so a watchdog
racing a new generate must re-issue cancel.
- **`last_error` is a borrowed pointer** and must only be read AFTER the
failing call returned (never while a generate is in flight on the same ctx).
- **Free vs in-flight requests**: requests hold `genMu.RLock` for their full
duration; `Free` takes the write lock, so it only runs when nothing is in
flight, then drains and closes the worker. Post-Free requests get a clean
"model not loaded" error.
- `tokenize_json`/`generate` return malloc'd `char*` (bound as `uintptr`,
copied, then `dllm_capi_free_string`d); opts/params JSON must be a FLAT
object of scalars (`buildOptsJSON` rejects anything else).
## Wire shape
| RPC | Implementation |
|---|---|
| LoadModel | `dllm_capi_load` (params: `n_gpu_layers`, `n_threads`, `ctx_len`); `Options[]` parsed into per-request gen opts (`eb_*`, `blocks`, `kv_cache`) by `parseModelGenOpts` |
| PredictRich | render (if templated) → `dllm_capi_generate` → parse → ONE Reply with aggregated ChatDeltas + legacy `Message` bytes |
| PredictStreamRich | `dllm_capi_generate_stream`; per committed diffusion block → UTF-8 holdback → parser.Feed → one Reply per non-empty delta batch (channel closed by the CALLER, per `pkg/grpc/interface.go`) |
| Predict / PredictStream | Legacy paths, delegate to the rich pair (legacy stream INVERTS channel ownership: the impl closes) |
| TokenizeString | `dllm_capi_tokenize_json` (C side prepends BOS per `vocab.add_bos`) |
| Cancel | `dllm_capi_cancel`, exposed as the `grpc.Cancellable` capability (`pkg/grpc/interface.go`): the gRPC server arms it via `context.AfterFunc` on the Predict/PredictStream context, so client disconnects/timeouts abort the in-flight generate - llama.cpp `IsCancelled()` parity for Go backends |
`n_threads` and `ctx_len` are accepted-but-ignored by the engine at the
current pin (the context bound comes from GGUF `n_ctx_train`); they are sent
for forward compatibility.
## Renderer / parser (the templated chat path)
With `use_tokenizer_template` + raw Messages, the backend owns templating and
parsing (the ds4 precedent, but in Go):
- `gemma4_renderer.go` - `RenderGemma4(msgs, toolsJSON, enableThinking,
addGenerationPrompt)`. The file embeds the FULL `tokenizer.chat_template`
jinja (17466 bytes, md5 `8c34cf93c7a7815b3fdb300a009c4c17`) extracted
verbatim from `diffusiongemma-26B-A4B-it-BF16.gguf` via gguf-py - e.g.
`python scripts/dump_gguf.py model.gguf | grep -A400 chat_template` in the
dllm.cpp checkout - as a numbered comment block; every Go rule cites its
"tpl L<n>" line. Re-verify the md5 before blaming the renderer for a
mismatch with a new GGUF. **BOS exception**: the template emits
`{{- bos_token -}}` but the renderer deliberately does NOT - dllm.cpp's
`run_generate` tokenizes with `prepend_bos = vocab.add_bos` (true for
gemma4), so a literal `<bos>` would double it.
- `gemma4_parser.go` - streaming state machine turning raw model text
(fragments can split anywhere, including mid-marker) into ChatDeltas:
thought channels → `reasoning_content`, `<|tool_call>call:name{...}` →
ToolCallDelta, `<turn|>` → done. Marker grammar cross-checked against vLLM
PR #45163's gemma4 tool/reasoning parsers. Malformed payloads are re-emitted
raw as content, never dropped.
- Thinking is **opt-in** for this family (`Metadata["enable_thinking"]`,
default OFF - the inverse of ds4): the template gates every thinking branch
on `enable_thinking`, and the no-thinking render pre-closes an empty thought
channel, so the parser always starts in content state.
- **UTF-8 boundary holdback** (`splitValidUTF8` in `dllm.go`): per-block
detokenization can split a multi-byte character across block boundaries, and
grpc-go refuses to marshal invalid UTF-8 in proto3 strings. An incomplete
trailing sequence (at most 3 bytes) is carried into the next block; genuinely
undecodable bytes become U+FFFD.
Without `use_tokenizer_template`, the prompt passes through verbatim and the
output is NOT gemma4-parsed (plain content, like any non-autoparsing backend).
## Tests
| Layer | Gate | What |
|---|---|---|
| `backend/go/dllm/*_test.go` (renderer/parser/wiring) | none - run in plain `go test ./backend/go/dllm/...` | Ginkgo specs over a fake `generator` seam; canonical renderer fixtures from transformers' `test_modeling_diffusion_gemma.py`, parser tables from the vLLM gemma4 parsers |
| `backend/go/dllm/dllm_test.go` C-ABI smoke | `DLLM_TEST_LIBRARY` + `DLLM_TEST_TINY_MODEL` (dllm.cpp's `tests/fixtures/tiny_with_vocab.gguf`); Skips when unset | Drives the real `libdllm.so`: ABI check, load, tokenize `[2,18]`, deterministic generate, cancel (incl. mid-stream `Dllm.Cancel` aborting a deliberately slow `eb_max_steps:256` run in ~10ms) |
| `tests/e2e-backends/dllm_test.go` | `BACKEND_TEST_DLLM=1` + `BACKEND_BINARY` (packaged run.sh) + `BACKEND_TEST_MODEL_FILE` (tiny fixture) | Templated chat round trip (Messages + UseTokenizerTemplate) over the real gRPC binary, non-streaming + streaming; plus client-context cancellation mid-stream (proves the `Cancellable` server plumbing end to end) |
| Real-model e2e | `BACKEND_TEST_DLLM_REAL_MODEL_FILE` (26B BF16, ~50 GB) + `BACKEND_TEST_DLLM_REAL_GPU_LAYERS` | CUDA-13-class hardware only |
Tool-call e2e is deliberately absent from the tiny-model spec: the fixture has
random weights and cannot be coaxed into emitting tool markup; the unit tables
carry that coverage.
## Build matrix
`cpu-dllm` (amd64 + arm64), `cuda13-dllm` (amd64), and
`cuda13-nvidia-l4t-arm64-dllm` (arm64 CUDA: Jetson / DGX Spark GB10), via
`.github/backend-matrix.yml`. No darwin/Metal. CUDA builds forward
`-DDLLM_CUDA=ON` (dllm.cpp gates ggml's CUDA behind its own flag - a bare
`-DGGML_CUDA=ON` is overridden by the cache FORCE). `libdllm.so` is
self-contained (ggml statically absorbed, PIC), so `package.sh` only ships
the binary, `run.sh` and that one .so (the parakeet-cpp-style stub layout;
no ldd walk yet).
## Known limitations
- **Cancel granularity**: the C-ABI cancel flag is per-ctx and resets on
every generate entry, so a Cancel racing a NEW generate can be lost, and
with requests queued on the worker it aborts whichever generate is
currently running (acceptable: the server de-registers the hook on normal
completion, one process serves one model).
- **Throughput**: ~0.15 tok/s on the 26B at default settings (GB10) - every
denoise step recomputes the full prompt+canvas. The upstream prefix-KV
cache (dllm.cpp P3) is the fix; `kv_cache:on` errors until it lands
(`auto`/`off` are accepted no-ops).
- **Repo privacy**: see the note at the top - CI clone of dllm.cpp needs the
repo published (or credentials) before the backend images can build.
- Engine spec/validation references: dllm.cpp `docs/validation.md` and
LocalAI `docs/superpowers/specs/2026-06-10-dllm-cpp-design.md`.

View File

@@ -1608,19 +1608,6 @@ include:
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-13-dllm'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "dllm"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
@@ -1660,19 +1647,6 @@ include:
backend: "parakeet-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
platforms: 'linux/arm64'
skip-drivers: 'false'
tag-latest: 'auto'
tag-suffix: '-nvidia-l4t-cuda-13-arm64-dllm'
base-image: "ubuntu:24.04"
ubuntu-version: '2404'
runs-on: 'ubuntu-24.04-arm'
backend: "dllm"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
@@ -1792,6 +1766,20 @@ include:
dockerfile: "./backend/Dockerfile.llama-cpp"
context: "./"
ubuntu-version: '2404'
- build-type: 'hipblas'
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-rocm-hipblas-turboquant'
builder-base-image: 'quay.io/go-skynet/ci-cache:base-grpc-rocm-amd64'
runs-on: 'ubuntu-latest'
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
skip-drivers: 'false'
backend: "turboquant"
dockerfile: "./backend/Dockerfile.turboquant"
context: "./"
ubuntu-version: '2404'
- build-type: 'hipblas'
cuda-major-version: ""
cuda-minor-version: ""
@@ -3171,35 +3159,6 @@ include:
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
# dllm
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
platform-tag: 'amd64'
tag-latest: 'auto'
tag-suffix: '-cpu-dllm'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "dllm"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/arm64'
platform-tag: 'arm64'
tag-latest: 'auto'
tag-suffix: '-cpu-dllm'
runs-on: 'ubuntu-24.04-arm'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "dllm"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'sycl_f32'
cuda-major-version: ""
cuda-minor-version: ""

View File

@@ -38,10 +38,6 @@ jobs:
variable: "PARAKEET_VERSION"
branch: "master"
file: "backend/go/parakeet-cpp/Makefile"
- repository: "mudler/dllm.cpp"
variable: "DLLM_VERSION"
branch: "main"
file: "backend/go/dllm/Makefile"
- repository: "leejet/stable-diffusion.cpp"
variable: "STABLEDIFFUSION_GGML_VERSION"
branch: "master"

View File

@@ -26,7 +26,6 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
| [.agents/vllm-backend.md](.agents/vllm-backend.md) | Working on the vLLM / vLLM-omni backends — native parsers, ChatDelta, CPU build, libnuma packaging, backend hooks |
| [.agents/sglang-backend.md](.agents/sglang-backend.md) | Working on the SGLang backend — `engine_args` validation against ServerArgs, speculative-decoding (EAGLE/EAGLE3/DFLASH/MTP) recipes, parser handling |
| [.agents/ds4-backend.md](.agents/ds4-backend.md) | Working on the ds4 backend - DSML state machine, thinking modes, KV cache, Metal+CUDA matrix |
| [.agents/dllm-backend.md](.agents/dllm-backend.md) | Working on the dllm backend (DiffusionGemma block-diffusion) - purego C-ABI binding, per-ctx serialization contract, gemma4 renderer/parser, gated test layers |
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |

View File

@@ -1,5 +1,5 @@
# Disable parallel execution for backend builds
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/dllm backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
GOCMD=go
GOTEST=$(GOCMD) test
@@ -180,7 +180,7 @@ osx-signed: build
## Run
run: ## run local-ai
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./cmd/local-ai
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
prepare-test: protogen-go build-mock-backend
@@ -1171,9 +1171,6 @@ BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|tr
BACKEND_WHISPER = whisper|golang|.|false|true
BACKEND_CRISPASR = crispasr|golang|.|false|true
BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true
# dllm is mudler/dllm.cpp, the DiffusionGemma block-diffusion engine,
# wrapped by the purego backend at backend/go/dllm.
BACKEND_DLLM = dllm|golang|.|false|true
BACKEND_VOXTRAL = voxtral|golang|.|false|true
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true
@@ -1263,7 +1260,6 @@ $(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
$(eval $(call generate-docker-build-target,$(BACKEND_CRISPASR)))
$(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP)))
$(eval $(call generate-docker-build-target,$(BACKEND_DLLM)))
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))

View File

@@ -149,16 +149,6 @@ local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
local-ai run oci://localai/phi-2:latest
```
To test a running LocalAI server from the terminal, open an interactive chat session from another shell. Inside the prompt, `/models` lists installed models and `/model <name>` switches between them.
```bash
# Terminal 1
local-ai run llama-3.2-1b-instruct:q4_k_m
# Terminal 2
local-ai chat --model llama-3.2-1b-instruct:q4_k_m
```
> **Automatic Backend Detection**: LocalAI automatically detects your GPU capabilities and downloads the appropriate backend. For advanced options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/).
For more details, see the [Getting Started guide](https://localai.io/basics/getting_started/).

View File

@@ -60,12 +60,10 @@ elseif(DS4_GPU STREQUAL "cpu")
set(DS4_OBJS "${DS4_DIR}/ds4_cpu.o")
endif()
# ds4.c now references ds4_distributed.c (distributed inference) and ds4_ssd.c
# (SSD expert-cache), each split into its own translation unit upstream. Both
# are GPU-agnostic objects shared by every GPU mode, so link them in regardless
# of DS4_GPU.
# ds4.c now references ds4_distributed.c (distributed inference was split into
# its own translation unit upstream). It is a single GPU-agnostic object shared
# by every GPU mode, so link it in regardless of DS4_GPU.
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_distributed.o")
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_ssd.o")
add_executable(${TARGET}
grpc-server.cpp

View File

@@ -1,10 +1,10 @@
# ds4 backend Makefile.
#
# Upstream pin lives below as DS4_VERSION?=8384adf0f9fa0f3bb342dd925372de778b95b263
# Upstream pin lives below as DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
# (.github/bump_deps.sh) can find and update it - matches the
# llama-cpp / ik-llama-cpp / turboquant convention.
DS4_VERSION?=8384adf0f9fa0f3bb342dd925372de778b95b263
DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
DS4_REPO?=https://github.com/antirez/ds4
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
@@ -18,20 +18,19 @@ UNAME_S := $(shell uname -s)
CMAKE_ARGS ?= -DCMAKE_BUILD_TYPE=Release
# ds4_distributed.o and ds4_ssd.o are GPU-agnostic translation units that
# ds4.c/ds4_cpu.o now reference (upstream split distributed inference and the
# SSD expert-cache into their own .c files). Both objects are shared by every
# GPU mode, so they are appended unconditionally below.
# ds4_distributed.o is a GPU-agnostic translation unit that ds4.c/ds4_cpu.o now
# reference (upstream split distributed inference into its own .c). The same
# object is shared by every GPU mode, so it is appended unconditionally below.
ifeq ($(BUILD_TYPE),cublas)
CMAKE_ARGS += -DDS4_GPU=cuda
DS4_OBJ_TARGET := ds4.o ds4_cuda.o ds4_distributed.o ds4_ssd.o
DS4_OBJ_TARGET := ds4.o ds4_cuda.o ds4_distributed.o
else ifeq ($(UNAME_S),Darwin)
CMAKE_ARGS += -DDS4_GPU=metal
DS4_OBJ_TARGET := ds4.o ds4_metal.o ds4_distributed.o ds4_ssd.o
DS4_OBJ_TARGET := ds4.o ds4_metal.o ds4_distributed.o
else
# CPU reference path (Linux only - macOS CPU path is broken by VM bug per ds4 README).
CMAKE_ARGS += -DDS4_GPU=cpu
DS4_OBJ_TARGET := ds4_cpu.o ds4_distributed.o ds4_ssd.o
DS4_OBJ_TARGET := ds4_cpu.o ds4_distributed.o
endif
ifneq ($(NATIVE),true)
@@ -56,11 +55,11 @@ ds4:
# the right per-platform compile flags (Objective-C/Metal on Darwin, nvcc on Linux+CUDA).
ds4/ds4.o: ds4
ifeq ($(BUILD_TYPE),cublas)
+$(MAKE) -C ds4 ds4.o ds4_cuda.o ds4_distributed.o ds4_ssd.o
+$(MAKE) -C ds4 ds4.o ds4_cuda.o ds4_distributed.o
else ifeq ($(UNAME_S),Darwin)
+$(MAKE) -C ds4 ds4.o ds4_metal.o ds4_distributed.o ds4_ssd.o
+$(MAKE) -C ds4 ds4.o ds4_metal.o ds4_distributed.o
else
+$(MAKE) -C ds4 ds4_cpu.o ds4_distributed.o ds4_ssd.o
+$(MAKE) -C ds4 ds4_cpu.o ds4_distributed.o
endif
grpc-server: ds4/ds4.o

View File

@@ -1,5 +1,5 @@
IK_LLAMA_VERSION?=e6f8112f3ba126eed3ff5b30cdd08085414a7516
IK_LLAMA_VERSION?=1520eda980564241434b791ce2bbbd128c4be9ea
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
CMAKE_ARGS?=

View File

@@ -1,5 +1,5 @@
LLAMA_VERSION?=039e20a2db9e87b2477c76cc04905f3e1acad77f
LLAMA_VERSION?=7c158fbb4aec1bdc9c81d6ca0e785139f4826fae
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
CMAKE_ARGS?=

View File

@@ -381,15 +381,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
});
}
// for each video in the request, add the video data
for (int i = 0; i < predict->videos_size(); i++) {
data["video_data"].push_back(json
{
{"id", i},
{"data", predict->videos(i)},
});
}
data["stop"] = predict->stopprompts();
// data["n_probs"] = predict->nprobs();
//TODO: images,
@@ -491,13 +482,23 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
if (!request->draftmodel().empty()) {
params.speculative.draft.mparams.path = request->draftmodel();
// Default to draft type if a draft model is set but no explicit type.
// Upstream made the speculative type a vector (ggml-org/llama.cpp#22838)
// and renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE (#22964).
// Upstream (post ggml-org/llama.cpp#22838) made the speculative type a
// vector; the turboquant fork still uses the legacy scalar. The
// LOCALAI_LEGACY_LLAMA_CPP_SPEC macro is injected by
// backend/cpp/turboquant/patch-grpc-server.sh for fork builds only.
// Upstream renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE
// in ggml-org/llama.cpp#22964; the fork still uses the old name.
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
}
#else
const bool no_spec_type = params.speculative.types.empty() ||
(params.speculative.types.size() == 1 && params.speculative.types[0] == COMMON_SPECULATIVE_TYPE_NONE);
if (no_spec_type) {
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE };
}
#endif
}
// params.model_alias ??
@@ -573,10 +574,9 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
// tokens (0 disables the minimum). Match upstream's default (256). This
// field was renamed from `checkpoint_every_nt` in llama.cpp; the semantics
// also shifted from a fixed cadence to a minimum spacing. The turboquant
// fork still lacks common_params::checkpoint_min_step, so skip it there
// (LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP is injected by
// backend/cpp/turboquant/patch-grpc-server.sh).
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
// fork branched before the field existed, so skip it on the legacy path
// (LOCALAI_LEGACY_LLAMA_CPP_SPEC is injected by patch-grpc-server.sh).
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
params.checkpoint_min_step = 256;
#endif
@@ -752,7 +752,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
params.cache_idle_slots = false;
}
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
// --- minimum context-checkpoint spacing (upstream -cms / --checkpoint-min-step) ---
// 0 disables the minimum-spacing gate. Old option names (`checkpoint_every_nt`,
// `checkpoint_every_n_tokens`) are kept as aliases for backward compatibility
@@ -906,6 +906,17 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
// Speculative decoding options
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
// Fork only knows a single scalar `type`. Take the first comma-
// separated value and assign it via the singular helper.
std::string first = optval_str;
const auto comma = first.find(',');
if (comma != std::string::npos) first = first.substr(0, comma);
auto type = common_speculative_type_from_name(first);
if (type != COMMON_SPECULATIVE_TYPE_COUNT) {
params.speculative.type = type;
}
#else
// Upstream switched to a vector of types (comma-separated for multi-type
// chaining via common_speculative_types_from_names). We keep accepting a
// single value here, but also tolerate comma-separated lists.
@@ -934,6 +945,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
if (!parsed.empty()) {
params.speculative.types = parsed;
}
#endif
} else if (!strcmp(optname, "spec_n_max") || !strcmp(optname, "draft_max")) {
if (optval != NULL) {
try { params.speculative.draft.n_max = std::stoi(optval_str); } catch (...) {}
@@ -971,6 +983,21 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
// shares the target context size. Accept the option for backward
// compatibility but silently ignore it.
// Everything below relies on struct shape introduced in ggml-org/llama.cpp#22838
// (parallel drafting): `ngram_mod`, `ngram_map_k`, `ngram_map_k4v`,
// `ngram_cache`, and the `draft.{cache_type_*, cpuparams*, tensor_buft_overrides}`
// fields. The turboquant fork branched before that, so its build defines
// LOCALAI_LEGACY_LLAMA_CPP_SPEC via patch-grpc-server.sh and these option
// keys become unrecognized (silently dropped, like any unknown opt) for it.
//
// The `#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC` / `#else` split below sits at the
// closing-brace position of the `draft_ctx_size` branch on purpose: in the
// legacy build the chain ends here (the brace closes draft_ctx_size), and in
// the modern build the chain continues with `} else if (...)` instead, so the
// brace count stays balanced under both branches of the preprocessor.
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
}
#else
// --- ngram_mod family (upstream --spec-ngram-mod-*) ---
} else if (!strcmp(optname, "spec_ngram_mod_n_min")) {
if (optval != NULL) {
@@ -1100,6 +1127,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
}
if (!cur.empty()) flush(cur);
}
#endif // LOCALAI_LEGACY_LLAMA_CPP_SPEC — closes the `else`/`#ifdef` opened at draft_ctx_size
}
// Set params.n_parallel from environment variable if not set via options (fallback)
@@ -1149,11 +1177,15 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}
}
// Terminate the draft tensor_buft_overrides list with a sentinel, mirroring
// the main-model handling above.
// The draft tensor_buft_overrides are only populated under the modern
// (post-#22838) layout, whose population code is itself gated by
// LOCALAI_LEGACY_LLAMA_CPP_SPEC above. The turboquant fork lacks
// common_params_speculative::draft entirely, so skip the sentinel there too.
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
if (!params.speculative.draft.tensor_buft_overrides.empty()) {
params.speculative.draft.tensor_buft_overrides.push_back({nullptr, nullptr});
}
#endif
// TODO: Add yarn
@@ -1512,7 +1544,7 @@ public:
msg_json["role"] = msg.role();
bool is_last_user_msg = (i == last_user_msg_idx);
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
// Handle content - can be string, null, or array
// For multimodal content, we'll embed images/audio from separate fields
@@ -1563,16 +1595,6 @@ public:
content_array.push_back(audio_chunk);
}
}
if (request->videos_size() > 0) {
for (int j = 0; j < request->videos_size(); j++) {
json video_chunk;
video_chunk["type"] = "input_video";
json input_video;
input_video["data"] = request->videos(j);
video_chunk["input_video"] = input_video;
content_array.push_back(video_chunk);
}
}
msg_json["content"] = content_array;
} else {
// Use content as-is (already array or not last user message)
@@ -1607,16 +1629,6 @@ public:
content_array.push_back(audio_chunk);
}
}
if (request->videos_size() > 0) {
for (int j = 0; j < request->videos_size(); j++) {
json video_chunk;
video_chunk["type"] = "input_video";
json input_video;
input_video["data"] = request->videos(j);
video_chunk["input_video"] = input_video;
content_array.push_back(video_chunk);
}
}
msg_json["content"] = content_array;
} else if (msg.role() == "tool") {
// Tool role messages must have content field set, even if empty
@@ -2068,16 +2080,6 @@ public:
files.push_back(decoded_data);
}
}
const auto &video_data = data.find("video_data");
if (video_data != data.end() && video_data->is_array())
{
for (const auto &video : *video_data)
{
auto decoded_data = base64_decode(video["data"].get<std::string>());
files.push_back(decoded_data);
}
}
}
const bool has_mtmd = ctx_server.impl->mctx != nullptr;
@@ -2330,7 +2332,7 @@ public:
}
bool is_last_user_msg = (i == last_user_msg_idx);
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
// Handle content - can be string, null, or array
// For multimodal content, we'll embed images/audio from separate fields
@@ -2383,16 +2385,6 @@ public:
content_array.push_back(audio_chunk);
}
}
if (request->videos_size() > 0) {
for (int j = 0; j < request->videos_size(); j++) {
json video_chunk;
video_chunk["type"] = "input_video";
json input_video;
input_video["data"] = request->videos(j);
video_chunk["input_video"] = input_video;
content_array.push_back(video_chunk);
}
}
msg_json["content"] = content_array;
} else {
// Use content as-is (already array or not last user message)
@@ -2432,16 +2424,6 @@ public:
content_array.push_back(audio_chunk);
}
}
if (request->videos_size() > 0) {
for (int j = 0; j < request->videos_size(); j++) {
json video_chunk;
video_chunk["type"] = "input_video";
json input_video;
input_video["data"] = request->videos(j);
video_chunk["input_video"] = input_video;
content_array.push_back(video_chunk);
}
}
msg_json["content"] = content_array;
SRV_INF("[CONTENT DEBUG] Predict: Message %d created content array with media\n", i);
} else if (!msg.tool_calls().empty()) {
@@ -2904,16 +2886,6 @@ public:
files.push_back(decoded_data);
}
}
const auto &video_data = data.find("video_data");
if (video_data != data.end() && video_data->is_array())
{
for (const auto &video : *video_data)
{
auto decoded_data = base64_decode(video["data"].get<std::string>());
files.push_back(decoded_data);
}
}
}
// process files

View File

@@ -1,7 +1,7 @@
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
TURBOQUANT_VERSION?=7d9715f1f071fa07c7b2ad3dbfd320b314139e65
TURBOQUANT_VERSION?=5aeb2fdbe26cd4c534c6fa15de73cb5749bd0403
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
CMAKE_ARGS?=

View File

@@ -4,19 +4,21 @@
#
# 1. Augment the kv_cache_types[] allow-list so `LoadModel` accepts the
# fork-specific `turbo2` / `turbo3` / `turbo4` cache types.
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file
# so the grpc-server option parser skips the two references to
# common_params::checkpoint_min_step (the default and the option handler).
# That field does not exist in the fork yet; drop this once it does.
#
# The fork used to lag upstream on the whole common_params_speculative refactor
# (ggml-org/llama.cpp#22397/#22838/#22964), the model_tgt rename (#22838) and
# get_media_marker (#21962), which required a much larger compat shim here
# (flat-field sed renames + a coarse LOCALAI_LEGACY_LLAMA_CPP_SPEC define). The
# fork has since rebased past all of those, so the only remaining gap is
# checkpoint_min_step. If a future bump reintroduces a divergence, add a narrow
# guard in grpc-server.cpp keyed on a fork-specific macro and inject it here
# rather than resurrecting the coarse one.
# 2. Replace `get_media_marker()` (added upstream in ggml-org/llama.cpp#21962,
# server-side random per-instance marker) with the legacy "<__media__>"
# literal. The fork branched before that PR, so server-common.cpp has no
# get_media_marker symbol. The fork's mtmd_default_marker() still returns
# "<__media__>", and Go-side tooling falls back to that sentinel when the
# backend does not expose media_marker, so substituting the literal keeps
# behavior identical on the turboquant path.
# 3. Revert the `common_params_speculative` field references to the
# pre-refactor flat layout. Upstream ggml-org/llama.cpp#22397 split the
# struct into nested `draft` / `ngram_simple` / `ngram_mod` / etc. members;
# the turboquant fork branched before that PR and still exposes the flat
# `n_max`, `mparams_dft`, `ngram_size_n`, ... fields. The substitutions
# below map the new nested paths back to the legacy flat names so the
# shared grpc-server.cpp keeps compiling against the fork's common.h.
# Drop this block once the fork rebases past #22397.
#
# We patch the *copy* sitting in turboquant-<flavor>-build/, never the original
# under backend/cpp/llama-cpp/, so the stock llama-cpp build keeps compiling
@@ -70,20 +72,72 @@ else
echo "==> KV allow-list patch OK"
fi
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file so
# the grpc-server option parser skips the two references to
# common_params::checkpoint_min_step (the default assignment and the option
# handler). That field does not exist in the fork yet. Drop this block once
# the fork rebases past the bump that added checkpoint_min_step.
if grep -q '^#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP' "$SRC"; then
echo "==> $SRC already defines LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP, skipping"
if grep -q 'get_media_marker()' "$SRC"; then
echo "==> patching $SRC to replace get_media_marker() with legacy \"<__media__>\" literal"
# Only one call site today (ModelMetadata), but replace all occurrences to
# stay robust if upstream adds more. Use a temp file to avoid relying on
# sed -i portability (the builder image uses GNU sed, but keeping this
# consistent with the awk block above).
sed 's/get_media_marker()/"<__media__>"/g' "$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> get_media_marker() substitution OK"
else
echo "==> patching $SRC to define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top"
# Insert the define before the very first `#include` so it precedes the
# checkpoint_min_step references.
echo "==> $SRC has no get_media_marker() call, skipping media-marker patch"
fi
if grep -q 'params\.speculative\.draft\.\|params\.speculative\.ngram_simple\.' "$SRC"; then
echo "==> patching $SRC to revert common_params_speculative refs to pre-#22397 flat layout"
# Each substitution is the exact post-refactor path → legacy flat field.
# Order doesn't matter because the source paths are disjoint, but we keep
# the most-specific (mparams.path) first for readability.
sed -E \
-e 's/params\.speculative\.draft\.mparams\.path/params.speculative.mparams_dft.path/g' \
-e 's/params\.speculative\.draft\.n_max/params.speculative.n_max/g' \
-e 's/params\.speculative\.draft\.n_min/params.speculative.n_min/g' \
-e 's/params\.speculative\.draft\.p_min/params.speculative.p_min/g' \
-e 's/params\.speculative\.draft\.p_split/params.speculative.p_split/g' \
-e 's/params\.speculative\.draft\.n_gpu_layers/params.speculative.n_gpu_layers/g' \
-e 's/params\.speculative\.draft\.n_ctx/params.speculative.n_ctx/g' \
-e 's/params\.speculative\.ngram_simple\.size_n/params.speculative.ngram_size_n/g' \
-e 's/params\.speculative\.ngram_simple\.size_m/params.speculative.ngram_size_m/g' \
-e 's/params\.speculative\.ngram_simple\.min_hits/params.speculative.ngram_min_hits/g' \
"$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> speculative field rename OK"
else
echo "==> $SRC has no post-#22397 speculative field refs, skipping spec rename patch"
fi
# 4. Revert the `ctx_server.impl->model_tgt` rename introduced by upstream
# ggml-org/llama.cpp#22838 (parallel drafting). The turboquant fork still
# exposes the field as `model` on `server_context_impl`. The two call sites
# are in the Rerank and ModelMetadata RPC handlers.
if grep -q 'ctx_server\.impl->model_tgt' "$SRC"; then
echo "==> patching $SRC to revert ctx_server.impl->model_tgt -> ctx_server.impl->model"
sed -E 's/ctx_server\.impl->model_tgt/ctx_server.impl->model/g' "$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> model_tgt rename OK"
else
echo "==> $SRC has no ctx_server.impl->model_tgt refs, skipping model_tgt rename patch"
fi
# 5. Define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top of the file so the
# grpc-server option parser skips the new option-handler blocks (ngram_mod,
# ngram_map_k, ngram_map_k4v, ngram_cache, draft.cache_type_*, draft.cpuparams*,
# draft.tensor_buft_overrides) introduced for the post-#22838 layout, the
# draft.tensor_buft_overrides sentinel termination, and the
# common_params::checkpoint_min_step default/option (added with the
# 35c9b1f3 bump). Those blocks reference struct fields that simply do not
# exist in the fork.
if grep -q '^#define LOCALAI_LEGACY_LLAMA_CPP_SPEC' "$SRC"; then
echo "==> $SRC already defines LOCALAI_LEGACY_LLAMA_CPP_SPEC, skipping"
else
echo "==> patching $SRC to define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top"
# Insert the define before the very first `#include` so it precedes all the
# speculative-decoding code paths.
awk '
!done && /^#include/ {
print "#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP 1"
print "#define LOCALAI_LEGACY_LLAMA_CPP_SPEC 1"
print "// ^ injected by backend/cpp/turboquant/patch-grpc-server.sh"
print ""
done = 1
@@ -91,13 +145,13 @@ else
{ print }
END {
if (!done) {
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP" > "/dev/stderr"
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_LEGACY_LLAMA_CPP_SPEC" > "/dev/stderr"
exit 1
}
}
' "$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP define OK"
echo "==> LOCALAI_LEGACY_LLAMA_CPP_SPEC define OK"
fi
echo "==> all patches applied"

View File

@@ -1,55 +0,0 @@
hip: port the turboquant CUDA additions that ggml's HIP shim doesn't cover
The turboquant fork adds/modifies a few ggml-cuda.cu spots with CUDA APIs
that ggml's HIP (and MUSA) compatibility layer does not provide, breaking
the -gpu-rocm-hipblas-turboquant build:
1. ggml_cuda_copy2d_across_devices() (host-staged cross-device copy for
split mul_mat output) uses the CUDA 3D-peer copy APIs
cudaMemcpy3DPeerParms / make_cudaPitchedPtr / make_cudaExtent /
cudaMemcpy3DPeerAsync. HIP genuinely does not support these (see the
fork's own comment "HIP does not support cudaMemcpy3DPeerAsync"), so
guard the peer fast path with #if !defined(GGML_USE_HIP) &&
!defined(GGML_USE_MUSA) -- matching how the fork already guards the
same API for the sibling 2D copy -- and fall through to the existing
cudaMemcpyAsync staging fallback below (functionally identical,
slightly slower on multi-GPU ROCm).
2. ggml_backend_cuda_device_event_new() creates its event with plain
cudaEventCreate, which ggml's HIP shim does not alias (it only aliases
cudaEventCreateWithFlags). Use cudaEventCreateWithFlags(...,
cudaEventDisableTiming) -- exactly what the rest of this file already
does (cf. lines ~1034, ~3461) and HIP-safe.
CUDA builds are unaffected. Drop the relevant hunk once the fork HIP-ports
these; apply-patches.sh fails fast if an anchor goes stale.
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 0427e6b..6352e6a 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1933,6 +1933,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
size_t width, size_t height, cudaStream_t dst_stream, cudaStream_t src_stream) {
const auto & info = ggml_cuda_info();
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) // 3D-peer copy types unmapped by ggml's HIP/MUSA shim; use staging fallback below
if (info.peer_access[src_device][dst_device]) {
cudaMemcpy3DPeerParms p = {};
p.dstDevice = dst_device;
@@ -1942,6 +1943,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
p.extent = make_cudaExtent(width, height, 1);
return cudaMemcpy3DPeerAsync(&p, dst_stream);
}
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
// Fallback: stage all rows through a single contiguous pinned buffer
int prev_device = ggml_cuda_get_device();
@@ -5714,7 +5716,7 @@ static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_
ggml_cuda_set_device(dev_ctx->device);
cudaEvent_t event;
- CUDA_CHECK(cudaEventCreate(&event));
+ CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
return new ggml_backend_event {
/* .device = */ dev,

View File

@@ -14,7 +14,7 @@ target_include_directories(gocrispasr PRIVATE
# whisper. crispasr is the referencer; the backend static libs supply the
# per-architecture symbols; ggml is the math/runtime base.
target_link_libraries(gocrispasr PRIVATE
crispasr-lib
crispasr
parakeet canary canary_ctc cohere granite_speech granite_nle
voxtral voxtral4b qwen3_asr qwen3_tts orpheus chatterbox indextts
kokoro voxcpm2_tts m2m100 t5_translate wav2vec2-ggml vibevoice

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# CrispASR version (release tag)
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
CRISPASR_VERSION?=c29f6653a516a3001d923944dad8892072cc7334
CRISPASR_VERSION?=13d54e110e1538e0f0bc3af0680b9ab246cfb48d
SO_TARGET?=libgocrispasr.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -1,10 +0,0 @@
.cache/
sources/
build/
package/
dllm-grpc
# build artifacts staged in-tree by the Makefile (cp from sources/) or
# symlinked for local dev; the real sources live in dllm.cpp upstream.
*.so
*.so.*
compile_commands.json

View File

@@ -1,93 +0,0 @@
# dllm backend Makefile.
#
# Upstream pin lives below as DLLM_VERSION?=<sha> so .github/bump_deps.sh
# can find and update it - matches the whisper.cpp / parakeet-cpp / ds4
# convention.
#
# Local dev shortcut: if you already have an out-of-tree dllm.cpp build,
# you can symlink the .so into this directory and skip the clone/cmake
# steps entirely, e.g.:
#
# ln -sf /path/to/dllm.cpp/build/libdllm.so .
# go build -o dllm-grpc .
#
# That's what the gated C-ABI binding smoke uses (DLLM_TEST_LIBRARY). The
# default target below does the proper clone-at-pin + cmake build so CI
# doesn't need a side-checkout.
#
# NOTE: github.com/mudler/dllm.cpp is still private (publishing is planned);
# until then the anonymous clone below fails. Use the symlink shortcut above
# with a local checkout, or a git credential helper with access to the repo.
DLLM_VERSION?=b22fcebebfb225131113188599a9ae542b2935d7
DLLM_REPO?=https://github.com/mudler/dllm.cpp
GOCMD?=go
GO_TAGS?=
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
BUILD_TYPE?=
NATIVE?=false
# libdllm.so is self-contained: dllm.cpp's CMakeLists statically absorbs ggml
# (BUILD_SHARED_LIBS=OFF + PIC) into the shared lib, so dlopen needs no
# libggml*.so alongside it, only system libs (libstdc++/libgomp/libc) the
# runtime image already provides. Tests/CLI are upstream-only concerns.
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DDLLM_BUILD_TESTS=OFF
ifeq ($(NATIVE),false)
CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif
# Same arch set the sibling ggml backends (acestep/vibevoice/qwen3-tts) bake
# for their cublas images; override for a native build.
CUDA_ARCHITECTURES?=75-virtual;80-virtual;86-real;89-real
# dllm.cpp gates CUDA behind DLLM_CUDA (set(GGML_CUDA ... CACHE FORCE)), so
# forward that instead of a bare -DGGML_CUDA=ON.
ifeq ($(BUILD_TYPE),cublas)
CMAKE_ARGS+=-DDLLM_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="$(CUDA_ARCHITECTURES)"
endif
.PHONY: dllm-grpc package build clean purge test all
all: dllm-grpc
# Clone the upstream dllm.cpp source at the pinned commit (ggml comes in as
# a submodule). Directory acts as the target so make only re-clones when
# missing. After a DLLM_VERSION bump, run 'make purge && make' to refetch.
sources/dllm.cpp:
mkdir -p sources/dllm.cpp
cd sources/dllm.cpp && \
git init -q && \
git remote add origin $(DLLM_REPO) && \
git fetch --depth 1 origin $(DLLM_VERSION) && \
git checkout FETCH_HEAD && \
git submodule update --init --recursive --depth 1 --single-branch
# Build the shared lib out-of-tree, then stage it next to the Go sources so
# purego.Dlopen("libdllm.so") and the packaging step both pick it up.
libdllm.so: sources/dllm.cpp
cmake -B sources/dllm.cpp/build -S sources/dllm.cpp $(CMAKE_ARGS)
cmake --build sources/dllm.cpp/build --config Release -j$(JOBS)
cp -fv sources/dllm.cpp/build/libdllm.so ./
dllm-grpc: libdllm.so main.go capi.go
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o dllm-grpc .
package: dllm-grpc
bash package.sh
build: package
# Test target. The C-ABI binding smoke is gated on DLLM_TEST_LIBRARY +
# DLLM_TEST_TINY_MODEL; without them the gated specs auto-skip and only the
# pure-Go helper specs run.
test:
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
clean: purge
rm -rf libdllm.so* package dllm-grpc
purge:
rm -rf sources/dllm.cpp

View File

@@ -1,256 +0,0 @@
package main
// Typed Go wrappers over dllm.cpp's flat C-ABI (include/dllm_capi.h, ABI v1).
//
// Contract highlights the wrappers encode (see the header + src/capi.cpp):
// - tokenize_json/generate return malloc'd char* the CALLER owns: bound as
// uintptr, copied with goStringFromCPtr, released via dllm_capi_free_string.
// - last_error returns a BORROWED pointer (valid until the next call on the
// same ctx): bound as a plain string (purego copies), never freed, and only
// read AFTER the failing call has returned - reading it while a generate is
// in flight on the same ctx violates the per-ctx serialization contract.
// - All entry points except dllm_capi_cancel must be externally serialized
// per ctx (one ctx = one concurrent generate/tokenize). Cancel only flips
// an atomic and may be called from any goroutine mid-generate.
// - No C++ exception crosses the boundary; failures land in last_error.
import (
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"unsafe"
"github.com/ebitengine/purego"
)
// dllmABIVersion is the DLLM_CAPI_ABI_VERSION this binding was written
// against; main.go refuses to start against a libdllm.so reporting another.
const dllmABIVersion = 1
// purego-bound entry points from libdllm.so. Names match dllm_capi.h
// exactly; loadCAPI (main.go) fills these in at boot.
var (
cppAbiVersion func() int32
cppLoad func(ggufPath, paramsJSON string) uintptr
cppFree func(ctx uintptr)
cppLastError func(ctx uintptr) string // borrowed pointer: purego copies, do NOT free
cppFreeString func(s uintptr)
// malloc'd char* returns, hence uintptr (see loadCAPI's doc comment).
cppTokenizeJSON func(ctx uintptr, text string) uintptr
cppGenerate func(ctx uintptr, prompt, optsJSON string) uintptr
// on_block/on_step are C function pointers produced by purego.NewCallback;
// userData carries the streamCallStates registry key.
cppGenerateStream func(ctx uintptr, prompt, optsJSON string, onBlock, onStep, userData uintptr) int32
cppCancel func(ctx uintptr)
)
// cAbiVersion returns the library's DLLM_CAPI_ABI_VERSION.
func cAbiVersion() int32 {
return cppAbiVersion()
}
// cLoad opens the GGUF at path with the flat params JSON (e.g.
// {"n_gpu_layers":99}). Returns 0 on failure; per the header contract there
// is no ctx to carry the reason, the C side logs it to stderr (and
// cLastError(0) only yields the static NULL-ctx message).
func cLoad(path, paramsJSON string) uintptr {
return cppLoad(path, paramsJSON)
}
// cFree releases a ctx; safe on 0 (delete nullptr).
func cFree(h uintptr) {
cppFree(h)
}
// cLastError returns the ctx's last error message (or the static NULL-ctx
// message for h==0). The C pointer is borrowed and only valid until the next
// call on the same ctx; purego's string return copies it immediately, so the
// returned Go string is safe to keep. Must not be called while another call
// on the same ctx is in flight.
func cLastError(h uintptr) string {
return cppLastError(h)
}
// lastErrorOr is cLastError with a fallback for the empty-message case, so
// wrapped errors never end in ": ".
func lastErrorOr(h uintptr, fallback string) string {
if msg := cLastError(h); msg != "" {
return msg
}
return fallback
}
// cTokenizeJSON tokenizes text (the C side prepends bos per vocab.add_bos)
// and returns the token ids as a JSON array string, e.g. "[2,18]".
func cTokenizeJSON(h uintptr, text string) (string, error) {
ret := cppTokenizeJSON(h, text)
if ret == 0 {
return "", fmt.Errorf("dllm: tokenize failed: %s", lastErrorOr(h, "unknown error"))
}
out := goStringFromCPtr(ret)
cppFreeString(ret)
return out, nil
}
// cGenerate runs a blocking generation and returns the detokenized text.
// optsJSON must be a FLAT JSON object of scalars (use buildOptsJSON); the C
// parser rejects nested objects/arrays. NULL return -> last_error (read only
// after the call returned, per the serialization contract); a cancelled call
// surfaces as the "cancelled" message.
func cGenerate(h uintptr, prompt, optsJSON string) (string, error) {
ret := cppGenerate(h, prompt, optsJSON)
if ret == 0 {
return "", fmt.Errorf("dllm: generate failed: %s", lastErrorOr(h, "unknown error"))
}
out := goStringFromCPtr(ret)
cppFreeString(ret)
return out, nil
}
// streamCallState carries the Go callbacks for one in-flight
// cGenerateStream call; the registry key travels through C as user_data.
// The map shape mirrors the whisper backend's streamCallStates: only one
// entry per ctx is ever live (the C-ABI is serialized per ctx), but keying
// by call survives multiple models/processes sharing the package.
type streamCallState struct {
onBlock func(text string)
onStep func(step, total int, preview string)
}
var (
streamCallStates sync.Map // uint64 -> *streamCallState
streamCallSeq atomic.Uint64
// purego.NewCallback allocates a finite, never-released callback slot, so
// the two trampolines are created exactly once and reused across calls.
streamCbOnce sync.Once
blockCbPtr uintptr
stepCbPtr uintptr
)
// onBlockTrampoline is the Go side of dllm_block_cb. It runs on the C
// calling thread, mid-generate: keep it tiny and non-blocking (callers that
// bridge to goroutines must hand off via buffered channels). The text
// pointer is only valid for the duration of the invocation, so it is copied
// to a Go string immediately.
func onBlockTrampoline(text uintptr, userData uintptr) {
v, ok := streamCallStates.Load(uint64(userData))
if !ok {
return // call already torn down
}
state := v.(*streamCallState)
if state.onBlock != nil {
state.onBlock(goStringFromCPtr(text))
}
}
// onStepTrampoline is the Go side of dllm_step_cb; same threading and
// lifetime caveats as onBlockTrampoline.
func onStepTrampoline(step int32, totalSteps int32, canvasPreview uintptr, userData uintptr) {
v, ok := streamCallStates.Load(uint64(userData))
if !ok {
return
}
state := v.(*streamCallState)
if state.onStep != nil {
state.onStep(int(step), int(totalSteps), goStringFromCPtr(canvasPreview))
}
}
// cGenerateStream runs a generation with per-committed-block (onBlock) and
// per-denoising-step (onStep) callbacks; either may be nil. The callbacks
// run on the C thread (see the trampoline docs). Returns an error carrying
// last_error on failure; cancellation surfaces as the "cancelled" message.
func cGenerateStream(h uintptr, prompt, optsJSON string, onBlock func(text string), onStep func(step, total int, preview string)) error {
streamCbOnce.Do(func() {
blockCbPtr = purego.NewCallback(onBlockTrampoline)
stepCbPtr = purego.NewCallback(onStepTrampoline)
})
id := streamCallSeq.Add(1)
streamCallStates.Store(id, &streamCallState{onBlock: onBlock, onStep: onStep})
defer streamCallStates.Delete(id)
// Pass NULL for absent callbacks so the C side skips the per-block /
// per-step detokenize work entirely.
var blockPtr, stepPtr uintptr
if onBlock != nil {
blockPtr = blockCbPtr
}
if onStep != nil {
stepPtr = stepCbPtr
}
if rc := cppGenerateStream(h, prompt, optsJSON, blockPtr, stepPtr, uintptr(id)); rc != 0 {
return fmt.Errorf("dllm: generate_stream failed: %s", lastErrorOr(h, "unknown error"))
}
return nil
}
// cCancel requests cancellation of the in-flight generate on h. This is the
// ONE entry point safe to call from any goroutine while a generate runs (it
// only flips an atomic). Note the cancel-reset race from the header: each
// generate resets the flag on entry, so a watchdog should re-issue cancel if
// the call has not returned.
func cCancel(h uintptr) {
cppCancel(h)
}
// buildOptsJSON renders generation options as the flat JSON object the
// C-ABI expects (known keys: n_predict, blocks, seed, eb_*, kv_cache). The
// C-side scanner only understands scalar number/string values and rejects
// nested objects/arrays loudly; bools are rejected here too because the
// scanner has no concept of them. Fail loud rather than let an option be
// silently misread.
//
// CAVEAT: json.Marshal HTML-escapes <, > and & inside string values (e.g.
// "<" becomes the six-byte \u003c sequence). None of the known string-valued keys
// (kv_cache: auto|on|off) can contain those bytes today; if one ever does,
// switch to an Encoder with SetEscapeHTML(false) like gemma4JSONString.
func buildOptsJSON(opts map[string]any) (string, error) {
if len(opts) == 0 {
return "{}", nil
}
for k, v := range opts {
switch v.(type) {
case string,
int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64,
float32, float64,
json.Number:
// scalar: fine
default:
return "", fmt.Errorf("dllm: opts key %q has non-scalar value %T (the C-ABI only accepts flat number/string scalars)", k, v)
}
}
b, err := json.Marshal(opts)
if err != nil {
return "", fmt.Errorf("dllm: marshal opts: %w", err)
}
return string(b), nil
}
// goStringFromCPtr copies a NUL-terminated C string into Go memory. cptr is
// the raw pointer returned by purego from the C-ABI (a malloc'd buffer the
// caller owns, or a callback argument only valid during the invocation);
// owning callers must free it via cppFreeString after the copy lands.
//
// A direct unsafe.Pointer(cptr) conversion trips go vet's unsafeptr check,
// which can't distinguish a C-owned heap pointer from Go-managed memory (the
// parakeet-cpp and whisper backends tolerate that warning). Reinterpreting
// through &cptr below is equivalent at runtime and keeps plain `go vet`
// clean. It is safe either way: the pointer addresses C memory the Go GC
// neither tracks nor moves, and we dereference it immediately to copy the
// bytes out.
func goStringFromCPtr(cptr uintptr) string {
if cptr == 0 {
return ""
}
p := *(*unsafe.Pointer)(unsafe.Pointer(&cptr)) // C-owned buffer, not Go-GC memory (see doc above)
n := 0
for *(*byte)(unsafe.Add(p, n)) != 0 {
n++
}
return string(unsafe.Slice((*byte)(p), n))
}

View File

@@ -1,553 +0,0 @@
package main
// LocalAI gRPC backend for dllm.cpp (DiffusionGemma block-diffusion models).
//
// Wiring overview:
// - Load opens the GGUF via dllm_capi_load and starts the per-model worker
// goroutine that serializes every C call (see submit).
// - PredictRich / PredictStreamRich implement grpc.AIModelRich: when the
// request carries raw messages (use_tokenizer_template), the backend owns
// templating (RenderGemma4) and output parsing (Gemma4Parser) and replies
// with ChatDeltas, like the llama.cpp autoparser and the ds4 backend.
// - The legacy Predict / PredictStream methods delegate to the rich pair
// (cloud-proxy precedent); the gRPC server prefers the rich path anyway.
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"unicode/utf8"
grpc "github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/grpc/base"
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/xlog"
)
// The gRPC server cancels in-flight generations on client disconnect only
// for backends advertising the Cancellable capability; keep Dllm pinned to
// it so a signature drift fails the build, not the disconnect path.
var _ grpc.Cancellable = (*Dllm)(nil)
// generator is the seam between the backend wiring and the dllm.cpp C-ABI:
// the real implementation (capiGenerator) wraps the cGenerate/cTokenizeJSON
// family, while tests substitute a fake to exercise prompt construction,
// parsing and serialization without libdllm.so.
type generator interface {
generate(prompt, optsJSON string) (string, error)
// generateStream invokes onBlock once per committed diffusion block, on
// the thread running the C call, before returning.
generateStream(prompt, optsJSON string, onBlock func(text string)) error
tokenizeJSON(text string) (string, error)
// cancel is the ONE entry point safe to call concurrently with an
// in-flight generate on the same ctx (dllm_capi.h: it only flips an
// atomic; everything else must be externally serialized per ctx).
cancel()
free()
}
// capiGenerator is the production generator over one dllm_ctx handle.
type capiGenerator struct {
h uintptr
}
func (g *capiGenerator) generate(prompt, optsJSON string) (string, error) {
return cGenerate(g.h, prompt, optsJSON)
}
func (g *capiGenerator) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
// on_step (per-denoise-step canvas preview, dllm.cpp's --visual) is
// passed as nil for now: a future progress hook for the React UI can
// plumb it through without touching the C binding.
return cGenerateStream(g.h, prompt, optsJSON, onBlock, nil)
}
func (g *capiGenerator) tokenizeJSON(text string) (string, error) {
return cTokenizeJSON(g.h, text)
}
func (g *capiGenerator) cancel() {
cCancel(g.h)
}
func (g *capiGenerator) free() {
cFree(g.h)
}
// Dllm is the gRPC backend instance: one per loaded model (LocalAI starts
// one backend process per model).
type Dllm struct {
base.Base
gen generator
// genOpts holds the model-level generation overrides parsed from
// ModelOptions.Options at Load (eb_*, blocks, kv_cache). The C-ABI takes
// them per-generate, not per-load, so they are merged into every
// request's opts JSON (requestOptsJSON).
genOpts map[string]any
// jobs is the per-model worker queue. dllm_capi.h requires every entry
// point EXCEPT dllm_capi_cancel to be externally serialized per ctx (one
// ctx = one concurrent generate/tokenize; last_error is unsafe to read
// while a call is in flight). A single goroutine owning all C calls makes
// that contract structural instead of relying on lock discipline.
jobs chan func()
workerWG sync.WaitGroup
// genMu guards gen against Free racing in-flight requests: requests hold
// the read lock for their full duration (they stay concurrent with each
// other - the worker still serializes the C calls), Free takes the write
// lock so it can only run when no request is in flight.
genMu sync.RWMutex
}
func (d *Dllm) startWorker() {
d.jobs = make(chan func())
d.workerWG.Add(1)
go func() {
defer d.workerWG.Done()
for job := range d.jobs {
job()
}
}()
}
// submit runs job on the worker goroutine and waits for it to finish.
// Concurrent gRPC requests therefore queue up and execute one at a time
// against the single dllm_ctx.
func (d *Dllm) submit(job func()) {
done := make(chan struct{})
d.jobs <- func() {
defer close(done)
job()
}
<-done
}
// Load opens the GGUF and prepares the worker. Load-time engine parameters
// travel as the flat params JSON of dllm_capi_load; generation overrides
// from Options are stored for per-request opts JSON instead (the C-ABI has
// no per-load sampler state).
func (d *Dllm) Load(opts *pb.ModelOptions) error {
if d.gen != nil {
return errors.New("dllm: model already loaded")
}
params := map[string]any{
"n_gpu_layers": opts.GetNGPULayers(),
}
if opts.GetThreads() > 0 {
params["n_threads"] = opts.GetThreads()
}
if opts.GetContextSize() > 0 {
params["ctx_len"] = opts.GetContextSize()
}
paramsJSON, err := buildOptsJSON(params)
if err != nil {
return err
}
d.genOpts = parseModelGenOpts(opts.GetOptions())
h := cLoad(opts.GetModelFile(), paramsJSON)
if h == 0 {
// No ctx exists on load failure, so last_error(NULL) only carries the
// static NULL-ctx message; the real reason is on the backend's stderr.
return fmt.Errorf("dllm: load %q failed: %s (see backend log for details)",
opts.GetModelFile(), lastErrorOr(0, "unknown error"))
}
d.gen = &capiGenerator{h: h}
d.startWorker()
xlog.Info("dllm: model loaded", "model", opts.GetModelFile(), "params", paramsJSON, "gen_opts", d.genOpts)
return nil
}
// Free releases the dllm ctx and stops the worker. Safe when never loaded.
//
// The write lock is essential: the gRPC server (pkg/grpc/server.go, see the
// model-unload path around line 764) calls Free with no locking of its own,
// and base.Base provides none either. Without it a request racing Free would
// panic sending on the closed jobs channel - or worse, generate on a freed C
// ctx. Holding genMu until gen is nil also turns post-Free requests into a
// clean "model not loaded" error instead of a crash.
func (d *Dllm) Free() error {
d.genMu.Lock()
defer d.genMu.Unlock()
if d.gen == nil {
return nil
}
d.submit(d.gen.free)
close(d.jobs)
d.workerWG.Wait()
d.gen = nil
return nil
}
// Cancel requests cancellation of the in-flight generate (the
// grpc.Cancellable capability). The gRPC server arms it via
// context.AfterFunc on the request/stream context, so a client
// disconnect or timeout aborts the generation server-side - the same
// semantics the llama.cpp C++ backend gets from polling IsCancelled().
// It deliberately bypasses the worker queue: dllm_capi_cancel is the one
// call the C-ABI allows from any goroutine mid-generate (it only flips
// an atomic).
//
// Note dllm_capi.h's cancel-reset race: each generate resets the flag on
// entry, so a Cancel racing a NEW generate on the same ctx can be lost
// (and, with requests queued on the worker, it aborts whichever generate
// is currently running). The single-flag granularity is acceptable here
// because the server de-registers the hook on normal completion and one
// backend process serves one model.
func (d *Dllm) Cancel() {
// RLock so a server-side AfterFunc firing in the window between a
// request finishing and a model unload cannot touch a freed C ctx
// (Free holds the write lock while tearing gen down). cancel() is the
// one C call that is safe concurrently with an in-flight generate, so
// taking a read lock here cannot deadlock against request holders.
d.genMu.RLock()
defer d.genMu.RUnlock()
if d.gen != nil {
d.gen.cancel()
}
}
// dllmGenOptKeys are the ModelOptions.Options keys this backend forwards to
// the engine. Options is a shared free-form bag (other layers put their own
// entries there), so unknown keys are skipped with a warning, not an error.
var dllmGenOptKeys = map[string]bool{
"blocks": true,
"kv_cache": true, // "auto"|"on"|"off"; honored by the engine from P3
}
// parseModelGenOpts parses "key:value" Options entries into the flat scalar
// map merged into every generate's opts JSON. eb_* (Entropy-Bound sampler
// knobs) and the keys in dllmGenOptKeys are recognized; values are typed by
// first successful parse (int, then float, else string) to match the C
// scanner's number/string scalars.
func parseModelGenOpts(options []string) map[string]any {
out := map[string]any{}
for _, o := range options {
key, val, found := strings.Cut(o, ":")
if !found {
xlog.Warn("dllm: ignoring malformed option (want key:value)", "option", o)
continue
}
if !strings.HasPrefix(key, "eb_") && !dllmGenOptKeys[key] {
xlog.Debug("dllm: ignoring unrecognized option", "key", key)
continue
}
out[key] = parseScalarOpt(val)
}
return out
}
func parseScalarOpt(v string) any {
if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
return iv
}
if fv, err := strconv.ParseFloat(v, 64); err == nil {
return fv
}
return v
}
// metadataEnableThinking reads the enable_thinking gate. Unlike ds4 (default
// ON, matching ds4-server), dllm defaults OFF: DiffusionGemma's chat
// template guards every thinking branch with `enable_thinking is defined and
// enable_thinking`, i.e. thinking is opt-in for this model family, and the
// no-thinking render pre-closes an empty thought channel that the OFF
// default must produce.
func metadataEnableThinking(opts *pb.PredictOptions) bool {
v := opts.GetMetadata()["enable_thinking"]
return v == "true" || v == "1"
}
// buildPrompt resolves the prompt for a request. With use_tokenizer_template
// and raw messages the backend owns templating (RenderGemma4) and the output
// is in the known gemma4 format, so parse=true. Without it the caller
// templated the prompt themselves (LocalAI's Go templates + PEG fallback, or
// a bare completion): the prompt passes through verbatim and the output is
// NOT gemma4-parsed - it is emitted as plain content and the Go side's
// extraction applies, as for any non-autoparsing backend.
func buildPrompt(opts *pb.PredictOptions) (prompt string, parse bool, err error) {
if opts.GetUseTokenizerTemplate() && len(opts.GetMessages()) > 0 {
prompt, err = RenderGemma4(opts.GetMessages(), opts.GetTools(), metadataEnableThinking(opts), true)
return prompt, true, err
}
return opts.GetPrompt(), false, nil
}
// requestOptsJSON merges the model-level overrides with the request's
// sampling fields into the flat opts JSON for one generate call.
func (d *Dllm) requestOptsJSON(opts *pb.PredictOptions) (string, error) {
m := make(map[string]any, len(d.genOpts)+2)
for k, v := range d.genOpts {
m[k] = v
}
if n := opts.GetTokens(); n > 0 {
// The engine rounds n_predict UP to a whole number of diffusion
// blocks (the canvas is denoised block-wise), so the completion may
// run slightly past the requested budget. Tokens==0 omits the key so
// the C-ABI default of 256 applies (hardcoded in capi.cpp's
// parse_gen_opts, independent of canvas_length).
m["n_predict"] = n
}
if s := opts.GetSeed(); s > 0 {
// The engine seeds mt19937 with explicit non-negative seeds. Seed<=0
// is omitted: proto3 cannot distinguish 0 from unset, and negative
// values conventionally mean "random" across LocalAI backends.
m["seed"] = s
}
return buildOptsJSON(m)
}
// prepareRequest is the shared prologue of the rich methods: resolve the
// prompt (and whether the output gets gemma4-parsed) and build the per-call
// opts JSON.
func (d *Dllm) prepareRequest(opts *pb.PredictOptions) (prompt string, parse bool, optsJSON string, err error) {
prompt, parse, err = buildPrompt(opts)
if err != nil {
return "", false, "", err
}
optsJSON, err = d.requestOptsJSON(opts)
if err != nil {
return "", false, "", err
}
return prompt, parse, optsJSON, nil
}
// sanitizeUTF8 makes s safe for a proto3 string field. Block-boundary
// detokenization and byte-fallback tokens can produce invalid UTF-8, and
// grpc-go refuses to marshal it ("string field contains invalid UTF-8"), so
// every string destined for a Reply/ChatDelta must pass through here (or
// through splitValidUTF8, which calls it). Lone malformed bytes are genuinely
// undecodable: replace with U+FFFD rather than crash the stream.
func sanitizeUTF8(s string) string {
if utf8.ValidString(s) {
return s
}
return strings.ToValidUTF8(s, "<22>")
}
// utf8SeqLen returns the declared sequence length of a UTF-8 leading byte
// (1 for bytes that can never lead a multi-byte sequence, so they are never
// held back and fall through to sanitizeUTF8's replacement).
func utf8SeqLen(b byte) int {
switch {
case b&0xE0 == 0xC0:
return 2
case b&0xF0 == 0xE0:
return 3
case b&0xF8 == 0xF0:
return 4
default:
return 1
}
}
// splitValidUTF8 prepends the previous block's carry to the new block and
// splits the result into text safe to emit now and a trailing INCOMPLETE
// UTF-8 sequence (at most utf8.UTFMax-1 bytes) to carry into the next block:
// the per-block detokenize can split a multi-byte character across block
// boundaries (llama.cpp's grpc-server holds back the same way). Only a
// suffix that can still become a valid rune is withheld; bytes that are
// already undecodable are replaced immediately so the carry stays bounded.
func splitValidUTF8(carry, block string) (emit, newCarry string) {
s := carry + block
cut := len(s)
for i := len(s) - 1; i >= 0 && len(s)-i < utf8.UTFMax; i-- {
b := s[i]
if b < utf8.RuneSelf {
break // ASCII: everything before the tail scan is complete
}
if !utf8.RuneStart(b) {
continue // continuation byte: keep looking for its leading byte
}
// Leading byte: hold the sequence back iff it declares more bytes
// than the stream has produced so far (it may complete next block).
if utf8SeqLen(b) > len(s)-i {
cut = i
}
break
}
return sanitizeUTF8(s[:cut]), s[cut:]
}
// PredictRich is the non-streaming inference path (grpc.AIModelRich).
// Returns one Reply whose Message is the aggregated assistant content and
// whose ChatDeltas carry the parsed content/reasoning/tool-call events.
func (d *Dllm) PredictRich(opts *pb.PredictOptions) (*pb.Reply, error) {
d.genMu.RLock()
defer d.genMu.RUnlock()
if d.gen == nil {
return nil, grpcerrors.ModelNotLoaded("dllm")
}
prompt, parse, optsJSON, err := d.prepareRequest(opts)
if err != nil {
return nil, err
}
var out string
var genErr error
d.submit(func() {
out, genErr = d.gen.generate(prompt, optsJSON)
})
if genErr != nil {
return nil, genErr
}
// Byte-fallback tokens can detokenize to invalid UTF-8; proto3 strings
// must be valid or grpc-go fails the whole reply at marshal time.
out = sanitizeUTF8(out)
if !parse {
// Raw-prompt mode: plain content, no gemma4 parsing (see buildPrompt).
return &pb.Reply{Message: []byte(out), ChatDeltas: []*pb.ChatDelta{{Content: out}}}, nil
}
// The prompt renders with add_generation_prompt; both thinking modes
// leave the model starting in content state (see the Gemma4Parser header
// comment), hence NewGemma4Parser(false).
parser := NewGemma4Parser(false)
if reply := replyFromDeltas(append(parser.Feed(out), parser.Close()...)); reply != nil {
return reply, nil
}
// Everything was markers (or out was empty): an empty but non-nil Reply.
return &pb.Reply{}, nil
}
// PredictStreamRich is the streaming counterpart (grpc.AIModelRich): one
// Reply per committed diffusion block that produced deltas. Per the
// interface contract the channel is only sent into here - the gRPC server
// closes it after this returns (opposite to legacy PredictStream).
func (d *Dllm) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) error {
d.genMu.RLock()
defer d.genMu.RUnlock()
if d.gen == nil {
return grpcerrors.ModelNotLoaded("dllm")
}
prompt, parse, optsJSON, err := d.prepareRequest(opts)
if err != nil {
return err
}
var parser *Gemma4Parser
if parse {
parser = NewGemma4Parser(false)
}
// emit runs inside onBlock, i.e. on the thread driving the C generate.
// Sending on results can block on a slow consumer, but the server-side
// pump (pkg/grpc/server.go PredictStream) drains continuously and drops
// undeliverable sends, so this backpressure is brief and bounded - and
// pausing the diffusion loop under it is the desired behavior anyway.
emit := func(text string) {
if !parse {
if text != "" {
results <- &pb.Reply{Message: []byte(text), ChatDeltas: []*pb.ChatDelta{{Content: text}}}
}
return
}
deltas := parser.Feed(text)
if reply := replyFromDeltas(deltas); reply != nil {
results <- reply
}
}
// onBlock guards emit (and through it the parser) against invalid UTF-8:
// a multi-byte character split across block boundaries is held back until
// it completes (see splitValidUTF8), so proto3 marshaling never fails.
var carry string
onBlock := func(block string) {
var text string
text, carry = splitValidUTF8(carry, block)
emit(text)
}
var genErr error
d.submit(func() {
genErr = d.gen.generateStream(prompt, optsJSON, onBlock)
})
if genErr != nil {
return genErr
}
if carry != "" {
// The stream ended mid-sequence: the held-back bytes can no longer
// complete, so flush them through the U+FFFD last resort.
emit(sanitizeUTF8(carry))
}
if parse {
if reply := replyFromDeltas(parser.Close()); reply != nil {
results <- reply
}
}
return nil
}
// replyFromDeltas wraps one batch of parsed deltas into a streaming Reply,
// or nil when the batch is empty (markers consumed, nothing emitted yet).
// Message mirrors the batch's content text so legacy chan-string consumers
// see exactly the displayed tokens.
func replyFromDeltas(deltas []*pb.ChatDelta) *pb.Reply {
if len(deltas) == 0 {
return nil
}
var content strings.Builder
for _, delta := range deltas {
content.WriteString(delta.GetContent())
}
return &pb.Reply{Message: []byte(content.String()), ChatDeltas: deltas}
}
// Predict is the legacy (string, error) signature; the gRPC server prefers
// PredictRich, this exists for non-rich callers (cloud-proxy precedent).
func (d *Dllm) Predict(opts *pb.PredictOptions) (string, error) {
reply, err := d.PredictRich(opts)
if err != nil {
return "", err
}
return string(reply.GetMessage()), nil
}
// PredictStream is the legacy chan-string path: rich replies reduced to
// their content text. Note the inverted channel ownership - the LEGACY
// contract requires the impl to close the channel.
func (d *Dllm) PredictStream(opts *pb.PredictOptions, results chan string) error {
defer close(results)
richCh := make(chan *pb.Reply)
errCh := make(chan error, 1)
go func() {
errCh <- d.PredictStreamRich(opts, richCh)
close(richCh)
}()
for reply := range richCh {
if msg := reply.GetMessage(); len(msg) > 0 {
results <- string(msg)
}
}
return <-errCh
}
// TokenizeString tokenizes opts.Prompt via dllm_capi_tokenize_json (the C
// side prepends bos per the vocab) and decodes the returned id array.
func (d *Dllm) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
d.genMu.RLock()
defer d.genMu.RUnlock()
if d.gen == nil {
return pb.TokenizationResponse{}, grpcerrors.ModelNotLoaded("dllm")
}
var out string
var tokErr error
d.submit(func() {
out, tokErr = d.gen.tokenizeJSON(opts.GetPrompt())
})
if tokErr != nil {
return pb.TokenizationResponse{}, tokErr
}
var tokens []int32
if err := json.Unmarshal([]byte(out), &tokens); err != nil {
return pb.TokenizationResponse{}, fmt.Errorf("dllm: decode tokenize result %q: %w", out, err)
}
return pb.TokenizationResponse{Length: int32(len(tokens)), Tokens: tokens}, nil
}

View File

@@ -1,807 +0,0 @@
package main
import (
"errors"
"os"
"runtime"
"sync"
"testing"
"time"
"unicode/utf8"
"unsafe"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
func TestDllm(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "dllm Backend Suite")
}
var (
libLoadOnce sync.Once
libLoadErr error
)
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
// C-ABI bridge without spinning up the gRPC server. The library path comes
// from DLLM_TEST_LIBRARY (gated specs Skip when it is unset).
func ensureLibLoaded() {
libLoadOnce.Do(func() {
libLoadErr = loadCAPI(os.Getenv("DLLM_TEST_LIBRARY"))
})
}
// C-ABI binding smoke: drives the real libdllm.so against the tiny GGUF
// fixture from dllm.cpp (tests/fixtures/tiny_with_vocab.gguf). Gated on:
//
// DLLM_TEST_LIBRARY absolute path to libdllm.so
// DLLM_TEST_TINY_MODEL absolute path to tiny_with_vocab.gguf
var _ = Describe("C-ABI binding", func() {
BeforeEach(func() {
if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" {
Skip("set DLLM_TEST_LIBRARY and DLLM_TEST_TINY_MODEL to run the C-ABI binding smoke")
}
ensureLibLoaded()
Expect(libLoadErr).ToNot(HaveOccurred())
})
It("binds the 9 symbols and round-trips the tiny model", func() {
Expect(cAbiVersion()).To(Equal(int32(1)))
h := cLoad(os.Getenv("DLLM_TEST_TINY_MODEL"), "{}")
Expect(h).ToNot(BeZero(), "dllm_capi_load of the tiny fixture")
// Tiny fixture vocab: "hello" tokenizes to ids [2,18] (bos prepended
// by the C side: vocab.add_bos).
toks, err := cTokenizeJSON(h, "hello")
Expect(err).ToNot(HaveOccurred())
Expect(toks).To(Equal("[2,18]"))
// Deterministic generation: an explicit non-negative seed seeds
// mt19937, so two identical calls must produce identical text.
out1, err := cGenerate(h, "hello", `{"n_predict":16,"seed":7}`)
Expect(err).ToNot(HaveOccurred())
Expect(out1).ToNot(BeEmpty())
// Cancel with no call in flight is dropped: each generate resets the
// cancel flag on entry (header contract), so this must not affect
// the next call. Also binds the 9th symbol; safe on NULL too.
cCancel(h)
cCancel(0)
out2, err := cGenerate(h, "hello", `{"n_predict":16,"seed":7}`)
Expect(err).ToNot(HaveOccurred())
Expect(out2).To(Equal(out1))
// Streaming variant: same opts, blocks arrive via the purego
// callback trampoline. The per-block detokenize can differ from the
// seamless full-text decode at block boundaries, so only assert that
// blocks arrived and were non-trivial, not byte equality with out1.
var blocks []string
var steps int
err = cGenerateStream(h, "hello", `{"n_predict":16,"seed":7}`,
func(text string) { blocks = append(blocks, text) },
func(step, total int, preview string) { steps++ },
)
Expect(err).ToNot(HaveOccurred())
Expect(blocks).ToNot(BeEmpty())
Expect(steps).To(BeNumerically(">", 0))
// Load failure path: NULL ctx back, and last_error(NULL) returns the
// static NULL-ctx message (there is no ctx to carry the real reason).
bad := cLoad("/nonexistent/dllm-model.gguf", "{}")
Expect(bad).To(BeZero())
Expect(cLastError(0)).ToNot(BeEmpty())
// Free is safe on a live handle and a NULL one (delete nullptr).
cFree(h)
cFree(0)
})
})
// Ungated specs for the pure-Go helpers (no libdllm.so required).
var _ = Describe("buildOptsJSON", func() {
It("renders flat scalars as a JSON object", func() {
out, err := buildOptsJSON(map[string]any{
"n_predict": 16,
"seed": int64(7),
"eb_t_min": 0.5,
"kv_cache": "auto",
})
Expect(err).ToNot(HaveOccurred())
Expect(out).To(MatchJSON(`{"n_predict":16,"seed":7,"eb_t_min":0.5,"kv_cache":"auto"}`))
})
It("renders an empty object for no options", func() {
out, err := buildOptsJSON(nil)
Expect(err).ToNot(HaveOccurred())
Expect(out).To(Equal("{}"))
})
It("rejects nested objects (the C-side scanner only reads flat scalars)", func() {
_, err := buildOptsJSON(map[string]any{"sampler": map[string]any{"seed": 1}})
Expect(err).To(HaveOccurred())
})
It("rejects arrays", func() {
_, err := buildOptsJSON(map[string]any{"stop": []string{"a"}})
Expect(err).To(HaveOccurred())
})
It("rejects booleans (the C-side scanner only understands numbers and strings)", func() {
_, err := buildOptsJSON(map[string]any{"flag": true})
Expect(err).To(HaveOccurred())
})
})
var _ = Describe("splitValidUTF8", func() {
It("holds back a trailing incomplete sequence and completes it next block", func() {
emit, carry := splitValidUTF8("", "caf\xe2")
Expect(emit).To(Equal("caf"))
Expect(carry).To(Equal("\xe2"))
emit, carry = splitValidUTF8(carry, "\x82")
Expect(emit).To(BeEmpty())
Expect(carry).To(Equal("\xe2\x82"))
emit, carry = splitValidUTF8(carry, "\xac!")
Expect(emit).To(Equal("€!"))
Expect(carry).To(BeEmpty())
})
It("holds back up to 3 bytes of a 4-byte sequence", func() {
emit, carry := splitValidUTF8("", "x\xf0\x9f\x98") // 😀 missing its last byte
Expect(emit).To(Equal("x"))
Expect(carry).To(Equal("\xf0\x9f\x98"))
emit, carry = splitValidUTF8(carry, "\x80")
Expect(emit).To(Equal("😀"))
Expect(carry).To(BeEmpty())
})
It("replaces undecodable bytes immediately instead of carrying them", func() {
// A mid-string invalid byte can never complete: carrying it would let
// the carry grow unboundedly, so it is substituted on the spot.
emit, carry := splitValidUTF8("", "a\xe2bc")
Expect(emit).To(Equal("a<>bc"))
Expect(carry).To(BeEmpty())
// Orphan continuation bytes at the end have no leading byte to wait
// for either.
emit, carry = splitValidUTF8("", "a\x82")
Expect(emit).To(Equal("a<>"))
Expect(carry).To(BeEmpty())
})
It("passes pure ASCII and complete UTF-8 through untouched", func() {
emit, carry := splitValidUTF8("", "héllo €")
Expect(emit).To(Equal("héllo €"))
Expect(carry).To(BeEmpty())
})
})
var _ = Describe("goStringFromCPtr", func() {
It("copies a NUL-terminated buffer", func() {
buf := []byte("dllm\x00")
s := goStringFromCPtr(uintptr(unsafe.Pointer(&buf[0])))
// The uintptr round-trip hides buf from the GC's liveness analysis;
// keep it reachable until after the copy.
runtime.KeepAlive(buf)
Expect(s).To(Equal("dllm"))
})
It("returns the empty string for NULL", func() {
Expect(goStringFromCPtr(0)).To(Equal(""))
})
})
// ---------------------------------------------------------------------------
// Backend wiring (T4): fake-generator specs, no libdllm.so required.
// ---------------------------------------------------------------------------
type fakeGenCall struct {
prompt string
optsJSON string
}
// fakeGen implements generator in-process. It records every call (prompt +
// opts JSON), tracks concurrent in-flight calls to prove worker
// serialization, and replays canned output (out for generate/tokenize,
// blocks for generateStream).
type fakeGen struct {
mu sync.Mutex
calls []fakeGenCall
inFlight int
maxInFlight int
out string
blocks []string
err error
delay time.Duration
}
func (f *fakeGen) begin(prompt, optsJSON string) {
f.mu.Lock()
defer f.mu.Unlock()
f.calls = append(f.calls, fakeGenCall{prompt: prompt, optsJSON: optsJSON})
f.inFlight++
if f.inFlight > f.maxInFlight {
f.maxInFlight = f.inFlight
}
}
func (f *fakeGen) end() {
f.mu.Lock()
defer f.mu.Unlock()
f.inFlight--
}
func (f *fakeGen) snapshot() (calls []fakeGenCall, maxInFlight int) {
f.mu.Lock()
defer f.mu.Unlock()
return append([]fakeGenCall(nil), f.calls...), f.maxInFlight
}
func (f *fakeGen) generate(prompt, optsJSON string) (string, error) {
f.begin(prompt, optsJSON)
defer f.end()
if f.delay > 0 {
time.Sleep(f.delay)
}
return f.out, f.err
}
func (f *fakeGen) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
f.begin(prompt, optsJSON)
defer f.end()
if f.err != nil {
return f.err
}
for _, b := range f.blocks {
onBlock(b)
}
return nil
}
func (f *fakeGen) tokenizeJSON(text string) (string, error) {
f.begin(text, "")
defer f.end()
return f.out, f.err
}
func (f *fakeGen) cancel() {}
func (f *fakeGen) free() {}
// newTestDllm assembles a backend around a fake generator (bypassing Load,
// which needs libdllm.so) and registers cleanup of the worker goroutine.
func newTestDllm(g generator, genOpts map[string]any) *Dllm {
d := &Dllm{gen: g, genOpts: genOpts}
d.startWorker()
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
return d
}
// drainReplies empties ch without blocking, failing the spec if the channel
// was closed (PredictStreamRich must NOT close it - interface.go contract).
// Size ch above the expected reply count: an overflow deadlocks the spec on
// the producer's send instead of failing it.
func drainReplies(ch chan *pb.Reply) []*pb.Reply {
var out []*pb.Reply
for {
select {
case r, ok := <-ch:
if !ok {
Fail("PredictStreamRich closed the results channel (the gRPC server owns the close)")
}
expectValidUTF8Reply(r)
out = append(out, r)
default:
return out
}
}
}
// expectValidUTF8Reply is the blanket guard for the proto3 marshaling
// contract: grpc-go rejects any string field carrying invalid UTF-8, so every
// reply field that ends up in a proto string must validate.
func expectValidUTF8Reply(r *pb.Reply) {
GinkgoHelper()
Expect(utf8.ValidString(string(r.GetMessage()))).To(BeTrue(), "Reply.Message carries invalid UTF-8")
for _, delta := range r.GetChatDeltas() {
Expect(utf8.ValidString(delta.GetContent())).To(BeTrue(), "ChatDelta.Content carries invalid UTF-8")
Expect(utf8.ValidString(delta.GetReasoningContent())).To(BeTrue(), "ChatDelta.ReasoningContent carries invalid UTF-8")
for _, tc := range delta.GetToolCalls() {
Expect(utf8.ValidString(tc.GetName())).To(BeTrue(), "ToolCallDelta.Name carries invalid UTF-8")
Expect(utf8.ValidString(tc.GetArguments())).To(BeTrue(), "ToolCallDelta.Arguments carries invalid UTF-8")
}
}
}
var _ = Describe("Dllm backend wiring", func() {
Describe("PredictRich", func() {
It("renders gemma4 from raw messages and parses the output when use_tokenizer_template is set", func() {
fake := &fakeGen{out: "<|channel>thought\npondering<channel|>The answer.<turn|>"}
d := newTestDllm(fake, nil)
reply, err := d.PredictRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}},
Metadata: map[string]string{"enable_thinking": "true"},
})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
Expect(calls).To(HaveLen(1))
// The enable_thinking=true render from the transformers fixture.
Expect(calls[0].prompt).To(Equal(
"<|turn>system\n<|think|>\n<turn|>\n<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n"))
Expect(string(reply.GetMessage())).To(Equal("The answer."))
Expect(reply.GetChatDeltas()).To(HaveLen(2))
Expect(reply.GetChatDeltas()[0].GetReasoningContent()).To(Equal("pondering"))
Expect(reply.GetChatDeltas()[1].GetContent()).To(Equal("The answer."))
})
It("defaults enable_thinking OFF (the gemma4 template treats thinking as opt-in)", func() {
fake := &fakeGen{out: "hi"}
d := newTestDllm(fake, nil)
_, err := d.PredictRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}},
})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
// No-thinking render: the template pre-opens AND pre-closes an
// empty thought channel in the generation prompt.
Expect(calls[0].prompt).To(Equal(
"<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"))
})
It("passes the raw prompt verbatim and skips gemma4 parsing without use_tokenizer_template", func() {
// Marker-looking text must survive untouched: in raw-prompt mode
// the caller templates themselves and the Go-side extraction
// applies, so the backend must not interpret the output.
fake := &fakeGen{out: "<|channel>thought\nnot parsed<channel|>tail"}
d := newTestDllm(fake, nil)
reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "my raw prompt"})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
Expect(calls[0].prompt).To(Equal("my raw prompt"))
Expect(string(reply.GetMessage())).To(Equal(fake.out))
Expect(reply.GetChatDeltas()).To(HaveLen(1))
Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal(fake.out))
})
It("sanitizes invalid UTF-8 in the non-streaming output", func() {
// Byte-fallback tokens can decode to lone malformed bytes; the
// whole-output sanitize must replace them so proto3 marshaling of
// Message/ChatDeltas cannot fail.
fake := &fakeGen{out: "a\xe2b"}
d := newTestDllm(fake, nil)
reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).ToNot(HaveOccurred())
expectValidUTF8Reply(reply)
Expect(string(reply.GetMessage())).To(Equal("a<>b"))
Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal("a<>b"))
})
It("maps Tokens and Seed into the opts JSON on top of the model-level overrides", func() {
fake := &fakeGen{out: "x"}
d := newTestDllm(fake, map[string]any{"eb_t_min": 0.5, "kv_cache": "auto"})
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p", Tokens: 32, Seed: 7})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
Expect(calls[0].optsJSON).To(MatchJSON(`{"n_predict":32,"seed":7,"eb_t_min":0.5,"kv_cache":"auto"}`))
})
It("omits n_predict and seed when unset so the engine defaults apply", func() {
fake := &fakeGen{out: "x"}
d := newTestDllm(fake, nil)
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).ToNot(HaveOccurred())
calls, _ := fake.snapshot()
Expect(calls[0].optsJSON).To(MatchJSON(`{}`))
})
It("surfaces generator errors", func() {
fake := &fakeGen{err: errors.New("boom")}
d := newTestDllm(fake, nil)
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).To(MatchError("boom"))
})
It("errors before generating when no model is loaded", func() {
d := &Dllm{} // no Load, no worker: must fail fast, not hang
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).To(HaveOccurred())
})
It("makes a concurrent Free wait for the in-flight request (both finish cleanly)", func() {
// server.go's Free has no locking of its own: the backend's genMu
// must hold Free back until the racing generate drains, instead of
// closing the jobs channel (panic) or freeing the C ctx under it.
fake := &fakeGen{out: "x", delay: 50 * time.Millisecond}
d := newTestDllm(fake, nil)
predictDone := make(chan error, 1)
go func() {
defer GinkgoRecover()
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
predictDone <- err
}()
// Wait until the fake generate is actually in flight (the read
// lock is held from before submit until PredictRich returns).
Eventually(func() int {
_, maxInFlight := fake.snapshot()
return maxInFlight
}).Should(Equal(1))
Expect(d.Free()).To(Succeed())
// Free's write lock means the request finished before Free did.
var predictErr error
Eventually(predictDone).Should(Receive(&predictErr))
Expect(predictErr).ToNot(HaveOccurred())
})
It("returns model-not-loaded for requests after Free", func() {
fake := &fakeGen{out: "x"}
d := newTestDllm(fake, nil)
Expect(d.Free()).To(Succeed())
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
})
It("serializes concurrent requests through the worker goroutine", func() {
// dllm_capi.h: one ctx = one concurrent generate. Two overlapping
// PredictRich calls must execute the C calls one at a time.
fake := &fakeGen{out: "x", delay: 30 * time.Millisecond}
d := newTestDllm(fake, nil)
var wg sync.WaitGroup
for range 2 {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
Expect(err).ToNot(HaveOccurred())
}()
}
wg.Wait()
calls, maxInFlight := fake.snapshot()
Expect(calls).To(HaveLen(2))
Expect(maxInFlight).To(Equal(1), "generate calls overlapped despite the worker queue")
})
})
Describe("PredictStreamRich", func() {
It("emits one reply per delta-producing block and leaves the channel open", func() {
// Blocks split mid-marker and mid-payload: the parser's holdback
// must keep marker fragments out of the emitted deltas.
fake := &fakeGen{blocks: []string{
"<|channel>thou", // partial channel open: no deltas yet
"ght\nponder", // header completes, reasoning starts
"ing<channel|>Hi ", // reasoning ends, content starts
"there<turn|>discarded", // turn ends: trailing text dropped
}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
}, ch)
Expect(err).ToNot(HaveOccurred())
replies := drainReplies(ch)
Expect(replies).To(HaveLen(3), "block 1 completes no delta and must not produce a reply")
var content, reasoning string
for _, r := range replies {
for _, delta := range r.GetChatDeltas() {
content += delta.GetContent()
reasoning += delta.GetReasoningContent()
}
}
Expect(reasoning).To(Equal("pondering"))
Expect(content).To(Equal("Hi there"))
// Message mirrors each reply's content so legacy consumers see
// exactly the displayed tokens.
Expect(string(replies[1].GetMessage())).To(Equal("Hi "))
Expect(string(replies[2].GetMessage())).To(Equal("there"))
})
It("streams raw blocks verbatim without use_tokenizer_template", func() {
fake := &fakeGen{blocks: []string{"abc", "", "<|channel>def"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
Expect(err).ToNot(HaveOccurred())
replies := drainReplies(ch)
Expect(replies).To(HaveLen(2), "empty blocks produce no reply")
Expect(string(replies[0].GetMessage())).To(Equal("abc"))
Expect(string(replies[1].GetMessage())).To(Equal("<|channel>def"))
Expect(replies[1].GetChatDeltas()).To(HaveLen(1))
})
It("flushes parser holdback after the stream ends", func() {
// The unterminated partial marker "<chan" is held back during the
// stream and must come out as content on the final flush.
fake := &fakeGen{blocks: []string{"tail<chan"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
}, ch)
Expect(err).ToNot(HaveOccurred())
var content string
for _, r := range drainReplies(ch) {
content += string(r.GetMessage())
}
Expect(content).To(Equal("tail<chan"))
})
It("reassembles a multi-byte character split across block boundaries", func() {
// Per-block detokenize can split "€" (E2 82 AC) as E2 | 82 AC.
// Emitting the lone E2 would make grpc-go fail the marshal of the
// whole reply; the trailing incomplete sequence must be held back
// and completed by the next block.
fake := &fakeGen{blocks: []string{"caf\xe2", "\x82\xac ok"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
Expect(err).ToNot(HaveOccurred())
var content string
for _, r := range drainReplies(ch) { // drain asserts ValidString per reply
content += string(r.GetMessage())
}
Expect(content).To(Equal("caf€ ok"))
})
It("reassembles a split multi-byte character in parsed (gemma4) mode too", func() {
fake := &fakeGen{blocks: []string{"caf\xe2", "\x82\xac<turn|>"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
}, ch)
Expect(err).ToNot(HaveOccurred())
var content string
for _, r := range drainReplies(ch) {
for _, delta := range r.GetChatDeltas() {
content += delta.GetContent()
}
}
Expect(content).To(Equal("caf€"))
})
It("replaces an incomplete sequence left at stream end with U+FFFD", func() {
// A byte-fallback token can leave a lone leading byte (0xE2) that
// no later block completes: the final flush must substitute it,
// never emit it raw and never drop into a marshal error.
fake := &fakeGen{blocks: []string{"ok\xe2"}}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
Expect(err).ToNot(HaveOccurred())
var content string
for _, r := range drainReplies(ch) {
content += string(r.GetMessage())
}
Expect(content).To(Equal("ok<6F>"))
})
It("surfaces generator errors without sending replies", func() {
fake := &fakeGen{err: errors.New("stream boom")}
d := newTestDllm(fake, nil)
ch := make(chan *pb.Reply, 16)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch)
Expect(err).To(MatchError("stream boom"))
Expect(drainReplies(ch)).To(BeEmpty())
})
It("errors before generating when no model is loaded", func() {
d := &Dllm{} // no Load, no worker: must fail fast, not hang
ch := make(chan *pb.Reply, 1)
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch)
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
Expect(drainReplies(ch)).To(BeEmpty())
})
})
Describe("legacy Predict/PredictStream adapters", func() {
It("Predict returns the aggregated content string", func() {
fake := &fakeGen{out: "plain text"}
d := newTestDllm(fake, nil)
out, err := d.Predict(&pb.PredictOptions{Prompt: "p"})
Expect(err).ToNot(HaveOccurred())
Expect(out).To(Equal("plain text"))
})
It("PredictStream forwards content strings and closes the channel (legacy ownership)", func() {
fake := &fakeGen{blocks: []string{"a", "b"}}
d := newTestDllm(fake, nil)
ch := make(chan string, 16)
Expect(d.PredictStream(&pb.PredictOptions{Prompt: "p"}, ch)).To(Succeed())
var got []string
for s := range ch { // terminates only if the impl closed ch
got = append(got, s)
}
Expect(got).To(Equal([]string{"a", "b"}))
})
})
Describe("TokenizeString", func() {
It("decodes the C-side JSON id array", func() {
fake := &fakeGen{out: "[2,18]"}
d := newTestDllm(fake, nil)
resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
Expect(err).ToNot(HaveOccurred())
Expect(resp.Length).To(Equal(int32(2)))
Expect(resp.Tokens).To(Equal([]int32{2, 18}))
calls, _ := fake.snapshot()
Expect(calls[0].prompt).To(Equal("hello"))
})
It("fails loud on a malformed id array", func() {
fake := &fakeGen{out: "not json"}
d := newTestDllm(fake, nil)
_, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
Expect(err).To(HaveOccurred())
})
It("errors before tokenizing when no model is loaded", func() {
d := &Dllm{} // no Load, no worker: must fail fast, not hang
_, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
})
})
Describe("parseModelGenOpts", func() {
It("parses eb_*/blocks/kv_cache entries and types values by first successful parse", func() {
got := parseModelGenOpts([]string{
"eb_max_steps:16",
"eb_t_min:0.25",
"kv_cache:auto",
"blocks:4",
"unrelated_key:1", // other layers' options: skipped
"malformed", // no colon: skipped
})
Expect(got).To(Equal(map[string]any{
"eb_max_steps": int64(16),
"eb_t_min": 0.25,
"kv_cache": "auto",
"blocks": int64(4),
}))
})
It("round-trips through buildOptsJSON (only flat scalars are produced)", func() {
got := parseModelGenOpts([]string{"eb_entropy_bound:0.8", "kv_cache:off"})
out, err := buildOptsJSON(got)
Expect(err).ToNot(HaveOccurred())
Expect(out).To(MatchJSON(`{"eb_entropy_bound":0.8,"kv_cache":"off"}`))
})
})
})
// ---------------------------------------------------------------------------
// Gated backend round-trip against the real libdllm.so + tiny GGUF fixture.
// ---------------------------------------------------------------------------
var _ = Describe("Dllm backend (real tiny model)", func() {
BeforeEach(func() {
if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" {
Skip("set DLLM_TEST_LIBRARY and DLLM_TEST_TINY_MODEL to run the backend round-trip")
}
ensureLibLoaded()
Expect(libLoadErr).ToNot(HaveOccurred())
})
It("round-trips Load, PredictRich, PredictStreamRich and TokenizeString", func() {
d := &Dllm{}
Expect(d.Load(&pb.ModelOptions{ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL")})).To(Succeed())
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
// TokenizeString: tiny fixture vocab tokenizes "hello" to [2,18].
resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
Expect(err).ToNot(HaveOccurred())
Expect(resp.Tokens).To(Equal([]int32{2, 18}))
Expect(resp.Length).To(Equal(int32(2)))
req := &pb.PredictOptions{
UseTokenizerTemplate: true,
Messages: []*pb.Message{{Role: "user", Content: "hello"}},
Tokens: 16,
Seed: 7,
}
// Non-streaming: the tiny random-weight model emits arbitrary vocab
// words; with no gemma4 markers in them everything is content.
reply, err := d.PredictRich(req)
Expect(err).ToNot(HaveOccurred())
Expect(string(reply.GetMessage())).ToNot(BeEmpty())
Expect(reply.GetChatDeltas()).ToNot(BeEmpty())
// Streaming: at least one reply, and the channel-ownership rule is
// honored (drainReplies fails the spec on a closed channel).
ch := make(chan *pb.Reply, 64)
Expect(d.PredictStreamRich(req, ch)).To(Succeed())
replies := drainReplies(ch)
Expect(replies).ToNot(BeEmpty())
var streamed string
for _, r := range replies {
streamed += string(r.GetMessage())
}
Expect(streamed).ToNot(BeEmpty())
})
It("aborts an in-flight generation promptly on Cancel", func() {
d := &Dllm{}
// eb_max_steps inflates the per-block denoise loop so the full run
// takes ~10s on the tiny fixture (vs ~40ms at engine defaults; 16
// blocks, first block after ~0.7s) - long enough that a prompt
// post-cancel return is distinguishable from the generation simply
// finishing.
Expect(d.Load(&pb.ModelOptions{
ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL"),
Options: []string{"eb_max_steps:256"},
})).To(Succeed())
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
ch := make(chan *pb.Reply, 64)
errCh := make(chan error, 1)
go func() {
defer GinkgoRecover()
errCh <- d.PredictStreamRich(&pb.PredictOptions{Prompt: "hello", Tokens: 256, Seed: 7}, ch)
}()
// Cancel only once the first block proves the generate is in
// flight: the C side resets the cancel flag on generate entry, so
// an earlier Cancel would be swallowed (dllm_capi.h race note).
Eventually(ch, "60s").Should(Receive())
cancelAt := time.Now()
d.Cancel()
// Uncancelled, ~10s of generation remain; the cancelled call must
// come back in milliseconds (the flag is checked per denoise step).
var genErr error
Eventually(errCh, "5s").Should(Receive(&genErr))
latency := time.Since(cancelAt)
Expect(genErr).To(MatchError(ContainSubstring("cancelled")))
GinkgoWriter.Printf("dllm cancel: PredictStreamRich returned %v after Cancel\n", latency)
})
})

View File

@@ -1,562 +0,0 @@
// Gemma4 (DiffusionGemma) streaming output parser: raw model text, fed in
// arbitrary fragments (per committed diffusion block; a fragment can split
// anywhere, including mid-marker and mid-payload), is turned into
// pb.ChatDelta events (content / reasoning_content / tool_calls).
//
// Normative sources:
// - The chat template embedded at the top of gemma4_renderer.go ("tpl L<n>"
// citations below refer to its numbered lines). The OUTPUT format mirrors
// what the template renders for assistant history: thought channels
// (<|channel>thought\n ... <channel|>, tpl L240), tool calls
// (<|tool_call>call:name{...}<tool_call|>, tpl L246-L257) and turn ends
// (<turn|>, tpl L351).
// - vLLM PR #45163: vllm/tool_parsers/gemma4_tool_parser.py (marker
// handling, the call:name{...} argument grammar and its decoder, ported
// below) and vllm/reasoning/gemma4_reasoning_parser.py (channel markers,
// the "thought\n" role label, is_reasoning_end semantics).
//
// Initial state (derived from the generation prompt, tpl L356-L362, see
// RenderGemma4):
// - enable_thinking=false: the prompt ends with "<|turn>model\n" +
// "<|channel>thought\n<channel|>" - an EMPTY thought channel, pre-opened
// AND pre-closed by the template. The model's output therefore starts in
// plain content. Use NewGemma4Parser(false).
// - enable_thinking=true: the prompt ends at "<|turn>model\n" and the model
// opens and closes its own thought channel in the OUTPUT
// ("<|channel>thought\n...reasoning...<channel|>final answer", per the
// vLLM Gemma4ReasoningParser docstring). The parser still starts in
// content state - the channel markers in the output drive the switch.
// Use NewGemma4Parser(false) here too.
// - NewGemma4Parser(true) is for callers that pre-open the thought channel
// in the prompt themselves (appending "<|channel>thought\n" after the
// generation prompt to force thinking): the output then begins mid-thought
// and everything is reasoning until the first <channel|>.
//
// State diagram (markers are consumed, never emitted):
//
// <|channel> \n (channel name dropped: the
// [content] --------------> [chan-header] ----> [thought] "thought\n" role
// ^ | <channel|> (stray close: swallowed, label, stripped
// +-+ strip_thinking semantics, tpl L148-L158) like vLLM does)
// ^ <channel|>
// +----------------------------------------- [thought]
// ^ <tool_call|> | <|tool_call> (implicit
// +-------------- [tool-call] <-------------------+ reasoning end, vLLM
// | <|tool_call> ^ is_reasoning_end)
// +-------------------+
// [content]/[thought] --- <turn|> ---> [done] (everything after is dropped)
//
// Buffering rules:
// - content/thought states hold back at most len(longest marker)-1 bytes:
// the longest tail that is still a proper prefix of a watched marker.
// Content is otherwise emitted immediately (no unbounded buffering).
// - the tool-call state buffers the whole payload until <tool_call|>. This
// is unbounded in principle but bounded in practice by the model's
// diffusion canvas, and is required because the call:name{...} payload
// only becomes decodable (and trustworthy) once complete - the same
// reason vLLM's parser accumulates before parsing.
// - Close() flushes whatever is still held: partial markers come out as
// content/reasoning (per the state that held them); an unterminated
// channel header or tool-call payload is re-emitted RAW (including its
// opening marker) as content - malformed output is never silently
// dropped (mirrors vLLM extract_tool_calls returning the raw text as
// content when its regex does not match).
//
// Streaming granularity DIVERGENCE from vLLM: vLLM re-parses the partial
// payload on every token and streams argument-JSON diffs (its `partial=True`
// decoder mode plus withholding logic exist only for that). Our fragments are
// whole committed diffusion blocks, so each completed tool call is emitted
// once, as a single ToolCallDelta carrying index + id + name + the full
// arguments JSON - exactly the shape backend/python/vllm/backend.py emits
// per call and pkg/functions.ToolCallsFromChatDeltas re-accumulates.
package main
import (
"encoding/json"
"regexp"
"strconv"
"strings"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
// gemma4CallRE is vLLM's tool_call_regex
// (`<\|tool_call>call:([\w\-\.]+)\{(.*?)\}<tool_call\|>`, DOTALL) anchored to
// a single already-extracted payload: name charset [\w\-.], braces mandatory.
var gemma4CallRE = regexp.MustCompile(`(?s)^call:([\w\-.]+)\{(.*)\}$`)
type g4State int
const (
g4Content g4State = iota
g4ChanHeader
g4Thought
g4ToolCall
g4Done
)
// Markers watched per emitting state. A stray <tool_call|> outside a tool
// call is deliberately NOT watched: it passes through verbatim, consistent
// with the malformed-payload fallback re-emitting it as content.
var (
gemma4ContentMarkers = []string{gemma4ChannelOpen, gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd}
gemma4ThoughtMarkers = []string{gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd}
)
type Gemma4Parser struct {
state g4State
// held is the per-state carry-over between Feed calls: a partial marker
// (content/thought), a partial channel header (chan-header) or the
// payload accumulated so far (tool-call).
held string
toolIdx int
}
// NewGemma4Parser returns a parser positioned per the initial-state rules in
// the header comment: startInThought=true only when the caller pre-opened a
// thought channel in the prompt.
func NewGemma4Parser(startInThought bool) *Gemma4Parser {
state := g4Content
if startInThought {
state = g4Thought
}
return &Gemma4Parser{state: state}
}
// Feed consumes the next output fragment and returns the deltas it completes.
func (p *Gemma4Parser) Feed(text string) []*pb.ChatDelta {
if text == "" || p.state == g4Done {
return nil
}
pending := p.held + text
p.held = ""
var em g4Emitter
for pending != "" {
switch p.state {
case g4Content, g4Thought:
markers := gemma4ContentMarkers
if p.state == g4Thought {
markers = gemma4ThoughtMarkers
}
idx, marker := findEarliestGemma4Marker(pending, markers)
if idx == -1 {
hold := gemma4MarkerHoldback(pending, markers)
p.emitText(&em, pending[:len(pending)-hold])
p.held = pending[len(pending)-hold:]
pending = ""
continue
}
p.emitText(&em, pending[:idx])
pending = pending[idx+len(marker):]
switch marker {
case gemma4ChannelOpen:
p.state = g4ChanHeader
case gemma4ChannelClose:
// In thought: channel ends. In content: stray close,
// swallowed (strip_thinking keeps both sides, tpl L148-L158).
p.state = g4Content
case gemma4ToolCallOpen:
p.state = g4ToolCall
case gemma4TurnEnd:
p.state = g4Done
}
case g4ChanHeader:
// The channel header is "<name>\n"; the template only ever writes
// "thought" (tpl L240/L360) and the label is structural, so it is
// dropped, not emitted (vLLM strips the same "thought\n" prefix).
nl := strings.IndexByte(pending, '\n')
if nl == -1 {
p.held = pending
pending = ""
continue
}
pending = pending[nl+1:]
p.state = g4Thought
case g4ToolCall:
end := strings.Index(pending, gemma4ToolCallClose)
if end == -1 {
p.held = pending
pending = ""
continue
}
p.emitToolCall(&em, pending[:end])
pending = pending[end+len(gemma4ToolCallClose):]
p.state = g4Content
case g4Done:
pending = ""
}
}
return em.deltas
}
// Close flushes held-back partials. Incomplete structures (open channel
// header, unterminated tool payload) are re-emitted raw as content rather
// than dropped. The parser is finished afterwards.
func (p *Gemma4Parser) Close() []*pb.ChatDelta {
var em g4Emitter
switch p.state {
case g4Content:
em.content(p.held)
case g4Thought:
em.reasoning(p.held)
case g4ChanHeader:
em.content(gemma4ChannelOpen + p.held)
case g4ToolCall:
em.content(gemma4ToolCallOpen + p.held)
case g4Done:
}
p.held = ""
p.state = g4Done
return em.deltas
}
func (p *Gemma4Parser) emitText(em *g4Emitter, s string) {
if p.state == g4Thought {
em.reasoning(s)
return
}
em.content(s)
}
// emitToolCall decodes one complete <|tool_call>...<tool_call|> payload. On a
// payload that does not match call:name{...} the raw text (markers included)
// is emitted as content, mirroring vLLM's extract_tool_calls fallback.
func (p *Gemma4Parser) emitToolCall(em *g4Emitter, payload string) {
m := gemma4CallRE.FindStringSubmatch(payload)
if m == nil {
em.content(gemma4ToolCallOpen + payload + gemma4ToolCallClose)
return
}
// Index-based ids: deterministic (the split-invariance property relies
// on it) and matching the call_<n> convention of pkg/grpc/rich_test.go;
// core only needs ids to be non-empty and unique within the response.
em.tool(p.toolIdx, "call_"+strconv.Itoa(p.toolIdx), m[1], decodeGemma4Args(m[2], 0))
p.toolIdx++
}
// g4Emitter collects ChatDeltas; empty text events are dropped.
type g4Emitter struct {
deltas []*pb.ChatDelta
}
func (e *g4Emitter) content(s string) {
if s != "" {
e.deltas = append(e.deltas, &pb.ChatDelta{Content: s})
}
}
func (e *g4Emitter) reasoning(s string) {
if s != "" {
e.deltas = append(e.deltas, &pb.ChatDelta{ReasoningContent: s})
}
}
func (e *g4Emitter) tool(index int, id, name, argsJSON string) {
e.deltas = append(e.deltas, &pb.ChatDelta{ToolCalls: []*pb.ToolCallDelta{{
Index: int32(index),
Id: id,
Name: name,
Arguments: argsJSON,
}}})
}
// findEarliestGemma4Marker returns the position and value of the first
// complete marker occurrence, or (-1, "").
func findEarliestGemma4Marker(s string, markers []string) (int, string) {
best, bestMarker := -1, ""
for _, m := range markers {
if idx := strings.Index(s, m); idx >= 0 && (best == -1 || idx < best) {
best, bestMarker = idx, m
}
}
return best, bestMarker
}
// gemma4MarkerHoldback returns the length of the longest suffix of s that is
// a proper prefix of a watched marker - the only bytes that may still grow
// into a marker and therefore must not be emitted yet (bounded by the
// longest marker, so content is never buffered unboundedly).
func gemma4MarkerHoldback(s string, markers []string) int {
maxHold := 0
for _, m := range markers {
if len(m)-1 > maxHold {
maxHold = len(m) - 1
}
}
if len(s) < maxHold {
maxHold = len(s)
}
for k := maxHold; k >= 1; k-- {
tail := s[len(s)-k:]
for _, m := range markers {
if strings.HasPrefix(m, tail) {
return k
}
}
}
return 0
}
// ---------------------------------------------------------------------------
// call:name{...} argument decoder
//
// Port of vLLM's _parse_gemma4_args / _parse_gemma4_array /
// _parse_gemma4_value (gemma4_tool_parser.py) in non-partial mode only: this
// parser decodes exclusively COMPLETE payloads (incomplete ones fall back to
// raw content at Close), so vLLM's partial-withholding machinery
// (trailing-dot floats, withheld bare tails) is intentionally not ported.
//
// Grammar (inverse of the renderer's formatGemma4Argument, tpl L118-L147):
//
// args := pair (',' pair)*
// pair := key ':' value (keys unquoted, up to the first ':')
// value := string | object | array | bare
// string := '<|"|>' ... '<|"|>' (no escapes; unterminated -> rest)
// object := '{' args '}' (delimited strings skipped when
// array := '[' value,* ']' counting braces/brackets)
// bare := true | false | null/none/nil | number | bare-string
//
// Output is a JSON object/array string with keys in payload order (Python
// dict insertion order), built with HTML escaping off so payload text
// survives byte-for-byte.
// ---------------------------------------------------------------------------
func isGemma4Space(c byte) bool { return c == ' ' || c == '\n' || c == '\t' }
// gemma4MaxArgsDepth caps the mutual recursion between decodeGemma4Args and
// decodeGemma4Array. Defense against model-generated deep nesting: a Go stack
// overflow is a fatal process kill, not a recoverable error, so past the cap
// a nested body gracefully degrades to a JSON string of its raw text.
const gemma4MaxArgsDepth = 100
// decodeGemma4Args decodes one args body (the text between the outer braces
// of call:name{...}) into a JSON object string. depth is the current nesting
// level (0 at the payload root); see gemma4MaxArgsDepth.
func decodeGemma4Args(s string, depth int) string {
if depth > gemma4MaxArgsDepth {
return gemma4JSONString(s)
}
var b strings.Builder
b.WriteString("{")
first := true
pair := func(key, val string) {
if !first {
b.WriteString(",")
}
first = false
b.WriteString(gemma4JSONString(key))
b.WriteString(":")
b.WriteString(val)
}
i, n := 0, len(s)
for i < n {
for i < n && (isGemma4Space(s[i]) || s[i] == ',') {
i++
}
if i >= n {
break
}
keyStart := i
for i < n && s[i] != ':' {
i++
}
if i >= n {
break // no ':' -> trailing junk, dropped (vLLM does the same)
}
key := strings.TrimSpace(s[keyStart:i])
i++ // skip ':'
for i < n && isGemma4Space(s[i]) {
i++
}
if i >= n {
pair(key, `""`) // "key:" with nothing after -> empty string
break
}
switch {
case strings.HasPrefix(s[i:], gemma4StringDelim):
i += len(gemma4StringDelim)
if end := strings.Index(s[i:], gemma4StringDelim); end == -1 {
pair(key, gemma4JSONString(s[i:])) // unterminated -> take rest
i = n
} else {
pair(key, gemma4JSONString(s[i:i+end]))
i += end + len(gemma4StringDelim)
}
case s[i] == '{':
inner, next := scanGemma4Balanced(s, i, '{', '}')
pair(key, decodeGemma4Args(inner, depth+1))
i = next
case s[i] == '[':
inner, next := scanGemma4Balanced(s, i, '[', ']')
pair(key, decodeGemma4Array(inner, depth+1))
i = next
default:
valStart := i
for i < n && s[i] != ',' && s[i] != '}' && s[i] != ']' {
i++
}
if i == valStart {
// No progress (value starts on a stray '}'/']'): abort on
// malformed input rather than loop, like vLLM.
i = n
continue
}
pair(key, decodeGemma4Bare(s[valStart:i]))
}
}
b.WriteString("}")
return b.String()
}
// decodeGemma4Array decodes one array body (the text between '[' and ']')
// into a JSON array string. depth is the current nesting level; see
// gemma4MaxArgsDepth.
func decodeGemma4Array(s string, depth int) string {
if depth > gemma4MaxArgsDepth {
return gemma4JSONString(s)
}
var b strings.Builder
b.WriteString("[")
first := true
item := func(val string) {
if !first {
b.WriteString(",")
}
first = false
b.WriteString(val)
}
i, n := 0, len(s)
for i < n {
for i < n && (isGemma4Space(s[i]) || s[i] == ',') {
i++
}
if i >= n {
break
}
switch {
case strings.HasPrefix(s[i:], gemma4StringDelim):
i += len(gemma4StringDelim)
if end := strings.Index(s[i:], gemma4StringDelim); end == -1 {
item(gemma4JSONString(s[i:]))
i = n
} else {
item(gemma4JSONString(s[i : i+end]))
i += end + len(gemma4StringDelim)
}
case s[i] == '{':
inner, next := scanGemma4Balanced(s, i, '{', '}')
item(decodeGemma4Args(inner, depth+1))
i = next
case s[i] == '[':
inner, next := scanGemma4Balanced(s, i, '[', ']')
item(decodeGemma4Array(inner, depth+1))
i = next
default:
valStart := i
for i < n && s[i] != ',' && s[i] != ']' {
i++
}
if i == valStart {
i = n // no progress: abort on malformed input, like vLLM
continue
}
item(decodeGemma4Bare(s[valStart:i]))
}
}
b.WriteString("]")
return b.String()
}
// scanGemma4Balanced scans a brace/bracket-balanced span starting at the
// opener s[start], skipping over <|"|>-delimited strings so structural
// characters inside them do not count (vLLM's depth scan). Returns the inner
// text and the index just past the closer; an unterminated span yields the
// rest of the string (the inner decoder still extracts what is there - this
// path is only reachable from genuinely malformed complete payloads).
func scanGemma4Balanced(s string, start int, open, close byte) (string, int) {
depth := 1
i := start + 1
innerStart := i
n := len(s)
for i < n && depth > 0 {
if strings.HasPrefix(s[i:], gemma4StringDelim) {
i += len(gemma4StringDelim)
if nd := strings.Index(s[i:], gemma4StringDelim); nd == -1 {
i = n
} else {
i += nd + len(gemma4StringDelim)
}
continue
}
switch s[i] {
case open:
depth++
case close:
depth--
}
i++
}
if depth > 0 {
return s[innerStart:], n
}
return s[innerStart : i-1], i
}
// decodeGemma4Bare maps an undelimited value to its JSON form: booleans,
// null aliases (null/none/nil, case-insensitive - the renderer writes
// Python None as "None", tpl L144-L145 via format_argument's else branch),
// numbers (vLLM's rule: a '.' tries float, otherwise int; anything that
// fails parses as a bare string).
func decodeGemma4Bare(raw string) string {
v := strings.TrimSpace(raw)
if v == "" {
return `""`
}
if v == "true" || v == "false" {
return v
}
switch strings.ToLower(v) {
case "null", "none", "nil":
return "null"
}
if strings.Contains(v, ".") {
if f, err := strconv.ParseFloat(v, 64); err == nil {
return formatGemma4Float(f)
}
} else if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
return strconv.FormatInt(iv, 10)
}
return gemma4JSONString(v)
}
// formatGemma4Float renders like Python's json.dumps(float): integral floats
// keep a ".0" suffix ("108." decodes to 108.0, not 108), so the arguments
// JSON matches what vLLM would have produced for the same payload.
func formatGemma4Float(f float64) string {
s := strconv.FormatFloat(f, 'g', -1, 64)
if !strings.ContainsAny(s, ".eE") {
s += ".0"
}
return s
}
// gemma4JSONString encodes a JSON string WITHOUT HTML escaping (json.Marshal
// would escape the angle brackets in "<div>" to \u003c / \u003e sequences;
// payload text should survive
// byte-for-byte, like Python's json.dumps(ensure_ascii=False)).
func gemma4JSONString(s string) string {
var sb strings.Builder
enc := json.NewEncoder(&sb)
enc.SetEscapeHTML(false)
if err := enc.Encode(s); err != nil {
// Unreachable for plain strings; fall back to default escaping
// rather than emitting invalid JSON.
b, mErr := json.Marshal(s)
if mErr != nil {
return `""`
}
return string(b)
}
// Encode appends a trailing newline.
return strings.TrimSuffix(sb.String(), "\n")
}

View File

@@ -1,592 +0,0 @@
package main
// Parser specs for Gemma4Parser (model output text -> pb.ChatDelta events).
//
// Fixture provenance:
// - Entries marked "vLLM: <name>" are direct ports of the named test from
// vLLM PR #45163, tests/tool_parsers/test_gemma4_tool_parser.py (the
// authoritative test-suite for the gemma4 tool-call wire format). The
// streaming tests' chunk lists are reused verbatim as Feed fragments.
// - Decoder entries port the TestParseGemma4Args / TestParseGemma4Array
// classes from the same file (non-partial mode only; this parser never
// decodes partial payloads, see the divergence note in gemma4_parser.go).
// - Channel/turn-marker expectations come from the chat template embedded
// in gemma4_renderer.go (tpl L356-L362 generation prompt, L148-L158
// strip_thinking) and vLLM's Gemma4ReasoningParser
// (vllm/reasoning/gemma4_reasoning_parser.py).
import (
"encoding/json"
"fmt"
"strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
// flatGemma4Tool is one accumulated tool call, mirroring how LocalAI core
// folds ToolCallDelta streams (pkg/functions/chat_deltas.go
// ToolCallsFromChatDeltas: name/id latch on first non-empty, arguments
// concatenate per index). Tests flatten through the same rules so they
// assert exactly what core will reconstruct.
type flatGemma4Tool struct {
id string
name string
args string
}
func flattenGemma4Deltas(deltas []*pb.ChatDelta) (string, string, []flatGemma4Tool) {
var content, reasoning strings.Builder
byIndex := map[int32]*flatGemma4Tool{}
maxIdx := int32(-1)
for _, d := range deltas {
content.WriteString(d.GetContent())
reasoning.WriteString(d.GetReasoningContent())
for _, tc := range d.GetToolCalls() {
acc, ok := byIndex[tc.GetIndex()]
if !ok {
acc = &flatGemma4Tool{}
byIndex[tc.GetIndex()] = acc
}
if tc.GetName() != "" {
acc.name = tc.GetName()
}
if tc.GetId() != "" {
acc.id = tc.GetId()
}
acc.args += tc.GetArguments()
if tc.GetIndex() > maxIdx {
maxIdx = tc.GetIndex()
}
}
}
var tools []flatGemma4Tool
for i := int32(0); i <= maxIdx; i++ {
if acc, ok := byIndex[i]; ok {
tools = append(tools, *acc)
}
}
return content.String(), reasoning.String(), tools
}
type wantGemma4Tool struct {
name string
argsJSON string // compared with MatchJSON (key order irrelevant)
}
type parseGemma4Case struct {
startInThought bool
fragments []string
wantContent string
wantReasoning string
wantTools []wantGemma4Tool
}
func parseGemma4Fragments(startInThought bool, fragments []string) []*pb.ChatDelta {
p := NewGemma4Parser(startInThought)
var all []*pb.ChatDelta
for _, f := range fragments {
all = append(all, p.Feed(f)...)
}
return append(all, p.Close()...)
}
var _ = Describe("Gemma4Parser", func() {
DescribeTable("parses streamed gemma4 output into ChatDeltas",
func(c parseGemma4Case) {
content, reasoning, tools := flattenGemma4Deltas(parseGemma4Fragments(c.startInThought, c.fragments))
Expect(content).To(Equal(c.wantContent))
Expect(reasoning).To(Equal(c.wantReasoning))
Expect(tools).To(HaveLen(len(c.wantTools)))
seenIDs := map[string]bool{}
for i, want := range c.wantTools {
Expect(tools[i].name).To(Equal(want.name), "tool %d name", i)
Expect(tools[i].args).To(MatchJSON(want.argsJSON), "tool %d arguments", i)
Expect(tools[i].id).ToNot(BeEmpty(), "tool %d id", i)
Expect(seenIDs).ToNot(HaveKey(tools[i].id), "tool %d id must be unique", i)
seenIDs[tools[i].id] = true
}
},
// --- (1) pure content -------------------------------------------------
// vLLM: test_no_tool_calls
Entry("pure content, single fragment", parseGemma4Case{
fragments: []string{"Hello, how can I help you today?"},
wantContent: "Hello, how can I help you today?",
}),
// --- (2) thought -> final transition ----------------------------------
// enable_thinking render: prompt ends at <|turn>model\n and the model
// opens/closes its own thought channel in the OUTPUT (vLLM
// Gemma4ReasoningParser docstring; tpl L356-L362). The "thought\n"
// role label after <|channel> is structural and must be stripped
// (vLLM _THOUGHT_PREFIX handling).
Entry("thought channel then final content", parseGemma4Case{
fragments: []string{"<|channel>thought\nLet me think about this.\n<channel|>The answer is 42."},
wantReasoning: "Let me think about this.\n",
wantContent: "The answer is 42.",
}),
// --- (3) startInThought both ways -------------------------------------
Entry("startInThought=true routes initial text to reasoning until <channel|>", parseGemma4Case{
startInThought: true,
fragments: []string{"I am thinking hard.<channel|>Done."},
wantReasoning: "I am thinking hard.",
wantContent: "Done.",
}),
// A stray <channel|> with no open channel is swallowed, matching the
// template's strip_thinking (tpl L148-L158: the marker is dropped,
// text on both sides is kept).
Entry("startInThought=false keeps the same text as content, stray <channel|> swallowed", parseGemma4Case{
startInThought: false,
fragments: []string{"I am thinking hard.<channel|>Done."},
wantContent: "I am thinking hard.Done.",
}),
// --- (4) one tool call, full payload type zoo --------------------------
Entry("single tool call: strings, numbers, bools, null, nested object and array", parseGemma4Case{
fragments: []string{`<|tool_call>call:complex_function{text:<|"|>with, comma and {braces}<|"|>,count:42,score:3.14,yes:true,no:false,nothing:null,obj:{inner:<|"|>v<|"|>,k:1},arr:[<|"|>a<|"|>,2,true]}<tool_call|>`},
wantTools: []wantGemma4Tool{{
name: "complex_function",
argsJSON: `{"text":"with, comma and {braces}","count":42,"score":3.14,"yes":true,"no":false,"nothing":null,"obj":{"inner":"v","k":1},"arr":["a",2,true]}`,
}},
}),
// --- (5) payload split across 3 fragments ------------------------------
Entry("tool-call payload split across three fragments", parseGemma4Case{
fragments: []string{
"<|tool_call>call:get_weather{loc",
`ation:<|"|>Paris, Fra`,
`nce<|"|>}<tool_call|>`,
},
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}},
}),
// --- (6) marker split across fragments ----------------------------------
Entry("tool-call open marker split across fragments", parseGemma4Case{
fragments: []string{
"<|tool_ca",
`ll>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`,
},
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
}),
Entry("channel open marker split across fragments", parseGemma4Case{
fragments: []string{
"<|chan",
"nel>thought\ndeep thought<channel|>final",
},
wantReasoning: "deep thought",
wantContent: "final",
}),
// --- (7) trailing partial marker held, flushed by Close -----------------
Entry("trailing partial marker is held back and flushed by Close", parseGemma4Case{
fragments: []string{"Hello <|tool"},
wantContent: "Hello <|tool",
}),
// --- (8) malformed/incomplete payload -> content fallback ---------------
// vLLM: test_incomplete_tool_call (no end marker: the whole text stays
// content, never silently dropped).
Entry("incomplete tool payload at Close is emitted as raw content", parseGemma4Case{
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London`},
wantContent: `<|tool_call>call:get_weather{location:<|"|>London`,
}),
Entry("malformed complete payload is emitted as raw content, parsing continues", parseGemma4Case{
fragments: []string{"<|tool_call>oops no call syntax<tool_call|> done"},
wantContent: "<|tool_call>oops no call syntax<tool_call|> done",
}),
// --- (9) <turn|> ends the turn -------------------------------------------
Entry("text after <turn|> is ignored, including later fragments", parseGemma4Case{
fragments: []string{
"before<turn|>after",
`more <|tool_call>call:f{}<tool_call|>`,
},
wantContent: "before",
}),
Entry("<turn|> inside a thought channel ends the turn", parseGemma4Case{
startInThought: true,
fragments: []string{"thinking<turn|>ignored"},
wantReasoning: "thinking",
}),
// --- (10) ported vLLM non-streaming cases ---------------------------------
// vLLM: test_single_tool_call
Entry("vLLM: test_single_tool_call", parseGemma4Case{
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
}),
// vLLM: test_multiple_arguments
Entry("vLLM: test_multiple_arguments", parseGemma4Case{
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"San Francisco","unit":"celsius"}`}},
}),
// vLLM: test_text_before_tool_call. DIVERGENCE: vLLM's non-streaming
// extractor trims the content ("...you."); a streaming parser cannot
// retroactively trim already-emitted text, so the trailing space is
// kept (vLLM's own streaming path keeps it too, see
// test_streaming_text_before_tool_call which only checks a prefix).
Entry("vLLM: test_text_before_tool_call (streaming semantics: no trim)", parseGemma4Case{
fragments: []string{`Let me check the weather for you. <|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|>`},
wantContent: "Let me check the weather for you. ",
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}},
}),
// vLLM: test_multiple_tool_calls (also covers case 11: multi-tool sequence)
Entry("vLLM: test_multiple_tool_calls", parseGemma4Case{
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|><|tool_call>call:get_time{location:<|"|>London<|"|>}<tool_call|>`},
wantTools: []wantGemma4Tool{
{name: "get_weather", argsJSON: `{"location":"London"}`},
{name: "get_time", argsJSON: `{"location":"London"}`},
},
}),
// vLLM: test_nested_arguments
Entry("vLLM: test_nested_arguments", parseGemma4Case{
fragments: []string{`<|tool_call>call:complex_function{nested:{inner:<|"|>value<|"|>},list:[<|"|>a<|"|>,<|"|>b<|"|>]}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "complex_function", argsJSON: `{"nested":{"inner":"value"},"list":["a","b"]}`}},
}),
// vLLM: test_tool_call_with_number_and_boolean
Entry("vLLM: test_tool_call_with_number_and_boolean", parseGemma4Case{
fragments: []string{`<|tool_call>call:set_status{is_active:true,count:42,score:3.14}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "set_status", argsJSON: `{"is_active":true,"count":42,"score":3.14}`}},
}),
// vLLM: test_hyphenated_function_name
Entry("vLLM: test_hyphenated_function_name", parseGemma4Case{
fragments: []string{`<|tool_call>call:get-weather{location:<|"|>London<|"|>}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "get-weather", argsJSON: `{"location":"London"}`}},
}),
// vLLM: test_dotted_function_name
Entry("vLLM: test_dotted_function_name", parseGemma4Case{
fragments: []string{`<|tool_call>call:weather.get{location:<|"|>London<|"|>}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "weather.get", argsJSON: `{"location":"London"}`}},
}),
// vLLM: test_no_arguments
Entry("vLLM: test_no_arguments", parseGemma4Case{
fragments: []string{"<|tool_call>call:get_status{}<tool_call|>"},
wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}},
}),
// --- ported vLLM streaming cases (chunk lists reused as fragments) --------
// vLLM: test_basic_streaming_single_tool
Entry("vLLM: test_basic_streaming_single_tool", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:get_weather{",
`location:<|"|>Paris`,
", France",
`<|"|>}`,
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}},
}),
// vLLM: test_streaming_multi_arg
Entry("vLLM: test_streaming_multi_arg", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:get_weather{",
`location:<|"|>Tokyo<|"|>,`,
`unit:<|"|>celsius<|"|>}`,
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Tokyo","unit":"celsius"}`}},
}),
// vLLM: test_streaming_text_before_tool_call
Entry("vLLM: test_streaming_text_before_tool_call", parseGemma4Case{
fragments: []string{
"Let me check ",
"the weather. ",
"<|tool_call>",
"call:get_weather{",
`location:<|"|>London<|"|>}`,
"<tool_call|>",
},
wantContent: "Let me check the weather. ",
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
}),
// vLLM: test_streaming_numeric_args
Entry("vLLM: test_streaming_numeric_args", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:set_config{",
"count:42,",
"active:true}",
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "set_config", argsJSON: `{"count":42,"active":true}`}},
}),
// vLLM: test_streaming_boolean_split_across_chunks
Entry("vLLM: test_streaming_boolean_split_across_chunks", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:search{input:{all:tru",
"e}}",
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "search", argsJSON: `{"input":{"all":true}}`}},
}),
// vLLM: test_streaming_false_split_across_chunks
Entry("vLLM: test_streaming_false_split_across_chunks", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:set{flag:fals",
"e}",
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"flag":false}`}},
}),
// vLLM: test_streaming_number_split_across_chunks
Entry("vLLM: test_streaming_number_split_across_chunks", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:set{count:4",
"2}",
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"count":42}`}},
}),
// vLLM: test_streaming_empty_args
Entry("vLLM: test_streaming_empty_args", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:get_status{}",
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}},
}),
// vLLM: test_streaming_split_delimiter_no_invalid_json (string
// delimiter <|"|> split across fragments must not leak fragments).
Entry("vLLM: test_streaming_split_delimiter_no_invalid_json", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:todowrite{",
`content:<|"|>Buy milk<|`,
`"|>}`,
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{name: "todowrite", argsJSON: `{"content":"Buy milk"}`}},
}),
// vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call
Entry("vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:get_weather{",
`location:<|"|>Paris<|"|>}`,
"<tool_call|><",
"div>",
},
wantContent: "<div>",
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}},
}),
// vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes
Entry("vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:write_file{",
`path:<|"|>index.html<|"|>,`,
`content:<|"|><!DOCTYPE html>` + "\n<",
`html lang="zh-CN">` + "\n<",
"head>\n <",
`meta charset="UTF-8">` + "\n <",
`meta name="viewport" content="width=device-width">` + "\n",
`<|"|>}`,
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{
name: "write_file",
argsJSON: `{"path":"index.html","content":"<!DOCTYPE html>\n<html lang=\"zh-CN\">\n<head>\n <meta charset=\"UTF-8\">\n <meta name=\"viewport\" content=\"width=device-width\">\n"}`,
}},
}),
// vLLM: test_streaming_single_chunk_complete_tool_call
Entry("vLLM: test_streaming_single_chunk_complete_tool_call", parseGemma4Case{
fragments: []string{`<|tool_call>call:name_a_color{color_hex:<|"|>00ff11<|"|>}<tool_call|>`},
wantTools: []wantGemma4Tool{{name: "name_a_color", argsJSON: `{"color_hex":"00ff11"}`}},
}),
// vLLM: test_streaming_multi_chunk_batched_tool_calls (two complete
// calls in ONE fragment; both must come out with distinct indices)
Entry("vLLM: test_streaming_multi_chunk_batched_tool_calls", parseGemma4Case{
fragments: []string{
`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>` +
`<|tool_call>call:get_time{timezone:<|"|>GMT<|"|>}<tool_call|>`,
},
wantTools: []wantGemma4Tool{
{name: "get_weather", argsJSON: `{"location":"London"}`},
{name: "get_time", argsJSON: `{"timezone":"GMT"}`},
},
}),
// vLLM: test_streaming_trailing_bare_bool_not_duplicated
Entry("vLLM: test_streaming_trailing_bare_bool_not_duplicated", parseGemma4Case{
fragments: []string{
"<|tool_call>",
"call:Edit{",
`file_path:<|"|>src/env.py<|"|>,`,
`old_string:<|"|>old_val<|"|>,`,
`new_string:<|"|>new_val<|"|>,`,
"replace_all:",
"false}",
"<tool_call|>",
},
wantTools: []wantGemma4Tool{{
name: "Edit",
argsJSON: `{"file_path":"src/env.py","old_string":"old_val","new_string":"new_val","replace_all":false}`,
}},
}),
// --- implicit reasoning end on <|tool_call> (vLLM is_reasoning_end:
// a tool_call token means reasoning is over) -----------------------------
Entry("tool call inside an open thought channel ends the reasoning", parseGemma4Case{
startInThought: true,
fragments: []string{`need the weather<|tool_call>call:get_weather{location:<|"|>Rome<|"|>}<tool_call|>`},
wantReasoning: "need the weather",
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Rome"}`}},
}),
// --- (12) empty fragments are no-ops --------------------------------------
Entry("empty fragments are no-ops", parseGemma4Case{
fragments: []string{"", "Hello", "", "", " world", ""},
wantContent: "Hello world",
}),
)
It("returns no deltas for an empty fragment and after Close", func() {
p := NewGemma4Parser(false)
Expect(p.Feed("")).To(BeEmpty())
Expect(p.Feed("hi")).ToNot(BeEmpty())
Expect(p.Close()).To(BeEmpty()) // nothing held back
// The parser is finished after Close: further input is dropped.
Expect(p.Feed("more")).To(BeEmpty())
Expect(p.Close()).To(BeEmpty())
})
It("generates index-based tool call ids (call_<index>)", func() {
// Mirrors the index-based id convention of pkg/grpc/rich_test.go and
// keeps ids deterministic for the split-invariance property below.
deltas := parseGemma4Fragments(false, []string{
`<|tool_call>call:a{}<tool_call|><|tool_call>call:b{}<tool_call|>`,
})
_, _, tools := flattenGemma4Deltas(deltas)
Expect(tools).To(HaveLen(2))
Expect(tools[0].id).To(Equal("call_0"))
Expect(tools[1].id).To(Equal("call_1"))
})
// Property: for a fixed full output, EVERY 2-split position must yield
// exactly the same flattened result as the unsplit parse. This kills
// fragment-boundary bugs (mid-marker, mid-delimiter, mid-payload splits).
DescribeTable("2-split fragment invariance",
func(startInThought bool, full string) {
refContent, refReasoning, refTools := flattenGemma4Deltas(
parseGemma4Fragments(startInThought, []string{full}))
for i := 0; i <= len(full); i++ {
content, reasoning, tools := flattenGemma4Deltas(
parseGemma4Fragments(startInThought, []string{full[:i], full[i:]}))
Expect(content).To(Equal(refContent), fmt.Sprintf("content diverged at split %d", i))
Expect(reasoning).To(Equal(refReasoning), fmt.Sprintf("reasoning diverged at split %d", i))
Expect(tools).To(Equal(refTools), fmt.Sprintf("tool calls diverged at split %d", i))
}
},
Entry("thought + content + two tool calls + turn end", false,
"<|channel>thought\nPondering the request...\n<channel|>Sure - calling tools now. "+
`<|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>,days:3,detailed:true}<tool_call|>`+
`<|tool_call>call:get_time{timezone:<|"|>Europe/Lisbon<|"|>,nested:{flag:false,vals:[1,2.5,<|"|>x<|"|>]}}<tool_call|>`+
"Done.<turn|>ignored tail"),
Entry("startInThought + tool call + trailing partial marker", true,
`Deep thought<channel|>final answer <|tool_call>call:noop{}<tool_call|> trailing <|tool`),
Entry("malformed payload fallback", false,
`pre <|tool_call>not a call<tool_call|> post`),
)
})
// Decoder-level ports of vLLM's TestParseGemma4Args / TestParseGemma4Array
// (non-partial mode; the partial-withholding tests do not apply because this
// parser only ever decodes COMPLETE payloads, see gemma4_parser.go).
var _ = Describe("decodeGemma4Args", func() {
DescribeTable("decodes the gemma4 call syntax into JSON arguments",
func(in, wantJSON string) {
Expect(decodeGemma4Args(in, 0)).To(MatchJSON(wantJSON))
},
// vLLM: test_empty_string / test_whitespace_only
Entry("empty string", "", `{}`),
Entry("whitespace only", " ", `{}`),
// vLLM: test_single_string_value
Entry("single string value", `location:<|"|>Paris<|"|>`, `{"location":"Paris"}`),
// vLLM: test_string_value_with_comma
Entry("string value with comma", `location:<|"|>Paris, France<|"|>`, `{"location":"Paris, France"}`),
// vLLM: test_multiple_string_values
Entry("multiple string values", `location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>`, `{"location":"San Francisco","unit":"celsius"}`),
// vLLM: test_integer_value / test_float_value
Entry("integer value", "count:42", `{"count":42}`),
Entry("float value", "score:3.14", `{"score":3.14}`),
// vLLM: test_boolean_true / test_boolean_false
Entry("boolean true", "flag:true", `{"flag":true}`),
Entry("boolean false", "flag:false", `{"flag":false}`),
// vLLM: test_null_value (bare null must become JSON null, not "null")
Entry("null value", "param:null", `{"param":null}`),
// vLLM: test_mixed_types
Entry("mixed types", `name:<|"|>test<|"|>,count:42,active:true,score:3.14`,
`{"name":"test","count":42,"active":true,"score":3.14}`),
// vLLM: test_nested_object
Entry("nested object", `nested:{inner:<|"|>value<|"|>}`, `{"nested":{"inner":"value"}}`),
// vLLM: test_array_of_strings
Entry("array of strings", `items:[<|"|>a<|"|>,<|"|>b<|"|>]`, `{"items":["a","b"]}`),
// vLLM: test_unterminated_string (take everything after the delimiter)
Entry("unterminated string", `key:<|"|>unterminated`, `{"key":"unterminated"}`),
// vLLM: test_empty_value (key with no value after colon)
Entry("empty value", "key:", `{"key":""}`),
// vLLM: test_trailing_dot_float_partial_withheld, non-partial branch
// (trailing-dot floats parse normally outside streaming).
Entry("trailing dot float, complete payload", "left:108.,right:22.8", `{"left":108.0,"right":22.8}`),
)
It("terminates and yields valid JSON on malformed input", func() {
// vLLM: test_malformed_partial_array (the assertion there is only
// "returns a dict without hanging"; ours is "valid JSON object").
out := decodeGemma4Args(":[t:[]", 0)
var v map[string]any
Expect(json.Unmarshal([]byte(out), &v)).To(Succeed())
})
It("degrades nesting beyond the recursion cap to a string value", func() {
// 200 levels of a:{a:{...a:1...}}. Without the depth cap the mutual
// recursion would grow the stack with the model's output; a Go stack
// overflow is a fatal process kill, so levels past gemma4MaxArgsDepth
// must gracefully fall back to the raw inner text as a JSON string.
const depth = 200
body := strings.Repeat("a:{", depth-1) + "a:1" + strings.Repeat("}", depth-1)
out := decodeGemma4Args(body, 0)
var v map[string]any
Expect(json.Unmarshal([]byte(out), &v)).To(Succeed())
levels := 0
var cur any = v
for {
m, ok := cur.(map[string]any)
if !ok {
break
}
Expect(m).To(HaveKey("a"))
cur = m["a"]
levels++
}
Expect(levels).To(Equal(gemma4MaxArgsDepth + 1))
Expect(cur).To(BeAssignableToTypeOf(""))
Expect(cur).To(ContainSubstring("a:{"))
})
})
var _ = Describe("decodeGemma4Array", func() {
DescribeTable("decodes gemma4 array bodies into JSON arrays",
func(in, wantJSON string) {
Expect(decodeGemma4Array(in, 0)).To(MatchJSON(wantJSON))
},
// vLLM: test_string_array / test_empty_array / test_bare_values
Entry("string array", `<|"|>a<|"|>,<|"|>b<|"|>`, `["a","b"]`),
Entry("empty array", "", `[]`),
Entry("bare values", "42,true,3.14", `[42,true,3.14]`),
// vLLM: test_string_element_with_closing_bracket (a ']' inside a
// delimited string must not close the array)
Entry("string element with closing bracket", `[<|"|>a]b<|"|>,<|"|>c<|"|>],<|"|>tail<|"|>`, `[["a]b","c"],"tail"]`),
// vLLM: test_stray_closing_bracket (no-progress abort, keep prefix)
Entry("stray closing bracket", "42,]trailing", `[42]`),
)
})

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,347 +0,0 @@
package main
// Renderer specs for RenderGemma4 against the canonical gemma4 chat template
// (see the normative template comment in gemma4_renderer.go).
//
// Fixture provenance:
// - "single user message" and "enable_thinking" are the EXACT expected
// decodes from transformers tests/models/diffusion_gemma/
// test_modeling_diffusion_gemma.py (test_diffusion_gemma_chat_template
// and ..._with_thinking) with ONE difference: the transformers fixtures
// start with "<bos>" because apply_chat_template tokenizes the rendered
// text with add_bos. Our prompt goes through dllm_capi_generate, whose
// run_generate already tokenizes with prepend_bos = vocab.add_bos
// (dllm.cpp src/capi.cpp:230-231, true for gemma4), so the renderer must
// NOT emit a literal <bos> (it would double) and every expected string
// here drops that leading token.
// - All other expected strings were produced by rendering the verbatim
// GGUF template with jinja2 3.1.2 (bos_token="<bos>") and dropping the
// leading "<bos>" for the same reason.
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
// Two-function tools array used by the tool fixtures (OpenAI wire shape, as
// LocalAI passes it through PredictOptions.Tools).
const testToolsJSON = `[{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a location.","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city name."},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}},{"type":"function","function":{"name":"get_time","description":"Get the current time in a timezone.","parameters":{"type":"object","properties":{"timezone":{"type":"string","description":"IANA timezone name."}},"required":["timezone"]}}}]`
// The <|tool>...<tool|> block the template renders for testToolsJSON inside
// the system turn (jinja2-verified).
const testToolsBlock = `<|tool>declaration:get_weather{description:<|"|>Get the current weather in a location.<|"|>,parameters:{properties:{location:{description:<|"|>The city name.<|"|>,type:<|"|>STRING<|"|>},unit:{enum:[<|"|>celsius<|"|>,<|"|>fahrenheit<|"|>],type:<|"|>STRING<|"|>}},required:[<|"|>location<|"|>],type:<|"|>OBJECT<|"|>}}<tool|><|tool>declaration:get_time{description:<|"|>Get the current time in a timezone.<|"|>,parameters:{properties:{timezone:{description:<|"|>IANA timezone name.<|"|>,type:<|"|>STRING<|"|>}},required:[<|"|>timezone<|"|>],type:<|"|>OBJECT<|"|>}}<tool|>`
// A single tool exercising the deep format_parameters branches: array items
// (string-typed and nested-array), nullable, enum+nullable, nested object
// properties/required, and a response declaration.
const complexToolsJSON = `[{"type":"function","function":{"name":"complex_tool","description":"A complex tool.","parameters":{"type":"object","properties":{"tags":{"type":"array","description":"Tags.","items":{"type":"string"}},"matrix":{"type":"array","items":{"type":"array","items":{"type":"number"}}},"opts":{"type":"object","description":"Options.","properties":{"depth":{"type":"integer","nullable":true}},"required":["depth"]},"mode":{"type":"string","enum":["a","b"],"nullable":true}},"required":["tags","opts"]},"response":{"description":"The result.","type":"object"}}}]`
// jinja2-verified render of complexToolsJSON. Notable template quirks pinned
// here: nested array items go through format_argument with ESCAPED keys and
// an un-uppercased type (<|"|>type<|"|>:<|"|>number<|"|>), while direct item
// types are uppercased; properties dictsort case-insensitively.
const complexToolsBlock = `<|tool>declaration:complex_tool{description:<|"|>A complex tool.<|"|>,parameters:{properties:{matrix:{items:{items:{<|"|>type<|"|>:<|"|>number<|"|>},type:<|"|>ARRAY<|"|>},type:<|"|>ARRAY<|"|>},mode:{enum:[<|"|>a<|"|>,<|"|>b<|"|>],nullable:true,type:<|"|>STRING<|"|>},opts:{description:<|"|>Options.<|"|>,properties:{depth:{nullable:true,type:<|"|>INTEGER<|"|>}},required:[<|"|>depth<|"|>],type:<|"|>OBJECT<|"|>},tags:{description:<|"|>Tags.<|"|>,items:{type:<|"|>STRING<|"|>},type:<|"|>ARRAY<|"|>}},required:[<|"|>tags<|"|>,<|"|>opts<|"|>],type:<|"|>OBJECT<|"|>},response:{description:<|"|>The result.<|"|>,type:<|"|>OBJECT<|"|>}}<tool|>`
type renderGemma4Case struct {
msgs []*pb.Message
toolsJSON string
enableThinking bool
noGenerationPrompt bool // inverted so the zero value is the common case
expected string
}
var _ = Describe("RenderGemma4", func() {
DescribeTable("renders the canonical gemma4 prompt",
func(c renderGemma4Case) {
out, err := RenderGemma4(c.msgs, c.toolsJSON, c.enableThinking, !c.noGenerationPrompt)
Expect(err).ToNot(HaveOccurred())
Expect(out).To(Equal(c.expected))
// The C-ABI generate prepends BOS itself: a literal <bos>
// anywhere in the rendered prompt would double-encode it.
Expect(out).ToNot(ContainSubstring("<bos>"))
},
// transformers fixture (test_diffusion_gemma_chat_template), sans <bos>:
// default thinking pre-opens an EMPTY thought channel in the
// generation prompt.
Entry("single user message, default (no thinking)", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "Write a long essay about Portugal."},
},
expected: "<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// transformers fixture (test_diffusion_gemma_chat_template_with_thinking),
// sans <bos>: a system turn carrying <|think|> and NO auto-opened
// thought channel.
Entry("enable_thinking=true", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "Write a long essay about Portugal."},
},
enableThinking: true,
expected: "<|turn>system\n<|think|>\n<turn|>\n<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n",
}),
Entry("multi-turn user/assistant/user", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "Hello, who are you?"},
{Role: "assistant", Content: "I am Gemma, a helpful assistant."},
{Role: "user", Content: "Tell me a joke."},
},
expected: "<|turn>user\nHello, who are you?<turn|>\n<|turn>model\nI am Gemma, a helpful assistant.<turn|>\n<|turn>user\nTell me a joke.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// tpl L178-L195: a leading system message is folded into the system
// turn (trimmed) and consumed from the loop.
Entry("system message folds into the system turn", renderGemma4Case{
msgs: []*pb.Message{
{Role: "system", Content: "You are a pirate."},
{Role: "user", Content: "Hello!"},
},
expected: "<|turn>system\nYou are a pirate.<turn|>\n<|turn>user\nHello!<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// tpl L182-L185: <|think|> goes at the very top of the SAME system
// turn, before the system prompt text.
Entry("system message with enable_thinking shares the turn", renderGemma4Case{
msgs: []*pb.Message{
{Role: "system", Content: "You are a pirate."},
{Role: "user", Content: "Hello!"},
},
enableThinking: true,
expected: "<|turn>system\n<|think|>\nYou are a pirate.<turn|>\n<|turn>user\nHello!<turn|>\n<|turn>model\n",
}),
// tpl L196-L203: tool declarations render in the system turn, one
// <|tool>declaration:...<tool|> block per tool, no separators.
Entry("tools array (two functions)", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "What is the weather in Tokyo?"},
},
toolsJSON: testToolsJSON,
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// format_parameters deep branches (tpl L1-L85) + response declaration
// (tpl L106-L116).
Entry("complex tool schema (array items, nullable, nested object, response)", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
},
toolsJSON: complexToolsJSON,
expected: "<|turn>system\n" + complexToolsBlock + "<turn|>\n<|turn>user\ngo<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// tpl L243-L313: assistant tool_calls render as
// <|tool_call>call:name{args}<tool_call|>; the following role=tool
// message renders inline as <|tool_response>response:name{value:..}
// <tool_response|>; the model turn stays OPEN (no <turn|>, no new
// generation prompt) so the model continues after the response.
Entry("assistant tool_calls + role=tool result", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "What is the weather in Tokyo?"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`},
{Role: "tool", ToolCallId: "call_1", Content: "Sunny, 22 degrees celsius."},
},
toolsJSON: testToolsJSON,
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny, 22 degrees celsius.<|"|>}<tool_response|>`,
}),
// tpl L348-L349: a tool_calls turn with no rendered responses ends
// on an OPEN <|tool_response> marker for the runtime to fill, and
// add_generation_prompt adds nothing (tpl L357).
Entry("assistant tool_calls without a result leaves <|tool_response> open", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "What is the weather in Tokyo?"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`},
},
toolsJSON: testToolsJSON,
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>`,
}),
// tpl L237-L241: reasoning_content renders as a thought channel only
// on a tool-calling turn after the last user message.
Entry("reasoning_content with tool_calls renders the thought channel", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "weather?"},
{Role: "assistant", Content: "", ReasoningContent: "I should call the tool", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`},
{Role: "tool", ToolCallId: "c1", Content: "Sunny"},
},
expected: "<|turn>user\nweather?<turn|>\n<|turn>model\n<|channel>thought\nI should call the tool\n<channel|>" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}<tool_response|>`,
}),
// tpl L220-L235: the assistant answer following its own tool round
// continues the SAME model turn (no second <|turn>model).
Entry("tool round then final assistant answer then user", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "weather?"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`},
{Role: "tool", ToolCallId: "c1", Content: "Sunny"},
{Role: "assistant", Content: "It is sunny."},
{Role: "user", Content: "thanks"},
},
expected: "<|turn>user\nweather?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}<tool_response|>` + "It is sunny.<turn|>\n<|turn>user\nthanks<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// format_argument (tpl L118-L147): numbers keep their JSON literal,
// booleans lower-case, nested maps have unquoted dictsorted keys,
// arrays bracketed; top-level args are dictsorted case-insensitively.
Entry("tool_call argument types (number/bool/nested/array)", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"count\":42,\"ratio\":3.5,\"flag\":true,\"off\":false,\"nested\":{\"x\":\"y\",\"n\":7},\"list\":[\"a\",1,true]}"}}]`},
},
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n" + `<|tool_call>call:f{count:42,flag:true,list:[<|"|>a<|"|>,1,true],nested:{n:7,x:<|"|>y<|"|>},off:false,ratio:3.5}<tool_call|><|tool_response>`,
}),
// jinja dictsort is case-insensitive: alpha sorts before Beta.
Entry("tool_call argument dictsort is case-insensitive", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"Beta\":1,\"alpha\":2}"}}]`},
},
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{alpha:2,Beta:1}<tool_call|><|tool_response>",
}),
// jinja renders Python None as "None" (round-trips through vLLM's
// parser, which lowers "none" back to null).
Entry("tool_call null argument renders as None", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"maybe\":null}"}}]`},
},
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{maybe:None}<tool_call|><|tool_response>",
}),
Entry("tool_call empty arguments render empty braces", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`},
},
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{}<tool_call|><|tool_response>",
}),
// tpl L253-L254: a non-object arguments string renders verbatim.
Entry("tool_call non-object string arguments render verbatim", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"just text"}}]`},
},
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{just text}<tool_call|><|tool_response>",
}),
// tpl L278-L285: unmatched tool_call_id falls back to the tool
// message's own name.
Entry("tool result name falls back when tool_call_id does not match", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "go"},
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`},
{Role: "tool", ToolCallId: "OTHER", Name: "named_tool", Content: "out"},
},
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n" + `<|tool_call>call:f{}<tool_call|><|tool_response>response:named_tool{value:<|"|>out<|"|>}<tool_response|>`,
}),
// strip_thinking (tpl L148-L158): historical assistant content loses
// its <|channel>...<channel|> spans.
Entry("assistant content thinking channels are stripped", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "hi"},
{Role: "assistant", Content: "<|channel>thought\nsecret\n<channel|>visible answer"},
{Role: "user", Content: "more"},
},
expected: "<|turn>user\nhi<turn|>\n<|turn>model\nvisible answer<turn|>\n<|turn>user\nmore<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
// tpl L220-L235: consecutive assistant messages suppress the second
// <|turn>model (continuation), but each still closes with <turn|>.
Entry("consecutive assistant messages continue the model turn", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "hi"},
{Role: "assistant", Content: "part one"},
{Role: "assistant", Content: "part two"},
{Role: "user", Content: "ok"},
},
expected: "<|turn>user\nhi<turn|>\n<|turn>model\npart one<turn|>\npart two<turn|>\n<|turn>user\nok<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
}),
Entry("add_generation_prompt=false renders no model turn", renderGemma4Case{
msgs: []*pb.Message{
{Role: "user", Content: "hi"},
},
noGenerationPrompt: true,
expected: "<|turn>user\nhi<turn|>\n",
}),
)
Describe("error handling", func() {
It("fails loud on an unknown role", func() {
_, err := RenderGemma4([]*pb.Message{
{Role: "narrator", Content: "Meanwhile..."},
}, "", false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(`unknown role "narrator"`))
})
It("fails on invalid tools JSON", func() {
_, err := RenderGemma4([]*pb.Message{
{Role: "user", Content: "hi"},
}, "{not json", false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tools JSON"))
})
It("fails on invalid tool_calls JSON", func() {
_, err := RenderGemma4([]*pb.Message{
{Role: "user", Content: "hi"},
{Role: "assistant", Content: "", ToolCalls: "{not json"},
}, "", false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tool_calls JSON"))
})
It("fails on an orphan tool message, naming its index", func() {
// A role:tool message with no preceding assistant tool_calls turn
// would be silently dropped by the jinja; we fail loud instead.
_, err := RenderGemma4([]*pb.Message{
{Role: "user", Content: "hi"},
{Role: "tool", Content: `{"temp": 20}`, ToolCallId: "call_1"},
}, "", false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("orphan tool message 1"))
})
It("fails on trailing garbage after the tools JSON array", func() {
_, err := RenderGemma4([]*pb.Message{
{Role: "user", Content: "hi"},
}, "[] junk", false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tools JSON"))
})
It("fails when the tools JSON is not an array", func() {
_, err := RenderGemma4([]*pb.Message{
{Role: "user", Content: "hi"},
}, `{"type":"function"}`, false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tools JSON is not an array"))
})
It("fails when a tools array element is not an object", func() {
_, err := RenderGemma4([]*pb.Message{
{Role: "user", Content: "hi"},
}, `[42]`, false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tools[0] is not an object"))
})
It("rejects a nil message via the unknown-role check", func() {
// Pins current behavior: pb getters are nil-safe, so a nil message
// reads as role "" and trips the fail-loud unknown-role guard.
_, err := RenderGemma4([]*pb.Message{nil}, "", false, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(`unknown role "" in message 0`))
})
})
})

View File

@@ -1,85 +0,0 @@
package main
// Started internally by LocalAI - one gRPC server per loaded model.
//
// Loads libdllm.so via purego and registers the 9-symbol flat C-ABI
// declared in dllm.cpp's include/dllm_capi.h (ABI v1). The library name can
// be overridden with DLLM_LIBRARY (mirrors the PARAKEET_LIBRARY /
// WHISPER_LIBRARY convention in the sibling backends); the default looks
// for the .so next to this binary (run.sh puts the package dir on
// LD_LIBRARY_PATH).
import (
"flag"
"fmt"
"os"
"github.com/ebitengine/purego"
grpc "github.com/mudler/LocalAI/pkg/grpc"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
type LibFuncs struct {
FuncPtr any
Name string
}
// loadCAPI dlopens libName and binds the 9 dllm_capi_* entry points 1:1 to
// dllm_capi.h, so an `nm libdllm.so | grep dllm_capi` is enough to spot
// drift. Shared with the test suite (ensureLibLoaded), which drives the
// bridge without the gRPC server.
//
// The C-ABI returns malloc'd char* buffers from tokenize_json/generate; we
// register those as uintptr so we get the raw pointer back and can call
// dllm_capi_free_string on it (purego's string return would copy and forget
// the original pointer, leaking it on every call). last_error returns a
// BORROWED pointer instead, so it is registered as a plain string: purego
// copies it and nothing must be freed.
func loadCAPI(libName string) error {
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
return fmt.Errorf("dllm: dlopen %q: %w", libName, err)
}
libFuncs := []LibFuncs{
{&cppAbiVersion, "dllm_capi_abi_version"},
{&cppLoad, "dllm_capi_load"},
{&cppFree, "dllm_capi_free"},
{&cppLastError, "dllm_capi_last_error"},
{&cppFreeString, "dllm_capi_free_string"},
{&cppTokenizeJSON, "dllm_capi_tokenize_json"},
{&cppGenerate, "dllm_capi_generate"},
{&cppGenerateStream, "dllm_capi_generate_stream"},
{&cppCancel, "dllm_capi_cancel"},
}
for _, lf := range libFuncs {
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
}
return nil
}
func main() {
libName := os.Getenv("DLLM_LIBRARY")
if libName == "" {
libName = "libdllm.so"
}
if err := loadCAPI(libName); err != nil {
panic(err)
}
// Hard-fail on an ABI mismatch: the flat-pointer bindings above would
// otherwise misbehave silently against a future libdllm.so.
if v := cAbiVersion(); v != dllmABIVersion {
panic(fmt.Errorf("dllm: libdllm.so ABI=%d, this backend speaks ABI=%d", v, dllmABIVersion))
}
fmt.Fprintf(os.Stderr, "[dllm] ABI=%d\n", cAbiVersion())
flag.Parse()
if err := grpc.StartServer(*addr, &Dllm{}); err != nil {
panic(err)
}
}

View File

@@ -1,24 +0,0 @@
#!/bin/bash
#
# T1 packaging stub: copy the binary, run.sh and libdllm.so into package/.
# The full ldd walk (libc, libstdc++, libgomp, GPU runtimes, arch
# detection) lands with the registration task, mirroring
# backend/go/whisper/package.sh.
set -e
CURDIR=$(dirname "$(realpath "$0")")
mkdir -p "$CURDIR/package/lib"
cp -avf "$CURDIR/dllm-grpc" "$CURDIR/package/"
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
# libdllm.so + any soname symlinks, should upstream ever add them.
cp -avf "$CURDIR"/libdllm.so* "$CURDIR/package/lib/" 2>/dev/null || {
echo "ERROR: libdllm.so not found in $CURDIR, run 'make' first" >&2
exit 1
}
echo "T1 package layout (full ldd walk lands with registration):"
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"

View File

@@ -1,16 +0,0 @@
#!/bin/bash
set -e
CURDIR=$(dirname "$(realpath "$0")")
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
# If a self-contained ld.so was packaged, route through it so the
# packaged libc / libstdc++ are used instead of the host's (matches the
# whisper / parakeet-cpp backends' runtime layout).
if [ -f "$CURDIR/lib/ld.so" ]; then
echo "Using lib/ld.so"
exec "$CURDIR/lib/ld.so" "$CURDIR/dllm-grpc" "$@"
fi
exec "$CURDIR/dllm-grpc" "$@"

View File

@@ -1,6 +1,6 @@
# parakeet-cpp backend Makefile.
#
# Upstream pin lives below as PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
# Upstream pin lives below as PARAKEET_VERSION?=b11fe5bca78ad8b342dd559a43d76df3984bb447
# (.github/bump_deps.sh) can find and update it - matches the
# whisper.cpp / ds4 / vibevoice-cpp convention.
#
@@ -15,7 +15,7 @@
# That's what the L0 smoke test uses. The default target below does the
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
PARAKEET_VERSION?=b11fe5bca78ad8b342dd559a43d76df3984bb447
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
GOCMD?=go

View File

@@ -7,12 +7,8 @@ import "time"
type batchRequest struct {
pcm []float32
decoder int32
// language is the per-request target locale ("" means the model default).
// parakeet.cpp's batched C-API takes ONE target_lang for the whole batch,
// so the dispatcher only coalesces requests that share a language.
language string
tag string
reply chan batchReply
tag string
reply chan batchReply
}
// batchReply carries one per-item JSON object string (an element of the C-API's
@@ -47,25 +43,13 @@ func newBatcher(maxSize int, maxWait time.Duration, runBatch func([]*batchReques
// run is the dispatcher loop: accumulate submitted requests until either maxSize
// is reached or maxWait elapses since the first queued request, then dispatch.
// Exits when stop is closed (draining any partially-filled batch first).
//
// A batch carries ONE language (parakeet.cpp's batched C-API takes a single
// target_lang), so a request whose language differs from the batch leader is
// not coalesced: it is held in carry and becomes the leader of the next batch.
// carry is therefore never dropped and its caller never deadlocks: every batch
// (including a lone carry on stop) is dispatched, and runBatch replies to all.
func (b *batcher) run(stop <-chan struct{}) {
var carry *batchRequest
for {
var first *batchRequest
if carry != nil {
// A mismatched request from the previous fill leads this batch.
first, carry = carry, nil
} else {
select {
case first = <-b.submit:
case <-stop:
return
}
select {
case first = <-b.submit:
case <-stop:
return
}
batch := []*batchRequest{first}
@@ -80,22 +64,12 @@ func (b *batcher) run(stop <-chan struct{}) {
for len(batch) < b.maxSize {
select {
case r := <-b.submit:
if r.language != first.language {
// Different language: carry it to the next batch so this
// batch stays single-language, then dispatch what we have.
carry = r
break fill
}
batch = append(batch, r)
case <-timer.C:
break fill
case <-stop:
timer.Stop()
b.runBatch(batch)
// Don't strand a carried request's caller on shutdown.
if carry != nil {
b.runBatch([]*batchRequest{carry})
}
return
}
}

View File

@@ -105,60 +105,4 @@ var _ = Describe("batcher", func() {
go func() { <-rep }()
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
})
It("never coalesces requests with different languages into one batch", func() {
// parakeet.cpp's batched C-API takes ONE target_lang per batch, so the
// dispatcher must keep every dispatched batch single-language. Submit a
// mix of languages and assert (a) no batch ever carries more than one
// distinct language and (b) every submitted request still gets a reply
// (the mismatched carry-over is never dropped).
var mu sync.Mutex
var langsPerBatch [][]string
run := func(reqs []*batchRequest) {
seen := map[string]struct{}{}
var distinct []string
for _, r := range reqs {
if _, ok := seen[r.language]; !ok {
seen[r.language] = struct{}{}
distinct = append(distinct, r.language)
}
}
mu.Lock()
langsPerBatch = append(langsPerBatch, distinct)
mu.Unlock()
echoReply(reqs)
}
// Large window + size so the fill loop stays open across submits and the
// language constraint (not the timer) is what splits the batches.
b := newBatcher(16, 200*time.Millisecond, run)
stop := make(chan struct{})
go b.run(stop)
defer close(stop)
langs := []string{"en", "en", "de", "de", "en", "fr", "fr"}
const N = 7
var wg sync.WaitGroup
got := make([]string, N)
for i := 0; i < N; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
rep := make(chan batchReply, 1)
b.submit <- &batchRequest{tag: string(rune('a' + i)), language: langs[i], reply: rep}
got[i] = (<-rep).json
}(i)
}
wg.Wait()
mu.Lock()
defer mu.Unlock()
// Invariant: every dispatched batch is single-language.
for _, distinct := range langsPerBatch {
Expect(len(distinct)).To(Equal(1), "a batch coalesced more than one language: %v", distinct)
}
// Liveness: every request got a reply (carry-over never stranded).
for i := 0; i < N; i++ {
Expect(got[i]).To(Equal(string(rune('a' + i))))
}
})
})

View File

@@ -48,13 +48,6 @@ var (
// side reads them as const float*/const int*.
CppTranscribePcmBatchJSON func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32) uintptr
// CppTranscribePcmBatchJSONLang is the multilingual variant of the batched
// JSON entry point: identical, plus a trailing target_lang. "" (the model
// default, "auto") is passed for non-prompt models, which ignore it; an
// unknown locale on a prompt model returns 0 and sets last_error. Present
// only in newer libparakeet.so; nil falls back to CppTranscribePcmBatchJSON.
CppTranscribePcmBatchJSONLang func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32, targetLang string) uintptr
// Cache-aware streaming (RNN-T) entry points. stream_begin returns 0 for
// non-streaming models. feed/finalize return a malloc'd char* (uintptr,
// freed via CppFreeString); feed writes 1 to *eouOut on an <EOU>/<EOB>.
@@ -62,18 +55,6 @@ var (
CppStreamFeed func(s uintptr, pcm []float32, nSamples int32, eouOut unsafe.Pointer) uintptr
CppStreamFinalize func(s uintptr) uintptr
CppStreamFree func(s uintptr)
// CppStreamBeginLang is the multilingual variant of stream_begin: identical,
// plus a trailing target_lang ("" means the model default). Present only in
// newer libparakeet.so; nil falls back to CppStreamBegin.
CppStreamBeginLang func(ctx uintptr, targetLang string) uintptr
// Streaming JSON variants (ABI v4): feed/finalize returning a malloc'd char*
// JSON document {text,eou,frame_sec,words} (uintptr, freed via CppFreeString)
// so streaming segments can carry per-word timestamps. Present only in newer
// libparakeet.so; nil falls back to the text-only CppStreamFeed/Finalize path.
CppStreamFeedJSON func(s uintptr, pcm []float32, nSamples int32) uintptr
CppStreamFinalizeJSON func(s uintptr) uintptr
)
// streamChunkSamples is how much 16 kHz mono PCM we hand to stream_feed per
@@ -91,26 +72,9 @@ const streamChunkSamples = 16000
//
// "start"/"end"/"t" are seconds; "conf" is confidence in (0,1].
type transcriptJSON struct {
Text string `json:"text"`
FrameSec float64 `json:"frame_sec"`
Words []transcriptWord `json:"words"`
Tokens []transcriptToken `json:"tokens"`
}
// streamFeedJSON mirrors the document returned by
// parakeet_capi_stream_feed_json / parakeet_capi_stream_finalize_json (ABI v4):
//
// {"text":"...","eou":0,"frame_sec":0.080000,
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...]}
//
// "text" is the newly-finalized text since the last call; "eou" is 1 when an
// <EOU>/<EOB> fired this feed; "words" are the words finalized this call with
// absolute (stream-relative) start/end seconds.
type streamFeedJSON struct {
Text string `json:"text"`
Eou int `json:"eou"`
FrameSec float64 `json:"frame_sec"`
Words []transcriptWord `json:"words"`
Text string `json:"text"`
Words []transcriptWord `json:"words"`
Tokens []transcriptToken `json:"tokens"`
}
type transcriptWord struct {
@@ -139,10 +103,6 @@ type ParakeetCpp struct {
engineMu sync.Mutex // sole guard of the one C engine (dispatcher + streaming)
bat *batcher
batStop chan struct{}
// segmentGapFrames is NeMo's segment_gap_threshold in ENCODER FRAMES (model
// YAML option, default 0=off). When >0 it adds NeMo's silence-gap split on
// top of the punctuation split; converted to seconds via the JSON frame_sec.
segmentGapFrames int
}
// Load is the LocalAI gRPC entry point for LoadModel: it calls
@@ -172,11 +132,6 @@ func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
if maxWaitMs < 0 {
maxWaitMs = 0
}
// NeMo's segment_gap_threshold (encoder frames, default 0=off). Off by
// default matches NeMo's default (punctuation-only segments); when set it
// additionally splits segments on inter-word silence (see transcriptResultFromDoc).
p.segmentGapFrames = optInt(opts, "segment_gap_threshold", 0)
if CppTranscribePcmBatchJSON != nil {
p.batStop = make(chan struct{})
p.bat = newBatcher(maxSize, time.Duration(maxWaitMs)*time.Millisecond, p.runBatch)
@@ -232,19 +187,8 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
if len(reqs) > 0 {
dec = reqs[0].decoder
}
// All requests in a batch share one language (the batcher coalesces only
// same-language requests), so any element's language describes the batch.
lang := ""
if len(reqs) > 0 {
lang = reqs[0].language
}
p.engineMu.Lock()
var cstr uintptr
if CppTranscribePcmBatchJSONLang != nil {
cstr = CppTranscribePcmBatchJSONLang(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec, lang)
} else {
cstr = CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
}
cstr := CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
p.engineMu.Unlock()
if cstr == 0 {
err := fmt.Errorf("parakeet-cpp: batch transcribe failed: %s", CppLastError(p.ctxPtr))
@@ -282,9 +226,8 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
// OpenAI API, whose default is segment-level); token ids always populate
// Segment.Tokens.
//
// translate/diarize/prompt/temperature/threads are not applicable to parakeet
// and are ignored; language is honored on the batched + streaming paths (see
// opts.GetLanguage() below); streaming is handled by AudioTranscriptionStream
// translate/diarize/prompt/temperature/language/threads are not applicable to
// parakeet and are ignored; streaming is handled by AudioTranscriptionStream
// (L2).
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
if p.ctxPtr == 0 {
@@ -316,7 +259,7 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
}
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
return transcriptResultFromDoc(doc, opts), nil
}
// Batched path: decode to PCM, submit to the batcher, wait for this request's
@@ -328,7 +271,7 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
}
rep := make(chan batchReply, 1)
select {
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, language: opts.GetLanguage(), reply: rep}:
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, reply: rep}:
case <-ctx.Done():
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
}
@@ -345,169 +288,34 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
if err := json.Unmarshal([]byte(res.json), &doc); err != nil {
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
}
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
return transcriptResultFromDoc(doc, opts), nil
}
// segmentSeparators is NeMo's default segment_seperators (sentence-ending
// punctuation). Splitting on these matches NeMo's default segment timestamps.
var segmentSeparators = []rune{'.', '?', '!'}
// transcriptResultFromDoc maps a decoded transcriptJSON to a TranscriptResult,
// grouping words into NeMo-faithful segments (see splitWordsIntoSegments). The
// optional gapFrames (NeMo's segment_gap_threshold, in encoder FRAMES; 0=off)
// additionally splits on inter-word silence; it is converted to a seconds gap
// with the document's frame_sec. Per-segment word timings are attached only when
// the caller requested word granularity; token ids populate each segment's
// Tokens by time-window membership. Shared by the batched and direct paths.
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gapFrames int) pb.TranscriptResult {
// synthesising a single whole-clip segment and attaching word timings only when
// the caller requested word granularity. Shared by the batched and direct paths.
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest) pb.TranscriptResult {
text := strings.TrimSpace(doc.Text)
// Frame-unit gap threshold -> seconds (NeMo segment_gap_threshold). 0 = off.
gapSeconds := 0.0
if gapFrames > 0 {
if doc.FrameSec > 0 {
gapSeconds = float64(gapFrames) * doc.FrameSec
} else {
xlog.Warn("parakeet-cpp: segment_gap_threshold set but libparakeet.so " +
"did not report frame_sec; falling back to punctuation-only segments")
}
words := make([]*pb.TranscriptWord, 0, len(doc.Words))
for _, w := range doc.Words {
words = append(words, &pb.TranscriptWord{Start: secondsToNanos(w.Start), End: secondsToNanos(w.End), Text: w.W})
}
groups := splitWordsIntoSegments(doc.Words, segmentSeparators, gapSeconds)
if len(groups) == 0 {
// No words (edge case): single whole-clip text segment.
return pb.TranscriptResult{
Text: text,
Segments: []*pb.TranscriptSegment{{Id: 0, Text: text}},
}
tokens := make([]int32, 0, len(doc.Tokens))
for _, t := range doc.Tokens {
tokens = append(tokens, t.ID)
}
wantWords := wordsRequested(opts.TimestampGranularities)
segments := make([]*pb.TranscriptSegment, 0, len(groups))
for id, group := range groups {
parts := make([]string, len(group))
for i, gw := range group {
parts[i] = gw.W
}
seg := &pb.TranscriptSegment{
Id: int32(id),
Start: secondsToNanos(group[0].Start),
End: secondsToNanos(group[len(group)-1].End),
Text: strings.TrimSpace(strings.Join(parts, " ")),
Tokens: tokensInWindow(doc.Tokens, group[0].Start, group[len(group)-1].End),
}
if wantWords {
ws := make([]*pb.TranscriptWord, len(group))
for i, gw := range group {
ws[i] = &pb.TranscriptWord{Start: secondsToNanos(gw.Start), End: secondsToNanos(gw.End), Text: gw.W}
}
seg.Words = ws
}
segments = append(segments, seg)
var segStart, segEnd int64
if len(words) > 0 {
segStart = words[0].Start
segEnd = words[len(words)-1].End
}
return pb.TranscriptResult{Text: text, Segments: segments}
seg := &pb.TranscriptSegment{Id: 0, Start: segStart, End: segEnd, Text: text, Tokens: tokens}
if wordsRequested(opts.TimestampGranularities) {
seg.Words = words
}
return pb.TranscriptResult{Text: text, Segments: []*pb.TranscriptSegment{seg}}
}
// splitWordsIntoSegments groups words into segments exactly as NeMo's
// get_segment_offsets does (nemo/collections/asr/parts/utils/timestamp_utils.py).
// Walking the words, it closes a segment when (1) the gap rule is enabled
// (gapSeconds > 0) and the segment already has words and the gap from the
// previous word's end to this word's start is >= gapSeconds - the current word
// then STARTS a new segment - or, checked only when the gap rule did not apply
// (NeMo's elif), (2) the word ends with (or is) a separator, which closes the
// segment INCLUDING that word. Trailing words flush into a final segment.
// gapSeconds <= 0 disables the gap rule, matching NeMo's default
// segment_gap_threshold=None (punctuation-only segments).
func splitWordsIntoSegments(words []transcriptWord, separators []rune, gapSeconds float64) [][]transcriptWord {
var segments [][]transcriptWord
var cur []transcriptWord
for i, word := range words {
gapActive := gapSeconds > 0 && len(cur) > 0
if gapActive && (word.Start-words[i-1].End) >= gapSeconds {
segments = append(segments, cur)
cur = []transcriptWord{word}
continue
}
if !gapActive && endsWithSeparator(word.W, separators) {
cur = append(cur, word)
segments = append(segments, cur)
cur = nil
continue
}
cur = append(cur, word)
}
if len(cur) > 0 {
segments = append(segments, cur)
}
return segments
}
// endsWithSeparator reports whether w's last rune is in separators (matching
// NeMo's `word[-1] in delims or word in delims`).
func endsWithSeparator(w string, separators []rune) bool {
r := []rune(strings.TrimSpace(w))
if len(r) == 0 {
return false
}
last := r[len(r)-1]
for _, s := range separators {
if last == s {
return true
}
}
return false
}
// tokensInWindow returns the ids of tokens whose timestamp t falls in
// [start, end] (inclusive), assigning each token to the segment that spans its
// time. The last segment's end is the last word end, so the final token is
// included.
func tokensInWindow(tokens []transcriptToken, start, end float64) []int32 {
var ids []int32
for _, t := range tokens {
if t.T >= start && t.T <= end {
ids = append(ids, t.ID)
}
}
return ids
}
// streamSegmenter accumulates streaming words into per-utterance segments. EOU
// is the model's own utterance boundary; each closed segment takes its start/end
// from its first/last accumulated word.
type streamSegmenter struct {
segs []*pb.TranscriptSegment
cur []transcriptWord
nextID int32
}
func (s *streamSegmenter) add(doc streamFeedJSON) {
s.cur = append(s.cur, doc.Words...)
if doc.Eou != 0 {
s.flush()
}
}
func (s *streamSegmenter) flush() {
if len(s.cur) == 0 {
return
}
parts := make([]string, len(s.cur))
for i, w := range s.cur {
parts[i] = w.W
}
s.segs = append(s.segs, &pb.TranscriptSegment{
Id: s.nextID,
Start: secondsToNanos(s.cur[0].Start),
End: secondsToNanos(s.cur[len(s.cur)-1].End),
Text: strings.TrimSpace(strings.Join(parts, " ")),
})
s.nextID++
s.cur = nil
}
func (s *streamSegmenter) segments() []*pb.TranscriptSegment { return s.segs }
// wordsRequested reports whether the caller asked for word-level timestamps.
// The OpenAI transcription API gates word timings behind
// timestamp_granularities[] containing "word" and defaults to segment-level
@@ -553,12 +361,7 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
return status.Error(codes.Canceled, "transcription cancelled")
}
var stream uintptr
if CppStreamBeginLang != nil {
stream = CppStreamBeginLang(p.ctxPtr, opts.GetLanguage())
} else {
stream = CppStreamBegin(p.ctxPtr)
}
stream := CppStreamBegin(p.ctxPtr)
if stream == 0 {
// Not a cache-aware streaming model: run a normal offline
// transcription and emit it as one delta + a closing final result.
@@ -587,14 +390,6 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
return err
}
// ABI v4: when the streaming JSON entry points are present, drive them so the
// per-utterance segments carry per-word start/end timestamps. Falls through to
// the text-only loop below against an older libparakeet.so. Runs under the
// engineMu already held above.
if CppStreamFeedJSON != nil {
return p.streamJSON(ctx, stream, data, duration, results)
}
var (
full strings.Builder
segText strings.Builder
@@ -671,71 +466,6 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
return nil
}
// streamJSON drives the ABI v4 streaming JSON entry points: each feed/finalize
// returns a {text,eou,frame_sec,words} document. The newly-finalized text is
// emitted as a delta (unchanged streaming contract) while words are accumulated
// into per-utterance segments (closed on EOU) so the closing FinalResult carries
// timestamped segments. Runs under engineMu (already held by the caller).
func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []float32,
duration float32, results chan *pb.TranscriptStreamResponse) error {
var (
full strings.Builder
seg streamSegmenter
)
// consume frees the malloc'd char* (a 0 return is an error), parses the JSON,
// emits the delta, and routes words through the segmenter.
consume := func(ret uintptr) error {
if ret == 0 {
msg := CppLastError(p.ctxPtr)
if msg == "" {
msg = "unknown error"
}
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
}
raw := goStringFromCPtr(ret)
CppFreeString(ret)
var doc streamFeedJSON
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
return fmt.Errorf("parakeet-cpp: decode stream json: %w", err)
}
if doc.Text != "" {
full.WriteString(doc.Text)
results <- &pb.TranscriptStreamResponse{Delta: doc.Text}
}
seg.add(doc)
return nil
}
for off := 0; off < len(data); off += streamChunkSamples {
if err := ctx.Err(); err != nil {
return status.Error(codes.Canceled, "transcription cancelled")
}
end := min(off+streamChunkSamples, len(data))
chunk := data[off:end]
if err := consume(CppStreamFeedJSON(stream, chunk, int32(len(chunk)))); err != nil {
return err
}
}
if err := consume(CppStreamFinalizeJSON(stream)); err != nil {
return err
}
seg.flush() // close any trailing utterance that never saw an EOU
text := strings.TrimSpace(full.String())
segments := seg.segments()
if len(segments) == 0 && text != "" {
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
}
results <- &pb.TranscriptStreamResponse{
FinalResult: &pb.TranscriptResult{
Text: text,
Segments: segments,
Duration: duration,
},
}
return nil
}
// decodeWavMono16k converts any input audio to 16 kHz mono PCM and returns the
// float samples plus the clip duration in seconds. Mirrors the whisper
// backend: utils.AudioToWav (ffmpeg) normalises rate/channels, go-audio

View File

@@ -53,10 +53,6 @@ func ensureLibLoaded() {
purego.RegisterLibFunc(&CppStreamFeed, lib, "parakeet_capi_stream_feed")
purego.RegisterLibFunc(&CppStreamFinalize, lib, "parakeet_capi_stream_finalize")
purego.RegisterLibFunc(&CppStreamFree, lib, "parakeet_capi_stream_free")
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
}
purego.RegisterLibFunc(&CppFreeString, lib, "parakeet_capi_free_string")
purego.RegisterLibFunc(&CppLastError, lib, "parakeet_capi_last_error")
})
@@ -111,22 +107,13 @@ var _ = Describe("ParakeetCpp", func() {
Expect(err).ToNot(HaveOccurred())
Expect(strings.TrimSpace(res.Text)).ToNot(BeEmpty(),
"expected non-empty transcript for %s", audioPath)
// NeMo-faithful segmentation: one or more punctuation-delimited
// segments, each with text and a monotonically-advancing time span.
Expect(res.Segments).ToNot(BeEmpty(), "expected at least one segment")
var prevEnd int64
for i, seg := range res.Segments {
Expect(strings.TrimSpace(seg.Text)).ToNot(BeEmpty(),
"segment %d must have text", i)
Expect(seg.End).To(BeNumerically(">=", seg.Start),
"segment %d end must not precede its start", i)
Expect(seg.Start).To(BeNumerically(">=", prevEnd),
"segments must be in time order")
prevEnd = seg.End
// Default (no granularities) is segment-level: no per-word timings.
Expect(seg.Words).To(BeEmpty(),
"word timings are opt-in via timestamp_granularities")
}
Expect(res.Segments).To(HaveLen(1),
"synthesises a single whole-clip segment")
Expect(res.Segments[0].Text).To(Equal(res.Text),
"single segment text must equal the top-level text")
// Default (no granularities) is segment-level: no per-word timings.
Expect(res.Segments[0].Words).To(BeEmpty(),
"word timings are opt-in via timestamp_granularities")
})
It("emits word-level timestamps when granularity=word", func() {
@@ -142,28 +129,15 @@ var _ = Describe("ParakeetCpp", func() {
TimestampGranularities: []string{"word"},
})
Expect(err).ToNot(HaveOccurred())
Expect(res.Segments).ToNot(BeEmpty())
// With word granularity every segment carries its own words, and each
// segment's span tracks its first/last word; word starts advance
// monotonically across the whole transcript.
totalWords := 0
var prevStart int64 = -1
for i, seg := range res.Segments {
Expect(seg.Words).ToNot(BeEmpty(),
"segment %d must carry per-word timestamps with granularity=word", i)
Expect(seg.Start).To(Equal(seg.Words[0].Start),
"segment %d start tracks its first word", i)
Expect(seg.End).To(Equal(seg.Words[len(seg.Words)-1].End),
"segment %d end tracks its last word", i)
for _, w := range seg.Words {
Expect(w.End).To(BeNumerically(">=", w.Start))
Expect(w.Start).To(BeNumerically(">=", prevStart))
prevStart = w.Start
totalWords++
}
}
Expect(totalWords).To(BeNumerically(">", 0))
Expect(res.Segments[0].Words[0].Start).To(BeNumerically(">=", int64(0)))
Expect(res.Segments).To(HaveLen(1))
seg := res.Segments[0]
Expect(seg.Words).ToNot(BeEmpty(),
"expected per-word timestamps with granularity=word")
// Monotonic, non-negative timings spanning the segment.
Expect(seg.Words[0].Start).To(BeNumerically(">=", int64(0)))
Expect(seg.End).To(BeNumerically(">=", seg.Start))
Expect(seg.Words[len(seg.Words)-1].End).To(Equal(seg.End),
"segment end tracks the last word")
})
})

View File

@@ -65,25 +65,6 @@ func main() {
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
}
// Per-request language variants (multilingual nemotron). Same probe pattern:
// present only in libparakeet.so built with multilingual support, so the
// backend still loads against an older library and falls back to the
// non-lang batched + streaming entry points (model default / "auto").
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json_lang"); err == nil && sym != 0 {
purego.RegisterLibFunc(&CppTranscribePcmBatchJSONLang, lib, "parakeet_capi_transcribe_pcm_batch_json_lang")
}
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_begin_lang"); err == nil && sym != 0 {
purego.RegisterLibFunc(&CppStreamBeginLang, lib, "parakeet_capi_stream_begin_lang")
}
// Streaming JSON entry points (ABI v4): surface per-word timestamps on the
// streaming path. Same probe pattern; absent in older libparakeet.so, where
// the backend falls back to the text-only streaming feed.
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
}
fmt.Fprintf(os.Stderr, "[parakeet-cpp] ABI=%d\n", CppAbiVersion())
flag.Parse()

View File

@@ -1,127 +0,0 @@
package main
import (
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func tw(text string, start, end float64) transcriptWord {
return transcriptWord{W: text, Start: start, End: end}
}
var _ = Describe("splitWordsIntoSegments (NeMo get_segment_offsets parity)", func() {
seps := []rune{'.', '?', '!'}
It("splits on sentence-ending punctuation, including the delimiter word", func() {
words := []transcriptWord{tw("hello", 0, 0.4), tw("world.", 0.4, 0.8), tw("bye", 1.0, 1.3)}
segs := splitWordsIntoSegments(words, seps, 0)
Expect(segs).To(HaveLen(2))
Expect(segs[0]).To(HaveLen(2))
Expect(segs[0][1].W).To(Equal("world."))
Expect(segs[1]).To(HaveLen(1))
Expect(segs[1][0].W).To(Equal("bye"))
})
It("keeps a single segment with no terminal punctuation and gap off", func() {
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
segs := splitWordsIntoSegments(words, seps, 0)
Expect(segs).To(HaveLen(1))
})
It("splits on the gap rule when enabled, the gapped word starting the next segment", func() {
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
segs := splitWordsIntoSegments(words, seps, 1.0) // c is 4.6s after b
Expect(segs).To(HaveLen(2))
Expect(segs[0]).To(HaveLen(2)) // a b
Expect(segs[1]).To(HaveLen(1)) // c
Expect(segs[1][0].W).To(Equal("c"))
})
It("checks the gap rule before punctuation (NeMo elif order)", func() {
// "b." would terminate, but c is far after it -> gap closes [a b.] at b.
words := []transcriptWord{tw("a", 0, 0.2), tw("b.", 0.2, 0.4), tw("c", 9.0, 9.2)}
segs := splitWordsIntoSegments(words, seps, 1.0)
Expect(segs).To(HaveLen(2))
Expect(segs[0]).To(HaveLen(2))
Expect(segs[1][0].W).To(Equal("c"))
})
It("still splits on punctuation when the gap rule is enabled but does not fire", func() {
words := []transcriptWord{tw("hi.", 0, 0.4), tw("bye", 0.4, 0.8)}
segs := splitWordsIntoSegments(words, seps, 5.0) // gap never reached
Expect(segs).To(HaveLen(2))
Expect(segs[0][0].W).To(Equal("hi."))
})
It("returns nothing for empty input", func() {
Expect(splitWordsIntoSegments(nil, seps, 0)).To(BeEmpty())
})
})
var _ = Describe("transcriptResultFromDoc (multi-segment)", func() {
doc := transcriptJSON{
Text: "hello world. bye now",
FrameSec: 0.08,
Words: []transcriptWord{
{W: "hello", Start: 0.0, End: 0.4},
{W: "world.", Start: 0.4, End: 0.8},
{W: "bye", Start: 1.0, End: 1.3},
{W: "now", Start: 1.3, End: 1.6},
},
Tokens: []transcriptToken{{ID: 1, T: 0.1}, {ID: 2, T: 0.5}, {ID: 3, T: 1.1}, {ID: 4, T: 1.4}},
}
It("emits one segment per punctuation-delimited group with start/end", func() {
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
Expect(res.Segments).To(HaveLen(2))
Expect(res.Segments[0].Text).To(Equal("hello world."))
Expect(res.Segments[0].Start).To(Equal(int64(0)))
Expect(res.Segments[0].End).To(Equal(secondsToNanos(0.8)))
Expect(res.Segments[1].Text).To(Equal("bye now"))
Expect(res.Segments[1].Start).To(Equal(secondsToNanos(1.0)))
Expect(res.Segments[1].Id).To(Equal(int32(1)))
})
It("assigns tokens to the segment whose time window contains them", func() {
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
Expect(res.Segments[0].Tokens).To(Equal([]int32{1, 2}))
Expect(res.Segments[1].Tokens).To(Equal([]int32{3, 4}))
})
It("attaches per-segment words only when word granularity requested", func() {
plain := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
Expect(plain.Segments[0].Words).To(BeEmpty())
withWords := transcriptResultFromDoc(doc, &pb.TranscriptRequest{TimestampGranularities: []string{"word"}}, 0)
Expect(withWords.Segments[0].Words).To(HaveLen(2))
})
It("falls back to a single text segment when there are no words", func() {
res := transcriptResultFromDoc(transcriptJSON{Text: "hi"}, &pb.TranscriptRequest{}, 0)
Expect(res.Segments).To(HaveLen(1))
Expect(res.Segments[0].Text).To(Equal("hi"))
})
})
var _ = Describe("streaming segment assembly", func() {
It("closes a segment with start/end from its words on EOU", func() {
acc := &streamSegmenter{}
acc.add(streamFeedJSON{Text: "hello world", Eou: 1, Words: []transcriptWord{
{W: "hello", Start: 0.0, End: 0.4}, {W: "world", Start: 0.4, End: 0.9},
}})
segs := acc.segments()
Expect(segs).To(HaveLen(1))
Expect(segs[0].Text).To(Equal("hello world"))
Expect(segs[0].Start).To(Equal(int64(0)))
Expect(segs[0].End).To(Equal(secondsToNanos(0.9)))
})
It("buffers words across feeds until EOU", func() {
acc := &streamSegmenter{}
acc.add(streamFeedJSON{Text: "hi", Eou: 0, Words: []transcriptWord{{W: "hi", Start: 0, End: 0.3}}})
Expect(acc.segments()).To(BeEmpty())
acc.add(streamFeedJSON{Text: "there", Eou: 1, Words: []transcriptWord{{W: "there", Start: 0.3, End: 0.7}}})
Expect(acc.segments()).To(HaveLen(1))
Expect(acc.segments()[0].Text).To(Equal("hi there"))
})
})

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# stablediffusion.cpp (ggml)
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
STABLEDIFFUSION_GGML_VERSION?=19bdfe22d255d5b4dff39d449318b9bc5ea2317f
STABLEDIFFUSION_GGML_VERSION?=1f9ee88e09c258053fa59d5e05e23dfb10fa0b13
CMAKE_ARGS+=-DGGML_MAX_NAME=128

View File

@@ -386,7 +386,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
const char *llm_vision_path = "";
const char *diffusion_model_path = stableDiffusionModel;
const char *high_noise_diffusion_model_path = "";
const char *uncond_diffusion_model_path = "";
const char *taesd_path = "";
const char *control_net_path = "";
const char *embedding_dir = "";
@@ -473,7 +472,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
if (!strcmp(optname, "llm_vision_path")) llm_vision_path = strdup(optval);
if (!strcmp(optname, "diffusion_model_path")) diffusion_model_path = strdup(optval);
if (!strcmp(optname, "high_noise_diffusion_model_path")) high_noise_diffusion_model_path = strdup(optval);
if (!strcmp(optname, "uncond_diffusion_model_path")) uncond_diffusion_model_path = strdup(optval);
if (!strcmp(optname, "taesd_path")) taesd_path = strdup(optval);
if (!strcmp(optname, "control_net_path")) control_net_path = strdup(optval);
if (!strcmp(optname, "embedding_dir")) {
@@ -573,7 +571,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
ctx_params.llm_vision_path = llm_vision_path;
ctx_params.diffusion_model_path = diffusion_model_path;
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
ctx_params.uncond_diffusion_model_path = uncond_diffusion_model_path;
ctx_params.vae_path = vae_path;
ctx_params.audio_vae_path = audio_vae_path;
ctx_params.embeddings_connectors_path = embeddings_connectors_path;

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
WHISPER_CPP_VERSION?=df7638d8229a243af8a4b5a8ae557e0d74e0a0ae
WHISPER_CPP_VERSION?=99613cb720b65036237d44b52f753b51f75c2797
SO_TARGET?=libgowhisper.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -95,29 +95,6 @@
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ds4"
metal: "metal-ds4"
metal-darwin-arm64: "metal-ds4"
- &dllm
name: "dllm"
alias: "dllm"
license: mit
description: |
mudler/dllm.cpp - DiffusionGemma block-diffusion LLM inference engine
(C++/ggml, GGUF weights). Decodes whole token canvases per diffusion
round instead of autoregressive sampling. Runs on CPU and NVIDIA CUDA 13
(including Jetson/GB10 L4T targets).
urls:
- https://github.com/mudler/dllm.cpp
tags:
- text-to-text
- LLM
- gguf
- diffusion
- CPU
- CUDA
capabilities:
default: "cpu-dllm"
nvidia: "cuda13-dllm"
nvidia-cuda-13: "cuda13-dllm"
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-dllm"
- &whispercpp
name: "whisper"
alias: "whisper"
@@ -1295,13 +1272,6 @@
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ds4-development"
metal: "metal-ds4-development"
metal-darwin-arm64: "metal-ds4-development"
- !!merge <<: *dllm
name: "dllm-development"
capabilities:
default: "cpu-dllm-development"
nvidia: "cuda13-dllm-development"
nvidia-cuda-13: "cuda13-dllm-development"
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-dllm-development"
- !!merge <<: *stablediffusionggml
name: "stablediffusion-ggml-development"
capabilities:
@@ -1889,37 +1859,6 @@
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-ds4"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-ds4
## dllm
- !!merge <<: *dllm
name: "cpu-dllm"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-dllm"
mirrors:
- localai/localai-backends:latest-cpu-dllm
- !!merge <<: *dllm
name: "cpu-dllm-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-dllm"
mirrors:
- localai/localai-backends:master-cpu-dllm
- !!merge <<: *dllm
name: "cuda13-dllm"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-dllm"
mirrors:
- localai/localai-backends:latest-gpu-nvidia-cuda-13-dllm
- !!merge <<: *dllm
name: "cuda13-dllm-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-dllm"
mirrors:
- localai/localai-backends:master-gpu-nvidia-cuda-13-dllm
- !!merge <<: *dllm
name: "cuda13-nvidia-l4t-arm64-dllm"
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-dllm"
mirrors:
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-dllm
- !!merge <<: *dllm
name: "cuda13-nvidia-l4t-arm64-dllm-development"
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-dllm"
mirrors:
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-dllm
## whisper
- !!merge <<: *whispercpp
name: "whisper-development"

View File

@@ -1,6 +1,6 @@
--extra-index-url https://download.pytorch.org/whl/cpu
transformers==4.48.3
transformers==5.0.0rc3
accelerate
torch==2.4.1
torch==2.7.1+cpu
torchaudio==2.4.1
coqui-tts

View File

@@ -1,5 +1,5 @@
torch==2.4.1
torch==2.7.1+cpu
torchaudio==2.4.1
transformers==4.48.3
transformers==5.0.0rc3
accelerate
coqui-tts

View File

@@ -1,6 +1,6 @@
--extra-index-url https://download.pytorch.org/whl/rocm7.0
torch==2.10.0+rocm7.0
torch==2.7.1+cpu
torchaudio==2.10.0+rocm7.0
transformers==4.48.3
transformers==5.0.0rc3
accelerate
coqui-tts

View File

@@ -1,8 +1,8 @@
--extra-index-url https://download.pytorch.org/whl/xpu
torch==2.8.0+xpu
torch==2.7.1+cpu
torchaudio==2.8.0+xpu
optimum[openvino]
setuptools
transformers==4.48.3
transformers==5.0.0rc3
accelerate
coqui-tts

View File

@@ -1,4 +1,4 @@
torch==2.7.1
transformers==4.48.3
torch==2.7.1+cpu
transformers==5.0.0rc3
accelerate
coqui-tts

View File

@@ -26,10 +26,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
try:
from vllm.tokenizers import get_tokenizer # vLLM >= 0.22
except ImportError:
from vllm.transformers_utils.tokenizer import get_tokenizer # vLLM < 0.22
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.multimodal.utils import fetch_image
from vllm.assets.video import VideoAsset
import base64

View File

@@ -23,9 +23,9 @@ import (
"github.com/mudler/LocalAI/core/services/routing/pii"
"github.com/mudler/LocalAI/core/services/routing/router"
"github.com/mudler/LocalAI/core/services/storage"
"github.com/mudler/LocalAI/pkg/signals"
coreStartup "github.com/mudler/LocalAI/core/startup"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/signals"
"github.com/mudler/LocalAI/pkg/vram"
"github.com/mudler/LocalAI/pkg/model"
@@ -308,31 +308,10 @@ func New(opts ...config.AppOption) (*Application, error) {
application.galleryService.SetNATSClient(distSvc.Nats)
if distSvc.DistStores != nil && distSvc.DistStores.Gallery != nil {
// Clean up stale in-progress operations from previous crashed instances
if _, err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
if err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
xlog.Warn("Failed to clean stale gallery operations", "error", err)
}
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
// Reap stale ops periodically, not just at boot: an op orphaned by
// a replica that died mid-install (its foreground handler goroutine
// gone) would otherwise linger "processing" in the UI until the next
// restart. 30m matches the install/upgrade ceiling so a genuinely
// slow op is never reaped out from under itself.
gsvc := application.galleryService
go func() {
ticker := time.NewTicker(15 * time.Minute)
defer ticker.Stop()
for {
select {
case <-options.Context.Done():
return
case <-ticker.C:
if _, err := gsvc.ReapStaleOperations(30 * time.Minute); err != nil {
xlog.Warn("Failed to reap stale gallery operations", "error", err)
}
}
}
}()
}
// Hydrate from the store first so the wildcard subscriber finds an
// already-populated statuses map for any operations still in flight

View File

@@ -214,9 +214,7 @@ func (uc *UpgradeChecker) runCheck(ctx context.Context) {
"from", info.InstalledVersion, "to", info.AvailableVersion)
var err error
if bm != nil {
// Background auto-upgrade: no live admin watching a progress bar,
// so opID is empty and the distributed path skips progress streaming.
err = bm.UpgradeBackend(ctx, "", name, nil)
err = bm.UpgradeBackend(ctx, name, nil)
} else {
err = gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
uc.galleries, name, nil, uc.appConfig.RequireBackendIntegrity)

View File

@@ -1,30 +0,0 @@
package chat
import (
"context"
"io"
"strings"
)
type Options struct {
Model string
BaseURL string
APIKey string
In io.Reader
Out io.Writer
}
func Run(ctx context.Context, opts Options) error {
if opts.In == nil {
opts.In = strings.NewReader("")
}
if opts.Out == nil {
opts.Out = io.Discard
}
session, err := newChatSession(ctx, newLocalAIChatClient(opts.BaseURL, opts.APIKey), opts.Model)
if err != nil {
return err
}
return runTerminalChat(ctx, session, opts.In, opts.Out)
}

View File

@@ -1,13 +0,0 @@
package chat
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestChat(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Chat Suite")
}

View File

@@ -1,172 +0,0 @@
package chat
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Run chat", func() {
It("streams a single chat response", func() {
var capturedModel string
var capturedAuth string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v1/models" {
w.Header().Set("Content-Type", "application/json")
writeResponse(w, `{"object":"list","data":[{"id":"test-model","object":"model"}]}`)
return
}
Expect(r.URL.Path).To(Equal("/v1/chat/completions"))
capturedAuth = r.Header.Get("Authorization")
var body struct {
Model string `json:"model"`
Messages []struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"messages"`
}
Expect(json.NewDecoder(r.Body).Decode(&body)).To(Succeed())
capturedModel = body.Model
Expect(body.Messages).To(HaveLen(1))
Expect(body.Messages[0].Role).To(Equal("user"))
Expect(body.Messages[0].Content).To(Equal("hello"))
w.Header().Set("Content-Type", "text/event-stream")
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n")
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\"}}]}\n\n")
writeResponse(w, "data: [DONE]\n\n")
}))
defer server.Close()
var out bytes.Buffer
err := Run(GinkgoT().Context(), Options{
Model: "test-model",
BaseURL: server.URL + "/v1",
APIKey: "secret",
In: strings.NewReader("hello\n/exit\n"),
Out: &out,
})
Expect(err).ToNot(HaveOccurred())
Expect(capturedModel).To(Equal("test-model"))
Expect(capturedAuth).To(Equal("Bearer secret"))
Expect(out.String()).To(ContainSubstring("assistant: hi!"))
Expect(out.String()).To(ContainSubstring("bye"))
})
It("auto-selects the only available model", func() {
server := chatTestServer([]string{"solo"}, nil)
defer server.Close()
var out bytes.Buffer
err := Run(GinkgoT().Context(), Options{
BaseURL: server.URL + "/v1",
In: strings.NewReader("/exit\n"),
Out: &out,
})
Expect(err).ToNot(HaveOccurred())
Expect(out.String()).To(ContainSubstring("LocalAI chat (solo)"))
})
It("returns an actionable error when no models are installed", func() {
server := chatTestServer(nil, nil)
defer server.Close()
err := Run(GinkgoT().Context(), Options{
BaseURL: server.URL + "/v1",
In: strings.NewReader(""),
})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no chat models are installed"))
Expect(err.Error()).To(ContainSubstring("local-ai models install <model>"))
})
It("returns an actionable error when multiple models are available without a selection", func() {
server := chatTestServer([]string{"alpha", "beta"}, nil)
defer server.Close()
err := Run(GinkgoT().Context(), Options{
BaseURL: server.URL + "/v1",
In: strings.NewReader(""),
})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("multiple models are available"))
Expect(err.Error()).To(ContainSubstring("--model"))
Expect(err.Error()).To(ContainSubstring("alpha"))
Expect(err.Error()).To(ContainSubstring("beta"))
})
It("lists and switches models inside the chat", func() {
requestedModels := []string{}
server := chatTestServer([]string{"alpha", "beta"}, func(model string) {
requestedModels = append(requestedModels, model)
})
defer server.Close()
var out bytes.Buffer
err := Run(GinkgoT().Context(), Options{
Model: "alpha",
BaseURL: server.URL + "/v1",
In: strings.NewReader("/models\n/model beta\nhello\n/exit\n"),
Out: &out,
})
Expect(err).ToNot(HaveOccurred())
Expect(out.String()).To(ContainSubstring("* alpha"))
Expect(out.String()).To(ContainSubstring(" beta"))
Expect(out.String()).To(ContainSubstring("switched to beta; conversation cleared"))
Expect(requestedModels).To(Equal([]string{"beta"}))
})
})
func chatTestServer(models []string, onChat func(model string)) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/v1/models":
w.Header().Set("Content-Type", "application/json")
writeResponse(w, `{"object":"list","data":[`)
for i, model := range models {
if i > 0 {
writeResponse(w, ",")
}
writeResponsef(w, `{"id":%q,"object":"model"}`, model)
}
writeResponse(w, `]}`)
case "/v1/chat/completions":
var body struct {
Model string `json:"model"`
}
Expect(json.NewDecoder(r.Body).Decode(&body)).To(Succeed())
if onChat != nil {
onChat(body.Model)
}
w.Header().Set("Content-Type", "text/event-stream")
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ok\"}}]}\n\n")
writeResponse(w, "data: [DONE]\n\n")
default:
w.WriteHeader(http.StatusNotFound)
}
}))
}
func writeResponse(w io.Writer, text string) {
_, err := fmt.Fprint(w, text)
Expect(err).ToNot(HaveOccurred())
}
func writeResponsef(w io.Writer, format string, args ...any) {
_, err := fmt.Fprintf(w, format, args...)
Expect(err).ToNot(HaveOccurred())
}

View File

@@ -1,114 +0,0 @@
package chat
import (
"context"
"errors"
"fmt"
"io"
"sort"
"strings"
openai "github.com/sashabaranov/go-openai"
)
type chatClient interface {
ListModels(ctx context.Context) ([]string, error)
StreamChat(ctx context.Context, model string, messages []chatMessage, out io.Writer) (string, error)
}
type localAIChatClient struct {
client *openai.Client
}
func newLocalAIChatClient(baseURL string, apiKey string) *localAIChatClient {
cfg := openai.DefaultConfig(apiKey)
cfg.BaseURL = baseURL
return &localAIChatClient{client: openai.NewClientWithConfig(cfg)}
}
func (c *localAIChatClient) ListModels(ctx context.Context) ([]string, error) {
resp, err := c.client.ListModels(ctx)
if err != nil {
return nil, err
}
models := make([]string, 0, len(resp.Models))
for _, model := range resp.Models {
if model.ID != "" {
models = append(models, model.ID)
}
}
sort.Strings(models)
return models, nil
}
func (c *localAIChatClient) StreamChat(ctx context.Context, model string, messages []chatMessage, out io.Writer) (string, error) {
stream, err := c.client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
Model: model,
Messages: openAIChatMessages(messages),
})
if err != nil {
return "", friendlyChatError(err, model)
}
defer func() {
_ = stream.Close()
}()
var answer strings.Builder
for {
resp, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return answer.String(), friendlyChatError(err, model)
}
if len(resp.Choices) == 0 {
continue
}
token := resp.Choices[0].Delta.Content
if token == "" {
continue
}
answer.WriteString(token)
if _, err := fmt.Fprint(out, token); err != nil {
return answer.String(), err
}
}
return answer.String(), nil
}
func openAIChatMessages(messages []chatMessage) []openai.ChatCompletionMessage {
converted := make([]openai.ChatCompletionMessage, len(messages))
for i, message := range messages {
converted[i] = openai.ChatCompletionMessage{
Role: message.Role,
Content: message.Content,
}
}
return converted
}
func friendlyChatError(err error, model string) error {
var apiErr *openai.APIError
if errors.As(err, &apiErr) {
switch apiErr.HTTPStatusCode {
case 404:
return fmt.Errorf("model %q is not available. Run `local-ai models list`, install a model with `local-ai models install <model>`, or switch with `/model <name>`", model)
case 403:
return fmt.Errorf("model %q is disabled. Enable it from LocalAI settings or choose another model with `/model <name>`", model)
}
if apiErr.Message != "" {
return errors.New(apiErr.Message)
}
}
msg := err.Error()
if strings.Contains(msg, "model") && strings.Contains(msg, "not found") {
return fmt.Errorf("model %q is not available. Run `local-ai models list`, install a model with `local-ai models install <model>`, or switch with `/model <name>`", model)
}
return err
}

View File

@@ -1,17 +0,0 @@
package chat
import "strings"
func formatChatModelList(models []string, current string) string {
var b strings.Builder
for _, model := range models {
prefix := " "
if model == current {
prefix = "* "
}
b.WriteString(prefix)
b.WriteString(model)
b.WriteByte('\n')
}
return b.String()
}

View File

@@ -1,120 +0,0 @@
package chat
import (
"context"
"errors"
"fmt"
"io"
"strings"
)
const (
chatRoleUser = "user"
chatRoleAssistant = "assistant"
)
type chatMessage struct {
Role string
Content string
}
type chatSession struct {
client chatClient
model string
models []string
messages []chatMessage
}
func newChatSession(ctx context.Context, client chatClient, requestedModel string) (*chatSession, error) {
models, err := client.ListModels(ctx)
if err != nil {
return nil, fmt.Errorf("list models: %w", err)
}
model, err := resolveChatModel(requestedModel, models)
if err != nil {
return nil, err
}
return &chatSession{
client: client,
model: model,
models: models,
}, nil
}
func (s *chatSession) CurrentModel() string {
return s.model
}
func (s *chatSession) Models() []string {
models := make([]string, len(s.models))
copy(models, s.models)
return models
}
func (s *chatSession) Clear() {
s.messages = nil
}
func (s *chatSession) SwitchModel(model string) error {
if !modelExists(s.models, model) {
return fmt.Errorf("model %q is not available. Use /models to see installed models", model)
}
s.model = model
s.Clear()
return nil
}
func (s *chatSession) Send(ctx context.Context, prompt string, out io.Writer) error {
s.messages = append(s.messages, chatMessage{
Role: chatRoleUser,
Content: prompt,
})
answer, err := s.client.StreamChat(ctx, s.model, s.messages, out)
if err != nil {
return err
}
s.messages = append(s.messages, chatMessage{
Role: chatRoleAssistant,
Content: answer,
})
return nil
}
func resolveChatModel(requested string, models []string) (string, error) {
switch {
case requested == "" && len(models) == 0:
return "", errors.New(`no chat models are installed.
Install a model first, for example:
local-ai models list
local-ai models install <model>
local-ai run
Then start a chat session:
local-ai chat --model <model>`)
case requested == "" && len(models) == 1:
return models[0], nil
case requested == "" && len(models) > 1:
var b strings.Builder
b.WriteString("multiple models are available; choose one with --model:\n")
b.WriteString(formatChatModelList(models, ""))
return "", errors.New(b.String())
case !modelExists(models, requested):
return "", fmt.Errorf("model %q is not available. Use `local-ai models list` and `local-ai models install <model>`, or pass an installed model with --model", requested)
default:
return requested, nil
}
}
func modelExists(models []string, name string) bool {
for _, model := range models {
if model == name {
return true
}
}
return false
}

View File

@@ -1,56 +0,0 @@
package chat
import (
"context"
"io"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Chat session", func() {
It("keeps model switching and message history out of the terminal adapter", func() {
client := &fakeChatClient{
models: []string{"alpha", "beta"},
answer: "pong",
}
session, err := newChatSession(context.Background(), client, "alpha")
Expect(err).ToNot(HaveOccurred())
Expect(session.CurrentModel()).To(Equal("alpha"))
Expect(session.SwitchModel("beta")).To(Succeed())
Expect(session.CurrentModel()).To(Equal("beta"))
Expect(session.Send(context.Background(), "ping", io.Discard)).To(Succeed())
Expect(client.requests).To(HaveLen(1))
Expect(client.requests[0].model).To(Equal("beta"))
Expect(client.requests[0].messages).To(HaveLen(1))
Expect(client.requests[0].messages[0].Content).To(Equal("ping"))
})
})
type fakeChatClient struct {
models []string
answer string
requests []fakeChatRequest
}
type fakeChatRequest struct {
model string
messages []chatMessage
}
func (c *fakeChatClient) ListModels(context.Context) ([]string, error) {
return c.models, nil
}
func (c *fakeChatClient) StreamChat(_ context.Context, model string, messages []chatMessage, out io.Writer) (string, error) {
copied := make([]chatMessage, len(messages))
copy(copied, messages)
c.requests = append(c.requests, fakeChatRequest{model: model, messages: copied})
if _, err := io.WriteString(out, c.answer); err != nil {
return "", err
}
return c.answer, nil
}

View File

@@ -1,93 +0,0 @@
package chat
import (
"bufio"
"context"
"fmt"
"io"
"strings"
)
func runTerminalChat(ctx context.Context, session *chatSession, in io.Reader, out io.Writer) error {
scanner := bufio.NewScanner(in)
scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
if err := writeChat(out, "LocalAI chat (%s)\n", session.CurrentModel()); err != nil {
return err
}
if err := writeChat(out, "Type /exit to quit, /clear to reset the conversation, /models to list models.\n"); err != nil {
return err
}
for {
if err := writeChat(out, "\n> "); err != nil {
return err
}
if !scanner.Scan() {
break
}
prompt := strings.TrimSpace(scanner.Text())
switch prompt {
case "":
continue
case "/bye", "/exit", "/quit":
return writeChat(out, "bye\n")
case "/clear":
session.Clear()
if err := writeChat(out, "conversation cleared\n"); err != nil {
return err
}
continue
case "/models":
if err := printChatModels(out, session.Models(), session.CurrentModel()); err != nil {
return err
}
continue
}
if nextModel, ok := strings.CutPrefix(prompt, "/model "); ok {
nextModel = strings.TrimSpace(nextModel)
if nextModel == "" {
if err := writeChat(out, "usage: /model <name>\n"); err != nil {
return err
}
continue
}
if err := session.SwitchModel(nextModel); err != nil {
if writeErr := writeChat(out, "%s\n", err); writeErr != nil {
return writeErr
}
continue
}
if err := writeChat(out, "switched to %s; conversation cleared\n", session.CurrentModel()); err != nil {
return err
}
continue
}
if err := writeChat(out, "assistant: "); err != nil {
return err
}
if err := session.Send(ctx, prompt, out); err != nil {
return err
}
if err := writeChat(out, "\n"); err != nil {
return err
}
}
return scanner.Err()
}
func printChatModels(out io.Writer, models []string, current string) error {
if len(models) == 0 {
return writeChat(out, "no models installed\n")
}
return writeChat(out, "%s", formatChatModelList(models, current))
}
func writeChat(out io.Writer, format string, args ...any) error {
_, err := fmt.Fprintf(out, format, args...)
return err
}

View File

@@ -1,25 +0,0 @@
package cli
import (
"context"
"os"
chatcli "github.com/mudler/LocalAI/core/cli/chat"
cliContext "github.com/mudler/LocalAI/core/cli/context"
)
type ChatCMD struct {
Model string `short:"m" help:"Model name to use. Defaults to the only model returned by the server when exactly one is available"`
Endpoint string `env:"LOCALAI_CHAT_ENDPOINT" default:"http://127.0.0.1:8080" help:"LocalAI server endpoint. The /v1 path is added automatically when omitted"`
APIKey string `env:"LOCALAI_API_KEY,API_KEY" help:"API key to use when the LocalAI server requires authentication"`
}
func (c *ChatCMD) Run(ctx *cliContext.Context) error {
return chatcli.Run(context.Background(), chatcli.Options{
Model: c.Model,
BaseURL: chatAPIBaseURL(c.Endpoint),
APIKey: c.APIKey,
In: os.Stdin,
Out: os.Stdout,
})
}

View File

@@ -1,27 +0,0 @@
package cli
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Chat command wiring", func() {
Describe("chatAPIBaseURL", func() {
It("adds /v1 to a root endpoint", func() {
Expect(chatAPIBaseURL("http://127.0.0.1:8080")).To(Equal("http://127.0.0.1:8080/v1"))
})
It("keeps endpoints that already include /v1", func() {
Expect(chatAPIBaseURL("http://127.0.0.1:8080/v1")).To(Equal("http://127.0.0.1:8080/v1"))
Expect(chatAPIBaseURL("http://127.0.0.1:8080/v1/")).To(Equal("http://127.0.0.1:8080/v1"))
})
It("adds a default http scheme", func() {
Expect(chatAPIBaseURL("127.0.0.1:8080")).To(Equal("http://127.0.0.1:8080/v1"))
})
It("preserves non-root paths before /v1", func() {
Expect(chatAPIBaseURL("http://127.0.0.1:8080/localai")).To(Equal("http://127.0.0.1:8080/localai/v1"))
})
})
})

View File

@@ -1,29 +0,0 @@
package cli
import (
"net/url"
"strings"
)
func chatAPIBaseURL(endpoint string) string {
if !strings.Contains(endpoint, "://") {
endpoint = "http://" + endpoint
}
u, err := url.Parse(endpoint)
if err != nil {
return strings.TrimRight(endpoint, "/") + "/v1"
}
path := strings.TrimRight(u.Path, "/")
if path == "" {
u.Path = "/v1"
} else if path != "/v1" && !strings.HasSuffix(path, "/v1") {
u.Path = path + "/v1"
} else {
u.Path = path
}
u.RawQuery = ""
u.Fragment = ""
return u.String()
}

View File

@@ -9,7 +9,6 @@ var CLI struct {
cliContext.Context `embed:""`
Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
Chat ChatCMD `cmd:"" help:"Open an interactive chat session against a running LocalAI server"`
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
Backends BackendsCMD `cmd:"" help:"Manage LocalAI backends and definitions"`

View File

@@ -30,8 +30,6 @@ type RunCMD struct {
ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"`
ExternalBackends []string `env:"LOCALAI_EXTERNAL_BACKENDS,EXTERNAL_BACKENDS" help:"A list of external backends to load from gallery on boot" group:"backends"`
WebRTCNAT1To1IPs []string `env:"LOCALAI_WEBRTC_NAT_1TO1_IPS,WEBRTC_NAT_1TO1_IPS" help:"IPs advertised as the host ICE candidates for /v1/realtime WebRTC instead of every local interface. Set to the reachable host/LAN IP when running under Docker host networking or NAT, where pion otherwise offers unreachable bridge addresses and the connection drops after ICE consent checks fail." group:"api"`
WebRTCICEInterfaces []string `env:"LOCALAI_WEBRTC_ICE_INTERFACES,WEBRTC_ICE_INTERFACES" help:"Restrict /v1/realtime WebRTC ICE candidate gathering to these network interfaces (e.g. eth0), filtering out docker0/veth noise." group:"api"`
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
@@ -227,8 +225,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithApiKeys(r.APIKeys),
config.WithModelsURL(append(r.Models, r.ModelArgs...)...),
config.WithExternalBackends(r.ExternalBackends...),
config.WithWebRTCNAT1To1IPs(r.WebRTCNAT1To1IPs...),
config.WithWebRTCICEInterfaces(r.WebRTCICEInterfaces...),
config.WithOpaqueErrors(r.OpaqueErrors),
config.WithEnforcedPredownloadScans(!r.DisablePredownloadScan),
config.WithSubtleKeyComparison(r.UseSubtleKeyComparison),
@@ -656,12 +652,12 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
// waitForServerReady polls the given address until the HTTP server is
// accepting connections or the context is cancelled.
func waitForServerReady(address string, ctx context.Context) {
// Ensure the address has a host component for dialing.
// Echo accepts ":8080" but net.Dial needs a resolvable host.
host, port, err := net.SplitHostPort(address)
if err == nil && host == "" {
address = "127.0.0.1:" + port
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
for {
select {
@@ -669,17 +665,11 @@ func waitForServerReady(address string, ctx context.Context) {
return
default:
}
conn, err := net.DialTimeout("tcp", address, 500*time.Millisecond)
if err == nil {
conn.Close()
return
}
select {
case <-ctx.Done():
return
case <-ticker.C:
}
time.Sleep(250 * time.Millisecond)
}
}

View File

@@ -12,19 +12,10 @@ import (
)
type ApplicationConfig struct {
Context context.Context
ConfigFile string
SystemState *system.SystemState
ExternalBackends []string
// WebRTCNAT1To1IPs, when set, are advertised as the host ICE candidates for
// /v1/realtime WebRTC instead of every local interface address. Needed when
// the routable address differs from what pion gathers — e.g. Docker host
// networking (where pion also offers unreachable bridge IPs) or NAT.
WebRTCNAT1To1IPs []string
// WebRTCICEInterfaces, when set, restricts ICE candidate gathering to these
// network interfaces (e.g. eth0), filtering out docker0/veth noise.
WebRTCICEInterfaces []string
Context context.Context
ConfigFile string
SystemState *system.SystemState
ExternalBackends []string
UploadLimitMB, Threads, ContextSize int
F16 bool
Debug bool
@@ -90,6 +81,7 @@ type ApplicationConfig struct {
// file is mode 0600.
MITMCADir string
// PIIPatternOverrides applies persisted per-id deltas (action,
// disabled) to the live redactor at startup. Loaded from
// runtime_settings.json and applied right after pii.NewRedactor.
@@ -124,11 +116,11 @@ type ApplicationConfig struct {
// --require-backend-integrity / LOCALAI_REQUIRE_BACKEND_INTEGRITY.
RequireBackendIntegrity bool
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
WatchDogIdle bool
WatchDogBusy bool
WatchDog bool
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
WatchDogIdle bool
WatchDogBusy bool
WatchDog bool
// Memory Reclaimer settings (works with GPU if available, otherwise RAM)
MemoryReclaimerEnabled bool // Enable memory threshold monitoring
@@ -319,18 +311,6 @@ func WithExternalBackends(backends ...string) AppOption {
}
}
func WithWebRTCNAT1To1IPs(ips ...string) AppOption {
return func(o *ApplicationConfig) {
o.WebRTCNAT1To1IPs = ips
}
}
func WithWebRTCICEInterfaces(interfaces ...string) AppOption {
return func(o *ApplicationConfig) {
o.WebRTCICEInterfaces = interfaces
}
}
func WithMachineTag(tag string) AppOption {
return func(o *ApplicationConfig) {
o.MachineTag = tag
@@ -722,6 +702,7 @@ func WithMITMCADir(dir string) AppOption {
}
}
func WithDynamicConfigDir(dynamicConfigsDir string) AppOption {
return func(o *ApplicationConfig) {
o.DynamicConfigsDir = dynamicConfigsDir

View File

@@ -39,21 +39,7 @@ func llamaCppDefaults(cfg *ModelConfig, modelPath string) {
}
}()
// Startup parses every model's GGUF header to guess defaults. We only need
// scalar metadata (architecture, head/ff counts, chat_template, token IDs,
// MTP head) plus array *lengths* — never the array *contents*. Two options
// keep this cheap, which matters when many models live on slow storage such
// as a Docker volume (see https://github.com/mudler/LocalAI/issues/9790):
//
// - SkipLargeMetadata: seek past large array-valued metadata (the tokenizer
// vocab: tokenizer.ggml.tokens/scores/merges, often >100k entries) instead
// of reading and allocating every element. Lengths stay populated.
// - UseMMap: read the header via a memory map so faulting in a few pages
// replaces hundreds of thousands of tiny read() syscalls (measured ~524k
// -> 8 for a 256k-token vocab), the dominant cost on slow filesystems.
//
// The mapping is released when ParseGGUFFile returns.
f, err := gguf.ParseGGUFFile(guessPath, gguf.UseMMap(), gguf.SkipLargeMetadata())
f, err := gguf.ParseGGUFFile(guessPath)
if err == nil {
guessGGUFFromFile(cfg, f, 0)
}

View File

@@ -1,76 +1,13 @@
package config_test
import (
"bytes"
"encoding/binary"
"os"
"path/filepath"
. "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
gguf "github.com/gpustack/gguf-parser-go"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// GGUF metadata value type tags (see github.com/gpustack/gguf-parser-go).
const (
ggufTypeUint32 uint32 = 4
ggufTypeString uint32 = 8
ggufTypeArray uint32 = 9
)
// writeTestGGUF emits a minimal but valid little-endian GGUF v3 header carrying
// the scalar metadata the llama-cpp hook guesses from plus a large string vocab
// array (tokenizer.ggml.tokens). The big array is exactly what SkipLargeMetadata
// + UseMMap are expected to avoid reading element-by-element, so it must survive a
// round-trip through the real hook without corrupting the guessed defaults.
func writeTestGGUF(path, chatTemplate string, vocab int) error {
wStr := func(b *bytes.Buffer, s string) {
binary.Write(b, binary.LittleEndian, uint64(len(s)))
b.WriteString(s)
}
kvStr := func(b *bytes.Buffer, k, v string) {
wStr(b, k)
binary.Write(b, binary.LittleEndian, ggufTypeString)
wStr(b, v)
}
kvU32 := func(b *bytes.Buffer, k string, v uint32) {
wStr(b, k)
binary.Write(b, binary.LittleEndian, ggufTypeUint32)
binary.Write(b, binary.LittleEndian, v)
}
var meta bytes.Buffer
kvStr(&meta, "general.architecture", "llama")
kvStr(&meta, "general.name", "ReproModel")
kvU32(&meta, "llama.context_length", 4096)
kvU32(&meta, "llama.attention.head_count", 32)
kvU32(&meta, "llama.feed_forward_length", 11008)
kvU32(&meta, "llama.block_count", 32)
kvU32(&meta, "tokenizer.ggml.bos_token_id", 1)
kvStr(&meta, "tokenizer.chat_template", chatTemplate)
// large array value — the one the optimization skips reading
wStr(&meta, "tokenizer.ggml.tokens")
binary.Write(&meta, binary.LittleEndian, ggufTypeArray)
binary.Write(&meta, binary.LittleEndian, ggufTypeString)
binary.Write(&meta, binary.LittleEndian, uint64(vocab))
for i := 0; i < vocab; i++ {
wStr(&meta, "token")
}
var out bytes.Buffer
binary.Write(&out, binary.LittleEndian, gguf.GGUFMagicGGUFLe)
binary.Write(&out, binary.LittleEndian, uint32(3)) // version
binary.Write(&out, binary.LittleEndian, uint64(0)) // tensor count
binary.Write(&out, binary.LittleEndian, uint64(9)) // metadata kv count
out.Write(meta.Bytes())
return os.WriteFile(path, out.Bytes(), 0o644)
}
var _ = Describe("Backend hooks and parser defaults", func() {
Context("MatchParserDefaults", func() {
It("matches Qwen3 family", func() {
@@ -200,58 +137,6 @@ var _ = Describe("Backend hooks and parser defaults", func() {
})
})
Context("llamaCppDefaults GGUF guessing", func() {
// Regression coverage for https://github.com/mudler/LocalAI/issues/9790:
// the hook reads GGUF headers with SkipLargeMetadata + UseMMap to avoid
// pulling the whole tokenizer vocab off (slow) disk on every startup. This
// verifies that skipping the vocab array still yields the correct guessed
// defaults from the remaining scalar metadata.
const chatTemplate = "{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}"
It("guesses defaults from a GGUF whose large vocab is skipped", func() {
dir := GinkgoT().TempDir()
modelFile := "repro.gguf"
Expect(writeTestGGUF(filepath.Join(dir, modelFile), chatTemplate, 50000)).To(Succeed())
// A pre-set context size short-circuits the GGUF run-estimate, which
// needs full tensor info this header-only fixture deliberately omits;
// the metadata-reading path the optimization touches is unaffected.
ctxSize := 4096
cfg := &ModelConfig{
Backend: "llama-cpp",
LLMConfig: LLMConfig{ContextSize: &ctxSize},
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{Model: modelFile},
},
}
cfg.SetDefaults(ModelPath(dir))
// chat_template is a scalar string, not part of the skipped array,
// so it must be captured verbatim.
Expect(cfg.GetModelTemplate()).To(Equal(chatTemplate))
// scalar-derived defaults are still applied
Expect(cfg.ContextSize).NotTo(BeNil())
Expect(cfg.NGPULayers).NotTo(BeNil())
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
Expect(cfg.KnownUsecaseStrings).To(ContainElement("FLAG_CHAT"))
})
It("falls back to the default context size when the GGUF is unreadable", func() {
dir := GinkgoT().TempDir()
Expect(os.WriteFile(filepath.Join(dir, "bad.gguf"), []byte("not a gguf"), 0o644)).To(Succeed())
cfg := &ModelConfig{
Backend: "llama-cpp",
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{Model: "bad.gguf"},
},
}
cfg.SetDefaults(ModelPath(dir))
Expect(cfg.ContextSize).NotTo(BeNil())
})
})
Context("PromptCacheAll default", func() {
It("defaults to true when omitted from YAML", func() {
cfg := &ModelConfig{}

View File

@@ -308,41 +308,6 @@ func DefaultRegistry() map[string]FieldMetaOverride {
},
Order: 64,
},
"pipeline.disable_thinking": {
Section: "pipeline",
Label: "Disable Thinking",
Description: "Suppress reasoning/thinking output from the pipeline LLM (sets enable_thinking=false on the underlying model). Use for models that emit <think> blocks you don't want spoken or streamed back to the realtime client.",
Component: "toggle",
Order: 65,
},
"pipeline.streaming.llm": {
Section: "pipeline",
Label: "Stream LLM",
Description: "Stream LLM tokens to the realtime client as they are generated instead of waiting for the full response. Emits incremental response.output_audio_transcript.delta / text deltas.",
Component: "toggle",
Order: 66,
},
"pipeline.streaming.tts": {
Section: "pipeline",
Label: "Stream TTS",
Description: "Stream synthesized audio chunks to the realtime client as they are produced (requires a TTS backend that implements TTSStream). Falls back to unary synthesis otherwise.",
Component: "toggle",
Order: 67,
},
"pipeline.streaming.transcription": {
Section: "pipeline",
Label: "Stream Transcription",
Description: "Stream partial transcription text to the realtime client as the STT backend produces it (requires a transcription backend that implements AudioTranscriptionStream). Falls back to unary transcription otherwise.",
Component: "toggle",
Order: 68,
},
"pipeline.streaming.clause_chunking": {
Section: "pipeline",
Label: "Clause Chunking",
Description: "Split the streamed reply into speakable clauses and synthesize each as soon as it completes, instead of buffering the whole message before TTS — lower time-to-first-audio. Script-aware (handles CJK 。!? and Thai/Lao spaces), so it does not whitespace-split. Requires Stream LLM; off buffers the whole message.",
Component: "toggle",
Order: 69,
},
// --- Functions ---
"function.grammar.parallel_calls": {

View File

@@ -499,16 +499,6 @@ type Pipeline struct {
// the pipeline's LLM without editing the LLM model config. Overrides the LLM's
// own reasoning_effort. Unset leaves the LLM model config in charge.
ReasoningEffort string `yaml:"reasoning_effort,omitempty" json:"reasoning_effort,omitempty"`
// Streaming opts each pipeline stage into incremental delivery (LLM tokens,
// TTS audio chunks, transcription text). Unset stages keep the blocking
// unary path, so existing configs are unaffected.
Streaming PipelineStreaming `yaml:"streaming,omitempty" json:"streaming,omitempty"`
// DisableThinking suppresses reasoning/thinking for the pipeline LLM (maps
// to enable_thinking=false backend metadata) without editing the underlying
// LLM model config. Unset leaves the LLM model config in charge.
DisableThinking *bool `yaml:"disable_thinking,omitempty" json:"disable_thinking,omitempty"`
}
// ApplyReasoningEffort resolves the effective reasoning effort — a per-request
@@ -540,41 +530,6 @@ func (c *ModelConfig) ApplyReasoningEffort(requestEffort string) {
}
}
// @Description PipelineStreaming toggles incremental delivery per realtime stage.
type PipelineStreaming struct {
LLM *bool `yaml:"llm,omitempty" json:"llm,omitempty"`
TTS *bool `yaml:"tts,omitempty" json:"tts,omitempty"`
Transcription *bool `yaml:"transcription,omitempty" json:"transcription,omitempty"`
// ClauseChunking splits the streamed LLM reply into speakable clauses and
// synthesizes each as soon as it completes, instead of buffering the whole
// message before TTS. Script-aware (CJK/Thai), so it does not rely on
// whitespace sentence boundaries. Requires LLM streaming; unset buffers the
// whole message (today's default).
ClauseChunking *bool `yaml:"clause_chunking,omitempty" json:"clause_chunking,omitempty"`
}
// StreamLLM reports whether LLM tokens should be streamed for this pipeline.
func (p Pipeline) StreamLLM() bool { return p.Streaming.LLM != nil && *p.Streaming.LLM }
// StreamTTS reports whether TTS audio should be streamed for this pipeline.
func (p Pipeline) StreamTTS() bool { return p.Streaming.TTS != nil && *p.Streaming.TTS }
// StreamTranscription reports whether transcription text should be streamed.
func (p Pipeline) StreamTranscription() bool {
return p.Streaming.Transcription != nil && *p.Streaming.Transcription
}
// ChunkClauses reports whether the streamed reply should be split into
// script-aware clauses and synthesized incrementally rather than buffered whole.
func (p Pipeline) ChunkClauses() bool {
return p.Streaming.ClauseChunking != nil && *p.Streaming.ClauseChunking
}
// ThinkingDisabled reports whether the pipeline forces the LLM's thinking off.
func (p Pipeline) ThinkingDisabled() bool {
return p.DisableThinking != nil && *p.DisableThinking
}
// @Description File configuration for model downloads
type File struct {
Filename string `yaml:"filename,omitempty" json:"filename,omitempty"`

View File

@@ -30,26 +30,11 @@ func MTPSpecOptions() []string {
return out
}
// isDraftOnlyAssistantArch reports whether an architecture names a standalone
// MTP *draft* model rather than a self-speculating trunk. Upstream's Gemma4 MTP
// (ggml-org/llama.cpp#23398) registers the head as a separate `gemma4-assistant`
// architecture whose GGUF still carries `nextn_predict_layers`, but which cannot
// run alone: it requires a paired target context (`ctx_other`). Such archs must
// not trigger the embedded-head self-speculation defaults. The `-assistant`
// suffix is upstream's naming convention for these draft-only checkpoints.
func isDraftOnlyAssistantArch(arch string) bool {
return strings.HasSuffix(arch, "-assistant")
}
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a self-speculating
// Multi-Token Prediction head. Detection reads `<arch>.nextn_predict_layers`,
// which is what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a Multi-Token
// Prediction head. Detection reads `<arch>.nextn_predict_layers`, which is
// what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
// `conversion/qwen.py` MTP mixin. A positive layer count means the head is
// present in the same GGUF as the trunk.
//
// Draft-only assistant architectures (e.g. Gemma4's `gemma4-assistant`) carry
// the same key but are separate draft checkpoints meant to be paired with a
// target model, so they are deliberately excluded here.
func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
if f == nil {
return 0, false
@@ -58,9 +43,6 @@ func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
if arch == "" {
return 0, false
}
if isDraftOnlyAssistantArch(arch) {
return 0, false
}
v, ok := f.Header.MetadataKV.Get(arch + ".nextn_predict_layers")
if !ok {
return 0, false

View File

@@ -3,33 +3,10 @@ package config_test
import (
. "github.com/mudler/LocalAI/core/config"
gguf "github.com/gpustack/gguf-parser-go"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// ggufWithArch fabricates a minimal in-memory GGUF carrying the given
// `general.architecture` and a positive `<arch>.nextn_predict_layers` count,
// so HasEmbeddedMTPHead can be exercised without a real model file.
func ggufWithArch(arch string, nextn uint32) *gguf.GGUFFile {
return &gguf.GGUFFile{
Header: gguf.GGUFHeader{
MetadataKV: gguf.GGUFMetadataKVs{
{
Key: "general.architecture",
ValueType: gguf.GGUFMetadataValueTypeString,
Value: arch,
},
{
Key: arch + ".nextn_predict_layers",
ValueType: gguf.GGUFMetadataValueTypeUint32,
Value: nextn,
},
},
},
}
}
var _ = Describe("MTP auto-defaults", func() {
Context("MTPSpecOptions", func() {
It("returns the upstream-recommended speculative tuple", func() {
@@ -105,20 +82,5 @@ var _ = Describe("MTP auto-defaults", func() {
Expect(ok).To(BeFalse())
Expect(n).To(BeZero())
})
It("detects a same-GGUF embedded head (DeepSeek/Qwen style)", func() {
n, ok := HasEmbeddedMTPHead(ggufWithArch("qwen3moe", 1))
Expect(ok).To(BeTrue())
Expect(n).To(Equal(uint32(1)))
})
It("ignores a gemma4-assistant draft-only model", func() {
// The assistant GGUF carries nextn_predict_layers but is a separate
// draft model that requires a paired target (ctx_other); it cannot
// self-speculate, so it must not trigger the embedded-head defaults.
n, ok := HasEmbeddedMTPHead(ggufWithArch("gemma4-assistant", 48))
Expect(ok).To(BeFalse())
Expect(n).To(BeZero())
})
})
})

View File

@@ -1,57 +0,0 @@
package config
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gopkg.in/yaml.v3"
)
// The realtime pipeline can stream each stage (LLM tokens, TTS audio,
// transcription text) and can disable model "thinking" for the LLM. These are
// opt-in per pipeline; everything defaults to off so existing configs keep the
// unary behaviour.
var _ = Describe("Pipeline streaming config", func() {
It("defaults every streaming + thinking helper to false when unset", func() {
var p Pipeline
Expect(p.StreamLLM()).To(BeFalse())
Expect(p.StreamTTS()).To(BeFalse())
Expect(p.StreamTranscription()).To(BeFalse())
Expect(p.ChunkClauses()).To(BeFalse())
Expect(p.ThinkingDisabled()).To(BeFalse())
})
It("parses the nested streaming block and disable_thinking from YAML", func() {
var c ModelConfig
err := yaml.Unmarshal([]byte(`
name: gpt-realtime
pipeline:
llm: my-llm
tts: my-tts
transcription: my-stt
streaming:
llm: true
tts: true
transcription: true
clause_chunking: true
disable_thinking: true
`), &c)
Expect(err).ToNot(HaveOccurred())
Expect(c.Pipeline.StreamLLM()).To(BeTrue())
Expect(c.Pipeline.StreamTTS()).To(BeTrue())
Expect(c.Pipeline.StreamTranscription()).To(BeTrue())
Expect(c.Pipeline.ChunkClauses()).To(BeTrue())
Expect(c.Pipeline.ThinkingDisabled()).To(BeTrue())
})
It("treats an explicit false in the streaming block as disabled", func() {
var c ModelConfig
err := yaml.Unmarshal([]byte(`
name: gpt-realtime
pipeline:
streaming:
tts: false
`), &c)
Expect(err).ToNot(HaveOccurred())
Expect(c.Pipeline.StreamTTS()).To(BeFalse())
})
})

View File

@@ -383,13 +383,13 @@ var _ = Describe("API test", func() {
Expect(err).ToNot(HaveOccurred())
go func() {
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
xlog.Error("server error", "error", err)
}
}()
defaultConfig := openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = testHTTPBase + "/v1"
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
@@ -418,7 +418,7 @@ var _ = Describe("API test", func() {
Context("Auth Tests", func() {
It("Should fail if the api key is missing", func() {
err, sc := postInvalidRequest(testHTTPBase + "/models/available")
err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available")
Expect(err).ToNot(BeNil())
Expect(sc).To(Equal(401))
})
@@ -427,7 +427,7 @@ var _ = Describe("API test", func() {
Context("URL routing Tests", func() {
It("Should support reverse-proxy when unauthenticated", func() {
err, sc, body := getRequest(testHTTPBase+"/myprefix/", http.Header{
err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
"X-Forwarded-Proto": {"https"},
"X-Forwarded-Host": {"example.org"},
"X-Forwarded-Prefix": {"/myprefix/"},
@@ -441,7 +441,7 @@ var _ = Describe("API test", func() {
It("Should support reverse-proxy when authenticated", func() {
err, sc, body := getRequest(testHTTPBase+"/myprefix/", http.Header{
err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
"Authorization": {bearerKey},
"X-Forwarded-Proto": {"https"},
"X-Forwarded-Host": {"example.org"},
@@ -459,7 +459,7 @@ var _ = Describe("API test", func() {
// requests them through the proxy.
It("Should support reverse-proxy when prefix is stripped by the proxy", func() {
err, sc, body := getRequest(testHTTPBase+"/app", http.Header{
err, sc, body := getRequest("http://127.0.0.1:9090/app", http.Header{
"X-Forwarded-Proto": {"https"},
"X-Forwarded-Host": {"example.org"},
"X-Forwarded-Prefix": {"/myprefix"},
@@ -477,7 +477,7 @@ var _ = Describe("API test", func() {
// from a foreign origin. BasePathPrefix must reject these via
// SafeForwardedPrefix and fall back to "/".
It("Should ignore an unsafe X-Forwarded-Prefix and not poison asset URLs", func() {
err, sc, body := getRequest(testHTTPBase+"/app", http.Header{
err, sc, body := getRequest("http://127.0.0.1:9090/app", http.Header{
"X-Forwarded-Proto": {"https"},
"X-Forwarded-Host": {"example.org"},
"X-Forwarded-Prefix": {"//evil.com"},
@@ -492,13 +492,13 @@ var _ = Describe("API test", func() {
Context("Applying models", func() {
It("applies models from a gallery", func() {
models, err := getModels(testHTTPBase + "/models/available")
models, err := getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "test@bert2",
})
@@ -507,7 +507,7 @@ var _ = Describe("API test", func() {
uuid := response["uuid"].(string)
resp := map[string]any{}
Eventually(func() bool {
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
resp = response
return response["processed"].(bool)
@@ -526,7 +526,7 @@ var _ = Describe("API test", func() {
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
Expect(content["foo"]).To(Equal("bar"))
models, err = getModels(testHTTPBase + "/models/available")
models, err = getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
@@ -541,7 +541,7 @@ var _ = Describe("API test", func() {
})
It("overrides models", func() {
response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: bertEmbeddingsURL,
Name: "bert",
Overrides: map[string]any{
@@ -554,7 +554,7 @@ var _ = Describe("API test", func() {
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
@@ -567,7 +567,7 @@ var _ = Describe("API test", func() {
Expect(content["backend"]).To(Equal("llama"))
})
It("apply models without overrides", func() {
response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: bertEmbeddingsURL,
Name: "bert",
Overrides: map[string]any{},
@@ -578,7 +578,7 @@ var _ = Describe("API test", func() {
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
@@ -622,14 +622,14 @@ parameters:
}
var response schema.GalleryResponse
err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response)
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
Expect(err).ToNot(HaveOccurred())
Expect(response.ID).ToNot(BeEmpty())
uuid := response.ID
resp := map[string]any{}
Eventually(func() bool {
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
resp = response
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
@@ -657,7 +657,7 @@ parameters:
}
var response schema.GalleryResponse
err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response)
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
// The endpoint should return an error immediately
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("failed to discover model config"))
@@ -693,14 +693,14 @@ parameters:
}
var response schema.GalleryResponse
err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response)
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
Expect(err).ToNot(HaveOccurred())
Expect(response.ID).ToNot(BeEmpty())
uuid := response.ID
resp := map[string]any{}
Eventually(func() bool {
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
resp = response
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
@@ -751,13 +751,13 @@ parameters:
app, err = API(localAIApp)
Expect(err).ToNot(HaveOccurred())
go func() {
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
xlog.Error("server error", "error", err)
}
}()
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = testHTTPBase + "/v1"
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
@@ -801,7 +801,7 @@ parameters:
// Mock-backend is registered via SetExternalBackend so it appears
// alongside any built-in entries; verifying that string proves the
// endpoint is wired up regardless of which real backends exist.
resp, err := http.Get(testHTTPBase + "/system")
resp, err := http.Get("http://127.0.0.1:9090/system")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
dat, err := io.ReadAll(resp.Body)
@@ -824,14 +824,14 @@ parameters:
}
var createResp map[string]any
err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp)
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred())
Expect(createResp["id"]).ToNot(BeEmpty())
taskID := createResp["id"].(string)
// Get the task
var task schema.Task
resp, err := http.Get(testHTTPBase + "/api/agent/tasks/" + taskID)
resp, err := http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, _ := io.ReadAll(resp.Body)
@@ -839,7 +839,7 @@ parameters:
Expect(task.Name).To(Equal("Test Task"))
// List tasks
resp, err = http.Get(testHTTPBase + "/api/agent/tasks")
resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
var tasks []schema.Task
@@ -849,18 +849,18 @@ parameters:
// Update task
taskBody["name"] = "Updated Task"
err = putRequestJSON(testHTTPBase+"/api/agent/tasks/"+taskID, &taskBody)
err = putRequestJSON("http://127.0.0.1:9090/api/agent/tasks/"+taskID, &taskBody)
Expect(err).ToNot(HaveOccurred())
// Verify update
resp, err = http.Get(testHTTPBase + "/api/agent/tasks/" + taskID)
resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID)
Expect(err).ToNot(HaveOccurred())
body, _ = io.ReadAll(resp.Body)
json.Unmarshal(body, &task)
Expect(task.Name).To(Equal("Updated Task"))
// Delete task
req, _ := http.NewRequest("DELETE", testHTTPBase+"/api/agent/tasks/"+taskID, nil)
req, _ := http.NewRequest("DELETE", "http://127.0.0.1:9090/api/agent/tasks/"+taskID, nil)
req.Header.Set("Authorization", bearerKey)
resp, err = http.DefaultClient.Do(req)
Expect(err).ToNot(HaveOccurred())
@@ -877,7 +877,7 @@ parameters:
}
var createResp map[string]any
err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp)
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred())
taskID := createResp["id"].(string)
@@ -888,14 +888,14 @@ parameters:
}
var jobResp schema.JobExecutionResponse
err = postRequestResponseJSON(testHTTPBase+"/api/agent/jobs/execute", &jobBody, &jobResp)
err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/jobs/execute", &jobBody, &jobResp)
Expect(err).ToNot(HaveOccurred())
Expect(jobResp.JobID).ToNot(BeEmpty())
jobID := jobResp.JobID
// Get job status
var job schema.Job
resp, err := http.Get(testHTTPBase + "/api/agent/jobs/" + jobID)
resp, err := http.Get("http://127.0.0.1:9090/api/agent/jobs/" + jobID)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, _ := io.ReadAll(resp.Body)
@@ -904,7 +904,7 @@ parameters:
Expect(job.TaskID).To(Equal(taskID))
// List jobs
resp, err = http.Get(testHTTPBase + "/api/agent/jobs")
resp, err = http.Get("http://127.0.0.1:9090/api/agent/jobs")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
var jobs []schema.Job
@@ -914,7 +914,7 @@ parameters:
// Cancel job (if still pending/running)
if job.Status == schema.JobStatusPending || job.Status == schema.JobStatusRunning {
req, _ := http.NewRequest("POST", testHTTPBase+"/api/agent/jobs/"+jobID+"/cancel", nil)
req, _ := http.NewRequest("POST", "http://127.0.0.1:9090/api/agent/jobs/"+jobID+"/cancel", nil)
req.Header.Set("Authorization", bearerKey)
resp, err = http.DefaultClient.Do(req)
Expect(err).ToNot(HaveOccurred())
@@ -932,13 +932,13 @@ parameters:
}
var createResp map[string]any
err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp)
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred())
// Execute by name
paramsBody := map[string]string{"param1": "value1"}
var jobResp schema.JobExecutionResponse
err = postRequestResponseJSON(testHTTPBase+"/api/agent/tasks/Named Task/execute", &paramsBody, &jobResp)
err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks/Named Task/execute", &paramsBody, &jobResp)
Expect(err).ToNot(HaveOccurred())
Expect(jobResp.JobID).ToNot(BeEmpty())
})
@@ -998,13 +998,13 @@ parameters:
Expect(err).ToNot(HaveOccurred())
go func() {
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
xlog.Error("server error", "error", err)
}
}()
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = testHTTPBase + "/v1"
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
// Wait for API to be ready

View File

@@ -25,10 +25,6 @@ var knownPrefOnlyBackends = []schema.KnownBackend{
// Text LLM
// ds4: antirez/ds4 - single-model DeepSeek V4 Flash engine; auto-detected via DS4Importer
{Name: "ds4", Modality: "text", AutoDetect: false, Description: "antirez/ds4 DeepSeek V4 Flash engine (auto-detected; pref-only fallback)"},
// dllm consumes GGUF weights like llama-cpp does, but only for the
// DiffusionGemma architecture - auto-detecting on .gguf would shadow
// llama-cpp, so it stays preference-only.
{Name: "dllm", Modality: "text", AutoDetect: false, Description: "dllm.cpp DiffusionGemma block-diffusion engine (preference-only)"},
{Name: "sglang", Modality: "text", AutoDetect: false, Description: "SGLang runtime (preference-only)"},
{Name: "tinygrad", Modality: "text", AutoDetect: false, Description: "tinygrad runtime (preference-only)"},
{Name: "trl", Modality: "text", AutoDetect: false, Description: "Transformers Reinforcement Learning (preference-only)"},

View File

@@ -135,7 +135,6 @@ var _ = Describe("Backend Endpoints", func() {
Expect(entry.Modality).To(Equal(modality))
}
expectPrefOnly("dllm", "text")
expectPrefOnly("sglang", "text")
expectPrefOnly("tinygrad", "text")
expectPrefOnly("trl", "text")

View File

@@ -103,12 +103,7 @@ func applyAutoparserOverride(
// blocks like "<think></think>" that some models emit when reasoning
// is disabled.
if deltaReasoning == "" && deltaContent != "" {
// Complete-response extraction: only honor a prefilled <think> start
// token when deltaContent actually closes the reasoning block. Without
// it the model answered directly and the whole answer must stay in
// content rather than be swallowed as unclosed reasoning. See
// reason.ExtractReasoningComplete.
deltaReasoning, deltaContent = reason.ExtractReasoningComplete(deltaContent, thinkingStartToken, reasoningConfig)
deltaReasoning, deltaContent = reason.ExtractReasoningWithConfig(deltaContent, thinkingStartToken, reasoningConfig)
}
xlog.Debug("[ChatDeltas] non-SSE no-tools: overriding result with C++ autoparser deltas",
"content_len", len(deltaContent), "reasoning_len", len(deltaReasoning))

View File

@@ -186,114 +186,6 @@ var _ = Describe("applyAutoparserOverride", func() {
Expect(result).To(Equal(existing))
})
})
// Regression tests for the prefilled-thinking-token path (thinkingStartToken
// != ""). This is the configuration the gallery qwen3 family runs in: the
// chat template injects <think> into the prompt, so DetectThinkingStartToken
// returns "<think>" and the model's output begins *inside* a reasoning block
// — it emits a closing </think> but no opening tag.
//
// The defensive Go-side fallback prepends the start token so the standard
// extractor can pair it with the model's </think>. But on a *complete*
// response that contains NO closing tag (the model answered directly with no
// reasoning at all), prepending <think> manufactures an unclosed block that
// swallows the entire answer into reasoning, leaving content empty. That is
// the bug: short/direct answers (session names, JSON summaries) come back
// with an empty content field.
Context("autoparser delivered content with empty reasoning and a prefilled thinking token", func() {
const startToken = "<think>"
It("keeps a tag-less direct answer as content instead of swallowing it as reasoning", func() {
// Model answered directly: no <think>, no </think> anywhere.
chatDeltas := []*pb.ChatDelta{
{Content: "hello", ReasoningContent: ""},
}
result := applyAutoparserOverride(chatDeltas, startToken, reason.Config{}, nil)
Expect(result).To(HaveLen(1))
Expect(result[0].Message.Content).ToNot(BeNil())
Expect(*(result[0].Message.Content.(*string))).To(Equal("hello"),
"a complete answer with no closing reasoning tag must stay in content")
Expect(result[0].Message.Reasoning).To(BeNil(),
"no reasoning block was emitted, so Reasoning must not be set")
})
It("keeps a tag-less JSON answer as content (the summary case)", func() {
raw := `{"short":"Tests pass","long":"go test ./... succeeded."}`
chatDeltas := []*pb.ChatDelta{
{Content: raw, ReasoningContent: ""},
}
result := applyAutoparserOverride(chatDeltas, startToken, reason.Config{}, nil)
Expect(result).To(HaveLen(1))
Expect(*(result[0].Message.Content.(*string))).To(Equal(raw))
Expect(result[0].Message.Reasoning).To(BeNil())
})
It("still splits reasoning when the model emits the closing tag (prefill paired with </think>)", func() {
// The legitimate prefill case: <think> was in the prompt, so the
// output carries only the closing tag. The closing tag is the proof
// that a reasoning block exists, so extraction must run.
raw := "The user wants a greeting.\n</think>\n\nHello there!"
chatDeltas := []*pb.ChatDelta{
{Content: raw, ReasoningContent: ""},
}
result := applyAutoparserOverride(chatDeltas, startToken, reason.Config{}, nil)
Expect(result).To(HaveLen(1))
content := *(result[0].Message.Content.(*string))
Expect(content).To(ContainSubstring("Hello there!"))
Expect(content).ToNot(ContainSubstring("</think>"))
Expect(content).ToNot(ContainSubstring("The user wants a greeting"))
Expect(result[0].Message.Reasoning).ToNot(BeNil())
Expect(*result[0].Message.Reasoning).To(ContainSubstring("The user wants a greeting"))
})
It("still splits a fully-tagged <think>…</think> block with a prefill token set", func() {
raw := "<think>Reasoning here.</think>Final answer."
chatDeltas := []*pb.ChatDelta{
{Content: raw, ReasoningContent: ""},
}
result := applyAutoparserOverride(chatDeltas, startToken, reason.Config{}, nil)
Expect(result).To(HaveLen(1))
Expect(*(result[0].Message.Content.(*string))).To(Equal("Final answer."))
Expect(result[0].Message.Reasoning).ToNot(BeNil())
Expect(*result[0].Message.Reasoning).To(ContainSubstring("Reasoning here"))
})
// End-to-end regression for the real production failure: a request with
// enable_thinking=false against a <think>-capable model (qwen3 family).
//
// In non-thinking mode the model emits no reasoning block, so llama.cpp's
// autoparser correctly returns ChatDeltas with Content set and
// ReasoningContent EMPTY (verified against stock llama-server: the same
// model with chat_template_kwargs.enable_thinking=false returns
// reasoning_content=null and content="hello"). But thinkingStartToken is
// detected per-model from the enable_thinking=TRUE render
// (grpc-server renders with enable_thinking=true; DetectThinkingStartToken
// does not evaluate the jinja {% if enable_thinking %} conditional), so it
// is "<think>" even for this non-thinking request. The old code prepended
// it and swallowed the answer. This is the case that broke session
// summaries and auto-titles and was NOT covered before.
It("preserves content for a non-thinking-mode request (enable_thinking=false, empty reasoning_content)", func() {
// What llama.cpp's autoparser actually returns in non-thinking mode.
chatDeltas := []*pb.ChatDelta{
{Content: `{"short":"Go tests passed for internal/session"}`, ReasoningContent: ""},
}
result := applyAutoparserOverride(chatDeltas, startToken, reason.Config{}, nil)
Expect(result).To(HaveLen(1))
Expect(*(result[0].Message.Content.(*string))).To(Equal(`{"short":"Go tests passed for internal/session"}`),
"non-thinking-mode answers must reach the client intact, not be swallowed as reasoning")
Expect(result[0].Message.Reasoning).To(BeNil())
})
})
})
var _ = Describe("mergeToolCallDeltas", func() {

View File

@@ -2,10 +2,8 @@ package openai
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"math"
@@ -237,12 +235,6 @@ type Model interface {
Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error)
Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error)
TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error)
// TTSStream synthesizes speech incrementally, invoking onAudio with raw PCM
// chunks (and the backend sample rate) as they are produced.
TTSStream(ctx context.Context, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error
// TranscribeStream transcribes audio incrementally, invoking onDelta for each
// transcript text fragment and returning the final aggregated result.
TranscribeStream(ctx context.Context, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error)
PredictConfig() *config.ModelConfig
}
@@ -1262,15 +1254,27 @@ func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Co
// TODO: If we have a real any-to-any model then transcription is optional
var transcript string
if session.InputAudioTranscription != nil {
// emitTranscription streams transcript deltas when
// pipeline.streaming.transcription is set, otherwise emits a single
// completed event; either way it returns the final transcript text.
var err error
transcript, err = emitTranscription(ctx, t, session, generateItemID(), f.Name())
tr, err := session.ModelInterface.Transcribe(ctx, f.Name(), session.InputAudioTranscription.Language, false, false, session.InputAudioTranscription.Prompt)
if err != nil {
sendError(t, "transcription_failed", err.Error(), "", "event_TODO")
return
} else if tr == nil {
sendError(t, "transcription_failed", "trancribe result is nil", "", "event_TODO")
return
}
transcript = tr.Text
sendEvent(t, types.ConversationItemInputAudioTranscriptionCompletedEvent{
ServerEventBase: types.ServerEventBase{
EventID: "event_TODO",
},
ItemID: generateItemID(),
// ResponseID: "resp_TODO", // Not needed for transcription completed event
// OutputIndex: 0,
ContentIndex: 0,
Transcript: transcript,
})
} else {
sendNotImplemented(t, "any-to-any models")
return
@@ -1498,26 +1502,6 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
},
})
// Streamed LLM path: when the pipeline opts into LLM streaming, stream the
// transcript to the client as it is generated and synthesize the buffered
// message once. Tool turns are supported only when the model uses its
// tokenizer template: the C++ autoparser then delivers content and tool
// calls via ChatDeltas (clearing the text stream), so the spoken transcript
// never leaks tool-call tokens. Grammar-based function calling emits the
// call as JSON in the token stream, so those turns keep the buffered path.
if config != nil && session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamLLM() {
canStream := len(tools) == 0 || config.TemplateConfig.UseTokenizerTemplate
var respMods []types.Modality
if overrides != nil {
respMods = overrides.OutputModalities
}
if canStream && modalitiesContainAudio(resolveOutputModalities(session.OutputModalities, respMods)) {
if streamLLMResponse(ctx, session, conv, t, responseID, conversationHistory, images, config, tools, toolChoice, toolTurn) {
return
}
}
}
predFunc, err := session.ModelInterface.Predict(ctx, conversationHistory, images, nil, nil, nil, tools, toolChoice, nil, nil, nil)
if err != nil {
sendError(t, "inference_failed", fmt.Sprintf("backend error: %v", err), "", "") // item.Assistant.ID is unknown here
@@ -1595,7 +1579,7 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
// ExtractReasoningWithConfig is a no-op when no tag pair matches,
// so it's safe to apply unconditionally in the no-reasoning branch.
if deltaReasoning == "" && deltaContent != "" {
deltaReasoning, deltaContent = reasoning.ExtractReasoningComplete(deltaContent, thinkingStartToken, spokenReasoningConfig(config.ReasoningConfig))
deltaReasoning, deltaContent = reasoning.ExtractReasoningWithConfig(deltaContent, thinkingStartToken, config.ReasoningConfig)
}
reasoningText = deltaReasoning
responseWithoutReasoning = deltaContent
@@ -1603,7 +1587,7 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
cleanedResponse = deltaContent
toolCalls = deltaToolCalls
} else {
reasoningText, responseWithoutReasoning = reasoning.ExtractReasoningComplete(rawResponse, thinkingStartToken, spokenReasoningConfig(config.ReasoningConfig))
reasoningText, responseWithoutReasoning = reasoning.ExtractReasoningWithConfig(rawResponse, thinkingStartToken, config.ReasoningConfig)
textContent = functions.ParseTextContent(responseWithoutReasoning, config.FunctionsConfig)
cleanedResponse = functions.CleanupLLMResult(responseWithoutReasoning, config.FunctionsConfig)
toolCalls = functions.ParseFunctionCall(cleanedResponse, config.FunctionsConfig)
@@ -1729,7 +1713,64 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
return
}
// Transcript of the spoken reply (the audio's text).
audioFilePath, res, err := session.ModelInterface.TTS(ctx, finalSpeech, session.Voice, session.InputAudioTranscription.Language)
if err != nil {
if ctx.Err() != nil {
xlog.Debug("TTS cancelled (barge-in)")
sendCancelledResponse()
return
}
xlog.Error("TTS failed", "error", err)
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID)
return
}
if !res.Success {
xlog.Error("TTS failed", "message", res.Message)
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID)
return
}
defer func() { _ = os.Remove(audioFilePath) }()
audioBytes, err := os.ReadFile(audioFilePath)
if err != nil {
xlog.Error("failed to read TTS file", "error", err)
sendError(t, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID)
return
}
// Parse WAV header to get raw PCM and the actual sample rate from the TTS backend.
pcmData, ttsSampleRate := laudio.ParseWAV(audioBytes)
if ttsSampleRate == 0 {
ttsSampleRate = localSampleRate
}
xlog.Debug("TTS audio parsed", "raw_bytes", len(audioBytes), "pcm_bytes", len(pcmData), "sample_rate", ttsSampleRate)
// SendAudio (WebRTC) passes PCM at the TTS sample rate directly to the
// Opus encoder, which resamples to 48kHz internally. This avoids a
// lossy intermediate resample through 16kHz.
// XXX: This is a noop in websocket mode; it's included in the JSON instead
if err := t.SendAudio(ctx, pcmData, ttsSampleRate); err != nil {
if ctx.Err() != nil {
xlog.Debug("Audio playback cancelled (barge-in)")
sendCancelledResponse()
return
}
xlog.Error("failed to send audio via transport", "error", err)
}
// For WebSocket clients, resample to the session's output rate and
// deliver audio as base64 in JSON events. WebRTC clients already
// received audio over the RTP track, so skip the base64 payload.
if !isWebRTC {
wsPCM := pcmData
if ttsSampleRate != session.OutputSampleRate {
samples := sound.BytesToInt16sLE(pcmData)
resampled := sound.ResampleInt16(samples, ttsSampleRate, session.OutputSampleRate)
wsPCM = sound.Int16toBytesLE(resampled)
}
audioString = base64.StdEncoding.EncodeToString(wsPCM)
}
sendEvent(t, types.ResponseOutputAudioTranscriptDeltaEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
@@ -1747,26 +1788,15 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
Transcript: finalSpeech,
})
// Synthesize and send the audio. With pipeline.streaming.tts enabled
// emitSpeech forwards a response.output_audio.delta per backend PCM
// chunk as it's produced; otherwise it sends the whole utterance as a
// single delta. The returned PCM is stored (base64) on the item below.
pcmAudio, err := emitSpeech(ctx, t, session, responseID, item.Assistant.ID, finalSpeech)
if err != nil {
if ctx.Err() != nil {
xlog.Debug("TTS cancelled (barge-in)")
sendCancelledResponse()
return
}
xlog.Error("TTS failed", "error", err)
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID)
return
}
if !isWebRTC {
audioString = base64.StdEncoding.EncodeToString(pcmAudio)
}
if !isWebRTC {
sendEvent(t, types.ResponseOutputAudioDeltaEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: item.Assistant.ID,
OutputIndex: 0,
ContentIndex: 0,
Delta: audioString,
})
sendEvent(t, types.ResponseOutputAudioDoneEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
@@ -1819,27 +1849,17 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
})
}
// Emit the parsed tool calls, the terminal response.done, and (for
// server-side assistant tools) the follow-up response. Shared with the
// streamed path so both finalize tool calls identically.
emitToolCallItems(ctx, session, conv, t, responseID, finalToolCalls, finalSpeech != "", toolTurn)
}
// emitToolCallItems emits the realtime function_call items for the parsed tool
// calls, the terminal response.done, and — for server-side LocalAI Assistant
// tools — re-triggers a follow-up response so the model can speak the result.
// hasContent shifts the tool-call output index past the assistant content item
// when the same turn also produced spoken/text content. Two tool paths:
// - LocalAI Assistant tools (session.AssistantExecutor.IsTool) run server-side;
// we append both the call and its output to conv.Items and re-trigger. The
// client only sees observability events.
// - All other tools follow the standard OpenAI flow: emit
// function_call_arguments.done and wait for the client to send
// conversation.item.create back.
func emitToolCallItems(ctx context.Context, session *Session, conv *Conversation, t Transport, responseID string, toolCalls []functions.FuncCallResults, hasContent bool, toolTurn int) {
xlog.Debug("About to handle tool calls", "finalToolCallsCount", len(toolCalls))
// Handle Tool Calls. Two paths:
// - LocalAI Assistant tools (session.AssistantExecutor.IsTool) run
// server-side; we append both the call and its output to conv.Items
// and re-trigger a follow-up response so the model can speak the
// result. The client only sees observability events.
// - All other tools follow the standard OpenAI flow: emit
// function_call_arguments.done and wait for the client to send
// conversation.item.create back.
xlog.Debug("About to handle tool calls", "finalToolCallsCount", len(finalToolCalls))
executedAssistantTool := false
for i, tc := range toolCalls {
for i, tc := range finalToolCalls {
toolCallID := generateItemID()
callID := "call_" + generateUniqueID() // OpenAI uses call_xyz
@@ -1859,7 +1879,7 @@ func emitToolCallItems(ctx context.Context, session *Session, conv *Conversation
conv.Lock.Unlock()
outputIndex := i
if hasContent {
if finalSpeech != "" {
outputIndex++
}
@@ -1985,11 +2005,8 @@ func generateItemID() string {
}
func generateUniqueID() string {
// 16 random bytes, hex-encoded. Must be collision-free: session, item,
// response and call IDs build on this, and the conversation tracks/removes
// items by ID (e.g. cancel() in realtime_stream.go, conversation.item.retrieve).
// A constant would make every ID alias and corrupt that bookkeeping.
var b [16]byte
_, _ = rand.Read(b[:])
return hex.EncodeToString(b[:])
// Generate a unique ID string
// For simplicity, use a counter or UUID
// Implement as needed
return "unique_id"
}

View File

@@ -1,200 +0,0 @@
package openai
import (
"strings"
"unicode"
"unicode/utf8"
"github.com/rivo/uniseg"
)
// Default clause-chunker bounds (in runes). minRunes gates only sub-sentence
// (clause-mark / Thai-space) cuts so we don't synthesize tiny choppy fragments;
// full sentences always flush regardless of length. maxRunes caps an
// unterminated run so a long punctuation-less span doesn't buffer unbounded.
const (
defaultClauseMinRunes = 12
defaultClauseMaxRunes = 200
)
// clauseChunker splits streamed LLM content into speakable clauses for
// incremental TTS, in a SCRIPT-AWARE way so it works for languages without
// whitespace word boundaries. It leans on UAX #29 sentence segmentation (which
// natively terminates on CJK 。!? as well as Latin .!?), adds CJK clause
// punctuation (,、;:) and Thai/Lao spaces as finer boundaries, and caps an
// over-long unterminated run via UAX #14 line-break opportunities.
//
// Unlike the old ASCII .!?/newline segmenter (dropped in 076dcdbe), it does not
// degrade to whole-message buffering for CJK (handled natively) or Thai/Lao
// (handled via spaces, which Thai uses at clause/sentence boundaries). Scripts
// that genuinely need a dictionary (Khmer/Myanmar) simply stay buffered until a
// space or end-of-message — no worse than the buffered default.
//
// It is not safe for concurrent use; callers feed it from a single goroutine
// (the LLM token callback).
type clauseChunker struct {
buf strings.Builder
minRunes int
maxRunes int
}
func newClauseChunker(minRunes, maxRunes int) *clauseChunker {
return &clauseChunker{minRunes: minRunes, maxRunes: maxRunes}
}
// push appends streamed content and returns any clauses that are now complete —
// "complete" meaning confirmed by following content, so we never speak a clause
// that the next token might extend. Incomplete trailing text stays buffered.
func (c *clauseChunker) push(text string) []string {
c.buf.WriteString(text)
return c.drain(false)
}
// flush returns the remaining buffered clauses, treating end-of-input as a hard
// boundary, and clears the buffer.
func (c *clauseChunker) flush() []string {
return c.drain(true)
}
func (c *clauseChunker) drain(final bool) []string {
s := c.buf.String()
rest := s
var out []string
for rest != "" {
end, ok := c.nextBoundary(rest, final)
if !ok {
break
}
if seg := strings.TrimSpace(rest[:end]); seg != "" {
out = append(out, seg)
}
rest = rest[end:]
}
// Rewriting the builder reallocates and copies the whole buffer; skip it on
// the common per-token call where no boundary was confirmed.
if len(rest) != len(s) {
c.buf.Reset()
c.buf.WriteString(rest)
}
return out
}
// nextBoundary returns the byte offset just past the first emittable clause in
// s, or ok=false when more input is needed (final=false) and no boundary is
// confirmed yet.
func (c *clauseChunker) nextBoundary(s string, final bool) (int, bool) {
if s == "" {
return 0, false
}
// 1) UAX #29 sentence boundary. When the first sentence is followed by more
// text it is a confirmed complete sentence (handles Latin .!? with
// abbreviation/decimal guards, and CJK 。!? with no whitespace).
sentence, rest, _ := uniseg.FirstSentenceInString(s, -1)
if rest != "" {
// Optionally cut finer inside the sentence at a clause boundary.
if cut, ok := c.firstClauseCut(sentence); ok {
return cut, true
}
return len(sentence), true
}
// 2) Unterminated tail: look for a sub-sentence clause boundary (CJK
// punctuation or a Thai/Lao space) confirmed by following content.
if cut, ok := c.firstClauseCut(s); ok {
return cut, true
}
// 3) Over-long punctuation-less run: force a typographically legal break so
// we don't buffer unbounded (e.g. a long CJK run with no punctuation).
if !final && c.maxRunes > 0 && utf8.RuneCountInString(s) > c.maxRunes {
if cut, ok := lineBreakCut(s, c.maxRunes); ok {
return cut, true
}
}
// 4) End of input: emit whatever remains as the final clause.
if final {
return len(s), true
}
return 0, false
}
// firstClauseCut returns the byte offset just past the first sub-sentence clause
// boundary in s — a CJK clause punctuation mark, or a space following a Thai/Lao
// letter — provided the prefix is at least minRunes long and non-space content
// follows. The boundary mark (and any trailing spaces) stay with the left clause.
func (c *clauseChunker) firstClauseCut(s string) (int, bool) {
var prev rune
runes := 0
for i, r := range s {
boundary := isCJKClausePunct(r) || (unicode.IsSpace(r) && isThaiLao(prev))
if boundary && runes+1 >= c.minRunes {
end := i + utf8.RuneLen(r)
for end < len(s) {
nr, sz := utf8.DecodeRuneInString(s[end:])
if !unicode.IsSpace(nr) {
break
}
end += sz
}
if end < len(s) { // confirmed: real content follows the boundary
return end, true
}
// Boundary sits at the end of the buffer with nothing after it yet —
// wait for the next token to confirm it rather than emit early.
return 0, false
}
prev = r
runes++
}
return 0, false
}
// lineBreakCut walks UAX #14 line segments and returns the byte offset of the
// last legal break opportunity at or before maxRunes. Returns ok=false when the
// run has no internal break opportunity (e.g. a space-less Thai run), leaving it
// buffered.
func lineBreakCut(s string, maxRunes int) (int, bool) {
state := -1
rest := s
consumed := 0
runes := 0
for rest != "" {
seg, rem, _, st := uniseg.FirstLineSegmentInString(rest, state)
state = st
runes += utf8.RuneCountInString(seg)
consumed += len(seg)
rest = rem
if runes >= maxRunes {
if consumed < len(s) {
return consumed, true
}
return 0, false
}
}
return 0, false
}
// isCJKClausePunct reports whether r is a CJK clause-level separator worth a
// soft TTS break. Sentence terminators (。!?) are intentionally excluded — UAX
// #29 sentence segmentation already handles those.
func isCJKClausePunct(r rune) bool {
switch r {
case '', // fullwidth comma
'、', // 、 ideographic comma
'', // fullwidth semicolon
'', // fullwidth colon
'・', // ・ katakana middle dot
'・': // ・ halfwidth katakana middle dot
return true
}
return false
}
// isThaiLao reports whether r is a Thai or Lao letter. Those scripts have no
// inter-word spaces; an ASCII space inside such a run marks a clause/sentence
// boundary, which is the only no-dictionary segmentation signal available.
func isThaiLao(r rune) bool {
return unicode.Is(unicode.Thai, r) || unicode.Is(unicode.Lao, r)
}

View File

@@ -1,103 +0,0 @@
package openai
import (
"strings"
"unicode/utf8"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// clauseChunker splits streamed LLM content into speakable clauses in a
// script-aware way: UAX#29 sentences (Latin .!? and CJK 。!?), CJK clause
// punctuation, and Thai/Lao spaces — never whitespace-splitting CJK.
var _ = Describe("clauseChunker", func() {
Context("Latin sentences", func() {
It("emits a sentence only once following content confirms it is complete", func() {
c := newClauseChunker(12, 200)
Expect(c.push("Hello world. How are you?")).To(Equal([]string{"Hello world."}))
// The trailing sentence is held until flush (the next token might extend it).
Expect(c.flush()).To(Equal([]string{"How are you?"}))
})
It("assembles a sentence across many small tokens", func() {
c := newClauseChunker(12, 200)
var got []string
for _, tok := range []string{"Hello", " world.", " How", " are", " you?"} {
got = append(got, c.push(tok)...)
}
got = append(got, c.flush()...)
Expect(got).To(Equal([]string{"Hello world.", "How are you?"}))
})
It("does not split decimals or abbreviations (UAX#29 SB6)", func() {
c := newClauseChunker(12, 200)
got := c.push("Pi is 3.14 and e is 2.72. Done")
Expect(got).To(Equal([]string{"Pi is 3.14 and e is 2.72."}))
Expect(c.flush()).To(Equal([]string{"Done"}))
})
})
Context("CJK (no whitespace)", func() {
It("splits Chinese on the ideographic full stop", func() {
c := newClauseChunker(12, 200)
Expect(c.push("你好世界。今天天气很好。")).To(Equal([]string{"你好世界。"}))
Expect(c.flush()).To(Equal([]string{"今天天气很好。"}))
})
It("splits Japanese on the ideographic full stop", func() {
c := newClauseChunker(12, 200)
Expect(c.push("こんにちは。元気ですか。")).To(Equal([]string{"こんにちは。"}))
Expect(c.flush()).To(Equal([]string{"元気ですか。"}))
})
It("splits on CJK clause punctuation for lower latency", func() {
c := newClauseChunker(2, 200) // small min so short test clauses cut
Expect(c.push("你好,世界。再见")).To(Equal([]string{"你好,", "世界。"}))
Expect(c.flush()).To(Equal([]string{"再见"}))
})
})
Context("Thai (spaces mark clauses, not words)", func() {
It("splits a Thai run on the inter-clause space", func() {
c := newClauseChunker(2, 200)
Expect(c.push("สวัสดีครับ กินข้าวไหม")).To(Equal([]string{"สวัสดีครับ"}))
Expect(c.flush()).To(Equal([]string{"กินข้าวไหม"}))
})
It("never shatters a space-less Thai run into characters", func() {
c := newClauseChunker(2, 200)
Expect(c.push("สวัสดีครับ")).To(BeEmpty()) // held, no boundary
Expect(c.flush()).To(Equal([]string{"สวัสดีครับ"}))
})
})
Context("length cap (UAX#14 fallback)", func() {
It("force-breaks an over-long punctuation-less CJK run at legal points", func() {
c := newClauseChunker(4, 10) // maxRunes = 10
run := strings.Repeat("字", 25)
got := c.push(run)
got = append(got, c.flush()...)
total := 0
for _, seg := range got {
n := utf8.RuneCountInString(seg)
Expect(n).To(BeNumerically("<=", 10)) // never exceeds the cap
total += n
}
Expect(total).To(Equal(25)) // nothing dropped
Expect(len(got)).To(BeNumerically(">=", 3)) // 10 + 10 + 5
})
})
Context("buffer lifecycle", func() {
It("flush clears the buffer so the chunker is reusable", func() {
c := newClauseChunker(12, 200)
// "First one." is confirmed by the following "Second", so push drains it;
// only the unterminated tail remains for flush.
Expect(c.push("First one. Second")).To(Equal([]string{"First one."}))
Expect(c.flush()).To(Equal([]string{"Second"}))
Expect(c.flush()).To(BeEmpty())
Expect(c.push("Again. More")).To(Equal([]string{"Again."}))
})
})
})

View File

@@ -1,138 +0,0 @@
package openai
import (
"context"
"strings"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
)
// fakeTransport records the server events and audio sent to a realtime client
// so streaming behaviour can be asserted without a real WebSocket/WebRTC peer.
// It is not a *WebRTCTransport, so handler code takes the WebSocket path.
type fakeTransport struct {
events []types.ServerEvent
audio []fakeAudioChunk
}
type fakeAudioChunk struct {
pcm []byte
sampleRate int
}
func (f *fakeTransport) SendEvent(e types.ServerEvent) error {
f.events = append(f.events, e)
return nil
}
func (f *fakeTransport) ReadEvent() ([]byte, error) { return nil, nil }
func (f *fakeTransport) SendAudio(_ context.Context, pcm []byte, sampleRate int) error {
f.audio = append(f.audio, fakeAudioChunk{pcm: pcm, sampleRate: sampleRate})
return nil
}
func (f *fakeTransport) Close() error { return nil }
// countEvents returns how many recorded events have the given type.
func (f *fakeTransport) countEvents(et types.ServerEventType) int {
n := 0
for _, e := range f.events {
if e.ServerEventType() == et {
n++
}
}
return n
}
// transcriptDeltaText concatenates the Delta of every recorded transcript
// delta event — i.e. the text streamed to the client as it is generated.
func (f *fakeTransport) transcriptDeltaText() string {
var b strings.Builder
for _, e := range f.events {
if d, ok := e.(types.ResponseOutputAudioTranscriptDeltaEvent); ok {
b.WriteString(d.Delta)
}
}
return b.String()
}
// fakeModel is a configurable Model double. TTSStream replays ttsStreamChunks
// and TranscribeStream replays transcribeDeltas, so the handler's streaming
// paths can be driven deterministically.
type fakeModel struct {
cfg *config.ModelConfig
ttsFile string
ttsStreamChunks [][]byte
ttsStreamRate int
ttsStreamErr error
transcribeDeltas []string
transcribeFinal *schema.TranscriptionResult
// Predict streaming: predictTokens are replayed through the token callback
// (simulating streamed LLM output); predictResp/predictErr are returned by
// the deferred predict function. predictChunkDeltas, when set, are delivered
// per-token via TokenUsage.ChatDeltas to exercise the autoparser path.
predictTokens []string
predictChunkDeltas [][]*proto.ChatDelta
predictResp backend.LLMResponse
predictErr error
}
func (m *fakeModel) VAD(context.Context, *schema.VADRequest) (*schema.VADResponse, error) {
return nil, nil
}
func (m *fakeModel) Transcribe(context.Context, string, string, bool, bool, string) (*schema.TranscriptionResult, error) {
return m.transcribeFinal, nil
}
func (m *fakeModel) Predict(_ context.Context, _ schema.Messages, _, _, _ []string, cb func(string, backend.TokenUsage) bool, _ []types.ToolUnion, _ *types.ToolChoiceUnion, _, _ *int, _ map[string]float64) (func() (backend.LLMResponse, error), error) {
if m.predictErr != nil {
return nil, m.predictErr
}
return func() (backend.LLMResponse, error) {
for i, tok := range m.predictTokens {
if cb == nil {
continue
}
usage := backend.TokenUsage{}
if i < len(m.predictChunkDeltas) {
usage.ChatDeltas = m.predictChunkDeltas[i]
}
cb(tok, usage)
}
return m.predictResp, nil
}, nil
}
func (m *fakeModel) TTS(context.Context, string, string, string) (string, *proto.Result, error) {
return m.ttsFile, &proto.Result{Success: true}, nil
}
func (m *fakeModel) TTSStream(_ context.Context, _, _, _ string, onAudio func(pcm []byte, sampleRate int) error) error {
if m.ttsStreamErr != nil {
return m.ttsStreamErr
}
for _, c := range m.ttsStreamChunks {
if err := onAudio(c, m.ttsStreamRate); err != nil {
return err
}
}
return nil
}
func (m *fakeModel) TranscribeStream(_ context.Context, _, _ string, _, _ bool, _ string, onDelta func(text string)) (*schema.TranscriptionResult, error) {
for _, d := range m.transcribeDeltas {
onDelta(d)
}
return m.transcribeFinal, nil
}
func (m *fakeModel) PredictConfig() *config.ModelConfig { return m.cfg }

View File

@@ -3,7 +3,6 @@ package openai
import (
"context"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
@@ -88,14 +87,6 @@ func (m *transcriptOnlyModel) TTS(ctx context.Context, text, voice, language str
return "", nil, fmt.Errorf("TTS not supported in transcript-only mode")
}
func (m *transcriptOnlyModel) TTSStream(ctx context.Context, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error {
return fmt.Errorf("TTS not supported in transcript-only mode")
}
func (m *transcriptOnlyModel) TranscribeStream(ctx context.Context, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error) {
return transcribeStream(ctx, m.modelLoader, *m.TranscriptionConfig, m.appConfig, audio, language, translate, diarize, prompt, onDelta)
}
func (m *transcriptOnlyModel) PredictConfig() *config.ModelConfig {
return nil
}
@@ -330,75 +321,10 @@ func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (s
return backend.ModelTTS(ctx, text, voice, language, "", nil, m.modelLoader, m.appConfig, *m.TTSConfig)
}
func (m *wrappedModel) TTSStream(ctx context.Context, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error {
return ttsStream(ctx, m.modelLoader, m.appConfig, *m.TTSConfig, text, voice, language, onAudio)
}
func (m *wrappedModel) TranscribeStream(ctx context.Context, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error) {
return transcribeStream(ctx, m.modelLoader, *m.TranscriptionConfig, m.appConfig, audio, language, translate, diarize, prompt, onDelta)
}
func (m *wrappedModel) PredictConfig() *config.ModelConfig {
return m.LLMConfig
}
// wavStreamHeaderBytes is the size of the WAV header that backend.ModelTTSStream
// emits as its first audio callback; the sample rate lives at byte offset 24.
const wavStreamHeaderBytes = 44
// ttsStream adapts backend.ModelTTSStream (which emits a WAV stream: a 44-byte
// header carrying the sample rate, then raw PCM) to the realtime onAudio
// callback, which wants raw PCM plus the sample rate. The header is buffered
// until complete, the sample rate is read from it, and subsequent bytes are
// forwarded as PCM.
func ttsStream(ctx context.Context, ml *model.ModelLoader, appConfig *config.ApplicationConfig, ttsConfig config.ModelConfig, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error {
var header []byte
headerDone := false
sampleRate := 0
return backend.ModelTTSStream(ctx, text, voice, language, "", nil, ml, appConfig, ttsConfig, func(b []byte) error {
if headerDone {
if len(b) == 0 {
return nil
}
return onAudio(b, sampleRate)
}
header = append(header, b...)
if len(header) < wavStreamHeaderBytes {
return nil
}
sampleRate = int(binary.LittleEndian.Uint32(header[24:28]))
headerDone = true
if len(header) > wavStreamHeaderBytes {
return onAudio(header[wavStreamHeaderBytes:], sampleRate)
}
return nil
})
}
// transcribeStream adapts backend.ModelTranscriptionStream to the realtime
// onDelta callback, returning the final aggregated transcription result.
func transcribeStream(ctx context.Context, ml *model.ModelLoader, transcriptionConfig config.ModelConfig, appConfig *config.ApplicationConfig, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error) {
var final *schema.TranscriptionResult
err := backend.ModelTranscriptionStream(ctx, backend.TranscriptionRequest{
Audio: audio,
Language: language,
Translate: translate,
Diarize: diarize,
Prompt: prompt,
}, ml, transcriptionConfig, appConfig, func(chunk backend.TranscriptionStreamChunk) {
if chunk.Delta != "" {
onDelta(chunk.Delta)
}
if chunk.Final != nil {
final = chunk.Final
}
})
if err != nil {
return nil, err
}
return final, nil
}
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) {
cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
if err != nil {
@@ -528,10 +454,8 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
return nil, fmt.Errorf("failed to validate config: %w", err)
}
// Let the pipeline set the LLM's reasoning effort and force thinking off
// (cfgLLM is a per-session copy). disable_thinking applies after the effort.
// Let the pipeline set the LLM's reasoning effort (cfgLLM is a per-session copy).
applyPipelineReasoning(cfgLLM, *pipeline)
applyPipelineThinking(cfgLLM, *pipeline)
cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath)
if err != nil {

View File

@@ -1,102 +0,0 @@
package openai
import (
"context"
"encoding/base64"
"fmt"
"os"
"path/filepath"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
laudio "github.com/mudler/LocalAI/pkg/audio"
"github.com/mudler/LocalAI/pkg/sound"
)
// emitSpeech synthesizes text and sends the audio to the client. When the
// pipeline opts into TTS streaming it forwards each PCM chunk as its own
// response.output_audio.delta as soon as the backend produces it; otherwise it
// synthesizes the whole utterance and sends it as a single delta.
//
// It deliberately does NOT emit transcript or audio-done events: the caller owns
// those so a streamed reply can be split into several spoken segments that share
// one response/item.
//
// It returns the PCM audio (at the session output rate) accumulated across all
// chunks, which the caller base64-encodes onto the conversation item. For WebRTC
// the audio goes over the RTP track instead, so the returned slice is empty.
func emitSpeech(ctx context.Context, t Transport, session *Session, responseID, itemID, text string) ([]byte, error) {
if text == "" {
return nil, nil
}
_, isWebRTC := t.(*WebRTCTransport)
var wsAudio []byte // PCM at the session output rate, accumulated for the item record
// sendChunk hands one PCM buffer to the transport: WebRTC consumes the raw
// PCM directly (it resamples internally); WebSocket gets base64 PCM at the
// session output rate via a JSON delta event.
sendChunk := func(pcm []byte, sampleRate int) error {
if len(pcm) == 0 {
return nil
}
if err := t.SendAudio(ctx, pcm, sampleRate); err != nil {
return err
}
if isWebRTC {
return nil
}
wsPCM := pcm
if sampleRate != 0 && sampleRate != session.OutputSampleRate {
samples := sound.BytesToInt16sLE(pcm)
resampled := sound.ResampleInt16(samples, sampleRate, session.OutputSampleRate)
wsPCM = sound.Int16toBytesLE(resampled)
}
wsAudio = append(wsAudio, wsPCM...)
return t.SendEvent(types.ResponseOutputAudioDeltaEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: itemID,
OutputIndex: 0,
ContentIndex: 0,
Delta: base64.StdEncoding.EncodeToString(wsPCM),
})
}
language := ""
if session.InputAudioTranscription != nil {
language = session.InputAudioTranscription.Language
}
if session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamTTS() {
if err := session.ModelInterface.TTSStream(ctx, text, session.Voice, language, sendChunk); err != nil {
return nil, err
}
return wsAudio, nil
}
// Unary fallback: synthesize the whole utterance to a file, then emit once.
audioFilePath, res, err := session.ModelInterface.TTS(ctx, text, session.Voice, language)
if err != nil {
return nil, err
}
if res != nil && !res.Success {
return nil, fmt.Errorf("tts generation failed: %s", res.Message)
}
defer func() { _ = os.Remove(audioFilePath) }()
// filepath.Clean normalizes the backend-produced temp path before reading
// (also keeps gosec G304 quiet — the path is backend-controlled, not user input).
audioBytes, err := os.ReadFile(filepath.Clean(audioFilePath))
if err != nil {
return nil, fmt.Errorf("read tts audio: %w", err)
}
pcm, sampleRate := laudio.ParseWAV(audioBytes)
if sampleRate == 0 {
sampleRate = session.OutputSampleRate
}
if err := sendChunk(pcm, sampleRate); err != nil {
return nil, err
}
return wsAudio, nil
}

View File

@@ -1,70 +0,0 @@
package openai
import (
"context"
"os"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
laudio "github.com/mudler/LocalAI/pkg/audio"
)
// emitSpeech synthesizes a piece of text and forwards the audio to the client,
// streaming a delta per TTS chunk when the pipeline opts in, or sending the
// whole utterance as one delta otherwise.
var _ = Describe("emitSpeech", func() {
ttsOn := true
streamingSession := func(m Model) *Session {
return &Session{
OutputSampleRate: 24000,
ModelInterface: m,
ModelConfig: &config.ModelConfig{
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{TTS: &ttsOn}},
},
}
}
It("streams one output_audio.delta per TTS chunk when streaming is enabled", func() {
m := &fakeModel{
ttsStreamChunks: [][]byte{{1, 2}, {3, 4}, {5, 6}},
ttsStreamRate: 24000,
}
t := &fakeTransport{}
audio, err := emitSpeech(context.Background(), t, streamingSession(m), "resp1", "item1", "Hello there.")
Expect(err).ToNot(HaveOccurred())
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(3))
// The returned audio is all chunks concatenated (session output rate).
Expect(audio).To(Equal([]byte{1, 2, 3, 4, 5, 6}))
})
It("sends a single output_audio.delta in unary mode", func() {
// A minimal real WAV file for the unary TTS path to read + parse.
f, err := os.CreateTemp("", "emit-*.wav")
Expect(err).ToNot(HaveOccurred())
defer func() { _ = os.Remove(f.Name()) }()
pcm := make([]byte, 320) // 160 samples of silence
hdr := laudio.NewWAVHeader(uint32(len(pcm)))
Expect(hdr.Write(f)).To(Succeed())
_, err = f.Write(pcm)
Expect(err).ToNot(HaveOccurred())
Expect(f.Close()).To(Succeed())
session := &Session{
OutputSampleRate: 24000,
ModelInterface: &fakeModel{ttsFile: f.Name()},
ModelConfig: &config.ModelConfig{}, // streaming off
}
t := &fakeTransport{}
_, err = emitSpeech(context.Background(), t, session, "resp1", "item1", "Hello there.")
Expect(err).ToNot(HaveOccurred())
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(1))
})
})

View File

@@ -1,315 +0,0 @@
package openai
import (
"context"
"encoding/base64"
"fmt"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/reasoning"
)
// transcriptStreamer turns streamed LLM tokens into the assistant's spoken
// transcript: it strips reasoning incrementally and sends one
// response.output_audio_transcript.delta per content fragment. It does NOT
// synthesize audio — the caller buffers the full message and synthesizes it
// once (streaming the audio chunks when the TTS backend supports TTSStream),
// which works uniformly for streaming and non-streaming TTS and for languages
// without sentence or word boundaries.
type transcriptStreamer struct {
ctx context.Context
t Transport
responseID string
itemID string
extractor *reasoning.ReasoningExtractor
// announce, if set, is invoked once just before the first transcript delta.
// It lets the caller create the assistant item lazily, so a content-less
// tool-call turn never emits a spurious empty assistant item.
announce func()
announced bool
}
func newTranscriptStreamer(ctx context.Context, t Transport, responseID, itemID, thinkingStartToken string, reasoningCfg reasoning.Config) *transcriptStreamer {
return &transcriptStreamer{
ctx: ctx,
t: t,
responseID: responseID,
itemID: itemID,
extractor: reasoning.NewReasoningExtractor(thinkingStartToken, spokenReasoningConfig(reasoningCfg)),
}
}
// onToken handles one streamed unit of model output, sending a transcript delta
// for the new content (reasoning stripped) and returning that content delta so
// the caller can also feed it to the clause chunker. For plain-content models
// the unit is the raw text token; for autoparser tool turns the backend clears
// the text and delivers content via ChatDeltas, so the caller passes that
// content here. Returns "" when the token produced no new spoken content.
func (s *transcriptStreamer) onToken(token string) string {
_, content := s.extractor.ProcessToken(token)
if content == "" {
return ""
}
if !s.announced {
s.announced = true
if s.announce != nil {
s.announce()
}
}
_ = s.t.SendEvent(types.ResponseOutputAudioTranscriptDeltaEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: s.responseID,
ItemID: s.itemID,
OutputIndex: 0,
ContentIndex: 0,
Delta: content,
})
return content
}
// content returns the full transcript so far with reasoning stripped.
func (s *transcriptStreamer) content() string {
return s.extractor.CleanedContent()
}
// streamLLMResponse drives a streamed realtime reply. It streams the assistant
// transcript as the LLM generates, then synthesizes the whole buffered message
// once (streaming the audio chunks when the TTS backend supports it, otherwise a
// single unary delta). Tool calls parsed from the autoparser ChatDeltas are
// emitted after the spoken content. The assistant content item is created lazily
// on the first content delta, so a content-less tool-call turn emits only the
// tool calls. It returns true when it has fully handled the response so the
// caller can return; callers must only invoke it for an audio modality, and with
// tools only when the model uses its tokenizer template (see triggerResponseAtTurn).
func streamLLMResponse(ctx context.Context, session *Session, conv *Conversation, t Transport, responseID string, history schema.Messages, images []string, llmCfg *config.ModelConfig, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, toolTurn int) bool {
itemID := generateItemID()
item := types.MessageItemUnion{
Assistant: &types.MessageItemAssistant{
ID: itemID,
Status: types.ItemStatusInProgress,
Content: []types.MessageContentOutput{{Type: types.MessageContentTypeOutputAudio}},
},
}
// announce creates the assistant content item lazily, just before the first
// transcript delta — a tool-only turn never produces content, so it stays out
// of the conversation and the client sees only the tool calls.
announced := false
announce := func() {
announced = true
conv.Lock.Lock()
conv.Items = append(conv.Items, &item)
conv.Lock.Unlock()
sendEvent(t, types.ResponseOutputItemAddedEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
OutputIndex: 0,
Item: item,
})
sendEvent(t, types.ResponseContentPartAddedEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: itemID,
OutputIndex: 0,
ContentIndex: 0,
Part: item.Assistant.Content[0],
})
}
cancel := func() {
if announced {
conv.Lock.Lock()
for i := len(conv.Items) - 1; i >= 0; i-- {
if conv.Items[i].Assistant != nil && conv.Items[i].Assistant.ID == itemID {
conv.Items = append(conv.Items[:i], conv.Items[i+1:]...)
break
}
}
conv.Lock.Unlock()
}
sendEvent(t, types.ResponseDoneEvent{
ServerEventBase: types.ServerEventBase{},
Response: types.Response{ID: responseID, Object: "realtime.response", Status: types.ResponseStatusCancelled},
})
}
var template string
if llmCfg.TemplateConfig.UseTokenizerTemplate {
template = llmCfg.GetModelTemplate()
} else {
template = llmCfg.TemplateConfig.Chat
}
thinkingStartToken := reasoning.DetectThinkingStartToken(template, &llmCfg.ReasoningConfig)
// The autoparser (tokenizer-template path) already delivers reasoning-free
// content. Prefilling the thinking start token here would re-tag that clean
// content as an unclosed reasoning block, leaving CleanedContent() empty —
// no spoken reply, no TTS. Disable the prefill; closed tag pairs are still
// stripped (PEG-fallback case, #9985).
reasoningCfg := llmCfg.ReasoningConfig
if llmCfg.TemplateConfig.UseTokenizerTemplate {
disablePrefill := true
reasoningCfg.DisableReasoningTagPrefill = &disablePrefill
}
streamer := newTranscriptStreamer(ctx, t, responseID, itemID, thinkingStartToken, reasoningCfg)
streamer.announce = announce
// Clause chunking (opt-in): synthesize each clause as soon as it completes
// instead of buffering the whole reply. streamedAudio accumulates the PCM
// across clauses for the conversation item record; ttsErr captures the first
// synthesis failure so the token callback can stop the prediction. emitSpeech
// runs synchronously here — the LLM keeps generating into the gRPC stream
// while a clause is synthesized, so audio still starts mid-generation.
var chunker *clauseChunker
if session.ModelConfig != nil && session.ModelConfig.Pipeline.ChunkClauses() {
chunker = newClauseChunker(defaultClauseMinRunes, defaultClauseMaxRunes)
}
var streamedAudio []byte
var ttsErr error
speakClause := func(clause string) error {
a, err := emitSpeech(ctx, t, session, responseID, itemID, clause)
if err != nil {
return err
}
streamedAudio = append(streamedAudio, a...)
return nil
}
// fail reports a mid-stream failure. A cancelled context means the client
// interrupted (barge-in), so roll the turn back instead of erroring.
fail := func(code, msg string, err error) bool {
if ctx.Err() != nil {
cancel()
} else {
sendError(t, code, fmt.Sprintf("%s: %v", msg, err), "", itemID)
}
return true
}
cb := func(token string, usage backend.TokenUsage) bool {
if ctx.Err() != nil {
return false
}
// Plain-content models stream text via the token; autoparser tool turns
// clear the text and deliver content via ChatDeltas, so prefer the latter
// when present. Either way only content reaches the transcript — tool-call
// deltas are parsed from the final response below.
text := token
if len(usage.ChatDeltas) > 0 {
text = functions.ContentFromChatDeltas(usage.ChatDeltas)
}
delta := streamer.onToken(text)
if chunker != nil && delta != "" {
for _, clause := range chunker.push(delta) {
if ttsErr = speakClause(clause); ttsErr != nil {
return false // stop the prediction; reported after predFunc returns
}
}
}
return true
}
predFunc, err := session.ModelInterface.Predict(ctx, history, images, nil, nil, cb, tools, toolChoice, nil, nil, nil)
if err != nil {
sendError(t, "inference_failed", fmt.Sprintf("backend error: %v", err), "", itemID)
return true
}
pred, err := predFunc()
// A clause synthesis failed mid-stream (the callback stopped the prediction);
// report it as a TTS error rather than a prediction error.
if ttsErr != nil {
return fail("tts_error", "TTS generation failed", ttsErr)
}
if err != nil {
return fail("prediction_failed", "backend error", err)
}
if ctx.Err() != nil {
cancel()
return true
}
content := streamer.content()
toolCalls := functions.ToolCallsFromChatDeltas(pred.ChatDeltas)
// Finalize the spoken content item only when the turn produced content. A
// tool-only turn skips this entirely (no empty assistant item).
if content != "" {
if !announced {
announce()
}
// Synthesize the audio. With clause chunking the completed clauses were
// already spoken inside the token callback; flush the trailing clause(s)
// the segmenter was still holding. Otherwise buffer the whole message and
// synthesize it once. emitSpeech streams the audio chunks when the TTS
// backend supports TTSStream, otherwise it sends a single unary delta.
var audio []byte
if chunker != nil {
for _, clause := range chunker.flush() {
if ttsErr = speakClause(clause); ttsErr != nil {
break
}
}
audio = streamedAudio
} else {
audio, ttsErr = emitSpeech(ctx, t, session, responseID, itemID, content)
}
if ttsErr != nil {
return fail("tts_error", "TTS generation failed", ttsErr)
}
_, isWebRTC := t.(*WebRTCTransport)
sendEvent(t, types.ResponseOutputAudioTranscriptDoneEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: itemID,
OutputIndex: 0,
ContentIndex: 0,
Transcript: content,
})
if !isWebRTC {
sendEvent(t, types.ResponseOutputAudioDoneEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: itemID,
OutputIndex: 0,
ContentIndex: 0,
})
}
conv.Lock.Lock()
item.Assistant.Status = types.ItemStatusCompleted
item.Assistant.Content[0].Transcript = content
if !isWebRTC {
item.Assistant.Content[0].Audio = base64.StdEncoding.EncodeToString(audio)
}
conv.Lock.Unlock()
sendEvent(t, types.ResponseContentPartDoneEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: itemID,
OutputIndex: 0,
ContentIndex: 0,
Part: item.Assistant.Content[0],
})
sendEvent(t, types.ResponseOutputItemDoneEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
OutputIndex: 0,
Item: item,
})
}
// Emit any tool calls, the terminal response.done, and (for server-side
// assistant tools) the follow-up turn — shared with the buffered path.
emitToolCallItems(ctx, session, conv, t, responseID, toolCalls, content != "", toolTurn)
return true
}

View File

@@ -1,213 +0,0 @@
package openai
import (
"context"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/reasoning"
)
// transcriptStreamer turns streamed LLM tokens into incremental transcript
// deltas, stripping reasoning. Audio is synthesized once from the full message
// by the caller, so there is no per-sentence segmentation.
var _ = Describe("transcriptStreamer", func() {
It("emits one transcript delta per content token", func() {
t := &fakeTransport{}
s := newTranscriptStreamer(context.Background(), t, "resp1", "item1", "", reasoning.Config{})
for _, tok := range []string{"Hello", " world.", " Bye"} {
s.onToken(tok)
}
Expect(s.content()).To(Equal("Hello world. Bye"))
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioTranscriptDelta)).To(Equal(3))
Expect(t.transcriptDeltaText()).To(Equal("Hello world. Bye"))
})
It("strips leaked reasoning even when reasoning is disabled (disable_thinking safety net)", func() {
// disable_thinking maps to DisableReasoning=true (enable_thinking=false to
// the backend). If the model emits thinking anyway, the transcript must
// still not leak it: stripping always runs for spoken output.
disable := true
t := &fakeTransport{}
s := newTranscriptStreamer(context.Background(), t, "resp1", "item1", "",
reasoning.Config{DisableReasoning: &disable})
s.onToken("<think>secret plan</think>")
s.onToken("The answer is 42.")
Expect(s.content()).To(Equal("The answer is 42."))
Expect(s.content()).ToNot(ContainSubstring("secret plan"))
Expect(t.transcriptDeltaText()).ToNot(ContainSubstring("secret plan"))
})
It("does not swallow autoparser content when the template has a thinking start token (tokenizer-template path)", func() {
// Regression: with tag prefill on, the detected <think> token is
// prepended to the autoparser's already-clean content, swallowing the
// whole reply (empty transcript → no TTS). streamLLMResponse disables
// the prefill for the tokenizer-template path.
disablePrefill := true
t := &fakeTransport{}
s := newTranscriptStreamer(context.Background(), t, "resp1", "item1", "<think>",
reasoning.Config{DisableReasoningTagPrefill: &disablePrefill})
s.onToken("Hello")
s.onToken(" there.")
Expect(s.content()).To(Equal("Hello there."))
Expect(t.transcriptDeltaText()).To(Equal("Hello there."))
})
It("still strips embedded closed reasoning tags with prefill disabled (PEG-fallback safety, #9985)", func() {
// Disabling prefill must not stop stripping closed <think>…</think>
// pairs the PEG fallback can leave in autoparser content.
disablePrefill := true
t := &fakeTransport{}
s := newTranscriptStreamer(context.Background(), t, "resp1", "item1", "<think>",
reasoning.Config{DisableReasoningTagPrefill: &disablePrefill})
s.onToken("<think>secret</think>")
s.onToken("The answer is 42.")
Expect(s.content()).To(Equal("The answer is 42."))
Expect(t.transcriptDeltaText()).ToNot(ContainSubstring("secret"))
})
})
// streamLLMResponse drives a full streamed realtime turn: live transcript
// deltas while the LLM generates, then the whole message is synthesized once.
var _ = Describe("streamLLMResponse", func() {
It("streams transcript deltas then synthesizes the whole message once", func() {
on := true
m := &fakeModel{
predictTokens: []string{"Hello", " world.", " How are you?"},
predictResp: backend.LLMResponse{Response: "Hello world. How are you?"},
ttsStreamChunks: [][]byte{{9}},
ttsStreamRate: 24000,
}
session := &Session{
OutputSampleRate: 24000,
ModelInterface: m,
ModelConfig: &config.ModelConfig{
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{LLM: &on, TTS: &on}},
},
}
conv := &Conversation{}
t := &fakeTransport{}
llmCfg := &config.ModelConfig{}
handled := streamLLMResponse(context.Background(), session, conv, t, "resp1", nil, nil, llmCfg, nil, nil, 0)
Expect(handled).To(BeTrue())
// One live transcript delta per streamed token.
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioTranscriptDelta)).To(Equal(3))
// The whole message is synthesized ONCE (not per sentence): a single
// emitSpeech replays the one TTS stream chunk.
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(1))
Expect(t.transcriptDeltaText()).To(Equal("Hello world. How are you?"))
})
It("synthesizes each clause as it completes when clause chunking is enabled", func() {
on := true
m := &fakeModel{
predictTokens: []string{"Hello world.", " How are you?"},
predictResp: backend.LLMResponse{Response: "Hello world. How are you?"},
ttsStreamChunks: [][]byte{{9}},
ttsStreamRate: 24000,
}
session := &Session{
OutputSampleRate: 24000,
ModelInterface: m,
ModelConfig: &config.ModelConfig{
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{LLM: &on, TTS: &on, ClauseChunking: &on}},
},
}
conv := &Conversation{}
t := &fakeTransport{}
llmCfg := &config.ModelConfig{}
handled := streamLLMResponse(context.Background(), session, conv, t, "resp1", nil, nil, llmCfg, nil, nil, 0)
Expect(handled).To(BeTrue())
// Two clauses ("Hello world." mid-stream, "How are you?" on flush) → two
// emitSpeech calls → two audio deltas, vs one for whole-message buffering.
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(2))
// The full transcript still streams verbatim.
Expect(t.transcriptDeltaText()).To(Equal("Hello world. How are you?"))
// Exactly one terminal response.done.
Expect(t.countEvents(types.ServerEventTypeResponseDone)).To(Equal(1))
})
It("streams content deltas and emits tool-call items (autoparser tool turn)", func() {
on := true
// Autoparser path: reply.Message is empty; content + tool calls arrive via
// ChatDeltas. Chunk 1 carries content, chunk 2 carries the tool call.
contentDelta := []*proto.ChatDelta{{Content: "Let me check."}}
toolDelta := []*proto.ChatDelta{{ToolCalls: []*proto.ToolCallDelta{{Index: 0, Name: "get_weather", Arguments: `{"city":"Paris"}`}}}}
m := &fakeModel{
predictTokens: []string{"", ""},
predictChunkDeltas: [][]*proto.ChatDelta{contentDelta, toolDelta},
predictResp: backend.LLMResponse{ChatDeltas: append(append([]*proto.ChatDelta{}, contentDelta...), toolDelta...)},
ttsStreamChunks: [][]byte{{9}},
ttsStreamRate: 24000,
}
session := &Session{
OutputSampleRate: 24000,
ModelInterface: m,
ModelConfig: &config.ModelConfig{
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{LLM: &on, TTS: &on}},
},
}
conv := &Conversation{}
t := &fakeTransport{}
llmCfg := &config.ModelConfig{}
llmCfg.TemplateConfig.UseTokenizerTemplate = true
handled := streamLLMResponse(context.Background(), session, conv, t, "resp1", nil, nil, llmCfg, nil, nil, 0)
Expect(handled).To(BeTrue())
// The spoken content was streamed live.
Expect(t.transcriptDeltaText()).To(Equal("Let me check."))
// The tool call is emitted as a function_call item.
Expect(t.countEvents(types.ServerEventTypeResponseFunctionCallArgumentsDone)).To(Equal(1))
// Exactly one terminal response.done.
Expect(t.countEvents(types.ServerEventTypeResponseDone)).To(Equal(1))
})
It("emits only tool-call items for a content-less tool turn (no empty assistant item)", func() {
on := true
toolDelta := []*proto.ChatDelta{{ToolCalls: []*proto.ToolCallDelta{{Index: 0, Name: "get_weather", Arguments: `{"city":"Rome"}`}}}}
m := &fakeModel{
predictTokens: []string{""},
predictChunkDeltas: [][]*proto.ChatDelta{toolDelta},
predictResp: backend.LLMResponse{ChatDeltas: toolDelta},
}
session := &Session{
OutputSampleRate: 24000,
ModelInterface: m,
ModelConfig: &config.ModelConfig{
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{LLM: &on, TTS: &on}},
},
}
conv := &Conversation{}
t := &fakeTransport{}
llmCfg := &config.ModelConfig{}
llmCfg.TemplateConfig.UseTokenizerTemplate = true
handled := streamLLMResponse(context.Background(), session, conv, t, "resp1", nil, nil, llmCfg, nil, nil, 0)
Expect(handled).To(BeTrue())
// No content → no transcript deltas and no spurious assistant content item.
Expect(t.transcriptDeltaText()).To(Equal(""))
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioTranscriptDelta)).To(Equal(0))
// The tool call is still emitted.
Expect(t.countEvents(types.ServerEventTypeResponseFunctionCallArgumentsDone)).To(Equal(1))
Expect(t.countEvents(types.ServerEventTypeResponseDone)).To(Equal(1))
})
})

View File

@@ -1,33 +0,0 @@
package openai
import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/reasoning"
)
// applyPipelineThinking forces the LLM's reasoning/thinking off when the realtime
// pipeline sets disable_thinking, mapping to the enable_thinking=false backend
// metadata via ReasoningConfig.DisableReasoning. The LLM config passed in is the
// per-session copy returned by the config loader, so this does not affect other
// users of the same model. When the pipeline does not set disable_thinking the
// LLM config is left untouched.
func applyPipelineThinking(llm *config.ModelConfig, pipeline config.Pipeline) {
if llm == nil || !pipeline.ThinkingDisabled() {
return
}
disable := true
llm.ReasoningConfig.DisableReasoning = &disable
}
// spokenReasoningConfig adapts a model's reasoning config for stripping reasoning
// OUT of realtime spoken output. ReasoningConfig.DisableReasoning is overloaded:
// the backend reads it as the "enable_thinking=false" hint (which pipeline
// disable_thinking sets via applyPipelineThinking), but the reasoning extractor
// reads it as "skip stripping, assume there is no reasoning". Honouring the latter
// when extracting for speech would leak raw <think>…</think> whenever the model
// ignores the suppression hint. Spoken output must never contain reasoning, so we
// always strip: clear DisableReasoning while keeping custom tokens/tag pairs.
func spokenReasoningConfig(cfg reasoning.Config) reasoning.Config {
cfg.DisableReasoning = nil
return cfg
}

View File

@@ -1,50 +0,0 @@
package openai
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/reasoning"
)
// applyPipelineThinking lets a realtime pipeline force the LLM's thinking off
// (enable_thinking=false metadata) without editing the LLM model config.
var _ = Describe("applyPipelineThinking", func() {
It("disables reasoning on the LLM config when the pipeline disables thinking", func() {
disable := true
llm := &config.ModelConfig{}
applyPipelineThinking(llm, config.Pipeline{DisableThinking: &disable})
Expect(llm.ReasoningConfig.DisableReasoning).ToNot(BeNil())
Expect(*llm.ReasoningConfig.DisableReasoning).To(BeTrue())
})
It("leaves the LLM config untouched when the pipeline does not set disable_thinking", func() {
llm := &config.ModelConfig{}
applyPipelineThinking(llm, config.Pipeline{})
Expect(llm.ReasoningConfig.DisableReasoning).To(BeNil())
})
})
// spokenReasoningConfig clears DisableReasoning so realtime spoken output always
// strips reasoning, even though disable_thinking sets DisableReasoning=true on the
// LLM config (which the backend reads as enable_thinking=false).
var _ = Describe("spokenReasoningConfig", func() {
It("clears DisableReasoning so the extractor still strips leaked reasoning", func() {
disable := true
out := spokenReasoningConfig(reasoning.Config{DisableReasoning: &disable})
Expect(out.DisableReasoning).To(BeNil())
})
It("preserves the other reasoning settings", func() {
disable := true
out := spokenReasoningConfig(reasoning.Config{
DisableReasoning: &disable,
ThinkingStartTokens: []string{"<reason>"},
TagPairs: []reasoning.TagPair{{Start: "<reason>", End: "</reason>"}},
})
Expect(out.ThinkingStartTokens).To(Equal([]string{"<reason>"}))
Expect(out.TagPairs).To(HaveLen(1))
Expect(out.TagPairs[0].Start).To(Equal("<reason>"))
})
})

View File

@@ -1,63 +0,0 @@
package openai
import (
"context"
"fmt"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
)
// emitTranscription transcribes a committed utterance and emits the transcription
// events for it, returning the final transcript text. With
// pipeline.streaming.transcription enabled it streams each transcript fragment as
// a conversation.item.input_audio_transcription.delta as the backend produces it,
// then a completed event; otherwise it transcribes the whole utterance and emits
// a single completed event. delta and completed events share itemID.
func emitTranscription(ctx context.Context, t Transport, session *Session, itemID, audioPath string) (string, error) {
cfg := session.InputAudioTranscription
if session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamTranscription() {
final, err := session.ModelInterface.TranscribeStream(ctx, audioPath, cfg.Language, false, false, cfg.Prompt, func(delta string) {
_ = t.SendEvent(types.ConversationItemInputAudioTranscriptionDeltaEvent{
ServerEventBase: types.ServerEventBase{EventID: "event_TODO"},
ItemID: itemID,
ContentIndex: 0,
Delta: delta,
})
})
if err != nil {
return "", err
}
transcript := ""
if final != nil {
transcript = final.Text
}
if err := t.SendEvent(types.ConversationItemInputAudioTranscriptionCompletedEvent{
ServerEventBase: types.ServerEventBase{EventID: "event_TODO"},
ItemID: itemID,
ContentIndex: 0,
Transcript: transcript,
}); err != nil {
return "", err
}
return transcript, nil
}
// Unary fallback: transcribe the whole utterance, emit one completed event.
tr, err := session.ModelInterface.Transcribe(ctx, audioPath, cfg.Language, false, false, cfg.Prompt)
if err != nil {
return "", err
}
if tr == nil {
return "", fmt.Errorf("transcribe result is nil")
}
if err := t.SendEvent(types.ConversationItemInputAudioTranscriptionCompletedEvent{
ServerEventBase: types.ServerEventBase{EventID: "event_TODO"},
ItemID: itemID,
ContentIndex: 0,
Transcript: tr.Text,
}); err != nil {
return "", err
}
return tr.Text, nil
}

View File

@@ -1,54 +0,0 @@
package openai
import (
"context"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
"github.com/mudler/LocalAI/core/schema"
)
// emitTranscription transcribes a committed utterance, streaming transcript text
// deltas when the pipeline opts in, and returns the final transcript text.
var _ = Describe("emitTranscription", func() {
It("streams transcription deltas then a completed event when streaming is enabled", func() {
on := true
session := &Session{
InputAudioTranscription: &types.AudioTranscription{},
ModelConfig: &config.ModelConfig{
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{Transcription: &on}},
},
ModelInterface: &fakeModel{
transcribeDeltas: []string{"Hel", "lo", " world"},
transcribeFinal: &schema.TranscriptionResult{Text: "Hello world"},
},
}
t := &fakeTransport{}
transcript, err := emitTranscription(context.Background(), t, session, "item1", "/tmp/x.wav")
Expect(err).ToNot(HaveOccurred())
Expect(transcript).To(Equal("Hello world"))
Expect(t.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionDelta)).To(Equal(3))
Expect(t.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionCompleted)).To(Equal(1))
})
It("emits a single completed event with no deltas in unary mode", func() {
session := &Session{
InputAudioTranscription: &types.AudioTranscription{},
ModelConfig: &config.ModelConfig{}, // streaming off
ModelInterface: &fakeModel{transcribeFinal: &schema.TranscriptionResult{Text: "Hi"}},
}
t := &fakeTransport{}
transcript, err := emitTranscription(context.Background(), t, session, "item1", "/tmp/x.wav")
Expect(err).ToNot(HaveOccurred())
Expect(transcript).To(Equal("Hi"))
Expect(t.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionDelta)).To(Equal(0))
Expect(t.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionCompleted)).To(Equal(1))
})
})

View File

@@ -48,8 +48,7 @@ func RealtimeCalls(application *application.Application) echo.HandlerFunc {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "codec registration failed"})
}
se := webRTCSettingEngine(application.ApplicationConfig())
api := webrtc.NewAPI(webrtc.WithMediaEngine(m), webrtc.WithSettingEngine(se))
api := webrtc.NewAPI(webrtc.WithMediaEngine(m))
pc, err := api.NewPeerConnection(webrtc.Configuration{})
if err != nil {

View File

@@ -1,47 +0,0 @@
package openai
import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/xlog"
"github.com/pion/webrtc/v4"
)
// webRTCSettingEngine builds the pion SettingEngine for /v1/realtime WebRTC.
//
// With a default (empty) SettingEngine, pion gathers a host ICE candidate for
// every local interface. Under Docker host networking that includes bridge
// addresses (docker0/veth, 172.x) that a remote browser cannot route to; the
// connection often establishes on a good pair and then drops once ICE consent
// checks fail on the unreachable ones. The two opt-in knobs below let an
// operator advertise only the reachable address.
func webRTCSettingEngine(cfg *config.ApplicationConfig) webrtc.SettingEngine {
s := webrtc.SettingEngine{}
if cfg == nil {
return s
}
if len(cfg.WebRTCNAT1To1IPs) > 0 {
s.SetNAT1To1IPs(cfg.WebRTCNAT1To1IPs, webrtc.ICECandidateTypeHost)
xlog.Debug("realtime webrtc: advertising NAT 1:1 host IPs", "ips", cfg.WebRTCNAT1To1IPs)
}
if filter := iceInterfaceFilter(cfg.WebRTCICEInterfaces); filter != nil {
s.SetInterfaceFilter(filter)
xlog.Debug("realtime webrtc: restricting ICE interfaces", "interfaces", cfg.WebRTCICEInterfaces)
}
return s
}
// iceInterfaceFilter returns an interface allow-list predicate for pion, or nil
// when no interfaces are configured (pion's default: gather from all).
func iceInterfaceFilter(allowed []string) func(string) bool {
if len(allowed) == 0 {
return nil
}
set := make(map[string]struct{}, len(allowed))
for _, name := range allowed {
set[name] = struct{}{}
}
return func(iface string) bool {
_, ok := set[iface]
return ok
}
}

View File

@@ -1,39 +0,0 @@
package openai
import (
"github.com/mudler/LocalAI/core/config"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("webRTC ICE settings", func() {
Describe("iceInterfaceFilter", func() {
It("returns nil when no interfaces are configured", func() {
Expect(iceInterfaceFilter(nil)).To(BeNil())
Expect(iceInterfaceFilter([]string{})).To(BeNil())
})
It("admits only the configured interfaces", func() {
f := iceInterfaceFilter([]string{"eth0", "wlan0"})
Expect(f).NotTo(BeNil())
Expect(f("eth0")).To(BeTrue())
Expect(f("wlan0")).To(BeTrue())
Expect(f("docker0")).To(BeFalse())
Expect(f("veth123")).To(BeFalse())
})
})
Describe("webRTCSettingEngine", func() {
It("does not panic on a nil config", func() {
Expect(func() { webRTCSettingEngine(nil) }).NotTo(Panic())
})
It("builds an engine with NAT 1:1 IPs and an interface filter configured", func() {
cfg := &config.ApplicationConfig{
WebRTCNAT1To1IPs: []string{"192.168.1.10"},
WebRTCICEInterfaces: []string{"eth0"},
}
Expect(func() { webRTCSettingEngine(cfg) }).NotTo(Panic())
})
})
})

View File

@@ -1356,7 +1356,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
thinkingStartToken := reason.DetectThinkingStartToken(template, &cfg.ReasoningConfig)
// Extract reasoning from result before cleaning
reasoningContent, cleanedResult := reason.ExtractReasoningComplete(result, thinkingStartToken, cfg.ReasoningConfig)
reasoningContent, cleanedResult := reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
// Parse tool calls if using functions
var outputItems []schema.ORItemField
@@ -1996,7 +1996,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
finalCleanedResult = extractor.CleanedContent()
}
if finalReasoning == "" && finalCleanedResult == "" {
finalReasoning, finalCleanedResult = reason.ExtractReasoningComplete(result, thinkingStartToken, cfg.ReasoningConfig)
finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
}
// Close reasoning item if it exists and wasn't closed yet
@@ -2493,7 +2493,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
finalCleanedResult = extractor.CleanedContent()
}
if finalReasoning == "" && finalCleanedResult == "" {
finalReasoning, finalCleanedResult = reason.ExtractReasoningComplete(result, thinkingStartToken, cfg.ReasoningConfig)
finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
}
// Close reasoning item if it exists and wasn't closed yet

View File

@@ -21,20 +21,6 @@ var (
mockBackendPath string
)
// testHTTPAddr is the listen address used by specs that start a full HTTP
// server. Configurable so the suite can run on machines where the default
// port is taken by an unrelated service (override: LOCALAI_TEST_HTTP_PORT).
var testHTTPAddr = func() string {
port := os.Getenv("LOCALAI_TEST_HTTP_PORT")
if port == "" {
port = "9090"
}
return "127.0.0.1:" + port
}()
// testHTTPBase is the matching http://host:port prefix for client requests.
var testHTTPBase = "http://" + testHTTPAddr
// findMockBackendBinary locates the mock-backend binary built by
// `make build-mock-backend`. Mirrors the lookup used by
// tests/e2e/e2e_suite_test.go so both suites consume the same artifact.

View File

@@ -59,14 +59,14 @@ var _ = Describe("Open Responses API", func() {
Expect(err).ToNot(HaveOccurred())
go func() {
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
xlog.Error("server error", "error", err)
}
}()
// Wait for API to be ready
Eventually(func() error {
resp, err := http.Get(testHTTPBase + "/healthz")
resp, err := http.Get("http://127.0.0.1:9090/healthz")
if err != nil {
return err
}
@@ -95,7 +95,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -118,7 +118,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -143,7 +143,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -168,7 +168,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -196,7 +196,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -241,7 +241,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -269,7 +269,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -297,7 +297,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -328,7 +328,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -358,7 +358,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -386,7 +386,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -418,7 +418,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -454,7 +454,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -490,7 +490,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -539,7 +539,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -590,7 +590,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -624,7 +624,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -658,7 +658,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -680,7 +680,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -727,7 +727,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -756,7 +756,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -799,7 +799,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -835,7 +835,7 @@ var _ = Describe("Open Responses API", func() {
payload1, err := json.Marshal(reqBody1)
Expect(err).ToNot(HaveOccurred())
req1, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload1))
req1, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload1))
Expect(err).ToNot(HaveOccurred())
req1.Header.Set("Content-Type", "application/json")
req1.Header.Set("Authorization", bearerKey)
@@ -869,7 +869,7 @@ var _ = Describe("Open Responses API", func() {
payload2, err := json.Marshal(reqBody2)
Expect(err).ToNot(HaveOccurred())
req2, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload2))
req2, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload2))
Expect(err).ToNot(HaveOccurred())
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", bearerKey)
@@ -897,7 +897,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
@@ -933,7 +933,7 @@ var _ = Describe("Open Responses API", func() {
payload1, err := json.Marshal(reqBody1)
Expect(err).ToNot(HaveOccurred())
req1, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload1))
req1, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload1))
Expect(err).ToNot(HaveOccurred())
req1.Header.Set("Content-Type", "application/json")
req1.Header.Set("Authorization", bearerKey)
@@ -983,7 +983,7 @@ var _ = Describe("Open Responses API", func() {
payload2, err := json.Marshal(reqBody2)
Expect(err).ToNot(HaveOccurred())
req2, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload2))
req2, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload2))
Expect(err).ToNot(HaveOccurred())
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", bearerKey)
@@ -1009,7 +1009,7 @@ var _ = Describe("Open Responses API", func() {
payload, err := json.Marshal(reqBody)
Expect(err).ToNot(HaveOccurred())
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)

View File

@@ -216,12 +216,6 @@ export function useChat(initialModel = '') {
audio_url: { url: `data:${file.type};base64,${file.base64}` },
})
userFiles.push({ name: file.name, type: 'audio' })
} else if (file.type?.startsWith('video/')) {
messageContent.push({
type: 'video_url',
video_url: { url: `data:${file.type};base64,${file.base64}` },
})
userFiles.push({ name: file.name, type: 'video' })
} else {
// Text/PDF files - append to content
if (file.textContent) {

View File

@@ -506,10 +506,7 @@ export default function Backends() {
<tbody>
{backends.map((b, idx) => {
const op = getBackendOp(b)
// A failed op is intentionally kept in the operations list so the
// OperationsBar can surface the error + Dismiss; it must NOT render
// as a perpetual "Installing..." spinner here (mirrors Models.jsx).
const isProcessing = !!op && !op.error
const isProcessing = !!op
const isExpanded = expandedRow === idx
return (

View File

@@ -265,7 +265,7 @@ function UserMessageContent({ content, files }) {
<div className="chat-message-files">
{files.map((f, i) => (
<span key={i} className="chat-file-inline">
<i className={`fas ${f.type === 'image' ? 'fa-image' : f.type === 'audio' ? 'fa-headphones' : f.type === 'video' ? 'fa-film' : 'fa-file'}`} />
<i className={`fas ${f.type === 'image' ? 'fa-image' : f.type === 'audio' ? 'fa-headphones' : 'fa-file'}`} />
{f.name}
</span>
))}
@@ -274,9 +274,6 @@ function UserMessageContent({ content, files }) {
{Array.isArray(content) && content.filter(c => c.type === 'image_url').map((img, i) => (
<img key={i} src={img.image_url.url} alt="attached" className="chat-inline-image" />
))}
{Array.isArray(content) && content.filter(c => c.type === 'video_url').map((vid, i) => (
<video key={i} src={vid.video_url.url} controls className="chat-inline-video" />
))}
</>
)
}
@@ -714,7 +711,7 @@ export default function Chat() {
for (const file of e.target.files) {
const base64 = await fileToBase64(file)
const entry = { name: file.name, type: file.type, base64 }
if (!file.type.startsWith('image/') && !file.type.startsWith('audio/') && !file.type.startsWith('video/')) {
if (!file.type.startsWith('image/') && !file.type.startsWith('audio/')) {
entry.textContent = await file.text().catch(() => '')
}
newFiles.push(entry)
@@ -1247,7 +1244,7 @@ export default function Chat() {
<div className="chat-files">
{files.map((f, i) => (
<span key={i} className="chat-file-badge">
<i className={`fas ${f.type?.startsWith('image/') ? 'fa-image' : f.type?.startsWith('audio/') ? 'fa-headphones' : f.type?.startsWith('video/') ? 'fa-film' : 'fa-file'}`} />
<i className={`fas ${f.type?.startsWith('image/') ? 'fa-image' : f.type?.startsWith('audio/') ? 'fa-headphones' : 'fa-file'}`} />
{f.name}
<button onClick={() => setFiles(prev => prev.filter((_, idx) => idx !== i))}>
<i className="fas fa-xmark" />
@@ -1346,7 +1343,7 @@ export default function Chat() {
ref={fileInputRef}
type="file"
multiple
accept="image/*,audio/*,video/*,application/pdf,.txt,.md,.csv,.json"
accept="image/*,audio/*,application/pdf,.txt,.md,.csv,.json"
style={{ display: 'none' }}
onChange={handleFileChange}
/>

View File

@@ -12,7 +12,6 @@ import ActionMenu from '../components/ActionMenu'
import ResourceRow, { ChevronCell, IconCell, StopPropagationCell } from '../components/ResourceRow'
import { useModels } from '../hooks/useModels'
import { useGalleryEnrichment } from '../hooks/useGalleryEnrichment'
import { useOperations } from '../hooks/useOperations'
import { backendControlApi, modelsApi, backendsApi, systemApi, nodesApi } from '../utils/api'
import { renderMarkdown } from '../utils/markdown'
import { safeHref } from '../utils/url'
@@ -127,7 +126,6 @@ export default function Manage() {
const [activeTab, setActiveTab] = useState(TABS.some(tab => tab.key === initialTab) ? initialTab : 'models')
const { models, loading: modelsLoading, refetch: refetchModels } = useModels()
const { enrichModel, enrichBackend } = useGalleryEnrichment()
const { operations } = useOperations()
const [loadedModelIds, setLoadedModelIds] = useState(new Set())
const [backends, setBackends] = useState([])
const [backendsLoading, setBackendsLoading] = useState(true)
@@ -260,19 +258,14 @@ export default function Manage() {
return `${m}m ago`
})()
// Refresh installed backends + available upgrades when the Backends tab opens
// AND whenever a backend operation settles (operations.length changes as a
// reinstall/upgrade completes and drops off the list). Without the op-settle
// refresh the installed-version cell and the "update available" badge stay
// stale after an upgrade until the user switches tabs - the op looks like it
// "did nothing". Mirrors the operations.length watch Backends.jsx uses.
// Fetch available backend upgrades
useEffect(() => {
if (activeTab !== 'backends') return
fetchBackends()
backendsApi.checkUpgrades()
.then(data => setUpgrades(data || {}))
.catch(() => {})
}, [operations.length, activeTab, fetchBackends])
if (activeTab === 'backends') {
backendsApi.checkUpgrades()
.then(data => setUpgrades(data || {}))
.catch(() => {})
}
}, [activeTab])
const handleStopModel = (modelName) => {
setConfirmDialog({

View File

@@ -17,24 +17,6 @@ const STATUS_STYLES = {
error: { icon: 'fa-solid fa-circle', color: 'var(--color-error)', bg: 'var(--color-error-light)' },
}
// upsertAssistant merges a streamed transcript fragment into the assistant entry
// identified by the server's item_id, or appends a new entry if none exists yet.
// Keying by item_id (not a mutable index tracked across handler/updater
// boundaries) makes streamed deltas idempotent and order-independent, so React's
// batching of non-React data-channel events cannot produce a duplicate bubble.
// mode 'append' adds to the running text; 'replace' sets the final transcript.
function upsertAssistant(prev, itemId, text, mode) {
// Only assistant entries carry an id, and the streaming entry is almost
// always the newest — search from the tail so per-delta cost stays constant.
const i = prev.findLastIndex(e => e.id === itemId)
if (i === -1) {
return [...prev, { role: 'assistant', id: itemId, text }]
}
const next = [...prev]
next[i] = { ...next[i], text: mode === 'append' ? next[i].text + text : text }
return next
}
export default function Talk() {
const { addToast } = useOutletContext()
const navigate = useNavigate()
@@ -52,10 +34,7 @@ export default function Talk() {
// Transcript
const [transcript, setTranscript] = useState([])
// item_id of the assistant message currently streaming — used only to remove
// its partial bubble when a response is cancelled (barge-in). The transcript
// itself is keyed by item_id via upsertAssistant, not by this ref.
const inProgressIdRef = useRef(null)
const streamingRef = useRef(null) // tracks the index of the in-progress assistant message
// Session settings
const [instructions, setInstructions] = useState(
@@ -248,21 +227,39 @@ export default function Talk() {
break
case 'conversation.item.input_audio_transcription.completed':
if (event.transcript) {
streamingRef.current = null
setTranscript(prev => [...prev, { role: 'user', text: event.transcript }])
}
updateStatus('thinking', 'Generating response...')
break
case 'response.output_audio_transcript.delta':
if (event.delta) {
inProgressIdRef.current = event.item_id
setTranscript(prev => upsertAssistant(prev, event.item_id, event.delta, 'append'))
setTranscript(prev => {
if (streamingRef.current !== null) {
const updated = [...prev]
updated[streamingRef.current] = {
...updated[streamingRef.current],
text: updated[streamingRef.current].text + event.delta,
}
return updated
}
streamingRef.current = prev.length
return [...prev, { role: 'assistant', text: event.delta }]
})
}
break
case 'response.output_audio_transcript.done':
if (event.transcript) {
setTranscript(prev => upsertAssistant(prev, event.item_id, event.transcript, 'replace'))
setTranscript(prev => {
if (streamingRef.current !== null) {
const updated = [...prev]
updated[streamingRef.current] = { ...updated[streamingRef.current], text: event.transcript }
return updated
}
return [...prev, { role: 'assistant', text: event.transcript }]
})
}
inProgressIdRef.current = null
streamingRef.current = null
break
case 'response.output_audio.delta':
updateStatus('speaking', 'Speaking...')
@@ -284,7 +281,7 @@ export default function Talk() {
// Pretty-print JSON for readability; fall back to raw string.
try { preview = JSON.stringify(JSON.parse(preview), null, 2) } catch (_) { /* keep raw */ }
setTranscript(prev => [...prev, { role: 'tool_result', text: preview }])
inProgressIdRef.current = null // tool result ends the current assistant text run
streamingRef.current = null // tool result ends the current assistant text run
}
break
}
@@ -293,20 +290,9 @@ export default function Talk() {
// conversation.item.create + response.create when it's done.
handleFunctionCall(event)
break
case 'response.done': {
// A cancelled response (barge-in / interruption) leaves a partial,
// incrementally-streamed assistant bubble behind. The server discards
// the interrupted item from history; mirror that here (remove the
// in-progress assistant entry by item_id) so the regenerated reply
// doesn't show up as a second assistant message.
if (event.response?.status === 'cancelled' && inProgressIdRef.current) {
const id = inProgressIdRef.current
inProgressIdRef.current = null
setTranscript(prev => prev.filter(e => e.id !== id))
}
case 'response.done':
updateStatus('listening', 'Listening...')
break
}
case 'error':
hasErrorRef.current = true
updateStatus('error', 'Error: ' + (event.error?.message || 'Unknown error'))
@@ -803,7 +789,7 @@ export default function Talk() {
const iconColor = isToolCall || isToolResult ? 'var(--color-text-secondary)'
: isUser ? 'var(--color-primary)' : 'var(--color-accent)'
return (
<div key={entry.id || i} style={{ display: 'flex', alignItems: 'flex-start', gap: 'var(--spacing-xs)' }}>
<div key={i} style={{ display: 'flex', alignItems: 'flex-start', gap: 'var(--spacing-xs)' }}>
<i className={iconClass} style={{ color: iconColor, marginTop: 3, flexShrink: 0, fontSize: '0.75rem' }} />
<p style={{
margin: 0,

View File

@@ -466,11 +466,10 @@ func (s *AgentPoolService) Chat(name, message string) (string, error) {
s.collectAndCopyMetadata(metadata, chatUserID)
}
content := s.appendLocalAGIKBCitations(response.Response, name, message, response.State)
msg := map[string]any{
"id": messageID + "-agent",
"sender": "agent",
"content": content,
"content": response.Response,
"timestamp": time.Now().Format(time.RFC3339),
}
if len(metadata) > 0 {
@@ -490,79 +489,6 @@ func (s *AgentPoolService) Chat(name, message string) (string, error) {
return messageID, nil
}
func (s *AgentPoolService) appendLocalAGIKBCitations(response, agentKey, message string, states []coreTypes.ActionState) string {
if strings.TrimSpace(response) == "" {
return response
}
userID, collection := splitAgentKey(agentKey)
cfg := s.localAGI.pool.GetConfig(agentKey)
if cfg == nil || !cfg.EnableKnowledgeBase {
return response
}
citations := kbCitationsFromActionStates(states)
if len(citations) == 0 && cfg.KBAutoSearch {
maxResults := cfg.KnowledgeBaseResults
if maxResults <= 0 {
maxResults = 5
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
kbResult := agents.KBAutoSearchPrompt(ctx, s.apiURL, s.apiKey, collection, message, maxResults, userID)
citations = kbResult.Citations
}
return agents.AppendKBCitations(response, collection, userID, citations)
}
func splitAgentKey(agentKey string) (userID, name string) {
if uid, n, ok := strings.Cut(agentKey, ":"); ok {
return uid, n
}
return "", agentKey
}
func kbCitationsFromActionStates(states []coreTypes.ActionState) []agents.KBCitation {
var citations []agents.KBCitation
for _, state := range states {
citations = append(citations, kbCitationsFromMetadata(state.Metadata)...)
}
return citations
}
func kbCitationsFromMetadata(metadata map[string]any) []agents.KBCitation {
if len(metadata) == 0 {
return nil
}
fileName := metadata["file_name"]
source := metadata["source"]
if fileName == nil && source == nil {
return nil
}
citation := agents.KBCitation{
FileName: metadataString(fileName),
EntryKey: metadataString(source),
}
if citation.FileName == "" && citation.EntryKey == "" {
return nil
}
return []agents.KBCitation{citation}
}
func metadataString(value any) string {
switch v := value.(type) {
case string:
return v
case fmt.Stringer:
return v.String()
default:
return ""
}
}
// userOutputsDir returns the per-user outputs directory, creating it if needed.
// If userID is empty, falls back to the shared outputs directory.
func (s *AgentPoolService) userOutputsDir(userID string) string {

View File

@@ -1,127 +0,0 @@
package agents
import (
"fmt"
"net/url"
"strings"
"sync"
)
type kbCitationList struct {
mu sync.Mutex
citations []KBCitation
}
func (l *kbCitationList) AddKBCitations(citations []KBCitation) {
if len(citations) == 0 {
return
}
l.mu.Lock()
defer l.mu.Unlock()
l.citations = append(l.citations, citations...)
}
func (l *kbCitationList) Citations() []KBCitation {
l.mu.Lock()
defer l.mu.Unlock()
out := make([]KBCitation, len(l.citations))
copy(out, l.citations)
return out
}
// AppendKBCitations appends a markdown Sources block for KB citations.
func AppendKBCitations(response, collection, userID string, citations []KBCitation) string {
if strings.TrimSpace(response) == "" || len(citations) == 0 {
return response
}
var lines []string
seen := make(map[string]struct{})
for _, citation := range citations {
key := strings.TrimSpace(citation.EntryKey)
if key == "" {
key = strings.TrimSpace(citation.FileName)
}
if key == "" {
continue
}
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
displayName := kbCitationDisplayName(citation)
if displayName == "" {
continue
}
sourceURL := kbCitationRawFileURL(collection, citation.EntryKey, userID)
number := len(lines) + 1
if sourceURL == "" {
lines = append(lines, fmt.Sprintf("[%d] %s", number, displayName))
continue
}
lines = append(lines, fmt.Sprintf("[%d] [%s](%s)", number, escapeMarkdownLinkText(displayName), sourceURL))
}
if len(lines) == 0 {
return response
}
var sb strings.Builder
sb.WriteString(strings.TrimRight(response, "\n"))
sb.WriteString("\n\nSources:\n")
for _, line := range lines {
sb.WriteString(line)
sb.WriteString("\n")
}
return strings.TrimRight(sb.String(), "\n")
}
func kbCitationDisplayName(citation KBCitation) string {
if fileName := strings.TrimSpace(citation.FileName); fileName != "" {
return fileName
}
segments := strings.Split(strings.Trim(strings.TrimSpace(citation.EntryKey), "/"), "/")
for i := len(segments) - 1; i >= 0; i-- {
if segment := strings.TrimSpace(segments[i]); segment != "" {
return segment
}
}
return ""
}
func kbCitationRawFileURL(collection, entryKey, userID string) string {
collection = strings.TrimSpace(collection)
entryKey = strings.Trim(strings.TrimSpace(entryKey), "/")
if collection == "" || entryKey == "" {
return ""
}
var escapedEntrySegments []string
for _, segment := range strings.Split(entryKey, "/") {
if segment == "" {
continue
}
escapedEntrySegments = append(escapedEntrySegments, url.PathEscape(segment))
}
if len(escapedEntrySegments) == 0 {
return ""
}
sourceURL := "/api/agents/collections/" + url.PathEscape(collection) + "/entries-raw/" + strings.Join(escapedEntrySegments, "/")
if userID != "" {
query := url.Values{}
query.Set("user_id", userID)
sourceURL += "?" + query.Encode()
}
return sourceURL
}
func escapeMarkdownLinkText(text string) string {
text = strings.ReplaceAll(text, `\`, `\\`)
text = strings.ReplaceAll(text, "[", `\[`)
text = strings.ReplaceAll(text, "]", `\]`)
return text
}

View File

@@ -167,12 +167,10 @@ func ExecuteChatWithLLM(ctx context.Context, llm cogito.LLM, cfg *AgentConfig, m
}
}
kbCitations := &kbCitationList{}
if cfg.EnableKnowledgeBase && (kbMode == KBModeAutoSearch || kbMode == KBModeBoth) {
kbResult := KBAutoSearchPrompt(ctx, effectiveURL, effectiveKey, cfg.Name, message, cfg.KnowledgeBaseResults, userID)
if kbResult.Prompt != "" {
fragment = fragment.AddMessage(cogito.SystemMessageRole, kbResult.Prompt)
kbCitations.AddKBCitations(kbResult.Citations)
kbResults := KBAutoSearchPrompt(ctx, effectiveURL, effectiveKey, cfg.Name, message, cfg.KnowledgeBaseResults, userID)
if kbResults != "" {
fragment = fragment.AddMessage(cogito.SystemMessageRole, kbResults)
}
}
@@ -199,7 +197,7 @@ func ExecuteChatWithLLM(ctx context.Context, llm cogito.LLM, cfg *AgentConfig, m
}
cogitoOpts = append(cogitoOpts, cogito.WithTools(
cogito.NewToolDefinition(
KBSearchMemoryTool{APIURL: effectiveURL, APIKey: effectiveKey, Collection: cfg.Name, MaxResults: kbResults, UserID: userID, CitationCollector: kbCitations},
KBSearchMemoryTool{APIURL: effectiveURL, APIKey: effectiveKey, Collection: cfg.Name, MaxResults: kbResults, UserID: userID},
KBSearchMemoryArgs{},
"search_memory",
"Search the knowledge base for relevant information",
@@ -338,8 +336,6 @@ func ExecuteChatWithLLM(ctx context.Context, llm cogito.LLM, cfg *AgentConfig, m
if cfg.StripThinkingTags && response != "" {
response = stripThinkingTags(response)
}
responseForMemory := response
response = AppendKBCitations(response, cfg.Name, userID, kbCitations.Citations())
// Save conversation to KB when long-term memory is enabled.
// Use a detached context: the parent ctx may be cancelled (e.g. in distributed
@@ -348,7 +344,7 @@ func ExecuteChatWithLLM(ctx context.Context, llm cogito.LLM, cfg *AgentConfig, m
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
saveConversationToKB(ctx, llm, effectiveURL, effectiveKey, cfg, message, responseForMemory, userID)
saveConversationToKB(ctx, llm, effectiveURL, effectiveKey, cfg, message, response, userID)
}()
}

Some files were not shown because too many files have changed in this diff Show More