Compare commits

..

2 Commits

Author SHA1 Message Date
Ettore Di Giacinto
b6fed26271 chore(turboquant): retreat pin to 4c1c3ac0 to skip fork GPU regression
CI on the prior 2cbfdc62 pin confirmed our grpc-server.cpp/patch fix
works (tests-turboquant-grpc + all multiarch turboquant builds passed),
but every GPU singlearch turboquant build now hits a static-assertion
error in the fork's own ggml/src/ggml-cuda/fattn-mma-f16.cuh — a
regression introduced by the May 14 #22880 `HIP: RDNA3 mma FA` refactor
(file went from 1855 to 2049 lines).

4c1c3ac0 (2026-05-13 22:12 UTC) is the last commit before that refactor
and still has every API piece grpc-server.cpp depends on (DRAFT_SIMPLE
enum, nested common_params_speculative, model_tgt, get_media_marker(),
common_speculative_types_from_names). MTP support landed later (May 16)
and is not exercised by grpc-server.cpp.

Assisted-by: Claude:claude-opus-4-7
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-05-21 15:54:38 +00:00
Ettore Di Giacinto
66af748332 chore(turboquant): bump to 2cbfdc62 and retire obsolete grpc-server patches
The turboquant fork rebased past ggml-org/llama.cpp#21962, #22397 and
#22838, so common_params_speculative now uses the nested draft/
ngram_simple/ngram_mod layout, server_context_impl exposes model_tgt
(not model), and get_media_marker() is provided. The compatibility
shims in patch-grpc-server.sh were rewriting the shared grpc-server.cpp
to the pre-refactor flat layout, which no longer matches the fork and
broke the build (see PR #9912 CI failure).

Keep only the fork-specific kv_cache_types[] insertion for the
TURBO2_0 / TURBO3_0 / TURBO4_0 enum entries. The dormant
LOCALAI_LEGACY_LLAMA_CPP_SPEC #ifdef blocks in
backend/cpp/llama-cpp/grpc-server.cpp stay as an escape hatch if a
future fork bump regresses.

Assisted-by: Claude:claude-opus-4-7
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-05-21 11:03:46 +00:00
512 changed files with 2627 additions and 45778 deletions

View File

@@ -16,8 +16,7 @@ side (`pkg/oci/cosignverify` plus the gallery YAML).
per-arch manifest before checking signatures.
- **Storage:** Signatures are written as OCI 1.1 referrers
(`--registry-referrers-mode=oci-1-1`) in the new Sigstore bundle format
(current cosign releases do this by default; no `--new-bundle-format`
flag). No `:sha256-<hex>.sig` tag clutter.
(`--new-bundle-format`). No `:sha256-<hex>.sig` tag clutter.
- **Consumer:** `pkg/oci/cosignverify` discovers the bundle via the
referrers API, hands it to `sigstore-go`, and verifies it against the
policy declared in the gallery YAML (`Gallery.Verification`).
@@ -34,14 +33,15 @@ to sign. The job needs:
- `permissions: { id-token: write, contents: read }` at the job level so
the runner can exchange its GitHub OIDC token for a Fulcio cert.
- `sigstore/cosign-installer@v3` step (current cosign releases already
default to the new bundle format).
- `sigstore/cosign-installer@v3` step (cosign ≥ 2.2 for
`--new-bundle-format`).
- After each `docker buildx imagetools create`, resolve the resulting
list digest with `docker buildx imagetools inspect <tag> --format
'{{.Manifest.Digest}}'` and sign:
```sh
cosign sign --yes --recursive \
--new-bundle-format \
--registry-referrers-mode=oci-1-1 \
"${REGISTRY_REPO}@${DIGEST}"
```
@@ -49,12 +49,6 @@ cosign sign --yes --recursive \
Sign by digest, never by tag — signing by tag binds the signature to
whatever the tag points at *now*, and a subsequent tag push orphans it.
`--registry-referrers-mode=oci-1-1` is still gated behind
`COSIGN_EXPERIMENTAL=1` in cosign v2.4.x (set at the job env level in
`backend_merge.yml`). Re-evaluate when bumping the pinned cosign release
— newer versions are expected to graduate this flag and the env var can
then be dropped.
`backend_build_darwin.yml` builds and pushes single-arch darwin images
that bypass the manifest-list merge. If/when those entries get a gallery
`verification:` policy, the equivalent cosign step has to land there

View File

@@ -15,32 +15,3 @@ Let's say the user wants to build a particular backend for a given platform. For
- Unless the user specifies that they want you to run the command, then just print it because not all agent frontends handle long running jobs well and the output may overflow your context
- The user may say they want to build AMD or ROCM instead of hipblas, or Intel instead of SYCL or NVIDIA insted of l4t or cublas. Ask for confirmation if there is ambiguity.
- Sometimes the user may need extra parameters to be added to `docker build` (e.g. `--platform` for cross-platform builds or `--progress` to view the full logs), in which case you can generate the `docker build` command directly.
## Test coverage gate
The core Go suites (`./pkg`, `./core`, plus the in-process integration suite `./tests/e2e`) are covered by a **strict, monotonic coverage ratchet**:
- `make test-coverage` — runs the suites with `covermode=atomic` instrumentation and writes a merged profile to `coverage/coverage.out`. Uses the same prerequisites as `make test`.
- **`--coverpkg` (`COVERAGE_COVERPKG = core/...,pkg/...`):** coverage is attributed to the core+pkg packages, not just the package under test. This is what lets the in-process `tests/e2e` suite (which drives the real HTTP server over loopback via `application.New`) credit the `core/http/endpoints/...` handlers it exercises — folding it in roughly doubled endpoint coverage (e.g. `endpoints/openai` 13.6% → 52%). The denominator is therefore *all* of `core`+`pkg` (minus generated proto, dropped via `COVERAGE_EXCLUDE_RE`), so the number isn't comparable to a plain per-package figure.
- **Integration suites (`COVERAGE_E2E_ROOTS = ./tests/e2e`)** run non-recursively (excludes `tests/e2e/distributed`, which needs containers) with `--label-filter=!real-models` (those need a downloaded model) against the mock backend built by `prepare-test`. `tests/integration` is deliberately excluded — it needs `make backends/local-store`, which the coverage CI job doesn't build.
- **Flake note:** folding integration tests into a *strict* gate means a hard e2e failure (or a spec that silently stops running) can fail the coverage gate, not just the test. `--flake-attempts` absorbs transient retryable failures; covermode=atomic keeps line coverage deterministic otherwise.
- **Why one ginkgo run per root (`scripts/run-coverage.sh`):** passing several recursive roots to a *single* ginkgo invocation (e.g. `ginkgo -r ./pkg ./core`) only merges **one** root's coverprofile into `--output-dir`/`--coverprofile` — the others are silently dropped. Verified with ginkgo 2.29.0: `-r ./pkg ./core` yields only `./pkg` coverage, while `-r ./core` alone yields all 34 core packages. So the script runs each root separately and concatenates the (disjoint) profiles. Don't "simplify" it back to a single multi-root invocation — that's how `core/` (including all of `core/http`, ~7.4k statements) silently vanished from the number before.
- **Build tags (`COVERAGE_TAGS`, passed via `GINKGO_TAGS`):** defaults to `debug auth`. The `auth` tag is required to compile the real (sqlite-backed) auth implementation and its ~150 `//go:build auth` tests — without it those files aren't built, the tests don't run, and the gate scores auth against a stub (~3.7% instead of ~38%). If you add new tag-gated tests, extend `COVERAGE_TAGS` or they won't count (and likely won't run in CI at all).
- `make test-coverage-check` — runs `test-coverage`, then `scripts/coverage-check.sh` fails the build if total coverage is **below** the committed baseline in `coverage-baseline.txt`. The Linux job in `.github/workflows/test.yml` runs this instead of `make test`.
- `make test-coverage-baseline` — regenerates and overwrites `coverage-baseline.txt` from the current run.
- `make install-hooks` — sets `core.hooksPath` to the versioned `.githooks/`, whose `pre-commit` runs checks scoped to what's staged: Go changes → `make lint` + `make test-coverage-check`; `core/http/react-ui/` changes → `make test-ui-coverage-check` (Playwright e2e + UI coverage gate). A commit touching neither is skipped; bypass with `git commit --no-verify`. The hook resolves golangci-lint's new-from base to `upstream/master``origin/master``master`, so it works from a fork clone where `origin/master` is stale (passed to `make lint` via `LINT_NEW_FROM`).
### React UI coverage
The React UI (`core/http/react-ui/`) has **no component/unit tests** — its only tests are the Playwright e2e specs in `e2e/`, which run against the real app served by `tests/e2e-ui/ui-test-server` (the dist is `//go:embed`ed, so the server is rebuilt per coverage run). Those specs do genuinely exercise the UI (clicks, `fill`, `setInputFiles`, `getByRole`/`getByText`, visibility/value assertions).
- `make test-ui-coverage` — builds an istanbul-instrumented bundle (`COVERAGE=true`, via `vite-plugin-istanbul` with `forceBuildInstrument: true` — the plugin skips production builds otherwise), re-embeds it into `ui-test-server` (the dist is `//go:embed`ed), runs the Playwright specs, and writes an `nyc` report to `core/http/react-ui/coverage/`. The specs import `{ test, expect }` from `e2e/coverage-fixtures.js` (re-exports Playwright's, plus harvests `window.__coverage__` into `.nyc_output/` after each test). Instrumentation is off unless `COVERAGE=true`, so dev/prod builds and plain `make test-ui-e2e` are unaffected (the fixture no-ops when `window.__coverage__` is absent).
- **Browser:** the flake dev shell ships `chromium` and exports `PLAYWRIGHT_CHROMIUM_PATH`; `playwright.config.js` uses it via `launchOptions.executablePath`, and the Makefile skips `playwright install` when it's set. This avoids Playwright's downloaded browser, which can't resolve system libs (`libglib-2.0`, …) on NixOS. In CI (no `PLAYWRIGHT_CHROMIUM_PATH`) the Makefile falls back to `playwright install --with-deps chromium`.
- The app is a React SPA, so coverage accumulates across in-app navigation within a test; a full `page.goto`/reload resets it.
- `.nycrc.json` uses `all: true`, so **every `src/**` file is in the report**, including 0%-coverage ones — that's how you spot features with no test at all (sort the HTML report or `coverage-summary.json` by line% ascending).
- **UI coverage gate:** `make test-ui-coverage-check` runs the suite then `scripts/ui-coverage-check.sh`, failing if total line coverage drops more than `UI_COVERAGE_TOLERANCE` (default **1.0pp**) below `core/http/react-ui/coverage-baseline.txt`. `make test-ui-coverage-baseline` regenerates the baseline. **Why a tolerance (unlike the strict Go gate):** UI e2e line coverage is *non-deterministic* — async/debounced paths (e.g. the VRAM estimate's 500ms debounce) make identical specs vary ~0.5pp run-to-run, so a zero-tolerance gate would flake. Keep the tolerance just above the observed jitter. Run in CI (`tests-ui-e2e.yml`) and pre-commit on `core/http/react-ui/` changes.
Rules:
- The gate is **strict — there is no tolerance**. Any decrease fails, regardless of how many lines a PR adds or deletes. `covermode=atomic` makes line coverage deterministic, so there's no run-to-run jitter to excuse.
- When a change legitimately **raises** coverage, run `make test-coverage-baseline` and **commit** the updated `coverage-baseline.txt` so the ratchet moves up. Never lower the baseline by hand.
- If you can't get coverage back to baseline, the fix is to **add tests**, not to edit the baseline.

View File

@@ -50,17 +50,6 @@ Do not mix styles within a package. If you are extending tests in a package that
This is enforced by `golangci-lint` via the `forbidigo` linter (see `.golangci.yml`); calls like `t.Errorf` / `t.Fatalf` / `t.Run` / `t.Skip` / `t.Logf` are flagged. Run `make lint` locally before submitting; the same check runs in CI (`.github/workflows/lint.yml`).
## Outbound HTTP
All outbound HTTP must go through `github.com/mudler/LocalAI/pkg/httpclient` rather than the standard library's default client. Use `httpclient.New(...)` (no body deadline — safe for streaming/SSE) or `httpclient.NewWithTimeout(d, ...)` (simple request/response). Both **refuse redirects by default** and set a TLS 1.2 floor.
The reason is GHSA-3mj3-57v2-4636: the std default client follows redirects, and on a *cross-host* redirect Go forwards custom credential headers (e.g. Anthropic's `x-api-key`) to the redirect target, leaking the secret. `httpclient` fails closed instead.
- Need to follow redirects (download CDNs, registry blobs, GitHub asset URLs)? Pass `httpclient.WithFollowRedirects()` — it still strips credential headers on any cross-host hop.
- Have a custom transport (IP-pinned dialer, HTTP/2 tuning, a credential-injecting `RoundTripper`)? Pass `httpclient.WithTransport(rt)`, basing the transport on `httpclient.HardenedTransport()` to keep the TLS floor. Handed a `*http.Client` by a library? `httpclient.Harden(c)` applies the policy in place.
This is enforced by `forbidigo` (see `.golangci.yml`): `http.DefaultClient` and `http.Get`/`Post`/`PostForm`/`Head` are flagged. The `&http.Client{}` composite literal can't be matched precisely by forbidigo without also flagging legitimate `*http.Client` type references, so that form is caught by review — don't construct raw clients.
## Documentation
The project documentation is located in `docs/content`. When adding new features or changing existing functionality, it is crucial to update the documentation to reflect these changes. This helps users understand how to use the new capabilities and ensures the documentation stays relevant.

View File

@@ -4,7 +4,6 @@
.devcontainer
models
backends
volumes
examples/chatbot-ui/models
backend/go/image/stablediffusion-ggml/build/
backend/go/*/build
@@ -22,27 +21,3 @@ __pycache__
# backend virtual environments
**/venv
backend/python/**/source
# In-place llama.cpp clone + per-variant build copies. The Makefile
# clones llama.cpp itself at the pinned LLAMA_VERSION; if a stale
# local checkout is COPY'd into the image, the `llama.cpp:` target
# sees the directory and skips re-cloning, so grpc-server.cpp ends
# up compiled against whatever (likely older) commit the host had.
backend/cpp/llama-cpp/llama.cpp
backend/cpp/llama-cpp-*-build
# Rust backend build output (sources are tracked; target/ is generated)
backend/rust/*/target
# Local-only artifacts that bloat the build context but the image never needs.
# Saved image tarballs, locally-installed backends, the host-built binary, and
# assorted tool/scratch dirs. None of these are git-tracked.
backend-images
local-backends
local-ai
.crush
protoc
tests
# Installed via npm inside the build stage; no need to ship the host copy.
**/node_modules

View File

@@ -1,60 +0,0 @@
#!/usr/bin/env sh
#
# LocalAI pre-commit hook. Install it (once per clone) with:
#
# make install-hooks
#
# Runs only the checks relevant to what's staged:
# - Go files -> make lint + make test-coverage-check
# - core/http/react-ui -> make test-ui-coverage-check (Playwright e2e + gate)
# A commit touching neither is skipped entirely (docs/YAML/etc. can't change
# lint findings, Go coverage, or the UI).
#
# To bypass for a single commit (e.g. a WIP checkpoint): git commit --no-verify
set -eu
repo_root="$(git rev-parse --show-toplevel)"
cd "$repo_root"
staged="$(git diff --cached --name-only --diff-filter=ACMRD)"
go_changed=0
ui_changed=0
if echo "$staged" | grep -qE '\.go$'; then go_changed=1; fi
if echo "$staged" | grep -qE '^core/http/react-ui/'; then ui_changed=1; fi
if [ "$go_changed" -eq 0 ] && [ "$ui_changed" -eq 0 ]; then
echo "pre-commit: no Go or React UI changes staged — skipping."
exit 0
fi
if [ "$go_changed" -eq 1 ]; then
# Resolve the ref golangci-lint's new-from-merge-base should compare
# against. .golangci.yml pins origin/master, which is correct in CI
# (origin == the canonical repo) but wrong from a fork clone, where
# origin/master lags behind and lint would report the whole upstream
# backlog. Prefer upstream/master, then origin/master, then master.
lint_base=""
for ref in upstream/master origin/master master; do
if git rev-parse --verify --quiet "${ref}^{commit}" >/dev/null 2>&1; then
lint_base="$ref"
break
fi
done
echo "pre-commit ▶ golangci-lint (make lint${lint_base:+, new-from $lint_base})"
make lint LINT_NEW_FROM="$lint_base"
echo "pre-commit ▶ coverage gate (make test-coverage-check) — builds and runs the"
echo " pkg/core suites plus tests/e2e; can take a few minutes."
make test-coverage-check
fi
if [ "$ui_changed" -eq 1 ]; then
echo "pre-commit ▶ React UI e2e + coverage gate (make test-ui-coverage-check) —"
echo " rebuilds the UI + ui-test-server, runs the Playwright specs, and"
echo " fails if line coverage regressed; can take a couple of minutes."
make test-ui-coverage-check
fi
echo "pre-commit ✓ all relevant checks passed"

View File

@@ -690,19 +690,6 @@ include:
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-rfdetr-cpp'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "rfdetr-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
@@ -716,19 +703,6 @@ include:
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-parakeet-cpp'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "parakeet-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
@@ -1517,19 +1491,6 @@ include:
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-13-rfdetr-cpp'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "rfdetr-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
@@ -1543,19 +1504,6 @@ include:
backend: "sam3-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
platforms: 'linux/arm64'
skip-drivers: 'false'
tag-latest: 'auto'
tag-suffix: '-nvidia-l4t-cuda-13-arm64-rfdetr-cpp'
base-image: "ubuntu:24.04"
ubuntu-version: '2404'
runs-on: 'ubuntu-24.04-arm'
backend: "rfdetr-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
@@ -1569,19 +1517,6 @@ include:
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-13-parakeet-cpp'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "parakeet-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
@@ -1595,19 +1530,6 @@ include:
backend: "whisper"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
platforms: 'linux/arm64'
skip-drivers: 'false'
tag-latest: 'auto'
tag-suffix: '-nvidia-l4t-cuda-13-arm64-parakeet-cpp'
base-image: "ubuntu:24.04"
ubuntu-version: '2404'
runs-on: 'ubuntu-24.04-arm'
backend: "parakeet-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
@@ -2713,74 +2635,6 @@ include:
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
# rfdetr-cpp
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-cpu-rfdetr-cpp'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "rfdetr-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'sycl_f32'
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-intel-sycl-f32-rfdetr-cpp'
runs-on: 'ubuntu-latest'
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
skip-drivers: 'false'
backend: "rfdetr-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'sycl_f16'
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-intel-sycl-f16-rfdetr-cpp'
runs-on: 'ubuntu-latest'
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
skip-drivers: 'false'
backend: "rfdetr-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'vulkan'
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
platform-tag: 'amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-vulkan-rfdetr-cpp'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "rfdetr-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'vulkan'
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/arm64'
platform-tag: 'arm64'
tag-latest: 'auto'
tag-suffix: '-gpu-vulkan-rfdetr-cpp'
runs-on: 'ubuntu-24.04-arm'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "rfdetr-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'sycl_f32'
cuda-major-version: ""
cuda-minor-version: ""
@@ -2861,19 +2715,6 @@ include:
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2204'
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "0"
platforms: 'linux/arm64'
skip-drivers: 'false'
tag-latest: 'auto'
tag-suffix: '-nvidia-l4t-arm64-rfdetr-cpp'
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
runs-on: 'ubuntu-24.04-arm'
backend: "rfdetr-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2204'
# whisper
- build-type: ''
cuda-major-version: ""
@@ -2983,115 +2824,6 @@ include:
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
# parakeet-cpp
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
platform-tag: 'amd64'
tag-latest: 'auto'
tag-suffix: '-cpu-parakeet-cpp'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "parakeet-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/arm64'
platform-tag: 'arm64'
tag-latest: 'auto'
tag-suffix: '-cpu-parakeet-cpp'
runs-on: 'ubuntu-24.04-arm'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "parakeet-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'sycl_f32'
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-intel-sycl-f32-parakeet-cpp'
runs-on: 'ubuntu-latest'
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
skip-drivers: 'false'
backend: "parakeet-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'sycl_f16'
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-intel-sycl-f16-parakeet-cpp'
runs-on: 'ubuntu-latest'
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
skip-drivers: 'false'
backend: "parakeet-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'vulkan'
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
platform-tag: 'amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-vulkan-parakeet-cpp'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "parakeet-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'vulkan'
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/arm64'
platform-tag: 'arm64'
tag-latest: 'auto'
tag-suffix: '-gpu-vulkan-parakeet-cpp'
runs-on: 'ubuntu-24.04-arm'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "parakeet-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "0"
platforms: 'linux/arm64'
skip-drivers: 'false'
tag-latest: 'auto'
tag-suffix: '-nvidia-l4t-arm64-parakeet-cpp'
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
runs-on: 'ubuntu-24.04-arm'
backend: "parakeet-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2204'
- build-type: 'hipblas'
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-rocm-hipblas-parakeet-cpp'
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
runs-on: 'ubuntu-latest'
skip-drivers: 'false'
backend: "parakeet-cpp"
dockerfile: "./backend/Dockerfile.golang"
context: "./"
ubuntu-version: '2404'
# acestep-cpp
- build-type: ''
cuda-major-version: ""
@@ -4124,10 +3856,6 @@ includeDarwin:
tag-suffix: "-metal-darwin-arm64-whisper"
build-type: "metal"
lang: "go"
- backend: "parakeet-cpp"
tag-suffix: "-metal-darwin-arm64-parakeet-cpp"
build-type: "metal"
lang: "go"
- backend: "acestep-cpp"
tag-suffix: "-metal-darwin-arm64-acestep-cpp"
build-type: "metal"

View File

@@ -40,11 +40,6 @@ jobs:
id-token: write
env:
quay_username: ${{ secrets.quayUsername }}
# cosign v2.4.x still gates --registry-referrers-mode=oci-1-1 behind
# this flag. Without it, signing fails with:
# invalid argument "oci-1-1" for "--registry-referrers-mode" flag:
# in order to use mode "oci-1-1", you must set COSIGN_EXPERIMENTAL=1
COSIGN_EXPERIMENTAL: '1'
steps:
# Sparse checkout: the merge job needs `.github/scripts/` (for the
# keepalive cleanup script) but none of the source tree.
@@ -71,8 +66,7 @@ jobs:
# cosign signs each pushed manifest list with --recursive so the
# index and every per-arch entry get an attached Sigstore bundle.
# Recent cosign releases always emit the new bundle format, so
# there's no extra CLI flag to opt into it.
# 2.2+ is required for --new-bundle-format.
- name: Install cosign
if: github.event_name != 'pull_request'
uses: sigstore/cosign-installer@v3
@@ -159,6 +153,7 @@ jobs:
# manifest before checking signatures need the per-arch
# signatures, not just the list-level one.
cosign sign --yes --recursive \
--new-bundle-format \
--registry-referrers-mode=oci-1-1 \
"quay.io/go-skynet/local-ai-backends@${digest}"
@@ -185,6 +180,7 @@ jobs:
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
digest=$(docker buildx imagetools inspect "$first_tag" --format '{{.Manifest.Digest}}')
cosign sign --yes --recursive \
--new-bundle-format \
--registry-referrers-mode=oci-1-1 \
"localai/localai-backends@${digest}"

View File

@@ -30,10 +30,6 @@ jobs:
variable: "WHISPER_CPP_VERSION"
branch: "master"
file: "backend/go/whisper/Makefile"
- repository: "mudler/parakeet.cpp"
variable: "PARAKEET_VERSION"
branch: "master"
file: "backend/go/parakeet-cpp/Makefile"
- repository: "leejet/stable-diffusion.cpp"
variable: "STABLEDIFFUSION_GGML_VERSION"
branch: "master"
@@ -54,10 +50,6 @@ jobs:
variable: "SAM3_VERSION"
branch: "main"
file: "backend/go/sam3-cpp/Makefile"
- repository: "mudler/rf-detr.cpp"
variable: "RFDETR_VERSION"
branch: "main"
file: "backend/go/rfdetr-cpp/Makefile"
- repository: "predict-woo/qwen3-tts.cpp"
variable: "QWEN3TTS_CPP_VERSION"
branch: "main"

View File

@@ -106,7 +106,6 @@ jobs:
type=ref,event=branch
type=semver,pattern={{raw}}
type=sha
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
flavor: |
latest=${{ inputs.tag-latest }}
suffix=${{ inputs.tag-suffix }},onlatest=true

View File

@@ -80,7 +80,6 @@ jobs:
type=ref,event=branch
type=semver,pattern={{raw}}
type=sha
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
flavor: |
latest=${{ inputs.tag-latest }}
suffix=${{ inputs.tag-suffix }},onlatest=true

View File

@@ -11,7 +11,7 @@ jobs:
if: github.repository == 'mudler/LocalAI'
runs-on: ubuntu-latest
steps:
- uses: actions/stale@eb5cf3af3ac0a1aa4c9c45633dd1ae542a27a899 # v9
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v9
with:
stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
stale-pr-message: 'This PR is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 10 days.'

View File

@@ -37,7 +37,6 @@ jobs:
sglang: ${{ steps.detect.outputs.sglang }}
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
qwen3-tts-cpp: ${{ steps.detect.outputs.qwen3-tts-cpp }}
rfdetr-cpp: ${{ steps.detect.outputs.rfdetr-cpp }}
vibevoice-cpp: ${{ steps.detect.outputs.vibevoice-cpp }}
localvqe: ${{ steps.detect.outputs.localvqe }}
voxtral: ${{ steps.detect.outputs.voxtral }}
@@ -46,7 +45,6 @@ jobs:
speaker-recognition: ${{ steps.detect.outputs.speaker-recognition }}
sherpa-onnx: ${{ steps.detect.outputs.sherpa-onnx }}
whisper: ${{ steps.detect.outputs.whisper }}
parakeet-cpp: ${{ steps.detect.outputs.parakeet-cpp }}
steps:
- name: Checkout repository
uses: actions/checkout@v6
@@ -634,26 +632,6 @@ jobs:
- name: Build whisper backend image and run transcription gRPC e2e tests
run: |
make test-extra-backend-whisper-transcription
# Parakeet ASR via the parakeet-cpp backend (C++/ggml port of NeMo
# Parakeet). Drives AudioTranscription (offline, with word timestamps) on
# tdt_ctc-110m + the JFK 11s clip.
tests-parakeet-cpp-grpc-transcription:
needs: detect-changes
if: needs.detect-changes.outputs.parakeet-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest
timeout-minutes: 90
steps:
- name: Clone
uses: actions/checkout@v6
with:
submodules: true
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: '1.25.4'
- name: Build parakeet-cpp backend image and run transcription gRPC e2e tests
run: |
make test-extra-backend-parakeet-cpp-transcription
# VITS TTS via the sherpa-onnx backend. Drives both TTS (file write) and
# TTSStream (PCM chunks) on the e2e-backends harness.
tests-sherpa-onnx-grpc-tts:
@@ -865,42 +843,6 @@ jobs:
- name: Test qwen3-tts-cpp
run: |
make --jobs=5 --output-sync=target -C backend/go/qwen3-tts-cpp test
# Per-backend smoke for rfdetr-cpp: builds the .so + Go binary and runs
# `make -C backend/go/rfdetr-cpp test`. test.sh fetches the small (~20 MB)
# rfdetr-nano-q8_0 GGUF from the published mudler/rfdetr-cpp-nano HF repo
# via curl and synthesises a tiny PNG to exercise the wire protocol.
tests-rfdetr-cpp:
needs: detect-changes
if: needs.detect-changes.outputs.rfdetr-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest
steps:
- name: Clone
uses: actions/checkout@v6
with:
submodules: true
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install -y build-essential cmake curl libopenblas-dev
- name: Setup Go
uses: actions/setup-go@v5
- name: Display Go version
run: go version
- name: Proto Dependencies
run: |
# Install protoc
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
rm protoc.zip
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
PATH="$PATH:$HOME/go/bin" make protogen-go
- name: Build rfdetr-cpp
run: |
make --jobs=5 --output-sync=target -C backend/go/rfdetr-cpp
- name: Test rfdetr-cpp
run: |
make --jobs=5 --output-sync=target -C backend/go/rfdetr-cpp test
# Per-backend smoke for vibevoice-cpp: builds the .so + Go binary and
# runs `make -C backend/go/vibevoice-cpp test`. test.sh auto-downloads
# the published mudler/vibevoice.cpp-models bundle (TTS Q8_0 + ASR Q4_K

View File

@@ -53,22 +53,9 @@ jobs:
node-version: '22'
- name: Build React UI
run: make react-ui
# Runs the core suite with coverage and fails if total coverage dropped
# below the committed baseline (coverage-baseline.txt). The gate is
# strict — any decrease fails. Raise the baseline with
# `make test-coverage-baseline` and commit it when coverage rises.
- name: Test (with coverage gate)
- name: Test
run: |
PATH="$PATH:/root/go/bin" make --jobs 5 --output-sync=target test-coverage-check
- name: Upload coverage report
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: coverage-linux
path: |
coverage/coverage.out
coverage/coverage.html
if-no-files-found: ignore
PATH="$PATH:/root/go/bin" make --jobs 5 --output-sync=target test
- name: Setup tmate session if tests fail
if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.23

View File

@@ -37,10 +37,6 @@ jobs:
uses: actions/setup-node@v6
with:
node-version: '22'
- name: Setup Bun
uses: oven-sh/setup-bun@v2
with:
bun-version: '1.3.11'
- name: Proto Dependencies
run: |
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
@@ -52,12 +48,16 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install -y build-essential libopus-dev
# Builds an instrumented UI bundle, runs the Playwright specs, and fails
# if line coverage regressed beyond the jitter tolerance (the gate is
# in `make test-ui-coverage-check`). PLAYWRIGHT_CHROMIUM_PATH is unset
# here, so scripts/ensure-playwright-browser.sh installs Chromium via apt.
- name: Run UI e2e + coverage gate
run: PATH="$PATH:$HOME/go/bin" make test-ui-coverage-check
- name: Build UI test server
run: PATH="$PATH:$HOME/go/bin" make build-ui-test-server
- name: Install Playwright
working-directory: core/http/react-ui
run: |
npm install
npx playwright install --with-deps chromium
- name: Run Playwright tests
working-directory: core/http/react-ui
run: npx playwright test
- name: Upload Playwright report
if: ${{ failure() }}
uses: actions/upload-artifact@v7
@@ -65,14 +65,6 @@ jobs:
name: playwright-report
path: core/http/react-ui/playwright-report/
retention-days: 7
- name: Upload UI coverage report
if: ${{ always() }}
uses: actions/upload-artifact@v7
with:
name: ui-coverage
path: core/http/react-ui/coverage/
if-no-files-found: ignore
retention-days: 7
- name: Setup tmate session if tests fail
if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.23

14
.gitignore vendored
View File

@@ -26,10 +26,6 @@ go-bert
LocalAI
/local-ai
/local-ai-launcher
# Root-level build artifacts when running `go build ./...` against
# Go backend packages whose main lives under backend/go/.
/cloud-proxy
/local-store
# prevent above rules from omitting the helm chart
!charts/*
# prevent above rules from omitting the api/localai folder
@@ -70,17 +66,10 @@ docs/static/gallery.html
# per-developer customization files for the development container
.devcontainer/customization/*
# Coverage profiles (the committed baseline is coverage-baseline.txt)
/coverage/
# React UI build artifacts (keep placeholder dist/index.html)
core/http/react-ui/node_modules/
core/http/react-ui/dist
# React UI coverage (vite-plugin-istanbul + nyc, via `make test-ui-coverage`)
core/http/react-ui/.nyc_output/
core/http/react-ui/coverage/
# Extracted backend binaries for container-based testing
local-backends/
@@ -88,6 +77,3 @@ local-backends/
tests/e2e-ui/ui-test-server
core/http/react-ui/playwright-report/
core/http/react-ui/test-results/
# Local worktrees
.worktrees/

View File

@@ -56,20 +56,6 @@ linters:
# are exempt — see linters.exclusions.rules below.
- pattern: '^os\.(Getenv|LookupEnv|Environ)$'
msg: 'Plumb config through ApplicationConfig (or the relevant CLI struct) instead of reading env directly. CLI entry points (core/cli/) bind env vars via kong''s `env:` tag — that is the only sanctioned env→struct boundary. See .agents/coding-style.md.'
# Outbound HTTP must go through pkg/httpclient, which refuses redirects
# by default and sets a TLS floor. The std-library default client and
# the http.Get/Post/... convenience helpers follow redirects (up to 10)
# and, on a cross-host redirect, forward custom credential headers such
# as Anthropic's x-api-key to the redirect target — leaking the secret
# (GHSA-3mj3-57v2-4636). forbidigo can't precisely match the
# `&http.Client{}` composite literal without also flagging legitimate
# `*http.Client` type references, so that form is enforced by
# convention + review; these two patterns catch the implicit-default
# client, which is the common footgun.
- pattern: '^http\.DefaultClient$'
msg: 'Use pkg/httpclient (httpclient.New / NewWithTimeout) instead of http.DefaultClient — the std client follows redirects and leaks credential headers cross-host (GHSA-3mj3-57v2-4636). See .agents/coding-style.md.'
- pattern: '^http\.(Get|Post|PostForm|Head)$'
msg: 'Use pkg/httpclient (httpclient.New / NewWithTimeout) instead of http.Get/Post/PostForm/Head — these use http.DefaultClient, which follows redirects and leaks credential headers cross-host (GHSA-3mj3-57v2-4636). See .agents/coding-style.md.'
exclusions:
paths:
# Upstream whisper.cpp source tree fetched by the whisper backend Makefile.
@@ -109,18 +95,3 @@ linters:
- path: _test\.go$
text: 'os\.(Getenv|LookupEnv|Environ)'
linters: [forbidigo]
# pkg/httpclient is the sanctioned home for outbound HTTP clients; it
# necessarily references net/http directly.
- path: ^pkg/httpclient/
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
linters: [forbidigo]
# Tests drive local httptest servers where redirect/TLS hardening is
# irrelevant; the std client is fine there.
- path: _test\.go$
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
linters: [forbidigo]
# Vendored upstream whisper.cpp Go bindings are a separate module and
# cannot import pkg/httpclient.
- path: ^backend/go/whisper/sources/
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
linters: [forbidigo]

View File

@@ -198,7 +198,6 @@ For AI-assisted development, see [`AGENTS.md`](AGENTS.md) (or the equivalent [`C
- Prefer modern Go idioms — for example, use `any` instead of `interface{}`.
- Use [`golangci-lint`](https://golangci-lint.run) to catch common issues before submitting a PR.
- Run `make install-hooks` once per clone to enable the pre-commit hook: Go changes run `make lint` + the coverage gate (`make test-coverage-check`); `core/http/react-ui/` changes run the Playwright e2e suite (`make test-ui`). Bypass a single commit with `git commit --no-verify`.
- Use [`github.com/mudler/xlog`](https://github.com/mudler/xlog) for logging (same API as `slog`). Do not use `fmt.Println` or the standard `log` package for operational logging.
- Use tab indentation for Go files (as defined in `.editorconfig`).

161
Makefile
View File

@@ -1,5 +1,5 @@
# Disable parallel execution for backend builds
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
GOCMD=go
GOTEST=$(GOCMD) test
@@ -69,41 +69,10 @@ else
GORELEASER=$(shell which goreleaser)
endif
TEST_PATHS?=./api/... ./pkg/... ./core/... ./backend/go/cloud-proxy/... ./backend/go/local-store/...
## Coverage output and the committed baseline that CI compares against.
## The gate is strict: total coverage must never decrease (no tolerance).
## covermode=atomic makes line coverage deterministic regardless of test
## ordering or flake retries, so there is no run-to-run jitter to absorb.
COVERAGE_DIR?=$(abspath ./coverage)
COVERAGE_PROFILE?=$(COVERAGE_DIR)/coverage.out
COVERAGE_BASELINE?=coverage-baseline.txt
## Coverage is collected one recursive root at a time and merged (see
## scripts/run-coverage.sh): passing several recursive roots to a single
## ginkgo invocation only keeps one root's coverprofile. Mirrors TEST_PATHS
## minus ./api (which doesn't exist).
COVERAGE_ROOTS?=./pkg ./core
## Build tags for the coverage build. `auth` is required to compile the real
## auth implementation and its ~150 `//go:build auth` tests (otherwise they're
## invisible and the gate scores auth against a stub). `debug` matches `test`.
COVERAGE_TAGS?=debug auth
## Coverage is attributed to these packages via --coverpkg, so the in-process
## integration suites (COVERAGE_E2E_ROOTS) credit the core/http handlers they
## drive over HTTP — not just their own test package.
COVERAGE_COVERPKG?=github.com/mudler/LocalAI/core/...,github.com/mudler/LocalAI/pkg/...
## In-process integration suites folded into coverage. Run non-recursively
## (excludes tests/e2e/distributed, which needs containers) with the mock
## backend built by prepare-test. real-models specs need a downloaded model,
## so they're filtered out. NOTE: tests/integration is intentionally NOT here —
## it needs the local-store backend built (`make backends/local-store`), which
## the coverage CI job doesn't do.
COVERAGE_E2E_ROOTS?=./tests/e2e
COVERAGE_E2E_LABELS?=!real-models
## Drop generated protobuf from the denominator (it has no tests by design).
COVERAGE_EXCLUDE_RE?=grpc/proto/.*[.]pb[.]go
TEST_PATHS?=./api/... ./pkg/... ./core/...
.PHONY: all test test-coverage test-coverage-baseline test-coverage-check test-ui test-ui-coverage-baseline test-ui-coverage-check install-hooks build vendor lint lint-all
.PHONY: all test build vendor lint lint-all
all: help
@@ -201,36 +170,6 @@ test: prepare-test
OPUS_SHIM_LIBRARY=$(abspath ./pkg/opus/shim/libopusshim.so) \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS)
## Runs the core suite ($(TEST_PATHS)) with statement-coverage instrumentation
## and writes a merged profile to $(COVERAGE_PROFILE). Deliberately omits
## --fail-fast so a single failure doesn't truncate the coverage number, and
## uses covermode=atomic so the result is deterministic. Prints the total.
test-coverage: prepare-test
@echo 'Running tests with coverage'
GINKGO_TAGS="$(COVERAGE_TAGS)" \
COVERAGE_COVERPKG="$(COVERAGE_COVERPKG)" \
COVERAGE_E2E_ROOTS="$(COVERAGE_E2E_ROOTS)" \
COVERAGE_E2E_LABELS="$(COVERAGE_E2E_LABELS)" \
COVERAGE_EXCLUDE_RE='$(COVERAGE_EXCLUDE_RE)' \
OPUS_SHIM_LIBRARY=$(abspath ./pkg/opus/shim/libopusshim.so) \
scripts/run-coverage.sh $(COVERAGE_DIR) $(COVERAGE_PROFILE) $(TEST_FLAKES) $(COVERAGE_ROOTS)
@$(GOCMD) tool cover -html=$(COVERAGE_PROFILE) -o $(COVERAGE_DIR)/coverage.html
@$(GOCMD) tool cover -func=$(COVERAGE_PROFILE) | tail -n1
## Writes the current total coverage to $(COVERAGE_BASELINE). Run this (and
## commit the result) whenever a change legitimately raises coverage so the
## ratchet moves up. Never lower it by hand.
test-coverage-baseline: test-coverage
@$(GOCMD) tool cover -func=$(COVERAGE_PROFILE) | awk '/^total:/{gsub(/%/,"",$$NF); print $$NF}' > $(COVERAGE_BASELINE)
@echo "Saved coverage baseline: $$(cat $(COVERAGE_BASELINE))%"
## CI gate: fails if total coverage dropped more than COVERAGE_TOLERANCE
## (default 0.5pp) below the committed baseline. A small tolerance absorbs the
## run-to-run jitter from the in-process tests/e2e suite folded in via
## --coverpkg (timing-dependent which handler lines execute).
test-coverage-check: test-coverage
@scripts/coverage-check.sh $(COVERAGE_PROFILE) $(COVERAGE_BASELINE)
########################################################
## Lint
########################################################
@@ -246,17 +185,12 @@ test-coverage-check: test-coverage
## everything else automatically, so new packages are scanned by default.
LINT_EXCLUDE_DIRS_RE=/(backend/go/(piper|silero-vad|llm)|cmd/launcher)(/|$$)
## Set LINT_NEW_FROM to a git ref to override .golangci.yml's
## new-from-merge-base (origin/master). Useful from a fork clone where
## origin/master is stale relative to the canonical repo — the pre-commit
## hook passes the resolved upstream ref here so local lint matches CI.
LINT_NEW_FROM?=
lint:
@command -v golangci-lint >/dev/null 2>&1 || { \
echo 'golangci-lint not installed. Install: go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest'; \
exit 1; \
}
golangci-lint run $(if $(LINT_NEW_FROM),--new-from-merge-base=$(LINT_NEW_FROM),) $$(go list -e -f '{{.Dir}}' ./... | grep -vE '$(LINT_EXCLUDE_DIRS_RE)')
golangci-lint run $$(go list -e -f '{{.Dir}}' ./... | grep -vE '$(LINT_EXCLUDE_DIRS_RE)')
## Like `lint` but reports every issue, including the pre-existing baseline
## that `lint` ignores via .golangci.yml's new-from-merge-base. Use this to
@@ -268,17 +202,6 @@ lint-all:
}
golangci-lint run --new=false --new-from-merge-base= --new-from-rev= $$(go list -e -f '{{.Dir}}' ./... | grep -vE '$(LINT_EXCLUDE_DIRS_RE)')
########################################################
## Git hooks
########################################################
## Points git at the versioned .githooks/ directory so the pre-commit hook
## (lint + coverage gate) runs locally. Run once per clone. Undo with:
## `git config --unset core.hooksPath`. Skip a single commit with
## `git commit --no-verify`.
install-hooks:
git config core.hooksPath .githooks
@echo 'Installed git hooks: core.hooksPath -> .githooks (pre-commit runs lint + test-coverage-check on Go changes)'
########################################################
## E2E AIO tests (uses standard image with pre-configured models)
########################################################
@@ -345,13 +268,12 @@ prepare-e2e:
run-e2e-image:
docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --name e2e-tests-$(RANDOM) localai-tests
test-e2e: build-mock-backend build-cloud-proxy-backend prepare-e2e run-e2e-image
test-e2e: build-mock-backend prepare-e2e run-e2e-image
@echo 'Running e2e tests'
BUILD_TYPE=$(BUILD_TYPE) \
LOCALAI_API=http://$(E2E_BRIDGE_IP):5390 \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e
$(MAKE) clean-mock-backend
$(MAKE) clean-cloud-proxy-backend
$(MAKE) teardown-e2e
docker rmi localai-tests
@@ -558,7 +480,6 @@ prepare-test-extra: protogen-python
$(MAKE) -C backend/python/insightface
$(MAKE) -C backend/python/speaker-recognition
$(MAKE) -C backend/rust/kokoros kokoros-grpc
$(MAKE) -C backend/go/rfdetr-cpp
test-extra: prepare-test-extra
$(MAKE) -C backend/python/transformers test
@@ -585,7 +506,6 @@ test-extra: prepare-test-extra
$(MAKE) -C backend/python/insightface test
$(MAKE) -C backend/python/speaker-recognition test
$(MAKE) -C backend/rust/kokoros test
$(MAKE) -C backend/go/rfdetr-cpp test
##
## End-to-end gRPC tests that exercise a built backend container image.
@@ -991,19 +911,6 @@ test-extra-backend-whisper-transcription: docker-build-whisper
BACKEND_TEST_CAPS=health,load,transcription \
$(MAKE) test-extra-backend
## Audio transcription wrapper for the parakeet-cpp (parakeet.cpp ggml port)
## backend. Mirrors test-extra-backend-whisper-transcription: drives the
## AudioTranscription / AudioTranscriptionStream RPCs against a published
## Parakeet GGUF using the JFK 11s clip from whisper.cpp's CI samples. Not
## part of the default test suite - run explicitly once the pinned model URL
## is reachable.
test-extra-backend-parakeet-cpp-transcription: docker-build-parakeet-cpp
BACKEND_IMAGE=local-ai-backend:parakeet-cpp \
BACKEND_TEST_MODEL_URL=https://huggingface.co/mudler/parakeet-cpp-gguf/resolve/main/tdt_ctc-110m-f16.gguf \
BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
BACKEND_TEST_CAPS=health,load,transcription \
$(MAKE) test-extra-backend
## LocalVQE audio transform (joint AEC + noise suppression + dereverb).
## Exercises the audio_transform capability end-to-end: batch transform
## of a real WAV fixture and bidi streaming of synthetic silent frames.
@@ -1157,12 +1064,10 @@ BACKEND_DS4 = ds4|ds4|.|false|false
# Golang backends
BACKEND_PIPER = piper|golang|.|false|true
BACKEND_LOCAL_STORE = local-store|golang|.|false|true
BACKEND_CLOUD_PROXY = cloud-proxy|golang|.|false|true
BACKEND_HUGGINGFACE = huggingface|golang|.|false|true
BACKEND_SILERO_VAD = silero-vad|golang|.|false|true
BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true
BACKEND_WHISPER = whisper|golang|.|false|true
BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true
BACKEND_VOXTRAL = voxtral|golang|.|false|true
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true
@@ -1212,7 +1117,6 @@ BACKEND_KOKOROS = kokoros|rust|.|false|true
# C++ backends (Go wrapper with purego)
BACKEND_SAM3_CPP = sam3-cpp|golang|.|false|true
BACKEND_RFDETR_CPP = rfdetr-cpp|golang|.|false|true
# Helper function to build docker image for a backend
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
@@ -1245,12 +1149,10 @@ $(eval $(call generate-docker-build-target,$(BACKEND_TURBOQUANT)))
$(eval $(call generate-docker-build-target,$(BACKEND_DS4)))
$(eval $(call generate-docker-build-target,$(BACKEND_PIPER)))
$(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE)))
$(eval $(call generate-docker-build-target,$(BACKEND_CLOUD_PROXY)))
$(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
$(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
$(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP)))
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))
@@ -1293,14 +1195,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
$(eval $(call generate-docker-build-target,$(BACKEND_TINYGRAD)))
$(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS)))
$(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
$(eval $(call generate-docker-build-target,$(BACKEND_RFDETR_CPP)))
$(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
# Pattern rule for docker-save targets
docker-save-%: backend-images
docker save local-ai-backend:$* -o backend-images/$*.tar
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx
########################################################
### Mock Backend for E2E Tests
@@ -1312,12 +1213,6 @@ build-mock-backend: protogen-go
clean-mock-backend:
rm -f tests/e2e/mock-backend/mock-backend
build-cloud-proxy-backend: protogen-go
$(GOCMD) build -o tests/e2e/mock-backend/cloud-proxy ./backend/go/cloud-proxy
clean-cloud-proxy-backend:
rm -f tests/e2e/mock-backend/cloud-proxy
########################################################
### UI E2E Test Server
########################################################
@@ -1328,50 +1223,6 @@ build-ui-test-server: build-mock-backend react-ui protogen-go
test-ui-e2e: build-ui-test-server
cd core/http/react-ui && npm install && npx playwright install --with-deps chromium && npx playwright test
## Optional Playwright worker count for the UI e2e targets below. Pass
## UI_TEST_WORKERS=N (e.g. `make test-ui-coverage UI_TEST_WORKERS=20`) to
## override Playwright's default (cores/2). Empty by default so Playwright
## picks its own worker count.
UI_TEST_WORKERS ?=
PLAYWRIGHT_WORKERS_FLAG = $(if $(UI_TEST_WORKERS),--workers=$(UI_TEST_WORKERS),)
## Fast Playwright e2e run used by the pre-commit hook on React UI changes.
## Force-rebuilds the (non-instrumented) dist so the suite tests the working
## tree — not a stale dist the `react-ui` skip-guard would leave — re-embeds
## it into ui-test-server, and runs the specs. Uses the nix-provided browser
## when PLAYWRIGHT_CHROMIUM_PATH is set (flake dev shell), else falls back to
## downloading it as `test-ui-e2e` does.
test-ui: build-mock-backend protogen-go
cd core/http/react-ui && bun install && bun run build
$(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui
cd core/http/react-ui && sh $(CURDIR)/scripts/ensure-playwright-browser.sh && bunx playwright test $(PLAYWRIGHT_WORKERS_FLAG)
## React UI code coverage from the Playwright e2e suite. Builds a
## NON-instrumented bundle with source maps (COVERAGE_V8=true), re-embeds it
## into the ui-test-server (the dist is //go:embed'ed at compile time), runs the
## Playwright specs which collect native Chromium V8 coverage (PW_V8_COVERAGE=1)
## — far cheaper than istanbul's build-time counters (~40% faster end-to-end) —
## convert it to istanbul via v8-to-istanbul in the coverage fixture, and write
## an nyc report to core/http/react-ui/coverage/. Removes the dist afterwards so
## normal builds aren't served source-mapped assets. (The legacy istanbul path
## still exists: `bun run build:coverage` + unset PW_V8_COVERAGE.)
test-ui-coverage: build-mock-backend protogen-go
trap 'rm -rf "$(CURDIR)/core/http/react-ui/dist"' EXIT; \
( cd core/http/react-ui && bun install && bun run build:coverage-v8 ) && \
$(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui && \
( cd core/http/react-ui && rm -rf .nyc_output coverage && \
sh $(CURDIR)/scripts/ensure-playwright-browser.sh && \
PW_V8_COVERAGE=1 bunx playwright test $(PLAYWRIGHT_WORKERS_FLAG) && bun run coverage:report )
## UI coverage baseline (committed) and the strict gate that compares against
## it — the React mirror of test-coverage-baseline / test-coverage-check.
test-ui-coverage-baseline: test-ui-coverage
@node -e 'const fs=require("fs");process.stdout.write(String(JSON.parse(fs.readFileSync("core/http/react-ui/coverage/coverage-summary.json")).total.lines.pct))' > core/http/react-ui/coverage-baseline.txt
@echo "Saved UI coverage baseline: $$(cat core/http/react-ui/coverage-baseline.txt)% lines"
test-ui-coverage-check: test-ui-coverage
sh $(CURDIR)/scripts/ui-coverage-check.sh core/http/react-ui/coverage/coverage-summary.json core/http/react-ui/coverage-baseline.txt
test-ui-e2e-docker:
docker build -t localai-ui-e2e -f tests/e2e-ui/Dockerfile .
docker run --rm localai-ui-e2e

View File

@@ -149,10 +149,8 @@ For more details, see the [Getting Started guide](https://localai.io/basics/gett
## Latest News
- **May 2026**: **LocalAI 4.3.0** - `llama.cpp` [prompt cache on by default](https://github.com/mudler/LocalAI/pull/9925) (repeated system prompts collapse from minutes to seconds), [keyless cosign signing of backend OCI images](https://github.com/mudler/LocalAI/pull/9823), [per-API-key + per-user usage attribution](https://github.com/mudler/LocalAI/pull/9920), Distributed v3 with [per-request replica routing](https://github.com/mudler/LocalAI/pull/9968). [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.3.0)
- **May 2026**: **LocalAI 4.2.0** - LocalAI sees and hears: [voice recognition](https://github.com/mudler/LocalAI/pull/9500), [face recognition + antispoofing liveness](https://github.com/mudler/LocalAI/pull/9480), speaker diarization. Plus [drop-in Ollama API](https://github.com/mudler/LocalAI/pull/9284), [video generation](https://github.com/mudler/LocalAI/pull/9420), redesigned UI with i18n + admin-configurable branding, vLLM at feature parity with llama.cpp, and 11 new backends. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.2.0)
- **April 2026**: **LocalAI 4.1.0** - LocalAI becomes a control tower: distributed cluster mode with VRAM-aware smart routing + autoscaling, multi-user platform with OIDC and API keys, per-user quotas with predictive analytics, in-UI fine-tuning with TRL (auto-export to GGUF), on-the-fly quantization backend, visual pipeline editor. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.1.0)
- **March 2026**: **LocalAI 4.0.0** - native agentic orchestration with the new [Agenthub](https://agenthub.localai.io) community hub, full React UI rewrite with Canvas mode, [MCP Apps + client-side](https://github.com/mudler/LocalAI/pull/8947) with tool streaming, [WebRTC realtime audio](https://github.com/mudler/LocalAI/pull/8790), [MLX-distributed](https://github.com/mudler/LocalAI/pull/8801). [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.0.0)
- **April 2026**: [Voice recognition](https://github.com/mudler/LocalAI/pull/9500), [Face recognition, identification & liveness detection](https://github.com/mudler/LocalAI/pull/9480), [Ollama API compatibility](https://github.com/mudler/LocalAI/pull/9284), [Video generation in stable-diffusion.ggml](https://github.com/mudler/LocalAI/pull/9420), [Backend versioning with auto-upgrade](https://github.com/mudler/LocalAI/pull/9315), [Pin models & load-on-demand toggle](https://github.com/mudler/LocalAI/pull/9309), [Universal model importer](https://github.com/mudler/LocalAI/pull/9466), new backends: [sglang](https://github.com/mudler/LocalAI/pull/9359), [ik-llama-cpp](https://github.com/mudler/LocalAI/pull/9326), [TurboQuant](https://github.com/mudler/LocalAI/pull/9355), [sam.cpp](https://github.com/mudler/LocalAI/pull/9288), [Kokoros](https://github.com/mudler/LocalAI/pull/9212), [qwen3tts.cpp](https://github.com/mudler/LocalAI/pull/9316), [tinygrad multimodal](https://github.com/mudler/LocalAI/pull/9364)
- **March 2026**: [Agent management](https://github.com/mudler/LocalAI/pull/8820), [New React UI](https://github.com/mudler/LocalAI/pull/8772), [WebRTC](https://github.com/mudler/LocalAI/pull/8790), [MLX-distributed via P2P and RDMA](https://github.com/mudler/LocalAI/pull/8801), [MCP Apps, MCP Client-side](https://github.com/mudler/LocalAI/pull/8947)
- **February 2026**: [Realtime API for audio-to-audio with tool calling](https://github.com/mudler/LocalAI/pull/6245), [ACE-Step 1.5 support](https://github.com/mudler/LocalAI/pull/8396)
- **January 2026**: **LocalAI 3.10.0** — Anthropic API support, Open Responses API, video & image generation (LTX-2), unified GPU backends, tool streaming, Moonshine, Pocket-TTS. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v3.10.0)
- **December 2025**: [Dynamic Memory Resource reclaimer](https://github.com/mudler/LocalAI/pull/7583), [Automatic multi-GPU model fitting (llama.cpp)](https://github.com/mudler/LocalAI/pull/7584), [Vibevoice backend](https://github.com/mudler/LocalAI/pull/7494)
@@ -238,22 +236,11 @@ A huge thank you to our generous sponsors who support this project covering CI e
<a href="https://www.spectrocloud.com/" target="blank">
<img height="200" src="https://github.com/user-attachments/assets/72eab1dd-8b93-4fc0-9ade-84db49f24962">
</a>
</p>
<details>
<summary>
Past sponsors
</summary>
<p align="center">
<a href="https://www.premai.io/" target="blank">
<img height="200" src="https://github.com/mudler/LocalAI/assets/2420543/42e4ca83-661e-4f79-8e46-ae43689683d6"> <br>
</a>
</p>
</details>
### Individual sponsors
A special thanks to individual sponsors, a full list is on [GitHub](https://github.com/sponsors/mudler) and [buymeacoffee](https://buymeacoffee.com/mudler). Special shout out to [drikster80](https://github.com/drikster80) for being generous. Thank you everyone!

View File

@@ -37,22 +37,6 @@ service Backend {
rpc Rerank(RerankRequest) returns (RerankResult) {}
// TokenClassify runs a token-classification (NER) model on the
// supplied text and returns each detected entity span. Used by the
// PII redactor's optional NER tier — the regex tier still handles
// formatted hits cheaply, while this catches names, locations, and
// other unformatted PII that regex misses.
rpc TokenClassify(TokenClassifyRequest) returns (TokenClassifyResponse) {}
// Score evaluates the model's joint log-probability of each
// supplied candidate continuation given a shared prompt. The
// prompt's KV cache is computed once and reused across candidates.
// Used for routing-policy multi-label classification, reranking,
// calibrated confidence, and reward-model scoring — any task where
// the consumer wants the model's confidence in a pre-specified
// continuation rather than a generated one.
rpc Score(ScoreRequest) returns (ScoreResponse) {}
rpc GetMetrics(MetricsRequest) returns (MetricsResponse);
rpc VAD(VADRequest) returns (VADResponse) {}
@@ -84,23 +68,6 @@ service Backend {
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
// Forward proxies a raw HTTP request to an upstream provider. The
// cloud-proxy backend implements this for passthrough-mode model
// configs: the client wire format is preserved end-to-end (no
// translation through internal proto), which means new provider
// fields work the day they ship. Translation-mode proxies use the
// standard Predict/PredictStream RPCs instead. Backends that don't
// support this return UNIMPLEMENTED.
//
// The request is bidirectionally streamed so large bodies can flow
// without buffering. In practice the first ForwardRequest carries
// path, method, headers, and the initial body chunk; subsequent
// messages append body chunks. The first ForwardReply carries the
// upstream status and response headers; subsequent messages stream
// body chunks (SSE frames or chunked transfer). Cancellation of the
// gRPC context closes the upstream connection.
rpc Forward(stream ForwardRequest) returns (stream ForwardReply) {}
}
// Define the empty request
@@ -114,76 +81,6 @@ message MetricsResponse {
int32 prompt_tokens_processed = 5;
}
// TokenClassifyRequest carries the text to classify plus an optional
// score threshold. The transformers backend interprets threshold as
// the minimum confidence to include in the response; 0 = include all.
message TokenClassifyRequest {
string text = 1;
float threshold = 2;
}
// TokenClassifyEntity is one detected entity span. Byte offsets are
// into the original UTF-8 text — start..end is a half-open range that
// addresses the substring corresponding to entity_group.
//
// entity_group follows HuggingFace's aggregated-tag convention (e.g.
// "PER", "LOC", "ORG", or a PII-specific label like "EMAIL" /
// "SSN" depending on the model). The redactor's per-pattern action
// map keys off this string.
message TokenClassifyEntity {
string entity_group = 1;
int32 start = 2;
int32 end = 3;
float score = 4;
string text = 5;
}
message TokenClassifyResponse {
repeated TokenClassifyEntity entities = 1;
}
// ScoreRequest carries one shared prompt and one or more continuations
// to score against it. The backend tokenises the prompt once and reuses
// the resulting KV cache across all candidates in this request.
message ScoreRequest {
string prompt = 1;
repeated string candidates = 2;
// Return per-token logprobs for each candidate when true. Default
// false to keep the wire response small; the joint log_prob field
// covers the common ranking case.
bool include_token_logprobs = 3;
// When true, the response also populates length_normalized_log_prob
// (joint log-prob divided by candidate token count). Useful when
// candidates differ in length and the consumer wants a per-token
// measure comparable across them (PMI-style scoring).
bool length_normalize = 4;
}
// CandidateScore is one row in the ScoreResponse, matching by index
// the candidate in ScoreRequest.candidates.
message CandidateScore {
// Sum of log P(token_i | prompt, candidate_token_<i) across the
// candidate's tokens. The primary ranking signal.
double log_prob = 1;
// log_prob / num_tokens — populated when length_normalize=true on
// the request.
double length_normalized_log_prob = 2;
// Per-token detail — populated when include_token_logprobs=true.
repeated TokenLogProb tokens = 3;
// Number of tokens the backend tokenised this candidate into, after
// any backend-specific normalisation (e.g. leading-space handling).
int32 num_tokens = 4;
}
message TokenLogProb {
string token = 1;
double log_prob = 2;
}
message ScoreResponse {
repeated CandidateScore candidates = 1;
}
message RerankRequest {
string query = 1;
repeated string documents = 2;
@@ -428,25 +325,6 @@ message ModelOptions {
// applied verbatim to the backend's engine constructor (e.g. vLLM AsyncEngineArgs).
// Unknown keys produce an error at LoadModel time.
string EngineArgs = 73;
// Proxy carries the cloud-proxy backend's per-model configuration.
// Empty for non-proxy backends.
ProxyOptions Proxy = 74;
}
// ProxyOptions configures the cloud-proxy backend. UpstreamURL and
// Mode are always meaningful; Provider only matters in translate mode.
// The two api_key_* fields are mutually exclusive and resolved by the
// backend at LoadModel — core forwards the references rather than the
// plaintext key.
message ProxyOptions {
string upstream_url = 1;
string mode = 2;
string provider = 3;
string api_key_env = 4;
string api_key_file = 5;
string upstream_model = 6;
int32 request_timeout_seconds = 7;
}
message Result {
@@ -1124,32 +1002,3 @@ message QuantizationStopRequest {
string job_id = 1;
}
// ForwardHeader is one HTTP header on the request or response. Headers
// like Authorization are typically injected by the backend (from the
// resolved API key) rather than passed through from the client.
message ForwardHeader {
string name = 1;
string value = 2;
}
// ForwardRequest is a streamed HTTP request to the upstream. First
// message carries path/method/headers; subsequent messages carry
// body_chunk only. All fields except body_chunk are honoured on the
// first message and ignored thereafter.
message ForwardRequest {
string path = 1; // e.g. "/v1/chat/completions" — appended to the model's upstream_url
string method = 2; // usually "POST"
repeated ForwardHeader headers = 3;
bytes body_chunk = 4;
}
// ForwardReply is a streamed HTTP response from the upstream. First
// message carries status/headers; subsequent messages carry body_chunk
// only. SSE responses arrive as a sequence of body_chunk frames; the
// caller is responsible for any parsing.
message ForwardReply {
int32 status = 1;
repeated ForwardHeader headers = 2;
bytes body_chunk = 3;
}

View File

@@ -60,11 +60,6 @@ elseif(DS4_GPU STREQUAL "cpu")
set(DS4_OBJS "${DS4_DIR}/ds4_cpu.o")
endif()
# ds4.c now references ds4_distributed.c (distributed inference was split into
# its own translation unit upstream). It is a single GPU-agnostic object shared
# by every GPU mode, so link it in regardless of DS4_GPU.
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_distributed.o")
add_executable(${TARGET}
grpc-server.cpp
dsml_parser.cpp

View File

@@ -1,10 +1,10 @@
# ds4 backend Makefile.
#
# Upstream pin lives below as DS4_VERSION?=e16ead1e29c81a67bbb64e5b001117679cf9ce6e
# Upstream pin lives below as DS4_VERSION?=2606543be7a8c125a32cee37f5d1d85dc78f2fcf
# (.github/bump_deps.sh) can find and update it - matches the
# llama-cpp / ik-llama-cpp / turboquant convention.
DS4_VERSION?=e16ead1e29c81a67bbb64e5b001117679cf9ce6e
DS4_VERSION?=2606543be7a8c125a32cee37f5d1d85dc78f2fcf
DS4_REPO?=https://github.com/antirez/ds4
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
@@ -18,19 +18,16 @@ UNAME_S := $(shell uname -s)
CMAKE_ARGS ?= -DCMAKE_BUILD_TYPE=Release
# ds4_distributed.o is a GPU-agnostic translation unit that ds4.c/ds4_cpu.o now
# reference (upstream split distributed inference into its own .c). The same
# object is shared by every GPU mode, so it is appended unconditionally below.
ifeq ($(BUILD_TYPE),cublas)
CMAKE_ARGS += -DDS4_GPU=cuda
DS4_OBJ_TARGET := ds4.o ds4_cuda.o ds4_distributed.o
DS4_OBJ_TARGET := ds4.o ds4_cuda.o
else ifeq ($(UNAME_S),Darwin)
CMAKE_ARGS += -DDS4_GPU=metal
DS4_OBJ_TARGET := ds4.o ds4_metal.o ds4_distributed.o
DS4_OBJ_TARGET := ds4.o ds4_metal.o
else
# CPU reference path (Linux only - macOS CPU path is broken by VM bug per ds4 README).
CMAKE_ARGS += -DDS4_GPU=cpu
DS4_OBJ_TARGET := ds4_cpu.o ds4_distributed.o
DS4_OBJ_TARGET := ds4_cpu.o
endif
ifneq ($(NATIVE),true)
@@ -55,11 +52,11 @@ ds4:
# the right per-platform compile flags (Objective-C/Metal on Darwin, nvcc on Linux+CUDA).
ds4/ds4.o: ds4
ifeq ($(BUILD_TYPE),cublas)
+$(MAKE) -C ds4 ds4.o ds4_cuda.o ds4_distributed.o
+$(MAKE) -C ds4 ds4.o ds4_cuda.o
else ifeq ($(UNAME_S),Darwin)
+$(MAKE) -C ds4 ds4.o ds4_metal.o ds4_distributed.o
+$(MAKE) -C ds4 ds4.o ds4_metal.o
else
+$(MAKE) -C ds4 ds4_cpu.o ds4_distributed.o
+$(MAKE) -C ds4 ds4_cpu.o
endif
grpc-server: ds4/ds4.o

View File

@@ -1,5 +1,5 @@
IK_LLAMA_VERSION?=8960c5ba5ee9db30ba838304373aa4dbec9f7cbd
IK_LLAMA_VERSION?=11a1fea9e291f12ce2c803a9d7812c30ca806bcf
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
CMAKE_ARGS?=

View File

@@ -1,5 +1,5 @@
LLAMA_VERSION?=22d66b567eef11cf2e9832f04db64ee0323a0fd0
LLAMA_VERSION?=ad277572619fcfb6ddd38f4c6437283a4b2b8636
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
CMAKE_ARGS?=

View File

@@ -34,7 +34,6 @@
#include <regex>
#include <algorithm>
#include <atomic>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iterator>
@@ -122,40 +121,6 @@ static std::string base64_encode_bytes(const unsigned char* data, size_t len) {
bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model
// Score bypasses the slot loop (see the comment on Score below) so it
// must not run concurrently with any slot-loop RPC. These counters
// are a defence-in-depth tripwire — ModelConfig.Validate already
// rejects llama-cpp configs that mix score with chat/completion/
// embeddings, so a healthy deployment never trips them. seq_cst is
// load-bearing for the increment-then-check pattern below.
static std::atomic<int> slot_loop_inflight{0};
static std::atomic<int> score_inflight{0};
// Increment-then-check, not check-then-increment: two simultaneous
// racers both observe the other's increment and both abort cleanly.
// Reversed, both could see zero and proceed.
struct conflict_guard {
std::atomic<int>& self;
conflict_guard(const char* rpc, std::atomic<int>& self_, std::atomic<int>& other, const char* other_name)
: self(self_) {
self.fetch_add(1, std::memory_order_seq_cst);
int o = other.load(std::memory_order_seq_cst);
if (o > 0) {
fprintf(stderr,
"FATAL: %s called with %s=%d. The llama-cpp backend cannot "
"service Score and slot-loop RPCs concurrently — Score "
"bypasses the slot loop and races the llama_context. Bind "
"Score-using features to a model dedicated to scoring "
"(known_usecases: [score] with no chat/completion/embeddings).\n",
rpc, other_name, o);
std::abort();
}
}
~conflict_guard() {
self.fetch_sub(1, std::memory_order_seq_cst);
}
};
static std::function<void(int)> shutdown_handler;
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
@@ -552,33 +517,10 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
params.warmup = true;
// no_op_offload: disable host tensor op offload (default: false)
params.no_op_offload = false;
// kv_unified: enable unified KV cache. Upstream's server auto-enables this
// when the slot count is auto (-np <0), bumping n_parallel to 4 alongside.
// LocalAI keeps n_parallel=1 by default, which would skip that auto path
// and leave kv_unified=false. We flip the default to true here so the
// server-side prompt cache (cache_idle_slots) is actually usable on the
// single-slot path that LocalAI ships with: without it, idle slots are
// never persisted across requests and the prompt cache is dead weight.
// Users can opt out with `options: [ "kv_unified:false" ]`.
params.kv_unified = true;
// n_ctx_checkpoints: max context checkpoints per slot. Match upstream's
// default (32); the previous LocalAI-specific 8 was unnecessarily tight
// and limits partial-prefix recovery without a clear memory rationale.
params.n_ctx_checkpoints = 32;
// cache_idle_slots: save and clear idle slot KV to the prompt cache on
// task switch. Upstream default is true; the server auto-disables it if
// kv_unified=false or cache_ram_mib=0, so flipping kv_unified above is
// what actually unlocks it.
params.cache_idle_slots = true;
// checkpoint_min_step: minimum spacing between context checkpoints in
// tokens (0 disables the minimum). Match upstream's default (256). This
// field was renamed from `checkpoint_every_nt` in llama.cpp; the semantics
// also shifted from a fixed cadence to a minimum spacing. The turboquant
// fork branched before the field existed, so skip it on the legacy path
// (LOCALAI_LEGACY_LLAMA_CPP_SPEC is injected by patch-grpc-server.sh).
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
params.checkpoint_min_step = 256;
#endif
// kv_unified: enable unified KV cache (default: false)
params.kv_unified = false;
// n_ctx_checkpoints: max context checkpoints per slot (default: 8)
params.n_ctx_checkpoints = 8;
// decode options. Options are in form optname:optvale, or if booleans only optname.
for (int i = 0; i < request->options_size(); i++) {
@@ -737,44 +679,10 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
try {
params.n_ctx_checkpoints = std::stoi(optval_str);
} catch (const std::exception& e) {
// If conversion fails, keep default value (32)
// If conversion fails, keep default value (8)
}
}
// --- server-side idle-slot prompt cache toggle (upstream --cache-idle-slots) ---
// Saves the slot's KV state into the host-side prompt cache on task
// switch so a later request with the same prefix can warm-load it.
// Auto-disabled by the server if kv_unified=false or cache_ram=0.
} else if (!strcmp(optname, "cache_idle_slots") || !strcmp(optname, "idle_slots_cache")) {
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
params.cache_idle_slots = true;
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
params.cache_idle_slots = false;
}
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
// --- minimum context-checkpoint spacing (upstream -cms / --checkpoint-min-step) ---
// 0 disables the minimum-spacing gate. Old option names (`checkpoint_every_nt`,
// `checkpoint_every_n_tokens`) are kept as aliases for backward compatibility
// with existing user configs: upstream renamed the field and shifted its
// semantics from a fixed cadence to a minimum spacing.
//
// Gated out for the turboquant fork, which lacks common_params::
// checkpoint_min_step. The leading `}` closing the cache_idle_slots
// branch is removed with this block; the next `} else if` (n_ubatch)
// then closes cache_idle_slots, so braces stay balanced under both
// preprocessor branches.
} else if (!strcmp(optname, "checkpoint_min_step") || !strcmp(optname, "checkpoint_min_spacing") ||
!strcmp(optname, "checkpoint_every_nt") || !strcmp(optname, "checkpoint_every_n_tokens")) {
if (optval != NULL) {
try {
params.checkpoint_min_step = std::stoi(optval_str);
} catch (const std::exception& e) {
// If conversion fails, keep default value (256)
}
}
#endif
// --- physical batch size (upstream -ub / --ubatch-size) ---
// Note: line ~482 already aliases n_ubatch to n_batch as a default; this
// option lets users decouple the two (useful for embeddings/rerank).
@@ -1177,15 +1085,9 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}
}
// The draft tensor_buft_overrides are only populated under the modern
// (post-#22838) layout, whose population code is itself gated by
// LOCALAI_LEGACY_LLAMA_CPP_SPEC above. The turboquant fork lacks
// common_params_speculative::draft entirely, so skip the sentinel there too.
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
if (!params.speculative.draft.tensor_buft_overrides.empty()) {
params.speculative.draft.tensor_buft_overrides.push_back({nullptr, nullptr});
}
#endif
// TODO: Add yarn
@@ -1505,7 +1407,6 @@ public:
if (params_base.model.path.empty()) {
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
}
conflict_guard guard("PredictStream", slot_loop_inflight, score_inflight, "score_inflight");
json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
@@ -2265,7 +2166,6 @@ public:
if (params_base.model.path.empty()) {
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
}
conflict_guard guard("Predict", slot_loop_inflight, score_inflight, "score_inflight");
json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
data["stream"] = false;
@@ -3024,7 +2924,6 @@ public:
if (params_base.model.path.empty()) {
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
}
conflict_guard guard("Embedding", slot_loop_inflight, score_inflight, "score_inflight");
json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
body["stream"] = false;
@@ -3132,8 +3031,6 @@ public:
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
}
conflict_guard guard("Rerank", slot_loop_inflight, score_inflight, "score_inflight");
// Create and queue the task
auto rd = ctx_server.get_response_reader();
{
@@ -3206,218 +3103,12 @@ public:
return grpc::Status::OK;
}
// Score returns the model's joint log-probability of each candidate
// continuation given a shared prompt.
//
// WHY bypass the slot/task queue: upstream server_context exposes
// get_llama_context as "main thread only" and the slot loop's
// update_slots() owns the context whenever a task is in flight.
// No public synchronization primitive is available — so Score is
// unsafe to call concurrently with active generation through this
// backend. In practice routing-classifier calls happen before the
// request is routed to a generation backend, so the model used
// for Score is typically idle. Concurrent Score calls are
// serialised by a local mutex; KV-cache state is isolated behind
// a dedicated sequence ID cleared between candidates.
//
// A patch to server-context.cpp that adds SERVER_TASK_TYPE_SCORE
// and routes scoring through the slot loop would be the correct
// long-term fix; tracked as a follow-up.
//
// Perf TODO (measured: ~450 ms warm for 3 candidates on Arch-
// Router-1.5B Q4_K_M + Intel SYCL): the current loop re-decodes
// `prompt + candidate` from scratch for every candidate, throwing
// away the prompt's KV cache between iterations. A smarter
// version would:
// 1. Decode just the prompt once into score_seq_id.
// 2. Snapshot/cp that sequence (llama_memory_seq_cp) into a
// per-candidate sequence id.
// 3. For each candidate, decode only its tokens onto the copy
// (continuing from the saved prompt state), read logits.
// 4. llama_memory_seq_rm the copy.
// Estimated speedup: 3-candidate calls 450 ms -> ~150-200 ms,
// 6-candidate calls 630 ms -> ~220 ms. Single source-file change,
// no proto / Go-side changes needed. Worth doing once routing is
// wired into the middleware and Score is on the hot path of every
// chat request.
grpc::Status Score(ServerContext* context, const backend::ScoreRequest* request, backend::ScoreResponse* response) override {
auto auth = checkAuth(context);
if (!auth.ok()) return auth;
if (params_base.model.path.empty()) {
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
}
if (request->candidates_size() == 0) {
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "candidates must be non-empty");
}
// Tripwire against the slot loop. Acquired before score_mutex
// so it fires even when this Score is queued behind another.
conflict_guard guard("Score", score_inflight, slot_loop_inflight, "slot_loop_inflight");
// Serialise concurrent Score calls. The slot loop is still
// free to race with us — see the class comment above.
static std::mutex score_mutex;
std::lock_guard<std::mutex> score_lock(score_mutex);
llama_context * lctx = ctx_server.get_llama_context();
if (lctx == nullptr) {
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "llama context unavailable (sleeping?)");
}
const llama_vocab * vocab = ctx_server.impl->vocab;
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
const int32_t n_ctx = llama_n_ctx(lctx);
llama_memory_t mem = llama_get_memory(lctx);
// The KV-cache is sized to seq_to_stream.size() at load
// (typically equal to n_slots, often 1). Sequence IDs must
// be in [0, n_seq_max), so we can't pick a high-value
// "private" ID — we have to share with the slot. We clear
// the cache before AND after each candidate to keep
// scoring isolated from whatever state the slot held, and
// the static mutex above guarantees no other Score call is
// racing in the meantime. The slot loop is still free to
// race (see comment on this method) — Score must not run
// concurrently with generation through this backend.
const llama_seq_id score_seq_id = 0;
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
// Tokenize the shared prompt once with add_special=true so
// BOS is prepended when the model requires it. parse_special
// keeps chat-template markers in the prompt intact.
const std::string prompt = request->prompt();
std::vector<llama_token> prompt_tokens = common_tokenize(vocab, prompt, /*add_special=*/true, /*parse_special=*/true);
const int32_t prompt_len = (int32_t) prompt_tokens.size();
for (int ci = 0; ci < request->candidates_size(); ci++) {
const std::string & candidate_text = request->candidates(ci);
// Re-tokenize prompt + candidate as a single string. BPE
// merges across the boundary can shift the tokenization
// versus tokenize(prompt) ++ tokenize(candidate), so we
// find the divergence point against prompt_tokens.
std::vector<llama_token> full_tokens = common_tokenize(vocab, prompt + candidate_text, /*add_special=*/true, /*parse_special=*/true);
int32_t divergence = prompt_len;
const int32_t min_len = std::min<int32_t>(prompt_len, (int32_t) full_tokens.size());
for (int32_t i = 0; i < min_len; i++) {
if (prompt_tokens[i] != full_tokens[i]) {
divergence = i;
break;
}
}
const int32_t cand_len = (int32_t) full_tokens.size() - divergence;
backend::CandidateScore * cs = response->add_candidates();
cs->set_num_tokens(cand_len);
if (cand_len <= 0) {
cs->set_log_prob(0.0);
if (request->length_normalize()) {
cs->set_length_normalized_log_prob(0.0);
}
continue;
}
if (divergence < 1) {
// Need at least one prior token (typically BOS) to
// predict the first candidate token's logit. Tokeniser
// models without BOS + an empty prompt fall in here.
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
"Score: prompt produced no leading tokens; need at least one (e.g. BOS) to predict candidate");
}
if ((int32_t) full_tokens.size() > n_ctx) {
return grpc::Status(grpc::StatusCode::OUT_OF_RANGE,
"Score: prompt+candidate exceeds context size (got " +
std::to_string(full_tokens.size()) + ", n_ctx=" + std::to_string(n_ctx) + ")");
}
// Build a batch covering the entire prompt+candidate. We
// need logits at (divergence-1) onward — those are the
// predictions for each candidate token.
llama_batch batch = llama_batch_init((int32_t) full_tokens.size(), 0, 1);
for (int32_t i = 0; i < (int32_t) full_tokens.size(); i++) {
batch.token[i] = full_tokens[i];
batch.pos[i] = i;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = score_seq_id;
// logits[i] is "do we want the prediction *for the
// next token*, computed from this position?"
// We want predictions for candidate tokens at
// positions divergence .. full_tokens.size()-1, which
// come from logits at positions (divergence-1) ..
// (full_tokens.size()-2).
bool need_logit = (i >= divergence - 1) && (i < (int32_t) full_tokens.size() - 1);
batch.logits[i] = need_logit ? 1 : 0;
}
batch.n_tokens = (int32_t) full_tokens.size();
// Decode the batch. If decode fails (e.g. KV slot
// exhaustion), surface as INTERNAL — the caller will
// typically fall back to a sampling-based classifier.
int decode_err = llama_decode(lctx, batch);
if (decode_err != 0) {
llama_batch_free(batch);
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
return grpc::Status(grpc::StatusCode::INTERNAL,
"llama_decode failed during Score: " + std::to_string(decode_err));
}
// Sum log-probabilities of the actual candidate tokens.
double total_log_prob = 0.0;
for (int32_t k = 0; k < cand_len; k++) {
// The k-th candidate token sits at full_tokens index
// (divergence + k). Its predicting logit is at batch
// position (divergence + k - 1).
int32_t logit_pos = divergence + k - 1;
const float * logits = llama_get_logits_ith(lctx, logit_pos);
if (logits == nullptr) {
llama_batch_free(batch);
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
return grpc::Status(grpc::StatusCode::INTERNAL,
"llama_get_logits_ith returned null at position " + std::to_string(logit_pos));
}
llama_token target_token = full_tokens[divergence + k];
// Compute log_softmax(logits)[target_token] with the
// max-subtraction stability trick.
float max_logit = logits[0];
for (int32_t v = 1; v < n_vocab; v++) {
if (logits[v] > max_logit) max_logit = logits[v];
}
double sum_exp = 0.0;
for (int32_t v = 0; v < n_vocab; v++) {
sum_exp += std::exp((double)(logits[v] - max_logit));
}
double token_log_prob = (double)(logits[target_token] - max_logit) - std::log(sum_exp);
total_log_prob += token_log_prob;
if (request->include_token_logprobs()) {
backend::TokenLogProb * tlp = cs->add_tokens();
std::string piece = common_token_to_piece(lctx, target_token);
tlp->set_token(piece);
tlp->set_log_prob(token_log_prob);
}
}
cs->set_log_prob(total_log_prob);
if (request->length_normalize() && cand_len > 0) {
cs->set_length_normalized_log_prob(total_log_prob / (double) cand_len);
}
llama_batch_free(batch);
// Drop this candidate's KV-cache contribution so the next
// candidate starts from a clean state. Without this, the
// next decode would conflict at positions 0..N-1 for our
// sequence ID.
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
}
return grpc::Status::OK;
}
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
auto auth = checkAuth(context);
if (!auth.ok()) return auth;
if (params_base.model.path.empty()) {
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
}
conflict_guard guard("TokenizeString", slot_loop_inflight, score_inflight, "score_inflight");
json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
body["stream"] = false;
@@ -3439,8 +3130,6 @@ public:
grpc::Status GetMetrics(ServerContext* /*context*/, const backend::MetricsRequest* /*request*/, backend::MetricsResponse* response) override {
conflict_guard guard("GetMetrics", slot_loop_inflight, score_inflight, "score_inflight");
// request slots data using task queue
auto rd = ctx_server.get_response_reader();
int task_id = rd.queue_tasks.get_new_id();

View File

@@ -1,7 +1,7 @@
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
TURBOQUANT_VERSION?=5aeb2fdbe26cd4c534c6fa15de73cb5749bd0403
TURBOQUANT_VERSION?=4c1c3ac09d2dba0aa9a55b94f6c50c41a92f9c8c
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
CMAKE_ARGS?=

View File

@@ -1,30 +1,23 @@
#!/bin/bash
# Patch the shared backend/cpp/llama-cpp/grpc-server.cpp *copy* used by the
# turboquant build to account for the gaps between upstream and the fork:
# turboquant build:
#
# 1. Augment the kv_cache_types[] allow-list so `LoadModel` accepts the
# fork-specific `turbo2` / `turbo3` / `turbo4` cache types.
# 2. Replace `get_media_marker()` (added upstream in ggml-org/llama.cpp#21962,
# server-side random per-instance marker) with the legacy "<__media__>"
# literal. The fork branched before that PR, so server-common.cpp has no
# get_media_marker symbol. The fork's mtmd_default_marker() still returns
# "<__media__>", and Go-side tooling falls back to that sentinel when the
# backend does not expose media_marker, so substituting the literal keeps
# behavior identical on the turboquant path.
# 3. Revert the `common_params_speculative` field references to the
# pre-refactor flat layout. Upstream ggml-org/llama.cpp#22397 split the
# struct into nested `draft` / `ngram_simple` / `ngram_mod` / etc. members;
# the turboquant fork branched before that PR and still exposes the flat
# `n_max`, `mparams_dft`, `ngram_size_n`, ... fields. The substitutions
# below map the new nested paths back to the legacy flat names so the
# shared grpc-server.cpp keeps compiling against the fork's common.h.
# Drop this block once the fork rebases past #22397.
#
# Historical context: this script used to also paper over API gaps between the
# fork and upstream (flat vs nested `common_params_speculative`, missing
# `get_media_marker()`, `ctx_server.impl->model` vs `model_tgt`, and a
# LOCALAI_LEGACY_LLAMA_CPP_SPEC compile gate). As of TURBOQUANT_VERSION
# 4c1c3ac0 the fork has rebased past ggml-org/llama.cpp#21962, #22397 and
# #22838, so the shared grpc-server.cpp compiles unmodified against the fork.
# Only the fork-specific KV-cache enum entries remain.
#
# We patch the *copy* sitting in turboquant-<flavor>-build/, never the original
# under backend/cpp/llama-cpp/, so the stock llama-cpp build keeps compiling
# under backend/cpp/llama-cpp/, so the stock llama-cpp build stays compiling
# against vanilla upstream.
#
# Idempotent: skips each insertion if its marker is already present (so re-runs
# Idempotent: skips the insertion if its marker is already present (so re-runs
# of the same build dir don't double-insert).
set -euo pipefail
@@ -52,7 +45,7 @@ else
awk '
/^ GGML_TYPE_Q5_1,$/ && !done {
print
print " // turboquant fork extras added by patch-grpc-server.sh"
print " // turboquant fork extras - added by patch-grpc-server.sh"
print " GGML_TYPE_TURBO2_0,"
print " GGML_TYPE_TURBO3_0,"
print " GGML_TYPE_TURBO4_0,"
@@ -72,86 +65,4 @@ else
echo "==> KV allow-list patch OK"
fi
if grep -q 'get_media_marker()' "$SRC"; then
echo "==> patching $SRC to replace get_media_marker() with legacy \"<__media__>\" literal"
# Only one call site today (ModelMetadata), but replace all occurrences to
# stay robust if upstream adds more. Use a temp file to avoid relying on
# sed -i portability (the builder image uses GNU sed, but keeping this
# consistent with the awk block above).
sed 's/get_media_marker()/"<__media__>"/g' "$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> get_media_marker() substitution OK"
else
echo "==> $SRC has no get_media_marker() call, skipping media-marker patch"
fi
if grep -q 'params\.speculative\.draft\.\|params\.speculative\.ngram_simple\.' "$SRC"; then
echo "==> patching $SRC to revert common_params_speculative refs to pre-#22397 flat layout"
# Each substitution is the exact post-refactor path → legacy flat field.
# Order doesn't matter because the source paths are disjoint, but we keep
# the most-specific (mparams.path) first for readability.
sed -E \
-e 's/params\.speculative\.draft\.mparams\.path/params.speculative.mparams_dft.path/g' \
-e 's/params\.speculative\.draft\.n_max/params.speculative.n_max/g' \
-e 's/params\.speculative\.draft\.n_min/params.speculative.n_min/g' \
-e 's/params\.speculative\.draft\.p_min/params.speculative.p_min/g' \
-e 's/params\.speculative\.draft\.p_split/params.speculative.p_split/g' \
-e 's/params\.speculative\.draft\.n_gpu_layers/params.speculative.n_gpu_layers/g' \
-e 's/params\.speculative\.draft\.n_ctx/params.speculative.n_ctx/g' \
-e 's/params\.speculative\.ngram_simple\.size_n/params.speculative.ngram_size_n/g' \
-e 's/params\.speculative\.ngram_simple\.size_m/params.speculative.ngram_size_m/g' \
-e 's/params\.speculative\.ngram_simple\.min_hits/params.speculative.ngram_min_hits/g' \
"$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> speculative field rename OK"
else
echo "==> $SRC has no post-#22397 speculative field refs, skipping spec rename patch"
fi
# 4. Revert the `ctx_server.impl->model_tgt` rename introduced by upstream
# ggml-org/llama.cpp#22838 (parallel drafting). The turboquant fork still
# exposes the field as `model` on `server_context_impl`. The two call sites
# are in the Rerank and ModelMetadata RPC handlers.
if grep -q 'ctx_server\.impl->model_tgt' "$SRC"; then
echo "==> patching $SRC to revert ctx_server.impl->model_tgt -> ctx_server.impl->model"
sed -E 's/ctx_server\.impl->model_tgt/ctx_server.impl->model/g' "$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> model_tgt rename OK"
else
echo "==> $SRC has no ctx_server.impl->model_tgt refs, skipping model_tgt rename patch"
fi
# 5. Define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top of the file so the
# grpc-server option parser skips the new option-handler blocks (ngram_mod,
# ngram_map_k, ngram_map_k4v, ngram_cache, draft.cache_type_*, draft.cpuparams*,
# draft.tensor_buft_overrides) introduced for the post-#22838 layout, the
# draft.tensor_buft_overrides sentinel termination, and the
# common_params::checkpoint_min_step default/option (added with the
# 35c9b1f3 bump). Those blocks reference struct fields that simply do not
# exist in the fork.
if grep -q '^#define LOCALAI_LEGACY_LLAMA_CPP_SPEC' "$SRC"; then
echo "==> $SRC already defines LOCALAI_LEGACY_LLAMA_CPP_SPEC, skipping"
else
echo "==> patching $SRC to define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top"
# Insert the define before the very first `#include` so it precedes all the
# speculative-decoding code paths.
awk '
!done && /^#include/ {
print "#define LOCALAI_LEGACY_LLAMA_CPP_SPEC 1"
print "// ^ injected by backend/cpp/turboquant/patch-grpc-server.sh"
print ""
done = 1
}
{ print }
END {
if (!done) {
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_LEGACY_LLAMA_CPP_SPEC" > "/dev/stderr"
exit 1
}
}
' "$SRC" > "$SRC.tmp"
mv "$SRC.tmp" "$SRC"
echo "==> LOCALAI_LEGACY_LLAMA_CPP_SPEC define OK"
fi
echo "==> all patches applied"

View File

@@ -1,12 +0,0 @@
GOCMD=go
cloud-proxy:
CGO_ENABLED=0 $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o cloud-proxy ./
package:
bash package.sh
build: cloud-proxy package
clean:
rm -f cloud-proxy

View File

@@ -1,16 +0,0 @@
package main
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Ginkgo bootstrap. The other Test* functions in this package use
// raw testing.T and run independently; they coexist with Ginkgo
// specs registered via Describe / Context.
func TestCloudProxySpecs(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "cloud-proxy specs")
}

View File

@@ -1,39 +0,0 @@
package main
// cloud-proxy is a LocalAI backend that forwards request traffic to an
// external HTTP provider (OpenAI, Anthropic, etc.). Two modes:
//
// - passthrough: serves the Forward RPC; the client wire format is
// preserved end-to-end, no translation.
// - translate: serves Predict/PredictStream; the backend converts
// internal proto to the provider's wire format. (Phases 56.)
//
// LoadModel reads UpstreamURL/Mode/Provider/key references from
// ProxyOptions and resolves the API key once at load time.
import (
"flag"
"os"
grpc "github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/xlog"
"golang.org/x/term"
)
var addr = flag.String("addr", "localhost:50051", "the address to listen on")
func main() {
// xlog's default handler emits ANSI color codes; that's fine for an
// interactive shell but unreadable when the backend's stdout is
// captured by LocalAI and tee'd to a log file. Force plain text when
// LOCALAI_LOG_FORMAT is unset and stdout isn't a terminal.
format := os.Getenv("LOCALAI_LOG_FORMAT")
if format == "" && !term.IsTerminal(int(os.Stdout.Fd())) {
format = xlog.TextFormat
}
xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(os.Getenv("LOCALAI_LOG_LEVEL")), format))
flag.Parse()
if err := grpc.StartServer(*addr, NewCloudProxy()); err != nil {
panic(err)
}
}

View File

@@ -1,13 +0,0 @@
#!/bin/bash
# Script to copy the cloud-proxy binary into the package dir for the
# final Dockerfile stage. Mirrors backend/go/local-store/package.sh —
# no extra runtime libs needed since the backend is pure Go.
set -e
CURDIR=$(dirname "$(realpath $0)")
mkdir -p $CURDIR/package
cp -avf $CURDIR/cloud-proxy $CURDIR/package/
cp -rfv $CURDIR/run.sh $CURDIR/package/

View File

@@ -1,325 +0,0 @@
package main
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strconv"
"sync"
grpc "github.com/mudler/LocalAI/pkg/grpc"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("composeURL", func() {
// Upstream URL convention: gallery configs put the canonical path
// in upstream_url, so per-request Path is ignored. A bare-host
// upstream_url accepts the per-request path.
DescribeTable("path resolution",
func(upstream, reqPath, want string) {
got, err := composeURL(upstream, reqPath)
Expect(err).NotTo(HaveOccurred())
Expect(got).To(Equal(want))
},
Entry("full path wins", "https://api.openai.com/v1/chat/completions", "/v1/something-else", "https://api.openai.com/v1/chat/completions"),
Entry("bare host accepts path", "https://api.openai.com", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"),
Entry("root slash treated as bare", "https://api.openai.com/", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"),
Entry("bare host + empty path", "https://api.openai.com", "", "https://api.openai.com"),
)
It("returns an error on invalid upstream URL", func() {
_, err := composeURL("://garbage", "")
Expect(err).To(HaveOccurred())
})
})
var _ = Describe("applyAuthHeader", func() {
It("sets x-api-key and anthropic-version for Anthropic, no Authorization", func() {
req, _ := http.NewRequest("POST", "https://example.com", nil)
applyAuthHeader(req, providerAnthropic, "ant-key")
Expect(req.Header.Get("x-api-key")).To(Equal("ant-key"))
Expect(req.Header.Get("anthropic-version")).NotTo(BeEmpty())
Expect(req.Header.Get("Authorization")).To(BeEmpty(), "Authorization must not leak on Anthropic backend")
})
It("sets Bearer Authorization for OpenAI, no x-api-key", func() {
req, _ := http.NewRequest("POST", "https://example.com", nil)
applyAuthHeader(req, providerOpenAI, "sk-key")
Expect(req.Header.Get("Authorization")).To(Equal("Bearer sk-key"))
Expect(req.Header.Get("x-api-key")).To(BeEmpty(), "x-api-key must not leak on OpenAI backend")
})
It("defaults to Bearer when provider is empty", func() {
// Passthrough mode often has provider == "" because the operator
// doesn't claim a specific upstream wire format. Most providers
// (including OpenAI-compatible ones) accept Bearer, so default to it.
req, _ := http.NewRequest("POST", "https://example.com", nil)
applyAuthHeader(req, "", "some-key")
Expect(req.Header.Get("Authorization")).To(Equal("Bearer some-key"))
})
It("preserves an existing anthropic-version header", func() {
// If the client supplied anthropic-version (rare but legitimate
// for an upstream pinned to a specific date), the proxy must not
// clobber it.
req, _ := http.NewRequest("POST", "https://example.com", nil)
req.Header.Set("anthropic-version", "2024-10-01")
applyAuthHeader(req, providerAnthropic, "k")
Expect(req.Header.Get("anthropic-version")).To(Equal("2024-10-01"))
})
})
var _ = Describe("isHopByHopHeader", func() {
DescribeTable("hop-by-hop classification",
func(header string, want bool) {
Expect(isHopByHopHeader(header)).To(Equal(want))
},
Entry("Connection is hop-by-hop", "Connection", true),
Entry("Keep-Alive is hop-by-hop", "Keep-Alive", true),
Entry("Proxy-Connection is hop-by-hop", "Proxy-Connection", true),
Entry("Transfer-Encoding is hop-by-hop", "Transfer-Encoding", true),
Entry("TE is hop-by-hop", "TE", true),
Entry("Trailer is hop-by-hop", "Trailer", true),
Entry("Upgrade is hop-by-hop", "Upgrade", true),
Entry("Host is hop-by-hop", "Host", true),
Entry("Content-Length is hop-by-hop", "Content-Length", true),
// Case-insensitive — RFC 7230 doesn't constrain header case.
Entry("lowercase connection is hop-by-hop", "connection", true),
Entry("uppercase HOST is hop-by-hop", "HOST", true),
// Non hop-by-hop — must NOT be stripped.
Entry("Authorization is end-to-end", "Authorization", false),
Entry("Content-Type is end-to-end", "Content-Type", false),
Entry("Accept is end-to-end", "Accept", false),
Entry("X-Custom is end-to-end", "X-Custom", false),
)
})
var _ = Describe("Forward", func() {
It("strips hop-by-hop and Connection headers before upstream, preserves custom headers", func() {
gotConnection := make(chan string, 1)
gotXCustom := make(chan string, 1)
gotHost := make(chan string, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotConnection <- r.Header.Get("Connection")
gotXCustom <- r.Header.Get("X-Custom")
gotHost <- r.Header.Get("Host")
w.WriteHeader(http.StatusOK)
}))
defer upstream.Close()
cp := NewCloudProxy()
Expect(cp.Load(&pb.ModelOptions{
Proxy: &pb.ProxyOptions{
UpstreamUrl: upstream.URL,
Mode: modePassthrough,
},
})).To(Succeed())
addr := "test://forward-hopbyhop"
grpc.Provide(addr, cp)
c := grpc.NewClient(addr, true, nil, false)
stream, err := c.Forward(context.Background())
Expect(err).NotTo(HaveOccurred())
Expect(stream.Send(&pb.ForwardRequest{
Path: "/v1/chat/completions",
Method: "POST",
Headers: []*pb.ForwardHeader{
{Name: "Connection", Value: "keep-alive"},
{Name: "Host", Value: "spoofed.example.com"},
{Name: "X-Custom", Value: "preserved"},
},
})).To(Succeed())
Expect(stream.CloseSend()).To(Succeed())
_, _ = stream.Recv()
for {
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
break
}
}
Expect(<-gotConnection).To(BeEmpty(), "Connection must not leak to upstream")
Expect(<-gotHost).NotTo(Equal("spoofed.example.com"), "Host header must not be spoofed through")
Expect(<-gotXCustom).To(Equal("preserved"), "X-Custom header must survive")
})
It("replaces caller-supplied Authorization with the configured key", func() {
// The proxy must overwrite a client-supplied Authorization header
// so a downstream caller can't smuggle stale or wrong credentials.
gotAuth := make(chan string, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAuth <- r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer upstream.Close()
GinkgoT().Setenv("CLOUD_PROXY_AUTH_REPLACE_KEY", "sk-real")
cp := NewCloudProxy()
Expect(cp.Load(&pb.ModelOptions{
Proxy: &pb.ProxyOptions{
UpstreamUrl: upstream.URL,
Mode: modePassthrough,
ApiKeyEnv: "CLOUD_PROXY_AUTH_REPLACE_KEY",
},
})).To(Succeed())
addr := "test://forward-replaces-auth"
grpc.Provide(addr, cp)
c := grpc.NewClient(addr, true, nil, false)
stream, err := c.Forward(context.Background())
Expect(err).NotTo(HaveOccurred())
Expect(stream.Send(&pb.ForwardRequest{
Path: "/v1/chat/completions",
Method: "POST",
Headers: []*pb.ForwardHeader{
// Client-supplied Authorization with the wrong scheme / key.
{Name: "Authorization", Value: "Basic Zm9vOmJhcg=="},
},
})).To(Succeed())
Expect(stream.CloseSend()).To(Succeed())
_, _ = stream.Recv()
for {
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
break
}
}
Expect(<-gotAuth).To(Equal("Bearer sk-real"), "caller-supplied Basic header must be replaced")
})
It("refuses to follow upstream redirects and never leaks the key to the redirect target", func() {
// A 3xx from the configured upstream means misconfiguration or a
// hijacked/spoofed host. Following it would replay the request —
// and the injected API key — to the Location host. Anthropic's
// x-api-key is NOT stripped by Go on cross-host redirects, so this
// would be a credential leak. The proxy must refuse the redirect.
sinkHit := make(chan string, 1)
sink := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sinkHit <- r.Header.Get("x-api-key")
w.WriteHeader(http.StatusOK)
}))
defer sink.Close()
redirector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, sink.URL, http.StatusFound)
}))
defer redirector.Close()
GinkgoT().Setenv("CLOUD_PROXY_REDIRECT_KEY", "ant-secret")
cp := NewCloudProxy()
Expect(cp.Load(&pb.ModelOptions{
Proxy: &pb.ProxyOptions{
UpstreamUrl: redirector.URL,
Mode: modePassthrough,
Provider: providerAnthropic,
ApiKeyEnv: "CLOUD_PROXY_REDIRECT_KEY",
},
})).To(Succeed())
addr := "test://forward-no-redirect"
grpc.Provide(addr, cp)
c := grpc.NewClient(addr, true, nil, false)
stream, err := c.Forward(context.Background())
Expect(err).NotTo(HaveOccurred())
Expect(stream.Send(&pb.ForwardRequest{
Path: "/v1/messages",
Method: "POST",
})).To(Succeed())
Expect(stream.CloseSend()).To(Succeed())
// Drain the stream; a refused redirect surfaces as a non-EOF error.
var streamErr error
for {
if _, err := stream.Recv(); err != nil {
if !errors.Is(err, io.EOF) {
streamErr = err
}
break
}
}
Expect(streamErr).To(HaveOccurred(), "refused redirect must surface as an error")
Expect(sinkHit).NotTo(Receive(), "the redirect target must never be contacted")
})
It("handles concurrent calls without interference", func() {
// CloudProxy explicitly omits base.SingleThread — independent
// Forward streams must not block each other or leak state.
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.WriteHeader(http.StatusOK)
_, _ = w.Write(body)
}))
defer upstream.Close()
cp := NewCloudProxy()
Expect(cp.Load(&pb.ModelOptions{
Proxy: &pb.ProxyOptions{
UpstreamUrl: upstream.URL,
Mode: modePassthrough,
},
})).To(Succeed())
addr := "test://forward-concurrent"
grpc.Provide(addr, cp)
c := grpc.NewClient(addr, true, nil, false)
const N = 8
var wg sync.WaitGroup
errs := make(chan error, N)
for i := 0; i < N; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
stream, err := c.Forward(context.Background())
if err != nil {
errs <- err
return
}
payload := "request-" + string(rune('A'+idx))
if err := stream.Send(&pb.ForwardRequest{
Path: "/v1/chat/completions",
Method: "POST",
BodyChunk: []byte(payload),
}); err != nil {
errs <- err
return
}
_ = stream.CloseSend()
_, _ = stream.Recv()
var body []byte
for {
r, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
errs <- err
return
}
body = append(body, r.GetBodyChunk()...)
}
if string(body) != payload {
errs <- &echoMismatch{want: payload, got: string(body)}
}
}(i)
}
wg.Wait()
close(errs)
var collected []error
for err := range errs {
collected = append(collected, err)
}
Expect(collected).To(BeEmpty(), "no concurrent Forward call should fail")
})
})
type echoMismatch struct{ want, got string }
func (e *echoMismatch) Error() string {
return "echo mismatch: want " + strconv.Quote(e.want) + " got " + strconv.Quote(e.got)
}

View File

@@ -1,508 +0,0 @@
package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/xlog"
)
// Anthropic Messages API wire-format types. Narrowed to what translate
// mode preserves through the Reply proto: text + tool_use blocks +
// usage tokens. Image blocks, prompt caching, metadata, and stop
// sequence metadata are not modelled — passthrough mode covers those.
//
// Notable differences from OpenAI:
// - max_tokens is REQUIRED. Anthropic 400s without it.
// - Roles are user/assistant only — system messages move to a
// top-level `system` string field.
// - Streaming SSE uses event: lines alongside data: lines. The
// events we care about: content_block_start (carries tool_use
// init: id + name), content_block_delta (text_delta with text;
// input_json_delta with partial_json for tool arguments), and
// message_stop (terminates the stream). Others are ignored.
type anthropicRequest struct {
Model string `json:"model"`
MaxTokens int32 `json:"max_tokens"`
System string `json:"system,omitempty"`
Messages []anthropicMessage `json:"messages"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []anthropicTool `json:"tools,omitempty"`
ToolChoice *anthropicToolChoice `json:"tool_choice,omitempty"`
}
// Content is `any` because Anthropic accepts a bare string OR a
// list of content blocks. Use the string form for plain user/
// assistant turns; switch to []anthropicContentBlock when the
// turn needs tool_use (assistant) or tool_result (user) blocks.
type anthropicMessage struct {
Role string `json:"role"`
Content any `json:"content"`
}
type anthropicTool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema"`
}
// anthropicToolChoice mirrors the four shapes Anthropic accepts:
// {"type":"auto"} | {"type":"any"} | {"type":"tool","name":"X"} |
// {"type":"none"} (newer models). OpenAI's "auto"/"none"/
// "required"/{"function":{"name":"X"}} all map here.
type anthropicToolChoice struct {
Type string `json:"type"`
Name string `json:"name,omitempty"`
}
// anthropicContentBlock is the union shape used both for response
// blocks (text/tool_use we read off the wire) and outbound request
// blocks (tool_use/tool_result we emit in the conversation history).
// Anthropic encodes tool calls inline rather than as a separate field,
// so we walk Content[] looking for type=="tool_use" on responses and
// produce equivalent blocks when serialising prior-turn tool calls.
type anthropicContentBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input json.RawMessage `json:"input,omitempty"`
// Tool-result block fields. tool_result uses `content` (not
// `text`) and pairs with `tool_use_id`; modelling them as
// distinct fields avoids ambiguity at marshal time.
ToolUseID string `json:"tool_use_id,omitempty"`
ResultContent string `json:"content,omitempty"`
}
type anthropicResponse struct {
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Content []anthropicContentBlock `json:"content"`
Model string `json:"model"`
Usage *anthropicUsage `json:"usage,omitempty"`
}
type anthropicUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// anthropicStreamEvent is the union shape used for every event type we
// process. Type discriminates; only the matching fields are populated.
// content_block_start carries ContentBlock (with id/name for tool_use);
// content_block_delta carries Delta (text or partial_json).
type anthropicStreamEvent struct {
Type string `json:"type"`
Index int `json:"index,omitempty"`
ContentBlock *anthropicContentBlock `json:"content_block,omitempty"`
Delta *anthropicStreamDelta `json:"delta,omitempty"`
Message *anthropicResponse `json:"message,omitempty"`
Usage *anthropicUsage `json:"usage,omitempty"`
}
type anthropicStreamDelta struct {
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
}
// Anthropic requires max_tokens. If the caller didn't set it, use a
// generous-but-bounded default so the request doesn't 400.
const anthropicDefaultMaxTokens int32 = 4096
const anthropicToolChoiceNone = "none"
// Reused JSON-Schema defaults for malformed inputs. Anthropic requires
// input_schema to be a JSON object and tool_use.input to be a JSON
// object; clients that omit them must not 400 the entire request.
var (
emptyJSONObject = json.RawMessage(`{}`)
emptyObjectSchema = json.RawMessage(`{"type":"object","properties":{}}`)
)
func buildAnthropicRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) {
req := anthropicRequest{
Model: modelName(cfg, opts),
MaxTokens: opts.GetTokens(),
Stream: stream,
StopSequences: opts.GetStopPrompts(),
}
if req.MaxTokens <= 0 {
req.MaxTokens = anthropicDefaultMaxTokens
}
// Newer Anthropic models 400 when both temperature and top_p are
// set ("`temperature` and `top_p` cannot both be specified for
// this model. Please use only one.") even though their docs only
// "recommend" picking one. The OpenAI-compatible chat UI almost
// always sends both with default values, so prefer temperature
// and drop top_p when both are present.
if t := opts.GetTemperature(); t != 0 {
v := float64(t)
req.Temperature = &v
} else if t := opts.GetTopP(); t != 0 {
v := float64(t)
req.TopP = &v
}
req.Tools = convertOpenAITools(opts.GetTools())
req.ToolChoice = convertOpenAIToolChoice(opts.GetToolChoice())
// Anthropic rejects tool_choice without tools and older models
// don't accept {"type":"none"} — collapse to a no-tools request.
if req.ToolChoice != nil && req.ToolChoice.Type == anthropicToolChoiceNone {
req.Tools, req.ToolChoice = nil, nil
}
var systemParts []string
for _, m := range opts.GetMessages() {
role := m.GetRole()
if role == "system" {
if c := m.GetContent(); c != "" {
systemParts = append(systemParts, c)
}
continue
}
switch role {
case "user":
req.Messages = append(req.Messages, anthropicMessage{
Role: "user",
Content: m.GetContent(),
})
case "assistant":
if blocks := assistantBlocks(m); blocks != nil {
req.Messages = append(req.Messages, anthropicMessage{Role: "assistant", Content: blocks})
continue
}
req.Messages = append(req.Messages, anthropicMessage{
Role: "assistant",
Content: m.GetContent(),
})
case "tool", "function":
req.Messages = appendToolResult(req.Messages, anthropicContentBlock{
Type: "tool_result",
ToolUseID: m.GetToolCallId(),
ResultContent: m.GetContent(),
})
}
}
req.System = strings.Join(systemParts, "\n\n")
if len(req.Messages) == 0 && opts.GetPrompt() != "" {
req.Messages = []anthropicMessage{{Role: "user", Content: opts.GetPrompt()}}
}
return json.Marshal(req)
}
// appendToolResult appends a tool_result block as a user message,
// merging into a preceding user message that already carries blocks.
// Anthropic concatenates consecutive same-role messages on its end,
// but explicit merging keeps the body smaller and the conversation
// strictly alternating — which some upstream filters require.
func appendToolResult(msgs []anthropicMessage, block anthropicContentBlock) []anthropicMessage {
if n := len(msgs); n > 0 && msgs[n-1].Role == "user" {
if existing, ok := msgs[n-1].Content.([]anthropicContentBlock); ok {
msgs[n-1].Content = append(existing, block)
return msgs
}
}
return append(msgs, anthropicMessage{
Role: "user",
Content: []anthropicContentBlock{block},
})
}
func convertOpenAITools(toolsJSON string) []anthropicTool {
if toolsJSON == "" {
return nil
}
var raw []openAITool
if err := json.Unmarshal([]byte(toolsJSON), &raw); err != nil {
xlog.Warn("cloud-proxy: anthropic translate: unparseable tools JSON, dropping", "error", err)
return nil
}
tools := make([]anthropicTool, 0, len(raw))
for _, t := range raw {
if t.Function.Name == "" {
continue
}
schema := t.Function.Parameters
if len(schema) == 0 {
schema = emptyObjectSchema
}
tools = append(tools, anthropicTool{
Name: t.Function.Name,
Description: t.Function.Description,
InputSchema: schema,
})
}
return tools
}
// convertOpenAIToolChoice accepts the spec form
// ({type:function, function:{name:X}}) and the flat legacy form
// ({type:function, name:X}) some clients send. Unknown object shapes
// are warned and dropped rather than silently treated as auto.
func convertOpenAIToolChoice(toolChoiceJSON string) *anthropicToolChoice {
if toolChoiceJSON == "" {
return nil
}
var asString string
if err := json.Unmarshal([]byte(toolChoiceJSON), &asString); err == nil {
switch asString {
case "auto":
return &anthropicToolChoice{Type: "auto"}
case "none":
return &anthropicToolChoice{Type: anthropicToolChoiceNone}
case "required":
return &anthropicToolChoice{Type: "any"}
}
return nil
}
var asObj struct {
Type string `json:"type"`
Name string `json:"name"`
Function struct {
Name string `json:"name"`
} `json:"function"`
}
if err := json.Unmarshal([]byte(toolChoiceJSON), &asObj); err != nil {
xlog.Warn("cloud-proxy: anthropic translate: unparseable tool_choice, dropping", "error", err)
return nil
}
if name := asObj.Function.Name; name != "" {
return &anthropicToolChoice{Type: "tool", Name: name}
}
if asObj.Name != "" {
return &anthropicToolChoice{Type: "tool", Name: asObj.Name}
}
xlog.Warn("cloud-proxy: anthropic translate: unrecognised tool_choice shape, dropping", "shape", toolChoiceJSON)
return nil
}
// openAITool mirrors pkg/functions.Tool but keeps Parameters as
// json.RawMessage so the input_schema passes through verbatim — no
// re-marshal cost, no fidelity loss on exotic schemas.
type openAITool struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters json.RawMessage `json:"parameters"`
} `json:"function"`
}
func assistantBlocks(m *pb.Message) []anthropicContentBlock {
toolCallsJSON := m.GetToolCalls()
if toolCallsJSON == "" {
return nil
}
var toolCalls []openAIToolCall
if err := json.Unmarshal([]byte(toolCallsJSON), &toolCalls); err != nil || len(toolCalls) == 0 {
return nil
}
blocks := make([]anthropicContentBlock, 0, len(toolCalls)+1)
if text := m.GetContent(); text != "" {
blocks = append(blocks, anthropicContentBlock{Type: "text", Text: text})
}
for _, tc := range toolCalls {
// OpenAI's arguments are a JSON-encoded string; pass through
// as RawMessage so a non-JSON string from a poorly-formed
// local model doesn't crash the marshaller downstream.
args := json.RawMessage(tc.Function.Arguments)
if len(args) == 0 {
args = emptyJSONObject
}
blocks = append(blocks, anthropicContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: args,
})
}
return blocks
}
// doAnthropicRequest is the Anthropic counterpart of doOpenAIRequest.
// applyAuthHeader sets x-api-key and anthropic-version when provider
// is anthropic, so this method doesn't need to duplicate that.
func (c *CloudProxy) doAnthropicRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("cloud-proxy: build request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "*/*")
if cfg.apiKey != "" {
applyAuthHeader(req, cfg.provider, cfg.apiKey)
}
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err)
}
return resp, nil
}
// predictAnthropicRich returns the full Reply: joined text from all
// text blocks, tool_use blocks mapped to ToolCallDelta, and usage
// tokens.
func (c *CloudProxy) predictAnthropicRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) {
body, err := buildAnthropicRequest(opts, cfg, false)
if err != nil {
return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err)
}
resp, err := c.doAnthropicRequest(ctx, cfg, body)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
}
var parsed anthropicResponse
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
return nil, fmt.Errorf("cloud-proxy: decode response: %w", err)
}
reply := &pb.Reply{}
if parsed.Usage != nil {
reply.PromptTokens = int32(parsed.Usage.InputTokens)
reply.Tokens = int32(parsed.Usage.OutputTokens)
}
var content strings.Builder
var toolCalls []*pb.ToolCallDelta
toolIdx := 0
for _, b := range parsed.Content {
switch b.Type {
case "text":
content.WriteString(b.Text)
case "tool_use":
// Input is a structured JSON object; we serialise to a
// string so it fits the OpenAI-shaped arguments field
// downstream consumers expect.
args := ""
if len(b.Input) > 0 {
args = string(b.Input)
}
toolCalls = append(toolCalls, newToolCallDelta(toolIdx, b.ID, b.Name, args))
toolIdx++
}
}
reply.Message = []byte(content.String())
if len(toolCalls) > 0 {
reply.ChatDeltas = []*pb.ChatDelta{{ToolCalls: toolCalls}}
}
return reply, nil
}
// predictAnthropicStreamRich streams Reply chunks from Anthropic's SSE.
// Three event types matter: content_block_start (initialises tool_use
// id+name), content_block_delta (carries text or input_json_delta),
// message_stop (terminates). The block index from the wire feeds
// straight into ToolCallDelta.Index so downstream consumers can
// reassemble multiple parallel tool calls.
func (c *CloudProxy) predictAnthropicStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error {
body, err := buildAnthropicRequest(opts, cfg, true)
if err != nil {
return fmt.Errorf("cloud-proxy: marshal request: %w", err)
}
resp, err := c.doAnthropicRequest(ctx, cfg, body)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data:") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if payload == "" {
continue
}
var ev anthropicStreamEvent
if err := json.Unmarshal([]byte(payload), &ev); err != nil {
xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err)
continue
}
switch ev.Type {
case "content_block_start":
// tool_use blocks announce id + name here; arguments arrive
// in subsequent input_json_delta events. Emit a Reply with
// just the tool_call init fields so consumers can allocate
// a slot at this index.
if ev.ContentBlock != nil && ev.ContentBlock.Type == "tool_use" {
if !sendReply(ctx, results, &pb.Reply{
ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{
newToolCallDelta(ev.Index, ev.ContentBlock.ID, ev.ContentBlock.Name, ""),
}}},
}) {
return ctx.Err()
}
}
case "content_block_delta":
if ev.Delta == nil {
continue
}
switch ev.Delta.Type {
case "text_delta":
if ev.Delta.Text == "" {
continue
}
if !sendReply(ctx, results, &pb.Reply{
Message: []byte(ev.Delta.Text),
ChatDeltas: []*pb.ChatDelta{{Content: ev.Delta.Text}},
}) {
return ctx.Err()
}
case "input_json_delta":
if ev.Delta.PartialJSON == "" {
continue
}
if !sendReply(ctx, results, &pb.Reply{
ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{
newToolCallDelta(ev.Index, "", "", ev.Delta.PartialJSON),
}}},
}) {
return ctx.Err()
}
}
case "message_delta":
// Anthropic sends final usage in message_delta.usage. Emit
// a usage-only Reply so the consumer can record totals.
if ev.Usage != nil {
if !sendReply(ctx, results, &pb.Reply{
Tokens: int32(ev.Usage.OutputTokens),
}) {
return ctx.Err()
}
}
case "message_stop":
return nil
}
}
return scanner.Err()
}

View File

@@ -1,334 +0,0 @@
package main
import (
"encoding/json"
"io"
"math"
"net/http"
"net/http/httptest"
"strings"
"testing"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/gomega"
)
// fakeAnthropicUpstream mirrors fakeOpenAIUpstream but decodes the
// request body as an anthropicRequest so tests can assert on the
// translated wire shape (system field, max_tokens, etc.).
func fakeAnthropicUpstream(t *testing.T, handler func(req anthropicRequest) (status int, body string, contentType string)) (*httptest.Server, *anthropicRequest) {
t.Helper()
var captured anthropicRequest
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, _ := io.ReadAll(r.Body)
_ = json.Unmarshal(raw, &captured)
status, body, ct := handler(captured)
w.Header().Set("Content-Type", ct)
w.WriteHeader(status)
_, _ = io.WriteString(w, body)
}))
return srv, &captured
}
func newAnthropicTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy {
t.Helper()
g := NewWithT(t)
t.Setenv("CLOUD_PROXY_ANTHROPIC_FAKE", "sk-ant-fake")
cp := NewCloudProxy()
err := cp.Load(&pb.ModelOptions{
Model: "claude-local",
Proxy: &pb.ProxyOptions{
UpstreamUrl: upstreamURL,
Mode: modeTranslate,
Provider: providerAnthropic,
ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_FAKE",
UpstreamModel: "claude-3-5-sonnet-20241022",
},
})
g.Expect(err).NotTo(HaveOccurred())
return cp
}
func TestPredict_Anthropic_BasicMessages(t *testing.T) {
g := NewWithT(t)
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, `{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hi there"}],"model":"claude-3-5-sonnet-20241022","usage":{"input_tokens":5,"output_tokens":2}}`, "application/json"
})
defer srv.Close()
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
got, err := cp.Predict(&pb.PredictOptions{
Messages: []*pb.Message{
{Role: "system", Content: "be brief"},
{Role: "user", Content: "hello"},
},
Temperature: 0.5,
TopP: 0.9,
Tokens: 32,
})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(got).To(Equal("hi there"))
g.Expect(captured.Model).To(Equal("claude-3-5-sonnet-20241022"))
// System message must be hoisted out of Messages into top-level field.
g.Expect(captured.System).To(Equal("be brief"))
g.Expect(captured.Messages).To(HaveLen(1))
g.Expect(captured.Messages[0].Role).To(Equal("user"))
g.Expect(captured.MaxTokens).To(Equal(int32(32)))
g.Expect(captured.Temperature).NotTo(BeNil())
g.Expect(*captured.Temperature).To(Equal(0.5))
// Anthropic 400s when both temperature and top_p are set; the
// translator must prefer temperature and drop top_p.
g.Expect(captured.TopP).To(BeNil())
g.Expect(captured.Stream).To(BeFalse())
}
// When only top_p is set, it should be forwarded.
func TestPredict_Anthropic_TopPOnly(t *testing.T) {
g := NewWithT(t)
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
})
defer srv.Close()
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
_, err := cp.Predict(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "hello"}},
TopP: 0.9,
Tokens: 16,
})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(captured.Temperature).To(BeNil())
// PredictOptions.TopP is float32 on the wire; the translator widens
// to float64 so 0.9 round-trips as 0.8999999761581421… — compare
// with a small tolerance rather than exact equality.
g.Expect(captured.TopP).NotTo(BeNil())
g.Expect(math.Abs(*captured.TopP - 0.9)).To(BeNumerically("<=", 1e-6))
}
func TestPredict_Anthropic_DefaultsMaxTokens(t *testing.T) {
g := NewWithT(t)
// Anthropic 400s without max_tokens. The translator must default
// it when the caller doesn't supply Tokens.
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
})
defer srv.Close()
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(captured.MaxTokens).To(Equal(anthropicDefaultMaxTokens))
}
func TestPredict_Anthropic_PromptFallback(t *testing.T) {
g := NewWithT(t)
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
})
defer srv.Close()
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
_, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?", Tokens: 16})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(captured.Messages).To(HaveLen(1))
g.Expect(captured.Messages[0].Role).To(Equal("user"))
g.Expect(captured.Messages[0].Content).To(Equal("what time is it?"))
}
func TestPredict_Anthropic_ConcatenatesContentBlocks(t *testing.T) {
g := NewWithT(t)
// Anthropic may return multiple text blocks; the translator joins
// them so the Predict() string return is the full assistant message.
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, `{"content":[{"type":"text","text":"hello "},{"type":"text","text":"world"}]}`, "application/json"
})
defer srv.Close()
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(got).To(Equal("hello world"))
}
func TestPredict_Anthropic_UpstreamError(t *testing.T) {
g := NewWithT(t)
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 401, `{"error":{"type":"authentication_error","message":"bad key"}}`, "application/json"
})
defer srv.Close()
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16})
g.Expect(err).To(HaveOccurred())
g.Expect(err.Error()).To(ContainSubstring("401"))
}
func TestPredictStream_Anthropic_StreamsTextDeltas(t *testing.T) {
g := NewWithT(t)
// Real Anthropic SSE has event: lines + data: lines. The translator
// only needs the data: payload; only content_block_delta with
// delta.type=text_delta carries content. message_stop ends.
frames := []string{
"event: message_start\ndata: {\"type\":\"message_start\"}\n\n",
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0}\n\n",
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\n",
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" \"}}\n\n",
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"world\"}}\n\n",
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
}
body := strings.Join(frames, "")
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, body, "text/event-stream"
})
defer srv.Close()
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
results := make(chan string, 8)
done := make(chan error, 1)
go func() {
done <- cp.PredictStream(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
Tokens: 16,
}, results)
}()
var got []string
for s := range results {
got = append(got, s)
}
err := <-done
g.Expect(err).NotTo(HaveOccurred())
g.Expect(strings.Join(got, "")).To(Equal("hello world"))
g.Expect(captured.Stream).To(BeTrue())
}
func TestBuildAnthropic_TranslatesOpenAITools(t *testing.T) {
g := NewWithT(t)
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
})
defer srv.Close()
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
tools := `[{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}}}]`
_, err := cp.Predict(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "weather in Paris?"}},
Tools: tools,
ToolChoice: `"auto"`,
Tokens: 32,
})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(captured.Tools).To(HaveLen(1))
g.Expect(captured.Tools[0].Name).To(Equal("get_weather"))
g.Expect(captured.Tools[0].Description).To(Equal("Get weather"))
// input_schema must be the parameters object verbatim.
g.Expect(string(captured.Tools[0].InputSchema)).To(ContainSubstring(`"city"`))
g.Expect(captured.ToolChoice).NotTo(BeNil())
g.Expect(captured.ToolChoice.Type).To(Equal("auto"))
}
func TestBuildAnthropic_ToolChoice_RequiredMapsToAny(t *testing.T) {
g := NewWithT(t)
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, `{"content":[]}`, "application/json"
})
defer srv.Close()
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
_, err := cp.Predict(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "x"}},
Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`,
ToolChoice: `"required"`,
Tokens: 16,
})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(captured.ToolChoice).NotTo(BeNil())
g.Expect(captured.ToolChoice.Type).To(Equal("any"))
}
func TestBuildAnthropic_ToolChoice_NoneDropsTools(t *testing.T) {
g := NewWithT(t)
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, `{"content":[]}`, "application/json"
})
defer srv.Close()
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
_, err := cp.Predict(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "x"}},
Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`,
ToolChoice: `"none"`,
Tokens: 16,
})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(captured.Tools).To(BeNil())
g.Expect(captured.ToolChoice).To(BeNil())
}
func TestBuildAnthropic_ToolChoice_NamedFunction(t *testing.T) {
g := NewWithT(t)
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, `{"content":[]}`, "application/json"
})
defer srv.Close()
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
_, err := cp.Predict(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "x"}},
Tools: `[{"type":"function","function":{"name":"weather","parameters":{"type":"object"}}}]`,
ToolChoice: `{"type":"function","function":{"name":"weather"}}`,
Tokens: 16,
})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(captured.ToolChoice).NotTo(BeNil())
g.Expect(captured.ToolChoice.Type).To(Equal("tool"))
g.Expect(captured.ToolChoice.Name).To(Equal("weather"))
}
func TestBuildAnthropic_RoundTripsAssistantToolCalls(t *testing.T) {
g := NewWithT(t)
// LocalAI Assistant's second turn: the LLM previously emitted a
// tool_use, the server executed it, and the conversation now
// includes the assistant turn (with tool_calls) plus a tool-role
// result message. Both must convert to Anthropic block form.
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
})
defer srv.Close()
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
tools := `[{"type":"function","function":{"name":"list_models","parameters":{"type":"object"}}}]`
toolCallsJSON := `[{"id":"call_abc","type":"function","function":{"name":"list_models","arguments":"{}"}}]`
_, err := cp.Predict(&pb.PredictOptions{
Tools: tools,
Messages: []*pb.Message{
{Role: "user", Content: "what models are installed?"},
{Role: "assistant", Content: "", ToolCalls: toolCallsJSON},
{Role: "tool", Content: `{"models":["a","b"]}`, ToolCallId: "call_abc"},
},
Tokens: 64,
})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(captured.Messages).To(HaveLen(3))
// 1. user text — bare string
s, ok := captured.Messages[0].Content.(string)
g.Expect(ok).To(BeTrue())
g.Expect(s).To(Equal("what models are installed?"))
// 2. assistant — must be a content-block list with one tool_use
// json.Unmarshal of `any` produces []any not []anthropicContentBlock.
blocks, ok := captured.Messages[1].Content.([]any)
g.Expect(ok).To(BeTrue())
g.Expect(blocks).To(HaveLen(1))
b0, _ := blocks[0].(map[string]any)
g.Expect(b0["type"]).To(Equal("tool_use"))
g.Expect(b0["id"]).To(Equal("call_abc"))
g.Expect(b0["name"]).To(Equal("list_models"))
// 3. tool → user with tool_result block
g.Expect(captured.Messages[2].Role).To(Equal("user"))
resBlocks, _ := captured.Messages[2].Content.([]any)
r0, _ := resBlocks[0].(map[string]any)
g.Expect(r0["type"]).To(Equal("tool_result"))
g.Expect(r0["tool_use_id"]).To(Equal("call_abc"))
g.Expect(r0["content"]).To(Equal(`{"models":["a","b"]}`))
}

View File

@@ -1,119 +0,0 @@
package main
import (
"encoding/json"
"strings"
"testing"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/gomega"
)
// Verify buildOpenAIRequest preserves caller-supplied tools and
// tool_choice as opaque JSON. PredictOptions carries them as strings;
// they must land in the outbound request body unchanged so the
// upstream sees the caller's intent verbatim. A regression here would
// silently disable function calling for translate-mode clients.
func TestBuildOpenAIRequest_ToolsAndToolChoicePassthrough(t *testing.T) {
g := NewWithT(t)
cfg := &proxyConfig{upstreamModel: "gpt-4o"}
toolsJSON := `[{"type":"function","function":{"name":"search","parameters":{"type":"object"}}}]`
choiceJSON := `{"type":"function","function":{"name":"search"}}`
body, err := buildOpenAIRequest(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "find x"}},
Tools: toolsJSON,
ToolChoice: choiceJSON,
}, cfg, false)
g.Expect(err).NotTo(HaveOccurred())
var decoded openAIRequest
err = json.Unmarshal(body, &decoded)
g.Expect(err).NotTo(HaveOccurred())
// Compare the JSON-canonical form so whitespace differences are ignored.
gotTools, _ := json.Marshal(json.RawMessage(decoded.Tools))
wantTools, _ := json.Marshal(json.RawMessage(toolsJSON))
g.Expect(string(gotTools)).To(Equal(string(wantTools)))
gotChoice, _ := json.Marshal(json.RawMessage(decoded.ToolChoice))
wantChoice, _ := json.Marshal(json.RawMessage(choiceJSON))
g.Expect(string(gotChoice)).To(Equal(string(wantChoice)))
}
// Garbage JSON in tools / tool_choice is silently dropped (omitted)
// rather than blowing up the request. Documents the parseRawJSON
// behaviour — operators shouldn't see hard failures from an upstream
// caller's mis-formatted tools field.
func TestBuildOpenAIRequest_InvalidToolsJSONDropped(t *testing.T) {
g := NewWithT(t)
cfg := &proxyConfig{upstreamModel: "gpt-4o"}
body, err := buildOpenAIRequest(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "x"}},
Tools: "this is not json",
ToolChoice: "{also bad",
}, cfg, false)
g.Expect(err).NotTo(HaveOccurred())
g.Expect(string(body)).NotTo(ContainSubstring("this is not json"))
g.Expect(string(body)).NotTo(ContainSubstring("{also bad"))
}
// Anthropic empty content array yields an empty Reply (not an error).
// Mirrors how an upstream tool_use-only response might arrive — the
// content array can legitimately be empty in some edge cases.
func TestPredictRich_Anthropic_EmptyContent(t *testing.T) {
g := NewWithT(t)
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, `{"id":"m1","type":"message","role":"assistant","content":[],"usage":{"input_tokens":3,"output_tokens":0}}`, "application/json"
})
defer srv.Close()
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
reply, err := cp.PredictRich(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "x"}},
Tokens: 16,
})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(string(reply.GetMessage())).To(Equal(""))
g.Expect(reply.GetChatDeltas()).To(HaveLen(0))
g.Expect(reply.GetPromptTokens()).To(Equal(int32(3)))
}
// A truncated / malformed SSE payload mid-stream should be tolerated:
// the malformed chunk gets skipped (xlog.Debug logged), valid chunks
// before AND after it still reach the channel.
func TestPredictStreamRich_OpenAI_TolerantOfBadChunks(t *testing.T) {
g := NewWithT(t)
body := strings.Join([]string{
`data: {"choices":[{"index":0,"delta":{"content":"hello"}}]}`,
``,
`data: this-is-not-json{{`,
``,
`data: {"choices":[{"index":0,"delta":{"content":" world"}}]}`,
``,
`data: [DONE]`,
``,
}, "\n")
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
return 200, body, "text/event-stream"
})
defer srv.Close()
cp := newTranslateCloudProxy(t, srv.URL)
results := make(chan *pb.Reply, 8)
done := make(chan error, 1)
go func() {
done <- cp.PredictStreamRich(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
}, results)
close(results)
}()
var assembled strings.Builder
for reply := range results {
assembled.Write(reply.GetMessage())
}
err := <-done
g.Expect(err).NotTo(HaveOccurred())
// The good chunks before and after the malformed one both made it through.
g.Expect(assembled.String()).To(Equal("hello world"))
}

View File

@@ -1,320 +0,0 @@
package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/xlog"
)
// OpenAI Chat Completions wire-format types. Narrowed to the fields
// translate mode needs to preserve through the Reply proto: content,
// role, tool_calls (typed so we can map them to pb.ToolCallDelta),
// and sampling params copied verbatim from PredictOptions.
//
// Provider-specific extensions (logit_bias, function calling beyond
// tool_calls, etc.) are not modelled — passthrough mode covers callers
// that need full upstream fidelity.
type openAIRequest struct {
Model string `json:"model"`
Messages []openAIMessage `json:"messages"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
MaxTokens *int32 `json:"max_tokens,omitempty"`
Stop []string `json:"stop,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
Tools json.RawMessage `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
}
type openAIMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
Name string `json:"name,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
}
// openAIToolCall covers both the non-streaming response shape (full
// id+function+arguments) and the streaming-delta shape (sparse fields,
// index assignment). The proto's ToolCallDelta absorbs both — name is
// set on first appearance, arguments arrive incrementally in streaming.
type openAIToolCall struct {
Index int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Function openAIFunctionCall `json:"function,omitempty"`
}
type openAIFunctionCall struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
}
type openAIChoice struct {
Index int `json:"index"`
Message openAIMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type openAIResponse struct {
ID string `json:"id"`
Choices []openAIChoice `json:"choices"`
Usage *openAIUsage `json:"usage,omitempty"`
}
type openAIStreamChoice struct {
Index int `json:"index"`
Delta struct {
Content string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
} `json:"delta"`
FinishReason string `json:"finish_reason,omitempty"`
}
type openAIStreamChunk struct {
Choices []openAIStreamChoice `json:"choices"`
Usage *openAIUsage `json:"usage,omitempty"`
}
type openAIUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
// buildOpenAIRequest converts pb.PredictOptions into the OpenAI Chat
// Completions request body. Prefers Messages when non-empty; falls
// back to wrapping Prompt as a single user message so plain
// /completions-style calls still work in translate mode.
func buildOpenAIRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) {
req := openAIRequest{
Model: modelName(cfg, opts),
Stream: stream,
Stop: opts.GetStopPrompts(),
Tools: parseRawJSON(opts.GetTools()),
ToolChoice: parseRawJSON(opts.GetToolChoice()),
}
if t := opts.GetTemperature(); t != 0 {
v := float64(t)
req.Temperature = &v
}
if t := opts.GetTopP(); t != 0 {
v := float64(t)
req.TopP = &v
}
if n := opts.GetTokens(); n > 0 {
req.MaxTokens = &n
}
if p := opts.GetFrequencyPenalty(); p != 0 {
v := float64(p)
req.FrequencyPenalty = &v
}
if p := opts.GetPresencePenalty(); p != 0 {
v := float64(p)
req.PresencePenalty = &v
}
for _, m := range opts.GetMessages() {
msg := openAIMessage{
Role: m.GetRole(),
Content: m.GetContent(),
Name: m.GetName(),
ToolCallID: m.GetToolCallId(),
}
// Pre-existing tool_calls arrive as a JSON string from the
// upstream caller's previous assistant turn; pass-through as-is.
if tc := m.GetToolCalls(); tc != "" {
_ = json.Unmarshal([]byte(tc), &msg.ToolCalls)
}
req.Messages = append(req.Messages, msg)
}
// Fallback for plain Prompt requests (no Messages array). LocalAI
// templating may have produced a flat prompt; rewrap as a single
// user message so the upstream chat endpoint accepts it.
if len(req.Messages) == 0 && opts.GetPrompt() != "" {
req.Messages = []openAIMessage{{Role: "user", Content: opts.GetPrompt()}}
}
return json.Marshal(req)
}
// modelName picks the upstream model: upstream_model from the proxy
// config wins (operator override), else the local model name captured
// at LoadModel time. Operator sets upstream_model to map LocalAI's
// alias (e.g. "claude-strict") to the upstream's canonical name
// (e.g. "claude-3-5-sonnet-20241022").
func modelName(cfg *proxyConfig, _ *pb.PredictOptions) string {
if cfg.upstreamModel != "" {
return cfg.upstreamModel
}
return cfg.localModel
}
// parseRawJSON parses a JSON string into a RawMessage so it round-trips
// into the upstream body. Returns nil for empty/invalid input so the
// field is omitted (omitempty).
func parseRawJSON(s string) json.RawMessage {
if s == "" {
return nil
}
var probe json.RawMessage
if err := json.Unmarshal([]byte(s), &probe); err != nil {
return nil
}
return probe
}
// doOpenAIRequest builds + sends the upstream request. Returns the
// raw response on success; caller handles status / body.
func (c *CloudProxy) doOpenAIRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("cloud-proxy: build request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "*/*")
if cfg.apiKey != "" {
applyAuthHeader(req, cfg.provider, cfg.apiKey)
}
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err)
}
return resp, nil
}
// predictOpenAIRich is the non-streaming translate path. Returns a
// fully-populated *pb.Reply with assistant content, tool calls, and
// token usage. The gRPC server forwards the Reply verbatim.
func (c *CloudProxy) predictOpenAIRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) {
body, err := buildOpenAIRequest(opts, cfg, false)
if err != nil {
return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err)
}
resp, err := c.doOpenAIRequest(ctx, cfg, body)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
}
var parsed openAIResponse
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
return nil, fmt.Errorf("cloud-proxy: decode response: %w", err)
}
if len(parsed.Choices) == 0 {
return nil, errors.New("cloud-proxy: upstream returned no choices")
}
choice := parsed.Choices[0]
reply := &pb.Reply{
Message: []byte(choice.Message.Content),
}
if parsed.Usage != nil {
reply.PromptTokens = int32(parsed.Usage.PromptTokens)
reply.Tokens = int32(parsed.Usage.CompletionTokens)
}
if len(choice.Message.ToolCalls) > 0 {
// Non-streaming: a single ChatDelta carries the full tool-call
// set. Index/Name/Arguments are populated together; downstream
// consumers don't need to assemble streaming deltas.
delta := &pb.ChatDelta{}
for _, tc := range choice.Message.ToolCalls {
delta.ToolCalls = append(delta.ToolCalls,
newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments))
}
reply.ChatDeltas = []*pb.ChatDelta{delta}
}
return reply, nil
}
// predictOpenAIStreamRich streams *pb.Reply chunks. Each chunk carries
// either a content delta (Message + ChatDeltas[].Content) or tool-call
// deltas (ChatDeltas[].ToolCalls). The final Reply carries usage tokens
// when the upstream sends them (stream_options.include_usage).
func (c *CloudProxy) predictOpenAIStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error {
body, err := buildOpenAIRequest(opts, cfg, true)
if err != nil {
return fmt.Errorf("cloud-proxy: marshal request: %w", err)
}
resp, err := c.doOpenAIRequest(ctx, cfg, body)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data:") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if payload == "" || payload == "[DONE]" {
return nil
}
var chunk openAIStreamChunk
if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err)
continue
}
// Usage frames may arrive separately from content frames when
// stream_options.include_usage is set; emit a usage-only Reply
// in that case so the consumer sees the totals.
if chunk.Usage != nil && len(chunk.Choices) == 0 {
if !sendReply(ctx, results, &pb.Reply{
PromptTokens: int32(chunk.Usage.PromptTokens),
Tokens: int32(chunk.Usage.CompletionTokens),
}) {
return ctx.Err()
}
continue
}
for _, ch := range chunk.Choices {
reply := &pb.Reply{}
if ch.Delta.Content != "" {
reply.Message = []byte(ch.Delta.Content)
reply.ChatDeltas = []*pb.ChatDelta{{Content: ch.Delta.Content}}
}
if len(ch.Delta.ToolCalls) > 0 {
if len(reply.ChatDeltas) == 0 {
reply.ChatDeltas = []*pb.ChatDelta{{}}
}
for _, tc := range ch.Delta.ToolCalls {
reply.ChatDeltas[0].ToolCalls = append(reply.ChatDeltas[0].ToolCalls,
newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments))
}
}
if reply.Message == nil && len(reply.ChatDeltas) == 0 {
continue
}
if !sendReply(ctx, results, reply) {
return ctx.Err()
}
}
}
return scanner.Err()
}

View File

@@ -1,170 +0,0 @@
package main
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
. "github.com/onsi/gomega"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
// fakeOpenAIUpstream returns an httptest.Server that decodes the
// inbound request as an openAIRequest, calls handler with it, and
// writes the handler's reply as the response.
func fakeOpenAIUpstream(t *testing.T, handler func(req openAIRequest) (status int, body string, contentType string)) (*httptest.Server, *openAIRequest) {
t.Helper()
var captured openAIRequest
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, _ := io.ReadAll(r.Body)
_ = json.Unmarshal(raw, &captured)
status, body, ct := handler(captured)
w.Header().Set("Content-Type", ct)
w.WriteHeader(status)
_, _ = io.WriteString(w, body)
}))
return srv, &captured
}
func newTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy {
t.Helper()
g := NewWithT(t)
t.Setenv("CLOUD_PROXY_OPENAI_FAKE", "sk-fake-openai")
cp := NewCloudProxy()
err := cp.Load(&pb.ModelOptions{
Model: "gpt-4o-local",
Proxy: &pb.ProxyOptions{
UpstreamUrl: upstreamURL,
Mode: modeTranslate,
Provider: providerOpenAI,
ApiKeyEnv: "CLOUD_PROXY_OPENAI_FAKE",
UpstreamModel: "gpt-4o",
},
})
g.Expect(err).NotTo(HaveOccurred())
return cp
}
func TestPredict_OpenAI_BasicChat(t *testing.T) {
g := NewWithT(t)
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
return 200, `{"id":"resp-1","choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`, "application/json"
})
defer srv.Close()
cp := newTranslateCloudProxy(t, srv.URL)
got, err := cp.Predict(&pb.PredictOptions{
Messages: []*pb.Message{
{Role: "system", Content: "be brief"},
{Role: "user", Content: "hello"},
},
Temperature: 0.5,
TopP: 0.9,
Tokens: 32,
})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(got).To(Equal("hi there"))
// Verify the upstream saw a properly-translated request.
g.Expect(captured.Model).To(Equal("gpt-4o"))
g.Expect(captured.Messages).To(HaveLen(2))
g.Expect(captured.Messages[0].Role).To(Equal("system"))
g.Expect(captured.Messages[1].Role).To(Equal("user"))
g.Expect(captured.Temperature).NotTo(BeNil())
g.Expect(*captured.Temperature).To(Equal(0.5))
g.Expect(captured.MaxTokens).NotTo(BeNil())
g.Expect(*captured.MaxTokens).To(Equal(int32(32)))
g.Expect(captured.Stream).To(BeFalse())
}
func TestPredict_OpenAI_PromptFallback(t *testing.T) {
g := NewWithT(t)
// No Messages array — backend should synth a single user message
// from Prompt so non-chat clients still route through translate.
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
return 200, `{"choices":[{"message":{"role":"assistant","content":"ok"}}]}`, "application/json"
})
defer srv.Close()
cp := newTranslateCloudProxy(t, srv.URL)
_, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?"})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(captured.Messages).To(HaveLen(1))
g.Expect(captured.Messages[0].Role).To(Equal("user"))
g.Expect(captured.Messages[0].Content).To(Equal("what time is it?"))
}
func TestPredict_OpenAI_UpstreamError(t *testing.T) {
g := NewWithT(t)
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
return 401, `{"error":{"message":"bad key"}}`, "application/json"
})
defer srv.Close()
cp := newTranslateCloudProxy(t, srv.URL)
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}})
g.Expect(err).To(HaveOccurred())
g.Expect(err.Error()).To(ContainSubstring("401"))
}
func TestPredictStream_OpenAI_StreamsContent(t *testing.T) {
g := NewWithT(t)
// Stream three content deltas then [DONE]. Verify the channel
// receives them in order with no missing pieces.
chunks := []string{
`{"choices":[{"index":0,"delta":{"role":"assistant"}}]}`,
`{"choices":[{"index":0,"delta":{"content":"hello"}}]}`,
`{"choices":[{"index":0,"delta":{"content":" "}}]}`,
`{"choices":[{"index":0,"delta":{"content":"world"}}]}`,
`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
}
body := ""
for _, c := range chunks {
body += "data: " + c + "\n\n"
}
body += "data: [DONE]\n\n"
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
return 200, body, "text/event-stream"
})
defer srv.Close()
cp := newTranslateCloudProxy(t, srv.URL)
results := make(chan string, 8)
done := make(chan error, 1)
go func() {
done <- cp.PredictStream(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
}, results)
}()
var got []string
for s := range results {
got = append(got, s)
}
err := <-done
g.Expect(err).NotTo(HaveOccurred())
g.Expect(strings.Join(got, "")).To(Equal("hello world"))
g.Expect(captured.Stream).To(BeTrue())
}
func TestPredict_RejectedInPassthroughMode(t *testing.T) {
g := NewWithT(t)
t.Setenv("CLOUD_PROXY_FAKE", "k")
cp := NewCloudProxy()
err := cp.Load(&pb.ModelOptions{
Proxy: &pb.ProxyOptions{
UpstreamUrl: "https://example.com",
Mode: modePassthrough,
ApiKeyEnv: "CLOUD_PROXY_FAKE",
},
})
g.Expect(err).NotTo(HaveOccurred())
_, err = cp.Predict(&pb.PredictOptions{})
g.Expect(err).To(HaveOccurred())
g.Expect(err.Error()).To(ContainSubstring("only valid in translate"))
}

View File

@@ -1,435 +0,0 @@
package main
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"sync/atomic"
"github.com/mudler/xlog"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/httpclient"
)
// Mirror of core/config.Proxy{Mode,Provider}* — backends don't
// import core to keep the boundary clean.
const (
modePassthrough = "passthrough"
modeTranslate = "translate"
providerOpenAI = "openai"
providerAnthropic = "anthropic"
)
// CloudProxy is the LocalAI backend that proxies model traffic to a
// configured upstream HTTP provider. Concurrency: base.SingleThread is
// NOT embedded — forward calls are independent and HTTP transport is
// goroutine-safe, so multiple Forward streams can run in parallel.
// Locking would serialise requests to a chat provider for no benefit.
type CloudProxy struct {
base.Base
cfg atomic.Pointer[proxyConfig]
client *http.Client
}
type proxyConfig struct {
upstreamURL string
mode string
provider string
upstreamModel string
localModel string // ModelOptions.Model — fallback when upstream_model is unset
apiKey string // resolved at Load time
}
func NewCloudProxy() *CloudProxy {
// httpclient.New refuses redirects outright: the proxy talks to a
// single configured upstream API (OpenAI/Anthropic/...) that answers
// directly, so a 3xx means misconfiguration, a hijacked upstream, or
// DNS trickery — never normal operation. Following it would replay the
// request, including the operator's x-api-key (which Go does NOT strip
// on cross-host redirects), to an unvetted host and leak the key
// (GHSA-3mj3-57v2-4636). It also imposes no body deadline, so streaming
// SSE responses that legitimately last minutes are not truncated.
return &CloudProxy{client: httpclient.New()}
}
func (c *CloudProxy) Load(opts *pb.ModelOptions) error {
po := opts.GetProxy()
if po == nil {
return errors.New("cloud-proxy: Load requires ProxyOptions to be set")
}
if po.GetUpstreamUrl() == "" {
return errors.New("cloud-proxy: upstream_url is required")
}
if _, err := url.ParseRequestURI(po.GetUpstreamUrl()); err != nil {
return fmt.Errorf("cloud-proxy: upstream_url %q invalid: %w", po.GetUpstreamUrl(), err)
}
mode := po.GetMode()
if mode == "" {
mode = modePassthrough
}
switch mode {
case modePassthrough:
case modeTranslate:
switch po.GetProvider() {
case providerOpenAI:
// implemented in provider_openai.go
case providerAnthropic:
// implemented in provider_anthropic.go
default:
return fmt.Errorf("cloud-proxy: translate mode requires provider in {%s, %s}, got %q",
providerOpenAI, providerAnthropic, po.GetProvider())
}
default:
return fmt.Errorf("cloud-proxy: unknown mode %q", mode)
}
key, err := resolveAPIKey(po.GetApiKeyEnv(), po.GetApiKeyFile())
if err != nil {
return err
}
c.cfg.Store(&proxyConfig{
upstreamURL: po.GetUpstreamUrl(),
mode: mode,
provider: po.GetProvider(),
upstreamModel: po.GetUpstreamModel(),
localModel: opts.GetModel(),
apiKey: key,
})
xlog.Info("cloud-proxy: ready",
"upstream", po.GetUpstreamUrl(),
"mode", mode,
"provider", po.GetProvider(),
"has_key", key != "")
return nil
}
// resolveAPIKey mirrors config.ProxyConfig.ResolveAPIKey. Duplicated
// (a few lines) rather than importing core/config from a backend
// binary — keeps backends independent of core's package layout.
// Mutual-exclusion is enforced upstream in core/config.Validate.
func resolveAPIKey(envName, filePath string) (string, error) {
if envName != "" {
v := os.Getenv(envName)
if v == "" {
return "", fmt.Errorf("cloud-proxy: api_key_env %q is unset", envName)
}
return v, nil
}
if filePath != "" {
b, err := os.ReadFile(filePath)
if err != nil {
return "", fmt.Errorf("cloud-proxy: read api_key_file %q: %w", filePath, err)
}
return strings.TrimSpace(string(b)), nil
}
return "", nil
}
// PredictRich is the non-streaming translate path. Returns a fully-
// populated *pb.Reply: content, tool-call deltas (ChatDeltas), and
// usage tokens. Implements the optional grpc.AIModelRich interface;
// the gRPC server prefers this path over Predict when present so
// tool calls survive the round-trip. Passthrough mode rejects
// PredictRich — callers must use Forward.
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
cfg := c.cfg.Load()
if cfg == nil {
return nil, errors.New("cloud-proxy: model not loaded")
}
if cfg.mode != modeTranslate {
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
}
xlog.Info("cloud-proxy: predict", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel)
defer func() {
if err != nil {
xlog.Warn("cloud-proxy: predict failed", "provider", cfg.provider, "error", err)
}
}()
ctx := context.Background()
switch cfg.provider {
case providerOpenAI:
return c.predictOpenAIRich(ctx, cfg, opts)
case providerAnthropic:
return c.predictAnthropicRich(ctx, cfg, opts)
default:
return nil, fmt.Errorf("cloud-proxy: predict not implemented for provider %q", cfg.provider)
}
}
// PredictStreamRich is the rich streaming counterpart of PredictRich.
// Each emitted Reply carries either a content delta, tool-call deltas,
// or usage tokens (the final upstream frame). base.Base.PredictStream
// is bypassed when AIModelRich is implemented, so the channel is
// closed by the gRPC server pump.
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
cfg := c.cfg.Load()
if cfg == nil {
return errors.New("cloud-proxy: model not loaded")
}
if cfg.mode != modeTranslate {
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
}
xlog.Info("cloud-proxy: predict-stream", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel)
defer func() {
if err != nil {
xlog.Warn("cloud-proxy: predict-stream failed", "provider", cfg.provider, "error", err)
}
}()
ctx := context.Background()
switch cfg.provider {
case providerOpenAI:
return c.predictOpenAIStreamRich(ctx, cfg, opts, results)
case providerAnthropic:
return c.predictAnthropicStreamRich(ctx, cfg, opts, results)
default:
return fmt.Errorf("cloud-proxy: predictStream not implemented for provider %q", cfg.provider)
}
}
// Predict is the legacy (string, error) AIModel signature. Used only
// if a caller goes through the non-rich path (it shouldn't, since
// server.go prefers PredictRich). Provided so the AIModel interface
// is satisfied for backends that haven't opted into the rich variant.
func (c *CloudProxy) Predict(opts *pb.PredictOptions) (string, error) {
reply, err := c.PredictRich(opts)
if err != nil {
return "", err
}
return string(reply.GetMessage()), nil
}
// PredictStream is the legacy chan-string streaming path. Adapts the
// rich stream by extracting only content text — tool-call-only chunks
// (no Message bytes) and usage-only chunks are silently dropped, since
// the legacy chan-string contract cannot represent them. Consumers
// that need tool calls must call PredictStreamRich directly.
func (c *CloudProxy) PredictStream(opts *pb.PredictOptions, results chan string) error {
defer close(results)
richCh := make(chan *pb.Reply)
errCh := make(chan error, 1)
go func() {
errCh <- c.PredictStreamRich(opts, richCh)
close(richCh)
}()
for reply := range richCh {
if msg := reply.GetMessage(); len(msg) > 0 {
results <- string(msg)
}
}
return <-errCh
}
// sendReply pushes one Reply onto a stream channel honouring ctx
// cancellation. Returns false on cancel so the caller can exit with
// ctx.Err(). Used by both translate-mode providers.
func sendReply(ctx context.Context, results chan<- *pb.Reply, reply *pb.Reply) bool {
select {
case results <- reply:
return true
case <-ctx.Done():
return false
}
}
// newToolCallDelta is a small constructor for the cross-provider
// tool-call delta shape. Centralised so the int32 cast and the four
// fields stay consistent across the OpenAI / Anthropic translators.
// Empty name/args are valid — Anthropic streaming announces the call
// with id+name then sends arguments incrementally; OpenAI's reverse
// pattern (args without name) also lands here.
func newToolCallDelta(index int, id, name, args string) *pb.ToolCallDelta {
return &pb.ToolCallDelta{
Index: int32(index),
Id: id,
Name: name,
Arguments: args,
}
}
// Forward shovels bytes between a Forward gRPC stream and an upstream
// HTTP request. First request message carries path/method/headers and
// the initial body chunk; subsequent messages append body chunks. The
// first reply carries upstream status + response headers; subsequent
// replies stream body chunks until the upstream connection closes.
// Cancellation of ctx (the gRPC stream context) closes the upstream
// connection.
func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest, out chan<- *pb.ForwardReply) error {
defer close(out)
cfg := c.cfg.Load()
if cfg == nil {
return errors.New("cloud-proxy: model not loaded")
}
if cfg.mode != modePassthrough {
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
}
first, ok := <-in
if !ok {
return errors.New("cloud-proxy: Forward stream closed before first request")
}
// Honour the per-request path only when the configured upstream_url
// has no path of its own — gallery convention is to put the
// canonical path in upstream_url.
fullURL, err := composeURL(cfg.upstreamURL, first.GetPath())
if err != nil {
return err
}
method := first.GetMethod()
if method == "" {
method = http.MethodPost
}
// Pipe the body in from the gRPC stream so the HTTP request can
// start before the client finishes sending. The pipe-reader is
// closed via CloseWithError on the error paths so the writer
// goroutine doesn't block forever.
pr, pw := io.Pipe()
go func() {
var writeErr error
defer func() { _ = pw.CloseWithError(writeErr) }()
if len(first.GetBodyChunk()) > 0 {
if _, writeErr = pw.Write(first.GetBodyChunk()); writeErr != nil {
return
}
}
for req := range in {
if len(req.GetBodyChunk()) == 0 {
continue
}
if _, writeErr = pw.Write(req.GetBodyChunk()); writeErr != nil {
return
}
}
}()
req, err := http.NewRequestWithContext(ctx, method, fullURL, pr)
if err != nil {
_ = pr.CloseWithError(err) // unblocks the body-pump's pw.Write
return fmt.Errorf("cloud-proxy: build request: %w", err)
}
// Apply caller-supplied headers, then override with the
// authorization header derived from the resolved key. Caller-
// supplied Authorization is always replaced — operators may not
// know the backend's auth scheme, and silently leaking through a
// client Authorization header to a different upstream would
// confuse the upstream and could leak credentials.
for _, h := range first.GetHeaders() {
if h == nil || h.GetName() == "" {
continue
}
// Strip hop-by-hop headers that aren't meaningful to the
// upstream (Host is set by the http client from the URL;
// Content-Length is computed from the body).
if isHopByHopHeader(h.GetName()) {
continue
}
req.Header.Add(h.GetName(), h.GetValue())
}
if cfg.apiKey != "" {
applyAuthHeader(req, cfg.provider, cfg.apiKey)
}
xlog.Info("cloud-proxy: forward", "method", method, "url", fullURL, "provider", cfg.provider)
resp, err := c.client.Do(req)
if err != nil {
xlog.Warn("cloud-proxy: forward upstream failed", "url", fullURL, "error", err)
return fmt.Errorf("cloud-proxy: upstream request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
logFn := xlog.Info
if resp.StatusCode >= 400 {
logFn = xlog.Warn
}
logFn("cloud-proxy: forward response", "url", fullURL, "status", resp.StatusCode)
// First reply: status + response headers, no body.
headers := make([]*pb.ForwardHeader, 0, len(resp.Header))
for k, vs := range resp.Header {
for _, v := range vs {
headers = append(headers, &pb.ForwardHeader{Name: k, Value: v})
}
}
out <- &pb.ForwardReply{Status: int32(resp.StatusCode), Headers: headers}
// Subsequent replies: body chunks. Use a fixed 8KB buffer — small
// enough that SSE token frames flush promptly, large enough that
// long chunked-transfer bodies aren't death by a thousand reads.
buf := make([]byte, 8*1024)
for {
n, rerr := resp.Body.Read(buf)
if n > 0 {
chunk := make([]byte, n)
copy(chunk, buf[:n])
out <- &pb.ForwardReply{BodyChunk: chunk}
}
if rerr != nil {
if errors.Is(rerr, io.EOF) {
return nil
}
return fmt.Errorf("cloud-proxy: upstream body read: %w", rerr)
}
}
}
// composeURL combines the configured upstream URL with the per-request
// path. The upstream URL typically already includes the canonical path
// (e.g. https://api.openai.com/v1/chat/completions) so the per-request
// path is ignored in that case. When upstream_url is a bare host
// (https://api.openai.com), the request path is appended.
func composeURL(upstream, reqPath string) (string, error) {
u, err := url.Parse(upstream)
if err != nil {
return "", fmt.Errorf("cloud-proxy: parse upstream_url %q: %w", upstream, err)
}
if u.Path == "" || u.Path == "/" {
u.Path = reqPath
}
return u.String(), nil
}
// applyAuthHeader writes the appropriate authorization header for the
// provider. OpenAI/Anthropic/most providers use Bearer; Anthropic
// historically uses x-api-key + anthropic-version, but accepts Bearer
// too via the OpenAI-compatible path. Default to Bearer when provider
// is empty (passthrough mode where the operator doesn't claim a
// provider).
func applyAuthHeader(req *http.Request, provider, key string) {
switch provider {
case providerAnthropic:
req.Header.Set("x-api-key", key)
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
default:
req.Header.Set("Authorization", "Bearer "+key)
}
}
// isHopByHopHeader returns true for headers that should not be
// forwarded from the client request to the upstream (RFC 7230 §6.1
// hop-by-hop list, plus a few that the http.Client sets itself).
func isHopByHopHeader(name string) bool {
switch strings.ToLower(name) {
case "connection", "proxy-connection", "keep-alive", "transfer-encoding",
"te", "trailer", "upgrade", "host", "content-length":
return true
}
return false
}

View File

@@ -1,206 +0,0 @@
package main
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
grpc "github.com/mudler/LocalAI/pkg/grpc"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/gomega"
)
// helper: run a CloudProxy in-process via grpc.Provide so tests can
// call Forward through the public Backend interface without listening
// on a real socket.
func newInProcClient(t *testing.T, proxy *CloudProxy) grpc.Backend {
t.Helper()
addr := "test://" + t.Name()
grpc.Provide(addr, proxy)
return grpc.NewClient(addr, true, nil, false)
}
func TestForward_PassthroughEcho(t *testing.T) {
g := NewWithT(t)
// Fake upstream: echoes the request body back, prefixed with a
// canary so the test can assert both that the body reached the
// upstream and the response made it back to the client.
gotBody := make(chan string, 1)
gotAuth := make(chan string, 1)
gotPath := make(chan string, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
gotBody <- string(body)
gotAuth <- r.Header.Get("Authorization")
gotPath <- r.URL.Path
w.Header().Set("X-Echo", "true")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("echo: " + string(body)))
}))
defer upstream.Close()
t.Setenv("CLOUD_PROXY_FAKE_KEY", "sk-fake")
cp := NewCloudProxy()
err := cp.Load(&pb.ModelOptions{
Proxy: &pb.ProxyOptions{
UpstreamUrl: upstream.URL,
Mode: modePassthrough,
ApiKeyEnv: "CLOUD_PROXY_FAKE_KEY",
},
})
g.Expect(err).NotTo(HaveOccurred())
c := newInProcClient(t, cp)
stream, err := c.Forward(context.Background())
g.Expect(err).NotTo(HaveOccurred())
err = stream.Send(&pb.ForwardRequest{
Path: "/v1/chat/completions",
Method: "POST",
Headers: []*pb.ForwardHeader{{Name: "Content-Type", Value: "application/json"}},
BodyChunk: []byte(`{"prompt":`),
})
g.Expect(err).NotTo(HaveOccurred())
err = stream.Send(&pb.ForwardRequest{BodyChunk: []byte(`"hi"}`)})
g.Expect(err).NotTo(HaveOccurred())
err = stream.CloseSend()
g.Expect(err).NotTo(HaveOccurred())
// First reply: status + headers.
first, err := stream.Recv()
g.Expect(err).NotTo(HaveOccurred())
g.Expect(first.Status).To(Equal(int32(http.StatusOK)))
g.Expect(hasHeader(first.Headers, "X-Echo", "true")).To(BeTrue())
// Subsequent replies: body.
var body []byte
for {
r, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
g.Expect(err).NotTo(HaveOccurred())
body = append(body, r.BodyChunk...)
}
g.Expect(string(body)).To(Equal(`echo: {"prompt":"hi"}`))
// Upstream observations.
var gotBodyVal, gotAuthVal, gotPathVal string
g.Eventually(gotBody).Should(Receive(&gotBodyVal), "upstream never saw body")
g.Expect(gotBodyVal).To(Equal(`{"prompt":"hi"}`))
g.Eventually(gotAuth).Should(Receive(&gotAuthVal), "upstream never saw auth header")
g.Expect(gotAuthVal).To(Equal("Bearer sk-fake"))
g.Eventually(gotPath).Should(Receive(&gotPathVal), "upstream never saw path")
g.Expect(gotPathVal).To(Equal("/v1/chat/completions"))
}
func TestForward_AnthropicAuthHeader(t *testing.T) {
g := NewWithT(t)
gotXAPIKey := make(chan string, 1)
gotVersion := make(chan string, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotXAPIKey <- r.Header.Get("x-api-key")
gotVersion <- r.Header.Get("anthropic-version")
w.WriteHeader(http.StatusOK)
}))
defer upstream.Close()
t.Setenv("CLOUD_PROXY_ANTHROPIC_KEY", "sk-ant-fake")
cp := NewCloudProxy()
err := cp.Load(&pb.ModelOptions{
Proxy: &pb.ProxyOptions{
UpstreamUrl: upstream.URL,
Mode: modePassthrough,
Provider: providerAnthropic,
ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_KEY",
},
})
g.Expect(err).NotTo(HaveOccurred())
c := newInProcClient(t, cp)
stream, err := c.Forward(context.Background())
g.Expect(err).NotTo(HaveOccurred())
err = stream.Send(&pb.ForwardRequest{Path: "/v1/messages", Method: "POST"})
g.Expect(err).NotTo(HaveOccurred())
_ = stream.CloseSend()
_, _ = stream.Recv() // drain status
for {
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
break
}
}
g.Expect(<-gotXAPIKey).To(Equal("sk-ant-fake"))
g.Expect(<-gotVersion).NotTo(BeEmpty())
}
func TestLoad_ValidatesConfig(t *testing.T) {
g := NewWithT(t)
cp := NewCloudProxy()
err := cp.Load(&pb.ModelOptions{})
g.Expect(err).To(HaveOccurred())
g.Expect(err.Error()).To(ContainSubstring("ProxyOptions"))
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{}})
g.Expect(err).To(HaveOccurred())
g.Expect(err.Error()).To(ContainSubstring("upstream_url"))
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
UpstreamUrl: "https://example.com",
Mode: "rewrite",
}})
g.Expect(err).To(HaveOccurred())
g.Expect(err.Error()).To(ContainSubstring("unknown mode"))
// translate + openai should load successfully (Phase 5).
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
UpstreamUrl: "https://example.com/v1/chat/completions",
Mode: modeTranslate,
Provider: providerOpenAI,
}})
g.Expect(err).NotTo(HaveOccurred())
// translate + anthropic should load successfully (Phase 6).
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
UpstreamUrl: "https://example.com/v1/messages",
Mode: modeTranslate,
Provider: providerAnthropic,
}})
g.Expect(err).NotTo(HaveOccurred())
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
UpstreamUrl: "https://example.com",
ApiKeyEnv: "DEFINITELY_UNSET_ENV_VAR_XYZ",
}})
g.Expect(err).To(HaveOccurred())
g.Expect(err.Error()).To(ContainSubstring("unset"))
}
func TestForward_RejectsWithoutLoad(t *testing.T) {
g := NewWithT(t)
cp := NewCloudProxy()
c := newInProcClient(t, cp)
stream, err := c.Forward(context.Background())
g.Expect(err).NotTo(HaveOccurred())
_ = stream.CloseSend()
_, err = stream.Recv()
g.Expect(err).To(HaveOccurred())
g.Expect(err.Error()).To(ContainSubstring("not loaded"))
}
func hasHeader(hs []*pb.ForwardHeader, name, value string) bool {
for _, h := range hs {
if strings.EqualFold(h.GetName(), name) && h.GetValue() == value {
return true
}
}
return false
}

View File

@@ -1,6 +0,0 @@
#!/bin/bash
set -ex
CURDIR=$(dirname "$(realpath $0)")
exec $CURDIR/cloud-proxy "$@"

View File

@@ -1,232 +0,0 @@
package main
import (
"strings"
"testing"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/gomega"
)
// OpenAI: non-streaming tool call response. Verify the response is
// mapped to Reply.ChatDeltas[].ToolCalls with id/name/arguments intact,
// and usage tokens land on Reply.PromptTokens / Reply.Tokens.
func TestPredictRich_OpenAI_ToolCalls(t *testing.T) {
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
return 200, `{
"id":"resp-1",
"choices":[{
"index":0,
"message":{
"role":"assistant",
"content":"",
"tool_calls":[
{"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"SF\"}"}},
{"id":"call_def","type":"function","function":{"name":"get_time","arguments":"{\"tz\":\"PT\"}"}}
]
},
"finish_reason":"tool_calls"
}],
"usage":{"prompt_tokens":42,"completion_tokens":18,"total_tokens":60}
}`, "application/json"
})
defer srv.Close()
g := NewWithT(t)
cp := newTranslateCloudProxy(t, srv.URL)
reply, err := cp.PredictRich(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}},
})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(string(reply.GetMessage())).To(Equal(""))
g.Expect(reply.GetPromptTokens()).To(Equal(int32(42)))
g.Expect(reply.GetTokens()).To(Equal(int32(18)))
g.Expect(reply.GetChatDeltas()).To(HaveLen(1))
tcs := reply.GetChatDeltas()[0].GetToolCalls()
g.Expect(tcs).To(HaveLen(2))
g.Expect(tcs[0].GetId()).To(Equal("call_abc"))
g.Expect(tcs[0].GetName()).To(Equal("get_weather"))
g.Expect(tcs[0].GetArguments()).To(ContainSubstring(`"location":"SF"`))
g.Expect(tcs[1].GetId()).To(Equal("call_def"))
g.Expect(tcs[1].GetName()).To(Equal("get_time"))
}
// OpenAI: streaming tool call. Arguments arrive as a sequence of
// delta chunks; the consumer is expected to concatenate by tool index.
// Verify each chunk reaches the channel and the assembled arguments
// match the input.
func TestPredictStreamRich_OpenAI_ToolCallDeltas(t *testing.T) {
chunks := []string{
// Frame 0: announce the tool call (id + name, no args yet).
`{"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_xyz","type":"function","function":{"name":"search"}}]}}]}`,
// Frames 1-3: arguments arrive in fragments.
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"q\":"}}]}}]}`,
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"clo"}}]}}]}`,
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"uds\"}"}}]}}]}`,
// Stop frame.
`{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
}
body := ""
for _, c := range chunks {
body += "data: " + c + "\n\n"
}
body += "data: [DONE]\n\n"
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
return 200, body, "text/event-stream"
})
defer srv.Close()
g := NewWithT(t)
cp := newTranslateCloudProxy(t, srv.URL)
results := make(chan *pb.Reply, 16)
done := make(chan error, 1)
go func() {
done <- cp.PredictStreamRich(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "find something"}},
}, results)
close(results)
}()
var (
toolName string
toolID string
toolIndex int32 = -1
argsBuf strings.Builder
)
for reply := range results {
for _, cd := range reply.GetChatDeltas() {
for _, tc := range cd.GetToolCalls() {
if tc.GetName() != "" {
toolName = tc.GetName()
}
if tc.GetId() != "" {
toolID = tc.GetId()
}
if toolIndex == -1 {
toolIndex = tc.GetIndex()
}
argsBuf.WriteString(tc.GetArguments())
}
}
}
err := <-done
g.Expect(err).NotTo(HaveOccurred())
g.Expect(toolID).To(Equal("call_xyz"))
g.Expect(toolName).To(Equal("search"))
g.Expect(toolIndex).To(Equal(int32(0)))
g.Expect(argsBuf.String()).To(Equal(`{"q":"clouds"}`))
}
// Anthropic: non-streaming tool_use block. The block appears in
// Content[] alongside text blocks; the input field is a structured
// JSON object. Map to ToolCallDelta with arguments as serialised JSON
// so downstream OpenAI-shaped consumers see a familiar format.
func TestPredictRich_Anthropic_ToolUse(t *testing.T) {
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, `{
"id":"msg_1","type":"message","role":"assistant",
"content":[
{"type":"text","text":"Let me check that."},
{"type":"tool_use","id":"toolu_01","name":"weather","input":{"location":"SF"}}
],
"model":"claude","usage":{"input_tokens":12,"output_tokens":34}
}`, "application/json"
})
defer srv.Close()
g := NewWithT(t)
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
reply, err := cp.PredictRich(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}},
Tokens: 64,
})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(string(reply.GetMessage())).To(Equal("Let me check that."))
g.Expect(reply.GetPromptTokens()).To(Equal(int32(12)))
g.Expect(reply.GetTokens()).To(Equal(int32(34)))
g.Expect(reply.GetChatDeltas()).To(HaveLen(1))
g.Expect(reply.GetChatDeltas()[0].GetToolCalls()).To(HaveLen(1))
tc := reply.GetChatDeltas()[0].GetToolCalls()[0]
g.Expect(tc.GetId()).To(Equal("toolu_01"))
g.Expect(tc.GetName()).To(Equal("weather"))
g.Expect(tc.GetArguments()).To(ContainSubstring(`"location":"SF"`))
}
// Anthropic: streaming tool_use. content_block_start announces the
// tool's id + name; input_json_delta events carry argument fragments
// which the consumer accumulates. message_delta carries final usage.
func TestPredictStreamRich_Anthropic_InputJSONDelta(t *testing.T) {
frames := []string{
"event: message_start\ndata: {\"type\":\"message_start\"}\n\n",
// Block 0 is a tool_use; consumer should allocate a slot.
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_42\",\"name\":\"lookup\"}}\n\n",
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"q\\\":\"}}\n\n",
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"\\\"rain\\\"}\"}}\n\n",
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
"event: message_delta\ndata: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n",
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
}
body := strings.Join(frames, "")
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
return 200, body, "text/event-stream"
})
defer srv.Close()
g := NewWithT(t)
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
results := make(chan *pb.Reply, 16)
done := make(chan error, 1)
go func() {
done <- cp.PredictStreamRich(&pb.PredictOptions{
Messages: []*pb.Message{{Role: "user", Content: "rain?"}},
Tokens: 64,
}, results)
close(results)
}()
var (
toolID, toolName string
argsBuf strings.Builder
finalTokens int32
)
for reply := range results {
if reply.GetTokens() > 0 && len(reply.GetChatDeltas()) == 0 {
finalTokens = reply.GetTokens()
continue
}
for _, cd := range reply.GetChatDeltas() {
for _, tc := range cd.GetToolCalls() {
if tc.GetId() != "" {
toolID = tc.GetId()
}
if tc.GetName() != "" {
toolName = tc.GetName()
}
argsBuf.WriteString(tc.GetArguments())
}
}
}
err := <-done
g.Expect(err).NotTo(HaveOccurred())
g.Expect(toolID).To(Equal("toolu_42"))
g.Expect(toolName).To(Equal("lookup"))
g.Expect(argsBuf.String()).To(Equal(`{"q":"rain"}`))
g.Expect(finalTokens).To(Equal(int32(7)))
}
// Sanity: the legacy Predict() (string, error) signature still works
// — it delegates to PredictRich and extracts Message.
func TestPredict_LegacyWrapper_OpenAI(t *testing.T) {
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
return 200, `{"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, "application/json"
})
defer srv.Close()
g := NewWithT(t)
cp := newTranslateCloudProxy(t, srv.URL)
got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "hi"}}})
g.Expect(err).NotTo(HaveOccurred())
g.Expect(got).To(Equal("hello"))
}

View File

@@ -8,6 +8,6 @@ import (
func assert(cond bool, msg string) {
if !cond {
xlog.Fatal(msg)
xlog.Fatal().Stack().Msg(msg)
}
}

View File

@@ -1,22 +1,7 @@
package main
// LocalAI's in-process vector store, exposed as a gRPC backend. Keep
// the implementation here — NOT in a pkg/ library imported by the main
// LocalAI process. The whole point of the gRPC surface is that vector
// storage is a backend like any other (local-store, qdrant, pinecone,
// ...) and can be swapped without changing the routing/recognition
// code that consumes it.
//
// Storage is a sorted parallel-slice (keys [][]float32, values
// [][]byte). Set/Delete preserve the sort so Get can binary-search.
// Find scans linearly and uses a heap to keep the top-K — fine for
// the tens-to-thousands range. The "normalized fast path" (Find when
// every stored key has unit magnitude AND the query is normalized)
// skips the per-item magnitude calculation.
//
// Concurrency: base.SingleThread serialises gRPC calls so the
// non-thread-safe slice/heap manipulation here is sound.
// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"container/heap"
"fmt"
@@ -25,29 +10,32 @@ import (
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/store"
"github.com/mudler/xlog"
)
type Store struct {
base.SingleThread
keys [][]float32
// The sorted keys
keys [][]float32
// The sorted values
values [][]byte
// keysAreNormalized stays true until any non-unit-magnitude key
// is added; once false, the magnitude-aware fallback path is
// used by Find. Re-evaluated only at Set time, never again on
// its own — a deletion of the offending key does NOT flip it
// back to true (the bookkeeping cost would dominate the gain).
// If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions
// TODO: Should we normalize incoming keys if they are not instead?
keysAreNormalized bool
// keyLen is the dimension of every stored key. -1 means "no
// keys yet, dimension is open". Dimension mismatch on Set is
// rejected so cosine similarity (which requires equal-length
// vectors) doesn't silently mis-match.
// The first key decides the length of the keys
keyLen int
}
// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because
// that's theoretically best for memory layout and cache locality, but this isn't optimized yet.
type Pair struct {
Key []float32
Value []byte
}
func NewStore() *Store {
return &Store{
keys: make([][]float32, 0),
@@ -57,278 +45,334 @@ func NewStore() *Store {
}
}
// Load is a no-op — local-store has no on-disk artefact. opts.Model is
// just a namespace identifier; isolation is already handled upstream
// (ModelLoader spawns a fresh local-store process per (backend,
// model) tuple, so each namespace is its own Store{} instance).
func compareSlices(k1, k2 []float32) int {
assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
return slices.Compare(k1, k2)
}
func hasKey(unsortedSlice [][]float32, target []float32) bool {
return slices.ContainsFunc(unsortedSlice, func(k []float32) bool {
return compareSlices(k, target) == 0
})
}
func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) {
return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int {
return compareSlices(k, t)
})
}
func isSortedPairs(kvs []Pair) bool {
for i := 1; i < len(kvs); i++ {
if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 {
return false
}
}
return true
}
func isSortedKeys(keys [][]float32) bool {
for i := 1; i < len(keys); i++ {
if compareSlices(keys[i-1], keys[i]) > 0 {
return false
}
}
return true
}
func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 {
ks := make([][]float32, len(keys))
for i, k := range keys {
ks[i] = k.Floats
}
slices.SortFunc(ks, compareSlices)
assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys)))
assert(isSortedKeys(ks), "keys are not sorted")
return ks
}
func (s *Store) Load(opts *pb.ModelOptions) error {
// local-store is an in-memory vector store with no on-disk artefact to
// load — opts.Model is just a namespace identifier. The old `!= ""` guard
// rejected any non-empty model name with "not implemented", which broke
// callers that pass a namespace to isolate embedding spaces (face vs.
// voice biometrics both go through local-store but need distinct stores
// so ArcFace 512-D and ECAPA-TDNN 192-D don't collide). Namespace
// isolation is already handled upstream: ModelLoader spawns a fresh
// local-store process per (backend, model) tuple, so each namespace is
// its own Store{} instance. Nothing to do here beyond accepting the load.
_ = opts
return nil
}
// Sort the incoming kvs and merge them with the existing sorted kvs
func (s *Store) StoresSet(opts *pb.StoresSetOptions) error {
keys := store.UnwrapKeys(opts.Keys)
values := store.UnwrapValues(opts.Values)
if len(keys) == 0 {
return fmt.Errorf("local-store: Set: no keys to add")
if len(opts.Keys) == 0 {
return fmt.Errorf("no keys to add")
}
if len(keys) != len(values) {
return fmt.Errorf("local-store: Set: len(keys) = %d, len(values) = %d", len(keys), len(values))
if len(opts.Keys) != len(opts.Values) {
return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values))
}
if s.keyLen == -1 {
s.keyLen = len(keys[0])
} else if len(keys[0]) != s.keyLen {
return fmt.Errorf("local-store: Set: key length %d does not match existing %d", len(keys[0]), s.keyLen)
s.keyLen = len(opts.Keys[0].Floats)
} else {
if len(opts.Keys[0].Floats) != s.keyLen {
return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
}
}
kvs := make([]incomingPair, len(keys))
for i, k := range keys {
if len(k) != s.keyLen {
return fmt.Errorf("local-store: Set: key %d length %d does not match existing %d", i, len(k), s.keyLen)
}
if s.keysAreNormalized && !isNormalized(k) {
kvs := make([]Pair, len(opts.Keys))
for i, k := range opts.Keys {
if s.keysAreNormalized && !isNormalized(k.Floats) {
s.keysAreNormalized = false
var sample []float32
if len(s.keys) > 5 {
sample = k.Floats[:5]
} else {
sample = k.Floats
}
xlog.Debug("Key is not normalized", "sample", sample)
}
kvs[i] = Pair{
Key: k.Floats,
Value: opts.Values[i].Bytes,
}
kvs[i] = incomingPair{key: k, value: values[i]}
}
slices.SortFunc(kvs, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) })
slices.SortFunc(kvs, func(a, b Pair) int {
return compareSlices(a.Key, b.Key)
})
assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys)))
assert(isSortedPairs(kvs), "keys are not sorted")
l := len(kvs) + len(s.keys)
merge_ks := make([][]float32, 0, l)
merge_vs := make([][]byte, 0, l)
i, j := 0, 0
for {
if i+j >= l {
break
}
if i >= len(kvs) {
merge_ks = append(merge_ks, s.keys[j])
merge_vs = append(merge_vs, s.values[j])
j++
continue
}
if j >= len(s.keys) {
merge_ks = append(merge_ks, kvs[i].Key)
merge_vs = append(merge_vs, kvs[i].Value)
i++
continue
}
c := compareSlices(kvs[i].Key, s.keys[j])
if c < 0 {
merge_ks = append(merge_ks, kvs[i].Key)
merge_vs = append(merge_vs, kvs[i].Value)
i++
} else if c > 0 {
merge_ks = append(merge_ks, s.keys[j])
merge_vs = append(merge_vs, s.values[j])
j++
} else {
merge_ks = append(merge_ks, kvs[i].Key)
merge_vs = append(merge_vs, kvs[i].Value)
i++
j++
}
}
assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l))
assert(isSortedKeys(merge_ks), "merge keys are not sorted")
s.keys = merge_ks
s.values = merge_vs
merged := mergeSortedPairs(s.keys, s.values, kvs)
s.keys = merged.keys
s.values = merged.values
assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Set: s.keys not sorted post-merge")
assert(len(s.keys) == len(s.values), "Set: keys/values length skew")
return nil
}
func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error {
keys := store.UnwrapKeys(opts.Keys)
if len(keys) == 0 {
return fmt.Errorf("local-store: Delete: no keys to delete")
if len(opts.Keys) == 0 {
return fmt.Errorf("no keys to delete")
}
if s.keyLen != -1 {
for i, k := range keys {
if len(k) != s.keyLen {
return fmt.Errorf("local-store: Delete: key %d length %d does not match existing %d", i, len(k), s.keyLen)
if len(opts.Keys) == 0 {
return fmt.Errorf("no keys to add")
}
if s.keyLen == -1 {
s.keyLen = len(opts.Keys[0].Floats)
} else {
if len(opts.Keys[0].Floats) != s.keyLen {
return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
}
}
ks := sortIntoKeySlicese(opts.Keys)
l := len(s.keys) - len(ks)
merge_ks := make([][]float32, 0, l)
merge_vs := make([][]byte, 0, l)
tail_ks := s.keys
tail_vs := s.values
for _, k := range ks {
j, found := findInSortedSlice(tail_ks, k)
if found {
merge_ks = append(merge_ks, tail_ks[:j]...)
merge_vs = append(merge_vs, tail_vs[:j]...)
tail_ks = tail_ks[j+1:]
tail_vs = tail_vs[j+1:]
} else {
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k))
}
xlog.Debug("Delete", "found", found, "tailLen", len(tail_ks), "j", j, "mergeKeysLen", len(merge_ks), "mergeValuesLen", len(merge_vs))
}
merge_ks = append(merge_ks, tail_ks...)
merge_vs = append(merge_vs, tail_vs...)
assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys)))
s.keys = merge_ks
s.values = merge_vs
assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l))
assert(isSortedKeys(s.keys), "keys are not sorted")
assert(func() bool {
for _, k := range ks {
if _, found := findInSortedSlice(s.keys, k); found {
return false
}
}
}
sortedKeys := append([][]float32(nil), keys...)
slices.SortFunc(sortedKeys, slices.Compare[[]float32])
return true
}(), "Keys to delete still present")
mergedK := make([][]float32, 0, len(s.keys))
mergedV := make([][]byte, 0, len(s.keys))
tailK := s.keys
tailV := s.values
for _, k := range sortedKeys {
j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32])
if ok {
mergedK = append(mergedK, tailK[:j]...)
mergedV = append(mergedV, tailV[:j]...)
tailK = tailK[j+1:]
tailV = tailV[j+1:]
}
if len(s.keys) != l {
xlog.Debug("Delete: Some keys not found", "keysLen", len(s.keys), "expectedLen", l)
}
mergedK = append(mergedK, tailK...)
mergedV = append(mergedV, tailV...)
s.keys = mergedK
s.values = mergedV
assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Delete: s.keys not sorted post-merge")
assert(len(s.keys) == len(s.values), "Delete: keys/values length skew")
return nil
}
// StoresGet fetches values for the given keys. Missing keys are
// omitted from the result rather than reported as an error — callers
// compare returned-key length against requested-key length to detect
// them. Returned slices are aligned.
func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) {
keys := store.UnwrapKeys(opts.Keys)
pbKeys := make([]*pb.StoresKey, 0, len(opts.Keys))
pbValues := make([]*pb.StoresValue, 0, len(opts.Keys))
ks := sortIntoKeySlicese(opts.Keys)
if len(s.keys) == 0 {
return pb.StoresGetResult{}, nil
}
if s.keyLen != -1 {
for i, k := range keys {
if len(k) != s.keyLen {
return pb.StoresGetResult{}, fmt.Errorf("local-store: Get: key %d length %d does not match existing %d", i, len(k), s.keyLen)
}
}
}
sortedKeys := append([][]float32(nil), keys...)
slices.SortFunc(sortedKeys, slices.Compare[[]float32])
var foundKeys [][]float32
var foundValues [][]byte
tailK := s.keys
tailV := s.values
for _, k := range sortedKeys {
j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32])
if !ok {
continue
}
foundKeys = append(foundKeys, tailK[j])
foundValues = append(foundValues, tailV[j])
tailK = tailK[j+1:]
tailV = tailV[j+1:]
}
return pb.StoresGetResult{
Keys: store.WrapKeys(foundKeys),
Values: store.WrapValues(foundValues),
}, nil
}
// StoresFind returns the topK nearest stored entries by cosine
// similarity, ordered most-similar first. An empty store returns
// empty slices and no error.
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
query := opts.Key.Floats
topK := int(opts.TopK)
if topK < 1 {
return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: topK = %d, must be >= 1", topK)
}
if len(s.keys) == 0 {
return pb.StoresFindResult{}, nil
}
if len(query) != s.keyLen {
return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: query length %d does not match existing %d", len(query), s.keyLen)
xlog.Debug("Get: No keys in store")
}
var keys [][]float32
var values [][]byte
var sims []float32
if s.keysAreNormalized && isNormalized(query) {
keys, values, sims = s.findNormalized(query, topK)
if s.keyLen == -1 {
s.keyLen = len(opts.Keys[0].Floats)
} else {
keys, values, sims = s.findFallback(query, topK)
if len(opts.Keys[0].Floats) != s.keyLen {
return pb.StoresGetResult{}, fmt.Errorf("Try to get a key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
}
}
return pb.StoresFindResult{
Keys: store.WrapKeys(keys),
Values: store.WrapValues(values),
Similarities: sims,
tail_k := s.keys
tail_v := s.values
for i, k := range ks {
j, found := findInSortedSlice(tail_k, k)
if found {
pbKeys = append(pbKeys, &pb.StoresKey{
Floats: k,
})
pbValues = append(pbValues, &pb.StoresValue{
Bytes: tail_v[j],
})
tail_k = tail_k[j+1:]
tail_v = tail_v[j+1:]
} else {
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: i=%d, %v", i, k))
}
}
if len(pbKeys) != len(opts.Keys) {
xlog.Debug("Get: Some keys not found", "pbKeysLen", len(pbKeys), "optsKeysLen", len(opts.Keys), "storeKeysLen", len(s.keys))
}
return pb.StoresGetResult{
Keys: pbKeys,
Values: pbValues,
}, nil
}
func (s *Store) findNormalized(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) {
assert(s.keysAreNormalized, "findNormalized: s.keysAreNormalized is false")
assert(isNormalized(query), "findNormalized: query is not unit-length")
pq := make(priorityQueue, 0, topK)
heap.Init(&pq)
for i, k := range s.keys {
var dot float32
for j := range k {
dot += query[j] * k[j]
}
assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("findNormalized: dot %f out of [-1, 1] — keysAreNormalized invariant violated", dot))
heap.Push(&pq, &priorityItem{similarity: dot, key: k, value: s.values[i]})
if pq.Len() > topK {
heap.Pop(&pq)
}
}
return drainPQ(&pq)
}
func (s *Store) findFallback(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) {
var qmag float64
for _, v := range query {
qmag += float64(v) * float64(v)
}
qmag = math.Sqrt(qmag)
pq := make(priorityQueue, 0, topK)
heap.Init(&pq)
for i, k := range s.keys {
var dot, kmag float64
for j := range k {
dot += float64(query[j]) * float64(k[j])
kmag += float64(k[j]) * float64(k[j])
}
denom := qmag * math.Sqrt(kmag)
var sim float32
if denom > 0 {
sim = float32(dot / denom)
}
heap.Push(&pq, &priorityItem{similarity: sim, key: k, value: s.values[i]})
if pq.Len() > topK {
heap.Pop(&pq)
}
}
return drainPQ(&pq)
}
func isNormalized(k []float32) bool {
var sum float64
for _, v := range k {
sum += float64(v) * float64(v)
v64 := float64(v)
sum += v64 * v64
}
mag := math.Sqrt(sum)
return mag >= 0.99 && mag <= 1.01
s := math.Sqrt(sum)
return s >= 0.99 && s <= 1.01
}
type incomingPair struct {
key []float32
value []byte
}
// TODO: This we could replace with handwritten SIMD code
func normalizedCosineSimilarity(k1, k2 []float32) float32 {
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
type pairs struct {
keys [][]float32
values [][]byte
}
// mergeSortedPairs merges (existing, incoming) into a fresh sorted
// slice. Equal keys take the incoming value — Set is upsert.
func mergeSortedPairs(existingK [][]float32, existingV [][]byte, incoming []incomingPair) pairs {
assert(slices.IsSortedFunc(existingK, slices.Compare[[]float32]), "mergeSortedPairs: existing not sorted")
assert(slices.IsSortedFunc(incoming, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) }), "mergeSortedPairs: incoming not sorted")
l := len(existingK) + len(incoming)
mk := make([][]float32, 0, l)
mv := make([][]byte, 0, l)
i, j := 0, 0
for i < len(incoming) || j < len(existingK) {
switch {
case j >= len(existingK):
mk = append(mk, incoming[i].key)
mv = append(mv, incoming[i].value)
i++
case i >= len(incoming):
mk = append(mk, existingK[j])
mv = append(mv, existingV[j])
j++
default:
c := slices.Compare(incoming[i].key, existingK[j])
switch {
case c < 0:
mk = append(mk, incoming[i].key)
mv = append(mv, incoming[i].value)
i++
case c > 0:
mk = append(mk, existingK[j])
mv = append(mv, existingV[j])
j++
default:
mk = append(mk, incoming[i].key)
mv = append(mv, incoming[i].value)
i++
j++
}
}
var dot float32
for i := range len(k1) {
dot += k1[i] * k2[i]
}
return pairs{keys: mk, values: mv}
assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("dot = %f", dot))
// 2.0 * (1.0 - dot) would be the Euclidean distance
return dot
}
type priorityItem struct {
similarity float32
key []float32
value []byte
type PriorityItem struct {
Similarity float32
Key []float32
Value []byte
}
type priorityQueue []*priorityItem
type PriorityQueue []*PriorityItem
func (pq priorityQueue) Len() int { return len(pq) }
func (pq priorityQueue) Less(i, j int) bool { return pq[i].similarity < pq[j].similarity }
func (pq priorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] }
func (pq *priorityQueue) Push(x any) { *pq = append(*pq, x.(*priorityItem)) }
func (pq *priorityQueue) Pop() any {
func (pq PriorityQueue) Len() int { return len(pq) }
func (pq PriorityQueue) Less(i, j int) bool {
// Inverted because the most similar should be at the top
return pq[i].Similarity < pq[j].Similarity
}
func (pq PriorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
}
func (pq *PriorityQueue) Push(x any) {
item := x.(*PriorityItem)
*pq = append(*pq, item)
}
func (pq *PriorityQueue) Pop() any {
old := *pq
n := len(old)
item := old[n-1]
@@ -336,16 +380,142 @@ func (pq *priorityQueue) Pop() any {
return item
}
func drainPQ(pq *priorityQueue) (keys [][]float32, values [][]byte, similarities []float32) {
n := pq.Len()
keys = make([][]float32, n)
values = make([][]byte, n)
similarities = make([]float32, n)
for i := n - 1; i >= 0; i-- {
item := heap.Pop(pq).(*priorityItem)
keys[i] = item.key
values[i] = item.value
similarities[i] = item.similarity
func (s *Store) StoresFindNormalized(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
tk := opts.Key.Floats
top_ks := make(PriorityQueue, 0, int(opts.TopK))
heap.Init(&top_ks)
for i, k := range s.keys {
sim := normalizedCosineSimilarity(tk, k)
heap.Push(&top_ks, &PriorityItem{
Similarity: sim,
Key: k,
Value: s.values[i],
})
if top_ks.Len() > int(opts.TopK) {
heap.Pop(&top_ks)
}
}
similarities := make([]float32, top_ks.Len())
pbKeys := make([]*pb.StoresKey, top_ks.Len())
pbValues := make([]*pb.StoresValue, top_ks.Len())
for i := top_ks.Len() - 1; i >= 0; i-- {
item := heap.Pop(&top_ks).(*PriorityItem)
similarities[i] = item.Similarity
pbKeys[i] = &pb.StoresKey{
Floats: item.Key,
}
pbValues[i] = &pb.StoresValue{
Bytes: item.Value,
}
}
return pb.StoresFindResult{
Keys: pbKeys,
Values: pbValues,
Similarities: similarities,
}, nil
}
func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
var dot, mag2 float64
for i := range len(k1) {
dot += float64(k1[i] * k2[i])
mag2 += float64(k2[i] * k2[i])
}
sim := float32(dot / (mag1 * math.Sqrt(mag2)))
assert(sim >= -1.01 && sim <= 1.01, fmt.Sprintf("sim = %f", sim))
return sim
}
func (s *Store) StoresFindFallback(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
tk := opts.Key.Floats
top_ks := make(PriorityQueue, 0, int(opts.TopK))
heap.Init(&top_ks)
var mag1 float64
for _, v := range tk {
mag1 += float64(v * v)
}
mag1 = math.Sqrt(mag1)
for i, k := range s.keys {
dist := cosineSimilarity(tk, k, mag1)
heap.Push(&top_ks, &PriorityItem{
Similarity: dist,
Key: k,
Value: s.values[i],
})
if top_ks.Len() > int(opts.TopK) {
heap.Pop(&top_ks)
}
}
similarities := make([]float32, top_ks.Len())
pbKeys := make([]*pb.StoresKey, top_ks.Len())
pbValues := make([]*pb.StoresValue, top_ks.Len())
for i := top_ks.Len() - 1; i >= 0; i-- {
item := heap.Pop(&top_ks).(*PriorityItem)
similarities[i] = item.Similarity
pbKeys[i] = &pb.StoresKey{
Floats: item.Key,
}
pbValues[i] = &pb.StoresValue{
Bytes: item.Value,
}
}
return pb.StoresFindResult{
Keys: pbKeys,
Values: pbValues,
Similarities: similarities,
}, nil
}
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
tk := opts.Key.Floats
if len(tk) != s.keyLen {
return pb.StoresFindResult{}, fmt.Errorf("Try to find key with length %d when existing length is %d", len(tk), s.keyLen)
}
if opts.TopK < 1 {
return pb.StoresFindResult{}, fmt.Errorf("opts.TopK = %d, must be >= 1", opts.TopK)
}
if s.keyLen == -1 {
s.keyLen = len(opts.Key.Floats)
} else {
if len(opts.Key.Floats) != s.keyLen {
return pb.StoresFindResult{}, fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Key.Floats), s.keyLen)
}
}
if s.keysAreNormalized && isNormalized(tk) {
return s.StoresFindNormalized(opts)
} else {
if s.keysAreNormalized {
var sample []float32
if len(s.keys) > 5 {
sample = tk[:5]
} else {
sample = tk
}
xlog.Debug("Trying to compare non-normalized key with normalized keys", "sample", sample)
}
return s.StoresFindFallback(opts)
}
return keys, values, similarities
}

View File

@@ -1,13 +0,0 @@
package main
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestLocalStore(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "local-store test suite")
}

View File

@@ -1,284 +0,0 @@
package main
// Regression suite for the local-store gRPC backend. Exercises the
// Stores{Set,Get,Find,Delete} surface — the only public contract.
// Callers (face/voice recognition, the routing KNN classifier) reach
// this code via grpc.Backend, so testing at the wire-shaped boundary
// matches the production import shape.
import (
"math"
"math/rand/v2"
"testing"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("StoresSet", func() {
It("rejects empty input", func() {
Expect(NewStore().StoresSet(&pb.StoresSetOptions{})).NotTo(Succeed(), "Set with no keys should fail")
})
It("rejects key/value length mismatch", func() {
err := NewStore().StoresSet(&pb.StoresSetOptions{
Keys: wrapKeys([][]float32{{1, 0, 0}}),
Values: wrapValues([][]byte{[]byte("a"), []byte("b")}),
})
Expect(err).To(HaveOccurred(), "len(keys) != len(values) should fail")
})
It("rejects dimension mismatch on later add", func() {
s := NewStore()
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("3d")})
err := s.StoresSet(&pb.StoresSetOptions{
Keys: wrapKeys([][]float32{{1, 0}}),
Values: wrapValues([][]byte{[]byte("2d")}),
})
Expect(err).To(HaveOccurred(), "dimension mismatch on later Set should fail")
})
It("rejects dimension mismatch within batch", func() {
err := NewStore().StoresSet(&pb.StoresSetOptions{
Keys: wrapKeys([][]float32{{1, 0, 0}, {1, 0}}),
Values: wrapValues([][]byte{[]byte("3d"), []byte("2d")}),
})
Expect(err).To(HaveOccurred(), "mixed-dimension within one batch should fail")
})
It("merges sorted and updates existing key", func() {
s := NewStore()
mustSet(s, [][]float32{{0.3, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("c"), []byte("a")})
mustSet(s, [][]float32{{0.2, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("b"), []byte("a-updated")})
Expect(s.keys).To(HaveLen(3))
got := singleGet(s, []float32{0.1, 0, 0})
Expect(string(got)).To(Equal("a-updated"))
})
})
var _ = Describe("StoresGet", func() {
It("round-trips multi-key", func() {
s := NewStore()
mustSet(s,
[][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}},
[][]byte{[]byte("a"), []byte("b"), []byte("c")},
)
res, err := s.StoresGet(&pb.StoresGetOptions{
Keys: wrapKeys([][]float32{{0.7, 0.8, 0.9}, {0.1, 0.2, 0.3}}),
})
Expect(err).NotTo(HaveOccurred())
Expect(res.Keys).To(HaveLen(2))
})
It("omits missing keys rather than erroring", func() {
s := NewStore()
mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")})
res, err := s.StoresGet(&pb.StoresGetOptions{
Keys: wrapKeys([][]float32{{0.1, 0, 0}, {0.9, 0, 0}}),
})
Expect(err).NotTo(HaveOccurred())
Expect(res.Keys).To(HaveLen(1))
})
})
var _ = Describe("StoresDelete", func() {
It("removes and preserves sort", func() {
s := NewStore()
mustSet(s,
[][]float32{{0.1, 0, 0}, {0.2, 0, 0}, {0.3, 0, 0}, {0.4, 0, 0}},
[][]byte{[]byte("a"), []byte("b"), []byte("c"), []byte("d")},
)
Expect(s.StoresDelete(&pb.StoresDeleteOptions{
Keys: wrapKeys([][]float32{{0.2, 0, 0}, {0.4, 0, 0}}),
})).To(Succeed())
Expect(s.keys).To(HaveLen(2))
})
It("tolerates missing keys", func() {
s := NewStore()
mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")})
Expect(s.StoresDelete(&pb.StoresDeleteOptions{
Keys: wrapKeys([][]float32{{0.9, 0, 0}}),
})).To(Succeed(), "delete of missing key should succeed")
Expect(s.keys).To(HaveLen(1))
})
})
var _ = Describe("StoresFind", func() {
It("returns normalized top-K", func() {
s := NewStore()
mustSet(s,
[][]float32{
normalizeVec([]float32{1, 0, 0}),
normalizeVec([]float32{0, 1, 0}),
normalizeVec([]float32{0, 0, 1}),
},
[][]byte{[]byte("x"), []byte("y"), []byte("z")},
)
res, err := s.StoresFind(&pb.StoresFindOptions{
Key: &pb.StoresKey{Floats: normalizeVec([]float32{0.9, 0.1, 0})},
TopK: 2,
})
Expect(err).NotTo(HaveOccurred())
Expect(res.Keys).To(HaveLen(2))
Expect(res.Similarities[0]).To(BeNumerically(">=", res.Similarities[1]), "results not sorted desc by similarity")
Expect(string(res.Values[0].Bytes)).To(Equal("x"))
})
It("falls back for non-normalized keys", func() {
s := NewStore()
mustSet(s, [][]float32{{2, 0, 0}, {0, 3, 0}}, [][]byte{[]byte("x"), []byte("y")})
Expect(s.keysAreNormalized).To(BeFalse(), "store should report non-normalized after Set with magnitude > 1")
res, err := s.StoresFind(&pb.StoresFindOptions{
Key: &pb.StoresKey{Floats: []float32{4, 0, 0}},
TopK: 1,
})
Expect(err).NotTo(HaveOccurred())
Expect(string(res.Values[0].Bytes)).To(Equal("x"))
Expect(res.Similarities[0]).To(BeNumerically(">=", float32(0.99)))
Expect(res.Similarities[0]).To(BeNumerically("<=", float32(1.01)))
})
It("rejects zero topK", func() {
s := NewStore()
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")})
_, err := s.StoresFind(&pb.StoresFindOptions{
Key: &pb.StoresKey{Floats: []float32{1, 0, 0}},
TopK: 0,
})
Expect(err).To(HaveOccurred(), "Find with topK=0 should fail")
})
It("rejects dimension mismatch", func() {
s := NewStore()
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")})
_, err := s.StoresFind(&pb.StoresFindOptions{
Key: &pb.StoresKey{Floats: []float32{1, 0}},
TopK: 1,
})
Expect(err).To(HaveOccurred(), "Find with mismatched dimension should fail")
})
It("returns empty result on empty store", func() {
res, err := NewStore().StoresFind(&pb.StoresFindOptions{
Key: &pb.StoresKey{Floats: []float32{1, 0, 0}},
TopK: 5,
})
Expect(err).NotTo(HaveOccurred(), "Find on empty store should succeed")
Expect(res.Keys).To(BeEmpty())
})
It("handles topK larger than store", func() {
s := NewStore()
mustSet(s,
[][]float32{normalizeVec([]float32{1, 0, 0}), normalizeVec([]float32{0, 1, 0})},
[][]byte{[]byte("x"), []byte("y")},
)
res, err := s.StoresFind(&pb.StoresFindOptions{
Key: &pb.StoresKey{Floats: normalizeVec([]float32{1, 0, 0})},
TopK: 10,
})
Expect(err).NotTo(HaveOccurred())
Expect(res.Keys).To(HaveLen(2))
})
})
var _ = Describe("StoresLoad", func() {
It("is a no-op", func() {
Expect(NewStore().Load(&pb.ModelOptions{Model: "any-namespace"})).To(Succeed())
})
})
func BenchmarkStoresFindNormalized(b *testing.B) {
const dim = 768
for _, n := range []int{8, 32, 128, 512} {
b.Run(fmtN(n), func(b *testing.B) {
s := buildStore(b, n, dim)
query := normalizeVec(randVec(dim, 42))
req := &pb.StoresFindOptions{Key: &pb.StoresKey{Floats: query}, TopK: 1}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := s.StoresFind(req); err != nil {
b.Fatal(err)
}
}
})
}
}
// --- test helpers ---
func mustSet(s *Store, keys [][]float32, values [][]byte) {
ExpectWithOffset(1, s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)})).To(Succeed())
}
func singleGet(s *Store, key []float32) []byte {
res, err := s.StoresGet(&pb.StoresGetOptions{Keys: wrapKeys([][]float32{key})})
ExpectWithOffset(1, err).NotTo(HaveOccurred())
if len(res.Values) == 0 {
return nil
}
return res.Values[0].Bytes
}
func wrapKeys(in [][]float32) []*pb.StoresKey {
out := make([]*pb.StoresKey, len(in))
for i, k := range in {
out[i] = &pb.StoresKey{Floats: k}
}
return out
}
func wrapValues(in [][]byte) []*pb.StoresValue {
out := make([]*pb.StoresValue, len(in))
for i, v := range in {
out[i] = &pb.StoresValue{Bytes: v}
}
return out
}
func buildStore(tb testing.TB, n, dim int) *Store {
tb.Helper()
s := NewStore()
keys := make([][]float32, n)
values := make([][]byte, n)
for i := 0; i < n; i++ {
keys[i] = normalizeVec(randVec(dim, int64(i)+1))
values[i] = []byte{byte(i)}
}
if err := s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)}); err != nil {
tb.Fatal(err)
}
return s
}
func randVec(dim int, seed int64) []float32 {
r := rand.New(rand.NewPCG(uint64(seed), 0xabcdef))
v := make([]float32, dim)
for i := range v {
v[i] = float32(r.NormFloat64())
}
return v
}
func normalizeVec(v []float32) []float32 {
var sum float64
for _, x := range v {
sum += float64(x) * float64(x)
}
mag := math.Sqrt(sum)
if mag == 0 {
return v
}
out := make([]float32, len(v))
for i, x := range v {
out[i] = float32(float64(x) / mag)
}
return out
}
func fmtN(n int) string {
return map[int]string{8: "n=8", 32: "n=32", 128: "n=128", 512: "n=512"}[n]
}

View File

@@ -1,11 +0,0 @@
.cache/
sources/
build/
package/
parakeet-cpp-grpc
# build artifacts staged in-tree by the Makefile (cp from sources/) or
# symlinked for local dev; the real sources live in parakeet.cpp upstream.
*.so
*.so.*
parakeet_capi.h
compile_commands.json

View File

@@ -1,89 +0,0 @@
# parakeet-cpp backend Makefile.
#
# Upstream pin lives below as PARAKEET_VERSION?=30a307553f1965ceb38a1a922069a71e7dd67bf3
# (.github/bump_deps.sh) can find and update it - matches the
# whisper.cpp / ds4 / vibevoice-cpp convention.
#
# Local dev shortcut: if you already have an out-of-tree parakeet.cpp
# build, you can symlink the .so + header into this directory and skip
# the clone/cmake steps entirely, e.g.:
#
# ln -sf /path/to/parakeet.cpp/build-shared/libparakeet.so .
# ln -sf /path/to/parakeet.cpp/include/parakeet_capi.h .
# go build -o parakeet-cpp-grpc .
#
# That's what the L0 smoke test uses. The default target below does the
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
PARAKEET_VERSION?=30a307553f1965ceb38a1a922069a71e7dd67bf3
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
GOCMD?=go
GO_TAGS?=
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
BUILD_TYPE?=
NATIVE?=false
# Build ggml statically into libparakeet.so (PIC) so the shared lib is
# self-contained: dlopen needs no libggml*.so alongside it, only system libs
# (libstdc++/libgomp/libc) that the runtime image already provides.
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DPARAKEET_SHARED=ON -DPARAKEET_BUILD_CLI=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
ifeq ($(NATIVE),false)
CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif
ifeq ($(BUILD_TYPE),cublas)
CMAKE_ARGS+=-DGGML_CUDA=ON
else ifeq ($(BUILD_TYPE),openblas)
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
else ifeq ($(BUILD_TYPE),hipblas)
CMAKE_ARGS+=-DGGML_HIP=ON
else ifeq ($(BUILD_TYPE),vulkan)
CMAKE_ARGS+=-DGGML_VULKAN=ON
endif
.PHONY: parakeet-cpp-grpc package build clean purge test all
all: parakeet-cpp-grpc
# Clone the upstream parakeet.cpp source at the pinned commit. Directory
# acts as the target so make only re-clones when missing. After a
# PARAKEET_VERSION bump, run 'make purge && make' to refetch.
sources/parakeet.cpp:
mkdir -p sources/parakeet.cpp
cd sources/parakeet.cpp && \
git init -q && \
git remote add origin $(PARAKEET_REPO) && \
git fetch --depth 1 origin $(PARAKEET_VERSION) && \
git checkout FETCH_HEAD && \
git submodule update --init --recursive --depth 1 --single-branch
# Build the shared lib + header out-of-tree, then stage them next to the
# Go sources so purego.Dlopen("libparakeet.so") and the cgo-less build
# both pick them up.
libparakeet.so: sources/parakeet.cpp
cmake -B sources/parakeet.cpp/build-shared -S sources/parakeet.cpp $(CMAKE_ARGS)
cmake --build sources/parakeet.cpp/build-shared --config Release -j$(JOBS)
cp -fv sources/parakeet.cpp/build-shared/libparakeet.so* ./ 2>/dev/null || true
cp -fv sources/parakeet.cpp/include/parakeet_capi.h ./
parakeet-cpp-grpc: libparakeet.so main.go goparakeetcpp.go
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o parakeet-cpp-grpc .
package: parakeet-cpp-grpc
bash package.sh
build: package
# Test target. Smoke test is gated on PARAKEET_BACKEND_TEST_MODEL +
# PARAKEET_BACKEND_TEST_WAV; without them the spec auto-skips.
test:
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
clean: purge
rm -rf libparakeet.so* parakeet_capi.h package parakeet-cpp-grpc
purge:
rm -rf sources/parakeet.cpp

View File

@@ -1,393 +0,0 @@
package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"unsafe"
"github.com/go-audio/wav"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/utils"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// purego-bound entry points from libparakeet.so. Names match
// parakeet_capi.h exactly so a `nm libparakeet.so | grep parakeet_capi`
// is enough to spot drift.
//
// Functions that return char* are declared as uintptr so we can call
// parakeet_capi_free_string on the same pointer after copying, the
// C-API contract is "caller owns and must free the returned buffer".
var (
CppAbiVersion func() int32
CppLoad func(ggufPath string) uintptr
CppFree func(ctx uintptr)
CppTranscribePath func(ctx uintptr, wavPath string, decoder int32) uintptr
CppTranscribePathJSON func(ctx uintptr, wavPath string, decoder int32) uintptr
CppFreeString func(s uintptr)
CppLastError func(ctx uintptr) string
// Cache-aware streaming (RNN-T) entry points. stream_begin returns 0 for
// non-streaming models. feed/finalize return a malloc'd char* (uintptr,
// freed via CppFreeString); feed writes 1 to *eouOut on an <EOU>/<EOB>.
CppStreamBegin func(ctx uintptr) uintptr
CppStreamFeed func(s uintptr, pcm []float32, nSamples int32, eouOut unsafe.Pointer) uintptr
CppStreamFinalize func(s uintptr) uintptr
CppStreamFree func(s uintptr)
)
// streamChunkSamples is how much 16 kHz mono PCM we hand to stream_feed per
// call (1 s). The session buffers internally and decodes once a full
// cache-aware encoder chunk is available, so this only bounds how often we
// poll for newly-finalized text, not the model's actual chunk size.
const streamChunkSamples = 16000
// transcriptJSON mirrors the document returned by
// parakeet_capi_transcribe_path_json (see parakeet_capi.h):
//
// {"text":"...",
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...],
// "tokens":[{"id":123,"t":0.480,"conf":0.9100}, ...]}
//
// "start"/"end"/"t" are seconds; "conf" is confidence in (0,1].
type transcriptJSON struct {
Text string `json:"text"`
Words []transcriptWord `json:"words"`
Tokens []transcriptToken `json:"tokens"`
}
type transcriptWord struct {
W string `json:"w"`
Start float64 `json:"start"`
End float64 `json:"end"`
Conf float64 `json:"conf"`
}
type transcriptToken struct {
ID int32 `json:"id"`
T float64 `json:"t"`
Conf float64 `json:"conf"`
}
// ParakeetCpp owns a single loaded parakeet_ctx. The C engine is a
// thread-unsafe singleton (mirrors whisper.cpp / vibevoice.cpp), so we
// serialize calls through base.SingleThread.
type ParakeetCpp struct {
base.SingleThread
ctxPtr uintptr
}
// Load is the LocalAI gRPC entry point for LoadModel: it calls
// parakeet_capi_load with the GGUF path and stashes the resulting
// opaque context pointer for AudioTranscription.
func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
if opts.ModelFile == "" {
return errors.New("parakeet-cpp: ModelFile is required")
}
ctx := CppLoad(opts.ModelFile)
if ctx == 0 {
// No ctx to ask for last_error (the C-API's last-error buffer
// lives on the ctx that was never returned). Surface the path
// so the operator at least knows which load failed.
return fmt.Errorf("parakeet-cpp: parakeet_capi_load failed for %q", opts.ModelFile)
}
p.ctxPtr = ctx
return nil
}
// AudioTranscription runs parakeet_capi_transcribe_path_json on the wav at
// opts.Dst with the default decoder (decoder=0, which selects the right head
// per architecture: transducer for tdt/rnnt/hybrid, CTC for ctc) and shapes
// the per-word timestamps into a LocalAI TranscriptResult.
//
// Parakeet emits word- and token-level timestamps but no native segment
// boundaries, so we synthesise a single whole-clip segment spanning the first
// word start to the last word end. Word-level timings are attached only when
// the caller opts in via timestamp_granularities=["word"] (matching the
// OpenAI API, whose default is segment-level); token ids always populate
// Segment.Tokens.
//
// translate/diarize/prompt/temperature/language/threads are not applicable to
// parakeet and are ignored; streaming is handled by AudioTranscriptionStream
// (L2).
func (p *ParakeetCpp) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
if p.ctxPtr == 0 {
return pb.TranscriptResult{}, errors.New("parakeet-cpp: model not loaded")
}
if opts.Dst == "" {
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
}
cstr := CppTranscribePathJSON(p.ctxPtr, opts.Dst, 0)
if cstr == 0 {
msg := CppLastError(p.ctxPtr)
if msg == "" {
msg = "unknown error"
}
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", msg)
}
raw := goStringFromCPtr(cstr)
CppFreeString(cstr)
var doc transcriptJSON
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
}
text := strings.TrimSpace(doc.Text)
words := make([]*pb.TranscriptWord, 0, len(doc.Words))
for _, w := range doc.Words {
words = append(words, &pb.TranscriptWord{
Start: secondsToNanos(w.Start),
End: secondsToNanos(w.End),
Text: w.W,
})
}
tokens := make([]int32, 0, len(doc.Tokens))
for _, t := range doc.Tokens {
tokens = append(tokens, t.ID)
}
// Single whole-clip segment, spanning the first word start to the last
// word end (0/0 when the clip produced no words).
var segStart, segEnd int64
if len(words) > 0 {
segStart = words[0].Start
segEnd = words[len(words)-1].End
}
seg := &pb.TranscriptSegment{
Id: 0,
Start: segStart,
End: segEnd,
Text: text,
Tokens: tokens,
}
if wordsRequested(opts.TimestampGranularities) {
seg.Words = words
}
return pb.TranscriptResult{
Text: text,
Segments: []*pb.TranscriptSegment{seg},
}, nil
}
// wordsRequested reports whether the caller asked for word-level timestamps.
// The OpenAI transcription API gates word timings behind
// timestamp_granularities[] containing "word" and defaults to segment-level
// otherwise; we follow that contract.
func wordsRequested(granularities []string) bool {
for _, g := range granularities {
if strings.EqualFold(strings.TrimSpace(g), "word") {
return true
}
}
return false
}
// secondsToNanos converts the C-API's fractional-second timestamps into the
// int64 nanoseconds LocalAI carries on TranscriptSegment/TranscriptWord, the
// same nanosecond convention the whisper backend uses.
func secondsToNanos(sec float64) int64 {
return int64(sec * 1e9)
}
// AudioTranscriptionStream drives the cache-aware streaming RNN-T over the
// audio at opts.Dst: it decodes the file to 16 kHz mono PCM, feeds it in
// chunks to parakeet_capi_stream_feed, and emits each newly-finalized text
// run as a TranscriptStreamResponse delta. <EOU>/<EOB> events close the
// current segment; a closing FinalResult carries the full transcript and the
// per-utterance segments.
//
// stream_begin returns 0 for models that are not cache-aware streaming models
// (only e.g. nvidia/parakeet_realtime_eou_120m-v1 qualifies). For those we fall
// back to a single offline transcription emitted as one delta plus a closing
// FinalResult, matching LocalAI's non-streaming streaming contract (and the
// whisper backend), so the streaming endpoint works for every model.
func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
defer close(results)
if p.ctxPtr == 0 {
return errors.New("parakeet-cpp: model not loaded")
}
if opts.Dst == "" {
return errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
}
if err := ctx.Err(); err != nil {
return status.Error(codes.Canceled, "transcription cancelled")
}
stream := CppStreamBegin(p.ctxPtr)
if stream == 0 {
// Not a cache-aware streaming model: run a normal offline
// transcription and emit it as one delta + a closing final result.
res, err := p.AudioTranscription(ctx, opts)
if err != nil {
return err
}
if t := strings.TrimSpace(res.Text); t != "" {
results <- &pb.TranscriptStreamResponse{Delta: t}
}
results <- &pb.TranscriptStreamResponse{FinalResult: &res}
return nil
}
defer CppStreamFree(stream)
data, duration, err := decodeWavMono16k(opts.Dst)
if err != nil {
return err
}
var (
full strings.Builder
segText strings.Builder
segments []*pb.TranscriptSegment
segID int32
)
flushSegment := func() {
t := strings.TrimSpace(segText.String())
segText.Reset()
if t == "" {
return
}
segments = append(segments, &pb.TranscriptSegment{Id: segID, Text: t})
segID++
}
// emitDelta consumes the malloc'd char* returned by feed/finalize: frees
// it, accumulates the text, and sends a delta when non-empty. A 0 return
// is an error (vs the "" empty-but-non-NULL no-new-text case).
emitDelta := func(ret uintptr) error {
if ret == 0 {
msg := CppLastError(p.ctxPtr)
if msg == "" {
msg = "unknown error"
}
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
}
delta := goStringFromCPtr(ret)
CppFreeString(ret)
if delta == "" {
return nil
}
full.WriteString(delta)
segText.WriteString(delta)
results <- &pb.TranscriptStreamResponse{Delta: delta}
return nil
}
for off := 0; off < len(data); off += streamChunkSamples {
if err := ctx.Err(); err != nil {
return status.Error(codes.Canceled, "transcription cancelled")
}
end := min(off+streamChunkSamples, len(data))
chunk := data[off:end]
var eou int32
ret := CppStreamFeed(stream, chunk, int32(len(chunk)), unsafe.Pointer(&eou))
if err := emitDelta(ret); err != nil {
return err
}
if eou != 0 {
flushSegment()
}
}
// Flush the streaming tail (final encoder chunk).
if err := emitDelta(CppStreamFinalize(stream)); err != nil {
return err
}
flushSegment()
text := strings.TrimSpace(full.String())
if len(segments) == 0 && text != "" {
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
}
results <- &pb.TranscriptStreamResponse{
FinalResult: &pb.TranscriptResult{
Text: text,
Segments: segments,
Duration: duration,
},
}
return nil
}
// decodeWavMono16k converts any input audio to 16 kHz mono PCM and returns the
// float samples plus the clip duration in seconds. Mirrors the whisper
// backend: utils.AudioToWav (ffmpeg) normalises rate/channels, go-audio
// decodes the PCM.
func decodeWavMono16k(path string) ([]float32, float32, error) {
dir, err := os.MkdirTemp("", "parakeet")
if err != nil {
return nil, 0, err
}
defer func() { _ = os.RemoveAll(dir) }()
converted := filepath.Join(dir, "converted.wav")
if err := utils.AudioToWav(path, converted); err != nil {
return nil, 0, err
}
fh, err := os.Open(converted)
if err != nil {
return nil, 0, err
}
defer func() { _ = fh.Close() }()
buf, err := wav.NewDecoder(fh).FullPCMBuffer()
if err != nil {
return nil, 0, err
}
data := buf.AsFloat32Buffer().Data
var duration float32
if buf.Format != nil && buf.Format.SampleRate > 0 {
duration = float32(len(data)) / float32(buf.Format.SampleRate)
}
return data, duration, nil
}
// Free releases the underlying parakeet_ctx. Called by LocalAI when the
// model is unloaded.
func (p *ParakeetCpp) Free() error {
if p.ctxPtr != 0 {
CppFree(p.ctxPtr)
p.ctxPtr = 0
}
return nil
}
// goStringFromCPtr copies a NUL-terminated C string into Go memory.
// cptr is the raw pointer returned by purego from the C-API (a malloc'd
// buffer the caller owns); callers must free it via CppFreeString after
// the copy lands.
//
// The uintptr->unsafe.Pointer conversion below trips go vet's unsafeptr
// check, which can't distinguish a C-owned heap pointer from Go-managed
// memory. It is safe here: the pointer addresses a malloc'd C buffer the
// Go GC neither tracks nor moves, and we dereference it immediately to
// copy the bytes out, the same pattern (and the same tolerated warning)
// as the whisper backend's unsafe.Slice over segsPtr.
func goStringFromCPtr(cptr uintptr) string {
if cptr == 0 {
return ""
}
p := unsafe.Pointer(cptr) //nolint:govet // C-owned malloc'd buffer, not Go-GC memory (see doc above)
n := 0
for *(*byte)(unsafe.Add(p, n)) != 0 {
n++
}
return string(unsafe.Slice((*byte)(p), n))
}

View File

@@ -1,164 +0,0 @@
package main
import (
"context"
"os"
"strings"
"sync"
"testing"
"github.com/ebitengine/purego"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestParakeetCpp(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "parakeet-cpp Backend Suite")
}
var (
libLoadOnce sync.Once
libLoadErr error
)
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive
// the C-API bridge without spinning up the gRPC server. Skips the
// current spec when libparakeet.so isn't loadable from cwd
// ($LD_LIBRARY_PATH or a symlink in ./).
func ensureLibLoaded() {
libLoadOnce.Do(func() {
libName := os.Getenv("PARAKEET_LIBRARY")
if libName == "" {
libName = "libparakeet.so"
}
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
libLoadErr = err
return
}
purego.RegisterLibFunc(&CppAbiVersion, lib, "parakeet_capi_abi_version")
purego.RegisterLibFunc(&CppLoad, lib, "parakeet_capi_load")
purego.RegisterLibFunc(&CppFree, lib, "parakeet_capi_free")
purego.RegisterLibFunc(&CppTranscribePath, lib, "parakeet_capi_transcribe_path")
purego.RegisterLibFunc(&CppTranscribePathJSON, lib, "parakeet_capi_transcribe_path_json")
purego.RegisterLibFunc(&CppStreamBegin, lib, "parakeet_capi_stream_begin")
purego.RegisterLibFunc(&CppStreamFeed, lib, "parakeet_capi_stream_feed")
purego.RegisterLibFunc(&CppStreamFinalize, lib, "parakeet_capi_stream_finalize")
purego.RegisterLibFunc(&CppStreamFree, lib, "parakeet_capi_stream_free")
purego.RegisterLibFunc(&CppFreeString, lib, "parakeet_capi_free_string")
purego.RegisterLibFunc(&CppLastError, lib, "parakeet_capi_last_error")
})
if libLoadErr != nil {
Skip("libparakeet.so not loadable: " + libLoadErr.Error())
}
}
// fixturesOrSkip returns the model + audio paths or skips the spec if
// either env var is unset. The smoke test never runs in default CI; it
// needs a real parakeet GGUF and a 16 kHz mono WAV on disk.
func fixturesOrSkip() (string, string) {
modelPath := os.Getenv("PARAKEET_BACKEND_TEST_MODEL")
audioPath := os.Getenv("PARAKEET_BACKEND_TEST_WAV")
if modelPath == "" || audioPath == "" {
Skip("set PARAKEET_BACKEND_TEST_MODEL and PARAKEET_BACKEND_TEST_WAV to run this spec")
}
return modelPath, audioPath
}
var _ = Describe("ParakeetCpp", func() {
Context("AudioTranscription", func() {
It("transcribes a WAV via the parakeet C-API", func() {
modelPath, audioPath := fixturesOrSkip()
ensureLibLoaded()
p := &ParakeetCpp{}
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
defer func() { _ = p.Free() }()
res, err := p.AudioTranscription(context.Background(), &pb.TranscriptRequest{
Dst: audioPath,
})
Expect(err).ToNot(HaveOccurred())
Expect(strings.TrimSpace(res.Text)).ToNot(BeEmpty(),
"expected non-empty transcript for %s", audioPath)
Expect(res.Segments).To(HaveLen(1),
"synthesises a single whole-clip segment")
Expect(res.Segments[0].Text).To(Equal(res.Text),
"single segment text must equal the top-level text")
// Default (no granularities) is segment-level: no per-word timings.
Expect(res.Segments[0].Words).To(BeEmpty(),
"word timings are opt-in via timestamp_granularities")
})
It("emits word-level timestamps when granularity=word", func() {
modelPath, audioPath := fixturesOrSkip()
ensureLibLoaded()
p := &ParakeetCpp{}
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
defer func() { _ = p.Free() }()
res, err := p.AudioTranscription(context.Background(), &pb.TranscriptRequest{
Dst: audioPath,
TimestampGranularities: []string{"word"},
})
Expect(err).ToNot(HaveOccurred())
Expect(res.Segments).To(HaveLen(1))
seg := res.Segments[0]
Expect(seg.Words).ToNot(BeEmpty(),
"expected per-word timestamps with granularity=word")
// Monotonic, non-negative timings spanning the segment.
Expect(seg.Words[0].Start).To(BeNumerically(">=", int64(0)))
Expect(seg.End).To(BeNumerically(">=", seg.Start))
Expect(seg.Words[len(seg.Words)-1].End).To(Equal(seg.End),
"segment end tracks the last word")
})
})
Context("AudioTranscriptionStream", func() {
It("streams deltas and a closing FinalResult from a cache-aware model", func() {
// Streaming needs a cache-aware streaming model (e.g.
// realtime_eou); the offline test model would fail stream_begin.
modelPath := os.Getenv("PARAKEET_BACKEND_TEST_STREAM_MODEL")
audioPath := os.Getenv("PARAKEET_BACKEND_TEST_WAV")
if modelPath == "" || audioPath == "" {
Skip("set PARAKEET_BACKEND_TEST_STREAM_MODEL (cache-aware streaming model) and PARAKEET_BACKEND_TEST_WAV")
}
ensureLibLoaded()
p := &ParakeetCpp{}
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
defer func() { _ = p.Free() }()
results := make(chan *pb.TranscriptStreamResponse, 64)
errCh := make(chan error, 1)
go func() {
errCh <- p.AudioTranscriptionStream(context.Background(),
&pb.TranscriptRequest{Dst: audioPath}, results)
}()
var deltas []string
var final *pb.TranscriptResult
for r := range results {
if r.Delta != "" {
deltas = append(deltas, r.Delta)
}
if r.FinalResult != nil {
final = r.FinalResult
}
}
Expect(<-errCh).ToNot(HaveOccurred())
Expect(final).ToNot(BeNil(), "expected a closing FinalResult")
Expect(strings.TrimSpace(final.Text)).ToNot(BeEmpty(),
"expected a non-empty streamed transcript")
Expect(final.Segments).ToNot(BeEmpty(),
"FinalResult always carries at least one segment")
// The concatenated deltas reconstruct the final transcript.
Expect(strings.TrimSpace(strings.Join(deltas, ""))).To(Equal(strings.TrimSpace(final.Text)),
"deltas should reconstruct the final text")
})
})
})

View File

@@ -1,68 +0,0 @@
package main
// Started internally by LocalAI - one gRPC server per loaded model.
//
// Loads libparakeet.so via purego and registers the flat C-API entry
// points declared in parakeet_capi.h. The library name can be overridden
// with PARAKEET_LIBRARY (mirrors the WHISPER_LIBRARY / VIBEVOICECPP_LIBRARY
// convention in the sibling backends); the default looks for the .so next
// to this binary.
import (
"flag"
"fmt"
"os"
"github.com/ebitengine/purego"
grpc "github.com/mudler/LocalAI/pkg/grpc"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
type LibFuncs struct {
FuncPtr any
Name string
}
func main() {
libName := os.Getenv("PARAKEET_LIBRARY")
if libName == "" {
libName = "libparakeet.so"
}
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
panic(fmt.Errorf("parakeet-cpp: dlopen %q: %w", libName, err))
}
// Bound 1:1 to parakeet_capi.h. The C-API returns malloc'd char*
// buffers from transcribe_*; we register those as uintptr so we get
// the raw pointer back and can call parakeet_capi_free_string on it
// (purego's string return would copy and forget the original pointer,
// leaking it on every call).
libFuncs := []LibFuncs{
{&CppAbiVersion, "parakeet_capi_abi_version"},
{&CppLoad, "parakeet_capi_load"},
{&CppFree, "parakeet_capi_free"},
{&CppTranscribePath, "parakeet_capi_transcribe_path"},
{&CppTranscribePathJSON, "parakeet_capi_transcribe_path_json"},
{&CppStreamBegin, "parakeet_capi_stream_begin"},
{&CppStreamFeed, "parakeet_capi_stream_feed"},
{&CppStreamFinalize, "parakeet_capi_stream_finalize"},
{&CppStreamFree, "parakeet_capi_stream_free"},
{&CppFreeString, "parakeet_capi_free_string"},
{&CppLastError, "parakeet_capi_last_error"},
}
for _, lf := range libFuncs {
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
}
fmt.Fprintf(os.Stderr, "[parakeet-cpp] ABI=%d\n", CppAbiVersion())
flag.Parse()
if err := grpc.StartServer(*addr, &ParakeetCpp{}); err != nil {
panic(err)
}
}

View File

@@ -1,23 +0,0 @@
#!/bin/bash
#
# L0 packaging stub: copy the binary, run.sh and libparakeet.so* into
# package/. The full ldd walk (libc, libstdc++, libgomp, GPU runtimes,
# arch detection) lands in L3, mirroring backend/go/whisper/package.sh.
set -e
CURDIR=$(dirname "$(realpath "$0")")
mkdir -p "$CURDIR/package/lib"
cp -avf "$CURDIR/parakeet-cpp-grpc" "$CURDIR/package/"
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
# libparakeet.so + any soname symlinks (libparakeet.so.X, libparakeet.so.X.Y).
cp -avf "$CURDIR"/libparakeet.so* "$CURDIR/package/lib/" 2>/dev/null || {
echo "ERROR: libparakeet.so not found in $CURDIR, run 'make' first" >&2
exit 1
}
echo "L0 package layout (full ldd walk lands in L3):"
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"

View File

@@ -1,16 +0,0 @@
#!/bin/bash
set -e
CURDIR=$(dirname "$(realpath "$0")")
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
# If a self-contained ld.so was packaged, route through it so the
# packaged libc / libstdc++ are used instead of the host's (matches the
# whisper backend's runtime layout).
if [ -f "$CURDIR/lib/ld.so" ]; then
echo "Using lib/ld.so"
exec "$CURDIR/lib/ld.so" "$CURDIR/parakeet-cpp-grpc" "$@"
fi
exec "$CURDIR/parakeet-cpp-grpc" "$@"

View File

@@ -1,7 +0,0 @@
sources/
build*/
package/
librfdetrcpp*.so
rfdetr-cpp
test-models/
test-data/

View File

@@ -1,79 +0,0 @@
cmake_minimum_required(VERSION 3.18)
project(librfdetrcpp LANGUAGES C CXX)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Static-link ggml + rfdetr so the resulting .so has no runtime dependency on
# extra ggml/rfdetr shared libraries — only on libc/libstdc++/libgomp, which
# the LocalAI package step bundles into the docker image.
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build static libraries" FORCE)
# rfdetr.cpp build switches: skip CLI/tests, keep static lib.
set(RFDETR_BUILD_CLI OFF CACHE BOOL "Disable rfdetr CLI" FORCE)
set(RFDETR_BUILD_TESTS OFF CACHE BOOL "Disable rfdetr tests" FORCE)
set(RFDETR_SHARED OFF CACHE BOOL "Build rfdetr as static lib" FORCE)
# rt-detr.cpp's top-level CMakeLists invokes
# `bash ${CMAKE_SOURCE_DIR}/scripts/apply_ggml_patches.sh` to apply its
# in-tree ggml patches before descending into the submodule. When we
# `add_subdirectory` it from a parent project, `CMAKE_SOURCE_DIR` points
# at *our* directory, not theirs, so the script path resolves wrong.
#
# Run the patches script ourselves up front (it's idempotent — re-running
# is a no-op once patches are applied) so the rt-detr.cpp configure step
# is essentially a no-op for the patch hook.
set(RFDETR_CPP_SRC ${CMAKE_CURRENT_SOURCE_DIR}/sources/rt-detr.cpp)
if(EXISTS ${RFDETR_CPP_SRC}/scripts/apply_ggml_patches.sh)
execute_process(
COMMAND bash ${RFDETR_CPP_SRC}/scripts/apply_ggml_patches.sh
RESULT_VARIABLE _rfdetr_patch_result
OUTPUT_VARIABLE _rfdetr_patch_output
ERROR_VARIABLE _rfdetr_patch_error
OUTPUT_STRIP_TRAILING_WHITESPACE
ERROR_STRIP_TRAILING_WHITESPACE)
if(NOT _rfdetr_patch_result EQUAL 0)
message(FATAL_ERROR
"Failed to apply ggml patches (exit ${_rfdetr_patch_result}):\n"
"stdout:\n${_rfdetr_patch_output}\n"
"stderr:\n${_rfdetr_patch_error}")
endif()
message(STATUS "${_rfdetr_patch_output}")
endif()
# Stage a shim 'scripts/apply_ggml_patches.sh' under our source dir so that
# rt-detr.cpp's CMakeLists — which calls
# bash ${CMAKE_SOURCE_DIR}/scripts/apply_ggml_patches.sh
# — finds an idempotent no-op there. The real patches have already been
# applied above; this just satisfies the path lookup.
file(MAKE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/scripts)
file(WRITE ${CMAKE_CURRENT_SOURCE_DIR}/scripts/apply_ggml_patches.sh
"#!/usr/bin/env bash
# Shim - patches were already applied by the parent CMakeLists.
exit 0
")
execute_process(COMMAND chmod +x ${CMAKE_CURRENT_SOURCE_DIR}/scripts/apply_ggml_patches.sh)
add_subdirectory(./sources/rt-detr.cpp)
# rfdetr.cpp's C-API symbols already live inside librfdetr (src/rfdetr_capi.cpp
# is compiled into the lib). We re-export them via a MODULE library that
# whole-archive-links rfdetr so the symbols are visible at dlopen time.
add_library(rfdetrcpp MODULE
sources/rt-detr.cpp/src/rfdetr_capi.cpp)
target_include_directories(rfdetrcpp PRIVATE
sources/rt-detr.cpp/include
sources/rt-detr.cpp/src
sources/rt-detr.cpp/third_party/stb
)
target_link_libraries(rfdetrcpp PRIVATE rfdetr ggml)
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
target_link_libraries(rfdetrcpp PRIVATE stdc++fs)
endif()
set_property(TARGET rfdetrcpp PROPERTY CXX_STANDARD 17)
set_target_properties(rfdetrcpp PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})

View File

@@ -1,135 +0,0 @@
CMAKE_ARGS?=
BUILD_TYPE?=
NATIVE?=false
GOCMD?=go
GO_TAGS?=
JOBS?=$(shell nproc --ignore=1)
# rt-detr.cpp (GitHub redirects the historical mudler/rt-detr.cpp to the new
# mudler/rf-detr.cpp slug). Pin to a specific commit if you need a stable
# build; leaving this on `master` always picks up the latest C-API surface
# (incl. the per-detection accessor functions used by gorfdetrcpp.go).
RFDETR_REPO?=https://github.com/mudler/rf-detr.cpp.git
RFDETR_VERSION?=65c0ffcc9a9bc9dae38252f63d0417c9845a6cf7
ifeq ($(NATIVE),false)
CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif
# Forward LocalAI's BUILD_TYPE to the matching ggml backend switch.
ifeq ($(BUILD_TYPE),cublas)
CMAKE_ARGS+=-DGGML_CUDA=ON -DRFDETR_GGML_CUDA=ON
else ifeq ($(BUILD_TYPE),openblas)
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
else ifeq ($(BUILD_TYPE),clblas)
CMAKE_ARGS+=-DGGML_CLBLAST=ON
else ifeq ($(BUILD_TYPE),hipblas)
ROCM_HOME ?= /opt/rocm
ROCM_PATH ?= /opt/rocm
export CXX=$(ROCM_HOME)/llvm/bin/clang++
export CC=$(ROCM_HOME)/llvm/bin/clang
AMDGPU_TARGETS?=gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201
CMAKE_ARGS+=-DGGML_HIPBLAS=ON -DRFDETR_GGML_HIPBLAS=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
else ifeq ($(BUILD_TYPE),vulkan)
CMAKE_ARGS+=-DGGML_VULKAN=ON -DRFDETR_GGML_VULKAN=ON
else ifeq ($(OS),Darwin)
ifneq ($(BUILD_TYPE),metal)
CMAKE_ARGS+=-DGGML_METAL=OFF
else
CMAKE_ARGS+=-DGGML_METAL=ON
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
CMAKE_ARGS+=-DRFDETR_GGML_METAL=ON
endif
endif
ifeq ($(BUILD_TYPE),sycl_f16)
CMAKE_ARGS+=-DGGML_SYCL=ON \
-DCMAKE_C_COMPILER=icx \
-DCMAKE_CXX_COMPILER=icpx \
-DGGML_SYCL_F16=ON
endif
ifeq ($(BUILD_TYPE),sycl_f32)
CMAKE_ARGS+=-DGGML_SYCL=ON \
-DCMAKE_C_COMPILER=icx \
-DCMAKE_CXX_COMPILER=icpx
endif
sources/rt-detr.cpp:
mkdir -p sources && \
git clone --recursive $(RFDETR_REPO) sources/rt-detr.cpp && \
cd sources/rt-detr.cpp && \
git checkout $(RFDETR_VERSION) && \
git submodule update --init --recursive --depth 1 --single-branch
# Detect OS
UNAME_S := $(shell uname -s)
# Only build CPU variants on Linux
ifeq ($(UNAME_S),Linux)
VARIANT_TARGETS = librfdetrcpp-avx.so librfdetrcpp-avx2.so librfdetrcpp-avx512.so librfdetrcpp-fallback.so
else
# On non-Linux (e.g., Darwin), build only fallback variant
VARIANT_TARGETS = librfdetrcpp-fallback.so
endif
rfdetr-cpp: main.go gorfdetrcpp.go $(VARIANT_TARGETS)
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o rfdetr-cpp ./
package: rfdetr-cpp
bash package.sh
build: package
clean: purge
rm -rf librfdetrcpp*.so rfdetr-cpp package sources
purge:
rm -rf build*
# Build all variants (Linux only)
ifeq ($(UNAME_S),Linux)
librfdetrcpp-avx.so: sources/rt-detr.cpp
rm -rfv build-$@
$(info ${GREEN}I rfdetr-cpp build info:avx${RESET})
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) librfdetrcpp-custom
rm -rfv build-$@
librfdetrcpp-avx2.so: sources/rt-detr.cpp
rm -rfv build-$@
$(info ${GREEN}I rfdetr-cpp build info:avx2${RESET})
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) librfdetrcpp-custom
rm -rfv build-$@
librfdetrcpp-avx512.so: sources/rt-detr.cpp
rm -rfv build-$@
$(info ${GREEN}I rfdetr-cpp build info:avx512${RESET})
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) librfdetrcpp-custom
rm -rfv build-$@
endif
# Build fallback variant (all platforms)
librfdetrcpp-fallback.so: sources/rt-detr.cpp
rm -rfv build-$@
$(info ${GREEN}I rfdetr-cpp build info:fallback${RESET})
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) librfdetrcpp-custom
rm -rfv build-$@
librfdetrcpp-custom: CMakeLists.txt
mkdir -p build-$(SO_TARGET) && \
cd build-$(SO_TARGET) && \
cmake .. $(CMAKE_ARGS) && \
cmake --build . --config Release -j$(JOBS) && \
cd .. && \
mv build-$(SO_TARGET)/librfdetrcpp.so ./$(SO_TARGET)
all: rfdetr-cpp package
# `test` is invoked by the top-level Makefile's `test-extra` target. It builds
# the backend binary + the fallback shared library (needed for dlopen at
# runtime), then runs test.sh which downloads the test models + COCO image
# and exercises the gRPC Load/Detect wire path via the Go smoke test in
# main_test.go for both the detection and segmentation models.
test: rfdetr-cpp librfdetrcpp-fallback.so
bash test.sh

View File

@@ -1,195 +0,0 @@
package main
// gorfdetrcpp.go - gRPC handlers (Load, Detect) for the rfdetr-cpp backend.
//
// Embeds base.SingleThread to default unimplemented RPCs to "not supported"
// while we only implement object detection.
import (
"encoding/base64"
"fmt"
"os"
"path/filepath"
"strconv"
"unsafe"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
// Default upper bound on detections returned per image. RF-DETR's decoder
// queries are limited to a few hundred; 300 is a safe ceiling.
const defaultTopK = 300
// rfdetr_handle_t is a uintptr-typed opaque handle (see include/rfdetr_capi.h).
var (
// rfdetr_capi_load(const char* model_path, int n_threads, rfdetr_handle_t* out_handle) -> int
CapiLoad func(modelPath string, nThreads int32, outHandle *uintptr) int32
// rfdetr_capi_unload(rfdetr_handle_t handle) -> int
CapiUnload func(handle uintptr) int32
// rfdetr_capi_detect_path(handle, image_path, threshold, top_k, out_json) -> int
CapiDetectPath func(handle uintptr, imagePath string, threshold float32, topK uint32, outJSON *uintptr) int32
// rfdetr_capi_detect_buffer(handle, bytes, len, threshold, top_k, out_json) -> int
CapiDetectBuffer func(handle uintptr, bytes uintptr, length uintptr, threshold float32, topK uint32, outJSON *uintptr) int32
// rfdetr_capi_free_string(char* s)
CapiFreeString func(s uintptr)
// rfdetr_capi_get_n_detections(handle) -> int
CapiGetNDetections func(handle uintptr) int32
// rfdetr_capi_get_detection_class_id(handle, i) -> int
CapiGetDetectionClassID func(handle uintptr, i int32) int32
// rfdetr_capi_get_detection_box(handle, i, out_xyxy[4]) -> int (0 on success)
CapiGetDetectionBox func(handle uintptr, i int32, outXYXY uintptr) int32
// rfdetr_capi_get_detection_score(handle, i) -> float
CapiGetDetectionScore func(handle uintptr, i int32) float32
// rfdetr_capi_get_detection_class_name(handle, i, buf, buf_size) -> int (needed/written; two-call sizing)
CapiGetDetectionClassName func(handle uintptr, i int32, buf uintptr, bufSize int32) int32
// rfdetr_capi_get_detection_mask_png(handle, i, buf, buf_size) -> int (needed/written; 0 means no mask)
CapiGetDetectionMaskPNG func(handle uintptr, i int32, buf uintptr, bufSize int32) int32
)
type RFDetrCpp struct {
base.SingleThread
handle uintptr
}
// Load loads the GGUF model at opts.ModelFile (joined with opts.ModelPath if relative)
// and stores the handle for later Detect calls.
func (r *RFDetrCpp) Load(opts *pb.ModelOptions) error {
modelFile := opts.ModelFile
if modelFile == "" {
modelFile = opts.Model
}
if modelFile == "" {
return fmt.Errorf("rfdetr-cpp: ModelFile is empty")
}
var modelPath string
if filepath.IsAbs(modelFile) {
modelPath = modelFile
} else {
modelPath = filepath.Join(opts.ModelPath, modelFile)
}
if _, err := os.Stat(modelPath); err != nil {
return fmt.Errorf("rfdetr-cpp: model file not found: %s: %w", modelPath, err)
}
threads := opts.Threads
if threads <= 0 {
threads = 4
}
// Release previous model if any (re-Load).
if r.handle != 0 {
CapiUnload(r.handle)
r.handle = 0
}
var h uintptr
rc := CapiLoad(modelPath, threads, &h)
if rc != 0 || h == 0 {
return fmt.Errorf("rfdetr-cpp: rfdetr_capi_load failed with rc=%d for %s", rc, modelPath)
}
r.handle = h
return nil
}
// Detect runs object detection on the base64-encoded image in opts.Src at
// opts.Threshold, returning one pb.Detection per result. Seg models also
// populate Detection.Mask with PNG-encoded mask bytes.
func (r *RFDetrCpp) Detect(opts *pb.DetectOptions) (pb.DetectResponse, error) {
if r.handle == 0 {
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: model not loaded")
}
// Decode base64 image and write to temp file.
imgData, err := base64.StdEncoding.DecodeString(opts.Src)
if err != nil {
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: failed to decode base64 image: %w", err)
}
tmpFile, err := os.CreateTemp("", "rfdetr-*.img")
if err != nil {
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: failed to create temp file: %w", err)
}
defer func() { _ = os.Remove(tmpFile.Name()) }()
if _, err := tmpFile.Write(imgData); err != nil {
_ = tmpFile.Close()
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: failed to write temp file: %w", err)
}
if err := tmpFile.Close(); err != nil {
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: failed to close temp file: %w", err)
}
threshold := opts.Threshold
if threshold <= 0 {
threshold = 0.5
}
// JSON output from detect_path is unused: we read structured detections via
// the accessor functions. Still must free the returned string.
var jsonPtr uintptr
rc := CapiDetectPath(r.handle, tmpFile.Name(), threshold, uint32(defaultTopK), &jsonPtr)
if jsonPtr != 0 {
CapiFreeString(jsonPtr)
}
if rc != 0 {
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: detect failed with rc=%d", rc)
}
n := CapiGetNDetections(r.handle)
if n < 0 {
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: invalid n_detections=%d", n)
}
detections := make([]*pb.Detection, 0, n)
for i := int32(0); i < n; i++ {
var bbox [4]float32 // x1, y1, x2, y2
if rc := CapiGetDetectionBox(r.handle, i, uintptr(unsafe.Pointer(&bbox[0]))); rc != 0 {
continue
}
cid := CapiGetDetectionClassID(r.handle, i)
score := CapiGetDetectionScore(r.handle, i)
// Two-call sizing for class_name.
var className string
nameSize := CapiGetDetectionClassName(r.handle, i, 0, 0)
if nameSize > 1 {
buf := make([]byte, nameSize)
written := CapiGetDetectionClassName(r.handle, i, uintptr(unsafe.Pointer(&buf[0])), nameSize)
// `written` is the same number (needed bytes including NUL); strip NUL.
if written > 0 && int(written) <= len(buf) {
className = string(buf[:written-1])
} else {
className = string(buf[:len(buf)-1])
}
}
if className == "" {
className = strconv.Itoa(int(cid))
}
// Two-call sizing for mask PNG (returns 0 when no mask).
var mask []byte
maskSize := CapiGetDetectionMaskPNG(r.handle, i, 0, 0)
if maskSize > 0 {
maskBuf := make([]byte, maskSize)
CapiGetDetectionMaskPNG(r.handle, i, uintptr(unsafe.Pointer(&maskBuf[0])), maskSize)
mask = maskBuf
}
detections = append(detections, &pb.Detection{
X: bbox[0],
Y: bbox[1],
Width: bbox[2] - bbox[0],
Height: bbox[3] - bbox[1],
Confidence: score,
ClassName: className,
Mask: mask,
})
}
return pb.DetectResponse{
Detections: detections,
}, nil
}

View File

@@ -1,61 +0,0 @@
package main
// main.go - entry point for the rfdetr-cpp gRPC backend.
//
// Dlopens librfdetrcpp-<variant>.so via purego at the path in
// RFDETR_LIBRARY (set by run.sh based on /proc/cpuinfo), registers the
// rfdetr_capi_* C ABI symbols, then starts the gRPC server.
import (
"flag"
"os"
"github.com/ebitengine/purego"
grpc "github.com/mudler/LocalAI/pkg/grpc"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
type LibFuncs struct {
FuncPtr any
Name string
}
func main() {
// Get library name from environment variable, default to fallback
libName := os.Getenv("RFDETR_LIBRARY")
if libName == "" {
libName = "./librfdetrcpp-fallback.so"
}
rfdetrLib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
panic(err)
}
libFuncs := []LibFuncs{
{&CapiLoad, "rfdetr_capi_load"},
{&CapiUnload, "rfdetr_capi_unload"},
{&CapiDetectPath, "rfdetr_capi_detect_path"},
{&CapiDetectBuffer, "rfdetr_capi_detect_buffer"},
{&CapiFreeString, "rfdetr_capi_free_string"},
{&CapiGetNDetections, "rfdetr_capi_get_n_detections"},
{&CapiGetDetectionClassID, "rfdetr_capi_get_detection_class_id"},
{&CapiGetDetectionBox, "rfdetr_capi_get_detection_box"},
{&CapiGetDetectionScore, "rfdetr_capi_get_detection_score"},
{&CapiGetDetectionClassName, "rfdetr_capi_get_detection_class_name"},
{&CapiGetDetectionMaskPNG, "rfdetr_capi_get_detection_mask_png"},
}
for _, lf := range libFuncs {
purego.RegisterLibFunc(lf.FuncPtr, rfdetrLib, lf.Name)
}
flag.Parse()
if err := grpc.StartServer(*addr, &RFDetrCpp{}); err != nil {
panic(err)
}
}

View File

@@ -1,220 +0,0 @@
package main
// main_test.go - end-to-end smoke test for the rfdetr-cpp gRPC backend.
//
// Spawns the compiled rfdetr-cpp binary on a free local port, dials it via
// gRPC, and exercises LoadModel + Detect against the test fixtures
// downloaded by test.sh. Two scenarios:
//
// 1. detection — loads rfdetr-nano-q8_0.gguf and asserts at least one
// detection comes back with a non-empty class name and a bounding box
// of non-zero size.
// 2. segmentation — loads rfdetr-seg-nano-q8_0.gguf and additionally
// asserts that at least one detection carries a PNG-encoded mask blob
// (verified by PNG magic bytes).
//
// Both specs Skip cleanly if their fixtures are missing so the test target
// stays usable on a fresh checkout where models haven't been downloaded.
import (
"context"
"encoding/base64"
"fmt"
"net"
"os"
"os/exec"
"path/filepath"
"testing"
"time"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
func TestRFDetrCpp(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "rfdetr-cpp backend smoke suite")
}
// freePort grabs an ephemeral TCP port and immediately releases it so the
// spawned backend can bind to it. There is a tiny TOCTOU window here but in
// practice it's adequate for a smoke test on a quiet runner.
func freePort() int {
l, err := net.Listen("tcp", "127.0.0.1:0")
Expect(err).ToNot(HaveOccurred(), "freePort listen")
port := l.Addr().(*net.TCPAddr).Port
Expect(l.Close()).To(Succeed())
return port
}
// startBackend spawns the rfdetr-cpp binary on the given port and waits
// until it accepts TCP connections (up to 10s). The returned cleanup func
// kills the process and reaps it.
func startBackend(port int) func() {
binary, err := filepath.Abs("./rfdetr-cpp")
Expect(err).ToNot(HaveOccurred())
if _, err := os.Stat(binary); err != nil {
Skip(fmt.Sprintf("backend binary not built: %s (run `make rfdetr-cpp` first)", binary))
}
libPath, err := filepath.Abs("./librfdetrcpp-fallback.so")
Expect(err).ToNot(HaveOccurred())
if _, err := os.Stat(libPath); err != nil {
Skip(fmt.Sprintf("fallback library not built: %s (run `make librfdetrcpp-fallback.so` first)", libPath))
}
addr := fmt.Sprintf("127.0.0.1:%d", port)
cmd := exec.Command(binary, "--addr", addr)
cmd.Env = append(os.Environ(), "RFDETR_LIBRARY="+libPath)
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
Expect(cmd.Start()).To(Succeed())
cleanup := func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
}
}
deadline := time.Now().Add(10 * time.Second)
for time.Now().Before(deadline) {
c, err := net.DialTimeout("tcp", addr, 200*time.Millisecond)
if err == nil {
_ = c.Close()
return cleanup
}
time.Sleep(200 * time.Millisecond)
}
cleanup()
Fail(fmt.Sprintf("backend did not become ready on %s within 10s", addr))
return func() {}
}
// loadTestImage reads the COCO test image downloaded by test.sh and returns
// its base64-encoded content (the wire format accepted by the Detect RPC).
func loadTestImage() string {
imgPath, err := filepath.Abs("test-data/test.jpg")
Expect(err).ToNot(HaveOccurred())
imgBytes, err := os.ReadFile(imgPath)
if err != nil {
Skip(fmt.Sprintf("test image not present: %s (run test.sh first)", imgPath))
}
return base64.StdEncoding.EncodeToString(imgBytes)
}
// dialBackend opens a gRPC client connection to the spawned backend.
func dialBackend(port int) (pb.BackendClient, func()) {
addr := fmt.Sprintf("127.0.0.1:%d", port)
conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expect(err).ToNot(HaveOccurred())
return pb.NewBackendClient(conn), func() { _ = conn.Close() }
}
// modelPathOrSkip resolves a model file under ./test-models/ and Skip()s
// the current spec if it's missing.
func modelPathOrSkip(name string) string {
modelDir, err := filepath.Abs("test-models")
Expect(err).ToNot(HaveOccurred())
modelPath := filepath.Join(modelDir, name)
if _, err := os.Stat(modelPath); err != nil {
Skip(fmt.Sprintf("model not present: %s (run test.sh first)", modelPath))
}
return modelPath
}
var _ = Describe("rfdetr-cpp backend", func() {
It("runs object detection against a known-good COCO image", func() {
modelPath := modelPathOrSkip("rfdetr-nano-q8_0.gguf")
imgB64 := loadTestImage()
port := freePort()
cleanup := startBackend(port)
defer cleanup()
client, closeConn := dialBackend(port)
defer closeConn()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
loadResp, err := client.LoadModel(ctx, &pb.ModelOptions{
Model: "rfdetr-nano-q8_0.gguf",
ModelFile: modelPath,
Threads: 2,
})
Expect(err).ToNot(HaveOccurred(), "LoadModel")
Expect(loadResp.GetSuccess()).To(BeTrue(), "LoadModel reported failure: %s", loadResp.GetMessage())
detResp, err := client.Detect(ctx, &pb.DetectOptions{
Src: imgB64,
Threshold: 0.5,
})
Expect(err).ToNot(HaveOccurred(), "Detect")
Expect(detResp.GetDetections()).ToNot(BeEmpty(), "no detections returned on a known-good COCO image")
_, _ = fmt.Fprintf(GinkgoWriter, "detection OK: %d detections\n", len(detResp.GetDetections()))
for i, d := range detResp.GetDetections() {
Expect(d.GetClassName()).ToNot(BeEmpty(), "detection %d has empty class_name", i)
Expect(d.GetConfidence()).To(BeNumerically(">=", float32(0.5)),
"detection %d below threshold", i)
Expect(d.GetWidth()).To(BeNumerically(">", float32(0)),
"detection %d has non-positive width", i)
Expect(d.GetHeight()).To(BeNumerically(">", float32(0)),
"detection %d has non-positive height", i)
}
})
It("runs segmentation and returns PNG-encoded masks", func() {
modelPath := modelPathOrSkip("rfdetr-seg-nano-q8_0.gguf")
imgB64 := loadTestImage()
port := freePort()
cleanup := startBackend(port)
defer cleanup()
client, closeConn := dialBackend(port)
defer closeConn()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
loadResp, err := client.LoadModel(ctx, &pb.ModelOptions{
Model: "rfdetr-seg-nano-q8_0.gguf",
ModelFile: modelPath,
Threads: 2,
})
Expect(err).ToNot(HaveOccurred(), "LoadModel")
Expect(loadResp.GetSuccess()).To(BeTrue(), "LoadModel reported failure: %s", loadResp.GetMessage())
detResp, err := client.Detect(ctx, &pb.DetectOptions{
Src: imgB64,
Threshold: 0.5,
})
Expect(err).ToNot(HaveOccurred(), "Detect")
Expect(detResp.GetDetections()).ToNot(BeEmpty(), "no detections returned from segmentation model")
haveMask := false
for i, d := range detResp.GetDetections() {
m := d.GetMask()
if len(m) == 0 {
continue
}
haveMask = true
// Verify PNG magic: 89 50 4E 47 ("\x89PNG").
Expect(len(m)).To(BeNumerically(">=", 4), "detection %d mask too short", i)
Expect([]byte{m[0], m[1], m[2], m[3]}).To(Equal([]byte{0x89, 'P', 'N', 'G'}),
"detection %d mask is not a PNG", i)
}
Expect(haveMask).To(BeTrue(),
"segmentation model returned %d detections but none carried a mask",
len(detResp.GetDetections()))
_, _ = fmt.Fprintf(GinkgoWriter, "segmentation OK: %d detections, at least one with PNG mask\n",
len(detResp.GetDetections()))
})
})

View File

@@ -1,59 +0,0 @@
#!/bin/bash
# Script to copy the appropriate libraries based on architecture
set -e
CURDIR=$(dirname "$(realpath $0)")
REPO_ROOT="${CURDIR}/../../.."
# Create lib directory
mkdir -p $CURDIR/package/lib
cp -avf $CURDIR/librfdetrcpp-*.so $CURDIR/package/
cp -avf $CURDIR/rfdetr-cpp $CURDIR/package/
cp -fv $CURDIR/run.sh $CURDIR/package/
# Detect architecture and copy appropriate libraries
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
# x86_64 architecture
echo "Detected x86_64 architecture, copying x86_64 libraries..."
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
# ARM64 architecture
echo "Detected ARM64 architecture, copying ARM64 libraries..."
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
elif [ $(uname -s) = "Darwin" ]; then
echo "Detected Darwin"
else
echo "Error: Could not detect architecture"
exit 1
fi
# Package GPU libraries based on BUILD_TYPE
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
if [ -f "$GPU_LIB_SCRIPT" ]; then
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
package_gpu_libs
fi
echo "Packaging completed successfully"
ls -liah $CURDIR/package/
ls -liah $CURDIR/package/lib/

View File

@@ -1,52 +0,0 @@
#!/bin/bash
set -ex
# Get the absolute current dir where the script is located
CURDIR=$(dirname "$(realpath $0)")
cd /
echo "CPU info:"
if [ "$(uname)" != "Darwin" ]; then
grep -e "model\sname" /proc/cpuinfo | head -1
grep -e "flags" /proc/cpuinfo | head -1
fi
LIBRARY="$CURDIR/librfdetrcpp-fallback.so"
if [ "$(uname)" != "Darwin" ]; then
if grep -q -e "\savx\s" /proc/cpuinfo ; then
echo "CPU: AVX found OK"
if [ -e $CURDIR/librfdetrcpp-avx.so ]; then
LIBRARY="$CURDIR/librfdetrcpp-avx.so"
fi
fi
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
echo "CPU: AVX2 found OK"
if [ -e $CURDIR/librfdetrcpp-avx2.so ]; then
LIBRARY="$CURDIR/librfdetrcpp-avx2.so"
fi
fi
# Check avx 512
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
echo "CPU: AVX512F found OK"
if [ -e $CURDIR/librfdetrcpp-avx512.so ]; then
LIBRARY="$CURDIR/librfdetrcpp-avx512.so"
fi
fi
fi
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
export RFDETR_LIBRARY=$LIBRARY
# If there is a lib/ld.so, use it
if [ -f $CURDIR/lib/ld.so ]; then
echo "Using lib/ld.so"
echo "Using library: $LIBRARY"
exec $CURDIR/lib/ld.so $CURDIR/rfdetr-cpp "$@"
fi
echo "Using library: $LIBRARY"
exec $CURDIR/rfdetr-cpp "$@"

View File

@@ -1,55 +0,0 @@
#!/bin/bash
set -e
CURDIR=$(dirname "$(realpath $0)")
echo "Running rfdetr-cpp backend tests..."
# Test models from the mudler/rfdetr-cpp-* HuggingFace repos. Both the
# detection (nano-q8_0, ~36 MB) and segmentation (seg-nano-q8_0, ~40 MB)
# variants are downloaded so the Go smoke test exercises both code paths.
RFDETR_MODEL_DIR="${RFDETR_MODEL_DIR:-$CURDIR/test-models}"
RFDETR_DET_FILE="${RFDETR_DET_FILE:-rfdetr-nano-q8_0.gguf}"
RFDETR_DET_URL="${RFDETR_DET_URL:-https://huggingface.co/mudler/rfdetr-cpp-nano/resolve/main/rfdetr-nano-q8_0.gguf}"
RFDETR_SEG_FILE="${RFDETR_SEG_FILE:-rfdetr-seg-nano-q8_0.gguf}"
RFDETR_SEG_URL="${RFDETR_SEG_URL:-https://huggingface.co/mudler/rfdetr-cpp-seg-nano/resolve/main/rfdetr-seg-nano-q8_0.gguf}"
mkdir -p "$RFDETR_MODEL_DIR"
if [ ! -f "$RFDETR_MODEL_DIR/$RFDETR_DET_FILE" ]; then
echo "Downloading rfdetr nano-q8_0 detection model..."
curl -L -o "$RFDETR_MODEL_DIR/$RFDETR_DET_FILE" "$RFDETR_DET_URL" --progress-bar
fi
if [ ! -f "$RFDETR_MODEL_DIR/$RFDETR_SEG_FILE" ]; then
echo "Downloading rfdetr seg-nano-q8_0 segmentation model..."
curl -L -o "$RFDETR_MODEL_DIR/$RFDETR_SEG_FILE" "$RFDETR_SEG_URL" --progress-bar
fi
# Use a real COCO test image from the upstream rf-detr.cpp repo (~46 KB).
# A synthetic 64x64 red PNG was too synthetic to elicit detections from a
# real model — the smoke test would always trivially pass with zero
# detections.
TEST_IMAGE_DIR="$CURDIR/test-data"
TEST_IMAGE_FILE="$TEST_IMAGE_DIR/test.jpg"
TEST_IMAGE_URL="${TEST_IMAGE_URL:-https://raw.githubusercontent.com/mudler/rf-detr.cpp/main/tests/fixtures/ci/test_image.jpg}"
mkdir -p "$TEST_IMAGE_DIR"
if [ ! -f "$TEST_IMAGE_FILE" ]; then
echo "Downloading COCO test image..."
curl -L -o "$TEST_IMAGE_FILE" "$TEST_IMAGE_URL" --progress-bar
fi
echo "rfdetr-cpp test setup complete."
echo " detection model: $RFDETR_MODEL_DIR/$RFDETR_DET_FILE"
echo " segmentation model: $RFDETR_MODEL_DIR/$RFDETR_SEG_FILE"
echo " test image: $TEST_IMAGE_FILE"
# Run the Go smoke test: spawns the backend binary on a free port, calls
# LoadModel + Detect via gRPC for both detection and segmentation models.
echo ""
echo "Running Go smoke test..."
cd "$CURDIR"
go test -v -timeout 5m ./...

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# stablediffusion.cpp (ggml)
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
STABLEDIFFUSION_GGML_VERSION?=0e4ee04488159b81d95a9ffcd983a077fd5dcb77
STABLEDIFFUSION_GGML_VERSION?=5b0267e941cade15bd80089d89838795d9f4baa6
CMAKE_ARGS+=-DGGML_MAX_NAME=128

View File

@@ -27,7 +27,6 @@
#include <stdlib.h>
#include <regex>
#include <errno.h>
#include <inttypes.h>
#include <signal.h>
#include <unistd.h>
#include <sys/wait.h>
@@ -377,8 +376,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
const char *clip_g_path = "";
const char *t5xxl_path = "";
const char *vae_path = "";
const char *audio_vae_path = "";
const char *embeddings_connectors_path = "";
const char *scheduler_str = "";
const char *sampler = "";
const char *clip_vision_path = "";
@@ -434,12 +431,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
if (!strcmp(optname, "vae_path")) {
vae_path = strdup(optval);
}
if (!strcmp(optname, "audio_vae_path")) {
audio_vae_path = strdup(optval);
}
if (!strcmp(optname, "embeddings_connectors_path")) {
embeddings_connectors_path = strdup(optval);
}
if (!strcmp(optname, "scheduler")) {
scheduler_str = optval;
}
@@ -572,8 +563,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
ctx_params.diffusion_model_path = diffusion_model_path;
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
ctx_params.vae_path = vae_path;
ctx_params.audio_vae_path = audio_vae_path;
ctx_params.embeddings_connectors_path = embeddings_connectors_path;
ctx_params.taesd_path = taesd_path;
ctx_params.control_net_path = control_net_path;
if (lora_dir && strlen(lora_dir) > 0) {
@@ -1076,71 +1065,9 @@ static uint8_t* load_and_resize_image(const char* path, int target_width, int ta
return buf;
}
// Write sd.cpp's audio buffer to a temp WAV file (IEEE float, interleaved).
// sd_audio_t.data is planar (all channel 0 samples, then channel 1, etc.) — we
// interleave on the fly so ffmpeg's standard wav demuxer can read it directly.
// Returns 0 on success and fills wav_path (must be at least 64 bytes).
static int write_planar_float_wav(const sd_audio_t* a, char* wav_path, size_t wav_path_sz) {
if (!a || !a->data || a->sample_count == 0 || a->channels == 0 || a->sample_rate == 0) {
return -1;
}
snprintf(wav_path, wav_path_sz, "/tmp/gosd-audio-XXXXXX.wav");
int fd = mkstemps(wav_path, 4);
if (fd < 0) { perror("mkstemps wav"); return -1; }
FILE* f = fdopen(fd, "wb");
if (!f) { perror("fdopen wav"); close(fd); return -1; }
uint64_t frames = a->sample_count;
uint32_t channels = a->channels;
uint32_t sample_rate = a->sample_rate;
uint64_t total_samples64 = frames * (uint64_t)channels;
uint64_t data_bytes64 = total_samples64 * sizeof(float);
if (data_bytes64 > 0xFFFFFFFFull - 44) {
fprintf(stderr, "audio too large for 32-bit WAV (%" PRIu64 " bytes)\n", data_bytes64);
fclose(f);
unlink(wav_path);
return -1;
}
uint32_t data_bytes = (uint32_t)data_bytes64;
uint32_t riff_size = 36 + data_bytes;
uint16_t fmt_code = 3; // WAVE_FORMAT_IEEE_FLOAT
uint16_t bits_per_sample = 32;
uint16_t block_align = (uint16_t)(channels * sizeof(float));
uint32_t byte_rate = sample_rate * block_align;
uint16_t ch16 = (uint16_t)channels;
uint32_t fmt_size = 16;
fwrite("RIFF", 1, 4, f);
fwrite(&riff_size, 4, 1, f);
fwrite("WAVEfmt ", 1, 8, f);
fwrite(&fmt_size, 4, 1, f);
fwrite(&fmt_code, 2, 1, f);
fwrite(&ch16, 2, 1, f);
fwrite(&sample_rate, 4, 1, f);
fwrite(&byte_rate, 4, 1, f);
fwrite(&block_align, 2, 1, f);
fwrite(&bits_per_sample, 2, 1, f);
fwrite("data", 1, 4, f);
fwrite(&data_bytes, 4, 1, f);
// Interleave planar [ch0_samples..., ch1_samples...] → [ch0_s0, ch1_s0, ...]
for (uint64_t s = 0; s < frames; s++) {
for (uint32_t c = 0; c < channels; c++) {
float v = a->data[(size_t)c * frames + s];
fwrite(&v, sizeof(float), 1, f);
}
}
fclose(f);
return 0;
}
// Pipe raw RGB/RGBA frames to ffmpeg stdin and let it produce an MP4 at dst.
// Uses fork+execvp to avoid shell interpretation of dst. When `audio` is
// non-null, the audio waveform is staged to a temp WAV and added as a second
// ffmpeg input so the final MP4 contains both video and AAC audio.
static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps,
const sd_audio_t* audio, const char* dst) {
// Uses fork+execvp to avoid shell interpretation of dst.
static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps, const char* dst) {
if (num_frames <= 0 || !frames || !frames[0].data) {
fprintf(stderr, "ffmpeg_mux: empty frames\n");
return 1;
@@ -1155,87 +1082,38 @@ static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps,
snprintf(size_str, sizeof(size_str), "%dx%d", width, height);
snprintf(fps_str, sizeof(fps_str), "%d", fps);
// Optional audio: write a temp WAV file if the model produced audio.
char wav_path[64] = {0};
bool have_audio = false;
if (audio && audio->data && audio->sample_count > 0 && audio->channels > 0 && audio->sample_rate > 0) {
if (write_planar_float_wav(audio, wav_path, sizeof(wav_path)) == 0) {
have_audio = true;
fprintf(stderr, "ffmpeg_mux: audio %u Hz × %u ch × %" PRIu64 " frames → %s\n",
audio->sample_rate, audio->channels, audio->sample_count, wav_path);
} else {
fprintf(stderr, "ffmpeg_mux: failed to stage audio; producing silent video\n");
}
}
int pipefd[2];
if (pipe(pipefd) != 0) {
perror("pipe");
if (have_audio) unlink(wav_path);
return 1;
}
if (pipe(pipefd) != 0) { perror("pipe"); return 1; }
pid_t pid = fork();
if (pid < 0) {
perror("fork");
close(pipefd[0]); close(pipefd[1]);
if (have_audio) unlink(wav_path);
return 1;
}
if (pid < 0) { perror("fork"); close(pipefd[0]); close(pipefd[1]); return 1; }
if (pid == 0) {
// child
close(pipefd[1]);
if (dup2(pipefd[0], STDIN_FILENO) < 0) { perror("dup2"); _exit(127); }
close(pipefd[0]);
std::vector<char*> argv;
argv.push_back(const_cast<char*>("ffmpeg"));
argv.push_back(const_cast<char*>("-y"));
argv.push_back(const_cast<char*>("-hide_banner"));
argv.push_back(const_cast<char*>("-loglevel"));
argv.push_back(const_cast<char*>("warning"));
// Input 0: raw video from stdin
argv.push_back(const_cast<char*>("-f"));
argv.push_back(const_cast<char*>("rawvideo"));
argv.push_back(const_cast<char*>("-pix_fmt"));
argv.push_back(const_cast<char*>(pix_fmt_in));
argv.push_back(const_cast<char*>("-s"));
argv.push_back(size_str);
argv.push_back(const_cast<char*>("-framerate"));
argv.push_back(fps_str);
argv.push_back(const_cast<char*>("-i"));
argv.push_back(const_cast<char*>("-"));
// Input 1: optional audio WAV
if (have_audio) {
argv.push_back(const_cast<char*>("-i"));
argv.push_back(wav_path);
argv.push_back(const_cast<char*>("-map"));
argv.push_back(const_cast<char*>("0:v:0"));
argv.push_back(const_cast<char*>("-map"));
argv.push_back(const_cast<char*>("1:a:0"));
argv.push_back(const_cast<char*>("-c:a"));
argv.push_back(const_cast<char*>("aac"));
argv.push_back(const_cast<char*>("-b:a"));
argv.push_back(const_cast<char*>("192k"));
// -shortest so the final clip ends with the shorter of the two
// streams — guards against an audio buffer that overshoots the
// video duration (or vice versa) on certain LTX variants.
argv.push_back(const_cast<char*>("-shortest"));
}
argv.push_back(const_cast<char*>("-c:v"));
argv.push_back(const_cast<char*>("libx264"));
argv.push_back(const_cast<char*>("-pix_fmt"));
argv.push_back(const_cast<char*>("yuv420p"));
argv.push_back(const_cast<char*>("-movflags"));
argv.push_back(const_cast<char*>("+faststart"));
// Force MP4 container. Distributed LocalAI hands us a staging
// path (e.g. /staging/localai-output-NNN.tmp) with a non-standard
// extension; relying on filename suffix makes ffmpeg bail with
// "Unable to choose an output format".
argv.push_back(const_cast<char*>("-f"));
argv.push_back(const_cast<char*>("mp4"));
argv.push_back(const_cast<char*>(dst));
argv.push_back(nullptr);
std::vector<char*> argv = {
const_cast<char*>("ffmpeg"),
const_cast<char*>("-y"),
const_cast<char*>("-hide_banner"),
const_cast<char*>("-loglevel"), const_cast<char*>("warning"),
const_cast<char*>("-f"), const_cast<char*>("rawvideo"),
const_cast<char*>("-pix_fmt"), const_cast<char*>(pix_fmt_in),
const_cast<char*>("-s"), size_str,
const_cast<char*>("-framerate"), fps_str,
const_cast<char*>("-i"), const_cast<char*>("-"),
const_cast<char*>("-c:v"), const_cast<char*>("libx264"),
const_cast<char*>("-pix_fmt"), const_cast<char*>("yuv420p"),
const_cast<char*>("-movflags"), const_cast<char*>("+faststart"),
// Force MP4 container. Distributed LocalAI hands us a staging
// path (e.g. /staging/localai-output-NNN.tmp) with a non-standard
// extension; relying on filename suffix makes ffmpeg bail with
// "Unable to choose an output format".
const_cast<char*>("-f"), const_cast<char*>("mp4"),
const_cast<char*>(dst),
nullptr
};
execvp(argv[0], argv.data());
perror("execvp ffmpeg");
_exit(127);
@@ -1260,7 +1138,6 @@ static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps,
close(pipefd[1]);
int status;
waitpid(pid, &status, 0);
if (have_audio) unlink(wav_path);
return 1;
}
p += n;
@@ -1271,13 +1148,8 @@ static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps,
int status = 0;
while (waitpid(pid, &status, 0) < 0) {
if (errno != EINTR) {
perror("waitpid");
if (have_audio) unlink(wav_path);
return 1;
}
if (errno != EINTR) { perror("waitpid"); return 1; }
}
if (have_audio) unlink(wav_path);
if (!WIFEXITED(status) || WEXITSTATUS(status) != 0) {
fprintf(stderr, "ffmpeg exited with status %d\n", status);
return 1;
@@ -1352,7 +1224,7 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
fprintf(stderr, "Generated %d frames, muxing to %s via ffmpeg\n", num_frames_out, dst);
int rc = ffmpeg_mux_raw_to_mp4(frames, num_frames_out, fps, audio, dst);
int rc = ffmpeg_mux_raw_to_mp4(frames, num_frames_out, fps, dst);
for (int i = 0; i < num_frames_out; i++) {
if (frames[i].data) free(frames[i].data);

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
WHISPER_CPP_VERSION?=f24588a272ae8e23280d9c220536437164e6ed28
WHISPER_CPP_VERSION?=afa2ea544fb4b0448916b4a31ecd33c8685bd482
SO_TARGET?=libgowhisper.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -122,35 +122,6 @@
nvidia-cuda-12: "cuda12-whisper"
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-whisper"
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-whisper"
- &parakeetcpp
name: "parakeet-cpp"
alias: "parakeet-cpp"
license: mit
icon: https://avatars.githubusercontent.com/u/95302084
description: |
parakeet.cpp is a C++/ggml port of NVIDIA NeMo Parakeet automatic speech recognition (ASR) models.
It supports the tdt, ctc, rnnt and hybrid decoder families as well as cache-aware streaming transcription,
and runs on CPU, NVIDIA CUDA, AMD ROCm/HIP, Intel SYCL and NVIDIA Jetson (L4T) targets.
urls:
- https://github.com/mudler/parakeet.cpp
tags:
- audio-transcription
- CPU
- GPU
- CUDA
- HIP
capabilities:
default: "cpu-parakeet-cpp"
nvidia: "cuda12-parakeet-cpp"
intel: "intel-sycl-f16-parakeet-cpp"
metal: "metal-parakeet-cpp"
amd: "rocm-parakeet-cpp"
vulkan: "vulkan-parakeet-cpp"
nvidia-l4t: "nvidia-l4t-arm64-parakeet-cpp"
nvidia-cuda-13: "cuda13-parakeet-cpp"
nvidia-cuda-12: "cuda12-parakeet-cpp"
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-parakeet-cpp"
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-parakeet-cpp"
- &voxtral
name: "voxtral"
alias: "voxtral"
@@ -282,34 +253,6 @@
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-sam3-cpp"
intel: "intel-sycl-f32-sam3-cpp"
vulkan: "vulkan-sam3-cpp"
- &rfdetrcpp
name: "rfdetr-cpp"
alias: "rfdetr-cpp"
license: apache-2.0
description: |
Native RF-DETR object detection and instance segmentation in C/C++
using GGML. Loads pre-built GGUF weights from the mudler/rfdetr-cpp-*
family (Nano/Small/Base/Medium/Large + SegNano/SegSmall/SegMedium)
and returns bounding boxes, class labels, confidence scores, and
(for segmentation variants) PNG-encoded per-detection masks.
urls:
- https://github.com/mudler/rf-detr.cpp
tags:
- object-detection
- image-segmentation
- rfdetr
- gpu
- cpu
capabilities:
default: "cpu-rfdetr-cpp"
nvidia: "cuda12-rfdetr-cpp"
nvidia-cuda-12: "cuda12-rfdetr-cpp"
nvidia-cuda-13: "cuda13-rfdetr-cpp"
nvidia-l4t: "nvidia-l4t-arm64-rfdetr-cpp"
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-rfdetr-cpp"
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-rfdetr-cpp"
intel: "intel-sycl-f32-rfdetr-cpp"
vulkan: "vulkan-rfdetr-cpp"
- &vllm
name: "vllm"
license: apache-2.0
@@ -1957,121 +1900,6 @@
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-whisper"
mirrors:
- localai/localai-backends:master-gpu-nvidia-cuda-13-whisper
## parakeet-cpp
- !!merge <<: *parakeetcpp
name: "parakeet-cpp-development"
capabilities:
default: "cpu-parakeet-cpp-development"
nvidia: "cuda12-parakeet-cpp-development"
intel: "intel-sycl-f16-parakeet-cpp-development"
metal: "metal-parakeet-cpp-development"
amd: "rocm-parakeet-cpp-development"
vulkan: "vulkan-parakeet-cpp-development"
nvidia-l4t: "nvidia-l4t-arm64-parakeet-cpp-development"
nvidia-cuda-13: "cuda13-parakeet-cpp-development"
nvidia-cuda-12: "cuda12-parakeet-cpp-development"
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-parakeet-cpp-development"
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-parakeet-cpp-development"
- !!merge <<: *parakeetcpp
name: "nvidia-l4t-arm64-parakeet-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-parakeet-cpp"
mirrors:
- localai/localai-backends:latest-nvidia-l4t-arm64-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "nvidia-l4t-arm64-parakeet-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-parakeet-cpp"
mirrors:
- localai/localai-backends:master-nvidia-l4t-arm64-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "cuda13-nvidia-l4t-arm64-parakeet-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-parakeet-cpp"
mirrors:
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "cuda13-nvidia-l4t-arm64-parakeet-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-parakeet-cpp"
mirrors:
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "cpu-parakeet-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-parakeet-cpp"
mirrors:
- localai/localai-backends:latest-cpu-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "cpu-parakeet-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-parakeet-cpp"
mirrors:
- localai/localai-backends:master-cpu-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "metal-parakeet-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-parakeet-cpp"
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "metal-parakeet-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-parakeet-cpp"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "cuda12-parakeet-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-parakeet-cpp"
mirrors:
- localai/localai-backends:latest-gpu-nvidia-cuda-12-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "cuda12-parakeet-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-parakeet-cpp"
mirrors:
- localai/localai-backends:master-gpu-nvidia-cuda-12-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "rocm-parakeet-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-parakeet-cpp"
mirrors:
- localai/localai-backends:latest-gpu-rocm-hipblas-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "rocm-parakeet-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-parakeet-cpp"
mirrors:
- localai/localai-backends:master-gpu-rocm-hipblas-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "intel-sycl-f32-parakeet-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-parakeet-cpp"
mirrors:
- localai/localai-backends:latest-gpu-intel-sycl-f32-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "intel-sycl-f32-parakeet-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-parakeet-cpp"
mirrors:
- localai/localai-backends:master-gpu-intel-sycl-f32-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "intel-sycl-f16-parakeet-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-parakeet-cpp"
mirrors:
- localai/localai-backends:latest-gpu-intel-sycl-f16-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "intel-sycl-f16-parakeet-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-parakeet-cpp"
mirrors:
- localai/localai-backends:master-gpu-intel-sycl-f16-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "vulkan-parakeet-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-parakeet-cpp"
mirrors:
- localai/localai-backends:latest-gpu-vulkan-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "vulkan-parakeet-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-parakeet-cpp"
mirrors:
- localai/localai-backends:master-gpu-vulkan-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "cuda13-parakeet-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-parakeet-cpp"
mirrors:
- localai/localai-backends:latest-gpu-nvidia-cuda-13-parakeet-cpp
- !!merge <<: *parakeetcpp
name: "cuda13-parakeet-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-parakeet-cpp"
mirrors:
- localai/localai-backends:master-gpu-nvidia-cuda-13-parakeet-cpp
## stablediffusion-ggml
- !!merge <<: *stablediffusionggml
name: "cpu-stablediffusion-ggml"
@@ -2521,99 +2349,6 @@
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-sam3-cpp"
mirrors:
- localai/localai-backends:master-gpu-vulkan-sam3-cpp
## rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "rfdetr-cpp-development"
capabilities:
default: "cpu-rfdetr-cpp-development"
nvidia: "cuda12-rfdetr-cpp-development"
nvidia-cuda-12: "cuda12-rfdetr-cpp-development"
nvidia-cuda-13: "cuda13-rfdetr-cpp-development"
nvidia-l4t: "nvidia-l4t-arm64-rfdetr-cpp-development"
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-rfdetr-cpp-development"
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-rfdetr-cpp-development"
intel: "intel-sycl-f32-rfdetr-cpp-development"
vulkan: "vulkan-rfdetr-cpp-development"
- !!merge <<: *rfdetrcpp
name: "cpu-rfdetr-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-rfdetr-cpp"
mirrors:
- localai/localai-backends:latest-cpu-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "cpu-rfdetr-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-rfdetr-cpp"
mirrors:
- localai/localai-backends:master-cpu-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "cuda12-rfdetr-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-rfdetr-cpp"
mirrors:
- localai/localai-backends:latest-gpu-nvidia-cuda-12-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "cuda12-rfdetr-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-rfdetr-cpp"
mirrors:
- localai/localai-backends:master-gpu-nvidia-cuda-12-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "cuda13-rfdetr-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-rfdetr-cpp"
mirrors:
- localai/localai-backends:latest-gpu-nvidia-cuda-13-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "cuda13-rfdetr-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-rfdetr-cpp"
mirrors:
- localai/localai-backends:master-gpu-nvidia-cuda-13-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "nvidia-l4t-arm64-rfdetr-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-rfdetr-cpp"
mirrors:
- localai/localai-backends:latest-nvidia-l4t-arm64-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "nvidia-l4t-arm64-rfdetr-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-rfdetr-cpp"
mirrors:
- localai/localai-backends:master-nvidia-l4t-arm64-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "cuda13-nvidia-l4t-arm64-rfdetr-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-rfdetr-cpp"
mirrors:
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "cuda13-nvidia-l4t-arm64-rfdetr-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-rfdetr-cpp"
mirrors:
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "intel-sycl-f32-rfdetr-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-rfdetr-cpp"
mirrors:
- localai/localai-backends:latest-gpu-intel-sycl-f32-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "intel-sycl-f32-rfdetr-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-rfdetr-cpp"
mirrors:
- localai/localai-backends:master-gpu-intel-sycl-f32-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "intel-sycl-f16-rfdetr-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-rfdetr-cpp"
mirrors:
- localai/localai-backends:latest-gpu-intel-sycl-f16-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "intel-sycl-f16-rfdetr-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-rfdetr-cpp"
mirrors:
- localai/localai-backends:master-gpu-intel-sycl-f16-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "vulkan-rfdetr-cpp"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-rfdetr-cpp"
mirrors:
- localai/localai-backends:latest-gpu-vulkan-rfdetr-cpp
- !!merge <<: *rfdetrcpp
name: "vulkan-rfdetr-cpp-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-rfdetr-cpp"
mirrors:
- localai/localai-backends:master-gpu-vulkan-rfdetr-cpp
## Rerankers
- !!merge <<: *rerankers
name: "rerankers-development"

View File

@@ -99,15 +99,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if not results or len(results) == 0:
return backend_pb2.TranscriptResult(segments=[], text="")
# Get the transcript text from the first result.
# CTC models return List[str], TDT/RNNT models return List[Hypothesis]
# where the actual text lives in Hypothesis.text.
result = results[0]
if isinstance(result, str):
text = result
else:
text = getattr(result, 'text', None) or ""
# Get the transcript text from the first result
text = results[0]
if text:
# Create a single segment with the full transcription
result_segments.append(backend_pb2.TranscriptSegment(

View File

@@ -134,156 +134,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(message="Model loaded successfully", success=True)
@staticmethod
def _is_cjk(ch):
"""Check if a character is CJK (Chinese/Japanese/Korean)."""
cp = ord(ch)
return (
0x4E00 <= cp <= 0x9FFF # CJK Unified Ideographs
or 0x3400 <= cp <= 0x4DBF # Extension A
or 0x20000 <= cp <= 0x2A6DF # Extension B
or 0xF900 <= cp <= 0xFAFF # Compatibility Ideographs
or 0x3040 <= cp <= 0x309F # Hiragana
or 0x30A0 <= cp <= 0x30FF # Katakana
or 0xAC00 <= cp <= 0xD7AF # Hangul Syllables
)
@staticmethod
def _is_punct(ch):
"""Check if a character is punctuation (no space before it)."""
import unicodedata
cat = unicodedata.category(ch)
return cat.startswith('P')
@staticmethod
def _smart_join(tokens):
"""Join tokens with spaces for non-CJK text, without spaces for CJK.
Rules:
- Between two CJK chars: no space
- Between two non-CJK tokens: space
- Before punctuation: no space
- CJK adjacent to non-CJK: no space (smooth mixed-text transition)
"""
if not tokens:
return ""
result = [tokens[0]]
for token in tokens[1:]:
if not token:
continue
prev_ch = result[-1][-1] if result[-1] else ''
curr_ch = token[0]
# Punctuation never gets a space before it
if BackendServicer._is_punct(curr_ch):
result.append(token)
# CJK to CJK: no space
elif prev_ch and BackendServicer._is_cjk(prev_ch) and BackendServicer._is_cjk(curr_ch):
result.append(token)
# CJK adjacent to non-CJK or vice versa: no space
elif prev_ch and (BackendServicer._is_cjk(prev_ch) or BackendServicer._is_cjk(curr_ch)):
result.append(token)
# Both non-CJK (Latin, Cyrillic, etc.): add space
else:
result.append(' ' + token)
return "".join(result)
@staticmethod
def _extract_word_info(ts):
"""Return (start_sec, end_sec, text) from a ForcedAlignItem or tuple."""
if hasattr(ts, 'start_time') and hasattr(ts, 'end_time') and hasattr(ts, 'text'):
return (
float(ts.start_time) if ts.start_time is not None else 0.0,
float(ts.end_time) if ts.end_time is not None else 0.0,
str(ts.text) if ts.text else "",
)
elif isinstance(ts, (list, tuple)) and len(ts) >= 3:
return (
float(ts[0]) if ts[0] is not None else 0.0,
float(ts[1]) if ts[1] is not None else 0.0,
ts[2] if len(ts) > 2 and ts[2] is not None else "",
)
return (0.0, 0.0, "")
@staticmethod
def _compute_gap_threshold(time_stamps):
"""Compute a gap threshold for sentence boundary detection.
Uses the median inter-item gap multiplied by a factor, with a
minimum floor of 0.3s. Returns 0 if there are too few items.
"""
if len(time_stamps) < 2:
return 0.0
gaps = []
for i in range(1, len(time_stamps)):
prev_s, prev_e, _ = BackendServicer._extract_word_info(time_stamps[i - 1])
curr_s, _, _ = BackendServicer._extract_word_info(time_stamps[i])
gaps.append(curr_s - prev_e)
if not gaps:
return 0.0
gaps.sort()
median = gaps[len(gaps) // 2]
# threshold = max(median * 4, 0.3s)
return max(median * 4, 0.3)
def _build_segments(self, time_stamps, granularity):
"""Build TranscriptSegment list from forced-aligner output.
granularity:
- "word": one segment per aligned item (character / word)
- "segment" (default): merge consecutive items, splitting at
time gaps that exceed a dynamic threshold (sentence boundaries).
"""
if granularity == "word":
result = []
for idx, ts in enumerate(time_stamps):
s, e, t = self._extract_word_info(ts)
result.append(backend_pb2.TranscriptSegment(
id=idx,
start=int(s * 1_000_000_000),
end=int(e * 1_000_000_000),
text=t,
))
return result
# segment mode — merge at time-gap boundaries
threshold = self._compute_gap_threshold(time_stamps)
result = []
buf_text = []
buf_start = None
buf_end = 0.0
prev_end = None
for ts in time_stamps:
s, e, t = self._extract_word_info(ts)
# Detect sentence boundary via time gap
if prev_end is not None and (s - prev_end) >= threshold and buf_text:
result.append(backend_pb2.TranscriptSegment(
id=len(result),
start=int(buf_start * 1_000_000_000),
end=int(buf_end * 1_000_000_000),
text=self._smart_join(buf_text),
))
buf_text = []
buf_start = None
if buf_start is None:
buf_start = s
buf_text.append(t)
buf_end = e
prev_end = e
# flush remaining
if buf_text and buf_start is not None:
result.append(backend_pb2.TranscriptSegment(
id=len(result),
start=int(buf_start * 1_000_000_000),
end=int(buf_end * 1_000_000_000),
text=self._smart_join(buf_text),
))
return result
def AudioTranscription(self, request, context):
result_segments = []
text = ""
@@ -297,22 +147,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if request.language and request.language.strip():
language = request.language.strip()
ctx = ""
context = ""
if request.prompt and request.prompt.strip():
ctx = request.prompt.strip()
context = request.prompt.strip()
# Determine requested granularity (default: segment)
granularities = list(request.timestamp_granularities) if request.timestamp_granularities else []
granularity = "word" if "word" in granularities else "segment"
has_aligner = getattr(self.model, 'forced_aligner', None) is not None
try:
results = self.model.transcribe(
audio=audio_path, language=language, context=ctx,
return_time_stamps=has_aligner,
)
except TypeError:
results = self.model.transcribe(audio=audio_path, language=language, context=ctx)
results = self.model.transcribe(audio=audio_path, language=language, context=context)
if not results:
return backend_pb2.TranscriptResult(segments=[], text="")
@@ -321,7 +160,17 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
text = r.text or ""
if getattr(r, 'time_stamps', None) and len(r.time_stamps) > 0:
result_segments = self._build_segments(r.time_stamps, granularity)
for idx, ts in enumerate(r.time_stamps):
start_ms = 0
end_ms = 0
seg_text = text
if isinstance(ts, (list, tuple)) and len(ts) >= 3:
start_ms = int(float(ts[0]) * 1000) if ts[0] is not None else 0
end_ms = int(float(ts[1]) * 1000) if ts[1] is not None else 0
seg_text = ts[2] if len(ts) > 2 and ts[2] is not None else ""
result_segments.append(backend_pb2.TranscriptSegment(
id=idx, start=start_ms, end=end_ms, text=seg_text
))
else:
if text:
result_segments.append(backend_pb2.TranscriptSegment(

View File

@@ -36,11 +36,15 @@ fi
# flash-attn-4 4.0 stable lands.
EXTRA_PIP_INSTALL_FLAGS+=" --prerelease=allow"
# JetPack 7 / L4T arm64 sglang + torch wheels come straight from PyPI now
# (torch 2.11+ ships aarch64 + cu130 manylinux wheels and sglang 0.5.11+
# ships a cp312 aarch64 wheel pinned to that torch). They're cp312-only,
# so bump the venv Python accordingly.
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
# JetPack 7 / L4T arm64 wheels are built for cp312 and shipped via
# pypi.jetson-ai-lab.io. Bump the venv Python so the prebuilt sglang
# wheel resolves cleanly. The actual install on l4t13 goes through
# pyproject.toml (see the elif branch below) so [tool.uv.sources] can
# pin only torch/torchvision/torchaudio/sglang to the jetson-ai-lab
# index — leaving PyPI as the path for transitive deps like
# markdown-it-py / anthropic / propcache that the L4T mirror's proxy
# 503s on. No --index-strategy flag here: the explicit index keeps the
# scoping clean.
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
PYTHON_VERSION="3.12"
PYTHON_PATCH="12"
@@ -106,6 +110,27 @@ if [ "x${BUILD_TYPE}" == "x" ] || [ "x${FROM_SOURCE:-}" == "xtrue" ]; then
fi
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} .
popd
# L4T arm64 (JetPack 7): drive the install through pyproject.toml so that
# [tool.uv.sources] can pin torch/torchvision/torchaudio/sglang to the
# jetson-ai-lab index, while everything else (transitive deps and
# PyPI-resolvable packages like transformers / accelerate) comes from
# PyPI. Bypasses installRequirements because uv pip install -r
# requirements.txt does not honor sources — see
# backend/python/sglang/pyproject.toml for the rationale. Mirrors the
# equivalent path in backend/python/vllm/install.sh.
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
ensureVenv
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
fi
pushd "${backend_dir}"
# Build deps first (matches installRequirements' requirements-install.txt
# pass — sglang/sgl-kernel sdists need packaging/setuptools-scm in the
# venv before they can build under --no-build-isolation).
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements-install.txt
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement pyproject.toml
popd
runProtogen
else
installRequirements
fi

View File

@@ -0,0 +1,68 @@
# L4T arm64 (JetPack 7 / sbsa cu130) install spec for the sglang backend.
#
# Why this file exists, and why only the l4t13 BUILD_PROFILE consumes it:
#
# pypi.jetson-ai-lab.io hosts the L4T-specific torch / sglang / sgl-kernel
# wheels we need on aarch64 + cuda13, but it ALSO transparently proxies the
# rest of PyPI through `/+f/<sha>/<filename>` URLs that 503 frequently.
# With `--extra-index-url` + `--index-strategy=unsafe-best-match` (the
# historical fix in install.sh) uv would pick those proxy URLs for ordinary
# PyPI packages — markdown-it-py, anthropic, propcache, etc. — and trip on
# the 503s. See e.g. CI run 25439791228 (markdown-it-py-4.0.0).
#
# `explicit = true` on the index makes uv consult the L4T mirror ONLY for
# packages mapped under [tool.uv.sources]. Everything else goes to PyPI.
# This breaks the historical 503 path without losing access to the L4T
# wheels we actually need from there. Mirrors the equivalent fix already
# in backend/python/vllm/pyproject.toml.
#
# `uv pip install -r requirements.txt` does NOT honor [tool.uv.sources]
# (sources are project-mode only, not pip-compat mode), so install.sh's
# l4t13 branch invokes `uv pip install --requirement pyproject.toml`
# directly. Other BUILD_PROFILEs continue to use the requirements-*.txt
# pipeline through libbackend.sh's installRequirements and never read
# this file.
[project]
name = "localai-sglang-l4t13"
version = "0.0.0"
requires-python = ">=3.12,<3.13"
dependencies = [
# Mirror of requirements.txt — kept in sync manually for now since the
# l4t13 path bypasses installRequirements (see install.sh).
"grpcio==1.80.0",
"protobuf",
"certifi",
"setuptools",
"pillow",
# L4T-specific accelerator stack (sourced from jetson-ai-lab below).
"torch",
"torchvision",
"torchaudio",
# sglang on jetson — the [all] extra is deliberately omitted because it
# pulls outlines/decord, and decord has no aarch64 cp312 wheel anywhere
# (PyPI nor the jetson-ai-lab index ships only legacy cp35-cp37). With
# [all] uv backtracks through versions trying to satisfy decord and
# lands on sglang==0.1.16. The 0.5.0 floor matches the only major
# series the jetson-ai-lab sbsa/cu130 mirror currently publishes
# (sglang==0.5.1.post2 as of 2026-05-06). Bumping to >=0.5.11 here
# would make the build unsatisfiable until the mirror catches up.
# Gemma 4 / MTP recipes are therefore not supported on l4t13 — those
# features land on cublas12/cublas13 hosts that pull the newer wheel
# from PyPI. backend.py keeps backward compat with the 0.5.x SamplingParams
# field rename via runtime detection.
"sglang>=0.5.0",
# PyPI-resolvable packages that complete the runtime.
"accelerate",
"transformers",
]
[[tool.uv.index]]
name = "jetson-ai-lab"
url = "https://pypi.jetson-ai-lab.io/sbsa/cu130"
explicit = true
[tool.uv.sources]
torch = { index = "jetson-ai-lab" }
torchvision = { index = "jetson-ai-lab" }
torchaudio = { index = "jetson-ai-lab" }
sglang = { index = "jetson-ai-lab" }

View File

@@ -1,15 +0,0 @@
# sglang 0.5.11+ ships an aarch64 manylinux wheel on PyPI whose Requires-Dist
# pins torch==2.11.0 / torchaudio==2.11.0, locking an ABI-consistent set with
# the cu130 torch wheel installed above. 0.5.11 is the floor for Gemma 4
# support (sgl-project/sglang#21952).
#
# The [all] extra is deliberately NOT used on aarch64: it pulls the
# [diffusion] sub-extra which requires `xatlas`, and xatlas ships no
# aarch64 wheel and its sdist depends on scikit_build_core without
# declaring it in build-system.requires — so under --no-build-isolation
# uv can't build it. Upstream sglang gates st_attn and vsa on
# platform_machine != aarch64 in the diffusion extra but forgot xatlas.
# Plain `sglang` carries everything backend.py uses (Engine, ServerArgs,
# FunctionCallParser, ReasoningParser); the [all] extras are optional
# accelerators not required at import time.
sglang>=0.5.11

View File

@@ -1,9 +0,0 @@
# JetPack 7 / L4T arm64 + CUDA 13. Since PyTorch 2.11 (April 2026), PyPI ships
# aarch64 + cu130 manylinux wheels for torch/torchvision/torchaudio directly,
# so we no longer need a custom --extra-index-url for the L4T mirror.
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
accelerate
torch
torchvision
torchaudio
transformers

View File

@@ -26,7 +26,7 @@ import torch.cuda
XPU=os.environ.get("XPU", "0") == "1"
import transformers as transformers_module
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, pipeline
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
from scipy.io import wavfile
from sentence_transformers import SentenceTransformer
@@ -200,21 +200,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
autoTokenizer = False
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
self.SentenceTransformer = True
elif request.Type == "TokenClassification":
# NER / PII tagging via HuggingFace's token-classification
# pipeline. aggregation_strategy="simple" merges B-/I- tags
# into single spans and gives byte offsets back. The
# tokenizer is bundled inside the pipeline, so we skip the
# AutoTokenizer load below.
autoTokenizer = False
self.tokenClassifier = pipeline(
"token-classification",
model=model_name,
aggregation_strategy="simple",
device=0 if self.CUDA else -1,
trust_remote_code=request.TrustRemoteCode,
)
self.TokenClassification = True
else:
# Generic: dynamically resolve model class from transformers
model_type = TYPE_ALIASES.get(request.Type, request.Type)
@@ -268,39 +253,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(message="Model loaded successfully", success=True)
def TokenClassify(self, request, context):
# Runs HuggingFace's token-classification pipeline and returns
# the aggregated entity spans. The pipeline gives us byte
# offsets via aggregation_strategy="simple" (set at load
# time), so the caller can slice the original text without
# re-tokenising on the Go side.
if not getattr(self, "TokenClassification", False):
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
context.set_details("model was not loaded as Type=TokenClassification")
return backend_pb2.TokenClassifyResponse()
try:
results = self.tokenClassifier(request.text)
except Exception as err:
print("TokenClassify error:", err, file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"token-classification failed: {err}")
return backend_pb2.TokenClassifyResponse()
threshold = request.threshold if request.threshold > 0 else 0.0
entities = []
for r in results:
score = float(r.get("score", 0.0))
if score < threshold:
continue
entities.append(backend_pb2.TokenClassifyEntity(
entity_group=str(r.get("entity_group") or r.get("entity") or ""),
start=int(r.get("start", 0)),
end=int(r.get("end", 0)),
score=score,
text=str(r.get("word", "")),
))
return backend_pb2.TokenClassifyResponse(entities=entities)
def Embedding(self, request, context):
set_seed(request.Seed)
# Tokenize input

View File

@@ -2,9 +2,9 @@ torch==2.7.1
llvmlite==0.43.0
numba==0.60.0
accelerate
transformers>=5.9.0
transformers>=5.8.1
bitsandbytes
sentence-transformers==5.5.1
sentence-transformers==5.5.0
diffusers
soundfile
protobuf==7.35.0
protobuf==6.33.5

View File

@@ -2,9 +2,9 @@ torch==2.7.1
accelerate
llvmlite==0.43.0
numba==0.60.0
transformers>=5.9.0
transformers>=5.8.1
bitsandbytes
sentence-transformers==5.5.1
sentence-transformers==5.5.0
diffusers
soundfile
protobuf==7.35.0
protobuf==6.33.5

View File

@@ -2,9 +2,9 @@
torch==2.9.0
llvmlite==0.43.0
numba==0.60.0
transformers>=5.9.0
transformers>=5.8.1
bitsandbytes
sentence-transformers==5.5.1
sentence-transformers==5.5.0
diffusers
soundfile
protobuf==7.35.0
protobuf==6.33.5

View File

@@ -1,11 +1,11 @@
--extra-index-url https://download.pytorch.org/whl/rocm7.0
torch==2.10.0+rocm7.0
accelerate
transformers>=5.9.0
transformers>=5.8.1
llvmlite==0.43.0
numba==0.60.0
bitsandbytes
sentence-transformers==5.5.1
sentence-transformers==5.5.0
diffusers
soundfile
protobuf==7.35.0
protobuf==6.33.5

View File

@@ -3,9 +3,9 @@ torch
optimum[openvino]
llvmlite==0.43.0
numba==0.60.0
transformers>=5.9.0
transformers>=5.8.1
bitsandbytes
sentence-transformers==5.5.1
sentence-transformers==5.5.0
diffusers
soundfile
protobuf==7.35.0
protobuf==6.33.5

View File

@@ -2,9 +2,9 @@ torch==2.7.1
llvmlite==0.43.0
numba==0.60.0
accelerate
transformers>=5.9.0
transformers>=5.8.1
bitsandbytes
sentence-transformers==5.5.1
sentence-transformers==5.5.0
diffusers
soundfile
protobuf==7.35.0
protobuf==6.33.5

View File

@@ -1,5 +1,5 @@
grpcio==1.80.0
protobuf==7.35.0
protobuf==6.33.5
certifi
setuptools
scipy==1.15.1

View File

@@ -13,14 +13,14 @@ else
fi
# Handle l4t build profiles (Python 3.12, pip fallback) if needed.
# Since PyTorch 2.11 (April 2026) PyPI ships aarch64 + cu130 manylinux wheels
# directly for torch/torchvision/torchaudio and an aarch64 vllm wheel pinned
# to that torch, so the jetson-ai-lab mirror is no longer needed.
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
# unsafe-best-match is required on l4t13 because the jetson-ai-lab index
# lists transitive deps at limited versions — without it uv pins to the
# first matching index and fails to resolve a compatible wheel from PyPI.
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
PYTHON_VERSION="3.12"
PYTHON_PATCH="12"
PY_STANDALONE_TAG="20251120"
EXTRA_PIP_INSTALL_FLAGS="${EXTRA_PIP_INSTALL_FLAGS:-} --index-strategy=unsafe-best-match"
fi
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
@@ -42,11 +42,18 @@ if [ "x${BUILD_TYPE}" == "xhipblas" ]; then
else
uv pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700
fi
elif [ "x${BUILD_PROFILE}" == "xcublas13" ] || [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
# cublas13 (x86_64) and l4t13 (aarch64) both pull vllm from PyPI now:
# vllm 0.19+ defaults to cu130 wheels on x86_64 and vllm 0.20+ ships an
# aarch64 manylinux wheel pinned to torch==2.11.0. No extra index needed
# in either case.
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
# JetPack 7 / L4T arm64 cu130 — vllm comes from the prebuilt SBSA wheel
# at jetson-ai-lab. Version is unpinned: the index ships whatever build
# matches the cu130/cp312 ABI. unsafe-best-match lets uv fall through
# to PyPI for transitive deps not present on the jetson-ai-lab index.
if [ "x${USE_PIP}" == "xtrue" ]; then
pip install vllm --extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
else
uv pip install --index-strategy=unsafe-best-match vllm --extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
fi
elif [ "x${BUILD_PROFILE}" == "xcublas13" ]; then
# vllm 0.19+ defaults to cu130 wheels on PyPI, no extra index needed.
if [ "x${USE_PIP}" == "xtrue" ]; then
pip install vllm --torch-backend=auto
else

View File

@@ -1,15 +1,11 @@
# JetPack 7 / L4T arm64 + CUDA 13. PyPI ships aarch64 + cu130 manylinux wheels
# for torch/torchvision/torchaudio directly since PyTorch 2.11 (April 2026),
# so no custom index is needed. flash-attn is dropped here: PyPI has no
# aarch64 wheel for it, but vLLM 0.20+ bundles its own vllm_flash_attn
# (fa2 + fa3) inside the main wheel, so it is not required at runtime.
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
--extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
accelerate
torch
torchvision
torchaudio
transformers
bitsandbytes
flash-attn
diffusers
librosa
soundfile

View File

@@ -356,133 +356,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
except Exception as e:
return backend_pb2.Result(success=False, message=str(e))
async def Score(self, request, context):
"""
Joint log-probability of each candidate continuation given the
shared prompt. Used by routing-policy multi-label classification
(read the distribution rather than asking the model to emit a
single argmax label), reranking, and reward-model scoring.
Implementation uses vLLM's `prompt_logprobs` to recover the
per-token log P(token_i | tokens_<i) for the full concatenated
sequence; the candidate's tokens are the suffix whose logprobs
get summed. max_tokens=1 because vLLM requires at least one
generated token; the generated token is discarded.
"""
if not hasattr(self, 'llm') or self.llm is None:
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
context.set_details("Model not loaded")
return backend_pb2.ScoreResponse()
if not hasattr(self, 'tokenizer') or self.tokenizer is None:
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
context.set_details("Tokenizer not available")
return backend_pb2.ScoreResponse()
if len(request.candidates) == 0:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
context.set_details("candidates must be non-empty")
return backend_pb2.ScoreResponse()
try:
prompt = request.prompt or ""
prompt_token_ids = self.tokenizer.encode(prompt)
prompt_len = len(prompt_token_ids)
results = []
for candidate in request.candidates:
# Tokenise the concatenated sequence. We can't naively
# use len(prompt_tokens) + len(tokenizer.encode(candidate))
# because BPE merges at the boundary may produce a
# different tokenisation. Encoding the joined text and
# walking the divergence point is the correct primitive.
full_text = prompt + candidate
full_token_ids = self.tokenizer.encode(full_text)
divergence = prompt_len
min_len = min(prompt_len, len(full_token_ids))
for i in range(min_len):
if prompt_token_ids[i] != full_token_ids[i]:
divergence = i
break
candidate_token_ids = full_token_ids[divergence:]
num_candidate_tokens = len(candidate_token_ids)
if num_candidate_tokens == 0:
results.append(backend_pb2.CandidateScore(
log_prob=0.0,
length_normalized_log_prob=0.0,
num_tokens=0,
))
continue
sampling = SamplingParams(
max_tokens=1,
temperature=0.0,
prompt_logprobs=1,
detokenize=False,
)
request_id = random_uuid()
last_output = None
outputs_iter = self.llm.generate(
{"prompt": full_text},
sampling_params=sampling,
request_id=request_id,
)
try:
async for out in outputs_iter:
last_output = out
finally:
try:
await outputs_iter.aclose()
except Exception:
pass
if last_output is None or not getattr(last_output, "prompt_logprobs", None):
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details("vLLM did not return prompt_logprobs")
return backend_pb2.ScoreResponse()
prompt_logprobs = last_output.prompt_logprobs
total = 0.0
tokens_proto = []
for offset, tok_id in enumerate(candidate_token_ids):
position = divergence + offset
if position >= len(prompt_logprobs) or prompt_logprobs[position] is None:
continue
entry = prompt_logprobs[position]
lp_obj = entry.get(tok_id)
if lp_obj is not None:
lp = lp_obj.logprob
else:
# Token not in top-K; vLLM's top-1 may miss it.
# Fall back to the lowest available logprob in the
# entry — a conservative lower-bound on the true
# log P, biased against this candidate.
lp = min(v.logprob for v in entry.values())
total += lp
if request.include_token_logprobs:
tokens_proto.append(backend_pb2.TokenLogProb(
token=self.tokenizer.decode([tok_id]),
log_prob=lp,
))
cs = backend_pb2.CandidateScore(
log_prob=total,
num_tokens=num_candidate_tokens,
)
if request.length_normalize and num_candidate_tokens > 0:
cs.length_normalized_log_prob = total / num_candidate_tokens
if tokens_proto:
cs.tokens.extend(tokens_proto)
results.append(cs)
return backend_pb2.ScoreResponse(candidates=results)
except Exception as e:
print(f"Score error: {e}", file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(e))
return backend_pb2.ScoreResponse()
async def _predict(self, request, context, streaming=False):
# Build the sampling parameters
# NOTE: this must stay in sync with the vllm backend

View File

@@ -43,11 +43,14 @@ if [ "x${BUILD_PROFILE}" == "xcublas13" ]; then
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
fi
# JetPack 7 / L4T arm64 vllm + torch wheels come straight from PyPI now
# (torch 2.11+ ships aarch64 + cu130 manylinux wheels and vllm 0.20+ ships
# an aarch64 wheel pinned to that torch). They're cp312-only, so bump the
# venv Python accordingly. JetPack 6 keeps cp310 + USE_PIP=true.
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
# JetPack 7 / L4T arm64 wheels (torch, vllm, flash-attn) live on
# pypi.jetson-ai-lab.io and are built for cp312, so bump the venv Python
# accordingly. JetPack 6 keeps cp310 + USE_PIP=true.
#
# l4t13 uses pyproject.toml (see the elif branch below) to pin only the
# L4T-specific wheels to the jetson-ai-lab index via [tool.uv.sources].
# That keeps PyPI as the resolution path for transitive deps like
# anthropic/openai/propcache, which the L4T mirror's proxy 503s on.
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
USE_PIP=true
fi
@@ -100,6 +103,25 @@ if [ "x${BUILD_TYPE}" == "xintel" ]; then
export CMAKE_PREFIX_PATH="$(python -c 'import site; print(site.getsitepackages()[0])'):${CMAKE_PREFIX_PATH:-}"
VLLM_TARGET_DEVICE=xpu uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --no-deps .
popd
# L4T arm64 (JetPack 7): drive the install through pyproject.toml so that
# [tool.uv.sources] can pin torch/vllm/flash-attn/torchvision/torchaudio
# to the jetson-ai-lab index, while everything else (transitive deps and
# PyPI-resolvable packages like transformers) comes from PyPI. Bypasses
# installRequirements because uv pip install -r requirements.txt does not
# honor sources — see backend/python/vllm/pyproject.toml for the rationale.
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
ensureVenv
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
fi
pushd "${backend_dir}"
# Build deps first (matches installRequirements' requirements-install.txt
# pass — fastsafetensors and friends need pybind11 in the venv before
# their sdists can build under --no-build-isolation).
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements-install.txt
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement pyproject.toml
popd
runProtogen
# FROM_SOURCE=true on a CPU build skips the prebuilt vllm wheel in
# requirements-cpu-after.txt and compiles vllm locally against the host's
# actual CPU. Not used by default because it takes ~30-40 minutes, but

View File

@@ -0,0 +1,61 @@
# L4T arm64 (JetPack 7 / sbsa cu130) install spec for the vllm backend.
#
# Why this file exists, and why only the l4t13 BUILD_PROFILE consumes it:
#
# pypi.jetson-ai-lab.io hosts the L4T-specific torch / vllm / flash-attn
# wheels we need on aarch64 + cuda13, but it ALSO transparently proxies the
# rest of PyPI through `/+f/<sha>/<filename>` URLs that 503 frequently. With
# `--extra-index-url` + `--index-strategy=unsafe-best-match` (the historical
# fix in install.sh) uv would pick those proxy URLs for ordinary PyPI
# packages — `anthropic`, `openai`, `propcache`, `annotated-types` — and
# trip on the 503s. See e.g. CI run 25212201349 (anthropic-0.97.0).
#
# `explicit = true` on the index makes uv consult the L4T mirror ONLY for
# packages mapped under [tool.uv.sources]. Everything else goes to PyPI.
# This breaks the historical 503 path without losing access to the L4T
# wheels we actually need from there.
#
# `uv pip install -r requirements.txt` does NOT honor [tool.uv.sources]
# (sources are project-mode only, not pip-compat mode), so install.sh's
# l4t13 branch invokes `uv pip install --requirement pyproject.toml`
# directly. Other BUILD_PROFILEs continue to use the requirements-*.txt
# pipeline through libbackend.sh's installRequirements and never read
# this file.
[project]
name = "localai-vllm-l4t13"
version = "0.0.0"
requires-python = ">=3.12,<3.13"
dependencies = [
# Mirror of requirements.txt — kept in sync manually for now since the
# l4t13 path bypasses installRequirements (see install.sh).
"grpcio==1.80.0",
"protobuf",
"certifi",
"setuptools",
"pillow",
"charset-normalizer>=3.4.7",
"chardet",
# L4T-specific accelerator stack (sourced from jetson-ai-lab below).
"torch",
"torchvision",
"torchaudio",
"flash-attn",
"vllm",
# PyPI-resolvable packages that complete the runtime — accelerate,
# transformers, bitsandbytes carry their own wheels for aarch64.
"accelerate",
"transformers",
"bitsandbytes",
]
[[tool.uv.index]]
name = "jetson-ai-lab"
url = "https://pypi.jetson-ai-lab.io/sbsa/cu130"
explicit = true
[tool.uv.sources]
torch = { index = "jetson-ai-lab" }
torchvision = { index = "jetson-ai-lab" }
torchaudio = { index = "jetson-ai-lab" }
flash-attn = { index = "jetson-ai-lab" }
vllm = { index = "jetson-ai-lab" }

View File

@@ -3,5 +3,5 @@
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
# so uv consults this index alongside PyPI.
--extra-index-url https://wheels.vllm.ai/0.22.0/cu130
vllm==0.22.0
--extra-index-url https://wheels.vllm.ai/0.21.0/cu130
vllm==0.21.0

View File

@@ -1,4 +0,0 @@
# vLLM 0.20+ ships an aarch64 manylinux wheel on PyPI whose Requires-Dist pins
# torch==2.11.0 / torchvision==0.26.0 / torchaudio==2.11.0, locking an ABI-
# consistent set with the cu130 torch wheel installed above.
vllm

View File

@@ -1,8 +0,0 @@
# JetPack 7 / L4T arm64 + CUDA 13. Since PyTorch 2.11 (April 2026), PyPI ships
# aarch64 + cu130 manylinux wheels for torch/torchvision/torchaudio directly,
# so we no longer need a custom --extra-index-url for the L4T mirror.
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
accelerate
torch
transformers
bitsandbytes

View File

@@ -375,15 +375,6 @@ impl Backend for KokorosService {
Err(Status::unimplemented("Not supported"))
}
type AudioToAudioStreamStream = ReceiverStream<Result<backend::AudioToAudioResponse, Status>>;
async fn audio_to_audio_stream(
&self,
_: Request<tonic::Streaming<backend::AudioToAudioRequest>>,
) -> Result<Response<Self::AudioToAudioStreamStream>, Status> {
Err(Status::unimplemented("Not supported"))
}
async fn sound_generation(
&self,
_: Request<backend::SoundGenerationRequest>,

View File

@@ -17,7 +17,6 @@ import (
"time"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/httpclient"
)
// Release represents a LocalAI release
@@ -68,7 +67,9 @@ func NewReleaseManager() *ReleaseManager {
CurrentVersion: internal.PrintableVersion(),
ChecksumsPath: checksumsPath,
MetadataPath: metadataPath,
HTTPClient: httpclient.NewWithTimeout(30*time.Second, httpclient.WithFollowRedirects()),
HTTPClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}

View File

@@ -9,18 +9,11 @@ import (
corebackend "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/auth"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
"github.com/mudler/LocalAI/core/services/agentpool"
"github.com/mudler/LocalAI/core/services/facerecognition"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/services/monitoring"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/routing/admission"
"github.com/mudler/LocalAI/core/services/routing/billing"
"github.com/mudler/LocalAI/core/services/cloudproxy/mitm"
"github.com/mudler/LocalAI/core/services/routing/pii"
"github.com/mudler/LocalAI/core/services/routing/router"
"github.com/mudler/LocalAI/core/services/voicerecognition"
"github.com/mudler/LocalAI/core/templates"
pkggrpc "github.com/mudler/LocalAI/pkg/grpc"
@@ -58,22 +51,6 @@ type Application struct {
faceRegistry facerecognition.Registry
voiceRegistry voicerecognition.Registry
authDB *gorm.DB
metricsService *monitoring.LocalAIMetricsService
statsRecorder *billing.Recorder
fallbackUser *auth.User
piiRedactor *pii.Redactor
piiEvents pii.EventStore
mitmCA atomic.Pointer[mitm.CA]
mitmServer atomic.Pointer[mitm.Server]
mitmMutex sync.Mutex // serializes Stop+Start; readers use atomic loads
// mitmHostConflicts records duplicate-host claims across model configs.
// Non-empty disables the MITM listener until resolved — the strict
// 1-to-1 host↔model invariant the dispatcher relies on. Read by
// /api/middleware/status so the admin UI can surface the cause.
mitmHostConflicts atomic.Pointer[map[string][]string]
routerDecisions router.DecisionStore
routerRegistry *router.Registry
admissionLimiter *admission.Limiter
watchdogMutex sync.Mutex
watchdogStop chan bool
p2pMutex sync.Mutex
@@ -90,8 +67,6 @@ type Application struct {
// LocalAI Assistant in-process MCP server. nil when DisableLocalAIAssistant
// is set; otherwise initialised in start() after galleryService.
localAIAssistant *mcpTools.LocalAIAssistantHolder
shutdownOnce sync.Once
}
func newApplication(appConfig *config.ApplicationConfig) *Application {
@@ -210,103 +185,6 @@ func (a *Application) AuthDB() *gorm.DB {
return a.authDB
}
// MetricsService returns the OTel + Prometheus metric service. nil when
// --disable-metrics is set or initialisation failed at startup.
//
// The service is created in startup.go before any counter is registered
// so that otel.SetMeterProvider runs early enough for the billing
// recorder's counters to bind to the Prom-backed provider rather than
// the no-op global. core/http/app.go reuses this instance instead of
// constructing its own — two providers would orphan one set of counters
// behind whichever provider lost the SetMeterProvider race.
func (a *Application) MetricsService() *monitoring.LocalAIMetricsService {
return a.metricsService
}
// StatsRecorder returns the billing recorder used by the usage
// middleware. It is non-nil whenever stats are not explicitly disabled
// — i.e., the no-auth single-user path still gets a working recorder
// (in-memory by default). Routes register UsageMiddleware against this
// recorder regardless of auth state.
func (a *Application) StatsRecorder() *billing.Recorder {
return a.statsRecorder
}
// FallbackUser is the synthetic "local" user that UsageMiddleware uses
// to attribute requests when no authenticated user is on the context
// (i.e., --auth is off). nil when auth is on, since real users are
// always available there.
func (a *Application) FallbackUser() *auth.User {
return a.fallbackUser
}
// PIIRedactor returns the regex-tier PII redactor or nil if PII
// filtering is disabled. The chat-route middleware uses this to apply
// redaction before dispatch.
func (a *Application) PIIRedactor() *pii.Redactor {
return a.piiRedactor
}
// PIIEvents returns the PII event store. Same nil-when-disabled
// semantics as PIIRedactor; admin REST and MCP read tools call List
// against it.
func (a *Application) PIIEvents() pii.EventStore {
return a.piiEvents
}
// MITMCA returns the cloudproxy MITM proxy's CA, or nil when the
// MITM listener is disabled.
func (a *Application) MITMCA() *mitm.CA { return a.mitmCA.Load() }
// MITMServer returns the running MITM proxy or nil.
func (a *Application) MITMServer() *mitm.Server { return a.mitmServer.Load() }
// MITMHostConflicts returns a snapshot of host→[]model-name pairs that
// are claimed by 2+ model configs. Empty when the 1-to-1 invariant
// holds. Non-empty disables the MITM listener — read by the admin
// status endpoint to explain why.
func (a *Application) MITMHostConflicts() map[string][]string {
p := a.mitmHostConflicts.Load()
if p == nil {
return nil
}
return *p
}
// MITMHostOwners returns the host→model-name map, useful for the
// admin status endpoint. The lookup is recomputed on each call to
// stay current with model-config edits without needing a
// MITMRestart.
func (a *Application) MITMHostOwners() map[string]string {
if a.backendLoader == nil {
return nil
}
return a.backendLoader.MITMHostOwners().Owners
}
// RouterDecisions returns the routing decision store. nil when stats
// are disabled (--disable-stats); the RouteModel middleware skips the
// log write in that case but still rewrites requests.
func (a *Application) RouterDecisions() router.DecisionStore {
return a.routerDecisions
}
// RouterClassifierRegistry returns the process-wide classifier cache.
// Shared between the OpenAI and Anthropic route middlewares so the
// admin stats endpoint sees every live classifier — and so a
// classifier built on the OpenAI route is reused on Anthropic.
func (a *Application) RouterClassifierRegistry() *router.Registry {
return a.routerRegistry
}
// AdmissionLimiter returns the per-model admission limiter. The
// admission middleware uses it to gate concurrent requests; the
// admin status surface reads InFlight/Capacity from it for live
// load visibility.
func (a *Application) AdmissionLimiter() *admission.Limiter {
return a.admissionLimiter
}
// StartupConfig returns the original startup configuration (from env vars, before file loading)
func (a *Application) StartupConfig() *config.ApplicationConfig {
return a.startupConfig
@@ -322,24 +200,6 @@ func (a *Application) IsDistributed() bool {
return a.distributed != nil
}
// Shutdown stops backend gRPC processes and distributed services
// synchronously on the caller's stack. The context-cancel goroutine wired
// in New does the same work asynchronously, which races test-binary exit
// and CLI shutdown — orphaning spawned mock-backend / llama.cpp / etc.
// children to init. Callers that need a guarantee that cleanup has
// finished before they proceed (AfterSuite/AfterEach, signal handlers)
// must call this. Safe to call multiple times.
func (a *Application) Shutdown() error {
var err error
a.shutdownOnce.Do(func() {
a.distributed.Shutdown()
if a.modelLoader != nil {
err = a.modelLoader.StopAllGRPC()
}
})
return err
}
// waitForHealthyWorker blocks until at least one healthy backend worker is registered.
// This prevents the agent pool from failing during startup when workers haven't connected yet.
func (a *Application) waitForHealthyWorker() {
@@ -395,15 +255,6 @@ func (a *Application) start() error {
a.modelLoader,
a.galleryService,
)
// Wire usage tracking so the assistant's get_usage_stats tool
// returns real data; nil values keep the tool returning a clear
// "unavailable" error if startup ran with --disable-stats.
assistantClient.StatsRecorder = a.statsRecorder
assistantClient.FallbackUser = a.fallbackUser
// PII filter — same nil-or-real wiring.
assistantClient.PIIRedactor = a.piiRedactor
assistantClient.PIIEvents = a.piiEvents
assistantClient.RouterDecisions = a.routerDecisions
if err := holder.Initialize(a.applicationConfig.Context, assistantClient, localaitools.Options{}); err != nil {
// Why log+continue instead of fail: the assistant is an optional
// feature; a failure here must not take down the whole server.

View File

@@ -233,12 +233,7 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
xlog.Info("File stager initialized (HTTP direct transfer)")
}
// Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go
remoteUnloader := nodes.NewRemoteUnloaderAdapter(
registry,
natsClient,
cfg.Distributed.BackendInstallTimeoutOrDefault(),
cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
)
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
// All dependencies ready — build SmartRouter with all options at once
var conflictResolver nodes.ConcurrencyConflictResolver

View File

@@ -1,146 +0,0 @@
package application
import (
"errors"
"fmt"
"path/filepath"
"sort"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services/cloudproxy/mitm"
"github.com/mudler/xlog"
)
func startMITMProxy(app *Application, options *config.ApplicationConfig) error {
app.mitmMutex.Lock()
defer app.mitmMutex.Unlock()
return startMITMLocked(app, options)
}
func startMITMLocked(app *Application, options *config.ApplicationConfig) error {
// Validate the host↔model-config 1-to-1 invariant before binding
// the listener. Two configs claiming the same host means the
// dispatcher would have ambiguous PII settings; refuse to start
// rather than silently picking one. The conflict map is published
// for /api/middleware/status to surface in the UI.
ownership := app.backendLoader.MITMHostOwners()
if len(ownership.Conflicts) > 0 {
conflicts := ownership.Conflicts
app.mitmHostConflicts.Store(&conflicts)
hosts := make([]string, 0, len(conflicts))
for h := range conflicts {
hosts = append(hosts, h)
}
sort.Strings(hosts)
xlog.Error("mitm: refusing to start — duplicate host claims across model configs",
"hosts", hosts,
"conflicts", conflicts,
)
return errors.New("mitm: configuration error: duplicate host claims (see /api/middleware/status)")
}
app.mitmHostConflicts.Store(nil)
caDir := options.MITMCADir
if caDir == "" {
base := options.DataPath
if base == "" {
base = "."
}
caDir = filepath.Join(base, "mitm-ca")
}
if app.mitmCA.Load() == nil {
ca, err := mitm.LoadOrCreateCA(caDir)
if err != nil {
return fmt.Errorf("ca: %w", err)
}
app.mitmCA.Store(ca)
}
// Allowlist is exactly the set of hosts claimed by model configs.
// No global list — admins add hosts by creating an MITM model
// config (template available in the Add Model UI). When no config
// claims any host, the listener still starts but every CONNECT
// tunnels through unmodified.
effectiveHosts := make([]string, 0, len(ownership.Owners))
for h := range ownership.Owners {
effectiveHosts = append(effectiveHosts, h)
}
sort.Strings(effectiveHosts)
// Per-host PII gate inherits from the owning model's pii.enabled.
// A non-cloud-proxy backend with no explicit pii.enabled resolves
// to false → host is intercepted but the regex pass is skipped
// (audit events still record).
var piiDisabled []string
for host, modelName := range ownership.Owners {
cfg, exists := app.backendLoader.GetModelConfig(modelName)
if !exists {
continue
}
if !cfg.PIIIsEnabled() {
piiDisabled = append(piiDisabled, host)
}
}
handler := mitm.NewPIIHandler(mitm.PIIHandlerOptions{
Redactor: app.piiRedactor,
EventStore: app.piiEvents,
HostsWithPIIDisabled: piiDisabled,
})
srv, err := mitm.NewServer(mitm.Config{
Addr: options.MITMListen,
CA: app.mitmCA.Load(),
InterceptHosts: effectiveHosts,
Handler: handler,
EventStore: app.piiEvents,
})
if err != nil {
return fmt.Errorf("server: %w", err)
}
if err := srv.Start(); err != nil {
return fmt.Errorf("listen: %w", err)
}
app.mitmServer.Store(srv)
xlog.Info("mitm: cloudproxy listener started",
"addr", srv.Addr(),
"ca_dir", caDir,
"intercept_hosts", effectiveHosts,
"model_owned_hosts", len(ownership.Owners),
"pii_disabled_hosts", len(piiDisabled),
)
return nil
}
// StopMITM is idempotent.
func (a *Application) StopMITM() error {
a.mitmMutex.Lock()
defer a.mitmMutex.Unlock()
stopMITMLocked(a)
return nil
}
// RestartMITM reuses the existing CA so trusted clients keep
// working across listener flips.
func (a *Application) RestartMITM() error {
a.mitmMutex.Lock()
defer a.mitmMutex.Unlock()
stopMITMLocked(a)
if a.applicationConfig.MITMListen == "" {
xlog.Info("mitm: cloudproxy listener stays disabled (no listen address)")
return nil
}
return startMITMLocked(a, a.applicationConfig)
}
func stopMITMLocked(a *Application) {
srv := a.mitmServer.Load()
if srv == nil {
return
}
srv.Stop()
a.mitmServer.Store(nil)
xlog.Info("mitm: cloudproxy listener stopped")
}

View File

@@ -1,63 +0,0 @@
package application
import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
)
// adapterConfig resolves a model name to its runtime ModelConfig, or
// nil when the name is unknown. Shared by the router-facing factories
// below and by ModelConfigLookup.
func (a *Application) adapterConfig(modelName string) *config.ModelConfig {
cfg, err := a.backendLoader.LoadModelConfigFileByNameDefaultOptions(modelName, a.applicationConfig)
if err != nil || cfg == nil {
return nil
}
return cfg
}
// ModelConfigLookup is the lookup function the router middleware's
// classifier validator uses to confirm classifier_model declares
// FLAG_SCORE before binding it.
func (a *Application) ModelConfigLookup() func(modelName string) *config.ModelConfig {
return a.adapterConfig
}
// Scorer returns a backend.Scorer bound to the named model, or nil
// when the model is unknown. Used as a method value (app.Scorer) by
// router.ClassifierDeps — no factory-of-factory wrapper needed.
func (a *Application) Scorer(modelName string) backend.Scorer {
cfg := a.adapterConfig(modelName)
if cfg == nil {
return nil
}
return backend.NewScorer(a.modelLoader, *cfg, a.applicationConfig)
}
// Reranker returns a backend.Reranker bound to the named model, or
// nil when unknown. The reranker model's `type:` (e.g. "colbert")
// selects the scoring head inside the rerankers backend.
func (a *Application) Reranker(modelName string) backend.Reranker {
cfg := a.adapterConfig(modelName)
if cfg == nil {
return nil
}
return backend.NewReranker(a.modelLoader, *cfg, a.applicationConfig)
}
// Embedder returns a backend.Embedder bound to the named model, or
// nil when unknown. Used by the router's L2 embedding cache.
func (a *Application) Embedder(modelName string) backend.Embedder {
cfg := a.adapterConfig(modelName)
if cfg == nil {
return nil
}
return backend.NewEmbedder(a.modelLoader, *cfg, a.applicationConfig)
}
// VectorStore returns a backend.VectorStore for the named collection,
// or nil when the name is empty. Each router model gets its own
// backend process via the model loader's cache keyed by storeName.
func (a *Application) VectorStore(storeName string) backend.VectorStore {
return backend.NewVectorStore(a.modelLoader, a.applicationConfig, storeName)
}

View File

@@ -87,28 +87,6 @@ var _ = Describe("loadRuntimeSettingsFromFile", func() {
})
})
// MITM listener address. The file is the only source — no env var
// exists — so a regression here means an admin who configured the
// listener via /api/settings loses it after a reboot, even though
// the value is still on disk in the volume. (Intercept hosts now
// live in model YAML mitm.hosts: blocks, not runtime_settings.json.)
Describe("MITM fields", func() {
It("loads mitm_listen", func() {
cfg := &config.ApplicationConfig{DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`)}
loadRuntimeSettingsFromFile(cfg)
Expect(cfg.MITMListen).To(Equal(":8443"))
})
It("does not override an explicit CLI flag", func() {
cfg := &config.ApplicationConfig{
DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`),
MITMListen: ":9999", // simulate WithMITMListen(":9999")
}
loadRuntimeSettingsFromFile(cfg)
Expect(cfg.MITMListen).To(Equal(":9999"), "CLI flag must win over the persisted file value")
})
})
// The Agent Pool block has a mix of zero and non-zero defaults
// (Enabled=true, EmbeddingModel="granite-...", MaxChunkingSize=400,
// VectorEngine="chromem", AgentHubURL="https://agenthub.localai.io").

View File

@@ -15,18 +15,11 @@ import (
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/services/jobs"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/core/services/monitoring"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/routing/admission"
"github.com/mudler/LocalAI/core/services/routing/billing"
"github.com/mudler/LocalAI/core/services/routing/pii"
"github.com/mudler/LocalAI/core/services/routing/router"
"github.com/mudler/LocalAI/core/services/storage"
"github.com/mudler/LocalAI/pkg/signals"
"github.com/mudler/LocalAI/pkg/vram"
coreStartup "github.com/mudler/LocalAI/core/startup"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/vram"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/sanitize"
@@ -135,117 +128,6 @@ func New(opts ...config.AppOption) (*Application, error) {
}()
}
// Initialize the OTel + Prometheus metric pipeline before any
// counter is created. monitoring.NewLocalAIMetricsService calls
// otel.SetMeterProvider, so any subsequent otel.Meter() call —
// including billing.NewRecorder below — sees the real provider
// rather than the no-op global. Initialising metrics later (in
// core/http/app.go) leaves billing's counters bound to a no-op
// meter and never reaches /metrics. We deliberately ignore
// DisableMetrics here for ordering purposes; the HTTP middleware
// that records api_call histograms is still gated.
if !options.DisableMetrics {
ms, err := monitoring.NewLocalAIMetricsService()
if err != nil {
xlog.Error("failed to initialize metrics provider", "error", err)
} else {
application.metricsService = ms
// Bind the billing package's counters to the same meter the
// metrics service exports. Without this, billing's counters
// resolve via the OTel global and never reach /metrics.
billing.SetMeter(ms.Meter)
}
}
// Wire the routing-module billing recorder. The recorder runs in
// every mode (auth on/off, distributed/single-node) so that token
// tracking is not gated on auth — a no-auth single-user box still
// gets dashboards and `/api/usage` populated.
//
// fallbackUser is wired *unconditionally* when stats are enabled.
// UsageMiddleware uses it as the attribution source whenever
// auth.GetUser(c) is nil — that covers (a) no-auth deployments and
// (b) internal callers under auth-on (cron flushers, distributed
// worker callbacks) that hit a recordable endpoint without a user
// in context. The billing.user_id_present invariant still rejects
// empty IDs; LocalUser() returns a stable UUID per data path.
if !options.DisableStats {
var statsBackend billing.StatsBackend
switch {
case application.authDB != nil:
statsBackend = billing.NewGormBackend(application.authDB, 0, 0)
xlog.Info("stats: using auth DB for usage records")
default:
statsBackend = billing.NewMemoryBackend(0)
xlog.Info("stats: using in-memory ring buffer (no-auth single-user mode)")
}
application.fallbackUser = billing.LocalUser(options.DataPath)
application.statsRecorder = billing.NewRecorder(statsBackend)
// Drain pending records on SIGTERM. The GORM backend buffers up
// to maxPending (5k) records across a 5s flush tick, so without
// this the last few seconds of usage disappear on graceful exit.
signals.RegisterGracefulTerminationHandler(func() {
_ = application.statsRecorder.Close()
})
xlog.Info("stats: fallback user wired", "local_user_id", application.fallbackUser.ID)
} else {
xlog.Info("stats: disabled by --disable-stats")
}
// Wire the regex PII filter. Default-on: a single-user box gets
// the built-in pattern set the first time it starts, with email/
// phone/SSN/credit-card on mask and api_key_prefix on block. If
// the operator wants different actions, --pii-config points at a
// YAML file that overrides per-id; --disable-pii turns it off
// entirely.
if !options.DisablePII {
patterns, err := pii.LoadConfig(options.PIIConfigPath)
if err != nil {
return nil, fmt.Errorf("pii config: %w", err)
}
application.piiRedactor = pii.NewRedactor(patterns)
application.piiEvents = pii.NewMemoryEventStore(0)
// Apply persisted per-pattern overrides — admins toggling
// action/disabled via the UI and clicking "Save to disk" land
// here on the next start. Bad ids are warned and ignored so a
// stale entry doesn't block startup.
for id, ov := range options.PIIPatternOverrides {
if ov.Action != nil {
if err := application.piiRedactor.SetAction(id, pii.Action(*ov.Action)); err != nil {
xlog.Warn("pii: persisted override skipped", "pattern", id, "error", err)
continue
}
}
if ov.Disabled != nil {
if err := application.piiRedactor.SetDisabled(id, *ov.Disabled); err != nil {
xlog.Warn("pii: persisted disable skipped", "pattern", id, "error", err)
}
}
}
xlog.Info("pii: filter enabled",
"patterns", len(patterns),
"config_path", options.PIIConfigPath,
"persisted_overrides", len(options.PIIPatternOverrides),
)
} else {
xlog.Info("pii: disabled by --disable-pii")
}
// Wire the routing decision log. Always-on when stats are enabled —
// the per-router admin page reads this as the live activity feed
// and as input to drift checks for subsystem 5.
if !options.DisableStats {
application.routerDecisions = router.NewMemoryDecisionStore(0)
}
// Process-wide classifier cache shared across all route middlewares so
// the embedding-cache stats endpoint sees a single source of truth.
application.routerRegistry = router.NewRegistry()
// Subsystem 5: admission control. Limiter is always wired so a
// model that gains a limits: block via gallery install or YAML
// edit takes effect on the next restart without conditional plumbing.
application.admissionLimiter = admission.New()
// Wire JobStore for DB-backed task/job persistence whenever auth DB is available.
// This ensures tasks and jobs survive restarts in both single-node and distributed modes.
if application.authDB != nil && application.agentJobService != nil {
@@ -313,36 +195,12 @@ func New(opts ...config.AppOption) (*Application, error) {
}
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
}
// Hydrate from the store first so the wildcard subscriber finds an
// already-populated statuses map for any operations still in flight
// on a peer replica.
if err := application.galleryService.Hydrate(); err != nil {
xlog.Warn("Gallery service hydrate failed", "error", err)
}
// Bind cache-invalidation handler before SubscribeBroadcasts so the
// first inbound event is already routed. Peer replicas install a
// model and broadcast on SubjectCacheInvalidateModels; this
// callback re-runs LoadModelConfigsFromPath so a subsequent chat
// completion that load-balances onto this replica finds the new
// config. The originating replica reloads inline in modelHandler
// and never enters this path.
gs := application.galleryService
sys := options.SystemState
cfgLoaderOpts := options.ToConfigLoaderOptions()
gs.OnModelsChanged = func(_ messaging.CacheInvalidateEvent) {
if err := application.ModelConfigLoader().LoadModelConfigsFromPath(sys.Model.ModelsPath, cfgLoaderOpts...); err != nil {
xlog.Warn("Failed to reload model configs after peer invalidation", "error", err)
}
}
if err := application.galleryService.SubscribeBroadcasts(); err != nil {
xlog.Warn("Gallery service subscribe failed", "error", err)
}
// Wire distributed model/backend managers so delete propagates to workers
application.galleryService.SetModelManager(
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
)
application.galleryService.SetBackendManager(
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry, application.galleryService),
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry),
)
}
}
@@ -433,31 +291,15 @@ func New(opts ...config.AppOption) (*Application, error) {
loadRuntimeSettingsFromFile(options)
}
// Wire the cloudproxy MITM listener. Opt-in: empty MITMListen
// means "no MITM" — operators must explicitly choose to start
// it because clients have to install the generated CA cert.
// The handler reuses the global redactor + event store so an
// admin who's already configured PII filtering for direct API
// traffic doesn't need a parallel config for MITM traffic.
// Runs after loadRuntimeSettingsFromFile so a listener configured
// via /api/settings is brought back up across restarts.
if options.MITMListen != "" {
if err := startMITMProxy(application, options); err != nil {
return nil, fmt.Errorf("mitm: startup: %w", err)
}
}
application.ModelLoader().SetBackendLoggingEnabled(options.EnableBackendLogging)
// Safety-net cleanup if the application context is cancelled without
// the caller invoking Shutdown directly. This is fire-and-forget — it
// races binary exit and is unreliable in tests; the deterministic path
// is application.Shutdown(), which Shutdown's sync.Once dedupes with
// this goroutine.
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
xlog.Debug("Context canceled, shutting down")
if err := application.Shutdown(); err != nil {
application.distributed.Shutdown()
err := application.ModelLoader().StopAllGRPC()
if err != nil {
xlog.Error("error while stopping all grpc backends", "error", err)
}
}()
@@ -710,13 +552,6 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
options.TracingMaxItems = *settings.TracingMaxItems
}
}
if settings.TracingMaxBodyBytes != nil {
// Allow the on-disk setting to override the CLI/env default. The
// startup default is non-zero (see NewApplicationConfig), so a plain
// `== 0` guard like the others would never trigger; we instead respect
// any value the file specifies. 0 in the file means "uncapped".
options.TracingMaxBodyBytes = *settings.TracingMaxBodyBytes
}
// Branding / whitelabeling. There are no env vars for these — the file is
// the only source — so apply unconditionally. Without this block a server
@@ -738,25 +573,6 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
options.Branding.FaviconFile = *settings.FaviconFile
}
// MITM listener address. The CLI flag WithMITMListen populates
// options at startup; if the user configured MITM via /api/settings
// after the fact, only the file holds the value. Apply when the
// CLI flag did not already set it. (Intercept hosts now live in
// model YAML mitm.hosts: rather than runtime_settings.json.)
if settings.MITMListen != nil && options.MITMListen == "" {
options.MITMListen = *settings.MITMListen
}
// PII pattern overrides — file is the only source; CLI flags don't
// reach into this map. Apply unconditionally when present; the
// redactor wiring below sees the result on first construction.
if settings.PIIPatternOverrides != nil {
options.PIIPatternOverrides = make(map[string]config.PIIPatternRuntimeOverride, len(*settings.PIIPatternOverrides))
for id, ov := range *settings.PIIPatternOverrides {
options.PIIPatternOverrides[id] = ov
}
}
// Backend upgrade flags
if settings.AutoUpgradeBackends != nil {
if !options.AutoUpgradeBackends {

View File

@@ -78,7 +78,7 @@ func ModelAudioTransform(
var startTime time.Time
if appConfig.EnableTracing {
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
startTime = time.Now()
}
@@ -104,7 +104,7 @@ func ModelAudioTransform(
data["sample_rate"] = res.SampleRate
data["samples"] = res.Samples
data["reference_provided"] = res.ReferenceProvided
if snippet := trace.AudioSnippet(dst, appConfig.TracingMaxBodyBytes); snippet != nil {
if snippet := trace.AudioSnippet(dst); snippet != nil {
maps.Copy(data, snippet)
}
}

View File

@@ -1,169 +0,0 @@
package backend_test
// Regression spec for X-LocalAI-Node coverage on audio/image/TTS/rerank/VAD.
//
// The X-LocalAI-Node middleware (core/http/middleware.ExposeNodeHeader)
// works end-to-end only if the per-request holder attached to the HTTP
// request context reaches the SmartRouter via ml.Load(opts...). The chain
// is:
//
// handler -> backend.Foo(ctx, ...) -> ModelOptions(cfg, app, WithContext(ctx))
// -> ml.Load(opts...) -> grpcModel(..., o.context) -> modelRouter(ctx, ...)
// -> SmartRouter -> distributedhdr.Stamp(ctx, nodeID)
//
// If any backend helper drops `ctx` and lets ModelOptions fall back to the
// app context, the router never sees the per-request holder and the
// header silently stays empty for that endpoint. These specs pin the
// request-context-reaches-router contract for the five backend helpers
// that were previously dropping ctx between the handler and Load.
import (
"context"
"sync/atomic"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
pbproto "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/distributedhdr"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// newCapturingLoader returns a ModelLoader wired with a stub model router
// that captures the context it receives and then short-circuits with a
// sentinel error. The router callback is the exact seam where the
// SmartRouter would call distributedhdr.Stamp in production, so observing
// the holder here is equivalent to observing it at the real router.
func newCapturingLoader() (*model.ModelLoader, *atomic.Value, func() context.Context) {
loader := model.NewModelLoader(&system.SystemState{})
var captured atomic.Value
loader.SetModelRouter(func(ctx context.Context, _ string, _, _, _ string, _ *pbproto.ModelOptions, _ bool) (*model.Model, error) {
captured.Store(ctx)
// Return an error so the backend short-circuits before trying to
// dial gRPC. We only care about the context-arrival contract.
return nil, errRouterShortCircuit
})
get := func() context.Context {
v, _ := captured.Load().(context.Context)
return v
}
return loader, &captured, get
}
var errRouterShortCircuit = sentinelErr("router short-circuit (test)")
type sentinelErr string
func (s sentinelErr) Error() string { return string(s) }
func newAppCfg() *config.ApplicationConfig {
return config.NewApplicationConfig(config.WithSystemState(&system.SystemState{}))
}
func newModelCfg() config.ModelConfig {
threads := 1
cfg := config.ModelConfig{
Name: "test-model",
Backend: "stub-backend",
Threads: &threads,
}
cfg.Model = "test.bin"
return cfg
}
var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
const fakeNodeID = "node-ctx-propagation-7"
var (
appCfg *config.ApplicationConfig
modelCfg config.ModelConfig
loader *model.ModelLoader
routerCtxOf func() context.Context
holder *atomic.Value
reqCtx context.Context
)
BeforeEach(func() {
appCfg = newAppCfg()
modelCfg = newModelCfg()
loader, _, routerCtxOf = newCapturingLoader()
holder = distributedhdr.NewHolder()
reqCtx = distributedhdr.WithHolder(context.Background(), holder)
})
// stampViaRouterCtx asserts the captured router context carries the
// SAME holder that was attached to the request. We verify by stamping
// through the router-side ctx and observing the value via the
// request-side holder; if the holders were different objects the load
// would return "".
stampViaRouterCtx := func() {
routerCtx := routerCtxOf()
Expect(routerCtx).ToNot(BeNil(), "router callback must have been invoked")
distributedhdr.Stamp(routerCtx, fakeNodeID)
Expect(distributedhdr.Load(holder)).To(Equal(fakeNodeID),
"stamp via router-side ctx must be observable via the request-side holder")
}
It("Rerank forwards the request context to the SmartRouter", func() {
_, err := backend.Rerank(reqCtx, &pbproto.RerankRequest{Query: "q"}, loader, appCfg, modelCfg)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
stampViaRouterCtx()
})
It("VAD forwards the request context to the SmartRouter", func() {
_, err := backend.VAD(&schema.VADRequest{}, reqCtx, loader, appCfg, modelCfg)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
stampViaRouterCtx()
})
It("ModelTTS forwards the request context to the SmartRouter", func() {
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", loader, appCfg, modelCfg)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
stampViaRouterCtx()
})
It("ModelTTSStream forwards the request context to the SmartRouter", func() {
err := backend.ModelTTSStream(reqCtx, "hello", "", "", loader, appCfg, modelCfg, func([]byte) error { return nil })
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
stampViaRouterCtx()
})
It("ModelTranscriptionWithOptions forwards the request context to the SmartRouter", func() {
_, err := backend.ModelTranscriptionWithOptions(reqCtx, backend.TranscriptionRequest{Audio: "x.wav"}, loader, modelCfg, appCfg)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
stampViaRouterCtx()
})
It("ModelTranscriptionStream forwards the request context to the SmartRouter", func() {
err := backend.ModelTranscriptionStream(reqCtx, backend.TranscriptionRequest{Audio: "x.wav"}, loader, modelCfg, appCfg, func(backend.TranscriptionStreamChunk) {})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
stampViaRouterCtx()
})
It("ImageGeneration forwards the request context to the SmartRouter", func() {
_, err := backend.ImageGeneration(reqCtx, 64, 64, 1, 0, "p", "", "", "/tmp/out.png", loader, modelCfg, appCfg, nil)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
stampViaRouterCtx()
})
It("does NOT leak the holder when the app context is used instead", func() {
// Sanity: the bug being fixed manifests as the router getting
// appCfg.Context (no holder) instead of reqCtx (holder). A direct
// call with context.Background() must not see the holder via the
// app context surface.
appCtxOnly := appCfg.Context
Expect(distributedhdr.Holder(appCtxOnly)).To(BeNil(),
"the app context must not be the carrier of per-request holders")
})
})

View File

@@ -35,7 +35,7 @@ func Detection(
var startTime time.Time
if appConfig.EnableTracing {
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
startTime = time.Now()
}

View File

@@ -1,7 +1,6 @@
package backend
import (
"context"
"fmt"
"time"
@@ -12,38 +11,9 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
// Embedder produces a fixed-dimension vector from a prompt. The
// router's L2 embedding cache uses it to look up semantically-similar
// past decisions.
type Embedder interface {
Embed(ctx context.Context, text string) ([]float32, error)
}
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
// NewEmbedder binds (loader, modelConfig, appConfig) into an Embedder.
func NewEmbedder(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Embedder {
return &modelEmbedder{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
}
type modelEmbedder struct {
loader *model.ModelLoader
modelConfig config.ModelConfig
appConfig *config.ApplicationConfig
}
func (e *modelEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
fn, err := ModelEmbedding(ctx, text, nil, e.loader, e.modelConfig, e.appConfig)
if err != nil {
return nil, err
}
return fn()
}
func ModelEmbedding(ctx context.Context, s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
// model.WithContext(ctx) overrides the app-context default set in
// ModelOptions so distributed routing decisions reach the request's
// X-LocalAI-Node holder via distributedhdr.Stamp.
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
opts := ModelOptions(modelConfig, appConfig)
inferenceModel, err := loader.Load(opts...)
if err != nil {
@@ -97,7 +67,7 @@ func ModelEmbedding(ctx context.Context, s string, tokens []int, loader *model.M
}
if appConfig.EnableTracing {
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
traceData := map[string]any{
"input_text": trace.TruncateString(s, 1000),

View File

@@ -32,7 +32,7 @@ func FaceAnalyze(
var startTime time.Time
if appConfig.EnableTracing {
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
startTime = time.Now()
}

Some files were not shown because too many files have changed in this diff Show More