mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-11 18:27:32 -04:00
Compare commits
1 Commits
feat/dllm-
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b15627c864 |
@@ -1,138 +0,0 @@
|
||||
# Working on the dllm Backend
|
||||
|
||||
`mudler/dllm.cpp` is a standalone C++/ggml engine for DiffusionGemma
|
||||
block-diffusion models. LocalAI wraps it with a **pure-Go** backend at
|
||||
`backend/go/dllm/` that dlopens `libdllm.so` via purego (ebitengine/purego) -
|
||||
NOT cgo, and NOT a C++ grpc-server fork. The Go side owns chat templating
|
||||
(gemma4 renderer) and output parsing (gemma4 streaming parser) and implements
|
||||
the rich gRPC interface (`PredictRich`/`PredictStreamRich`, ChatDelta replies).
|
||||
|
||||
> NOTE: github.com/mudler/dllm.cpp is still **private** (publishing is
|
||||
> planned). Until then the Makefile's anonymous clone fails; use the local-dev
|
||||
> symlink shortcut documented at the top of `backend/go/dllm/Makefile`
|
||||
> (symlink an out-of-tree `build/libdllm.so` into the backend dir and skip the
|
||||
> clone), or a git credential helper with repo access.
|
||||
|
||||
## Pin
|
||||
|
||||
`backend/go/dllm/Makefile` pins `DLLM_VERSION?=<sha>` at the top
|
||||
(whisper / parakeet-cpp / ds4 convention). The bump-deps bot
|
||||
(`.github/workflows/bump_deps.yaml`) tracks `mudler/dllm.cpp` `main` and
|
||||
rewrites that variable. After a manual bump: `make -C backend/go/dllm purge &&
|
||||
make -C backend/go/dllm` (the clone is keyed on the directory existing, not
|
||||
the sha).
|
||||
|
||||
## C-ABI and the serialization contract
|
||||
|
||||
The binding covers the 9-symbol flat C-ABI from dllm.cpp's
|
||||
`include/dllm_capi.h` (ABI v1; `main.go` hard-fails on a version mismatch):
|
||||
`abi_version, load, free, last_error, free_string, tokenize_json, generate,
|
||||
generate_stream, cancel`. Contract points the Go wiring encodes (`capi.go`
|
||||
header comment has the full list):
|
||||
|
||||
- **One ctx = one concurrent generate/tokenize.** A per-model worker
|
||||
goroutine (`Dllm.jobs` in `dllm.go`) owns ALL C calls, making the
|
||||
serialization structural instead of lock discipline.
|
||||
- **`dllm_capi_cancel` is the ONE exception**: it only flips an atomic and may
|
||||
be called from any goroutine mid-generate, so `Dllm.Cancel` bypasses the
|
||||
worker queue. The flag resets at the start of each generate, so a watchdog
|
||||
racing a new generate must re-issue cancel.
|
||||
- **`last_error` is a borrowed pointer** and must only be read AFTER the
|
||||
failing call returned (never while a generate is in flight on the same ctx).
|
||||
- **Free vs in-flight requests**: requests hold `genMu.RLock` for their full
|
||||
duration; `Free` takes the write lock, so it only runs when nothing is in
|
||||
flight, then drains and closes the worker. Post-Free requests get a clean
|
||||
"model not loaded" error.
|
||||
- `tokenize_json`/`generate` return malloc'd `char*` (bound as `uintptr`,
|
||||
copied, then `dllm_capi_free_string`d); opts/params JSON must be a FLAT
|
||||
object of scalars (`buildOptsJSON` rejects anything else).
|
||||
|
||||
## Wire shape
|
||||
|
||||
| RPC | Implementation |
|
||||
|---|---|
|
||||
| LoadModel | `dllm_capi_load` (params: `n_gpu_layers`, `n_threads`, `ctx_len`); `Options[]` parsed into per-request gen opts (`eb_*`, `blocks`, `kv_cache`) by `parseModelGenOpts` |
|
||||
| PredictRich | render (if templated) → `dllm_capi_generate` → parse → ONE Reply with aggregated ChatDeltas + legacy `Message` bytes |
|
||||
| PredictStreamRich | `dllm_capi_generate_stream`; per committed diffusion block → UTF-8 holdback → parser.Feed → one Reply per non-empty delta batch (channel closed by the CALLER, per `pkg/grpc/interface.go`) |
|
||||
| Predict / PredictStream | Legacy paths, delegate to the rich pair (legacy stream INVERTS channel ownership: the impl closes) |
|
||||
| TokenizeString | `dllm_capi_tokenize_json` (C side prepends BOS per `vocab.add_bos`) |
|
||||
| Cancel | `dllm_capi_cancel`, exposed as the `grpc.Cancellable` capability (`pkg/grpc/interface.go`): the gRPC server arms it via `context.AfterFunc` on the Predict/PredictStream context, so client disconnects/timeouts abort the in-flight generate - llama.cpp `IsCancelled()` parity for Go backends |
|
||||
|
||||
`n_threads` and `ctx_len` are accepted-but-ignored by the engine at the
|
||||
current pin (the context bound comes from GGUF `n_ctx_train`); they are sent
|
||||
for forward compatibility.
|
||||
|
||||
## Renderer / parser (the templated chat path)
|
||||
|
||||
With `use_tokenizer_template` + raw Messages, the backend owns templating and
|
||||
parsing (the ds4 precedent, but in Go):
|
||||
|
||||
- `gemma4_renderer.go` - `RenderGemma4(msgs, toolsJSON, enableThinking,
|
||||
addGenerationPrompt)`. The file embeds the FULL `tokenizer.chat_template`
|
||||
jinja (17466 bytes, md5 `8c34cf93c7a7815b3fdb300a009c4c17`) extracted
|
||||
verbatim from `diffusiongemma-26B-A4B-it-BF16.gguf` via gguf-py - e.g.
|
||||
`python scripts/dump_gguf.py model.gguf | grep -A400 chat_template` in the
|
||||
dllm.cpp checkout - as a numbered comment block; every Go rule cites its
|
||||
"tpl L<n>" line. Re-verify the md5 before blaming the renderer for a
|
||||
mismatch with a new GGUF. **BOS exception**: the template emits
|
||||
`{{- bos_token -}}` but the renderer deliberately does NOT - dllm.cpp's
|
||||
`run_generate` tokenizes with `prepend_bos = vocab.add_bos` (true for
|
||||
gemma4), so a literal `<bos>` would double it.
|
||||
- `gemma4_parser.go` - streaming state machine turning raw model text
|
||||
(fragments can split anywhere, including mid-marker) into ChatDeltas:
|
||||
thought channels → `reasoning_content`, `<|tool_call>call:name{...}` →
|
||||
ToolCallDelta, `<turn|>` → done. Marker grammar cross-checked against vLLM
|
||||
PR #45163's gemma4 tool/reasoning parsers. Malformed payloads are re-emitted
|
||||
raw as content, never dropped.
|
||||
- Thinking is **opt-in** for this family (`Metadata["enable_thinking"]`,
|
||||
default OFF - the inverse of ds4): the template gates every thinking branch
|
||||
on `enable_thinking`, and the no-thinking render pre-closes an empty thought
|
||||
channel, so the parser always starts in content state.
|
||||
- **UTF-8 boundary holdback** (`splitValidUTF8` in `dllm.go`): per-block
|
||||
detokenization can split a multi-byte character across block boundaries, and
|
||||
grpc-go refuses to marshal invalid UTF-8 in proto3 strings. An incomplete
|
||||
trailing sequence (at most 3 bytes) is carried into the next block; genuinely
|
||||
undecodable bytes become U+FFFD.
|
||||
|
||||
Without `use_tokenizer_template`, the prompt passes through verbatim and the
|
||||
output is NOT gemma4-parsed (plain content, like any non-autoparsing backend).
|
||||
|
||||
## Tests
|
||||
|
||||
| Layer | Gate | What |
|
||||
|---|---|---|
|
||||
| `backend/go/dllm/*_test.go` (renderer/parser/wiring) | none - run in plain `go test ./backend/go/dllm/...` | Ginkgo specs over a fake `generator` seam; canonical renderer fixtures from transformers' `test_modeling_diffusion_gemma.py`, parser tables from the vLLM gemma4 parsers |
|
||||
| `backend/go/dllm/dllm_test.go` C-ABI smoke | `DLLM_TEST_LIBRARY` + `DLLM_TEST_TINY_MODEL` (dllm.cpp's `tests/fixtures/tiny_with_vocab.gguf`); Skips when unset | Drives the real `libdllm.so`: ABI check, load, tokenize `[2,18]`, deterministic generate, cancel (incl. mid-stream `Dllm.Cancel` aborting a deliberately slow `eb_max_steps:256` run in ~10ms) |
|
||||
| `tests/e2e-backends/dllm_test.go` | `BACKEND_TEST_DLLM=1` + `BACKEND_BINARY` (packaged run.sh) + `BACKEND_TEST_MODEL_FILE` (tiny fixture) | Templated chat round trip (Messages + UseTokenizerTemplate) over the real gRPC binary, non-streaming + streaming; plus client-context cancellation mid-stream (proves the `Cancellable` server plumbing end to end) |
|
||||
| Real-model e2e | `BACKEND_TEST_DLLM_REAL_MODEL_FILE` (26B BF16, ~50 GB) + `BACKEND_TEST_DLLM_REAL_GPU_LAYERS` | CUDA-13-class hardware only |
|
||||
|
||||
Tool-call e2e is deliberately absent from the tiny-model spec: the fixture has
|
||||
random weights and cannot be coaxed into emitting tool markup; the unit tables
|
||||
carry that coverage.
|
||||
|
||||
## Build matrix
|
||||
|
||||
`cpu-dllm` (amd64 + arm64), `cuda13-dllm` (amd64), and
|
||||
`cuda13-nvidia-l4t-arm64-dllm` (arm64 CUDA: Jetson / DGX Spark GB10), via
|
||||
`.github/backend-matrix.yml`. No darwin/Metal. CUDA builds forward
|
||||
`-DDLLM_CUDA=ON` (dllm.cpp gates ggml's CUDA behind its own flag - a bare
|
||||
`-DGGML_CUDA=ON` is overridden by the cache FORCE). `libdllm.so` is
|
||||
self-contained (ggml statically absorbed, PIC), so `package.sh` only ships
|
||||
the binary, `run.sh` and that one .so (the parakeet-cpp-style stub layout;
|
||||
no ldd walk yet).
|
||||
|
||||
## Known limitations
|
||||
|
||||
- **Cancel granularity**: the C-ABI cancel flag is per-ctx and resets on
|
||||
every generate entry, so a Cancel racing a NEW generate can be lost, and
|
||||
with requests queued on the worker it aborts whichever generate is
|
||||
currently running (acceptable: the server de-registers the hook on normal
|
||||
completion, one process serves one model).
|
||||
- **Throughput**: ~0.15 tok/s on the 26B at default settings (GB10) - every
|
||||
denoise step recomputes the full prompt+canvas. The upstream prefix-KV
|
||||
cache (dllm.cpp P3) is the fix; `kv_cache:on` errors until it lands
|
||||
(`auto`/`off` are accepted no-ops).
|
||||
- **Repo privacy**: see the note at the top - CI clone of dllm.cpp needs the
|
||||
repo published (or credentials) before the backend images can build.
|
||||
- Engine spec/validation references: dllm.cpp `docs/validation.md` and
|
||||
LocalAI `docs/superpowers/specs/2026-06-10-dllm-cpp-design.md`.
|
||||
69
.github/backend-matrix.yml
vendored
69
.github/backend-matrix.yml
vendored
@@ -1608,19 +1608,6 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-dllm'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "dllm"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1660,19 +1647,6 @@ include:
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-dllm'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "dllm"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1792,6 +1766,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-turboquant'
|
||||
builder-base-image: 'quay.io/go-skynet/ci-cache:base-grpc-rocm-amd64'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -3171,35 +3159,6 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# dllm
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-dllm'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "dllm"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-dllm'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "dllm"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
|
||||
4
.github/workflows/bump_deps.yaml
vendored
4
.github/workflows/bump_deps.yaml
vendored
@@ -38,10 +38,6 @@ jobs:
|
||||
variable: "PARAKEET_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/parakeet-cpp/Makefile"
|
||||
- repository: "mudler/dllm.cpp"
|
||||
variable: "DLLM_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/dllm/Makefile"
|
||||
- repository: "leejet/stable-diffusion.cpp"
|
||||
variable: "STABLEDIFFUSION_GGML_VERSION"
|
||||
branch: "master"
|
||||
|
||||
@@ -26,7 +26,6 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
||||
| [.agents/vllm-backend.md](.agents/vllm-backend.md) | Working on the vLLM / vLLM-omni backends — native parsers, ChatDelta, CPU build, libnuma packaging, backend hooks |
|
||||
| [.agents/sglang-backend.md](.agents/sglang-backend.md) | Working on the SGLang backend — `engine_args` validation against ServerArgs, speculative-decoding (EAGLE/EAGLE3/DFLASH/MTP) recipes, parser handling |
|
||||
| [.agents/ds4-backend.md](.agents/ds4-backend.md) | Working on the ds4 backend - DSML state machine, thinking modes, KV cache, Metal+CUDA matrix |
|
||||
| [.agents/dllm-backend.md](.agents/dllm-backend.md) | Working on the dllm backend (DiffusionGemma block-diffusion) - purego C-ABI binding, per-ctx serialization contract, gemma4 renderer/parser, gated test layers |
|
||||
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
||||
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||
|
||||
8
Makefile
8
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/dllm backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -180,7 +180,7 @@ osx-signed: build
|
||||
|
||||
## Run
|
||||
run: ## run local-ai
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./cmd/local-ai
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
|
||||
|
||||
prepare-test: protogen-go build-mock-backend
|
||||
|
||||
@@ -1171,9 +1171,6 @@ BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|tr
|
||||
BACKEND_WHISPER = whisper|golang|.|false|true
|
||||
BACKEND_CRISPASR = crispasr|golang|.|false|true
|
||||
BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true
|
||||
# dllm is mudler/dllm.cpp, the DiffusionGemma block-diffusion engine,
|
||||
# wrapped by the purego backend at backend/go/dllm.
|
||||
BACKEND_DLLM = dllm|golang|.|false|true
|
||||
BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
||||
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
|
||||
BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true
|
||||
@@ -1263,7 +1260,6 @@ $(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CRISPASR)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_DLLM)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))
|
||||
|
||||
10
README.md
10
README.md
@@ -149,16 +149,6 @@ local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
|
||||
local-ai run oci://localai/phi-2:latest
|
||||
```
|
||||
|
||||
To test a running LocalAI server from the terminal, open an interactive chat session from another shell. Inside the prompt, `/models` lists installed models and `/model <name>` switches between them.
|
||||
|
||||
```bash
|
||||
# Terminal 1
|
||||
local-ai run llama-3.2-1b-instruct:q4_k_m
|
||||
|
||||
# Terminal 2
|
||||
local-ai chat --model llama-3.2-1b-instruct:q4_k_m
|
||||
```
|
||||
|
||||
> **Automatic Backend Detection**: LocalAI automatically detects your GPU capabilities and downloads the appropriate backend. For advanced options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/).
|
||||
|
||||
For more details, see the [Getting Started guide](https://localai.io/basics/getting_started/).
|
||||
|
||||
@@ -60,12 +60,10 @@ elseif(DS4_GPU STREQUAL "cpu")
|
||||
set(DS4_OBJS "${DS4_DIR}/ds4_cpu.o")
|
||||
endif()
|
||||
|
||||
# ds4.c now references ds4_distributed.c (distributed inference) and ds4_ssd.c
|
||||
# (SSD expert-cache), each split into its own translation unit upstream. Both
|
||||
# are GPU-agnostic objects shared by every GPU mode, so link them in regardless
|
||||
# of DS4_GPU.
|
||||
# ds4.c now references ds4_distributed.c (distributed inference was split into
|
||||
# its own translation unit upstream). It is a single GPU-agnostic object shared
|
||||
# by every GPU mode, so link it in regardless of DS4_GPU.
|
||||
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_distributed.o")
|
||||
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_ssd.o")
|
||||
|
||||
add_executable(${TARGET}
|
||||
grpc-server.cpp
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=8384adf0f9fa0f3bb342dd925372de778b95b263
|
||||
# Upstream pin lives below as DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=8384adf0f9fa0f3bb342dd925372de778b95b263
|
||||
DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
@@ -18,20 +18,19 @@ UNAME_S := $(shell uname -s)
|
||||
|
||||
CMAKE_ARGS ?= -DCMAKE_BUILD_TYPE=Release
|
||||
|
||||
# ds4_distributed.o and ds4_ssd.o are GPU-agnostic translation units that
|
||||
# ds4.c/ds4_cpu.o now reference (upstream split distributed inference and the
|
||||
# SSD expert-cache into their own .c files). Both objects are shared by every
|
||||
# GPU mode, so they are appended unconditionally below.
|
||||
# ds4_distributed.o is a GPU-agnostic translation unit that ds4.c/ds4_cpu.o now
|
||||
# reference (upstream split distributed inference into its own .c). The same
|
||||
# object is shared by every GPU mode, so it is appended unconditionally below.
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS += -DDS4_GPU=cuda
|
||||
DS4_OBJ_TARGET := ds4.o ds4_cuda.o ds4_distributed.o ds4_ssd.o
|
||||
DS4_OBJ_TARGET := ds4.o ds4_cuda.o ds4_distributed.o
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
CMAKE_ARGS += -DDS4_GPU=metal
|
||||
DS4_OBJ_TARGET := ds4.o ds4_metal.o ds4_distributed.o ds4_ssd.o
|
||||
DS4_OBJ_TARGET := ds4.o ds4_metal.o ds4_distributed.o
|
||||
else
|
||||
# CPU reference path (Linux only - macOS CPU path is broken by VM bug per ds4 README).
|
||||
CMAKE_ARGS += -DDS4_GPU=cpu
|
||||
DS4_OBJ_TARGET := ds4_cpu.o ds4_distributed.o ds4_ssd.o
|
||||
DS4_OBJ_TARGET := ds4_cpu.o ds4_distributed.o
|
||||
endif
|
||||
|
||||
ifneq ($(NATIVE),true)
|
||||
@@ -56,11 +55,11 @@ ds4:
|
||||
# the right per-platform compile flags (Objective-C/Metal on Darwin, nvcc on Linux+CUDA).
|
||||
ds4/ds4.o: ds4
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
+$(MAKE) -C ds4 ds4.o ds4_cuda.o ds4_distributed.o ds4_ssd.o
|
||||
+$(MAKE) -C ds4 ds4.o ds4_cuda.o ds4_distributed.o
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
+$(MAKE) -C ds4 ds4.o ds4_metal.o ds4_distributed.o ds4_ssd.o
|
||||
+$(MAKE) -C ds4 ds4.o ds4_metal.o ds4_distributed.o
|
||||
else
|
||||
+$(MAKE) -C ds4 ds4_cpu.o ds4_distributed.o ds4_ssd.o
|
||||
+$(MAKE) -C ds4 ds4_cpu.o ds4_distributed.o
|
||||
endif
|
||||
|
||||
grpc-server: ds4/ds4.o
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=e6f8112f3ba126eed3ff5b30cdd08085414a7516
|
||||
IK_LLAMA_VERSION?=1520eda980564241434b791ce2bbbd128c4be9ea
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=039e20a2db9e87b2477c76cc04905f3e1acad77f
|
||||
LLAMA_VERSION?=7c158fbb4aec1bdc9c81d6ca0e785139f4826fae
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -381,15 +381,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
||||
});
|
||||
}
|
||||
|
||||
// for each video in the request, add the video data
|
||||
for (int i = 0; i < predict->videos_size(); i++) {
|
||||
data["video_data"].push_back(json
|
||||
{
|
||||
{"id", i},
|
||||
{"data", predict->videos(i)},
|
||||
});
|
||||
}
|
||||
|
||||
data["stop"] = predict->stopprompts();
|
||||
// data["n_probs"] = predict->nprobs();
|
||||
//TODO: images,
|
||||
@@ -491,13 +482,23 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!request->draftmodel().empty()) {
|
||||
params.speculative.draft.mparams.path = request->draftmodel();
|
||||
// Default to draft type if a draft model is set but no explicit type.
|
||||
// Upstream made the speculative type a vector (ggml-org/llama.cpp#22838)
|
||||
// and renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE (#22964).
|
||||
// Upstream (post ggml-org/llama.cpp#22838) made the speculative type a
|
||||
// vector; the turboquant fork still uses the legacy scalar. The
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC macro is injected by
|
||||
// backend/cpp/turboquant/patch-grpc-server.sh for fork builds only.
|
||||
// Upstream renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE
|
||||
// in ggml-org/llama.cpp#22964; the fork still uses the old name.
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
||||
}
|
||||
#else
|
||||
const bool no_spec_type = params.speculative.types.empty() ||
|
||||
(params.speculative.types.size() == 1 && params.speculative.types[0] == COMMON_SPECULATIVE_TYPE_NONE);
|
||||
if (no_spec_type) {
|
||||
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE };
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// params.model_alias ??
|
||||
@@ -573,10 +574,9 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// tokens (0 disables the minimum). Match upstream's default (256). This
|
||||
// field was renamed from `checkpoint_every_nt` in llama.cpp; the semantics
|
||||
// also shifted from a fixed cadence to a minimum spacing. The turboquant
|
||||
// fork still lacks common_params::checkpoint_min_step, so skip it there
|
||||
// (LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP is injected by
|
||||
// backend/cpp/turboquant/patch-grpc-server.sh).
|
||||
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
|
||||
// fork branched before the field existed, so skip it on the legacy path
|
||||
// (LOCALAI_LEGACY_LLAMA_CPP_SPEC is injected by patch-grpc-server.sh).
|
||||
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
params.checkpoint_min_step = 256;
|
||||
#endif
|
||||
|
||||
@@ -752,7 +752,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.cache_idle_slots = false;
|
||||
}
|
||||
|
||||
#ifndef LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP
|
||||
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
// --- minimum context-checkpoint spacing (upstream -cms / --checkpoint-min-step) ---
|
||||
// 0 disables the minimum-spacing gate. Old option names (`checkpoint_every_nt`,
|
||||
// `checkpoint_every_n_tokens`) are kept as aliases for backward compatibility
|
||||
@@ -906,6 +906,17 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
|
||||
// Speculative decoding options
|
||||
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
// Fork only knows a single scalar `type`. Take the first comma-
|
||||
// separated value and assign it via the singular helper.
|
||||
std::string first = optval_str;
|
||||
const auto comma = first.find(',');
|
||||
if (comma != std::string::npos) first = first.substr(0, comma);
|
||||
auto type = common_speculative_type_from_name(first);
|
||||
if (type != COMMON_SPECULATIVE_TYPE_COUNT) {
|
||||
params.speculative.type = type;
|
||||
}
|
||||
#else
|
||||
// Upstream switched to a vector of types (comma-separated for multi-type
|
||||
// chaining via common_speculative_types_from_names). We keep accepting a
|
||||
// single value here, but also tolerate comma-separated lists.
|
||||
@@ -934,6 +945,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!parsed.empty()) {
|
||||
params.speculative.types = parsed;
|
||||
}
|
||||
#endif
|
||||
} else if (!strcmp(optname, "spec_n_max") || !strcmp(optname, "draft_max")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.draft.n_max = std::stoi(optval_str); } catch (...) {}
|
||||
@@ -971,6 +983,21 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// shares the target context size. Accept the option for backward
|
||||
// compatibility but silently ignore it.
|
||||
|
||||
// Everything below relies on struct shape introduced in ggml-org/llama.cpp#22838
|
||||
// (parallel drafting): `ngram_mod`, `ngram_map_k`, `ngram_map_k4v`,
|
||||
// `ngram_cache`, and the `draft.{cache_type_*, cpuparams*, tensor_buft_overrides}`
|
||||
// fields. The turboquant fork branched before that, so its build defines
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC via patch-grpc-server.sh and these option
|
||||
// keys become unrecognized (silently dropped, like any unknown opt) for it.
|
||||
//
|
||||
// The `#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC` / `#else` split below sits at the
|
||||
// closing-brace position of the `draft_ctx_size` branch on purpose: in the
|
||||
// legacy build the chain ends here (the brace closes draft_ctx_size), and in
|
||||
// the modern build the chain continues with `} else if (...)` instead, so the
|
||||
// brace count stays balanced under both branches of the preprocessor.
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
}
|
||||
#else
|
||||
// --- ngram_mod family (upstream --spec-ngram-mod-*) ---
|
||||
} else if (!strcmp(optname, "spec_ngram_mod_n_min")) {
|
||||
if (optval != NULL) {
|
||||
@@ -1100,6 +1127,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
}
|
||||
if (!cur.empty()) flush(cur);
|
||||
}
|
||||
#endif // LOCALAI_LEGACY_LLAMA_CPP_SPEC — closes the `else`/`#ifdef` opened at draft_ctx_size
|
||||
}
|
||||
|
||||
// Set params.n_parallel from environment variable if not set via options (fallback)
|
||||
@@ -1149,11 +1177,15 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
}
|
||||
// Terminate the draft tensor_buft_overrides list with a sentinel, mirroring
|
||||
// the main-model handling above.
|
||||
// The draft tensor_buft_overrides are only populated under the modern
|
||||
// (post-#22838) layout, whose population code is itself gated by
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC above. The turboquant fork lacks
|
||||
// common_params_speculative::draft entirely, so skip the sentinel there too.
|
||||
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
if (!params.speculative.draft.tensor_buft_overrides.empty()) {
|
||||
params.speculative.draft.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
#endif
|
||||
|
||||
// TODO: Add yarn
|
||||
|
||||
@@ -1512,7 +1544,7 @@ public:
|
||||
msg_json["role"] = msg.role();
|
||||
|
||||
bool is_last_user_msg = (i == last_user_msg_idx);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
|
||||
|
||||
// Handle content - can be string, null, or array
|
||||
// For multimodal content, we'll embed images/audio from separate fields
|
||||
@@ -1563,16 +1595,6 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
} else {
|
||||
// Use content as-is (already array or not last user message)
|
||||
@@ -1607,16 +1629,6 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
} else if (msg.role() == "tool") {
|
||||
// Tool role messages must have content field set, even if empty
|
||||
@@ -2068,16 +2080,6 @@ public:
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
|
||||
const auto &video_data = data.find("video_data");
|
||||
if (video_data != data.end() && video_data->is_array())
|
||||
{
|
||||
for (const auto &video : *video_data)
|
||||
{
|
||||
auto decoded_data = base64_decode(video["data"].get<std::string>());
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const bool has_mtmd = ctx_server.impl->mctx != nullptr;
|
||||
@@ -2330,7 +2332,7 @@ public:
|
||||
}
|
||||
|
||||
bool is_last_user_msg = (i == last_user_msg_idx);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0 || request->videos_size() > 0);
|
||||
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
|
||||
|
||||
// Handle content - can be string, null, or array
|
||||
// For multimodal content, we'll embed images/audio from separate fields
|
||||
@@ -2383,16 +2385,6 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
} else {
|
||||
// Use content as-is (already array or not last user message)
|
||||
@@ -2432,16 +2424,6 @@ public:
|
||||
content_array.push_back(audio_chunk);
|
||||
}
|
||||
}
|
||||
if (request->videos_size() > 0) {
|
||||
for (int j = 0; j < request->videos_size(); j++) {
|
||||
json video_chunk;
|
||||
video_chunk["type"] = "input_video";
|
||||
json input_video;
|
||||
input_video["data"] = request->videos(j);
|
||||
video_chunk["input_video"] = input_video;
|
||||
content_array.push_back(video_chunk);
|
||||
}
|
||||
}
|
||||
msg_json["content"] = content_array;
|
||||
SRV_INF("[CONTENT DEBUG] Predict: Message %d created content array with media\n", i);
|
||||
} else if (!msg.tool_calls().empty()) {
|
||||
@@ -2904,16 +2886,6 @@ public:
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
|
||||
const auto &video_data = data.find("video_data");
|
||||
if (video_data != data.end() && video_data->is_array())
|
||||
{
|
||||
for (const auto &video : *video_data)
|
||||
{
|
||||
auto decoded_data = base64_decode(video["data"].get<std::string>());
|
||||
files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process files
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
|
||||
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
|
||||
TURBOQUANT_VERSION?=7d9715f1f071fa07c7b2ad3dbfd320b314139e65
|
||||
TURBOQUANT_VERSION?=5aeb2fdbe26cd4c534c6fa15de73cb5749bd0403
|
||||
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -4,19 +4,21 @@
|
||||
#
|
||||
# 1. Augment the kv_cache_types[] allow-list so `LoadModel` accepts the
|
||||
# fork-specific `turbo2` / `turbo3` / `turbo4` cache types.
|
||||
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file
|
||||
# so the grpc-server option parser skips the two references to
|
||||
# common_params::checkpoint_min_step (the default and the option handler).
|
||||
# That field does not exist in the fork yet; drop this once it does.
|
||||
#
|
||||
# The fork used to lag upstream on the whole common_params_speculative refactor
|
||||
# (ggml-org/llama.cpp#22397/#22838/#22964), the model_tgt rename (#22838) and
|
||||
# get_media_marker (#21962), which required a much larger compat shim here
|
||||
# (flat-field sed renames + a coarse LOCALAI_LEGACY_LLAMA_CPP_SPEC define). The
|
||||
# fork has since rebased past all of those, so the only remaining gap is
|
||||
# checkpoint_min_step. If a future bump reintroduces a divergence, add a narrow
|
||||
# guard in grpc-server.cpp keyed on a fork-specific macro and inject it here
|
||||
# rather than resurrecting the coarse one.
|
||||
# 2. Replace `get_media_marker()` (added upstream in ggml-org/llama.cpp#21962,
|
||||
# server-side random per-instance marker) with the legacy "<__media__>"
|
||||
# literal. The fork branched before that PR, so server-common.cpp has no
|
||||
# get_media_marker symbol. The fork's mtmd_default_marker() still returns
|
||||
# "<__media__>", and Go-side tooling falls back to that sentinel when the
|
||||
# backend does not expose media_marker, so substituting the literal keeps
|
||||
# behavior identical on the turboquant path.
|
||||
# 3. Revert the `common_params_speculative` field references to the
|
||||
# pre-refactor flat layout. Upstream ggml-org/llama.cpp#22397 split the
|
||||
# struct into nested `draft` / `ngram_simple` / `ngram_mod` / etc. members;
|
||||
# the turboquant fork branched before that PR and still exposes the flat
|
||||
# `n_max`, `mparams_dft`, `ngram_size_n`, ... fields. The substitutions
|
||||
# below map the new nested paths back to the legacy flat names so the
|
||||
# shared grpc-server.cpp keeps compiling against the fork's common.h.
|
||||
# Drop this block once the fork rebases past #22397.
|
||||
#
|
||||
# We patch the *copy* sitting in turboquant-<flavor>-build/, never the original
|
||||
# under backend/cpp/llama-cpp/, so the stock llama-cpp build keeps compiling
|
||||
@@ -70,20 +72,72 @@ else
|
||||
echo "==> KV allow-list patch OK"
|
||||
fi
|
||||
|
||||
# 2. Define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top of the file so
|
||||
# the grpc-server option parser skips the two references to
|
||||
# common_params::checkpoint_min_step (the default assignment and the option
|
||||
# handler). That field does not exist in the fork yet. Drop this block once
|
||||
# the fork rebases past the bump that added checkpoint_min_step.
|
||||
if grep -q '^#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP' "$SRC"; then
|
||||
echo "==> $SRC already defines LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP, skipping"
|
||||
if grep -q 'get_media_marker()' "$SRC"; then
|
||||
echo "==> patching $SRC to replace get_media_marker() with legacy \"<__media__>\" literal"
|
||||
# Only one call site today (ModelMetadata), but replace all occurrences to
|
||||
# stay robust if upstream adds more. Use a temp file to avoid relying on
|
||||
# sed -i portability (the builder image uses GNU sed, but keeping this
|
||||
# consistent with the awk block above).
|
||||
sed 's/get_media_marker()/"<__media__>"/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> get_media_marker() substitution OK"
|
||||
else
|
||||
echo "==> patching $SRC to define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP at the top"
|
||||
# Insert the define before the very first `#include` so it precedes the
|
||||
# checkpoint_min_step references.
|
||||
echo "==> $SRC has no get_media_marker() call, skipping media-marker patch"
|
||||
fi
|
||||
|
||||
if grep -q 'params\.speculative\.draft\.\|params\.speculative\.ngram_simple\.' "$SRC"; then
|
||||
echo "==> patching $SRC to revert common_params_speculative refs to pre-#22397 flat layout"
|
||||
# Each substitution is the exact post-refactor path → legacy flat field.
|
||||
# Order doesn't matter because the source paths are disjoint, but we keep
|
||||
# the most-specific (mparams.path) first for readability.
|
||||
sed -E \
|
||||
-e 's/params\.speculative\.draft\.mparams\.path/params.speculative.mparams_dft.path/g' \
|
||||
-e 's/params\.speculative\.draft\.n_max/params.speculative.n_max/g' \
|
||||
-e 's/params\.speculative\.draft\.n_min/params.speculative.n_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_min/params.speculative.p_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_split/params.speculative.p_split/g' \
|
||||
-e 's/params\.speculative\.draft\.n_gpu_layers/params.speculative.n_gpu_layers/g' \
|
||||
-e 's/params\.speculative\.draft\.n_ctx/params.speculative.n_ctx/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_n/params.speculative.ngram_size_n/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_m/params.speculative.ngram_size_m/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.min_hits/params.speculative.ngram_min_hits/g' \
|
||||
"$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> speculative field rename OK"
|
||||
else
|
||||
echo "==> $SRC has no post-#22397 speculative field refs, skipping spec rename patch"
|
||||
fi
|
||||
|
||||
# 4. Revert the `ctx_server.impl->model_tgt` rename introduced by upstream
|
||||
# ggml-org/llama.cpp#22838 (parallel drafting). The turboquant fork still
|
||||
# exposes the field as `model` on `server_context_impl`. The two call sites
|
||||
# are in the Rerank and ModelMetadata RPC handlers.
|
||||
if grep -q 'ctx_server\.impl->model_tgt' "$SRC"; then
|
||||
echo "==> patching $SRC to revert ctx_server.impl->model_tgt -> ctx_server.impl->model"
|
||||
sed -E 's/ctx_server\.impl->model_tgt/ctx_server.impl->model/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> model_tgt rename OK"
|
||||
else
|
||||
echo "==> $SRC has no ctx_server.impl->model_tgt refs, skipping model_tgt rename patch"
|
||||
fi
|
||||
|
||||
# 5. Define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top of the file so the
|
||||
# grpc-server option parser skips the new option-handler blocks (ngram_mod,
|
||||
# ngram_map_k, ngram_map_k4v, ngram_cache, draft.cache_type_*, draft.cpuparams*,
|
||||
# draft.tensor_buft_overrides) introduced for the post-#22838 layout, the
|
||||
# draft.tensor_buft_overrides sentinel termination, and the
|
||||
# common_params::checkpoint_min_step default/option (added with the
|
||||
# 35c9b1f3 bump). Those blocks reference struct fields that simply do not
|
||||
# exist in the fork.
|
||||
if grep -q '^#define LOCALAI_LEGACY_LLAMA_CPP_SPEC' "$SRC"; then
|
||||
echo "==> $SRC already defines LOCALAI_LEGACY_LLAMA_CPP_SPEC, skipping"
|
||||
else
|
||||
echo "==> patching $SRC to define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top"
|
||||
# Insert the define before the very first `#include` so it precedes all the
|
||||
# speculative-decoding code paths.
|
||||
awk '
|
||||
!done && /^#include/ {
|
||||
print "#define LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP 1"
|
||||
print "#define LOCALAI_LEGACY_LLAMA_CPP_SPEC 1"
|
||||
print "// ^ injected by backend/cpp/turboquant/patch-grpc-server.sh"
|
||||
print ""
|
||||
done = 1
|
||||
@@ -91,13 +145,13 @@ else
|
||||
{ print }
|
||||
END {
|
||||
if (!done) {
|
||||
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP" > "/dev/stderr"
|
||||
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_LEGACY_LLAMA_CPP_SPEC" > "/dev/stderr"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> LOCALAI_TURBOQUANT_NO_CHECKPOINT_MIN_STEP define OK"
|
||||
echo "==> LOCALAI_LEGACY_LLAMA_CPP_SPEC define OK"
|
||||
fi
|
||||
|
||||
echo "==> all patches applied"
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
hip: port the turboquant CUDA additions that ggml's HIP shim doesn't cover
|
||||
|
||||
The turboquant fork adds/modifies a few ggml-cuda.cu spots with CUDA APIs
|
||||
that ggml's HIP (and MUSA) compatibility layer does not provide, breaking
|
||||
the -gpu-rocm-hipblas-turboquant build:
|
||||
|
||||
1. ggml_cuda_copy2d_across_devices() (host-staged cross-device copy for
|
||||
split mul_mat output) uses the CUDA 3D-peer copy APIs
|
||||
cudaMemcpy3DPeerParms / make_cudaPitchedPtr / make_cudaExtent /
|
||||
cudaMemcpy3DPeerAsync. HIP genuinely does not support these (see the
|
||||
fork's own comment "HIP does not support cudaMemcpy3DPeerAsync"), so
|
||||
guard the peer fast path with #if !defined(GGML_USE_HIP) &&
|
||||
!defined(GGML_USE_MUSA) -- matching how the fork already guards the
|
||||
same API for the sibling 2D copy -- and fall through to the existing
|
||||
cudaMemcpyAsync staging fallback below (functionally identical,
|
||||
slightly slower on multi-GPU ROCm).
|
||||
|
||||
2. ggml_backend_cuda_device_event_new() creates its event with plain
|
||||
cudaEventCreate, which ggml's HIP shim does not alias (it only aliases
|
||||
cudaEventCreateWithFlags). Use cudaEventCreateWithFlags(...,
|
||||
cudaEventDisableTiming) -- exactly what the rest of this file already
|
||||
does (cf. lines ~1034, ~3461) and HIP-safe.
|
||||
|
||||
CUDA builds are unaffected. Drop the relevant hunk once the fork HIP-ports
|
||||
these; apply-patches.sh fails fast if an anchor goes stale.
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 0427e6b..6352e6a 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -1933,6 +1933,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
|
||||
size_t width, size_t height, cudaStream_t dst_stream, cudaStream_t src_stream) {
|
||||
|
||||
const auto & info = ggml_cuda_info();
|
||||
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) // 3D-peer copy types unmapped by ggml's HIP/MUSA shim; use staging fallback below
|
||||
if (info.peer_access[src_device][dst_device]) {
|
||||
cudaMemcpy3DPeerParms p = {};
|
||||
p.dstDevice = dst_device;
|
||||
@@ -1942,6 +1943,7 @@ static cudaError_t ggml_cuda_copy2d_across_devices(
|
||||
p.extent = make_cudaExtent(width, height, 1);
|
||||
return cudaMemcpy3DPeerAsync(&p, dst_stream);
|
||||
}
|
||||
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
|
||||
// Fallback: stage all rows through a single contiguous pinned buffer
|
||||
int prev_device = ggml_cuda_get_device();
|
||||
@@ -5714,7 +5716,7 @@ static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_
|
||||
ggml_cuda_set_device(dev_ctx->device);
|
||||
|
||||
cudaEvent_t event;
|
||||
- CUDA_CHECK(cudaEventCreate(&event));
|
||||
+ CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
|
||||
|
||||
return new ggml_backend_event {
|
||||
/* .device = */ dev,
|
||||
@@ -14,7 +14,7 @@ target_include_directories(gocrispasr PRIVATE
|
||||
# whisper. crispasr is the referencer; the backend static libs supply the
|
||||
# per-architecture symbols; ggml is the math/runtime base.
|
||||
target_link_libraries(gocrispasr PRIVATE
|
||||
crispasr-lib
|
||||
crispasr
|
||||
parakeet canary canary_ctc cohere granite_speech granite_nle
|
||||
voxtral voxtral4b qwen3_asr qwen3_tts orpheus chatterbox indextts
|
||||
kokoro voxcpm2_tts m2m100 t5_translate wav2vec2-ggml vibevoice
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# CrispASR version (release tag)
|
||||
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
||||
CRISPASR_VERSION?=c29f6653a516a3001d923944dad8892072cc7334
|
||||
CRISPASR_VERSION?=13d54e110e1538e0f0bc3af0680b9ab246cfb48d
|
||||
SO_TARGET?=libgocrispasr.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
10
backend/go/dllm/.gitignore
vendored
10
backend/go/dllm/.gitignore
vendored
@@ -1,10 +0,0 @@
|
||||
.cache/
|
||||
sources/
|
||||
build/
|
||||
package/
|
||||
dllm-grpc
|
||||
# build artifacts staged in-tree by the Makefile (cp from sources/) or
|
||||
# symlinked for local dev; the real sources live in dllm.cpp upstream.
|
||||
*.so
|
||||
*.so.*
|
||||
compile_commands.json
|
||||
@@ -1,93 +0,0 @@
|
||||
# dllm backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DLLM_VERSION?=<sha> so .github/bump_deps.sh
|
||||
# can find and update it - matches the whisper.cpp / parakeet-cpp / ds4
|
||||
# convention.
|
||||
#
|
||||
# Local dev shortcut: if you already have an out-of-tree dllm.cpp build,
|
||||
# you can symlink the .so into this directory and skip the clone/cmake
|
||||
# steps entirely, e.g.:
|
||||
#
|
||||
# ln -sf /path/to/dllm.cpp/build/libdllm.so .
|
||||
# go build -o dllm-grpc .
|
||||
#
|
||||
# That's what the gated C-ABI binding smoke uses (DLLM_TEST_LIBRARY). The
|
||||
# default target below does the proper clone-at-pin + cmake build so CI
|
||||
# doesn't need a side-checkout.
|
||||
#
|
||||
# NOTE: github.com/mudler/dllm.cpp is still private (publishing is planned);
|
||||
# until then the anonymous clone below fails. Use the symlink shortcut above
|
||||
# with a local checkout, or a git credential helper with access to the repo.
|
||||
|
||||
DLLM_VERSION?=b22fcebebfb225131113188599a9ae542b2935d7
|
||||
DLLM_REPO?=https://github.com/mudler/dllm.cpp
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
||||
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
# libdllm.so is self-contained: dllm.cpp's CMakeLists statically absorbs ggml
|
||||
# (BUILD_SHARED_LIBS=OFF + PIC) into the shared lib, so dlopen needs no
|
||||
# libggml*.so alongside it, only system libs (libstdc++/libgomp/libc) the
|
||||
# runtime image already provides. Tests/CLI are upstream-only concerns.
|
||||
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DDLLM_BUILD_TESTS=OFF
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# Same arch set the sibling ggml backends (acestep/vibevoice/qwen3-tts) bake
|
||||
# for their cublas images; override for a native build.
|
||||
CUDA_ARCHITECTURES?=75-virtual;80-virtual;86-real;89-real
|
||||
|
||||
# dllm.cpp gates CUDA behind DLLM_CUDA (set(GGML_CUDA ... CACHE FORCE)), so
|
||||
# forward that instead of a bare -DGGML_CUDA=ON.
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DDLLM_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="$(CUDA_ARCHITECTURES)"
|
||||
endif
|
||||
|
||||
.PHONY: dllm-grpc package build clean purge test all
|
||||
|
||||
all: dllm-grpc
|
||||
|
||||
# Clone the upstream dllm.cpp source at the pinned commit (ggml comes in as
|
||||
# a submodule). Directory acts as the target so make only re-clones when
|
||||
# missing. After a DLLM_VERSION bump, run 'make purge && make' to refetch.
|
||||
sources/dllm.cpp:
|
||||
mkdir -p sources/dllm.cpp
|
||||
cd sources/dllm.cpp && \
|
||||
git init -q && \
|
||||
git remote add origin $(DLLM_REPO) && \
|
||||
git fetch --depth 1 origin $(DLLM_VERSION) && \
|
||||
git checkout FETCH_HEAD && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Build the shared lib out-of-tree, then stage it next to the Go sources so
|
||||
# purego.Dlopen("libdllm.so") and the packaging step both pick it up.
|
||||
libdllm.so: sources/dllm.cpp
|
||||
cmake -B sources/dllm.cpp/build -S sources/dllm.cpp $(CMAKE_ARGS)
|
||||
cmake --build sources/dllm.cpp/build --config Release -j$(JOBS)
|
||||
cp -fv sources/dllm.cpp/build/libdllm.so ./
|
||||
|
||||
dllm-grpc: libdllm.so main.go capi.go
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o dllm-grpc .
|
||||
|
||||
package: dllm-grpc
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
# Test target. The C-ABI binding smoke is gated on DLLM_TEST_LIBRARY +
|
||||
# DLLM_TEST_TINY_MODEL; without them the gated specs auto-skip and only the
|
||||
# pure-Go helper specs run.
|
||||
test:
|
||||
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
|
||||
|
||||
clean: purge
|
||||
rm -rf libdllm.so* package dllm-grpc
|
||||
|
||||
purge:
|
||||
rm -rf sources/dllm.cpp
|
||||
@@ -1,256 +0,0 @@
|
||||
package main
|
||||
|
||||
// Typed Go wrappers over dllm.cpp's flat C-ABI (include/dllm_capi.h, ABI v1).
|
||||
//
|
||||
// Contract highlights the wrappers encode (see the header + src/capi.cpp):
|
||||
// - tokenize_json/generate return malloc'd char* the CALLER owns: bound as
|
||||
// uintptr, copied with goStringFromCPtr, released via dllm_capi_free_string.
|
||||
// - last_error returns a BORROWED pointer (valid until the next call on the
|
||||
// same ctx): bound as a plain string (purego copies), never freed, and only
|
||||
// read AFTER the failing call has returned - reading it while a generate is
|
||||
// in flight on the same ctx violates the per-ctx serialization contract.
|
||||
// - All entry points except dllm_capi_cancel must be externally serialized
|
||||
// per ctx (one ctx = one concurrent generate/tokenize). Cancel only flips
|
||||
// an atomic and may be called from any goroutine mid-generate.
|
||||
// - No C++ exception crosses the boundary; failures land in last_error.
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
)
|
||||
|
||||
// dllmABIVersion is the DLLM_CAPI_ABI_VERSION this binding was written
|
||||
// against; main.go refuses to start against a libdllm.so reporting another.
|
||||
const dllmABIVersion = 1
|
||||
|
||||
// purego-bound entry points from libdllm.so. Names match dllm_capi.h
|
||||
// exactly; loadCAPI (main.go) fills these in at boot.
|
||||
var (
|
||||
cppAbiVersion func() int32
|
||||
cppLoad func(ggufPath, paramsJSON string) uintptr
|
||||
cppFree func(ctx uintptr)
|
||||
cppLastError func(ctx uintptr) string // borrowed pointer: purego copies, do NOT free
|
||||
cppFreeString func(s uintptr)
|
||||
// malloc'd char* returns, hence uintptr (see loadCAPI's doc comment).
|
||||
cppTokenizeJSON func(ctx uintptr, text string) uintptr
|
||||
cppGenerate func(ctx uintptr, prompt, optsJSON string) uintptr
|
||||
// on_block/on_step are C function pointers produced by purego.NewCallback;
|
||||
// userData carries the streamCallStates registry key.
|
||||
cppGenerateStream func(ctx uintptr, prompt, optsJSON string, onBlock, onStep, userData uintptr) int32
|
||||
cppCancel func(ctx uintptr)
|
||||
)
|
||||
|
||||
// cAbiVersion returns the library's DLLM_CAPI_ABI_VERSION.
|
||||
func cAbiVersion() int32 {
|
||||
return cppAbiVersion()
|
||||
}
|
||||
|
||||
// cLoad opens the GGUF at path with the flat params JSON (e.g.
|
||||
// {"n_gpu_layers":99}). Returns 0 on failure; per the header contract there
|
||||
// is no ctx to carry the reason, the C side logs it to stderr (and
|
||||
// cLastError(0) only yields the static NULL-ctx message).
|
||||
func cLoad(path, paramsJSON string) uintptr {
|
||||
return cppLoad(path, paramsJSON)
|
||||
}
|
||||
|
||||
// cFree releases a ctx; safe on 0 (delete nullptr).
|
||||
func cFree(h uintptr) {
|
||||
cppFree(h)
|
||||
}
|
||||
|
||||
// cLastError returns the ctx's last error message (or the static NULL-ctx
|
||||
// message for h==0). The C pointer is borrowed and only valid until the next
|
||||
// call on the same ctx; purego's string return copies it immediately, so the
|
||||
// returned Go string is safe to keep. Must not be called while another call
|
||||
// on the same ctx is in flight.
|
||||
func cLastError(h uintptr) string {
|
||||
return cppLastError(h)
|
||||
}
|
||||
|
||||
// lastErrorOr is cLastError with a fallback for the empty-message case, so
|
||||
// wrapped errors never end in ": ".
|
||||
func lastErrorOr(h uintptr, fallback string) string {
|
||||
if msg := cLastError(h); msg != "" {
|
||||
return msg
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// cTokenizeJSON tokenizes text (the C side prepends bos per vocab.add_bos)
|
||||
// and returns the token ids as a JSON array string, e.g. "[2,18]".
|
||||
func cTokenizeJSON(h uintptr, text string) (string, error) {
|
||||
ret := cppTokenizeJSON(h, text)
|
||||
if ret == 0 {
|
||||
return "", fmt.Errorf("dllm: tokenize failed: %s", lastErrorOr(h, "unknown error"))
|
||||
}
|
||||
out := goStringFromCPtr(ret)
|
||||
cppFreeString(ret)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// cGenerate runs a blocking generation and returns the detokenized text.
|
||||
// optsJSON must be a FLAT JSON object of scalars (use buildOptsJSON); the C
|
||||
// parser rejects nested objects/arrays. NULL return -> last_error (read only
|
||||
// after the call returned, per the serialization contract); a cancelled call
|
||||
// surfaces as the "cancelled" message.
|
||||
func cGenerate(h uintptr, prompt, optsJSON string) (string, error) {
|
||||
ret := cppGenerate(h, prompt, optsJSON)
|
||||
if ret == 0 {
|
||||
return "", fmt.Errorf("dllm: generate failed: %s", lastErrorOr(h, "unknown error"))
|
||||
}
|
||||
out := goStringFromCPtr(ret)
|
||||
cppFreeString(ret)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// streamCallState carries the Go callbacks for one in-flight
|
||||
// cGenerateStream call; the registry key travels through C as user_data.
|
||||
// The map shape mirrors the whisper backend's streamCallStates: only one
|
||||
// entry per ctx is ever live (the C-ABI is serialized per ctx), but keying
|
||||
// by call survives multiple models/processes sharing the package.
|
||||
type streamCallState struct {
|
||||
onBlock func(text string)
|
||||
onStep func(step, total int, preview string)
|
||||
}
|
||||
|
||||
var (
|
||||
streamCallStates sync.Map // uint64 -> *streamCallState
|
||||
streamCallSeq atomic.Uint64
|
||||
|
||||
// purego.NewCallback allocates a finite, never-released callback slot, so
|
||||
// the two trampolines are created exactly once and reused across calls.
|
||||
streamCbOnce sync.Once
|
||||
blockCbPtr uintptr
|
||||
stepCbPtr uintptr
|
||||
)
|
||||
|
||||
// onBlockTrampoline is the Go side of dllm_block_cb. It runs on the C
|
||||
// calling thread, mid-generate: keep it tiny and non-blocking (callers that
|
||||
// bridge to goroutines must hand off via buffered channels). The text
|
||||
// pointer is only valid for the duration of the invocation, so it is copied
|
||||
// to a Go string immediately.
|
||||
func onBlockTrampoline(text uintptr, userData uintptr) {
|
||||
v, ok := streamCallStates.Load(uint64(userData))
|
||||
if !ok {
|
||||
return // call already torn down
|
||||
}
|
||||
state := v.(*streamCallState)
|
||||
if state.onBlock != nil {
|
||||
state.onBlock(goStringFromCPtr(text))
|
||||
}
|
||||
}
|
||||
|
||||
// onStepTrampoline is the Go side of dllm_step_cb; same threading and
|
||||
// lifetime caveats as onBlockTrampoline.
|
||||
func onStepTrampoline(step int32, totalSteps int32, canvasPreview uintptr, userData uintptr) {
|
||||
v, ok := streamCallStates.Load(uint64(userData))
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
state := v.(*streamCallState)
|
||||
if state.onStep != nil {
|
||||
state.onStep(int(step), int(totalSteps), goStringFromCPtr(canvasPreview))
|
||||
}
|
||||
}
|
||||
|
||||
// cGenerateStream runs a generation with per-committed-block (onBlock) and
|
||||
// per-denoising-step (onStep) callbacks; either may be nil. The callbacks
|
||||
// run on the C thread (see the trampoline docs). Returns an error carrying
|
||||
// last_error on failure; cancellation surfaces as the "cancelled" message.
|
||||
func cGenerateStream(h uintptr, prompt, optsJSON string, onBlock func(text string), onStep func(step, total int, preview string)) error {
|
||||
streamCbOnce.Do(func() {
|
||||
blockCbPtr = purego.NewCallback(onBlockTrampoline)
|
||||
stepCbPtr = purego.NewCallback(onStepTrampoline)
|
||||
})
|
||||
|
||||
id := streamCallSeq.Add(1)
|
||||
streamCallStates.Store(id, &streamCallState{onBlock: onBlock, onStep: onStep})
|
||||
defer streamCallStates.Delete(id)
|
||||
|
||||
// Pass NULL for absent callbacks so the C side skips the per-block /
|
||||
// per-step detokenize work entirely.
|
||||
var blockPtr, stepPtr uintptr
|
||||
if onBlock != nil {
|
||||
blockPtr = blockCbPtr
|
||||
}
|
||||
if onStep != nil {
|
||||
stepPtr = stepCbPtr
|
||||
}
|
||||
|
||||
if rc := cppGenerateStream(h, prompt, optsJSON, blockPtr, stepPtr, uintptr(id)); rc != 0 {
|
||||
return fmt.Errorf("dllm: generate_stream failed: %s", lastErrorOr(h, "unknown error"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cCancel requests cancellation of the in-flight generate on h. This is the
|
||||
// ONE entry point safe to call from any goroutine while a generate runs (it
|
||||
// only flips an atomic). Note the cancel-reset race from the header: each
|
||||
// generate resets the flag on entry, so a watchdog should re-issue cancel if
|
||||
// the call has not returned.
|
||||
func cCancel(h uintptr) {
|
||||
cppCancel(h)
|
||||
}
|
||||
|
||||
// buildOptsJSON renders generation options as the flat JSON object the
|
||||
// C-ABI expects (known keys: n_predict, blocks, seed, eb_*, kv_cache). The
|
||||
// C-side scanner only understands scalar number/string values and rejects
|
||||
// nested objects/arrays loudly; bools are rejected here too because the
|
||||
// scanner has no concept of them. Fail loud rather than let an option be
|
||||
// silently misread.
|
||||
//
|
||||
// CAVEAT: json.Marshal HTML-escapes <, > and & inside string values (e.g.
|
||||
// "<" becomes the six-byte \u003c sequence). None of the known string-valued keys
|
||||
// (kv_cache: auto|on|off) can contain those bytes today; if one ever does,
|
||||
// switch to an Encoder with SetEscapeHTML(false) like gemma4JSONString.
|
||||
func buildOptsJSON(opts map[string]any) (string, error) {
|
||||
if len(opts) == 0 {
|
||||
return "{}", nil
|
||||
}
|
||||
for k, v := range opts {
|
||||
switch v.(type) {
|
||||
case string,
|
||||
int, int8, int16, int32, int64,
|
||||
uint, uint8, uint16, uint32, uint64,
|
||||
float32, float64,
|
||||
json.Number:
|
||||
// scalar: fine
|
||||
default:
|
||||
return "", fmt.Errorf("dllm: opts key %q has non-scalar value %T (the C-ABI only accepts flat number/string scalars)", k, v)
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(opts)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("dllm: marshal opts: %w", err)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// goStringFromCPtr copies a NUL-terminated C string into Go memory. cptr is
|
||||
// the raw pointer returned by purego from the C-ABI (a malloc'd buffer the
|
||||
// caller owns, or a callback argument only valid during the invocation);
|
||||
// owning callers must free it via cppFreeString after the copy lands.
|
||||
//
|
||||
// A direct unsafe.Pointer(cptr) conversion trips go vet's unsafeptr check,
|
||||
// which can't distinguish a C-owned heap pointer from Go-managed memory (the
|
||||
// parakeet-cpp and whisper backends tolerate that warning). Reinterpreting
|
||||
// through &cptr below is equivalent at runtime and keeps plain `go vet`
|
||||
// clean. It is safe either way: the pointer addresses C memory the Go GC
|
||||
// neither tracks nor moves, and we dereference it immediately to copy the
|
||||
// bytes out.
|
||||
func goStringFromCPtr(cptr uintptr) string {
|
||||
if cptr == 0 {
|
||||
return ""
|
||||
}
|
||||
p := *(*unsafe.Pointer)(unsafe.Pointer(&cptr)) // C-owned buffer, not Go-GC memory (see doc above)
|
||||
n := 0
|
||||
for *(*byte)(unsafe.Add(p, n)) != 0 {
|
||||
n++
|
||||
}
|
||||
return string(unsafe.Slice((*byte)(p), n))
|
||||
}
|
||||
@@ -1,553 +0,0 @@
|
||||
package main
|
||||
|
||||
// LocalAI gRPC backend for dllm.cpp (DiffusionGemma block-diffusion models).
|
||||
//
|
||||
// Wiring overview:
|
||||
// - Load opens the GGUF via dllm_capi_load and starts the per-model worker
|
||||
// goroutine that serializes every C call (see submit).
|
||||
// - PredictRich / PredictStreamRich implement grpc.AIModelRich: when the
|
||||
// request carries raw messages (use_tokenizer_template), the backend owns
|
||||
// templating (RenderGemma4) and output parsing (Gemma4Parser) and replies
|
||||
// with ChatDeltas, like the llama.cpp autoparser and the ds4 backend.
|
||||
// - The legacy Predict / PredictStream methods delegate to the rich pair
|
||||
// (cloud-proxy precedent); the gRPC server prefers the rich path anyway.
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// The gRPC server cancels in-flight generations on client disconnect only
|
||||
// for backends advertising the Cancellable capability; keep Dllm pinned to
|
||||
// it so a signature drift fails the build, not the disconnect path.
|
||||
var _ grpc.Cancellable = (*Dllm)(nil)
|
||||
|
||||
// generator is the seam between the backend wiring and the dllm.cpp C-ABI:
|
||||
// the real implementation (capiGenerator) wraps the cGenerate/cTokenizeJSON
|
||||
// family, while tests substitute a fake to exercise prompt construction,
|
||||
// parsing and serialization without libdllm.so.
|
||||
type generator interface {
|
||||
generate(prompt, optsJSON string) (string, error)
|
||||
// generateStream invokes onBlock once per committed diffusion block, on
|
||||
// the thread running the C call, before returning.
|
||||
generateStream(prompt, optsJSON string, onBlock func(text string)) error
|
||||
tokenizeJSON(text string) (string, error)
|
||||
// cancel is the ONE entry point safe to call concurrently with an
|
||||
// in-flight generate on the same ctx (dllm_capi.h: it only flips an
|
||||
// atomic; everything else must be externally serialized per ctx).
|
||||
cancel()
|
||||
free()
|
||||
}
|
||||
|
||||
// capiGenerator is the production generator over one dllm_ctx handle.
|
||||
type capiGenerator struct {
|
||||
h uintptr
|
||||
}
|
||||
|
||||
func (g *capiGenerator) generate(prompt, optsJSON string) (string, error) {
|
||||
return cGenerate(g.h, prompt, optsJSON)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
|
||||
// on_step (per-denoise-step canvas preview, dllm.cpp's --visual) is
|
||||
// passed as nil for now: a future progress hook for the React UI can
|
||||
// plumb it through without touching the C binding.
|
||||
return cGenerateStream(g.h, prompt, optsJSON, onBlock, nil)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) tokenizeJSON(text string) (string, error) {
|
||||
return cTokenizeJSON(g.h, text)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) cancel() {
|
||||
cCancel(g.h)
|
||||
}
|
||||
|
||||
func (g *capiGenerator) free() {
|
||||
cFree(g.h)
|
||||
}
|
||||
|
||||
// Dllm is the gRPC backend instance: one per loaded model (LocalAI starts
|
||||
// one backend process per model).
|
||||
type Dllm struct {
|
||||
base.Base
|
||||
|
||||
gen generator
|
||||
// genOpts holds the model-level generation overrides parsed from
|
||||
// ModelOptions.Options at Load (eb_*, blocks, kv_cache). The C-ABI takes
|
||||
// them per-generate, not per-load, so they are merged into every
|
||||
// request's opts JSON (requestOptsJSON).
|
||||
genOpts map[string]any
|
||||
|
||||
// jobs is the per-model worker queue. dllm_capi.h requires every entry
|
||||
// point EXCEPT dllm_capi_cancel to be externally serialized per ctx (one
|
||||
// ctx = one concurrent generate/tokenize; last_error is unsafe to read
|
||||
// while a call is in flight). A single goroutine owning all C calls makes
|
||||
// that contract structural instead of relying on lock discipline.
|
||||
jobs chan func()
|
||||
workerWG sync.WaitGroup
|
||||
|
||||
// genMu guards gen against Free racing in-flight requests: requests hold
|
||||
// the read lock for their full duration (they stay concurrent with each
|
||||
// other - the worker still serializes the C calls), Free takes the write
|
||||
// lock so it can only run when no request is in flight.
|
||||
genMu sync.RWMutex
|
||||
}
|
||||
|
||||
func (d *Dllm) startWorker() {
|
||||
d.jobs = make(chan func())
|
||||
d.workerWG.Add(1)
|
||||
go func() {
|
||||
defer d.workerWG.Done()
|
||||
for job := range d.jobs {
|
||||
job()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// submit runs job on the worker goroutine and waits for it to finish.
|
||||
// Concurrent gRPC requests therefore queue up and execute one at a time
|
||||
// against the single dllm_ctx.
|
||||
func (d *Dllm) submit(job func()) {
|
||||
done := make(chan struct{})
|
||||
d.jobs <- func() {
|
||||
defer close(done)
|
||||
job()
|
||||
}
|
||||
<-done
|
||||
}
|
||||
|
||||
// Load opens the GGUF and prepares the worker. Load-time engine parameters
|
||||
// travel as the flat params JSON of dllm_capi_load; generation overrides
|
||||
// from Options are stored for per-request opts JSON instead (the C-ABI has
|
||||
// no per-load sampler state).
|
||||
func (d *Dllm) Load(opts *pb.ModelOptions) error {
|
||||
if d.gen != nil {
|
||||
return errors.New("dllm: model already loaded")
|
||||
}
|
||||
|
||||
params := map[string]any{
|
||||
"n_gpu_layers": opts.GetNGPULayers(),
|
||||
}
|
||||
if opts.GetThreads() > 0 {
|
||||
params["n_threads"] = opts.GetThreads()
|
||||
}
|
||||
if opts.GetContextSize() > 0 {
|
||||
params["ctx_len"] = opts.GetContextSize()
|
||||
}
|
||||
paramsJSON, err := buildOptsJSON(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.genOpts = parseModelGenOpts(opts.GetOptions())
|
||||
|
||||
h := cLoad(opts.GetModelFile(), paramsJSON)
|
||||
if h == 0 {
|
||||
// No ctx exists on load failure, so last_error(NULL) only carries the
|
||||
// static NULL-ctx message; the real reason is on the backend's stderr.
|
||||
return fmt.Errorf("dllm: load %q failed: %s (see backend log for details)",
|
||||
opts.GetModelFile(), lastErrorOr(0, "unknown error"))
|
||||
}
|
||||
d.gen = &capiGenerator{h: h}
|
||||
d.startWorker()
|
||||
xlog.Info("dllm: model loaded", "model", opts.GetModelFile(), "params", paramsJSON, "gen_opts", d.genOpts)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Free releases the dllm ctx and stops the worker. Safe when never loaded.
|
||||
//
|
||||
// The write lock is essential: the gRPC server (pkg/grpc/server.go, see the
|
||||
// model-unload path around line 764) calls Free with no locking of its own,
|
||||
// and base.Base provides none either. Without it a request racing Free would
|
||||
// panic sending on the closed jobs channel - or worse, generate on a freed C
|
||||
// ctx. Holding genMu until gen is nil also turns post-Free requests into a
|
||||
// clean "model not loaded" error instead of a crash.
|
||||
func (d *Dllm) Free() error {
|
||||
d.genMu.Lock()
|
||||
defer d.genMu.Unlock()
|
||||
if d.gen == nil {
|
||||
return nil
|
||||
}
|
||||
d.submit(d.gen.free)
|
||||
close(d.jobs)
|
||||
d.workerWG.Wait()
|
||||
d.gen = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel requests cancellation of the in-flight generate (the
|
||||
// grpc.Cancellable capability). The gRPC server arms it via
|
||||
// context.AfterFunc on the request/stream context, so a client
|
||||
// disconnect or timeout aborts the generation server-side - the same
|
||||
// semantics the llama.cpp C++ backend gets from polling IsCancelled().
|
||||
// It deliberately bypasses the worker queue: dllm_capi_cancel is the one
|
||||
// call the C-ABI allows from any goroutine mid-generate (it only flips
|
||||
// an atomic).
|
||||
//
|
||||
// Note dllm_capi.h's cancel-reset race: each generate resets the flag on
|
||||
// entry, so a Cancel racing a NEW generate on the same ctx can be lost
|
||||
// (and, with requests queued on the worker, it aborts whichever generate
|
||||
// is currently running). The single-flag granularity is acceptable here
|
||||
// because the server de-registers the hook on normal completion and one
|
||||
// backend process serves one model.
|
||||
func (d *Dllm) Cancel() {
|
||||
// RLock so a server-side AfterFunc firing in the window between a
|
||||
// request finishing and a model unload cannot touch a freed C ctx
|
||||
// (Free holds the write lock while tearing gen down). cancel() is the
|
||||
// one C call that is safe concurrently with an in-flight generate, so
|
||||
// taking a read lock here cannot deadlock against request holders.
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen != nil {
|
||||
d.gen.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// dllmGenOptKeys are the ModelOptions.Options keys this backend forwards to
|
||||
// the engine. Options is a shared free-form bag (other layers put their own
|
||||
// entries there), so unknown keys are skipped with a warning, not an error.
|
||||
var dllmGenOptKeys = map[string]bool{
|
||||
"blocks": true,
|
||||
"kv_cache": true, // "auto"|"on"|"off"; honored by the engine from P3
|
||||
}
|
||||
|
||||
// parseModelGenOpts parses "key:value" Options entries into the flat scalar
|
||||
// map merged into every generate's opts JSON. eb_* (Entropy-Bound sampler
|
||||
// knobs) and the keys in dllmGenOptKeys are recognized; values are typed by
|
||||
// first successful parse (int, then float, else string) to match the C
|
||||
// scanner's number/string scalars.
|
||||
func parseModelGenOpts(options []string) map[string]any {
|
||||
out := map[string]any{}
|
||||
for _, o := range options {
|
||||
key, val, found := strings.Cut(o, ":")
|
||||
if !found {
|
||||
xlog.Warn("dllm: ignoring malformed option (want key:value)", "option", o)
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(key, "eb_") && !dllmGenOptKeys[key] {
|
||||
xlog.Debug("dllm: ignoring unrecognized option", "key", key)
|
||||
continue
|
||||
}
|
||||
out[key] = parseScalarOpt(val)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseScalarOpt(v string) any {
|
||||
if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return iv
|
||||
}
|
||||
if fv, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return fv
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// metadataEnableThinking reads the enable_thinking gate. Unlike ds4 (default
|
||||
// ON, matching ds4-server), dllm defaults OFF: DiffusionGemma's chat
|
||||
// template guards every thinking branch with `enable_thinking is defined and
|
||||
// enable_thinking`, i.e. thinking is opt-in for this model family, and the
|
||||
// no-thinking render pre-closes an empty thought channel that the OFF
|
||||
// default must produce.
|
||||
func metadataEnableThinking(opts *pb.PredictOptions) bool {
|
||||
v := opts.GetMetadata()["enable_thinking"]
|
||||
return v == "true" || v == "1"
|
||||
}
|
||||
|
||||
// buildPrompt resolves the prompt for a request. With use_tokenizer_template
|
||||
// and raw messages the backend owns templating (RenderGemma4) and the output
|
||||
// is in the known gemma4 format, so parse=true. Without it the caller
|
||||
// templated the prompt themselves (LocalAI's Go templates + PEG fallback, or
|
||||
// a bare completion): the prompt passes through verbatim and the output is
|
||||
// NOT gemma4-parsed - it is emitted as plain content and the Go side's
|
||||
// extraction applies, as for any non-autoparsing backend.
|
||||
func buildPrompt(opts *pb.PredictOptions) (prompt string, parse bool, err error) {
|
||||
if opts.GetUseTokenizerTemplate() && len(opts.GetMessages()) > 0 {
|
||||
prompt, err = RenderGemma4(opts.GetMessages(), opts.GetTools(), metadataEnableThinking(opts), true)
|
||||
return prompt, true, err
|
||||
}
|
||||
return opts.GetPrompt(), false, nil
|
||||
}
|
||||
|
||||
// requestOptsJSON merges the model-level overrides with the request's
|
||||
// sampling fields into the flat opts JSON for one generate call.
|
||||
func (d *Dllm) requestOptsJSON(opts *pb.PredictOptions) (string, error) {
|
||||
m := make(map[string]any, len(d.genOpts)+2)
|
||||
for k, v := range d.genOpts {
|
||||
m[k] = v
|
||||
}
|
||||
if n := opts.GetTokens(); n > 0 {
|
||||
// The engine rounds n_predict UP to a whole number of diffusion
|
||||
// blocks (the canvas is denoised block-wise), so the completion may
|
||||
// run slightly past the requested budget. Tokens==0 omits the key so
|
||||
// the C-ABI default of 256 applies (hardcoded in capi.cpp's
|
||||
// parse_gen_opts, independent of canvas_length).
|
||||
m["n_predict"] = n
|
||||
}
|
||||
if s := opts.GetSeed(); s > 0 {
|
||||
// The engine seeds mt19937 with explicit non-negative seeds. Seed<=0
|
||||
// is omitted: proto3 cannot distinguish 0 from unset, and negative
|
||||
// values conventionally mean "random" across LocalAI backends.
|
||||
m["seed"] = s
|
||||
}
|
||||
return buildOptsJSON(m)
|
||||
}
|
||||
|
||||
// prepareRequest is the shared prologue of the rich methods: resolve the
|
||||
// prompt (and whether the output gets gemma4-parsed) and build the per-call
|
||||
// opts JSON.
|
||||
func (d *Dllm) prepareRequest(opts *pb.PredictOptions) (prompt string, parse bool, optsJSON string, err error) {
|
||||
prompt, parse, err = buildPrompt(opts)
|
||||
if err != nil {
|
||||
return "", false, "", err
|
||||
}
|
||||
optsJSON, err = d.requestOptsJSON(opts)
|
||||
if err != nil {
|
||||
return "", false, "", err
|
||||
}
|
||||
return prompt, parse, optsJSON, nil
|
||||
}
|
||||
|
||||
// sanitizeUTF8 makes s safe for a proto3 string field. Block-boundary
|
||||
// detokenization and byte-fallback tokens can produce invalid UTF-8, and
|
||||
// grpc-go refuses to marshal it ("string field contains invalid UTF-8"), so
|
||||
// every string destined for a Reply/ChatDelta must pass through here (or
|
||||
// through splitValidUTF8, which calls it). Lone malformed bytes are genuinely
|
||||
// undecodable: replace with U+FFFD rather than crash the stream.
|
||||
func sanitizeUTF8(s string) string {
|
||||
if utf8.ValidString(s) {
|
||||
return s
|
||||
}
|
||||
return strings.ToValidUTF8(s, "<22>")
|
||||
}
|
||||
|
||||
// utf8SeqLen returns the declared sequence length of a UTF-8 leading byte
|
||||
// (1 for bytes that can never lead a multi-byte sequence, so they are never
|
||||
// held back and fall through to sanitizeUTF8's replacement).
|
||||
func utf8SeqLen(b byte) int {
|
||||
switch {
|
||||
case b&0xE0 == 0xC0:
|
||||
return 2
|
||||
case b&0xF0 == 0xE0:
|
||||
return 3
|
||||
case b&0xF8 == 0xF0:
|
||||
return 4
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
// splitValidUTF8 prepends the previous block's carry to the new block and
|
||||
// splits the result into text safe to emit now and a trailing INCOMPLETE
|
||||
// UTF-8 sequence (at most utf8.UTFMax-1 bytes) to carry into the next block:
|
||||
// the per-block detokenize can split a multi-byte character across block
|
||||
// boundaries (llama.cpp's grpc-server holds back the same way). Only a
|
||||
// suffix that can still become a valid rune is withheld; bytes that are
|
||||
// already undecodable are replaced immediately so the carry stays bounded.
|
||||
func splitValidUTF8(carry, block string) (emit, newCarry string) {
|
||||
s := carry + block
|
||||
cut := len(s)
|
||||
for i := len(s) - 1; i >= 0 && len(s)-i < utf8.UTFMax; i-- {
|
||||
b := s[i]
|
||||
if b < utf8.RuneSelf {
|
||||
break // ASCII: everything before the tail scan is complete
|
||||
}
|
||||
if !utf8.RuneStart(b) {
|
||||
continue // continuation byte: keep looking for its leading byte
|
||||
}
|
||||
// Leading byte: hold the sequence back iff it declares more bytes
|
||||
// than the stream has produced so far (it may complete next block).
|
||||
if utf8SeqLen(b) > len(s)-i {
|
||||
cut = i
|
||||
}
|
||||
break
|
||||
}
|
||||
return sanitizeUTF8(s[:cut]), s[cut:]
|
||||
}
|
||||
|
||||
// PredictRich is the non-streaming inference path (grpc.AIModelRich).
|
||||
// Returns one Reply whose Message is the aggregated assistant content and
|
||||
// whose ChatDeltas carry the parsed content/reasoning/tool-call events.
|
||||
func (d *Dllm) PredictRich(opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen == nil {
|
||||
return nil, grpcerrors.ModelNotLoaded("dllm")
|
||||
}
|
||||
prompt, parse, optsJSON, err := d.prepareRequest(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out string
|
||||
var genErr error
|
||||
d.submit(func() {
|
||||
out, genErr = d.gen.generate(prompt, optsJSON)
|
||||
})
|
||||
if genErr != nil {
|
||||
return nil, genErr
|
||||
}
|
||||
// Byte-fallback tokens can detokenize to invalid UTF-8; proto3 strings
|
||||
// must be valid or grpc-go fails the whole reply at marshal time.
|
||||
out = sanitizeUTF8(out)
|
||||
|
||||
if !parse {
|
||||
// Raw-prompt mode: plain content, no gemma4 parsing (see buildPrompt).
|
||||
return &pb.Reply{Message: []byte(out), ChatDeltas: []*pb.ChatDelta{{Content: out}}}, nil
|
||||
}
|
||||
|
||||
// The prompt renders with add_generation_prompt; both thinking modes
|
||||
// leave the model starting in content state (see the Gemma4Parser header
|
||||
// comment), hence NewGemma4Parser(false).
|
||||
parser := NewGemma4Parser(false)
|
||||
if reply := replyFromDeltas(append(parser.Feed(out), parser.Close()...)); reply != nil {
|
||||
return reply, nil
|
||||
}
|
||||
// Everything was markers (or out was empty): an empty but non-nil Reply.
|
||||
return &pb.Reply{}, nil
|
||||
}
|
||||
|
||||
// PredictStreamRich is the streaming counterpart (grpc.AIModelRich): one
|
||||
// Reply per committed diffusion block that produced deltas. Per the
|
||||
// interface contract the channel is only sent into here - the gRPC server
|
||||
// closes it after this returns (opposite to legacy PredictStream).
|
||||
func (d *Dllm) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen == nil {
|
||||
return grpcerrors.ModelNotLoaded("dllm")
|
||||
}
|
||||
prompt, parse, optsJSON, err := d.prepareRequest(opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var parser *Gemma4Parser
|
||||
if parse {
|
||||
parser = NewGemma4Parser(false)
|
||||
}
|
||||
// emit runs inside onBlock, i.e. on the thread driving the C generate.
|
||||
// Sending on results can block on a slow consumer, but the server-side
|
||||
// pump (pkg/grpc/server.go PredictStream) drains continuously and drops
|
||||
// undeliverable sends, so this backpressure is brief and bounded - and
|
||||
// pausing the diffusion loop under it is the desired behavior anyway.
|
||||
emit := func(text string) {
|
||||
if !parse {
|
||||
if text != "" {
|
||||
results <- &pb.Reply{Message: []byte(text), ChatDeltas: []*pb.ChatDelta{{Content: text}}}
|
||||
}
|
||||
return
|
||||
}
|
||||
deltas := parser.Feed(text)
|
||||
if reply := replyFromDeltas(deltas); reply != nil {
|
||||
results <- reply
|
||||
}
|
||||
}
|
||||
// onBlock guards emit (and through it the parser) against invalid UTF-8:
|
||||
// a multi-byte character split across block boundaries is held back until
|
||||
// it completes (see splitValidUTF8), so proto3 marshaling never fails.
|
||||
var carry string
|
||||
onBlock := func(block string) {
|
||||
var text string
|
||||
text, carry = splitValidUTF8(carry, block)
|
||||
emit(text)
|
||||
}
|
||||
|
||||
var genErr error
|
||||
d.submit(func() {
|
||||
genErr = d.gen.generateStream(prompt, optsJSON, onBlock)
|
||||
})
|
||||
if genErr != nil {
|
||||
return genErr
|
||||
}
|
||||
if carry != "" {
|
||||
// The stream ended mid-sequence: the held-back bytes can no longer
|
||||
// complete, so flush them through the U+FFFD last resort.
|
||||
emit(sanitizeUTF8(carry))
|
||||
}
|
||||
if parse {
|
||||
if reply := replyFromDeltas(parser.Close()); reply != nil {
|
||||
results <- reply
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// replyFromDeltas wraps one batch of parsed deltas into a streaming Reply,
|
||||
// or nil when the batch is empty (markers consumed, nothing emitted yet).
|
||||
// Message mirrors the batch's content text so legacy chan-string consumers
|
||||
// see exactly the displayed tokens.
|
||||
func replyFromDeltas(deltas []*pb.ChatDelta) *pb.Reply {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
var content strings.Builder
|
||||
for _, delta := range deltas {
|
||||
content.WriteString(delta.GetContent())
|
||||
}
|
||||
return &pb.Reply{Message: []byte(content.String()), ChatDeltas: deltas}
|
||||
}
|
||||
|
||||
// Predict is the legacy (string, error) signature; the gRPC server prefers
|
||||
// PredictRich, this exists for non-rich callers (cloud-proxy precedent).
|
||||
func (d *Dllm) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
reply, err := d.PredictRich(opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(reply.GetMessage()), nil
|
||||
}
|
||||
|
||||
// PredictStream is the legacy chan-string path: rich replies reduced to
|
||||
// their content text. Note the inverted channel ownership - the LEGACY
|
||||
// contract requires the impl to close the channel.
|
||||
func (d *Dllm) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
defer close(results)
|
||||
richCh := make(chan *pb.Reply)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- d.PredictStreamRich(opts, richCh)
|
||||
close(richCh)
|
||||
}()
|
||||
for reply := range richCh {
|
||||
if msg := reply.GetMessage(); len(msg) > 0 {
|
||||
results <- string(msg)
|
||||
}
|
||||
}
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
// TokenizeString tokenizes opts.Prompt via dllm_capi_tokenize_json (the C
|
||||
// side prepends bos per the vocab) and decodes the returned id array.
|
||||
func (d *Dllm) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
|
||||
d.genMu.RLock()
|
||||
defer d.genMu.RUnlock()
|
||||
if d.gen == nil {
|
||||
return pb.TokenizationResponse{}, grpcerrors.ModelNotLoaded("dllm")
|
||||
}
|
||||
var out string
|
||||
var tokErr error
|
||||
d.submit(func() {
|
||||
out, tokErr = d.gen.tokenizeJSON(opts.GetPrompt())
|
||||
})
|
||||
if tokErr != nil {
|
||||
return pb.TokenizationResponse{}, tokErr
|
||||
}
|
||||
var tokens []int32
|
||||
if err := json.Unmarshal([]byte(out), &tokens); err != nil {
|
||||
return pb.TokenizationResponse{}, fmt.Errorf("dllm: decode tokenize result %q: %w", out, err)
|
||||
}
|
||||
return pb.TokenizationResponse{Length: int32(len(tokens)), Tokens: tokens}, nil
|
||||
}
|
||||
@@ -1,807 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
"unsafe"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
func TestDllm(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "dllm Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
|
||||
// C-ABI bridge without spinning up the gRPC server. The library path comes
|
||||
// from DLLM_TEST_LIBRARY (gated specs Skip when it is unset).
|
||||
func ensureLibLoaded() {
|
||||
libLoadOnce.Do(func() {
|
||||
libLoadErr = loadCAPI(os.Getenv("DLLM_TEST_LIBRARY"))
|
||||
})
|
||||
}
|
||||
|
||||
// C-ABI binding smoke: drives the real libdllm.so against the tiny GGUF
|
||||
// fixture from dllm.cpp (tests/fixtures/tiny_with_vocab.gguf). Gated on:
|
||||
//
|
||||
// DLLM_TEST_LIBRARY absolute path to libdllm.so
|
||||
// DLLM_TEST_TINY_MODEL absolute path to tiny_with_vocab.gguf
|
||||
var _ = Describe("C-ABI binding", func() {
|
||||
BeforeEach(func() {
|
||||
if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" {
|
||||
Skip("set DLLM_TEST_LIBRARY and DLLM_TEST_TINY_MODEL to run the C-ABI binding smoke")
|
||||
}
|
||||
ensureLibLoaded()
|
||||
Expect(libLoadErr).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("binds the 9 symbols and round-trips the tiny model", func() {
|
||||
Expect(cAbiVersion()).To(Equal(int32(1)))
|
||||
|
||||
h := cLoad(os.Getenv("DLLM_TEST_TINY_MODEL"), "{}")
|
||||
Expect(h).ToNot(BeZero(), "dllm_capi_load of the tiny fixture")
|
||||
|
||||
// Tiny fixture vocab: "hello" tokenizes to ids [2,18] (bos prepended
|
||||
// by the C side: vocab.add_bos).
|
||||
toks, err := cTokenizeJSON(h, "hello")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(toks).To(Equal("[2,18]"))
|
||||
|
||||
// Deterministic generation: an explicit non-negative seed seeds
|
||||
// mt19937, so two identical calls must produce identical text.
|
||||
out1, err := cGenerate(h, "hello", `{"n_predict":16,"seed":7}`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out1).ToNot(BeEmpty())
|
||||
// Cancel with no call in flight is dropped: each generate resets the
|
||||
// cancel flag on entry (header contract), so this must not affect
|
||||
// the next call. Also binds the 9th symbol; safe on NULL too.
|
||||
cCancel(h)
|
||||
cCancel(0)
|
||||
|
||||
out2, err := cGenerate(h, "hello", `{"n_predict":16,"seed":7}`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out2).To(Equal(out1))
|
||||
|
||||
// Streaming variant: same opts, blocks arrive via the purego
|
||||
// callback trampoline. The per-block detokenize can differ from the
|
||||
// seamless full-text decode at block boundaries, so only assert that
|
||||
// blocks arrived and were non-trivial, not byte equality with out1.
|
||||
var blocks []string
|
||||
var steps int
|
||||
err = cGenerateStream(h, "hello", `{"n_predict":16,"seed":7}`,
|
||||
func(text string) { blocks = append(blocks, text) },
|
||||
func(step, total int, preview string) { steps++ },
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(blocks).ToNot(BeEmpty())
|
||||
Expect(steps).To(BeNumerically(">", 0))
|
||||
|
||||
// Load failure path: NULL ctx back, and last_error(NULL) returns the
|
||||
// static NULL-ctx message (there is no ctx to carry the real reason).
|
||||
bad := cLoad("/nonexistent/dllm-model.gguf", "{}")
|
||||
Expect(bad).To(BeZero())
|
||||
Expect(cLastError(0)).ToNot(BeEmpty())
|
||||
|
||||
// Free is safe on a live handle and a NULL one (delete nullptr).
|
||||
cFree(h)
|
||||
cFree(0)
|
||||
})
|
||||
})
|
||||
|
||||
// Ungated specs for the pure-Go helpers (no libdllm.so required).
|
||||
var _ = Describe("buildOptsJSON", func() {
|
||||
It("renders flat scalars as a JSON object", func() {
|
||||
out, err := buildOptsJSON(map[string]any{
|
||||
"n_predict": 16,
|
||||
"seed": int64(7),
|
||||
"eb_t_min": 0.5,
|
||||
"kv_cache": "auto",
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out).To(MatchJSON(`{"n_predict":16,"seed":7,"eb_t_min":0.5,"kv_cache":"auto"}`))
|
||||
})
|
||||
|
||||
It("renders an empty object for no options", func() {
|
||||
out, err := buildOptsJSON(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out).To(Equal("{}"))
|
||||
})
|
||||
|
||||
It("rejects nested objects (the C-side scanner only reads flat scalars)", func() {
|
||||
_, err := buildOptsJSON(map[string]any{"sampler": map[string]any{"seed": 1}})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects arrays", func() {
|
||||
_, err := buildOptsJSON(map[string]any{"stop": []string{"a"}})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects booleans (the C-side scanner only understands numbers and strings)", func() {
|
||||
_, err := buildOptsJSON(map[string]any{"flag": true})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("splitValidUTF8", func() {
|
||||
It("holds back a trailing incomplete sequence and completes it next block", func() {
|
||||
emit, carry := splitValidUTF8("", "caf\xe2")
|
||||
Expect(emit).To(Equal("caf"))
|
||||
Expect(carry).To(Equal("\xe2"))
|
||||
|
||||
emit, carry = splitValidUTF8(carry, "\x82")
|
||||
Expect(emit).To(BeEmpty())
|
||||
Expect(carry).To(Equal("\xe2\x82"))
|
||||
|
||||
emit, carry = splitValidUTF8(carry, "\xac!")
|
||||
Expect(emit).To(Equal("€!"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("holds back up to 3 bytes of a 4-byte sequence", func() {
|
||||
emit, carry := splitValidUTF8("", "x\xf0\x9f\x98") // 😀 missing its last byte
|
||||
Expect(emit).To(Equal("x"))
|
||||
Expect(carry).To(Equal("\xf0\x9f\x98"))
|
||||
|
||||
emit, carry = splitValidUTF8(carry, "\x80")
|
||||
Expect(emit).To(Equal("😀"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("replaces undecodable bytes immediately instead of carrying them", func() {
|
||||
// A mid-string invalid byte can never complete: carrying it would let
|
||||
// the carry grow unboundedly, so it is substituted on the spot.
|
||||
emit, carry := splitValidUTF8("", "a\xe2bc")
|
||||
Expect(emit).To(Equal("a<>bc"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
|
||||
// Orphan continuation bytes at the end have no leading byte to wait
|
||||
// for either.
|
||||
emit, carry = splitValidUTF8("", "a\x82")
|
||||
Expect(emit).To(Equal("a<>"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("passes pure ASCII and complete UTF-8 through untouched", func() {
|
||||
emit, carry := splitValidUTF8("", "héllo €")
|
||||
Expect(emit).To(Equal("héllo €"))
|
||||
Expect(carry).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("goStringFromCPtr", func() {
|
||||
It("copies a NUL-terminated buffer", func() {
|
||||
buf := []byte("dllm\x00")
|
||||
s := goStringFromCPtr(uintptr(unsafe.Pointer(&buf[0])))
|
||||
// The uintptr round-trip hides buf from the GC's liveness analysis;
|
||||
// keep it reachable until after the copy.
|
||||
runtime.KeepAlive(buf)
|
||||
Expect(s).To(Equal("dllm"))
|
||||
})
|
||||
|
||||
It("returns the empty string for NULL", func() {
|
||||
Expect(goStringFromCPtr(0)).To(Equal(""))
|
||||
})
|
||||
})
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Backend wiring (T4): fake-generator specs, no libdllm.so required.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type fakeGenCall struct {
|
||||
prompt string
|
||||
optsJSON string
|
||||
}
|
||||
|
||||
// fakeGen implements generator in-process. It records every call (prompt +
|
||||
// opts JSON), tracks concurrent in-flight calls to prove worker
|
||||
// serialization, and replays canned output (out for generate/tokenize,
|
||||
// blocks for generateStream).
|
||||
type fakeGen struct {
|
||||
mu sync.Mutex
|
||||
calls []fakeGenCall
|
||||
inFlight int
|
||||
maxInFlight int
|
||||
|
||||
out string
|
||||
blocks []string
|
||||
err error
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
func (f *fakeGen) begin(prompt, optsJSON string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.calls = append(f.calls, fakeGenCall{prompt: prompt, optsJSON: optsJSON})
|
||||
f.inFlight++
|
||||
if f.inFlight > f.maxInFlight {
|
||||
f.maxInFlight = f.inFlight
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeGen) end() {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.inFlight--
|
||||
}
|
||||
|
||||
func (f *fakeGen) snapshot() (calls []fakeGenCall, maxInFlight int) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return append([]fakeGenCall(nil), f.calls...), f.maxInFlight
|
||||
}
|
||||
|
||||
func (f *fakeGen) generate(prompt, optsJSON string) (string, error) {
|
||||
f.begin(prompt, optsJSON)
|
||||
defer f.end()
|
||||
if f.delay > 0 {
|
||||
time.Sleep(f.delay)
|
||||
}
|
||||
return f.out, f.err
|
||||
}
|
||||
|
||||
func (f *fakeGen) generateStream(prompt, optsJSON string, onBlock func(text string)) error {
|
||||
f.begin(prompt, optsJSON)
|
||||
defer f.end()
|
||||
if f.err != nil {
|
||||
return f.err
|
||||
}
|
||||
for _, b := range f.blocks {
|
||||
onBlock(b)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeGen) tokenizeJSON(text string) (string, error) {
|
||||
f.begin(text, "")
|
||||
defer f.end()
|
||||
return f.out, f.err
|
||||
}
|
||||
|
||||
func (f *fakeGen) cancel() {}
|
||||
func (f *fakeGen) free() {}
|
||||
|
||||
// newTestDllm assembles a backend around a fake generator (bypassing Load,
|
||||
// which needs libdllm.so) and registers cleanup of the worker goroutine.
|
||||
func newTestDllm(g generator, genOpts map[string]any) *Dllm {
|
||||
d := &Dllm{gen: g, genOpts: genOpts}
|
||||
d.startWorker()
|
||||
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
|
||||
return d
|
||||
}
|
||||
|
||||
// drainReplies empties ch without blocking, failing the spec if the channel
|
||||
// was closed (PredictStreamRich must NOT close it - interface.go contract).
|
||||
// Size ch above the expected reply count: an overflow deadlocks the spec on
|
||||
// the producer's send instead of failing it.
|
||||
func drainReplies(ch chan *pb.Reply) []*pb.Reply {
|
||||
var out []*pb.Reply
|
||||
for {
|
||||
select {
|
||||
case r, ok := <-ch:
|
||||
if !ok {
|
||||
Fail("PredictStreamRich closed the results channel (the gRPC server owns the close)")
|
||||
}
|
||||
expectValidUTF8Reply(r)
|
||||
out = append(out, r)
|
||||
default:
|
||||
return out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// expectValidUTF8Reply is the blanket guard for the proto3 marshaling
|
||||
// contract: grpc-go rejects any string field carrying invalid UTF-8, so every
|
||||
// reply field that ends up in a proto string must validate.
|
||||
func expectValidUTF8Reply(r *pb.Reply) {
|
||||
GinkgoHelper()
|
||||
Expect(utf8.ValidString(string(r.GetMessage()))).To(BeTrue(), "Reply.Message carries invalid UTF-8")
|
||||
for _, delta := range r.GetChatDeltas() {
|
||||
Expect(utf8.ValidString(delta.GetContent())).To(BeTrue(), "ChatDelta.Content carries invalid UTF-8")
|
||||
Expect(utf8.ValidString(delta.GetReasoningContent())).To(BeTrue(), "ChatDelta.ReasoningContent carries invalid UTF-8")
|
||||
for _, tc := range delta.GetToolCalls() {
|
||||
Expect(utf8.ValidString(tc.GetName())).To(BeTrue(), "ToolCallDelta.Name carries invalid UTF-8")
|
||||
Expect(utf8.ValidString(tc.GetArguments())).To(BeTrue(), "ToolCallDelta.Arguments carries invalid UTF-8")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("Dllm backend wiring", func() {
|
||||
Describe("PredictRich", func() {
|
||||
It("renders gemma4 from raw messages and parses the output when use_tokenizer_template is set", func() {
|
||||
fake := &fakeGen{out: "<|channel>thought\npondering<channel|>The answer.<turn|>"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
reply, err := d.PredictRich(&pb.PredictOptions{
|
||||
UseTokenizerTemplate: true,
|
||||
Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}},
|
||||
Metadata: map[string]string{"enable_thinking": "true"},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
calls, _ := fake.snapshot()
|
||||
Expect(calls).To(HaveLen(1))
|
||||
// The enable_thinking=true render from the transformers fixture.
|
||||
Expect(calls[0].prompt).To(Equal(
|
||||
"<|turn>system\n<|think|>\n<turn|>\n<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n"))
|
||||
|
||||
Expect(string(reply.GetMessage())).To(Equal("The answer."))
|
||||
Expect(reply.GetChatDeltas()).To(HaveLen(2))
|
||||
Expect(reply.GetChatDeltas()[0].GetReasoningContent()).To(Equal("pondering"))
|
||||
Expect(reply.GetChatDeltas()[1].GetContent()).To(Equal("The answer."))
|
||||
})
|
||||
|
||||
It("defaults enable_thinking OFF (the gemma4 template treats thinking as opt-in)", func() {
|
||||
fake := &fakeGen{out: "hi"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
_, err := d.PredictRich(&pb.PredictOptions{
|
||||
UseTokenizerTemplate: true,
|
||||
Messages: []*pb.Message{{Role: "user", Content: "Write a long essay about Portugal."}},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
calls, _ := fake.snapshot()
|
||||
// No-thinking render: the template pre-opens AND pre-closes an
|
||||
// empty thought channel in the generation prompt.
|
||||
Expect(calls[0].prompt).To(Equal(
|
||||
"<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"))
|
||||
})
|
||||
|
||||
It("passes the raw prompt verbatim and skips gemma4 parsing without use_tokenizer_template", func() {
|
||||
// Marker-looking text must survive untouched: in raw-prompt mode
|
||||
// the caller templates themselves and the Go-side extraction
|
||||
// applies, so the backend must not interpret the output.
|
||||
fake := &fakeGen{out: "<|channel>thought\nnot parsed<channel|>tail"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "my raw prompt"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
calls, _ := fake.snapshot()
|
||||
Expect(calls[0].prompt).To(Equal("my raw prompt"))
|
||||
Expect(string(reply.GetMessage())).To(Equal(fake.out))
|
||||
Expect(reply.GetChatDeltas()).To(HaveLen(1))
|
||||
Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal(fake.out))
|
||||
})
|
||||
|
||||
It("sanitizes invalid UTF-8 in the non-streaming output", func() {
|
||||
// Byte-fallback tokens can decode to lone malformed bytes; the
|
||||
// whole-output sanitize must replace them so proto3 marshaling of
|
||||
// Message/ChatDeltas cannot fail.
|
||||
fake := &fakeGen{out: "a\xe2b"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
reply, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectValidUTF8Reply(reply)
|
||||
Expect(string(reply.GetMessage())).To(Equal("a<>b"))
|
||||
Expect(reply.GetChatDeltas()[0].GetContent()).To(Equal("a<>b"))
|
||||
})
|
||||
|
||||
It("maps Tokens and Seed into the opts JSON on top of the model-level overrides", func() {
|
||||
fake := &fakeGen{out: "x"}
|
||||
d := newTestDllm(fake, map[string]any{"eb_t_min": 0.5, "kv_cache": "auto"})
|
||||
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p", Tokens: 32, Seed: 7})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
calls, _ := fake.snapshot()
|
||||
Expect(calls[0].optsJSON).To(MatchJSON(`{"n_predict":32,"seed":7,"eb_t_min":0.5,"kv_cache":"auto"}`))
|
||||
})
|
||||
|
||||
It("omits n_predict and seed when unset so the engine defaults apply", func() {
|
||||
fake := &fakeGen{out: "x"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
calls, _ := fake.snapshot()
|
||||
Expect(calls[0].optsJSON).To(MatchJSON(`{}`))
|
||||
})
|
||||
|
||||
It("surfaces generator errors", func() {
|
||||
fake := &fakeGen{err: errors.New("boom")}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).To(MatchError("boom"))
|
||||
})
|
||||
|
||||
It("errors before generating when no model is loaded", func() {
|
||||
d := &Dllm{} // no Load, no worker: must fail fast, not hang
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("makes a concurrent Free wait for the in-flight request (both finish cleanly)", func() {
|
||||
// server.go's Free has no locking of its own: the backend's genMu
|
||||
// must hold Free back until the racing generate drains, instead of
|
||||
// closing the jobs channel (panic) or freeing the C ctx under it.
|
||||
fake := &fakeGen{out: "x", delay: 50 * time.Millisecond}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
predictDone := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
predictDone <- err
|
||||
}()
|
||||
// Wait until the fake generate is actually in flight (the read
|
||||
// lock is held from before submit until PredictRich returns).
|
||||
Eventually(func() int {
|
||||
_, maxInFlight := fake.snapshot()
|
||||
return maxInFlight
|
||||
}).Should(Equal(1))
|
||||
|
||||
Expect(d.Free()).To(Succeed())
|
||||
// Free's write lock means the request finished before Free did.
|
||||
var predictErr error
|
||||
Eventually(predictDone).Should(Receive(&predictErr))
|
||||
Expect(predictErr).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("returns model-not-loaded for requests after Free", func() {
|
||||
fake := &fakeGen{out: "x"}
|
||||
d := newTestDllm(fake, nil)
|
||||
Expect(d.Free()).To(Succeed())
|
||||
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
|
||||
})
|
||||
|
||||
It("serializes concurrent requests through the worker goroutine", func() {
|
||||
// dllm_capi.h: one ctx = one concurrent generate. Two overlapping
|
||||
// PredictRich calls must execute the C calls one at a time.
|
||||
fake := &fakeGen{out: "x", delay: 30 * time.Millisecond}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 2 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer GinkgoRecover()
|
||||
_, err := d.PredictRich(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
calls, maxInFlight := fake.snapshot()
|
||||
Expect(calls).To(HaveLen(2))
|
||||
Expect(maxInFlight).To(Equal(1), "generate calls overlapped despite the worker queue")
|
||||
})
|
||||
})
|
||||
|
||||
Describe("PredictStreamRich", func() {
|
||||
It("emits one reply per delta-producing block and leaves the channel open", func() {
|
||||
// Blocks split mid-marker and mid-payload: the parser's holdback
|
||||
// must keep marker fragments out of the emitted deltas.
|
||||
fake := &fakeGen{blocks: []string{
|
||||
"<|channel>thou", // partial channel open: no deltas yet
|
||||
"ght\nponder", // header completes, reasoning starts
|
||||
"ing<channel|>Hi ", // reasoning ends, content starts
|
||||
"there<turn|>discarded", // turn ends: trailing text dropped
|
||||
}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{
|
||||
UseTokenizerTemplate: true,
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, ch)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
replies := drainReplies(ch)
|
||||
Expect(replies).To(HaveLen(3), "block 1 completes no delta and must not produce a reply")
|
||||
|
||||
var content, reasoning string
|
||||
for _, r := range replies {
|
||||
for _, delta := range r.GetChatDeltas() {
|
||||
content += delta.GetContent()
|
||||
reasoning += delta.GetReasoningContent()
|
||||
}
|
||||
}
|
||||
Expect(reasoning).To(Equal("pondering"))
|
||||
Expect(content).To(Equal("Hi there"))
|
||||
// Message mirrors each reply's content so legacy consumers see
|
||||
// exactly the displayed tokens.
|
||||
Expect(string(replies[1].GetMessage())).To(Equal("Hi "))
|
||||
Expect(string(replies[2].GetMessage())).To(Equal("there"))
|
||||
})
|
||||
|
||||
It("streams raw blocks verbatim without use_tokenizer_template", func() {
|
||||
fake := &fakeGen{blocks: []string{"abc", "", "<|channel>def"}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
replies := drainReplies(ch)
|
||||
Expect(replies).To(HaveLen(2), "empty blocks produce no reply")
|
||||
Expect(string(replies[0].GetMessage())).To(Equal("abc"))
|
||||
Expect(string(replies[1].GetMessage())).To(Equal("<|channel>def"))
|
||||
Expect(replies[1].GetChatDeltas()).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("flushes parser holdback after the stream ends", func() {
|
||||
// The unterminated partial marker "<chan" is held back during the
|
||||
// stream and must come out as content on the final flush.
|
||||
fake := &fakeGen{blocks: []string{"tail<chan"}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{
|
||||
UseTokenizerTemplate: true,
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, ch)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var content string
|
||||
for _, r := range drainReplies(ch) {
|
||||
content += string(r.GetMessage())
|
||||
}
|
||||
Expect(content).To(Equal("tail<chan"))
|
||||
})
|
||||
|
||||
It("reassembles a multi-byte character split across block boundaries", func() {
|
||||
// Per-block detokenize can split "€" (E2 82 AC) as E2 | 82 AC.
|
||||
// Emitting the lone E2 would make grpc-go fail the marshal of the
|
||||
// whole reply; the trailing incomplete sequence must be held back
|
||||
// and completed by the next block.
|
||||
fake := &fakeGen{blocks: []string{"caf\xe2", "\x82\xac ok"}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var content string
|
||||
for _, r := range drainReplies(ch) { // drain asserts ValidString per reply
|
||||
content += string(r.GetMessage())
|
||||
}
|
||||
Expect(content).To(Equal("caf€ ok"))
|
||||
})
|
||||
|
||||
It("reassembles a split multi-byte character in parsed (gemma4) mode too", func() {
|
||||
fake := &fakeGen{blocks: []string{"caf\xe2", "\x82\xac<turn|>"}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{
|
||||
UseTokenizerTemplate: true,
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, ch)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var content string
|
||||
for _, r := range drainReplies(ch) {
|
||||
for _, delta := range r.GetChatDeltas() {
|
||||
content += delta.GetContent()
|
||||
}
|
||||
}
|
||||
Expect(content).To(Equal("caf€"))
|
||||
})
|
||||
|
||||
It("replaces an incomplete sequence left at stream end with U+FFFD", func() {
|
||||
// A byte-fallback token can leave a lone leading byte (0xE2) that
|
||||
// no later block completes: the final flush must substitute it,
|
||||
// never emit it raw and never drop into a marshal error.
|
||||
fake := &fakeGen{blocks: []string{"ok\xe2"}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "raw"}, ch)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var content string
|
||||
for _, r := range drainReplies(ch) {
|
||||
content += string(r.GetMessage())
|
||||
}
|
||||
Expect(content).To(Equal("ok<6F>"))
|
||||
})
|
||||
|
||||
It("surfaces generator errors without sending replies", func() {
|
||||
fake := &fakeGen{err: errors.New("stream boom")}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan *pb.Reply, 16)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch)
|
||||
Expect(err).To(MatchError("stream boom"))
|
||||
Expect(drainReplies(ch)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("errors before generating when no model is loaded", func() {
|
||||
d := &Dllm{} // no Load, no worker: must fail fast, not hang
|
||||
ch := make(chan *pb.Reply, 1)
|
||||
err := d.PredictStreamRich(&pb.PredictOptions{Prompt: "p"}, ch)
|
||||
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
|
||||
Expect(drainReplies(ch)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("legacy Predict/PredictStream adapters", func() {
|
||||
It("Predict returns the aggregated content string", func() {
|
||||
fake := &fakeGen{out: "plain text"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
out, err := d.Predict(&pb.PredictOptions{Prompt: "p"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out).To(Equal("plain text"))
|
||||
})
|
||||
|
||||
It("PredictStream forwards content strings and closes the channel (legacy ownership)", func() {
|
||||
fake := &fakeGen{blocks: []string{"a", "b"}}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
ch := make(chan string, 16)
|
||||
Expect(d.PredictStream(&pb.PredictOptions{Prompt: "p"}, ch)).To(Succeed())
|
||||
|
||||
var got []string
|
||||
for s := range ch { // terminates only if the impl closed ch
|
||||
got = append(got, s)
|
||||
}
|
||||
Expect(got).To(Equal([]string{"a", "b"}))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("TokenizeString", func() {
|
||||
It("decodes the C-side JSON id array", func() {
|
||||
fake := &fakeGen{out: "[2,18]"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Length).To(Equal(int32(2)))
|
||||
Expect(resp.Tokens).To(Equal([]int32{2, 18}))
|
||||
|
||||
calls, _ := fake.snapshot()
|
||||
Expect(calls[0].prompt).To(Equal("hello"))
|
||||
})
|
||||
|
||||
It("fails loud on a malformed id array", func() {
|
||||
fake := &fakeGen{out: "not json"}
|
||||
d := newTestDllm(fake, nil)
|
||||
|
||||
_, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors before tokenizing when no model is loaded", func() {
|
||||
d := &Dllm{} // no Load, no worker: must fail fast, not hang
|
||||
_, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
|
||||
Expect(err).To(MatchError(ContainSubstring("model not loaded")))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("parseModelGenOpts", func() {
|
||||
It("parses eb_*/blocks/kv_cache entries and types values by first successful parse", func() {
|
||||
got := parseModelGenOpts([]string{
|
||||
"eb_max_steps:16",
|
||||
"eb_t_min:0.25",
|
||||
"kv_cache:auto",
|
||||
"blocks:4",
|
||||
"unrelated_key:1", // other layers' options: skipped
|
||||
"malformed", // no colon: skipped
|
||||
})
|
||||
Expect(got).To(Equal(map[string]any{
|
||||
"eb_max_steps": int64(16),
|
||||
"eb_t_min": 0.25,
|
||||
"kv_cache": "auto",
|
||||
"blocks": int64(4),
|
||||
}))
|
||||
})
|
||||
|
||||
It("round-trips through buildOptsJSON (only flat scalars are produced)", func() {
|
||||
got := parseModelGenOpts([]string{"eb_entropy_bound:0.8", "kv_cache:off"})
|
||||
out, err := buildOptsJSON(got)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out).To(MatchJSON(`{"eb_entropy_bound":0.8,"kv_cache":"off"}`))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Gated backend round-trip against the real libdllm.so + tiny GGUF fixture.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var _ = Describe("Dllm backend (real tiny model)", func() {
|
||||
BeforeEach(func() {
|
||||
if os.Getenv("DLLM_TEST_LIBRARY") == "" || os.Getenv("DLLM_TEST_TINY_MODEL") == "" {
|
||||
Skip("set DLLM_TEST_LIBRARY and DLLM_TEST_TINY_MODEL to run the backend round-trip")
|
||||
}
|
||||
ensureLibLoaded()
|
||||
Expect(libLoadErr).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("round-trips Load, PredictRich, PredictStreamRich and TokenizeString", func() {
|
||||
d := &Dllm{}
|
||||
Expect(d.Load(&pb.ModelOptions{ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL")})).To(Succeed())
|
||||
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
|
||||
|
||||
// TokenizeString: tiny fixture vocab tokenizes "hello" to [2,18].
|
||||
resp, err := d.TokenizeString(&pb.PredictOptions{Prompt: "hello"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Tokens).To(Equal([]int32{2, 18}))
|
||||
Expect(resp.Length).To(Equal(int32(2)))
|
||||
|
||||
req := &pb.PredictOptions{
|
||||
UseTokenizerTemplate: true,
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hello"}},
|
||||
Tokens: 16,
|
||||
Seed: 7,
|
||||
}
|
||||
|
||||
// Non-streaming: the tiny random-weight model emits arbitrary vocab
|
||||
// words; with no gemma4 markers in them everything is content.
|
||||
reply, err := d.PredictRich(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(reply.GetMessage())).ToNot(BeEmpty())
|
||||
Expect(reply.GetChatDeltas()).ToNot(BeEmpty())
|
||||
|
||||
// Streaming: at least one reply, and the channel-ownership rule is
|
||||
// honored (drainReplies fails the spec on a closed channel).
|
||||
ch := make(chan *pb.Reply, 64)
|
||||
Expect(d.PredictStreamRich(req, ch)).To(Succeed())
|
||||
replies := drainReplies(ch)
|
||||
Expect(replies).ToNot(BeEmpty())
|
||||
var streamed string
|
||||
for _, r := range replies {
|
||||
streamed += string(r.GetMessage())
|
||||
}
|
||||
Expect(streamed).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("aborts an in-flight generation promptly on Cancel", func() {
|
||||
d := &Dllm{}
|
||||
// eb_max_steps inflates the per-block denoise loop so the full run
|
||||
// takes ~10s on the tiny fixture (vs ~40ms at engine defaults; 16
|
||||
// blocks, first block after ~0.7s) - long enough that a prompt
|
||||
// post-cancel return is distinguishable from the generation simply
|
||||
// finishing.
|
||||
Expect(d.Load(&pb.ModelOptions{
|
||||
ModelFile: os.Getenv("DLLM_TEST_TINY_MODEL"),
|
||||
Options: []string{"eb_max_steps:256"},
|
||||
})).To(Succeed())
|
||||
DeferCleanup(func() { Expect(d.Free()).To(Succeed()) })
|
||||
|
||||
ch := make(chan *pb.Reply, 64)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
errCh <- d.PredictStreamRich(&pb.PredictOptions{Prompt: "hello", Tokens: 256, Seed: 7}, ch)
|
||||
}()
|
||||
|
||||
// Cancel only once the first block proves the generate is in
|
||||
// flight: the C side resets the cancel flag on generate entry, so
|
||||
// an earlier Cancel would be swallowed (dllm_capi.h race note).
|
||||
Eventually(ch, "60s").Should(Receive())
|
||||
cancelAt := time.Now()
|
||||
d.Cancel()
|
||||
|
||||
// Uncancelled, ~10s of generation remain; the cancelled call must
|
||||
// come back in milliseconds (the flag is checked per denoise step).
|
||||
var genErr error
|
||||
Eventually(errCh, "5s").Should(Receive(&genErr))
|
||||
latency := time.Since(cancelAt)
|
||||
Expect(genErr).To(MatchError(ContainSubstring("cancelled")))
|
||||
GinkgoWriter.Printf("dllm cancel: PredictStreamRich returned %v after Cancel\n", latency)
|
||||
})
|
||||
})
|
||||
@@ -1,562 +0,0 @@
|
||||
// Gemma4 (DiffusionGemma) streaming output parser: raw model text, fed in
|
||||
// arbitrary fragments (per committed diffusion block; a fragment can split
|
||||
// anywhere, including mid-marker and mid-payload), is turned into
|
||||
// pb.ChatDelta events (content / reasoning_content / tool_calls).
|
||||
//
|
||||
// Normative sources:
|
||||
// - The chat template embedded at the top of gemma4_renderer.go ("tpl L<n>"
|
||||
// citations below refer to its numbered lines). The OUTPUT format mirrors
|
||||
// what the template renders for assistant history: thought channels
|
||||
// (<|channel>thought\n ... <channel|>, tpl L240), tool calls
|
||||
// (<|tool_call>call:name{...}<tool_call|>, tpl L246-L257) and turn ends
|
||||
// (<turn|>, tpl L351).
|
||||
// - vLLM PR #45163: vllm/tool_parsers/gemma4_tool_parser.py (marker
|
||||
// handling, the call:name{...} argument grammar and its decoder, ported
|
||||
// below) and vllm/reasoning/gemma4_reasoning_parser.py (channel markers,
|
||||
// the "thought\n" role label, is_reasoning_end semantics).
|
||||
//
|
||||
// Initial state (derived from the generation prompt, tpl L356-L362, see
|
||||
// RenderGemma4):
|
||||
// - enable_thinking=false: the prompt ends with "<|turn>model\n" +
|
||||
// "<|channel>thought\n<channel|>" - an EMPTY thought channel, pre-opened
|
||||
// AND pre-closed by the template. The model's output therefore starts in
|
||||
// plain content. Use NewGemma4Parser(false).
|
||||
// - enable_thinking=true: the prompt ends at "<|turn>model\n" and the model
|
||||
// opens and closes its own thought channel in the OUTPUT
|
||||
// ("<|channel>thought\n...reasoning...<channel|>final answer", per the
|
||||
// vLLM Gemma4ReasoningParser docstring). The parser still starts in
|
||||
// content state - the channel markers in the output drive the switch.
|
||||
// Use NewGemma4Parser(false) here too.
|
||||
// - NewGemma4Parser(true) is for callers that pre-open the thought channel
|
||||
// in the prompt themselves (appending "<|channel>thought\n" after the
|
||||
// generation prompt to force thinking): the output then begins mid-thought
|
||||
// and everything is reasoning until the first <channel|>.
|
||||
//
|
||||
// State diagram (markers are consumed, never emitted):
|
||||
//
|
||||
// <|channel> \n (channel name dropped: the
|
||||
// [content] --------------> [chan-header] ----> [thought] "thought\n" role
|
||||
// ^ | <channel|> (stray close: swallowed, label, stripped
|
||||
// +-+ strip_thinking semantics, tpl L148-L158) like vLLM does)
|
||||
// ^ <channel|>
|
||||
// +----------------------------------------- [thought]
|
||||
// ^ <tool_call|> | <|tool_call> (implicit
|
||||
// +-------------- [tool-call] <-------------------+ reasoning end, vLLM
|
||||
// | <|tool_call> ^ is_reasoning_end)
|
||||
// +-------------------+
|
||||
// [content]/[thought] --- <turn|> ---> [done] (everything after is dropped)
|
||||
//
|
||||
// Buffering rules:
|
||||
// - content/thought states hold back at most len(longest marker)-1 bytes:
|
||||
// the longest tail that is still a proper prefix of a watched marker.
|
||||
// Content is otherwise emitted immediately (no unbounded buffering).
|
||||
// - the tool-call state buffers the whole payload until <tool_call|>. This
|
||||
// is unbounded in principle but bounded in practice by the model's
|
||||
// diffusion canvas, and is required because the call:name{...} payload
|
||||
// only becomes decodable (and trustworthy) once complete - the same
|
||||
// reason vLLM's parser accumulates before parsing.
|
||||
// - Close() flushes whatever is still held: partial markers come out as
|
||||
// content/reasoning (per the state that held them); an unterminated
|
||||
// channel header or tool-call payload is re-emitted RAW (including its
|
||||
// opening marker) as content - malformed output is never silently
|
||||
// dropped (mirrors vLLM extract_tool_calls returning the raw text as
|
||||
// content when its regex does not match).
|
||||
//
|
||||
// Streaming granularity DIVERGENCE from vLLM: vLLM re-parses the partial
|
||||
// payload on every token and streams argument-JSON diffs (its `partial=True`
|
||||
// decoder mode plus withholding logic exist only for that). Our fragments are
|
||||
// whole committed diffusion blocks, so each completed tool call is emitted
|
||||
// once, as a single ToolCallDelta carrying index + id + name + the full
|
||||
// arguments JSON - exactly the shape backend/python/vllm/backend.py emits
|
||||
// per call and pkg/functions.ToolCallsFromChatDeltas re-accumulates.
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// gemma4CallRE is vLLM's tool_call_regex
|
||||
// (`<\|tool_call>call:([\w\-\.]+)\{(.*?)\}<tool_call\|>`, DOTALL) anchored to
|
||||
// a single already-extracted payload: name charset [\w\-.], braces mandatory.
|
||||
var gemma4CallRE = regexp.MustCompile(`(?s)^call:([\w\-.]+)\{(.*)\}$`)
|
||||
|
||||
type g4State int
|
||||
|
||||
const (
|
||||
g4Content g4State = iota
|
||||
g4ChanHeader
|
||||
g4Thought
|
||||
g4ToolCall
|
||||
g4Done
|
||||
)
|
||||
|
||||
// Markers watched per emitting state. A stray <tool_call|> outside a tool
|
||||
// call is deliberately NOT watched: it passes through verbatim, consistent
|
||||
// with the malformed-payload fallback re-emitting it as content.
|
||||
var (
|
||||
gemma4ContentMarkers = []string{gemma4ChannelOpen, gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd}
|
||||
gemma4ThoughtMarkers = []string{gemma4ChannelClose, gemma4ToolCallOpen, gemma4TurnEnd}
|
||||
)
|
||||
|
||||
type Gemma4Parser struct {
|
||||
state g4State
|
||||
// held is the per-state carry-over between Feed calls: a partial marker
|
||||
// (content/thought), a partial channel header (chan-header) or the
|
||||
// payload accumulated so far (tool-call).
|
||||
held string
|
||||
toolIdx int
|
||||
}
|
||||
|
||||
// NewGemma4Parser returns a parser positioned per the initial-state rules in
|
||||
// the header comment: startInThought=true only when the caller pre-opened a
|
||||
// thought channel in the prompt.
|
||||
func NewGemma4Parser(startInThought bool) *Gemma4Parser {
|
||||
state := g4Content
|
||||
if startInThought {
|
||||
state = g4Thought
|
||||
}
|
||||
return &Gemma4Parser{state: state}
|
||||
}
|
||||
|
||||
// Feed consumes the next output fragment and returns the deltas it completes.
|
||||
func (p *Gemma4Parser) Feed(text string) []*pb.ChatDelta {
|
||||
if text == "" || p.state == g4Done {
|
||||
return nil
|
||||
}
|
||||
pending := p.held + text
|
||||
p.held = ""
|
||||
var em g4Emitter
|
||||
for pending != "" {
|
||||
switch p.state {
|
||||
case g4Content, g4Thought:
|
||||
markers := gemma4ContentMarkers
|
||||
if p.state == g4Thought {
|
||||
markers = gemma4ThoughtMarkers
|
||||
}
|
||||
idx, marker := findEarliestGemma4Marker(pending, markers)
|
||||
if idx == -1 {
|
||||
hold := gemma4MarkerHoldback(pending, markers)
|
||||
p.emitText(&em, pending[:len(pending)-hold])
|
||||
p.held = pending[len(pending)-hold:]
|
||||
pending = ""
|
||||
continue
|
||||
}
|
||||
p.emitText(&em, pending[:idx])
|
||||
pending = pending[idx+len(marker):]
|
||||
switch marker {
|
||||
case gemma4ChannelOpen:
|
||||
p.state = g4ChanHeader
|
||||
case gemma4ChannelClose:
|
||||
// In thought: channel ends. In content: stray close,
|
||||
// swallowed (strip_thinking keeps both sides, tpl L148-L158).
|
||||
p.state = g4Content
|
||||
case gemma4ToolCallOpen:
|
||||
p.state = g4ToolCall
|
||||
case gemma4TurnEnd:
|
||||
p.state = g4Done
|
||||
}
|
||||
case g4ChanHeader:
|
||||
// The channel header is "<name>\n"; the template only ever writes
|
||||
// "thought" (tpl L240/L360) and the label is structural, so it is
|
||||
// dropped, not emitted (vLLM strips the same "thought\n" prefix).
|
||||
nl := strings.IndexByte(pending, '\n')
|
||||
if nl == -1 {
|
||||
p.held = pending
|
||||
pending = ""
|
||||
continue
|
||||
}
|
||||
pending = pending[nl+1:]
|
||||
p.state = g4Thought
|
||||
case g4ToolCall:
|
||||
end := strings.Index(pending, gemma4ToolCallClose)
|
||||
if end == -1 {
|
||||
p.held = pending
|
||||
pending = ""
|
||||
continue
|
||||
}
|
||||
p.emitToolCall(&em, pending[:end])
|
||||
pending = pending[end+len(gemma4ToolCallClose):]
|
||||
p.state = g4Content
|
||||
case g4Done:
|
||||
pending = ""
|
||||
}
|
||||
}
|
||||
return em.deltas
|
||||
}
|
||||
|
||||
// Close flushes held-back partials. Incomplete structures (open channel
|
||||
// header, unterminated tool payload) are re-emitted raw as content rather
|
||||
// than dropped. The parser is finished afterwards.
|
||||
func (p *Gemma4Parser) Close() []*pb.ChatDelta {
|
||||
var em g4Emitter
|
||||
switch p.state {
|
||||
case g4Content:
|
||||
em.content(p.held)
|
||||
case g4Thought:
|
||||
em.reasoning(p.held)
|
||||
case g4ChanHeader:
|
||||
em.content(gemma4ChannelOpen + p.held)
|
||||
case g4ToolCall:
|
||||
em.content(gemma4ToolCallOpen + p.held)
|
||||
case g4Done:
|
||||
}
|
||||
p.held = ""
|
||||
p.state = g4Done
|
||||
return em.deltas
|
||||
}
|
||||
|
||||
func (p *Gemma4Parser) emitText(em *g4Emitter, s string) {
|
||||
if p.state == g4Thought {
|
||||
em.reasoning(s)
|
||||
return
|
||||
}
|
||||
em.content(s)
|
||||
}
|
||||
|
||||
// emitToolCall decodes one complete <|tool_call>...<tool_call|> payload. On a
|
||||
// payload that does not match call:name{...} the raw text (markers included)
|
||||
// is emitted as content, mirroring vLLM's extract_tool_calls fallback.
|
||||
func (p *Gemma4Parser) emitToolCall(em *g4Emitter, payload string) {
|
||||
m := gemma4CallRE.FindStringSubmatch(payload)
|
||||
if m == nil {
|
||||
em.content(gemma4ToolCallOpen + payload + gemma4ToolCallClose)
|
||||
return
|
||||
}
|
||||
// Index-based ids: deterministic (the split-invariance property relies
|
||||
// on it) and matching the call_<n> convention of pkg/grpc/rich_test.go;
|
||||
// core only needs ids to be non-empty and unique within the response.
|
||||
em.tool(p.toolIdx, "call_"+strconv.Itoa(p.toolIdx), m[1], decodeGemma4Args(m[2], 0))
|
||||
p.toolIdx++
|
||||
}
|
||||
|
||||
// g4Emitter collects ChatDeltas; empty text events are dropped.
|
||||
type g4Emitter struct {
|
||||
deltas []*pb.ChatDelta
|
||||
}
|
||||
|
||||
func (e *g4Emitter) content(s string) {
|
||||
if s != "" {
|
||||
e.deltas = append(e.deltas, &pb.ChatDelta{Content: s})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *g4Emitter) reasoning(s string) {
|
||||
if s != "" {
|
||||
e.deltas = append(e.deltas, &pb.ChatDelta{ReasoningContent: s})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *g4Emitter) tool(index int, id, name, argsJSON string) {
|
||||
e.deltas = append(e.deltas, &pb.ChatDelta{ToolCalls: []*pb.ToolCallDelta{{
|
||||
Index: int32(index),
|
||||
Id: id,
|
||||
Name: name,
|
||||
Arguments: argsJSON,
|
||||
}}})
|
||||
}
|
||||
|
||||
// findEarliestGemma4Marker returns the position and value of the first
|
||||
// complete marker occurrence, or (-1, "").
|
||||
func findEarliestGemma4Marker(s string, markers []string) (int, string) {
|
||||
best, bestMarker := -1, ""
|
||||
for _, m := range markers {
|
||||
if idx := strings.Index(s, m); idx >= 0 && (best == -1 || idx < best) {
|
||||
best, bestMarker = idx, m
|
||||
}
|
||||
}
|
||||
return best, bestMarker
|
||||
}
|
||||
|
||||
// gemma4MarkerHoldback returns the length of the longest suffix of s that is
|
||||
// a proper prefix of a watched marker - the only bytes that may still grow
|
||||
// into a marker and therefore must not be emitted yet (bounded by the
|
||||
// longest marker, so content is never buffered unboundedly).
|
||||
func gemma4MarkerHoldback(s string, markers []string) int {
|
||||
maxHold := 0
|
||||
for _, m := range markers {
|
||||
if len(m)-1 > maxHold {
|
||||
maxHold = len(m) - 1
|
||||
}
|
||||
}
|
||||
if len(s) < maxHold {
|
||||
maxHold = len(s)
|
||||
}
|
||||
for k := maxHold; k >= 1; k-- {
|
||||
tail := s[len(s)-k:]
|
||||
for _, m := range markers {
|
||||
if strings.HasPrefix(m, tail) {
|
||||
return k
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// call:name{...} argument decoder
|
||||
//
|
||||
// Port of vLLM's _parse_gemma4_args / _parse_gemma4_array /
|
||||
// _parse_gemma4_value (gemma4_tool_parser.py) in non-partial mode only: this
|
||||
// parser decodes exclusively COMPLETE payloads (incomplete ones fall back to
|
||||
// raw content at Close), so vLLM's partial-withholding machinery
|
||||
// (trailing-dot floats, withheld bare tails) is intentionally not ported.
|
||||
//
|
||||
// Grammar (inverse of the renderer's formatGemma4Argument, tpl L118-L147):
|
||||
//
|
||||
// args := pair (',' pair)*
|
||||
// pair := key ':' value (keys unquoted, up to the first ':')
|
||||
// value := string | object | array | bare
|
||||
// string := '<|"|>' ... '<|"|>' (no escapes; unterminated -> rest)
|
||||
// object := '{' args '}' (delimited strings skipped when
|
||||
// array := '[' value,* ']' counting braces/brackets)
|
||||
// bare := true | false | null/none/nil | number | bare-string
|
||||
//
|
||||
// Output is a JSON object/array string with keys in payload order (Python
|
||||
// dict insertion order), built with HTML escaping off so payload text
|
||||
// survives byte-for-byte.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func isGemma4Space(c byte) bool { return c == ' ' || c == '\n' || c == '\t' }
|
||||
|
||||
// gemma4MaxArgsDepth caps the mutual recursion between decodeGemma4Args and
|
||||
// decodeGemma4Array. Defense against model-generated deep nesting: a Go stack
|
||||
// overflow is a fatal process kill, not a recoverable error, so past the cap
|
||||
// a nested body gracefully degrades to a JSON string of its raw text.
|
||||
const gemma4MaxArgsDepth = 100
|
||||
|
||||
// decodeGemma4Args decodes one args body (the text between the outer braces
|
||||
// of call:name{...}) into a JSON object string. depth is the current nesting
|
||||
// level (0 at the payload root); see gemma4MaxArgsDepth.
|
||||
func decodeGemma4Args(s string, depth int) string {
|
||||
if depth > gemma4MaxArgsDepth {
|
||||
return gemma4JSONString(s)
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("{")
|
||||
first := true
|
||||
pair := func(key, val string) {
|
||||
if !first {
|
||||
b.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
b.WriteString(gemma4JSONString(key))
|
||||
b.WriteString(":")
|
||||
b.WriteString(val)
|
||||
}
|
||||
i, n := 0, len(s)
|
||||
for i < n {
|
||||
for i < n && (isGemma4Space(s[i]) || s[i] == ',') {
|
||||
i++
|
||||
}
|
||||
if i >= n {
|
||||
break
|
||||
}
|
||||
keyStart := i
|
||||
for i < n && s[i] != ':' {
|
||||
i++
|
||||
}
|
||||
if i >= n {
|
||||
break // no ':' -> trailing junk, dropped (vLLM does the same)
|
||||
}
|
||||
key := strings.TrimSpace(s[keyStart:i])
|
||||
i++ // skip ':'
|
||||
for i < n && isGemma4Space(s[i]) {
|
||||
i++
|
||||
}
|
||||
if i >= n {
|
||||
pair(key, `""`) // "key:" with nothing after -> empty string
|
||||
break
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(s[i:], gemma4StringDelim):
|
||||
i += len(gemma4StringDelim)
|
||||
if end := strings.Index(s[i:], gemma4StringDelim); end == -1 {
|
||||
pair(key, gemma4JSONString(s[i:])) // unterminated -> take rest
|
||||
i = n
|
||||
} else {
|
||||
pair(key, gemma4JSONString(s[i:i+end]))
|
||||
i += end + len(gemma4StringDelim)
|
||||
}
|
||||
case s[i] == '{':
|
||||
inner, next := scanGemma4Balanced(s, i, '{', '}')
|
||||
pair(key, decodeGemma4Args(inner, depth+1))
|
||||
i = next
|
||||
case s[i] == '[':
|
||||
inner, next := scanGemma4Balanced(s, i, '[', ']')
|
||||
pair(key, decodeGemma4Array(inner, depth+1))
|
||||
i = next
|
||||
default:
|
||||
valStart := i
|
||||
for i < n && s[i] != ',' && s[i] != '}' && s[i] != ']' {
|
||||
i++
|
||||
}
|
||||
if i == valStart {
|
||||
// No progress (value starts on a stray '}'/']'): abort on
|
||||
// malformed input rather than loop, like vLLM.
|
||||
i = n
|
||||
continue
|
||||
}
|
||||
pair(key, decodeGemma4Bare(s[valStart:i]))
|
||||
}
|
||||
}
|
||||
b.WriteString("}")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// decodeGemma4Array decodes one array body (the text between '[' and ']')
|
||||
// into a JSON array string. depth is the current nesting level; see
|
||||
// gemma4MaxArgsDepth.
|
||||
func decodeGemma4Array(s string, depth int) string {
|
||||
if depth > gemma4MaxArgsDepth {
|
||||
return gemma4JSONString(s)
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("[")
|
||||
first := true
|
||||
item := func(val string) {
|
||||
if !first {
|
||||
b.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
b.WriteString(val)
|
||||
}
|
||||
i, n := 0, len(s)
|
||||
for i < n {
|
||||
for i < n && (isGemma4Space(s[i]) || s[i] == ',') {
|
||||
i++
|
||||
}
|
||||
if i >= n {
|
||||
break
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(s[i:], gemma4StringDelim):
|
||||
i += len(gemma4StringDelim)
|
||||
if end := strings.Index(s[i:], gemma4StringDelim); end == -1 {
|
||||
item(gemma4JSONString(s[i:]))
|
||||
i = n
|
||||
} else {
|
||||
item(gemma4JSONString(s[i : i+end]))
|
||||
i += end + len(gemma4StringDelim)
|
||||
}
|
||||
case s[i] == '{':
|
||||
inner, next := scanGemma4Balanced(s, i, '{', '}')
|
||||
item(decodeGemma4Args(inner, depth+1))
|
||||
i = next
|
||||
case s[i] == '[':
|
||||
inner, next := scanGemma4Balanced(s, i, '[', ']')
|
||||
item(decodeGemma4Array(inner, depth+1))
|
||||
i = next
|
||||
default:
|
||||
valStart := i
|
||||
for i < n && s[i] != ',' && s[i] != ']' {
|
||||
i++
|
||||
}
|
||||
if i == valStart {
|
||||
i = n // no progress: abort on malformed input, like vLLM
|
||||
continue
|
||||
}
|
||||
item(decodeGemma4Bare(s[valStart:i]))
|
||||
}
|
||||
}
|
||||
b.WriteString("]")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// scanGemma4Balanced scans a brace/bracket-balanced span starting at the
|
||||
// opener s[start], skipping over <|"|>-delimited strings so structural
|
||||
// characters inside them do not count (vLLM's depth scan). Returns the inner
|
||||
// text and the index just past the closer; an unterminated span yields the
|
||||
// rest of the string (the inner decoder still extracts what is there - this
|
||||
// path is only reachable from genuinely malformed complete payloads).
|
||||
func scanGemma4Balanced(s string, start int, open, close byte) (string, int) {
|
||||
depth := 1
|
||||
i := start + 1
|
||||
innerStart := i
|
||||
n := len(s)
|
||||
for i < n && depth > 0 {
|
||||
if strings.HasPrefix(s[i:], gemma4StringDelim) {
|
||||
i += len(gemma4StringDelim)
|
||||
if nd := strings.Index(s[i:], gemma4StringDelim); nd == -1 {
|
||||
i = n
|
||||
} else {
|
||||
i += nd + len(gemma4StringDelim)
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch s[i] {
|
||||
case open:
|
||||
depth++
|
||||
case close:
|
||||
depth--
|
||||
}
|
||||
i++
|
||||
}
|
||||
if depth > 0 {
|
||||
return s[innerStart:], n
|
||||
}
|
||||
return s[innerStart : i-1], i
|
||||
}
|
||||
|
||||
// decodeGemma4Bare maps an undelimited value to its JSON form: booleans,
|
||||
// null aliases (null/none/nil, case-insensitive - the renderer writes
|
||||
// Python None as "None", tpl L144-L145 via format_argument's else branch),
|
||||
// numbers (vLLM's rule: a '.' tries float, otherwise int; anything that
|
||||
// fails parses as a bare string).
|
||||
func decodeGemma4Bare(raw string) string {
|
||||
v := strings.TrimSpace(raw)
|
||||
if v == "" {
|
||||
return `""`
|
||||
}
|
||||
if v == "true" || v == "false" {
|
||||
return v
|
||||
}
|
||||
switch strings.ToLower(v) {
|
||||
case "null", "none", "nil":
|
||||
return "null"
|
||||
}
|
||||
if strings.Contains(v, ".") {
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return formatGemma4Float(f)
|
||||
}
|
||||
} else if iv, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return strconv.FormatInt(iv, 10)
|
||||
}
|
||||
return gemma4JSONString(v)
|
||||
}
|
||||
|
||||
// formatGemma4Float renders like Python's json.dumps(float): integral floats
|
||||
// keep a ".0" suffix ("108." decodes to 108.0, not 108), so the arguments
|
||||
// JSON matches what vLLM would have produced for the same payload.
|
||||
func formatGemma4Float(f float64) string {
|
||||
s := strconv.FormatFloat(f, 'g', -1, 64)
|
||||
if !strings.ContainsAny(s, ".eE") {
|
||||
s += ".0"
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// gemma4JSONString encodes a JSON string WITHOUT HTML escaping (json.Marshal
|
||||
// would escape the angle brackets in "<div>" to \u003c / \u003e sequences;
|
||||
// payload text should survive
|
||||
// byte-for-byte, like Python's json.dumps(ensure_ascii=False)).
|
||||
func gemma4JSONString(s string) string {
|
||||
var sb strings.Builder
|
||||
enc := json.NewEncoder(&sb)
|
||||
enc.SetEscapeHTML(false)
|
||||
if err := enc.Encode(s); err != nil {
|
||||
// Unreachable for plain strings; fall back to default escaping
|
||||
// rather than emitting invalid JSON.
|
||||
b, mErr := json.Marshal(s)
|
||||
if mErr != nil {
|
||||
return `""`
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
// Encode appends a trailing newline.
|
||||
return strings.TrimSuffix(sb.String(), "\n")
|
||||
}
|
||||
@@ -1,592 +0,0 @@
|
||||
package main
|
||||
|
||||
// Parser specs for Gemma4Parser (model output text -> pb.ChatDelta events).
|
||||
//
|
||||
// Fixture provenance:
|
||||
// - Entries marked "vLLM: <name>" are direct ports of the named test from
|
||||
// vLLM PR #45163, tests/tool_parsers/test_gemma4_tool_parser.py (the
|
||||
// authoritative test-suite for the gemma4 tool-call wire format). The
|
||||
// streaming tests' chunk lists are reused verbatim as Feed fragments.
|
||||
// - Decoder entries port the TestParseGemma4Args / TestParseGemma4Array
|
||||
// classes from the same file (non-partial mode only; this parser never
|
||||
// decodes partial payloads, see the divergence note in gemma4_parser.go).
|
||||
// - Channel/turn-marker expectations come from the chat template embedded
|
||||
// in gemma4_renderer.go (tpl L356-L362 generation prompt, L148-L158
|
||||
// strip_thinking) and vLLM's Gemma4ReasoningParser
|
||||
// (vllm/reasoning/gemma4_reasoning_parser.py).
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// flatGemma4Tool is one accumulated tool call, mirroring how LocalAI core
|
||||
// folds ToolCallDelta streams (pkg/functions/chat_deltas.go
|
||||
// ToolCallsFromChatDeltas: name/id latch on first non-empty, arguments
|
||||
// concatenate per index). Tests flatten through the same rules so they
|
||||
// assert exactly what core will reconstruct.
|
||||
type flatGemma4Tool struct {
|
||||
id string
|
||||
name string
|
||||
args string
|
||||
}
|
||||
|
||||
func flattenGemma4Deltas(deltas []*pb.ChatDelta) (string, string, []flatGemma4Tool) {
|
||||
var content, reasoning strings.Builder
|
||||
byIndex := map[int32]*flatGemma4Tool{}
|
||||
maxIdx := int32(-1)
|
||||
for _, d := range deltas {
|
||||
content.WriteString(d.GetContent())
|
||||
reasoning.WriteString(d.GetReasoningContent())
|
||||
for _, tc := range d.GetToolCalls() {
|
||||
acc, ok := byIndex[tc.GetIndex()]
|
||||
if !ok {
|
||||
acc = &flatGemma4Tool{}
|
||||
byIndex[tc.GetIndex()] = acc
|
||||
}
|
||||
if tc.GetName() != "" {
|
||||
acc.name = tc.GetName()
|
||||
}
|
||||
if tc.GetId() != "" {
|
||||
acc.id = tc.GetId()
|
||||
}
|
||||
acc.args += tc.GetArguments()
|
||||
if tc.GetIndex() > maxIdx {
|
||||
maxIdx = tc.GetIndex()
|
||||
}
|
||||
}
|
||||
}
|
||||
var tools []flatGemma4Tool
|
||||
for i := int32(0); i <= maxIdx; i++ {
|
||||
if acc, ok := byIndex[i]; ok {
|
||||
tools = append(tools, *acc)
|
||||
}
|
||||
}
|
||||
return content.String(), reasoning.String(), tools
|
||||
}
|
||||
|
||||
type wantGemma4Tool struct {
|
||||
name string
|
||||
argsJSON string // compared with MatchJSON (key order irrelevant)
|
||||
}
|
||||
|
||||
type parseGemma4Case struct {
|
||||
startInThought bool
|
||||
fragments []string
|
||||
wantContent string
|
||||
wantReasoning string
|
||||
wantTools []wantGemma4Tool
|
||||
}
|
||||
|
||||
func parseGemma4Fragments(startInThought bool, fragments []string) []*pb.ChatDelta {
|
||||
p := NewGemma4Parser(startInThought)
|
||||
var all []*pb.ChatDelta
|
||||
for _, f := range fragments {
|
||||
all = append(all, p.Feed(f)...)
|
||||
}
|
||||
return append(all, p.Close()...)
|
||||
}
|
||||
|
||||
var _ = Describe("Gemma4Parser", func() {
|
||||
DescribeTable("parses streamed gemma4 output into ChatDeltas",
|
||||
func(c parseGemma4Case) {
|
||||
content, reasoning, tools := flattenGemma4Deltas(parseGemma4Fragments(c.startInThought, c.fragments))
|
||||
Expect(content).To(Equal(c.wantContent))
|
||||
Expect(reasoning).To(Equal(c.wantReasoning))
|
||||
Expect(tools).To(HaveLen(len(c.wantTools)))
|
||||
seenIDs := map[string]bool{}
|
||||
for i, want := range c.wantTools {
|
||||
Expect(tools[i].name).To(Equal(want.name), "tool %d name", i)
|
||||
Expect(tools[i].args).To(MatchJSON(want.argsJSON), "tool %d arguments", i)
|
||||
Expect(tools[i].id).ToNot(BeEmpty(), "tool %d id", i)
|
||||
Expect(seenIDs).ToNot(HaveKey(tools[i].id), "tool %d id must be unique", i)
|
||||
seenIDs[tools[i].id] = true
|
||||
}
|
||||
},
|
||||
|
||||
// --- (1) pure content -------------------------------------------------
|
||||
// vLLM: test_no_tool_calls
|
||||
Entry("pure content, single fragment", parseGemma4Case{
|
||||
fragments: []string{"Hello, how can I help you today?"},
|
||||
wantContent: "Hello, how can I help you today?",
|
||||
}),
|
||||
|
||||
// --- (2) thought -> final transition ----------------------------------
|
||||
// enable_thinking render: prompt ends at <|turn>model\n and the model
|
||||
// opens/closes its own thought channel in the OUTPUT (vLLM
|
||||
// Gemma4ReasoningParser docstring; tpl L356-L362). The "thought\n"
|
||||
// role label after <|channel> is structural and must be stripped
|
||||
// (vLLM _THOUGHT_PREFIX handling).
|
||||
Entry("thought channel then final content", parseGemma4Case{
|
||||
fragments: []string{"<|channel>thought\nLet me think about this.\n<channel|>The answer is 42."},
|
||||
wantReasoning: "Let me think about this.\n",
|
||||
wantContent: "The answer is 42.",
|
||||
}),
|
||||
|
||||
// --- (3) startInThought both ways -------------------------------------
|
||||
Entry("startInThought=true routes initial text to reasoning until <channel|>", parseGemma4Case{
|
||||
startInThought: true,
|
||||
fragments: []string{"I am thinking hard.<channel|>Done."},
|
||||
wantReasoning: "I am thinking hard.",
|
||||
wantContent: "Done.",
|
||||
}),
|
||||
// A stray <channel|> with no open channel is swallowed, matching the
|
||||
// template's strip_thinking (tpl L148-L158: the marker is dropped,
|
||||
// text on both sides is kept).
|
||||
Entry("startInThought=false keeps the same text as content, stray <channel|> swallowed", parseGemma4Case{
|
||||
startInThought: false,
|
||||
fragments: []string{"I am thinking hard.<channel|>Done."},
|
||||
wantContent: "I am thinking hard.Done.",
|
||||
}),
|
||||
|
||||
// --- (4) one tool call, full payload type zoo --------------------------
|
||||
Entry("single tool call: strings, numbers, bools, null, nested object and array", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:complex_function{text:<|"|>with, comma and {braces}<|"|>,count:42,score:3.14,yes:true,no:false,nothing:null,obj:{inner:<|"|>v<|"|>,k:1},arr:[<|"|>a<|"|>,2,true]}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{
|
||||
name: "complex_function",
|
||||
argsJSON: `{"text":"with, comma and {braces}","count":42,"score":3.14,"yes":true,"no":false,"nothing":null,"obj":{"inner":"v","k":1},"arr":["a",2,true]}`,
|
||||
}},
|
||||
}),
|
||||
|
||||
// --- (5) payload split across 3 fragments ------------------------------
|
||||
Entry("tool-call payload split across three fragments", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>call:get_weather{loc",
|
||||
`ation:<|"|>Paris, Fra`,
|
||||
`nce<|"|>}<tool_call|>`,
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}},
|
||||
}),
|
||||
|
||||
// --- (6) marker split across fragments ----------------------------------
|
||||
Entry("tool-call open marker split across fragments", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_ca",
|
||||
`ll>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`,
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
Entry("channel open marker split across fragments", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|chan",
|
||||
"nel>thought\ndeep thought<channel|>final",
|
||||
},
|
||||
wantReasoning: "deep thought",
|
||||
wantContent: "final",
|
||||
}),
|
||||
|
||||
// --- (7) trailing partial marker held, flushed by Close -----------------
|
||||
Entry("trailing partial marker is held back and flushed by Close", parseGemma4Case{
|
||||
fragments: []string{"Hello <|tool"},
|
||||
wantContent: "Hello <|tool",
|
||||
}),
|
||||
|
||||
// --- (8) malformed/incomplete payload -> content fallback ---------------
|
||||
// vLLM: test_incomplete_tool_call (no end marker: the whole text stays
|
||||
// content, never silently dropped).
|
||||
Entry("incomplete tool payload at Close is emitted as raw content", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London`},
|
||||
wantContent: `<|tool_call>call:get_weather{location:<|"|>London`,
|
||||
}),
|
||||
Entry("malformed complete payload is emitted as raw content, parsing continues", parseGemma4Case{
|
||||
fragments: []string{"<|tool_call>oops no call syntax<tool_call|> done"},
|
||||
wantContent: "<|tool_call>oops no call syntax<tool_call|> done",
|
||||
}),
|
||||
|
||||
// --- (9) <turn|> ends the turn -------------------------------------------
|
||||
Entry("text after <turn|> is ignored, including later fragments", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"before<turn|>after",
|
||||
`more <|tool_call>call:f{}<tool_call|>`,
|
||||
},
|
||||
wantContent: "before",
|
||||
}),
|
||||
Entry("<turn|> inside a thought channel ends the turn", parseGemma4Case{
|
||||
startInThought: true,
|
||||
fragments: []string{"thinking<turn|>ignored"},
|
||||
wantReasoning: "thinking",
|
||||
}),
|
||||
|
||||
// --- (10) ported vLLM non-streaming cases ---------------------------------
|
||||
// vLLM: test_single_tool_call
|
||||
Entry("vLLM: test_single_tool_call", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
// vLLM: test_multiple_arguments
|
||||
Entry("vLLM: test_multiple_arguments", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"San Francisco","unit":"celsius"}`}},
|
||||
}),
|
||||
// vLLM: test_text_before_tool_call. DIVERGENCE: vLLM's non-streaming
|
||||
// extractor trims the content ("...you."); a streaming parser cannot
|
||||
// retroactively trim already-emitted text, so the trailing space is
|
||||
// kept (vLLM's own streaming path keeps it too, see
|
||||
// test_streaming_text_before_tool_call which only checks a prefix).
|
||||
Entry("vLLM: test_text_before_tool_call (streaming semantics: no trim)", parseGemma4Case{
|
||||
fragments: []string{`Let me check the weather for you. <|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|>`},
|
||||
wantContent: "Let me check the weather for you. ",
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}},
|
||||
}),
|
||||
// vLLM: test_multiple_tool_calls (also covers case 11: multi-tool sequence)
|
||||
Entry("vLLM: test_multiple_tool_calls", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|><|tool_call>call:get_time{location:<|"|>London<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{
|
||||
{name: "get_weather", argsJSON: `{"location":"London"}`},
|
||||
{name: "get_time", argsJSON: `{"location":"London"}`},
|
||||
},
|
||||
}),
|
||||
// vLLM: test_nested_arguments
|
||||
Entry("vLLM: test_nested_arguments", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:complex_function{nested:{inner:<|"|>value<|"|>},list:[<|"|>a<|"|>,<|"|>b<|"|>]}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "complex_function", argsJSON: `{"nested":{"inner":"value"},"list":["a","b"]}`}},
|
||||
}),
|
||||
// vLLM: test_tool_call_with_number_and_boolean
|
||||
Entry("vLLM: test_tool_call_with_number_and_boolean", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:set_status{is_active:true,count:42,score:3.14}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "set_status", argsJSON: `{"is_active":true,"count":42,"score":3.14}`}},
|
||||
}),
|
||||
// vLLM: test_hyphenated_function_name
|
||||
Entry("vLLM: test_hyphenated_function_name", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:get-weather{location:<|"|>London<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "get-weather", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
// vLLM: test_dotted_function_name
|
||||
Entry("vLLM: test_dotted_function_name", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:weather.get{location:<|"|>London<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "weather.get", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
// vLLM: test_no_arguments
|
||||
Entry("vLLM: test_no_arguments", parseGemma4Case{
|
||||
fragments: []string{"<|tool_call>call:get_status{}<tool_call|>"},
|
||||
wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}},
|
||||
}),
|
||||
|
||||
// --- ported vLLM streaming cases (chunk lists reused as fragments) --------
|
||||
// vLLM: test_basic_streaming_single_tool
|
||||
Entry("vLLM: test_basic_streaming_single_tool", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
`location:<|"|>Paris`,
|
||||
", France",
|
||||
`<|"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris, France"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_multi_arg
|
||||
Entry("vLLM: test_streaming_multi_arg", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
`location:<|"|>Tokyo<|"|>,`,
|
||||
`unit:<|"|>celsius<|"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Tokyo","unit":"celsius"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_text_before_tool_call
|
||||
Entry("vLLM: test_streaming_text_before_tool_call", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"Let me check ",
|
||||
"the weather. ",
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
`location:<|"|>London<|"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantContent: "Let me check the weather. ",
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"London"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_numeric_args
|
||||
Entry("vLLM: test_streaming_numeric_args", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:set_config{",
|
||||
"count:42,",
|
||||
"active:true}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "set_config", argsJSON: `{"count":42,"active":true}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_boolean_split_across_chunks
|
||||
Entry("vLLM: test_streaming_boolean_split_across_chunks", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:search{input:{all:tru",
|
||||
"e}}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "search", argsJSON: `{"input":{"all":true}}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_false_split_across_chunks
|
||||
Entry("vLLM: test_streaming_false_split_across_chunks", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:set{flag:fals",
|
||||
"e}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"flag":false}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_number_split_across_chunks
|
||||
Entry("vLLM: test_streaming_number_split_across_chunks", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:set{count:4",
|
||||
"2}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "set", argsJSON: `{"count":42}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_empty_args
|
||||
Entry("vLLM: test_streaming_empty_args", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:get_status{}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "get_status", argsJSON: `{}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_split_delimiter_no_invalid_json (string
|
||||
// delimiter <|"|> split across fragments must not leak fragments).
|
||||
Entry("vLLM: test_streaming_split_delimiter_no_invalid_json", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:todowrite{",
|
||||
`content:<|"|>Buy milk<|`,
|
||||
`"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{name: "todowrite", argsJSON: `{"content":"Buy milk"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call
|
||||
Entry("vLLM: test_streaming_does_not_duplicate_plain_text_after_tool_call", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
`location:<|"|>Paris<|"|>}`,
|
||||
"<tool_call|><",
|
||||
"div>",
|
||||
},
|
||||
wantContent: "<div>",
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Paris"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes
|
||||
Entry("vLLM: test_streaming_html_argument_does_not_duplicate_tag_prefixes", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:write_file{",
|
||||
`path:<|"|>index.html<|"|>,`,
|
||||
`content:<|"|><!DOCTYPE html>` + "\n<",
|
||||
`html lang="zh-CN">` + "\n<",
|
||||
"head>\n <",
|
||||
`meta charset="UTF-8">` + "\n <",
|
||||
`meta name="viewport" content="width=device-width">` + "\n",
|
||||
`<|"|>}`,
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{
|
||||
name: "write_file",
|
||||
argsJSON: `{"path":"index.html","content":"<!DOCTYPE html>\n<html lang=\"zh-CN\">\n<head>\n <meta charset=\"UTF-8\">\n <meta name=\"viewport\" content=\"width=device-width\">\n"}`,
|
||||
}},
|
||||
}),
|
||||
// vLLM: test_streaming_single_chunk_complete_tool_call
|
||||
Entry("vLLM: test_streaming_single_chunk_complete_tool_call", parseGemma4Case{
|
||||
fragments: []string{`<|tool_call>call:name_a_color{color_hex:<|"|>00ff11<|"|>}<tool_call|>`},
|
||||
wantTools: []wantGemma4Tool{{name: "name_a_color", argsJSON: `{"color_hex":"00ff11"}`}},
|
||||
}),
|
||||
// vLLM: test_streaming_multi_chunk_batched_tool_calls (two complete
|
||||
// calls in ONE fragment; both must come out with distinct indices)
|
||||
Entry("vLLM: test_streaming_multi_chunk_batched_tool_calls", parseGemma4Case{
|
||||
fragments: []string{
|
||||
`<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>` +
|
||||
`<|tool_call>call:get_time{timezone:<|"|>GMT<|"|>}<tool_call|>`,
|
||||
},
|
||||
wantTools: []wantGemma4Tool{
|
||||
{name: "get_weather", argsJSON: `{"location":"London"}`},
|
||||
{name: "get_time", argsJSON: `{"timezone":"GMT"}`},
|
||||
},
|
||||
}),
|
||||
// vLLM: test_streaming_trailing_bare_bool_not_duplicated
|
||||
Entry("vLLM: test_streaming_trailing_bare_bool_not_duplicated", parseGemma4Case{
|
||||
fragments: []string{
|
||||
"<|tool_call>",
|
||||
"call:Edit{",
|
||||
`file_path:<|"|>src/env.py<|"|>,`,
|
||||
`old_string:<|"|>old_val<|"|>,`,
|
||||
`new_string:<|"|>new_val<|"|>,`,
|
||||
"replace_all:",
|
||||
"false}",
|
||||
"<tool_call|>",
|
||||
},
|
||||
wantTools: []wantGemma4Tool{{
|
||||
name: "Edit",
|
||||
argsJSON: `{"file_path":"src/env.py","old_string":"old_val","new_string":"new_val","replace_all":false}`,
|
||||
}},
|
||||
}),
|
||||
|
||||
// --- implicit reasoning end on <|tool_call> (vLLM is_reasoning_end:
|
||||
// a tool_call token means reasoning is over) -----------------------------
|
||||
Entry("tool call inside an open thought channel ends the reasoning", parseGemma4Case{
|
||||
startInThought: true,
|
||||
fragments: []string{`need the weather<|tool_call>call:get_weather{location:<|"|>Rome<|"|>}<tool_call|>`},
|
||||
wantReasoning: "need the weather",
|
||||
wantTools: []wantGemma4Tool{{name: "get_weather", argsJSON: `{"location":"Rome"}`}},
|
||||
}),
|
||||
|
||||
// --- (12) empty fragments are no-ops --------------------------------------
|
||||
Entry("empty fragments are no-ops", parseGemma4Case{
|
||||
fragments: []string{"", "Hello", "", "", " world", ""},
|
||||
wantContent: "Hello world",
|
||||
}),
|
||||
)
|
||||
|
||||
It("returns no deltas for an empty fragment and after Close", func() {
|
||||
p := NewGemma4Parser(false)
|
||||
Expect(p.Feed("")).To(BeEmpty())
|
||||
Expect(p.Feed("hi")).ToNot(BeEmpty())
|
||||
Expect(p.Close()).To(BeEmpty()) // nothing held back
|
||||
// The parser is finished after Close: further input is dropped.
|
||||
Expect(p.Feed("more")).To(BeEmpty())
|
||||
Expect(p.Close()).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("generates index-based tool call ids (call_<index>)", func() {
|
||||
// Mirrors the index-based id convention of pkg/grpc/rich_test.go and
|
||||
// keeps ids deterministic for the split-invariance property below.
|
||||
deltas := parseGemma4Fragments(false, []string{
|
||||
`<|tool_call>call:a{}<tool_call|><|tool_call>call:b{}<tool_call|>`,
|
||||
})
|
||||
_, _, tools := flattenGemma4Deltas(deltas)
|
||||
Expect(tools).To(HaveLen(2))
|
||||
Expect(tools[0].id).To(Equal("call_0"))
|
||||
Expect(tools[1].id).To(Equal("call_1"))
|
||||
})
|
||||
|
||||
// Property: for a fixed full output, EVERY 2-split position must yield
|
||||
// exactly the same flattened result as the unsplit parse. This kills
|
||||
// fragment-boundary bugs (mid-marker, mid-delimiter, mid-payload splits).
|
||||
DescribeTable("2-split fragment invariance",
|
||||
func(startInThought bool, full string) {
|
||||
refContent, refReasoning, refTools := flattenGemma4Deltas(
|
||||
parseGemma4Fragments(startInThought, []string{full}))
|
||||
for i := 0; i <= len(full); i++ {
|
||||
content, reasoning, tools := flattenGemma4Deltas(
|
||||
parseGemma4Fragments(startInThought, []string{full[:i], full[i:]}))
|
||||
Expect(content).To(Equal(refContent), fmt.Sprintf("content diverged at split %d", i))
|
||||
Expect(reasoning).To(Equal(refReasoning), fmt.Sprintf("reasoning diverged at split %d", i))
|
||||
Expect(tools).To(Equal(refTools), fmt.Sprintf("tool calls diverged at split %d", i))
|
||||
}
|
||||
},
|
||||
Entry("thought + content + two tool calls + turn end", false,
|
||||
"<|channel>thought\nPondering the request...\n<channel|>Sure - calling tools now. "+
|
||||
`<|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>,days:3,detailed:true}<tool_call|>`+
|
||||
`<|tool_call>call:get_time{timezone:<|"|>Europe/Lisbon<|"|>,nested:{flag:false,vals:[1,2.5,<|"|>x<|"|>]}}<tool_call|>`+
|
||||
"Done.<turn|>ignored tail"),
|
||||
Entry("startInThought + tool call + trailing partial marker", true,
|
||||
`Deep thought<channel|>final answer <|tool_call>call:noop{}<tool_call|> trailing <|tool`),
|
||||
Entry("malformed payload fallback", false,
|
||||
`pre <|tool_call>not a call<tool_call|> post`),
|
||||
)
|
||||
})
|
||||
|
||||
// Decoder-level ports of vLLM's TestParseGemma4Args / TestParseGemma4Array
|
||||
// (non-partial mode; the partial-withholding tests do not apply because this
|
||||
// parser only ever decodes COMPLETE payloads, see gemma4_parser.go).
|
||||
var _ = Describe("decodeGemma4Args", func() {
|
||||
DescribeTable("decodes the gemma4 call syntax into JSON arguments",
|
||||
func(in, wantJSON string) {
|
||||
Expect(decodeGemma4Args(in, 0)).To(MatchJSON(wantJSON))
|
||||
},
|
||||
// vLLM: test_empty_string / test_whitespace_only
|
||||
Entry("empty string", "", `{}`),
|
||||
Entry("whitespace only", " ", `{}`),
|
||||
// vLLM: test_single_string_value
|
||||
Entry("single string value", `location:<|"|>Paris<|"|>`, `{"location":"Paris"}`),
|
||||
// vLLM: test_string_value_with_comma
|
||||
Entry("string value with comma", `location:<|"|>Paris, France<|"|>`, `{"location":"Paris, France"}`),
|
||||
// vLLM: test_multiple_string_values
|
||||
Entry("multiple string values", `location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>`, `{"location":"San Francisco","unit":"celsius"}`),
|
||||
// vLLM: test_integer_value / test_float_value
|
||||
Entry("integer value", "count:42", `{"count":42}`),
|
||||
Entry("float value", "score:3.14", `{"score":3.14}`),
|
||||
// vLLM: test_boolean_true / test_boolean_false
|
||||
Entry("boolean true", "flag:true", `{"flag":true}`),
|
||||
Entry("boolean false", "flag:false", `{"flag":false}`),
|
||||
// vLLM: test_null_value (bare null must become JSON null, not "null")
|
||||
Entry("null value", "param:null", `{"param":null}`),
|
||||
// vLLM: test_mixed_types
|
||||
Entry("mixed types", `name:<|"|>test<|"|>,count:42,active:true,score:3.14`,
|
||||
`{"name":"test","count":42,"active":true,"score":3.14}`),
|
||||
// vLLM: test_nested_object
|
||||
Entry("nested object", `nested:{inner:<|"|>value<|"|>}`, `{"nested":{"inner":"value"}}`),
|
||||
// vLLM: test_array_of_strings
|
||||
Entry("array of strings", `items:[<|"|>a<|"|>,<|"|>b<|"|>]`, `{"items":["a","b"]}`),
|
||||
// vLLM: test_unterminated_string (take everything after the delimiter)
|
||||
Entry("unterminated string", `key:<|"|>unterminated`, `{"key":"unterminated"}`),
|
||||
// vLLM: test_empty_value (key with no value after colon)
|
||||
Entry("empty value", "key:", `{"key":""}`),
|
||||
// vLLM: test_trailing_dot_float_partial_withheld, non-partial branch
|
||||
// (trailing-dot floats parse normally outside streaming).
|
||||
Entry("trailing dot float, complete payload", "left:108.,right:22.8", `{"left":108.0,"right":22.8}`),
|
||||
)
|
||||
|
||||
It("terminates and yields valid JSON on malformed input", func() {
|
||||
// vLLM: test_malformed_partial_array (the assertion there is only
|
||||
// "returns a dict without hanging"; ours is "valid JSON object").
|
||||
out := decodeGemma4Args(":[t:[]", 0)
|
||||
var v map[string]any
|
||||
Expect(json.Unmarshal([]byte(out), &v)).To(Succeed())
|
||||
})
|
||||
|
||||
It("degrades nesting beyond the recursion cap to a string value", func() {
|
||||
// 200 levels of a:{a:{...a:1...}}. Without the depth cap the mutual
|
||||
// recursion would grow the stack with the model's output; a Go stack
|
||||
// overflow is a fatal process kill, so levels past gemma4MaxArgsDepth
|
||||
// must gracefully fall back to the raw inner text as a JSON string.
|
||||
const depth = 200
|
||||
body := strings.Repeat("a:{", depth-1) + "a:1" + strings.Repeat("}", depth-1)
|
||||
out := decodeGemma4Args(body, 0)
|
||||
var v map[string]any
|
||||
Expect(json.Unmarshal([]byte(out), &v)).To(Succeed())
|
||||
levels := 0
|
||||
var cur any = v
|
||||
for {
|
||||
m, ok := cur.(map[string]any)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
Expect(m).To(HaveKey("a"))
|
||||
cur = m["a"]
|
||||
levels++
|
||||
}
|
||||
Expect(levels).To(Equal(gemma4MaxArgsDepth + 1))
|
||||
Expect(cur).To(BeAssignableToTypeOf(""))
|
||||
Expect(cur).To(ContainSubstring("a:{"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("decodeGemma4Array", func() {
|
||||
DescribeTable("decodes gemma4 array bodies into JSON arrays",
|
||||
func(in, wantJSON string) {
|
||||
Expect(decodeGemma4Array(in, 0)).To(MatchJSON(wantJSON))
|
||||
},
|
||||
// vLLM: test_string_array / test_empty_array / test_bare_values
|
||||
Entry("string array", `<|"|>a<|"|>,<|"|>b<|"|>`, `["a","b"]`),
|
||||
Entry("empty array", "", `[]`),
|
||||
Entry("bare values", "42,true,3.14", `[42,true,3.14]`),
|
||||
// vLLM: test_string_element_with_closing_bracket (a ']' inside a
|
||||
// delimited string must not close the array)
|
||||
Entry("string element with closing bracket", `[<|"|>a]b<|"|>,<|"|>c<|"|>],<|"|>tail<|"|>`, `[["a]b","c"],"tail"]`),
|
||||
// vLLM: test_stray_closing_bracket (no-progress abort, keep prefix)
|
||||
Entry("stray closing bracket", "42,]trailing", `[42]`),
|
||||
)
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,347 +0,0 @@
|
||||
package main
|
||||
|
||||
// Renderer specs for RenderGemma4 against the canonical gemma4 chat template
|
||||
// (see the normative template comment in gemma4_renderer.go).
|
||||
//
|
||||
// Fixture provenance:
|
||||
// - "single user message" and "enable_thinking" are the EXACT expected
|
||||
// decodes from transformers tests/models/diffusion_gemma/
|
||||
// test_modeling_diffusion_gemma.py (test_diffusion_gemma_chat_template
|
||||
// and ..._with_thinking) with ONE difference: the transformers fixtures
|
||||
// start with "<bos>" because apply_chat_template tokenizes the rendered
|
||||
// text with add_bos. Our prompt goes through dllm_capi_generate, whose
|
||||
// run_generate already tokenizes with prepend_bos = vocab.add_bos
|
||||
// (dllm.cpp src/capi.cpp:230-231, true for gemma4), so the renderer must
|
||||
// NOT emit a literal <bos> (it would double) and every expected string
|
||||
// here drops that leading token.
|
||||
// - All other expected strings were produced by rendering the verbatim
|
||||
// GGUF template with jinja2 3.1.2 (bos_token="<bos>") and dropping the
|
||||
// leading "<bos>" for the same reason.
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// Two-function tools array used by the tool fixtures (OpenAI wire shape, as
|
||||
// LocalAI passes it through PredictOptions.Tools).
|
||||
const testToolsJSON = `[{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a location.","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city name."},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}},{"type":"function","function":{"name":"get_time","description":"Get the current time in a timezone.","parameters":{"type":"object","properties":{"timezone":{"type":"string","description":"IANA timezone name."}},"required":["timezone"]}}}]`
|
||||
|
||||
// The <|tool>...<tool|> block the template renders for testToolsJSON inside
|
||||
// the system turn (jinja2-verified).
|
||||
const testToolsBlock = `<|tool>declaration:get_weather{description:<|"|>Get the current weather in a location.<|"|>,parameters:{properties:{location:{description:<|"|>The city name.<|"|>,type:<|"|>STRING<|"|>},unit:{enum:[<|"|>celsius<|"|>,<|"|>fahrenheit<|"|>],type:<|"|>STRING<|"|>}},required:[<|"|>location<|"|>],type:<|"|>OBJECT<|"|>}}<tool|><|tool>declaration:get_time{description:<|"|>Get the current time in a timezone.<|"|>,parameters:{properties:{timezone:{description:<|"|>IANA timezone name.<|"|>,type:<|"|>STRING<|"|>}},required:[<|"|>timezone<|"|>],type:<|"|>OBJECT<|"|>}}<tool|>`
|
||||
|
||||
// A single tool exercising the deep format_parameters branches: array items
|
||||
// (string-typed and nested-array), nullable, enum+nullable, nested object
|
||||
// properties/required, and a response declaration.
|
||||
const complexToolsJSON = `[{"type":"function","function":{"name":"complex_tool","description":"A complex tool.","parameters":{"type":"object","properties":{"tags":{"type":"array","description":"Tags.","items":{"type":"string"}},"matrix":{"type":"array","items":{"type":"array","items":{"type":"number"}}},"opts":{"type":"object","description":"Options.","properties":{"depth":{"type":"integer","nullable":true}},"required":["depth"]},"mode":{"type":"string","enum":["a","b"],"nullable":true}},"required":["tags","opts"]},"response":{"description":"The result.","type":"object"}}}]`
|
||||
|
||||
// jinja2-verified render of complexToolsJSON. Notable template quirks pinned
|
||||
// here: nested array items go through format_argument with ESCAPED keys and
|
||||
// an un-uppercased type (<|"|>type<|"|>:<|"|>number<|"|>), while direct item
|
||||
// types are uppercased; properties dictsort case-insensitively.
|
||||
const complexToolsBlock = `<|tool>declaration:complex_tool{description:<|"|>A complex tool.<|"|>,parameters:{properties:{matrix:{items:{items:{<|"|>type<|"|>:<|"|>number<|"|>},type:<|"|>ARRAY<|"|>},type:<|"|>ARRAY<|"|>},mode:{enum:[<|"|>a<|"|>,<|"|>b<|"|>],nullable:true,type:<|"|>STRING<|"|>},opts:{description:<|"|>Options.<|"|>,properties:{depth:{nullable:true,type:<|"|>INTEGER<|"|>}},required:[<|"|>depth<|"|>],type:<|"|>OBJECT<|"|>},tags:{description:<|"|>Tags.<|"|>,items:{type:<|"|>STRING<|"|>},type:<|"|>ARRAY<|"|>}},required:[<|"|>tags<|"|>,<|"|>opts<|"|>],type:<|"|>OBJECT<|"|>},response:{description:<|"|>The result.<|"|>,type:<|"|>OBJECT<|"|>}}<tool|>`
|
||||
|
||||
type renderGemma4Case struct {
|
||||
msgs []*pb.Message
|
||||
toolsJSON string
|
||||
enableThinking bool
|
||||
noGenerationPrompt bool // inverted so the zero value is the common case
|
||||
expected string
|
||||
}
|
||||
|
||||
var _ = Describe("RenderGemma4", func() {
|
||||
DescribeTable("renders the canonical gemma4 prompt",
|
||||
func(c renderGemma4Case) {
|
||||
out, err := RenderGemma4(c.msgs, c.toolsJSON, c.enableThinking, !c.noGenerationPrompt)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out).To(Equal(c.expected))
|
||||
// The C-ABI generate prepends BOS itself: a literal <bos>
|
||||
// anywhere in the rendered prompt would double-encode it.
|
||||
Expect(out).ToNot(ContainSubstring("<bos>"))
|
||||
},
|
||||
|
||||
// transformers fixture (test_diffusion_gemma_chat_template), sans <bos>:
|
||||
// default thinking pre-opens an EMPTY thought channel in the
|
||||
// generation prompt.
|
||||
Entry("single user message, default (no thinking)", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "Write a long essay about Portugal."},
|
||||
},
|
||||
expected: "<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// transformers fixture (test_diffusion_gemma_chat_template_with_thinking),
|
||||
// sans <bos>: a system turn carrying <|think|> and NO auto-opened
|
||||
// thought channel.
|
||||
Entry("enable_thinking=true", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "Write a long essay about Portugal."},
|
||||
},
|
||||
enableThinking: true,
|
||||
expected: "<|turn>system\n<|think|>\n<turn|>\n<|turn>user\nWrite a long essay about Portugal.<turn|>\n<|turn>model\n",
|
||||
}),
|
||||
|
||||
Entry("multi-turn user/assistant/user", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "Hello, who are you?"},
|
||||
{Role: "assistant", Content: "I am Gemma, a helpful assistant."},
|
||||
{Role: "user", Content: "Tell me a joke."},
|
||||
},
|
||||
expected: "<|turn>user\nHello, who are you?<turn|>\n<|turn>model\nI am Gemma, a helpful assistant.<turn|>\n<|turn>user\nTell me a joke.<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// tpl L178-L195: a leading system message is folded into the system
|
||||
// turn (trimmed) and consumed from the loop.
|
||||
Entry("system message folds into the system turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "system", Content: "You are a pirate."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|turn>system\nYou are a pirate.<turn|>\n<|turn>user\nHello!<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// tpl L182-L185: <|think|> goes at the very top of the SAME system
|
||||
// turn, before the system prompt text.
|
||||
Entry("system message with enable_thinking shares the turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "system", Content: "You are a pirate."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
enableThinking: true,
|
||||
expected: "<|turn>system\n<|think|>\nYou are a pirate.<turn|>\n<|turn>user\nHello!<turn|>\n<|turn>model\n",
|
||||
}),
|
||||
|
||||
// tpl L196-L203: tool declarations render in the system turn, one
|
||||
// <|tool>declaration:...<tool|> block per tool, no separators.
|
||||
Entry("tools array (two functions)", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "What is the weather in Tokyo?"},
|
||||
},
|
||||
toolsJSON: testToolsJSON,
|
||||
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// format_parameters deep branches (tpl L1-L85) + response declaration
|
||||
// (tpl L106-L116).
|
||||
Entry("complex tool schema (array items, nullable, nested object, response)", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
},
|
||||
toolsJSON: complexToolsJSON,
|
||||
expected: "<|turn>system\n" + complexToolsBlock + "<turn|>\n<|turn>user\ngo<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// tpl L243-L313: assistant tool_calls render as
|
||||
// <|tool_call>call:name{args}<tool_call|>; the following role=tool
|
||||
// message renders inline as <|tool_response>response:name{value:..}
|
||||
// <tool_response|>; the model turn stays OPEN (no <turn|>, no new
|
||||
// generation prompt) so the model continues after the response.
|
||||
Entry("assistant tool_calls + role=tool result", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "What is the weather in Tokyo?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`},
|
||||
{Role: "tool", ToolCallId: "call_1", Content: "Sunny, 22 degrees celsius."},
|
||||
},
|
||||
toolsJSON: testToolsJSON,
|
||||
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny, 22 degrees celsius.<|"|>}<tool_response|>`,
|
||||
}),
|
||||
|
||||
// tpl L348-L349: a tool_calls turn with no rendered responses ends
|
||||
// on an OPEN <|tool_response> marker for the runtime to fill, and
|
||||
// add_generation_prompt adds nothing (tpl L357).
|
||||
Entry("assistant tool_calls without a result leaves <|tool_response> open", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "What is the weather in Tokyo?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\",\"unit\":\"celsius\"}"}}]`},
|
||||
},
|
||||
toolsJSON: testToolsJSON,
|
||||
expected: "<|turn>system\n" + testToolsBlock + "<turn|>\n<|turn>user\nWhat is the weather in Tokyo?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>`,
|
||||
}),
|
||||
|
||||
// tpl L237-L241: reasoning_content renders as a thought channel only
|
||||
// on a tool-calling turn after the last user message.
|
||||
Entry("reasoning_content with tool_calls renders the thought channel", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "weather?"},
|
||||
{Role: "assistant", Content: "", ReasoningContent: "I should call the tool", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`},
|
||||
{Role: "tool", ToolCallId: "c1", Content: "Sunny"},
|
||||
},
|
||||
expected: "<|turn>user\nweather?<turn|>\n<|turn>model\n<|channel>thought\nI should call the tool\n<channel|>" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}<tool_response|>`,
|
||||
}),
|
||||
|
||||
// tpl L220-L235: the assistant answer following its own tool round
|
||||
// continues the SAME model turn (no second <|turn>model).
|
||||
Entry("tool round then final assistant answer then user", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "weather?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Tokyo\"}"}}]`},
|
||||
{Role: "tool", ToolCallId: "c1", Content: "Sunny"},
|
||||
{Role: "assistant", Content: "It is sunny."},
|
||||
{Role: "user", Content: "thanks"},
|
||||
},
|
||||
expected: "<|turn>user\nweather?<turn|>\n<|turn>model\n" + `<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}<tool_call|><|tool_response>response:get_weather{value:<|"|>Sunny<|"|>}<tool_response|>` + "It is sunny.<turn|>\n<|turn>user\nthanks<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// format_argument (tpl L118-L147): numbers keep their JSON literal,
|
||||
// booleans lower-case, nested maps have unquoted dictsorted keys,
|
||||
// arrays bracketed; top-level args are dictsorted case-insensitively.
|
||||
Entry("tool_call argument types (number/bool/nested/array)", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"count\":42,\"ratio\":3.5,\"flag\":true,\"off\":false,\"nested\":{\"x\":\"y\",\"n\":7},\"list\":[\"a\",1,true]}"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n" + `<|tool_call>call:f{count:42,flag:true,list:[<|"|>a<|"|>,1,true],nested:{n:7,x:<|"|>y<|"|>},off:false,ratio:3.5}<tool_call|><|tool_response>`,
|
||||
}),
|
||||
|
||||
// jinja dictsort is case-insensitive: alpha sorts before Beta.
|
||||
Entry("tool_call argument dictsort is case-insensitive", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"Beta\":1,\"alpha\":2}"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{alpha:2,Beta:1}<tool_call|><|tool_response>",
|
||||
}),
|
||||
|
||||
// jinja renders Python None as "None" (round-trips through vLLM's
|
||||
// parser, which lowers "none" back to null).
|
||||
Entry("tool_call null argument renders as None", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{\"maybe\":null}"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{maybe:None}<tool_call|><|tool_response>",
|
||||
}),
|
||||
|
||||
Entry("tool_call empty arguments render empty braces", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{}<tool_call|><|tool_response>",
|
||||
}),
|
||||
|
||||
// tpl L253-L254: a non-object arguments string renders verbatim.
|
||||
Entry("tool_call non-object string arguments render verbatim", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"just text"}}]`},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n<|tool_call>call:f{just text}<tool_call|><|tool_response>",
|
||||
}),
|
||||
|
||||
// tpl L278-L285: unmatched tool_call_id falls back to the tool
|
||||
// message's own name.
|
||||
Entry("tool result name falls back when tool_call_id does not match", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "go"},
|
||||
{Role: "assistant", Content: "", ToolCalls: `[{"index":0,"id":"c1","type":"function","function":{"name":"f","arguments":"{}"}}]`},
|
||||
{Role: "tool", ToolCallId: "OTHER", Name: "named_tool", Content: "out"},
|
||||
},
|
||||
expected: "<|turn>user\ngo<turn|>\n<|turn>model\n" + `<|tool_call>call:f{}<tool_call|><|tool_response>response:named_tool{value:<|"|>out<|"|>}<tool_response|>`,
|
||||
}),
|
||||
|
||||
// strip_thinking (tpl L148-L158): historical assistant content loses
|
||||
// its <|channel>...<channel|> spans.
|
||||
Entry("assistant content thinking channels are stripped", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "<|channel>thought\nsecret\n<channel|>visible answer"},
|
||||
{Role: "user", Content: "more"},
|
||||
},
|
||||
expected: "<|turn>user\nhi<turn|>\n<|turn>model\nvisible answer<turn|>\n<|turn>user\nmore<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
// tpl L220-L235: consecutive assistant messages suppress the second
|
||||
// <|turn>model (continuation), but each still closes with <turn|>.
|
||||
Entry("consecutive assistant messages continue the model turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "part one"},
|
||||
{Role: "assistant", Content: "part two"},
|
||||
{Role: "user", Content: "ok"},
|
||||
},
|
||||
expected: "<|turn>user\nhi<turn|>\n<|turn>model\npart one<turn|>\npart two<turn|>\n<|turn>user\nok<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
}),
|
||||
|
||||
Entry("add_generation_prompt=false renders no model turn", renderGemma4Case{
|
||||
msgs: []*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
},
|
||||
noGenerationPrompt: true,
|
||||
expected: "<|turn>user\nhi<turn|>\n",
|
||||
}),
|
||||
)
|
||||
|
||||
Describe("error handling", func() {
|
||||
It("fails loud on an unknown role", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "narrator", Content: "Meanwhile..."},
|
||||
}, "", false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(`unknown role "narrator"`))
|
||||
})
|
||||
|
||||
It("fails on invalid tools JSON", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
}, "{not json", false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tools JSON"))
|
||||
})
|
||||
|
||||
It("fails on invalid tool_calls JSON", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "", ToolCalls: "{not json"},
|
||||
}, "", false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tool_calls JSON"))
|
||||
})
|
||||
|
||||
It("fails on an orphan tool message, naming its index", func() {
|
||||
// A role:tool message with no preceding assistant tool_calls turn
|
||||
// would be silently dropped by the jinja; we fail loud instead.
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "tool", Content: `{"temp": 20}`, ToolCallId: "call_1"},
|
||||
}, "", false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("orphan tool message 1"))
|
||||
})
|
||||
|
||||
It("fails on trailing garbage after the tools JSON array", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
}, "[] junk", false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tools JSON"))
|
||||
})
|
||||
|
||||
It("fails when the tools JSON is not an array", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
}, `{"type":"function"}`, false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tools JSON is not an array"))
|
||||
})
|
||||
|
||||
It("fails when a tools array element is not an object", func() {
|
||||
_, err := RenderGemma4([]*pb.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
}, `[42]`, false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tools[0] is not an object"))
|
||||
})
|
||||
|
||||
It("rejects a nil message via the unknown-role check", func() {
|
||||
// Pins current behavior: pb getters are nil-safe, so a nil message
|
||||
// reads as role "" and trips the fail-loud unknown-role guard.
|
||||
_, err := RenderGemma4([]*pb.Message{nil}, "", false, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(`unknown role "" in message 0`))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,85 +0,0 @@
|
||||
package main
|
||||
|
||||
// Started internally by LocalAI - one gRPC server per loaded model.
|
||||
//
|
||||
// Loads libdllm.so via purego and registers the 9-symbol flat C-ABI
|
||||
// declared in dllm.cpp's include/dllm_capi.h (ABI v1). The library name can
|
||||
// be overridden with DLLM_LIBRARY (mirrors the PARAKEET_LIBRARY /
|
||||
// WHISPER_LIBRARY convention in the sibling backends); the default looks
|
||||
// for the .so next to this binary (run.sh puts the package dir on
|
||||
// LD_LIBRARY_PATH).
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
// loadCAPI dlopens libName and binds the 9 dllm_capi_* entry points 1:1 to
|
||||
// dllm_capi.h, so an `nm libdllm.so | grep dllm_capi` is enough to spot
|
||||
// drift. Shared with the test suite (ensureLibLoaded), which drives the
|
||||
// bridge without the gRPC server.
|
||||
//
|
||||
// The C-ABI returns malloc'd char* buffers from tokenize_json/generate; we
|
||||
// register those as uintptr so we get the raw pointer back and can call
|
||||
// dllm_capi_free_string on it (purego's string return would copy and forget
|
||||
// the original pointer, leaking it on every call). last_error returns a
|
||||
// BORROWED pointer instead, so it is registered as a plain string: purego
|
||||
// copies it and nothing must be freed.
|
||||
func loadCAPI(libName string) error {
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dllm: dlopen %q: %w", libName, err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&cppAbiVersion, "dllm_capi_abi_version"},
|
||||
{&cppLoad, "dllm_capi_load"},
|
||||
{&cppFree, "dllm_capi_free"},
|
||||
{&cppLastError, "dllm_capi_last_error"},
|
||||
{&cppFreeString, "dllm_capi_free_string"},
|
||||
{&cppTokenizeJSON, "dllm_capi_tokenize_json"},
|
||||
{&cppGenerate, "dllm_capi_generate"},
|
||||
{&cppGenerateStream, "dllm_capi_generate_stream"},
|
||||
{&cppCancel, "dllm_capi_cancel"},
|
||||
}
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("DLLM_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libdllm.so"
|
||||
}
|
||||
|
||||
if err := loadCAPI(libName); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Hard-fail on an ABI mismatch: the flat-pointer bindings above would
|
||||
// otherwise misbehave silently against a future libdllm.so.
|
||||
if v := cAbiVersion(); v != dllmABIVersion {
|
||||
panic(fmt.Errorf("dllm: libdllm.so ABI=%d, this backend speaks ABI=%d", v, dllmABIVersion))
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "[dllm] ABI=%d\n", cAbiVersion())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &Dllm{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# T1 packaging stub: copy the binary, run.sh and libdllm.so into package/.
|
||||
# The full ldd walk (libc, libstdc++, libgomp, GPU runtimes, arch
|
||||
# detection) lands with the registration task, mirroring
|
||||
# backend/go/whisper/package.sh.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
|
||||
cp -avf "$CURDIR/dllm-grpc" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
# libdllm.so + any soname symlinks, should upstream ever add them.
|
||||
cp -avf "$CURDIR"/libdllm.so* "$CURDIR/package/lib/" 2>/dev/null || {
|
||||
echo "ERROR: libdllm.so not found in $CURDIR, run 'make' first" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
echo "T1 package layout (full ldd walk lands with registration):"
|
||||
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"
|
||||
@@ -1,16 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
|
||||
|
||||
# If a self-contained ld.so was packaged, route through it so the
|
||||
# packaged libc / libstdc++ are used instead of the host's (matches the
|
||||
# whisper / parakeet-cpp backends' runtime layout).
|
||||
if [ -f "$CURDIR/lib/ld.so" ]; then
|
||||
echo "Using lib/ld.so"
|
||||
exec "$CURDIR/lib/ld.so" "$CURDIR/dllm-grpc" "$@"
|
||||
fi
|
||||
|
||||
exec "$CURDIR/dllm-grpc" "$@"
|
||||
@@ -1,6 +1,6 @@
|
||||
# parakeet-cpp backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=b11fe5bca78ad8b342dd559a43d76df3984bb447
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# whisper.cpp / ds4 / vibevoice-cpp convention.
|
||||
#
|
||||
@@ -15,7 +15,7 @@
|
||||
# That's what the L0 smoke test uses. The default target below does the
|
||||
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
|
||||
|
||||
PARAKEET_VERSION?=e270af73b94c9a5c37ec516230219ed4580e1db6
|
||||
PARAKEET_VERSION?=b11fe5bca78ad8b342dd559a43d76df3984bb447
|
||||
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
|
||||
|
||||
GOCMD?=go
|
||||
|
||||
@@ -7,12 +7,8 @@ import "time"
|
||||
type batchRequest struct {
|
||||
pcm []float32
|
||||
decoder int32
|
||||
// language is the per-request target locale ("" means the model default).
|
||||
// parakeet.cpp's batched C-API takes ONE target_lang for the whole batch,
|
||||
// so the dispatcher only coalesces requests that share a language.
|
||||
language string
|
||||
tag string
|
||||
reply chan batchReply
|
||||
tag string
|
||||
reply chan batchReply
|
||||
}
|
||||
|
||||
// batchReply carries one per-item JSON object string (an element of the C-API's
|
||||
@@ -47,25 +43,13 @@ func newBatcher(maxSize int, maxWait time.Duration, runBatch func([]*batchReques
|
||||
// run is the dispatcher loop: accumulate submitted requests until either maxSize
|
||||
// is reached or maxWait elapses since the first queued request, then dispatch.
|
||||
// Exits when stop is closed (draining any partially-filled batch first).
|
||||
//
|
||||
// A batch carries ONE language (parakeet.cpp's batched C-API takes a single
|
||||
// target_lang), so a request whose language differs from the batch leader is
|
||||
// not coalesced: it is held in carry and becomes the leader of the next batch.
|
||||
// carry is therefore never dropped and its caller never deadlocks: every batch
|
||||
// (including a lone carry on stop) is dispatched, and runBatch replies to all.
|
||||
func (b *batcher) run(stop <-chan struct{}) {
|
||||
var carry *batchRequest
|
||||
for {
|
||||
var first *batchRequest
|
||||
if carry != nil {
|
||||
// A mismatched request from the previous fill leads this batch.
|
||||
first, carry = carry, nil
|
||||
} else {
|
||||
select {
|
||||
case first = <-b.submit:
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
select {
|
||||
case first = <-b.submit:
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
batch := []*batchRequest{first}
|
||||
|
||||
@@ -80,22 +64,12 @@ func (b *batcher) run(stop <-chan struct{}) {
|
||||
for len(batch) < b.maxSize {
|
||||
select {
|
||||
case r := <-b.submit:
|
||||
if r.language != first.language {
|
||||
// Different language: carry it to the next batch so this
|
||||
// batch stays single-language, then dispatch what we have.
|
||||
carry = r
|
||||
break fill
|
||||
}
|
||||
batch = append(batch, r)
|
||||
case <-timer.C:
|
||||
break fill
|
||||
case <-stop:
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
// Don't strand a carried request's caller on shutdown.
|
||||
if carry != nil {
|
||||
b.runBatch([]*batchRequest{carry})
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,60 +105,4 @@ var _ = Describe("batcher", func() {
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
|
||||
It("never coalesces requests with different languages into one batch", func() {
|
||||
// parakeet.cpp's batched C-API takes ONE target_lang per batch, so the
|
||||
// dispatcher must keep every dispatched batch single-language. Submit a
|
||||
// mix of languages and assert (a) no batch ever carries more than one
|
||||
// distinct language and (b) every submitted request still gets a reply
|
||||
// (the mismatched carry-over is never dropped).
|
||||
var mu sync.Mutex
|
||||
var langsPerBatch [][]string
|
||||
run := func(reqs []*batchRequest) {
|
||||
seen := map[string]struct{}{}
|
||||
var distinct []string
|
||||
for _, r := range reqs {
|
||||
if _, ok := seen[r.language]; !ok {
|
||||
seen[r.language] = struct{}{}
|
||||
distinct = append(distinct, r.language)
|
||||
}
|
||||
}
|
||||
mu.Lock()
|
||||
langsPerBatch = append(langsPerBatch, distinct)
|
||||
mu.Unlock()
|
||||
echoReply(reqs)
|
||||
}
|
||||
// Large window + size so the fill loop stays open across submits and the
|
||||
// language constraint (not the timer) is what splits the batches.
|
||||
b := newBatcher(16, 200*time.Millisecond, run)
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
|
||||
langs := []string{"en", "en", "de", "de", "en", "fr", "fr"}
|
||||
const N = 7
|
||||
var wg sync.WaitGroup
|
||||
got := make([]string, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: string(rune('a' + i)), language: langs[i], reply: rep}
|
||||
got[i] = (<-rep).json
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
// Invariant: every dispatched batch is single-language.
|
||||
for _, distinct := range langsPerBatch {
|
||||
Expect(len(distinct)).To(Equal(1), "a batch coalesced more than one language: %v", distinct)
|
||||
}
|
||||
// Liveness: every request got a reply (carry-over never stranded).
|
||||
for i := 0; i < N; i++ {
|
||||
Expect(got[i]).To(Equal(string(rune('a' + i))))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@@ -48,13 +48,6 @@ var (
|
||||
// side reads them as const float*/const int*.
|
||||
CppTranscribePcmBatchJSON func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32) uintptr
|
||||
|
||||
// CppTranscribePcmBatchJSONLang is the multilingual variant of the batched
|
||||
// JSON entry point: identical, plus a trailing target_lang. "" (the model
|
||||
// default, "auto") is passed for non-prompt models, which ignore it; an
|
||||
// unknown locale on a prompt model returns 0 and sets last_error. Present
|
||||
// only in newer libparakeet.so; nil falls back to CppTranscribePcmBatchJSON.
|
||||
CppTranscribePcmBatchJSONLang func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32, targetLang string) uintptr
|
||||
|
||||
// Cache-aware streaming (RNN-T) entry points. stream_begin returns 0 for
|
||||
// non-streaming models. feed/finalize return a malloc'd char* (uintptr,
|
||||
// freed via CppFreeString); feed writes 1 to *eouOut on an <EOU>/<EOB>.
|
||||
@@ -62,18 +55,6 @@ var (
|
||||
CppStreamFeed func(s uintptr, pcm []float32, nSamples int32, eouOut unsafe.Pointer) uintptr
|
||||
CppStreamFinalize func(s uintptr) uintptr
|
||||
CppStreamFree func(s uintptr)
|
||||
|
||||
// CppStreamBeginLang is the multilingual variant of stream_begin: identical,
|
||||
// plus a trailing target_lang ("" means the model default). Present only in
|
||||
// newer libparakeet.so; nil falls back to CppStreamBegin.
|
||||
CppStreamBeginLang func(ctx uintptr, targetLang string) uintptr
|
||||
|
||||
// Streaming JSON variants (ABI v4): feed/finalize returning a malloc'd char*
|
||||
// JSON document {text,eou,frame_sec,words} (uintptr, freed via CppFreeString)
|
||||
// so streaming segments can carry per-word timestamps. Present only in newer
|
||||
// libparakeet.so; nil falls back to the text-only CppStreamFeed/Finalize path.
|
||||
CppStreamFeedJSON func(s uintptr, pcm []float32, nSamples int32) uintptr
|
||||
CppStreamFinalizeJSON func(s uintptr) uintptr
|
||||
)
|
||||
|
||||
// streamChunkSamples is how much 16 kHz mono PCM we hand to stream_feed per
|
||||
@@ -91,26 +72,9 @@ const streamChunkSamples = 16000
|
||||
//
|
||||
// "start"/"end"/"t" are seconds; "conf" is confidence in (0,1].
|
||||
type transcriptJSON struct {
|
||||
Text string `json:"text"`
|
||||
FrameSec float64 `json:"frame_sec"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
Tokens []transcriptToken `json:"tokens"`
|
||||
}
|
||||
|
||||
// streamFeedJSON mirrors the document returned by
|
||||
// parakeet_capi_stream_feed_json / parakeet_capi_stream_finalize_json (ABI v4):
|
||||
//
|
||||
// {"text":"...","eou":0,"frame_sec":0.080000,
|
||||
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...]}
|
||||
//
|
||||
// "text" is the newly-finalized text since the last call; "eou" is 1 when an
|
||||
// <EOU>/<EOB> fired this feed; "words" are the words finalized this call with
|
||||
// absolute (stream-relative) start/end seconds.
|
||||
type streamFeedJSON struct {
|
||||
Text string `json:"text"`
|
||||
Eou int `json:"eou"`
|
||||
FrameSec float64 `json:"frame_sec"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
Text string `json:"text"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
Tokens []transcriptToken `json:"tokens"`
|
||||
}
|
||||
|
||||
type transcriptWord struct {
|
||||
@@ -139,10 +103,6 @@ type ParakeetCpp struct {
|
||||
engineMu sync.Mutex // sole guard of the one C engine (dispatcher + streaming)
|
||||
bat *batcher
|
||||
batStop chan struct{}
|
||||
// segmentGapFrames is NeMo's segment_gap_threshold in ENCODER FRAMES (model
|
||||
// YAML option, default 0=off). When >0 it adds NeMo's silence-gap split on
|
||||
// top of the punctuation split; converted to seconds via the JSON frame_sec.
|
||||
segmentGapFrames int
|
||||
}
|
||||
|
||||
// Load is the LocalAI gRPC entry point for LoadModel: it calls
|
||||
@@ -172,11 +132,6 @@ func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
|
||||
if maxWaitMs < 0 {
|
||||
maxWaitMs = 0
|
||||
}
|
||||
|
||||
// NeMo's segment_gap_threshold (encoder frames, default 0=off). Off by
|
||||
// default matches NeMo's default (punctuation-only segments); when set it
|
||||
// additionally splits segments on inter-word silence (see transcriptResultFromDoc).
|
||||
p.segmentGapFrames = optInt(opts, "segment_gap_threshold", 0)
|
||||
if CppTranscribePcmBatchJSON != nil {
|
||||
p.batStop = make(chan struct{})
|
||||
p.bat = newBatcher(maxSize, time.Duration(maxWaitMs)*time.Millisecond, p.runBatch)
|
||||
@@ -232,19 +187,8 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
if len(reqs) > 0 {
|
||||
dec = reqs[0].decoder
|
||||
}
|
||||
// All requests in a batch share one language (the batcher coalesces only
|
||||
// same-language requests), so any element's language describes the batch.
|
||||
lang := ""
|
||||
if len(reqs) > 0 {
|
||||
lang = reqs[0].language
|
||||
}
|
||||
p.engineMu.Lock()
|
||||
var cstr uintptr
|
||||
if CppTranscribePcmBatchJSONLang != nil {
|
||||
cstr = CppTranscribePcmBatchJSONLang(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec, lang)
|
||||
} else {
|
||||
cstr = CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
|
||||
}
|
||||
cstr := CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
|
||||
p.engineMu.Unlock()
|
||||
if cstr == 0 {
|
||||
err := fmt.Errorf("parakeet-cpp: batch transcribe failed: %s", CppLastError(p.ctxPtr))
|
||||
@@ -282,9 +226,8 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
// OpenAI API, whose default is segment-level); token ids always populate
|
||||
// Segment.Tokens.
|
||||
//
|
||||
// translate/diarize/prompt/temperature/threads are not applicable to parakeet
|
||||
// and are ignored; language is honored on the batched + streaming paths (see
|
||||
// opts.GetLanguage() below); streaming is handled by AudioTranscriptionStream
|
||||
// translate/diarize/prompt/temperature/language/threads are not applicable to
|
||||
// parakeet and are ignored; streaming is handled by AudioTranscriptionStream
|
||||
// (L2).
|
||||
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if p.ctxPtr == 0 {
|
||||
@@ -316,7 +259,7 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
}
|
||||
|
||||
// Batched path: decode to PCM, submit to the batcher, wait for this request's
|
||||
@@ -328,7 +271,7 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
|
||||
}
|
||||
rep := make(chan batchReply, 1)
|
||||
select {
|
||||
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, language: opts.GetLanguage(), reply: rep}:
|
||||
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, reply: rep}:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
@@ -345,169 +288,34 @@ func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.Transcrip
|
||||
if err := json.Unmarshal([]byte(res.json), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts, p.segmentGapFrames), nil
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
}
|
||||
|
||||
// segmentSeparators is NeMo's default segment_seperators (sentence-ending
|
||||
// punctuation). Splitting on these matches NeMo's default segment timestamps.
|
||||
var segmentSeparators = []rune{'.', '?', '!'}
|
||||
|
||||
// transcriptResultFromDoc maps a decoded transcriptJSON to a TranscriptResult,
|
||||
// grouping words into NeMo-faithful segments (see splitWordsIntoSegments). The
|
||||
// optional gapFrames (NeMo's segment_gap_threshold, in encoder FRAMES; 0=off)
|
||||
// additionally splits on inter-word silence; it is converted to a seconds gap
|
||||
// with the document's frame_sec. Per-segment word timings are attached only when
|
||||
// the caller requested word granularity; token ids populate each segment's
|
||||
// Tokens by time-window membership. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gapFrames int) pb.TranscriptResult {
|
||||
// synthesising a single whole-clip segment and attaching word timings only when
|
||||
// the caller requested word granularity. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest) pb.TranscriptResult {
|
||||
text := strings.TrimSpace(doc.Text)
|
||||
|
||||
// Frame-unit gap threshold -> seconds (NeMo segment_gap_threshold). 0 = off.
|
||||
gapSeconds := 0.0
|
||||
if gapFrames > 0 {
|
||||
if doc.FrameSec > 0 {
|
||||
gapSeconds = float64(gapFrames) * doc.FrameSec
|
||||
} else {
|
||||
xlog.Warn("parakeet-cpp: segment_gap_threshold set but libparakeet.so " +
|
||||
"did not report frame_sec; falling back to punctuation-only segments")
|
||||
}
|
||||
words := make([]*pb.TranscriptWord, 0, len(doc.Words))
|
||||
for _, w := range doc.Words {
|
||||
words = append(words, &pb.TranscriptWord{Start: secondsToNanos(w.Start), End: secondsToNanos(w.End), Text: w.W})
|
||||
}
|
||||
|
||||
groups := splitWordsIntoSegments(doc.Words, segmentSeparators, gapSeconds)
|
||||
if len(groups) == 0 {
|
||||
// No words (edge case): single whole-clip text segment.
|
||||
return pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: []*pb.TranscriptSegment{{Id: 0, Text: text}},
|
||||
}
|
||||
tokens := make([]int32, 0, len(doc.Tokens))
|
||||
for _, t := range doc.Tokens {
|
||||
tokens = append(tokens, t.ID)
|
||||
}
|
||||
|
||||
wantWords := wordsRequested(opts.TimestampGranularities)
|
||||
segments := make([]*pb.TranscriptSegment, 0, len(groups))
|
||||
for id, group := range groups {
|
||||
parts := make([]string, len(group))
|
||||
for i, gw := range group {
|
||||
parts[i] = gw.W
|
||||
}
|
||||
seg := &pb.TranscriptSegment{
|
||||
Id: int32(id),
|
||||
Start: secondsToNanos(group[0].Start),
|
||||
End: secondsToNanos(group[len(group)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
Tokens: tokensInWindow(doc.Tokens, group[0].Start, group[len(group)-1].End),
|
||||
}
|
||||
if wantWords {
|
||||
ws := make([]*pb.TranscriptWord, len(group))
|
||||
for i, gw := range group {
|
||||
ws[i] = &pb.TranscriptWord{Start: secondsToNanos(gw.Start), End: secondsToNanos(gw.End), Text: gw.W}
|
||||
}
|
||||
seg.Words = ws
|
||||
}
|
||||
segments = append(segments, seg)
|
||||
var segStart, segEnd int64
|
||||
if len(words) > 0 {
|
||||
segStart = words[0].Start
|
||||
segEnd = words[len(words)-1].End
|
||||
}
|
||||
return pb.TranscriptResult{Text: text, Segments: segments}
|
||||
seg := &pb.TranscriptSegment{Id: 0, Start: segStart, End: segEnd, Text: text, Tokens: tokens}
|
||||
if wordsRequested(opts.TimestampGranularities) {
|
||||
seg.Words = words
|
||||
}
|
||||
return pb.TranscriptResult{Text: text, Segments: []*pb.TranscriptSegment{seg}}
|
||||
}
|
||||
|
||||
// splitWordsIntoSegments groups words into segments exactly as NeMo's
|
||||
// get_segment_offsets does (nemo/collections/asr/parts/utils/timestamp_utils.py).
|
||||
// Walking the words, it closes a segment when (1) the gap rule is enabled
|
||||
// (gapSeconds > 0) and the segment already has words and the gap from the
|
||||
// previous word's end to this word's start is >= gapSeconds - the current word
|
||||
// then STARTS a new segment - or, checked only when the gap rule did not apply
|
||||
// (NeMo's elif), (2) the word ends with (or is) a separator, which closes the
|
||||
// segment INCLUDING that word. Trailing words flush into a final segment.
|
||||
// gapSeconds <= 0 disables the gap rule, matching NeMo's default
|
||||
// segment_gap_threshold=None (punctuation-only segments).
|
||||
func splitWordsIntoSegments(words []transcriptWord, separators []rune, gapSeconds float64) [][]transcriptWord {
|
||||
var segments [][]transcriptWord
|
||||
var cur []transcriptWord
|
||||
for i, word := range words {
|
||||
gapActive := gapSeconds > 0 && len(cur) > 0
|
||||
if gapActive && (word.Start-words[i-1].End) >= gapSeconds {
|
||||
segments = append(segments, cur)
|
||||
cur = []transcriptWord{word}
|
||||
continue
|
||||
}
|
||||
if !gapActive && endsWithSeparator(word.W, separators) {
|
||||
cur = append(cur, word)
|
||||
segments = append(segments, cur)
|
||||
cur = nil
|
||||
continue
|
||||
}
|
||||
cur = append(cur, word)
|
||||
}
|
||||
if len(cur) > 0 {
|
||||
segments = append(segments, cur)
|
||||
}
|
||||
return segments
|
||||
}
|
||||
|
||||
// endsWithSeparator reports whether w's last rune is in separators (matching
|
||||
// NeMo's `word[-1] in delims or word in delims`).
|
||||
func endsWithSeparator(w string, separators []rune) bool {
|
||||
r := []rune(strings.TrimSpace(w))
|
||||
if len(r) == 0 {
|
||||
return false
|
||||
}
|
||||
last := r[len(r)-1]
|
||||
for _, s := range separators {
|
||||
if last == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// tokensInWindow returns the ids of tokens whose timestamp t falls in
|
||||
// [start, end] (inclusive), assigning each token to the segment that spans its
|
||||
// time. The last segment's end is the last word end, so the final token is
|
||||
// included.
|
||||
func tokensInWindow(tokens []transcriptToken, start, end float64) []int32 {
|
||||
var ids []int32
|
||||
for _, t := range tokens {
|
||||
if t.T >= start && t.T <= end {
|
||||
ids = append(ids, t.ID)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// streamSegmenter accumulates streaming words into per-utterance segments. EOU
|
||||
// is the model's own utterance boundary; each closed segment takes its start/end
|
||||
// from its first/last accumulated word.
|
||||
type streamSegmenter struct {
|
||||
segs []*pb.TranscriptSegment
|
||||
cur []transcriptWord
|
||||
nextID int32
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) add(doc streamFeedJSON) {
|
||||
s.cur = append(s.cur, doc.Words...)
|
||||
if doc.Eou != 0 {
|
||||
s.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) flush() {
|
||||
if len(s.cur) == 0 {
|
||||
return
|
||||
}
|
||||
parts := make([]string, len(s.cur))
|
||||
for i, w := range s.cur {
|
||||
parts[i] = w.W
|
||||
}
|
||||
s.segs = append(s.segs, &pb.TranscriptSegment{
|
||||
Id: s.nextID,
|
||||
Start: secondsToNanos(s.cur[0].Start),
|
||||
End: secondsToNanos(s.cur[len(s.cur)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
})
|
||||
s.nextID++
|
||||
s.cur = nil
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) segments() []*pb.TranscriptSegment { return s.segs }
|
||||
|
||||
// wordsRequested reports whether the caller asked for word-level timestamps.
|
||||
// The OpenAI transcription API gates word timings behind
|
||||
// timestamp_granularities[] containing "word" and defaults to segment-level
|
||||
@@ -553,12 +361,7 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
var stream uintptr
|
||||
if CppStreamBeginLang != nil {
|
||||
stream = CppStreamBeginLang(p.ctxPtr, opts.GetLanguage())
|
||||
} else {
|
||||
stream = CppStreamBegin(p.ctxPtr)
|
||||
}
|
||||
stream := CppStreamBegin(p.ctxPtr)
|
||||
if stream == 0 {
|
||||
// Not a cache-aware streaming model: run a normal offline
|
||||
// transcription and emit it as one delta + a closing final result.
|
||||
@@ -587,14 +390,6 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
return err
|
||||
}
|
||||
|
||||
// ABI v4: when the streaming JSON entry points are present, drive them so the
|
||||
// per-utterance segments carry per-word start/end timestamps. Falls through to
|
||||
// the text-only loop below against an older libparakeet.so. Runs under the
|
||||
// engineMu already held above.
|
||||
if CppStreamFeedJSON != nil {
|
||||
return p.streamJSON(ctx, stream, data, duration, results)
|
||||
}
|
||||
|
||||
var (
|
||||
full strings.Builder
|
||||
segText strings.Builder
|
||||
@@ -671,71 +466,6 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamJSON drives the ABI v4 streaming JSON entry points: each feed/finalize
|
||||
// returns a {text,eou,frame_sec,words} document. The newly-finalized text is
|
||||
// emitted as a delta (unchanged streaming contract) while words are accumulated
|
||||
// into per-utterance segments (closed on EOU) so the closing FinalResult carries
|
||||
// timestamped segments. Runs under engineMu (already held by the caller).
|
||||
func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []float32,
|
||||
duration float32, results chan *pb.TranscriptStreamResponse) error {
|
||||
var (
|
||||
full strings.Builder
|
||||
seg streamSegmenter
|
||||
)
|
||||
// consume frees the malloc'd char* (a 0 return is an error), parses the JSON,
|
||||
// emits the delta, and routes words through the segmenter.
|
||||
consume := func(ret uintptr) error {
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
raw := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
var doc streamFeedJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return fmt.Errorf("parakeet-cpp: decode stream json: %w", err)
|
||||
}
|
||||
if doc.Text != "" {
|
||||
full.WriteString(doc.Text)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: doc.Text}
|
||||
}
|
||||
seg.add(doc)
|
||||
return nil
|
||||
}
|
||||
|
||||
for off := 0; off < len(data); off += streamChunkSamples {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(data))
|
||||
chunk := data[off:end]
|
||||
if err := consume(CppStreamFeedJSON(stream, chunk, int32(len(chunk)))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := consume(CppStreamFinalizeJSON(stream)); err != nil {
|
||||
return err
|
||||
}
|
||||
seg.flush() // close any trailing utterance that never saw an EOU
|
||||
|
||||
text := strings.TrimSpace(full.String())
|
||||
segments := seg.segments()
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{
|
||||
FinalResult: &pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: segments,
|
||||
Duration: duration,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeWavMono16k converts any input audio to 16 kHz mono PCM and returns the
|
||||
// float samples plus the clip duration in seconds. Mirrors the whisper
|
||||
// backend: utils.AudioToWav (ffmpeg) normalises rate/channels, go-audio
|
||||
|
||||
@@ -53,10 +53,6 @@ func ensureLibLoaded() {
|
||||
purego.RegisterLibFunc(&CppStreamFeed, lib, "parakeet_capi_stream_feed")
|
||||
purego.RegisterLibFunc(&CppStreamFinalize, lib, "parakeet_capi_stream_finalize")
|
||||
purego.RegisterLibFunc(&CppStreamFree, lib, "parakeet_capi_stream_free")
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
|
||||
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
|
||||
}
|
||||
purego.RegisterLibFunc(&CppFreeString, lib, "parakeet_capi_free_string")
|
||||
purego.RegisterLibFunc(&CppLastError, lib, "parakeet_capi_last_error")
|
||||
})
|
||||
@@ -111,22 +107,13 @@ var _ = Describe("ParakeetCpp", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(strings.TrimSpace(res.Text)).ToNot(BeEmpty(),
|
||||
"expected non-empty transcript for %s", audioPath)
|
||||
// NeMo-faithful segmentation: one or more punctuation-delimited
|
||||
// segments, each with text and a monotonically-advancing time span.
|
||||
Expect(res.Segments).ToNot(BeEmpty(), "expected at least one segment")
|
||||
var prevEnd int64
|
||||
for i, seg := range res.Segments {
|
||||
Expect(strings.TrimSpace(seg.Text)).ToNot(BeEmpty(),
|
||||
"segment %d must have text", i)
|
||||
Expect(seg.End).To(BeNumerically(">=", seg.Start),
|
||||
"segment %d end must not precede its start", i)
|
||||
Expect(seg.Start).To(BeNumerically(">=", prevEnd),
|
||||
"segments must be in time order")
|
||||
prevEnd = seg.End
|
||||
// Default (no granularities) is segment-level: no per-word timings.
|
||||
Expect(seg.Words).To(BeEmpty(),
|
||||
"word timings are opt-in via timestamp_granularities")
|
||||
}
|
||||
Expect(res.Segments).To(HaveLen(1),
|
||||
"synthesises a single whole-clip segment")
|
||||
Expect(res.Segments[0].Text).To(Equal(res.Text),
|
||||
"single segment text must equal the top-level text")
|
||||
// Default (no granularities) is segment-level: no per-word timings.
|
||||
Expect(res.Segments[0].Words).To(BeEmpty(),
|
||||
"word timings are opt-in via timestamp_granularities")
|
||||
})
|
||||
|
||||
It("emits word-level timestamps when granularity=word", func() {
|
||||
@@ -142,28 +129,15 @@ var _ = Describe("ParakeetCpp", func() {
|
||||
TimestampGranularities: []string{"word"},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.Segments).ToNot(BeEmpty())
|
||||
// With word granularity every segment carries its own words, and each
|
||||
// segment's span tracks its first/last word; word starts advance
|
||||
// monotonically across the whole transcript.
|
||||
totalWords := 0
|
||||
var prevStart int64 = -1
|
||||
for i, seg := range res.Segments {
|
||||
Expect(seg.Words).ToNot(BeEmpty(),
|
||||
"segment %d must carry per-word timestamps with granularity=word", i)
|
||||
Expect(seg.Start).To(Equal(seg.Words[0].Start),
|
||||
"segment %d start tracks its first word", i)
|
||||
Expect(seg.End).To(Equal(seg.Words[len(seg.Words)-1].End),
|
||||
"segment %d end tracks its last word", i)
|
||||
for _, w := range seg.Words {
|
||||
Expect(w.End).To(BeNumerically(">=", w.Start))
|
||||
Expect(w.Start).To(BeNumerically(">=", prevStart))
|
||||
prevStart = w.Start
|
||||
totalWords++
|
||||
}
|
||||
}
|
||||
Expect(totalWords).To(BeNumerically(">", 0))
|
||||
Expect(res.Segments[0].Words[0].Start).To(BeNumerically(">=", int64(0)))
|
||||
Expect(res.Segments).To(HaveLen(1))
|
||||
seg := res.Segments[0]
|
||||
Expect(seg.Words).ToNot(BeEmpty(),
|
||||
"expected per-word timestamps with granularity=word")
|
||||
// Monotonic, non-negative timings spanning the segment.
|
||||
Expect(seg.Words[0].Start).To(BeNumerically(">=", int64(0)))
|
||||
Expect(seg.End).To(BeNumerically(">=", seg.Start))
|
||||
Expect(seg.Words[len(seg.Words)-1].End).To(Equal(seg.End),
|
||||
"segment end tracks the last word")
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -65,25 +65,6 @@ func main() {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
|
||||
// Per-request language variants (multilingual nemotron). Same probe pattern:
|
||||
// present only in libparakeet.so built with multilingual support, so the
|
||||
// backend still loads against an older library and falls back to the
|
||||
// non-lang batched + streaming entry points (model default / "auto").
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json_lang"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSONLang, lib, "parakeet_capi_transcribe_pcm_batch_json_lang")
|
||||
}
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_begin_lang"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamBeginLang, lib, "parakeet_capi_stream_begin_lang")
|
||||
}
|
||||
|
||||
// Streaming JSON entry points (ABI v4): surface per-word timestamps on the
|
||||
// streaming path. Same probe pattern; absent in older libparakeet.so, where
|
||||
// the backend falls back to the text-only streaming feed.
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_stream_feed_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppStreamFeedJSON, lib, "parakeet_capi_stream_feed_json")
|
||||
purego.RegisterLibFunc(&CppStreamFinalizeJSON, lib, "parakeet_capi_stream_finalize_json")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[parakeet-cpp] ABI=%d\n", CppAbiVersion())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func tw(text string, start, end float64) transcriptWord {
|
||||
return transcriptWord{W: text, Start: start, End: end}
|
||||
}
|
||||
|
||||
var _ = Describe("splitWordsIntoSegments (NeMo get_segment_offsets parity)", func() {
|
||||
seps := []rune{'.', '?', '!'}
|
||||
|
||||
It("splits on sentence-ending punctuation, including the delimiter word", func() {
|
||||
words := []transcriptWord{tw("hello", 0, 0.4), tw("world.", 0.4, 0.8), tw("bye", 1.0, 1.3)}
|
||||
segs := splitWordsIntoSegments(words, seps, 0)
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2))
|
||||
Expect(segs[0][1].W).To(Equal("world."))
|
||||
Expect(segs[1]).To(HaveLen(1))
|
||||
Expect(segs[1][0].W).To(Equal("bye"))
|
||||
})
|
||||
|
||||
It("keeps a single segment with no terminal punctuation and gap off", func() {
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 0)
|
||||
Expect(segs).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("splits on the gap rule when enabled, the gapped word starting the next segment", func() {
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b", 0.2, 0.4), tw("c", 5.0, 5.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 1.0) // c is 4.6s after b
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2)) // a b
|
||||
Expect(segs[1]).To(HaveLen(1)) // c
|
||||
Expect(segs[1][0].W).To(Equal("c"))
|
||||
})
|
||||
|
||||
It("checks the gap rule before punctuation (NeMo elif order)", func() {
|
||||
// "b." would terminate, but c is far after it -> gap closes [a b.] at b.
|
||||
words := []transcriptWord{tw("a", 0, 0.2), tw("b.", 0.2, 0.4), tw("c", 9.0, 9.2)}
|
||||
segs := splitWordsIntoSegments(words, seps, 1.0)
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0]).To(HaveLen(2))
|
||||
Expect(segs[1][0].W).To(Equal("c"))
|
||||
})
|
||||
|
||||
It("still splits on punctuation when the gap rule is enabled but does not fire", func() {
|
||||
words := []transcriptWord{tw("hi.", 0, 0.4), tw("bye", 0.4, 0.8)}
|
||||
segs := splitWordsIntoSegments(words, seps, 5.0) // gap never reached
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0][0].W).To(Equal("hi."))
|
||||
})
|
||||
|
||||
It("returns nothing for empty input", func() {
|
||||
Expect(splitWordsIntoSegments(nil, seps, 0)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("transcriptResultFromDoc (multi-segment)", func() {
|
||||
doc := transcriptJSON{
|
||||
Text: "hello world. bye now",
|
||||
FrameSec: 0.08,
|
||||
Words: []transcriptWord{
|
||||
{W: "hello", Start: 0.0, End: 0.4},
|
||||
{W: "world.", Start: 0.4, End: 0.8},
|
||||
{W: "bye", Start: 1.0, End: 1.3},
|
||||
{W: "now", Start: 1.3, End: 1.6},
|
||||
},
|
||||
Tokens: []transcriptToken{{ID: 1, T: 0.1}, {ID: 2, T: 0.5}, {ID: 3, T: 1.1}, {ID: 4, T: 1.4}},
|
||||
}
|
||||
|
||||
It("emits one segment per punctuation-delimited group with start/end", func() {
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments).To(HaveLen(2))
|
||||
Expect(res.Segments[0].Text).To(Equal("hello world."))
|
||||
Expect(res.Segments[0].Start).To(Equal(int64(0)))
|
||||
Expect(res.Segments[0].End).To(Equal(secondsToNanos(0.8)))
|
||||
Expect(res.Segments[1].Text).To(Equal("bye now"))
|
||||
Expect(res.Segments[1].Start).To(Equal(secondsToNanos(1.0)))
|
||||
Expect(res.Segments[1].Id).To(Equal(int32(1)))
|
||||
})
|
||||
|
||||
It("assigns tokens to the segment whose time window contains them", func() {
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments[0].Tokens).To(Equal([]int32{1, 2}))
|
||||
Expect(res.Segments[1].Tokens).To(Equal([]int32{3, 4}))
|
||||
})
|
||||
|
||||
It("attaches per-segment words only when word granularity requested", func() {
|
||||
plain := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(plain.Segments[0].Words).To(BeEmpty())
|
||||
withWords := transcriptResultFromDoc(doc, &pb.TranscriptRequest{TimestampGranularities: []string{"word"}}, 0)
|
||||
Expect(withWords.Segments[0].Words).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("falls back to a single text segment when there are no words", func() {
|
||||
res := transcriptResultFromDoc(transcriptJSON{Text: "hi"}, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Segments).To(HaveLen(1))
|
||||
Expect(res.Segments[0].Text).To(Equal("hi"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("streaming segment assembly", func() {
|
||||
It("closes a segment with start/end from its words on EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hello world", Eou: 1, Words: []transcriptWord{
|
||||
{W: "hello", Start: 0.0, End: 0.4}, {W: "world", Start: 0.4, End: 0.9},
|
||||
}})
|
||||
segs := acc.segments()
|
||||
Expect(segs).To(HaveLen(1))
|
||||
Expect(segs[0].Text).To(Equal("hello world"))
|
||||
Expect(segs[0].Start).To(Equal(int64(0)))
|
||||
Expect(segs[0].End).To(Equal(secondsToNanos(0.9)))
|
||||
})
|
||||
|
||||
It("buffers words across feeds until EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hi", Eou: 0, Words: []transcriptWord{{W: "hi", Start: 0, End: 0.3}}})
|
||||
Expect(acc.segments()).To(BeEmpty())
|
||||
acc.add(streamFeedJSON{Text: "there", Eou: 1, Words: []transcriptWord{{W: "there", Start: 0.3, End: 0.7}}})
|
||||
Expect(acc.segments()).To(HaveLen(1))
|
||||
Expect(acc.segments()[0].Text).To(Equal("hi there"))
|
||||
})
|
||||
})
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=19bdfe22d255d5b4dff39d449318b9bc5ea2317f
|
||||
STABLEDIFFUSION_GGML_VERSION?=1f9ee88e09c258053fa59d5e05e23dfb10fa0b13
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -386,7 +386,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
const char *llm_vision_path = "";
|
||||
const char *diffusion_model_path = stableDiffusionModel;
|
||||
const char *high_noise_diffusion_model_path = "";
|
||||
const char *uncond_diffusion_model_path = "";
|
||||
const char *taesd_path = "";
|
||||
const char *control_net_path = "";
|
||||
const char *embedding_dir = "";
|
||||
@@ -473,7 +472,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
if (!strcmp(optname, "llm_vision_path")) llm_vision_path = strdup(optval);
|
||||
if (!strcmp(optname, "diffusion_model_path")) diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "high_noise_diffusion_model_path")) high_noise_diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "uncond_diffusion_model_path")) uncond_diffusion_model_path = strdup(optval);
|
||||
if (!strcmp(optname, "taesd_path")) taesd_path = strdup(optval);
|
||||
if (!strcmp(optname, "control_net_path")) control_net_path = strdup(optval);
|
||||
if (!strcmp(optname, "embedding_dir")) {
|
||||
@@ -573,7 +571,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
ctx_params.llm_vision_path = llm_vision_path;
|
||||
ctx_params.diffusion_model_path = diffusion_model_path;
|
||||
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
|
||||
ctx_params.uncond_diffusion_model_path = uncond_diffusion_model_path;
|
||||
ctx_params.vae_path = vae_path;
|
||||
ctx_params.audio_vae_path = audio_vae_path;
|
||||
ctx_params.embeddings_connectors_path = embeddings_connectors_path;
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=df7638d8229a243af8a4b5a8ae557e0d74e0a0ae
|
||||
WHISPER_CPP_VERSION?=99613cb720b65036237d44b52f753b51f75c2797
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -95,29 +95,6 @@
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ds4"
|
||||
metal: "metal-ds4"
|
||||
metal-darwin-arm64: "metal-ds4"
|
||||
- &dllm
|
||||
name: "dllm"
|
||||
alias: "dllm"
|
||||
license: mit
|
||||
description: |
|
||||
mudler/dllm.cpp - DiffusionGemma block-diffusion LLM inference engine
|
||||
(C++/ggml, GGUF weights). Decodes whole token canvases per diffusion
|
||||
round instead of autoregressive sampling. Runs on CPU and NVIDIA CUDA 13
|
||||
(including Jetson/GB10 L4T targets).
|
||||
urls:
|
||||
- https://github.com/mudler/dllm.cpp
|
||||
tags:
|
||||
- text-to-text
|
||||
- LLM
|
||||
- gguf
|
||||
- diffusion
|
||||
- CPU
|
||||
- CUDA
|
||||
capabilities:
|
||||
default: "cpu-dllm"
|
||||
nvidia: "cuda13-dllm"
|
||||
nvidia-cuda-13: "cuda13-dllm"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-dllm"
|
||||
- &whispercpp
|
||||
name: "whisper"
|
||||
alias: "whisper"
|
||||
@@ -1295,13 +1272,6 @@
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ds4-development"
|
||||
metal: "metal-ds4-development"
|
||||
metal-darwin-arm64: "metal-ds4-development"
|
||||
- !!merge <<: *dllm
|
||||
name: "dllm-development"
|
||||
capabilities:
|
||||
default: "cpu-dllm-development"
|
||||
nvidia: "cuda13-dllm-development"
|
||||
nvidia-cuda-13: "cuda13-dllm-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-dllm-development"
|
||||
- !!merge <<: *stablediffusionggml
|
||||
name: "stablediffusion-ggml-development"
|
||||
capabilities:
|
||||
@@ -1889,37 +1859,6 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-ds4"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-ds4
|
||||
## dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cpu-dllm"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cpu-dllm-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cuda13-dllm"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cuda13-dllm-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cuda13-nvidia-l4t-arm64-dllm"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-dllm
|
||||
- !!merge <<: *dllm
|
||||
name: "cuda13-nvidia-l4t-arm64-dllm-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-dllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-dllm
|
||||
## whisper
|
||||
- !!merge <<: *whispercpp
|
||||
name: "whisper-development"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
transformers==4.48.3
|
||||
transformers==5.0.0rc3
|
||||
accelerate
|
||||
torch==2.4.1
|
||||
torch==2.7.1+cpu
|
||||
torchaudio==2.4.1
|
||||
coqui-tts
|
||||
@@ -1,5 +1,5 @@
|
||||
torch==2.4.1
|
||||
torch==2.7.1+cpu
|
||||
torchaudio==2.4.1
|
||||
transformers==4.48.3
|
||||
transformers==5.0.0rc3
|
||||
accelerate
|
||||
coqui-tts
|
||||
@@ -1,6 +1,6 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm7.0
|
||||
torch==2.10.0+rocm7.0
|
||||
torch==2.7.1+cpu
|
||||
torchaudio==2.10.0+rocm7.0
|
||||
transformers==4.48.3
|
||||
transformers==5.0.0rc3
|
||||
accelerate
|
||||
coqui-tts
|
||||
@@ -1,8 +1,8 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch==2.8.0+xpu
|
||||
torch==2.7.1+cpu
|
||||
torchaudio==2.8.0+xpu
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
transformers==4.48.3
|
||||
transformers==5.0.0rc3
|
||||
accelerate
|
||||
coqui-tts
|
||||
@@ -1,4 +1,4 @@
|
||||
torch==2.7.1
|
||||
transformers==4.48.3
|
||||
torch==2.7.1+cpu
|
||||
transformers==5.0.0rc3
|
||||
accelerate
|
||||
coqui-tts
|
||||
|
||||
@@ -26,10 +26,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import random_uuid
|
||||
try:
|
||||
from vllm.tokenizers import get_tokenizer # vLLM >= 0.22
|
||||
except ImportError:
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer # vLLM < 0.22
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.assets.video import VideoAsset
|
||||
import base64
|
||||
|
||||
@@ -23,9 +23,9 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -308,31 +308,10 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
application.galleryService.SetNATSClient(distSvc.Nats)
|
||||
if distSvc.DistStores != nil && distSvc.DistStores.Gallery != nil {
|
||||
// Clean up stale in-progress operations from previous crashed instances
|
||||
if _, err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
|
||||
if err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
|
||||
xlog.Warn("Failed to clean stale gallery operations", "error", err)
|
||||
}
|
||||
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
|
||||
|
||||
// Reap stale ops periodically, not just at boot: an op orphaned by
|
||||
// a replica that died mid-install (its foreground handler goroutine
|
||||
// gone) would otherwise linger "processing" in the UI until the next
|
||||
// restart. 30m matches the install/upgrade ceiling so a genuinely
|
||||
// slow op is never reaped out from under itself.
|
||||
gsvc := application.galleryService
|
||||
go func() {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-options.Context.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if _, err := gsvc.ReapStaleOperations(30 * time.Minute); err != nil {
|
||||
xlog.Warn("Failed to reap stale gallery operations", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
// Hydrate from the store first so the wildcard subscriber finds an
|
||||
// already-populated statuses map for any operations still in flight
|
||||
|
||||
@@ -214,9 +214,7 @@ func (uc *UpgradeChecker) runCheck(ctx context.Context) {
|
||||
"from", info.InstalledVersion, "to", info.AvailableVersion)
|
||||
var err error
|
||||
if bm != nil {
|
||||
// Background auto-upgrade: no live admin watching a progress bar,
|
||||
// so opID is empty and the distributed path skips progress streaming.
|
||||
err = bm.UpgradeBackend(ctx, "", name, nil)
|
||||
err = bm.UpgradeBackend(ctx, name, nil)
|
||||
} else {
|
||||
err = gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
|
||||
uc.galleries, name, nil, uc.appConfig.RequireBackendIntegrity)
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Model string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
In io.Reader
|
||||
Out io.Writer
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, opts Options) error {
|
||||
if opts.In == nil {
|
||||
opts.In = strings.NewReader("")
|
||||
}
|
||||
if opts.Out == nil {
|
||||
opts.Out = io.Discard
|
||||
}
|
||||
|
||||
session, err := newChatSession(ctx, newLocalAIChatClient(opts.BaseURL, opts.APIKey), opts.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return runTerminalChat(ctx, session, opts.In, opts.Out)
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestChat(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Chat Suite")
|
||||
}
|
||||
@@ -1,172 +0,0 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Run chat", func() {
|
||||
It("streams a single chat response", func() {
|
||||
var capturedModel string
|
||||
var capturedAuth string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/v1/models" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
writeResponse(w, `{"object":"list","data":[{"id":"test-model","object":"model"}]}`)
|
||||
return
|
||||
}
|
||||
|
||||
Expect(r.URL.Path).To(Equal("/v1/chat/completions"))
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
|
||||
var body struct {
|
||||
Model string `json:"model"`
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
Expect(json.NewDecoder(r.Body).Decode(&body)).To(Succeed())
|
||||
capturedModel = body.Model
|
||||
Expect(body.Messages).To(HaveLen(1))
|
||||
Expect(body.Messages[0].Role).To(Equal("user"))
|
||||
Expect(body.Messages[0].Content).To(Equal("hello"))
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n")
|
||||
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\"}}]}\n\n")
|
||||
writeResponse(w, "data: [DONE]\n\n")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
var out bytes.Buffer
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
Model: "test-model",
|
||||
BaseURL: server.URL + "/v1",
|
||||
APIKey: "secret",
|
||||
In: strings.NewReader("hello\n/exit\n"),
|
||||
Out: &out,
|
||||
})
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(capturedModel).To(Equal("test-model"))
|
||||
Expect(capturedAuth).To(Equal("Bearer secret"))
|
||||
Expect(out.String()).To(ContainSubstring("assistant: hi!"))
|
||||
Expect(out.String()).To(ContainSubstring("bye"))
|
||||
})
|
||||
|
||||
It("auto-selects the only available model", func() {
|
||||
server := chatTestServer([]string{"solo"}, nil)
|
||||
defer server.Close()
|
||||
|
||||
var out bytes.Buffer
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
BaseURL: server.URL + "/v1",
|
||||
In: strings.NewReader("/exit\n"),
|
||||
Out: &out,
|
||||
})
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out.String()).To(ContainSubstring("LocalAI chat (solo)"))
|
||||
})
|
||||
|
||||
It("returns an actionable error when no models are installed", func() {
|
||||
server := chatTestServer(nil, nil)
|
||||
defer server.Close()
|
||||
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
BaseURL: server.URL + "/v1",
|
||||
In: strings.NewReader(""),
|
||||
})
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no chat models are installed"))
|
||||
Expect(err.Error()).To(ContainSubstring("local-ai models install <model>"))
|
||||
})
|
||||
|
||||
It("returns an actionable error when multiple models are available without a selection", func() {
|
||||
server := chatTestServer([]string{"alpha", "beta"}, nil)
|
||||
defer server.Close()
|
||||
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
BaseURL: server.URL + "/v1",
|
||||
In: strings.NewReader(""),
|
||||
})
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("multiple models are available"))
|
||||
Expect(err.Error()).To(ContainSubstring("--model"))
|
||||
Expect(err.Error()).To(ContainSubstring("alpha"))
|
||||
Expect(err.Error()).To(ContainSubstring("beta"))
|
||||
})
|
||||
|
||||
It("lists and switches models inside the chat", func() {
|
||||
requestedModels := []string{}
|
||||
server := chatTestServer([]string{"alpha", "beta"}, func(model string) {
|
||||
requestedModels = append(requestedModels, model)
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
var out bytes.Buffer
|
||||
err := Run(GinkgoT().Context(), Options{
|
||||
Model: "alpha",
|
||||
BaseURL: server.URL + "/v1",
|
||||
In: strings.NewReader("/models\n/model beta\nhello\n/exit\n"),
|
||||
Out: &out,
|
||||
})
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(out.String()).To(ContainSubstring("* alpha"))
|
||||
Expect(out.String()).To(ContainSubstring(" beta"))
|
||||
Expect(out.String()).To(ContainSubstring("switched to beta; conversation cleared"))
|
||||
Expect(requestedModels).To(Equal([]string{"beta"}))
|
||||
})
|
||||
})
|
||||
|
||||
func chatTestServer(models []string, onChat func(model string)) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/v1/models":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
writeResponse(w, `{"object":"list","data":[`)
|
||||
for i, model := range models {
|
||||
if i > 0 {
|
||||
writeResponse(w, ",")
|
||||
}
|
||||
writeResponsef(w, `{"id":%q,"object":"model"}`, model)
|
||||
}
|
||||
writeResponse(w, `]}`)
|
||||
case "/v1/chat/completions":
|
||||
var body struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
Expect(json.NewDecoder(r.Body).Decode(&body)).To(Succeed())
|
||||
if onChat != nil {
|
||||
onChat(body.Model)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ok\"}}]}\n\n")
|
||||
writeResponse(w, "data: [DONE]\n\n")
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func writeResponse(w io.Writer, text string) {
|
||||
_, err := fmt.Fprint(w, text)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
func writeResponsef(w io.Writer, format string, args ...any) {
|
||||
_, err := fmt.Fprintf(w, format, args...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type chatClient interface {
|
||||
ListModels(ctx context.Context) ([]string, error)
|
||||
StreamChat(ctx context.Context, model string, messages []chatMessage, out io.Writer) (string, error)
|
||||
}
|
||||
|
||||
type localAIChatClient struct {
|
||||
client *openai.Client
|
||||
}
|
||||
|
||||
func newLocalAIChatClient(baseURL string, apiKey string) *localAIChatClient {
|
||||
cfg := openai.DefaultConfig(apiKey)
|
||||
cfg.BaseURL = baseURL
|
||||
return &localAIChatClient{client: openai.NewClientWithConfig(cfg)}
|
||||
}
|
||||
|
||||
func (c *localAIChatClient) ListModels(ctx context.Context) ([]string, error) {
|
||||
resp, err := c.client.ListModels(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(resp.Models))
|
||||
for _, model := range resp.Models {
|
||||
if model.ID != "" {
|
||||
models = append(models, model.ID)
|
||||
}
|
||||
}
|
||||
sort.Strings(models)
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (c *localAIChatClient) StreamChat(ctx context.Context, model string, messages []chatMessage, out io.Writer) (string, error) {
|
||||
stream, err := c.client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: openAIChatMessages(messages),
|
||||
})
|
||||
if err != nil {
|
||||
return "", friendlyChatError(err, model)
|
||||
}
|
||||
defer func() {
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
var answer strings.Builder
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return answer.String(), friendlyChatError(err, model)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
token := resp.Choices[0].Delta.Content
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
answer.WriteString(token)
|
||||
if _, err := fmt.Fprint(out, token); err != nil {
|
||||
return answer.String(), err
|
||||
}
|
||||
}
|
||||
|
||||
return answer.String(), nil
|
||||
}
|
||||
|
||||
func openAIChatMessages(messages []chatMessage) []openai.ChatCompletionMessage {
|
||||
converted := make([]openai.ChatCompletionMessage, len(messages))
|
||||
for i, message := range messages {
|
||||
converted[i] = openai.ChatCompletionMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
}
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func friendlyChatError(err error, model string) error {
|
||||
var apiErr *openai.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
switch apiErr.HTTPStatusCode {
|
||||
case 404:
|
||||
return fmt.Errorf("model %q is not available. Run `local-ai models list`, install a model with `local-ai models install <model>`, or switch with `/model <name>`", model)
|
||||
case 403:
|
||||
return fmt.Errorf("model %q is disabled. Enable it from LocalAI settings or choose another model with `/model <name>`", model)
|
||||
}
|
||||
if apiErr.Message != "" {
|
||||
return errors.New(apiErr.Message)
|
||||
}
|
||||
}
|
||||
|
||||
msg := err.Error()
|
||||
if strings.Contains(msg, "model") && strings.Contains(msg, "not found") {
|
||||
return fmt.Errorf("model %q is not available. Run `local-ai models list`, install a model with `local-ai models install <model>`, or switch with `/model <name>`", model)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package chat
|
||||
|
||||
import "strings"
|
||||
|
||||
func formatChatModelList(models []string, current string) string {
|
||||
var b strings.Builder
|
||||
for _, model := range models {
|
||||
prefix := " "
|
||||
if model == current {
|
||||
prefix = "* "
|
||||
}
|
||||
b.WriteString(prefix)
|
||||
b.WriteString(model)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
chatRoleUser = "user"
|
||||
chatRoleAssistant = "assistant"
|
||||
)
|
||||
|
||||
type chatMessage struct {
|
||||
Role string
|
||||
Content string
|
||||
}
|
||||
|
||||
type chatSession struct {
|
||||
client chatClient
|
||||
model string
|
||||
models []string
|
||||
messages []chatMessage
|
||||
}
|
||||
|
||||
func newChatSession(ctx context.Context, client chatClient, requestedModel string) (*chatSession, error) {
|
||||
models, err := client.ListModels(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list models: %w", err)
|
||||
}
|
||||
|
||||
model, err := resolveChatModel(requestedModel, models)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &chatSession{
|
||||
client: client,
|
||||
model: model,
|
||||
models: models,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *chatSession) CurrentModel() string {
|
||||
return s.model
|
||||
}
|
||||
|
||||
func (s *chatSession) Models() []string {
|
||||
models := make([]string, len(s.models))
|
||||
copy(models, s.models)
|
||||
return models
|
||||
}
|
||||
|
||||
func (s *chatSession) Clear() {
|
||||
s.messages = nil
|
||||
}
|
||||
|
||||
func (s *chatSession) SwitchModel(model string) error {
|
||||
if !modelExists(s.models, model) {
|
||||
return fmt.Errorf("model %q is not available. Use /models to see installed models", model)
|
||||
}
|
||||
s.model = model
|
||||
s.Clear()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *chatSession) Send(ctx context.Context, prompt string, out io.Writer) error {
|
||||
s.messages = append(s.messages, chatMessage{
|
||||
Role: chatRoleUser,
|
||||
Content: prompt,
|
||||
})
|
||||
|
||||
answer, err := s.client.StreamChat(ctx, s.model, s.messages, out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.messages = append(s.messages, chatMessage{
|
||||
Role: chatRoleAssistant,
|
||||
Content: answer,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveChatModel(requested string, models []string) (string, error) {
|
||||
switch {
|
||||
case requested == "" && len(models) == 0:
|
||||
return "", errors.New(`no chat models are installed.
|
||||
|
||||
Install a model first, for example:
|
||||
local-ai models list
|
||||
local-ai models install <model>
|
||||
local-ai run
|
||||
|
||||
Then start a chat session:
|
||||
local-ai chat --model <model>`)
|
||||
case requested == "" && len(models) == 1:
|
||||
return models[0], nil
|
||||
case requested == "" && len(models) > 1:
|
||||
var b strings.Builder
|
||||
b.WriteString("multiple models are available; choose one with --model:\n")
|
||||
b.WriteString(formatChatModelList(models, ""))
|
||||
return "", errors.New(b.String())
|
||||
case !modelExists(models, requested):
|
||||
return "", fmt.Errorf("model %q is not available. Use `local-ai models list` and `local-ai models install <model>`, or pass an installed model with --model", requested)
|
||||
default:
|
||||
return requested, nil
|
||||
}
|
||||
}
|
||||
|
||||
func modelExists(models []string, name string) bool {
|
||||
for _, model := range models {
|
||||
if model == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Chat session", func() {
|
||||
It("keeps model switching and message history out of the terminal adapter", func() {
|
||||
client := &fakeChatClient{
|
||||
models: []string{"alpha", "beta"},
|
||||
answer: "pong",
|
||||
}
|
||||
|
||||
session, err := newChatSession(context.Background(), client, "alpha")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(session.CurrentModel()).To(Equal("alpha"))
|
||||
|
||||
Expect(session.SwitchModel("beta")).To(Succeed())
|
||||
Expect(session.CurrentModel()).To(Equal("beta"))
|
||||
Expect(session.Send(context.Background(), "ping", io.Discard)).To(Succeed())
|
||||
|
||||
Expect(client.requests).To(HaveLen(1))
|
||||
Expect(client.requests[0].model).To(Equal("beta"))
|
||||
Expect(client.requests[0].messages).To(HaveLen(1))
|
||||
Expect(client.requests[0].messages[0].Content).To(Equal("ping"))
|
||||
})
|
||||
})
|
||||
|
||||
type fakeChatClient struct {
|
||||
models []string
|
||||
answer string
|
||||
requests []fakeChatRequest
|
||||
}
|
||||
|
||||
type fakeChatRequest struct {
|
||||
model string
|
||||
messages []chatMessage
|
||||
}
|
||||
|
||||
func (c *fakeChatClient) ListModels(context.Context) ([]string, error) {
|
||||
return c.models, nil
|
||||
}
|
||||
|
||||
func (c *fakeChatClient) StreamChat(_ context.Context, model string, messages []chatMessage, out io.Writer) (string, error) {
|
||||
copied := make([]chatMessage, len(messages))
|
||||
copy(copied, messages)
|
||||
c.requests = append(c.requests, fakeChatRequest{model: model, messages: copied})
|
||||
if _, err := io.WriteString(out, c.answer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return c.answer, nil
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func runTerminalChat(ctx context.Context, session *chatSession, in io.Reader, out io.Writer) error {
|
||||
scanner := bufio.NewScanner(in)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
|
||||
|
||||
if err := writeChat(out, "LocalAI chat (%s)\n", session.CurrentModel()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeChat(out, "Type /exit to quit, /clear to reset the conversation, /models to list models.\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
if err := writeChat(out, "\n> "); err != nil {
|
||||
return err
|
||||
}
|
||||
if !scanner.Scan() {
|
||||
break
|
||||
}
|
||||
|
||||
prompt := strings.TrimSpace(scanner.Text())
|
||||
switch prompt {
|
||||
case "":
|
||||
continue
|
||||
case "/bye", "/exit", "/quit":
|
||||
return writeChat(out, "bye\n")
|
||||
case "/clear":
|
||||
session.Clear()
|
||||
if err := writeChat(out, "conversation cleared\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
case "/models":
|
||||
if err := printChatModels(out, session.Models(), session.CurrentModel()); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if nextModel, ok := strings.CutPrefix(prompt, "/model "); ok {
|
||||
nextModel = strings.TrimSpace(nextModel)
|
||||
if nextModel == "" {
|
||||
if err := writeChat(out, "usage: /model <name>\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := session.SwitchModel(nextModel); err != nil {
|
||||
if writeErr := writeChat(out, "%s\n", err); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := writeChat(out, "switched to %s; conversation cleared\n", session.CurrentModel()); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err := writeChat(out, "assistant: "); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := session.Send(ctx, prompt, out); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeChat(out, "\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func printChatModels(out io.Writer, models []string, current string) error {
|
||||
if len(models) == 0 {
|
||||
return writeChat(out, "no models installed\n")
|
||||
}
|
||||
return writeChat(out, "%s", formatChatModelList(models, current))
|
||||
}
|
||||
|
||||
func writeChat(out io.Writer, format string, args ...any) error {
|
||||
_, err := fmt.Fprintf(out, format, args...)
|
||||
return err
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
chatcli "github.com/mudler/LocalAI/core/cli/chat"
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
)
|
||||
|
||||
type ChatCMD struct {
|
||||
Model string `short:"m" help:"Model name to use. Defaults to the only model returned by the server when exactly one is available"`
|
||||
Endpoint string `env:"LOCALAI_CHAT_ENDPOINT" default:"http://127.0.0.1:8080" help:"LocalAI server endpoint. The /v1 path is added automatically when omitted"`
|
||||
APIKey string `env:"LOCALAI_API_KEY,API_KEY" help:"API key to use when the LocalAI server requires authentication"`
|
||||
}
|
||||
|
||||
func (c *ChatCMD) Run(ctx *cliContext.Context) error {
|
||||
return chatcli.Run(context.Background(), chatcli.Options{
|
||||
Model: c.Model,
|
||||
BaseURL: chatAPIBaseURL(c.Endpoint),
|
||||
APIKey: c.APIKey,
|
||||
In: os.Stdin,
|
||||
Out: os.Stdout,
|
||||
})
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Chat command wiring", func() {
|
||||
Describe("chatAPIBaseURL", func() {
|
||||
It("adds /v1 to a root endpoint", func() {
|
||||
Expect(chatAPIBaseURL("http://127.0.0.1:8080")).To(Equal("http://127.0.0.1:8080/v1"))
|
||||
})
|
||||
|
||||
It("keeps endpoints that already include /v1", func() {
|
||||
Expect(chatAPIBaseURL("http://127.0.0.1:8080/v1")).To(Equal("http://127.0.0.1:8080/v1"))
|
||||
Expect(chatAPIBaseURL("http://127.0.0.1:8080/v1/")).To(Equal("http://127.0.0.1:8080/v1"))
|
||||
})
|
||||
|
||||
It("adds a default http scheme", func() {
|
||||
Expect(chatAPIBaseURL("127.0.0.1:8080")).To(Equal("http://127.0.0.1:8080/v1"))
|
||||
})
|
||||
|
||||
It("preserves non-root paths before /v1", func() {
|
||||
Expect(chatAPIBaseURL("http://127.0.0.1:8080/localai")).To(Equal("http://127.0.0.1:8080/localai/v1"))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,29 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func chatAPIBaseURL(endpoint string) string {
|
||||
if !strings.Contains(endpoint, "://") {
|
||||
endpoint = "http://" + endpoint
|
||||
}
|
||||
|
||||
u, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
return strings.TrimRight(endpoint, "/") + "/v1"
|
||||
}
|
||||
|
||||
path := strings.TrimRight(u.Path, "/")
|
||||
if path == "" {
|
||||
u.Path = "/v1"
|
||||
} else if path != "/v1" && !strings.HasSuffix(path, "/v1") {
|
||||
u.Path = path + "/v1"
|
||||
} else {
|
||||
u.Path = path
|
||||
}
|
||||
u.RawQuery = ""
|
||||
u.Fragment = ""
|
||||
return u.String()
|
||||
}
|
||||
@@ -9,7 +9,6 @@ var CLI struct {
|
||||
cliContext.Context `embed:""`
|
||||
|
||||
Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
|
||||
Chat ChatCMD `cmd:"" help:"Open an interactive chat session against a running LocalAI server"`
|
||||
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
|
||||
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
|
||||
Backends BackendsCMD `cmd:"" help:"Manage LocalAI backends and definitions"`
|
||||
|
||||
@@ -30,8 +30,6 @@ type RunCMD struct {
|
||||
ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"`
|
||||
|
||||
ExternalBackends []string `env:"LOCALAI_EXTERNAL_BACKENDS,EXTERNAL_BACKENDS" help:"A list of external backends to load from gallery on boot" group:"backends"`
|
||||
WebRTCNAT1To1IPs []string `env:"LOCALAI_WEBRTC_NAT_1TO1_IPS,WEBRTC_NAT_1TO1_IPS" help:"IPs advertised as the host ICE candidates for /v1/realtime WebRTC instead of every local interface. Set to the reachable host/LAN IP when running under Docker host networking or NAT, where pion otherwise offers unreachable bridge addresses and the connection drops after ICE consent checks fail." group:"api"`
|
||||
WebRTCICEInterfaces []string `env:"LOCALAI_WEBRTC_ICE_INTERFACES,WEBRTC_ICE_INTERFACES" help:"Restrict /v1/realtime WebRTC ICE candidate gathering to these network interfaces (e.g. eth0), filtering out docker0/veth noise." group:"api"`
|
||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
|
||||
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
||||
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
|
||||
@@ -227,8 +225,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
config.WithApiKeys(r.APIKeys),
|
||||
config.WithModelsURL(append(r.Models, r.ModelArgs...)...),
|
||||
config.WithExternalBackends(r.ExternalBackends...),
|
||||
config.WithWebRTCNAT1To1IPs(r.WebRTCNAT1To1IPs...),
|
||||
config.WithWebRTCICEInterfaces(r.WebRTCICEInterfaces...),
|
||||
config.WithOpaqueErrors(r.OpaqueErrors),
|
||||
config.WithEnforcedPredownloadScans(!r.DisablePredownloadScan),
|
||||
config.WithSubtleKeyComparison(r.UseSubtleKeyComparison),
|
||||
@@ -656,12 +652,12 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
// waitForServerReady polls the given address until the HTTP server is
|
||||
// accepting connections or the context is cancelled.
|
||||
func waitForServerReady(address string, ctx context.Context) {
|
||||
// Ensure the address has a host component for dialing.
|
||||
// Echo accepts ":8080" but net.Dial needs a resolvable host.
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err == nil && host == "" {
|
||||
address = "127.0.0.1:" + port
|
||||
}
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -669,17 +665,11 @@ func waitForServerReady(address string, ctx context.Context) {
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", address, 500*time.Millisecond)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,19 +12,10 @@ import (
|
||||
)
|
||||
|
||||
type ApplicationConfig struct {
|
||||
Context context.Context
|
||||
ConfigFile string
|
||||
SystemState *system.SystemState
|
||||
ExternalBackends []string
|
||||
|
||||
// WebRTCNAT1To1IPs, when set, are advertised as the host ICE candidates for
|
||||
// /v1/realtime WebRTC instead of every local interface address. Needed when
|
||||
// the routable address differs from what pion gathers — e.g. Docker host
|
||||
// networking (where pion also offers unreachable bridge IPs) or NAT.
|
||||
WebRTCNAT1To1IPs []string
|
||||
// WebRTCICEInterfaces, when set, restricts ICE candidate gathering to these
|
||||
// network interfaces (e.g. eth0), filtering out docker0/veth noise.
|
||||
WebRTCICEInterfaces []string
|
||||
Context context.Context
|
||||
ConfigFile string
|
||||
SystemState *system.SystemState
|
||||
ExternalBackends []string
|
||||
UploadLimitMB, Threads, ContextSize int
|
||||
F16 bool
|
||||
Debug bool
|
||||
@@ -90,6 +81,7 @@ type ApplicationConfig struct {
|
||||
// file is mode 0600.
|
||||
MITMCADir string
|
||||
|
||||
|
||||
// PIIPatternOverrides applies persisted per-id deltas (action,
|
||||
// disabled) to the live redactor at startup. Loaded from
|
||||
// runtime_settings.json and applied right after pii.NewRedactor.
|
||||
@@ -124,11 +116,11 @@ type ApplicationConfig struct {
|
||||
// --require-backend-integrity / LOCALAI_REQUIRE_BACKEND_INTEGRITY.
|
||||
RequireBackendIntegrity bool
|
||||
|
||||
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
|
||||
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||
WatchDogIdle bool
|
||||
WatchDogBusy bool
|
||||
WatchDog bool
|
||||
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
|
||||
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||
WatchDogIdle bool
|
||||
WatchDogBusy bool
|
||||
WatchDog bool
|
||||
|
||||
// Memory Reclaimer settings (works with GPU if available, otherwise RAM)
|
||||
MemoryReclaimerEnabled bool // Enable memory threshold monitoring
|
||||
@@ -319,18 +311,6 @@ func WithExternalBackends(backends ...string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithWebRTCNAT1To1IPs(ips ...string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.WebRTCNAT1To1IPs = ips
|
||||
}
|
||||
}
|
||||
|
||||
func WithWebRTCICEInterfaces(interfaces ...string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.WebRTCICEInterfaces = interfaces
|
||||
}
|
||||
}
|
||||
|
||||
func WithMachineTag(tag string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.MachineTag = tag
|
||||
@@ -722,6 +702,7 @@ func WithMITMCADir(dir string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func WithDynamicConfigDir(dynamicConfigsDir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.DynamicConfigsDir = dynamicConfigsDir
|
||||
|
||||
@@ -39,21 +39,7 @@ func llamaCppDefaults(cfg *ModelConfig, modelPath string) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Startup parses every model's GGUF header to guess defaults. We only need
|
||||
// scalar metadata (architecture, head/ff counts, chat_template, token IDs,
|
||||
// MTP head) plus array *lengths* — never the array *contents*. Two options
|
||||
// keep this cheap, which matters when many models live on slow storage such
|
||||
// as a Docker volume (see https://github.com/mudler/LocalAI/issues/9790):
|
||||
//
|
||||
// - SkipLargeMetadata: seek past large array-valued metadata (the tokenizer
|
||||
// vocab: tokenizer.ggml.tokens/scores/merges, often >100k entries) instead
|
||||
// of reading and allocating every element. Lengths stay populated.
|
||||
// - UseMMap: read the header via a memory map so faulting in a few pages
|
||||
// replaces hundreds of thousands of tiny read() syscalls (measured ~524k
|
||||
// -> 8 for a 256k-token vocab), the dominant cost on slow filesystems.
|
||||
//
|
||||
// The mapping is released when ParseGGUFFile returns.
|
||||
f, err := gguf.ParseGGUFFile(guessPath, gguf.UseMMap(), gguf.SkipLargeMetadata())
|
||||
f, err := gguf.ParseGGUFFile(guessPath)
|
||||
if err == nil {
|
||||
guessGGUFFromFile(cfg, f, 0)
|
||||
}
|
||||
|
||||
@@ -1,76 +1,13 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// GGUF metadata value type tags (see github.com/gpustack/gguf-parser-go).
|
||||
const (
|
||||
ggufTypeUint32 uint32 = 4
|
||||
ggufTypeString uint32 = 8
|
||||
ggufTypeArray uint32 = 9
|
||||
)
|
||||
|
||||
// writeTestGGUF emits a minimal but valid little-endian GGUF v3 header carrying
|
||||
// the scalar metadata the llama-cpp hook guesses from plus a large string vocab
|
||||
// array (tokenizer.ggml.tokens). The big array is exactly what SkipLargeMetadata
|
||||
// + UseMMap are expected to avoid reading element-by-element, so it must survive a
|
||||
// round-trip through the real hook without corrupting the guessed defaults.
|
||||
func writeTestGGUF(path, chatTemplate string, vocab int) error {
|
||||
wStr := func(b *bytes.Buffer, s string) {
|
||||
binary.Write(b, binary.LittleEndian, uint64(len(s)))
|
||||
b.WriteString(s)
|
||||
}
|
||||
kvStr := func(b *bytes.Buffer, k, v string) {
|
||||
wStr(b, k)
|
||||
binary.Write(b, binary.LittleEndian, ggufTypeString)
|
||||
wStr(b, v)
|
||||
}
|
||||
kvU32 := func(b *bytes.Buffer, k string, v uint32) {
|
||||
wStr(b, k)
|
||||
binary.Write(b, binary.LittleEndian, ggufTypeUint32)
|
||||
binary.Write(b, binary.LittleEndian, v)
|
||||
}
|
||||
|
||||
var meta bytes.Buffer
|
||||
kvStr(&meta, "general.architecture", "llama")
|
||||
kvStr(&meta, "general.name", "ReproModel")
|
||||
kvU32(&meta, "llama.context_length", 4096)
|
||||
kvU32(&meta, "llama.attention.head_count", 32)
|
||||
kvU32(&meta, "llama.feed_forward_length", 11008)
|
||||
kvU32(&meta, "llama.block_count", 32)
|
||||
kvU32(&meta, "tokenizer.ggml.bos_token_id", 1)
|
||||
kvStr(&meta, "tokenizer.chat_template", chatTemplate)
|
||||
|
||||
// large array value — the one the optimization skips reading
|
||||
wStr(&meta, "tokenizer.ggml.tokens")
|
||||
binary.Write(&meta, binary.LittleEndian, ggufTypeArray)
|
||||
binary.Write(&meta, binary.LittleEndian, ggufTypeString)
|
||||
binary.Write(&meta, binary.LittleEndian, uint64(vocab))
|
||||
for i := 0; i < vocab; i++ {
|
||||
wStr(&meta, "token")
|
||||
}
|
||||
|
||||
var out bytes.Buffer
|
||||
binary.Write(&out, binary.LittleEndian, gguf.GGUFMagicGGUFLe)
|
||||
binary.Write(&out, binary.LittleEndian, uint32(3)) // version
|
||||
binary.Write(&out, binary.LittleEndian, uint64(0)) // tensor count
|
||||
binary.Write(&out, binary.LittleEndian, uint64(9)) // metadata kv count
|
||||
out.Write(meta.Bytes())
|
||||
|
||||
return os.WriteFile(path, out.Bytes(), 0o644)
|
||||
}
|
||||
|
||||
var _ = Describe("Backend hooks and parser defaults", func() {
|
||||
Context("MatchParserDefaults", func() {
|
||||
It("matches Qwen3 family", func() {
|
||||
@@ -200,58 +137,6 @@ var _ = Describe("Backend hooks and parser defaults", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("llamaCppDefaults GGUF guessing", func() {
|
||||
// Regression coverage for https://github.com/mudler/LocalAI/issues/9790:
|
||||
// the hook reads GGUF headers with SkipLargeMetadata + UseMMap to avoid
|
||||
// pulling the whole tokenizer vocab off (slow) disk on every startup. This
|
||||
// verifies that skipping the vocab array still yields the correct guessed
|
||||
// defaults from the remaining scalar metadata.
|
||||
const chatTemplate = "{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}"
|
||||
|
||||
It("guesses defaults from a GGUF whose large vocab is skipped", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
modelFile := "repro.gguf"
|
||||
Expect(writeTestGGUF(filepath.Join(dir, modelFile), chatTemplate, 50000)).To(Succeed())
|
||||
|
||||
// A pre-set context size short-circuits the GGUF run-estimate, which
|
||||
// needs full tensor info this header-only fixture deliberately omits;
|
||||
// the metadata-reading path the optimization touches is unaffected.
|
||||
ctxSize := 4096
|
||||
cfg := &ModelConfig{
|
||||
Backend: "llama-cpp",
|
||||
LLMConfig: LLMConfig{ContextSize: &ctxSize},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: modelFile},
|
||||
},
|
||||
}
|
||||
cfg.SetDefaults(ModelPath(dir))
|
||||
|
||||
// chat_template is a scalar string, not part of the skipped array,
|
||||
// so it must be captured verbatim.
|
||||
Expect(cfg.GetModelTemplate()).To(Equal(chatTemplate))
|
||||
// scalar-derived defaults are still applied
|
||||
Expect(cfg.ContextSize).NotTo(BeNil())
|
||||
Expect(cfg.NGPULayers).NotTo(BeNil())
|
||||
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
|
||||
Expect(cfg.KnownUsecaseStrings).To(ContainElement("FLAG_CHAT"))
|
||||
})
|
||||
|
||||
It("falls back to the default context size when the GGUF is unreadable", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
Expect(os.WriteFile(filepath.Join(dir, "bad.gguf"), []byte("not a gguf"), 0o644)).To(Succeed())
|
||||
|
||||
cfg := &ModelConfig{
|
||||
Backend: "llama-cpp",
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: "bad.gguf"},
|
||||
},
|
||||
}
|
||||
cfg.SetDefaults(ModelPath(dir))
|
||||
|
||||
Expect(cfg.ContextSize).NotTo(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("PromptCacheAll default", func() {
|
||||
It("defaults to true when omitted from YAML", func() {
|
||||
cfg := &ModelConfig{}
|
||||
|
||||
@@ -308,41 +308,6 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
},
|
||||
Order: 64,
|
||||
},
|
||||
"pipeline.disable_thinking": {
|
||||
Section: "pipeline",
|
||||
Label: "Disable Thinking",
|
||||
Description: "Suppress reasoning/thinking output from the pipeline LLM (sets enable_thinking=false on the underlying model). Use for models that emit <think> blocks you don't want spoken or streamed back to the realtime client.",
|
||||
Component: "toggle",
|
||||
Order: 65,
|
||||
},
|
||||
"pipeline.streaming.llm": {
|
||||
Section: "pipeline",
|
||||
Label: "Stream LLM",
|
||||
Description: "Stream LLM tokens to the realtime client as they are generated instead of waiting for the full response. Emits incremental response.output_audio_transcript.delta / text deltas.",
|
||||
Component: "toggle",
|
||||
Order: 66,
|
||||
},
|
||||
"pipeline.streaming.tts": {
|
||||
Section: "pipeline",
|
||||
Label: "Stream TTS",
|
||||
Description: "Stream synthesized audio chunks to the realtime client as they are produced (requires a TTS backend that implements TTSStream). Falls back to unary synthesis otherwise.",
|
||||
Component: "toggle",
|
||||
Order: 67,
|
||||
},
|
||||
"pipeline.streaming.transcription": {
|
||||
Section: "pipeline",
|
||||
Label: "Stream Transcription",
|
||||
Description: "Stream partial transcription text to the realtime client as the STT backend produces it (requires a transcription backend that implements AudioTranscriptionStream). Falls back to unary transcription otherwise.",
|
||||
Component: "toggle",
|
||||
Order: 68,
|
||||
},
|
||||
"pipeline.streaming.clause_chunking": {
|
||||
Section: "pipeline",
|
||||
Label: "Clause Chunking",
|
||||
Description: "Split the streamed reply into speakable clauses and synthesize each as soon as it completes, instead of buffering the whole message before TTS — lower time-to-first-audio. Script-aware (handles CJK 。!? and Thai/Lao spaces), so it does not whitespace-split. Requires Stream LLM; off buffers the whole message.",
|
||||
Component: "toggle",
|
||||
Order: 69,
|
||||
},
|
||||
|
||||
// --- Functions ---
|
||||
"function.grammar.parallel_calls": {
|
||||
|
||||
@@ -499,16 +499,6 @@ type Pipeline struct {
|
||||
// the pipeline's LLM without editing the LLM model config. Overrides the LLM's
|
||||
// own reasoning_effort. Unset leaves the LLM model config in charge.
|
||||
ReasoningEffort string `yaml:"reasoning_effort,omitempty" json:"reasoning_effort,omitempty"`
|
||||
|
||||
// Streaming opts each pipeline stage into incremental delivery (LLM tokens,
|
||||
// TTS audio chunks, transcription text). Unset stages keep the blocking
|
||||
// unary path, so existing configs are unaffected.
|
||||
Streaming PipelineStreaming `yaml:"streaming,omitempty" json:"streaming,omitempty"`
|
||||
|
||||
// DisableThinking suppresses reasoning/thinking for the pipeline LLM (maps
|
||||
// to enable_thinking=false backend metadata) without editing the underlying
|
||||
// LLM model config. Unset leaves the LLM model config in charge.
|
||||
DisableThinking *bool `yaml:"disable_thinking,omitempty" json:"disable_thinking,omitempty"`
|
||||
}
|
||||
|
||||
// ApplyReasoningEffort resolves the effective reasoning effort — a per-request
|
||||
@@ -540,41 +530,6 @@ func (c *ModelConfig) ApplyReasoningEffort(requestEffort string) {
|
||||
}
|
||||
}
|
||||
|
||||
// @Description PipelineStreaming toggles incremental delivery per realtime stage.
|
||||
type PipelineStreaming struct {
|
||||
LLM *bool `yaml:"llm,omitempty" json:"llm,omitempty"`
|
||||
TTS *bool `yaml:"tts,omitempty" json:"tts,omitempty"`
|
||||
Transcription *bool `yaml:"transcription,omitempty" json:"transcription,omitempty"`
|
||||
// ClauseChunking splits the streamed LLM reply into speakable clauses and
|
||||
// synthesizes each as soon as it completes, instead of buffering the whole
|
||||
// message before TTS. Script-aware (CJK/Thai), so it does not rely on
|
||||
// whitespace sentence boundaries. Requires LLM streaming; unset buffers the
|
||||
// whole message (today's default).
|
||||
ClauseChunking *bool `yaml:"clause_chunking,omitempty" json:"clause_chunking,omitempty"`
|
||||
}
|
||||
|
||||
// StreamLLM reports whether LLM tokens should be streamed for this pipeline.
|
||||
func (p Pipeline) StreamLLM() bool { return p.Streaming.LLM != nil && *p.Streaming.LLM }
|
||||
|
||||
// StreamTTS reports whether TTS audio should be streamed for this pipeline.
|
||||
func (p Pipeline) StreamTTS() bool { return p.Streaming.TTS != nil && *p.Streaming.TTS }
|
||||
|
||||
// StreamTranscription reports whether transcription text should be streamed.
|
||||
func (p Pipeline) StreamTranscription() bool {
|
||||
return p.Streaming.Transcription != nil && *p.Streaming.Transcription
|
||||
}
|
||||
|
||||
// ChunkClauses reports whether the streamed reply should be split into
|
||||
// script-aware clauses and synthesized incrementally rather than buffered whole.
|
||||
func (p Pipeline) ChunkClauses() bool {
|
||||
return p.Streaming.ClauseChunking != nil && *p.Streaming.ClauseChunking
|
||||
}
|
||||
|
||||
// ThinkingDisabled reports whether the pipeline forces the LLM's thinking off.
|
||||
func (p Pipeline) ThinkingDisabled() bool {
|
||||
return p.DisableThinking != nil && *p.DisableThinking
|
||||
}
|
||||
|
||||
// @Description File configuration for model downloads
|
||||
type File struct {
|
||||
Filename string `yaml:"filename,omitempty" json:"filename,omitempty"`
|
||||
|
||||
@@ -30,26 +30,11 @@ func MTPSpecOptions() []string {
|
||||
return out
|
||||
}
|
||||
|
||||
// isDraftOnlyAssistantArch reports whether an architecture names a standalone
|
||||
// MTP *draft* model rather than a self-speculating trunk. Upstream's Gemma4 MTP
|
||||
// (ggml-org/llama.cpp#23398) registers the head as a separate `gemma4-assistant`
|
||||
// architecture whose GGUF still carries `nextn_predict_layers`, but which cannot
|
||||
// run alone: it requires a paired target context (`ctx_other`). Such archs must
|
||||
// not trigger the embedded-head self-speculation defaults. The `-assistant`
|
||||
// suffix is upstream's naming convention for these draft-only checkpoints.
|
||||
func isDraftOnlyAssistantArch(arch string) bool {
|
||||
return strings.HasSuffix(arch, "-assistant")
|
||||
}
|
||||
|
||||
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a self-speculating
|
||||
// Multi-Token Prediction head. Detection reads `<arch>.nextn_predict_layers`,
|
||||
// which is what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
|
||||
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a Multi-Token
|
||||
// Prediction head. Detection reads `<arch>.nextn_predict_layers`, which is
|
||||
// what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
|
||||
// `conversion/qwen.py` MTP mixin. A positive layer count means the head is
|
||||
// present in the same GGUF as the trunk.
|
||||
//
|
||||
// Draft-only assistant architectures (e.g. Gemma4's `gemma4-assistant`) carry
|
||||
// the same key but are separate draft checkpoints meant to be paired with a
|
||||
// target model, so they are deliberately excluded here.
|
||||
func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
|
||||
if f == nil {
|
||||
return 0, false
|
||||
@@ -58,9 +43,6 @@ func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
|
||||
if arch == "" {
|
||||
return 0, false
|
||||
}
|
||||
if isDraftOnlyAssistantArch(arch) {
|
||||
return 0, false
|
||||
}
|
||||
v, ok := f.Header.MetadataKV.Get(arch + ".nextn_predict_layers")
|
||||
if !ok {
|
||||
return 0, false
|
||||
|
||||
@@ -3,33 +3,10 @@ package config_test
|
||||
import (
|
||||
. "github.com/mudler/LocalAI/core/config"
|
||||
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// ggufWithArch fabricates a minimal in-memory GGUF carrying the given
|
||||
// `general.architecture` and a positive `<arch>.nextn_predict_layers` count,
|
||||
// so HasEmbeddedMTPHead can be exercised without a real model file.
|
||||
func ggufWithArch(arch string, nextn uint32) *gguf.GGUFFile {
|
||||
return &gguf.GGUFFile{
|
||||
Header: gguf.GGUFHeader{
|
||||
MetadataKV: gguf.GGUFMetadataKVs{
|
||||
{
|
||||
Key: "general.architecture",
|
||||
ValueType: gguf.GGUFMetadataValueTypeString,
|
||||
Value: arch,
|
||||
},
|
||||
{
|
||||
Key: arch + ".nextn_predict_layers",
|
||||
ValueType: gguf.GGUFMetadataValueTypeUint32,
|
||||
Value: nextn,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("MTP auto-defaults", func() {
|
||||
Context("MTPSpecOptions", func() {
|
||||
It("returns the upstream-recommended speculative tuple", func() {
|
||||
@@ -105,20 +82,5 @@ var _ = Describe("MTP auto-defaults", func() {
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(n).To(BeZero())
|
||||
})
|
||||
|
||||
It("detects a same-GGUF embedded head (DeepSeek/Qwen style)", func() {
|
||||
n, ok := HasEmbeddedMTPHead(ggufWithArch("qwen3moe", 1))
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(n).To(Equal(uint32(1)))
|
||||
})
|
||||
|
||||
It("ignores a gemma4-assistant draft-only model", func() {
|
||||
// The assistant GGUF carries nextn_predict_layers but is a separate
|
||||
// draft model that requires a paired target (ctx_other); it cannot
|
||||
// self-speculate, so it must not trigger the embedded-head defaults.
|
||||
n, ok := HasEmbeddedMTPHead(ggufWithArch("gemma4-assistant", 48))
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(n).To(BeZero())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// The realtime pipeline can stream each stage (LLM tokens, TTS audio,
|
||||
// transcription text) and can disable model "thinking" for the LLM. These are
|
||||
// opt-in per pipeline; everything defaults to off so existing configs keep the
|
||||
// unary behaviour.
|
||||
var _ = Describe("Pipeline streaming config", func() {
|
||||
It("defaults every streaming + thinking helper to false when unset", func() {
|
||||
var p Pipeline
|
||||
Expect(p.StreamLLM()).To(BeFalse())
|
||||
Expect(p.StreamTTS()).To(BeFalse())
|
||||
Expect(p.StreamTranscription()).To(BeFalse())
|
||||
Expect(p.ChunkClauses()).To(BeFalse())
|
||||
Expect(p.ThinkingDisabled()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("parses the nested streaming block and disable_thinking from YAML", func() {
|
||||
var c ModelConfig
|
||||
err := yaml.Unmarshal([]byte(`
|
||||
name: gpt-realtime
|
||||
pipeline:
|
||||
llm: my-llm
|
||||
tts: my-tts
|
||||
transcription: my-stt
|
||||
streaming:
|
||||
llm: true
|
||||
tts: true
|
||||
transcription: true
|
||||
clause_chunking: true
|
||||
disable_thinking: true
|
||||
`), &c)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c.Pipeline.StreamLLM()).To(BeTrue())
|
||||
Expect(c.Pipeline.StreamTTS()).To(BeTrue())
|
||||
Expect(c.Pipeline.StreamTranscription()).To(BeTrue())
|
||||
Expect(c.Pipeline.ChunkClauses()).To(BeTrue())
|
||||
Expect(c.Pipeline.ThinkingDisabled()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("treats an explicit false in the streaming block as disabled", func() {
|
||||
var c ModelConfig
|
||||
err := yaml.Unmarshal([]byte(`
|
||||
name: gpt-realtime
|
||||
pipeline:
|
||||
streaming:
|
||||
tts: false
|
||||
`), &c)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c.Pipeline.StreamTTS()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
@@ -383,13 +383,13 @@ var _ = Describe("API test", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go func() {
|
||||
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
xlog.Error("server error", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig(apiKey)
|
||||
defaultConfig.BaseURL = testHTTPBase + "/v1"
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
|
||||
client2 = openaigo.NewClient("")
|
||||
client2.BaseURL = defaultConfig.BaseURL
|
||||
@@ -418,7 +418,7 @@ var _ = Describe("API test", func() {
|
||||
|
||||
Context("Auth Tests", func() {
|
||||
It("Should fail if the api key is missing", func() {
|
||||
err, sc := postInvalidRequest(testHTTPBase + "/models/available")
|
||||
err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available")
|
||||
Expect(err).ToNot(BeNil())
|
||||
Expect(sc).To(Equal(401))
|
||||
})
|
||||
@@ -427,7 +427,7 @@ var _ = Describe("API test", func() {
|
||||
Context("URL routing Tests", func() {
|
||||
It("Should support reverse-proxy when unauthenticated", func() {
|
||||
|
||||
err, sc, body := getRequest(testHTTPBase+"/myprefix/", http.Header{
|
||||
err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
|
||||
"X-Forwarded-Proto": {"https"},
|
||||
"X-Forwarded-Host": {"example.org"},
|
||||
"X-Forwarded-Prefix": {"/myprefix/"},
|
||||
@@ -441,7 +441,7 @@ var _ = Describe("API test", func() {
|
||||
|
||||
It("Should support reverse-proxy when authenticated", func() {
|
||||
|
||||
err, sc, body := getRequest(testHTTPBase+"/myprefix/", http.Header{
|
||||
err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
|
||||
"Authorization": {bearerKey},
|
||||
"X-Forwarded-Proto": {"https"},
|
||||
"X-Forwarded-Host": {"example.org"},
|
||||
@@ -459,7 +459,7 @@ var _ = Describe("API test", func() {
|
||||
// requests them through the proxy.
|
||||
It("Should support reverse-proxy when prefix is stripped by the proxy", func() {
|
||||
|
||||
err, sc, body := getRequest(testHTTPBase+"/app", http.Header{
|
||||
err, sc, body := getRequest("http://127.0.0.1:9090/app", http.Header{
|
||||
"X-Forwarded-Proto": {"https"},
|
||||
"X-Forwarded-Host": {"example.org"},
|
||||
"X-Forwarded-Prefix": {"/myprefix"},
|
||||
@@ -477,7 +477,7 @@ var _ = Describe("API test", func() {
|
||||
// from a foreign origin. BasePathPrefix must reject these via
|
||||
// SafeForwardedPrefix and fall back to "/".
|
||||
It("Should ignore an unsafe X-Forwarded-Prefix and not poison asset URLs", func() {
|
||||
err, sc, body := getRequest(testHTTPBase+"/app", http.Header{
|
||||
err, sc, body := getRequest("http://127.0.0.1:9090/app", http.Header{
|
||||
"X-Forwarded-Proto": {"https"},
|
||||
"X-Forwarded-Host": {"example.org"},
|
||||
"X-Forwarded-Prefix": {"//evil.com"},
|
||||
@@ -492,13 +492,13 @@ var _ = Describe("API test", func() {
|
||||
Context("Applying models", func() {
|
||||
|
||||
It("applies models from a gallery", func() {
|
||||
models, err := getModels(testHTTPBase + "/models/available")
|
||||
models, err := getModels("http://127.0.0.1:9090/models/available")
|
||||
Expect(err).To(BeNil())
|
||||
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
|
||||
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
|
||||
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
|
||||
|
||||
response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
ID: "test@bert2",
|
||||
})
|
||||
|
||||
@@ -507,7 +507,7 @@ var _ = Describe("API test", func() {
|
||||
uuid := response["uuid"].(string)
|
||||
resp := map[string]any{}
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
resp = response
|
||||
return response["processed"].(bool)
|
||||
@@ -526,7 +526,7 @@ var _ = Describe("API test", func() {
|
||||
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
|
||||
Expect(content["foo"]).To(Equal("bar"))
|
||||
|
||||
models, err = getModels(testHTTPBase + "/models/available")
|
||||
models, err = getModels("http://127.0.0.1:9090/models/available")
|
||||
Expect(err).To(BeNil())
|
||||
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
|
||||
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
|
||||
@@ -541,7 +541,7 @@ var _ = Describe("API test", func() {
|
||||
})
|
||||
It("overrides models", func() {
|
||||
|
||||
response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
URL: bertEmbeddingsURL,
|
||||
Name: "bert",
|
||||
Overrides: map[string]any{
|
||||
@@ -554,7 +554,7 @@ var _ = Describe("API test", func() {
|
||||
uuid := response["uuid"].(string)
|
||||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
@@ -567,7 +567,7 @@ var _ = Describe("API test", func() {
|
||||
Expect(content["backend"]).To(Equal("llama"))
|
||||
})
|
||||
It("apply models without overrides", func() {
|
||||
response := postModelApplyRequest(testHTTPBase+"/models/apply", modelApplyRequest{
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
URL: bertEmbeddingsURL,
|
||||
Name: "bert",
|
||||
Overrides: map[string]any{},
|
||||
@@ -578,7 +578,7 @@ var _ = Describe("API test", func() {
|
||||
uuid := response["uuid"].(string)
|
||||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
@@ -622,14 +622,14 @@ parameters:
|
||||
}
|
||||
|
||||
var response schema.GalleryResponse
|
||||
err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response)
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(response.ID).ToNot(BeEmpty())
|
||||
|
||||
uuid := response.ID
|
||||
resp := map[string]any{}
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
resp = response
|
||||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
@@ -657,7 +657,7 @@ parameters:
|
||||
}
|
||||
|
||||
var response schema.GalleryResponse
|
||||
err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response)
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
|
||||
// The endpoint should return an error immediately
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("failed to discover model config"))
|
||||
@@ -693,14 +693,14 @@ parameters:
|
||||
}
|
||||
|
||||
var response schema.GalleryResponse
|
||||
err := postRequestResponseJSON(testHTTPBase+"/models/import-uri", &importReq, &response)
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(response.ID).ToNot(BeEmpty())
|
||||
|
||||
uuid := response.ID
|
||||
resp := map[string]any{}
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus(testHTTPBase + "/models/jobs/" + uuid)
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
resp = response
|
||||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
@@ -751,13 +751,13 @@ parameters:
|
||||
app, err = API(localAIApp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
xlog.Error("server error", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = testHTTPBase + "/v1"
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
|
||||
client2 = openaigo.NewClient("")
|
||||
client2.BaseURL = defaultConfig.BaseURL
|
||||
@@ -801,7 +801,7 @@ parameters:
|
||||
// Mock-backend is registered via SetExternalBackend so it appears
|
||||
// alongside any built-in entries; verifying that string proves the
|
||||
// endpoint is wired up regardless of which real backends exist.
|
||||
resp, err := http.Get(testHTTPBase + "/system")
|
||||
resp, err := http.Get("http://127.0.0.1:9090/system")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
dat, err := io.ReadAll(resp.Body)
|
||||
@@ -824,14 +824,14 @@ parameters:
|
||||
}
|
||||
|
||||
var createResp map[string]any
|
||||
err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp)
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(createResp["id"]).ToNot(BeEmpty())
|
||||
taskID := createResp["id"].(string)
|
||||
|
||||
// Get the task
|
||||
var task schema.Task
|
||||
resp, err := http.Get(testHTTPBase + "/api/agent/tasks/" + taskID)
|
||||
resp, err := http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -839,7 +839,7 @@ parameters:
|
||||
Expect(task.Name).To(Equal("Test Task"))
|
||||
|
||||
// List tasks
|
||||
resp, err = http.Get(testHTTPBase + "/api/agent/tasks")
|
||||
resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
var tasks []schema.Task
|
||||
@@ -849,18 +849,18 @@ parameters:
|
||||
|
||||
// Update task
|
||||
taskBody["name"] = "Updated Task"
|
||||
err = putRequestJSON(testHTTPBase+"/api/agent/tasks/"+taskID, &taskBody)
|
||||
err = putRequestJSON("http://127.0.0.1:9090/api/agent/tasks/"+taskID, &taskBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Verify update
|
||||
resp, err = http.Get(testHTTPBase + "/api/agent/tasks/" + taskID)
|
||||
resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
body, _ = io.ReadAll(resp.Body)
|
||||
json.Unmarshal(body, &task)
|
||||
Expect(task.Name).To(Equal("Updated Task"))
|
||||
|
||||
// Delete task
|
||||
req, _ := http.NewRequest("DELETE", testHTTPBase+"/api/agent/tasks/"+taskID, nil)
|
||||
req, _ := http.NewRequest("DELETE", "http://127.0.0.1:9090/api/agent/tasks/"+taskID, nil)
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -877,7 +877,7 @@ parameters:
|
||||
}
|
||||
|
||||
var createResp map[string]any
|
||||
err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp)
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
taskID := createResp["id"].(string)
|
||||
|
||||
@@ -888,14 +888,14 @@ parameters:
|
||||
}
|
||||
|
||||
var jobResp schema.JobExecutionResponse
|
||||
err = postRequestResponseJSON(testHTTPBase+"/api/agent/jobs/execute", &jobBody, &jobResp)
|
||||
err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/jobs/execute", &jobBody, &jobResp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(jobResp.JobID).ToNot(BeEmpty())
|
||||
jobID := jobResp.JobID
|
||||
|
||||
// Get job status
|
||||
var job schema.Job
|
||||
resp, err := http.Get(testHTTPBase + "/api/agent/jobs/" + jobID)
|
||||
resp, err := http.Get("http://127.0.0.1:9090/api/agent/jobs/" + jobID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -904,7 +904,7 @@ parameters:
|
||||
Expect(job.TaskID).To(Equal(taskID))
|
||||
|
||||
// List jobs
|
||||
resp, err = http.Get(testHTTPBase + "/api/agent/jobs")
|
||||
resp, err = http.Get("http://127.0.0.1:9090/api/agent/jobs")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
var jobs []schema.Job
|
||||
@@ -914,7 +914,7 @@ parameters:
|
||||
|
||||
// Cancel job (if still pending/running)
|
||||
if job.Status == schema.JobStatusPending || job.Status == schema.JobStatusRunning {
|
||||
req, _ := http.NewRequest("POST", testHTTPBase+"/api/agent/jobs/"+jobID+"/cancel", nil)
|
||||
req, _ := http.NewRequest("POST", "http://127.0.0.1:9090/api/agent/jobs/"+jobID+"/cancel", nil)
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -932,13 +932,13 @@ parameters:
|
||||
}
|
||||
|
||||
var createResp map[string]any
|
||||
err := postRequestResponseJSON(testHTTPBase+"/api/agent/tasks", &taskBody, &createResp)
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Execute by name
|
||||
paramsBody := map[string]string{"param1": "value1"}
|
||||
var jobResp schema.JobExecutionResponse
|
||||
err = postRequestResponseJSON(testHTTPBase+"/api/agent/tasks/Named Task/execute", ¶msBody, &jobResp)
|
||||
err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks/Named Task/execute", ¶msBody, &jobResp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(jobResp.JobID).ToNot(BeEmpty())
|
||||
})
|
||||
@@ -998,13 +998,13 @@ parameters:
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go func() {
|
||||
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
xlog.Error("server error", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = testHTTPBase + "/v1"
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
client2 = openaigo.NewClient("")
|
||||
client2.BaseURL = defaultConfig.BaseURL
|
||||
// Wait for API to be ready
|
||||
|
||||
@@ -25,10 +25,6 @@ var knownPrefOnlyBackends = []schema.KnownBackend{
|
||||
// Text LLM
|
||||
// ds4: antirez/ds4 - single-model DeepSeek V4 Flash engine; auto-detected via DS4Importer
|
||||
{Name: "ds4", Modality: "text", AutoDetect: false, Description: "antirez/ds4 DeepSeek V4 Flash engine (auto-detected; pref-only fallback)"},
|
||||
// dllm consumes GGUF weights like llama-cpp does, but only for the
|
||||
// DiffusionGemma architecture - auto-detecting on .gguf would shadow
|
||||
// llama-cpp, so it stays preference-only.
|
||||
{Name: "dllm", Modality: "text", AutoDetect: false, Description: "dllm.cpp DiffusionGemma block-diffusion engine (preference-only)"},
|
||||
{Name: "sglang", Modality: "text", AutoDetect: false, Description: "SGLang runtime (preference-only)"},
|
||||
{Name: "tinygrad", Modality: "text", AutoDetect: false, Description: "tinygrad runtime (preference-only)"},
|
||||
{Name: "trl", Modality: "text", AutoDetect: false, Description: "Transformers Reinforcement Learning (preference-only)"},
|
||||
|
||||
@@ -135,7 +135,6 @@ var _ = Describe("Backend Endpoints", func() {
|
||||
Expect(entry.Modality).To(Equal(modality))
|
||||
}
|
||||
|
||||
expectPrefOnly("dllm", "text")
|
||||
expectPrefOnly("sglang", "text")
|
||||
expectPrefOnly("tinygrad", "text")
|
||||
expectPrefOnly("trl", "text")
|
||||
|
||||
@@ -103,12 +103,7 @@ func applyAutoparserOverride(
|
||||
// blocks like "<think></think>" that some models emit when reasoning
|
||||
// is disabled.
|
||||
if deltaReasoning == "" && deltaContent != "" {
|
||||
// Complete-response extraction: only honor a prefilled <think> start
|
||||
// token when deltaContent actually closes the reasoning block. Without
|
||||
// it the model answered directly and the whole answer must stay in
|
||||
// content rather than be swallowed as unclosed reasoning. See
|
||||
// reason.ExtractReasoningComplete.
|
||||
deltaReasoning, deltaContent = reason.ExtractReasoningComplete(deltaContent, thinkingStartToken, reasoningConfig)
|
||||
deltaReasoning, deltaContent = reason.ExtractReasoningWithConfig(deltaContent, thinkingStartToken, reasoningConfig)
|
||||
}
|
||||
xlog.Debug("[ChatDeltas] non-SSE no-tools: overriding result with C++ autoparser deltas",
|
||||
"content_len", len(deltaContent), "reasoning_len", len(deltaReasoning))
|
||||
|
||||
@@ -186,114 +186,6 @@ var _ = Describe("applyAutoparserOverride", func() {
|
||||
Expect(result).To(Equal(existing))
|
||||
})
|
||||
})
|
||||
|
||||
// Regression tests for the prefilled-thinking-token path (thinkingStartToken
|
||||
// != ""). This is the configuration the gallery qwen3 family runs in: the
|
||||
// chat template injects <think> into the prompt, so DetectThinkingStartToken
|
||||
// returns "<think>" and the model's output begins *inside* a reasoning block
|
||||
// — it emits a closing </think> but no opening tag.
|
||||
//
|
||||
// The defensive Go-side fallback prepends the start token so the standard
|
||||
// extractor can pair it with the model's </think>. But on a *complete*
|
||||
// response that contains NO closing tag (the model answered directly with no
|
||||
// reasoning at all), prepending <think> manufactures an unclosed block that
|
||||
// swallows the entire answer into reasoning, leaving content empty. That is
|
||||
// the bug: short/direct answers (session names, JSON summaries) come back
|
||||
// with an empty content field.
|
||||
Context("autoparser delivered content with empty reasoning and a prefilled thinking token", func() {
|
||||
const startToken = "<think>"
|
||||
|
||||
It("keeps a tag-less direct answer as content instead of swallowing it as reasoning", func() {
|
||||
// Model answered directly: no <think>, no </think> anywhere.
|
||||
chatDeltas := []*pb.ChatDelta{
|
||||
{Content: "hello", ReasoningContent: ""},
|
||||
}
|
||||
|
||||
result := applyAutoparserOverride(chatDeltas, startToken, reason.Config{}, nil)
|
||||
|
||||
Expect(result).To(HaveLen(1))
|
||||
Expect(result[0].Message.Content).ToNot(BeNil())
|
||||
Expect(*(result[0].Message.Content.(*string))).To(Equal("hello"),
|
||||
"a complete answer with no closing reasoning tag must stay in content")
|
||||
Expect(result[0].Message.Reasoning).To(BeNil(),
|
||||
"no reasoning block was emitted, so Reasoning must not be set")
|
||||
})
|
||||
|
||||
It("keeps a tag-less JSON answer as content (the summary case)", func() {
|
||||
raw := `{"short":"Tests pass","long":"go test ./... succeeded."}`
|
||||
chatDeltas := []*pb.ChatDelta{
|
||||
{Content: raw, ReasoningContent: ""},
|
||||
}
|
||||
|
||||
result := applyAutoparserOverride(chatDeltas, startToken, reason.Config{}, nil)
|
||||
|
||||
Expect(result).To(HaveLen(1))
|
||||
Expect(*(result[0].Message.Content.(*string))).To(Equal(raw))
|
||||
Expect(result[0].Message.Reasoning).To(BeNil())
|
||||
})
|
||||
|
||||
It("still splits reasoning when the model emits the closing tag (prefill paired with </think>)", func() {
|
||||
// The legitimate prefill case: <think> was in the prompt, so the
|
||||
// output carries only the closing tag. The closing tag is the proof
|
||||
// that a reasoning block exists, so extraction must run.
|
||||
raw := "The user wants a greeting.\n</think>\n\nHello there!"
|
||||
chatDeltas := []*pb.ChatDelta{
|
||||
{Content: raw, ReasoningContent: ""},
|
||||
}
|
||||
|
||||
result := applyAutoparserOverride(chatDeltas, startToken, reason.Config{}, nil)
|
||||
|
||||
Expect(result).To(HaveLen(1))
|
||||
content := *(result[0].Message.Content.(*string))
|
||||
Expect(content).To(ContainSubstring("Hello there!"))
|
||||
Expect(content).ToNot(ContainSubstring("</think>"))
|
||||
Expect(content).ToNot(ContainSubstring("The user wants a greeting"))
|
||||
Expect(result[0].Message.Reasoning).ToNot(BeNil())
|
||||
Expect(*result[0].Message.Reasoning).To(ContainSubstring("The user wants a greeting"))
|
||||
})
|
||||
|
||||
It("still splits a fully-tagged <think>…</think> block with a prefill token set", func() {
|
||||
raw := "<think>Reasoning here.</think>Final answer."
|
||||
chatDeltas := []*pb.ChatDelta{
|
||||
{Content: raw, ReasoningContent: ""},
|
||||
}
|
||||
|
||||
result := applyAutoparserOverride(chatDeltas, startToken, reason.Config{}, nil)
|
||||
|
||||
Expect(result).To(HaveLen(1))
|
||||
Expect(*(result[0].Message.Content.(*string))).To(Equal("Final answer."))
|
||||
Expect(result[0].Message.Reasoning).ToNot(BeNil())
|
||||
Expect(*result[0].Message.Reasoning).To(ContainSubstring("Reasoning here"))
|
||||
})
|
||||
|
||||
// End-to-end regression for the real production failure: a request with
|
||||
// enable_thinking=false against a <think>-capable model (qwen3 family).
|
||||
//
|
||||
// In non-thinking mode the model emits no reasoning block, so llama.cpp's
|
||||
// autoparser correctly returns ChatDeltas with Content set and
|
||||
// ReasoningContent EMPTY (verified against stock llama-server: the same
|
||||
// model with chat_template_kwargs.enable_thinking=false returns
|
||||
// reasoning_content=null and content="hello"). But thinkingStartToken is
|
||||
// detected per-model from the enable_thinking=TRUE render
|
||||
// (grpc-server renders with enable_thinking=true; DetectThinkingStartToken
|
||||
// does not evaluate the jinja {% if enable_thinking %} conditional), so it
|
||||
// is "<think>" even for this non-thinking request. The old code prepended
|
||||
// it and swallowed the answer. This is the case that broke session
|
||||
// summaries and auto-titles and was NOT covered before.
|
||||
It("preserves content for a non-thinking-mode request (enable_thinking=false, empty reasoning_content)", func() {
|
||||
// What llama.cpp's autoparser actually returns in non-thinking mode.
|
||||
chatDeltas := []*pb.ChatDelta{
|
||||
{Content: `{"short":"Go tests passed for internal/session"}`, ReasoningContent: ""},
|
||||
}
|
||||
|
||||
result := applyAutoparserOverride(chatDeltas, startToken, reason.Config{}, nil)
|
||||
|
||||
Expect(result).To(HaveLen(1))
|
||||
Expect(*(result[0].Message.Content.(*string))).To(Equal(`{"short":"Go tests passed for internal/session"}`),
|
||||
"non-thinking-mode answers must reach the client intact, not be swallowed as reasoning")
|
||||
Expect(result[0].Message.Reasoning).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("mergeToolCallDeltas", func() {
|
||||
|
||||
@@ -2,10 +2,8 @@ package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
@@ -237,12 +235,6 @@ type Model interface {
|
||||
Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error)
|
||||
Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error)
|
||||
TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error)
|
||||
// TTSStream synthesizes speech incrementally, invoking onAudio with raw PCM
|
||||
// chunks (and the backend sample rate) as they are produced.
|
||||
TTSStream(ctx context.Context, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error
|
||||
// TranscribeStream transcribes audio incrementally, invoking onDelta for each
|
||||
// transcript text fragment and returning the final aggregated result.
|
||||
TranscribeStream(ctx context.Context, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error)
|
||||
PredictConfig() *config.ModelConfig
|
||||
}
|
||||
|
||||
@@ -1262,15 +1254,27 @@ func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Co
|
||||
// TODO: If we have a real any-to-any model then transcription is optional
|
||||
var transcript string
|
||||
if session.InputAudioTranscription != nil {
|
||||
// emitTranscription streams transcript deltas when
|
||||
// pipeline.streaming.transcription is set, otherwise emits a single
|
||||
// completed event; either way it returns the final transcript text.
|
||||
var err error
|
||||
transcript, err = emitTranscription(ctx, t, session, generateItemID(), f.Name())
|
||||
tr, err := session.ModelInterface.Transcribe(ctx, f.Name(), session.InputAudioTranscription.Language, false, false, session.InputAudioTranscription.Prompt)
|
||||
if err != nil {
|
||||
sendError(t, "transcription_failed", err.Error(), "", "event_TODO")
|
||||
return
|
||||
} else if tr == nil {
|
||||
sendError(t, "transcription_failed", "trancribe result is nil", "", "event_TODO")
|
||||
return
|
||||
}
|
||||
|
||||
transcript = tr.Text
|
||||
sendEvent(t, types.ConversationItemInputAudioTranscriptionCompletedEvent{
|
||||
ServerEventBase: types.ServerEventBase{
|
||||
EventID: "event_TODO",
|
||||
},
|
||||
|
||||
ItemID: generateItemID(),
|
||||
// ResponseID: "resp_TODO", // Not needed for transcription completed event
|
||||
// OutputIndex: 0,
|
||||
ContentIndex: 0,
|
||||
Transcript: transcript,
|
||||
})
|
||||
} else {
|
||||
sendNotImplemented(t, "any-to-any models")
|
||||
return
|
||||
@@ -1498,26 +1502,6 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
|
||||
},
|
||||
})
|
||||
|
||||
// Streamed LLM path: when the pipeline opts into LLM streaming, stream the
|
||||
// transcript to the client as it is generated and synthesize the buffered
|
||||
// message once. Tool turns are supported only when the model uses its
|
||||
// tokenizer template: the C++ autoparser then delivers content and tool
|
||||
// calls via ChatDeltas (clearing the text stream), so the spoken transcript
|
||||
// never leaks tool-call tokens. Grammar-based function calling emits the
|
||||
// call as JSON in the token stream, so those turns keep the buffered path.
|
||||
if config != nil && session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamLLM() {
|
||||
canStream := len(tools) == 0 || config.TemplateConfig.UseTokenizerTemplate
|
||||
var respMods []types.Modality
|
||||
if overrides != nil {
|
||||
respMods = overrides.OutputModalities
|
||||
}
|
||||
if canStream && modalitiesContainAudio(resolveOutputModalities(session.OutputModalities, respMods)) {
|
||||
if streamLLMResponse(ctx, session, conv, t, responseID, conversationHistory, images, config, tools, toolChoice, toolTurn) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
predFunc, err := session.ModelInterface.Predict(ctx, conversationHistory, images, nil, nil, nil, tools, toolChoice, nil, nil, nil)
|
||||
if err != nil {
|
||||
sendError(t, "inference_failed", fmt.Sprintf("backend error: %v", err), "", "") // item.Assistant.ID is unknown here
|
||||
@@ -1595,7 +1579,7 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
|
||||
// ExtractReasoningWithConfig is a no-op when no tag pair matches,
|
||||
// so it's safe to apply unconditionally in the no-reasoning branch.
|
||||
if deltaReasoning == "" && deltaContent != "" {
|
||||
deltaReasoning, deltaContent = reasoning.ExtractReasoningComplete(deltaContent, thinkingStartToken, spokenReasoningConfig(config.ReasoningConfig))
|
||||
deltaReasoning, deltaContent = reasoning.ExtractReasoningWithConfig(deltaContent, thinkingStartToken, config.ReasoningConfig)
|
||||
}
|
||||
reasoningText = deltaReasoning
|
||||
responseWithoutReasoning = deltaContent
|
||||
@@ -1603,7 +1587,7 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
|
||||
cleanedResponse = deltaContent
|
||||
toolCalls = deltaToolCalls
|
||||
} else {
|
||||
reasoningText, responseWithoutReasoning = reasoning.ExtractReasoningComplete(rawResponse, thinkingStartToken, spokenReasoningConfig(config.ReasoningConfig))
|
||||
reasoningText, responseWithoutReasoning = reasoning.ExtractReasoningWithConfig(rawResponse, thinkingStartToken, config.ReasoningConfig)
|
||||
textContent = functions.ParseTextContent(responseWithoutReasoning, config.FunctionsConfig)
|
||||
cleanedResponse = functions.CleanupLLMResult(responseWithoutReasoning, config.FunctionsConfig)
|
||||
toolCalls = functions.ParseFunctionCall(cleanedResponse, config.FunctionsConfig)
|
||||
@@ -1729,7 +1713,64 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
|
||||
return
|
||||
}
|
||||
|
||||
// Transcript of the spoken reply (the audio's text).
|
||||
audioFilePath, res, err := session.ModelInterface.TTS(ctx, finalSpeech, session.Voice, session.InputAudioTranscription.Language)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
xlog.Debug("TTS cancelled (barge-in)")
|
||||
sendCancelledResponse()
|
||||
return
|
||||
}
|
||||
xlog.Error("TTS failed", "error", err)
|
||||
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID)
|
||||
return
|
||||
}
|
||||
if !res.Success {
|
||||
xlog.Error("TTS failed", "message", res.Message)
|
||||
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID)
|
||||
return
|
||||
}
|
||||
defer func() { _ = os.Remove(audioFilePath) }()
|
||||
|
||||
audioBytes, err := os.ReadFile(audioFilePath)
|
||||
if err != nil {
|
||||
xlog.Error("failed to read TTS file", "error", err)
|
||||
sendError(t, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse WAV header to get raw PCM and the actual sample rate from the TTS backend.
|
||||
pcmData, ttsSampleRate := laudio.ParseWAV(audioBytes)
|
||||
if ttsSampleRate == 0 {
|
||||
ttsSampleRate = localSampleRate
|
||||
}
|
||||
xlog.Debug("TTS audio parsed", "raw_bytes", len(audioBytes), "pcm_bytes", len(pcmData), "sample_rate", ttsSampleRate)
|
||||
|
||||
// SendAudio (WebRTC) passes PCM at the TTS sample rate directly to the
|
||||
// Opus encoder, which resamples to 48kHz internally. This avoids a
|
||||
// lossy intermediate resample through 16kHz.
|
||||
// XXX: This is a noop in websocket mode; it's included in the JSON instead
|
||||
if err := t.SendAudio(ctx, pcmData, ttsSampleRate); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
xlog.Debug("Audio playback cancelled (barge-in)")
|
||||
sendCancelledResponse()
|
||||
return
|
||||
}
|
||||
xlog.Error("failed to send audio via transport", "error", err)
|
||||
}
|
||||
|
||||
// For WebSocket clients, resample to the session's output rate and
|
||||
// deliver audio as base64 in JSON events. WebRTC clients already
|
||||
// received audio over the RTP track, so skip the base64 payload.
|
||||
if !isWebRTC {
|
||||
wsPCM := pcmData
|
||||
if ttsSampleRate != session.OutputSampleRate {
|
||||
samples := sound.BytesToInt16sLE(pcmData)
|
||||
resampled := sound.ResampleInt16(samples, ttsSampleRate, session.OutputSampleRate)
|
||||
wsPCM = sound.Int16toBytesLE(resampled)
|
||||
}
|
||||
audioString = base64.StdEncoding.EncodeToString(wsPCM)
|
||||
}
|
||||
|
||||
sendEvent(t, types.ResponseOutputAudioTranscriptDeltaEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
@@ -1747,26 +1788,15 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
|
||||
Transcript: finalSpeech,
|
||||
})
|
||||
|
||||
// Synthesize and send the audio. With pipeline.streaming.tts enabled
|
||||
// emitSpeech forwards a response.output_audio.delta per backend PCM
|
||||
// chunk as it's produced; otherwise it sends the whole utterance as a
|
||||
// single delta. The returned PCM is stored (base64) on the item below.
|
||||
pcmAudio, err := emitSpeech(ctx, t, session, responseID, item.Assistant.ID, finalSpeech)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
xlog.Debug("TTS cancelled (barge-in)")
|
||||
sendCancelledResponse()
|
||||
return
|
||||
}
|
||||
xlog.Error("TTS failed", "error", err)
|
||||
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID)
|
||||
return
|
||||
}
|
||||
if !isWebRTC {
|
||||
audioString = base64.StdEncoding.EncodeToString(pcmAudio)
|
||||
}
|
||||
|
||||
if !isWebRTC {
|
||||
sendEvent(t, types.ResponseOutputAudioDeltaEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
ItemID: item.Assistant.ID,
|
||||
OutputIndex: 0,
|
||||
ContentIndex: 0,
|
||||
Delta: audioString,
|
||||
})
|
||||
sendEvent(t, types.ResponseOutputAudioDoneEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
@@ -1819,27 +1849,17 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
|
||||
})
|
||||
}
|
||||
|
||||
// Emit the parsed tool calls, the terminal response.done, and (for
|
||||
// server-side assistant tools) the follow-up response. Shared with the
|
||||
// streamed path so both finalize tool calls identically.
|
||||
emitToolCallItems(ctx, session, conv, t, responseID, finalToolCalls, finalSpeech != "", toolTurn)
|
||||
}
|
||||
|
||||
// emitToolCallItems emits the realtime function_call items for the parsed tool
|
||||
// calls, the terminal response.done, and — for server-side LocalAI Assistant
|
||||
// tools — re-triggers a follow-up response so the model can speak the result.
|
||||
// hasContent shifts the tool-call output index past the assistant content item
|
||||
// when the same turn also produced spoken/text content. Two tool paths:
|
||||
// - LocalAI Assistant tools (session.AssistantExecutor.IsTool) run server-side;
|
||||
// we append both the call and its output to conv.Items and re-trigger. The
|
||||
// client only sees observability events.
|
||||
// - All other tools follow the standard OpenAI flow: emit
|
||||
// function_call_arguments.done and wait for the client to send
|
||||
// conversation.item.create back.
|
||||
func emitToolCallItems(ctx context.Context, session *Session, conv *Conversation, t Transport, responseID string, toolCalls []functions.FuncCallResults, hasContent bool, toolTurn int) {
|
||||
xlog.Debug("About to handle tool calls", "finalToolCallsCount", len(toolCalls))
|
||||
// Handle Tool Calls. Two paths:
|
||||
// - LocalAI Assistant tools (session.AssistantExecutor.IsTool) run
|
||||
// server-side; we append both the call and its output to conv.Items
|
||||
// and re-trigger a follow-up response so the model can speak the
|
||||
// result. The client only sees observability events.
|
||||
// - All other tools follow the standard OpenAI flow: emit
|
||||
// function_call_arguments.done and wait for the client to send
|
||||
// conversation.item.create back.
|
||||
xlog.Debug("About to handle tool calls", "finalToolCallsCount", len(finalToolCalls))
|
||||
executedAssistantTool := false
|
||||
for i, tc := range toolCalls {
|
||||
for i, tc := range finalToolCalls {
|
||||
toolCallID := generateItemID()
|
||||
callID := "call_" + generateUniqueID() // OpenAI uses call_xyz
|
||||
|
||||
@@ -1859,7 +1879,7 @@ func emitToolCallItems(ctx context.Context, session *Session, conv *Conversation
|
||||
conv.Lock.Unlock()
|
||||
|
||||
outputIndex := i
|
||||
if hasContent {
|
||||
if finalSpeech != "" {
|
||||
outputIndex++
|
||||
}
|
||||
|
||||
@@ -1985,11 +2005,8 @@ func generateItemID() string {
|
||||
}
|
||||
|
||||
func generateUniqueID() string {
|
||||
// 16 random bytes, hex-encoded. Must be collision-free: session, item,
|
||||
// response and call IDs build on this, and the conversation tracks/removes
|
||||
// items by ID (e.g. cancel() in realtime_stream.go, conversation.item.retrieve).
|
||||
// A constant would make every ID alias and corrupt that bookkeeping.
|
||||
var b [16]byte
|
||||
_, _ = rand.Read(b[:])
|
||||
return hex.EncodeToString(b[:])
|
||||
// Generate a unique ID string
|
||||
// For simplicity, use a counter or UUID
|
||||
// Implement as needed
|
||||
return "unique_id"
|
||||
}
|
||||
|
||||
@@ -1,200 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/rivo/uniseg"
|
||||
)
|
||||
|
||||
// Default clause-chunker bounds (in runes). minRunes gates only sub-sentence
|
||||
// (clause-mark / Thai-space) cuts so we don't synthesize tiny choppy fragments;
|
||||
// full sentences always flush regardless of length. maxRunes caps an
|
||||
// unterminated run so a long punctuation-less span doesn't buffer unbounded.
|
||||
const (
|
||||
defaultClauseMinRunes = 12
|
||||
defaultClauseMaxRunes = 200
|
||||
)
|
||||
|
||||
// clauseChunker splits streamed LLM content into speakable clauses for
|
||||
// incremental TTS, in a SCRIPT-AWARE way so it works for languages without
|
||||
// whitespace word boundaries. It leans on UAX #29 sentence segmentation (which
|
||||
// natively terminates on CJK 。!? as well as Latin .!?), adds CJK clause
|
||||
// punctuation (,、;:) and Thai/Lao spaces as finer boundaries, and caps an
|
||||
// over-long unterminated run via UAX #14 line-break opportunities.
|
||||
//
|
||||
// Unlike the old ASCII .!?/newline segmenter (dropped in 076dcdbe), it does not
|
||||
// degrade to whole-message buffering for CJK (handled natively) or Thai/Lao
|
||||
// (handled via spaces, which Thai uses at clause/sentence boundaries). Scripts
|
||||
// that genuinely need a dictionary (Khmer/Myanmar) simply stay buffered until a
|
||||
// space or end-of-message — no worse than the buffered default.
|
||||
//
|
||||
// It is not safe for concurrent use; callers feed it from a single goroutine
|
||||
// (the LLM token callback).
|
||||
type clauseChunker struct {
|
||||
buf strings.Builder
|
||||
minRunes int
|
||||
maxRunes int
|
||||
}
|
||||
|
||||
func newClauseChunker(minRunes, maxRunes int) *clauseChunker {
|
||||
return &clauseChunker{minRunes: minRunes, maxRunes: maxRunes}
|
||||
}
|
||||
|
||||
// push appends streamed content and returns any clauses that are now complete —
|
||||
// "complete" meaning confirmed by following content, so we never speak a clause
|
||||
// that the next token might extend. Incomplete trailing text stays buffered.
|
||||
func (c *clauseChunker) push(text string) []string {
|
||||
c.buf.WriteString(text)
|
||||
return c.drain(false)
|
||||
}
|
||||
|
||||
// flush returns the remaining buffered clauses, treating end-of-input as a hard
|
||||
// boundary, and clears the buffer.
|
||||
func (c *clauseChunker) flush() []string {
|
||||
return c.drain(true)
|
||||
}
|
||||
|
||||
func (c *clauseChunker) drain(final bool) []string {
|
||||
s := c.buf.String()
|
||||
rest := s
|
||||
var out []string
|
||||
for rest != "" {
|
||||
end, ok := c.nextBoundary(rest, final)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if seg := strings.TrimSpace(rest[:end]); seg != "" {
|
||||
out = append(out, seg)
|
||||
}
|
||||
rest = rest[end:]
|
||||
}
|
||||
// Rewriting the builder reallocates and copies the whole buffer; skip it on
|
||||
// the common per-token call where no boundary was confirmed.
|
||||
if len(rest) != len(s) {
|
||||
c.buf.Reset()
|
||||
c.buf.WriteString(rest)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// nextBoundary returns the byte offset just past the first emittable clause in
|
||||
// s, or ok=false when more input is needed (final=false) and no boundary is
|
||||
// confirmed yet.
|
||||
func (c *clauseChunker) nextBoundary(s string, final bool) (int, bool) {
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// 1) UAX #29 sentence boundary. When the first sentence is followed by more
|
||||
// text it is a confirmed complete sentence (handles Latin .!? with
|
||||
// abbreviation/decimal guards, and CJK 。!? with no whitespace).
|
||||
sentence, rest, _ := uniseg.FirstSentenceInString(s, -1)
|
||||
if rest != "" {
|
||||
// Optionally cut finer inside the sentence at a clause boundary.
|
||||
if cut, ok := c.firstClauseCut(sentence); ok {
|
||||
return cut, true
|
||||
}
|
||||
return len(sentence), true
|
||||
}
|
||||
|
||||
// 2) Unterminated tail: look for a sub-sentence clause boundary (CJK
|
||||
// punctuation or a Thai/Lao space) confirmed by following content.
|
||||
if cut, ok := c.firstClauseCut(s); ok {
|
||||
return cut, true
|
||||
}
|
||||
|
||||
// 3) Over-long punctuation-less run: force a typographically legal break so
|
||||
// we don't buffer unbounded (e.g. a long CJK run with no punctuation).
|
||||
if !final && c.maxRunes > 0 && utf8.RuneCountInString(s) > c.maxRunes {
|
||||
if cut, ok := lineBreakCut(s, c.maxRunes); ok {
|
||||
return cut, true
|
||||
}
|
||||
}
|
||||
|
||||
// 4) End of input: emit whatever remains as the final clause.
|
||||
if final {
|
||||
return len(s), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// firstClauseCut returns the byte offset just past the first sub-sentence clause
|
||||
// boundary in s — a CJK clause punctuation mark, or a space following a Thai/Lao
|
||||
// letter — provided the prefix is at least minRunes long and non-space content
|
||||
// follows. The boundary mark (and any trailing spaces) stay with the left clause.
|
||||
func (c *clauseChunker) firstClauseCut(s string) (int, bool) {
|
||||
var prev rune
|
||||
runes := 0
|
||||
for i, r := range s {
|
||||
boundary := isCJKClausePunct(r) || (unicode.IsSpace(r) && isThaiLao(prev))
|
||||
if boundary && runes+1 >= c.minRunes {
|
||||
end := i + utf8.RuneLen(r)
|
||||
for end < len(s) {
|
||||
nr, sz := utf8.DecodeRuneInString(s[end:])
|
||||
if !unicode.IsSpace(nr) {
|
||||
break
|
||||
}
|
||||
end += sz
|
||||
}
|
||||
if end < len(s) { // confirmed: real content follows the boundary
|
||||
return end, true
|
||||
}
|
||||
// Boundary sits at the end of the buffer with nothing after it yet —
|
||||
// wait for the next token to confirm it rather than emit early.
|
||||
return 0, false
|
||||
}
|
||||
prev = r
|
||||
runes++
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// lineBreakCut walks UAX #14 line segments and returns the byte offset of the
|
||||
// last legal break opportunity at or before maxRunes. Returns ok=false when the
|
||||
// run has no internal break opportunity (e.g. a space-less Thai run), leaving it
|
||||
// buffered.
|
||||
func lineBreakCut(s string, maxRunes int) (int, bool) {
|
||||
state := -1
|
||||
rest := s
|
||||
consumed := 0
|
||||
runes := 0
|
||||
for rest != "" {
|
||||
seg, rem, _, st := uniseg.FirstLineSegmentInString(rest, state)
|
||||
state = st
|
||||
runes += utf8.RuneCountInString(seg)
|
||||
consumed += len(seg)
|
||||
rest = rem
|
||||
if runes >= maxRunes {
|
||||
if consumed < len(s) {
|
||||
return consumed, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// isCJKClausePunct reports whether r is a CJK clause-level separator worth a
|
||||
// soft TTS break. Sentence terminators (。!?) are intentionally excluded — UAX
|
||||
// #29 sentence segmentation already handles those.
|
||||
func isCJKClausePunct(r rune) bool {
|
||||
switch r {
|
||||
case ',', // , fullwidth comma
|
||||
'、', // 、 ideographic comma
|
||||
';', // ; fullwidth semicolon
|
||||
':', // : fullwidth colon
|
||||
'・', // ・ katakana middle dot
|
||||
'・': // ・ halfwidth katakana middle dot
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isThaiLao reports whether r is a Thai or Lao letter. Those scripts have no
|
||||
// inter-word spaces; an ASCII space inside such a run marks a clause/sentence
|
||||
// boundary, which is the only no-dictionary segmentation signal available.
|
||||
func isThaiLao(r rune) bool {
|
||||
return unicode.Is(unicode.Thai, r) || unicode.Is(unicode.Lao, r)
|
||||
}
|
||||
@@ -1,103 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// clauseChunker splits streamed LLM content into speakable clauses in a
|
||||
// script-aware way: UAX#29 sentences (Latin .!? and CJK 。!?), CJK clause
|
||||
// punctuation, and Thai/Lao spaces — never whitespace-splitting CJK.
|
||||
var _ = Describe("clauseChunker", func() {
|
||||
Context("Latin sentences", func() {
|
||||
It("emits a sentence only once following content confirms it is complete", func() {
|
||||
c := newClauseChunker(12, 200)
|
||||
Expect(c.push("Hello world. How are you?")).To(Equal([]string{"Hello world."}))
|
||||
// The trailing sentence is held until flush (the next token might extend it).
|
||||
Expect(c.flush()).To(Equal([]string{"How are you?"}))
|
||||
})
|
||||
|
||||
It("assembles a sentence across many small tokens", func() {
|
||||
c := newClauseChunker(12, 200)
|
||||
var got []string
|
||||
for _, tok := range []string{"Hello", " world.", " How", " are", " you?"} {
|
||||
got = append(got, c.push(tok)...)
|
||||
}
|
||||
got = append(got, c.flush()...)
|
||||
Expect(got).To(Equal([]string{"Hello world.", "How are you?"}))
|
||||
})
|
||||
|
||||
It("does not split decimals or abbreviations (UAX#29 SB6)", func() {
|
||||
c := newClauseChunker(12, 200)
|
||||
got := c.push("Pi is 3.14 and e is 2.72. Done")
|
||||
Expect(got).To(Equal([]string{"Pi is 3.14 and e is 2.72."}))
|
||||
Expect(c.flush()).To(Equal([]string{"Done"}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("CJK (no whitespace)", func() {
|
||||
It("splits Chinese on the ideographic full stop", func() {
|
||||
c := newClauseChunker(12, 200)
|
||||
Expect(c.push("你好世界。今天天气很好。")).To(Equal([]string{"你好世界。"}))
|
||||
Expect(c.flush()).To(Equal([]string{"今天天气很好。"}))
|
||||
})
|
||||
|
||||
It("splits Japanese on the ideographic full stop", func() {
|
||||
c := newClauseChunker(12, 200)
|
||||
Expect(c.push("こんにちは。元気ですか。")).To(Equal([]string{"こんにちは。"}))
|
||||
Expect(c.flush()).To(Equal([]string{"元気ですか。"}))
|
||||
})
|
||||
|
||||
It("splits on CJK clause punctuation for lower latency", func() {
|
||||
c := newClauseChunker(2, 200) // small min so short test clauses cut
|
||||
Expect(c.push("你好,世界。再见")).To(Equal([]string{"你好,", "世界。"}))
|
||||
Expect(c.flush()).To(Equal([]string{"再见"}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Thai (spaces mark clauses, not words)", func() {
|
||||
It("splits a Thai run on the inter-clause space", func() {
|
||||
c := newClauseChunker(2, 200)
|
||||
Expect(c.push("สวัสดีครับ กินข้าวไหม")).To(Equal([]string{"สวัสดีครับ"}))
|
||||
Expect(c.flush()).To(Equal([]string{"กินข้าวไหม"}))
|
||||
})
|
||||
|
||||
It("never shatters a space-less Thai run into characters", func() {
|
||||
c := newClauseChunker(2, 200)
|
||||
Expect(c.push("สวัสดีครับ")).To(BeEmpty()) // held, no boundary
|
||||
Expect(c.flush()).To(Equal([]string{"สวัสดีครับ"}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("length cap (UAX#14 fallback)", func() {
|
||||
It("force-breaks an over-long punctuation-less CJK run at legal points", func() {
|
||||
c := newClauseChunker(4, 10) // maxRunes = 10
|
||||
run := strings.Repeat("字", 25)
|
||||
got := c.push(run)
|
||||
got = append(got, c.flush()...)
|
||||
total := 0
|
||||
for _, seg := range got {
|
||||
n := utf8.RuneCountInString(seg)
|
||||
Expect(n).To(BeNumerically("<=", 10)) // never exceeds the cap
|
||||
total += n
|
||||
}
|
||||
Expect(total).To(Equal(25)) // nothing dropped
|
||||
Expect(len(got)).To(BeNumerically(">=", 3)) // 10 + 10 + 5
|
||||
})
|
||||
})
|
||||
|
||||
Context("buffer lifecycle", func() {
|
||||
It("flush clears the buffer so the chunker is reusable", func() {
|
||||
c := newClauseChunker(12, 200)
|
||||
// "First one." is confirmed by the following "Second", so push drains it;
|
||||
// only the unterminated tail remains for flush.
|
||||
Expect(c.push("First one. Second")).To(Equal([]string{"First one."}))
|
||||
Expect(c.flush()).To(Equal([]string{"Second"}))
|
||||
Expect(c.flush()).To(BeEmpty())
|
||||
Expect(c.push("Again. More")).To(Equal([]string{"Again."}))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,138 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// fakeTransport records the server events and audio sent to a realtime client
|
||||
// so streaming behaviour can be asserted without a real WebSocket/WebRTC peer.
|
||||
// It is not a *WebRTCTransport, so handler code takes the WebSocket path.
|
||||
type fakeTransport struct {
|
||||
events []types.ServerEvent
|
||||
audio []fakeAudioChunk
|
||||
}
|
||||
|
||||
type fakeAudioChunk struct {
|
||||
pcm []byte
|
||||
sampleRate int
|
||||
}
|
||||
|
||||
func (f *fakeTransport) SendEvent(e types.ServerEvent) error {
|
||||
f.events = append(f.events, e)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeTransport) ReadEvent() ([]byte, error) { return nil, nil }
|
||||
|
||||
func (f *fakeTransport) SendAudio(_ context.Context, pcm []byte, sampleRate int) error {
|
||||
f.audio = append(f.audio, fakeAudioChunk{pcm: pcm, sampleRate: sampleRate})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeTransport) Close() error { return nil }
|
||||
|
||||
// countEvents returns how many recorded events have the given type.
|
||||
func (f *fakeTransport) countEvents(et types.ServerEventType) int {
|
||||
n := 0
|
||||
for _, e := range f.events {
|
||||
if e.ServerEventType() == et {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// transcriptDeltaText concatenates the Delta of every recorded transcript
|
||||
// delta event — i.e. the text streamed to the client as it is generated.
|
||||
func (f *fakeTransport) transcriptDeltaText() string {
|
||||
var b strings.Builder
|
||||
for _, e := range f.events {
|
||||
if d, ok := e.(types.ResponseOutputAudioTranscriptDeltaEvent); ok {
|
||||
b.WriteString(d.Delta)
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// fakeModel is a configurable Model double. TTSStream replays ttsStreamChunks
|
||||
// and TranscribeStream replays transcribeDeltas, so the handler's streaming
|
||||
// paths can be driven deterministically.
|
||||
type fakeModel struct {
|
||||
cfg *config.ModelConfig
|
||||
|
||||
ttsFile string
|
||||
ttsStreamChunks [][]byte
|
||||
ttsStreamRate int
|
||||
ttsStreamErr error
|
||||
|
||||
transcribeDeltas []string
|
||||
transcribeFinal *schema.TranscriptionResult
|
||||
|
||||
// Predict streaming: predictTokens are replayed through the token callback
|
||||
// (simulating streamed LLM output); predictResp/predictErr are returned by
|
||||
// the deferred predict function. predictChunkDeltas, when set, are delivered
|
||||
// per-token via TokenUsage.ChatDeltas to exercise the autoparser path.
|
||||
predictTokens []string
|
||||
predictChunkDeltas [][]*proto.ChatDelta
|
||||
predictResp backend.LLMResponse
|
||||
predictErr error
|
||||
}
|
||||
|
||||
func (m *fakeModel) VAD(context.Context, *schema.VADRequest) (*schema.VADResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *fakeModel) Transcribe(context.Context, string, string, bool, bool, string) (*schema.TranscriptionResult, error) {
|
||||
return m.transcribeFinal, nil
|
||||
}
|
||||
|
||||
func (m *fakeModel) Predict(_ context.Context, _ schema.Messages, _, _, _ []string, cb func(string, backend.TokenUsage) bool, _ []types.ToolUnion, _ *types.ToolChoiceUnion, _, _ *int, _ map[string]float64) (func() (backend.LLMResponse, error), error) {
|
||||
if m.predictErr != nil {
|
||||
return nil, m.predictErr
|
||||
}
|
||||
return func() (backend.LLMResponse, error) {
|
||||
for i, tok := range m.predictTokens {
|
||||
if cb == nil {
|
||||
continue
|
||||
}
|
||||
usage := backend.TokenUsage{}
|
||||
if i < len(m.predictChunkDeltas) {
|
||||
usage.ChatDeltas = m.predictChunkDeltas[i]
|
||||
}
|
||||
cb(tok, usage)
|
||||
}
|
||||
return m.predictResp, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *fakeModel) TTS(context.Context, string, string, string) (string, *proto.Result, error) {
|
||||
return m.ttsFile, &proto.Result{Success: true}, nil
|
||||
}
|
||||
|
||||
func (m *fakeModel) TTSStream(_ context.Context, _, _, _ string, onAudio func(pcm []byte, sampleRate int) error) error {
|
||||
if m.ttsStreamErr != nil {
|
||||
return m.ttsStreamErr
|
||||
}
|
||||
for _, c := range m.ttsStreamChunks {
|
||||
if err := onAudio(c, m.ttsStreamRate); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *fakeModel) TranscribeStream(_ context.Context, _, _ string, _, _ bool, _ string, onDelta func(text string)) (*schema.TranscriptionResult, error) {
|
||||
for _, d := range m.transcribeDeltas {
|
||||
onDelta(d)
|
||||
}
|
||||
return m.transcribeFinal, nil
|
||||
}
|
||||
|
||||
func (m *fakeModel) PredictConfig() *config.ModelConfig { return m.cfg }
|
||||
@@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -88,14 +87,6 @@ func (m *transcriptOnlyModel) TTS(ctx context.Context, text, voice, language str
|
||||
return "", nil, fmt.Errorf("TTS not supported in transcript-only mode")
|
||||
}
|
||||
|
||||
func (m *transcriptOnlyModel) TTSStream(ctx context.Context, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error {
|
||||
return fmt.Errorf("TTS not supported in transcript-only mode")
|
||||
}
|
||||
|
||||
func (m *transcriptOnlyModel) TranscribeStream(ctx context.Context, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error) {
|
||||
return transcribeStream(ctx, m.modelLoader, *m.TranscriptionConfig, m.appConfig, audio, language, translate, diarize, prompt, onDelta)
|
||||
}
|
||||
|
||||
func (m *transcriptOnlyModel) PredictConfig() *config.ModelConfig {
|
||||
return nil
|
||||
}
|
||||
@@ -330,75 +321,10 @@ func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (s
|
||||
return backend.ModelTTS(ctx, text, voice, language, "", nil, m.modelLoader, m.appConfig, *m.TTSConfig)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) TTSStream(ctx context.Context, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error {
|
||||
return ttsStream(ctx, m.modelLoader, m.appConfig, *m.TTSConfig, text, voice, language, onAudio)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) TranscribeStream(ctx context.Context, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error) {
|
||||
return transcribeStream(ctx, m.modelLoader, *m.TranscriptionConfig, m.appConfig, audio, language, translate, diarize, prompt, onDelta)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) PredictConfig() *config.ModelConfig {
|
||||
return m.LLMConfig
|
||||
}
|
||||
|
||||
// wavStreamHeaderBytes is the size of the WAV header that backend.ModelTTSStream
|
||||
// emits as its first audio callback; the sample rate lives at byte offset 24.
|
||||
const wavStreamHeaderBytes = 44
|
||||
|
||||
// ttsStream adapts backend.ModelTTSStream (which emits a WAV stream: a 44-byte
|
||||
// header carrying the sample rate, then raw PCM) to the realtime onAudio
|
||||
// callback, which wants raw PCM plus the sample rate. The header is buffered
|
||||
// until complete, the sample rate is read from it, and subsequent bytes are
|
||||
// forwarded as PCM.
|
||||
func ttsStream(ctx context.Context, ml *model.ModelLoader, appConfig *config.ApplicationConfig, ttsConfig config.ModelConfig, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error {
|
||||
var header []byte
|
||||
headerDone := false
|
||||
sampleRate := 0
|
||||
return backend.ModelTTSStream(ctx, text, voice, language, "", nil, ml, appConfig, ttsConfig, func(b []byte) error {
|
||||
if headerDone {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
return onAudio(b, sampleRate)
|
||||
}
|
||||
header = append(header, b...)
|
||||
if len(header) < wavStreamHeaderBytes {
|
||||
return nil
|
||||
}
|
||||
sampleRate = int(binary.LittleEndian.Uint32(header[24:28]))
|
||||
headerDone = true
|
||||
if len(header) > wavStreamHeaderBytes {
|
||||
return onAudio(header[wavStreamHeaderBytes:], sampleRate)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// transcribeStream adapts backend.ModelTranscriptionStream to the realtime
|
||||
// onDelta callback, returning the final aggregated transcription result.
|
||||
func transcribeStream(ctx context.Context, ml *model.ModelLoader, transcriptionConfig config.ModelConfig, appConfig *config.ApplicationConfig, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error) {
|
||||
var final *schema.TranscriptionResult
|
||||
err := backend.ModelTranscriptionStream(ctx, backend.TranscriptionRequest{
|
||||
Audio: audio,
|
||||
Language: language,
|
||||
Translate: translate,
|
||||
Diarize: diarize,
|
||||
Prompt: prompt,
|
||||
}, ml, transcriptionConfig, appConfig, func(chunk backend.TranscriptionStreamChunk) {
|
||||
if chunk.Delta != "" {
|
||||
onDelta(chunk.Delta)
|
||||
}
|
||||
if chunk.Final != nil {
|
||||
final = chunk.Final
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return final, nil
|
||||
}
|
||||
|
||||
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) {
|
||||
cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
if err != nil {
|
||||
@@ -528,10 +454,8 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
// Let the pipeline set the LLM's reasoning effort and force thinking off
|
||||
// (cfgLLM is a per-session copy). disable_thinking applies after the effort.
|
||||
// Let the pipeline set the LLM's reasoning effort (cfgLLM is a per-session copy).
|
||||
applyPipelineReasoning(cfgLLM, *pipeline)
|
||||
applyPipelineThinking(cfgLLM, *pipeline)
|
||||
|
||||
cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
laudio "github.com/mudler/LocalAI/pkg/audio"
|
||||
"github.com/mudler/LocalAI/pkg/sound"
|
||||
)
|
||||
|
||||
// emitSpeech synthesizes text and sends the audio to the client. When the
|
||||
// pipeline opts into TTS streaming it forwards each PCM chunk as its own
|
||||
// response.output_audio.delta as soon as the backend produces it; otherwise it
|
||||
// synthesizes the whole utterance and sends it as a single delta.
|
||||
//
|
||||
// It deliberately does NOT emit transcript or audio-done events: the caller owns
|
||||
// those so a streamed reply can be split into several spoken segments that share
|
||||
// one response/item.
|
||||
//
|
||||
// It returns the PCM audio (at the session output rate) accumulated across all
|
||||
// chunks, which the caller base64-encodes onto the conversation item. For WebRTC
|
||||
// the audio goes over the RTP track instead, so the returned slice is empty.
|
||||
func emitSpeech(ctx context.Context, t Transport, session *Session, responseID, itemID, text string) ([]byte, error) {
|
||||
if text == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
_, isWebRTC := t.(*WebRTCTransport)
|
||||
|
||||
var wsAudio []byte // PCM at the session output rate, accumulated for the item record
|
||||
|
||||
// sendChunk hands one PCM buffer to the transport: WebRTC consumes the raw
|
||||
// PCM directly (it resamples internally); WebSocket gets base64 PCM at the
|
||||
// session output rate via a JSON delta event.
|
||||
sendChunk := func(pcm []byte, sampleRate int) error {
|
||||
if len(pcm) == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := t.SendAudio(ctx, pcm, sampleRate); err != nil {
|
||||
return err
|
||||
}
|
||||
if isWebRTC {
|
||||
return nil
|
||||
}
|
||||
wsPCM := pcm
|
||||
if sampleRate != 0 && sampleRate != session.OutputSampleRate {
|
||||
samples := sound.BytesToInt16sLE(pcm)
|
||||
resampled := sound.ResampleInt16(samples, sampleRate, session.OutputSampleRate)
|
||||
wsPCM = sound.Int16toBytesLE(resampled)
|
||||
}
|
||||
wsAudio = append(wsAudio, wsPCM...)
|
||||
return t.SendEvent(types.ResponseOutputAudioDeltaEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
ItemID: itemID,
|
||||
OutputIndex: 0,
|
||||
ContentIndex: 0,
|
||||
Delta: base64.StdEncoding.EncodeToString(wsPCM),
|
||||
})
|
||||
}
|
||||
|
||||
language := ""
|
||||
if session.InputAudioTranscription != nil {
|
||||
language = session.InputAudioTranscription.Language
|
||||
}
|
||||
|
||||
if session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamTTS() {
|
||||
if err := session.ModelInterface.TTSStream(ctx, text, session.Voice, language, sendChunk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wsAudio, nil
|
||||
}
|
||||
|
||||
// Unary fallback: synthesize the whole utterance to a file, then emit once.
|
||||
audioFilePath, res, err := session.ModelInterface.TTS(ctx, text, session.Voice, language)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res != nil && !res.Success {
|
||||
return nil, fmt.Errorf("tts generation failed: %s", res.Message)
|
||||
}
|
||||
defer func() { _ = os.Remove(audioFilePath) }()
|
||||
|
||||
// filepath.Clean normalizes the backend-produced temp path before reading
|
||||
// (also keeps gosec G304 quiet — the path is backend-controlled, not user input).
|
||||
audioBytes, err := os.ReadFile(filepath.Clean(audioFilePath))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read tts audio: %w", err)
|
||||
}
|
||||
pcm, sampleRate := laudio.ParseWAV(audioBytes)
|
||||
if sampleRate == 0 {
|
||||
sampleRate = session.OutputSampleRate
|
||||
}
|
||||
if err := sendChunk(pcm, sampleRate); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wsAudio, nil
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
laudio "github.com/mudler/LocalAI/pkg/audio"
|
||||
)
|
||||
|
||||
// emitSpeech synthesizes a piece of text and forwards the audio to the client,
|
||||
// streaming a delta per TTS chunk when the pipeline opts in, or sending the
|
||||
// whole utterance as one delta otherwise.
|
||||
var _ = Describe("emitSpeech", func() {
|
||||
ttsOn := true
|
||||
|
||||
streamingSession := func(m Model) *Session {
|
||||
return &Session{
|
||||
OutputSampleRate: 24000,
|
||||
ModelInterface: m,
|
||||
ModelConfig: &config.ModelConfig{
|
||||
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{TTS: &ttsOn}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
It("streams one output_audio.delta per TTS chunk when streaming is enabled", func() {
|
||||
m := &fakeModel{
|
||||
ttsStreamChunks: [][]byte{{1, 2}, {3, 4}, {5, 6}},
|
||||
ttsStreamRate: 24000,
|
||||
}
|
||||
t := &fakeTransport{}
|
||||
|
||||
audio, err := emitSpeech(context.Background(), t, streamingSession(m), "resp1", "item1", "Hello there.")
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(3))
|
||||
// The returned audio is all chunks concatenated (session output rate).
|
||||
Expect(audio).To(Equal([]byte{1, 2, 3, 4, 5, 6}))
|
||||
})
|
||||
|
||||
It("sends a single output_audio.delta in unary mode", func() {
|
||||
// A minimal real WAV file for the unary TTS path to read + parse.
|
||||
f, err := os.CreateTemp("", "emit-*.wav")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer func() { _ = os.Remove(f.Name()) }()
|
||||
pcm := make([]byte, 320) // 160 samples of silence
|
||||
hdr := laudio.NewWAVHeader(uint32(len(pcm)))
|
||||
Expect(hdr.Write(f)).To(Succeed())
|
||||
_, err = f.Write(pcm)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(f.Close()).To(Succeed())
|
||||
|
||||
session := &Session{
|
||||
OutputSampleRate: 24000,
|
||||
ModelInterface: &fakeModel{ttsFile: f.Name()},
|
||||
ModelConfig: &config.ModelConfig{}, // streaming off
|
||||
}
|
||||
t := &fakeTransport{}
|
||||
|
||||
_, err = emitSpeech(context.Background(), t, session, "resp1", "item1", "Hello there.")
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(1))
|
||||
})
|
||||
})
|
||||
@@ -1,315 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/reasoning"
|
||||
)
|
||||
|
||||
// transcriptStreamer turns streamed LLM tokens into the assistant's spoken
|
||||
// transcript: it strips reasoning incrementally and sends one
|
||||
// response.output_audio_transcript.delta per content fragment. It does NOT
|
||||
// synthesize audio — the caller buffers the full message and synthesizes it
|
||||
// once (streaming the audio chunks when the TTS backend supports TTSStream),
|
||||
// which works uniformly for streaming and non-streaming TTS and for languages
|
||||
// without sentence or word boundaries.
|
||||
type transcriptStreamer struct {
|
||||
ctx context.Context
|
||||
t Transport
|
||||
responseID string
|
||||
itemID string
|
||||
extractor *reasoning.ReasoningExtractor
|
||||
|
||||
// announce, if set, is invoked once just before the first transcript delta.
|
||||
// It lets the caller create the assistant item lazily, so a content-less
|
||||
// tool-call turn never emits a spurious empty assistant item.
|
||||
announce func()
|
||||
announced bool
|
||||
}
|
||||
|
||||
func newTranscriptStreamer(ctx context.Context, t Transport, responseID, itemID, thinkingStartToken string, reasoningCfg reasoning.Config) *transcriptStreamer {
|
||||
return &transcriptStreamer{
|
||||
ctx: ctx,
|
||||
t: t,
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
extractor: reasoning.NewReasoningExtractor(thinkingStartToken, spokenReasoningConfig(reasoningCfg)),
|
||||
}
|
||||
}
|
||||
|
||||
// onToken handles one streamed unit of model output, sending a transcript delta
|
||||
// for the new content (reasoning stripped) and returning that content delta so
|
||||
// the caller can also feed it to the clause chunker. For plain-content models
|
||||
// the unit is the raw text token; for autoparser tool turns the backend clears
|
||||
// the text and delivers content via ChatDeltas, so the caller passes that
|
||||
// content here. Returns "" when the token produced no new spoken content.
|
||||
func (s *transcriptStreamer) onToken(token string) string {
|
||||
_, content := s.extractor.ProcessToken(token)
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
if !s.announced {
|
||||
s.announced = true
|
||||
if s.announce != nil {
|
||||
s.announce()
|
||||
}
|
||||
}
|
||||
_ = s.t.SendEvent(types.ResponseOutputAudioTranscriptDeltaEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: s.responseID,
|
||||
ItemID: s.itemID,
|
||||
OutputIndex: 0,
|
||||
ContentIndex: 0,
|
||||
Delta: content,
|
||||
})
|
||||
return content
|
||||
}
|
||||
|
||||
// content returns the full transcript so far with reasoning stripped.
|
||||
func (s *transcriptStreamer) content() string {
|
||||
return s.extractor.CleanedContent()
|
||||
}
|
||||
|
||||
// streamLLMResponse drives a streamed realtime reply. It streams the assistant
|
||||
// transcript as the LLM generates, then synthesizes the whole buffered message
|
||||
// once (streaming the audio chunks when the TTS backend supports it, otherwise a
|
||||
// single unary delta). Tool calls parsed from the autoparser ChatDeltas are
|
||||
// emitted after the spoken content. The assistant content item is created lazily
|
||||
// on the first content delta, so a content-less tool-call turn emits only the
|
||||
// tool calls. It returns true when it has fully handled the response so the
|
||||
// caller can return; callers must only invoke it for an audio modality, and with
|
||||
// tools only when the model uses its tokenizer template (see triggerResponseAtTurn).
|
||||
func streamLLMResponse(ctx context.Context, session *Session, conv *Conversation, t Transport, responseID string, history schema.Messages, images []string, llmCfg *config.ModelConfig, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, toolTurn int) bool {
|
||||
itemID := generateItemID()
|
||||
item := types.MessageItemUnion{
|
||||
Assistant: &types.MessageItemAssistant{
|
||||
ID: itemID,
|
||||
Status: types.ItemStatusInProgress,
|
||||
Content: []types.MessageContentOutput{{Type: types.MessageContentTypeOutputAudio}},
|
||||
},
|
||||
}
|
||||
|
||||
// announce creates the assistant content item lazily, just before the first
|
||||
// transcript delta — a tool-only turn never produces content, so it stays out
|
||||
// of the conversation and the client sees only the tool calls.
|
||||
announced := false
|
||||
announce := func() {
|
||||
announced = true
|
||||
conv.Lock.Lock()
|
||||
conv.Items = append(conv.Items, &item)
|
||||
conv.Lock.Unlock()
|
||||
sendEvent(t, types.ResponseOutputItemAddedEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
OutputIndex: 0,
|
||||
Item: item,
|
||||
})
|
||||
sendEvent(t, types.ResponseContentPartAddedEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
ItemID: itemID,
|
||||
OutputIndex: 0,
|
||||
ContentIndex: 0,
|
||||
Part: item.Assistant.Content[0],
|
||||
})
|
||||
}
|
||||
|
||||
cancel := func() {
|
||||
if announced {
|
||||
conv.Lock.Lock()
|
||||
for i := len(conv.Items) - 1; i >= 0; i-- {
|
||||
if conv.Items[i].Assistant != nil && conv.Items[i].Assistant.ID == itemID {
|
||||
conv.Items = append(conv.Items[:i], conv.Items[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
conv.Lock.Unlock()
|
||||
}
|
||||
sendEvent(t, types.ResponseDoneEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
Response: types.Response{ID: responseID, Object: "realtime.response", Status: types.ResponseStatusCancelled},
|
||||
})
|
||||
}
|
||||
|
||||
var template string
|
||||
if llmCfg.TemplateConfig.UseTokenizerTemplate {
|
||||
template = llmCfg.GetModelTemplate()
|
||||
} else {
|
||||
template = llmCfg.TemplateConfig.Chat
|
||||
}
|
||||
thinkingStartToken := reasoning.DetectThinkingStartToken(template, &llmCfg.ReasoningConfig)
|
||||
|
||||
// The autoparser (tokenizer-template path) already delivers reasoning-free
|
||||
// content. Prefilling the thinking start token here would re-tag that clean
|
||||
// content as an unclosed reasoning block, leaving CleanedContent() empty —
|
||||
// no spoken reply, no TTS. Disable the prefill; closed tag pairs are still
|
||||
// stripped (PEG-fallback case, #9985).
|
||||
reasoningCfg := llmCfg.ReasoningConfig
|
||||
if llmCfg.TemplateConfig.UseTokenizerTemplate {
|
||||
disablePrefill := true
|
||||
reasoningCfg.DisableReasoningTagPrefill = &disablePrefill
|
||||
}
|
||||
|
||||
streamer := newTranscriptStreamer(ctx, t, responseID, itemID, thinkingStartToken, reasoningCfg)
|
||||
streamer.announce = announce
|
||||
|
||||
// Clause chunking (opt-in): synthesize each clause as soon as it completes
|
||||
// instead of buffering the whole reply. streamedAudio accumulates the PCM
|
||||
// across clauses for the conversation item record; ttsErr captures the first
|
||||
// synthesis failure so the token callback can stop the prediction. emitSpeech
|
||||
// runs synchronously here — the LLM keeps generating into the gRPC stream
|
||||
// while a clause is synthesized, so audio still starts mid-generation.
|
||||
var chunker *clauseChunker
|
||||
if session.ModelConfig != nil && session.ModelConfig.Pipeline.ChunkClauses() {
|
||||
chunker = newClauseChunker(defaultClauseMinRunes, defaultClauseMaxRunes)
|
||||
}
|
||||
var streamedAudio []byte
|
||||
var ttsErr error
|
||||
speakClause := func(clause string) error {
|
||||
a, err := emitSpeech(ctx, t, session, responseID, itemID, clause)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
streamedAudio = append(streamedAudio, a...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// fail reports a mid-stream failure. A cancelled context means the client
|
||||
// interrupted (barge-in), so roll the turn back instead of erroring.
|
||||
fail := func(code, msg string, err error) bool {
|
||||
if ctx.Err() != nil {
|
||||
cancel()
|
||||
} else {
|
||||
sendError(t, code, fmt.Sprintf("%s: %v", msg, err), "", itemID)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
cb := func(token string, usage backend.TokenUsage) bool {
|
||||
if ctx.Err() != nil {
|
||||
return false
|
||||
}
|
||||
// Plain-content models stream text via the token; autoparser tool turns
|
||||
// clear the text and deliver content via ChatDeltas, so prefer the latter
|
||||
// when present. Either way only content reaches the transcript — tool-call
|
||||
// deltas are parsed from the final response below.
|
||||
text := token
|
||||
if len(usage.ChatDeltas) > 0 {
|
||||
text = functions.ContentFromChatDeltas(usage.ChatDeltas)
|
||||
}
|
||||
delta := streamer.onToken(text)
|
||||
if chunker != nil && delta != "" {
|
||||
for _, clause := range chunker.push(delta) {
|
||||
if ttsErr = speakClause(clause); ttsErr != nil {
|
||||
return false // stop the prediction; reported after predFunc returns
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
predFunc, err := session.ModelInterface.Predict(ctx, history, images, nil, nil, cb, tools, toolChoice, nil, nil, nil)
|
||||
if err != nil {
|
||||
sendError(t, "inference_failed", fmt.Sprintf("backend error: %v", err), "", itemID)
|
||||
return true
|
||||
}
|
||||
pred, err := predFunc()
|
||||
// A clause synthesis failed mid-stream (the callback stopped the prediction);
|
||||
// report it as a TTS error rather than a prediction error.
|
||||
if ttsErr != nil {
|
||||
return fail("tts_error", "TTS generation failed", ttsErr)
|
||||
}
|
||||
if err != nil {
|
||||
return fail("prediction_failed", "backend error", err)
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
cancel()
|
||||
return true
|
||||
}
|
||||
|
||||
content := streamer.content()
|
||||
toolCalls := functions.ToolCallsFromChatDeltas(pred.ChatDeltas)
|
||||
|
||||
// Finalize the spoken content item only when the turn produced content. A
|
||||
// tool-only turn skips this entirely (no empty assistant item).
|
||||
if content != "" {
|
||||
if !announced {
|
||||
announce()
|
||||
}
|
||||
|
||||
// Synthesize the audio. With clause chunking the completed clauses were
|
||||
// already spoken inside the token callback; flush the trailing clause(s)
|
||||
// the segmenter was still holding. Otherwise buffer the whole message and
|
||||
// synthesize it once. emitSpeech streams the audio chunks when the TTS
|
||||
// backend supports TTSStream, otherwise it sends a single unary delta.
|
||||
var audio []byte
|
||||
if chunker != nil {
|
||||
for _, clause := range chunker.flush() {
|
||||
if ttsErr = speakClause(clause); ttsErr != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
audio = streamedAudio
|
||||
} else {
|
||||
audio, ttsErr = emitSpeech(ctx, t, session, responseID, itemID, content)
|
||||
}
|
||||
if ttsErr != nil {
|
||||
return fail("tts_error", "TTS generation failed", ttsErr)
|
||||
}
|
||||
|
||||
_, isWebRTC := t.(*WebRTCTransport)
|
||||
|
||||
sendEvent(t, types.ResponseOutputAudioTranscriptDoneEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
ItemID: itemID,
|
||||
OutputIndex: 0,
|
||||
ContentIndex: 0,
|
||||
Transcript: content,
|
||||
})
|
||||
if !isWebRTC {
|
||||
sendEvent(t, types.ResponseOutputAudioDoneEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
ItemID: itemID,
|
||||
OutputIndex: 0,
|
||||
ContentIndex: 0,
|
||||
})
|
||||
}
|
||||
|
||||
conv.Lock.Lock()
|
||||
item.Assistant.Status = types.ItemStatusCompleted
|
||||
item.Assistant.Content[0].Transcript = content
|
||||
if !isWebRTC {
|
||||
item.Assistant.Content[0].Audio = base64.StdEncoding.EncodeToString(audio)
|
||||
}
|
||||
conv.Lock.Unlock()
|
||||
|
||||
sendEvent(t, types.ResponseContentPartDoneEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
ItemID: itemID,
|
||||
OutputIndex: 0,
|
||||
ContentIndex: 0,
|
||||
Part: item.Assistant.Content[0],
|
||||
})
|
||||
sendEvent(t, types.ResponseOutputItemDoneEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
OutputIndex: 0,
|
||||
Item: item,
|
||||
})
|
||||
}
|
||||
|
||||
// Emit any tool calls, the terminal response.done, and (for server-side
|
||||
// assistant tools) the follow-up turn — shared with the buffered path.
|
||||
emitToolCallItems(ctx, session, conv, t, responseID, toolCalls, content != "", toolTurn)
|
||||
return true
|
||||
}
|
||||
@@ -1,213 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/reasoning"
|
||||
)
|
||||
|
||||
// transcriptStreamer turns streamed LLM tokens into incremental transcript
|
||||
// deltas, stripping reasoning. Audio is synthesized once from the full message
|
||||
// by the caller, so there is no per-sentence segmentation.
|
||||
var _ = Describe("transcriptStreamer", func() {
|
||||
It("emits one transcript delta per content token", func() {
|
||||
t := &fakeTransport{}
|
||||
s := newTranscriptStreamer(context.Background(), t, "resp1", "item1", "", reasoning.Config{})
|
||||
|
||||
for _, tok := range []string{"Hello", " world.", " Bye"} {
|
||||
s.onToken(tok)
|
||||
}
|
||||
|
||||
Expect(s.content()).To(Equal("Hello world. Bye"))
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioTranscriptDelta)).To(Equal(3))
|
||||
Expect(t.transcriptDeltaText()).To(Equal("Hello world. Bye"))
|
||||
})
|
||||
|
||||
It("strips leaked reasoning even when reasoning is disabled (disable_thinking safety net)", func() {
|
||||
// disable_thinking maps to DisableReasoning=true (enable_thinking=false to
|
||||
// the backend). If the model emits thinking anyway, the transcript must
|
||||
// still not leak it: stripping always runs for spoken output.
|
||||
disable := true
|
||||
t := &fakeTransport{}
|
||||
s := newTranscriptStreamer(context.Background(), t, "resp1", "item1", "",
|
||||
reasoning.Config{DisableReasoning: &disable})
|
||||
|
||||
s.onToken("<think>secret plan</think>")
|
||||
s.onToken("The answer is 42.")
|
||||
|
||||
Expect(s.content()).To(Equal("The answer is 42."))
|
||||
Expect(s.content()).ToNot(ContainSubstring("secret plan"))
|
||||
Expect(t.transcriptDeltaText()).ToNot(ContainSubstring("secret plan"))
|
||||
})
|
||||
|
||||
It("does not swallow autoparser content when the template has a thinking start token (tokenizer-template path)", func() {
|
||||
// Regression: with tag prefill on, the detected <think> token is
|
||||
// prepended to the autoparser's already-clean content, swallowing the
|
||||
// whole reply (empty transcript → no TTS). streamLLMResponse disables
|
||||
// the prefill for the tokenizer-template path.
|
||||
disablePrefill := true
|
||||
t := &fakeTransport{}
|
||||
s := newTranscriptStreamer(context.Background(), t, "resp1", "item1", "<think>",
|
||||
reasoning.Config{DisableReasoningTagPrefill: &disablePrefill})
|
||||
|
||||
s.onToken("Hello")
|
||||
s.onToken(" there.")
|
||||
|
||||
Expect(s.content()).To(Equal("Hello there."))
|
||||
Expect(t.transcriptDeltaText()).To(Equal("Hello there."))
|
||||
})
|
||||
|
||||
It("still strips embedded closed reasoning tags with prefill disabled (PEG-fallback safety, #9985)", func() {
|
||||
// Disabling prefill must not stop stripping closed <think>…</think>
|
||||
// pairs the PEG fallback can leave in autoparser content.
|
||||
disablePrefill := true
|
||||
t := &fakeTransport{}
|
||||
s := newTranscriptStreamer(context.Background(), t, "resp1", "item1", "<think>",
|
||||
reasoning.Config{DisableReasoningTagPrefill: &disablePrefill})
|
||||
|
||||
s.onToken("<think>secret</think>")
|
||||
s.onToken("The answer is 42.")
|
||||
|
||||
Expect(s.content()).To(Equal("The answer is 42."))
|
||||
Expect(t.transcriptDeltaText()).ToNot(ContainSubstring("secret"))
|
||||
})
|
||||
})
|
||||
|
||||
// streamLLMResponse drives a full streamed realtime turn: live transcript
|
||||
// deltas while the LLM generates, then the whole message is synthesized once.
|
||||
var _ = Describe("streamLLMResponse", func() {
|
||||
It("streams transcript deltas then synthesizes the whole message once", func() {
|
||||
on := true
|
||||
m := &fakeModel{
|
||||
predictTokens: []string{"Hello", " world.", " How are you?"},
|
||||
predictResp: backend.LLMResponse{Response: "Hello world. How are you?"},
|
||||
ttsStreamChunks: [][]byte{{9}},
|
||||
ttsStreamRate: 24000,
|
||||
}
|
||||
session := &Session{
|
||||
OutputSampleRate: 24000,
|
||||
ModelInterface: m,
|
||||
ModelConfig: &config.ModelConfig{
|
||||
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{LLM: &on, TTS: &on}},
|
||||
},
|
||||
}
|
||||
conv := &Conversation{}
|
||||
t := &fakeTransport{}
|
||||
llmCfg := &config.ModelConfig{}
|
||||
|
||||
handled := streamLLMResponse(context.Background(), session, conv, t, "resp1", nil, nil, llmCfg, nil, nil, 0)
|
||||
|
||||
Expect(handled).To(BeTrue())
|
||||
// One live transcript delta per streamed token.
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioTranscriptDelta)).To(Equal(3))
|
||||
// The whole message is synthesized ONCE (not per sentence): a single
|
||||
// emitSpeech replays the one TTS stream chunk.
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(1))
|
||||
Expect(t.transcriptDeltaText()).To(Equal("Hello world. How are you?"))
|
||||
})
|
||||
|
||||
It("synthesizes each clause as it completes when clause chunking is enabled", func() {
|
||||
on := true
|
||||
m := &fakeModel{
|
||||
predictTokens: []string{"Hello world.", " How are you?"},
|
||||
predictResp: backend.LLMResponse{Response: "Hello world. How are you?"},
|
||||
ttsStreamChunks: [][]byte{{9}},
|
||||
ttsStreamRate: 24000,
|
||||
}
|
||||
session := &Session{
|
||||
OutputSampleRate: 24000,
|
||||
ModelInterface: m,
|
||||
ModelConfig: &config.ModelConfig{
|
||||
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{LLM: &on, TTS: &on, ClauseChunking: &on}},
|
||||
},
|
||||
}
|
||||
conv := &Conversation{}
|
||||
t := &fakeTransport{}
|
||||
llmCfg := &config.ModelConfig{}
|
||||
|
||||
handled := streamLLMResponse(context.Background(), session, conv, t, "resp1", nil, nil, llmCfg, nil, nil, 0)
|
||||
|
||||
Expect(handled).To(BeTrue())
|
||||
// Two clauses ("Hello world." mid-stream, "How are you?" on flush) → two
|
||||
// emitSpeech calls → two audio deltas, vs one for whole-message buffering.
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(2))
|
||||
// The full transcript still streams verbatim.
|
||||
Expect(t.transcriptDeltaText()).To(Equal("Hello world. How are you?"))
|
||||
// Exactly one terminal response.done.
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseDone)).To(Equal(1))
|
||||
})
|
||||
|
||||
It("streams content deltas and emits tool-call items (autoparser tool turn)", func() {
|
||||
on := true
|
||||
// Autoparser path: reply.Message is empty; content + tool calls arrive via
|
||||
// ChatDeltas. Chunk 1 carries content, chunk 2 carries the tool call.
|
||||
contentDelta := []*proto.ChatDelta{{Content: "Let me check."}}
|
||||
toolDelta := []*proto.ChatDelta{{ToolCalls: []*proto.ToolCallDelta{{Index: 0, Name: "get_weather", Arguments: `{"city":"Paris"}`}}}}
|
||||
m := &fakeModel{
|
||||
predictTokens: []string{"", ""},
|
||||
predictChunkDeltas: [][]*proto.ChatDelta{contentDelta, toolDelta},
|
||||
predictResp: backend.LLMResponse{ChatDeltas: append(append([]*proto.ChatDelta{}, contentDelta...), toolDelta...)},
|
||||
ttsStreamChunks: [][]byte{{9}},
|
||||
ttsStreamRate: 24000,
|
||||
}
|
||||
session := &Session{
|
||||
OutputSampleRate: 24000,
|
||||
ModelInterface: m,
|
||||
ModelConfig: &config.ModelConfig{
|
||||
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{LLM: &on, TTS: &on}},
|
||||
},
|
||||
}
|
||||
conv := &Conversation{}
|
||||
t := &fakeTransport{}
|
||||
llmCfg := &config.ModelConfig{}
|
||||
llmCfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
|
||||
handled := streamLLMResponse(context.Background(), session, conv, t, "resp1", nil, nil, llmCfg, nil, nil, 0)
|
||||
|
||||
Expect(handled).To(BeTrue())
|
||||
// The spoken content was streamed live.
|
||||
Expect(t.transcriptDeltaText()).To(Equal("Let me check."))
|
||||
// The tool call is emitted as a function_call item.
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseFunctionCallArgumentsDone)).To(Equal(1))
|
||||
// Exactly one terminal response.done.
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseDone)).To(Equal(1))
|
||||
})
|
||||
|
||||
It("emits only tool-call items for a content-less tool turn (no empty assistant item)", func() {
|
||||
on := true
|
||||
toolDelta := []*proto.ChatDelta{{ToolCalls: []*proto.ToolCallDelta{{Index: 0, Name: "get_weather", Arguments: `{"city":"Rome"}`}}}}
|
||||
m := &fakeModel{
|
||||
predictTokens: []string{""},
|
||||
predictChunkDeltas: [][]*proto.ChatDelta{toolDelta},
|
||||
predictResp: backend.LLMResponse{ChatDeltas: toolDelta},
|
||||
}
|
||||
session := &Session{
|
||||
OutputSampleRate: 24000,
|
||||
ModelInterface: m,
|
||||
ModelConfig: &config.ModelConfig{
|
||||
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{LLM: &on, TTS: &on}},
|
||||
},
|
||||
}
|
||||
conv := &Conversation{}
|
||||
t := &fakeTransport{}
|
||||
llmCfg := &config.ModelConfig{}
|
||||
llmCfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
|
||||
handled := streamLLMResponse(context.Background(), session, conv, t, "resp1", nil, nil, llmCfg, nil, nil, 0)
|
||||
|
||||
Expect(handled).To(BeTrue())
|
||||
// No content → no transcript deltas and no spurious assistant content item.
|
||||
Expect(t.transcriptDeltaText()).To(Equal(""))
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioTranscriptDelta)).To(Equal(0))
|
||||
// The tool call is still emitted.
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseFunctionCallArgumentsDone)).To(Equal(1))
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseDone)).To(Equal(1))
|
||||
})
|
||||
})
|
||||
@@ -1,33 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/reasoning"
|
||||
)
|
||||
|
||||
// applyPipelineThinking forces the LLM's reasoning/thinking off when the realtime
|
||||
// pipeline sets disable_thinking, mapping to the enable_thinking=false backend
|
||||
// metadata via ReasoningConfig.DisableReasoning. The LLM config passed in is the
|
||||
// per-session copy returned by the config loader, so this does not affect other
|
||||
// users of the same model. When the pipeline does not set disable_thinking the
|
||||
// LLM config is left untouched.
|
||||
func applyPipelineThinking(llm *config.ModelConfig, pipeline config.Pipeline) {
|
||||
if llm == nil || !pipeline.ThinkingDisabled() {
|
||||
return
|
||||
}
|
||||
disable := true
|
||||
llm.ReasoningConfig.DisableReasoning = &disable
|
||||
}
|
||||
|
||||
// spokenReasoningConfig adapts a model's reasoning config for stripping reasoning
|
||||
// OUT of realtime spoken output. ReasoningConfig.DisableReasoning is overloaded:
|
||||
// the backend reads it as the "enable_thinking=false" hint (which pipeline
|
||||
// disable_thinking sets via applyPipelineThinking), but the reasoning extractor
|
||||
// reads it as "skip stripping, assume there is no reasoning". Honouring the latter
|
||||
// when extracting for speech would leak raw <think>…</think> whenever the model
|
||||
// ignores the suppression hint. Spoken output must never contain reasoning, so we
|
||||
// always strip: clear DisableReasoning while keeping custom tokens/tag pairs.
|
||||
func spokenReasoningConfig(cfg reasoning.Config) reasoning.Config {
|
||||
cfg.DisableReasoning = nil
|
||||
return cfg
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/reasoning"
|
||||
)
|
||||
|
||||
// applyPipelineThinking lets a realtime pipeline force the LLM's thinking off
|
||||
// (enable_thinking=false metadata) without editing the LLM model config.
|
||||
var _ = Describe("applyPipelineThinking", func() {
|
||||
It("disables reasoning on the LLM config when the pipeline disables thinking", func() {
|
||||
disable := true
|
||||
llm := &config.ModelConfig{}
|
||||
applyPipelineThinking(llm, config.Pipeline{DisableThinking: &disable})
|
||||
Expect(llm.ReasoningConfig.DisableReasoning).ToNot(BeNil())
|
||||
Expect(*llm.ReasoningConfig.DisableReasoning).To(BeTrue())
|
||||
})
|
||||
|
||||
It("leaves the LLM config untouched when the pipeline does not set disable_thinking", func() {
|
||||
llm := &config.ModelConfig{}
|
||||
applyPipelineThinking(llm, config.Pipeline{})
|
||||
Expect(llm.ReasoningConfig.DisableReasoning).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
// spokenReasoningConfig clears DisableReasoning so realtime spoken output always
|
||||
// strips reasoning, even though disable_thinking sets DisableReasoning=true on the
|
||||
// LLM config (which the backend reads as enable_thinking=false).
|
||||
var _ = Describe("spokenReasoningConfig", func() {
|
||||
It("clears DisableReasoning so the extractor still strips leaked reasoning", func() {
|
||||
disable := true
|
||||
out := spokenReasoningConfig(reasoning.Config{DisableReasoning: &disable})
|
||||
Expect(out.DisableReasoning).To(BeNil())
|
||||
})
|
||||
|
||||
It("preserves the other reasoning settings", func() {
|
||||
disable := true
|
||||
out := spokenReasoningConfig(reasoning.Config{
|
||||
DisableReasoning: &disable,
|
||||
ThinkingStartTokens: []string{"<reason>"},
|
||||
TagPairs: []reasoning.TagPair{{Start: "<reason>", End: "</reason>"}},
|
||||
})
|
||||
Expect(out.ThinkingStartTokens).To(Equal([]string{"<reason>"}))
|
||||
Expect(out.TagPairs).To(HaveLen(1))
|
||||
Expect(out.TagPairs[0].Start).To(Equal("<reason>"))
|
||||
})
|
||||
})
|
||||
@@ -1,63 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
)
|
||||
|
||||
// emitTranscription transcribes a committed utterance and emits the transcription
|
||||
// events for it, returning the final transcript text. With
|
||||
// pipeline.streaming.transcription enabled it streams each transcript fragment as
|
||||
// a conversation.item.input_audio_transcription.delta as the backend produces it,
|
||||
// then a completed event; otherwise it transcribes the whole utterance and emits
|
||||
// a single completed event. delta and completed events share itemID.
|
||||
func emitTranscription(ctx context.Context, t Transport, session *Session, itemID, audioPath string) (string, error) {
|
||||
cfg := session.InputAudioTranscription
|
||||
|
||||
if session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamTranscription() {
|
||||
final, err := session.ModelInterface.TranscribeStream(ctx, audioPath, cfg.Language, false, false, cfg.Prompt, func(delta string) {
|
||||
_ = t.SendEvent(types.ConversationItemInputAudioTranscriptionDeltaEvent{
|
||||
ServerEventBase: types.ServerEventBase{EventID: "event_TODO"},
|
||||
ItemID: itemID,
|
||||
ContentIndex: 0,
|
||||
Delta: delta,
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
transcript := ""
|
||||
if final != nil {
|
||||
transcript = final.Text
|
||||
}
|
||||
if err := t.SendEvent(types.ConversationItemInputAudioTranscriptionCompletedEvent{
|
||||
ServerEventBase: types.ServerEventBase{EventID: "event_TODO"},
|
||||
ItemID: itemID,
|
||||
ContentIndex: 0,
|
||||
Transcript: transcript,
|
||||
}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return transcript, nil
|
||||
}
|
||||
|
||||
// Unary fallback: transcribe the whole utterance, emit one completed event.
|
||||
tr, err := session.ModelInterface.Transcribe(ctx, audioPath, cfg.Language, false, false, cfg.Prompt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if tr == nil {
|
||||
return "", fmt.Errorf("transcribe result is nil")
|
||||
}
|
||||
if err := t.SendEvent(types.ConversationItemInputAudioTranscriptionCompletedEvent{
|
||||
ServerEventBase: types.ServerEventBase{EventID: "event_TODO"},
|
||||
ItemID: itemID,
|
||||
ContentIndex: 0,
|
||||
Transcript: tr.Text,
|
||||
}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return tr.Text, nil
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
)
|
||||
|
||||
// emitTranscription transcribes a committed utterance, streaming transcript text
|
||||
// deltas when the pipeline opts in, and returns the final transcript text.
|
||||
var _ = Describe("emitTranscription", func() {
|
||||
It("streams transcription deltas then a completed event when streaming is enabled", func() {
|
||||
on := true
|
||||
session := &Session{
|
||||
InputAudioTranscription: &types.AudioTranscription{},
|
||||
ModelConfig: &config.ModelConfig{
|
||||
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{Transcription: &on}},
|
||||
},
|
||||
ModelInterface: &fakeModel{
|
||||
transcribeDeltas: []string{"Hel", "lo", " world"},
|
||||
transcribeFinal: &schema.TranscriptionResult{Text: "Hello world"},
|
||||
},
|
||||
}
|
||||
t := &fakeTransport{}
|
||||
|
||||
transcript, err := emitTranscription(context.Background(), t, session, "item1", "/tmp/x.wav")
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(transcript).To(Equal("Hello world"))
|
||||
Expect(t.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionDelta)).To(Equal(3))
|
||||
Expect(t.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionCompleted)).To(Equal(1))
|
||||
})
|
||||
|
||||
It("emits a single completed event with no deltas in unary mode", func() {
|
||||
session := &Session{
|
||||
InputAudioTranscription: &types.AudioTranscription{},
|
||||
ModelConfig: &config.ModelConfig{}, // streaming off
|
||||
ModelInterface: &fakeModel{transcribeFinal: &schema.TranscriptionResult{Text: "Hi"}},
|
||||
}
|
||||
t := &fakeTransport{}
|
||||
|
||||
transcript, err := emitTranscription(context.Background(), t, session, "item1", "/tmp/x.wav")
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(transcript).To(Equal("Hi"))
|
||||
Expect(t.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionDelta)).To(Equal(0))
|
||||
Expect(t.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionCompleted)).To(Equal(1))
|
||||
})
|
||||
})
|
||||
@@ -48,8 +48,7 @@ func RealtimeCalls(application *application.Application) echo.HandlerFunc {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "codec registration failed"})
|
||||
}
|
||||
|
||||
se := webRTCSettingEngine(application.ApplicationConfig())
|
||||
api := webrtc.NewAPI(webrtc.WithMediaEngine(m), webrtc.WithSettingEngine(se))
|
||||
api := webrtc.NewAPI(webrtc.WithMediaEngine(m))
|
||||
|
||||
pc, err := api.NewPeerConnection(webrtc.Configuration{})
|
||||
if err != nil {
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
// webRTCSettingEngine builds the pion SettingEngine for /v1/realtime WebRTC.
|
||||
//
|
||||
// With a default (empty) SettingEngine, pion gathers a host ICE candidate for
|
||||
// every local interface. Under Docker host networking that includes bridge
|
||||
// addresses (docker0/veth, 172.x) that a remote browser cannot route to; the
|
||||
// connection often establishes on a good pair and then drops once ICE consent
|
||||
// checks fail on the unreachable ones. The two opt-in knobs below let an
|
||||
// operator advertise only the reachable address.
|
||||
func webRTCSettingEngine(cfg *config.ApplicationConfig) webrtc.SettingEngine {
|
||||
s := webrtc.SettingEngine{}
|
||||
if cfg == nil {
|
||||
return s
|
||||
}
|
||||
if len(cfg.WebRTCNAT1To1IPs) > 0 {
|
||||
s.SetNAT1To1IPs(cfg.WebRTCNAT1To1IPs, webrtc.ICECandidateTypeHost)
|
||||
xlog.Debug("realtime webrtc: advertising NAT 1:1 host IPs", "ips", cfg.WebRTCNAT1To1IPs)
|
||||
}
|
||||
if filter := iceInterfaceFilter(cfg.WebRTCICEInterfaces); filter != nil {
|
||||
s.SetInterfaceFilter(filter)
|
||||
xlog.Debug("realtime webrtc: restricting ICE interfaces", "interfaces", cfg.WebRTCICEInterfaces)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// iceInterfaceFilter returns an interface allow-list predicate for pion, or nil
|
||||
// when no interfaces are configured (pion's default: gather from all).
|
||||
func iceInterfaceFilter(allowed []string) func(string) bool {
|
||||
if len(allowed) == 0 {
|
||||
return nil
|
||||
}
|
||||
set := make(map[string]struct{}, len(allowed))
|
||||
for _, name := range allowed {
|
||||
set[name] = struct{}{}
|
||||
}
|
||||
return func(iface string) bool {
|
||||
_, ok := set[iface]
|
||||
return ok
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("webRTC ICE settings", func() {
|
||||
Describe("iceInterfaceFilter", func() {
|
||||
It("returns nil when no interfaces are configured", func() {
|
||||
Expect(iceInterfaceFilter(nil)).To(BeNil())
|
||||
Expect(iceInterfaceFilter([]string{})).To(BeNil())
|
||||
})
|
||||
|
||||
It("admits only the configured interfaces", func() {
|
||||
f := iceInterfaceFilter([]string{"eth0", "wlan0"})
|
||||
Expect(f).NotTo(BeNil())
|
||||
Expect(f("eth0")).To(BeTrue())
|
||||
Expect(f("wlan0")).To(BeTrue())
|
||||
Expect(f("docker0")).To(BeFalse())
|
||||
Expect(f("veth123")).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("webRTCSettingEngine", func() {
|
||||
It("does not panic on a nil config", func() {
|
||||
Expect(func() { webRTCSettingEngine(nil) }).NotTo(Panic())
|
||||
})
|
||||
|
||||
It("builds an engine with NAT 1:1 IPs and an interface filter configured", func() {
|
||||
cfg := &config.ApplicationConfig{
|
||||
WebRTCNAT1To1IPs: []string{"192.168.1.10"},
|
||||
WebRTCICEInterfaces: []string{"eth0"},
|
||||
}
|
||||
Expect(func() { webRTCSettingEngine(cfg) }).NotTo(Panic())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1356,7 +1356,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &cfg.ReasoningConfig)
|
||||
|
||||
// Extract reasoning from result before cleaning
|
||||
reasoningContent, cleanedResult := reason.ExtractReasoningComplete(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
reasoningContent, cleanedResult := reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
|
||||
// Parse tool calls if using functions
|
||||
var outputItems []schema.ORItemField
|
||||
@@ -1996,7 +1996,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
finalCleanedResult = extractor.CleanedContent()
|
||||
}
|
||||
if finalReasoning == "" && finalCleanedResult == "" {
|
||||
finalReasoning, finalCleanedResult = reason.ExtractReasoningComplete(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
}
|
||||
|
||||
// Close reasoning item if it exists and wasn't closed yet
|
||||
@@ -2493,7 +2493,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
finalCleanedResult = extractor.CleanedContent()
|
||||
}
|
||||
if finalReasoning == "" && finalCleanedResult == "" {
|
||||
finalReasoning, finalCleanedResult = reason.ExtractReasoningComplete(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
}
|
||||
|
||||
// Close reasoning item if it exists and wasn't closed yet
|
||||
|
||||
@@ -21,20 +21,6 @@ var (
|
||||
mockBackendPath string
|
||||
)
|
||||
|
||||
// testHTTPAddr is the listen address used by specs that start a full HTTP
|
||||
// server. Configurable so the suite can run on machines where the default
|
||||
// port is taken by an unrelated service (override: LOCALAI_TEST_HTTP_PORT).
|
||||
var testHTTPAddr = func() string {
|
||||
port := os.Getenv("LOCALAI_TEST_HTTP_PORT")
|
||||
if port == "" {
|
||||
port = "9090"
|
||||
}
|
||||
return "127.0.0.1:" + port
|
||||
}()
|
||||
|
||||
// testHTTPBase is the matching http://host:port prefix for client requests.
|
||||
var testHTTPBase = "http://" + testHTTPAddr
|
||||
|
||||
// findMockBackendBinary locates the mock-backend binary built by
|
||||
// `make build-mock-backend`. Mirrors the lookup used by
|
||||
// tests/e2e/e2e_suite_test.go so both suites consume the same artifact.
|
||||
|
||||
@@ -59,14 +59,14 @@ var _ = Describe("Open Responses API", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go func() {
|
||||
if err := app.Start(testHTTPAddr); err != nil && err != http.ErrServerClosed {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
xlog.Error("server error", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for API to be ready
|
||||
Eventually(func() error {
|
||||
resp, err := http.Get(testHTTPBase + "/healthz")
|
||||
resp, err := http.Get("http://127.0.0.1:9090/healthz")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -95,7 +95,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -118,7 +118,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -143,7 +143,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -168,7 +168,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -196,7 +196,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -241,7 +241,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -269,7 +269,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -297,7 +297,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -328,7 +328,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -358,7 +358,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -386,7 +386,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -418,7 +418,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -454,7 +454,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -490,7 +490,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -539,7 +539,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -590,7 +590,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -624,7 +624,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -658,7 +658,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -680,7 +680,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -727,7 +727,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -756,7 +756,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -799,7 +799,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -835,7 +835,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload1, err := json.Marshal(reqBody1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req1, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload1))
|
||||
req1, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload1))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req1.Header.Set("Content-Type", "application/json")
|
||||
req1.Header.Set("Authorization", bearerKey)
|
||||
@@ -869,7 +869,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload2, err := json.Marshal(reqBody2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req2, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload2))
|
||||
req2, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload2))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req2.Header.Set("Content-Type", "application/json")
|
||||
req2.Header.Set("Authorization", bearerKey)
|
||||
@@ -897,7 +897,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
@@ -933,7 +933,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload1, err := json.Marshal(reqBody1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req1, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload1))
|
||||
req1, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload1))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req1.Header.Set("Content-Type", "application/json")
|
||||
req1.Header.Set("Authorization", bearerKey)
|
||||
@@ -983,7 +983,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload2, err := json.Marshal(reqBody2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req2, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload2))
|
||||
req2, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload2))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req2.Header.Set("Content-Type", "application/json")
|
||||
req2.Header.Set("Authorization", bearerKey)
|
||||
@@ -1009,7 +1009,7 @@ var _ = Describe("Open Responses API", func() {
|
||||
payload, err := json.Marshal(reqBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
req, err := http.NewRequest("POST", testHTTPBase+"/v1/responses", bytes.NewBuffer(payload))
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerKey)
|
||||
|
||||
6
core/http/react-ui/src/hooks/useChat.js
vendored
6
core/http/react-ui/src/hooks/useChat.js
vendored
@@ -216,12 +216,6 @@ export function useChat(initialModel = '') {
|
||||
audio_url: { url: `data:${file.type};base64,${file.base64}` },
|
||||
})
|
||||
userFiles.push({ name: file.name, type: 'audio' })
|
||||
} else if (file.type?.startsWith('video/')) {
|
||||
messageContent.push({
|
||||
type: 'video_url',
|
||||
video_url: { url: `data:${file.type};base64,${file.base64}` },
|
||||
})
|
||||
userFiles.push({ name: file.name, type: 'video' })
|
||||
} else {
|
||||
// Text/PDF files - append to content
|
||||
if (file.textContent) {
|
||||
|
||||
@@ -506,10 +506,7 @@ export default function Backends() {
|
||||
<tbody>
|
||||
{backends.map((b, idx) => {
|
||||
const op = getBackendOp(b)
|
||||
// A failed op is intentionally kept in the operations list so the
|
||||
// OperationsBar can surface the error + Dismiss; it must NOT render
|
||||
// as a perpetual "Installing..." spinner here (mirrors Models.jsx).
|
||||
const isProcessing = !!op && !op.error
|
||||
const isProcessing = !!op
|
||||
const isExpanded = expandedRow === idx
|
||||
|
||||
return (
|
||||
|
||||
@@ -265,7 +265,7 @@ function UserMessageContent({ content, files }) {
|
||||
<div className="chat-message-files">
|
||||
{files.map((f, i) => (
|
||||
<span key={i} className="chat-file-inline">
|
||||
<i className={`fas ${f.type === 'image' ? 'fa-image' : f.type === 'audio' ? 'fa-headphones' : f.type === 'video' ? 'fa-film' : 'fa-file'}`} />
|
||||
<i className={`fas ${f.type === 'image' ? 'fa-image' : f.type === 'audio' ? 'fa-headphones' : 'fa-file'}`} />
|
||||
{f.name}
|
||||
</span>
|
||||
))}
|
||||
@@ -274,9 +274,6 @@ function UserMessageContent({ content, files }) {
|
||||
{Array.isArray(content) && content.filter(c => c.type === 'image_url').map((img, i) => (
|
||||
<img key={i} src={img.image_url.url} alt="attached" className="chat-inline-image" />
|
||||
))}
|
||||
{Array.isArray(content) && content.filter(c => c.type === 'video_url').map((vid, i) => (
|
||||
<video key={i} src={vid.video_url.url} controls className="chat-inline-video" />
|
||||
))}
|
||||
</>
|
||||
)
|
||||
}
|
||||
@@ -714,7 +711,7 @@ export default function Chat() {
|
||||
for (const file of e.target.files) {
|
||||
const base64 = await fileToBase64(file)
|
||||
const entry = { name: file.name, type: file.type, base64 }
|
||||
if (!file.type.startsWith('image/') && !file.type.startsWith('audio/') && !file.type.startsWith('video/')) {
|
||||
if (!file.type.startsWith('image/') && !file.type.startsWith('audio/')) {
|
||||
entry.textContent = await file.text().catch(() => '')
|
||||
}
|
||||
newFiles.push(entry)
|
||||
@@ -1247,7 +1244,7 @@ export default function Chat() {
|
||||
<div className="chat-files">
|
||||
{files.map((f, i) => (
|
||||
<span key={i} className="chat-file-badge">
|
||||
<i className={`fas ${f.type?.startsWith('image/') ? 'fa-image' : f.type?.startsWith('audio/') ? 'fa-headphones' : f.type?.startsWith('video/') ? 'fa-film' : 'fa-file'}`} />
|
||||
<i className={`fas ${f.type?.startsWith('image/') ? 'fa-image' : f.type?.startsWith('audio/') ? 'fa-headphones' : 'fa-file'}`} />
|
||||
{f.name}
|
||||
<button onClick={() => setFiles(prev => prev.filter((_, idx) => idx !== i))}>
|
||||
<i className="fas fa-xmark" />
|
||||
@@ -1346,7 +1343,7 @@ export default function Chat() {
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
multiple
|
||||
accept="image/*,audio/*,video/*,application/pdf,.txt,.md,.csv,.json"
|
||||
accept="image/*,audio/*,application/pdf,.txt,.md,.csv,.json"
|
||||
style={{ display: 'none' }}
|
||||
onChange={handleFileChange}
|
||||
/>
|
||||
|
||||
@@ -12,7 +12,6 @@ import ActionMenu from '../components/ActionMenu'
|
||||
import ResourceRow, { ChevronCell, IconCell, StopPropagationCell } from '../components/ResourceRow'
|
||||
import { useModels } from '../hooks/useModels'
|
||||
import { useGalleryEnrichment } from '../hooks/useGalleryEnrichment'
|
||||
import { useOperations } from '../hooks/useOperations'
|
||||
import { backendControlApi, modelsApi, backendsApi, systemApi, nodesApi } from '../utils/api'
|
||||
import { renderMarkdown } from '../utils/markdown'
|
||||
import { safeHref } from '../utils/url'
|
||||
@@ -127,7 +126,6 @@ export default function Manage() {
|
||||
const [activeTab, setActiveTab] = useState(TABS.some(tab => tab.key === initialTab) ? initialTab : 'models')
|
||||
const { models, loading: modelsLoading, refetch: refetchModels } = useModels()
|
||||
const { enrichModel, enrichBackend } = useGalleryEnrichment()
|
||||
const { operations } = useOperations()
|
||||
const [loadedModelIds, setLoadedModelIds] = useState(new Set())
|
||||
const [backends, setBackends] = useState([])
|
||||
const [backendsLoading, setBackendsLoading] = useState(true)
|
||||
@@ -260,19 +258,14 @@ export default function Manage() {
|
||||
return `${m}m ago`
|
||||
})()
|
||||
|
||||
// Refresh installed backends + available upgrades when the Backends tab opens
|
||||
// AND whenever a backend operation settles (operations.length changes as a
|
||||
// reinstall/upgrade completes and drops off the list). Without the op-settle
|
||||
// refresh the installed-version cell and the "update available" badge stay
|
||||
// stale after an upgrade until the user switches tabs - the op looks like it
|
||||
// "did nothing". Mirrors the operations.length watch Backends.jsx uses.
|
||||
// Fetch available backend upgrades
|
||||
useEffect(() => {
|
||||
if (activeTab !== 'backends') return
|
||||
fetchBackends()
|
||||
backendsApi.checkUpgrades()
|
||||
.then(data => setUpgrades(data || {}))
|
||||
.catch(() => {})
|
||||
}, [operations.length, activeTab, fetchBackends])
|
||||
if (activeTab === 'backends') {
|
||||
backendsApi.checkUpgrades()
|
||||
.then(data => setUpgrades(data || {}))
|
||||
.catch(() => {})
|
||||
}
|
||||
}, [activeTab])
|
||||
|
||||
const handleStopModel = (modelName) => {
|
||||
setConfirmDialog({
|
||||
|
||||
@@ -17,24 +17,6 @@ const STATUS_STYLES = {
|
||||
error: { icon: 'fa-solid fa-circle', color: 'var(--color-error)', bg: 'var(--color-error-light)' },
|
||||
}
|
||||
|
||||
// upsertAssistant merges a streamed transcript fragment into the assistant entry
|
||||
// identified by the server's item_id, or appends a new entry if none exists yet.
|
||||
// Keying by item_id (not a mutable index tracked across handler/updater
|
||||
// boundaries) makes streamed deltas idempotent and order-independent, so React's
|
||||
// batching of non-React data-channel events cannot produce a duplicate bubble.
|
||||
// mode 'append' adds to the running text; 'replace' sets the final transcript.
|
||||
function upsertAssistant(prev, itemId, text, mode) {
|
||||
// Only assistant entries carry an id, and the streaming entry is almost
|
||||
// always the newest — search from the tail so per-delta cost stays constant.
|
||||
const i = prev.findLastIndex(e => e.id === itemId)
|
||||
if (i === -1) {
|
||||
return [...prev, { role: 'assistant', id: itemId, text }]
|
||||
}
|
||||
const next = [...prev]
|
||||
next[i] = { ...next[i], text: mode === 'append' ? next[i].text + text : text }
|
||||
return next
|
||||
}
|
||||
|
||||
export default function Talk() {
|
||||
const { addToast } = useOutletContext()
|
||||
const navigate = useNavigate()
|
||||
@@ -52,10 +34,7 @@ export default function Talk() {
|
||||
|
||||
// Transcript
|
||||
const [transcript, setTranscript] = useState([])
|
||||
// item_id of the assistant message currently streaming — used only to remove
|
||||
// its partial bubble when a response is cancelled (barge-in). The transcript
|
||||
// itself is keyed by item_id via upsertAssistant, not by this ref.
|
||||
const inProgressIdRef = useRef(null)
|
||||
const streamingRef = useRef(null) // tracks the index of the in-progress assistant message
|
||||
|
||||
// Session settings
|
||||
const [instructions, setInstructions] = useState(
|
||||
@@ -248,21 +227,39 @@ export default function Talk() {
|
||||
break
|
||||
case 'conversation.item.input_audio_transcription.completed':
|
||||
if (event.transcript) {
|
||||
streamingRef.current = null
|
||||
setTranscript(prev => [...prev, { role: 'user', text: event.transcript }])
|
||||
}
|
||||
updateStatus('thinking', 'Generating response...')
|
||||
break
|
||||
case 'response.output_audio_transcript.delta':
|
||||
if (event.delta) {
|
||||
inProgressIdRef.current = event.item_id
|
||||
setTranscript(prev => upsertAssistant(prev, event.item_id, event.delta, 'append'))
|
||||
setTranscript(prev => {
|
||||
if (streamingRef.current !== null) {
|
||||
const updated = [...prev]
|
||||
updated[streamingRef.current] = {
|
||||
...updated[streamingRef.current],
|
||||
text: updated[streamingRef.current].text + event.delta,
|
||||
}
|
||||
return updated
|
||||
}
|
||||
streamingRef.current = prev.length
|
||||
return [...prev, { role: 'assistant', text: event.delta }]
|
||||
})
|
||||
}
|
||||
break
|
||||
case 'response.output_audio_transcript.done':
|
||||
if (event.transcript) {
|
||||
setTranscript(prev => upsertAssistant(prev, event.item_id, event.transcript, 'replace'))
|
||||
setTranscript(prev => {
|
||||
if (streamingRef.current !== null) {
|
||||
const updated = [...prev]
|
||||
updated[streamingRef.current] = { ...updated[streamingRef.current], text: event.transcript }
|
||||
return updated
|
||||
}
|
||||
return [...prev, { role: 'assistant', text: event.transcript }]
|
||||
})
|
||||
}
|
||||
inProgressIdRef.current = null
|
||||
streamingRef.current = null
|
||||
break
|
||||
case 'response.output_audio.delta':
|
||||
updateStatus('speaking', 'Speaking...')
|
||||
@@ -284,7 +281,7 @@ export default function Talk() {
|
||||
// Pretty-print JSON for readability; fall back to raw string.
|
||||
try { preview = JSON.stringify(JSON.parse(preview), null, 2) } catch (_) { /* keep raw */ }
|
||||
setTranscript(prev => [...prev, { role: 'tool_result', text: preview }])
|
||||
inProgressIdRef.current = null // tool result ends the current assistant text run
|
||||
streamingRef.current = null // tool result ends the current assistant text run
|
||||
}
|
||||
break
|
||||
}
|
||||
@@ -293,20 +290,9 @@ export default function Talk() {
|
||||
// conversation.item.create + response.create when it's done.
|
||||
handleFunctionCall(event)
|
||||
break
|
||||
case 'response.done': {
|
||||
// A cancelled response (barge-in / interruption) leaves a partial,
|
||||
// incrementally-streamed assistant bubble behind. The server discards
|
||||
// the interrupted item from history; mirror that here (remove the
|
||||
// in-progress assistant entry by item_id) so the regenerated reply
|
||||
// doesn't show up as a second assistant message.
|
||||
if (event.response?.status === 'cancelled' && inProgressIdRef.current) {
|
||||
const id = inProgressIdRef.current
|
||||
inProgressIdRef.current = null
|
||||
setTranscript(prev => prev.filter(e => e.id !== id))
|
||||
}
|
||||
case 'response.done':
|
||||
updateStatus('listening', 'Listening...')
|
||||
break
|
||||
}
|
||||
case 'error':
|
||||
hasErrorRef.current = true
|
||||
updateStatus('error', 'Error: ' + (event.error?.message || 'Unknown error'))
|
||||
@@ -803,7 +789,7 @@ export default function Talk() {
|
||||
const iconColor = isToolCall || isToolResult ? 'var(--color-text-secondary)'
|
||||
: isUser ? 'var(--color-primary)' : 'var(--color-accent)'
|
||||
return (
|
||||
<div key={entry.id || i} style={{ display: 'flex', alignItems: 'flex-start', gap: 'var(--spacing-xs)' }}>
|
||||
<div key={i} style={{ display: 'flex', alignItems: 'flex-start', gap: 'var(--spacing-xs)' }}>
|
||||
<i className={iconClass} style={{ color: iconColor, marginTop: 3, flexShrink: 0, fontSize: '0.75rem' }} />
|
||||
<p style={{
|
||||
margin: 0,
|
||||
|
||||
@@ -466,11 +466,10 @@ func (s *AgentPoolService) Chat(name, message string) (string, error) {
|
||||
s.collectAndCopyMetadata(metadata, chatUserID)
|
||||
}
|
||||
|
||||
content := s.appendLocalAGIKBCitations(response.Response, name, message, response.State)
|
||||
msg := map[string]any{
|
||||
"id": messageID + "-agent",
|
||||
"sender": "agent",
|
||||
"content": content,
|
||||
"content": response.Response,
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
if len(metadata) > 0 {
|
||||
@@ -490,79 +489,6 @@ func (s *AgentPoolService) Chat(name, message string) (string, error) {
|
||||
return messageID, nil
|
||||
}
|
||||
|
||||
func (s *AgentPoolService) appendLocalAGIKBCitations(response, agentKey, message string, states []coreTypes.ActionState) string {
|
||||
if strings.TrimSpace(response) == "" {
|
||||
return response
|
||||
}
|
||||
|
||||
userID, collection := splitAgentKey(agentKey)
|
||||
cfg := s.localAGI.pool.GetConfig(agentKey)
|
||||
if cfg == nil || !cfg.EnableKnowledgeBase {
|
||||
return response
|
||||
}
|
||||
|
||||
citations := kbCitationsFromActionStates(states)
|
||||
if len(citations) == 0 && cfg.KBAutoSearch {
|
||||
maxResults := cfg.KnowledgeBaseResults
|
||||
if maxResults <= 0 {
|
||||
maxResults = 5
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
kbResult := agents.KBAutoSearchPrompt(ctx, s.apiURL, s.apiKey, collection, message, maxResults, userID)
|
||||
citations = kbResult.Citations
|
||||
}
|
||||
|
||||
return agents.AppendKBCitations(response, collection, userID, citations)
|
||||
}
|
||||
|
||||
func splitAgentKey(agentKey string) (userID, name string) {
|
||||
if uid, n, ok := strings.Cut(agentKey, ":"); ok {
|
||||
return uid, n
|
||||
}
|
||||
return "", agentKey
|
||||
}
|
||||
|
||||
func kbCitationsFromActionStates(states []coreTypes.ActionState) []agents.KBCitation {
|
||||
var citations []agents.KBCitation
|
||||
for _, state := range states {
|
||||
citations = append(citations, kbCitationsFromMetadata(state.Metadata)...)
|
||||
}
|
||||
return citations
|
||||
}
|
||||
|
||||
func kbCitationsFromMetadata(metadata map[string]any) []agents.KBCitation {
|
||||
if len(metadata) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fileName := metadata["file_name"]
|
||||
source := metadata["source"]
|
||||
if fileName == nil && source == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
citation := agents.KBCitation{
|
||||
FileName: metadataString(fileName),
|
||||
EntryKey: metadataString(source),
|
||||
}
|
||||
if citation.FileName == "" && citation.EntryKey == "" {
|
||||
return nil
|
||||
}
|
||||
return []agents.KBCitation{citation}
|
||||
}
|
||||
|
||||
func metadataString(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case fmt.Stringer:
|
||||
return v.String()
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// userOutputsDir returns the per-user outputs directory, creating it if needed.
|
||||
// If userID is empty, falls back to the shared outputs directory.
|
||||
func (s *AgentPoolService) userOutputsDir(userID string) string {
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
package agents
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type kbCitationList struct {
|
||||
mu sync.Mutex
|
||||
citations []KBCitation
|
||||
}
|
||||
|
||||
func (l *kbCitationList) AddKBCitations(citations []KBCitation) {
|
||||
if len(citations) == 0 {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.citations = append(l.citations, citations...)
|
||||
}
|
||||
|
||||
func (l *kbCitationList) Citations() []KBCitation {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
out := make([]KBCitation, len(l.citations))
|
||||
copy(out, l.citations)
|
||||
return out
|
||||
}
|
||||
|
||||
// AppendKBCitations appends a markdown Sources block for KB citations.
|
||||
func AppendKBCitations(response, collection, userID string, citations []KBCitation) string {
|
||||
if strings.TrimSpace(response) == "" || len(citations) == 0 {
|
||||
return response
|
||||
}
|
||||
|
||||
var lines []string
|
||||
seen := make(map[string]struct{})
|
||||
for _, citation := range citations {
|
||||
key := strings.TrimSpace(citation.EntryKey)
|
||||
if key == "" {
|
||||
key = strings.TrimSpace(citation.FileName)
|
||||
}
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
|
||||
displayName := kbCitationDisplayName(citation)
|
||||
if displayName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
sourceURL := kbCitationRawFileURL(collection, citation.EntryKey, userID)
|
||||
number := len(lines) + 1
|
||||
if sourceURL == "" {
|
||||
lines = append(lines, fmt.Sprintf("[%d] %s", number, displayName))
|
||||
continue
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("[%d] [%s](%s)", number, escapeMarkdownLinkText(displayName), sourceURL))
|
||||
}
|
||||
|
||||
if len(lines) == 0 {
|
||||
return response
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(strings.TrimRight(response, "\n"))
|
||||
sb.WriteString("\n\nSources:\n")
|
||||
for _, line := range lines {
|
||||
sb.WriteString(line)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
return strings.TrimRight(sb.String(), "\n")
|
||||
}
|
||||
|
||||
func kbCitationDisplayName(citation KBCitation) string {
|
||||
if fileName := strings.TrimSpace(citation.FileName); fileName != "" {
|
||||
return fileName
|
||||
}
|
||||
|
||||
segments := strings.Split(strings.Trim(strings.TrimSpace(citation.EntryKey), "/"), "/")
|
||||
for i := len(segments) - 1; i >= 0; i-- {
|
||||
if segment := strings.TrimSpace(segments[i]); segment != "" {
|
||||
return segment
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func kbCitationRawFileURL(collection, entryKey, userID string) string {
|
||||
collection = strings.TrimSpace(collection)
|
||||
entryKey = strings.Trim(strings.TrimSpace(entryKey), "/")
|
||||
if collection == "" || entryKey == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var escapedEntrySegments []string
|
||||
for _, segment := range strings.Split(entryKey, "/") {
|
||||
if segment == "" {
|
||||
continue
|
||||
}
|
||||
escapedEntrySegments = append(escapedEntrySegments, url.PathEscape(segment))
|
||||
}
|
||||
if len(escapedEntrySegments) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
sourceURL := "/api/agents/collections/" + url.PathEscape(collection) + "/entries-raw/" + strings.Join(escapedEntrySegments, "/")
|
||||
if userID != "" {
|
||||
query := url.Values{}
|
||||
query.Set("user_id", userID)
|
||||
sourceURL += "?" + query.Encode()
|
||||
}
|
||||
return sourceURL
|
||||
}
|
||||
|
||||
func escapeMarkdownLinkText(text string) string {
|
||||
text = strings.ReplaceAll(text, `\`, `\\`)
|
||||
text = strings.ReplaceAll(text, "[", `\[`)
|
||||
text = strings.ReplaceAll(text, "]", `\]`)
|
||||
return text
|
||||
}
|
||||
@@ -167,12 +167,10 @@ func ExecuteChatWithLLM(ctx context.Context, llm cogito.LLM, cfg *AgentConfig, m
|
||||
}
|
||||
}
|
||||
|
||||
kbCitations := &kbCitationList{}
|
||||
if cfg.EnableKnowledgeBase && (kbMode == KBModeAutoSearch || kbMode == KBModeBoth) {
|
||||
kbResult := KBAutoSearchPrompt(ctx, effectiveURL, effectiveKey, cfg.Name, message, cfg.KnowledgeBaseResults, userID)
|
||||
if kbResult.Prompt != "" {
|
||||
fragment = fragment.AddMessage(cogito.SystemMessageRole, kbResult.Prompt)
|
||||
kbCitations.AddKBCitations(kbResult.Citations)
|
||||
kbResults := KBAutoSearchPrompt(ctx, effectiveURL, effectiveKey, cfg.Name, message, cfg.KnowledgeBaseResults, userID)
|
||||
if kbResults != "" {
|
||||
fragment = fragment.AddMessage(cogito.SystemMessageRole, kbResults)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,7 +197,7 @@ func ExecuteChatWithLLM(ctx context.Context, llm cogito.LLM, cfg *AgentConfig, m
|
||||
}
|
||||
cogitoOpts = append(cogitoOpts, cogito.WithTools(
|
||||
cogito.NewToolDefinition(
|
||||
KBSearchMemoryTool{APIURL: effectiveURL, APIKey: effectiveKey, Collection: cfg.Name, MaxResults: kbResults, UserID: userID, CitationCollector: kbCitations},
|
||||
KBSearchMemoryTool{APIURL: effectiveURL, APIKey: effectiveKey, Collection: cfg.Name, MaxResults: kbResults, UserID: userID},
|
||||
KBSearchMemoryArgs{},
|
||||
"search_memory",
|
||||
"Search the knowledge base for relevant information",
|
||||
@@ -338,8 +336,6 @@ func ExecuteChatWithLLM(ctx context.Context, llm cogito.LLM, cfg *AgentConfig, m
|
||||
if cfg.StripThinkingTags && response != "" {
|
||||
response = stripThinkingTags(response)
|
||||
}
|
||||
responseForMemory := response
|
||||
response = AppendKBCitations(response, cfg.Name, userID, kbCitations.Citations())
|
||||
|
||||
// Save conversation to KB when long-term memory is enabled.
|
||||
// Use a detached context: the parent ctx may be cancelled (e.g. in distributed
|
||||
@@ -348,7 +344,7 @@ func ExecuteChatWithLLM(ctx context.Context, llm cogito.LLM, cfg *AgentConfig, m
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
saveConversationToKB(ctx, llm, effectiveURL, effectiveKey, cfg, message, responseForMemory, userID)
|
||||
saveConversationToKB(ctx, llm, effectiveURL, effectiveKey, cfg, message, response, userID)
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user