mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-25 17:18:18 -04:00
Compare commits
117 Commits
v4.2.3
...
fix/9988-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
391c1a3dcc | ||
|
|
2c7f83d6a2 | ||
|
|
9ff270eb65 | ||
|
|
8d6548c0b9 | ||
|
|
b02e3ffe61 | ||
|
|
a891eedd08 | ||
|
|
06e777b75e | ||
|
|
90ea327178 | ||
|
|
6a80e23733 | ||
|
|
1dcd1ae915 | ||
|
|
acad78a95a | ||
|
|
c94d1e1f5b | ||
|
|
270c256409 | ||
|
|
1a30020a82 | ||
|
|
8bbe89a537 | ||
|
|
dcc5599f89 | ||
|
|
a95f4e63e0 | ||
|
|
dfd19a3f88 | ||
|
|
d7387c725c | ||
|
|
63d84a5705 | ||
|
|
1198d10b58 | ||
|
|
a0f3e26245 | ||
|
|
e4cc1f11f3 | ||
|
|
6ed269d0b9 | ||
|
|
5756fb046d | ||
|
|
7980629bc5 | ||
|
|
d0a59be9de | ||
|
|
5cda4f1ccf | ||
|
|
c500461c69 | ||
|
|
834ecc36bf | ||
|
|
61bf34ea2f | ||
|
|
0b2ae3c6ca | ||
|
|
4735345105 | ||
|
|
7384fd800b | ||
|
|
6942713d85 | ||
|
|
0cf52c44d4 | ||
|
|
0d34cf7cbd | ||
|
|
f0cb02afb8 | ||
|
|
a39e025d64 | ||
|
|
05e8e1e9f4 | ||
|
|
a7f6cc8956 | ||
|
|
f15b9178ec | ||
|
|
959de86761 | ||
|
|
4c234abc2c | ||
|
|
c68818a62e | ||
|
|
11d5bd0cc3 | ||
|
|
12e056e96d | ||
|
|
308aa8908a | ||
|
|
b2d68a53a2 | ||
|
|
e3706c0512 | ||
|
|
1ffd82a050 | ||
|
|
f515168dbe | ||
|
|
ef6ca34513 | ||
|
|
9413c3767f | ||
|
|
3bf3cce232 | ||
|
|
06f8159035 | ||
|
|
f6a73f54fa | ||
|
|
24e04d8e81 | ||
|
|
b9a49449ae | ||
|
|
1879e11042 | ||
|
|
403d391316 | ||
|
|
fc3980dadd | ||
|
|
2009544b44 | ||
|
|
e859345b12 | ||
|
|
f30712f8e8 | ||
|
|
a19c77c5f8 | ||
|
|
4b02d23c0c | ||
|
|
21140e96b2 | ||
|
|
fc803e8d48 | ||
|
|
ca51606bfe | ||
|
|
cb502de309 | ||
|
|
5d0b549049 | ||
|
|
11cff1b309 | ||
|
|
4ca3d2cdc0 | ||
|
|
3cba35ed32 | ||
|
|
265ae35231 | ||
|
|
6a48157a80 | ||
|
|
41c838b2df | ||
|
|
21e793ad2a | ||
|
|
7c190bb4b9 | ||
|
|
d77a9137d8 | ||
|
|
661a0c3b9d | ||
|
|
00b8989886 | ||
|
|
43e0d397ca | ||
|
|
a1a7a219ed | ||
|
|
3937ec6527 | ||
|
|
1355b55794 | ||
|
|
5a2626d465 | ||
|
|
a39591f144 | ||
|
|
8c785dbe4a | ||
|
|
4abf5befbb | ||
|
|
195b910260 | ||
|
|
ba21bf667c | ||
|
|
7bd1693ad0 | ||
|
|
b5ac3a7373 | ||
|
|
53de474ef5 | ||
|
|
c33d36b870 | ||
|
|
57fa178a64 | ||
|
|
745473cbe6 | ||
|
|
594c9fd92e | ||
|
|
8af963bdd9 | ||
|
|
6e1dbae256 | ||
|
|
53bdb18d10 | ||
|
|
42a8db3573 | ||
|
|
0353d3bd77 | ||
|
|
ec49995190 | ||
|
|
67c34bbb96 | ||
|
|
4430fae779 | ||
|
|
ab01ed1a3e | ||
|
|
6bfe7f8c05 | ||
|
|
5a42dbf3ec | ||
|
|
c2fe0a6475 | ||
|
|
ddbbdf45b9 | ||
|
|
b4fdb41dcc | ||
|
|
0245b33eab | ||
|
|
a2940e5d47 | ||
|
|
a645c1f4aa |
@@ -112,6 +112,8 @@ Add a YAML anchor definition in the `## metas` section (around line 2-300). Look
|
||||
|
||||
Add image entries at the end of the file, following the pattern of similar backends such as `diffusers` or `chatterbox`. Include both `latest` (production) and `master` (development) tags.
|
||||
|
||||
**Note on integrity:** OCI backends installed from a gallery whose `verification:` block is set are verified against a keyless-cosign policy before extraction; tarball/HTTP backends use the optional `sha256:` field. New backends do not need any extra YAML — the gallery-level `verification:` block covers every entry. See [.agents/backend-signing.md](backend-signing.md) for the producer-side CI step.
|
||||
|
||||
## 4. Update the Makefile
|
||||
|
||||
The Makefile needs to be updated in several places to support building and testing the new backend:
|
||||
|
||||
@@ -284,7 +284,17 @@ Also bump the expected-length count in `api_instructions_test.go` and add the na
|
||||
|
||||
### 3. `capabilities.js` symbol (for new model-config FLAG_* flags)
|
||||
|
||||
If your feature needs a new `FLAG_*` usecase flag in `core/config/model_config.go` (so users can filter gallery models by it, and so `/v1/models` surfaces it), also declare the matching symbol in `core/http/react-ui/src/utils/capabilities.js`:
|
||||
If your feature needs a new `FLAG_*` usecase flag in `core/config/model_config.go` (so users can filter gallery models by it, and so `/v1/models` surfaces it), you need to update **all** of:
|
||||
|
||||
- `Usecase<Name>` string constant in `core/config/backend_capabilities.go`
|
||||
- `UsecaseInfoMap` entry mapping the string to its flag + gRPC method
|
||||
- `FLAG_<NAME>` bitmask in `core/config/model_config.go`
|
||||
- `GetAllModelConfigUsecases()` map entry (otherwise the YAML loader silently ignores the string)
|
||||
- `ModalityGroups` membership if the flag should affect `IsMultimodal()` (e.g. realtime_audio is in both speech-input and audio-output groups so a lone flag still reads as multimodal)
|
||||
- `GuessUsecases()` branch listing the backends that own this capability
|
||||
- `usecaseFilters` in `core/http/routes/ui_api.go` (drives the gallery filter dropdown)
|
||||
- `Models.jsx` `FILTERS` array + matching `filters.<camelCase>` i18n key in `core/http/react-ui/public/locales/en/models.json`
|
||||
- `core/http/react-ui/src/utils/capabilities.js`:
|
||||
|
||||
```js
|
||||
export const CAP_MY_CAPABILITY = 'FLAG_MY_CAPABILITY'
|
||||
|
||||
126
.agents/backend-signing.md
Normal file
126
.agents/backend-signing.md
Normal file
@@ -0,0 +1,126 @@
|
||||
# Backend image signing & verification
|
||||
|
||||
LocalAI verifies backend OCI images against a per-gallery keyless-cosign
|
||||
policy. This page documents the trust model, the producer side
|
||||
(`.github/workflows/backend_merge.yml` in this repo), and the consumer
|
||||
side (`pkg/oci/cosignverify` plus the gallery YAML).
|
||||
|
||||
## Trust model
|
||||
|
||||
- **Producer:** `.github/workflows/backend_merge.yml` signs each pushed
|
||||
manifest list with `cosign sign --recursive` in keyless mode after
|
||||
`docker buildx imagetools create`. The signing cert is issued by
|
||||
Fulcio bound to the workflow's OIDC identity. There is no long-lived
|
||||
signing key. `--recursive` signs both the manifest list and every
|
||||
per-arch entry — needed because our consumer resolves a tag to a
|
||||
per-arch manifest before checking signatures.
|
||||
- **Storage:** Signatures are written as OCI 1.1 referrers
|
||||
(`--registry-referrers-mode=oci-1-1`) in the new Sigstore bundle format
|
||||
(current cosign releases do this by default; no `--new-bundle-format`
|
||||
flag). No `:sha256-<hex>.sig` tag clutter.
|
||||
- **Consumer:** `pkg/oci/cosignverify` discovers the bundle via the
|
||||
referrers API, hands it to `sigstore-go`, and verifies it against the
|
||||
policy declared in the gallery YAML (`Gallery.Verification`).
|
||||
- **Revocation:** Keyless cosign certs are ephemeral (10-minute Fulcio
|
||||
validity), so revocation is policy-side, not CA-side. The gallery's
|
||||
`verification.not_before` (RFC3339) is the kill-switch — advance it to
|
||||
invalidate every signature produced before the cutoff, including those made during a known compromise window.
|
||||
|
||||
## Producer setup
|
||||
|
||||
`backend_merge.yml` is the workflow that joins per-arch digests into the
|
||||
multi-arch manifest list users actually pull, so it's also the right place
|
||||
to sign. The job needs:
|
||||
|
||||
- `permissions: { id-token: write, contents: read }` at the job level so
|
||||
the runner can exchange its GitHub OIDC token for a Fulcio cert.
|
||||
- `sigstore/cosign-installer@v3` step (current cosign releases already
|
||||
default to the new bundle format).
|
||||
- After each `docker buildx imagetools create`, resolve the resulting
|
||||
list digest with `docker buildx imagetools inspect <tag> --format
|
||||
'{{.Manifest.Digest}}'` and sign:
|
||||
|
||||
```sh
|
||||
cosign sign --yes --recursive \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"${REGISTRY_REPO}@${DIGEST}"
|
||||
```
|
||||
|
||||
Sign by digest, never by tag — signing by tag binds the signature to
|
||||
whatever the tag points at *now*, and a subsequent tag push orphans it.
|
||||
|
||||
`--registry-referrers-mode=oci-1-1` is still gated behind
|
||||
`COSIGN_EXPERIMENTAL=1` in cosign v2.4.x (set at the job env level in
|
||||
`backend_merge.yml`). Re-evaluate when bumping the pinned cosign release
|
||||
— newer versions are expected to graduate this flag and the env var can
|
||||
then be dropped.
|
||||
|
||||
`backend_build_darwin.yml` builds and pushes single-arch darwin images
|
||||
that bypass the manifest-list merge. If/when those entries get a gallery
|
||||
`verification:` policy, the equivalent cosign step has to land there
|
||||
too.
|
||||
|
||||
## Consumer setup (in `mudler/LocalAI` gallery YAML)
|
||||
|
||||
Once CI is signing, add a `verification:` block to the backend gallery
|
||||
entry (`backend/index.yaml`):
|
||||
|
||||
```yaml
|
||||
- name: localai
|
||||
url: github:mudler/LocalAI/backend/index.yaml@master
|
||||
verification:
|
||||
issuer: "https://token.actions.githubusercontent.com"
|
||||
identity_regex: "^https://github\\.com/mudler/LocalAI/\\.github/workflows/backend_merge\\.yml@refs/heads/master$"
|
||||
# Optional revocation cutoff; advance during incident response.
|
||||
# not_before: "2026-06-01T00:00:00Z"
|
||||
```
|
||||
|
||||
Identity matching pins the OIDC subject Fulcio issued the signing cert
|
||||
to. Without this, any image signed by *anyone* with a Fulcio cert would
|
||||
pass — the regex is what makes a signature mean "produced by our CI".
|
||||
|
||||
## Strict mode
|
||||
|
||||
Default behaviour: OCI backends without a `verification:` block install
|
||||
with a warning (logs include `installing OCI backend without signature
|
||||
verification`). Tarball/HTTP backends without a `sha256` field log a
|
||||
similar warning.
|
||||
|
||||
For production, set `LOCALAI_REQUIRE_BACKEND_INTEGRITY=1` (or pass
|
||||
`--require-backend-integrity` to `local-ai run` / `local-ai backends
|
||||
install` / `local-ai models install`). The warning becomes a hard error
|
||||
and unverifiable backends refuse to install.
|
||||
|
||||
## Revocation playbook
|
||||
|
||||
If `backend_merge.yml` (or any workflow with `id-token: write`) is
|
||||
compromised and we've shipped malicious signed images:
|
||||
|
||||
1. **Identify the compromise window.** Find the earliest IntegratedTime
|
||||
from the bad signatures (Rekor search by `subject` filter).
|
||||
2. **Set `verification.not_before`** in `backend/index.yaml` to a
|
||||
timestamp just *after* that window's end, so every bad signature predates the cutoff.
|
||||
3. **Push the YAML.** Deployed LocalAI instances pick it up on next
|
||||
gallery refresh (1-hour cache in `core/gallery/gallery.go`).
|
||||
4. **Fix the underlying compromise** in the workflow and re-sign images
|
||||
with the new build, which will have IntegratedTime > `not_before`.
|
||||
5. **Optional:** for a clean break, also rotate to a new
|
||||
workflow path (`backend_merge_v2.yml`) and update `identity_regex`.
|
||||
|
||||
## Where the code lives
|
||||
|
||||
- `pkg/oci/cosignverify/` — verifier, policy, OCI referrer fetch, NotBefore enforcement.
|
||||
- `pkg/downloader/uri.go` — `WithImageVerifier` option threaded through `DownloadFileWithContext`.
|
||||
- `core/gallery/backends.go` — `backendDownloadOptions` builds the verifier from the gallery's policy.
|
||||
- `core/config/gallery.go` — `Gallery.Verification` YAML schema.
|
||||
- `core/cli/run.go`, `core/cli/backends.go`, `core/cli/models.go` — `--require-backend-integrity` flag propagation.
|
||||
- `.github/workflows/backend_merge.yml` — producer-side `cosign sign --recursive` after each multi-arch manifest list push.
|
||||
|
||||
## Out of scope (follow-ups)
|
||||
|
||||
- **Signing the gallery YAML itself.** The index is fetched over HTTPS
|
||||
from GitHub; we trust the host. A cosign blob signature on the YAML
|
||||
would close that gap but adds key-management overhead. Revisit this
|
||||
page if/when added.
|
||||
- **Tarball/HTTP backend signing.** Cosign can sign arbitrary blobs, but
|
||||
for now non-OCI backends keep using the `sha256:` field in YAML.
|
||||
@@ -61,6 +61,12 @@ Always check `llama.cpp` for new model configuration options that should be supp
|
||||
- `reasoning_format` - Reasoning format options
|
||||
- Any new flags or parameters
|
||||
|
||||
### Speculative Decoding Types
|
||||
|
||||
The `spec_type` option in `grpc-server.cpp` delegates to upstream's `common_speculative_types_from_names()`, so new speculative types added to the `common_speculative_type_from_name` map in `common/speculative.cpp` are picked up automatically with no code changes - only docs need an entry in `docs/content/advanced/model-configuration.md`. Current values: `none`, `draft-simple`, `draft-eagle3`, `draft-mtp`, `ngram-simple`, `ngram-map-k`, `ngram-map-k4v`, `ngram-mod`, `ngram-cache`.
|
||||
|
||||
`draft-mtp` (Multi-Token Prediction, [ggml-org/llama.cpp#22673](https://github.com/ggml-org/llama.cpp/pull/22673)) does not need a separate draft GGUF: when `spec_type` includes `draft-mtp` and `draftmodel` is empty, the upstream server creates an MTP context off the target model itself. LocalAI's gRPC layer needs no changes for this — it works through the existing `params.speculative.types` plumbing and the derived `cparams.n_rs_seq = params.speculative.need_n_rs_seq()` in `common_context_params_to_llama`.
|
||||
|
||||
### Implementation Guidelines
|
||||
|
||||
1. **Feature Parity**: Always aim for feature parity with llama.cpp's implementation
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
.devcontainer
|
||||
models
|
||||
backends
|
||||
volumes
|
||||
examples/chatbot-ui/models
|
||||
backend/go/image/stablediffusion-ggml/build/
|
||||
backend/go/*/build
|
||||
@@ -21,3 +22,11 @@ __pycache__
|
||||
# backend virtual environments
|
||||
**/venv
|
||||
backend/python/**/source
|
||||
|
||||
# In-place llama.cpp clone + per-variant build copies. The Makefile
|
||||
# clones llama.cpp itself at the pinned LLAMA_VERSION; if a stale
|
||||
# local checkout is COPY'd into the image, the `llama.cpp:` target
|
||||
# sees the directory and skips re-cloning, so grpc-server.cpp ends
|
||||
# up compiled against whatever (likely older) commit the host had.
|
||||
backend/cpp/llama-cpp/llama.cpp
|
||||
backend/cpp/llama-cpp-*-build
|
||||
|
||||
79
.github/backend-matrix.yml
vendored
79
.github/backend-matrix.yml
vendored
@@ -278,6 +278,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-liquid-audio'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "liquid-audio"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -808,6 +821,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-liquid-audio'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "liquid-audio"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1088,6 +1114,19 @@ include:
|
||||
backend: "vibevoice"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-liquid-audio'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
ubuntu-version: '2404'
|
||||
backend: "liquid-audio"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1729,6 +1768,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-liquid-audio'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "liquid-audio"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2177,6 +2229,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-liquid-audio'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "liquid-audio"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -3503,6 +3568,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-liquid-audio'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "liquid-audio"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
|
||||
59
.github/workflows/backend_merge.yml
vendored
59
.github/workflows/backend_merge.yml
vendored
@@ -31,8 +31,20 @@ on:
|
||||
jobs:
|
||||
merge:
|
||||
runs-on: ubuntu-latest
|
||||
# id-token: write is required for keyless cosign — the workflow
|
||||
# exchanges the GitHub OIDC token for a short-lived Fulcio cert that
|
||||
# signs each pushed manifest. Without this permission the runner
|
||||
# cannot mint the token, and `cosign sign` fails with "no token".
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
env:
|
||||
quay_username: ${{ secrets.quayUsername }}
|
||||
# cosign v2.4.x still gates --registry-referrers-mode=oci-1-1 behind
|
||||
# this flag. Without it, signing fails with:
|
||||
# invalid argument "oci-1-1" for "--registry-referrers-mode" flag:
|
||||
# in order to use mode "oci-1-1", you must set COSIGN_EXPERIMENTAL=1
|
||||
COSIGN_EXPERIMENTAL: '1'
|
||||
steps:
|
||||
# Sparse checkout: the merge job needs `.github/scripts/` (for the
|
||||
# keepalive cleanup script) but none of the source tree.
|
||||
@@ -57,6 +69,16 @@ jobs:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@master
|
||||
|
||||
# cosign signs each pushed manifest list with --recursive so the
|
||||
# index and every per-arch entry get an attached Sigstore bundle.
|
||||
# Recent cosign releases always emit the new bundle format, so
|
||||
# there's no extra CLI flag to opt into it.
|
||||
- name: Install cosign
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: sigstore/cosign-installer@v3
|
||||
with:
|
||||
cosign-release: 'v2.4.1'
|
||||
|
||||
- name: Login to DockerHub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v4
|
||||
@@ -120,11 +142,25 @@ jobs:
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
if [ -z "$tags" ]; then
|
||||
echo "No quay.io tags from docker/metadata-action; skipping quay merge"
|
||||
else
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags \
|
||||
$(printf 'quay.io/go-skynet/ci-cache@sha256:%s ' *)
|
||||
exit 0
|
||||
fi
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags \
|
||||
$(printf 'quay.io/go-skynet/ci-cache@sha256:%s ' *)
|
||||
# Resolve the manifest-list digest (any tag points at it) so
|
||||
# cosign can sign by digest. Signing by tag would leave the
|
||||
# signature orphaned the next time the tag moves.
|
||||
first_tag=$(jq -cr '
|
||||
.tags | map(select(startswith("quay.io/"))) | .[0]
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
digest=$(docker buildx imagetools inspect "$first_tag" --format '{{.Manifest.Digest}}')
|
||||
# --recursive walks the list and signs every per-arch entry
|
||||
# too — clients that resolve a tag to a platform-specific
|
||||
# manifest before checking signatures need the per-arch
|
||||
# signatures, not just the list-level one.
|
||||
cosign sign --yes --recursive \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"quay.io/go-skynet/local-ai-backends@${digest}"
|
||||
|
||||
- name: Create manifest list and push (dockerhub)
|
||||
if: github.event_name != 'pull_request'
|
||||
@@ -139,11 +175,18 @@ jobs:
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
if [ -z "$tags" ]; then
|
||||
echo "No dockerhub tags from docker/metadata-action; skipping dockerhub merge"
|
||||
else
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags \
|
||||
$(printf 'localai/localai-backends@sha256:%s ' *)
|
||||
exit 0
|
||||
fi
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags \
|
||||
$(printf 'localai/localai-backends@sha256:%s ' *)
|
||||
first_tag=$(jq -cr '
|
||||
.tags | map(select(startswith("localai/"))) | .[0]
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
digest=$(docker buildx imagetools inspect "$first_tag" --format '{{.Manifest.Digest}}')
|
||||
cosign sign --yes --recursive \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"localai/localai-backends@${digest}"
|
||||
|
||||
- name: Inspect manifest
|
||||
if: github.event_name != 'pull_request'
|
||||
|
||||
94
.github/workflows/image.yml
vendored
94
.github/workflows/image.yml
vendored
@@ -151,7 +151,11 @@
|
||||
ubuntu-codename: 'noble'
|
||||
|
||||
core-image-merge:
|
||||
if: github.repository == 'mudler/LocalAI'
|
||||
# !cancelled(): without it, GHA's default `needs:` cascade skips the
|
||||
# merge whenever any matrix cell of the parent build fails or is
|
||||
# cancelled. Same fix as backend.yml's merge jobs — we still want to
|
||||
# publish the manifest list for tag-suffixes whose legs all succeeded.
|
||||
if: ${{ !cancelled() && github.repository == 'mudler/LocalAI' }}
|
||||
needs: core-image-build
|
||||
uses: ./.github/workflows/image_merge.yml
|
||||
with:
|
||||
@@ -164,7 +168,7 @@
|
||||
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
||||
|
||||
gpu-vulkan-image-merge:
|
||||
if: github.repository == 'mudler/LocalAI'
|
||||
if: ${{ !cancelled() && github.repository == 'mudler/LocalAI' }}
|
||||
needs: core-image-build
|
||||
uses: ./.github/workflows/image_merge.yml
|
||||
with:
|
||||
@@ -175,7 +179,91 @@
|
||||
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
|
||||
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
||||
|
||||
|
||||
# Single-arch server-image merges. Same conceptual fix as the backend
|
||||
# singletons in PR #9781: image_build.yml pushes by canonical digest
|
||||
# only, so without a downstream merge step there's no tag for consumers
|
||||
# (no :latest-gpu-nvidia-cuda-12, no :v<X>-gpu-nvidia-cuda-12, etc.).
|
||||
# Each merge job needs only its parent build matrix and is filtered by
|
||||
# tag-suffix in image_merge.yml's artifact-download pattern.
|
||||
gpu-nvidia-cuda-12-image-merge:
|
||||
if: ${{ !cancelled() && github.repository == 'mudler/LocalAI' }}
|
||||
needs: core-image-build
|
||||
uses: ./.github/workflows/image_merge.yml
|
||||
with:
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12'
|
||||
secrets:
|
||||
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
|
||||
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
||||
|
||||
gpu-nvidia-cuda-13-image-merge:
|
||||
if: ${{ !cancelled() && github.repository == 'mudler/LocalAI' }}
|
||||
needs: core-image-build
|
||||
uses: ./.github/workflows/image_merge.yml
|
||||
with:
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13'
|
||||
secrets:
|
||||
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
|
||||
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
||||
|
||||
gpu-intel-image-merge:
|
||||
if: ${{ !cancelled() && github.repository == 'mudler/LocalAI' }}
|
||||
needs: core-image-build
|
||||
uses: ./.github/workflows/image_merge.yml
|
||||
with:
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel'
|
||||
secrets:
|
||||
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
|
||||
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
||||
|
||||
gpu-hipblas-image-merge:
|
||||
if: ${{ !cancelled() && github.repository == 'mudler/LocalAI' }}
|
||||
needs: hipblas-jobs
|
||||
uses: ./.github/workflows/image_merge.yml
|
||||
with:
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-hipblas'
|
||||
secrets:
|
||||
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
|
||||
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
||||
|
||||
nvidia-l4t-arm64-image-merge:
|
||||
if: ${{ !cancelled() && github.repository == 'mudler/LocalAI' }}
|
||||
needs: gh-runner
|
||||
uses: ./.github/workflows/image_merge.yml
|
||||
with:
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64'
|
||||
secrets:
|
||||
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
|
||||
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
||||
|
||||
nvidia-l4t-arm64-cuda-13-image-merge:
|
||||
if: ${{ !cancelled() && github.repository == 'mudler/LocalAI' }}
|
||||
needs: gh-runner
|
||||
uses: ./.github/workflows/image_merge.yml
|
||||
with:
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-cuda-13'
|
||||
secrets:
|
||||
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
|
||||
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
||||
|
||||
gh-runner:
|
||||
if: github.repository == 'mudler/LocalAI'
|
||||
uses: ./.github/workflows/image_build.yml
|
||||
|
||||
20
.github/workflows/image_build.yml
vendored
20
.github/workflows/image_build.yml
vendored
@@ -106,6 +106,7 @@ jobs:
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{raw}}
|
||||
type=sha
|
||||
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
|
||||
flavor: |
|
||||
latest=${{ inputs.tag-latest }}
|
||||
suffix=${{ inputs.tag-suffix }},onlatest=true
|
||||
@@ -185,11 +186,28 @@ jobs:
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
# See .github/scripts/anchor-digest-in-cache.sh for why this is needed
|
||||
# and how it interacts with image_merge.yml's cleanup step. Mirrors the
|
||||
# same anchor in backend_build.yml — quay's per-repo manifest GC reaps
|
||||
# untagged manifests in local-ai before the merge runs.
|
||||
- name: Anchor digest in ci-cache so quay GC won't reap before merge
|
||||
if: github.event_name != 'pull_request'
|
||||
env:
|
||||
TAG_SUFFIX: ${{ inputs.tag-suffix == '' && '-core' || inputs.tag-suffix }}
|
||||
PLATFORM_TAG: ${{ inputs.platform-tag || 'single' }}
|
||||
DIGEST: ${{ steps.build.outputs.digest }}
|
||||
SOURCE_IMAGE: quay.io/go-skynet/local-ai
|
||||
run: .github/scripts/anchor-digest-in-cache.sh
|
||||
|
||||
- name: Upload digest artifact
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: digests-localai${{ inputs.tag-suffix == '' && '-core' || inputs.tag-suffix }}-${{ inputs.platform-tag }}
|
||||
# `--` separator + 'single' placeholder for empty platform-tag —
|
||||
# same pattern as backend_build.yml. Prevents prefix collisions
|
||||
# in the merge-side glob (e.g. -nvidia-l4t-arm64 is a prefix of
|
||||
# -nvidia-l4t-arm64-cuda-13).
|
||||
name: digests-localai${{ inputs.tag-suffix == '' && '-core' || inputs.tag-suffix }}--${{ inputs.platform-tag || 'single' }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
33
.github/workflows/image_merge.yml
vendored
33
.github/workflows/image_merge.yml
vendored
@@ -33,10 +33,22 @@ jobs:
|
||||
env:
|
||||
quay_username: ${{ secrets.quayUsername }}
|
||||
steps:
|
||||
# Sparse checkout: needed for .github/scripts/ (the keepalive cleanup
|
||||
# script). Skips the rest of the source tree.
|
||||
- name: Checkout (.github/scripts only)
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
sparse-checkout: |
|
||||
.github/scripts
|
||||
sparse-checkout-cone-mode: false
|
||||
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v8
|
||||
with:
|
||||
pattern: digests-localai${{ inputs.tag-suffix == '' && '-core' || inputs.tag-suffix }}-*
|
||||
# `--` separator anchors the glob so we don't over-match sibling
|
||||
# tag-suffixes (e.g. -nvidia-l4t-arm64 vs -nvidia-l4t-arm64-cuda-13).
|
||||
# Must stay in sync with image_build.yml's upload-artifact name.
|
||||
pattern: digests-localai${{ inputs.tag-suffix == '' && '-core' || inputs.tag-suffix }}--*
|
||||
merge-multiple: true
|
||||
path: /tmp/digests
|
||||
|
||||
@@ -68,10 +80,18 @@ jobs:
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{raw}}
|
||||
type=sha
|
||||
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
|
||||
flavor: |
|
||||
latest=${{ inputs.tag-latest }}
|
||||
suffix=${{ inputs.tag-suffix }},onlatest=true
|
||||
|
||||
# Source from ci-cache, not local-ai. See backend_merge.yml for the
|
||||
# detailed rationale — quay's manifest GC is per-repository, so the
|
||||
# untagged digest in local-ai gets reaped while the same content lives
|
||||
# tagged under ci-cache (anchored by image_build.yml). buildx imagetools
|
||||
# create copies the manifest into local-ai (blobs already cross-mounted)
|
||||
# and publishes the manifest list with user-facing tags. End state in
|
||||
# local-ai is self-contained; no embedded reference to ci-cache.
|
||||
- name: Create manifest list and push (quay)
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
@@ -82,7 +102,7 @@ jobs:
|
||||
else
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags \
|
||||
$(printf 'quay.io/go-skynet/local-ai@sha256:%s ' *)
|
||||
$(printf 'quay.io/go-skynet/ci-cache@sha256:%s ' *)
|
||||
fi
|
||||
|
||||
- name: Create manifest list and push (dockerhub)
|
||||
@@ -107,6 +127,15 @@ jobs:
|
||||
docker buildx imagetools inspect "$first_tag"
|
||||
fi
|
||||
|
||||
# See .github/scripts/cleanup-keepalive-tags.sh for the best-effort
|
||||
# semantics — fails soft when the registry credential isn't OAuth-scoped.
|
||||
- name: Cleanup keepalive tags in ci-cache
|
||||
if: github.event_name != 'pull_request' && success()
|
||||
env:
|
||||
TAG_SUFFIX: ${{ inputs.tag-suffix == '' && '-core' || inputs.tag-suffix }}
|
||||
QUAY_TOKEN: ${{ secrets.quayPassword }}
|
||||
run: .github/scripts/cleanup-keepalive-tags.sh
|
||||
|
||||
- name: Job summary
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
27
.github/workflows/test-extra.yml
vendored
27
.github/workflows/test-extra.yml
vendored
@@ -28,6 +28,7 @@ jobs:
|
||||
qwen-asr: ${{ steps.detect.outputs.qwen-asr }}
|
||||
nemo: ${{ steps.detect.outputs.nemo }}
|
||||
voxcpm: ${{ steps.detect.outputs.voxcpm }}
|
||||
liquid-audio: ${{ steps.detect.outputs.liquid-audio }}
|
||||
llama-cpp-quantization: ${{ steps.detect.outputs.llama-cpp-quantization }}
|
||||
llama-cpp: ${{ steps.detect.outputs.llama-cpp }}
|
||||
ik-llama-cpp: ${{ steps.detect.outputs.ik-llama-cpp }}
|
||||
@@ -447,6 +448,32 @@ jobs:
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/voxcpm
|
||||
make --jobs=5 --output-sync=target -C backend/python/voxcpm test
|
||||
# liquid-audio: LFM2.5-Audio any-to-any backend. The CI smoke test
|
||||
# exercises Health() and LoadModel(mode:finetune) — fine-tune mode
|
||||
# short-circuits before pulling weights (backend.py:192), so no
|
||||
# HuggingFace download or GPU is needed. The full-inference path is
|
||||
# gated on LIQUID_AUDIO_MODEL_ID, which we don't set here.
|
||||
tests-liquid-audio:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.liquid-audio == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential ffmpeg
|
||||
sudo apt-get install -y ca-certificates cmake curl patch python3-pip
|
||||
# Install UV
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
pip install --user --no-cache-dir grpcio-tools==1.64.1
|
||||
- name: Test liquid-audio
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/liquid-audio
|
||||
make --jobs=5 --output-sync=target -C backend/python/liquid-audio test
|
||||
tests-llama-cpp-quantization:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.llama-cpp-quantization == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -26,6 +26,10 @@ go-bert
|
||||
LocalAI
|
||||
/local-ai
|
||||
/local-ai-launcher
|
||||
# Root-level build artifacts when running `go build ./...` against
|
||||
# Go backend packages whose main lives under backend/go/.
|
||||
/cloud-proxy
|
||||
/local-store
|
||||
# prevent above rules from omitting the helm chart
|
||||
!charts/*
|
||||
# prevent above rules from omitting the api/localai folder
|
||||
@@ -77,3 +81,6 @@ local-backends/
|
||||
tests/e2e-ui/ui-test-server
|
||||
core/http/react-ui/playwright-report/
|
||||
core/http/react-ui/test-results/
|
||||
|
||||
# Local worktrees
|
||||
.worktrees/
|
||||
|
||||
@@ -46,8 +46,52 @@ linters:
|
||||
msg: 'LocalAI tests must use Ginkgo/Gomega; use Fail(...) instead of t.Fail. See .agents/coding-style.md.'
|
||||
- pattern: '^t\.FailNow$'
|
||||
msg: 'LocalAI tests must use Ginkgo/Gomega; use Fail(...) instead of t.FailNow. See .agents/coding-style.md.'
|
||||
# In-process config should flow through ApplicationConfig / kong-bound
|
||||
# CLI flags, not via os.Getenv. The CLI layer is the legitimate
|
||||
# env→struct boundary (kong's `env:"..."` tag); anything deeper that
|
||||
# reads env directly leaks process state into business logic and
|
||||
# makes flags impossible to test or override per-request. Backend
|
||||
# subprocesses, the system/capabilities probe, and a few places that
|
||||
# read non-LocalAI env vars (HOME, PATH, AUTH_TOKEN passed by parent)
|
||||
# are exempt — see linters.exclusions.rules below.
|
||||
- pattern: '^os\.(Getenv|LookupEnv|Environ)$'
|
||||
msg: 'Plumb config through ApplicationConfig (or the relevant CLI struct) instead of reading env directly. CLI entry points (core/cli/) bind env vars via kong''s `env:` tag — that is the only sanctioned env→struct boundary. See .agents/coding-style.md.'
|
||||
exclusions:
|
||||
paths:
|
||||
# Upstream whisper.cpp source tree fetched by the whisper backend Makefile.
|
||||
- 'backend/go/whisper/sources'
|
||||
- 'docs/'
|
||||
rules:
|
||||
# CLI entry points: kong's `env:"..."` tag is the legitimate env→struct
|
||||
# boundary, and a handful of subcommands legitimately propagate values
|
||||
# to spawned subprocesses (LLAMACPP_GRPC_SERVERS, MLX hostfile, ...).
|
||||
- path: ^core/cli/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# Backend subprocesses are independent binaries with their own env
|
||||
# surface; they're not "in-process config" of the LocalAI server.
|
||||
- path: ^backend/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# System capability probe reads HOME, PATH-style vars to discover
|
||||
# GPUs, default paths, etc. — not LocalAI config.
|
||||
- path: ^pkg/system/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# gRPC server reads AUTH_TOKEN passed in by the parent process at spawn
|
||||
# time; model.Loader sets/inherits env to communicate with subprocesses.
|
||||
- path: ^pkg/grpc/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
- path: ^pkg/model/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# Top-level main binaries (local-ai, launcher) are entry points.
|
||||
- path: ^cmd/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# Tests legitimately read $HOME, $TMPDIR, and gating env vars
|
||||
# (LOCALAI_COSIGN_LIVE, etc.) to skip live-network specs.
|
||||
- path: _test\.go$
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
|
||||
@@ -31,6 +31,7 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||
| [.agents/adding-gallery-models.md](.agents/adding-gallery-models.md) | Adding GGUF models from HuggingFace to the model gallery |
|
||||
| [.agents/localai-assistant-mcp.md](.agents/localai-assistant-mcp.md) | LocalAI Assistant chat modality — adding admin tools to the in-process MCP server, editing skill prompts, keeping REST + MCP + skills in sync |
|
||||
| [.agents/backend-signing.md](.agents/backend-signing.md) | Backend OCI image signing (keyless cosign + sigstore-go) — producer-side CI setup, consumer-side gallery `verification:` block, strict mode (`LOCALAI_REQUIRE_BACKEND_INTEGRITY`), revocation via `not_before` |
|
||||
|
||||
## Quick Reference
|
||||
|
||||
|
||||
21
Makefile
21
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -69,7 +69,7 @@ else
|
||||
GORELEASER=$(shell which goreleaser)
|
||||
endif
|
||||
|
||||
TEST_PATHS?=./api/... ./pkg/... ./core/...
|
||||
TEST_PATHS?=./api/... ./pkg/... ./core/... ./backend/go/cloud-proxy/... ./backend/go/local-store/...
|
||||
|
||||
|
||||
.PHONY: all test build vendor lint lint-all
|
||||
@@ -268,12 +268,13 @@ prepare-e2e:
|
||||
run-e2e-image:
|
||||
docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --name e2e-tests-$(RANDOM) localai-tests
|
||||
|
||||
test-e2e: build-mock-backend prepare-e2e run-e2e-image
|
||||
test-e2e: build-mock-backend build-cloud-proxy-backend prepare-e2e run-e2e-image
|
||||
@echo 'Running e2e tests'
|
||||
BUILD_TYPE=$(BUILD_TYPE) \
|
||||
LOCALAI_API=http://$(E2E_BRIDGE_IP):5390 \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e
|
||||
$(MAKE) clean-mock-backend
|
||||
$(MAKE) clean-cloud-proxy-backend
|
||||
$(MAKE) teardown-e2e
|
||||
docker rmi localai-tests
|
||||
|
||||
@@ -463,6 +464,7 @@ prepare-test-extra: protogen-python
|
||||
$(MAKE) -C backend/python/vllm-omni
|
||||
$(MAKE) -C backend/python/sglang
|
||||
$(MAKE) -C backend/python/vibevoice
|
||||
$(MAKE) -C backend/python/liquid-audio
|
||||
$(MAKE) -C backend/python/moonshine
|
||||
$(MAKE) -C backend/python/pocket-tts
|
||||
$(MAKE) -C backend/python/qwen-tts
|
||||
@@ -488,6 +490,7 @@ test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/vllm test
|
||||
$(MAKE) -C backend/python/vllm-omni test
|
||||
$(MAKE) -C backend/python/vibevoice test
|
||||
$(MAKE) -C backend/python/liquid-audio test
|
||||
$(MAKE) -C backend/python/moonshine test
|
||||
$(MAKE) -C backend/python/pocket-tts test
|
||||
$(MAKE) -C backend/python/qwen-tts test
|
||||
@@ -1062,6 +1065,7 @@ BACKEND_DS4 = ds4|ds4|.|false|false
|
||||
# Golang backends
|
||||
BACKEND_PIPER = piper|golang|.|false|true
|
||||
BACKEND_LOCAL_STORE = local-store|golang|.|false|true
|
||||
BACKEND_CLOUD_PROXY = cloud-proxy|golang|.|false|true
|
||||
BACKEND_HUGGINGFACE = huggingface|golang|.|false|true
|
||||
BACKEND_SILERO_VAD = silero-vad|golang|.|false|true
|
||||
BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true
|
||||
@@ -1092,6 +1096,7 @@ BACKEND_SGLANG = sglang|python|.|false|true
|
||||
BACKEND_DIFFUSERS = diffusers|python|.|--progress=plain|true
|
||||
BACKEND_CHATTERBOX = chatterbox|python|.|false|true
|
||||
BACKEND_VIBEVOICE = vibevoice|python|.|--progress=plain|true
|
||||
BACKEND_LIQUID_AUDIO = liquid-audio|python|.|--progress=plain|true
|
||||
BACKEND_MOONSHINE = moonshine|python|.|false|true
|
||||
BACKEND_POCKET_TTS = pocket-tts|python|.|false|true
|
||||
BACKEND_QWEN_TTS = qwen-tts|python|.|false|true
|
||||
@@ -1146,6 +1151,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_TURBOQUANT)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_DS4)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PIPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CLOUD_PROXY)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
@@ -1169,6 +1175,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SGLANG)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_DIFFUSERS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CHATTERBOX)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VIBEVOICE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LIQUID_AUDIO)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MOONSHINE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_POCKET_TTS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_QWEN_TTS)))
|
||||
@@ -1197,7 +1204,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
@@ -1209,6 +1216,12 @@ build-mock-backend: protogen-go
|
||||
clean-mock-backend:
|
||||
rm -f tests/e2e/mock-backend/mock-backend
|
||||
|
||||
build-cloud-proxy-backend: protogen-go
|
||||
$(GOCMD) build -o tests/e2e/mock-backend/cloud-proxy ./backend/go/cloud-proxy
|
||||
|
||||
clean-cloud-proxy-backend:
|
||||
rm -f tests/e2e/mock-backend/cloud-proxy
|
||||
|
||||
########################################################
|
||||
### UI E2E Test Server
|
||||
########################################################
|
||||
|
||||
@@ -37,6 +37,22 @@ service Backend {
|
||||
|
||||
rpc Rerank(RerankRequest) returns (RerankResult) {}
|
||||
|
||||
// TokenClassify runs a token-classification (NER) model on the
|
||||
// supplied text and returns each detected entity span. Used by the
|
||||
// PII redactor's optional NER tier — the regex tier still handles
|
||||
// formatted hits cheaply, while this catches names, locations, and
|
||||
// other unformatted PII that regex misses.
|
||||
rpc TokenClassify(TokenClassifyRequest) returns (TokenClassifyResponse) {}
|
||||
|
||||
// Score evaluates the model's joint log-probability of each
|
||||
// supplied candidate continuation given a shared prompt. The
|
||||
// prompt's KV cache is computed once and reused across candidates.
|
||||
// Used for routing-policy multi-label classification, reranking,
|
||||
// calibrated confidence, and reward-model scoring — any task where
|
||||
// the consumer wants the model's confidence in a pre-specified
|
||||
// continuation rather than a generated one.
|
||||
rpc Score(ScoreRequest) returns (ScoreResponse) {}
|
||||
|
||||
rpc GetMetrics(MetricsRequest) returns (MetricsResponse);
|
||||
|
||||
rpc VAD(VADRequest) returns (VADResponse) {}
|
||||
@@ -48,6 +64,11 @@ service Backend {
|
||||
|
||||
rpc AudioTransform(AudioTransformRequest) returns (AudioTransformResult) {}
|
||||
rpc AudioTransformStream(stream AudioTransformFrameRequest) returns (stream AudioTransformFrameResponse) {}
|
||||
// AudioToAudioStream is the bidirectional any-to-any S2S RPC. Backends
|
||||
// that load a speech-to-speech model consume input audio frames and emit
|
||||
// interleaved audio + transcript + tool-call deltas as typed events.
|
||||
// Backends without S2S support return UNIMPLEMENTED.
|
||||
rpc AudioToAudioStream(stream AudioToAudioRequest) returns (stream AudioToAudioResponse) {}
|
||||
|
||||
rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
|
||||
|
||||
@@ -63,6 +84,23 @@ service Backend {
|
||||
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
|
||||
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
|
||||
|
||||
// Forward proxies a raw HTTP request to an upstream provider. The
|
||||
// cloud-proxy backend implements this for passthrough-mode model
|
||||
// configs: the client wire format is preserved end-to-end (no
|
||||
// translation through internal proto), which means new provider
|
||||
// fields work the day they ship. Translation-mode proxies use the
|
||||
// standard Predict/PredictStream RPCs instead. Backends that don't
|
||||
// support this return UNIMPLEMENTED.
|
||||
//
|
||||
// The request is bidirectionally streamed so large bodies can flow
|
||||
// without buffering. In practice the first ForwardRequest carries
|
||||
// path, method, headers, and the initial body chunk; subsequent
|
||||
// messages append body chunks. The first ForwardReply carries the
|
||||
// upstream status and response headers; subsequent messages stream
|
||||
// body chunks (SSE frames or chunked transfer). Cancellation of the
|
||||
// gRPC context closes the upstream connection.
|
||||
rpc Forward(stream ForwardRequest) returns (stream ForwardReply) {}
|
||||
|
||||
}
|
||||
|
||||
// Define the empty request
|
||||
@@ -76,6 +114,76 @@ message MetricsResponse {
|
||||
int32 prompt_tokens_processed = 5;
|
||||
}
|
||||
|
||||
// TokenClassifyRequest carries the text to classify plus an optional
|
||||
// score threshold. The transformers backend interprets threshold as
|
||||
// the minimum confidence to include in the response; 0 = include all.
|
||||
message TokenClassifyRequest {
|
||||
string text = 1;
|
||||
float threshold = 2;
|
||||
}
|
||||
|
||||
// TokenClassifyEntity is one detected entity span. Byte offsets are
|
||||
// into the original UTF-8 text — start..end is a half-open range that
|
||||
// addresses the substring corresponding to entity_group.
|
||||
//
|
||||
// entity_group follows HuggingFace's aggregated-tag convention (e.g.
|
||||
// "PER", "LOC", "ORG", or a PII-specific label like "EMAIL" /
|
||||
// "SSN" depending on the model). The redactor's per-pattern action
|
||||
// map keys off this string.
|
||||
message TokenClassifyEntity {
|
||||
string entity_group = 1;
|
||||
int32 start = 2;
|
||||
int32 end = 3;
|
||||
float score = 4;
|
||||
string text = 5;
|
||||
}
|
||||
|
||||
message TokenClassifyResponse {
|
||||
repeated TokenClassifyEntity entities = 1;
|
||||
}
|
||||
|
||||
// ScoreRequest carries one shared prompt and one or more continuations
|
||||
// to score against it. The backend tokenises the prompt once and reuses
|
||||
// the resulting KV cache across all candidates in this request.
|
||||
message ScoreRequest {
|
||||
string prompt = 1;
|
||||
repeated string candidates = 2;
|
||||
// Return per-token logprobs for each candidate when true. Default
|
||||
// false to keep the wire response small; the joint log_prob field
|
||||
// covers the common ranking case.
|
||||
bool include_token_logprobs = 3;
|
||||
// When true, the response also populates length_normalized_log_prob
|
||||
// (joint log-prob divided by candidate token count). Useful when
|
||||
// candidates differ in length and the consumer wants a per-token
|
||||
// measure comparable across them (PMI-style scoring).
|
||||
bool length_normalize = 4;
|
||||
}
|
||||
|
||||
// CandidateScore is one row in the ScoreResponse, matching by index
|
||||
// the candidate in ScoreRequest.candidates.
|
||||
message CandidateScore {
|
||||
// Sum of log P(token_i | prompt, candidate_token_<i) across the
|
||||
// candidate's tokens. The primary ranking signal.
|
||||
double log_prob = 1;
|
||||
// log_prob / num_tokens — populated when length_normalize=true on
|
||||
// the request.
|
||||
double length_normalized_log_prob = 2;
|
||||
// Per-token detail — populated when include_token_logprobs=true.
|
||||
repeated TokenLogProb tokens = 3;
|
||||
// Number of tokens the backend tokenised this candidate into, after
|
||||
// any backend-specific normalisation (e.g. leading-space handling).
|
||||
int32 num_tokens = 4;
|
||||
}
|
||||
|
||||
message TokenLogProb {
|
||||
string token = 1;
|
||||
double log_prob = 2;
|
||||
}
|
||||
|
||||
message ScoreResponse {
|
||||
repeated CandidateScore candidates = 1;
|
||||
}
|
||||
|
||||
message RerankRequest {
|
||||
string query = 1;
|
||||
repeated string documents = 2;
|
||||
@@ -320,6 +428,25 @@ message ModelOptions {
|
||||
// applied verbatim to the backend's engine constructor (e.g. vLLM AsyncEngineArgs).
|
||||
// Unknown keys produce an error at LoadModel time.
|
||||
string EngineArgs = 73;
|
||||
|
||||
// Proxy carries the cloud-proxy backend's per-model configuration.
|
||||
// Empty for non-proxy backends.
|
||||
ProxyOptions Proxy = 74;
|
||||
}
|
||||
|
||||
// ProxyOptions configures the cloud-proxy backend. UpstreamURL and
|
||||
// Mode are always meaningful; Provider only matters in translate mode.
|
||||
// The two api_key_* fields are mutually exclusive and resolved by the
|
||||
// backend at LoadModel — core forwards the references rather than the
|
||||
// plaintext key.
|
||||
message ProxyOptions {
|
||||
string upstream_url = 1;
|
||||
string mode = 2;
|
||||
string provider = 3;
|
||||
string api_key_env = 4;
|
||||
string api_key_file = 5;
|
||||
string upstream_model = 6;
|
||||
int32 request_timeout_seconds = 7;
|
||||
}
|
||||
|
||||
message Result {
|
||||
@@ -768,6 +895,93 @@ message AudioTransformFrameResponse {
|
||||
int64 frame_index = 2;
|
||||
}
|
||||
|
||||
// === AudioToAudioStream messages =========================================
|
||||
//
|
||||
// Bidirectional stream between the LocalAI core and an any-to-any audio
|
||||
// model. The client opens the stream with a Config payload, then alternates
|
||||
// Frame (input audio) and Control (turn boundaries, function-call results,
|
||||
// session updates) payloads. The server streams back typed events: audio
|
||||
// frames carry PCM in `pcm`; transcript / tool-call deltas carry JSON in
|
||||
// `meta`; the stream ends with a `response.done` (success) or `error` event.
|
||||
|
||||
message AudioToAudioRequest {
|
||||
oneof payload {
|
||||
AudioToAudioConfig config = 1;
|
||||
AudioToAudioFrame frame = 2;
|
||||
AudioToAudioControl control = 3;
|
||||
}
|
||||
}
|
||||
|
||||
message AudioToAudioConfig {
|
||||
// PCM format for client→server audio. 0 => backend default
|
||||
// (16 kHz for the LFM2-Audio Conformer encoder).
|
||||
int32 input_sample_rate = 1;
|
||||
// Preferred server→client audio rate. 0 => backend default
|
||||
// (24 kHz for the LFM2-Audio vocoder).
|
||||
int32 output_sample_rate = 2;
|
||||
// Optional system prompt override. Empty => backend chooses based on
|
||||
// mode (e.g. "Respond with interleaved text and audio.").
|
||||
string system_prompt = 3;
|
||||
// Optional baked-voice id. Models that only ship a fixed set of
|
||||
// voices (e.g. LFM2-Audio: us_male/us_female/uk_male/uk_female) match
|
||||
// this against their voice table; an empty string keeps the default.
|
||||
string voice = 4;
|
||||
// JSON-encoded array of tool definitions in OpenAI Chat Completions
|
||||
// format. Empty => no tools.
|
||||
string tools = 5;
|
||||
// Free-form sampling / decoding parameters (temperature, top_k,
|
||||
// max_new_tokens, audio_top_k, etc).
|
||||
map<string, string> params = 6;
|
||||
// True => reset any session-scoped state before processing further
|
||||
// frames on this stream. The first Config implicitly resets.
|
||||
bool reset = 7;
|
||||
}
|
||||
|
||||
message AudioToAudioFrame {
|
||||
// Raw PCM s16le mono at config.input_sample_rate. Empty pcm + end_of_input
|
||||
// is a valid "user finished speaking" marker without trailing audio.
|
||||
bytes pcm = 1;
|
||||
// Marks the last frame of a user turn. The backend may begin emitting
|
||||
// a response immediately after seeing this.
|
||||
bool end_of_input = 2;
|
||||
}
|
||||
|
||||
message AudioToAudioControl {
|
||||
// Free-form control event names. Initial set:
|
||||
// "input_audio_buffer.commit" — user finished speaking
|
||||
// "response.cancel" — abort in-flight generation
|
||||
// "conversation.item.create" — inject a non-audio item (e.g.
|
||||
// function_call_output as JSON in
|
||||
// `payload`)
|
||||
// "session.update" — re-configure mid-stream
|
||||
string event = 1;
|
||||
// Event-specific JSON payload.
|
||||
bytes payload = 2;
|
||||
}
|
||||
|
||||
message AudioToAudioResponse {
|
||||
// Event identifies what this frame carries. Mirrors the OpenAI Realtime
|
||||
// API server-event names where applicable. Initial set:
|
||||
// "response.audio.delta"
|
||||
// "response.audio_transcript.delta"
|
||||
// "response.function_call_arguments.delta"
|
||||
// "response.function_call_arguments.done"
|
||||
// "response.done"
|
||||
// "error"
|
||||
string event = 1;
|
||||
// Populated when event = response.audio.delta.
|
||||
bytes pcm = 2;
|
||||
// Populated alongside pcm to identify its rate. 0 => same as the
|
||||
// session's negotiated output_sample_rate.
|
||||
int32 sample_rate = 3;
|
||||
// JSON payload for non-PCM events (transcript chunk, tool args, error
|
||||
// body).
|
||||
bytes meta = 4;
|
||||
// Monotonic per-stream counter, useful for client reordering and
|
||||
// debugging.
|
||||
int64 sequence = 5;
|
||||
}
|
||||
|
||||
message ModelMetadataResponse {
|
||||
bool supports_thinking = 1;
|
||||
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
|
||||
@@ -910,3 +1124,32 @@ message QuantizationStopRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
// ForwardHeader is one HTTP header on the request or response. Headers
|
||||
// like Authorization are typically injected by the backend (from the
|
||||
// resolved API key) rather than passed through from the client.
|
||||
message ForwardHeader {
|
||||
string name = 1;
|
||||
string value = 2;
|
||||
}
|
||||
|
||||
// ForwardRequest is a streamed HTTP request to the upstream. First
|
||||
// message carries path/method/headers; subsequent messages carry
|
||||
// body_chunk only. All fields except body_chunk are honoured on the
|
||||
// first message and ignored thereafter.
|
||||
message ForwardRequest {
|
||||
string path = 1; // e.g. "/v1/chat/completions" — appended to the model's upstream_url
|
||||
string method = 2; // usually "POST"
|
||||
repeated ForwardHeader headers = 3;
|
||||
bytes body_chunk = 4;
|
||||
}
|
||||
|
||||
// ForwardReply is a streamed HTTP response from the upstream. First
|
||||
// message carries status/headers; subsequent messages carry body_chunk
|
||||
// only. SSE responses arrive as a sequence of body_chunk frames; the
|
||||
// caller is responsible for any parsing.
|
||||
message ForwardReply {
|
||||
int32 status = 1;
|
||||
repeated ForwardHeader headers = 2;
|
||||
bytes body_chunk = 3;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=f8b4ed635d559b3a5b44bf2df6a77e21b3e9178f
|
||||
# Upstream pin lives below as DS4_VERSION?=f91c12b50a1448527c435c028bfc70d1b00f6c33
|
||||
# so the dependency bump script (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=f8b4ed635d559b3a5b44bf2df6a77e21b3e9178f
|
||||
DS4_VERSION?=f91c12b50a1448527c435c028bfc70d1b00f6c33
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=f9a93c37e2fc021760c3c1aa99cf74c73b7591a7
|
||||
IK_LLAMA_VERSION?=9f7ba245ab41e118f03aa8dd5134d18a81159d02
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=1ec7ba0c14f33f17e980daeeda5f35b225d41994
|
||||
LLAMA_VERSION?=549b9d84330c327e6791fa812a7d60c0cf63572e
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -32,7 +32,9 @@
|
||||
#include <grpcpp/health_check_service_interface.h>
|
||||
#include <grpcpp/security/server_credentials.h>
|
||||
#include <regex>
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
@@ -120,6 +122,40 @@ static std::string base64_encode_bytes(const unsigned char* data, size_t len) {
|
||||
|
||||
bool loaded_model; // TODO: add a mutex for this, though it only happens once, when loading the model
|
||||
|
||||
// Score bypasses the slot loop (see the comment on Score below) so it
|
||||
// must not run concurrently with any slot-loop RPC. These counters
|
||||
// are a defence-in-depth tripwire — ModelConfig.Validate already
|
||||
// rejects llama-cpp configs that mix score with chat/completion/
|
||||
// embeddings, so a healthy deployment never trips them. seq_cst is
|
||||
// load-bearing for the increment-then-check pattern below.
|
||||
static std::atomic<int> slot_loop_inflight{0};
|
||||
static std::atomic<int> score_inflight{0};
|
||||
|
||||
// Increment-then-check, not check-then-increment: two simultaneous
|
||||
// racers both observe the other's increment and both abort cleanly.
|
||||
// Reversed, both could see zero and proceed.
|
||||
struct conflict_guard {
|
||||
std::atomic<int>& self;
|
||||
conflict_guard(const char* rpc, std::atomic<int>& self_, std::atomic<int>& other, const char* other_name)
|
||||
: self(self_) {
|
||||
self.fetch_add(1, std::memory_order_seq_cst);
|
||||
int o = other.load(std::memory_order_seq_cst);
|
||||
if (o > 0) {
|
||||
fprintf(stderr,
|
||||
"FATAL: %s called with %s=%d. The llama-cpp backend cannot "
|
||||
"service Score and slot-loop RPCs concurrently — Score "
|
||||
"bypasses the slot loop and races the llama_context. Bind "
|
||||
"Score-using features to a model dedicated to scoring "
|
||||
"(known_usecases: [score] with no chat/completion/embeddings).\n",
|
||||
rpc, other_name, o);
|
||||
std::abort();
|
||||
}
|
||||
}
|
||||
~conflict_guard() {
|
||||
self.fetch_sub(1, std::memory_order_seq_cst);
|
||||
}
|
||||
};
|
||||
|
||||
static std::function<void(int)> shutdown_handler;
|
||||
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
|
||||
|
||||
@@ -450,6 +486,8 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// vector; the turboquant fork still uses the legacy scalar. The
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC macro is injected by
|
||||
// backend/cpp/turboquant/patch-grpc-server.sh for fork builds only.
|
||||
// Upstream renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE
|
||||
// in ggml-org/llama.cpp#22964; the fork still uses the old name.
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
||||
@@ -458,7 +496,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
const bool no_spec_type = params.speculative.types.empty() ||
|
||||
(params.speculative.types.size() == 1 && params.speculative.types[0] == COMMON_SPECULATIVE_TYPE_NONE);
|
||||
if (no_spec_type) {
|
||||
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT };
|
||||
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE };
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -514,16 +552,27 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.warmup = true;
|
||||
// no_op_offload: disable host tensor op offload (default: false)
|
||||
params.no_op_offload = false;
|
||||
// kv_unified: enable unified KV cache (default: false)
|
||||
params.kv_unified = false;
|
||||
// n_ctx_checkpoints: max context checkpoints per slot (default: 8)
|
||||
params.n_ctx_checkpoints = 8;
|
||||
|
||||
// llama memory fit fails if we don't provide a buffer for tensor overrides
|
||||
const size_t ntbo = llama_max_tensor_buft_overrides();
|
||||
while (params.tensor_buft_overrides.size() < ntbo) {
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
// kv_unified: enable unified KV cache. Upstream's server auto-enables this
|
||||
// when the slot count is auto (-np <0), bumping n_parallel to 4 alongside.
|
||||
// LocalAI keeps n_parallel=1 by default, which would skip that auto path
|
||||
// and leave kv_unified=false. We flip the default to true here so the
|
||||
// server-side prompt cache (cache_idle_slots) is actually usable on the
|
||||
// single-slot path that LocalAI ships with: without it, idle slots are
|
||||
// never persisted across requests and the prompt cache is dead weight.
|
||||
// Users can opt out with `options: [ "kv_unified:false" ]`.
|
||||
params.kv_unified = true;
|
||||
// n_ctx_checkpoints: max context checkpoints per slot. Match upstream's
|
||||
// default (32); the previous LocalAI-specific 8 was unnecessarily tight
|
||||
// and limits partial-prefix recovery without a clear memory rationale.
|
||||
params.n_ctx_checkpoints = 32;
|
||||
// cache_idle_slots: save and clear idle slot KV to the prompt cache on
|
||||
// task switch. Upstream default is true; the server auto-disables it if
|
||||
// kv_unified=false or cache_ram_mib=0, so flipping kv_unified above is
|
||||
// what actually unlocks it.
|
||||
params.cache_idle_slots = true;
|
||||
// checkpoint_every_nt: create a context checkpoint every N tokens during
|
||||
// prefill (-1 disables). Match upstream's default (8192).
|
||||
params.checkpoint_every_nt = 8192;
|
||||
|
||||
// decode options. Options are in the form optname:optvalue, or just optname for booleans.
|
||||
for (int i = 0; i < request->options_size(); i++) {
|
||||
@@ -682,9 +731,161 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
try {
|
||||
params.n_ctx_checkpoints = std::stoi(optval_str);
|
||||
} catch (const std::exception& e) {
|
||||
// If conversion fails, keep default value (8)
|
||||
// If conversion fails, keep default value (32)
|
||||
}
|
||||
}
|
||||
|
||||
// --- server-side idle-slot prompt cache toggle (upstream --cache-idle-slots) ---
|
||||
// Saves the slot's KV state into the host-side prompt cache on task
|
||||
// switch so a later request with the same prefix can warm-load it.
|
||||
// Auto-disabled by the server if kv_unified=false or cache_ram=0.
|
||||
} else if (!strcmp(optname, "cache_idle_slots") || !strcmp(optname, "idle_slots_cache")) {
|
||||
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||
params.cache_idle_slots = true;
|
||||
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||
params.cache_idle_slots = false;
|
||||
}
|
||||
|
||||
// --- prefill checkpoint cadence (upstream -cpent / --checkpoint-every-n-tokens) ---
|
||||
// -1 disables checkpointing during prefill.
|
||||
} else if (!strcmp(optname, "checkpoint_every_nt") || !strcmp(optname, "checkpoint_every_n_tokens")) {
|
||||
if (optval != NULL) {
|
||||
try {
|
||||
params.checkpoint_every_nt = std::stoi(optval_str);
|
||||
} catch (const std::exception& e) {
|
||||
// If conversion fails, keep default value (8192)
|
||||
}
|
||||
}
|
||||
|
||||
// --- physical batch size (upstream -ub / --ubatch-size) ---
|
||||
// Note: line ~482 already aliases n_ubatch to n_batch as a default; this
|
||||
// option lets users decouple the two (useful for embeddings/rerank).
|
||||
} else if (!strcmp(optname, "n_ubatch") || !strcmp(optname, "ubatch")) {
|
||||
if (optval != NULL) {
|
||||
try { params.n_ubatch = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
|
||||
// --- main-model batch threads (upstream -tb / --threads-batch) ---
|
||||
} else if (!strcmp(optname, "threads_batch") || !strcmp(optname, "n_threads_batch")) {
|
||||
if (optval != NULL) {
|
||||
try {
|
||||
int n = std::stoi(optval_str);
|
||||
if (n <= 0) n = (int)std::thread::hardware_concurrency();
|
||||
params.cpuparams_batch.n_threads = n;
|
||||
} catch (...) {}
|
||||
}
|
||||
|
||||
// --- pooling type for embeddings (upstream --pooling) ---
|
||||
} else if (!strcmp(optname, "pooling_type") || !strcmp(optname, "pooling")) {
|
||||
if (optval != NULL) {
|
||||
if (optval_str == "none") params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
else if (optval_str == "mean") params.pooling_type = LLAMA_POOLING_TYPE_MEAN;
|
||||
else if (optval_str == "cls") params.pooling_type = LLAMA_POOLING_TYPE_CLS;
|
||||
else if (optval_str == "last") params.pooling_type = LLAMA_POOLING_TYPE_LAST;
|
||||
else if (optval_str == "rank") params.pooling_type = LLAMA_POOLING_TYPE_RANK;
|
||||
// unknown values silently leave UNSPECIFIED (auto-detect)
|
||||
}
|
||||
|
||||
// --- llama log verbosity threshold (upstream -lv / --verbosity) ---
|
||||
} else if (!strcmp(optname, "verbosity")) {
|
||||
if (optval != NULL) {
|
||||
try { params.verbosity = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
|
||||
// --- O_DIRECT model loading (upstream --direct-io) ---
|
||||
} else if (!strcmp(optname, "direct_io") || !strcmp(optname, "use_direct_io")) {
|
||||
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||
params.use_direct_io = true;
|
||||
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||
params.use_direct_io = false;
|
||||
}
|
||||
|
||||
// --- embedding normalization (upstream --embd-normalize) ---
|
||||
// -1 none, 0 max-abs, 1 taxicab, 2 L2 (default), >2 p-norm
|
||||
} else if (!strcmp(optname, "embd_normalize") || !strcmp(optname, "embedding_normalize")) {
|
||||
if (optval != NULL) {
|
||||
try { params.embd_normalize = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
|
||||
// --- reasoning parser (upstream --reasoning-format) ---
|
||||
// Picks the parser for <think> blocks emitted by reasoning models.
|
||||
// none / auto / deepseek / deepseek-legacy
|
||||
} else if (!strcmp(optname, "reasoning_format")) {
|
||||
if (optval != NULL) {
|
||||
if (optval_str == "none") params.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
else if (optval_str == "auto") params.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
|
||||
else if (optval_str == "deepseek") params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
else if (optval_str == "deepseek-legacy" || optval_str == "deepseek_legacy")
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY;
|
||||
// unknown values silently keep the upstream default (DEEPSEEK)
|
||||
}
|
||||
|
||||
// --- reasoning budget (upstream --reasoning-budget) ---
|
||||
// -1 unlimited, 0 disabled, >0 token budget for thinking blocks.
|
||||
// Distinct from per-request `enable_thinking` (chat_template_kwargs).
|
||||
} else if (!strcmp(optname, "enable_reasoning") || !strcmp(optname, "reasoning_budget")) {
|
||||
if (optval != NULL) {
|
||||
try { params.enable_reasoning = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
|
||||
// --- prefill assistant turn (upstream --no-prefill-assistant) ---
|
||||
} else if (!strcmp(optname, "prefill_assistant")) {
|
||||
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||
params.prefill_assistant = true;
|
||||
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||
params.prefill_assistant = false;
|
||||
}
|
||||
|
||||
// --- mmproj GPU offload (upstream --no-mmproj-offload, inverted) ---
|
||||
} else if (!strcmp(optname, "mmproj_use_gpu") || !strcmp(optname, "mmproj_offload")) {
|
||||
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||
params.mmproj_use_gpu = true;
|
||||
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||
params.mmproj_use_gpu = false;
|
||||
}
|
||||
|
||||
// --- per-image vision token budget (upstream --image-min/max-tokens) ---
|
||||
} else if (!strcmp(optname, "image_min_tokens")) {
|
||||
if (optval != NULL) {
|
||||
try { params.image_min_tokens = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "image_max_tokens")) {
|
||||
if (optval != NULL) {
|
||||
try { params.image_max_tokens = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
|
||||
// --- main-model tensor buffer overrides (upstream --override-tensor) ---
|
||||
// Format: <tensor regex>=<buffer type>,<tensor regex>=<buffer type>,...
|
||||
// Mirrors the existing `draft_override_tensor` parser below.
|
||||
} else if (!strcmp(optname, "override_tensor") || !strcmp(optname, "tensor_buft_overrides")) {
|
||||
ggml_backend_load_all();
|
||||
std::map<std::string, ggml_backend_buffer_type_t> buft_list;
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
auto * buft = ggml_backend_dev_buffer_type(dev);
|
||||
if (buft) {
|
||||
buft_list[ggml_backend_buft_name(buft)] = buft;
|
||||
}
|
||||
}
|
||||
static std::list<std::string> override_names;
|
||||
std::string cur;
|
||||
auto flush = [&](const std::string & spec) {
|
||||
auto pos = spec.find('=');
|
||||
if (pos == std::string::npos) return;
|
||||
const std::string name = spec.substr(0, pos);
|
||||
const std::string type = spec.substr(pos + 1);
|
||||
auto it = buft_list.find(type);
|
||||
if (it == buft_list.end()) return; // unknown buffer type: ignore
|
||||
override_names.push_back(name);
|
||||
params.tensor_buft_overrides.push_back(
|
||||
{override_names.back().c_str(), it->second});
|
||||
};
|
||||
for (char c : optval_str) {
|
||||
if (c == ',') { if (!cur.empty()) { flush(cur); cur.clear(); } }
|
||||
else { cur.push_back(c); }
|
||||
}
|
||||
if (!cur.empty()) flush(cur);
|
||||
|
||||
// Speculative decoding options
|
||||
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
@@ -701,16 +902,27 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// Upstream switched to a vector of types (comma-separated for multi-type
|
||||
// chaining via common_speculative_types_from_names). We keep accepting a
|
||||
// single value here, but also tolerate comma-separated lists.
|
||||
//
|
||||
// ggml-org/llama.cpp#22964 also renamed the registered names from
|
||||
// underscore- to dash-separated form, and replaced the bare
|
||||
// `draft`/`eagle3` aliases with `draft-simple`/`draft-eagle3`. We
|
||||
// normalize each token here so existing model configs keep working.
|
||||
auto normalize_spec_name = [](std::string s) -> std::string {
|
||||
std::replace(s.begin(), s.end(), '_', '-');
|
||||
if (s == "draft") return "draft-simple";
|
||||
if (s == "eagle3") return "draft-eagle3";
|
||||
return s;
|
||||
};
|
||||
std::vector<std::string> names;
|
||||
std::string item;
|
||||
for (char c : optval_str) {
|
||||
if (c == ',') {
|
||||
if (!item.empty()) { names.push_back(item); item.clear(); }
|
||||
if (!item.empty()) { names.push_back(normalize_spec_name(item)); item.clear(); }
|
||||
} else {
|
||||
item.push_back(c);
|
||||
}
|
||||
}
|
||||
if (!item.empty()) names.push_back(item);
|
||||
if (!item.empty()) names.push_back(normalize_spec_name(item));
|
||||
auto parsed = common_speculative_types_from_names(names);
|
||||
if (!parsed.empty()) {
|
||||
params.speculative.types = parsed;
|
||||
@@ -937,6 +1149,20 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.kv_overrides.back().key[0] = 0;
|
||||
}
|
||||
|
||||
// tensor_buft_overrides sentinel termination (mirrors upstream common/arg.cpp).
|
||||
// Real entries are pushed during option parsing; here we pad/terminate so the
|
||||
// model loader sees back().pattern == nullptr (GGML_ASSERT at common.cpp:1543)
|
||||
// and so llama_params_fit has the placeholder slots it requires.
|
||||
{
|
||||
const size_t ntbo = llama_max_tensor_buft_overrides();
|
||||
while (params.tensor_buft_overrides.size() < ntbo) {
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
}
|
||||
if (!params.speculative.draft.tensor_buft_overrides.empty()) {
|
||||
params.speculative.draft.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
|
||||
// TODO: Add yarn
|
||||
|
||||
if (!request->tensorsplit().empty()) {
|
||||
@@ -1255,6 +1481,7 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("PredictStream", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
|
||||
@@ -2014,6 +2241,7 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("Predict", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
data["stream"] = false;
|
||||
@@ -2772,6 +3000,7 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("Embedding", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
body["stream"] = false;
|
||||
@@ -2794,7 +3023,9 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
int embd_normalize = 2; // default to Euclidean/L2 norm
|
||||
// Honor the load-time embd_normalize set via options:embd_normalize.
|
||||
// -1 none, 0 max-abs, 1 taxicab, 2 L2 (default), >2 p-norm.
|
||||
int embd_normalize = params_base.embd_normalize;
|
||||
// create and queue the task
|
||||
auto rd = ctx_server.get_response_reader();
|
||||
{
|
||||
@@ -2877,6 +3108,8 @@ public:
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
|
||||
}
|
||||
|
||||
conflict_guard guard("Rerank", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
|
||||
// Create and queue the task
|
||||
auto rd = ctx_server.get_response_reader();
|
||||
{
|
||||
@@ -2949,12 +3182,218 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
// Score returns the model's joint log-probability of each candidate
|
||||
// continuation given a shared prompt.
|
||||
//
|
||||
// WHY bypass the slot/task queue: upstream server_context exposes
|
||||
// get_llama_context as "main thread only" and the slot loop's
|
||||
// update_slots() owns the context whenever a task is in flight.
|
||||
// No public synchronization primitive is available — so Score is
|
||||
// unsafe to call concurrently with active generation through this
|
||||
// backend. In practice routing-classifier calls happen before the
|
||||
// request is routed to a generation backend, so the model used
|
||||
// for Score is typically idle. Concurrent Score calls are
|
||||
// serialised by a local mutex; KV-cache state is isolated behind
|
||||
// a dedicated sequence ID cleared between candidates.
|
||||
//
|
||||
// A patch to server-context.cpp that adds SERVER_TASK_TYPE_SCORE
|
||||
// and routes scoring through the slot loop would be the correct
|
||||
// long-term fix; tracked as a follow-up.
|
||||
//
|
||||
// Perf TODO (measured: ~450 ms warm for 3 candidates on Arch-
|
||||
// Router-1.5B Q4_K_M + Intel SYCL): the current loop re-decodes
|
||||
// `prompt + candidate` from scratch for every candidate, throwing
|
||||
// away the prompt's KV cache between iterations. A smarter
|
||||
// version would:
|
||||
// 1. Decode just the prompt once into score_seq_id.
|
||||
// 2. Snapshot/cp that sequence (llama_memory_seq_cp) into a
|
||||
// per-candidate sequence id.
|
||||
// 3. For each candidate, decode only its tokens onto the copy
|
||||
// (continuing from the saved prompt state), read logits.
|
||||
// 4. llama_memory_seq_rm the copy.
|
||||
// Estimated speedup: 3-candidate calls 450 ms -> ~150-200 ms,
|
||||
// 6-candidate calls 630 ms -> ~220 ms. Single source-file change,
|
||||
// no proto / Go-side changes needed. Worth doing once routing is
|
||||
// wired into the middleware and Score is on the hot path of every
|
||||
// chat request.
|
||||
grpc::Status Score(ServerContext* context, const backend::ScoreRequest* request, backend::ScoreResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
if (request->candidates_size() == 0) {
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "candidates must be non-empty");
|
||||
}
|
||||
|
||||
// Tripwire against the slot loop. Acquired before score_mutex
|
||||
// so it fires even when this Score is queued behind another.
|
||||
conflict_guard guard("Score", score_inflight, slot_loop_inflight, "slot_loop_inflight");
|
||||
|
||||
// Serialise concurrent Score calls. The slot loop is still
|
||||
// free to race with us — see the comment on Score above.
|
||||
static std::mutex score_mutex;
|
||||
std::lock_guard<std::mutex> score_lock(score_mutex);
|
||||
|
||||
llama_context * lctx = ctx_server.get_llama_context();
|
||||
if (lctx == nullptr) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "llama context unavailable (sleeping?)");
|
||||
}
|
||||
const llama_vocab * vocab = ctx_server.impl->vocab;
|
||||
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
|
||||
const int32_t n_ctx = llama_n_ctx(lctx);
|
||||
llama_memory_t mem = llama_get_memory(lctx);
|
||||
|
||||
// The KV-cache is sized to seq_to_stream.size() at load
|
||||
// (typically equal to n_slots, often 1). Sequence IDs must
|
||||
// be in [0, n_seq_max), so we can't pick a high-value
|
||||
// "private" ID — we have to share with the slot. We clear
|
||||
// the cache before AND after each candidate to keep
|
||||
// scoring isolated from whatever state the slot held, and
|
||||
// the static mutex above guarantees no other Score call is
|
||||
// racing in the meantime. The slot loop is still free to
|
||||
// race (see comment on this method) — Score must not run
|
||||
// concurrently with generation through this backend.
|
||||
const llama_seq_id score_seq_id = 0;
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
|
||||
// Tokenize the shared prompt once with add_special=true so
|
||||
// BOS is prepended when the model requires it. parse_special
|
||||
// keeps chat-template markers in the prompt intact.
|
||||
const std::string prompt = request->prompt();
|
||||
std::vector<llama_token> prompt_tokens = common_tokenize(vocab, prompt, /*add_special=*/true, /*parse_special=*/true);
|
||||
const int32_t prompt_len = (int32_t) prompt_tokens.size();
|
||||
|
||||
for (int ci = 0; ci < request->candidates_size(); ci++) {
|
||||
const std::string & candidate_text = request->candidates(ci);
|
||||
|
||||
// Re-tokenize prompt + candidate as a single string. BPE
|
||||
// merges across the boundary can shift the tokenization
|
||||
// versus tokenize(prompt) ++ tokenize(candidate), so we
|
||||
// find the divergence point against prompt_tokens.
|
||||
std::vector<llama_token> full_tokens = common_tokenize(vocab, prompt + candidate_text, /*add_special=*/true, /*parse_special=*/true);
|
||||
int32_t divergence = prompt_len;
|
||||
const int32_t min_len = std::min<int32_t>(prompt_len, (int32_t) full_tokens.size());
|
||||
for (int32_t i = 0; i < min_len; i++) {
|
||||
if (prompt_tokens[i] != full_tokens[i]) {
|
||||
divergence = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const int32_t cand_len = (int32_t) full_tokens.size() - divergence;
|
||||
backend::CandidateScore * cs = response->add_candidates();
|
||||
cs->set_num_tokens(cand_len);
|
||||
if (cand_len <= 0) {
|
||||
cs->set_log_prob(0.0);
|
||||
if (request->length_normalize()) {
|
||||
cs->set_length_normalized_log_prob(0.0);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (divergence < 1) {
|
||||
// Need at least one prior token (typically BOS) to
|
||||
// predict the first candidate token's logit. A model whose
|
||||
// tokeniser adds no BOS, given an empty prompt, lands here.
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
|
||||
"Score: prompt produced no leading tokens; need at least one (e.g. BOS) to predict candidate");
|
||||
}
|
||||
if ((int32_t) full_tokens.size() > n_ctx) {
|
||||
return grpc::Status(grpc::StatusCode::OUT_OF_RANGE,
|
||||
"Score: prompt+candidate exceeds context size (got " +
|
||||
std::to_string(full_tokens.size()) + ", n_ctx=" + std::to_string(n_ctx) + ")");
|
||||
}
|
||||
|
||||
// Build a batch covering the entire prompt+candidate. We
|
||||
// need logits at (divergence-1) onward — those are the
|
||||
// predictions for each candidate token.
|
||||
llama_batch batch = llama_batch_init((int32_t) full_tokens.size(), 0, 1);
|
||||
for (int32_t i = 0; i < (int32_t) full_tokens.size(); i++) {
|
||||
batch.token[i] = full_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = score_seq_id;
|
||||
// logits[i] is "do we want the prediction *for the
|
||||
// next token*, computed from this position?"
|
||||
// We want predictions for candidate tokens at
|
||||
// positions divergence .. full_tokens.size()-1, which
|
||||
// come from logits at positions (divergence-1) ..
|
||||
// (full_tokens.size()-2).
|
||||
bool need_logit = (i >= divergence - 1) && (i < (int32_t) full_tokens.size() - 1);
|
||||
batch.logits[i] = need_logit ? 1 : 0;
|
||||
}
|
||||
batch.n_tokens = (int32_t) full_tokens.size();
|
||||
|
||||
// Decode the batch. If decode fails (e.g. KV slot
|
||||
// exhaustion), surface as INTERNAL — the caller will
|
||||
// typically fall back to a sampling-based classifier.
|
||||
int decode_err = llama_decode(lctx, batch);
|
||||
if (decode_err != 0) {
|
||||
llama_batch_free(batch);
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL,
|
||||
"llama_decode failed during Score: " + std::to_string(decode_err));
|
||||
}
|
||||
|
||||
// Sum log-probabilities of the actual candidate tokens.
|
||||
double total_log_prob = 0.0;
|
||||
for (int32_t k = 0; k < cand_len; k++) {
|
||||
// The k-th candidate token sits at full_tokens index
|
||||
// (divergence + k). Its predicting logit is at batch
|
||||
// position (divergence + k - 1).
|
||||
int32_t logit_pos = divergence + k - 1;
|
||||
const float * logits = llama_get_logits_ith(lctx, logit_pos);
|
||||
if (logits == nullptr) {
|
||||
llama_batch_free(batch);
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL,
|
||||
"llama_get_logits_ith returned null at position " + std::to_string(logit_pos));
|
||||
}
|
||||
llama_token target_token = full_tokens[divergence + k];
|
||||
|
||||
// Compute log_softmax(logits)[target_token] with the
|
||||
// max-subtraction stability trick.
|
||||
float max_logit = logits[0];
|
||||
for (int32_t v = 1; v < n_vocab; v++) {
|
||||
if (logits[v] > max_logit) max_logit = logits[v];
|
||||
}
|
||||
double sum_exp = 0.0;
|
||||
for (int32_t v = 0; v < n_vocab; v++) {
|
||||
sum_exp += std::exp((double)(logits[v] - max_logit));
|
||||
}
|
||||
double token_log_prob = (double)(logits[target_token] - max_logit) - std::log(sum_exp);
|
||||
total_log_prob += token_log_prob;
|
||||
|
||||
if (request->include_token_logprobs()) {
|
||||
backend::TokenLogProb * tlp = cs->add_tokens();
|
||||
std::string piece = common_token_to_piece(lctx, target_token);
|
||||
tlp->set_token(piece);
|
||||
tlp->set_log_prob(token_log_prob);
|
||||
}
|
||||
}
|
||||
|
||||
cs->set_log_prob(total_log_prob);
|
||||
if (request->length_normalize() && cand_len > 0) {
|
||||
cs->set_length_normalized_log_prob(total_log_prob / (double) cand_len);
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
// Drop this candidate's KV-cache contribution so the next
|
||||
// candidate starts from a clean state. Without this, the
|
||||
// next decode would conflict at positions 0..N-1 for our
|
||||
// sequence ID.
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
}
|
||||
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("TokenizeString", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
|
||||
body["stream"] = false;
|
||||
|
||||
@@ -2976,6 +3415,8 @@ public:
|
||||
|
||||
grpc::Status GetMetrics(ServerContext* /*context*/, const backend::MetricsRequest* /*request*/, backend::MetricsResponse* response) override {
|
||||
|
||||
conflict_guard guard("GetMetrics", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
|
||||
// request slots data using task queue
|
||||
auto rd = ctx_server.get_response_reader();
|
||||
int task_id = rd.queue_tasks.get_new_id();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
|
||||
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
|
||||
TURBOQUANT_VERSION?=69d8e4be47243e83b3d0d71e932bc7aa61c644dc
|
||||
TURBOQUANT_VERSION?=5aeb2fdbe26cd4c534c6fa15de73cb5749bd0403
|
||||
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# acestep.cpp version
|
||||
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
|
||||
ACESTEP_CPP_VERSION?=e0c8d75a672fca5684c88c68dbf6d12f58754258
|
||||
ACESTEP_CPP_VERSION?=ed53caf164e4492a5620b2e3f2264629cf66da24
|
||||
SO_TARGET?=libgoacestepcpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -22,12 +22,11 @@
|
||||
#include <vector>
|
||||
|
||||
// Global model contexts (loaded once, reused across requests)
|
||||
static DiTGGML g_dit = {};
|
||||
static DiTGGMLConfig g_dit_cfg;
|
||||
static VAEGGML g_vae = {};
|
||||
static bool g_dit_loaded = false;
|
||||
static bool g_vae_loaded = false;
|
||||
static bool g_is_turbo = false;
|
||||
static DiTGGML g_dit = {};
|
||||
static VAEGGML g_vae = {};
|
||||
static bool g_dit_loaded = false;
|
||||
static bool g_vae_loaded = false;
|
||||
static bool g_is_turbo = false;
|
||||
|
||||
// Silence latent [15000, 64] — read once from DiT GGUF
|
||||
static std::vector<float> g_silence_full;
|
||||
@@ -72,10 +71,9 @@ int load_model(const char * lm_model_path, const char * text_encoder_path,
|
||||
g_text_enc_path = text_encoder_path;
|
||||
g_dit_path = dit_model_path;
|
||||
|
||||
// Load DiT model
|
||||
// Load DiT model (backend init + config are handled inside dit_ggml_load)
|
||||
fprintf(stderr, "[acestep-cpp] Loading DiT from %s\n", dit_model_path);
|
||||
dit_ggml_init_backend(&g_dit);
|
||||
if (!dit_ggml_load(&g_dit, dit_model_path, g_dit_cfg, nullptr, 0.0f)) {
|
||||
if (!dit_ggml_load(&g_dit, dit_model_path)) {
|
||||
fprintf(stderr, "[acestep-cpp] FATAL: failed to load DiT from %s\n", dit_model_path);
|
||||
return 1;
|
||||
}
|
||||
@@ -149,16 +147,16 @@ int generate_music(const char * caption, const char * lyrics, int bpm,
|
||||
|
||||
// Compute T (latent frames at 25Hz)
|
||||
int T = (int)(duration * FRAMES_PER_SECOND);
|
||||
T = ((T + g_dit_cfg.patch_size - 1) / g_dit_cfg.patch_size) * g_dit_cfg.patch_size;
|
||||
int S = T / g_dit_cfg.patch_size;
|
||||
T = ((T + g_dit.cfg.patch_size - 1) / g_dit.cfg.patch_size) * g_dit.cfg.patch_size;
|
||||
int S = T / g_dit.cfg.patch_size;
|
||||
|
||||
if (T > 15000) {
|
||||
fprintf(stderr, "[acestep-cpp] ERROR: T=%d exceeds max 15000\n", T);
|
||||
return 2;
|
||||
}
|
||||
|
||||
int Oc = g_dit_cfg.out_channels; // 64
|
||||
int ctx_ch = g_dit_cfg.in_channels - Oc; // 128
|
||||
int Oc = g_dit.cfg.out_channels; // 64
|
||||
int ctx_ch = g_dit.cfg.in_channels - Oc; // 128
|
||||
|
||||
fprintf(stderr, "[acestep-cpp] T=%d, S=%d, duration=%.1fs, seed=%d\n", T, S, duration, seed);
|
||||
|
||||
@@ -191,9 +189,8 @@ int generate_music(const char * caption, const char * lyrics, int bpm,
|
||||
|
||||
fprintf(stderr, "[acestep-cpp] caption: %d tokens, lyrics: %d tokens\n", S_text, S_lyric);
|
||||
|
||||
// 4. Text encoder forward
|
||||
// 4. Text encoder forward (backend init handled inside qwen3_load_text_encoder)
|
||||
Qwen3GGML text_enc = {};
|
||||
qwen3_init_backend(&text_enc);
|
||||
if (!qwen3_load_text_encoder(&text_enc, g_text_enc_path.c_str())) {
|
||||
fprintf(stderr, "[acestep-cpp] FATAL: failed to load text encoder\n");
|
||||
return 4;
|
||||
@@ -209,9 +206,8 @@ int generate_music(const char * caption, const char * lyrics, int bpm,
|
||||
std::vector<float> lyric_embed(H_text * S_lyric);
|
||||
qwen3_embed_lookup(&text_enc, lyric_ids.data(), S_lyric, lyric_embed.data());
|
||||
|
||||
// 6. Condition encoder
|
||||
// 6. Condition encoder (backend init handled inside cond_ggml_load)
|
||||
CondGGML cond = {};
|
||||
cond_ggml_init_backend(&cond);
|
||||
if (!cond_ggml_load(&cond, g_dit_path.c_str())) {
|
||||
fprintf(stderr, "[acestep-cpp] FATAL: failed to load condition encoder\n");
|
||||
qwen3_free(&text_enc);
|
||||
|
||||
12
backend/go/cloud-proxy/Makefile
Normal file
12
backend/go/cloud-proxy/Makefile
Normal file
@@ -0,0 +1,12 @@
|
||||
GOCMD=go

# Every target here is a command, not a file to be freshness-checked:
#   - package.sh creates a directory called "package", which would make a
#     non-phony `package` target look permanently up to date;
#   - `cloud-proxy` lists no prerequisites, so without .PHONY an existing
#     binary is never rebuilt when the sources change (go build has its own
#     cache, so always invoking it is cheap).
.PHONY: cloud-proxy package build clean

# Build the backend binary with cgo disabled.
cloud-proxy:
	CGO_ENABLED=0 $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o cloud-proxy ./

# Copy the binary and run.sh into ./package (see package.sh).
package:
	bash package.sh

build: cloud-proxy package

clean:
	rm -f cloud-proxy
|
||||
16
backend/go/cloud-proxy/cloud_proxy_suite_test.go
Normal file
16
backend/go/cloud-proxy/cloud_proxy_suite_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestCloudProxySpecs is the Ginkgo bootstrap for this package: it routes
// Gomega assertion failures to Ginkgo's Fail handler and then runs every
// registered spec under the suite name "cloud-proxy specs".
//
// The other Test* functions in this package use raw testing.T and run
// independently; they coexist with Ginkgo specs registered via
// Describe / Context.
func TestCloudProxySpecs(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "cloud-proxy specs")
}
|
||||
39
backend/go/cloud-proxy/main.go
Normal file
39
backend/go/cloud-proxy/main.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package main
|
||||
|
||||
// cloud-proxy is a LocalAI backend that forwards request traffic to an
|
||||
// external HTTP provider (OpenAI, Anthropic, etc.). Two modes:
|
||||
//
|
||||
// - passthrough: serves the Forward RPC; the client wire format is
|
||||
// preserved end-to-end, no translation.
|
||||
// - translate: serves Predict/PredictStream; the backend converts
|
||||
// internal proto to the provider's wire format. (Phases 5–6.)
|
||||
//
|
||||
// LoadModel reads UpstreamURL/Mode/Provider/key references from
|
||||
// ProxyOptions and resolves the API key once at load time.
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/xlog"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
var addr = flag.String("addr", "localhost:50051", "the address to listen on")
|
||||
|
||||
func main() {
|
||||
// xlog's default handler emits ANSI color codes; that's fine for an
|
||||
// interactive shell but unreadable when the backend's stdout is
|
||||
// captured by LocalAI and tee'd to a log file. Force plain text when
|
||||
// LOCALAI_LOG_FORMAT is unset and stdout isn't a terminal.
|
||||
format := os.Getenv("LOCALAI_LOG_FORMAT")
|
||||
if format == "" && !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
format = xlog.TextFormat
|
||||
}
|
||||
xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(os.Getenv("LOCALAI_LOG_LEVEL")), format))
|
||||
flag.Parse()
|
||||
if err := grpc.StartServer(*addr, NewCloudProxy()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
13
backend/go/cloud-proxy/package.sh
Executable file
13
backend/go/cloud-proxy/package.sh
Executable file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash

# Script to copy the cloud-proxy binary into the package dir for the
# final Dockerfile stage. Mirrors backend/go/local-store/package.sh —
# no extra runtime libs needed since the backend is pure Go.

# Abort on the first failing command so a missing binary or run.sh is
# reported instead of producing a half-populated package directory.
set -e

# Directory containing this script. Every expansion is quoted so that a
# checkout path with spaces or glob characters is handled correctly.
CURDIR=$(dirname "$(realpath "$0")")

mkdir -p "$CURDIR/package"
cp -avf "$CURDIR/cloud-proxy" "$CURDIR/package/"
cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
270
backend/go/cloud-proxy/passthrough_edge_test.go
Normal file
270
backend/go/cloud-proxy/passthrough_edge_test.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Specs for composeURL(upstream, reqPath): the table pins how the
// configured upstream URL and the per-request path are combined, and a
// separate case checks that an unparseable upstream yields an error.
var _ = Describe("composeURL", func() {
	// Upstream URL convention: gallery configs put the canonical path
	// in upstream_url, so per-request Path is ignored. A bare-host
	// upstream_url accepts the per-request path.
	DescribeTable("path resolution",
		// Each entry supplies (upstream, request path, expected URL).
		func(upstream, reqPath, want string) {
			got, err := composeURL(upstream, reqPath)
			Expect(err).NotTo(HaveOccurred())
			Expect(got).To(Equal(want))
		},
		Entry("full path wins", "https://api.openai.com/v1/chat/completions", "/v1/something-else", "https://api.openai.com/v1/chat/completions"),
		Entry("bare host accepts path", "https://api.openai.com", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"),
		Entry("root slash treated as bare", "https://api.openai.com/", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"),
		Entry("bare host + empty path", "https://api.openai.com", "", "https://api.openai.com"),
	)

	It("returns an error on invalid upstream URL", func() {
		// "://garbage" has no scheme before the colon.
		_, err := composeURL("://garbage", "")
		Expect(err).To(HaveOccurred())
	})
})
|
||||
|
||||
// Specs for applyAuthHeader(req, provider, key): each case builds a bare
// request, applies the header for one provider value and checks which
// credential headers end up set — and which must stay absent.
var _ = Describe("applyAuthHeader", func() {
	It("sets x-api-key and anthropic-version for Anthropic, no Authorization", func() {
		// NewRequest cannot fail for this constant method/URL, so the
		// error is deliberately discarded here and in the cases below.
		req, _ := http.NewRequest("POST", "https://example.com", nil)
		applyAuthHeader(req, providerAnthropic, "ant-key")
		Expect(req.Header.Get("x-api-key")).To(Equal("ant-key"))
		Expect(req.Header.Get("anthropic-version")).NotTo(BeEmpty())
		Expect(req.Header.Get("Authorization")).To(BeEmpty(), "Authorization must not leak on Anthropic backend")
	})

	It("sets Bearer Authorization for OpenAI, no x-api-key", func() {
		req, _ := http.NewRequest("POST", "https://example.com", nil)
		applyAuthHeader(req, providerOpenAI, "sk-key")
		Expect(req.Header.Get("Authorization")).To(Equal("Bearer sk-key"))
		Expect(req.Header.Get("x-api-key")).To(BeEmpty(), "x-api-key must not leak on OpenAI backend")
	})

	It("defaults to Bearer when provider is empty", func() {
		// Passthrough mode often has provider == "" because the operator
		// doesn't claim a specific upstream wire format. Most providers
		// (including OpenAI-compatible ones) accept Bearer, so default to it.
		req, _ := http.NewRequest("POST", "https://example.com", nil)
		applyAuthHeader(req, "", "some-key")
		Expect(req.Header.Get("Authorization")).To(Equal("Bearer some-key"))
	})

	It("preserves an existing anthropic-version header", func() {
		// If the client supplied anthropic-version (rare but legitimate
		// for an upstream pinned to a specific date), the proxy must not
		// clobber it.
		req, _ := http.NewRequest("POST", "https://example.com", nil)
		req.Header.Set("anthropic-version", "2024-10-01")
		applyAuthHeader(req, providerAnthropic, "k")
		Expect(req.Header.Get("anthropic-version")).To(Equal("2024-10-01"))
	})
})
|
||||
|
||||
// Specs for isHopByHopHeader(name): a table of header names and the
// classification the proxy is expected to give each one.
//
// Note that the table also expects Host and Content-Length to be reported
// as hop-by-hop. They are listed here as headers this predicate must flag,
// alongside the connection-level ones.
var _ = Describe("isHopByHopHeader", func() {
	DescribeTable("hop-by-hop classification",
		// Each entry supplies (header name, expected result).
		func(header string, want bool) {
			Expect(isHopByHopHeader(header)).To(Equal(want))
		},
		Entry("Connection is hop-by-hop", "Connection", true),
		Entry("Keep-Alive is hop-by-hop", "Keep-Alive", true),
		Entry("Proxy-Connection is hop-by-hop", "Proxy-Connection", true),
		Entry("Transfer-Encoding is hop-by-hop", "Transfer-Encoding", true),
		Entry("TE is hop-by-hop", "TE", true),
		Entry("Trailer is hop-by-hop", "Trailer", true),
		Entry("Upgrade is hop-by-hop", "Upgrade", true),
		Entry("Host is hop-by-hop", "Host", true),
		Entry("Content-Length is hop-by-hop", "Content-Length", true),
		// Case-insensitive — RFC 7230 doesn't constrain header case.
		Entry("lowercase connection is hop-by-hop", "connection", true),
		Entry("uppercase HOST is hop-by-hop", "HOST", true),
		// Non hop-by-hop — must NOT be stripped.
		Entry("Authorization is end-to-end", "Authorization", false),
		Entry("Content-Type is end-to-end", "Content-Type", false),
		Entry("Accept is end-to-end", "Accept", false),
		Entry("X-Custom is end-to-end", "X-Custom", false),
	)
})
|
||||
|
||||
var _ = Describe("Forward", func() {
|
||||
It("strips hop-by-hop and Connection headers before upstream, preserves custom headers", func() {
	// The upstream handler records what actually arrived. Each channel has
	// room for one value, so the handler never blocks on the single
	// request this spec sends.
	gotConnection := make(chan string, 1)
	gotXCustom := make(chan string, 1)
	gotHost := make(chan string, 1)
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotConnection <- r.Header.Get("Connection")
		gotXCustom <- r.Header.Get("X-Custom")
		// For server requests net/http promotes the Host header to r.Host
		// and removes it from r.Header, so r.Header.Get("Host") is always
		// empty and would make the spoofing assertion below vacuous.
		gotHost <- r.Host
		w.WriteHeader(http.StatusOK)
	}))
	defer upstream.Close()

	cp := NewCloudProxy()
	Expect(cp.Load(&pb.ModelOptions{
		Proxy: &pb.ProxyOptions{
			UpstreamUrl: upstream.URL,
			Mode:        modePassthrough,
		},
	})).To(Succeed())

	addr := "test://forward-hopbyhop"
	grpc.Provide(addr, cp)
	c := grpc.NewClient(addr, true, nil, false)
	stream, err := c.Forward(context.Background())
	Expect(err).NotTo(HaveOccurred())
	Expect(stream.Send(&pb.ForwardRequest{
		Path:   "/v1/chat/completions",
		Method: "POST",
		Headers: []*pb.ForwardHeader{
			{Name: "Connection", Value: "keep-alive"},
			{Name: "Host", Value: "spoofed.example.com"},
			{Name: "X-Custom", Value: "preserved"},
		},
	})).To(Succeed())
	Expect(stream.CloseSend()).To(Succeed())

	// Drain the response stream; only what the upstream observed matters
	// here. Any error — io.EOF included — ends the stream.
	_, _ = stream.Recv()
	for {
		if _, err := stream.Recv(); err != nil {
			break
		}
	}

	Expect(<-gotConnection).To(BeEmpty(), "Connection must not leak to upstream")
	Expect(<-gotHost).NotTo(Equal("spoofed.example.com"), "Host header must not be spoofed through")
	Expect(<-gotXCustom).To(Equal("preserved"), "X-Custom header must survive")
})
|
||||
|
||||
It("replaces caller-supplied Authorization with the configured key", func() {
|
||||
// The proxy must overwrite a client-supplied Authorization header
|
||||
// so a downstream caller can't smuggle stale or wrong credentials.
|
||||
gotAuth := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuth <- r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
GinkgoT().Setenv("CLOUD_PROXY_AUTH_REPLACE_KEY", "sk-real")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_AUTH_REPLACE_KEY",
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-replaces-auth"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{
|
||||
// Client-supplied Authorization with the wrong scheme / key.
|
||||
{Name: "Authorization", Value: "Basic Zm9vOmJhcg=="},
|
||||
},
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
_, _ = stream.Recv()
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
Expect(<-gotAuth).To(Equal("Bearer sk-real"), "caller-supplied Basic header must be replaced")
|
||||
})
|
||||
|
||||
It("handles concurrent calls without interference", func() {
	// CloudProxy explicitly omits base.SingleThread — independent
	// Forward streams must not block each other or leak state.
	//
	// The upstream is an echo server: it writes the request body back, so
	// each stream can verify it received its own payload and nobody
	// else's.
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, _ := io.ReadAll(r.Body)
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write(body)
	}))
	defer upstream.Close()

	cp := NewCloudProxy()
	Expect(cp.Load(&pb.ModelOptions{
		Proxy: &pb.ProxyOptions{
			UpstreamUrl: upstream.URL,
			Mode:        modePassthrough,
		},
	})).To(Succeed())
	addr := "test://forward-concurrent"
	grpc.Provide(addr, cp)
	c := grpc.NewClient(addr, true, nil, false)

	const N = 8
	var wg sync.WaitGroup
	// Buffered to N: every goroutine reports at most one error, so a send
	// never blocks even though nothing reads until after wg.Wait().
	errs := make(chan error, N)
	for i := 0; i < N; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			stream, err := c.Forward(context.Background())
			if err != nil {
				errs <- err
				return
			}
			// Distinct payload per goroutine: "request-A" .. "request-H".
			payload := "request-" + string(rune('A'+idx))
			if err := stream.Send(&pb.ForwardRequest{
				Path:      "/v1/chat/completions",
				Method:    "POST",
				BodyChunk: []byte(payload),
			}); err != nil {
				errs <- err
				return
			}
			_ = stream.CloseSend()
			// The first message is discarded; presumably it carries the
			// response head rather than body bytes — confirm against the
			// Forward protocol.
			_, _ = stream.Recv()
			var body []byte
			for {
				r, err := stream.Recv()
				if errors.Is(err, io.EOF) {
					break
				}
				if err != nil {
					errs <- err
					return
				}
				body = append(body, r.GetBodyChunk()...)
			}
			if string(body) != payload {
				errs <- &echoMismatch{want: payload, got: string(body)}
			}
		}(i)
	}
	wg.Wait()
	// All senders are done, so closing lets the range below terminate.
	close(errs)
	var collected []error
	for err := range errs {
		collected = append(collected, err)
	}
	Expect(collected).To(BeEmpty(), "no concurrent Forward call should fail")
})
|
||||
})
|
||||
|
||||
// echoMismatch is the error reported when the body echoed back by the
// upstream differs from the payload that was sent.
type echoMismatch struct {
	want string
	got  string
}

// Error renders both values Go-quoted so empty or non-printable bodies
// remain visible in the failure output.
func (e *echoMismatch) Error() string {
	msg := []byte("echo mismatch: want ")
	msg = strconv.AppendQuote(msg, e.want)
	msg = append(msg, " got "...)
	msg = strconv.AppendQuote(msg, e.got)
	return string(msg)
}
|
||||
508
backend/go/cloud-proxy/provider_anthropic.go
Normal file
508
backend/go/cloud-proxy/provider_anthropic.go
Normal file
@@ -0,0 +1,508 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// Anthropic Messages API wire-format types. Narrowed to what translate
|
||||
// mode preserves through the Reply proto: text + tool_use blocks +
|
||||
// usage tokens. Image blocks, prompt caching, metadata, and stop
|
||||
// sequence metadata are not modelled — passthrough mode covers those.
|
||||
//
|
||||
// Notable differences from OpenAI:
|
||||
// - max_tokens is REQUIRED. Anthropic 400s without it.
|
||||
// - Roles are user/assistant only — system messages move to a
|
||||
// top-level `system` string field.
|
||||
// - Streaming SSE uses event: lines alongside data: lines. The
|
||||
// events we care about: content_block_start (carries tool_use
|
||||
// init: id + name), content_block_delta (text_delta with text;
|
||||
// input_json_delta with partial_json for tool arguments), and
|
||||
// message_stop (terminates the stream). Others are ignored.
|
||||
|
||||
// anthropicRequest is the JSON body marshalled for the upstream request.
// model, max_tokens and messages are always emitted; every other field
// carries omitempty and is left out of the body when unset.
type anthropicRequest struct {
	Model     string `json:"model"`
	MaxTokens int32  `json:"max_tokens"`
	// System is the top-level system prompt (system-role messages are
	// not sent inside Messages).
	System   string             `json:"system,omitempty"`
	Messages []anthropicMessage `json:"messages"`
	Stream   bool               `json:"stream,omitempty"`
	// Pointers so that "not set" can be told apart from an explicit value.
	Temperature   *float64             `json:"temperature,omitempty"`
	TopP          *float64             `json:"top_p,omitempty"`
	StopSequences []string             `json:"stop_sequences,omitempty"`
	Tools         []anthropicTool      `json:"tools,omitempty"`
	ToolChoice    *anthropicToolChoice `json:"tool_choice,omitempty"`
}
|
||||
|
||||
// anthropicMessage is one conversation turn in the request body.
//
// Content is `any` because Anthropic accepts a bare string OR a
// list of content blocks. Use the string form for plain user/
// assistant turns; switch to []anthropicContentBlock when the
// turn needs tool_use (assistant) or tool_result (user) blocks.
type anthropicMessage struct {
	Role    string `json:"role"`
	Content any    `json:"content"`
}
|
||||
|
||||
// anthropicTool is one entry of the request's `tools` array.
type anthropicTool struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	// InputSchema is kept as raw JSON and has no omitempty, so it is
	// always present in the marshalled tool.
	InputSchema json.RawMessage `json:"input_schema"`
}
|
||||
|
||||
// anthropicToolChoice mirrors the four shapes Anthropic accepts:
// {"type":"auto"} | {"type":"any"} | {"type":"tool","name":"X"} |
// {"type":"none"} (newer models). OpenAI's "auto"/"none"/
// "required"/{"function":{"name":"X"}} all map here.
type anthropicToolChoice struct {
	Type string `json:"type"`
	// Name is only meaningful for Type "tool"; omitted otherwise.
	Name string `json:"name,omitempty"`
}
|
||||
|
||||
// anthropicContentBlock is the union shape used both for response
// blocks (text/tool_use we read off the wire) and outbound request
// blocks (tool_use/tool_result we emit in the conversation history).
// Anthropic encodes tool calls inline rather than as a separate field,
// so we walk Content[] looking for type=="tool_use" on responses and
// produce equivalent blocks when serialising prior-turn tool calls.
//
// Type discriminates; all other fields carry omitempty, so only the ones
// populated for a given block type appear in the JSON.
type anthropicContentBlock struct {
	Type string `json:"type"`
	// Text block payload.
	Text string `json:"text,omitempty"`
	// tool_use block fields: call id, tool name and raw JSON arguments.
	ID    string          `json:"id,omitempty"`
	Name  string          `json:"name,omitempty"`
	Input json.RawMessage `json:"input,omitempty"`
	// Tool-result block fields. tool_result uses `content` (not
	// `text`) and pairs with `tool_use_id`; modelling them as
	// distinct fields avoids ambiguity at marshal time.
	ToolUseID     string `json:"tool_use_id,omitempty"`
	ResultContent string `json:"content,omitempty"`
}
|
||||
|
||||
// anthropicResponse is the decoded non-streaming response body. It is
// also the type of the `message` field on stream events.
type anthropicResponse struct {
	ID   string `json:"id"`
	Type string `json:"type"`
	Role string `json:"role"`
	// Content holds the response blocks; callers switch on each block's
	// Type.
	Content []anthropicContentBlock `json:"content"`
	Model   string                  `json:"model"`
	// Usage is a pointer so a response without usage data decodes to nil.
	Usage *anthropicUsage `json:"usage,omitempty"`
}
|
||||
|
||||
// anthropicUsage carries the token counts decoded from the `usage`
// object (`input_tokens` / `output_tokens`).
type anthropicUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}
|
||||
|
||||
// anthropicStreamEvent is the union shape used for every event type we
// process. Type discriminates; only the matching fields are populated.
// content_block_start carries ContentBlock (with id/name for tool_use);
// content_block_delta carries Delta (text or partial_json).
type anthropicStreamEvent struct {
	Type string `json:"type"`
	// Index is the content-block index the event refers to.
	Index        int                    `json:"index,omitempty"`
	ContentBlock *anthropicContentBlock `json:"content_block,omitempty"`
	Delta        *anthropicStreamDelta  `json:"delta,omitempty"`
	Message      *anthropicResponse     `json:"message,omitempty"`
	Usage        *anthropicUsage        `json:"usage,omitempty"`
}
|
||||
|
||||
// anthropicStreamDelta is the `delta` object of a stream event. Type
// selects the payload: Text for text deltas, PartialJSON for the
// incremental tool-argument JSON fragments.
type anthropicStreamDelta struct {
	Type        string `json:"type,omitempty"`
	Text        string `json:"text,omitempty"`
	PartialJSON string `json:"partial_json,omitempty"`
}
|
||||
|
||||
// Anthropic requires max_tokens. If the caller didn't set it, use a
// generous-but-bounded default so the request doesn't 400.
const anthropicDefaultMaxTokens int32 = 4096

// anthropicToolChoiceNone is the tool_choice type meaning "do not call
// tools"; it is what OpenAI's "none" is translated to.
const anthropicToolChoiceNone = "none"

// Reused JSON-Schema defaults for malformed inputs. Anthropic requires
// input_schema to be a JSON object and tool_use.input to be a JSON
// object; clients that omit them must not 400 the entire request.
var (
	// emptyJSONObject stands in for missing tool_use input.
	emptyJSONObject = json.RawMessage(`{}`)
	// emptyObjectSchema stands in for a missing tool input_schema.
	emptyObjectSchema = json.RawMessage(`{"type":"object","properties":{}}`)
)
|
||||
|
||||
func buildAnthropicRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) {
|
||||
req := anthropicRequest{
|
||||
Model: modelName(cfg, opts),
|
||||
MaxTokens: opts.GetTokens(),
|
||||
Stream: stream,
|
||||
StopSequences: opts.GetStopPrompts(),
|
||||
}
|
||||
if req.MaxTokens <= 0 {
|
||||
req.MaxTokens = anthropicDefaultMaxTokens
|
||||
}
|
||||
// Newer Anthropic models 400 when both temperature and top_p are
|
||||
// set ("`temperature` and `top_p` cannot both be specified for
|
||||
// this model. Please use only one.") even though their docs only
|
||||
// "recommend" picking one. The OpenAI-compatible chat UI almost
|
||||
// always sends both with default values, so prefer temperature
|
||||
// and drop top_p when both are present.
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
} else if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
|
||||
req.Tools = convertOpenAITools(opts.GetTools())
|
||||
req.ToolChoice = convertOpenAIToolChoice(opts.GetToolChoice())
|
||||
// Anthropic rejects tool_choice without tools and older models
|
||||
// don't accept {"type":"none"} — collapse to a no-tools request.
|
||||
if req.ToolChoice != nil && req.ToolChoice.Type == anthropicToolChoiceNone {
|
||||
req.Tools, req.ToolChoice = nil, nil
|
||||
}
|
||||
|
||||
var systemParts []string
|
||||
for _, m := range opts.GetMessages() {
|
||||
role := m.GetRole()
|
||||
if role == "system" {
|
||||
if c := m.GetContent(); c != "" {
|
||||
systemParts = append(systemParts, c)
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch role {
|
||||
case "user":
|
||||
req.Messages = append(req.Messages, anthropicMessage{
|
||||
Role: "user",
|
||||
Content: m.GetContent(),
|
||||
})
|
||||
case "assistant":
|
||||
if blocks := assistantBlocks(m); blocks != nil {
|
||||
req.Messages = append(req.Messages, anthropicMessage{Role: "assistant", Content: blocks})
|
||||
continue
|
||||
}
|
||||
req.Messages = append(req.Messages, anthropicMessage{
|
||||
Role: "assistant",
|
||||
Content: m.GetContent(),
|
||||
})
|
||||
case "tool", "function":
|
||||
req.Messages = appendToolResult(req.Messages, anthropicContentBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: m.GetToolCallId(),
|
||||
ResultContent: m.GetContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
req.System = strings.Join(systemParts, "\n\n")
|
||||
|
||||
if len(req.Messages) == 0 && opts.GetPrompt() != "" {
|
||||
req.Messages = []anthropicMessage{{Role: "user", Content: opts.GetPrompt()}}
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
// appendToolResult appends a tool_result block as a user message,
|
||||
// merging into a preceding user message that already carries blocks.
|
||||
// Anthropic concatenates consecutive same-role messages on its end,
|
||||
// but explicit merging keeps the body smaller and the conversation
|
||||
// strictly alternating — which some upstream filters require.
|
||||
func appendToolResult(msgs []anthropicMessage, block anthropicContentBlock) []anthropicMessage {
|
||||
if n := len(msgs); n > 0 && msgs[n-1].Role == "user" {
|
||||
if existing, ok := msgs[n-1].Content.([]anthropicContentBlock); ok {
|
||||
msgs[n-1].Content = append(existing, block)
|
||||
return msgs
|
||||
}
|
||||
}
|
||||
return append(msgs, anthropicMessage{
|
||||
Role: "user",
|
||||
Content: []anthropicContentBlock{block},
|
||||
})
|
||||
}
|
||||
|
||||
func convertOpenAITools(toolsJSON string) []anthropicTool {
|
||||
if toolsJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var raw []openAITool
|
||||
if err := json.Unmarshal([]byte(toolsJSON), &raw); err != nil {
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unparseable tools JSON, dropping", "error", err)
|
||||
return nil
|
||||
}
|
||||
tools := make([]anthropicTool, 0, len(raw))
|
||||
for _, t := range raw {
|
||||
if t.Function.Name == "" {
|
||||
continue
|
||||
}
|
||||
schema := t.Function.Parameters
|
||||
if len(schema) == 0 {
|
||||
schema = emptyObjectSchema
|
||||
}
|
||||
tools = append(tools, anthropicTool{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
InputSchema: schema,
|
||||
})
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// convertOpenAIToolChoice accepts the spec form
|
||||
// ({type:function, function:{name:X}}) and the flat legacy form
|
||||
// ({type:function, name:X}) some clients send. Unknown object shapes
|
||||
// are warned and dropped rather than silently treated as auto.
|
||||
func convertOpenAIToolChoice(toolChoiceJSON string) *anthropicToolChoice {
|
||||
if toolChoiceJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var asString string
|
||||
if err := json.Unmarshal([]byte(toolChoiceJSON), &asString); err == nil {
|
||||
switch asString {
|
||||
case "auto":
|
||||
return &anthropicToolChoice{Type: "auto"}
|
||||
case "none":
|
||||
return &anthropicToolChoice{Type: anthropicToolChoiceNone}
|
||||
case "required":
|
||||
return &anthropicToolChoice{Type: "any"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var asObj struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"function"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(toolChoiceJSON), &asObj); err != nil {
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unparseable tool_choice, dropping", "error", err)
|
||||
return nil
|
||||
}
|
||||
if name := asObj.Function.Name; name != "" {
|
||||
return &anthropicToolChoice{Type: "tool", Name: name}
|
||||
}
|
||||
if asObj.Name != "" {
|
||||
return &anthropicToolChoice{Type: "tool", Name: asObj.Name}
|
||||
}
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unrecognised tool_choice shape, dropping", "shape", toolChoiceJSON)
|
||||
return nil
|
||||
}
|
||||
|
||||
// openAITool mirrors pkg/functions.Tool but keeps Parameters as
// json.RawMessage so the input_schema passes through verbatim — no
// re-marshal cost, no fidelity loss on exotic schemas.
//
// It is the decode target for each element of the OpenAI `tools` array.
type openAITool struct {
	Type     string `json:"type"`
	Function struct {
		Name        string `json:"name"`
		Description string `json:"description"`
		// Raw JSON Schema of the function's arguments.
		Parameters json.RawMessage `json:"parameters"`
	} `json:"function"`
}
|
||||
|
||||
func assistantBlocks(m *pb.Message) []anthropicContentBlock {
|
||||
toolCallsJSON := m.GetToolCalls()
|
||||
if toolCallsJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var toolCalls []openAIToolCall
|
||||
if err := json.Unmarshal([]byte(toolCallsJSON), &toolCalls); err != nil || len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
blocks := make([]anthropicContentBlock, 0, len(toolCalls)+1)
|
||||
if text := m.GetContent(); text != "" {
|
||||
blocks = append(blocks, anthropicContentBlock{Type: "text", Text: text})
|
||||
}
|
||||
for _, tc := range toolCalls {
|
||||
// OpenAI's arguments are a JSON-encoded string; pass through
|
||||
// as RawMessage so a non-JSON string from a poorly-formed
|
||||
// local model doesn't crash the marshaller downstream.
|
||||
args := json.RawMessage(tc.Function.Arguments)
|
||||
if len(args) == 0 {
|
||||
args = emptyJSONObject
|
||||
}
|
||||
blocks = append(blocks, anthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: args,
|
||||
})
|
||||
}
|
||||
return blocks
|
||||
}
|
||||
|
||||
// doAnthropicRequest is the Anthropic counterpart of doOpenAIRequest.
|
||||
// applyAuthHeader sets x-api-key and anthropic-version when provider
|
||||
// is anthropic, so this method doesn't need to duplicate that.
|
||||
func (c *CloudProxy) doAnthropicRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// predictAnthropicRich returns the full Reply: joined text from all
|
||||
// text blocks, tool_use blocks mapped to ToolCallDelta, and usage
|
||||
// tokens.
|
||||
func (c *CloudProxy) predictAnthropicRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
body, err := buildAnthropicRequest(opts, cfg, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doAnthropicRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var parsed anthropicResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: decode response: %w", err)
|
||||
}
|
||||
|
||||
reply := &pb.Reply{}
|
||||
if parsed.Usage != nil {
|
||||
reply.PromptTokens = int32(parsed.Usage.InputTokens)
|
||||
reply.Tokens = int32(parsed.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
var content strings.Builder
|
||||
var toolCalls []*pb.ToolCallDelta
|
||||
toolIdx := 0
|
||||
for _, b := range parsed.Content {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
content.WriteString(b.Text)
|
||||
case "tool_use":
|
||||
// Input is a structured JSON object; we serialise to a
|
||||
// string so it fits the OpenAI-shaped arguments field
|
||||
// downstream consumers expect.
|
||||
args := ""
|
||||
if len(b.Input) > 0 {
|
||||
args = string(b.Input)
|
||||
}
|
||||
toolCalls = append(toolCalls, newToolCallDelta(toolIdx, b.ID, b.Name, args))
|
||||
toolIdx++
|
||||
}
|
||||
}
|
||||
reply.Message = []byte(content.String())
|
||||
if len(toolCalls) > 0 {
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{ToolCalls: toolCalls}}
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// predictAnthropicStreamRich streams Reply chunks from Anthropic's SSE.
|
||||
// Three event types matter: content_block_start (initialises tool_use
|
||||
// id+name), content_block_delta (carries text or input_json_delta),
|
||||
// message_stop (terminates). The block index from the wire feeds
|
||||
// straight into ToolCallDelta.Index so downstream consumers can
|
||||
// reassemble multiple parallel tool calls.
|
||||
func (c *CloudProxy) predictAnthropicStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
body, err := buildAnthropicRequest(opts, cfg, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doAnthropicRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "" {
|
||||
continue
|
||||
}
|
||||
var ev anthropicStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &ev); err != nil {
|
||||
xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err)
|
||||
continue
|
||||
}
|
||||
switch ev.Type {
|
||||
case "content_block_start":
|
||||
// tool_use blocks announce id + name here; arguments arrive
|
||||
// in subsequent input_json_delta events. Emit a Reply with
|
||||
// just the tool_call init fields so consumers can allocate
|
||||
// a slot at this index.
|
||||
if ev.ContentBlock != nil && ev.ContentBlock.Type == "tool_use" {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{
|
||||
newToolCallDelta(ev.Index, ev.ContentBlock.ID, ev.ContentBlock.Name, ""),
|
||||
}}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "content_block_delta":
|
||||
if ev.Delta == nil {
|
||||
continue
|
||||
}
|
||||
switch ev.Delta.Type {
|
||||
case "text_delta":
|
||||
if ev.Delta.Text == "" {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
Message: []byte(ev.Delta.Text),
|
||||
ChatDeltas: []*pb.ChatDelta{{Content: ev.Delta.Text}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
case "input_json_delta":
|
||||
if ev.Delta.PartialJSON == "" {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{
|
||||
newToolCallDelta(ev.Index, "", "", ev.Delta.PartialJSON),
|
||||
}}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "message_delta":
|
||||
// Anthropic sends final usage in message_delta.usage. Emit
|
||||
// a usage-only Reply so the consumer can record totals.
|
||||
if ev.Usage != nil {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
Tokens: int32(ev.Usage.OutputTokens),
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "message_stop":
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
334
backend/go/cloud-proxy/provider_anthropic_test.go
Normal file
334
backend/go/cloud-proxy/provider_anthropic_test.go
Normal file
@@ -0,0 +1,334 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// fakeAnthropicUpstream mirrors fakeOpenAIUpstream but decodes the
|
||||
// request body as an anthropicRequest so tests can assert on the
|
||||
// translated wire shape (system field, max_tokens, etc.).
|
||||
func fakeAnthropicUpstream(t *testing.T, handler func(req anthropicRequest) (status int, body string, contentType string)) (*httptest.Server, *anthropicRequest) {
|
||||
t.Helper()
|
||||
var captured anthropicRequest
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(raw, &captured)
|
||||
status, body, ct := handler(captured)
|
||||
w.Header().Set("Content-Type", ct)
|
||||
w.WriteHeader(status)
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
return srv, &captured
|
||||
}
|
||||
|
||||
func newAnthropicTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy {
|
||||
t.Helper()
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_ANTHROPIC_FAKE", "sk-ant-fake")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Model: "claude-local",
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstreamURL,
|
||||
Mode: modeTranslate,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_FAKE",
|
||||
UpstreamModel: "claude-3-5-sonnet-20241022",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
return cp
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_BasicMessages(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hi there"}],"model":"claude-3-5-sonnet-20241022","usage":{"input_tokens":5,"output_tokens":2}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{
|
||||
{Role: "system", Content: "be brief"},
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
Temperature: 0.5,
|
||||
TopP: 0.9,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hi there"))
|
||||
|
||||
g.Expect(captured.Model).To(Equal("claude-3-5-sonnet-20241022"))
|
||||
// System message must be hoisted out of Messages into top-level field.
|
||||
g.Expect(captured.System).To(Equal("be brief"))
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.MaxTokens).To(Equal(int32(32)))
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
// Anthropic 400s when both temperature and top_p are set; the
|
||||
// translator must prefer temperature and drop top_p.
|
||||
g.Expect(captured.TopP).To(BeNil())
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
}
|
||||
|
||||
// When only top_p is set, it should be forwarded.
|
||||
func TestPredict_Anthropic_TopPOnly(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hello"}},
|
||||
TopP: 0.9,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Temperature).To(BeNil())
|
||||
// PredictOptions.TopP is float32 on the wire; the translator widens
|
||||
// to float64 so 0.9 round-trips as 0.8999999761581421… — compare
|
||||
// with a small tolerance rather than exact equality.
|
||||
g.Expect(captured.TopP).NotTo(BeNil())
|
||||
g.Expect(math.Abs(*captured.TopP - 0.9)).To(BeNumerically("<=", 1e-6))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_DefaultsMaxTokens(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Anthropic 400s without max_tokens. The translator must default
|
||||
// it when the caller doesn't supply Tokens.
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.MaxTokens).To(Equal(anthropicDefaultMaxTokens))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_PromptFallback(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?", Tokens: 16})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.Messages[0].Content).To(Equal("what time is it?"))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_ConcatenatesContentBlocks(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Anthropic may return multiple text blocks; the translator joins
|
||||
// them so the Predict() string return is the full assistant message.
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"hello "},{"type":"text","text":"world"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hello world"))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_UpstreamError(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 401, `{"error":{"type":"authentication_error","message":"bad key"}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("401"))
|
||||
}
|
||||
|
||||
func TestPredictStream_Anthropic_StreamsTextDeltas(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Real Anthropic SSE has event: lines + data: lines. The translator
|
||||
// only needs the data: payload; only content_block_delta with
|
||||
// delta.type=text_delta carries content. message_stop ends.
|
||||
frames := []string{
|
||||
"event: message_start\ndata: {\"type\":\"message_start\"}\n\n",
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" \"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"world\"}}\n\n",
|
||||
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
body := strings.Join(frames, "")
|
||||
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan string, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStream(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
Tokens: 16,
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var got []string
|
||||
for s := range results {
|
||||
got = append(got, s)
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(strings.Join(got, "")).To(Equal("hello world"))
|
||||
g.Expect(captured.Stream).To(BeTrue())
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_TranslatesOpenAITools(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
tools := `[{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}}}]`
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "weather in Paris?"}},
|
||||
Tools: tools,
|
||||
ToolChoice: `"auto"`,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Tools).To(HaveLen(1))
|
||||
g.Expect(captured.Tools[0].Name).To(Equal("get_weather"))
|
||||
g.Expect(captured.Tools[0].Description).To(Equal("Get weather"))
|
||||
// input_schema must be the parameters object verbatim.
|
||||
g.Expect(string(captured.Tools[0].InputSchema)).To(ContainSubstring(`"city"`))
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("auto"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_RequiredMapsToAny(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `"required"`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("any"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_NoneDropsTools(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `"none"`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Tools).To(BeNil())
|
||||
g.Expect(captured.ToolChoice).To(BeNil())
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_NamedFunction(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"weather","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `{"type":"function","function":{"name":"weather"}}`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("tool"))
|
||||
g.Expect(captured.ToolChoice.Name).To(Equal("weather"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_RoundTripsAssistantToolCalls(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// LocalAI Assistant's second turn: the LLM previously emitted a
|
||||
// tool_use, the server executed it, and the conversation now
|
||||
// includes the assistant turn (with tool_calls) plus a tool-role
|
||||
// result message. Both must convert to Anthropic block form.
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
tools := `[{"type":"function","function":{"name":"list_models","parameters":{"type":"object"}}}]`
|
||||
toolCallsJSON := `[{"id":"call_abc","type":"function","function":{"name":"list_models","arguments":"{}"}}]`
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Tools: tools,
|
||||
Messages: []*pb.Message{
|
||||
{Role: "user", Content: "what models are installed?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: toolCallsJSON},
|
||||
{Role: "tool", Content: `{"models":["a","b"]}`, ToolCallId: "call_abc"},
|
||||
},
|
||||
Tokens: 64,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
g.Expect(captured.Messages).To(HaveLen(3))
|
||||
// 1. user text — bare string
|
||||
s, ok := captured.Messages[0].Content.(string)
|
||||
g.Expect(ok).To(BeTrue())
|
||||
g.Expect(s).To(Equal("what models are installed?"))
|
||||
// 2. assistant — must be a content-block list with one tool_use
|
||||
// json.Unmarshal of `any` produces []any not []anthropicContentBlock.
|
||||
blocks, ok := captured.Messages[1].Content.([]any)
|
||||
g.Expect(ok).To(BeTrue())
|
||||
g.Expect(blocks).To(HaveLen(1))
|
||||
b0, _ := blocks[0].(map[string]any)
|
||||
g.Expect(b0["type"]).To(Equal("tool_use"))
|
||||
g.Expect(b0["id"]).To(Equal("call_abc"))
|
||||
g.Expect(b0["name"]).To(Equal("list_models"))
|
||||
// 3. tool → user with tool_result block
|
||||
g.Expect(captured.Messages[2].Role).To(Equal("user"))
|
||||
resBlocks, _ := captured.Messages[2].Content.([]any)
|
||||
r0, _ := resBlocks[0].(map[string]any)
|
||||
g.Expect(r0["type"]).To(Equal("tool_result"))
|
||||
g.Expect(r0["tool_use_id"]).To(Equal("call_abc"))
|
||||
g.Expect(r0["content"]).To(Equal(`{"models":["a","b"]}`))
|
||||
}
|
||||
119
backend/go/cloud-proxy/provider_edge_test.go
Normal file
119
backend/go/cloud-proxy/provider_edge_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Verify buildOpenAIRequest preserves caller-supplied tools and
|
||||
// tool_choice as opaque JSON. PredictOptions carries them as strings;
|
||||
// they must land in the outbound request body unchanged so the
|
||||
// upstream sees the caller's intent verbatim. A regression here would
|
||||
// silently disable function calling for translate-mode clients.
|
||||
func TestBuildOpenAIRequest_ToolsAndToolChoicePassthrough(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cfg := &proxyConfig{upstreamModel: "gpt-4o"}
|
||||
toolsJSON := `[{"type":"function","function":{"name":"search","parameters":{"type":"object"}}}]`
|
||||
choiceJSON := `{"type":"function","function":{"name":"search"}}`
|
||||
|
||||
body, err := buildOpenAIRequest(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "find x"}},
|
||||
Tools: toolsJSON,
|
||||
ToolChoice: choiceJSON,
|
||||
}, cfg, false)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
var decoded openAIRequest
|
||||
err = json.Unmarshal(body, &decoded)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
// Compare the JSON-canonical form so whitespace differences are ignored.
|
||||
gotTools, _ := json.Marshal(json.RawMessage(decoded.Tools))
|
||||
wantTools, _ := json.Marshal(json.RawMessage(toolsJSON))
|
||||
g.Expect(string(gotTools)).To(Equal(string(wantTools)))
|
||||
gotChoice, _ := json.Marshal(json.RawMessage(decoded.ToolChoice))
|
||||
wantChoice, _ := json.Marshal(json.RawMessage(choiceJSON))
|
||||
g.Expect(string(gotChoice)).To(Equal(string(wantChoice)))
|
||||
}
|
||||
|
||||
// Garbage JSON in tools / tool_choice is silently dropped (omitted)
|
||||
// rather than blowing up the request. Documents the parseRawJSON
|
||||
// behaviour — operators shouldn't see hard failures from an upstream
|
||||
// caller's mis-formatted tools field.
|
||||
func TestBuildOpenAIRequest_InvalidToolsJSONDropped(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cfg := &proxyConfig{upstreamModel: "gpt-4o"}
|
||||
body, err := buildOpenAIRequest(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: "this is not json",
|
||||
ToolChoice: "{also bad",
|
||||
}, cfg, false)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(body)).NotTo(ContainSubstring("this is not json"))
|
||||
g.Expect(string(body)).NotTo(ContainSubstring("{also bad"))
|
||||
}
|
||||
|
||||
// Anthropic empty content array yields an empty Reply (not an error).
|
||||
// Mirrors how an upstream tool_use-only response might arrive — the
|
||||
// content array can legitimately be empty in some edge cases.
|
||||
func TestPredictRich_Anthropic_EmptyContent(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"id":"m1","type":"message","role":"assistant","content":[],"usage":{"input_tokens":3,"output_tokens":0}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal(""))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(0))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(3)))
|
||||
}
|
||||
|
||||
// A truncated / malformed SSE payload mid-stream should be tolerated:
|
||||
// the malformed chunk gets skipped (xlog.Debug logged), valid chunks
|
||||
// before AND after it still reach the channel.
|
||||
func TestPredictStreamRich_OpenAI_TolerantOfBadChunks(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
body := strings.Join([]string{
|
||||
`data: {"choices":[{"index":0,"delta":{"content":"hello"}}]}`,
|
||||
``,
|
||||
`data: this-is-not-json{{`,
|
||||
``,
|
||||
`data: {"choices":[{"index":0,"delta":{"content":" world"}}]}`,
|
||||
``,
|
||||
`data: [DONE]`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var assembled strings.Builder
|
||||
for reply := range results {
|
||||
assembled.Write(reply.GetMessage())
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
// The good chunks before and after the malformed one both made it through.
|
||||
g.Expect(assembled.String()).To(Equal("hello world"))
|
||||
}
|
||||
320
backend/go/cloud-proxy/provider_openai.go
Normal file
320
backend/go/cloud-proxy/provider_openai.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// OpenAI Chat Completions wire-format types. Narrowed to the fields
|
||||
// translate mode needs to preserve through the Reply proto: content,
|
||||
// role, tool_calls (typed so we can map them to pb.ToolCallDelta),
|
||||
// and sampling params copied verbatim from PredictOptions.
|
||||
//
|
||||
// Provider-specific extensions (logit_bias, function calling beyond
|
||||
// tool_calls, etc.) are not modelled — passthrough mode covers callers
|
||||
// that need full upstream fidelity.
|
||||
|
||||
type openAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAIMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
MaxTokens *int32 `json:"max_tokens,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
type openAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// openAIToolCall covers both the non-streaming response shape (full
|
||||
// id+function+arguments) and the streaming-delta shape (sparse fields,
|
||||
// index assignment). The proto's ToolCallDelta absorbs both — name is
|
||||
// set on first appearance, arguments arrive incrementally in streaming.
|
||||
type openAIToolCall struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function openAIFunctionCall `json:"function,omitempty"`
|
||||
}
|
||||
|
||||
type openAIFunctionCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type openAIChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message openAIMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type openAIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Choices []openAIChoice `json:"choices"`
|
||||
Usage *openAIUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type openAIStreamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
|
||||
} `json:"delta"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
type openAIStreamChunk struct {
|
||||
Choices []openAIStreamChoice `json:"choices"`
|
||||
Usage *openAIUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type openAIUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// buildOpenAIRequest converts pb.PredictOptions into the OpenAI Chat
|
||||
// Completions request body. Prefers Messages when non-empty; falls
|
||||
// back to wrapping Prompt as a single user message so plain
|
||||
// /completions-style calls still work in translate mode.
|
||||
func buildOpenAIRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) {
|
||||
req := openAIRequest{
|
||||
Model: modelName(cfg, opts),
|
||||
Stream: stream,
|
||||
Stop: opts.GetStopPrompts(),
|
||||
Tools: parseRawJSON(opts.GetTools()),
|
||||
ToolChoice: parseRawJSON(opts.GetToolChoice()),
|
||||
}
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
}
|
||||
if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
if n := opts.GetTokens(); n > 0 {
|
||||
req.MaxTokens = &n
|
||||
}
|
||||
if p := opts.GetFrequencyPenalty(); p != 0 {
|
||||
v := float64(p)
|
||||
req.FrequencyPenalty = &v
|
||||
}
|
||||
if p := opts.GetPresencePenalty(); p != 0 {
|
||||
v := float64(p)
|
||||
req.PresencePenalty = &v
|
||||
}
|
||||
|
||||
for _, m := range opts.GetMessages() {
|
||||
msg := openAIMessage{
|
||||
Role: m.GetRole(),
|
||||
Content: m.GetContent(),
|
||||
Name: m.GetName(),
|
||||
ToolCallID: m.GetToolCallId(),
|
||||
}
|
||||
// Pre-existing tool_calls arrive as a JSON string from the
|
||||
// upstream caller's previous assistant turn; pass-through as-is.
|
||||
if tc := m.GetToolCalls(); tc != "" {
|
||||
_ = json.Unmarshal([]byte(tc), &msg.ToolCalls)
|
||||
}
|
||||
req.Messages = append(req.Messages, msg)
|
||||
}
|
||||
// Fallback for plain Prompt requests (no Messages array). LocalAI
|
||||
// templating may have produced a flat prompt; rewrap as a single
|
||||
// user message so the upstream chat endpoint accepts it.
|
||||
if len(req.Messages) == 0 && opts.GetPrompt() != "" {
|
||||
req.Messages = []openAIMessage{{Role: "user", Content: opts.GetPrompt()}}
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
// modelName picks the upstream model: upstream_model from the proxy
|
||||
// config wins (operator override), else the local model name captured
|
||||
// at LoadModel time. Operator sets upstream_model to map LocalAI's
|
||||
// alias (e.g. "claude-strict") to the upstream's canonical name
|
||||
// (e.g. "claude-3-5-sonnet-20241022").
|
||||
func modelName(cfg *proxyConfig, _ *pb.PredictOptions) string {
|
||||
if cfg.upstreamModel != "" {
|
||||
return cfg.upstreamModel
|
||||
}
|
||||
return cfg.localModel
|
||||
}
|
||||
|
||||
// parseRawJSON turns a JSON string into a json.RawMessage that can be
// embedded in the upstream request body. An empty string or input that
// fails to decode yields nil, which omitempty then leaves out of the
// marshalled request.
func parseRawJSON(s string) json.RawMessage {
	var raw json.RawMessage
	if s == "" || json.Unmarshal([]byte(s), &raw) != nil {
		return nil
	}
	return raw
}
|
||||
|
||||
// doOpenAIRequest builds + sends the upstream request. Returns the
|
||||
// raw response on success; caller handles status / body.
|
||||
func (c *CloudProxy) doOpenAIRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// predictOpenAIRich is the non-streaming translate path. Returns a
|
||||
// fully-populated *pb.Reply with assistant content, tool calls, and
|
||||
// token usage. The gRPC server forwards the Reply verbatim.
|
||||
func (c *CloudProxy) predictOpenAIRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
body, err := buildOpenAIRequest(opts, cfg, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doOpenAIRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var parsed openAIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: decode response: %w", err)
|
||||
}
|
||||
if len(parsed.Choices) == 0 {
|
||||
return nil, errors.New("cloud-proxy: upstream returned no choices")
|
||||
}
|
||||
|
||||
choice := parsed.Choices[0]
|
||||
reply := &pb.Reply{
|
||||
Message: []byte(choice.Message.Content),
|
||||
}
|
||||
if parsed.Usage != nil {
|
||||
reply.PromptTokens = int32(parsed.Usage.PromptTokens)
|
||||
reply.Tokens = int32(parsed.Usage.CompletionTokens)
|
||||
}
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
// Non-streaming: a single ChatDelta carries the full tool-call
|
||||
// set. Index/Name/Arguments are populated together; downstream
|
||||
// consumers don't need to assemble streaming deltas.
|
||||
delta := &pb.ChatDelta{}
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
delta.ToolCalls = append(delta.ToolCalls,
|
||||
newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments))
|
||||
}
|
||||
reply.ChatDeltas = []*pb.ChatDelta{delta}
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// predictOpenAIStreamRich streams *pb.Reply chunks. Each chunk carries
|
||||
// either a content delta (Message + ChatDeltas[].Content) or tool-call
|
||||
// deltas (ChatDeltas[].ToolCalls). The final Reply carries usage tokens
|
||||
// when the upstream sends them (stream_options.include_usage).
|
||||
func (c *CloudProxy) predictOpenAIStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
body, err := buildOpenAIRequest(opts, cfg, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doOpenAIRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
return nil
|
||||
}
|
||||
var chunk openAIStreamChunk
|
||||
if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
|
||||
xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err)
|
||||
continue
|
||||
}
|
||||
// Usage frames may arrive separately from content frames when
|
||||
// stream_options.include_usage is set; emit a usage-only Reply
|
||||
// in that case so the consumer sees the totals.
|
||||
if chunk.Usage != nil && len(chunk.Choices) == 0 {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
PromptTokens: int32(chunk.Usage.PromptTokens),
|
||||
Tokens: int32(chunk.Usage.CompletionTokens),
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, ch := range chunk.Choices {
|
||||
reply := &pb.Reply{}
|
||||
if ch.Delta.Content != "" {
|
||||
reply.Message = []byte(ch.Delta.Content)
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{Content: ch.Delta.Content}}
|
||||
}
|
||||
if len(ch.Delta.ToolCalls) > 0 {
|
||||
if len(reply.ChatDeltas) == 0 {
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{}}
|
||||
}
|
||||
for _, tc := range ch.Delta.ToolCalls {
|
||||
reply.ChatDeltas[0].ToolCalls = append(reply.ChatDeltas[0].ToolCalls,
|
||||
newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments))
|
||||
}
|
||||
}
|
||||
if reply.Message == nil && len(reply.ChatDeltas) == 0 {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, reply) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
170
backend/go/cloud-proxy/provider_openai_test.go
Normal file
170
backend/go/cloud-proxy/provider_openai_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// fakeOpenAIUpstream returns an httptest.Server that decodes the
|
||||
// inbound request as an openAIRequest, calls handler with it, and
|
||||
// writes the handler's reply as the response.
|
||||
func fakeOpenAIUpstream(t *testing.T, handler func(req openAIRequest) (status int, body string, contentType string)) (*httptest.Server, *openAIRequest) {
|
||||
t.Helper()
|
||||
var captured openAIRequest
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(raw, &captured)
|
||||
status, body, ct := handler(captured)
|
||||
w.Header().Set("Content-Type", ct)
|
||||
w.WriteHeader(status)
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
return srv, &captured
|
||||
}
|
||||
|
||||
func newTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy {
|
||||
t.Helper()
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_OPENAI_FAKE", "sk-fake-openai")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Model: "gpt-4o-local",
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstreamURL,
|
||||
Mode: modeTranslate,
|
||||
Provider: providerOpenAI,
|
||||
ApiKeyEnv: "CLOUD_PROXY_OPENAI_FAKE",
|
||||
UpstreamModel: "gpt-4o",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
return cp
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_BasicChat(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"id":"resp-1","choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{
|
||||
{Role: "system", Content: "be brief"},
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
Temperature: 0.5,
|
||||
TopP: 0.9,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hi there"))
|
||||
|
||||
// Verify the upstream saw a properly-translated request.
|
||||
g.Expect(captured.Model).To(Equal("gpt-4o"))
|
||||
g.Expect(captured.Messages).To(HaveLen(2))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("system"))
|
||||
g.Expect(captured.Messages[1].Role).To(Equal("user"))
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
g.Expect(captured.MaxTokens).NotTo(BeNil())
|
||||
g.Expect(*captured.MaxTokens).To(Equal(int32(32)))
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_PromptFallback(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// No Messages array — backend should synth a single user message
|
||||
// from Prompt so non-chat clients still route through translate.
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"choices":[{"message":{"role":"assistant","content":"ok"}}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?"})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.Messages[0].Content).To(Equal("what time is it?"))
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_UpstreamError(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 401, `{"error":{"message":"bad key"}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("401"))
|
||||
}
|
||||
|
||||
func TestPredictStream_OpenAI_StreamsContent(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Stream three content deltas then [DONE]. Verify the channel
|
||||
// receives them in order with no missing pieces.
|
||||
chunks := []string{
|
||||
`{"choices":[{"index":0,"delta":{"role":"assistant"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"hello"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":" "}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"world"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||
}
|
||||
body := ""
|
||||
for _, c := range chunks {
|
||||
body += "data: " + c + "\n\n"
|
||||
}
|
||||
body += "data: [DONE]\n\n"
|
||||
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan string, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStream(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var got []string
|
||||
for s := range results {
|
||||
got = append(got, s)
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(strings.Join(got, "")).To(Equal("hello world"))
|
||||
g.Expect(captured.Stream).To(BeTrue())
|
||||
}
|
||||
|
||||
func TestPredict_RejectedInPassthroughMode(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_FAKE", "k")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_FAKE",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_, err = cp.Predict(&pb.PredictOptions{})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("only valid in translate"))
|
||||
}
|
||||
429
backend/go/cloud-proxy/proxy.go
Normal file
429
backend/go/cloud-proxy/proxy.go
Normal file
@@ -0,0 +1,429 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// Proxy modes. The values mirror core/config.ProxyMode*; they are
// duplicated here because backends do not import core, which keeps the
// package boundary clean.
const (
	modePassthrough = "passthrough"
	modeTranslate   = "translate"
)

// Upstream providers. The values mirror core/config.ProxyProvider* for
// the same reason as the modes above.
const (
	providerOpenAI    = "openai"
	providerAnthropic = "anthropic"
)
|
||||
|
||||
// CloudProxy is the LocalAI backend that proxies model traffic to a
|
||||
// configured upstream HTTP provider. Concurrency: base.SingleThread is
|
||||
// NOT embedded — forward calls are independent and HTTP transport is
|
||||
// goroutine-safe, so multiple Forward streams can run in parallel.
|
||||
// Locking would serialise requests to a chat provider for no benefit.
|
||||
type CloudProxy struct {
|
||||
base.Base
|
||||
|
||||
cfg atomic.Pointer[proxyConfig]
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// proxyConfig is the snapshot of settings that Load stores for the
// request paths to read.
type proxyConfig struct {
	// upstreamURL is the provider endpoint requests are sent to.
	upstreamURL string
	// mode is the proxy mode: passthrough or translate.
	mode string
	// provider names the upstream provider.
	provider string
	// upstreamModel is the model name configured for the upstream.
	upstreamModel string
	// localModel is ModelOptions.Model, the fallback when upstream_model
	// is unset.
	localModel string
	// apiKey is the credential resolved at Load time; empty means none.
	apiKey string
}
|
||||
|
||||
func NewCloudProxy() *CloudProxy {
|
||||
// No Client-level Timeout — that would bound streaming SSE
|
||||
// responses too, which can legitimately last minutes. Per-request
|
||||
// deadlines come from the gRPC stream context.
|
||||
return &CloudProxy{client: &http.Client{}}
|
||||
}
|
||||
|
||||
func (c *CloudProxy) Load(opts *pb.ModelOptions) error {
|
||||
po := opts.GetProxy()
|
||||
if po == nil {
|
||||
return errors.New("cloud-proxy: Load requires ProxyOptions to be set")
|
||||
}
|
||||
if po.GetUpstreamUrl() == "" {
|
||||
return errors.New("cloud-proxy: upstream_url is required")
|
||||
}
|
||||
if _, err := url.ParseRequestURI(po.GetUpstreamUrl()); err != nil {
|
||||
return fmt.Errorf("cloud-proxy: upstream_url %q invalid: %w", po.GetUpstreamUrl(), err)
|
||||
}
|
||||
|
||||
mode := po.GetMode()
|
||||
if mode == "" {
|
||||
mode = modePassthrough
|
||||
}
|
||||
switch mode {
|
||||
case modePassthrough:
|
||||
case modeTranslate:
|
||||
switch po.GetProvider() {
|
||||
case providerOpenAI:
|
||||
// implemented in provider_openai.go
|
||||
case providerAnthropic:
|
||||
// implemented in provider_anthropic.go
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: translate mode requires provider in {%s, %s}, got %q",
|
||||
providerOpenAI, providerAnthropic, po.GetProvider())
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: unknown mode %q", mode)
|
||||
}
|
||||
|
||||
key, err := resolveAPIKey(po.GetApiKeyEnv(), po.GetApiKeyFile())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.cfg.Store(&proxyConfig{
|
||||
upstreamURL: po.GetUpstreamUrl(),
|
||||
mode: mode,
|
||||
provider: po.GetProvider(),
|
||||
upstreamModel: po.GetUpstreamModel(),
|
||||
localModel: opts.GetModel(),
|
||||
apiKey: key,
|
||||
})
|
||||
xlog.Info("cloud-proxy: ready",
|
||||
"upstream", po.GetUpstreamUrl(),
|
||||
"mode", mode,
|
||||
"provider", po.GetProvider(),
|
||||
"has_key", key != "")
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveAPIKey returns the upstream API key named by the configuration.
// It mirrors config.ProxyConfig.ResolveAPIKey; the few lines are
// duplicated rather than importing core/config from a backend binary so
// backends stay independent of core's package layout. Mutual exclusion
// of the two sources is enforced upstream in core/config.Validate; if
// both are passed here, envName takes precedence.
//
// Resolution:
//   - envName set: the variable's value is returned as-is; an unset or
//     empty variable is an error.
//   - filePath set: the file content is returned with surrounding
//     whitespace trimmed; an unreadable file, or one that is empty after
//     trimming, is an error.
//   - neither set: ("", nil), meaning no key is configured.
func resolveAPIKey(envName, filePath string) (string, error) {
	if envName != "" {
		v := os.Getenv(envName)
		if v == "" {
			return "", fmt.Errorf("cloud-proxy: api_key_env %q is unset", envName)
		}
		return v, nil
	}
	if filePath != "" {
		b, err := os.ReadFile(filePath)
		if err != nil {
			return "", fmt.Errorf("cloud-proxy: read api_key_file %q: %w", filePath, err)
		}
		key := strings.TrimSpace(string(b))
		// A configured but blank key file is a misconfiguration. Report
		// it like an unset env var instead of silently sending
		// unauthenticated requests.
		if key == "" {
			return "", fmt.Errorf("cloud-proxy: api_key_file %q is empty", filePath)
		}
		return key, nil
	}
	return "", nil
}
|
||||
|
||||
// PredictRich is the non-streaming translate path. Returns a fully-
|
||||
// populated *pb.Reply: content, tool-call deltas (ChatDeltas), and
|
||||
// usage tokens. Implements the optional grpc.AIModelRich interface;
|
||||
// the gRPC server prefers this path over Predict when present so
|
||||
// tool calls survive the round-trip. Passthrough mode rejects
|
||||
// PredictRich — callers must use Forward.
|
||||
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cloud-proxy: model not loaded")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
|
||||
}
|
||||
xlog.Info("cloud-proxy: predict", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: predict failed", "provider", cfg.provider, "error", err)
|
||||
}
|
||||
}()
|
||||
ctx := context.Background()
|
||||
switch cfg.provider {
|
||||
case providerOpenAI:
|
||||
return c.predictOpenAIRich(ctx, cfg, opts)
|
||||
case providerAnthropic:
|
||||
return c.predictAnthropicRich(ctx, cfg, opts)
|
||||
default:
|
||||
return nil, fmt.Errorf("cloud-proxy: predict not implemented for provider %q", cfg.provider)
|
||||
}
|
||||
}
|
||||
|
||||
// PredictStreamRich is the rich streaming counterpart of PredictRich.
|
||||
// Each emitted Reply carries either a content delta, tool-call deltas,
|
||||
// or usage tokens (the final upstream frame). base.Base.PredictStream
|
||||
// is bypassed when AIModelRich is implemented, so the channel is
|
||||
// closed by the gRPC server pump.
|
||||
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
|
||||
}
|
||||
xlog.Info("cloud-proxy: predict-stream", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: predict-stream failed", "provider", cfg.provider, "error", err)
|
||||
}
|
||||
}()
|
||||
ctx := context.Background()
|
||||
switch cfg.provider {
|
||||
case providerOpenAI:
|
||||
return c.predictOpenAIStreamRich(ctx, cfg, opts, results)
|
||||
case providerAnthropic:
|
||||
return c.predictAnthropicStreamRich(ctx, cfg, opts, results)
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: predictStream not implemented for provider %q", cfg.provider)
|
||||
}
|
||||
}
|
||||
|
||||
// Predict is the legacy (string, error) AIModel signature. Used only
|
||||
// if a caller goes through the non-rich path (it shouldn't, since
|
||||
// server.go prefers PredictRich). Provided so the AIModel interface
|
||||
// is satisfied for backends that haven't opted into the rich variant.
|
||||
func (c *CloudProxy) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
reply, err := c.PredictRich(opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(reply.GetMessage()), nil
|
||||
}
|
||||
|
||||
// PredictStream is the legacy chan-string streaming path. Adapts the
|
||||
// rich stream by extracting only content text — tool-call-only chunks
|
||||
// (no Message bytes) and usage-only chunks are silently dropped, since
|
||||
// the legacy chan-string contract cannot represent them. Consumers
|
||||
// that need tool calls must call PredictStreamRich directly.
|
||||
func (c *CloudProxy) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
defer close(results)
|
||||
richCh := make(chan *pb.Reply)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- c.PredictStreamRich(opts, richCh)
|
||||
close(richCh)
|
||||
}()
|
||||
for reply := range richCh {
|
||||
if msg := reply.GetMessage(); len(msg) > 0 {
|
||||
results <- string(msg)
|
||||
}
|
||||
}
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
// sendReply pushes one Reply onto a stream channel honouring ctx
|
||||
// cancellation. Returns false on cancel so the caller can exit with
|
||||
// ctx.Err(). Used by both translate-mode providers.
|
||||
func sendReply(ctx context.Context, results chan<- *pb.Reply, reply *pb.Reply) bool {
|
||||
select {
|
||||
case results <- reply:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// newToolCallDelta is a small constructor for the cross-provider
|
||||
// tool-call delta shape. Centralised so the int32 cast and the four
|
||||
// fields stay consistent across the OpenAI / Anthropic translators.
|
||||
// Empty name/args are valid — Anthropic streaming announces the call
|
||||
// with id+name then sends arguments incrementally; OpenAI's reverse
|
||||
// pattern (args without name) also lands here.
|
||||
func newToolCallDelta(index int, id, name, args string) *pb.ToolCallDelta {
|
||||
return &pb.ToolCallDelta{
|
||||
Index: int32(index),
|
||||
Id: id,
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward shovels bytes between a Forward gRPC stream and an upstream
|
||||
// HTTP request. First request message carries path/method/headers and
|
||||
// the initial body chunk; subsequent messages append body chunks. The
|
||||
// first reply carries upstream status + response headers; subsequent
|
||||
// replies stream body chunks until the upstream connection closes.
|
||||
// Cancellation of ctx (the gRPC stream context) closes the upstream
|
||||
// connection.
|
||||
func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest, out chan<- *pb.ForwardReply) error {
|
||||
defer close(out)
|
||||
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
}
|
||||
if cfg.mode != modePassthrough {
|
||||
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
|
||||
}
|
||||
|
||||
first, ok := <-in
|
||||
if !ok {
|
||||
return errors.New("cloud-proxy: Forward stream closed before first request")
|
||||
}
|
||||
|
||||
// Honour the per-request path only when the configured upstream_url
|
||||
// has no path of its own — gallery convention is to put the
|
||||
// canonical path in upstream_url.
|
||||
fullURL, err := composeURL(cfg.upstreamURL, first.GetPath())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
method := first.GetMethod()
|
||||
if method == "" {
|
||||
method = http.MethodPost
|
||||
}
|
||||
|
||||
// Pipe the body in from the gRPC stream so the HTTP request can
|
||||
// start before the client finishes sending. The pipe-reader is
|
||||
// closed via CloseWithError on the error paths so the writer
|
||||
// goroutine doesn't block forever.
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
go func() {
|
||||
var writeErr error
|
||||
defer func() { _ = pw.CloseWithError(writeErr) }()
|
||||
if len(first.GetBodyChunk()) > 0 {
|
||||
if _, writeErr = pw.Write(first.GetBodyChunk()); writeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
for req := range in {
|
||||
if len(req.GetBodyChunk()) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, writeErr = pw.Write(req.GetBodyChunk()); writeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, fullURL, pr)
|
||||
if err != nil {
|
||||
_ = pr.CloseWithError(err) // unblocks the body-pump's pw.Write
|
||||
return fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
|
||||
// Apply caller-supplied headers, then override with the
|
||||
// authorization header derived from the resolved key. Caller-
|
||||
// supplied Authorization is always replaced — operators may not
|
||||
// know the backend's auth scheme, and silently leaking through a
|
||||
// client Authorization header to a different upstream would
|
||||
// confuse the upstream and could leak credentials.
|
||||
for _, h := range first.GetHeaders() {
|
||||
if h == nil || h.GetName() == "" {
|
||||
continue
|
||||
}
|
||||
// Strip hop-by-hop headers that aren't meaningful to the
|
||||
// upstream (Host is set by the http client from the URL;
|
||||
// Content-Length is computed from the body).
|
||||
if isHopByHopHeader(h.GetName()) {
|
||||
continue
|
||||
}
|
||||
req.Header.Add(h.GetName(), h.GetValue())
|
||||
}
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
|
||||
xlog.Info("cloud-proxy: forward", "method", method, "url", fullURL, "provider", cfg.provider)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: forward upstream failed", "url", fullURL, "error", err)
|
||||
return fmt.Errorf("cloud-proxy: upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
logFn := xlog.Info
|
||||
if resp.StatusCode >= 400 {
|
||||
logFn = xlog.Warn
|
||||
}
|
||||
logFn("cloud-proxy: forward response", "url", fullURL, "status", resp.StatusCode)
|
||||
|
||||
// First reply: status + response headers, no body.
|
||||
headers := make([]*pb.ForwardHeader, 0, len(resp.Header))
|
||||
for k, vs := range resp.Header {
|
||||
for _, v := range vs {
|
||||
headers = append(headers, &pb.ForwardHeader{Name: k, Value: v})
|
||||
}
|
||||
}
|
||||
out <- &pb.ForwardReply{Status: int32(resp.StatusCode), Headers: headers}
|
||||
|
||||
// Subsequent replies: body chunks. Use a fixed 8KB buffer — small
|
||||
// enough that SSE token frames flush promptly, large enough that
|
||||
// long chunked-transfer bodies aren't death by a thousand reads.
|
||||
buf := make([]byte, 8*1024)
|
||||
for {
|
||||
n, rerr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
out <- &pb.ForwardReply{BodyChunk: chunk}
|
||||
}
|
||||
if rerr != nil {
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("cloud-proxy: upstream body read: %w", rerr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// composeURL combines the configured upstream URL with the per-request
// path and returns the URL to call.
//
// When upstream already has a path of its own (for example
// https://api.openai.com/v1/chat/completions) it is returned unchanged
// and reqPath is ignored. When upstream is a bare host (empty path or
// "/"), reqPath supplies the path. A query string on reqPath
// ("/v1/models?limit=5") is split off and placed in the URL's query,
// appended to any query upstream already has, instead of being escaped
// into the path.
//
// An error is returned only when upstream cannot be parsed.
func composeURL(upstream, reqPath string) (string, error) {
	u, err := url.Parse(upstream)
	if err != nil {
		return "", fmt.Errorf("cloud-proxy: parse upstream_url %q: %w", upstream, err)
	}
	if u.Path != "" && u.Path != "/" {
		return u.String(), nil
	}

	// Assigning "path?query" to URL.Path would percent-encode the "?",
	// so the two parts are separated first.
	path, query, _ := strings.Cut(reqPath, "?")
	u.Path = path
	if query != "" {
		if u.RawQuery == "" {
			u.RawQuery = query
		} else {
			u.RawQuery += "&" + query
		}
	}
	return u.String(), nil
}
|
||||
|
||||
// applyAuthHeader writes the appropriate authorization header for the
|
||||
// provider. OpenAI/Anthropic/most providers use Bearer; Anthropic
|
||||
// historically uses x-api-key + anthropic-version, but accepts Bearer
|
||||
// too via the OpenAI-compatible path. Default to Bearer when provider
|
||||
// is empty (passthrough mode where the operator doesn't claim a
|
||||
// provider).
|
||||
func applyAuthHeader(req *http.Request, provider, key string) {
|
||||
switch provider {
|
||||
case providerAnthropic:
|
||||
req.Header.Set("x-api-key", key)
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
default:
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
}
|
||||
|
||||
// isHopByHopHeader reports whether a client request header must not be
// forwarded to the upstream. The comparison is case-insensitive. The set
// covers the hop-by-hop headers (Connection, Proxy-Connection,
// Keep-Alive, Proxy-Authenticate, Proxy-Authorization, TE, Trailer,
// Transfer-Encoding, Upgrade) plus Host and Content-Length, which the
// HTTP client derives itself from the URL and the body.
func isHopByHopHeader(name string) bool {
	switch strings.ToLower(name) {
	case "connection", "proxy-connection", "keep-alive", "transfer-encoding",
		"te", "trailer", "upgrade", "host", "content-length",
		// Proxy credentials and challenges are scoped to a single hop
		// and must never travel on to the upstream.
		"proxy-authenticate", "proxy-authorization":
		return true
	}
	return false
}
|
||||
|
||||
206
backend/go/cloud-proxy/proxy_test.go
Normal file
206
backend/go/cloud-proxy/proxy_test.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// helper: run a CloudProxy in-process via grpc.Provide so tests can
|
||||
// call Forward through the public Backend interface without listening
|
||||
// on a real socket.
|
||||
func newInProcClient(t *testing.T, proxy *CloudProxy) grpc.Backend {
|
||||
t.Helper()
|
||||
addr := "test://" + t.Name()
|
||||
grpc.Provide(addr, proxy)
|
||||
return grpc.NewClient(addr, true, nil, false)
|
||||
}
|
||||
|
||||
func TestForward_PassthroughEcho(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Fake upstream: echoes the request body back, prefixed with a
|
||||
// canary so the test can assert both that the body reached the
|
||||
// upstream and the response made it back to the client.
|
||||
gotBody := make(chan string, 1)
|
||||
gotAuth := make(chan string, 1)
|
||||
gotPath := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
gotBody <- string(body)
|
||||
gotAuth <- r.Header.Get("Authorization")
|
||||
gotPath <- r.URL.Path
|
||||
w.Header().Set("X-Echo", "true")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("echo: " + string(body)))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
t.Setenv("CLOUD_PROXY_FAKE_KEY", "sk-fake")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_FAKE_KEY",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{{Name: "Content-Type", Value: "application/json"}},
|
||||
BodyChunk: []byte(`{"prompt":`),
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.Send(&pb.ForwardRequest{BodyChunk: []byte(`"hi"}`)})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.CloseSend()
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// First reply: status + headers.
|
||||
first, err := stream.Recv()
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(first.Status).To(Equal(int32(http.StatusOK)))
|
||||
g.Expect(hasHeader(first.Headers, "X-Echo", "true")).To(BeTrue())
|
||||
|
||||
// Subsequent replies: body.
|
||||
var body []byte
|
||||
for {
|
||||
r, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
body = append(body, r.BodyChunk...)
|
||||
}
|
||||
g.Expect(string(body)).To(Equal(`echo: {"prompt":"hi"}`))
|
||||
|
||||
// Upstream observations.
|
||||
var gotBodyVal, gotAuthVal, gotPathVal string
|
||||
g.Eventually(gotBody).Should(Receive(&gotBodyVal), "upstream never saw body")
|
||||
g.Expect(gotBodyVal).To(Equal(`{"prompt":"hi"}`))
|
||||
g.Eventually(gotAuth).Should(Receive(&gotAuthVal), "upstream never saw auth header")
|
||||
g.Expect(gotAuthVal).To(Equal("Bearer sk-fake"))
|
||||
g.Eventually(gotPath).Should(Receive(&gotPathVal), "upstream never saw path")
|
||||
g.Expect(gotPathVal).To(Equal("/v1/chat/completions"))
|
||||
}
|
||||
|
||||
func TestForward_AnthropicAuthHeader(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
gotXAPIKey := make(chan string, 1)
|
||||
gotVersion := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotXAPIKey <- r.Header.Get("x-api-key")
|
||||
gotVersion <- r.Header.Get("anthropic-version")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
t.Setenv("CLOUD_PROXY_ANTHROPIC_KEY", "sk-ant-fake")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_KEY",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.Send(&pb.ForwardRequest{Path: "/v1/messages", Method: "POST"})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_ = stream.CloseSend()
|
||||
_, _ = stream.Recv() // drain status
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
g.Expect(<-gotXAPIKey).To(Equal("sk-ant-fake"))
|
||||
g.Expect(<-gotVersion).NotTo(BeEmpty())
|
||||
}
|
||||
|
||||
func TestLoad_ValidatesConfig(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cp := NewCloudProxy()
|
||||
|
||||
err := cp.Load(&pb.ModelOptions{})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("ProxyOptions"))
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("upstream_url"))
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
Mode: "rewrite",
|
||||
}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("unknown mode"))
|
||||
|
||||
// translate + openai should load successfully (Phase 5).
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com/v1/chat/completions",
|
||||
Mode: modeTranslate,
|
||||
Provider: providerOpenAI,
|
||||
}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// translate + anthropic should load successfully (Phase 6).
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com/v1/messages",
|
||||
Mode: modeTranslate,
|
||||
Provider: providerAnthropic,
|
||||
}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
ApiKeyEnv: "DEFINITELY_UNSET_ENV_VAR_XYZ",
|
||||
}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("unset"))
|
||||
}
|
||||
|
||||
func TestForward_RejectsWithoutLoad(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cp := NewCloudProxy()
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_ = stream.CloseSend()
|
||||
_, err = stream.Recv()
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("not loaded"))
|
||||
}
|
||||
|
||||
func hasHeader(hs []*pb.ForwardHeader, name, value string) bool {
|
||||
for _, h := range hs {
|
||||
if strings.EqualFold(h.GetName(), name) && h.GetValue() == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
6
backend/go/cloud-proxy/run.sh
Executable file
6
backend/go/cloud-proxy/run.sh
Executable file
@@ -0,0 +1,6 @@
|
||||
#!/bin/bash
# Launcher for the cloud-proxy backend: runs the cloud-proxy binary located
# next to this script and forwards all arguments to it unchanged.
set -ex

# "$0" and "$CURDIR" are quoted so that an install path containing spaces or
# glob characters is not word-split or expanded.
CURDIR=$(dirname "$(realpath "$0")")

exec "$CURDIR/cloud-proxy" "$@"
|
||||
232
backend/go/cloud-proxy/toolcalls_test.go
Normal file
232
backend/go/cloud-proxy/toolcalls_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// OpenAI: non-streaming tool call response. Verify the response is
|
||||
// mapped to Reply.ChatDeltas[].ToolCalls with id/name/arguments intact,
|
||||
// and usage tokens land on Reply.PromptTokens / Reply.Tokens.
|
||||
func TestPredictRich_OpenAI_ToolCalls(t *testing.T) {
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{
|
||||
"id":"resp-1",
|
||||
"choices":[{
|
||||
"index":0,
|
||||
"message":{
|
||||
"role":"assistant",
|
||||
"content":"",
|
||||
"tool_calls":[
|
||||
{"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"SF\"}"}},
|
||||
{"id":"call_def","type":"function","function":{"name":"get_time","arguments":"{\"tz\":\"PT\"}"}}
|
||||
]
|
||||
},
|
||||
"finish_reason":"tool_calls"
|
||||
}],
|
||||
"usage":{"prompt_tokens":42,"completion_tokens":18,"total_tokens":60}
|
||||
}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal(""))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(42)))
|
||||
g.Expect(reply.GetTokens()).To(Equal(int32(18)))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(1))
|
||||
tcs := reply.GetChatDeltas()[0].GetToolCalls()
|
||||
g.Expect(tcs).To(HaveLen(2))
|
||||
g.Expect(tcs[0].GetId()).To(Equal("call_abc"))
|
||||
g.Expect(tcs[0].GetName()).To(Equal("get_weather"))
|
||||
g.Expect(tcs[0].GetArguments()).To(ContainSubstring(`"location":"SF"`))
|
||||
g.Expect(tcs[1].GetId()).To(Equal("call_def"))
|
||||
g.Expect(tcs[1].GetName()).To(Equal("get_time"))
|
||||
}
|
||||
|
||||
// OpenAI: streaming tool call. Arguments arrive as a sequence of
|
||||
// delta chunks; the consumer is expected to concatenate by tool index.
|
||||
// Verify each chunk reaches the channel and the assembled arguments
|
||||
// match the input.
|
||||
func TestPredictStreamRich_OpenAI_ToolCallDeltas(t *testing.T) {
|
||||
chunks := []string{
|
||||
// Frame 0: announce the tool call (id + name, no args yet).
|
||||
`{"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_xyz","type":"function","function":{"name":"search"}}]}}]}`,
|
||||
// Frames 1-3: arguments arrive in fragments.
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"q\":"}}]}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"clo"}}]}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"uds\"}"}}]}}]}`,
|
||||
// Stop frame.
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
}
|
||||
body := ""
|
||||
for _, c := range chunks {
|
||||
body += "data: " + c + "\n\n"
|
||||
}
|
||||
body += "data: [DONE]\n\n"
|
||||
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 16)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "find something"}},
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var (
|
||||
toolName string
|
||||
toolID string
|
||||
toolIndex int32 = -1
|
||||
argsBuf strings.Builder
|
||||
)
|
||||
for reply := range results {
|
||||
for _, cd := range reply.GetChatDeltas() {
|
||||
for _, tc := range cd.GetToolCalls() {
|
||||
if tc.GetName() != "" {
|
||||
toolName = tc.GetName()
|
||||
}
|
||||
if tc.GetId() != "" {
|
||||
toolID = tc.GetId()
|
||||
}
|
||||
if toolIndex == -1 {
|
||||
toolIndex = tc.GetIndex()
|
||||
}
|
||||
argsBuf.WriteString(tc.GetArguments())
|
||||
}
|
||||
}
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(toolID).To(Equal("call_xyz"))
|
||||
g.Expect(toolName).To(Equal("search"))
|
||||
g.Expect(toolIndex).To(Equal(int32(0)))
|
||||
g.Expect(argsBuf.String()).To(Equal(`{"q":"clouds"}`))
|
||||
}
|
||||
|
||||
// Anthropic: non-streaming tool_use block. The block appears in
|
||||
// Content[] alongside text blocks; the input field is a structured
|
||||
// JSON object. Map to ToolCallDelta with arguments as serialised JSON
|
||||
// so downstream OpenAI-shaped consumers see a familiar format.
|
||||
func TestPredictRich_Anthropic_ToolUse(t *testing.T) {
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{
|
||||
"id":"msg_1","type":"message","role":"assistant",
|
||||
"content":[
|
||||
{"type":"text","text":"Let me check that."},
|
||||
{"type":"tool_use","id":"toolu_01","name":"weather","input":{"location":"SF"}}
|
||||
],
|
||||
"model":"claude","usage":{"input_tokens":12,"output_tokens":34}
|
||||
}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}},
|
||||
Tokens: 64,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal("Let me check that."))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(12)))
|
||||
g.Expect(reply.GetTokens()).To(Equal(int32(34)))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(1))
|
||||
g.Expect(reply.GetChatDeltas()[0].GetToolCalls()).To(HaveLen(1))
|
||||
tc := reply.GetChatDeltas()[0].GetToolCalls()[0]
|
||||
g.Expect(tc.GetId()).To(Equal("toolu_01"))
|
||||
g.Expect(tc.GetName()).To(Equal("weather"))
|
||||
g.Expect(tc.GetArguments()).To(ContainSubstring(`"location":"SF"`))
|
||||
}
|
||||
|
||||
// Anthropic: streaming tool_use. content_block_start announces the
|
||||
// tool's id + name; input_json_delta events carry argument fragments
|
||||
// which the consumer accumulates. message_delta carries final usage.
|
||||
func TestPredictStreamRich_Anthropic_InputJSONDelta(t *testing.T) {
|
||||
frames := []string{
|
||||
"event: message_start\ndata: {\"type\":\"message_start\"}\n\n",
|
||||
// Block 0 is a tool_use; consumer should allocate a slot.
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_42\",\"name\":\"lookup\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"q\\\":\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"\\\"rain\\\"}\"}}\n\n",
|
||||
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_delta\ndata: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n",
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
body := strings.Join(frames, "")
|
||||
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 16)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "rain?"}},
|
||||
Tokens: 64,
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var (
|
||||
toolID, toolName string
|
||||
argsBuf strings.Builder
|
||||
finalTokens int32
|
||||
)
|
||||
for reply := range results {
|
||||
if reply.GetTokens() > 0 && len(reply.GetChatDeltas()) == 0 {
|
||||
finalTokens = reply.GetTokens()
|
||||
continue
|
||||
}
|
||||
for _, cd := range reply.GetChatDeltas() {
|
||||
for _, tc := range cd.GetToolCalls() {
|
||||
if tc.GetId() != "" {
|
||||
toolID = tc.GetId()
|
||||
}
|
||||
if tc.GetName() != "" {
|
||||
toolName = tc.GetName()
|
||||
}
|
||||
argsBuf.WriteString(tc.GetArguments())
|
||||
}
|
||||
}
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(toolID).To(Equal("toolu_42"))
|
||||
g.Expect(toolName).To(Equal("lookup"))
|
||||
g.Expect(argsBuf.String()).To(Equal(`{"q":"rain"}`))
|
||||
g.Expect(finalTokens).To(Equal(int32(7)))
|
||||
}
|
||||
|
||||
// Sanity: the legacy Predict() (string, error) signature still works
|
||||
// — it delegates to PredictRich and extracts Message.
|
||||
func TestPredict_LegacyWrapper_OpenAI(t *testing.T) {
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "hi"}}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hello"))
|
||||
}
|
||||
@@ -8,6 +8,6 @@ import (
|
||||
|
||||
func assert(cond bool, msg string) {
|
||||
if !cond {
|
||||
xlog.Fatal().Stack().Msg(msg)
|
||||
xlog.Fatal(msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,22 @@
|
||||
package main
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
// LocalAI's in-process vector store, exposed as a gRPC backend. Keep
|
||||
// the implementation here — NOT in a pkg/ library imported by the main
|
||||
// LocalAI process. The whole point of the gRPC surface is that vector
|
||||
// storage is a backend like any other (local-store, qdrant, pinecone,
|
||||
// ...) and can be swapped without changing the routing/recognition
|
||||
// code that consumes it.
|
||||
//
|
||||
// Storage is a sorted parallel-slice (keys [][]float32, values
|
||||
// [][]byte). Set/Delete preserve the sort so Get can binary-search.
|
||||
// Find scans linearly and uses a heap to keep the top-K — fine for
|
||||
// the tens-to-thousands range. The "normalized fast path" (Find when
|
||||
// every stored key has unit magnitude AND the query is normalized)
|
||||
// skips the per-item magnitude calculation.
|
||||
//
|
||||
// Concurrency: base.SingleThread serialises gRPC calls so the
|
||||
// non-thread-safe slice/heap manipulation here is sound.
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
@@ -10,30 +25,27 @@ import (
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/mudler/LocalAI/pkg/store"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
base.SingleThread
|
||||
|
||||
// The sorted keys
|
||||
keys [][]float32
|
||||
// The sorted values
|
||||
keys [][]float32
|
||||
values [][]byte
|
||||
|
||||
// If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions
|
||||
// TODO: Should we normalize incoming keys if they are not instead?
|
||||
// keysAreNormalized stays true until any non-unit-magnitude key
|
||||
// is added; once false, the magnitude-aware fallback path is
|
||||
// used by Find. Re-evaluated only at Set time, never again on
|
||||
// its own — a deletion of the offending key does NOT flip it
|
||||
// back to true (the bookkeeping cost would dominate the gain).
|
||||
keysAreNormalized bool
|
||||
// The first key decides the length of the keys
|
||||
keyLen int
|
||||
}
|
||||
|
||||
// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because
|
||||
// that's theoretically best for memory layout and cache locality, but this isn't optimized yet.
|
||||
type Pair struct {
|
||||
Key []float32
|
||||
Value []byte
|
||||
// keyLen is the dimension of every stored key. -1 means "no
|
||||
// keys yet, dimension is open". Dimension mismatch on Set is
|
||||
// rejected so cosine similarity (which requires equal-length
|
||||
// vectors) doesn't silently mis-match.
|
||||
keyLen int
|
||||
}
|
||||
|
||||
func NewStore() *Store {
|
||||
@@ -45,334 +57,278 @@ func NewStore() *Store {
|
||||
}
|
||||
}
|
||||
|
||||
func compareSlices(k1, k2 []float32) int {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
return slices.Compare(k1, k2)
|
||||
}
|
||||
|
||||
func hasKey(unsortedSlice [][]float32, target []float32) bool {
|
||||
return slices.ContainsFunc(unsortedSlice, func(k []float32) bool {
|
||||
return compareSlices(k, target) == 0
|
||||
})
|
||||
}
|
||||
|
||||
func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) {
|
||||
return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int {
|
||||
return compareSlices(k, t)
|
||||
})
|
||||
}
|
||||
|
||||
func isSortedPairs(kvs []Pair) bool {
|
||||
for i := 1; i < len(kvs); i++ {
|
||||
if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isSortedKeys(keys [][]float32) bool {
|
||||
for i := 1; i < len(keys); i++ {
|
||||
if compareSlices(keys[i-1], keys[i]) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 {
|
||||
ks := make([][]float32, len(keys))
|
||||
|
||||
for i, k := range keys {
|
||||
ks[i] = k.Floats
|
||||
}
|
||||
|
||||
slices.SortFunc(ks, compareSlices)
|
||||
|
||||
assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys)))
|
||||
assert(isSortedKeys(ks), "keys are not sorted")
|
||||
|
||||
return ks
|
||||
}
|
||||
|
||||
// Load is a no-op — local-store has no on-disk artefact. opts.Model is
|
||||
// just a namespace identifier; isolation is already handled upstream
|
||||
// (ModelLoader spawns a fresh local-store process per (backend,
|
||||
// model) tuple, so each namespace is its own Store{} instance).
|
||||
func (s *Store) Load(opts *pb.ModelOptions) error {
|
||||
// local-store is an in-memory vector store with no on-disk artefact to
|
||||
// load — opts.Model is just a namespace identifier. The old `!= ""` guard
|
||||
// rejected any non-empty model name with "not implemented", which broke
|
||||
// callers that pass a namespace to isolate embedding spaces (face vs.
|
||||
// voice biometrics both go through local-store but need distinct stores
|
||||
// so ArcFace 512-D and ECAPA-TDNN 192-D don't collide). Namespace
|
||||
// isolation is already handled upstream: ModelLoader spawns a fresh
|
||||
// local-store process per (backend, model) tuple, so each namespace is
|
||||
// its own Store{} instance. Nothing to do here beyond accepting the load.
|
||||
_ = opts
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort the incoming kvs and merge them with the existing sorted kvs
|
||||
func (s *Store) StoresSet(opts *pb.StoresSetOptions) error {
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to add")
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
values := store.UnwrapValues(opts.Values)
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("local-store: Set: no keys to add")
|
||||
}
|
||||
|
||||
if len(opts.Keys) != len(opts.Values) {
|
||||
return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values))
|
||||
if len(keys) != len(values) {
|
||||
return fmt.Errorf("local-store: Set: len(keys) = %d, len(values) = %d", len(keys), len(values))
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
s.keyLen = len(keys[0])
|
||||
} else if len(keys[0]) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Set: key length %d does not match existing %d", len(keys[0]), s.keyLen)
|
||||
}
|
||||
|
||||
kvs := make([]Pair, len(opts.Keys))
|
||||
|
||||
for i, k := range opts.Keys {
|
||||
if s.keysAreNormalized && !isNormalized(k.Floats) {
|
||||
kvs := make([]incomingPair, len(keys))
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Set: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
if s.keysAreNormalized && !isNormalized(k) {
|
||||
s.keysAreNormalized = false
|
||||
var sample []float32
|
||||
if len(s.keys) > 5 {
|
||||
sample = k.Floats[:5]
|
||||
} else {
|
||||
sample = k.Floats
|
||||
}
|
||||
xlog.Debug("Key is not normalized", "sample", sample)
|
||||
}
|
||||
|
||||
kvs[i] = Pair{
|
||||
Key: k.Floats,
|
||||
Value: opts.Values[i].Bytes,
|
||||
}
|
||||
kvs[i] = incomingPair{key: k, value: values[i]}
|
||||
}
|
||||
|
||||
slices.SortFunc(kvs, func(a, b Pair) int {
|
||||
return compareSlices(a.Key, b.Key)
|
||||
})
|
||||
|
||||
assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys)))
|
||||
assert(isSortedPairs(kvs), "keys are not sorted")
|
||||
|
||||
l := len(kvs) + len(s.keys)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
i, j := 0, 0
|
||||
for {
|
||||
if i+j >= l {
|
||||
break
|
||||
}
|
||||
|
||||
if i >= len(kvs) {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
continue
|
||||
}
|
||||
|
||||
if j >= len(s.keys) {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
c := compareSlices(kvs[i].Key, s.keys[j])
|
||||
if c < 0 {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
} else if c > 0 {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
} else {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l))
|
||||
assert(isSortedKeys(merge_ks), "merge keys are not sorted")
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
slices.SortFunc(kvs, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) })
|
||||
|
||||
merged := mergeSortedPairs(s.keys, s.values, kvs)
|
||||
s.keys = merged.keys
|
||||
s.values = merged.values
|
||||
assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Set: s.keys not sorted post-merge")
|
||||
assert(len(s.keys) == len(s.values), "Set: keys/values length skew")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error {
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to delete")
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("local-store: Delete: no keys to delete")
|
||||
}
|
||||
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to add")
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
ks := sortIntoKeySlicese(opts.Keys)
|
||||
|
||||
l := len(s.keys) - len(ks)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
tail_ks := s.keys
|
||||
tail_vs := s.values
|
||||
for _, k := range ks {
|
||||
j, found := findInSortedSlice(tail_ks, k)
|
||||
|
||||
if found {
|
||||
merge_ks = append(merge_ks, tail_ks[:j]...)
|
||||
merge_vs = append(merge_vs, tail_vs[:j]...)
|
||||
tail_ks = tail_ks[j+1:]
|
||||
tail_vs = tail_vs[j+1:]
|
||||
} else {
|
||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k))
|
||||
}
|
||||
|
||||
xlog.Debug("Delete", "found", found, "tailLen", len(tail_ks), "j", j, "mergeKeysLen", len(merge_ks), "mergeValuesLen", len(merge_vs))
|
||||
}
|
||||
|
||||
merge_ks = append(merge_ks, tail_ks...)
|
||||
merge_vs = append(merge_vs, tail_vs...)
|
||||
|
||||
assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys)))
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
|
||||
assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l))
|
||||
assert(isSortedKeys(s.keys), "keys are not sorted")
|
||||
assert(func() bool {
|
||||
for _, k := range ks {
|
||||
if _, found := findInSortedSlice(s.keys, k); found {
|
||||
return false
|
||||
if s.keyLen != -1 {
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Delete: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}(), "Keys to delete still present")
|
||||
|
||||
if len(s.keys) != l {
|
||||
xlog.Debug("Delete: Some keys not found", "keysLen", len(s.keys), "expectedLen", l)
|
||||
}
|
||||
sortedKeys := append([][]float32(nil), keys...)
|
||||
slices.SortFunc(sortedKeys, slices.Compare[[]float32])
|
||||
|
||||
mergedK := make([][]float32, 0, len(s.keys))
|
||||
mergedV := make([][]byte, 0, len(s.keys))
|
||||
tailK := s.keys
|
||||
tailV := s.values
|
||||
for _, k := range sortedKeys {
|
||||
j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32])
|
||||
if ok {
|
||||
mergedK = append(mergedK, tailK[:j]...)
|
||||
mergedV = append(mergedV, tailV[:j]...)
|
||||
tailK = tailK[j+1:]
|
||||
tailV = tailV[j+1:]
|
||||
}
|
||||
}
|
||||
mergedK = append(mergedK, tailK...)
|
||||
mergedV = append(mergedV, tailV...)
|
||||
s.keys = mergedK
|
||||
s.values = mergedV
|
||||
assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Delete: s.keys not sorted post-merge")
|
||||
assert(len(s.keys) == len(s.values), "Delete: keys/values length skew")
|
||||
return nil
|
||||
}
|
||||
|
||||
// StoresGet fetches values for the given keys. Missing keys are
|
||||
// omitted from the result rather than reported as an error — callers
|
||||
// compare returned-key length against requested-key length to detect
|
||||
// them. Returned slices are aligned.
|
||||
func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) {
|
||||
pbKeys := make([]*pb.StoresKey, 0, len(opts.Keys))
|
||||
pbValues := make([]*pb.StoresValue, 0, len(opts.Keys))
|
||||
ks := sortIntoKeySlicese(opts.Keys)
|
||||
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
if len(s.keys) == 0 {
|
||||
xlog.Debug("Get: No keys in store")
|
||||
return pb.StoresGetResult{}, nil
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return pb.StoresGetResult{}, fmt.Errorf("Try to get a key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
if s.keyLen != -1 {
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return pb.StoresGetResult{}, fmt.Errorf("local-store: Get: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
sortedKeys := append([][]float32(nil), keys...)
|
||||
slices.SortFunc(sortedKeys, slices.Compare[[]float32])
|
||||
|
||||
tail_k := s.keys
|
||||
tail_v := s.values
|
||||
for i, k := range ks {
|
||||
j, found := findInSortedSlice(tail_k, k)
|
||||
|
||||
if found {
|
||||
pbKeys = append(pbKeys, &pb.StoresKey{
|
||||
Floats: k,
|
||||
})
|
||||
pbValues = append(pbValues, &pb.StoresValue{
|
||||
Bytes: tail_v[j],
|
||||
})
|
||||
|
||||
tail_k = tail_k[j+1:]
|
||||
tail_v = tail_v[j+1:]
|
||||
} else {
|
||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: i=%d, %v", i, k))
|
||||
var foundKeys [][]float32
|
||||
var foundValues [][]byte
|
||||
tailK := s.keys
|
||||
tailV := s.values
|
||||
for _, k := range sortedKeys {
|
||||
j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
foundKeys = append(foundKeys, tailK[j])
|
||||
foundValues = append(foundValues, tailV[j])
|
||||
tailK = tailK[j+1:]
|
||||
tailV = tailV[j+1:]
|
||||
}
|
||||
|
||||
if len(pbKeys) != len(opts.Keys) {
|
||||
xlog.Debug("Get: Some keys not found", "pbKeysLen", len(pbKeys), "optsKeysLen", len(opts.Keys), "storeKeysLen", len(s.keys))
|
||||
}
|
||||
|
||||
return pb.StoresGetResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Keys: store.WrapKeys(foundKeys),
|
||||
Values: store.WrapValues(foundValues),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StoresFind returns the topK nearest stored entries by cosine
|
||||
// similarity, ordered most-similar first. An empty store returns
|
||||
// empty slices and no error.
|
||||
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
query := opts.Key.Floats
|
||||
topK := int(opts.TopK)
|
||||
if topK < 1 {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: topK = %d, must be >= 1", topK)
|
||||
}
|
||||
if len(s.keys) == 0 {
|
||||
return pb.StoresFindResult{}, nil
|
||||
}
|
||||
if len(query) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: query length %d does not match existing %d", len(query), s.keyLen)
|
||||
}
|
||||
|
||||
var keys [][]float32
|
||||
var values [][]byte
|
||||
var sims []float32
|
||||
if s.keysAreNormalized && isNormalized(query) {
|
||||
keys, values, sims = s.findNormalized(query, topK)
|
||||
} else {
|
||||
keys, values, sims = s.findFallback(query, topK)
|
||||
}
|
||||
return pb.StoresFindResult{
|
||||
Keys: store.WrapKeys(keys),
|
||||
Values: store.WrapValues(values),
|
||||
Similarities: sims,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) findNormalized(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
assert(s.keysAreNormalized, "findNormalized: s.keysAreNormalized is false")
|
||||
assert(isNormalized(query), "findNormalized: query is not unit-length")
|
||||
pq := make(priorityQueue, 0, topK)
|
||||
heap.Init(&pq)
|
||||
for i, k := range s.keys {
|
||||
var dot float32
|
||||
for j := range k {
|
||||
dot += query[j] * k[j]
|
||||
}
|
||||
assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("findNormalized: dot %f out of [-1, 1] — keysAreNormalized invariant violated", dot))
|
||||
heap.Push(&pq, &priorityItem{similarity: dot, key: k, value: s.values[i]})
|
||||
if pq.Len() > topK {
|
||||
heap.Pop(&pq)
|
||||
}
|
||||
}
|
||||
return drainPQ(&pq)
|
||||
}
|
||||
|
||||
func (s *Store) findFallback(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
var qmag float64
|
||||
for _, v := range query {
|
||||
qmag += float64(v) * float64(v)
|
||||
}
|
||||
qmag = math.Sqrt(qmag)
|
||||
pq := make(priorityQueue, 0, topK)
|
||||
heap.Init(&pq)
|
||||
for i, k := range s.keys {
|
||||
var dot, kmag float64
|
||||
for j := range k {
|
||||
dot += float64(query[j]) * float64(k[j])
|
||||
kmag += float64(k[j]) * float64(k[j])
|
||||
}
|
||||
denom := qmag * math.Sqrt(kmag)
|
||||
var sim float32
|
||||
if denom > 0 {
|
||||
sim = float32(dot / denom)
|
||||
}
|
||||
heap.Push(&pq, &priorityItem{similarity: sim, key: k, value: s.values[i]})
|
||||
if pq.Len() > topK {
|
||||
heap.Pop(&pq)
|
||||
}
|
||||
}
|
||||
return drainPQ(&pq)
|
||||
}
|
||||
|
||||
func isNormalized(k []float32) bool {
|
||||
var sum float64
|
||||
|
||||
for _, v := range k {
|
||||
v64 := float64(v)
|
||||
sum += v64 * v64
|
||||
sum += float64(v) * float64(v)
|
||||
}
|
||||
|
||||
s := math.Sqrt(sum)
|
||||
|
||||
return s >= 0.99 && s <= 1.01
|
||||
mag := math.Sqrt(sum)
|
||||
return mag >= 0.99 && mag <= 1.01
|
||||
}
|
||||
|
||||
// TODO: This we could replace with handwritten SIMD code
|
||||
func normalizedCosineSimilarity(k1, k2 []float32) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
type incomingPair struct {
|
||||
key []float32
|
||||
value []byte
|
||||
}
|
||||
|
||||
var dot float32
|
||||
for i := range len(k1) {
|
||||
dot += k1[i] * k2[i]
|
||||
type pairs struct {
|
||||
keys [][]float32
|
||||
values [][]byte
|
||||
}
|
||||
|
||||
// mergeSortedPairs merges (existing, incoming) into a fresh sorted
|
||||
// slice. Equal keys take the incoming value — Set is upsert.
|
||||
func mergeSortedPairs(existingK [][]float32, existingV [][]byte, incoming []incomingPair) pairs {
|
||||
assert(slices.IsSortedFunc(existingK, slices.Compare[[]float32]), "mergeSortedPairs: existing not sorted")
|
||||
assert(slices.IsSortedFunc(incoming, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) }), "mergeSortedPairs: incoming not sorted")
|
||||
l := len(existingK) + len(incoming)
|
||||
mk := make([][]float32, 0, l)
|
||||
mv := make([][]byte, 0, l)
|
||||
i, j := 0, 0
|
||||
for i < len(incoming) || j < len(existingK) {
|
||||
switch {
|
||||
case j >= len(existingK):
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
case i >= len(incoming):
|
||||
mk = append(mk, existingK[j])
|
||||
mv = append(mv, existingV[j])
|
||||
j++
|
||||
default:
|
||||
c := slices.Compare(incoming[i].key, existingK[j])
|
||||
switch {
|
||||
case c < 0:
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
case c > 0:
|
||||
mk = append(mk, existingK[j])
|
||||
mv = append(mv, existingV[j])
|
||||
j++
|
||||
default:
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
j++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("dot = %f", dot))
|
||||
|
||||
// 2.0 * (1.0 - dot) would be the Euclidean distance
|
||||
return dot
|
||||
return pairs{keys: mk, values: mv}
|
||||
}
|
||||
|
||||
type PriorityItem struct {
|
||||
Similarity float32
|
||||
Key []float32
|
||||
Value []byte
|
||||
type priorityItem struct {
|
||||
similarity float32
|
||||
key []float32
|
||||
value []byte
|
||||
}
|
||||
|
||||
type PriorityQueue []*PriorityItem
|
||||
type priorityQueue []*priorityItem
|
||||
|
||||
func (pq PriorityQueue) Len() int { return len(pq) }
|
||||
|
||||
func (pq PriorityQueue) Less(i, j int) bool {
|
||||
// Inverted because the most similar should be at the top
|
||||
return pq[i].Similarity < pq[j].Similarity
|
||||
}
|
||||
|
||||
func (pq PriorityQueue) Swap(i, j int) {
|
||||
pq[i], pq[j] = pq[j], pq[i]
|
||||
}
|
||||
|
||||
func (pq *PriorityQueue) Push(x any) {
|
||||
item := x.(*PriorityItem)
|
||||
*pq = append(*pq, item)
|
||||
}
|
||||
|
||||
func (pq *PriorityQueue) Pop() any {
|
||||
func (pq priorityQueue) Len() int { return len(pq) }
|
||||
func (pq priorityQueue) Less(i, j int) bool { return pq[i].similarity < pq[j].similarity }
|
||||
func (pq priorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] }
|
||||
func (pq *priorityQueue) Push(x any) { *pq = append(*pq, x.(*priorityItem)) }
|
||||
func (pq *priorityQueue) Pop() any {
|
||||
old := *pq
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
@@ -380,142 +336,16 @@ func (pq *PriorityQueue) Pop() any {
|
||||
return item
|
||||
}
|
||||
|
||||
func (s *Store) StoresFindNormalized(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
top_ks := make(PriorityQueue, 0, int(opts.TopK))
|
||||
heap.Init(&top_ks)
|
||||
|
||||
for i, k := range s.keys {
|
||||
sim := normalizedCosineSimilarity(tk, k)
|
||||
heap.Push(&top_ks, &PriorityItem{
|
||||
Similarity: sim,
|
||||
Key: k,
|
||||
Value: s.values[i],
|
||||
})
|
||||
|
||||
if top_ks.Len() > int(opts.TopK) {
|
||||
heap.Pop(&top_ks)
|
||||
}
|
||||
}
|
||||
|
||||
similarities := make([]float32, top_ks.Len())
|
||||
pbKeys := make([]*pb.StoresKey, top_ks.Len())
|
||||
pbValues := make([]*pb.StoresValue, top_ks.Len())
|
||||
|
||||
for i := top_ks.Len() - 1; i >= 0; i-- {
|
||||
item := heap.Pop(&top_ks).(*PriorityItem)
|
||||
|
||||
similarities[i] = item.Similarity
|
||||
pbKeys[i] = &pb.StoresKey{
|
||||
Floats: item.Key,
|
||||
}
|
||||
pbValues[i] = &pb.StoresValue{
|
||||
Bytes: item.Value,
|
||||
}
|
||||
}
|
||||
|
||||
return pb.StoresFindResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Similarities: similarities,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
var dot, mag2 float64
|
||||
for i := range len(k1) {
|
||||
dot += float64(k1[i] * k2[i])
|
||||
mag2 += float64(k2[i] * k2[i])
|
||||
}
|
||||
|
||||
sim := float32(dot / (mag1 * math.Sqrt(mag2)))
|
||||
|
||||
assert(sim >= -1.01 && sim <= 1.01, fmt.Sprintf("sim = %f", sim))
|
||||
|
||||
return sim
|
||||
}
|
||||
|
||||
func (s *Store) StoresFindFallback(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
top_ks := make(PriorityQueue, 0, int(opts.TopK))
|
||||
heap.Init(&top_ks)
|
||||
|
||||
var mag1 float64
|
||||
for _, v := range tk {
|
||||
mag1 += float64(v * v)
|
||||
}
|
||||
mag1 = math.Sqrt(mag1)
|
||||
|
||||
for i, k := range s.keys {
|
||||
dist := cosineSimilarity(tk, k, mag1)
|
||||
heap.Push(&top_ks, &PriorityItem{
|
||||
Similarity: dist,
|
||||
Key: k,
|
||||
Value: s.values[i],
|
||||
})
|
||||
|
||||
if top_ks.Len() > int(opts.TopK) {
|
||||
heap.Pop(&top_ks)
|
||||
}
|
||||
}
|
||||
|
||||
similarities := make([]float32, top_ks.Len())
|
||||
pbKeys := make([]*pb.StoresKey, top_ks.Len())
|
||||
pbValues := make([]*pb.StoresValue, top_ks.Len())
|
||||
|
||||
for i := top_ks.Len() - 1; i >= 0; i-- {
|
||||
item := heap.Pop(&top_ks).(*PriorityItem)
|
||||
|
||||
similarities[i] = item.Similarity
|
||||
pbKeys[i] = &pb.StoresKey{
|
||||
Floats: item.Key,
|
||||
}
|
||||
pbValues[i] = &pb.StoresValue{
|
||||
Bytes: item.Value,
|
||||
}
|
||||
}
|
||||
|
||||
return pb.StoresFindResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Similarities: similarities,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
|
||||
if len(tk) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("Try to find key with length %d when existing length is %d", len(tk), s.keyLen)
|
||||
}
|
||||
|
||||
if opts.TopK < 1 {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("opts.TopK = %d, must be >= 1", opts.TopK)
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Key.Floats)
|
||||
} else {
|
||||
if len(opts.Key.Floats) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Key.Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
if s.keysAreNormalized && isNormalized(tk) {
|
||||
return s.StoresFindNormalized(opts)
|
||||
} else {
|
||||
if s.keysAreNormalized {
|
||||
var sample []float32
|
||||
if len(s.keys) > 5 {
|
||||
sample = tk[:5]
|
||||
} else {
|
||||
sample = tk
|
||||
}
|
||||
xlog.Debug("Trying to compare non-normalized key with normalized keys", "sample", sample)
|
||||
}
|
||||
|
||||
return s.StoresFindFallback(opts)
|
||||
func drainPQ(pq *priorityQueue) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
n := pq.Len()
|
||||
keys = make([][]float32, n)
|
||||
values = make([][]byte, n)
|
||||
similarities = make([]float32, n)
|
||||
for i := n - 1; i >= 0; i-- {
|
||||
item := heap.Pop(pq).(*priorityItem)
|
||||
keys[i] = item.key
|
||||
values[i] = item.value
|
||||
similarities[i] = item.similarity
|
||||
}
|
||||
return keys, values, similarities
|
||||
}
|
||||
|
||||
13
backend/go/local-store/store_suite_test.go
Normal file
13
backend/go/local-store/store_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestLocalStore is the `go test` entry point for this package's Ginkgo
// specs: it installs Ginkgo's Fail as the Gomega failure handler and then
// runs the registered specs under the suite name "local-store test suite".
func TestLocalStore(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "local-store test suite")
|
||||
}
|
||||
284
backend/go/local-store/store_test.go
Normal file
284
backend/go/local-store/store_test.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package main
|
||||
|
||||
// Regression suite for the local-store gRPC backend. Exercises the
|
||||
// Stores{Set,Get,Find,Delete} surface — the only public contract.
|
||||
// Callers (face/voice recognition, the routing KNN classifier) reach
|
||||
// this code via grpc.Backend, so testing at the wire-shaped boundary
|
||||
// matches the production import shape.
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("StoresSet", func() {
|
||||
It("rejects empty input", func() {
|
||||
Expect(NewStore().StoresSet(&pb.StoresSetOptions{})).NotTo(Succeed(), "Set with no keys should fail")
|
||||
})
|
||||
|
||||
It("rejects key/value length mismatch", func() {
|
||||
err := NewStore().StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("a"), []byte("b")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "len(keys) != len(values) should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch on later add", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("3d")})
|
||||
err := s.StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("2d")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "dimension mismatch on later Set should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch within batch", func() {
|
||||
err := NewStore().StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0, 0}, {1, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("3d"), []byte("2d")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "mixed-dimension within one batch should fail")
|
||||
})
|
||||
|
||||
It("merges sorted and updates existing key", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.3, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("c"), []byte("a")})
|
||||
mustSet(s, [][]float32{{0.2, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("b"), []byte("a-updated")})
|
||||
Expect(s.keys).To(HaveLen(3))
|
||||
got := singleGet(s, []float32{0.1, 0, 0})
|
||||
Expect(string(got)).To(Equal("a-updated"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresGet", func() {
|
||||
It("round-trips multi-key", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}},
|
||||
[][]byte{[]byte("a"), []byte("b"), []byte("c")},
|
||||
)
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{
|
||||
Keys: wrapKeys([][]float32{{0.7, 0.8, 0.9}, {0.1, 0.2, 0.3}}),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("omits missing keys rather than erroring", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")})
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{
|
||||
Keys: wrapKeys([][]float32{{0.1, 0, 0}, {0.9, 0, 0}}),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(1))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresDelete", func() {
|
||||
It("removes and preserves sort", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{{0.1, 0, 0}, {0.2, 0, 0}, {0.3, 0, 0}, {0.4, 0, 0}},
|
||||
[][]byte{[]byte("a"), []byte("b"), []byte("c"), []byte("d")},
|
||||
)
|
||||
Expect(s.StoresDelete(&pb.StoresDeleteOptions{
|
||||
Keys: wrapKeys([][]float32{{0.2, 0, 0}, {0.4, 0, 0}}),
|
||||
})).To(Succeed())
|
||||
Expect(s.keys).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("tolerates missing keys", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")})
|
||||
Expect(s.StoresDelete(&pb.StoresDeleteOptions{
|
||||
Keys: wrapKeys([][]float32{{0.9, 0, 0}}),
|
||||
})).To(Succeed(), "delete of missing key should succeed")
|
||||
Expect(s.keys).To(HaveLen(1))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresFind", func() {
|
||||
It("returns normalized top-K", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{
|
||||
normalizeVec([]float32{1, 0, 0}),
|
||||
normalizeVec([]float32{0, 1, 0}),
|
||||
normalizeVec([]float32{0, 0, 1}),
|
||||
},
|
||||
[][]byte{[]byte("x"), []byte("y"), []byte("z")},
|
||||
)
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: normalizeVec([]float32{0.9, 0.1, 0})},
|
||||
TopK: 2,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
Expect(res.Similarities[0]).To(BeNumerically(">=", res.Similarities[1]), "results not sorted desc by similarity")
|
||||
Expect(string(res.Values[0].Bytes)).To(Equal("x"))
|
||||
})
|
||||
|
||||
It("falls back for non-normalized keys", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{2, 0, 0}, {0, 3, 0}}, [][]byte{[]byte("x"), []byte("y")})
|
||||
Expect(s.keysAreNormalized).To(BeFalse(), "store should report non-normalized after Set with magnitude > 1")
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{4, 0, 0}},
|
||||
TopK: 1,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(string(res.Values[0].Bytes)).To(Equal("x"))
|
||||
Expect(res.Similarities[0]).To(BeNumerically(">=", float32(0.99)))
|
||||
Expect(res.Similarities[0]).To(BeNumerically("<=", float32(1.01)))
|
||||
})
|
||||
|
||||
It("rejects zero topK", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")})
|
||||
_, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0, 0}},
|
||||
TopK: 0,
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "Find with topK=0 should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")})
|
||||
_, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0}},
|
||||
TopK: 1,
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "Find with mismatched dimension should fail")
|
||||
})
|
||||
|
||||
It("returns empty result on empty store", func() {
|
||||
res, err := NewStore().StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0, 0}},
|
||||
TopK: 5,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred(), "Find on empty store should succeed")
|
||||
Expect(res.Keys).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("handles topK larger than store", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{normalizeVec([]float32{1, 0, 0}), normalizeVec([]float32{0, 1, 0})},
|
||||
[][]byte{[]byte("x"), []byte("y")},
|
||||
)
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: normalizeVec([]float32{1, 0, 0})},
|
||||
TopK: 10,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresLoad", func() {
|
||||
It("is a no-op", func() {
|
||||
Expect(NewStore().Load(&pb.ModelOptions{Model: "any-namespace"})).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
func BenchmarkStoresFindNormalized(b *testing.B) {
|
||||
const dim = 768
|
||||
for _, n := range []int{8, 32, 128, 512} {
|
||||
b.Run(fmtN(n), func(b *testing.B) {
|
||||
s := buildStore(b, n, dim)
|
||||
query := normalizeVec(randVec(dim, 42))
|
||||
req := &pb.StoresFindOptions{Key: &pb.StoresKey{Floats: query}, TopK: 1}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := s.StoresFind(req); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- test helpers ---
|
||||
|
||||
func mustSet(s *Store, keys [][]float32, values [][]byte) {
|
||||
ExpectWithOffset(1, s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)})).To(Succeed())
|
||||
}
|
||||
|
||||
func singleGet(s *Store, key []float32) []byte {
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{Keys: wrapKeys([][]float32{key})})
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred())
|
||||
if len(res.Values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return res.Values[0].Bytes
|
||||
}
|
||||
|
||||
func wrapKeys(in [][]float32) []*pb.StoresKey {
|
||||
out := make([]*pb.StoresKey, len(in))
|
||||
for i, k := range in {
|
||||
out[i] = &pb.StoresKey{Floats: k}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func wrapValues(in [][]byte) []*pb.StoresValue {
|
||||
out := make([]*pb.StoresValue, len(in))
|
||||
for i, v := range in {
|
||||
out[i] = &pb.StoresValue{Bytes: v}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildStore(tb testing.TB, n, dim int) *Store {
|
||||
tb.Helper()
|
||||
s := NewStore()
|
||||
keys := make([][]float32, n)
|
||||
values := make([][]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
keys[i] = normalizeVec(randVec(dim, int64(i)+1))
|
||||
values[i] = []byte{byte(i)}
|
||||
}
|
||||
if err := s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)}); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func randVec(dim int, seed int64) []float32 {
|
||||
r := rand.New(rand.NewPCG(uint64(seed), 0xabcdef))
|
||||
v := make([]float32, dim)
|
||||
for i := range v {
|
||||
v[i] = float32(r.NormFloat64())
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// normalizeVec returns a copy of v scaled to unit Euclidean length. The
// norm and the division are computed in float64. A vector whose norm is
// exactly zero cannot be scaled and is returned as-is (the same slice, not
// a copy).
func normalizeVec(v []float32) []float32 {
	sumSq := 0.0
	for i := range v {
		c := float64(v[i])
		sumSq += c * c
	}

	norm := math.Sqrt(sumSq)
	if norm == 0 {
		return v
	}

	scaled := make([]float32, 0, len(v))
	for _, c := range v {
		scaled = append(scaled, float32(float64(c)/norm))
	}
	return scaled
}
|
||||
|
||||
// fmtN returns the sub-benchmark label ("n=<size>") for the store sizes the
// benchmark uses: 8, 32, 128 and 512. Any other value yields "".
func fmtN(n int) string {
	switch n {
	case 8:
		return "n=8"
	case 32:
		return "n=32"
	case 128:
		return "n=128"
	case 512:
		return "n=512"
	default:
		return ""
	}
}
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=90e87bc846f17059771efb8aaa31e9ef0cab6f78
|
||||
STABLEDIFFUSION_GGML_VERSION?=a397e03488cc27e1a42da646b82dfce9f50741c0
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -376,6 +376,8 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
const char *clip_g_path = "";
|
||||
const char *t5xxl_path = "";
|
||||
const char *vae_path = "";
|
||||
const char *audio_vae_path = "";
|
||||
const char *embeddings_connectors_path = "";
|
||||
const char *scheduler_str = "";
|
||||
const char *sampler = "";
|
||||
const char *clip_vision_path = "";
|
||||
@@ -431,6 +433,12 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
if (!strcmp(optname, "vae_path")) {
|
||||
vae_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "audio_vae_path")) {
|
||||
audio_vae_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "embeddings_connectors_path")) {
|
||||
embeddings_connectors_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "scheduler")) {
|
||||
scheduler_str = optval;
|
||||
}
|
||||
@@ -563,6 +571,8 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
ctx_params.diffusion_model_path = diffusion_model_path;
|
||||
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
|
||||
ctx_params.vae_path = vae_path;
|
||||
ctx_params.audio_vae_path = audio_vae_path;
|
||||
ctx_params.embeddings_connectors_path = embeddings_connectors_path;
|
||||
ctx_params.taesd_path = taesd_path;
|
||||
ctx_params.control_net_path = control_net_path;
|
||||
if (lora_dir && strlen(lora_dir) > 0) {
|
||||
@@ -1188,6 +1198,9 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
|
||||
p->high_noise_sample_params.scheduler = scheduler;
|
||||
p->high_noise_sample_params.flow_shift = flow_shift;
|
||||
|
||||
// Pin output fps in params; upstream uses it for audio sync (and we also mux at this rate).
|
||||
p->fps = fps;
|
||||
|
||||
// Load init/end reference images if provided (resized to output dims).
|
||||
uint8_t* init_buf = nullptr;
|
||||
uint8_t* end_buf = nullptr;
|
||||
@@ -1206,11 +1219,14 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
|
||||
|
||||
// Generate
|
||||
int num_frames_out = 0;
|
||||
sd_image_t* frames = generate_video(sd_c, p, &num_frames_out);
|
||||
sd_image_t* frames = nullptr;
|
||||
sd_audio_t* audio = nullptr;
|
||||
bool ok = generate_video(sd_c, p, &frames, &num_frames_out, &audio);
|
||||
std::free(p);
|
||||
|
||||
if (!frames || num_frames_out == 0) {
|
||||
if (!ok || !frames || num_frames_out == 0) {
|
||||
fprintf(stderr, "generate_video produced no frames\n");
|
||||
if (audio) free_sd_audio(audio);
|
||||
if (init_buf) free(init_buf);
|
||||
if (end_buf) free(end_buf);
|
||||
return 1;
|
||||
@@ -1224,6 +1240,7 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
|
||||
if (frames[i].data) free(frames[i].data);
|
||||
}
|
||||
free(frames);
|
||||
if (audio) free_sd_audio(audio);
|
||||
if (init_buf) free(init_buf);
|
||||
if (end_buf) free(end_buf);
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=338cce1e58133261753243802a0e7a430118866d
|
||||
WHISPER_CPP_VERSION?=0ccd896f5b882628e1c077f9769735ef4ce52860
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -847,6 +847,35 @@
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-vibevoice"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-vibevoice"
|
||||
icon: https://avatars.githubusercontent.com/u/6154722?s=200&v=4
|
||||
- &liquid-audio
|
||||
urls:
|
||||
- https://github.com/Liquid4All/liquid-audio
|
||||
- https://huggingface.co/LiquidAI/LFM2.5-Audio-1.5B
|
||||
description: |
|
||||
LiquidAI LFM2 / LFM2.5 Audio Python backend. End-to-end speech-to-speech, ASR,
|
||||
TTS (4 baked voices), and text chat from a single 1.5B model. Wraps the
|
||||
upstream `liquid-audio` package; supports fine-tuning via LocalAI's
|
||||
/v1/fine-tuning/jobs endpoint.
|
||||
tags:
|
||||
- speech-to-speech
|
||||
- any-to-any
|
||||
- text-to-speech
|
||||
- speech-to-text
|
||||
- TTS
|
||||
- ASR
|
||||
- realtime
|
||||
license: LFM-Open-License-v1.0
|
||||
name: "liquid-audio"
|
||||
alias: "liquid-audio"
|
||||
capabilities:
|
||||
nvidia: "cuda12-liquid-audio"
|
||||
intel: "intel-liquid-audio"
|
||||
amd: "rocm-liquid-audio"
|
||||
default: "cpu-liquid-audio"
|
||||
nvidia-cuda-13: "cuda13-liquid-audio"
|
||||
nvidia-cuda-12: "cuda12-liquid-audio"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-liquid-audio"
|
||||
icon: https://cdn-avatars.huggingface.co/v1/production/uploads/61b8e2ba285851687028d395/7_6D7rWrLxp2hb6OHSV1p.png
|
||||
- &qwen-tts
|
||||
urls:
|
||||
- https://github.com/QwenLM/Qwen3-TTS
|
||||
@@ -3437,6 +3466,77 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-vibevoice"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-vibevoice
|
||||
## liquid-audio
|
||||
- !!merge <<: *liquid-audio
|
||||
name: "liquid-audio-development"
|
||||
capabilities:
|
||||
nvidia: "cuda12-liquid-audio-development"
|
||||
intel: "intel-liquid-audio-development"
|
||||
amd: "rocm-liquid-audio-development"
|
||||
default: "cpu-liquid-audio-development"
|
||||
nvidia-cuda-13: "cuda13-liquid-audio-development"
|
||||
nvidia-cuda-12: "cuda12-liquid-audio-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-liquid-audio-development"
|
||||
- !!merge <<: *liquid-audio
|
||||
name: "cpu-liquid-audio"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-liquid-audio"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-liquid-audio
|
||||
- !!merge <<: *liquid-audio
|
||||
name: "cpu-liquid-audio-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-liquid-audio"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-liquid-audio
|
||||
- !!merge <<: *liquid-audio
|
||||
name: "cuda12-liquid-audio"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-liquid-audio"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-liquid-audio
|
||||
- !!merge <<: *liquid-audio
|
||||
name: "cuda12-liquid-audio-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-liquid-audio"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-liquid-audio
|
||||
- !!merge <<: *liquid-audio
|
||||
name: "cuda13-liquid-audio"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-liquid-audio"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-liquid-audio
|
||||
- !!merge <<: *liquid-audio
|
||||
name: "cuda13-liquid-audio-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-liquid-audio"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-liquid-audio
|
||||
- !!merge <<: *liquid-audio
|
||||
name: "intel-liquid-audio"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-liquid-audio"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-liquid-audio
|
||||
- !!merge <<: *liquid-audio
|
||||
name: "intel-liquid-audio-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-liquid-audio"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-liquid-audio
|
||||
- !!merge <<: *liquid-audio
|
||||
name: "rocm-liquid-audio"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-liquid-audio"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-liquid-audio
|
||||
- !!merge <<: *liquid-audio
|
||||
name: "rocm-liquid-audio-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-liquid-audio"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-liquid-audio
|
||||
- !!merge <<: *liquid-audio
|
||||
name: "cuda13-nvidia-l4t-arm64-liquid-audio"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-liquid-audio"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-liquid-audio
|
||||
- !!merge <<: *liquid-audio
|
||||
name: "cuda13-nvidia-l4t-arm64-liquid-audio-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-liquid-audio"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-liquid-audio
|
||||
## qwen-tts
|
||||
- !!merge <<: *qwen-tts
|
||||
name: "qwen-tts-development"
|
||||
|
||||
23
backend/python/liquid-audio/Makefile
Normal file
23
backend/python/liquid-audio/Makefile
Normal file
@@ -0,0 +1,23 @@
|
||||
.PHONY: liquid-audio
|
||||
liquid-audio:
|
||||
bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: liquid-audio
|
||||
@echo "Running liquid-audio..."
|
||||
bash run.sh
|
||||
@echo "liquid-audio run."
|
||||
|
||||
.PHONY: test
|
||||
test: liquid-audio
|
||||
@echo "Testing liquid-audio..."
|
||||
bash test.sh
|
||||
@echo "liquid-audio tested."
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
871
backend/python/liquid-audio/backend.py
Normal file
871
backend/python/liquid-audio/backend.py
Normal file
@@ -0,0 +1,871 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Liquid Audio backend for LocalAI.
|
||||
|
||||
Wraps LiquidAI's `liquid-audio` Python package (https://github.com/Liquid4All/liquid-audio).
|
||||
The same model serves four roles, selected by the `mode` option at load time:
|
||||
chat, asr, tts, s2s. Fine-tuning is exposed via StartFineTune.
|
||||
"""
|
||||
from concurrent import futures
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
|
||||
import grpc
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors # noqa: E402
|
||||
from python_utils import parse_options # noqa: E402
|
||||
|
||||
import backend_pb2 # noqa: E402
|
||||
import backend_pb2_grpc # noqa: E402
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
|
||||
# Voice id → system-prompt suffix. The model only ships these four voices.
|
||||
VOICE_PROMPTS = {
|
||||
"us_male": "Perform TTS. Use the US male voice.",
|
||||
"us_female": "Perform TTS. Use the US female voice.",
|
||||
"uk_male": "Perform TTS. Use the UK male voice.",
|
||||
"uk_female": "Perform TTS. Use the UK female voice.",
|
||||
}
|
||||
DEFAULT_VOICE = "us_female"
|
||||
|
||||
# Special-token IDs that LFM2-Audio emits to delimit modality boundaries.
|
||||
# Sourced from liquid_audio/model/lfm2_audio.py (see generate_sequential/_sample_*).
|
||||
TEXT_END_TOKEN = 130 # <|text_end|>
|
||||
AUDIO_START_TOKEN = 128 # <|audio_start|>
|
||||
IM_END_TOKEN = 7 # <|im_end|>
|
||||
AUDIO_EOS_CODE = 2048 # signals end-of-audio in any codebook position
|
||||
|
||||
_PATCHED_LOCAL_PATHS = False
|
||||
|
||||
|
||||
def _patch_liquid_audio_local_paths():
    """Let liquid_audio.utils resolve a model from an existing local directory.

    Upstream always hands its argument to huggingface_hub.snapshot_download,
    which only understands `owner/repo` ids, whereas LocalAI's gallery passes
    absolute paths under <ModelPath>/<owner>/<repo>. This replaces
    `snapshot_download` in the liquid_audio.utils namespace with a wrapper
    that returns the argument unchanged (as a str) when it is a directory on
    disk and defers to the original function otherwise.

    Safe to call repeatedly: the module-level flag makes later calls no-ops.
    """
    global _PATCHED_LOCAL_PATHS
    if not _PATCHED_LOCAL_PATHS:
        import liquid_audio.utils as la_utils

        upstream_download = la_utils.snapshot_download

        def snapshot_download_or_local(repo_id, revision=None, **kwargs):
            # An existing directory is already usable as a model dir.
            path_like = isinstance(repo_id, (str, os.PathLike))
            if path_like and os.path.isdir(str(repo_id)):
                return str(repo_id)
            return upstream_download(repo_id, revision=revision, **kwargs)

        la_utils.snapshot_download = snapshot_download_or_local
        _PATCHED_LOCAL_PATHS = True
|
||||
|
||||
|
||||
def _select_device():
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
return "cpu"
|
||||
|
||||
|
||||
class ActiveJob:
    """State for one in-flight fine-tune job.

    Holds the job id together with a queue that FineTuneProgress can stream
    from, plus placeholders for the worker thread, the stop/completion flags
    and any error. Everything except `job_id` and the empty queue starts out
    unset (None / False).
    """

    def __init__(self, job_id):
        self.job_id = job_id
        # Fresh, unbounded queue for this job's progress updates.
        self.progress_queue = queue.Queue()
        # No worker thread is attached at construction time.
        self.thread = None
        self.stopped = self.completed = False
        self.error = None
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def __init__(self):
|
||||
self.processor = None
|
||||
self.model = None
|
||||
self.device = "cpu"
|
||||
self.dtype = None
|
||||
self.options = {}
|
||||
self.model_id = None
|
||||
self.active_job = None
|
||||
|
||||
@property
def mode(self):
    """Lower-cased ``mode`` option; "chat" when the option is not set."""
    configured = self.options.get("mode", "chat")
    return str(configured).lower()
|
||||
|
||||
@property
def voice(self):
    """Lower-cased ``voice`` option, or DEFAULT_VOICE if it is not a key of VOICE_PROMPTS."""
    requested = str(self.options.get("voice", DEFAULT_VOICE)).lower()
    if requested in VOICE_PROMPTS:
        return requested
    return DEFAULT_VOICE
|
||||
|
||||
|
||||
def Free(self, request, context):
    """Release the loaded model when LocalAI unloads this backend.

    Drops the references to the model objects, runs the garbage collector and,
    when CUDA is available, empties the CUDA cache so the next load starts
    from a clean state instead of bumping into OOM.

    Returns:
        backend_pb2.Result with success=True, or success=False and the error
        text if the cleanup itself raised.
    """
    try:
        # Reset to None rather than deleting the attributes: the other RPCs
        # guard with `self.model is None` / `self.processor is None`, and a
        # deleted attribute would turn that guard into an AttributeError.
        for attr in ("model", "processor", "tokenizer"):
            if getattr(self, attr, None) is not None:
                setattr(self, attr, None)

        import gc
        gc.collect()

        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            # Best effort: failing to empty the cache must not fail the unload.
            pass

        return backend_pb2.Result(success=True, message="OK")
    except Exception as exc:
        print(f"Free failed: {exc}", file=sys.stderr)
        return backend_pb2.Result(success=False, message=str(exc))
|
||||
|
||||
|
||||
def Health(self, request, context):
    """Liveness probe: always answers with the UTF-8 bytes of "OK"."""
    payload = "OK".encode("utf-8")
    return backend_pb2.Reply(message=payload)
|
||||
|
||||
|
||||
def LoadModel(self, request, context):
    """Parse the request options, choose device/dtype and load the model.

    Stores the parsed options, device, dtype and resolved model identifier on
    the servicer. In fine-tune mode it returns before loading any weights;
    otherwise it loads the LFM2 audio processor and model.

    Returns:
        backend_pb2.Result with success=True on success, or success=False and
        a message when CUDA was requested but is unavailable, when no model
        identifier was supplied, or when loading raised.
    """
    try:
        import torch

        self.options = parse_options(request.Options)

        # Compare case-insensitively, exactly like the `voice` property does,
        # so a valid voice written in mixed case does not trigger the warning.
        configured_voice = self.options.get("voice")
        if configured_voice and str(configured_voice).lower() not in VOICE_PROMPTS:
            print(f"Warning: unknown voice '{configured_voice}'; defaulting to '{DEFAULT_VOICE}'",
                  file=sys.stderr)

        requested_device = self.options.get("device")
        self.device = requested_device or _select_device()
        if self.device == "cuda" and not torch.cuda.is_available():
            return backend_pb2.Result(success=False, message="CUDA requested but not available")
        if self.device == "mps" and not (hasattr(torch.backends, "mps") and
                                         torch.backends.mps.is_available()):
            print("MPS not available; falling back to CPU", file=sys.stderr)
            self.device = "cpu"

        dtype_name = str(self.options.get("dtype", "bfloat16")).lower()
        dtype_aliases = {
            "bfloat16": torch.bfloat16,
            "bf16": torch.bfloat16,
            "float16": torch.float16,
            "fp16": torch.float16,
            "half": torch.float16,
            "float32": torch.float32,
            "fp32": torch.float32,
        }
        if dtype_name not in dtype_aliases:
            # Surface the fallback instead of silently ignoring a typo.
            print(f"Warning: unknown dtype '{dtype_name}'; defaulting to 'bfloat16'",
                  file=sys.stderr)
        self.dtype = dtype_aliases.get(dtype_name, torch.bfloat16)

        # request.Model holds the raw `parameters.model` value (an HF
        # repo id like "LiquidAI/LFM2.5-Audio-1.5B"); request.ModelFile
        # is LocalAI's ModelPath-prefixed local copy that exists only
        # when the gallery supplied a `files:` list. Mirror the
        # transformers/vibevoice convention: prefer the repo id and
        # only switch to the local path if it's been staged on disk.
        model_id = request.Model
        if not model_id:
            model_id = request.ModelFile
        if not model_id:
            return backend_pb2.Result(success=False, message="No model identifier provided")
        if request.ModelFile and os.path.isdir(request.ModelFile):
            model_id = request.ModelFile
        self.model_id = model_id

        # Pure fine-tune jobs don't need an in-memory inference model — the
        # Trainer instantiates its own copy at StartFineTune time.
        if self.mode == "finetune":
            print(f"Loaded liquid-audio backend in fine-tune mode (model id: {model_id})",
                  file=sys.stderr)
            return backend_pb2.Result(success=True, message="OK")

        from liquid_audio import LFM2AudioModel, LFM2AudioProcessor

        # liquid_audio's from_pretrained unconditionally routes through
        # huggingface_hub.snapshot_download, which rejects local paths
        # (HFValidationError on `/models/LiquidAI/LFM2.5-Audio-1.5B`).
        # When LocalAI's gallery has already staged the weights on disk,
        # short-circuit the download to return the local directory.
        _patch_liquid_audio_local_paths()

        print(f"Loading liquid-audio model '{model_id}' on {self.device} ({self.dtype})",
              file=sys.stderr)
        self.processor = LFM2AudioProcessor.from_pretrained(model_id, device=self.device).eval()
        self.model = LFM2AudioModel.from_pretrained(
            model_id, device=self.device, dtype=self.dtype
        ).eval()

        print(f"Liquid-audio mode={self.mode}, voice={self.voice}", file=sys.stderr)
        return backend_pb2.Result(success=True, message="OK")

    except Exception as exc:
        print(f"LoadModel failed: {exc}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return backend_pb2.Result(success=False, message=str(exc))
|
||||
|
||||
|
||||
def Predict(self, request, context):
    """Unary text generation: concatenate every streamed delta into one Reply.

    On failure the error is logged to stderr, the gRPC status is set to
    INTERNAL with the exception text, and an empty Reply is returned.
    """
    try:
        pieces = list(self._generate_text_stream(request))
        return backend_pb2.Reply(message="".join(pieces).encode("utf-8"))
    except Exception as exc:
        print(f"Predict failed: {exc}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        context.set_code(grpc.StatusCode.INTERNAL)
        context.set_details(str(exc))
        return backend_pb2.Reply()
|
||||
|
||||
def PredictStream(self, request, context):
    """Streaming text generation: one Reply per decoded text delta.

    On failure the error is logged to stderr and the gRPC status is set to
    INTERNAL with the exception text; the stream then simply ends.
    """
    try:
        deltas = self._generate_text_stream(request)
        for piece in deltas:
            yield backend_pb2.Reply(message=piece.encode("utf-8"))
    except Exception as exc:
        print(f"PredictStream failed: {exc}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        context.set_code(grpc.StatusCode.INTERNAL)
        context.set_details(str(exc))
|
||||
|
||||
|
||||
def VAD(self, request, context):
    """Stub voice-activity detector based on an RMS-energy threshold.

    The audio in ``request.audio`` is cut into 30 ms frames at an assumed
    16 kHz sample rate; runs of frames whose RMS exceeds the threshold and that
    last at least ``vad_min_speech_frames`` frames are reported as segments.

    Options read from ``self.options``:
        vad_rms_threshold (default 0.01), vad_min_speech_frames (default 2),
        vad_window_s (default 5.0; a non-positive value disables windowing).

    Returns:
        backend_pb2.VADResponse with segment start/end times in seconds,
        relative to the start of the submitted buffer. On failure the gRPC
        status is set to INTERNAL and an empty response is returned.
    """
    # Good enough for the realtime endpoint's handleVAD loop, which only
    # inspects segment presence + last segment end. The proper signal would
    # come from the model's audio encoder, but that ride-along is a PR-D scope
    # item — until then this keeps the legacy pipeline path working without
    # forcing the operator to install a separate VAD model.
    import numpy as np
    try:
        audio = np.asarray(request.audio, dtype=np.float32)
        if audio.size == 0:
            return backend_pb2.VADResponse(segments=[])

        sample_rate = 16000
        frame_size = sample_rate * 30 // 1000  # 30ms → 480 samples
        threshold = float(self.options.get("vad_rms_threshold", 0.01))
        min_speech_frames = int(self.options.get("vad_min_speech_frames", 2))  # ≥60ms
        # handleVAD ticks every 300 ms and only inspects segment presence
        # + last segment end relative to silence_threshold (~500 ms). Cap
        # the analysed window to the tail of the buffer so we don't redo
        # the entire growing utterance every tick.
        window_s = float(self.options.get("vad_window_s", 5.0))
        window_samples = int(window_s * sample_rate)
        time_offset_s = 0.0
        # Only trim for a positive window: `audio[-0:]` would keep the whole
        # buffer while the offset below still shifted every timestamp, and a
        # negative window would slice from the wrong end.
        if 0 < window_samples < audio.size:
            time_offset_s = (audio.size - window_samples) / sample_rate
            audio = audio[-window_samples:]

        n_frames = audio.size // frame_size
        if n_frames == 0:
            return backend_pb2.VADResponse(segments=[])
        frames = audio[: n_frames * frame_size].reshape(n_frames, frame_size)
        rms = np.sqrt(np.mean(frames ** 2, axis=1))
        speech = rms > threshold

        def _emit(start_idx, end_idx, out):
            # Runs shorter than the minimum are treated as noise and dropped.
            if end_idx - start_idx >= min_speech_frames:
                out.append(backend_pb2.VADSegment(
                    start=time_offset_s + start_idx * frame_size / sample_rate,
                    end=time_offset_s + end_idx * frame_size / sample_rate,
                ))

        segments = []
        start_idx = None
        for i, is_speech in enumerate(speech):
            if is_speech and start_idx is None:
                start_idx = i
            elif not is_speech and start_idx is not None:
                _emit(start_idx, i, segments)
                start_idx = None
        if start_idx is not None:
            # Speech still running at the end of the buffer.
            _emit(start_idx, n_frames, segments)
        return backend_pb2.VADResponse(segments=segments)
    except Exception as exc:
        print(f"VAD failed: {exc}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        context.set_code(grpc.StatusCode.INTERNAL)
        context.set_details(str(exc))
        return backend_pb2.VADResponse(segments=[])
|
||||
|
||||
|
||||
def TTS(self, request, context):
    """Synthesize ``request.text`` and write the audio to ``request.dst``.

    The voice comes from ``request.voice`` (an optional ``lfm2:`` / ``lfm:``
    prefix is stripped) and falls back to the configured voice when it is
    missing or not a key of VOICE_PROMPTS. The file is written with a sample
    rate of 24000.

    Options read from ``self.options``:
        audio_top_k (default 64), audio_temperature (default 0.8),
        max_new_tokens (default 2048).

    Returns:
        backend_pb2.Result; success=False with a message when the model is not
        loaded, ``dst`` is empty, no audio was generated, or an error occurred.
    """
    try:
        if self.model is None or self.processor is None:
            return backend_pb2.Result(success=False, message="Model not loaded")

        # Reject a missing destination before paying for a full generation.
        out_path = request.dst
        if not out_path:
            return backend_pb2.Result(success=False, message="dst path is required")

        import torch
        from liquid_audio import ChatState

        voice = request.voice.lower() if request.voice else self.voice
        voice = voice.removeprefix("lfm2:").removeprefix("lfm:")
        if voice not in VOICE_PROMPTS:
            voice = self.voice
        system_prompt = VOICE_PROMPTS[voice]

        chat = ChatState(self.processor)
        chat.new_turn("system")
        chat.add_text(system_prompt)
        chat.end_turn()
        chat.new_turn("user")
        chat.add_text(request.text or "")
        chat.end_turn()
        chat.new_turn("assistant")

        audio_top_k = int(self.options.get("audio_top_k", 64))
        audio_temp = float(self.options.get("audio_temperature", 0.8))
        max_new = int(self.options.get("max_new_tokens", 2048))

        # Keep only multi-element tokens; single-element ones are text tokens.
        audio_out = []
        for tok in self.model.generate_sequential(
            **chat,
            max_new_tokens=max_new,
            audio_temperature=audio_temp,
            audio_top_k=audio_top_k,
        ):
            if tok.numel() > 1:
                audio_out.append(tok)

        if len(audio_out) <= 1:
            return backend_pb2.Result(success=False, message="No audio frames generated")

        # Drop the trailing end-of-audio frame, matching the package's examples.
        audio_codes = torch.stack(audio_out[:-1], 1).unsqueeze(0)
        waveform = self.processor.decode(audio_codes)

        os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
        # soundfile in preference to torchaudio.save — the latter routes
        # through torchcodec, whose native libs need NVIDIA NPP that we
        # don't bundle in the cuda13 image.
        import soundfile as _sf
        _sf.write(out_path, waveform.cpu().numpy().squeeze(0).T, 24_000)

        return backend_pb2.Result(success=True)
    except Exception as exc:
        print(f"TTS failed: {exc}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return backend_pb2.Result(success=False, message=str(exc))
|
||||
|
||||
|
||||
def AudioToAudioStream(self, request_iterator, context):
    """Bidirectional any-to-any speech-to-speech stream.

    The wire protocol is described by AudioToAudioStream in `backend.proto`.
    Audio is decoded once per turn; chunked detokenization for sub-second
    TTFB is left to a future iteration once the LFM2AudioDetokenizer gains a
    streaming entry point.

    Any exception from the inner stream is logged to stderr and reported to
    the client as a final "error" event whose meta is a JSON object with a
    "message" field.
    """
    try:
        yield from self._audio_to_audio_stream(request_iterator, context)
    except Exception as exc:
        print(f"AudioToAudioStream failed: {exc}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        error_meta = json.dumps({"message": str(exc)}).encode("utf-8")
        yield backend_pb2.AudioToAudioResponse(event="error", meta=error_meta)
|
||||
|
||||
def _audio_to_audio_stream(self, request_iterator, context):
    """Run one speech-to-speech session and yield AudioToAudioResponse events.

    Each incoming request carries one of three payloads:
      * ``config``  - optional sample rates / system prompt; resets the chat.
      * ``frame``   - PCM bytes appended to the pending user turn; a frame
                      with ``end_of_input`` commits the turn and generates.
      * ``control`` - ``input_audio_buffer.commit`` (commit + generate),
                      ``response.cancel`` (acknowledged only) or
                      ``session.update`` (soft reset). Other events are ignored.

    Incoming PCM is interpreted as 16-bit integers and scaled by 1/32768.
    Every emitted event carries an increasing ``sequence`` number.

    Raises:
        RuntimeError: if no model/processor is loaded.
    """
    if self.model is None or self.processor is None:
        raise RuntimeError("Model not loaded")

    import torch
    import torchaudio
    from liquid_audio import ChatState

    cfg = None
    chat = None
    input_sample_rate = 16000
    output_sample_rate = 24000
    sequence = 0

    def _new_event(event, **kwargs):
        nonlocal sequence
        sequence += 1
        kwargs.setdefault("sequence", sequence)
        return backend_pb2.AudioToAudioResponse(event=event, **kwargs)

    def _ensure_chat():
        """Build a fresh ChatState seeded with the system prompt."""
        nonlocal chat
        chat = ChatState(self.processor)
        system_prompt = (cfg.system_prompt if cfg and cfg.system_prompt
                         else "Respond with interleaved text and audio.")
        chat.new_turn("system")
        chat.add_text(system_prompt)
        chat.end_turn()

    # Buffers for the in-flight user turn
    pcm_buffer = bytearray()

    def _consume_user_turn():
        """Move the buffered PCM into the chat as a user turn, then clear it."""
        nonlocal pcm_buffer
        if not pcm_buffer:
            return
        # int16 samples are two bytes each; a stray trailing byte (e.g. a
        # frame split mid-sample) would make np.frombuffer raise, so drop it.
        usable = len(pcm_buffer) - (len(pcm_buffer) % 2)
        if usable == 0:
            pcm_buffer = bytearray()
            return
        # Avoid the bytes(pcm_buffer) copy and let the float widen happen
        # in-place: numpy view → torch view → in-place divide.
        import numpy as np
        arr = np.frombuffer(memoryview(pcm_buffer)[:usable], dtype=np.int16)
        wav = torch.from_numpy(arr).to(torch.float32).div_(32768.0).unsqueeze(0)
        chat.new_turn("user")
        chat.add_audio(wav, input_sample_rate)
        chat.end_turn()
        pcm_buffer = bytearray()

    def _run_generation():
        """Run generate_interleaved; yield response events as we go."""
        chat.new_turn("assistant")
        audio_top_k = int(self.options.get("audio_top_k", 4))
        audio_temp = float(self.options.get("audio_temperature", 1.0))
        # A value of 0 (the default) is passed on as None.
        text_top_k = int(self.options.get("text_top_k", 0)) or None
        text_temp = float(self.options.get("text_temperature", 0)) or None
        max_new = int(self.options.get("max_new_tokens", 512))

        audio_tokens = []
        for tok in self.model.generate_interleaved(
            **chat,
            max_new_tokens=max_new,
            text_temperature=text_temp,
            text_top_k=text_top_k,
            audio_temperature=audio_temp,
            audio_top_k=audio_top_k,
        ):
            if tok.numel() == 1:
                # Single-element tokens are text; IM_END_TOKEN ends the turn.
                if tok.item() == IM_END_TOKEN:
                    break
                text = self.processor.text.decode(tok)
                if not text:
                    continue
                yield _new_event(
                    "response.audio_transcript.delta",
                    meta=json.dumps({"delta": text}).encode("utf-8"),
                )
            else:
                audio_tokens.append(tok)

        # Detokenize the accumulated audio at end-of-turn — the
        # LFM2AudioDetokenizer is non-streaming today.
        if len(audio_tokens) > 1:
            audio_codes = torch.stack(audio_tokens[:-1], 1).unsqueeze(0)
            waveform = self.processor.decode(audio_codes)
            # Convert to s16le PCM bytes at output_sample_rate
            if output_sample_rate != 24000:
                waveform = torchaudio.functional.resample(
                    waveform.cpu(), 24000, output_sample_rate
                )
            pcm = (waveform.cpu().squeeze(0).clamp(-1, 1) * 32767.0).to(
                torch.int16
            ).numpy().tobytes()
            yield _new_event(
                "response.audio.delta",
                pcm=pcm,
                sample_rate=output_sample_rate,
            )

        yield _new_event("response.done", meta=b"{}")

    for req in request_iterator:
        if not context.is_active():
            return
        payload = req.WhichOneof("payload")
        if payload == "config":
            cfg = req.config
            if cfg.input_sample_rate > 0:
                input_sample_rate = cfg.input_sample_rate
            if cfg.output_sample_rate > 0:
                output_sample_rate = cfg.output_sample_rate
            # The first config implicitly resets state.
            _ensure_chat()
            pcm_buffer = bytearray()
        elif payload == "frame":
            if chat is None:
                _ensure_chat()
            if req.frame.pcm:
                pcm_buffer.extend(req.frame.pcm)
            if req.frame.end_of_input:
                _consume_user_turn()
                yield from _run_generation()
        elif payload == "control":
            event = req.control.event
            if event == "input_audio_buffer.commit":
                # A commit may arrive before any config/frame; without a
                # chat the generation step would dereference None.
                if chat is None:
                    _ensure_chat()
                _consume_user_turn()
                yield from _run_generation()
            elif event == "response.cancel":
                # Synchronous generation here means cancel can only
                # take effect between turns; we ack so the client unblocks.
                yield _new_event("response.done", meta=b'{"cancelled":true}')
            elif event == "session.update":
                # Free-form session re-config; treat as a soft reset.
                _ensure_chat()
                pcm_buffer = bytearray()
            # Unknown events are ignored — forward-compatible.
|
||||
|
||||
|
||||
def AudioTranscription(self, request, context):
    """Transcribe the audio file at ``request.dst`` with an "Perform ASR." prompt.

    The file is read with soundfile, down-mixed to mono when it has several
    channels, and passed to the model as a single user turn. The whole
    transcript is returned as one segment spanning the clip, with ``end`` set
    to the clip duration in milliseconds.

    Options read from ``self.options``:
        max_new_tokens (default 1024).

    Returns:
        backend_pb2.TranscriptResult. It is empty when the model is not loaded
        or ``dst`` is empty. On an exception the result is also empty, and the
        gRPC status is set to INTERNAL with the error text — as Predict and
        VAD do — so callers can tell a failure from silence.
    """
    try:
        if self.model is None or self.processor is None:
            return backend_pb2.TranscriptResult(segments=[], text="")

        from liquid_audio import ChatState

        audio_path = request.dst
        if not audio_path:
            return backend_pb2.TranscriptResult(segments=[], text="")

        chat = ChatState(self.processor)
        chat.new_turn("system")
        chat.add_text("Perform ASR.")
        chat.end_turn()
        chat.new_turn("user")
        # soundfile in preference to torchaudio.load — the latter routes
        # through torchcodec which needs NVIDIA NPP libs we don't bundle.
        import soundfile as _sf
        import torch
        audio_np, sr = _sf.read(audio_path, dtype="float32", always_2d=True)
        wav = torch.from_numpy(audio_np.T)  # (channels, samples)
        if wav.shape[0] > 1:
            # Down-mix to mono — the processor expects a single channel
            wav = wav.mean(dim=0, keepdim=True)
        chat.add_audio(wav, sr)
        chat.end_turn()
        chat.new_turn("assistant")

        max_new = int(self.options.get("max_new_tokens", 1024))

        pieces = []
        for tok in self.model.generate_sequential(**chat, max_new_tokens=max_new):
            if tok.numel() == 1:
                # Single-element tokens are text; IM_END_TOKEN ends the answer.
                if tok.item() == IM_END_TOKEN:
                    break
                pieces.append(self.processor.text.decode(tok))

        text = "".join(pieces).strip()
        duration_ms = int((wav.shape[1] / sr) * 1000)
        segment = backend_pb2.TranscriptSegment(
            id=0, start=0, end=duration_ms, text=text, tokens=[],
        )
        return backend_pb2.TranscriptResult(segments=[segment], text=text)
    except Exception as exc:
        print(f"AudioTranscription failed: {exc}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        # Report the failure instead of disguising it as an empty transcript.
        context.set_code(grpc.StatusCode.INTERNAL)
        context.set_details(str(exc))
        return backend_pb2.TranscriptResult(segments=[], text="")
|
||||
|
||||
|
||||
def StartFineTune(self, request, context):
    """Start a fine-tune on a daemon thread, unless one is still in progress.

    The job id comes from ``request.job_id`` or, when empty, a fresh UUID4.

    Returns:
        backend_pb2.FineTuneJobResult; success=False with an empty job id if
        the current job has not completed yet.
    """
    current = self.active_job
    if current is not None and not current.completed:
        return backend_pb2.FineTuneJobResult(
            job_id="", success=False,
            message="A fine-tuning job is already running",
        )

    job_id = request.job_id or str(uuid.uuid4())
    job = ActiveJob(job_id)
    self.active_job = job

    # Training runs in the background; progress is read via FineTuneProgress.
    worker = threading.Thread(target=self._run_training, args=(request, job), daemon=True)
    job.thread = worker
    worker.start()

    return backend_pb2.FineTuneJobResult(
        job_id=job_id, success=True, message="Training started",
    )
|
||||
|
||||
def FineTuneProgress(self, request, context):
    """Stream progress updates of the active job until it reaches an end state.

    Sets NOT_FOUND and ends immediately when ``request.job_id`` does not match
    the active job. Otherwise updates are forwarded from the job's queue until
    a None sentinel arrives, an update has status "completed", "failed" or
    "stopped", or — while the queue is idle — the job is flagged
    completed/stopped or the client has gone away.
    """
    job = self.active_job
    if job is None or job.job_id != request.job_id:
        context.set_code(grpc.StatusCode.NOT_FOUND)
        context.set_details(f"Job {request.job_id} not found")
        return

    terminal_statuses = ("completed", "failed", "stopped")
    while True:
        try:
            # Wake up once a second so the exit conditions get re-checked.
            update = job.progress_queue.get(timeout=1.0)
        except queue.Empty:
            if job.completed or job.stopped or not context.is_active():
                break
            continue
        if update is None:
            break
        yield update
        if update.status in terminal_statuses:
            break
|
||||
|
||||
def StopFineTune(self, request, context):
    """Flag the matching active job as stopped and wake its progress stream.

    Always returns a successful Result, whether or not a job matched.
    """
    # We can't kill the Accelerate training loop mid-step cleanly from here;
    # LocalAI's job manager kills the backend process on stop. The flag below
    # at least lets the progress stream terminate quickly.
    job = self.active_job
    if job is not None and job.job_id == request.job_id:
        job.stopped = True
        job.progress_queue.put(None)
    return backend_pb2.Result(success=True, message="OK")
|
||||
|
||||
def _run_training(self, request, job):
    """Thread entry point: run the training and report how it ended.

    Exactly one terminal update ("completed", "stopped" or "failed") is put on
    the job's progress queue, followed by a None sentinel. ``job.completed`` is
    set on every path so StartFineTune accepts a new job afterwards.
    """
    try:
        self._do_train(request, job)
        job.completed = True
        job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
            job_id=job.job_id, status="completed", message="Training completed",
            progress_percent=100.0,
        ))
    except KeyboardInterrupt:
        # The trainer's log hook raises KeyboardInterrupt to honour a stop
        # request. It is not an Exception subclass, so without this branch the
        # job would never be marked finished and would block every later job.
        job.completed = True
        print("Training stopped on request", file=sys.stderr)
        job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
            job_id=job.job_id, status="stopped", message="Training stopped",
        ))
    except Exception as exc:
        job.error = str(exc)
        job.completed = True
        print(f"Training failed: {exc}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
            job_id=job.job_id, status="failed", message=str(exc),
        ))
    finally:
        # End-of-stream sentinel for FineTuneProgress.
        job.progress_queue.put(None)
|
||||
|
||||
def _do_train(self, request, job):
    """Run a fine-tune with liquid_audio's Trainer, reporting progress on the job queue.

    The model id is taken from ``request.model``, then ``self.model_id``, then
    a built-in default. Hyperparameters left unset (falsy) in the request fall
    back to the defaults below. The output directory defaults to
    ``$LIQUID_AUDIO_OUTPUT_DIR`` (or /tmp) plus ``liquid-audio-<job id>``.

    Raises:
        ValueError: if ``request.dataset_source`` is empty.
        KeyboardInterrupt: from the trainer's log hook once ``job.stopped`` is set.
    """
    from liquid_audio import LFM2AudioModel  # noqa: F401 (sanity import)
    from liquid_audio.data.dataloader import LFM2DataLoader
    from liquid_audio.trainer import Trainer

    model_id = request.model or self.model_id or "LiquidAI/LFM2.5-Audio-1.5B"

    dataset_path = request.dataset_source
    if not dataset_path:
        raise ValueError("dataset_source is required (path to a preprocessed dataset)")

    extras = dict(request.extra_options) if request.extra_options else {}
    val_path = extras.get("val_dataset")

    # Map FineTuneRequest hyperparameters to liquid_audio.Trainer constructor args
    lr = request.learning_rate or 3e-5
    max_steps = request.max_steps or 1000
    warmup_steps = request.warmup_steps or min(100, max_steps // 10)
    batch_size = request.batch_size or 16
    save_interval = request.save_steps or max(1, max_steps // 4)

    output_dir = request.output_dir
    if not output_dir:
        base_dir = os.environ.get("LIQUID_AUDIO_OUTPUT_DIR", "/tmp")
        output_dir = os.path.join(base_dir, f"liquid-audio-{job.job_id}")
    os.makedirs(output_dir, exist_ok=True)

    def _publish(**fields):
        # Every update is tagged with this job's id.
        job.progress_queue.put(
            backend_pb2.FineTuneProgressUpdate(job_id=job.job_id, **fields)
        )

    _publish(
        status="loading_dataset",
        message=f"Loading preprocessed dataset from {dataset_path}",
    )
    train_data = LFM2DataLoader(dataset_path)
    val_data = LFM2DataLoader(val_path) if val_path else None

    _publish(status="loading_model", message=f"Loading base model {model_id}")

    # The Liquid Trainer logs via self.accelerator.print; we subclass it to
    # also push progress events onto the queue every logging_interval steps.
    class QueuedTrainer(Trainer):
        def log(self_, model_output):
            on_interval = self_.step > 0 and self_.step % self_.logging_interval == 0
            if on_interval:
                try:
                    loss = self_.accelerator.reduce(
                        model_output.loss.detach(), reduction="mean"
                    ).item()
                except Exception:
                    # Still report the step even if the loss cannot be reduced.
                    loss = float("nan")
                lr_now = self_.optimizer.param_groups[0]["lr"]
                pct = (self_.step / self_.max_steps * 100.0) if self_.max_steps else 0.0
                _publish(
                    current_step=int(self_.step),
                    total_steps=int(self_.max_steps),
                    current_epoch=float(self_.epoch),
                    loss=float(loss),
                    learning_rate=float(lr_now),
                    progress_percent=float(pct),
                    status="training",
                )
            # Honour stop requests: raising here terminates the loop cleanly
            if job.stopped:
                raise KeyboardInterrupt("stop requested")
            return super().log(model_output)

        def validate(self_):
            _publish(
                current_step=int(self_.step),
                total_steps=int(self_.max_steps),
                status="training",
                message=f"Running validation at step {self_.step}",
            )
            return super().validate()

    trainer = QueuedTrainer(
        model_id=model_id,
        train_data=train_data,
        val_data=val_data,
        lr=lr,
        max_steps=max_steps,
        warmup_steps=warmup_steps,
        batch_size=batch_size,
        save_interval=save_interval,
        output_dir=output_dir,
        weight_decay=request.weight_decay or 0.1,
    )

    _publish(status="training", message="Training started", total_steps=int(max_steps))
    trainer.train()

    _publish(
        status="saving",
        message=f"Saved final model to {output_dir}",
        checkpoint_path=os.path.join(output_dir, "final"),
    )
|
||||
|
||||
|
||||
def _build_chat_state(self, messages, user_prompt, tools_prelude=None):
    """Assemble a ChatState ready for the assistant to answer.

    Turn order: an optional system turn carrying ``tools_prelude`` (the LFM2
    tool-list block, mirroring gallery/lfm.yaml's `function:` template so the
    model sees the same prompt shape whether served via llama-cpp or here),
    then every ``(role, content)`` pair from ``messages``, then an optional
    final user turn with ``user_prompt``. An assistant turn is left open.
    """
    from liquid_audio import ChatState

    chat = ChatState(self.processor)

    # Collect all closed turns first, then replay them in one loop.
    turns = []
    if tools_prelude:
        turns.append(("system", tools_prelude))
    turns.extend(messages)
    if user_prompt:
        turns.append(("user", user_prompt))

    for role, content in turns:
        chat.new_turn(role)
        chat.add_text(content)
        chat.end_turn()

    chat.new_turn("assistant")
    return chat
|
||||
|
||||
def _collect_messages(self, request):
|
||||
"""Translate PredictOptions.Messages into (role, content) tuples."""
|
||||
out = []
|
||||
for m in request.Messages:
|
||||
role = (m.role or "user").lower()
|
||||
if role not in ("system", "user", "assistant"):
|
||||
role = "user"
|
||||
out.append((role, m.content or ""))
|
||||
return out
|
||||
|
||||
def _render_tools_prelude(self, request):
|
||||
"""Build the LFM2 `<|tool_list_start|>…<|tool_list_end|>` system prelude
|
||||
from request.Tools (OpenAI Chat-Completions tool JSON). Returns "" when
|
||||
no tools are attached. Output mirrors gallery/lfm.yaml's `function:`
|
||||
template so the model sees the same prompt whether routed via llama-cpp
|
||||
or this backend."""
|
||||
tools_raw = getattr(request, "Tools", "") or ""
|
||||
if not tools_raw:
|
||||
return ""
|
||||
try:
|
||||
tools = json.loads(tools_raw)
|
||||
except json.JSONDecodeError:
|
||||
print(f"liquid-audio: ignoring malformed Tools JSON: {tools_raw[:200]!r}",
|
||||
file=sys.stderr)
|
||||
return ""
|
||||
if not isinstance(tools, list) or not tools:
|
||||
return ""
|
||||
# The LFM2 chat template uses single-quoted Python-dict-ish syntax in
|
||||
# examples, but the tokenizer treats this whole block as opaque text;
|
||||
# JSON works fine and is what other backends emit.
|
||||
return (
|
||||
"You are a function calling AI model. You are provided with functions to "
|
||||
"execute. You may call one or more functions to assist with the user query. "
|
||||
"Don't make assumptions about what values to plug into functions.\n"
|
||||
"List of tools: <|tool_list_start|>"
|
||||
+ json.dumps(tools, separators=(",", ":"))
|
||||
+ "<|tool_list_end|>"
|
||||
)
|
||||
|
||||
def _generate_text_stream(self, request):
    """Yield decoded text deltas from generate_sequential; Predict joins them.

    ``request.Tokens``, ``request.Temperature`` and ``request.TopK`` are used
    only when positive; otherwise max_new_tokens comes from the options
    (default 512) and temperature/top-k are passed as None.

    Raises:
        RuntimeError: if no model/processor is loaded.
    """
    if self.model is None or self.processor is None:
        raise RuntimeError("Model not loaded")

    history = self._collect_messages(request)
    prelude = self._render_tools_prelude(request)
    # If the request already carries Messages, Prompt is the templated form
    # of the same content — don't append a duplicate user turn.
    trailing_prompt = None if history else (request.Prompt or None)
    chat = self._build_chat_state(history, trailing_prompt, tools_prelude=prelude)

    if request.Tokens > 0:
        max_new = request.Tokens
    else:
        max_new = int(self.options.get("max_new_tokens", 512))
    temperature = request.Temperature if request.Temperature > 0 else None
    top_k = request.TopK if request.TopK > 0 else None

    token_stream = self.model.generate_sequential(
        **chat,
        max_new_tokens=max_new,
        text_temperature=temperature,
        text_top_k=top_k,
    )
    for tok in token_stream:
        # Only single-element tokens are decoded as text; others are skipped.
        if tok.numel() != 1:
            continue
        if tok.item() == IM_END_TOKEN:
            break
        yield self.processor.text.decode(tok)
|
||||
|
||||
|
||||
def serve(address):
    """Start the gRPC server on ``address`` and block until it is stopped.

    Message size limits are set to 50 MiB. SIGTERM and SIGINT stop the server
    immediately and exit the process with status 0.
    """
    max_message_bytes = 50 * 1024 * 1024
    size_option_keys = (
        'grpc.max_message_length',
        'grpc.max_send_message_length',
        'grpc.max_receive_message_length',
    )
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[(key, max_message_bytes) for key in size_option_keys],
        interceptors=get_auth_interceptors(),
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print(f"Liquid-audio backend listening on {address}", file=sys.stderr, flush=True)

    def _shutdown(_signum, _frame):
        server.stop(0)
        sys.exit(0)

    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, _shutdown)

    # Keep the main thread alive; the server works on its own threads.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the listen address and start serving.
    cli = argparse.ArgumentParser(description="Liquid Audio gRPC backend")
    cli.add_argument("--addr", default="localhost:50051", help="gRPC server address")
    cli_args = cli.parse_args()
    serve(cli_args.addr)
|
||||
18
backend/python/liquid-audio/install.sh
Executable file
18
backend/python/liquid-audio/install.sh
Executable file
@@ -0,0 +1,18 @@
|
||||
#!/bin/bash
# Installs the Python requirements for the liquid-audio backend.
set -e

# liquid-audio requires Python ≥ 3.12 (per its pyproject.toml); the default
# portable Python in libbackend.sh is 3.10. Override before sourcing.
export PYTHON_VERSION="${PYTHON_VERSION:-3.12}"
export PYTHON_PATCH="${PYTHON_PATCH:-11}"

# Quote every path expansion so a checkout directory containing spaces works.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# liquid-audio's torch wheels are large; allow upgrades to satisfy transitive pins
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
installRequirements
|
||||
11
backend/python/liquid-audio/protogen.sh
Executable file
11
backend/python/liquid-audio/protogen.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
# Generates the gRPC Python stubs for the liquid-audio backend.
set -e

# Quote every path expansion so a checkout directory containing spaces works.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

runProtogen
|
||||
13
backend/python/liquid-audio/requirements-cpu.txt
Normal file
13
backend/python/liquid-audio/requirements-cpu.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch>=2.8.0
|
||||
torchaudio>=2.8.0
|
||||
torchcodec>=0.9.1
|
||||
transformers>=4.55.4
|
||||
accelerate>=1.10.1
|
||||
datasets>=4.8.4
|
||||
einops>=0.8.1
|
||||
librosa>=0.11.0
|
||||
soundfile>=0.12.1
|
||||
sentencepiece>=0.2.1
|
||||
huggingface-hub>=1.3.0
|
||||
liquid-audio>=1.2.0
|
||||
13
backend/python/liquid-audio/requirements-cublas12.txt
Normal file
13
backend/python/liquid-audio/requirements-cublas12.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
# The cu121 wheel index does not publish the torch 2.8+ series required below,
# so it could never satisfy torch>=2.8.0 and resolution fell through to
# whatever build PyPI serves by default. Point at a CUDA 12 index that carries
# torch 2.8+ so this profile stays on CUDA 12 wheels.
--extra-index-url https://download.pytorch.org/whl/cu128
|
||||
torch>=2.8.0
|
||||
torchaudio>=2.8.0
|
||||
torchcodec>=0.9.1
|
||||
transformers>=4.55.4
|
||||
accelerate>=1.10.1
|
||||
datasets>=4.8.4
|
||||
einops>=0.8.1
|
||||
librosa>=0.11.0
|
||||
soundfile>=0.12.1
|
||||
sentencepiece>=0.2.1
|
||||
huggingface-hub>=1.3.0
|
||||
liquid-audio>=1.2.0
|
||||
13
backend/python/liquid-audio/requirements-cublas13.txt
Normal file
13
backend/python/liquid-audio/requirements-cublas13.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch>=2.8.0
|
||||
torchaudio>=2.8.0
|
||||
torchcodec>=0.9.1
|
||||
transformers>=4.55.4
|
||||
accelerate>=1.10.1
|
||||
datasets>=4.8.4
|
||||
einops>=0.8.1
|
||||
librosa>=0.11.0
|
||||
soundfile>=0.12.1
|
||||
sentencepiece>=0.2.1
|
||||
huggingface-hub>=1.3.0
|
||||
liquid-audio>=1.2.0
|
||||
13
backend/python/liquid-audio/requirements-hipblas.txt
Normal file
13
backend/python/liquid-audio/requirements-hipblas.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm7.0
|
||||
torch>=2.8.0
|
||||
torchaudio>=2.8.0
|
||||
torchcodec>=0.9.1
|
||||
transformers>=4.55.4
|
||||
accelerate>=1.10.1
|
||||
datasets>=4.8.4
|
||||
einops>=0.8.1
|
||||
librosa>=0.11.0
|
||||
soundfile>=0.12.1
|
||||
sentencepiece>=0.2.1
|
||||
huggingface-hub>=1.3.0
|
||||
liquid-audio>=1.2.0
|
||||
13
backend/python/liquid-audio/requirements-l4t13.txt
Normal file
13
backend/python/liquid-audio/requirements-l4t13.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/jp7/cu130
|
||||
torch>=2.8.0
|
||||
torchaudio>=2.8.0
|
||||
torchcodec>=0.9.1
|
||||
transformers>=4.55.4
|
||||
accelerate>=1.10.1
|
||||
datasets>=4.8.4
|
||||
einops>=0.8.1
|
||||
librosa>=0.11.0
|
||||
soundfile>=0.12.1
|
||||
sentencepiece>=0.2.1
|
||||
huggingface-hub>=1.3.0
|
||||
liquid-audio>=1.2.0
|
||||
12
backend/python/liquid-audio/requirements-mps.txt
Normal file
12
backend/python/liquid-audio/requirements-mps.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
torch>=2.8.0
|
||||
torchaudio>=2.8.0
|
||||
torchcodec>=0.9.1
|
||||
transformers>=4.55.4
|
||||
accelerate>=1.10.1
|
||||
datasets>=4.8.4
|
||||
einops>=0.8.1
|
||||
librosa>=0.11.0
|
||||
soundfile>=0.12.1
|
||||
sentencepiece>=0.2.1
|
||||
huggingface-hub>=1.3.0
|
||||
liquid-audio>=1.2.0
|
||||
3
backend/python/liquid-audio/requirements.txt
Normal file
3
backend/python/liquid-audio/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
certifi
|
||||
10
backend/python/liquid-audio/run.sh
Executable file
10
backend/python/liquid-audio/run.sh
Executable file
@@ -0,0 +1,10 @@
|
||||
#!/bin/bash
# Launches the liquid-audio backend by delegating to startBackend from the
# shared libbackend.sh helper, forwarding this script's arguments unchanged.

# Locate libbackend.sh: it sits either in a "common" directory next to this
# script or one level up. Expansions are quoted so a checkout path containing
# spaces does not get word-split.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# "$@" (quoted) keeps each original argument intact; a bare $@ would re-split
# arguments containing whitespace and expand any glob characters in them.
startBackend "$@"
|
||||
89
backend/python/liquid-audio/test.py
Normal file
89
backend/python/liquid-audio/test.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Smoke tests for the liquid-audio backend.
|
||||
|
||||
These run without contacting HuggingFace or loading model weights:
|
||||
they only verify that the gRPC service starts, Health() responds, and LoadModel succeeds in fine-tune mode.
|
||||
|
||||
To run an end-to-end inference test, set LIQUID_AUDIO_MODEL_ID
|
||||
(e.g. "LiquidAI/LFM2.5-Audio-1.5B") in the environment — see test_inference().
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import grpc
|
||||
|
||||
# Ensure generated protobuf stubs are importable
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
|
||||
class TestBackend(unittest.TestCase):
    """Smoke tests that drive a real backend.py subprocess over gRPC.

    One server process and one client channel are shared by every test in the
    class. The listen address comes from LIQUID_AUDIO_TEST_ADDR and defaults
    to localhost:50053.
    """

    # Upper bound, in seconds, on how long the backend may take to start
    # accepting connections before the whole class is failed.
    STARTUP_TIMEOUT_S = 30

    @classmethod
    def setUpClass(cls):
        """Start the backend and block until its gRPC port is connectable."""
        addr = os.environ.get("LIQUID_AUDIO_TEST_ADDR", "localhost:50053")
        cls.addr = addr
        cls.server = subprocess.Popen(
            [sys.executable, os.path.join(os.path.dirname(__file__), "backend.py"), "--addr", addr],
        )
        cls.channel = grpc.insecure_channel(addr)

        # Wait for actual readiness instead of sleeping a fixed interval: a
        # fixed sleep is either too short on a slow machine (flaky tests) or
        # wasted time on a fast one. Poll in short slices so a server that
        # crashed during startup is reported immediately.
        ready = grpc.channel_ready_future(cls.channel)
        deadline = time.monotonic() + cls.STARTUP_TIMEOUT_S
        while True:
            try:
                ready.result(timeout=0.5)
                break
            except grpc.FutureTimeoutError:
                exited = cls.server.poll() is not None
                if exited or time.monotonic() >= deadline:
                    # unittest does not call tearDownClass when setUpClass
                    # raises, so clean up explicitly before failing.
                    cls.tearDownClass()
                    raise RuntimeError(f"backend did not become ready on {addr}")

    @classmethod
    def tearDownClass(cls):
        """Close the shared channel and stop the backend process."""
        cls.channel.close()
        cls.server.terminate()
        try:
            cls.server.wait(timeout=5)
        except subprocess.TimeoutExpired:
            cls.server.kill()
            # Reap the killed child so it does not linger as a zombie.
            cls.server.wait()

    def _stub(self):
        """Return a Backend stub bound to the class-wide channel."""
        # Reusing one channel avoids leaking a new connection per test.
        return backend_pb2_grpc.BackendStub(self.channel)

    def test_health(self):
        """Health() answers with the literal b"OK"."""
        stub = self._stub()
        reply = stub.Health(backend_pb2.HealthMessage(), timeout=5)
        self.assertEqual(reply.message, b"OK")

    def test_load_finetune_mode_without_weights(self):
        """Loading in fine-tune mode should succeed without pulling model weights."""
        stub = self._stub()
        result = stub.LoadModel(
            backend_pb2.ModelOptions(
                Model="LiquidAI/LFM2.5-Audio-1.5B",
                Options=["mode:finetune"],
            ),
            timeout=10,
        )
        self.assertTrue(result.success, msg=result.message)

    @unittest.skipUnless(os.environ.get("LIQUID_AUDIO_MODEL_ID"),
                         "Set LIQUID_AUDIO_MODEL_ID to run an end-to-end inference smoke test")
    def test_inference(self):
        """End-to-end: load a real LFM2-Audio model and run one short prediction."""
        stub = self._stub()
        model_id = os.environ["LIQUID_AUDIO_MODEL_ID"]
        # Generous timeout: this path loads real model weights.
        result = stub.LoadModel(
            backend_pb2.ModelOptions(
                Model=model_id,
                Options=["mode:chat"],
            ),
            timeout=600,
        )
        self.assertTrue(result.success, msg=result.message)
        reply = stub.Predict(
            backend_pb2.PredictOptions(
                Prompt="Hello!",
                Tokens=8,
                Temperature=0.0,
            ),
            timeout=120,
        )
        self.assertGreater(len(reply.message), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
11
backend/python/liquid-audio/test.sh
Executable file
11
backend/python/liquid-audio/test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
# Runs the backend's unit tests by delegating to runUnittests from the shared
# libbackend.sh helper.
set -e

# Locate libbackend.sh: it sits either in a "common" directory next to this
# script or one level up. Expansions are quoted so a checkout path containing
# spaces does not get word-split.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

runUnittests
|
||||
@@ -36,15 +36,11 @@ fi
|
||||
# flash-attn-4 4.0 stable lands.
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --prerelease=allow"
|
||||
|
||||
# JetPack 7 / L4T arm64 wheels are built for cp312 and shipped via
|
||||
# pypi.jetson-ai-lab.io. Bump the venv Python so the prebuilt sglang
|
||||
# wheel resolves cleanly. The actual install on l4t13 goes through
|
||||
# pyproject.toml (see the elif branch below) so [tool.uv.sources] can
|
||||
# pin only torch/torchvision/torchaudio/sglang to the jetson-ai-lab
|
||||
# index — leaving PyPI as the path for transitive deps like
|
||||
# markdown-it-py / anthropic / propcache that the L4T mirror's proxy
|
||||
# 503s on. No --index-strategy flag here: the explicit index keeps the
|
||||
# scoping clean.
|
||||
# JetPack 7 / L4T arm64 sglang + torch wheels come straight from PyPI now
|
||||
# (torch 2.11+ ships aarch64 + cu130 manylinux wheels and sglang 0.5.11+
|
||||
# ships a cp312 aarch64 wheel pinned to that torch). They're cp312-only,
|
||||
# so bump the venv Python accordingly.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
@@ -110,27 +106,6 @@ if [ "x${BUILD_TYPE}" == "x" ] || [ "x${FROM_SOURCE:-}" == "xtrue" ]; then
|
||||
fi
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} .
|
||||
popd
|
||||
# L4T arm64 (JetPack 7): drive the install through pyproject.toml so that
|
||||
# [tool.uv.sources] can pin torch/torchvision/torchaudio/sglang to the
|
||||
# jetson-ai-lab index, while everything else (transitive deps and
|
||||
# PyPI-resolvable packages like transformers / accelerate) comes from
|
||||
# PyPI. Bypasses installRequirements because uv pip install -r
|
||||
# requirements.txt does not honor sources — see
|
||||
# backend/python/sglang/pyproject.toml for the rationale. Mirrors the
|
||||
# equivalent path in backend/python/vllm/install.sh.
|
||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
ensureVenv
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
||||
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
|
||||
fi
|
||||
pushd "${backend_dir}"
|
||||
# Build deps first (matches installRequirements' requirements-install.txt
|
||||
# pass — sglang/sgl-kernel sdists need packaging/setuptools-scm in the
|
||||
# venv before they can build under --no-build-isolation).
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements-install.txt
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement pyproject.toml
|
||||
popd
|
||||
runProtogen
|
||||
else
|
||||
installRequirements
|
||||
fi
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
# L4T arm64 (JetPack 7 / sbsa cu130) install spec for the sglang backend.
|
||||
#
|
||||
# Why this file exists, and why only the l4t13 BUILD_PROFILE consumes it:
|
||||
#
|
||||
# pypi.jetson-ai-lab.io hosts the L4T-specific torch / sglang / sgl-kernel
|
||||
# wheels we need on aarch64 + cuda13, but it ALSO transparently proxies the
|
||||
# rest of PyPI through `/+f/<sha>/<filename>` URLs that 503 frequently.
|
||||
# With `--extra-index-url` + `--index-strategy=unsafe-best-match` (the
|
||||
# historical fix in install.sh) uv would pick those proxy URLs for ordinary
|
||||
# PyPI packages — markdown-it-py, anthropic, propcache, etc. — and trip on
|
||||
# the 503s. See e.g. CI run 25439791228 (markdown-it-py-4.0.0).
|
||||
#
|
||||
# `explicit = true` on the index makes uv consult the L4T mirror ONLY for
|
||||
# packages mapped under [tool.uv.sources]. Everything else goes to PyPI.
|
||||
# This breaks the historical 503 path without losing access to the L4T
|
||||
# wheels we actually need from there. Mirrors the equivalent fix already
|
||||
# in backend/python/vllm/pyproject.toml.
|
||||
#
|
||||
# `uv pip install -r requirements.txt` does NOT honor [tool.uv.sources]
|
||||
# (sources are project-mode only, not pip-compat mode), so install.sh's
|
||||
# l4t13 branch invokes `uv pip install --requirement pyproject.toml`
|
||||
# directly. Other BUILD_PROFILEs continue to use the requirements-*.txt
|
||||
# pipeline through libbackend.sh's installRequirements and never read
|
||||
# this file.
|
||||
[project]
|
||||
name = "localai-sglang-l4t13"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.12,<3.13"
|
||||
dependencies = [
|
||||
# Mirror of requirements.txt — kept in sync manually for now since the
|
||||
# l4t13 path bypasses installRequirements (see install.sh).
|
||||
"grpcio==1.80.0",
|
||||
"protobuf",
|
||||
"certifi",
|
||||
"setuptools",
|
||||
"pillow",
|
||||
# L4T-specific accelerator stack (sourced from jetson-ai-lab below).
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
# sglang on jetson — the [all] extra is deliberately omitted because it
|
||||
# pulls outlines/decord, and decord has no aarch64 cp312 wheel anywhere
|
||||
# (PyPI and the jetson-ai-lab index both ship only legacy cp35-cp37 wheels). With
|
||||
# [all] uv backtracks through versions trying to satisfy decord and
|
||||
# lands on sglang==0.1.16. The 0.5.0 floor matches the only major
|
||||
# series the jetson-ai-lab sbsa/cu130 mirror currently publishes
|
||||
# (sglang==0.5.1.post2 as of 2026-05-06). Bumping to >=0.5.11 here
|
||||
# would make the build unsatisfiable until the mirror catches up.
|
||||
# Gemma 4 / MTP recipes are therefore not supported on l4t13 — those
|
||||
# features land on cublas12/cublas13 hosts that pull the newer wheel
|
||||
# from PyPI. backend.py keeps backward compat with the 0.5.x SamplingParams
|
||||
# field rename via runtime detection.
|
||||
"sglang>=0.5.0",
|
||||
# PyPI-resolvable packages that complete the runtime.
|
||||
"accelerate",
|
||||
"transformers",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "jetson-ai-lab"
|
||||
url = "https://pypi.jetson-ai-lab.io/sbsa/cu130"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "jetson-ai-lab" }
|
||||
torchvision = { index = "jetson-ai-lab" }
|
||||
torchaudio = { index = "jetson-ai-lab" }
|
||||
sglang = { index = "jetson-ai-lab" }
|
||||
15
backend/python/sglang/requirements-l4t13-after.txt
Normal file
15
backend/python/sglang/requirements-l4t13-after.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
# sglang 0.5.11+ ships an aarch64 manylinux wheel on PyPI whose Requires-Dist
|
||||
# pins torch==2.11.0 / torchaudio==2.11.0, locking an ABI-consistent set with
|
||||
# the cu130 torch wheel installed above. 0.5.11 is the floor for Gemma 4
|
||||
# support (sgl-project/sglang#21952).
|
||||
#
|
||||
# The [all] extra is deliberately NOT used on aarch64: it pulls the
|
||||
# [diffusion] sub-extra which requires `xatlas`, and xatlas ships no
|
||||
# aarch64 wheel and its sdist depends on scikit_build_core without
|
||||
# declaring it in build-system.requires — so under --no-build-isolation
|
||||
# uv can't build it. Upstream sglang gates st_attn and vsa on
|
||||
# platform_machine != aarch64 in the diffusion extra but forgot xatlas.
|
||||
# Plain `sglang` carries everything backend.py uses (Engine, ServerArgs,
|
||||
# FunctionCallParser, ReasoningParser); the [all] extras are optional
|
||||
# accelerators not required at import time.
|
||||
sglang>=0.5.11
|
||||
9
backend/python/sglang/requirements-l4t13.txt
Normal file
9
backend/python/sglang/requirements-l4t13.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
# JetPack 7 / L4T arm64 + CUDA 13. Since PyTorch 2.11 (April 2026), PyPI ships
|
||||
# aarch64 + cu130 manylinux wheels for torch/torchvision/torchaudio directly,
|
||||
# so we no longer need a custom --extra-index-url for the L4T mirror.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
accelerate
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
transformers
|
||||
@@ -26,7 +26,7 @@ import torch.cuda
|
||||
|
||||
XPU=os.environ.get("XPU", "0") == "1"
|
||||
import transformers as transformers_module
|
||||
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
|
||||
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, pipeline
|
||||
from scipy.io import wavfile
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
@@ -200,6 +200,21 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
autoTokenizer = False
|
||||
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||
self.SentenceTransformer = True
|
||||
elif request.Type == "TokenClassification":
|
||||
# NER / PII tagging via HuggingFace's token-classification
|
||||
# pipeline. aggregation_strategy="simple" merges B-/I- tags
|
||||
# into single spans and gives character offsets back. The
|
||||
# tokenizer is bundled inside the pipeline, so we skip the
|
||||
# AutoTokenizer load below.
|
||||
autoTokenizer = False
|
||||
self.tokenClassifier = pipeline(
|
||||
"token-classification",
|
||||
model=model_name,
|
||||
aggregation_strategy="simple",
|
||||
device=0 if self.CUDA else -1,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
)
|
||||
self.TokenClassification = True
|
||||
else:
|
||||
# Generic: dynamically resolve model class from transformers
|
||||
model_type = TYPE_ALIASES.get(request.Type, request.Type)
|
||||
@@ -253,6 +268,39 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def TokenClassify(self, request, context):
    """Run the loaded token-classification pipeline on ``request.text``.

    Returns a ``TokenClassifyResponse`` with one ``TokenClassifyEntity`` per
    pipeline result whose score is at least ``request.threshold`` (a
    non-positive threshold keeps everything).

    Error reporting goes through ``context``; an empty response is returned in
    both cases:
      * FAILED_PRECONDITION when the model was not loaded as
        Type=TokenClassification.
      * INTERNAL when the pipeline call itself raises.

    NOTE(review): Hugging Face documents ``start``/``end`` as indices into the
    input sentence, i.e. Python ``str`` (character) positions rather than
    UTF-8 byte offsets. A caller slicing a byte string needs to convert for
    non-ASCII text; confirm against the Go side.
    """
    if not getattr(self, "TokenClassification", False):
        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
        context.set_details("model was not loaded as Type=TokenClassification")
        return backend_pb2.TokenClassifyResponse()
    try:
        results = self.tokenClassifier(request.text)
    except Exception as err:
        print("TokenClassify error:", err, file=sys.stderr)
        context.set_code(grpc.StatusCode.INTERNAL)
        context.set_details(f"token-classification failed: {err}")
        return backend_pb2.TokenClassifyResponse()

    threshold = request.threshold if request.threshold > 0 else 0.0
    entities = []
    for r in results:
        score = float(r.get("score") or 0.0)
        if score < threshold:
            continue
        # The pipeline can return a key with an explicit None value (start/end
        # when the tokenizer exposes no offsets). dict.get's default does not
        # apply then, so use "or" to avoid int(None) raising and to avoid
        # emitting the literal string "None" as entity text. Missing offsets
        # are reported as 0.
        entities.append(backend_pb2.TokenClassifyEntity(
            entity_group=str(r.get("entity_group") or r.get("entity") or ""),
            start=int(r.get("start") or 0),
            end=int(r.get("end") or 0),
            score=score,
            text=str(r.get("word") or ""),
        ))
    return backend_pb2.TokenClassifyResponse(entities=entities)
|
||||
|
||||
def Embedding(self, request, context):
|
||||
set_seed(request.Seed)
|
||||
# Tokenize input
|
||||
|
||||
@@ -2,9 +2,9 @@ torch==2.7.1
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
accelerate
|
||||
transformers>=5.8.0
|
||||
transformers>=5.8.1
|
||||
bitsandbytes
|
||||
sentence-transformers==5.4.0
|
||||
sentence-transformers==5.5.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -2,9 +2,9 @@ torch==2.7.1
|
||||
accelerate
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers>=5.8.0
|
||||
transformers>=5.8.1
|
||||
bitsandbytes
|
||||
sentence-transformers==5.4.0
|
||||
sentence-transformers==5.5.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -2,9 +2,9 @@
|
||||
torch==2.9.0
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers>=5.8.0
|
||||
transformers>=5.8.1
|
||||
bitsandbytes
|
||||
sentence-transformers==5.4.0
|
||||
sentence-transformers==5.5.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -1,11 +1,11 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm7.0
|
||||
torch==2.10.0+rocm7.0
|
||||
accelerate
|
||||
transformers>=5.8.0
|
||||
transformers>=5.8.1
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.4.0
|
||||
sentence-transformers==5.5.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -3,9 +3,9 @@ torch
|
||||
optimum[openvino]
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers>=5.8.0
|
||||
transformers>=5.8.1
|
||||
bitsandbytes
|
||||
sentence-transformers==5.4.0
|
||||
sentence-transformers==5.5.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -2,9 +2,9 @@ torch==2.7.1
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
accelerate
|
||||
transformers>=5.8.0
|
||||
transformers>=5.8.1
|
||||
bitsandbytes
|
||||
sentence-transformers==5.4.0
|
||||
sentence-transformers==5.5.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
|
||||
@@ -13,14 +13,14 @@ else
|
||||
fi
|
||||
|
||||
# Handle l4t build profiles (Python 3.12, pip fallback) if needed.
|
||||
# unsafe-best-match is required on l4t13 because the jetson-ai-lab index
|
||||
# lists transitive deps at limited versions — without it uv pins to the
|
||||
# first matching index and fails to resolve a compatible wheel from PyPI.
|
||||
# Since PyTorch 2.11 (April 2026) PyPI ships aarch64 + cu130 manylinux wheels
|
||||
# directly for torch/torchvision/torchaudio and an aarch64 vllm wheel pinned
|
||||
# to that torch, so the jetson-ai-lab mirror is no longer needed.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
PY_STANDALONE_TAG="20251120"
|
||||
EXTRA_PIP_INSTALL_FLAGS="${EXTRA_PIP_INSTALL_FLAGS:-} --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
@@ -42,18 +42,11 @@ if [ "x${BUILD_TYPE}" == "xhipblas" ]; then
|
||||
else
|
||||
uv pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700
|
||||
fi
|
||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
# JetPack 7 / L4T arm64 cu130 — vllm comes from the prebuilt SBSA wheel
|
||||
# at jetson-ai-lab. Version is unpinned: the index ships whatever build
|
||||
# matches the cu130/cp312 ABI. unsafe-best-match lets uv fall through
|
||||
# to PyPI for transitive deps not present on the jetson-ai-lab index.
|
||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||
pip install vllm --extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
else
|
||||
uv pip install --index-strategy=unsafe-best-match vllm --extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
fi
|
||||
elif [ "x${BUILD_PROFILE}" == "xcublas13" ]; then
|
||||
# vllm 0.19+ defaults to cu130 wheels on PyPI, no extra index needed.
|
||||
elif [ "x${BUILD_PROFILE}" == "xcublas13" ] || [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
# cublas13 (x86_64) and l4t13 (aarch64) both pull vllm from PyPI now:
|
||||
# vllm 0.19+ defaults to cu130 wheels on x86_64 and vllm 0.20+ ships an
|
||||
# aarch64 manylinux wheel pinned to torch==2.11.0. No extra index needed
|
||||
# in either case.
|
||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||
pip install vllm --torch-backend=auto
|
||||
else
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
# JetPack 7 / L4T arm64 + CUDA 13. PyPI ships aarch64 + cu130 manylinux wheels
|
||||
# for torch/torchvision/torchaudio directly since PyTorch 2.11 (April 2026),
|
||||
# so no custom index is needed. flash-attn is dropped here: PyPI has no
|
||||
# aarch64 wheel for it, but vLLM 0.20+ bundles its own vllm_flash_attn
|
||||
# (fa2 + fa3) inside the main wheel, so it is not required at runtime.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
accelerate
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
transformers
|
||||
bitsandbytes
|
||||
flash-attn
|
||||
diffusers
|
||||
librosa
|
||||
soundfile
|
||||
|
||||
@@ -356,6 +356,133 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
except Exception as e:
|
||||
return backend_pb2.Result(success=False, message=str(e))
|
||||
|
||||
async def Score(self, request, context):
    """
    Joint log-probability of each candidate continuation given the
    shared prompt. Used by routing-policy multi-label classification
    (read the distribution rather than asking the model to emit a
    single argmax label), reranking, and reward-model scoring.

    Implementation uses vLLM's `prompt_logprobs` to recover the
    per-token log P(token_i | tokens_<i) for the full concatenated
    sequence; the candidate's tokens are the suffix whose logprobs
    get summed. max_tokens=1 because vLLM requires at least one
    generated token; the generated token is discarded.

    Returns a ScoreResponse with one CandidateScore per entry of
    ``request.candidates``, in order. Failures are reported through
    ``context`` together with an empty ScoreResponse:
      * FAILED_PRECONDITION - no model or no tokenizer is loaded.
      * INVALID_ARGUMENT    - ``request.candidates`` is empty.
      * INTERNAL            - vLLM returned no prompt_logprobs, or any
                              exception was raised while scoring.
    """
    if not hasattr(self, 'llm') or self.llm is None:
        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
        context.set_details("Model not loaded")
        return backend_pb2.ScoreResponse()
    if not hasattr(self, 'tokenizer') or self.tokenizer is None:
        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
        context.set_details("Tokenizer not available")
        return backend_pb2.ScoreResponse()
    if len(request.candidates) == 0:
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details("candidates must be non-empty")
        return backend_pb2.ScoreResponse()

    try:
        prompt = request.prompt or ""
        prompt_token_ids = self.tokenizer.encode(prompt)
        prompt_len = len(prompt_token_ids)
        results = []

        for candidate in request.candidates:
            # Tokenise the concatenated sequence. We can't naively
            # use len(prompt_tokens) + len(tokenizer.encode(candidate))
            # because BPE merges at the boundary may produce a
            # different tokenisation. Encoding the joined text and
            # walking the divergence point is the correct primitive.
            full_text = prompt + candidate
            full_token_ids = self.tokenizer.encode(full_text)

            # First index at which the joined tokenisation stops matching
            # the prompt-only tokenisation; defaults to the prompt length
            # when the prompt tokens are an exact prefix.
            divergence = prompt_len
            min_len = min(prompt_len, len(full_token_ids))
            for i in range(min_len):
                if prompt_token_ids[i] != full_token_ids[i]:
                    divergence = i
                    break

            candidate_token_ids = full_token_ids[divergence:]
            num_candidate_tokens = len(candidate_token_ids)
            if num_candidate_tokens == 0:
                # Nothing to score (e.g. an empty candidate): report zeros
                # without running a generation.
                results.append(backend_pb2.CandidateScore(
                    log_prob=0.0,
                    length_normalized_log_prob=0.0,
                    num_tokens=0,
                ))
                continue

            sampling = SamplingParams(
                max_tokens=1,
                temperature=0.0,
                prompt_logprobs=1,
                detokenize=False,
            )

            request_id = random_uuid()
            last_output = None
            outputs_iter = self.llm.generate(
                {"prompt": full_text},
                sampling_params=sampling,
                request_id=request_id,
            )
            try:
                # Only the final yielded output is kept.
                async for out in outputs_iter:
                    last_output = out
            finally:
                # Best-effort close of the generator; a failure here must
                # not mask the scoring result or the original exception.
                try:
                    await outputs_iter.aclose()
                except Exception:
                    pass

            if last_output is None or not getattr(last_output, "prompt_logprobs", None):
                context.set_code(grpc.StatusCode.INTERNAL)
                context.set_details("vLLM did not return prompt_logprobs")
                return backend_pb2.ScoreResponse()

            prompt_logprobs = last_output.prompt_logprobs
            total = 0.0
            tokens_proto = []
            for offset, tok_id in enumerate(candidate_token_ids):
                position = divergence + offset
                # Skip positions with no usable distribution: out of range,
                # None, or an empty mapping. The empty case previously made
                # the min() fallback below raise and failed the whole
                # request instead of just this token.
                if position >= len(prompt_logprobs) or not prompt_logprobs[position]:
                    continue
                entry = prompt_logprobs[position]
                lp_obj = entry.get(tok_id)
                if lp_obj is not None:
                    lp = lp_obj.logprob
                else:
                    # Token absent from the returned entries. Substitute the
                    # lowest logprob that was returned. A token outside the
                    # returned top-K cannot be more likely than the entries
                    # in it, so this is an UPPER bound on the true log P:
                    # it can over-estimate this candidate, never penalise it.
                    lp = min(v.logprob for v in entry.values())
                total += lp
                if request.include_token_logprobs:
                    tokens_proto.append(backend_pb2.TokenLogProb(
                        token=self.tokenizer.decode([tok_id]),
                        log_prob=lp,
                    ))

            cs = backend_pb2.CandidateScore(
                log_prob=total,
                num_tokens=num_candidate_tokens,
            )
            if request.length_normalize and num_candidate_tokens > 0:
                cs.length_normalized_log_prob = total / num_candidate_tokens
            if tokens_proto:
                cs.tokens.extend(tokens_proto)
            results.append(cs)

        return backend_pb2.ScoreResponse(candidates=results)
    except Exception as e:
        print(f"Score error: {e}", file=sys.stderr)
        context.set_code(grpc.StatusCode.INTERNAL)
        context.set_details(str(e))
        return backend_pb2.ScoreResponse()
|
||||
|
||||
async def _predict(self, request, context, streaming=False):
|
||||
# Build the sampling parameters
|
||||
# NOTE: this must stay in sync with the vllm backend
|
||||
|
||||
@@ -43,14 +43,11 @@ if [ "x${BUILD_PROFILE}" == "xcublas13" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
# JetPack 7 / L4T arm64 wheels (torch, vllm, flash-attn) live on
|
||||
# pypi.jetson-ai-lab.io and are built for cp312, so bump the venv Python
|
||||
# accordingly. JetPack 6 keeps cp310 + USE_PIP=true.
|
||||
#
|
||||
# l4t13 uses pyproject.toml (see the elif branch below) to pin only the
|
||||
# L4T-specific wheels to the jetson-ai-lab index via [tool.uv.sources].
|
||||
# That keeps PyPI as the resolution path for transitive deps like
|
||||
# anthropic/openai/propcache, which the L4T mirror's proxy 503s on.
|
||||
# JetPack 7 / L4T arm64 vllm + torch wheels come straight from PyPI now
|
||||
# (torch 2.11+ ships aarch64 + cu130 manylinux wheels and vllm 0.20+ ships
|
||||
# an aarch64 wheel pinned to that torch). They're cp312-only, so bump the
|
||||
# venv Python accordingly. JetPack 6 keeps cp310 + USE_PIP=true.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
@@ -103,25 +100,6 @@ if [ "x${BUILD_TYPE}" == "xintel" ]; then
|
||||
export CMAKE_PREFIX_PATH="$(python -c 'import site; print(site.getsitepackages()[0])'):${CMAKE_PREFIX_PATH:-}"
|
||||
VLLM_TARGET_DEVICE=xpu uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --no-deps .
|
||||
popd
|
||||
# L4T arm64 (JetPack 7): drive the install through pyproject.toml so that
|
||||
# [tool.uv.sources] can pin torch/vllm/flash-attn/torchvision/torchaudio
|
||||
# to the jetson-ai-lab index, while everything else (transitive deps and
|
||||
# PyPI-resolvable packages like transformers) comes from PyPI. Bypasses
|
||||
# installRequirements because uv pip install -r requirements.txt does not
|
||||
# honor sources — see backend/python/vllm/pyproject.toml for the rationale.
|
||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
ensureVenv
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
||||
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
|
||||
fi
|
||||
pushd "${backend_dir}"
|
||||
# Build deps first (matches installRequirements' requirements-install.txt
|
||||
# pass — fastsafetensors and friends need pybind11 in the venv before
|
||||
# their sdists can build under --no-build-isolation).
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements-install.txt
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement pyproject.toml
|
||||
popd
|
||||
runProtogen
|
||||
# FROM_SOURCE=true on a CPU build skips the prebuilt vllm wheel in
|
||||
# requirements-cpu-after.txt and compiles vllm locally against the host's
|
||||
# actual CPU. Not used by default because it takes ~30-40 minutes, but
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
# L4T arm64 (JetPack 7 / sbsa cu130) install spec for the vllm backend.
|
||||
#
|
||||
# Why this file exists, and why only the l4t13 BUILD_PROFILE consumes it:
|
||||
#
|
||||
# pypi.jetson-ai-lab.io hosts the L4T-specific torch / vllm / flash-attn
|
||||
# wheels we need on aarch64 + cuda13, but it ALSO transparently proxies the
|
||||
# rest of PyPI through `/+f/<sha>/<filename>` URLs that 503 frequently. With
|
||||
# `--extra-index-url` + `--index-strategy=unsafe-best-match` (the historical
|
||||
# fix in install.sh) uv would pick those proxy URLs for ordinary PyPI
|
||||
# packages — `anthropic`, `openai`, `propcache`, `annotated-types` — and
|
||||
# trip on the 503s. See e.g. CI run 25212201349 (anthropic-0.97.0).
|
||||
#
|
||||
# `explicit = true` on the index makes uv consult the L4T mirror ONLY for
|
||||
# packages mapped under [tool.uv.sources]. Everything else goes to PyPI.
|
||||
# This breaks the historical 503 path without losing access to the L4T
|
||||
# wheels we actually need from there.
|
||||
#
|
||||
# `uv pip install -r requirements.txt` does NOT honor [tool.uv.sources]
|
||||
# (sources are project-mode only, not pip-compat mode), so install.sh's
|
||||
# l4t13 branch invokes `uv pip install --requirement pyproject.toml`
|
||||
# directly. Other BUILD_PROFILEs continue to use the requirements-*.txt
|
||||
# pipeline through libbackend.sh's installRequirements and never read
|
||||
# this file.
|
||||
[project]
|
||||
name = "localai-vllm-l4t13"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.12,<3.13"
|
||||
dependencies = [
|
||||
# Mirror of requirements.txt — kept in sync manually for now since the
|
||||
# l4t13 path bypasses installRequirements (see install.sh).
|
||||
"grpcio==1.80.0",
|
||||
"protobuf",
|
||||
"certifi",
|
||||
"setuptools",
|
||||
"pillow",
|
||||
"charset-normalizer>=3.4.7",
|
||||
"chardet",
|
||||
# L4T-specific accelerator stack (sourced from jetson-ai-lab below).
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
"flash-attn",
|
||||
"vllm",
|
||||
# PyPI-resolvable packages that complete the runtime — accelerate,
|
||||
# transformers, bitsandbytes carry their own wheels for aarch64.
|
||||
"accelerate",
|
||||
"transformers",
|
||||
"bitsandbytes",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "jetson-ai-lab"
|
||||
url = "https://pypi.jetson-ai-lab.io/sbsa/cu130"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "jetson-ai-lab" }
|
||||
torchvision = { index = "jetson-ai-lab" }
|
||||
torchaudio = { index = "jetson-ai-lab" }
|
||||
flash-attn = { index = "jetson-ai-lab" }
|
||||
vllm = { index = "jetson-ai-lab" }
|
||||
@@ -3,5 +3,5 @@
|
||||
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
|
||||
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
|
||||
# so uv consults this index alongside PyPI.
|
||||
--extra-index-url https://wheels.vllm.ai/0.20.2/cu130
|
||||
vllm==0.20.2
|
||||
--extra-index-url https://wheels.vllm.ai/0.21.0/cu130
|
||||
vllm==0.21.0
|
||||
|
||||
4
backend/python/vllm/requirements-l4t13-after.txt
Normal file
4
backend/python/vllm/requirements-l4t13-after.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# vLLM 0.20+ ships an aarch64 manylinux wheel on PyPI whose Requires-Dist pins
|
||||
# torch==2.11.0 / torchvision==0.26.0 / torchaudio==2.11.0, locking an ABI-
|
||||
# consistent set with the cu130 torch wheel installed above.
|
||||
vllm
|
||||
8
backend/python/vllm/requirements-l4t13.txt
Normal file
8
backend/python/vllm/requirements-l4t13.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
# JetPack 7 / L4T arm64 + CUDA 13. Since PyTorch 2.11 (April 2026), PyPI ships
|
||||
# aarch64 + cu130 manylinux wheels for torch/torchvision/torchaudio directly,
|
||||
# so we no longer need a custom --extra-index-url for the L4T mirror.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
accelerate
|
||||
torch
|
||||
transformers
|
||||
bitsandbytes
|
||||
@@ -375,6 +375,15 @@ impl Backend for KokorosService {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
type AudioToAudioStreamStream = ReceiverStream<Result<backend::AudioToAudioResponse, Status>>;
|
||||
|
||||
/// Streaming audio-to-audio RPC. This implementation does not support it:
/// the incoming request stream is ignored and the call always fails with
/// `Status::unimplemented("Not supported")`.
async fn audio_to_audio_stream(
|
||||
&self,
|
||||
_: Request<tonic::Streaming<backend::AudioToAudioRequest>>,
|
||||
) -> Result<Response<Self::AudioToAudioStreamStream>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn sound_generation(
|
||||
&self,
|
||||
_: Request<backend::SoundGenerationRequest>,
|
||||
|
||||
@@ -9,11 +9,18 @@ import (
|
||||
|
||||
corebackend "github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
"github.com/mudler/LocalAI/core/services/facerecognition"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/routing/admission"
|
||||
"github.com/mudler/LocalAI/core/services/routing/billing"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/mitm"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/services/voicerecognition"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
pkggrpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
@@ -51,6 +58,22 @@ type Application struct {
|
||||
faceRegistry facerecognition.Registry
|
||||
voiceRegistry voicerecognition.Registry
|
||||
authDB *gorm.DB
|
||||
metricsService *monitoring.LocalAIMetricsService
|
||||
statsRecorder *billing.Recorder
|
||||
fallbackUser *auth.User
|
||||
piiRedactor *pii.Redactor
|
||||
piiEvents pii.EventStore
|
||||
mitmCA atomic.Pointer[mitm.CA]
|
||||
mitmServer atomic.Pointer[mitm.Server]
|
||||
mitmMutex sync.Mutex // serializes Stop+Start; readers use atomic loads
|
||||
// mitmHostConflicts records duplicate-host claims across model configs.
|
||||
// Non-empty disables the MITM listener until resolved — the strict
|
||||
// 1-to-1 host↔model invariant the dispatcher relies on. Read by
|
||||
// /api/middleware/status so the admin UI can surface the cause.
|
||||
mitmHostConflicts atomic.Pointer[map[string][]string]
|
||||
routerDecisions router.DecisionStore
|
||||
routerRegistry *router.Registry
|
||||
admissionLimiter *admission.Limiter
|
||||
watchdogMutex sync.Mutex
|
||||
watchdogStop chan bool
|
||||
p2pMutex sync.Mutex
|
||||
@@ -185,6 +208,103 @@ func (a *Application) AuthDB() *gorm.DB {
|
||||
return a.authDB
|
||||
}
|
||||
|
||||
// MetricsService returns the OTel + Prometheus metric service. nil when
|
||||
// --disable-metrics is set or initialisation failed at startup.
|
||||
//
|
||||
// The service is created in startup.go before any counter is registered
|
||||
// so that otel.SetMeterProvider runs early enough for the billing
|
||||
// recorder's counters to bind to the Prom-backed provider rather than
|
||||
// the no-op global. core/http/app.go reuses this instance instead of
|
||||
// constructing its own — two providers would orphan one set of counters
|
||||
// behind whichever provider lost the SetMeterProvider race.
|
||||
func (a *Application) MetricsService() *monitoring.LocalAIMetricsService {
|
||||
return a.metricsService
|
||||
}
|
||||
|
||||
// StatsRecorder returns the billing recorder used by the usage
|
||||
// middleware. It is non-nil whenever stats are not explicitly disabled
|
||||
// — i.e., the no-auth single-user path still gets a working recorder
|
||||
// (in-memory by default). Routes register UsageMiddleware against this
|
||||
// recorder regardless of auth state.
|
||||
func (a *Application) StatsRecorder() *billing.Recorder {
|
||||
return a.statsRecorder
|
||||
}
|
||||
|
||||
// FallbackUser is the synthetic "local" user that UsageMiddleware uses
|
||||
// to attribute requests when no authenticated user is on the context
|
||||
// (i.e., --auth is off). nil when auth is on, since real users are
|
||||
// always available there.
|
||||
func (a *Application) FallbackUser() *auth.User {
|
||||
return a.fallbackUser
|
||||
}
|
||||
|
||||
// PIIRedactor returns the regex-tier PII redactor or nil if PII
|
||||
// filtering is disabled. The chat-route middleware uses this to apply
|
||||
// redaction before dispatch.
|
||||
func (a *Application) PIIRedactor() *pii.Redactor {
|
||||
return a.piiRedactor
|
||||
}
|
||||
|
||||
// PIIEvents returns the PII event store. Same nil-when-disabled
|
||||
// semantics as PIIRedactor; admin REST and MCP read tools call List
|
||||
// against it.
|
||||
func (a *Application) PIIEvents() pii.EventStore {
|
||||
return a.piiEvents
|
||||
}
|
||||
|
||||
// MITMCA returns the cloudproxy MITM proxy's CA, or nil when the
|
||||
// MITM listener is disabled.
|
||||
func (a *Application) MITMCA() *mitm.CA { return a.mitmCA.Load() }
|
||||
|
||||
// MITMServer returns the running MITM proxy or nil.
|
||||
func (a *Application) MITMServer() *mitm.Server { return a.mitmServer.Load() }
|
||||
|
||||
// MITMHostConflicts returns a snapshot of host→[]model-name pairs that
|
||||
// are claimed by 2+ model configs. Empty when the 1-to-1 invariant
|
||||
// holds. Non-empty disables the MITM listener — read by the admin
|
||||
// status endpoint to explain why.
|
||||
func (a *Application) MITMHostConflicts() map[string][]string {
|
||||
p := a.mitmHostConflicts.Load()
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return *p
|
||||
}
|
||||
|
||||
// MITMHostOwners returns the host→model-name map, useful for the
|
||||
// admin status endpoint. The lookup is recomputed on each call to
|
||||
// stay current with model-config edits without needing a
|
||||
// MITMRestart.
|
||||
func (a *Application) MITMHostOwners() map[string]string {
|
||||
if a.backendLoader == nil {
|
||||
return nil
|
||||
}
|
||||
return a.backendLoader.MITMHostOwners().Owners
|
||||
}
|
||||
|
||||
// RouterDecisions returns the routing decision store. nil when stats
|
||||
// are disabled (--disable-stats); the RouteModel middleware skips the
|
||||
// log write in that case but still rewrites requests.
|
||||
func (a *Application) RouterDecisions() router.DecisionStore {
|
||||
return a.routerDecisions
|
||||
}
|
||||
|
||||
// RouterClassifierRegistry returns the process-wide classifier cache.
|
||||
// Shared between the OpenAI and Anthropic route middlewares so the
|
||||
// admin stats endpoint sees every live classifier — and so a
|
||||
// classifier built on the OpenAI route is reused on Anthropic.
|
||||
func (a *Application) RouterClassifierRegistry() *router.Registry {
|
||||
return a.routerRegistry
|
||||
}
|
||||
|
||||
// AdmissionLimiter returns the per-model admission limiter. The
|
||||
// admission middleware uses it to gate concurrent requests; the
|
||||
// admin status surface reads InFlight/Capacity from it for live
|
||||
// load visibility.
|
||||
func (a *Application) AdmissionLimiter() *admission.Limiter {
|
||||
return a.admissionLimiter
|
||||
}
|
||||
|
||||
// StartupConfig returns the original startup configuration (from env vars, before file loading)
|
||||
func (a *Application) StartupConfig() *config.ApplicationConfig {
|
||||
return a.startupConfig
|
||||
@@ -255,6 +375,15 @@ func (a *Application) start() error {
|
||||
a.modelLoader,
|
||||
a.galleryService,
|
||||
)
|
||||
// Wire usage tracking so the assistant's get_usage_stats tool
|
||||
// returns real data; nil values keep the tool returning a clear
|
||||
// "unavailable" error if startup ran with --disable-stats.
|
||||
assistantClient.StatsRecorder = a.statsRecorder
|
||||
assistantClient.FallbackUser = a.fallbackUser
|
||||
// PII filter — same nil-or-real wiring.
|
||||
assistantClient.PIIRedactor = a.piiRedactor
|
||||
assistantClient.PIIEvents = a.piiEvents
|
||||
assistantClient.RouterDecisions = a.routerDecisions
|
||||
if err := holder.Initialize(a.applicationConfig.Context, assistantClient, localaitools.Options{}); err != nil {
|
||||
// Why log+continue instead of fail: the assistant is an optional
|
||||
// feature; a failure here must not take down the whole server.
|
||||
|
||||
@@ -169,7 +169,7 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
cfg.Distributed.HealthCheckIntervalOrDefault(),
|
||||
cfg.Distributed.StaleNodeThresholdOrDefault(),
|
||||
routerAuthToken,
|
||||
cfg.Distributed.PerModelHealthCheck,
|
||||
!cfg.Distributed.DisablePerModelHealthCheck,
|
||||
)
|
||||
|
||||
// Initialize job store
|
||||
@@ -233,7 +233,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
xlog.Info("File stager initialized (HTTP direct transfer)")
|
||||
}
|
||||
// Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(
|
||||
registry,
|
||||
natsClient,
|
||||
cfg.Distributed.BackendInstallTimeoutOrDefault(),
|
||||
cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
|
||||
)
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
var conflictResolver nodes.ConcurrencyConflictResolver
|
||||
|
||||
146
core/application/mitm.go
Normal file
146
core/application/mitm.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/mitm"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
func startMITMProxy(app *Application, options *config.ApplicationConfig) error {
|
||||
app.mitmMutex.Lock()
|
||||
defer app.mitmMutex.Unlock()
|
||||
return startMITMLocked(app, options)
|
||||
}
|
||||
|
||||
func startMITMLocked(app *Application, options *config.ApplicationConfig) error {
|
||||
// Validate the host↔model-config 1-to-1 invariant before binding
|
||||
// the listener. Two configs claiming the same host means the
|
||||
// dispatcher would have ambiguous PII settings; refuse to start
|
||||
// rather than silently picking one. The conflict map is published
|
||||
// for /api/middleware/status to surface in the UI.
|
||||
ownership := app.backendLoader.MITMHostOwners()
|
||||
if len(ownership.Conflicts) > 0 {
|
||||
conflicts := ownership.Conflicts
|
||||
app.mitmHostConflicts.Store(&conflicts)
|
||||
hosts := make([]string, 0, len(conflicts))
|
||||
for h := range conflicts {
|
||||
hosts = append(hosts, h)
|
||||
}
|
||||
sort.Strings(hosts)
|
||||
xlog.Error("mitm: refusing to start — duplicate host claims across model configs",
|
||||
"hosts", hosts,
|
||||
"conflicts", conflicts,
|
||||
)
|
||||
return errors.New("mitm: configuration error: duplicate host claims (see /api/middleware/status)")
|
||||
}
|
||||
app.mitmHostConflicts.Store(nil)
|
||||
|
||||
caDir := options.MITMCADir
|
||||
if caDir == "" {
|
||||
base := options.DataPath
|
||||
if base == "" {
|
||||
base = "."
|
||||
}
|
||||
caDir = filepath.Join(base, "mitm-ca")
|
||||
}
|
||||
|
||||
if app.mitmCA.Load() == nil {
|
||||
ca, err := mitm.LoadOrCreateCA(caDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ca: %w", err)
|
||||
}
|
||||
app.mitmCA.Store(ca)
|
||||
}
|
||||
|
||||
// Allowlist is exactly the set of hosts claimed by model configs.
|
||||
// No global list — admins add hosts by creating an MITM model
|
||||
// config (template available in the Add Model UI). When no config
|
||||
// claims any host, the listener still starts but every CONNECT
|
||||
// tunnels through unmodified.
|
||||
effectiveHosts := make([]string, 0, len(ownership.Owners))
|
||||
for h := range ownership.Owners {
|
||||
effectiveHosts = append(effectiveHosts, h)
|
||||
}
|
||||
sort.Strings(effectiveHosts)
|
||||
|
||||
// Per-host PII gate inherits from the owning model's pii.enabled.
|
||||
// A non-cloud-proxy backend with no explicit pii.enabled resolves
|
||||
// to false → host is intercepted but the regex pass is skipped
|
||||
// (audit events still record).
|
||||
var piiDisabled []string
|
||||
for host, modelName := range ownership.Owners {
|
||||
cfg, exists := app.backendLoader.GetModelConfig(modelName)
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
if !cfg.PIIIsEnabled() {
|
||||
piiDisabled = append(piiDisabled, host)
|
||||
}
|
||||
}
|
||||
|
||||
handler := mitm.NewPIIHandler(mitm.PIIHandlerOptions{
|
||||
Redactor: app.piiRedactor,
|
||||
EventStore: app.piiEvents,
|
||||
HostsWithPIIDisabled: piiDisabled,
|
||||
})
|
||||
|
||||
srv, err := mitm.NewServer(mitm.Config{
|
||||
Addr: options.MITMListen,
|
||||
CA: app.mitmCA.Load(),
|
||||
InterceptHosts: effectiveHosts,
|
||||
Handler: handler,
|
||||
EventStore: app.piiEvents,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("server: %w", err)
|
||||
}
|
||||
if err := srv.Start(); err != nil {
|
||||
return fmt.Errorf("listen: %w", err)
|
||||
}
|
||||
app.mitmServer.Store(srv)
|
||||
|
||||
xlog.Info("mitm: cloudproxy listener started",
|
||||
"addr", srv.Addr(),
|
||||
"ca_dir", caDir,
|
||||
"intercept_hosts", effectiveHosts,
|
||||
"model_owned_hosts", len(ownership.Owners),
|
||||
"pii_disabled_hosts", len(piiDisabled),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopMITM is idempotent.
|
||||
func (a *Application) StopMITM() error {
|
||||
a.mitmMutex.Lock()
|
||||
defer a.mitmMutex.Unlock()
|
||||
stopMITMLocked(a)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestartMITM reuses the existing CA so trusted clients keep
|
||||
// working across listener flips.
|
||||
func (a *Application) RestartMITM() error {
|
||||
a.mitmMutex.Lock()
|
||||
defer a.mitmMutex.Unlock()
|
||||
stopMITMLocked(a)
|
||||
if a.applicationConfig.MITMListen == "" {
|
||||
xlog.Info("mitm: cloudproxy listener stays disabled (no listen address)")
|
||||
return nil
|
||||
}
|
||||
return startMITMLocked(a, a.applicationConfig)
|
||||
}
|
||||
|
||||
func stopMITMLocked(a *Application) {
|
||||
srv := a.mitmServer.Load()
|
||||
if srv == nil {
|
||||
return
|
||||
}
|
||||
srv.Stop()
|
||||
a.mitmServer.Store(nil)
|
||||
xlog.Info("mitm: cloudproxy listener stopped")
|
||||
}
|
||||
63
core/application/router_factories.go
Normal file
63
core/application/router_factories.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
// adapterConfig resolves a model name to its runtime ModelConfig, or
|
||||
// nil when the name is unknown. Shared by the router-facing factories
|
||||
// below and by ModelConfigLookup.
|
||||
func (a *Application) adapterConfig(modelName string) *config.ModelConfig {
|
||||
cfg, err := a.backendLoader.LoadModelConfigFileByNameDefaultOptions(modelName, a.applicationConfig)
|
||||
if err != nil || cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// ModelConfigLookup is the lookup function the router middleware's
|
||||
// classifier validator uses to confirm classifier_model declares
|
||||
// FLAG_SCORE before binding it.
|
||||
func (a *Application) ModelConfigLookup() func(modelName string) *config.ModelConfig {
|
||||
return a.adapterConfig
|
||||
}
|
||||
|
||||
// Scorer returns a backend.Scorer bound to the named model, or nil
|
||||
// when the model is unknown. Used as a method value (app.Scorer) by
|
||||
// router.ClassifierDeps — no factory-of-factory wrapper needed.
|
||||
func (a *Application) Scorer(modelName string) backend.Scorer {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewScorer(a.modelLoader, *cfg, a.applicationConfig)
|
||||
}
|
||||
|
||||
// Reranker returns a backend.Reranker bound to the named model, or
|
||||
// nil when unknown. The reranker model's `type:` (e.g. "colbert")
|
||||
// selects the scoring head inside the rerankers backend.
|
||||
func (a *Application) Reranker(modelName string) backend.Reranker {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewReranker(a.modelLoader, *cfg, a.applicationConfig)
|
||||
}
|
||||
|
||||
// Embedder returns a backend.Embedder bound to the named model, or
|
||||
// nil when unknown. Used by the router's L2 embedding cache.
|
||||
func (a *Application) Embedder(modelName string) backend.Embedder {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewEmbedder(a.modelLoader, *cfg, a.applicationConfig)
|
||||
}
|
||||
|
||||
// VectorStore returns a backend.VectorStore for the named collection,
|
||||
// or nil when the name is empty. Each router model gets its own
|
||||
// backend process via the model loader's cache keyed by storeName.
|
||||
func (a *Application) VectorStore(storeName string) backend.VectorStore {
|
||||
return backend.NewVectorStore(a.modelLoader, a.applicationConfig, storeName)
|
||||
}
|
||||
@@ -87,6 +87,28 @@ var _ = Describe("loadRuntimeSettingsFromFile", func() {
|
||||
})
|
||||
})
|
||||
|
||||
// MITM listener address. The file is the only source — no env var
|
||||
// exists — so a regression here means an admin who configured the
|
||||
// listener via /api/settings loses it after a reboot, even though
|
||||
// the value is still on disk in the volume. (Intercept hosts now
|
||||
// live in model YAML mitm.hosts: blocks, not runtime_settings.json.)
|
||||
Describe("MITM fields", func() {
|
||||
It("loads mitm_listen", func() {
|
||||
cfg := &config.ApplicationConfig{DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`)}
|
||||
loadRuntimeSettingsFromFile(cfg)
|
||||
Expect(cfg.MITMListen).To(Equal(":8443"))
|
||||
})
|
||||
|
||||
It("does not override an explicit CLI flag", func() {
|
||||
cfg := &config.ApplicationConfig{
|
||||
DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`),
|
||||
MITMListen: ":9999", // simulate WithMITMListen(":9999")
|
||||
}
|
||||
loadRuntimeSettingsFromFile(cfg)
|
||||
Expect(cfg.MITMListen).To(Equal(":9999"), "CLI flag must win over the persisted file value")
|
||||
})
|
||||
})
|
||||
|
||||
// The Agent Pool block has a mix of zero and non-zero defaults
|
||||
// (Enabled=true, EmbeddingModel="granite-...", MaxChunkingSize=400,
|
||||
// VectorEngine="chromem", AgentHubURL="https://agenthub.localai.io").
|
||||
|
||||
@@ -15,11 +15,18 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/routing/admission"
|
||||
"github.com/mudler/LocalAI/core/services/routing/billing"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
@@ -128,6 +135,117 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}()
|
||||
}
|
||||
|
||||
// Initialize the OTel + Prometheus metric pipeline before any
|
||||
// counter is created. monitoring.NewLocalAIMetricsService calls
|
||||
// otel.SetMeterProvider, so any subsequent otel.Meter() call —
|
||||
// including billing.NewRecorder below — sees the real provider
|
||||
// rather than the no-op global. Initialising metrics later (in
|
||||
// core/http/app.go) leaves billing's counters bound to a no-op
|
||||
// meter and never reaches /metrics. DisableMetrics is still honoured
|
||||
// here: when it is set no provider is created, MetricsService()
|
||||
// stays nil, and billing's counters are not bound to an exported meter.
|
||||
if !options.DisableMetrics {
|
||||
ms, err := monitoring.NewLocalAIMetricsService()
|
||||
if err != nil {
|
||||
xlog.Error("failed to initialize metrics provider", "error", err)
|
||||
} else {
|
||||
application.metricsService = ms
|
||||
// Bind the billing package's counters to the same meter the
|
||||
// metrics service exports. Without this, billing's counters
|
||||
// resolve via the OTel global and never reach /metrics.
|
||||
billing.SetMeter(ms.Meter)
|
||||
}
|
||||
}
|
||||
|
||||
// Wire the routing-module billing recorder. The recorder runs in
|
||||
// every mode (auth on/off, distributed/single-node) so that token
|
||||
// tracking is not gated on auth — a no-auth single-user box still
|
||||
// gets dashboards and `/api/usage` populated.
|
||||
//
|
||||
// fallbackUser is wired *unconditionally* when stats are enabled.
|
||||
// UsageMiddleware uses it as the attribution source whenever
|
||||
// auth.GetUser(c) is nil — that covers (a) no-auth deployments and
|
||||
// (b) internal callers under auth-on (cron flushers, distributed
|
||||
// worker callbacks) that hit a recordable endpoint without a user
|
||||
// in context. The billing.user_id_present invariant still rejects
|
||||
// empty IDs; LocalUser() returns a stable UUID per data path.
|
||||
if !options.DisableStats {
|
||||
var statsBackend billing.StatsBackend
|
||||
switch {
|
||||
case application.authDB != nil:
|
||||
statsBackend = billing.NewGormBackend(application.authDB, 0, 0)
|
||||
xlog.Info("stats: using auth DB for usage records")
|
||||
default:
|
||||
statsBackend = billing.NewMemoryBackend(0)
|
||||
xlog.Info("stats: using in-memory ring buffer (no-auth single-user mode)")
|
||||
}
|
||||
application.fallbackUser = billing.LocalUser(options.DataPath)
|
||||
application.statsRecorder = billing.NewRecorder(statsBackend)
|
||||
// Drain pending records on SIGTERM. The GORM backend buffers up
|
||||
// to maxPending (5k) records across a 5s flush tick, so without
|
||||
// this the last few seconds of usage disappear on graceful exit.
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
_ = application.statsRecorder.Close()
|
||||
})
|
||||
xlog.Info("stats: fallback user wired", "local_user_id", application.fallbackUser.ID)
|
||||
} else {
|
||||
xlog.Info("stats: disabled by --disable-stats")
|
||||
}
|
||||
|
||||
// Wire the regex PII filter. Default-on: a single-user box gets
|
||||
// the built-in pattern set the first time it starts, with email/
|
||||
// phone/SSN/credit-card on mask and api_key_prefix on block. If
|
||||
// the operator wants different actions, --pii-config points at a
|
||||
// YAML file that overrides per-id; --disable-pii turns it off
|
||||
// entirely.
|
||||
if !options.DisablePII {
|
||||
patterns, err := pii.LoadConfig(options.PIIConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pii config: %w", err)
|
||||
}
|
||||
application.piiRedactor = pii.NewRedactor(patterns)
|
||||
application.piiEvents = pii.NewMemoryEventStore(0)
|
||||
// Apply persisted per-pattern overrides — admins toggling
|
||||
// action/disabled via the UI and clicking "Save to disk" land
|
||||
// here on the next start. Bad ids are warned and ignored so a
|
||||
// stale entry doesn't block startup.
|
||||
for id, ov := range options.PIIPatternOverrides {
|
||||
if ov.Action != nil {
|
||||
if err := application.piiRedactor.SetAction(id, pii.Action(*ov.Action)); err != nil {
|
||||
xlog.Warn("pii: persisted override skipped", "pattern", id, "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if ov.Disabled != nil {
|
||||
if err := application.piiRedactor.SetDisabled(id, *ov.Disabled); err != nil {
|
||||
xlog.Warn("pii: persisted disable skipped", "pattern", id, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
xlog.Info("pii: filter enabled",
|
||||
"patterns", len(patterns),
|
||||
"config_path", options.PIIConfigPath,
|
||||
"persisted_overrides", len(options.PIIPatternOverrides),
|
||||
)
|
||||
} else {
|
||||
xlog.Info("pii: disabled by --disable-pii")
|
||||
}
|
||||
|
||||
// Wire the routing decision log. Always-on when stats are enabled —
|
||||
// the per-router admin page reads this as the live activity feed
|
||||
// and as input to drift checks for subsystem 5.
|
||||
if !options.DisableStats {
|
||||
application.routerDecisions = router.NewMemoryDecisionStore(0)
|
||||
}
|
||||
// Process-wide classifier cache shared across all route middlewares so
|
||||
// the embedding-cache stats endpoint sees a single source of truth.
|
||||
application.routerRegistry = router.NewRegistry()
|
||||
|
||||
// Subsystem 5: admission control. Limiter is always wired so a
|
||||
// model that gains a limits: block via gallery install or YAML
|
||||
// edit takes effect on the next restart without conditional plumbing.
|
||||
application.admissionLimiter = admission.New()
|
||||
|
||||
// Wire JobStore for DB-backed task/job persistence whenever auth DB is available.
|
||||
// This ensures tasks and jobs survive restarts in both single-node and distributed modes.
|
||||
if application.authDB != nil && application.agentJobService != nil {
|
||||
@@ -195,12 +313,36 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
|
||||
}
|
||||
// Hydrate from the store first so the wildcard subscriber finds an
|
||||
// already-populated statuses map for any operations still in flight
|
||||
// on a peer replica.
|
||||
if err := application.galleryService.Hydrate(); err != nil {
|
||||
xlog.Warn("Gallery service hydrate failed", "error", err)
|
||||
}
|
||||
// Bind cache-invalidation handler before SubscribeBroadcasts so the
|
||||
// first inbound event is already routed. Peer replicas install a
|
||||
// model and broadcast on SubjectCacheInvalidateModels; this
|
||||
// callback re-runs LoadModelConfigsFromPath so a subsequent chat
|
||||
// completion that load-balances onto this replica finds the new
|
||||
// config. The originating replica reloads inline in modelHandler
|
||||
// and never enters this path.
|
||||
gs := application.galleryService
|
||||
sys := options.SystemState
|
||||
cfgLoaderOpts := options.ToConfigLoaderOptions()
|
||||
gs.OnModelsChanged = func(_ messaging.CacheInvalidateEvent) {
|
||||
if err := application.ModelConfigLoader().LoadModelConfigsFromPath(sys.Model.ModelsPath, cfgLoaderOpts...); err != nil {
|
||||
xlog.Warn("Failed to reload model configs after peer invalidation", "error", err)
|
||||
}
|
||||
}
|
||||
if err := application.galleryService.SubscribeBroadcasts(); err != nil {
|
||||
xlog.Warn("Gallery service subscribe failed", "error", err)
|
||||
}
|
||||
// Wire distributed model/backend managers so delete propagates to workers
|
||||
application.galleryService.SetModelManager(
|
||||
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
|
||||
)
|
||||
application.galleryService.SetBackendManager(
|
||||
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry),
|
||||
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry, application.galleryService),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -212,12 +354,12 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
||||
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.RequireBackendIntegrity, nil, options.ModelsURL...); err != nil {
|
||||
xlog.Error("error installing models", "error", err)
|
||||
}
|
||||
|
||||
for _, backend := range options.ExternalBackends {
|
||||
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
|
||||
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", "", options.RequireBackendIntegrity); err != nil {
|
||||
xlog.Error("error installing external backend", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -267,13 +409,13 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
|
||||
if options.PreloadJSONModels != "" {
|
||||
if err := galleryop.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
|
||||
if err := galleryop.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels, options.RequireBackendIntegrity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.PreloadModelsFromPath != "" {
|
||||
if err := galleryop.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
|
||||
if err := galleryop.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath, options.RequireBackendIntegrity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -291,6 +433,20 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
loadRuntimeSettingsFromFile(options)
|
||||
}
|
||||
|
||||
// Wire the cloudproxy MITM listener. Opt-in: empty MITMListen
|
||||
// means "no MITM" — operators must explicitly choose to start
|
||||
// it because clients have to install the generated CA cert.
|
||||
// The handler reuses the global redactor + event store so an
|
||||
// admin who's already configured PII filtering for direct API
|
||||
// traffic doesn't need a parallel config for MITM traffic.
|
||||
// Runs after loadRuntimeSettingsFromFile so a listener configured
|
||||
// via /api/settings is brought back up across restarts.
|
||||
if options.MITMListen != "" {
|
||||
if err := startMITMProxy(application, options); err != nil {
|
||||
return nil, fmt.Errorf("mitm: startup: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
application.ModelLoader().SetBackendLoggingEnabled(options.EnableBackendLogging)
|
||||
|
||||
// turn off any process that was started by GRPC if the context is canceled
|
||||
@@ -552,6 +708,13 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
options.TracingMaxItems = *settings.TracingMaxItems
|
||||
}
|
||||
}
|
||||
if settings.TracingMaxBodyBytes != nil {
|
||||
// Allow the on-disk setting to override the CLI/env default. The
|
||||
// startup default is non-zero (see NewApplicationConfig), so a plain
|
||||
// `== 0` guard like the others would never trigger; we instead respect
|
||||
// any value the file specifies. 0 in the file means "uncapped".
|
||||
options.TracingMaxBodyBytes = *settings.TracingMaxBodyBytes
|
||||
}
|
||||
|
||||
// Branding / whitelabeling. There are no env vars for these — the file is
|
||||
// the only source — so apply unconditionally. Without this block a server
|
||||
@@ -573,6 +736,25 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
options.Branding.FaviconFile = *settings.FaviconFile
|
||||
}
|
||||
|
||||
// MITM listener address. The CLI flag WithMITMListen populates
|
||||
// options at startup; if the user configured MITM via /api/settings
|
||||
// after the fact, only the file holds the value. Apply when the
|
||||
// CLI flag did not already set it. (Intercept hosts now live in
|
||||
// model YAML mitm.hosts: rather than runtime_settings.json.)
|
||||
if settings.MITMListen != nil && options.MITMListen == "" {
|
||||
options.MITMListen = *settings.MITMListen
|
||||
}
|
||||
|
||||
// PII pattern overrides — file is the only source; CLI flags don't
|
||||
// reach into this map. Apply unconditionally when present; the
|
||||
// redactor wiring below sees the result on first construction.
|
||||
if settings.PIIPatternOverrides != nil {
|
||||
options.PIIPatternOverrides = make(map[string]config.PIIPatternRuntimeOverride, len(*settings.PIIPatternOverrides))
|
||||
for id, ov := range *settings.PIIPatternOverrides {
|
||||
options.PIIPatternOverrides[id] = ov
|
||||
}
|
||||
}
|
||||
|
||||
// Backend upgrade flags
|
||||
if settings.AutoUpgradeBackends != nil {
|
||||
if !options.AutoUpgradeBackends {
|
||||
|
||||
@@ -217,7 +217,7 @@ func (uc *UpgradeChecker) runCheck(ctx context.Context) {
|
||||
err = bm.UpgradeBackend(ctx, name, nil)
|
||||
} else {
|
||||
err = gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
|
||||
uc.galleries, name, nil)
|
||||
uc.galleries, name, nil, uc.appConfig.RequireBackendIntegrity)
|
||||
}
|
||||
if err != nil {
|
||||
xlog.Error("Failed to auto-upgrade backend",
|
||||
|
||||
@@ -78,7 +78,7 @@ func ModelAudioTransform(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func ModelAudioTransform(
|
||||
data["sample_rate"] = res.SampleRate
|
||||
data["samples"] = res.Samples
|
||||
data["reference_provided"] = res.ReferenceProvided
|
||||
if snippet := trace.AudioSnippet(dst); snippet != nil {
|
||||
if snippet := trace.AudioSnippet(dst, appConfig.TracingMaxBodyBytes); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
}
|
||||
|
||||
169
core/backend/ctx_propagation_test.go
Normal file
169
core/backend/ctx_propagation_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package backend_test
|
||||
|
||||
// Regression spec for X-LocalAI-Node coverage on audio/image/TTS/rerank/VAD.
|
||||
//
|
||||
// The X-LocalAI-Node middleware (core/http/middleware.ExposeNodeHeader)
|
||||
// works end-to-end only if the per-request holder attached to the HTTP
|
||||
// request context reaches the SmartRouter via ml.Load(opts...). The chain
|
||||
// is:
|
||||
//
|
||||
// handler -> backend.Foo(ctx, ...) -> ModelOptions(cfg, app, WithContext(ctx))
|
||||
// -> ml.Load(opts...) -> grpcModel(..., o.context) -> modelRouter(ctx, ...)
|
||||
// -> SmartRouter -> distributedhdr.Stamp(ctx, nodeID)
|
||||
//
|
||||
// If any backend helper drops `ctx` and lets ModelOptions fall back to the
|
||||
// app context, the router never sees the per-request holder and the
|
||||
// header silently stays empty for that endpoint. These specs pin the
|
||||
// request-context-reaches-router contract for the five backend helpers
|
||||
// that were previously dropping ctx between the handler and Load.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
pbproto "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// newCapturingLoader returns a ModelLoader wired with a stub model router
|
||||
// that captures the context it receives and then short-circuits with a
|
||||
// sentinel error. The router callback is the exact seam where the
|
||||
// SmartRouter would call distributedhdr.Stamp in production, so observing
|
||||
// the holder here is equivalent to observing it at the real router.
|
||||
func newCapturingLoader() (*model.ModelLoader, *atomic.Value, func() context.Context) {
|
||||
loader := model.NewModelLoader(&system.SystemState{})
|
||||
var captured atomic.Value
|
||||
loader.SetModelRouter(func(ctx context.Context, _ string, _, _, _ string, _ *pbproto.ModelOptions, _ bool) (*model.Model, error) {
|
||||
captured.Store(ctx)
|
||||
// Return an error so the backend short-circuits before trying to
|
||||
// dial gRPC. We only care about the context-arrival contract.
|
||||
return nil, errRouterShortCircuit
|
||||
})
|
||||
get := func() context.Context {
|
||||
v, _ := captured.Load().(context.Context)
|
||||
return v
|
||||
}
|
||||
return loader, &captured, get
|
||||
}
|
||||
|
||||
// sentinelErr is a string-backed error type used for fixed test sentinels.
type sentinelErr string

// Error returns the sentinel's text verbatim, satisfying the error interface.
func (e sentinelErr) Error() string {
	return string(e)
}

// errRouterShortCircuit is the error the stub model router returns so that
// backend helpers stop right after the routing step.
var errRouterShortCircuit = sentinelErr("router short-circuit (test)")
|
||||
|
||||
func newAppCfg() *config.ApplicationConfig {
|
||||
return config.NewApplicationConfig(config.WithSystemState(&system.SystemState{}))
|
||||
}
|
||||
|
||||
func newModelCfg() config.ModelConfig {
|
||||
threads := 1
|
||||
cfg := config.ModelConfig{
|
||||
Name: "test-model",
|
||||
Backend: "stub-backend",
|
||||
Threads: &threads,
|
||||
}
|
||||
cfg.Model = "test.bin"
|
||||
return cfg
|
||||
}
|
||||
|
||||
var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
|
||||
const fakeNodeID = "node-ctx-propagation-7"
|
||||
|
||||
var (
|
||||
appCfg *config.ApplicationConfig
|
||||
modelCfg config.ModelConfig
|
||||
loader *model.ModelLoader
|
||||
routerCtxOf func() context.Context
|
||||
holder *atomic.Value
|
||||
reqCtx context.Context
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
appCfg = newAppCfg()
|
||||
modelCfg = newModelCfg()
|
||||
loader, _, routerCtxOf = newCapturingLoader()
|
||||
holder = distributedhdr.NewHolder()
|
||||
reqCtx = distributedhdr.WithHolder(context.Background(), holder)
|
||||
})
|
||||
|
||||
// stampViaRouterCtx asserts the captured router context carries the
|
||||
// SAME holder that was attached to the request. We verify by stamping
|
||||
// through the router-side ctx and observing the value via the
|
||||
// request-side holder; if the holders were different objects the load
|
||||
// would return "".
|
||||
stampViaRouterCtx := func() {
|
||||
routerCtx := routerCtxOf()
|
||||
Expect(routerCtx).ToNot(BeNil(), "router callback must have been invoked")
|
||||
distributedhdr.Stamp(routerCtx, fakeNodeID)
|
||||
Expect(distributedhdr.Load(holder)).To(Equal(fakeNodeID),
|
||||
"stamp via router-side ctx must be observable via the request-side holder")
|
||||
}
|
||||
|
||||
It("Rerank forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.Rerank(reqCtx, &pbproto.RerankRequest{Query: "q"}, loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("VAD forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.VAD(&schema.VADRequest{}, reqCtx, loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTTS forwards the request context to the SmartRouter", func() {
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTTSStream forwards the request context to the SmartRouter", func() {
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTranscriptionWithOptions forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.ModelTranscriptionWithOptions(reqCtx, backend.TranscriptionRequest{Audio: "x.wav"}, loader, modelCfg, appCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTranscriptionStream forwards the request context to the SmartRouter", func() {
|
||||
err := backend.ModelTranscriptionStream(reqCtx, backend.TranscriptionRequest{Audio: "x.wav"}, loader, modelCfg, appCfg, func(backend.TranscriptionStreamChunk) {})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ImageGeneration forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.ImageGeneration(reqCtx, 64, 64, 1, 0, "p", "", "", "/tmp/out.png", loader, modelCfg, appCfg, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("does NOT leak the holder when the app context is used instead", func() {
|
||||
// Sanity: the bug being fixed manifests as the router getting
|
||||
// appCfg.Context (no holder) instead of reqCtx (holder). A direct
|
||||
// call with context.Background() must not see the holder via the
|
||||
// app context surface.
|
||||
appCtxOnly := appCfg.Context
|
||||
Expect(distributedhdr.Holder(appCtxOnly)).To(BeNil(),
|
||||
"the app context must not be the carrier of per-request holders")
|
||||
})
|
||||
})
|
||||
@@ -35,7 +35,7 @@ func Detection(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -11,9 +12,38 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
|
||||
// Embedder turns a piece of text into an embedding vector.
//
// The router's L2 embedding cache relies on it to look up past decisions
// whose prompts were semantically similar.
type Embedder interface {
	// Embed returns the embedding for text, or an error if it could not
	// be produced. ctx is the caller's request context.
	Embed(ctx context.Context, text string) ([]float32, error)
}
|
||||
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// NewEmbedder binds (loader, modelConfig, appConfig) into an Embedder.
|
||||
func NewEmbedder(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Embedder {
|
||||
return &modelEmbedder{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
|
||||
}
|
||||
|
||||
type modelEmbedder struct {
|
||||
loader *model.ModelLoader
|
||||
modelConfig config.ModelConfig
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func (e *modelEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
|
||||
fn, err := ModelEmbedding(ctx, text, nil, e.loader, e.modelConfig, e.appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fn()
|
||||
}
|
||||
|
||||
func ModelEmbedding(ctx context.Context, s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
|
||||
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
@@ -67,7 +97,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConf
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
|
||||
traceData := map[string]any{
|
||||
"input_text": trace.TruncateString(s, 1000),
|
||||
|
||||
@@ -32,7 +32,7 @@ func FaceAnalyze(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func FaceVerify(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
@@ -10,9 +11,12 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ImageGeneration(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
func ImageGeneration(ctx context.Context, height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
inferenceModel, err := loader.Load(
|
||||
opts...,
|
||||
)
|
||||
@@ -23,7 +27,7 @@ func ImageGeneration(height, width, step, seed int, positive_prompt, negative_pr
|
||||
|
||||
fn := func() error {
|
||||
_, err := inferenceModel.GenerateImage(
|
||||
appConfig.Context,
|
||||
ctx,
|
||||
&proto.GenerateImageRequest{
|
||||
Height: int32(height),
|
||||
Width: int32(width),
|
||||
@@ -41,7 +45,7 @@ func ImageGeneration(height, width, step, seed int, positive_prompt, negative_pr
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
|
||||
traceData := map[string]any{
|
||||
"positive_prompt": positive_prompt,
|
||||
|
||||
@@ -86,7 +86,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
if !slices.Contains(modelNames, modelName) {
|
||||
utils.ResetDownloadTimers()
|
||||
// if we failed to load the model, we try to download it
|
||||
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, modelName, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
||||
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, modelName, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries, o.RequireBackendIntegrity)
|
||||
if err != nil {
|
||||
xlog.Error("failed to install model from gallery", "error", err, "model", modelFile)
|
||||
//return nil, err
|
||||
@@ -94,7 +94,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
}
|
||||
|
||||
opts := ModelOptions(*c, o)
|
||||
opts := ModelOptions(*c, o, model.WithContext(ctx))
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(o, c.Name, c.Backend, err, map[string]any{"model_file": modelFile})
|
||||
@@ -305,7 +305,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
|
||||
if o.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(o.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(o.TracingMaxItems, o.TracingMaxBodyBytes)
|
||||
|
||||
traceData := map[string]any{
|
||||
"chat_template": c.TemplateConfig.Chat,
|
||||
@@ -316,9 +316,13 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
"audios_count": len(audios),
|
||||
}
|
||||
|
||||
// Cap the captured fields up front: agent-pool LLM calls embed the
|
||||
// full augmented chat history in messages and the full reply in
|
||||
// response, so without a per-field cap a single trace can dwarf the
|
||||
// rest of the buffer. The cap matches the API-trace body cap.
|
||||
if len(messages) > 0 {
|
||||
if msgJSON, err := json.Marshal(messages); err == nil {
|
||||
traceData["messages"] = string(msgJSON)
|
||||
traceData["messages"] = trace.TruncateToBytes(string(msgJSON), o.TracingMaxBodyBytes)
|
||||
}
|
||||
}
|
||||
if reasoningJSON, err := json.Marshal(c.ReasoningConfig); err == nil {
|
||||
@@ -337,7 +341,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
resp, err := originalFn()
|
||||
duration := time.Since(startTime)
|
||||
|
||||
traceData["response"] = resp.Response
|
||||
traceData["response"] = trace.TruncateToBytes(resp.Response, o.TracingMaxBodyBytes)
|
||||
traceData["token_usage"] = map[string]any{
|
||||
"prompt": resp.Usage.Prompt,
|
||||
"completion": resp.Usage.Completion,
|
||||
@@ -359,10 +363,10 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
toolCallCount += len(d.ToolCalls)
|
||||
}
|
||||
if len(contentParts) > 0 {
|
||||
chatDeltasInfo["content"] = strings.Join(contentParts, "")
|
||||
chatDeltasInfo["content"] = trace.TruncateToBytes(strings.Join(contentParts, ""), o.TracingMaxBodyBytes)
|
||||
}
|
||||
if len(reasoningParts) > 0 {
|
||||
chatDeltasInfo["reasoning_content"] = strings.Join(reasoningParts, "")
|
||||
chatDeltasInfo["reasoning_content"] = trace.TruncateToBytes(strings.Join(reasoningParts, ""), o.TracingMaxBodyBytes)
|
||||
}
|
||||
if toolCallCount > 0 {
|
||||
chatDeltasInfo["tool_call_count"] = toolCallCount
|
||||
|
||||
@@ -21,7 +21,7 @@ func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, back
|
||||
if !appConfig.EnableTracing {
|
||||
return
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: time.Now(),
|
||||
Type: trace.BackendTraceModelLoad,
|
||||
@@ -242,6 +242,18 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
|
||||
Tokenizer: c.Tokenizer,
|
||||
}
|
||||
|
||||
if c.Backend == "cloud-proxy" {
|
||||
opts.Proxy = &pb.ProxyOptions{
|
||||
UpstreamUrl: c.Proxy.UpstreamURL,
|
||||
Mode: c.Proxy.Mode,
|
||||
Provider: c.Proxy.Provider,
|
||||
ApiKeyEnv: c.Proxy.APIKeyEnv,
|
||||
ApiKeyFile: c.Proxy.APIKeyFile,
|
||||
UpstreamModel: c.Proxy.UpstreamModel,
|
||||
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
|
||||
}
|
||||
}
|
||||
|
||||
if c.MMProj != "" {
|
||||
opts.MMProj = filepath.Join(modelPath, c.MMProj)
|
||||
}
|
||||
@@ -277,7 +289,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
|
||||
MinP: float32(*c.MinP),
|
||||
Tokens: int32(*c.Maxtokens),
|
||||
Threads: int32(*c.Threads),
|
||||
PromptCacheAll: c.PromptCacheAll,
|
||||
PromptCacheAll: *c.PromptCacheAll,
|
||||
PromptCacheRO: c.PromptCacheRO,
|
||||
PromptCachePath: promptCachePath,
|
||||
F16KV: *c.F16,
|
||||
|
||||
@@ -11,8 +11,56 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// RerankResult carries the score of a single document. It is a narrowed
// view of proto.RerankResult so that consumers do not have to import the
// proto package.
type RerankResult struct {
	// Index is copied from the backend result's index field.
	Index int
	// RelevanceScore is the score the backend reported for that document.
	RelevanceScore float32
}
|
||||
|
||||
// Reranker scores a list of candidate documents against a query.
|
||||
// Returns one RerankResult per input document (no top-N truncation -
|
||||
// callers that need it can sort and slice).
|
||||
type Reranker interface {
|
||||
Rerank(ctx context.Context, query string, documents []string) ([]RerankResult, error)
|
||||
}
|
||||
|
||||
// NewReranker binds (loader, modelConfig, appConfig) into a Reranker.
|
||||
func NewReranker(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Reranker {
|
||||
return &modelReranker{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
|
||||
}
|
||||
|
||||
type modelReranker struct {
|
||||
loader *model.ModelLoader
|
||||
modelConfig config.ModelConfig
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func (r *modelReranker) Rerank(ctx context.Context, query string, documents []string) ([]RerankResult, error) {
|
||||
req := &proto.RerankRequest{
|
||||
Query: query,
|
||||
Documents: documents,
|
||||
// TopN=0: backend returns scores for every document. Truncating
|
||||
// here would silently zero out labels the reranker considered
|
||||
// unlikely, which the router classifier needs.
|
||||
}
|
||||
res, err := Rerank(ctx, req, r.loader, r.appConfig, r.modelConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]RerankResult, 0, len(res.GetResults()))
|
||||
for _, dr := range res.GetResults() {
|
||||
out = append(out, RerankResult{Index: int(dr.GetIndex()), RelevanceScore: dr.GetRelevanceScore()})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
rerankModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
@@ -25,7 +73,7 @@ func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.Mod
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
159
core/backend/score.go
Normal file
159
core/backend/score.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// ScoreOptions holds the per-request switches for a Score call.
type ScoreOptions struct {
	// IncludeTokenLogprobs asks for per-token log-probability detail on
	// every candidate. It is off by default because the joint LogProb is
	// sufficient for ranking; callers that need calibration or entropy
	// over the token stream turn it on.
	IncludeTokenLogprobs bool

	// LengthNormalize divides the joint log-prob by the candidate's token
	// count. This matters when candidates differ in length: without it,
	// longer candidates score lower by default.
	LengthNormalize bool
}
|
||||
|
||||
// TokenLogProb pairs one token with its log-probability.
type TokenLogProb struct {
	Token   string
	LogProb float64
}

// CandidateScore is the result for one candidate. It mirrors
// pb.CandidateScore without exposing the proto type to consumers.
type CandidateScore struct {
	// LogProb is the candidate's joint log-probability.
	LogProb float64
	// LengthNormalizedLogProb is the length-normalized counterpart of
	// LogProb (see ScoreOptions.LengthNormalize).
	LengthNormalizedLogProb float64
	// NumTokens is the candidate's token count.
	NumTokens int
	// Tokens holds per-token detail; presumably only populated when
	// ScoreOptions.IncludeTokenLogprobs is set — confirm against the
	// response conversion code.
	Tokens []TokenLogProb
}
|
||||
|
||||
// Scorer evaluates a model's joint log-probability of each candidate
|
||||
// continuation given a shared prompt. Implemented by NewScorer over a
|
||||
// model-loaded backend; the router's score classifier consumes this
|
||||
// for multi-label policy selection.
|
||||
type Scorer interface {
|
||||
Score(ctx context.Context, prompt string, candidates []string) ([]CandidateScore, error)
|
||||
}
|
||||
|
||||
// NewScorer binds (loader, modelConfig, appConfig) into a Scorer. The
|
||||
// underlying backend is resolved lazily on the first Score call.
|
||||
// Returns nil only as a contract violation — callers that need to
|
||||
// detect "model not loadable" should look up the config first.
|
||||
func NewScorer(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Scorer {
|
||||
return &modelScorer{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
|
||||
}
|
||||
|
||||
type modelScorer struct {
|
||||
loader *model.ModelLoader
|
||||
modelConfig config.ModelConfig
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func (m *modelScorer) Score(ctx context.Context, prompt string, candidates []string) ([]CandidateScore, error) {
|
||||
fn, err := ModelScore(prompt, candidates, ScoreOptions{LengthNormalize: true}, m.loader, m.modelConfig, m.appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
// ModelScore loads the backend for modelConfig and returns a closure
|
||||
// that scores `candidates` against `prompt`. The closure is bound to
|
||||
// the loaded model so callers can keep it around for repeat scoring
|
||||
// within the same request without re-resolving the backend.
|
||||
func ModelScore(prompt string, candidates []string, opts ScoreOptions, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func(ctx context.Context) ([]CandidateScore, error), error) {
|
||||
modelOpts := ModelOptions(modelConfig, appConfig)
|
||||
inferenceModel, err := loader.Load(modelOpts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
b, ok := inferenceModel.(grpc.Backend)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("scoring not supported by backend %q", modelConfig.Backend)
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("Score: candidates must be non-empty")
|
||||
}
|
||||
return func(ctx context.Context) ([]CandidateScore, error) {
|
||||
// Surface score calls in the Traces UI alongside the LLM calls
|
||||
// they typically gate (router classifier, eval scoring). Without
|
||||
// this, a router-classified request shows only the downstream LLM
|
||||
// trace with no record of the classification that picked it.
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
resp, err := b.Score(ctx, &pb.ScoreRequest{
|
||||
Prompt: prompt,
|
||||
Candidates: candidates,
|
||||
IncludeTokenLogprobs: opts.IncludeTokenLogprobs,
|
||||
LengthNormalize: opts.LengthNormalize,
|
||||
})
|
||||
results := scoreResponseToCandidates(resp, opts.IncludeTokenLogprobs)
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceScore,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(prompt, 200),
|
||||
Error: errStr,
|
||||
Data: map[string]any{
|
||||
// Copy candidates so the trace buffer doesn't pin a
|
||||
// caller-owned slice for the lifetime of the ring.
|
||||
"candidates": append([]string(nil), candidates...),
|
||||
"results": results,
|
||||
},
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// scoreResponseToCandidates converts the wire-format pb response into
|
||||
// the value type consumed by callers. Extracted to keep ModelScore's
|
||||
// closure trivial and so the conversion can be unit-tested without a
|
||||
// real backend.
|
||||
func scoreResponseToCandidates(resp *pb.ScoreResponse, includeTokens bool) []CandidateScore {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]CandidateScore, len(resp.Candidates))
|
||||
for i, c := range resp.Candidates {
|
||||
cs := CandidateScore{
|
||||
LogProb: c.LogProb,
|
||||
LengthNormalizedLogProb: c.LengthNormalizedLogProb,
|
||||
NumTokens: int(c.NumTokens),
|
||||
}
|
||||
if includeTokens && len(c.Tokens) > 0 {
|
||||
cs.Tokens = make([]TokenLogProb, len(c.Tokens))
|
||||
for j, t := range c.Tokens {
|
||||
cs.Tokens[j] = TokenLogProb{Token: t.Token, LogProb: t.LogProb}
|
||||
}
|
||||
}
|
||||
out[i] = cs
|
||||
}
|
||||
return out
|
||||
}
|
||||
63
core/backend/score_test.go
Normal file
63
core/backend/score_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("scoreResponseToCandidates", func() {
|
||||
It("returns nil for a nil response", func() {
|
||||
Expect(scoreResponseToCandidates(nil, false)).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns an empty slice when the response has no candidates", func() {
|
||||
Expect(scoreResponseToCandidates(&pb.ScoreResponse{}, false)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("copies LogProb / LengthNormalizedLogProb / NumTokens for every candidate", func() {
|
||||
resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{
|
||||
{LogProb: -2.0, LengthNormalizedLogProb: -1.0, NumTokens: 2},
|
||||
{LogProb: -7.5, LengthNormalizedLogProb: -1.5, NumTokens: 5},
|
||||
}}
|
||||
got := scoreResponseToCandidates(resp, false)
|
||||
Expect(got).To(HaveLen(2))
|
||||
Expect(got[0].LogProb).To(Equal(-2.0))
|
||||
Expect(got[0].LengthNormalizedLogProb).To(Equal(-1.0))
|
||||
Expect(got[0].NumTokens).To(Equal(2))
|
||||
Expect(got[1].LogProb).To(Equal(-7.5))
|
||||
Expect(got[1].NumTokens).To(Equal(5))
|
||||
})
|
||||
|
||||
It("omits per-token detail when includeTokens=false even if the wire response carries it", func() {
|
||||
// Defensive: if the backend over-reports we still respect the
|
||||
// caller's opt-in so consumers don't pay marshaling for data
|
||||
// they didn't ask for.
|
||||
resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{{
|
||||
LogProb: -1.0,
|
||||
Tokens: []*pb.TokenLogProb{{Token: "hi", LogProb: -1.0}},
|
||||
}}}
|
||||
got := scoreResponseToCandidates(resp, false)
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Tokens).To(BeNil())
|
||||
})
|
||||
|
||||
It("populates per-token detail when includeTokens=true", func() {
|
||||
resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{{
|
||||
LogProb: -3.0,
|
||||
NumTokens: 2,
|
||||
Tokens: []*pb.TokenLogProb{
|
||||
{Token: "Hello", LogProb: -1.0},
|
||||
{Token: " world", LogProb: -2.0},
|
||||
},
|
||||
}}}
|
||||
got := scoreResponseToCandidates(resp, true)
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Tokens).To(HaveLen(2))
|
||||
Expect(got[0].Tokens[0].Token).To(Equal("Hello"))
|
||||
Expect(got[0].Tokens[0].LogProb).To(Equal(-1.0))
|
||||
Expect(got[0].Tokens[1].Token).To(Equal(" world"))
|
||||
Expect(got[0].Tokens[1].LogProb).To(Equal(-2.0))
|
||||
})
|
||||
})
|
||||
@@ -98,7 +98,7 @@ func SoundGeneration(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,74 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/store"
|
||||
)
|
||||
|
||||
// VectorStore is the narrowed KNN store used by the router's embedding
// cache.
type VectorStore interface {
	// Search returns the top-1 match for vec: its cosine similarity
	// (in [-1, 1]) and the serialised payload. A clean miss is reported
	// as ok=false with a nil error.
	Search(ctx context.Context, vec []float32) (similarity float64, payload []byte, ok bool, err error)

	// Insert stores payload under vec.
	Insert(ctx context.Context, vec []float32, payload []byte) error
}
|
||||
|
||||
// NewVectorStore returns a VectorStore backed by the local-store
|
||||
// gRPC backend, namespaced by storeName so two routers don't collide.
|
||||
func NewVectorStore(loader *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) VectorStore {
|
||||
if storeName == "" {
|
||||
return nil
|
||||
}
|
||||
return &localVectorStore{loader: loader, appConfig: appConfig, storeName: storeName}
|
||||
}
|
||||
|
||||
type localVectorStore struct {
|
||||
loader *model.ModelLoader
|
||||
appConfig *config.ApplicationConfig
|
||||
storeName string
|
||||
}
|
||||
|
||||
func (s *localVectorStore) backend(_ context.Context) (grpc.Backend, error) {
|
||||
return StoreBackend(s.loader, s.appConfig, s.storeName, "")
|
||||
}
|
||||
|
||||
func (s *localVectorStore) Search(ctx context.Context, vec []float32) (float64, []byte, bool, error) {
|
||||
be, err := s.backend(ctx)
|
||||
if err != nil {
|
||||
return 0, nil, false, fmt.Errorf("vector store load: %w", err)
|
||||
}
|
||||
_, values, similarities, err := store.Find(ctx, be, vec, 1)
|
||||
if err != nil {
|
||||
// local-store's Find returns "existing length is -1" before
|
||||
// any keys are inserted. Surface that as a clean miss so the
|
||||
// cache layer treats it as an empty store and proceeds to
|
||||
// Insert rather than skipping.
|
||||
if strings.Contains(err.Error(), "existing length is -1") {
|
||||
return 0, nil, false, nil
|
||||
}
|
||||
return 0, nil, false, fmt.Errorf("vector store find: %w", err)
|
||||
}
|
||||
if len(values) == 0 || len(similarities) == 0 {
|
||||
return 0, nil, false, nil
|
||||
}
|
||||
return float64(similarities[0]), values[0], true, nil
|
||||
}
|
||||
|
||||
func (s *localVectorStore) Insert(ctx context.Context, vec []float32, payload []byte) error {
|
||||
be, err := s.backend(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("vector store load: %w", err)
|
||||
}
|
||||
return store.SetSingle(ctx, be, vec, payload)
|
||||
}
|
||||
|
||||
func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string, backend string) (grpc.Backend, error) {
|
||||
if backend == "" {
|
||||
backend = model.LocalStoreBackend
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user