mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-26 01:31:27 -04:00
Compare commits
75 Commits
v4.2.6
...
fix/galler
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
11c5fd677d | ||
|
|
b3300ef207 | ||
|
|
8d6548c0b9 | ||
|
|
b02e3ffe61 | ||
|
|
a891eedd08 | ||
|
|
06e777b75e | ||
|
|
90ea327178 | ||
|
|
6a80e23733 | ||
|
|
1dcd1ae915 | ||
|
|
acad78a95a | ||
|
|
c94d1e1f5b | ||
|
|
270c256409 | ||
|
|
1a30020a82 | ||
|
|
8bbe89a537 | ||
|
|
dcc5599f89 | ||
|
|
a95f4e63e0 | ||
|
|
dfd19a3f88 | ||
|
|
d7387c725c | ||
|
|
63d84a5705 | ||
|
|
1198d10b58 | ||
|
|
a0f3e26245 | ||
|
|
e4cc1f11f3 | ||
|
|
6ed269d0b9 | ||
|
|
5756fb046d | ||
|
|
7980629bc5 | ||
|
|
d0a59be9de | ||
|
|
5cda4f1ccf | ||
|
|
c500461c69 | ||
|
|
834ecc36bf | ||
|
|
61bf34ea2f | ||
|
|
0b2ae3c6ca | ||
|
|
4735345105 | ||
|
|
7384fd800b | ||
|
|
6942713d85 | ||
|
|
0cf52c44d4 | ||
|
|
0d34cf7cbd | ||
|
|
f0cb02afb8 | ||
|
|
a39e025d64 | ||
|
|
05e8e1e9f4 | ||
|
|
a7f6cc8956 | ||
|
|
f15b9178ec | ||
|
|
959de86761 | ||
|
|
4c234abc2c | ||
|
|
c68818a62e | ||
|
|
11d5bd0cc3 | ||
|
|
12e056e96d | ||
|
|
308aa8908a | ||
|
|
b2d68a53a2 | ||
|
|
e3706c0512 | ||
|
|
1ffd82a050 | ||
|
|
f515168dbe | ||
|
|
ef6ca34513 | ||
|
|
9413c3767f | ||
|
|
3bf3cce232 | ||
|
|
06f8159035 | ||
|
|
f6a73f54fa | ||
|
|
24e04d8e81 | ||
|
|
b9a49449ae | ||
|
|
1879e11042 | ||
|
|
403d391316 | ||
|
|
fc3980dadd | ||
|
|
2009544b44 | ||
|
|
e859345b12 | ||
|
|
f30712f8e8 | ||
|
|
a19c77c5f8 | ||
|
|
4b02d23c0c | ||
|
|
21140e96b2 | ||
|
|
fc803e8d48 | ||
|
|
ca51606bfe | ||
|
|
cb502de309 | ||
|
|
5d0b549049 | ||
|
|
11cff1b309 | ||
|
|
4ca3d2cdc0 | ||
|
|
3cba35ed32 | ||
|
|
265ae35231 |
@@ -112,6 +112,8 @@ Add a YAML anchor definition in the `## metas` section (around line 2-300). Look
|
||||
|
||||
Add image entries at the end of the file, following the pattern of similar backends such as `diffusers` or `chatterbox`. Include both `latest` (production) and `master` (development) tags.
|
||||
|
||||
**Note on integrity:** OCI backends installed from a gallery whose `verification:` block is set are verified against a keyless-cosign policy before extraction; tarball/HTTP backends use the optional `sha256:` field. New backends do not need any extra YAML — the gallery-level `verification:` block covers every entry. See [.agents/backend-signing.md](backend-signing.md) for the producer-side CI step.
|
||||
|
||||
## 4. Update the Makefile
|
||||
|
||||
The Makefile needs to be updated in several places to support building and testing the new backend:
|
||||
|
||||
126
.agents/backend-signing.md
Normal file
126
.agents/backend-signing.md
Normal file
@@ -0,0 +1,126 @@
|
||||
# Backend image signing & verification
|
||||
|
||||
LocalAI verifies backend OCI images against a per-gallery keyless-cosign
|
||||
policy. This page documents the trust model, the producer side
|
||||
(`.github/workflows/backend_merge.yml` in this repo), and the consumer
|
||||
side (`pkg/oci/cosignverify` plus the gallery YAML).
|
||||
|
||||
## Trust model
|
||||
|
||||
- **Producer:** `.github/workflows/backend_merge.yml` signs each pushed
|
||||
manifest list with `cosign sign --recursive` in keyless mode after
|
||||
`docker buildx imagetools create`. The signing cert is issued by
|
||||
Fulcio bound to the workflow's OIDC identity. There is no long-lived
|
||||
signing key. `--recursive` signs both the manifest list and every
|
||||
per-arch entry — needed because our consumer resolves a tag to a
|
||||
per-arch manifest before checking signatures.
|
||||
- **Storage:** Signatures are written as OCI 1.1 referrers
|
||||
(`--registry-referrers-mode=oci-1-1`) in the new Sigstore bundle format
|
||||
(current cosign releases do this by default; no `--new-bundle-format`
|
||||
flag). No `:sha256-<hex>.sig` tag clutter.
|
||||
- **Consumer:** `pkg/oci/cosignverify` discovers the bundle via the
|
||||
referrers API, hands it to `sigstore-go`, and verifies it against the
|
||||
policy declared in the gallery YAML (`Gallery.Verification`).
|
||||
- **Revocation:** Keyless cosign certs are ephemeral (10-minute Fulcio
|
||||
validity), so revocation is policy-side, not CA-side. The gallery's
|
||||
`verification.not_before` (RFC3339) is the kill-switch — advance it to
|
||||
invalidate every signature produced before a known compromise window.
|
||||
|
||||
## Producer setup
|
||||
|
||||
`backend_merge.yml` is the workflow that joins per-arch digests into the
|
||||
multi-arch manifest list users actually pull, so it's also the right place
|
||||
to sign. The job needs:
|
||||
|
||||
- `permissions: { id-token: write, contents: read }` at the job level so
|
||||
the runner can exchange its GitHub OIDC token for a Fulcio cert.
|
||||
- `sigstore/cosign-installer@v3` step (current cosign releases already
|
||||
default to the new bundle format).
|
||||
- After each `docker buildx imagetools create`, resolve the resulting
|
||||
list digest with `docker buildx imagetools inspect <tag> --format
|
||||
'{{.Manifest.Digest}}'` and sign:
|
||||
|
||||
```sh
|
||||
cosign sign --yes --recursive \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"${REGISTRY_REPO}@${DIGEST}"
|
||||
```
|
||||
|
||||
Sign by digest, never by tag — signing by tag binds the signature to
|
||||
whatever the tag points at *now*, and a subsequent tag push orphans it.
|
||||
|
||||
`--registry-referrers-mode=oci-1-1` is still gated behind
|
||||
`COSIGN_EXPERIMENTAL=1` in cosign v2.4.x (set at the job env level in
|
||||
`backend_merge.yml`). Re-evaluate when bumping the pinned cosign release
|
||||
— newer versions are expected to graduate this flag and the env var can
|
||||
then be dropped.
|
||||
|
||||
`backend_build_darwin.yml` builds and pushes single-arch darwin images
|
||||
that bypass the manifest-list merge. If/when those entries get a gallery
|
||||
`verification:` policy, the equivalent cosign step has to land there
|
||||
too.
|
||||
|
||||
## Consumer setup (in `mudler/LocalAI` gallery YAML)
|
||||
|
||||
Once CI is signing, add a `verification:` block to the backend gallery
|
||||
entry (`backend/index.yaml`):
|
||||
|
||||
```yaml
|
||||
- name: localai
|
||||
url: github:mudler/LocalAI/backend/index.yaml@master
|
||||
verification:
|
||||
issuer: "https://token.actions.githubusercontent.com"
|
||||
identity_regex: "^https://github\\.com/mudler/LocalAI/\\.github/workflows/backend_merge\\.yml@refs/heads/master$"
|
||||
# Optional revocation cutoff; advance during incident response.
|
||||
# not_before: "2026-06-01T00:00:00Z"
|
||||
```
|
||||
|
||||
Identity matching pins the OIDC subject Fulcio issued the signing cert
|
||||
to. Without this, any image signed by *anyone* with a Fulcio cert would
|
||||
pass — the regex is what makes a signature mean "produced by our CI".
|
||||
|
||||
## Strict mode
|
||||
|
||||
Default behaviour: OCI backends without a `verification:` block install
|
||||
with a warning (logs include `installing OCI backend without signature
|
||||
verification`). Tarball/HTTP backends without a `sha256` field log a
|
||||
similar warning.
|
||||
|
||||
For production, set `LOCALAI_REQUIRE_BACKEND_INTEGRITY=1` (or pass
|
||||
`--require-backend-integrity` to `local-ai run` / `local-ai backends
|
||||
install` / `local-ai models install`). The warning becomes a hard error
|
||||
and unverifiable backends refuse to install.
|
||||
|
||||
## Revocation playbook
|
||||
|
||||
If `backend_merge.yml` (or any workflow with `id-token: write`) is
|
||||
compromised and we've shipped malicious signed images:
|
||||
|
||||
1. **Identify the compromise window.** Find the earliest IntegratedTime
|
||||
from the bad signatures (Rekor search by `subject` filter).
|
||||
2. **Set `verification.not_before`** in `backend/index.yaml` to a
|
||||
timestamp just *after* that window's start.
|
||||
3. **Push the YAML.** Deployed LocalAI instances pick it up on next
|
||||
gallery refresh (1-hour cache in `core/gallery/gallery.go`).
|
||||
4. **Fix the underlying compromise** in the workflow and re-sign images
|
||||
with the new build, which will have IntegratedTime > `not_before`.
|
||||
5. **Optional:** for absolute decisiveness, also rotate to a new
|
||||
workflow path (`backend_merge_v2.yml`) and update `identity_regex`.
|
||||
|
||||
## Where the code lives
|
||||
|
||||
- `pkg/oci/cosignverify/` — verifier, policy, OCI referrer fetch, NotBefore enforcement.
|
||||
- `pkg/downloader/uri.go` — `WithImageVerifier` option threaded through `DownloadFileWithContext`.
|
||||
- `core/gallery/backends.go` — `backendDownloadOptions` builds the verifier from the gallery's policy.
|
||||
- `core/config/gallery.go` — `Gallery.Verification` YAML schema.
|
||||
- `core/cli/run.go`, `core/cli/backends.go`, `core/cli/models.go` — `--require-backend-integrity` flag propagation.
|
||||
- `.github/workflows/backend_merge.yml` — producer-side `cosign sign --recursive` after each multi-arch manifest list push.
|
||||
|
||||
## Out of scope (follow-ups)
|
||||
|
||||
- **Signing the gallery YAML itself.** The index is fetched over HTTPS
|
||||
from GitHub; we trust the host. A cosign blob signature on the YAML
|
||||
would close that gap but adds key-management overhead. Revisit this
|
||||
page if/when added.
|
||||
- **Tarball/HTTP backend signing.** Cosign can sign arbitrary blobs, but
|
||||
for now non-OCI backends keep using the `sha256:` field in YAML.
|
||||
@@ -4,6 +4,7 @@
|
||||
.devcontainer
|
||||
models
|
||||
backends
|
||||
volumes
|
||||
examples/chatbot-ui/models
|
||||
backend/go/image/stablediffusion-ggml/build/
|
||||
backend/go/*/build
|
||||
@@ -21,3 +22,11 @@ __pycache__
|
||||
# backend virtual environments
|
||||
**/venv
|
||||
backend/python/**/source
|
||||
|
||||
# In-place llama.cpp clone + per-variant build copies. The Makefile
|
||||
# clones llama.cpp itself at the pinned LLAMA_VERSION; if a stale
|
||||
# local checkout is COPY'd into the image, the `llama.cpp:` target
|
||||
# sees the directory and skips re-cloning, so grpc-server.cpp ends
|
||||
# up compiled against whatever (likely older) commit the host had.
|
||||
backend/cpp/llama-cpp/llama.cpp
|
||||
backend/cpp/llama-cpp-*-build
|
||||
|
||||
59
.github/workflows/backend_merge.yml
vendored
59
.github/workflows/backend_merge.yml
vendored
@@ -31,8 +31,20 @@ on:
|
||||
jobs:
|
||||
merge:
|
||||
runs-on: ubuntu-latest
|
||||
# id-token: write is required for keyless cosign — the workflow
|
||||
# exchanges the GitHub OIDC token for a short-lived Fulcio cert that
|
||||
# signs each pushed manifest. Without this permission the runner
|
||||
# cannot mint the token, and `cosign sign` fails with "no token".
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
env:
|
||||
quay_username: ${{ secrets.quayUsername }}
|
||||
# cosign v2.4.x still gates --registry-referrers-mode=oci-1-1 behind
|
||||
# this flag. Without it, signing fails with:
|
||||
# invalid argument "oci-1-1" for "--registry-referrers-mode" flag:
|
||||
# in order to use mode "oci-1-1", you must set COSIGN_EXPERIMENTAL=1
|
||||
COSIGN_EXPERIMENTAL: '1'
|
||||
steps:
|
||||
# Sparse checkout: the merge job needs `.github/scripts/` (for the
|
||||
# keepalive cleanup script) but none of the source tree.
|
||||
@@ -57,6 +69,16 @@ jobs:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@master
|
||||
|
||||
# cosign signs each pushed manifest list with --recursive so the
|
||||
# index and every per-arch entry get an attached Sigstore bundle.
|
||||
# Recent cosign releases always emit the new bundle format, so
|
||||
# there's no extra CLI flag to opt into it.
|
||||
- name: Install cosign
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: sigstore/cosign-installer@v3
|
||||
with:
|
||||
cosign-release: 'v2.4.1'
|
||||
|
||||
- name: Login to DockerHub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v4
|
||||
@@ -120,11 +142,25 @@ jobs:
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
if [ -z "$tags" ]; then
|
||||
echo "No quay.io tags from docker/metadata-action; skipping quay merge"
|
||||
else
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags \
|
||||
$(printf 'quay.io/go-skynet/ci-cache@sha256:%s ' *)
|
||||
exit 0
|
||||
fi
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags \
|
||||
$(printf 'quay.io/go-skynet/ci-cache@sha256:%s ' *)
|
||||
# Resolve the manifest-list digest (any tag points at it) so
|
||||
# cosign can sign by digest. Signing by tag would leave the
|
||||
# signature orphaned the next time the tag moves.
|
||||
first_tag=$(jq -cr '
|
||||
.tags | map(select(startswith("quay.io/"))) | .[0]
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
digest=$(docker buildx imagetools inspect "$first_tag" --format '{{.Manifest.Digest}}')
|
||||
# --recursive walks the list and signs every per-arch entry
|
||||
# too — clients that resolve a tag to a platform-specific
|
||||
# manifest before checking signatures need the per-arch
|
||||
# signatures, not just the list-level one.
|
||||
cosign sign --yes --recursive \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"quay.io/go-skynet/local-ai-backends@${digest}"
|
||||
|
||||
- name: Create manifest list and push (dockerhub)
|
||||
if: github.event_name != 'pull_request'
|
||||
@@ -139,11 +175,18 @@ jobs:
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
if [ -z "$tags" ]; then
|
||||
echo "No dockerhub tags from docker/metadata-action; skipping dockerhub merge"
|
||||
else
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags \
|
||||
$(printf 'localai/localai-backends@sha256:%s ' *)
|
||||
exit 0
|
||||
fi
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags \
|
||||
$(printf 'localai/localai-backends@sha256:%s ' *)
|
||||
first_tag=$(jq -cr '
|
||||
.tags | map(select(startswith("localai/"))) | .[0]
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
digest=$(docker buildx imagetools inspect "$first_tag" --format '{{.Manifest.Digest}}')
|
||||
cosign sign --yes --recursive \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"localai/localai-backends@${digest}"
|
||||
|
||||
- name: Inspect manifest
|
||||
if: github.event_name != 'pull_request'
|
||||
|
||||
1
.github/workflows/image_build.yml
vendored
1
.github/workflows/image_build.yml
vendored
@@ -106,6 +106,7 @@ jobs:
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{raw}}
|
||||
type=sha
|
||||
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
|
||||
flavor: |
|
||||
latest=${{ inputs.tag-latest }}
|
||||
suffix=${{ inputs.tag-suffix }},onlatest=true
|
||||
|
||||
1
.github/workflows/image_merge.yml
vendored
1
.github/workflows/image_merge.yml
vendored
@@ -80,6 +80,7 @@ jobs:
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{raw}}
|
||||
type=sha
|
||||
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
|
||||
flavor: |
|
||||
latest=${{ inputs.tag-latest }}
|
||||
suffix=${{ inputs.tag-suffix }},onlatest=true
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -26,6 +26,10 @@ go-bert
|
||||
LocalAI
|
||||
/local-ai
|
||||
/local-ai-launcher
|
||||
# Root-level build artifacts when running `go build ./...` against
|
||||
# Go backend packages whose main lives under backend/go/.
|
||||
/cloud-proxy
|
||||
/local-store
|
||||
# prevent above rules from omitting the helm chart
|
||||
!charts/*
|
||||
# prevent above rules from omitting the api/localai folder
|
||||
@@ -77,3 +81,6 @@ local-backends/
|
||||
tests/e2e-ui/ui-test-server
|
||||
core/http/react-ui/playwright-report/
|
||||
core/http/react-ui/test-results/
|
||||
|
||||
# Local worktrees
|
||||
.worktrees/
|
||||
|
||||
@@ -46,8 +46,52 @@ linters:
|
||||
msg: 'LocalAI tests must use Ginkgo/Gomega; use Fail(...) instead of t.Fail. See .agents/coding-style.md.'
|
||||
- pattern: '^t\.FailNow$'
|
||||
msg: 'LocalAI tests must use Ginkgo/Gomega; use Fail(...) instead of t.FailNow. See .agents/coding-style.md.'
|
||||
# In-process config should flow through ApplicationConfig / kong-bound
|
||||
# CLI flags, not via os.Getenv. The CLI layer is the legitimate
|
||||
# env→struct boundary (kong's `env:"..."` tag); anything deeper that
|
||||
# reads env directly leaks process state into business logic and
|
||||
# makes flags impossible to test or override per-request. Backend
|
||||
# subprocesses, the system/capabilities probe, and a few places that
|
||||
# read non-LocalAI env vars (HOME, PATH, AUTH_TOKEN passed by parent)
|
||||
# are exempt — see linters.exclusions.rules below.
|
||||
- pattern: '^os\.(Getenv|LookupEnv|Environ)$'
|
||||
msg: 'Plumb config through ApplicationConfig (or the relevant CLI struct) instead of reading env directly. CLI entry points (core/cli/) bind env vars via kong''s `env:` tag — that is the only sanctioned env→struct boundary. See .agents/coding-style.md.'
|
||||
exclusions:
|
||||
paths:
|
||||
# Upstream whisper.cpp source tree fetched by the whisper backend Makefile.
|
||||
- 'backend/go/whisper/sources'
|
||||
- 'docs/'
|
||||
rules:
|
||||
# CLI entry points: kong's `env:"..."` tag is the legitimate env→struct
|
||||
# boundary, and a handful of subcommands legitimately propagate values
|
||||
# to spawned subprocesses (LLAMACPP_GRPC_SERVERS, MLX hostfile, ...).
|
||||
- path: ^core/cli/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# Backend subprocesses are independent binaries with their own env
|
||||
# surface; they're not "in-process config" of the LocalAI server.
|
||||
- path: ^backend/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# System capability probe reads HOME, PATH-style vars to discover
|
||||
# GPUs, default paths, etc. — not LocalAI config.
|
||||
- path: ^pkg/system/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# gRPC server reads AUTH_TOKEN passed in by the parent process at spawn
|
||||
# time; model.Loader sets/inherits env to communicate with subprocesses.
|
||||
- path: ^pkg/grpc/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
- path: ^pkg/model/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# Top-level main binaries (local-ai, launcher) are entry points.
|
||||
- path: ^cmd/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# Tests legitimately read $HOME, $TMPDIR, and gating env vars
|
||||
# (LOCALAI_COSIGN_LIVE, etc.) to skip live-network specs.
|
||||
- path: _test\.go$
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
|
||||
@@ -31,6 +31,7 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||
| [.agents/adding-gallery-models.md](.agents/adding-gallery-models.md) | Adding GGUF models from HuggingFace to the model gallery |
|
||||
| [.agents/localai-assistant-mcp.md](.agents/localai-assistant-mcp.md) | LocalAI Assistant chat modality — adding admin tools to the in-process MCP server, editing skill prompts, keeping REST + MCP + skills in sync |
|
||||
| [.agents/backend-signing.md](.agents/backend-signing.md) | Backend OCI image signing (keyless cosign + sigstore-go) — producer-side CI setup, consumer-side gallery `verification:` block, strict mode (`LOCALAI_REQUIRE_BACKEND_INTEGRITY`), revocation via `not_before` |
|
||||
|
||||
## Quick Reference
|
||||
|
||||
|
||||
15
Makefile
15
Makefile
@@ -69,7 +69,7 @@ else
|
||||
GORELEASER=$(shell which goreleaser)
|
||||
endif
|
||||
|
||||
TEST_PATHS?=./api/... ./pkg/... ./core/...
|
||||
TEST_PATHS?=./api/... ./pkg/... ./core/... ./backend/go/cloud-proxy/... ./backend/go/local-store/...
|
||||
|
||||
|
||||
.PHONY: all test build vendor lint lint-all
|
||||
@@ -268,12 +268,13 @@ prepare-e2e:
|
||||
run-e2e-image:
|
||||
docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --name e2e-tests-$(RANDOM) localai-tests
|
||||
|
||||
test-e2e: build-mock-backend prepare-e2e run-e2e-image
|
||||
test-e2e: build-mock-backend build-cloud-proxy-backend prepare-e2e run-e2e-image
|
||||
@echo 'Running e2e tests'
|
||||
BUILD_TYPE=$(BUILD_TYPE) \
|
||||
LOCALAI_API=http://$(E2E_BRIDGE_IP):5390 \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e
|
||||
$(MAKE) clean-mock-backend
|
||||
$(MAKE) clean-cloud-proxy-backend
|
||||
$(MAKE) teardown-e2e
|
||||
docker rmi localai-tests
|
||||
|
||||
@@ -1064,6 +1065,7 @@ BACKEND_DS4 = ds4|ds4|.|false|false
|
||||
# Golang backends
|
||||
BACKEND_PIPER = piper|golang|.|false|true
|
||||
BACKEND_LOCAL_STORE = local-store|golang|.|false|true
|
||||
BACKEND_CLOUD_PROXY = cloud-proxy|golang|.|false|true
|
||||
BACKEND_HUGGINGFACE = huggingface|golang|.|false|true
|
||||
BACKEND_SILERO_VAD = silero-vad|golang|.|false|true
|
||||
BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true
|
||||
@@ -1149,6 +1151,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_TURBOQUANT)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_DS4)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PIPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CLOUD_PROXY)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
@@ -1201,7 +1204,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
@@ -1213,6 +1216,12 @@ build-mock-backend: protogen-go
|
||||
clean-mock-backend:
|
||||
rm -f tests/e2e/mock-backend/mock-backend
|
||||
|
||||
build-cloud-proxy-backend: protogen-go
|
||||
$(GOCMD) build -o tests/e2e/mock-backend/cloud-proxy ./backend/go/cloud-proxy
|
||||
|
||||
clean-cloud-proxy-backend:
|
||||
rm -f tests/e2e/mock-backend/cloud-proxy
|
||||
|
||||
########################################################
|
||||
### UI E2E Test Server
|
||||
########################################################
|
||||
|
||||
@@ -37,6 +37,22 @@ service Backend {
|
||||
|
||||
rpc Rerank(RerankRequest) returns (RerankResult) {}
|
||||
|
||||
// TokenClassify runs a token-classification (NER) model on the
|
||||
// supplied text and returns each detected entity span. Used by the
|
||||
// PII redactor's optional NER tier — the regex tier still handles
|
||||
// formatted hits cheaply, while this catches names, locations, and
|
||||
// other unformatted PII that regex misses.
|
||||
rpc TokenClassify(TokenClassifyRequest) returns (TokenClassifyResponse) {}
|
||||
|
||||
// Score evaluates the model's joint log-probability of each
|
||||
// supplied candidate continuation given a shared prompt. The
|
||||
// prompt's KV cache is computed once and reused across candidates.
|
||||
// Used for routing-policy multi-label classification, reranking,
|
||||
// calibrated confidence, and reward-model scoring — any task where
|
||||
// the consumer wants the model's confidence in a pre-specified
|
||||
// continuation rather than a generated one.
|
||||
rpc Score(ScoreRequest) returns (ScoreResponse) {}
|
||||
|
||||
rpc GetMetrics(MetricsRequest) returns (MetricsResponse);
|
||||
|
||||
rpc VAD(VADRequest) returns (VADResponse) {}
|
||||
@@ -68,6 +84,23 @@ service Backend {
|
||||
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
|
||||
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
|
||||
|
||||
// Forward proxies a raw HTTP request to an upstream provider. The
|
||||
// cloud-proxy backend implements this for passthrough-mode model
|
||||
// configs: the client wire format is preserved end-to-end (no
|
||||
// translation through internal proto), which means new provider
|
||||
// fields work the day they ship. Translation-mode proxies use the
|
||||
// standard Predict/PredictStream RPCs instead. Backends that don't
|
||||
// support this return UNIMPLEMENTED.
|
||||
//
|
||||
// The request is bidirectionally streamed so large bodies can flow
|
||||
// without buffering. In practice the first ForwardRequest carries
|
||||
// path, method, headers, and the initial body chunk; subsequent
|
||||
// messages append body chunks. The first ForwardReply carries the
|
||||
// upstream status and response headers; subsequent messages stream
|
||||
// body chunks (SSE frames or chunked transfer). Cancellation of the
|
||||
// gRPC context closes the upstream connection.
|
||||
rpc Forward(stream ForwardRequest) returns (stream ForwardReply) {}
|
||||
|
||||
}
|
||||
|
||||
// Define the empty request
|
||||
@@ -81,6 +114,76 @@ message MetricsResponse {
|
||||
int32 prompt_tokens_processed = 5;
|
||||
}
|
||||
|
||||
// TokenClassifyRequest carries the text to classify plus an optional
|
||||
// score threshold. The transformers backend interprets threshold as
|
||||
// the minimum confidence to include in the response; 0 = include all.
|
||||
message TokenClassifyRequest {
|
||||
string text = 1;
|
||||
float threshold = 2;
|
||||
}
|
||||
|
||||
// TokenClassifyEntity is one detected entity span. Byte offsets are
|
||||
// into the original UTF-8 text — start..end is a half-open range that
|
||||
// addresses the substring corresponding to entity_group.
|
||||
//
|
||||
// entity_group follows HuggingFace's aggregated-tag convention (e.g.
|
||||
// "PER", "LOC", "ORG", or a PII-specific label like "EMAIL" /
|
||||
// "SSN" depending on the model). The redactor's per-pattern action
|
||||
// map keys off this string.
|
||||
message TokenClassifyEntity {
|
||||
string entity_group = 1;
|
||||
int32 start = 2;
|
||||
int32 end = 3;
|
||||
float score = 4;
|
||||
string text = 5;
|
||||
}
|
||||
|
||||
message TokenClassifyResponse {
|
||||
repeated TokenClassifyEntity entities = 1;
|
||||
}
|
||||
|
||||
// ScoreRequest carries one shared prompt and one or more continuations
|
||||
// to score against it. The backend tokenises the prompt once and reuses
|
||||
// the resulting KV cache across all candidates in this request.
|
||||
message ScoreRequest {
|
||||
string prompt = 1;
|
||||
repeated string candidates = 2;
|
||||
// Return per-token logprobs for each candidate when true. Default
|
||||
// false to keep the wire response small; the joint log_prob field
|
||||
// covers the common ranking case.
|
||||
bool include_token_logprobs = 3;
|
||||
// When true, the response also populates length_normalized_log_prob
|
||||
// (joint log-prob divided by candidate token count). Useful when
|
||||
// candidates differ in length and the consumer wants a per-token
|
||||
// measure comparable across them (PMI-style scoring).
|
||||
bool length_normalize = 4;
|
||||
}
|
||||
|
||||
// CandidateScore is one row in the ScoreResponse, matching by index
|
||||
// the candidate in ScoreRequest.candidates.
|
||||
message CandidateScore {
|
||||
// Sum of log P(token_i | prompt, candidate_token_<i) across the
|
||||
// candidate's tokens. The primary ranking signal.
|
||||
double log_prob = 1;
|
||||
// log_prob / num_tokens — populated when length_normalize=true on
|
||||
// the request.
|
||||
double length_normalized_log_prob = 2;
|
||||
// Per-token detail — populated when include_token_logprobs=true.
|
||||
repeated TokenLogProb tokens = 3;
|
||||
// Number of tokens the backend tokenised this candidate into, after
|
||||
// any backend-specific normalisation (e.g. leading-space handling).
|
||||
int32 num_tokens = 4;
|
||||
}
|
||||
|
||||
message TokenLogProb {
|
||||
string token = 1;
|
||||
double log_prob = 2;
|
||||
}
|
||||
|
||||
message ScoreResponse {
|
||||
repeated CandidateScore candidates = 1;
|
||||
}
|
||||
|
||||
message RerankRequest {
|
||||
string query = 1;
|
||||
repeated string documents = 2;
|
||||
@@ -325,6 +428,25 @@ message ModelOptions {
|
||||
// applied verbatim to the backend's engine constructor (e.g. vLLM AsyncEngineArgs).
|
||||
// Unknown keys produce an error at LoadModel time.
|
||||
string EngineArgs = 73;
|
||||
|
||||
// Proxy carries the cloud-proxy backend's per-model configuration.
|
||||
// Empty for non-proxy backends.
|
||||
ProxyOptions Proxy = 74;
|
||||
}
|
||||
|
||||
// ProxyOptions configures the cloud-proxy backend. UpstreamURL and
|
||||
// Mode are always meaningful; Provider only matters in translate mode.
|
||||
// The two api_key_* fields are mutually exclusive and resolved by the
|
||||
// backend at LoadModel — core forwards the references rather than the
|
||||
// plaintext key.
|
||||
message ProxyOptions {
|
||||
string upstream_url = 1;
|
||||
string mode = 2;
|
||||
string provider = 3;
|
||||
string api_key_env = 4;
|
||||
string api_key_file = 5;
|
||||
string upstream_model = 6;
|
||||
int32 request_timeout_seconds = 7;
|
||||
}
|
||||
|
||||
message Result {
|
||||
@@ -1002,3 +1124,32 @@ message QuantizationStopRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
// ForwardHeader is one HTTP header on the request or response. Headers
|
||||
// like Authorization are typically injected by the backend (from the
|
||||
// resolved API key) rather than passed through from the client.
|
||||
message ForwardHeader {
|
||||
string name = 1;
|
||||
string value = 2;
|
||||
}
|
||||
|
||||
// ForwardRequest is a streamed HTTP request to the upstream. First
|
||||
// message carries path/method/headers; subsequent messages carry
|
||||
// body_chunk only. All fields except body_chunk are honoured on the
|
||||
// first message and ignored thereafter.
|
||||
message ForwardRequest {
|
||||
string path = 1; // e.g. "/v1/chat/completions" — appended to the model's upstream_url
|
||||
string method = 2; // usually "POST"
|
||||
repeated ForwardHeader headers = 3;
|
||||
bytes body_chunk = 4;
|
||||
}
|
||||
|
||||
// ForwardReply is a streamed HTTP response from the upstream. First
|
||||
// message carries status/headers; subsequent messages carry body_chunk
|
||||
// only. SSE responses arrive as a sequence of body_chunk frames; the
|
||||
// caller is responsible for any parsing.
|
||||
message ForwardReply {
|
||||
int32 status = 1;
|
||||
repeated ForwardHeader headers = 2;
|
||||
bytes body_chunk = 3;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=ef0a4905d05263df8e63689f2dd1efac618a752c
|
||||
# Upstream pin lives below as DS4_VERSION?=f91c12b50a1448527c435c028bfc70d1b00f6c33
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=ef0a4905d05263df8e63689f2dd1efac618a752c
|
||||
DS4_VERSION?=f91c12b50a1448527c435c028bfc70d1b00f6c33
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=3e573cfea6e0a332eff822ffbdb1dd3b112e9051
|
||||
IK_LLAMA_VERSION?=9f7ba245ab41e118f03aa8dd5134d18a81159d02
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=0253fb21f595246f54c192fe8332f34173be251b
|
||||
LLAMA_VERSION?=549b9d84330c327e6791fa812a7d60c0cf63572e
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
#include <regex>
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
@@ -121,6 +122,40 @@ static std::string base64_encode_bytes(const unsigned char* data, size_t len) {
|
||||
|
||||
bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model
|
||||
|
||||
// Score bypasses the slot loop (see the comment on Score below) so it
|
||||
// must not run concurrently with any slot-loop RPC. These counters
|
||||
// are a defence-in-depth tripwire — ModelConfig.Validate already
|
||||
// rejects llama-cpp configs that mix score with chat/completion/
|
||||
// embeddings, so a healthy deployment never trips them. seq_cst is
|
||||
// load-bearing for the increment-then-check pattern below.
|
||||
static std::atomic<int> slot_loop_inflight{0};
|
||||
static std::atomic<int> score_inflight{0};
|
||||
|
||||
// Increment-then-check, not check-then-increment: two simultaneous
|
||||
// racers both observe the other's increment and both abort cleanly.
|
||||
// Reversed, both could see zero and proceed.
|
||||
struct conflict_guard {
|
||||
std::atomic<int>& self;
|
||||
conflict_guard(const char* rpc, std::atomic<int>& self_, std::atomic<int>& other, const char* other_name)
|
||||
: self(self_) {
|
||||
self.fetch_add(1, std::memory_order_seq_cst);
|
||||
int o = other.load(std::memory_order_seq_cst);
|
||||
if (o > 0) {
|
||||
fprintf(stderr,
|
||||
"FATAL: %s called with %s=%d. The llama-cpp backend cannot "
|
||||
"service Score and slot-loop RPCs concurrently — Score "
|
||||
"bypasses the slot loop and races the llama_context. Bind "
|
||||
"Score-using features to a model dedicated to scoring "
|
||||
"(known_usecases: [score] with no chat/completion/embeddings).\n",
|
||||
rpc, other_name, o);
|
||||
std::abort();
|
||||
}
|
||||
}
|
||||
~conflict_guard() {
|
||||
self.fetch_sub(1, std::memory_order_seq_cst);
|
||||
}
|
||||
};
|
||||
|
||||
static std::function<void(int)> shutdown_handler;
|
||||
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
|
||||
|
||||
@@ -517,16 +552,27 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.warmup = true;
|
||||
// no_op_offload: disable host tensor op offload (default: false)
|
||||
params.no_op_offload = false;
|
||||
// kv_unified: enable unified KV cache (default: false)
|
||||
params.kv_unified = false;
|
||||
// n_ctx_checkpoints: max context checkpoints per slot (default: 8)
|
||||
params.n_ctx_checkpoints = 8;
|
||||
|
||||
// llama memory fit fails if we don't provide a buffer for tensor overrides
|
||||
const size_t ntbo = llama_max_tensor_buft_overrides();
|
||||
while (params.tensor_buft_overrides.size() < ntbo) {
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
// kv_unified: enable unified KV cache. Upstream's server auto-enables this
|
||||
// when the slot count is auto (-np <0), bumping n_parallel to 4 alongside.
|
||||
// LocalAI keeps n_parallel=1 by default, which would skip that auto path
|
||||
// and leave kv_unified=false. We flip the default to true here so the
|
||||
// server-side prompt cache (cache_idle_slots) is actually usable on the
|
||||
// single-slot path that LocalAI ships with: without it, idle slots are
|
||||
// never persisted across requests and the prompt cache is dead weight.
|
||||
// Users can opt out with `options: [ "kv_unified:false" ]`.
|
||||
params.kv_unified = true;
|
||||
// n_ctx_checkpoints: max context checkpoints per slot. Match upstream's
|
||||
// default (32); the previous LocalAI-specific 8 was unnecessarily tight
|
||||
// and limits partial-prefix recovery without a clear memory rationale.
|
||||
params.n_ctx_checkpoints = 32;
|
||||
// cache_idle_slots: save and clear idle slot KV to the prompt cache on
|
||||
// task switch. Upstream default is true; the server auto-disables it if
|
||||
// kv_unified=false or cache_ram_mib=0, so flipping kv_unified above is
|
||||
// what actually unlocks it.
|
||||
params.cache_idle_slots = true;
|
||||
// checkpoint_every_nt: create a context checkpoint every N tokens during
|
||||
// prefill (-1 disables). Match upstream's default (8192).
|
||||
params.checkpoint_every_nt = 8192;
|
||||
|
||||
// decode options. Options are in form optname:optvale, or if booleans only optname.
|
||||
for (int i = 0; i < request->options_size(); i++) {
|
||||
@@ -685,7 +731,29 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
try {
|
||||
params.n_ctx_checkpoints = std::stoi(optval_str);
|
||||
} catch (const std::exception& e) {
|
||||
// If conversion fails, keep default value (8)
|
||||
// If conversion fails, keep default value (32)
|
||||
}
|
||||
}
|
||||
|
||||
// --- server-side idle-slot prompt cache toggle (upstream --cache-idle-slots) ---
|
||||
// Saves the slot's KV state into the host-side prompt cache on task
|
||||
// switch so a later request with the same prefix can warm-load it.
|
||||
// Auto-disabled by the server if kv_unified=false or cache_ram=0.
|
||||
} else if (!strcmp(optname, "cache_idle_slots") || !strcmp(optname, "idle_slots_cache")) {
|
||||
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||
params.cache_idle_slots = true;
|
||||
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||
params.cache_idle_slots = false;
|
||||
}
|
||||
|
||||
// --- prefill checkpoint cadence (upstream -cpent / --checkpoint-every-n-tokens) ---
|
||||
// -1 disables checkpointing during prefill.
|
||||
} else if (!strcmp(optname, "checkpoint_every_nt") || !strcmp(optname, "checkpoint_every_n_tokens")) {
|
||||
if (optval != NULL) {
|
||||
try {
|
||||
params.checkpoint_every_nt = std::stoi(optval_str);
|
||||
} catch (const std::exception& e) {
|
||||
// If conversion fails, keep default value (8192)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1081,6 +1149,20 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.kv_overrides.back().key[0] = 0;
|
||||
}
|
||||
|
||||
// tensor_buft_overrides sentinel termination (mirrors upstream common/arg.cpp).
|
||||
// Real entries are pushed during option parsing; here we pad/terminate so the
|
||||
// model loader sees back().pattern == nullptr (GGML_ASSERT at common.cpp:1543)
|
||||
// and so llama_params_fit has the placeholder slots it requires.
|
||||
{
|
||||
const size_t ntbo = llama_max_tensor_buft_overrides();
|
||||
while (params.tensor_buft_overrides.size() < ntbo) {
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
}
|
||||
if (!params.speculative.draft.tensor_buft_overrides.empty()) {
|
||||
params.speculative.draft.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
|
||||
// TODO: Add yarn
|
||||
|
||||
if (!request->tensorsplit().empty()) {
|
||||
@@ -1399,6 +1481,7 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("PredictStream", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
|
||||
@@ -2158,6 +2241,7 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("Predict", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
data["stream"] = false;
|
||||
@@ -2916,6 +3000,7 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("Embedding", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
body["stream"] = false;
|
||||
@@ -3023,6 +3108,8 @@ public:
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
|
||||
}
|
||||
|
||||
conflict_guard guard("Rerank", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
|
||||
// Create and queue the task
|
||||
auto rd = ctx_server.get_response_reader();
|
||||
{
|
||||
@@ -3095,12 +3182,218 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
// Score returns the model's joint log-probability of each candidate
|
||||
// continuation given a shared prompt.
|
||||
//
|
||||
// WHY bypass the slot/task queue: upstream server_context exposes
|
||||
// get_llama_context as "main thread only" and the slot loop's
|
||||
// update_slots() owns the context whenever a task is in flight.
|
||||
// No public synchronization primitive is available — so Score is
|
||||
// unsafe to call concurrently with active generation through this
|
||||
// backend. In practice routing-classifier calls happen before the
|
||||
// request is routed to a generation backend, so the model used
|
||||
// for Score is typically idle. Concurrent Score calls are
|
||||
// serialised by a local mutex; KV-cache state is isolated behind
|
||||
// a dedicated sequence ID cleared between candidates.
|
||||
//
|
||||
// A patch to server-context.cpp that adds SERVER_TASK_TYPE_SCORE
|
||||
// and routes scoring through the slot loop would be the correct
|
||||
// long-term fix; tracked as a follow-up.
|
||||
//
|
||||
// Perf TODO (measured: ~450 ms warm for 3 candidates on Arch-
|
||||
// Router-1.5B Q4_K_M + Intel SYCL): the current loop re-decodes
|
||||
// `prompt + candidate` from scratch for every candidate, throwing
|
||||
// away the prompt's KV cache between iterations. A smarter
|
||||
// version would:
|
||||
// 1. Decode just the prompt once into score_seq_id.
|
||||
// 2. Snapshot/cp that sequence (llama_memory_seq_cp) into a
|
||||
// per-candidate sequence id.
|
||||
// 3. For each candidate, decode only its tokens onto the copy
|
||||
// (continuing from the saved prompt state), read logits.
|
||||
// 4. llama_memory_seq_rm the copy.
|
||||
// Estimated speedup: 3-candidate calls 450 ms -> ~150-200 ms,
|
||||
// 6-candidate calls 630 ms -> ~220 ms. Single source-file change,
|
||||
// no proto / Go-side changes needed. Worth doing once routing is
|
||||
// wired into the middleware and Score is on the hot path of every
|
||||
// chat request.
|
||||
grpc::Status Score(ServerContext* context, const backend::ScoreRequest* request, backend::ScoreResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
if (request->candidates_size() == 0) {
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "candidates must be non-empty");
|
||||
}
|
||||
|
||||
// Tripwire against the slot loop. Acquired before score_mutex
|
||||
// so it fires even when this Score is queued behind another.
|
||||
conflict_guard guard("Score", score_inflight, slot_loop_inflight, "slot_loop_inflight");
|
||||
|
||||
// Serialise concurrent Score calls. The slot loop is still
|
||||
// free to race with us — see the class comment above.
|
||||
static std::mutex score_mutex;
|
||||
std::lock_guard<std::mutex> score_lock(score_mutex);
|
||||
|
||||
llama_context * lctx = ctx_server.get_llama_context();
|
||||
if (lctx == nullptr) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "llama context unavailable (sleeping?)");
|
||||
}
|
||||
const llama_vocab * vocab = ctx_server.impl->vocab;
|
||||
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
|
||||
const int32_t n_ctx = llama_n_ctx(lctx);
|
||||
llama_memory_t mem = llama_get_memory(lctx);
|
||||
|
||||
// The KV-cache is sized to seq_to_stream.size() at load
|
||||
// (typically equal to n_slots, often 1). Sequence IDs must
|
||||
// be in [0, n_seq_max), so we can't pick a high-value
|
||||
// "private" ID — we have to share with the slot. We clear
|
||||
// the cache before AND after each candidate to keep
|
||||
// scoring isolated from whatever state the slot held, and
|
||||
// the static mutex above guarantees no other Score call is
|
||||
// racing in the meantime. The slot loop is still free to
|
||||
// race (see comment on this method) — Score must not run
|
||||
// concurrently with generation through this backend.
|
||||
const llama_seq_id score_seq_id = 0;
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
|
||||
// Tokenize the shared prompt once with add_special=true so
|
||||
// BOS is prepended when the model requires it. parse_special
|
||||
// keeps chat-template markers in the prompt intact.
|
||||
const std::string prompt = request->prompt();
|
||||
std::vector<llama_token> prompt_tokens = common_tokenize(vocab, prompt, /*add_special=*/true, /*parse_special=*/true);
|
||||
const int32_t prompt_len = (int32_t) prompt_tokens.size();
|
||||
|
||||
for (int ci = 0; ci < request->candidates_size(); ci++) {
|
||||
const std::string & candidate_text = request->candidates(ci);
|
||||
|
||||
// Re-tokenize prompt + candidate as a single string. BPE
|
||||
// merges across the boundary can shift the tokenization
|
||||
// versus tokenize(prompt) ++ tokenize(candidate), so we
|
||||
// find the divergence point against prompt_tokens.
|
||||
std::vector<llama_token> full_tokens = common_tokenize(vocab, prompt + candidate_text, /*add_special=*/true, /*parse_special=*/true);
|
||||
int32_t divergence = prompt_len;
|
||||
const int32_t min_len = std::min<int32_t>(prompt_len, (int32_t) full_tokens.size());
|
||||
for (int32_t i = 0; i < min_len; i++) {
|
||||
if (prompt_tokens[i] != full_tokens[i]) {
|
||||
divergence = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const int32_t cand_len = (int32_t) full_tokens.size() - divergence;
|
||||
backend::CandidateScore * cs = response->add_candidates();
|
||||
cs->set_num_tokens(cand_len);
|
||||
if (cand_len <= 0) {
|
||||
cs->set_log_prob(0.0);
|
||||
if (request->length_normalize()) {
|
||||
cs->set_length_normalized_log_prob(0.0);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (divergence < 1) {
|
||||
// Need at least one prior token (typically BOS) to
|
||||
// predict the first candidate token's logit. Tokeniser
|
||||
// models without BOS + an empty prompt fall in here.
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
|
||||
"Score: prompt produced no leading tokens; need at least one (e.g. BOS) to predict candidate");
|
||||
}
|
||||
if ((int32_t) full_tokens.size() > n_ctx) {
|
||||
return grpc::Status(grpc::StatusCode::OUT_OF_RANGE,
|
||||
"Score: prompt+candidate exceeds context size (got " +
|
||||
std::to_string(full_tokens.size()) + ", n_ctx=" + std::to_string(n_ctx) + ")");
|
||||
}
|
||||
|
||||
// Build a batch covering the entire prompt+candidate. We
|
||||
// need logits at (divergence-1) onward — those are the
|
||||
// predictions for each candidate token.
|
||||
llama_batch batch = llama_batch_init((int32_t) full_tokens.size(), 0, 1);
|
||||
for (int32_t i = 0; i < (int32_t) full_tokens.size(); i++) {
|
||||
batch.token[i] = full_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = score_seq_id;
|
||||
// logits[i] is "do we want the prediction *for the
|
||||
// next token*, computed from this position?"
|
||||
// We want predictions for candidate tokens at
|
||||
// positions divergence .. full_tokens.size()-1, which
|
||||
// come from logits at positions (divergence-1) ..
|
||||
// (full_tokens.size()-2).
|
||||
bool need_logit = (i >= divergence - 1) && (i < (int32_t) full_tokens.size() - 1);
|
||||
batch.logits[i] = need_logit ? 1 : 0;
|
||||
}
|
||||
batch.n_tokens = (int32_t) full_tokens.size();
|
||||
|
||||
// Decode the batch. If decode fails (e.g. KV slot
|
||||
// exhaustion), surface as INTERNAL — the caller will
|
||||
// typically fall back to a sampling-based classifier.
|
||||
int decode_err = llama_decode(lctx, batch);
|
||||
if (decode_err != 0) {
|
||||
llama_batch_free(batch);
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL,
|
||||
"llama_decode failed during Score: " + std::to_string(decode_err));
|
||||
}
|
||||
|
||||
// Sum log-probabilities of the actual candidate tokens.
|
||||
double total_log_prob = 0.0;
|
||||
for (int32_t k = 0; k < cand_len; k++) {
|
||||
// The k-th candidate token sits at full_tokens index
|
||||
// (divergence + k). Its predicting logit is at batch
|
||||
// position (divergence + k - 1).
|
||||
int32_t logit_pos = divergence + k - 1;
|
||||
const float * logits = llama_get_logits_ith(lctx, logit_pos);
|
||||
if (logits == nullptr) {
|
||||
llama_batch_free(batch);
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL,
|
||||
"llama_get_logits_ith returned null at position " + std::to_string(logit_pos));
|
||||
}
|
||||
llama_token target_token = full_tokens[divergence + k];
|
||||
|
||||
// Compute log_softmax(logits)[target_token] with the
|
||||
// max-subtraction stability trick.
|
||||
float max_logit = logits[0];
|
||||
for (int32_t v = 1; v < n_vocab; v++) {
|
||||
if (logits[v] > max_logit) max_logit = logits[v];
|
||||
}
|
||||
double sum_exp = 0.0;
|
||||
for (int32_t v = 0; v < n_vocab; v++) {
|
||||
sum_exp += std::exp((double)(logits[v] - max_logit));
|
||||
}
|
||||
double token_log_prob = (double)(logits[target_token] - max_logit) - std::log(sum_exp);
|
||||
total_log_prob += token_log_prob;
|
||||
|
||||
if (request->include_token_logprobs()) {
|
||||
backend::TokenLogProb * tlp = cs->add_tokens();
|
||||
std::string piece = common_token_to_piece(lctx, target_token);
|
||||
tlp->set_token(piece);
|
||||
tlp->set_log_prob(token_log_prob);
|
||||
}
|
||||
}
|
||||
|
||||
cs->set_log_prob(total_log_prob);
|
||||
if (request->length_normalize() && cand_len > 0) {
|
||||
cs->set_length_normalized_log_prob(total_log_prob / (double) cand_len);
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
// Drop this candidate's KV-cache contribution so the next
|
||||
// candidate starts from a clean state. Without this, the
|
||||
// next decode would conflict at positions 0..N-1 for our
|
||||
// sequence ID.
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
}
|
||||
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("TokenizeString", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
|
||||
body["stream"] = false;
|
||||
|
||||
@@ -3122,6 +3415,8 @@ public:
|
||||
|
||||
grpc::Status GetMetrics(ServerContext* /*context*/, const backend::MetricsRequest* /*request*/, backend::MetricsResponse* response) override {
|
||||
|
||||
conflict_guard guard("GetMetrics", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
|
||||
// request slots data using task queue
|
||||
auto rd = ctx_server.get_response_reader();
|
||||
int task_id = rd.queue_tasks.get_new_id();
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# acestep.cpp version
|
||||
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
|
||||
ACESTEP_CPP_VERSION?=e0c8d75a672fca5684c88c68dbf6d12f58754258
|
||||
ACESTEP_CPP_VERSION?=ed53caf164e4492a5620b2e3f2264629cf66da24
|
||||
SO_TARGET?=libgoacestepcpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -22,12 +22,11 @@
|
||||
#include <vector>
|
||||
|
||||
// Global model contexts (loaded once, reused across requests)
|
||||
static DiTGGML g_dit = {};
|
||||
static DiTGGMLConfig g_dit_cfg;
|
||||
static VAEGGML g_vae = {};
|
||||
static bool g_dit_loaded = false;
|
||||
static bool g_vae_loaded = false;
|
||||
static bool g_is_turbo = false;
|
||||
static DiTGGML g_dit = {};
|
||||
static VAEGGML g_vae = {};
|
||||
static bool g_dit_loaded = false;
|
||||
static bool g_vae_loaded = false;
|
||||
static bool g_is_turbo = false;
|
||||
|
||||
// Silence latent [15000, 64] — read once from DiT GGUF
|
||||
static std::vector<float> g_silence_full;
|
||||
@@ -72,10 +71,9 @@ int load_model(const char * lm_model_path, const char * text_encoder_path,
|
||||
g_text_enc_path = text_encoder_path;
|
||||
g_dit_path = dit_model_path;
|
||||
|
||||
// Load DiT model
|
||||
// Load DiT model (backend init + config are handled inside dit_ggml_load)
|
||||
fprintf(stderr, "[acestep-cpp] Loading DiT from %s\n", dit_model_path);
|
||||
dit_ggml_init_backend(&g_dit);
|
||||
if (!dit_ggml_load(&g_dit, dit_model_path, g_dit_cfg, nullptr, 0.0f)) {
|
||||
if (!dit_ggml_load(&g_dit, dit_model_path)) {
|
||||
fprintf(stderr, "[acestep-cpp] FATAL: failed to load DiT from %s\n", dit_model_path);
|
||||
return 1;
|
||||
}
|
||||
@@ -149,16 +147,16 @@ int generate_music(const char * caption, const char * lyrics, int bpm,
|
||||
|
||||
// Compute T (latent frames at 25Hz)
|
||||
int T = (int)(duration * FRAMES_PER_SECOND);
|
||||
T = ((T + g_dit_cfg.patch_size - 1) / g_dit_cfg.patch_size) * g_dit_cfg.patch_size;
|
||||
int S = T / g_dit_cfg.patch_size;
|
||||
T = ((T + g_dit.cfg.patch_size - 1) / g_dit.cfg.patch_size) * g_dit.cfg.patch_size;
|
||||
int S = T / g_dit.cfg.patch_size;
|
||||
|
||||
if (T > 15000) {
|
||||
fprintf(stderr, "[acestep-cpp] ERROR: T=%d exceeds max 15000\n", T);
|
||||
return 2;
|
||||
}
|
||||
|
||||
int Oc = g_dit_cfg.out_channels; // 64
|
||||
int ctx_ch = g_dit_cfg.in_channels - Oc; // 128
|
||||
int Oc = g_dit.cfg.out_channels; // 64
|
||||
int ctx_ch = g_dit.cfg.in_channels - Oc; // 128
|
||||
|
||||
fprintf(stderr, "[acestep-cpp] T=%d, S=%d, duration=%.1fs, seed=%d\n", T, S, duration, seed);
|
||||
|
||||
@@ -191,9 +189,8 @@ int generate_music(const char * caption, const char * lyrics, int bpm,
|
||||
|
||||
fprintf(stderr, "[acestep-cpp] caption: %d tokens, lyrics: %d tokens\n", S_text, S_lyric);
|
||||
|
||||
// 4. Text encoder forward
|
||||
// 4. Text encoder forward (backend init handled inside qwen3_load_text_encoder)
|
||||
Qwen3GGML text_enc = {};
|
||||
qwen3_init_backend(&text_enc);
|
||||
if (!qwen3_load_text_encoder(&text_enc, g_text_enc_path.c_str())) {
|
||||
fprintf(stderr, "[acestep-cpp] FATAL: failed to load text encoder\n");
|
||||
return 4;
|
||||
@@ -209,9 +206,8 @@ int generate_music(const char * caption, const char * lyrics, int bpm,
|
||||
std::vector<float> lyric_embed(H_text * S_lyric);
|
||||
qwen3_embed_lookup(&text_enc, lyric_ids.data(), S_lyric, lyric_embed.data());
|
||||
|
||||
// 6. Condition encoder
|
||||
// 6. Condition encoder (backend init handled inside cond_ggml_load)
|
||||
CondGGML cond = {};
|
||||
cond_ggml_init_backend(&cond);
|
||||
if (!cond_ggml_load(&cond, g_dit_path.c_str())) {
|
||||
fprintf(stderr, "[acestep-cpp] FATAL: failed to load condition encoder\n");
|
||||
qwen3_free(&text_enc);
|
||||
|
||||
12
backend/go/cloud-proxy/Makefile
Normal file
12
backend/go/cloud-proxy/Makefile
Normal file
@@ -0,0 +1,12 @@
|
||||
GOCMD=go
|
||||
|
||||
cloud-proxy:
|
||||
CGO_ENABLED=0 $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o cloud-proxy ./
|
||||
|
||||
package:
|
||||
bash package.sh
|
||||
|
||||
build: cloud-proxy package
|
||||
|
||||
clean:
|
||||
rm -f cloud-proxy
|
||||
16
backend/go/cloud-proxy/cloud_proxy_suite_test.go
Normal file
16
backend/go/cloud-proxy/cloud_proxy_suite_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Ginkgo bootstrap. The other Test* functions in this package use
|
||||
// raw testing.T and run independently; they coexist with Ginkgo
|
||||
// specs registered via Describe / Context.
|
||||
func TestCloudProxySpecs(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "cloud-proxy specs")
|
||||
}
|
||||
39
backend/go/cloud-proxy/main.go
Normal file
39
backend/go/cloud-proxy/main.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package main
|
||||
|
||||
// cloud-proxy is a LocalAI backend that forwards request traffic to an
|
||||
// external HTTP provider (OpenAI, Anthropic, etc.). Two modes:
|
||||
//
|
||||
// - passthrough: serves the Forward RPC; the client wire format is
|
||||
// preserved end-to-end, no translation.
|
||||
// - translate: serves Predict/PredictStream; the backend converts
|
||||
// internal proto to the provider's wire format. (Phases 5–6.)
|
||||
//
|
||||
// LoadModel reads UpstreamURL/Mode/Provider/key references from
|
||||
// ProxyOptions and resolves the API key once at load time.
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/xlog"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
var addr = flag.String("addr", "localhost:50051", "the address to listen on")
|
||||
|
||||
func main() {
|
||||
// xlog's default handler emits ANSI color codes; that's fine for an
|
||||
// interactive shell but unreadable when the backend's stdout is
|
||||
// captured by LocalAI and tee'd to a log file. Force plain text when
|
||||
// LOCALAI_LOG_FORMAT is unset and stdout isn't a terminal.
|
||||
format := os.Getenv("LOCALAI_LOG_FORMAT")
|
||||
if format == "" && !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
format = xlog.TextFormat
|
||||
}
|
||||
xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(os.Getenv("LOCALAI_LOG_LEVEL")), format))
|
||||
flag.Parse()
|
||||
if err := grpc.StartServer(*addr, NewCloudProxy()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
13
backend/go/cloud-proxy/package.sh
Executable file
13
backend/go/cloud-proxy/package.sh
Executable file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the cloud-proxy binary into the package dir for the
|
||||
# final Dockerfile stage. Mirrors backend/go/local-store/package.sh —
|
||||
# no extra runtime libs needed since the backend is pure Go.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
mkdir -p $CURDIR/package
|
||||
cp -avf $CURDIR/cloud-proxy $CURDIR/package/
|
||||
cp -rfv $CURDIR/run.sh $CURDIR/package/
|
||||
270
backend/go/cloud-proxy/passthrough_edge_test.go
Normal file
270
backend/go/cloud-proxy/passthrough_edge_test.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("composeURL", func() {
|
||||
// Upstream URL convention: gallery configs put the canonical path
|
||||
// in upstream_url, so per-request Path is ignored. A bare-host
|
||||
// upstream_url accepts the per-request path.
|
||||
DescribeTable("path resolution",
|
||||
func(upstream, reqPath, want string) {
|
||||
got, err := composeURL(upstream, reqPath)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(got).To(Equal(want))
|
||||
},
|
||||
Entry("full path wins", "https://api.openai.com/v1/chat/completions", "/v1/something-else", "https://api.openai.com/v1/chat/completions"),
|
||||
Entry("bare host accepts path", "https://api.openai.com", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"),
|
||||
Entry("root slash treated as bare", "https://api.openai.com/", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"),
|
||||
Entry("bare host + empty path", "https://api.openai.com", "", "https://api.openai.com"),
|
||||
)
|
||||
|
||||
It("returns an error on invalid upstream URL", func() {
|
||||
_, err := composeURL("://garbage", "")
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("applyAuthHeader", func() {
|
||||
It("sets x-api-key and anthropic-version for Anthropic, no Authorization", func() {
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
applyAuthHeader(req, providerAnthropic, "ant-key")
|
||||
Expect(req.Header.Get("x-api-key")).To(Equal("ant-key"))
|
||||
Expect(req.Header.Get("anthropic-version")).NotTo(BeEmpty())
|
||||
Expect(req.Header.Get("Authorization")).To(BeEmpty(), "Authorization must not leak on Anthropic backend")
|
||||
})
|
||||
|
||||
It("sets Bearer Authorization for OpenAI, no x-api-key", func() {
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
applyAuthHeader(req, providerOpenAI, "sk-key")
|
||||
Expect(req.Header.Get("Authorization")).To(Equal("Bearer sk-key"))
|
||||
Expect(req.Header.Get("x-api-key")).To(BeEmpty(), "x-api-key must not leak on OpenAI backend")
|
||||
})
|
||||
|
||||
It("defaults to Bearer when provider is empty", func() {
|
||||
// Passthrough mode often has provider == "" because the operator
|
||||
// doesn't claim a specific upstream wire format. Most providers
|
||||
// (including OpenAI-compatible ones) accept Bearer, so default to it.
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
applyAuthHeader(req, "", "some-key")
|
||||
Expect(req.Header.Get("Authorization")).To(Equal("Bearer some-key"))
|
||||
})
|
||||
|
||||
It("preserves an existing anthropic-version header", func() {
|
||||
// If the client supplied anthropic-version (rare but legitimate
|
||||
// for an upstream pinned to a specific date), the proxy must not
|
||||
// clobber it.
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
req.Header.Set("anthropic-version", "2024-10-01")
|
||||
applyAuthHeader(req, providerAnthropic, "k")
|
||||
Expect(req.Header.Get("anthropic-version")).To(Equal("2024-10-01"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("isHopByHopHeader", func() {
|
||||
DescribeTable("hop-by-hop classification",
|
||||
func(header string, want bool) {
|
||||
Expect(isHopByHopHeader(header)).To(Equal(want))
|
||||
},
|
||||
Entry("Connection is hop-by-hop", "Connection", true),
|
||||
Entry("Keep-Alive is hop-by-hop", "Keep-Alive", true),
|
||||
Entry("Proxy-Connection is hop-by-hop", "Proxy-Connection", true),
|
||||
Entry("Transfer-Encoding is hop-by-hop", "Transfer-Encoding", true),
|
||||
Entry("TE is hop-by-hop", "TE", true),
|
||||
Entry("Trailer is hop-by-hop", "Trailer", true),
|
||||
Entry("Upgrade is hop-by-hop", "Upgrade", true),
|
||||
Entry("Host is hop-by-hop", "Host", true),
|
||||
Entry("Content-Length is hop-by-hop", "Content-Length", true),
|
||||
// Case-insensitive — RFC 7230 doesn't constrain header case.
|
||||
Entry("lowercase connection is hop-by-hop", "connection", true),
|
||||
Entry("uppercase HOST is hop-by-hop", "HOST", true),
|
||||
// Non hop-by-hop — must NOT be stripped.
|
||||
Entry("Authorization is end-to-end", "Authorization", false),
|
||||
Entry("Content-Type is end-to-end", "Content-Type", false),
|
||||
Entry("Accept is end-to-end", "Accept", false),
|
||||
Entry("X-Custom is end-to-end", "X-Custom", false),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("Forward", func() {
|
||||
It("strips hop-by-hop and Connection headers before upstream, preserves custom headers", func() {
|
||||
gotConnection := make(chan string, 1)
|
||||
gotXCustom := make(chan string, 1)
|
||||
gotHost := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotConnection <- r.Header.Get("Connection")
|
||||
gotXCustom <- r.Header.Get("X-Custom")
|
||||
gotHost <- r.Header.Get("Host")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-hopbyhop"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{
|
||||
{Name: "Connection", Value: "keep-alive"},
|
||||
{Name: "Host", Value: "spoofed.example.com"},
|
||||
{Name: "X-Custom", Value: "preserved"},
|
||||
},
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
_, _ = stream.Recv()
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
Expect(<-gotConnection).To(BeEmpty(), "Connection must not leak to upstream")
|
||||
Expect(<-gotHost).NotTo(Equal("spoofed.example.com"), "Host header must not be spoofed through")
|
||||
Expect(<-gotXCustom).To(Equal("preserved"), "X-Custom header must survive")
|
||||
})
|
||||
|
||||
It("replaces caller-supplied Authorization with the configured key", func() {
|
||||
// The proxy must overwrite a client-supplied Authorization header
|
||||
// so a downstream caller can't smuggle stale or wrong credentials.
|
||||
gotAuth := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuth <- r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
GinkgoT().Setenv("CLOUD_PROXY_AUTH_REPLACE_KEY", "sk-real")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_AUTH_REPLACE_KEY",
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-replaces-auth"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{
|
||||
// Client-supplied Authorization with the wrong scheme / key.
|
||||
{Name: "Authorization", Value: "Basic Zm9vOmJhcg=="},
|
||||
},
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
_, _ = stream.Recv()
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
Expect(<-gotAuth).To(Equal("Bearer sk-real"), "caller-supplied Basic header must be replaced")
|
||||
})
|
||||
|
||||
It("handles concurrent calls without interference", func() {
|
||||
// CloudProxy explicitly omits base.SingleThread — independent
|
||||
// Forward streams must not block each other or leak state.
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(body)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
},
|
||||
})).To(Succeed())
|
||||
addr := "test://forward-concurrent"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
|
||||
const N = 8
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
stream, err := c.Forward(context.Background())
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
payload := "request-" + string(rune('A'+idx))
|
||||
if err := stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
BodyChunk: []byte(payload),
|
||||
}); err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
_ = stream.CloseSend()
|
||||
_, _ = stream.Recv()
|
||||
var body []byte
|
||||
for {
|
||||
r, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
body = append(body, r.GetBodyChunk()...)
|
||||
}
|
||||
if string(body) != payload {
|
||||
errs <- &echoMismatch{want: payload, got: string(body)}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
var collected []error
|
||||
for err := range errs {
|
||||
collected = append(collected, err)
|
||||
}
|
||||
Expect(collected).To(BeEmpty(), "no concurrent Forward call should fail")
|
||||
})
|
||||
})
|
||||
|
||||
type echoMismatch struct{ want, got string }
|
||||
|
||||
func (e *echoMismatch) Error() string {
|
||||
return "echo mismatch: want " + strconv.Quote(e.want) + " got " + strconv.Quote(e.got)
|
||||
}
|
||||
508
backend/go/cloud-proxy/provider_anthropic.go
Normal file
508
backend/go/cloud-proxy/provider_anthropic.go
Normal file
@@ -0,0 +1,508 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// Anthropic Messages API wire-format types. Narrowed to what translate
|
||||
// mode preserves through the Reply proto: text + tool_use blocks +
|
||||
// usage tokens. Image blocks, prompt caching, metadata, and stop
|
||||
// sequence metadata are not modelled — passthrough mode covers those.
|
||||
//
|
||||
// Notable differences from OpenAI:
|
||||
// - max_tokens is REQUIRED. Anthropic 400s without it.
|
||||
// - Roles are user/assistant only — system messages move to a
|
||||
// top-level `system` string field.
|
||||
// - Streaming SSE uses event: lines alongside data: lines. The
|
||||
// events we care about: content_block_start (carries tool_use
|
||||
// init: id + name), content_block_delta (text_delta with text;
|
||||
// input_json_delta with partial_json for tool arguments), and
|
||||
// message_stop (terminates the stream). Others are ignored.
|
||||
|
||||
type anthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int32 `json:"max_tokens"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []anthropicMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []anthropicTool `json:"tools,omitempty"`
|
||||
ToolChoice *anthropicToolChoice `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
// Content is `any` because Anthropic accepts a bare string OR a
|
||||
// list of content blocks. Use the string form for plain user/
|
||||
// assistant turns; switch to []anthropicContentBlock when the
|
||||
// turn needs tool_use (assistant) or tool_result (user) blocks.
|
||||
type anthropicMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type anthropicTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"`
|
||||
}
|
||||
|
||||
// anthropicToolChoice mirrors the four shapes Anthropic accepts:
|
||||
// {"type":"auto"} | {"type":"any"} | {"type":"tool","name":"X"} |
|
||||
// {"type":"none"} (newer models). OpenAI's "auto"/"none"/
|
||||
// "required"/{"function":{"name":"X"}} all map here.
|
||||
type anthropicToolChoice struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// anthropicContentBlock is the union shape used both for response
|
||||
// blocks (text/tool_use we read off the wire) and outbound request
|
||||
// blocks (tool_use/tool_result we emit in the conversation history).
|
||||
// Anthropic encodes tool calls inline rather than as a separate field,
|
||||
// so we walk Content[] looking for type=="tool_use" on responses and
|
||||
// produce equivalent blocks when serialising prior-turn tool calls.
|
||||
type anthropicContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
// Tool-result block fields. tool_result uses `content` (not
|
||||
// `text`) and pairs with `tool_use_id`; modelling them as
|
||||
// distinct fields avoids ambiguity at marshal time.
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
ResultContent string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Content []anthropicContentBlock `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Usage *anthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// anthropicStreamEvent is the union shape used for every event type we
|
||||
// process. Type discriminates; only the matching fields are populated.
|
||||
// content_block_start carries ContentBlock (with id/name for tool_use);
|
||||
// content_block_delta carries Delta (text or partial_json).
|
||||
type anthropicStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Index int `json:"index,omitempty"`
|
||||
ContentBlock *anthropicContentBlock `json:"content_block,omitempty"`
|
||||
Delta *anthropicStreamDelta `json:"delta,omitempty"`
|
||||
Message *anthropicResponse `json:"message,omitempty"`
|
||||
Usage *anthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicStreamDelta struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
}
|
||||
|
||||
// Anthropic requires max_tokens. If the caller didn't set it, use a
|
||||
// generous-but-bounded default so the request doesn't 400.
|
||||
const anthropicDefaultMaxTokens int32 = 4096
|
||||
|
||||
const anthropicToolChoiceNone = "none"
|
||||
|
||||
// Reused JSON-Schema defaults for malformed inputs. Anthropic requires
|
||||
// input_schema to be a JSON object and tool_use.input to be a JSON
|
||||
// object; clients that omit them must not 400 the entire request.
|
||||
var (
|
||||
emptyJSONObject = json.RawMessage(`{}`)
|
||||
emptyObjectSchema = json.RawMessage(`{"type":"object","properties":{}}`)
|
||||
)
|
||||
|
||||
func buildAnthropicRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) {
|
||||
req := anthropicRequest{
|
||||
Model: modelName(cfg, opts),
|
||||
MaxTokens: opts.GetTokens(),
|
||||
Stream: stream,
|
||||
StopSequences: opts.GetStopPrompts(),
|
||||
}
|
||||
if req.MaxTokens <= 0 {
|
||||
req.MaxTokens = anthropicDefaultMaxTokens
|
||||
}
|
||||
// Newer Anthropic models 400 when both temperature and top_p are
|
||||
// set ("`temperature` and `top_p` cannot both be specified for
|
||||
// this model. Please use only one.") even though their docs only
|
||||
// "recommend" picking one. The OpenAI-compatible chat UI almost
|
||||
// always sends both with default values, so prefer temperature
|
||||
// and drop top_p when both are present.
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
} else if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
|
||||
req.Tools = convertOpenAITools(opts.GetTools())
|
||||
req.ToolChoice = convertOpenAIToolChoice(opts.GetToolChoice())
|
||||
// Anthropic rejects tool_choice without tools and older models
|
||||
// don't accept {"type":"none"} — collapse to a no-tools request.
|
||||
if req.ToolChoice != nil && req.ToolChoice.Type == anthropicToolChoiceNone {
|
||||
req.Tools, req.ToolChoice = nil, nil
|
||||
}
|
||||
|
||||
var systemParts []string
|
||||
for _, m := range opts.GetMessages() {
|
||||
role := m.GetRole()
|
||||
if role == "system" {
|
||||
if c := m.GetContent(); c != "" {
|
||||
systemParts = append(systemParts, c)
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch role {
|
||||
case "user":
|
||||
req.Messages = append(req.Messages, anthropicMessage{
|
||||
Role: "user",
|
||||
Content: m.GetContent(),
|
||||
})
|
||||
case "assistant":
|
||||
if blocks := assistantBlocks(m); blocks != nil {
|
||||
req.Messages = append(req.Messages, anthropicMessage{Role: "assistant", Content: blocks})
|
||||
continue
|
||||
}
|
||||
req.Messages = append(req.Messages, anthropicMessage{
|
||||
Role: "assistant",
|
||||
Content: m.GetContent(),
|
||||
})
|
||||
case "tool", "function":
|
||||
req.Messages = appendToolResult(req.Messages, anthropicContentBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: m.GetToolCallId(),
|
||||
ResultContent: m.GetContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
req.System = strings.Join(systemParts, "\n\n")
|
||||
|
||||
if len(req.Messages) == 0 && opts.GetPrompt() != "" {
|
||||
req.Messages = []anthropicMessage{{Role: "user", Content: opts.GetPrompt()}}
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
// appendToolResult appends a tool_result block as a user message,
|
||||
// merging into a preceding user message that already carries blocks.
|
||||
// Anthropic concatenates consecutive same-role messages on its end,
|
||||
// but explicit merging keeps the body smaller and the conversation
|
||||
// strictly alternating — which some upstream filters require.
|
||||
func appendToolResult(msgs []anthropicMessage, block anthropicContentBlock) []anthropicMessage {
|
||||
if n := len(msgs); n > 0 && msgs[n-1].Role == "user" {
|
||||
if existing, ok := msgs[n-1].Content.([]anthropicContentBlock); ok {
|
||||
msgs[n-1].Content = append(existing, block)
|
||||
return msgs
|
||||
}
|
||||
}
|
||||
return append(msgs, anthropicMessage{
|
||||
Role: "user",
|
||||
Content: []anthropicContentBlock{block},
|
||||
})
|
||||
}
|
||||
|
||||
func convertOpenAITools(toolsJSON string) []anthropicTool {
|
||||
if toolsJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var raw []openAITool
|
||||
if err := json.Unmarshal([]byte(toolsJSON), &raw); err != nil {
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unparseable tools JSON, dropping", "error", err)
|
||||
return nil
|
||||
}
|
||||
tools := make([]anthropicTool, 0, len(raw))
|
||||
for _, t := range raw {
|
||||
if t.Function.Name == "" {
|
||||
continue
|
||||
}
|
||||
schema := t.Function.Parameters
|
||||
if len(schema) == 0 {
|
||||
schema = emptyObjectSchema
|
||||
}
|
||||
tools = append(tools, anthropicTool{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
InputSchema: schema,
|
||||
})
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// convertOpenAIToolChoice accepts the spec form
|
||||
// ({type:function, function:{name:X}}) and the flat legacy form
|
||||
// ({type:function, name:X}) some clients send. Unknown object shapes
|
||||
// are warned and dropped rather than silently treated as auto.
|
||||
func convertOpenAIToolChoice(toolChoiceJSON string) *anthropicToolChoice {
|
||||
if toolChoiceJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var asString string
|
||||
if err := json.Unmarshal([]byte(toolChoiceJSON), &asString); err == nil {
|
||||
switch asString {
|
||||
case "auto":
|
||||
return &anthropicToolChoice{Type: "auto"}
|
||||
case "none":
|
||||
return &anthropicToolChoice{Type: anthropicToolChoiceNone}
|
||||
case "required":
|
||||
return &anthropicToolChoice{Type: "any"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var asObj struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"function"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(toolChoiceJSON), &asObj); err != nil {
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unparseable tool_choice, dropping", "error", err)
|
||||
return nil
|
||||
}
|
||||
if name := asObj.Function.Name; name != "" {
|
||||
return &anthropicToolChoice{Type: "tool", Name: name}
|
||||
}
|
||||
if asObj.Name != "" {
|
||||
return &anthropicToolChoice{Type: "tool", Name: asObj.Name}
|
||||
}
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unrecognised tool_choice shape, dropping", "shape", toolChoiceJSON)
|
||||
return nil
|
||||
}
|
||||
|
||||
// openAITool mirrors pkg/functions.Tool but keeps Parameters as
|
||||
// json.RawMessage so the input_schema passes through verbatim — no
|
||||
// re-marshal cost, no fidelity loss on exotic schemas.
|
||||
type openAITool struct {
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters json.RawMessage `json:"parameters"`
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
func assistantBlocks(m *pb.Message) []anthropicContentBlock {
|
||||
toolCallsJSON := m.GetToolCalls()
|
||||
if toolCallsJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var toolCalls []openAIToolCall
|
||||
if err := json.Unmarshal([]byte(toolCallsJSON), &toolCalls); err != nil || len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
blocks := make([]anthropicContentBlock, 0, len(toolCalls)+1)
|
||||
if text := m.GetContent(); text != "" {
|
||||
blocks = append(blocks, anthropicContentBlock{Type: "text", Text: text})
|
||||
}
|
||||
for _, tc := range toolCalls {
|
||||
// OpenAI's arguments are a JSON-encoded string; pass through
|
||||
// as RawMessage so a non-JSON string from a poorly-formed
|
||||
// local model doesn't crash the marshaller downstream.
|
||||
args := json.RawMessage(tc.Function.Arguments)
|
||||
if len(args) == 0 {
|
||||
args = emptyJSONObject
|
||||
}
|
||||
blocks = append(blocks, anthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: args,
|
||||
})
|
||||
}
|
||||
return blocks
|
||||
}
|
||||
|
||||
// doAnthropicRequest is the Anthropic counterpart of doOpenAIRequest.
|
||||
// applyAuthHeader sets x-api-key and anthropic-version when provider
|
||||
// is anthropic, so this method doesn't need to duplicate that.
|
||||
func (c *CloudProxy) doAnthropicRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// predictAnthropicRich returns the full Reply: joined text from all
|
||||
// text blocks, tool_use blocks mapped to ToolCallDelta, and usage
|
||||
// tokens.
|
||||
func (c *CloudProxy) predictAnthropicRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
body, err := buildAnthropicRequest(opts, cfg, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doAnthropicRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var parsed anthropicResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: decode response: %w", err)
|
||||
}
|
||||
|
||||
reply := &pb.Reply{}
|
||||
if parsed.Usage != nil {
|
||||
reply.PromptTokens = int32(parsed.Usage.InputTokens)
|
||||
reply.Tokens = int32(parsed.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
var content strings.Builder
|
||||
var toolCalls []*pb.ToolCallDelta
|
||||
toolIdx := 0
|
||||
for _, b := range parsed.Content {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
content.WriteString(b.Text)
|
||||
case "tool_use":
|
||||
// Input is a structured JSON object; we serialise to a
|
||||
// string so it fits the OpenAI-shaped arguments field
|
||||
// downstream consumers expect.
|
||||
args := ""
|
||||
if len(b.Input) > 0 {
|
||||
args = string(b.Input)
|
||||
}
|
||||
toolCalls = append(toolCalls, newToolCallDelta(toolIdx, b.ID, b.Name, args))
|
||||
toolIdx++
|
||||
}
|
||||
}
|
||||
reply.Message = []byte(content.String())
|
||||
if len(toolCalls) > 0 {
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{ToolCalls: toolCalls}}
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// predictAnthropicStreamRich streams Reply chunks from Anthropic's SSE.
|
||||
// Three event types matter: content_block_start (initialises tool_use
|
||||
// id+name), content_block_delta (carries text or input_json_delta),
|
||||
// message_stop (terminates). The block index from the wire feeds
|
||||
// straight into ToolCallDelta.Index so downstream consumers can
|
||||
// reassemble multiple parallel tool calls.
|
||||
func (c *CloudProxy) predictAnthropicStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
body, err := buildAnthropicRequest(opts, cfg, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doAnthropicRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "" {
|
||||
continue
|
||||
}
|
||||
var ev anthropicStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &ev); err != nil {
|
||||
xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err)
|
||||
continue
|
||||
}
|
||||
switch ev.Type {
|
||||
case "content_block_start":
|
||||
// tool_use blocks announce id + name here; arguments arrive
|
||||
// in subsequent input_json_delta events. Emit a Reply with
|
||||
// just the tool_call init fields so consumers can allocate
|
||||
// a slot at this index.
|
||||
if ev.ContentBlock != nil && ev.ContentBlock.Type == "tool_use" {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{
|
||||
newToolCallDelta(ev.Index, ev.ContentBlock.ID, ev.ContentBlock.Name, ""),
|
||||
}}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "content_block_delta":
|
||||
if ev.Delta == nil {
|
||||
continue
|
||||
}
|
||||
switch ev.Delta.Type {
|
||||
case "text_delta":
|
||||
if ev.Delta.Text == "" {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
Message: []byte(ev.Delta.Text),
|
||||
ChatDeltas: []*pb.ChatDelta{{Content: ev.Delta.Text}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
case "input_json_delta":
|
||||
if ev.Delta.PartialJSON == "" {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{
|
||||
newToolCallDelta(ev.Index, "", "", ev.Delta.PartialJSON),
|
||||
}}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "message_delta":
|
||||
// Anthropic sends final usage in message_delta.usage. Emit
|
||||
// a usage-only Reply so the consumer can record totals.
|
||||
if ev.Usage != nil {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
Tokens: int32(ev.Usage.OutputTokens),
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "message_stop":
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
334
backend/go/cloud-proxy/provider_anthropic_test.go
Normal file
334
backend/go/cloud-proxy/provider_anthropic_test.go
Normal file
@@ -0,0 +1,334 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// fakeAnthropicUpstream mirrors fakeOpenAIUpstream but decodes the
|
||||
// request body as an anthropicRequest so tests can assert on the
|
||||
// translated wire shape (system field, max_tokens, etc.).
|
||||
func fakeAnthropicUpstream(t *testing.T, handler func(req anthropicRequest) (status int, body string, contentType string)) (*httptest.Server, *anthropicRequest) {
|
||||
t.Helper()
|
||||
var captured anthropicRequest
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(raw, &captured)
|
||||
status, body, ct := handler(captured)
|
||||
w.Header().Set("Content-Type", ct)
|
||||
w.WriteHeader(status)
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
return srv, &captured
|
||||
}
|
||||
|
||||
func newAnthropicTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy {
|
||||
t.Helper()
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_ANTHROPIC_FAKE", "sk-ant-fake")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Model: "claude-local",
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstreamURL,
|
||||
Mode: modeTranslate,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_FAKE",
|
||||
UpstreamModel: "claude-3-5-sonnet-20241022",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
return cp
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_BasicMessages(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hi there"}],"model":"claude-3-5-sonnet-20241022","usage":{"input_tokens":5,"output_tokens":2}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{
|
||||
{Role: "system", Content: "be brief"},
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
Temperature: 0.5,
|
||||
TopP: 0.9,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hi there"))
|
||||
|
||||
g.Expect(captured.Model).To(Equal("claude-3-5-sonnet-20241022"))
|
||||
// System message must be hoisted out of Messages into top-level field.
|
||||
g.Expect(captured.System).To(Equal("be brief"))
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.MaxTokens).To(Equal(int32(32)))
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
// Anthropic 400s when both temperature and top_p are set; the
|
||||
// translator must prefer temperature and drop top_p.
|
||||
g.Expect(captured.TopP).To(BeNil())
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
}
|
||||
|
||||
// When only top_p is set, it should be forwarded.
|
||||
func TestPredict_Anthropic_TopPOnly(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hello"}},
|
||||
TopP: 0.9,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Temperature).To(BeNil())
|
||||
// PredictOptions.TopP is float32 on the wire; the translator widens
|
||||
// to float64 so 0.9 round-trips as 0.8999999761581421… — compare
|
||||
// with a small tolerance rather than exact equality.
|
||||
g.Expect(captured.TopP).NotTo(BeNil())
|
||||
g.Expect(math.Abs(*captured.TopP - 0.9)).To(BeNumerically("<=", 1e-6))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_DefaultsMaxTokens(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Anthropic 400s without max_tokens. The translator must default
|
||||
// it when the caller doesn't supply Tokens.
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.MaxTokens).To(Equal(anthropicDefaultMaxTokens))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_PromptFallback(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?", Tokens: 16})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.Messages[0].Content).To(Equal("what time is it?"))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_ConcatenatesContentBlocks(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Anthropic may return multiple text blocks; the translator joins
|
||||
// them so the Predict() string return is the full assistant message.
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"hello "},{"type":"text","text":"world"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hello world"))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_UpstreamError(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 401, `{"error":{"type":"authentication_error","message":"bad key"}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("401"))
|
||||
}
|
||||
|
||||
func TestPredictStream_Anthropic_StreamsTextDeltas(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Real Anthropic SSE has event: lines + data: lines. The translator
|
||||
// only needs the data: payload; only content_block_delta with
|
||||
// delta.type=text_delta carries content. message_stop ends.
|
||||
frames := []string{
|
||||
"event: message_start\ndata: {\"type\":\"message_start\"}\n\n",
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" \"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"world\"}}\n\n",
|
||||
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
body := strings.Join(frames, "")
|
||||
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan string, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStream(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
Tokens: 16,
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var got []string
|
||||
for s := range results {
|
||||
got = append(got, s)
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(strings.Join(got, "")).To(Equal("hello world"))
|
||||
g.Expect(captured.Stream).To(BeTrue())
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_TranslatesOpenAITools(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
tools := `[{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}}}]`
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "weather in Paris?"}},
|
||||
Tools: tools,
|
||||
ToolChoice: `"auto"`,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Tools).To(HaveLen(1))
|
||||
g.Expect(captured.Tools[0].Name).To(Equal("get_weather"))
|
||||
g.Expect(captured.Tools[0].Description).To(Equal("Get weather"))
|
||||
// input_schema must be the parameters object verbatim.
|
||||
g.Expect(string(captured.Tools[0].InputSchema)).To(ContainSubstring(`"city"`))
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("auto"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_RequiredMapsToAny(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `"required"`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("any"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_NoneDropsTools(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `"none"`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Tools).To(BeNil())
|
||||
g.Expect(captured.ToolChoice).To(BeNil())
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_NamedFunction(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"weather","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `{"type":"function","function":{"name":"weather"}}`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("tool"))
|
||||
g.Expect(captured.ToolChoice.Name).To(Equal("weather"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_RoundTripsAssistantToolCalls(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// LocalAI Assistant's second turn: the LLM previously emitted a
|
||||
// tool_use, the server executed it, and the conversation now
|
||||
// includes the assistant turn (with tool_calls) plus a tool-role
|
||||
// result message. Both must convert to Anthropic block form.
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
tools := `[{"type":"function","function":{"name":"list_models","parameters":{"type":"object"}}}]`
|
||||
toolCallsJSON := `[{"id":"call_abc","type":"function","function":{"name":"list_models","arguments":"{}"}}]`
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Tools: tools,
|
||||
Messages: []*pb.Message{
|
||||
{Role: "user", Content: "what models are installed?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: toolCallsJSON},
|
||||
{Role: "tool", Content: `{"models":["a","b"]}`, ToolCallId: "call_abc"},
|
||||
},
|
||||
Tokens: 64,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
g.Expect(captured.Messages).To(HaveLen(3))
|
||||
// 1. user text — bare string
|
||||
s, ok := captured.Messages[0].Content.(string)
|
||||
g.Expect(ok).To(BeTrue())
|
||||
g.Expect(s).To(Equal("what models are installed?"))
|
||||
// 2. assistant — must be a content-block list with one tool_use
|
||||
// json.Unmarshal of `any` produces []any not []anthropicContentBlock.
|
||||
blocks, ok := captured.Messages[1].Content.([]any)
|
||||
g.Expect(ok).To(BeTrue())
|
||||
g.Expect(blocks).To(HaveLen(1))
|
||||
b0, _ := blocks[0].(map[string]any)
|
||||
g.Expect(b0["type"]).To(Equal("tool_use"))
|
||||
g.Expect(b0["id"]).To(Equal("call_abc"))
|
||||
g.Expect(b0["name"]).To(Equal("list_models"))
|
||||
// 3. tool → user with tool_result block
|
||||
g.Expect(captured.Messages[2].Role).To(Equal("user"))
|
||||
resBlocks, _ := captured.Messages[2].Content.([]any)
|
||||
r0, _ := resBlocks[0].(map[string]any)
|
||||
g.Expect(r0["type"]).To(Equal("tool_result"))
|
||||
g.Expect(r0["tool_use_id"]).To(Equal("call_abc"))
|
||||
g.Expect(r0["content"]).To(Equal(`{"models":["a","b"]}`))
|
||||
}
|
||||
119
backend/go/cloud-proxy/provider_edge_test.go
Normal file
119
backend/go/cloud-proxy/provider_edge_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Verify buildOpenAIRequest preserves caller-supplied tools and
|
||||
// tool_choice as opaque JSON. PredictOptions carries them as strings;
|
||||
// they must land in the outbound request body unchanged so the
|
||||
// upstream sees the caller's intent verbatim. A regression here would
|
||||
// silently disable function calling for translate-mode clients.
|
||||
func TestBuildOpenAIRequest_ToolsAndToolChoicePassthrough(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cfg := &proxyConfig{upstreamModel: "gpt-4o"}
|
||||
toolsJSON := `[{"type":"function","function":{"name":"search","parameters":{"type":"object"}}}]`
|
||||
choiceJSON := `{"type":"function","function":{"name":"search"}}`
|
||||
|
||||
body, err := buildOpenAIRequest(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "find x"}},
|
||||
Tools: toolsJSON,
|
||||
ToolChoice: choiceJSON,
|
||||
}, cfg, false)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
var decoded openAIRequest
|
||||
err = json.Unmarshal(body, &decoded)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
// Compare the JSON-canonical form so whitespace differences are ignored.
|
||||
gotTools, _ := json.Marshal(json.RawMessage(decoded.Tools))
|
||||
wantTools, _ := json.Marshal(json.RawMessage(toolsJSON))
|
||||
g.Expect(string(gotTools)).To(Equal(string(wantTools)))
|
||||
gotChoice, _ := json.Marshal(json.RawMessage(decoded.ToolChoice))
|
||||
wantChoice, _ := json.Marshal(json.RawMessage(choiceJSON))
|
||||
g.Expect(string(gotChoice)).To(Equal(string(wantChoice)))
|
||||
}
|
||||
|
||||
// Garbage JSON in tools / tool_choice is silently dropped (omitted)
|
||||
// rather than blowing up the request. Documents the parseRawJSON
|
||||
// behaviour — operators shouldn't see hard failures from an upstream
|
||||
// caller's mis-formatted tools field.
|
||||
func TestBuildOpenAIRequest_InvalidToolsJSONDropped(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cfg := &proxyConfig{upstreamModel: "gpt-4o"}
|
||||
body, err := buildOpenAIRequest(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: "this is not json",
|
||||
ToolChoice: "{also bad",
|
||||
}, cfg, false)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(body)).NotTo(ContainSubstring("this is not json"))
|
||||
g.Expect(string(body)).NotTo(ContainSubstring("{also bad"))
|
||||
}
|
||||
|
||||
// Anthropic empty content array yields an empty Reply (not an error).
|
||||
// Mirrors how an upstream tool_use-only response might arrive — the
|
||||
// content array can legitimately be empty in some edge cases.
|
||||
func TestPredictRich_Anthropic_EmptyContent(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"id":"m1","type":"message","role":"assistant","content":[],"usage":{"input_tokens":3,"output_tokens":0}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal(""))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(0))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(3)))
|
||||
}
|
||||
|
||||
// A truncated / malformed SSE payload mid-stream should be tolerated:
|
||||
// the malformed chunk gets skipped (xlog.Debug logged), valid chunks
|
||||
// before AND after it still reach the channel.
|
||||
func TestPredictStreamRich_OpenAI_TolerantOfBadChunks(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
body := strings.Join([]string{
|
||||
`data: {"choices":[{"index":0,"delta":{"content":"hello"}}]}`,
|
||||
``,
|
||||
`data: this-is-not-json{{`,
|
||||
``,
|
||||
`data: {"choices":[{"index":0,"delta":{"content":" world"}}]}`,
|
||||
``,
|
||||
`data: [DONE]`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var assembled strings.Builder
|
||||
for reply := range results {
|
||||
assembled.Write(reply.GetMessage())
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
// The good chunks before and after the malformed one both made it through.
|
||||
g.Expect(assembled.String()).To(Equal("hello world"))
|
||||
}
|
||||
320
backend/go/cloud-proxy/provider_openai.go
Normal file
320
backend/go/cloud-proxy/provider_openai.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// OpenAI Chat Completions wire-format types. Narrowed to the fields
|
||||
// translate mode needs to preserve through the Reply proto: content,
|
||||
// role, tool_calls (typed so we can map them to pb.ToolCallDelta),
|
||||
// and sampling params copied verbatim from PredictOptions.
|
||||
//
|
||||
// Provider-specific extensions (logit_bias, function calling beyond
|
||||
// tool_calls, etc.) are not modelled — passthrough mode covers callers
|
||||
// that need full upstream fidelity.
|
||||
|
||||
type openAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAIMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
MaxTokens *int32 `json:"max_tokens,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
type openAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// openAIToolCall covers both the non-streaming response shape (full
|
||||
// id+function+arguments) and the streaming-delta shape (sparse fields,
|
||||
// index assignment). The proto's ToolCallDelta absorbs both — name is
|
||||
// set on first appearance, arguments arrive incrementally in streaming.
|
||||
type openAIToolCall struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function openAIFunctionCall `json:"function,omitempty"`
|
||||
}
|
||||
|
||||
type openAIFunctionCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type openAIChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message openAIMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type openAIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Choices []openAIChoice `json:"choices"`
|
||||
Usage *openAIUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type openAIStreamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
|
||||
} `json:"delta"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
type openAIStreamChunk struct {
|
||||
Choices []openAIStreamChoice `json:"choices"`
|
||||
Usage *openAIUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type openAIUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// buildOpenAIRequest converts pb.PredictOptions into the OpenAI Chat
|
||||
// Completions request body. Prefers Messages when non-empty; falls
|
||||
// back to wrapping Prompt as a single user message so plain
|
||||
// /completions-style calls still work in translate mode.
|
||||
func buildOpenAIRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) {
|
||||
req := openAIRequest{
|
||||
Model: modelName(cfg, opts),
|
||||
Stream: stream,
|
||||
Stop: opts.GetStopPrompts(),
|
||||
Tools: parseRawJSON(opts.GetTools()),
|
||||
ToolChoice: parseRawJSON(opts.GetToolChoice()),
|
||||
}
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
}
|
||||
if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
if n := opts.GetTokens(); n > 0 {
|
||||
req.MaxTokens = &n
|
||||
}
|
||||
if p := opts.GetFrequencyPenalty(); p != 0 {
|
||||
v := float64(p)
|
||||
req.FrequencyPenalty = &v
|
||||
}
|
||||
if p := opts.GetPresencePenalty(); p != 0 {
|
||||
v := float64(p)
|
||||
req.PresencePenalty = &v
|
||||
}
|
||||
|
||||
for _, m := range opts.GetMessages() {
|
||||
msg := openAIMessage{
|
||||
Role: m.GetRole(),
|
||||
Content: m.GetContent(),
|
||||
Name: m.GetName(),
|
||||
ToolCallID: m.GetToolCallId(),
|
||||
}
|
||||
// Pre-existing tool_calls arrive as a JSON string from the
|
||||
// upstream caller's previous assistant turn; pass-through as-is.
|
||||
if tc := m.GetToolCalls(); tc != "" {
|
||||
_ = json.Unmarshal([]byte(tc), &msg.ToolCalls)
|
||||
}
|
||||
req.Messages = append(req.Messages, msg)
|
||||
}
|
||||
// Fallback for plain Prompt requests (no Messages array). LocalAI
|
||||
// templating may have produced a flat prompt; rewrap as a single
|
||||
// user message so the upstream chat endpoint accepts it.
|
||||
if len(req.Messages) == 0 && opts.GetPrompt() != "" {
|
||||
req.Messages = []openAIMessage{{Role: "user", Content: opts.GetPrompt()}}
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
// modelName picks the upstream model: upstream_model from the proxy
|
||||
// config wins (operator override), else the local model name captured
|
||||
// at LoadModel time. Operator sets upstream_model to map LocalAI's
|
||||
// alias (e.g. "claude-strict") to the upstream's canonical name
|
||||
// (e.g. "claude-3-5-sonnet-20241022").
|
||||
func modelName(cfg *proxyConfig, _ *pb.PredictOptions) string {
|
||||
if cfg.upstreamModel != "" {
|
||||
return cfg.upstreamModel
|
||||
}
|
||||
return cfg.localModel
|
||||
}
|
||||
|
||||
// parseRawJSON parses a JSON string into a RawMessage so it round-trips
|
||||
// into the upstream body. Returns nil for empty/invalid input so the
|
||||
// field is omitted (omitempty).
|
||||
func parseRawJSON(s string) json.RawMessage {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
var probe json.RawMessage
|
||||
if err := json.Unmarshal([]byte(s), &probe); err != nil {
|
||||
return nil
|
||||
}
|
||||
return probe
|
||||
}
|
||||
|
||||
// doOpenAIRequest builds + sends the upstream request. Returns the
|
||||
// raw response on success; caller handles status / body.
|
||||
func (c *CloudProxy) doOpenAIRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// predictOpenAIRich is the non-streaming translate path. Returns a
|
||||
// fully-populated *pb.Reply with assistant content, tool calls, and
|
||||
// token usage. The gRPC server forwards the Reply verbatim.
|
||||
func (c *CloudProxy) predictOpenAIRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
body, err := buildOpenAIRequest(opts, cfg, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doOpenAIRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var parsed openAIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: decode response: %w", err)
|
||||
}
|
||||
if len(parsed.Choices) == 0 {
|
||||
return nil, errors.New("cloud-proxy: upstream returned no choices")
|
||||
}
|
||||
|
||||
choice := parsed.Choices[0]
|
||||
reply := &pb.Reply{
|
||||
Message: []byte(choice.Message.Content),
|
||||
}
|
||||
if parsed.Usage != nil {
|
||||
reply.PromptTokens = int32(parsed.Usage.PromptTokens)
|
||||
reply.Tokens = int32(parsed.Usage.CompletionTokens)
|
||||
}
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
// Non-streaming: a single ChatDelta carries the full tool-call
|
||||
// set. Index/Name/Arguments are populated together; downstream
|
||||
// consumers don't need to assemble streaming deltas.
|
||||
delta := &pb.ChatDelta{}
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
delta.ToolCalls = append(delta.ToolCalls,
|
||||
newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments))
|
||||
}
|
||||
reply.ChatDeltas = []*pb.ChatDelta{delta}
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// predictOpenAIStreamRich streams *pb.Reply chunks. Each chunk carries
|
||||
// either a content delta (Message + ChatDeltas[].Content) or tool-call
|
||||
// deltas (ChatDeltas[].ToolCalls). The final Reply carries usage tokens
|
||||
// when the upstream sends them (stream_options.include_usage).
|
||||
func (c *CloudProxy) predictOpenAIStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
body, err := buildOpenAIRequest(opts, cfg, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doOpenAIRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
return nil
|
||||
}
|
||||
var chunk openAIStreamChunk
|
||||
if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
|
||||
xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err)
|
||||
continue
|
||||
}
|
||||
// Usage frames may arrive separately from content frames when
|
||||
// stream_options.include_usage is set; emit a usage-only Reply
|
||||
// in that case so the consumer sees the totals.
|
||||
if chunk.Usage != nil && len(chunk.Choices) == 0 {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
PromptTokens: int32(chunk.Usage.PromptTokens),
|
||||
Tokens: int32(chunk.Usage.CompletionTokens),
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, ch := range chunk.Choices {
|
||||
reply := &pb.Reply{}
|
||||
if ch.Delta.Content != "" {
|
||||
reply.Message = []byte(ch.Delta.Content)
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{Content: ch.Delta.Content}}
|
||||
}
|
||||
if len(ch.Delta.ToolCalls) > 0 {
|
||||
if len(reply.ChatDeltas) == 0 {
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{}}
|
||||
}
|
||||
for _, tc := range ch.Delta.ToolCalls {
|
||||
reply.ChatDeltas[0].ToolCalls = append(reply.ChatDeltas[0].ToolCalls,
|
||||
newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments))
|
||||
}
|
||||
}
|
||||
if reply.Message == nil && len(reply.ChatDeltas) == 0 {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, reply) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
170
backend/go/cloud-proxy/provider_openai_test.go
Normal file
170
backend/go/cloud-proxy/provider_openai_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// fakeOpenAIUpstream returns an httptest.Server that decodes the
|
||||
// inbound request as an openAIRequest, calls handler with it, and
|
||||
// writes the handler's reply as the response.
|
||||
func fakeOpenAIUpstream(t *testing.T, handler func(req openAIRequest) (status int, body string, contentType string)) (*httptest.Server, *openAIRequest) {
|
||||
t.Helper()
|
||||
var captured openAIRequest
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(raw, &captured)
|
||||
status, body, ct := handler(captured)
|
||||
w.Header().Set("Content-Type", ct)
|
||||
w.WriteHeader(status)
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
return srv, &captured
|
||||
}
|
||||
|
||||
func newTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy {
|
||||
t.Helper()
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_OPENAI_FAKE", "sk-fake-openai")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Model: "gpt-4o-local",
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstreamURL,
|
||||
Mode: modeTranslate,
|
||||
Provider: providerOpenAI,
|
||||
ApiKeyEnv: "CLOUD_PROXY_OPENAI_FAKE",
|
||||
UpstreamModel: "gpt-4o",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
return cp
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_BasicChat(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"id":"resp-1","choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{
|
||||
{Role: "system", Content: "be brief"},
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
Temperature: 0.5,
|
||||
TopP: 0.9,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hi there"))
|
||||
|
||||
// Verify the upstream saw a properly-translated request.
|
||||
g.Expect(captured.Model).To(Equal("gpt-4o"))
|
||||
g.Expect(captured.Messages).To(HaveLen(2))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("system"))
|
||||
g.Expect(captured.Messages[1].Role).To(Equal("user"))
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
g.Expect(captured.MaxTokens).NotTo(BeNil())
|
||||
g.Expect(*captured.MaxTokens).To(Equal(int32(32)))
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_PromptFallback(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// No Messages array — backend should synth a single user message
|
||||
// from Prompt so non-chat clients still route through translate.
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"choices":[{"message":{"role":"assistant","content":"ok"}}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?"})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.Messages[0].Content).To(Equal("what time is it?"))
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_UpstreamError(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 401, `{"error":{"message":"bad key"}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("401"))
|
||||
}
|
||||
|
||||
func TestPredictStream_OpenAI_StreamsContent(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Stream three content deltas then [DONE]. Verify the channel
|
||||
// receives them in order with no missing pieces.
|
||||
chunks := []string{
|
||||
`{"choices":[{"index":0,"delta":{"role":"assistant"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"hello"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":" "}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"world"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||
}
|
||||
body := ""
|
||||
for _, c := range chunks {
|
||||
body += "data: " + c + "\n\n"
|
||||
}
|
||||
body += "data: [DONE]\n\n"
|
||||
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan string, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStream(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var got []string
|
||||
for s := range results {
|
||||
got = append(got, s)
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(strings.Join(got, "")).To(Equal("hello world"))
|
||||
g.Expect(captured.Stream).To(BeTrue())
|
||||
}
|
||||
|
||||
func TestPredict_RejectedInPassthroughMode(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_FAKE", "k")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_FAKE",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_, err = cp.Predict(&pb.PredictOptions{})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("only valid in translate"))
|
||||
}
|
||||
429
backend/go/cloud-proxy/proxy.go
Normal file
429
backend/go/cloud-proxy/proxy.go
Normal file
@@ -0,0 +1,429 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// Mirror of core/config.Proxy{Mode,Provider}* — backends don't
|
||||
// import core to keep the boundary clean.
|
||||
const (
|
||||
modePassthrough = "passthrough"
|
||||
modeTranslate = "translate"
|
||||
|
||||
providerOpenAI = "openai"
|
||||
providerAnthropic = "anthropic"
|
||||
)
|
||||
|
||||
// CloudProxy is the LocalAI backend that proxies model traffic to a
|
||||
// configured upstream HTTP provider. Concurrency: base.SingleThread is
|
||||
// NOT embedded — forward calls are independent and HTTP transport is
|
||||
// goroutine-safe, so multiple Forward streams can run in parallel.
|
||||
// Locking would serialise requests to a chat provider for no benefit.
|
||||
type CloudProxy struct {
|
||||
base.Base
|
||||
|
||||
cfg atomic.Pointer[proxyConfig]
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type proxyConfig struct {
|
||||
upstreamURL string
|
||||
mode string
|
||||
provider string
|
||||
upstreamModel string
|
||||
localModel string // ModelOptions.Model — fallback when upstream_model is unset
|
||||
apiKey string // resolved at Load time
|
||||
}
|
||||
|
||||
func NewCloudProxy() *CloudProxy {
|
||||
// No Client-level Timeout — that would bound streaming SSE
|
||||
// responses too, which can legitimately last minutes. Per-request
|
||||
// deadlines come from the gRPC stream context.
|
||||
return &CloudProxy{client: &http.Client{}}
|
||||
}
|
||||
|
||||
func (c *CloudProxy) Load(opts *pb.ModelOptions) error {
|
||||
po := opts.GetProxy()
|
||||
if po == nil {
|
||||
return errors.New("cloud-proxy: Load requires ProxyOptions to be set")
|
||||
}
|
||||
if po.GetUpstreamUrl() == "" {
|
||||
return errors.New("cloud-proxy: upstream_url is required")
|
||||
}
|
||||
if _, err := url.ParseRequestURI(po.GetUpstreamUrl()); err != nil {
|
||||
return fmt.Errorf("cloud-proxy: upstream_url %q invalid: %w", po.GetUpstreamUrl(), err)
|
||||
}
|
||||
|
||||
mode := po.GetMode()
|
||||
if mode == "" {
|
||||
mode = modePassthrough
|
||||
}
|
||||
switch mode {
|
||||
case modePassthrough:
|
||||
case modeTranslate:
|
||||
switch po.GetProvider() {
|
||||
case providerOpenAI:
|
||||
// implemented in provider_openai.go
|
||||
case providerAnthropic:
|
||||
// implemented in provider_anthropic.go
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: translate mode requires provider in {%s, %s}, got %q",
|
||||
providerOpenAI, providerAnthropic, po.GetProvider())
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: unknown mode %q", mode)
|
||||
}
|
||||
|
||||
key, err := resolveAPIKey(po.GetApiKeyEnv(), po.GetApiKeyFile())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.cfg.Store(&proxyConfig{
|
||||
upstreamURL: po.GetUpstreamUrl(),
|
||||
mode: mode,
|
||||
provider: po.GetProvider(),
|
||||
upstreamModel: po.GetUpstreamModel(),
|
||||
localModel: opts.GetModel(),
|
||||
apiKey: key,
|
||||
})
|
||||
xlog.Info("cloud-proxy: ready",
|
||||
"upstream", po.GetUpstreamUrl(),
|
||||
"mode", mode,
|
||||
"provider", po.GetProvider(),
|
||||
"has_key", key != "")
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveAPIKey mirrors config.ProxyConfig.ResolveAPIKey. Duplicated
|
||||
// (a few lines) rather than importing core/config from a backend
|
||||
// binary — keeps backends independent of core's package layout.
|
||||
// Mutual-exclusion is enforced upstream in core/config.Validate.
|
||||
func resolveAPIKey(envName, filePath string) (string, error) {
|
||||
if envName != "" {
|
||||
v := os.Getenv(envName)
|
||||
if v == "" {
|
||||
return "", fmt.Errorf("cloud-proxy: api_key_env %q is unset", envName)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
if filePath != "" {
|
||||
b, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cloud-proxy: read api_key_file %q: %w", filePath, err)
|
||||
}
|
||||
return strings.TrimSpace(string(b)), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// PredictRich is the non-streaming translate path. Returns a fully-
|
||||
// populated *pb.Reply: content, tool-call deltas (ChatDeltas), and
|
||||
// usage tokens. Implements the optional grpc.AIModelRich interface;
|
||||
// the gRPC server prefers this path over Predict when present so
|
||||
// tool calls survive the round-trip. Passthrough mode rejects
|
||||
// PredictRich — callers must use Forward.
|
||||
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cloud-proxy: model not loaded")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
|
||||
}
|
||||
xlog.Info("cloud-proxy: predict", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: predict failed", "provider", cfg.provider, "error", err)
|
||||
}
|
||||
}()
|
||||
ctx := context.Background()
|
||||
switch cfg.provider {
|
||||
case providerOpenAI:
|
||||
return c.predictOpenAIRich(ctx, cfg, opts)
|
||||
case providerAnthropic:
|
||||
return c.predictAnthropicRich(ctx, cfg, opts)
|
||||
default:
|
||||
return nil, fmt.Errorf("cloud-proxy: predict not implemented for provider %q", cfg.provider)
|
||||
}
|
||||
}
|
||||
|
||||
// PredictStreamRich is the rich streaming counterpart of PredictRich.
|
||||
// Each emitted Reply carries either a content delta, tool-call deltas,
|
||||
// or usage tokens (the final upstream frame). base.Base.PredictStream
|
||||
// is bypassed when AIModelRich is implemented, so the channel is
|
||||
// closed by the gRPC server pump.
|
||||
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
|
||||
}
|
||||
xlog.Info("cloud-proxy: predict-stream", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: predict-stream failed", "provider", cfg.provider, "error", err)
|
||||
}
|
||||
}()
|
||||
ctx := context.Background()
|
||||
switch cfg.provider {
|
||||
case providerOpenAI:
|
||||
return c.predictOpenAIStreamRich(ctx, cfg, opts, results)
|
||||
case providerAnthropic:
|
||||
return c.predictAnthropicStreamRich(ctx, cfg, opts, results)
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: predictStream not implemented for provider %q", cfg.provider)
|
||||
}
|
||||
}
|
||||
|
||||
// Predict is the legacy (string, error) AIModel signature. Used only
|
||||
// if a caller goes through the non-rich path (it shouldn't, since
|
||||
// server.go prefers PredictRich). Provided so the AIModel interface
|
||||
// is satisfied for backends that haven't opted into the rich variant.
|
||||
func (c *CloudProxy) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
reply, err := c.PredictRich(opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(reply.GetMessage()), nil
|
||||
}
|
||||
|
||||
// PredictStream is the legacy chan-string streaming path. Adapts the
|
||||
// rich stream by extracting only content text — tool-call-only chunks
|
||||
// (no Message bytes) and usage-only chunks are silently dropped, since
|
||||
// the legacy chan-string contract cannot represent them. Consumers
|
||||
// that need tool calls must call PredictStreamRich directly.
|
||||
func (c *CloudProxy) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
defer close(results)
|
||||
richCh := make(chan *pb.Reply)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- c.PredictStreamRich(opts, richCh)
|
||||
close(richCh)
|
||||
}()
|
||||
for reply := range richCh {
|
||||
if msg := reply.GetMessage(); len(msg) > 0 {
|
||||
results <- string(msg)
|
||||
}
|
||||
}
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
// sendReply pushes one Reply onto a stream channel honouring ctx
|
||||
// cancellation. Returns false on cancel so the caller can exit with
|
||||
// ctx.Err(). Used by both translate-mode providers.
|
||||
func sendReply(ctx context.Context, results chan<- *pb.Reply, reply *pb.Reply) bool {
|
||||
select {
|
||||
case results <- reply:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// newToolCallDelta is a small constructor for the cross-provider
|
||||
// tool-call delta shape. Centralised so the int32 cast and the four
|
||||
// fields stay consistent across the OpenAI / Anthropic translators.
|
||||
// Empty name/args are valid — Anthropic streaming announces the call
|
||||
// with id+name then sends arguments incrementally; OpenAI's reverse
|
||||
// pattern (args without name) also lands here.
|
||||
func newToolCallDelta(index int, id, name, args string) *pb.ToolCallDelta {
|
||||
return &pb.ToolCallDelta{
|
||||
Index: int32(index),
|
||||
Id: id,
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward shovels bytes between a Forward gRPC stream and an upstream
|
||||
// HTTP request. First request message carries path/method/headers and
|
||||
// the initial body chunk; subsequent messages append body chunks. The
|
||||
// first reply carries upstream status + response headers; subsequent
|
||||
// replies stream body chunks until the upstream connection closes.
|
||||
// Cancellation of ctx (the gRPC stream context) closes the upstream
|
||||
// connection.
|
||||
func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest, out chan<- *pb.ForwardReply) error {
|
||||
defer close(out)
|
||||
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
}
|
||||
if cfg.mode != modePassthrough {
|
||||
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
|
||||
}
|
||||
|
||||
first, ok := <-in
|
||||
if !ok {
|
||||
return errors.New("cloud-proxy: Forward stream closed before first request")
|
||||
}
|
||||
|
||||
// Honour the per-request path only when the configured upstream_url
|
||||
// has no path of its own — gallery convention is to put the
|
||||
// canonical path in upstream_url.
|
||||
fullURL, err := composeURL(cfg.upstreamURL, first.GetPath())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
method := first.GetMethod()
|
||||
if method == "" {
|
||||
method = http.MethodPost
|
||||
}
|
||||
|
||||
// Pipe the body in from the gRPC stream so the HTTP request can
|
||||
// start before the client finishes sending. The pipe-reader is
|
||||
// closed via CloseWithError on the error paths so the writer
|
||||
// goroutine doesn't block forever.
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
go func() {
|
||||
var writeErr error
|
||||
defer func() { _ = pw.CloseWithError(writeErr) }()
|
||||
if len(first.GetBodyChunk()) > 0 {
|
||||
if _, writeErr = pw.Write(first.GetBodyChunk()); writeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
for req := range in {
|
||||
if len(req.GetBodyChunk()) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, writeErr = pw.Write(req.GetBodyChunk()); writeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, fullURL, pr)
|
||||
if err != nil {
|
||||
_ = pr.CloseWithError(err) // unblocks the body-pump's pw.Write
|
||||
return fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
|
||||
// Apply caller-supplied headers, then override with the
|
||||
// authorization header derived from the resolved key. Caller-
|
||||
// supplied Authorization is always replaced — operators may not
|
||||
// know the backend's auth scheme, and silently leaking through a
|
||||
// client Authorization header to a different upstream would
|
||||
// confuse the upstream and could leak credentials.
|
||||
for _, h := range first.GetHeaders() {
|
||||
if h == nil || h.GetName() == "" {
|
||||
continue
|
||||
}
|
||||
// Strip hop-by-hop headers that aren't meaningful to the
|
||||
// upstream (Host is set by the http client from the URL;
|
||||
// Content-Length is computed from the body).
|
||||
if isHopByHopHeader(h.GetName()) {
|
||||
continue
|
||||
}
|
||||
req.Header.Add(h.GetName(), h.GetValue())
|
||||
}
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
|
||||
xlog.Info("cloud-proxy: forward", "method", method, "url", fullURL, "provider", cfg.provider)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: forward upstream failed", "url", fullURL, "error", err)
|
||||
return fmt.Errorf("cloud-proxy: upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
logFn := xlog.Info
|
||||
if resp.StatusCode >= 400 {
|
||||
logFn = xlog.Warn
|
||||
}
|
||||
logFn("cloud-proxy: forward response", "url", fullURL, "status", resp.StatusCode)
|
||||
|
||||
// First reply: status + response headers, no body.
|
||||
headers := make([]*pb.ForwardHeader, 0, len(resp.Header))
|
||||
for k, vs := range resp.Header {
|
||||
for _, v := range vs {
|
||||
headers = append(headers, &pb.ForwardHeader{Name: k, Value: v})
|
||||
}
|
||||
}
|
||||
out <- &pb.ForwardReply{Status: int32(resp.StatusCode), Headers: headers}
|
||||
|
||||
// Subsequent replies: body chunks. Use a fixed 8KB buffer — small
|
||||
// enough that SSE token frames flush promptly, large enough that
|
||||
// long chunked-transfer bodies aren't death by a thousand reads.
|
||||
buf := make([]byte, 8*1024)
|
||||
for {
|
||||
n, rerr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
out <- &pb.ForwardReply{BodyChunk: chunk}
|
||||
}
|
||||
if rerr != nil {
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("cloud-proxy: upstream body read: %w", rerr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// composeURL combines the configured upstream URL with the per-request
|
||||
// path. The upstream URL typically already includes the canonical path
|
||||
// (e.g. https://api.openai.com/v1/chat/completions) so the per-request
|
||||
// path is ignored in that case. When upstream_url is a bare host
|
||||
// (https://api.openai.com), the request path is appended.
|
||||
func composeURL(upstream, reqPath string) (string, error) {
|
||||
u, err := url.Parse(upstream)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cloud-proxy: parse upstream_url %q: %w", upstream, err)
|
||||
}
|
||||
if u.Path == "" || u.Path == "/" {
|
||||
u.Path = reqPath
|
||||
}
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// applyAuthHeader writes the appropriate authorization header for the
|
||||
// provider. OpenAI/Anthropic/most providers use Bearer; Anthropic
|
||||
// historically uses x-api-key + anthropic-version, but accepts Bearer
|
||||
// too via the OpenAI-compatible path. Default to Bearer when provider
|
||||
// is empty (passthrough mode where the operator doesn't claim a
|
||||
// provider).
|
||||
func applyAuthHeader(req *http.Request, provider, key string) {
|
||||
switch provider {
|
||||
case providerAnthropic:
|
||||
req.Header.Set("x-api-key", key)
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
default:
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
}
|
||||
|
||||
// isHopByHopHeader returns true for headers that should not be
|
||||
// forwarded from the client request to the upstream (RFC 7230 §6.1
|
||||
// hop-by-hop list, plus a few that the http.Client sets itself).
|
||||
func isHopByHopHeader(name string) bool {
|
||||
switch strings.ToLower(name) {
|
||||
case "connection", "proxy-connection", "keep-alive", "transfer-encoding",
|
||||
"te", "trailer", "upgrade", "host", "content-length":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
206
backend/go/cloud-proxy/proxy_test.go
Normal file
206
backend/go/cloud-proxy/proxy_test.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// helper: run a CloudProxy in-process via grpc.Provide so tests can
|
||||
// call Forward through the public Backend interface without listening
|
||||
// on a real socket.
|
||||
func newInProcClient(t *testing.T, proxy *CloudProxy) grpc.Backend {
|
||||
t.Helper()
|
||||
addr := "test://" + t.Name()
|
||||
grpc.Provide(addr, proxy)
|
||||
return grpc.NewClient(addr, true, nil, false)
|
||||
}
|
||||
|
||||
func TestForward_PassthroughEcho(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Fake upstream: echoes the request body back, prefixed with a
|
||||
// canary so the test can assert both that the body reached the
|
||||
// upstream and the response made it back to the client.
|
||||
gotBody := make(chan string, 1)
|
||||
gotAuth := make(chan string, 1)
|
||||
gotPath := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
gotBody <- string(body)
|
||||
gotAuth <- r.Header.Get("Authorization")
|
||||
gotPath <- r.URL.Path
|
||||
w.Header().Set("X-Echo", "true")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("echo: " + string(body)))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
t.Setenv("CLOUD_PROXY_FAKE_KEY", "sk-fake")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_FAKE_KEY",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{{Name: "Content-Type", Value: "application/json"}},
|
||||
BodyChunk: []byte(`{"prompt":`),
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.Send(&pb.ForwardRequest{BodyChunk: []byte(`"hi"}`)})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.CloseSend()
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// First reply: status + headers.
|
||||
first, err := stream.Recv()
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(first.Status).To(Equal(int32(http.StatusOK)))
|
||||
g.Expect(hasHeader(first.Headers, "X-Echo", "true")).To(BeTrue())
|
||||
|
||||
// Subsequent replies: body.
|
||||
var body []byte
|
||||
for {
|
||||
r, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
body = append(body, r.BodyChunk...)
|
||||
}
|
||||
g.Expect(string(body)).To(Equal(`echo: {"prompt":"hi"}`))
|
||||
|
||||
// Upstream observations.
|
||||
var gotBodyVal, gotAuthVal, gotPathVal string
|
||||
g.Eventually(gotBody).Should(Receive(&gotBodyVal), "upstream never saw body")
|
||||
g.Expect(gotBodyVal).To(Equal(`{"prompt":"hi"}`))
|
||||
g.Eventually(gotAuth).Should(Receive(&gotAuthVal), "upstream never saw auth header")
|
||||
g.Expect(gotAuthVal).To(Equal("Bearer sk-fake"))
|
||||
g.Eventually(gotPath).Should(Receive(&gotPathVal), "upstream never saw path")
|
||||
g.Expect(gotPathVal).To(Equal("/v1/chat/completions"))
|
||||
}
|
||||
|
||||
func TestForward_AnthropicAuthHeader(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
gotXAPIKey := make(chan string, 1)
|
||||
gotVersion := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotXAPIKey <- r.Header.Get("x-api-key")
|
||||
gotVersion <- r.Header.Get("anthropic-version")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
t.Setenv("CLOUD_PROXY_ANTHROPIC_KEY", "sk-ant-fake")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_KEY",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.Send(&pb.ForwardRequest{Path: "/v1/messages", Method: "POST"})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_ = stream.CloseSend()
|
||||
_, _ = stream.Recv() // drain status
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
g.Expect(<-gotXAPIKey).To(Equal("sk-ant-fake"))
|
||||
g.Expect(<-gotVersion).NotTo(BeEmpty())
|
||||
}
|
||||
|
||||
func TestLoad_ValidatesConfig(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cp := NewCloudProxy()
|
||||
|
||||
err := cp.Load(&pb.ModelOptions{})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("ProxyOptions"))
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("upstream_url"))
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
Mode: "rewrite",
|
||||
}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("unknown mode"))
|
||||
|
||||
// translate + openai should load successfully (Phase 5).
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com/v1/chat/completions",
|
||||
Mode: modeTranslate,
|
||||
Provider: providerOpenAI,
|
||||
}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// translate + anthropic should load successfully (Phase 6).
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com/v1/messages",
|
||||
Mode: modeTranslate,
|
||||
Provider: providerAnthropic,
|
||||
}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
ApiKeyEnv: "DEFINITELY_UNSET_ENV_VAR_XYZ",
|
||||
}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("unset"))
|
||||
}
|
||||
|
||||
func TestForward_RejectsWithoutLoad(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cp := NewCloudProxy()
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_ = stream.CloseSend()
|
||||
_, err = stream.Recv()
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("not loaded"))
|
||||
}
|
||||
|
||||
func hasHeader(hs []*pb.ForwardHeader, name, value string) bool {
|
||||
for _, h := range hs {
|
||||
if strings.EqualFold(h.GetName(), name) && h.GetValue() == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
6
backend/go/cloud-proxy/run.sh
Executable file
6
backend/go/cloud-proxy/run.sh
Executable file
@@ -0,0 +1,6 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
exec $CURDIR/cloud-proxy "$@"
|
||||
232
backend/go/cloud-proxy/toolcalls_test.go
Normal file
232
backend/go/cloud-proxy/toolcalls_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// OpenAI: non-streaming tool call response. Verify the response is
|
||||
// mapped to Reply.ChatDeltas[].ToolCalls with id/name/arguments intact,
|
||||
// and usage tokens land on Reply.PromptTokens / Reply.Tokens.
|
||||
func TestPredictRich_OpenAI_ToolCalls(t *testing.T) {
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{
|
||||
"id":"resp-1",
|
||||
"choices":[{
|
||||
"index":0,
|
||||
"message":{
|
||||
"role":"assistant",
|
||||
"content":"",
|
||||
"tool_calls":[
|
||||
{"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"SF\"}"}},
|
||||
{"id":"call_def","type":"function","function":{"name":"get_time","arguments":"{\"tz\":\"PT\"}"}}
|
||||
]
|
||||
},
|
||||
"finish_reason":"tool_calls"
|
||||
}],
|
||||
"usage":{"prompt_tokens":42,"completion_tokens":18,"total_tokens":60}
|
||||
}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal(""))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(42)))
|
||||
g.Expect(reply.GetTokens()).To(Equal(int32(18)))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(1))
|
||||
tcs := reply.GetChatDeltas()[0].GetToolCalls()
|
||||
g.Expect(tcs).To(HaveLen(2))
|
||||
g.Expect(tcs[0].GetId()).To(Equal("call_abc"))
|
||||
g.Expect(tcs[0].GetName()).To(Equal("get_weather"))
|
||||
g.Expect(tcs[0].GetArguments()).To(ContainSubstring(`"location":"SF"`))
|
||||
g.Expect(tcs[1].GetId()).To(Equal("call_def"))
|
||||
g.Expect(tcs[1].GetName()).To(Equal("get_time"))
|
||||
}
|
||||
|
||||
// OpenAI: streaming tool call. Arguments arrive as a sequence of
|
||||
// delta chunks; the consumer is expected to concatenate by tool index.
|
||||
// Verify each chunk reaches the channel and the assembled arguments
|
||||
// match the input.
|
||||
func TestPredictStreamRich_OpenAI_ToolCallDeltas(t *testing.T) {
|
||||
chunks := []string{
|
||||
// Frame 0: announce the tool call (id + name, no args yet).
|
||||
`{"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_xyz","type":"function","function":{"name":"search"}}]}}]}`,
|
||||
// Frames 1-3: arguments arrive in fragments.
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"q\":"}}]}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"clo"}}]}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"uds\"}"}}]}}]}`,
|
||||
// Stop frame.
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
}
|
||||
body := ""
|
||||
for _, c := range chunks {
|
||||
body += "data: " + c + "\n\n"
|
||||
}
|
||||
body += "data: [DONE]\n\n"
|
||||
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 16)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "find something"}},
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var (
|
||||
toolName string
|
||||
toolID string
|
||||
toolIndex int32 = -1
|
||||
argsBuf strings.Builder
|
||||
)
|
||||
for reply := range results {
|
||||
for _, cd := range reply.GetChatDeltas() {
|
||||
for _, tc := range cd.GetToolCalls() {
|
||||
if tc.GetName() != "" {
|
||||
toolName = tc.GetName()
|
||||
}
|
||||
if tc.GetId() != "" {
|
||||
toolID = tc.GetId()
|
||||
}
|
||||
if toolIndex == -1 {
|
||||
toolIndex = tc.GetIndex()
|
||||
}
|
||||
argsBuf.WriteString(tc.GetArguments())
|
||||
}
|
||||
}
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(toolID).To(Equal("call_xyz"))
|
||||
g.Expect(toolName).To(Equal("search"))
|
||||
g.Expect(toolIndex).To(Equal(int32(0)))
|
||||
g.Expect(argsBuf.String()).To(Equal(`{"q":"clouds"}`))
|
||||
}
|
||||
|
||||
// Anthropic: non-streaming tool_use block. The block appears in
|
||||
// Content[] alongside text blocks; the input field is a structured
|
||||
// JSON object. Map to ToolCallDelta with arguments as serialised JSON
|
||||
// so downstream OpenAI-shaped consumers see a familiar format.
|
||||
func TestPredictRich_Anthropic_ToolUse(t *testing.T) {
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{
|
||||
"id":"msg_1","type":"message","role":"assistant",
|
||||
"content":[
|
||||
{"type":"text","text":"Let me check that."},
|
||||
{"type":"tool_use","id":"toolu_01","name":"weather","input":{"location":"SF"}}
|
||||
],
|
||||
"model":"claude","usage":{"input_tokens":12,"output_tokens":34}
|
||||
}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}},
|
||||
Tokens: 64,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal("Let me check that."))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(12)))
|
||||
g.Expect(reply.GetTokens()).To(Equal(int32(34)))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(1))
|
||||
g.Expect(reply.GetChatDeltas()[0].GetToolCalls()).To(HaveLen(1))
|
||||
tc := reply.GetChatDeltas()[0].GetToolCalls()[0]
|
||||
g.Expect(tc.GetId()).To(Equal("toolu_01"))
|
||||
g.Expect(tc.GetName()).To(Equal("weather"))
|
||||
g.Expect(tc.GetArguments()).To(ContainSubstring(`"location":"SF"`))
|
||||
}
|
||||
|
||||
// Anthropic: streaming tool_use. content_block_start announces the
|
||||
// tool's id + name; input_json_delta events carry argument fragments
|
||||
// which the consumer accumulates. message_delta carries final usage.
|
||||
func TestPredictStreamRich_Anthropic_InputJSONDelta(t *testing.T) {
|
||||
frames := []string{
|
||||
"event: message_start\ndata: {\"type\":\"message_start\"}\n\n",
|
||||
// Block 0 is a tool_use; consumer should allocate a slot.
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_42\",\"name\":\"lookup\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"q\\\":\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"\\\"rain\\\"}\"}}\n\n",
|
||||
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_delta\ndata: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n",
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
body := strings.Join(frames, "")
|
||||
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 16)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "rain?"}},
|
||||
Tokens: 64,
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var (
|
||||
toolID, toolName string
|
||||
argsBuf strings.Builder
|
||||
finalTokens int32
|
||||
)
|
||||
for reply := range results {
|
||||
if reply.GetTokens() > 0 && len(reply.GetChatDeltas()) == 0 {
|
||||
finalTokens = reply.GetTokens()
|
||||
continue
|
||||
}
|
||||
for _, cd := range reply.GetChatDeltas() {
|
||||
for _, tc := range cd.GetToolCalls() {
|
||||
if tc.GetId() != "" {
|
||||
toolID = tc.GetId()
|
||||
}
|
||||
if tc.GetName() != "" {
|
||||
toolName = tc.GetName()
|
||||
}
|
||||
argsBuf.WriteString(tc.GetArguments())
|
||||
}
|
||||
}
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(toolID).To(Equal("toolu_42"))
|
||||
g.Expect(toolName).To(Equal("lookup"))
|
||||
g.Expect(argsBuf.String()).To(Equal(`{"q":"rain"}`))
|
||||
g.Expect(finalTokens).To(Equal(int32(7)))
|
||||
}
|
||||
|
||||
// Sanity: the legacy Predict() (string, error) signature still works
|
||||
// — it delegates to PredictRich and extracts Message.
|
||||
func TestPredict_LegacyWrapper_OpenAI(t *testing.T) {
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "hi"}}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hello"))
|
||||
}
|
||||
@@ -8,6 +8,6 @@ import (
|
||||
|
||||
func assert(cond bool, msg string) {
|
||||
if !cond {
|
||||
xlog.Fatal().Stack().Msg(msg)
|
||||
xlog.Fatal(msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,22 @@
|
||||
package main
|
||||
|
||||
// This is a wrapper to statisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
// LocalAI's in-process vector store, exposed as a gRPC backend. Keep
|
||||
// the implementation here — NOT in a pkg/ library imported by the main
|
||||
// LocalAI process. The whole point of the gRPC surface is that vector
|
||||
// storage is a backend like any other (local-store, qdrant, pinecone,
|
||||
// ...) and can be swapped without changing the routing/recognition
|
||||
// code that consumes it.
|
||||
//
|
||||
// Storage is a sorted parallel-slice (keys [][]float32, values
|
||||
// [][]byte). Set/Delete preserve the sort so Get can binary-search.
|
||||
// Find scans linearly and uses a heap to keep the top-K — fine for
|
||||
// the tens-to-thousands range. The "normalized fast path" (Find when
|
||||
// every stored key has unit magnitude AND the query is normalized)
|
||||
// skips the per-item magnitude calculation.
|
||||
//
|
||||
// Concurrency: base.SingleThread serialises gRPC calls so the
|
||||
// non-thread-safe slice/heap manipulation here is sound.
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
@@ -10,30 +25,27 @@ import (
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/mudler/LocalAI/pkg/store"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
base.SingleThread
|
||||
|
||||
// The sorted keys
|
||||
keys [][]float32
|
||||
// The sorted values
|
||||
keys [][]float32
|
||||
values [][]byte
|
||||
|
||||
// If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions
|
||||
// TODO: Should we normalize incoming keys if they are not instead?
|
||||
// keysAreNormalized stays true until any non-unit-magnitude key
|
||||
// is added; once false, the magnitude-aware fallback path is
|
||||
// used by Find. Re-evaluated only at Set time, never again on
|
||||
// its own — a deletion of the offending key does NOT flip it
|
||||
// back to true (the bookkeeping cost would dominate the gain).
|
||||
keysAreNormalized bool
|
||||
// The first key decides the length of the keys
|
||||
keyLen int
|
||||
}
|
||||
|
||||
// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because
|
||||
// that's theoretically best for memory layout and cache locality, but this isn't optimized yet.
|
||||
type Pair struct {
|
||||
Key []float32
|
||||
Value []byte
|
||||
// keyLen is the dimension of every stored key. -1 means "no
|
||||
// keys yet, dimension is open". Dimension mismatch on Set is
|
||||
// rejected so cosine similarity (which requires equal-length
|
||||
// vectors) doesn't silently mis-match.
|
||||
keyLen int
|
||||
}
|
||||
|
||||
func NewStore() *Store {
|
||||
@@ -45,334 +57,278 @@ func NewStore() *Store {
|
||||
}
|
||||
}
|
||||
|
||||
func compareSlices(k1, k2 []float32) int {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
return slices.Compare(k1, k2)
|
||||
}
|
||||
|
||||
func hasKey(unsortedSlice [][]float32, target []float32) bool {
|
||||
return slices.ContainsFunc(unsortedSlice, func(k []float32) bool {
|
||||
return compareSlices(k, target) == 0
|
||||
})
|
||||
}
|
||||
|
||||
func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) {
|
||||
return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int {
|
||||
return compareSlices(k, t)
|
||||
})
|
||||
}
|
||||
|
||||
func isSortedPairs(kvs []Pair) bool {
|
||||
for i := 1; i < len(kvs); i++ {
|
||||
if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isSortedKeys(keys [][]float32) bool {
|
||||
for i := 1; i < len(keys); i++ {
|
||||
if compareSlices(keys[i-1], keys[i]) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 {
|
||||
ks := make([][]float32, len(keys))
|
||||
|
||||
for i, k := range keys {
|
||||
ks[i] = k.Floats
|
||||
}
|
||||
|
||||
slices.SortFunc(ks, compareSlices)
|
||||
|
||||
assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys)))
|
||||
assert(isSortedKeys(ks), "keys are not sorted")
|
||||
|
||||
return ks
|
||||
}
|
||||
|
||||
// Load is a no-op — local-store has no on-disk artefact. opts.Model is
|
||||
// just a namespace identifier; isolation is already handled upstream
|
||||
// (ModelLoader spawns a fresh local-store process per (backend,
|
||||
// model) tuple, so each namespace is its own Store{} instance).
|
||||
func (s *Store) Load(opts *pb.ModelOptions) error {
|
||||
// local-store is an in-memory vector store with no on-disk artefact to
|
||||
// load — opts.Model is just a namespace identifier. The old `!= ""` guard
|
||||
// rejected any non-empty model name with "not implemented", which broke
|
||||
// callers that pass a namespace to isolate embedding spaces (face vs.
|
||||
// voice biometrics both go through local-store but need distinct stores
|
||||
// so ArcFace 512-D and ECAPA-TDNN 192-D don't collide). Namespace
|
||||
// isolation is already handled upstream: ModelLoader spawns a fresh
|
||||
// local-store process per (backend, model) tuple, so each namespace is
|
||||
// its own Store{} instance. Nothing to do here beyond accepting the load.
|
||||
_ = opts
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort the incoming kvs and merge them with the existing sorted kvs
|
||||
func (s *Store) StoresSet(opts *pb.StoresSetOptions) error {
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to add")
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
values := store.UnwrapValues(opts.Values)
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("local-store: Set: no keys to add")
|
||||
}
|
||||
|
||||
if len(opts.Keys) != len(opts.Values) {
|
||||
return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values))
|
||||
if len(keys) != len(values) {
|
||||
return fmt.Errorf("local-store: Set: len(keys) = %d, len(values) = %d", len(keys), len(values))
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
s.keyLen = len(keys[0])
|
||||
} else if len(keys[0]) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Set: key length %d does not match existing %d", len(keys[0]), s.keyLen)
|
||||
}
|
||||
|
||||
kvs := make([]Pair, len(opts.Keys))
|
||||
|
||||
for i, k := range opts.Keys {
|
||||
if s.keysAreNormalized && !isNormalized(k.Floats) {
|
||||
kvs := make([]incomingPair, len(keys))
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Set: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
if s.keysAreNormalized && !isNormalized(k) {
|
||||
s.keysAreNormalized = false
|
||||
var sample []float32
|
||||
if len(s.keys) > 5 {
|
||||
sample = k.Floats[:5]
|
||||
} else {
|
||||
sample = k.Floats
|
||||
}
|
||||
xlog.Debug("Key is not normalized", "sample", sample)
|
||||
}
|
||||
|
||||
kvs[i] = Pair{
|
||||
Key: k.Floats,
|
||||
Value: opts.Values[i].Bytes,
|
||||
}
|
||||
kvs[i] = incomingPair{key: k, value: values[i]}
|
||||
}
|
||||
|
||||
slices.SortFunc(kvs, func(a, b Pair) int {
|
||||
return compareSlices(a.Key, b.Key)
|
||||
})
|
||||
|
||||
assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys)))
|
||||
assert(isSortedPairs(kvs), "keys are not sorted")
|
||||
|
||||
l := len(kvs) + len(s.keys)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
i, j := 0, 0
|
||||
for {
|
||||
if i+j >= l {
|
||||
break
|
||||
}
|
||||
|
||||
if i >= len(kvs) {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
continue
|
||||
}
|
||||
|
||||
if j >= len(s.keys) {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
c := compareSlices(kvs[i].Key, s.keys[j])
|
||||
if c < 0 {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
} else if c > 0 {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
} else {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l))
|
||||
assert(isSortedKeys(merge_ks), "merge keys are not sorted")
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
slices.SortFunc(kvs, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) })
|
||||
|
||||
merged := mergeSortedPairs(s.keys, s.values, kvs)
|
||||
s.keys = merged.keys
|
||||
s.values = merged.values
|
||||
assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Set: s.keys not sorted post-merge")
|
||||
assert(len(s.keys) == len(s.values), "Set: keys/values length skew")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error {
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to delete")
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("local-store: Delete: no keys to delete")
|
||||
}
|
||||
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to add")
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
ks := sortIntoKeySlicese(opts.Keys)
|
||||
|
||||
l := len(s.keys) - len(ks)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
tail_ks := s.keys
|
||||
tail_vs := s.values
|
||||
for _, k := range ks {
|
||||
j, found := findInSortedSlice(tail_ks, k)
|
||||
|
||||
if found {
|
||||
merge_ks = append(merge_ks, tail_ks[:j]...)
|
||||
merge_vs = append(merge_vs, tail_vs[:j]...)
|
||||
tail_ks = tail_ks[j+1:]
|
||||
tail_vs = tail_vs[j+1:]
|
||||
} else {
|
||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k))
|
||||
}
|
||||
|
||||
xlog.Debug("Delete", "found", found, "tailLen", len(tail_ks), "j", j, "mergeKeysLen", len(merge_ks), "mergeValuesLen", len(merge_vs))
|
||||
}
|
||||
|
||||
merge_ks = append(merge_ks, tail_ks...)
|
||||
merge_vs = append(merge_vs, tail_vs...)
|
||||
|
||||
assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys)))
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
|
||||
assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l))
|
||||
assert(isSortedKeys(s.keys), "keys are not sorted")
|
||||
assert(func() bool {
|
||||
for _, k := range ks {
|
||||
if _, found := findInSortedSlice(s.keys, k); found {
|
||||
return false
|
||||
if s.keyLen != -1 {
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Delete: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}(), "Keys to delete still present")
|
||||
|
||||
if len(s.keys) != l {
|
||||
xlog.Debug("Delete: Some keys not found", "keysLen", len(s.keys), "expectedLen", l)
|
||||
}
|
||||
sortedKeys := append([][]float32(nil), keys...)
|
||||
slices.SortFunc(sortedKeys, slices.Compare[[]float32])
|
||||
|
||||
mergedK := make([][]float32, 0, len(s.keys))
|
||||
mergedV := make([][]byte, 0, len(s.keys))
|
||||
tailK := s.keys
|
||||
tailV := s.values
|
||||
for _, k := range sortedKeys {
|
||||
j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32])
|
||||
if ok {
|
||||
mergedK = append(mergedK, tailK[:j]...)
|
||||
mergedV = append(mergedV, tailV[:j]...)
|
||||
tailK = tailK[j+1:]
|
||||
tailV = tailV[j+1:]
|
||||
}
|
||||
}
|
||||
mergedK = append(mergedK, tailK...)
|
||||
mergedV = append(mergedV, tailV...)
|
||||
s.keys = mergedK
|
||||
s.values = mergedV
|
||||
assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Delete: s.keys not sorted post-merge")
|
||||
assert(len(s.keys) == len(s.values), "Delete: keys/values length skew")
|
||||
return nil
|
||||
}
|
||||
|
||||
// StoresGet fetches values for the given keys. Missing keys are
|
||||
// omitted from the result rather than reported as an error — callers
|
||||
// compare returned-key length against requested-key length to detect
|
||||
// them. Returned slices are aligned.
|
||||
func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) {
|
||||
pbKeys := make([]*pb.StoresKey, 0, len(opts.Keys))
|
||||
pbValues := make([]*pb.StoresValue, 0, len(opts.Keys))
|
||||
ks := sortIntoKeySlicese(opts.Keys)
|
||||
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
if len(s.keys) == 0 {
|
||||
xlog.Debug("Get: No keys in store")
|
||||
return pb.StoresGetResult{}, nil
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return pb.StoresGetResult{}, fmt.Errorf("Try to get a key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
if s.keyLen != -1 {
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return pb.StoresGetResult{}, fmt.Errorf("local-store: Get: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
sortedKeys := append([][]float32(nil), keys...)
|
||||
slices.SortFunc(sortedKeys, slices.Compare[[]float32])
|
||||
|
||||
tail_k := s.keys
|
||||
tail_v := s.values
|
||||
for i, k := range ks {
|
||||
j, found := findInSortedSlice(tail_k, k)
|
||||
|
||||
if found {
|
||||
pbKeys = append(pbKeys, &pb.StoresKey{
|
||||
Floats: k,
|
||||
})
|
||||
pbValues = append(pbValues, &pb.StoresValue{
|
||||
Bytes: tail_v[j],
|
||||
})
|
||||
|
||||
tail_k = tail_k[j+1:]
|
||||
tail_v = tail_v[j+1:]
|
||||
} else {
|
||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: i=%d, %v", i, k))
|
||||
var foundKeys [][]float32
|
||||
var foundValues [][]byte
|
||||
tailK := s.keys
|
||||
tailV := s.values
|
||||
for _, k := range sortedKeys {
|
||||
j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
foundKeys = append(foundKeys, tailK[j])
|
||||
foundValues = append(foundValues, tailV[j])
|
||||
tailK = tailK[j+1:]
|
||||
tailV = tailV[j+1:]
|
||||
}
|
||||
|
||||
if len(pbKeys) != len(opts.Keys) {
|
||||
xlog.Debug("Get: Some keys not found", "pbKeysLen", len(pbKeys), "optsKeysLen", len(opts.Keys), "storeKeysLen", len(s.keys))
|
||||
}
|
||||
|
||||
return pb.StoresGetResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Keys: store.WrapKeys(foundKeys),
|
||||
Values: store.WrapValues(foundValues),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StoresFind returns the topK nearest stored entries by cosine
|
||||
// similarity, ordered most-similar first. An empty store returns
|
||||
// empty slices and no error.
|
||||
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
query := opts.Key.Floats
|
||||
topK := int(opts.TopK)
|
||||
if topK < 1 {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: topK = %d, must be >= 1", topK)
|
||||
}
|
||||
if len(s.keys) == 0 {
|
||||
return pb.StoresFindResult{}, nil
|
||||
}
|
||||
if len(query) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: query length %d does not match existing %d", len(query), s.keyLen)
|
||||
}
|
||||
|
||||
var keys [][]float32
|
||||
var values [][]byte
|
||||
var sims []float32
|
||||
if s.keysAreNormalized && isNormalized(query) {
|
||||
keys, values, sims = s.findNormalized(query, topK)
|
||||
} else {
|
||||
keys, values, sims = s.findFallback(query, topK)
|
||||
}
|
||||
return pb.StoresFindResult{
|
||||
Keys: store.WrapKeys(keys),
|
||||
Values: store.WrapValues(values),
|
||||
Similarities: sims,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) findNormalized(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
assert(s.keysAreNormalized, "findNormalized: s.keysAreNormalized is false")
|
||||
assert(isNormalized(query), "findNormalized: query is not unit-length")
|
||||
pq := make(priorityQueue, 0, topK)
|
||||
heap.Init(&pq)
|
||||
for i, k := range s.keys {
|
||||
var dot float32
|
||||
for j := range k {
|
||||
dot += query[j] * k[j]
|
||||
}
|
||||
assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("findNormalized: dot %f out of [-1, 1] — keysAreNormalized invariant violated", dot))
|
||||
heap.Push(&pq, &priorityItem{similarity: dot, key: k, value: s.values[i]})
|
||||
if pq.Len() > topK {
|
||||
heap.Pop(&pq)
|
||||
}
|
||||
}
|
||||
return drainPQ(&pq)
|
||||
}
|
||||
|
||||
func (s *Store) findFallback(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
var qmag float64
|
||||
for _, v := range query {
|
||||
qmag += float64(v) * float64(v)
|
||||
}
|
||||
qmag = math.Sqrt(qmag)
|
||||
pq := make(priorityQueue, 0, topK)
|
||||
heap.Init(&pq)
|
||||
for i, k := range s.keys {
|
||||
var dot, kmag float64
|
||||
for j := range k {
|
||||
dot += float64(query[j]) * float64(k[j])
|
||||
kmag += float64(k[j]) * float64(k[j])
|
||||
}
|
||||
denom := qmag * math.Sqrt(kmag)
|
||||
var sim float32
|
||||
if denom > 0 {
|
||||
sim = float32(dot / denom)
|
||||
}
|
||||
heap.Push(&pq, &priorityItem{similarity: sim, key: k, value: s.values[i]})
|
||||
if pq.Len() > topK {
|
||||
heap.Pop(&pq)
|
||||
}
|
||||
}
|
||||
return drainPQ(&pq)
|
||||
}
|
||||
|
||||
func isNormalized(k []float32) bool {
|
||||
var sum float64
|
||||
|
||||
for _, v := range k {
|
||||
v64 := float64(v)
|
||||
sum += v64 * v64
|
||||
sum += float64(v) * float64(v)
|
||||
}
|
||||
|
||||
s := math.Sqrt(sum)
|
||||
|
||||
return s >= 0.99 && s <= 1.01
|
||||
mag := math.Sqrt(sum)
|
||||
return mag >= 0.99 && mag <= 1.01
|
||||
}
|
||||
|
||||
// TODO: This we could replace with handwritten SIMD code
|
||||
func normalizedCosineSimilarity(k1, k2 []float32) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
type incomingPair struct {
|
||||
key []float32
|
||||
value []byte
|
||||
}
|
||||
|
||||
var dot float32
|
||||
for i := range len(k1) {
|
||||
dot += k1[i] * k2[i]
|
||||
type pairs struct {
|
||||
keys [][]float32
|
||||
values [][]byte
|
||||
}
|
||||
|
||||
// mergeSortedPairs merges (existing, incoming) into a fresh sorted
|
||||
// slice. Equal keys take the incoming value — Set is upsert.
|
||||
func mergeSortedPairs(existingK [][]float32, existingV [][]byte, incoming []incomingPair) pairs {
|
||||
assert(slices.IsSortedFunc(existingK, slices.Compare[[]float32]), "mergeSortedPairs: existing not sorted")
|
||||
assert(slices.IsSortedFunc(incoming, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) }), "mergeSortedPairs: incoming not sorted")
|
||||
l := len(existingK) + len(incoming)
|
||||
mk := make([][]float32, 0, l)
|
||||
mv := make([][]byte, 0, l)
|
||||
i, j := 0, 0
|
||||
for i < len(incoming) || j < len(existingK) {
|
||||
switch {
|
||||
case j >= len(existingK):
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
case i >= len(incoming):
|
||||
mk = append(mk, existingK[j])
|
||||
mv = append(mv, existingV[j])
|
||||
j++
|
||||
default:
|
||||
c := slices.Compare(incoming[i].key, existingK[j])
|
||||
switch {
|
||||
case c < 0:
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
case c > 0:
|
||||
mk = append(mk, existingK[j])
|
||||
mv = append(mv, existingV[j])
|
||||
j++
|
||||
default:
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
j++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("dot = %f", dot))
|
||||
|
||||
// 2.0 * (1.0 - dot) would be the Euclidean distance
|
||||
return dot
|
||||
return pairs{keys: mk, values: mv}
|
||||
}
|
||||
|
||||
type PriorityItem struct {
|
||||
Similarity float32
|
||||
Key []float32
|
||||
Value []byte
|
||||
type priorityItem struct {
|
||||
similarity float32
|
||||
key []float32
|
||||
value []byte
|
||||
}
|
||||
|
||||
type PriorityQueue []*PriorityItem
|
||||
type priorityQueue []*priorityItem
|
||||
|
||||
func (pq PriorityQueue) Len() int { return len(pq) }
|
||||
|
||||
func (pq PriorityQueue) Less(i, j int) bool {
|
||||
// Inverted because the most similar should be at the top
|
||||
return pq[i].Similarity < pq[j].Similarity
|
||||
}
|
||||
|
||||
func (pq PriorityQueue) Swap(i, j int) {
|
||||
pq[i], pq[j] = pq[j], pq[i]
|
||||
}
|
||||
|
||||
func (pq *PriorityQueue) Push(x any) {
|
||||
item := x.(*PriorityItem)
|
||||
*pq = append(*pq, item)
|
||||
}
|
||||
|
||||
func (pq *PriorityQueue) Pop() any {
|
||||
func (pq priorityQueue) Len() int { return len(pq) }
|
||||
func (pq priorityQueue) Less(i, j int) bool { return pq[i].similarity < pq[j].similarity }
|
||||
func (pq priorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] }
|
||||
func (pq *priorityQueue) Push(x any) { *pq = append(*pq, x.(*priorityItem)) }
|
||||
func (pq *priorityQueue) Pop() any {
|
||||
old := *pq
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
@@ -380,142 +336,16 @@ func (pq *PriorityQueue) Pop() any {
|
||||
return item
|
||||
}
|
||||
|
||||
func (s *Store) StoresFindNormalized(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
top_ks := make(PriorityQueue, 0, int(opts.TopK))
|
||||
heap.Init(&top_ks)
|
||||
|
||||
for i, k := range s.keys {
|
||||
sim := normalizedCosineSimilarity(tk, k)
|
||||
heap.Push(&top_ks, &PriorityItem{
|
||||
Similarity: sim,
|
||||
Key: k,
|
||||
Value: s.values[i],
|
||||
})
|
||||
|
||||
if top_ks.Len() > int(opts.TopK) {
|
||||
heap.Pop(&top_ks)
|
||||
}
|
||||
}
|
||||
|
||||
similarities := make([]float32, top_ks.Len())
|
||||
pbKeys := make([]*pb.StoresKey, top_ks.Len())
|
||||
pbValues := make([]*pb.StoresValue, top_ks.Len())
|
||||
|
||||
for i := top_ks.Len() - 1; i >= 0; i-- {
|
||||
item := heap.Pop(&top_ks).(*PriorityItem)
|
||||
|
||||
similarities[i] = item.Similarity
|
||||
pbKeys[i] = &pb.StoresKey{
|
||||
Floats: item.Key,
|
||||
}
|
||||
pbValues[i] = &pb.StoresValue{
|
||||
Bytes: item.Value,
|
||||
}
|
||||
}
|
||||
|
||||
return pb.StoresFindResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Similarities: similarities,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
var dot, mag2 float64
|
||||
for i := range len(k1) {
|
||||
dot += float64(k1[i] * k2[i])
|
||||
mag2 += float64(k2[i] * k2[i])
|
||||
}
|
||||
|
||||
sim := float32(dot / (mag1 * math.Sqrt(mag2)))
|
||||
|
||||
assert(sim >= -1.01 && sim <= 1.01, fmt.Sprintf("sim = %f", sim))
|
||||
|
||||
return sim
|
||||
}
|
||||
|
||||
func (s *Store) StoresFindFallback(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
top_ks := make(PriorityQueue, 0, int(opts.TopK))
|
||||
heap.Init(&top_ks)
|
||||
|
||||
var mag1 float64
|
||||
for _, v := range tk {
|
||||
mag1 += float64(v * v)
|
||||
}
|
||||
mag1 = math.Sqrt(mag1)
|
||||
|
||||
for i, k := range s.keys {
|
||||
dist := cosineSimilarity(tk, k, mag1)
|
||||
heap.Push(&top_ks, &PriorityItem{
|
||||
Similarity: dist,
|
||||
Key: k,
|
||||
Value: s.values[i],
|
||||
})
|
||||
|
||||
if top_ks.Len() > int(opts.TopK) {
|
||||
heap.Pop(&top_ks)
|
||||
}
|
||||
}
|
||||
|
||||
similarities := make([]float32, top_ks.Len())
|
||||
pbKeys := make([]*pb.StoresKey, top_ks.Len())
|
||||
pbValues := make([]*pb.StoresValue, top_ks.Len())
|
||||
|
||||
for i := top_ks.Len() - 1; i >= 0; i-- {
|
||||
item := heap.Pop(&top_ks).(*PriorityItem)
|
||||
|
||||
similarities[i] = item.Similarity
|
||||
pbKeys[i] = &pb.StoresKey{
|
||||
Floats: item.Key,
|
||||
}
|
||||
pbValues[i] = &pb.StoresValue{
|
||||
Bytes: item.Value,
|
||||
}
|
||||
}
|
||||
|
||||
return pb.StoresFindResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Similarities: similarities,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
|
||||
if len(tk) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("Try to find key with length %d when existing length is %d", len(tk), s.keyLen)
|
||||
}
|
||||
|
||||
if opts.TopK < 1 {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("opts.TopK = %d, must be >= 1", opts.TopK)
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Key.Floats)
|
||||
} else {
|
||||
if len(opts.Key.Floats) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Key.Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
if s.keysAreNormalized && isNormalized(tk) {
|
||||
return s.StoresFindNormalized(opts)
|
||||
} else {
|
||||
if s.keysAreNormalized {
|
||||
var sample []float32
|
||||
if len(s.keys) > 5 {
|
||||
sample = tk[:5]
|
||||
} else {
|
||||
sample = tk
|
||||
}
|
||||
xlog.Debug("Trying to compare non-normalized key with normalized keys", "sample", sample)
|
||||
}
|
||||
|
||||
return s.StoresFindFallback(opts)
|
||||
func drainPQ(pq *priorityQueue) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
n := pq.Len()
|
||||
keys = make([][]float32, n)
|
||||
values = make([][]byte, n)
|
||||
similarities = make([]float32, n)
|
||||
for i := n - 1; i >= 0; i-- {
|
||||
item := heap.Pop(pq).(*priorityItem)
|
||||
keys[i] = item.key
|
||||
values[i] = item.value
|
||||
similarities[i] = item.similarity
|
||||
}
|
||||
return keys, values, similarities
|
||||
}
|
||||
|
||||
13
backend/go/local-store/store_suite_test.go
Normal file
13
backend/go/local-store/store_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestLocalStore(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "local-store test suite")
|
||||
}
|
||||
284
backend/go/local-store/store_test.go
Normal file
284
backend/go/local-store/store_test.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package main
|
||||
|
||||
// Regression suite for the local-store gRPC backend. Exercises the
|
||||
// Stores{Set,Get,Find,Delete} surface — the only public contract.
|
||||
// Callers (face/voice recognition, the routing KNN classifier) reach
|
||||
// this code via grpc.Backend, so testing at the wire-shaped boundary
|
||||
// matches the production import shape.
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("StoresSet", func() {
|
||||
It("rejects empty input", func() {
|
||||
Expect(NewStore().StoresSet(&pb.StoresSetOptions{})).NotTo(Succeed(), "Set with no keys should fail")
|
||||
})
|
||||
|
||||
It("rejects key/value length mismatch", func() {
|
||||
err := NewStore().StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("a"), []byte("b")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "len(keys) != len(values) should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch on later add", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("3d")})
|
||||
err := s.StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("2d")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "dimension mismatch on later Set should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch within batch", func() {
|
||||
err := NewStore().StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0, 0}, {1, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("3d"), []byte("2d")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "mixed-dimension within one batch should fail")
|
||||
})
|
||||
|
||||
It("merges sorted and updates existing key", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.3, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("c"), []byte("a")})
|
||||
mustSet(s, [][]float32{{0.2, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("b"), []byte("a-updated")})
|
||||
Expect(s.keys).To(HaveLen(3))
|
||||
got := singleGet(s, []float32{0.1, 0, 0})
|
||||
Expect(string(got)).To(Equal("a-updated"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresGet", func() {
|
||||
It("round-trips multi-key", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}},
|
||||
[][]byte{[]byte("a"), []byte("b"), []byte("c")},
|
||||
)
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{
|
||||
Keys: wrapKeys([][]float32{{0.7, 0.8, 0.9}, {0.1, 0.2, 0.3}}),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("omits missing keys rather than erroring", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")})
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{
|
||||
Keys: wrapKeys([][]float32{{0.1, 0, 0}, {0.9, 0, 0}}),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(1))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresDelete", func() {
|
||||
It("removes and preserves sort", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{{0.1, 0, 0}, {0.2, 0, 0}, {0.3, 0, 0}, {0.4, 0, 0}},
|
||||
[][]byte{[]byte("a"), []byte("b"), []byte("c"), []byte("d")},
|
||||
)
|
||||
Expect(s.StoresDelete(&pb.StoresDeleteOptions{
|
||||
Keys: wrapKeys([][]float32{{0.2, 0, 0}, {0.4, 0, 0}}),
|
||||
})).To(Succeed())
|
||||
Expect(s.keys).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("tolerates missing keys", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")})
|
||||
Expect(s.StoresDelete(&pb.StoresDeleteOptions{
|
||||
Keys: wrapKeys([][]float32{{0.9, 0, 0}}),
|
||||
})).To(Succeed(), "delete of missing key should succeed")
|
||||
Expect(s.keys).To(HaveLen(1))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresFind", func() {
|
||||
It("returns normalized top-K", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{
|
||||
normalizeVec([]float32{1, 0, 0}),
|
||||
normalizeVec([]float32{0, 1, 0}),
|
||||
normalizeVec([]float32{0, 0, 1}),
|
||||
},
|
||||
[][]byte{[]byte("x"), []byte("y"), []byte("z")},
|
||||
)
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: normalizeVec([]float32{0.9, 0.1, 0})},
|
||||
TopK: 2,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
Expect(res.Similarities[0]).To(BeNumerically(">=", res.Similarities[1]), "results not sorted desc by similarity")
|
||||
Expect(string(res.Values[0].Bytes)).To(Equal("x"))
|
||||
})
|
||||
|
||||
It("falls back for non-normalized keys", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{2, 0, 0}, {0, 3, 0}}, [][]byte{[]byte("x"), []byte("y")})
|
||||
Expect(s.keysAreNormalized).To(BeFalse(), "store should report non-normalized after Set with magnitude > 1")
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{4, 0, 0}},
|
||||
TopK: 1,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(string(res.Values[0].Bytes)).To(Equal("x"))
|
||||
Expect(res.Similarities[0]).To(BeNumerically(">=", float32(0.99)))
|
||||
Expect(res.Similarities[0]).To(BeNumerically("<=", float32(1.01)))
|
||||
})
|
||||
|
||||
It("rejects zero topK", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")})
|
||||
_, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0, 0}},
|
||||
TopK: 0,
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "Find with topK=0 should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")})
|
||||
_, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0}},
|
||||
TopK: 1,
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "Find with mismatched dimension should fail")
|
||||
})
|
||||
|
||||
It("returns empty result on empty store", func() {
|
||||
res, err := NewStore().StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0, 0}},
|
||||
TopK: 5,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred(), "Find on empty store should succeed")
|
||||
Expect(res.Keys).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("handles topK larger than store", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{normalizeVec([]float32{1, 0, 0}), normalizeVec([]float32{0, 1, 0})},
|
||||
[][]byte{[]byte("x"), []byte("y")},
|
||||
)
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: normalizeVec([]float32{1, 0, 0})},
|
||||
TopK: 10,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresLoad", func() {
|
||||
It("is a no-op", func() {
|
||||
Expect(NewStore().Load(&pb.ModelOptions{Model: "any-namespace"})).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
func BenchmarkStoresFindNormalized(b *testing.B) {
|
||||
const dim = 768
|
||||
for _, n := range []int{8, 32, 128, 512} {
|
||||
b.Run(fmtN(n), func(b *testing.B) {
|
||||
s := buildStore(b, n, dim)
|
||||
query := normalizeVec(randVec(dim, 42))
|
||||
req := &pb.StoresFindOptions{Key: &pb.StoresKey{Floats: query}, TopK: 1}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := s.StoresFind(req); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- test helpers ---
|
||||
|
||||
func mustSet(s *Store, keys [][]float32, values [][]byte) {
|
||||
ExpectWithOffset(1, s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)})).To(Succeed())
|
||||
}
|
||||
|
||||
func singleGet(s *Store, key []float32) []byte {
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{Keys: wrapKeys([][]float32{key})})
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred())
|
||||
if len(res.Values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return res.Values[0].Bytes
|
||||
}
|
||||
|
||||
func wrapKeys(in [][]float32) []*pb.StoresKey {
|
||||
out := make([]*pb.StoresKey, len(in))
|
||||
for i, k := range in {
|
||||
out[i] = &pb.StoresKey{Floats: k}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func wrapValues(in [][]byte) []*pb.StoresValue {
|
||||
out := make([]*pb.StoresValue, len(in))
|
||||
for i, v := range in {
|
||||
out[i] = &pb.StoresValue{Bytes: v}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildStore(tb testing.TB, n, dim int) *Store {
|
||||
tb.Helper()
|
||||
s := NewStore()
|
||||
keys := make([][]float32, n)
|
||||
values := make([][]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
keys[i] = normalizeVec(randVec(dim, int64(i)+1))
|
||||
values[i] = []byte{byte(i)}
|
||||
}
|
||||
if err := s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)}); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func randVec(dim int, seed int64) []float32 {
|
||||
r := rand.New(rand.NewPCG(uint64(seed), 0xabcdef))
|
||||
v := make([]float32, dim)
|
||||
for i := range v {
|
||||
v[i] = float32(r.NormFloat64())
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func normalizeVec(v []float32) []float32 {
|
||||
var sum float64
|
||||
for _, x := range v {
|
||||
sum += float64(x) * float64(x)
|
||||
}
|
||||
mag := math.Sqrt(sum)
|
||||
if mag == 0 {
|
||||
return v
|
||||
}
|
||||
out := make([]float32, len(v))
|
||||
for i, x := range v {
|
||||
out[i] = float32(float64(x) / mag)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func fmtN(n int) string {
|
||||
return map[int]string{8: "n=8", 32: "n=32", 128: "n=128", 512: "n=512"}[n]
|
||||
}
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=bd17f53b7386fb5f60e8587b75e73c4b2fed3426
|
||||
STABLEDIFFUSION_GGML_VERSION?=a397e03488cc27e1a42da646b82dfce9f50741c0
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -376,6 +376,8 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
const char *clip_g_path = "";
|
||||
const char *t5xxl_path = "";
|
||||
const char *vae_path = "";
|
||||
const char *audio_vae_path = "";
|
||||
const char *embeddings_connectors_path = "";
|
||||
const char *scheduler_str = "";
|
||||
const char *sampler = "";
|
||||
const char *clip_vision_path = "";
|
||||
@@ -431,6 +433,12 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
if (!strcmp(optname, "vae_path")) {
|
||||
vae_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "audio_vae_path")) {
|
||||
audio_vae_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "embeddings_connectors_path")) {
|
||||
embeddings_connectors_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "scheduler")) {
|
||||
scheduler_str = optval;
|
||||
}
|
||||
@@ -563,6 +571,8 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
ctx_params.diffusion_model_path = diffusion_model_path;
|
||||
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
|
||||
ctx_params.vae_path = vae_path;
|
||||
ctx_params.audio_vae_path = audio_vae_path;
|
||||
ctx_params.embeddings_connectors_path = embeddings_connectors_path;
|
||||
ctx_params.taesd_path = taesd_path;
|
||||
ctx_params.control_net_path = control_net_path;
|
||||
if (lora_dir && strlen(lora_dir) > 0) {
|
||||
@@ -1188,6 +1198,9 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
|
||||
p->high_noise_sample_params.scheduler = scheduler;
|
||||
p->high_noise_sample_params.flow_shift = flow_shift;
|
||||
|
||||
// Pin output fps in params; upstream uses it for audio sync (and we also mux at this rate).
|
||||
p->fps = fps;
|
||||
|
||||
// Load init/end reference images if provided (resized to output dims).
|
||||
uint8_t* init_buf = nullptr;
|
||||
uint8_t* end_buf = nullptr;
|
||||
@@ -1206,11 +1219,14 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
|
||||
|
||||
// Generate
|
||||
int num_frames_out = 0;
|
||||
sd_image_t* frames = generate_video(sd_c, p, &num_frames_out);
|
||||
sd_image_t* frames = nullptr;
|
||||
sd_audio_t* audio = nullptr;
|
||||
bool ok = generate_video(sd_c, p, &frames, &num_frames_out, &audio);
|
||||
std::free(p);
|
||||
|
||||
if (!frames || num_frames_out == 0) {
|
||||
if (!ok || !frames || num_frames_out == 0) {
|
||||
fprintf(stderr, "generate_video produced no frames\n");
|
||||
if (audio) free_sd_audio(audio);
|
||||
if (init_buf) free(init_buf);
|
||||
if (end_buf) free(end_buf);
|
||||
return 1;
|
||||
@@ -1224,6 +1240,7 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
|
||||
if (frames[i].data) free(frames[i].data);
|
||||
}
|
||||
free(frames);
|
||||
if (audio) free_sd_audio(audio);
|
||||
if (init_buf) free(init_buf);
|
||||
if (end_buf) free(end_buf);
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=968eebe77225d25e57a3f981da7c696310f0e881
|
||||
WHISPER_CPP_VERSION?=0ccd896f5b882628e1c077f9769735ef4ce52860
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -36,15 +36,11 @@ fi
|
||||
# flash-attn-4 4.0 stable lands.
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --prerelease=allow"
|
||||
|
||||
# JetPack 7 / L4T arm64 wheels are built for cp312 and shipped via
|
||||
# pypi.jetson-ai-lab.io. Bump the venv Python so the prebuilt sglang
|
||||
# wheel resolves cleanly. The actual install on l4t13 goes through
|
||||
# pyproject.toml (see the elif branch below) so [tool.uv.sources] can
|
||||
# pin only torch/torchvision/torchaudio/sglang to the jetson-ai-lab
|
||||
# index — leaving PyPI as the path for transitive deps like
|
||||
# markdown-it-py / anthropic / propcache that the L4T mirror's proxy
|
||||
# 503s on. No --index-strategy flag here: the explicit index keeps the
|
||||
# scoping clean.
|
||||
# JetPack 7 / L4T arm64 sglang + torch wheels come straight from PyPI now
|
||||
# (torch 2.11+ ships aarch64 + cu130 manylinux wheels and sglang 0.5.11+
|
||||
# ships a cp312 aarch64 wheel pinned to that torch). They're cp312-only,
|
||||
# so bump the venv Python accordingly.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
@@ -110,27 +106,6 @@ if [ "x${BUILD_TYPE}" == "x" ] || [ "x${FROM_SOURCE:-}" == "xtrue" ]; then
|
||||
fi
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} .
|
||||
popd
|
||||
# L4T arm64 (JetPack 7): drive the install through pyproject.toml so that
|
||||
# [tool.uv.sources] can pin torch/torchvision/torchaudio/sglang to the
|
||||
# jetson-ai-lab index, while everything else (transitive deps and
|
||||
# PyPI-resolvable packages like transformers / accelerate) comes from
|
||||
# PyPI. Bypasses installRequirements because uv pip install -r
|
||||
# requirements.txt does not honor sources — see
|
||||
# backend/python/sglang/pyproject.toml for the rationale. Mirrors the
|
||||
# equivalent path in backend/python/vllm/install.sh.
|
||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
ensureVenv
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
||||
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
|
||||
fi
|
||||
pushd "${backend_dir}"
|
||||
# Build deps first (matches installRequirements' requirements-install.txt
|
||||
# pass — sglang/sgl-kernel sdists need packaging/setuptools-scm in the
|
||||
# venv before they can build under --no-build-isolation).
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements-install.txt
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement pyproject.toml
|
||||
popd
|
||||
runProtogen
|
||||
else
|
||||
installRequirements
|
||||
fi
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
# L4T arm64 (JetPack 7 / sbsa cu130) install spec for the sglang backend.
|
||||
#
|
||||
# Why this file exists, and why only the l4t13 BUILD_PROFILE consumes it:
|
||||
#
|
||||
# pypi.jetson-ai-lab.io hosts the L4T-specific torch / sglang / sgl-kernel
|
||||
# wheels we need on aarch64 + cuda13, but it ALSO transparently proxies the
|
||||
# rest of PyPI through `/+f/<sha>/<filename>` URLs that 503 frequently.
|
||||
# With `--extra-index-url` + `--index-strategy=unsafe-best-match` (the
|
||||
# historical fix in install.sh) uv would pick those proxy URLs for ordinary
|
||||
# PyPI packages — markdown-it-py, anthropic, propcache, etc. — and trip on
|
||||
# the 503s. See e.g. CI run 25439791228 (markdown-it-py-4.0.0).
|
||||
#
|
||||
# `explicit = true` on the index makes uv consult the L4T mirror ONLY for
|
||||
# packages mapped under [tool.uv.sources]. Everything else goes to PyPI.
|
||||
# This breaks the historical 503 path without losing access to the L4T
|
||||
# wheels we actually need from there. Mirrors the equivalent fix already
|
||||
# in backend/python/vllm/pyproject.toml.
|
||||
#
|
||||
# `uv pip install -r requirements.txt` does NOT honor [tool.uv.sources]
|
||||
# (sources are project-mode only, not pip-compat mode), so install.sh's
|
||||
# l4t13 branch invokes `uv pip install --requirement pyproject.toml`
|
||||
# directly. Other BUILD_PROFILEs continue to use the requirements-*.txt
|
||||
# pipeline through libbackend.sh's installRequirements and never read
|
||||
# this file.
|
||||
[project]
|
||||
name = "localai-sglang-l4t13"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.12,<3.13"
|
||||
dependencies = [
|
||||
# Mirror of requirements.txt — kept in sync manually for now since the
|
||||
# l4t13 path bypasses installRequirements (see install.sh).
|
||||
"grpcio==1.80.0",
|
||||
"protobuf",
|
||||
"certifi",
|
||||
"setuptools",
|
||||
"pillow",
|
||||
# L4T-specific accelerator stack (sourced from jetson-ai-lab below).
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
# sglang on jetson — the [all] extra is deliberately omitted because it
|
||||
# pulls outlines/decord, and decord has no aarch64 cp312 wheel anywhere
|
||||
# (PyPI nor the jetson-ai-lab index ships only legacy cp35-cp37). With
|
||||
# [all] uv backtracks through versions trying to satisfy decord and
|
||||
# lands on sglang==0.1.16. The 0.5.0 floor matches the only major
|
||||
# series the jetson-ai-lab sbsa/cu130 mirror currently publishes
|
||||
# (sglang==0.5.1.post2 as of 2026-05-06). Bumping to >=0.5.11 here
|
||||
# would make the build unsatisfiable until the mirror catches up.
|
||||
# Gemma 4 / MTP recipes are therefore not supported on l4t13 — those
|
||||
# features land on cublas12/cublas13 hosts that pull the newer wheel
|
||||
# from PyPI. backend.py keeps backward compat with the 0.5.x SamplingParams
|
||||
# field rename via runtime detection.
|
||||
"sglang>=0.5.0",
|
||||
# PyPI-resolvable packages that complete the runtime.
|
||||
"accelerate",
|
||||
"transformers",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "jetson-ai-lab"
|
||||
url = "https://pypi.jetson-ai-lab.io/sbsa/cu130"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "jetson-ai-lab" }
|
||||
torchvision = { index = "jetson-ai-lab" }
|
||||
torchaudio = { index = "jetson-ai-lab" }
|
||||
sglang = { index = "jetson-ai-lab" }
|
||||
15
backend/python/sglang/requirements-l4t13-after.txt
Normal file
15
backend/python/sglang/requirements-l4t13-after.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
# sglang 0.5.11+ ships an aarch64 manylinux wheel on PyPI whose Requires-Dist
|
||||
# pins torch==2.11.0 / torchaudio==2.11.0, locking an ABI-consistent set with
|
||||
# the cu130 torch wheel installed above. 0.5.11 is the floor for Gemma 4
|
||||
# support (sgl-project/sglang#21952).
|
||||
#
|
||||
# The [all] extra is deliberately NOT used on aarch64: it pulls the
|
||||
# [diffusion] sub-extra which requires `xatlas`, and xatlas ships no
|
||||
# aarch64 wheel and its sdist depends on scikit_build_core without
|
||||
# declaring it in build-system.requires — so under --no-build-isolation
|
||||
# uv can't build it. Upstream sglang gates st_attn and vsa on
|
||||
# platform_machine != aarch64 in the diffusion extra but forgot xatlas.
|
||||
# Plain `sglang` carries everything backend.py uses (Engine, ServerArgs,
|
||||
# FunctionCallParser, ReasoningParser); the [all] extras are optional
|
||||
# accelerators not required at import time.
|
||||
sglang>=0.5.11
|
||||
9
backend/python/sglang/requirements-l4t13.txt
Normal file
9
backend/python/sglang/requirements-l4t13.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
# JetPack 7 / L4T arm64 + CUDA 13. Since PyTorch 2.11 (April 2026), PyPI ships
|
||||
# aarch64 + cu130 manylinux wheels for torch/torchvision/torchaudio directly,
|
||||
# so we no longer need a custom --extra-index-url for the L4T mirror.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
accelerate
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
transformers
|
||||
@@ -26,7 +26,7 @@ import torch.cuda
|
||||
|
||||
XPU=os.environ.get("XPU", "0") == "1"
|
||||
import transformers as transformers_module
|
||||
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
|
||||
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, pipeline
|
||||
from scipy.io import wavfile
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
@@ -200,6 +200,21 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
autoTokenizer = False
|
||||
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||
self.SentenceTransformer = True
|
||||
elif request.Type == "TokenClassification":
|
||||
# NER / PII tagging via HuggingFace's token-classification
|
||||
# pipeline. aggregation_strategy="simple" merges B-/I- tags
|
||||
# into single spans and gives byte offsets back. The
|
||||
# tokenizer is bundled inside the pipeline, so we skip the
|
||||
# AutoTokenizer load below.
|
||||
autoTokenizer = False
|
||||
self.tokenClassifier = pipeline(
|
||||
"token-classification",
|
||||
model=model_name,
|
||||
aggregation_strategy="simple",
|
||||
device=0 if self.CUDA else -1,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
)
|
||||
self.TokenClassification = True
|
||||
else:
|
||||
# Generic: dynamically resolve model class from transformers
|
||||
model_type = TYPE_ALIASES.get(request.Type, request.Type)
|
||||
@@ -253,6 +268,39 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def TokenClassify(self, request, context):
|
||||
# Runs HuggingFace's token-classification pipeline and returns
|
||||
# the aggregated entity spans. The pipeline gives us byte
|
||||
# offsets via aggregation_strategy="simple" (set at load
|
||||
# time), so the caller can slice the original text without
|
||||
# re-tokenising on the Go side.
|
||||
if not getattr(self, "TokenClassification", False):
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("model was not loaded as Type=TokenClassification")
|
||||
return backend_pb2.TokenClassifyResponse()
|
||||
try:
|
||||
results = self.tokenClassifier(request.text)
|
||||
except Exception as err:
|
||||
print("TokenClassify error:", err, file=sys.stderr)
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"token-classification failed: {err}")
|
||||
return backend_pb2.TokenClassifyResponse()
|
||||
|
||||
threshold = request.threshold if request.threshold > 0 else 0.0
|
||||
entities = []
|
||||
for r in results:
|
||||
score = float(r.get("score", 0.0))
|
||||
if score < threshold:
|
||||
continue
|
||||
entities.append(backend_pb2.TokenClassifyEntity(
|
||||
entity_group=str(r.get("entity_group") or r.get("entity") or ""),
|
||||
start=int(r.get("start", 0)),
|
||||
end=int(r.get("end", 0)),
|
||||
score=score,
|
||||
text=str(r.get("word", "")),
|
||||
))
|
||||
return backend_pb2.TokenClassifyResponse(entities=entities)
|
||||
|
||||
def Embedding(self, request, context):
|
||||
set_seed(request.Seed)
|
||||
# Tokenize input
|
||||
|
||||
@@ -2,9 +2,9 @@ torch==2.7.1
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
accelerate
|
||||
transformers>=5.8.0
|
||||
transformers>=5.8.1
|
||||
bitsandbytes
|
||||
sentence-transformers==5.4.0
|
||||
sentence-transformers==5.5.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -2,9 +2,9 @@ torch==2.7.1
|
||||
accelerate
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers>=5.8.0
|
||||
transformers>=5.8.1
|
||||
bitsandbytes
|
||||
sentence-transformers==5.4.0
|
||||
sentence-transformers==5.5.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -2,9 +2,9 @@
|
||||
torch==2.9.0
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers>=5.8.0
|
||||
transformers>=5.8.1
|
||||
bitsandbytes
|
||||
sentence-transformers==5.4.0
|
||||
sentence-transformers==5.5.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -1,11 +1,11 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm7.0
|
||||
torch==2.10.0+rocm7.0
|
||||
accelerate
|
||||
transformers>=5.8.0
|
||||
transformers>=5.8.1
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.4.0
|
||||
sentence-transformers==5.5.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -3,9 +3,9 @@ torch
|
||||
optimum[openvino]
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers>=5.8.0
|
||||
transformers>=5.8.1
|
||||
bitsandbytes
|
||||
sentence-transformers==5.4.0
|
||||
sentence-transformers==5.5.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -2,9 +2,9 @@ torch==2.7.1
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
accelerate
|
||||
transformers>=5.8.0
|
||||
transformers>=5.8.1
|
||||
bitsandbytes
|
||||
sentence-transformers==5.4.0
|
||||
sentence-transformers==5.5.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
|
||||
@@ -13,14 +13,14 @@ else
|
||||
fi
|
||||
|
||||
# Handle l4t build profiles (Python 3.12, pip fallback) if needed.
|
||||
# unsafe-best-match is required on l4t13 because the jetson-ai-lab index
|
||||
# lists transitive deps at limited versions — without it uv pins to the
|
||||
# first matching index and fails to resolve a compatible wheel from PyPI.
|
||||
# Since PyTorch 2.11 (April 2026) PyPI ships aarch64 + cu130 manylinux wheels
|
||||
# directly for torch/torchvision/torchaudio and an aarch64 vllm wheel pinned
|
||||
# to that torch, so the jetson-ai-lab mirror is no longer needed.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
PY_STANDALONE_TAG="20251120"
|
||||
EXTRA_PIP_INSTALL_FLAGS="${EXTRA_PIP_INSTALL_FLAGS:-} --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
@@ -42,18 +42,11 @@ if [ "x${BUILD_TYPE}" == "xhipblas" ]; then
|
||||
else
|
||||
uv pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700
|
||||
fi
|
||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
# JetPack 7 / L4T arm64 cu130 — vllm comes from the prebuilt SBSA wheel
|
||||
# at jetson-ai-lab. Version is unpinned: the index ships whatever build
|
||||
# matches the cu130/cp312 ABI. unsafe-best-match lets uv fall through
|
||||
# to PyPI for transitive deps not present on the jetson-ai-lab index.
|
||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||
pip install vllm --extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
else
|
||||
uv pip install --index-strategy=unsafe-best-match vllm --extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
fi
|
||||
elif [ "x${BUILD_PROFILE}" == "xcublas13" ]; then
|
||||
# vllm 0.19+ defaults to cu130 wheels on PyPI, no extra index needed.
|
||||
elif [ "x${BUILD_PROFILE}" == "xcublas13" ] || [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
# cublas13 (x86_64) and l4t13 (aarch64) both pull vllm from PyPI now:
|
||||
# vllm 0.19+ defaults to cu130 wheels on x86_64 and vllm 0.20+ ships an
|
||||
# aarch64 manylinux wheel pinned to torch==2.11.0. No extra index needed
|
||||
# in either case.
|
||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||
pip install vllm --torch-backend=auto
|
||||
else
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
# JetPack 7 / L4T arm64 + CUDA 13. PyPI ships aarch64 + cu130 manylinux wheels
|
||||
# for torch/torchvision/torchaudio directly since PyTorch 2.11 (April 2026),
|
||||
# so no custom index is needed. flash-attn is dropped here: PyPI has no
|
||||
# aarch64 wheel for it, but vLLM 0.20+ bundles its own vllm_flash_attn
|
||||
# (fa2 + fa3) inside the main wheel, so it is not required at runtime.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
accelerate
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
transformers
|
||||
bitsandbytes
|
||||
flash-attn
|
||||
diffusers
|
||||
librosa
|
||||
soundfile
|
||||
|
||||
@@ -356,6 +356,133 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
except Exception as e:
|
||||
return backend_pb2.Result(success=False, message=str(e))
|
||||
|
||||
async def Score(self, request, context):
|
||||
"""
|
||||
Joint log-probability of each candidate continuation given the
|
||||
shared prompt. Used by routing-policy multi-label classification
|
||||
(read the distribution rather than asking the model to emit a
|
||||
single argmax label), reranking, and reward-model scoring.
|
||||
|
||||
Implementation uses vLLM's `prompt_logprobs` to recover the
|
||||
per-token log P(token_i | tokens_<i) for the full concatenated
|
||||
sequence; the candidate's tokens are the suffix whose logprobs
|
||||
get summed. max_tokens=1 because vLLM requires at least one
|
||||
generated token; the generated token is discarded.
|
||||
"""
|
||||
if not hasattr(self, 'llm') or self.llm is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("Model not loaded")
|
||||
return backend_pb2.ScoreResponse()
|
||||
if not hasattr(self, 'tokenizer') or self.tokenizer is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("Tokenizer not available")
|
||||
return backend_pb2.ScoreResponse()
|
||||
if len(request.candidates) == 0:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("candidates must be non-empty")
|
||||
return backend_pb2.ScoreResponse()
|
||||
|
||||
try:
|
||||
prompt = request.prompt or ""
|
||||
prompt_token_ids = self.tokenizer.encode(prompt)
|
||||
prompt_len = len(prompt_token_ids)
|
||||
results = []
|
||||
|
||||
for candidate in request.candidates:
|
||||
# Tokenise the concatenated sequence. We can't naively
|
||||
# use len(prompt_tokens) + len(tokenizer.encode(candidate))
|
||||
# because BPE merges at the boundary may produce a
|
||||
# different tokenisation. Encoding the joined text and
|
||||
# walking the divergence point is the correct primitive.
|
||||
full_text = prompt + candidate
|
||||
full_token_ids = self.tokenizer.encode(full_text)
|
||||
|
||||
divergence = prompt_len
|
||||
min_len = min(prompt_len, len(full_token_ids))
|
||||
for i in range(min_len):
|
||||
if prompt_token_ids[i] != full_token_ids[i]:
|
||||
divergence = i
|
||||
break
|
||||
|
||||
candidate_token_ids = full_token_ids[divergence:]
|
||||
num_candidate_tokens = len(candidate_token_ids)
|
||||
if num_candidate_tokens == 0:
|
||||
results.append(backend_pb2.CandidateScore(
|
||||
log_prob=0.0,
|
||||
length_normalized_log_prob=0.0,
|
||||
num_tokens=0,
|
||||
))
|
||||
continue
|
||||
|
||||
sampling = SamplingParams(
|
||||
max_tokens=1,
|
||||
temperature=0.0,
|
||||
prompt_logprobs=1,
|
||||
detokenize=False,
|
||||
)
|
||||
|
||||
request_id = random_uuid()
|
||||
last_output = None
|
||||
outputs_iter = self.llm.generate(
|
||||
{"prompt": full_text},
|
||||
sampling_params=sampling,
|
||||
request_id=request_id,
|
||||
)
|
||||
try:
|
||||
async for out in outputs_iter:
|
||||
last_output = out
|
||||
finally:
|
||||
try:
|
||||
await outputs_iter.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if last_output is None or not getattr(last_output, "prompt_logprobs", None):
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details("vLLM did not return prompt_logprobs")
|
||||
return backend_pb2.ScoreResponse()
|
||||
|
||||
prompt_logprobs = last_output.prompt_logprobs
|
||||
total = 0.0
|
||||
tokens_proto = []
|
||||
for offset, tok_id in enumerate(candidate_token_ids):
|
||||
position = divergence + offset
|
||||
if position >= len(prompt_logprobs) or prompt_logprobs[position] is None:
|
||||
continue
|
||||
entry = prompt_logprobs[position]
|
||||
lp_obj = entry.get(tok_id)
|
||||
if lp_obj is not None:
|
||||
lp = lp_obj.logprob
|
||||
else:
|
||||
# Token not in top-K; vLLM's top-1 may miss it.
|
||||
# Fall back to the lowest available logprob in the
|
||||
# entry — a conservative lower-bound on the true
|
||||
# log P, biased against this candidate.
|
||||
lp = min(v.logprob for v in entry.values())
|
||||
total += lp
|
||||
if request.include_token_logprobs:
|
||||
tokens_proto.append(backend_pb2.TokenLogProb(
|
||||
token=self.tokenizer.decode([tok_id]),
|
||||
log_prob=lp,
|
||||
))
|
||||
|
||||
cs = backend_pb2.CandidateScore(
|
||||
log_prob=total,
|
||||
num_tokens=num_candidate_tokens,
|
||||
)
|
||||
if request.length_normalize and num_candidate_tokens > 0:
|
||||
cs.length_normalized_log_prob = total / num_candidate_tokens
|
||||
if tokens_proto:
|
||||
cs.tokens.extend(tokens_proto)
|
||||
results.append(cs)
|
||||
|
||||
return backend_pb2.ScoreResponse(candidates=results)
|
||||
except Exception as e:
|
||||
print(f"Score error: {e}", file=sys.stderr)
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(str(e))
|
||||
return backend_pb2.ScoreResponse()
|
||||
|
||||
async def _predict(self, request, context, streaming=False):
|
||||
# Build the sampling parameters
|
||||
# NOTE: this must stay in sync with the vllm backend
|
||||
|
||||
@@ -43,14 +43,11 @@ if [ "x${BUILD_PROFILE}" == "xcublas13" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
# JetPack 7 / L4T arm64 wheels (torch, vllm, flash-attn) live on
|
||||
# pypi.jetson-ai-lab.io and are built for cp312, so bump the venv Python
|
||||
# accordingly. JetPack 6 keeps cp310 + USE_PIP=true.
|
||||
#
|
||||
# l4t13 uses pyproject.toml (see the elif branch below) to pin only the
|
||||
# L4T-specific wheels to the jetson-ai-lab index via [tool.uv.sources].
|
||||
# That keeps PyPI as the resolution path for transitive deps like
|
||||
# anthropic/openai/propcache, which the L4T mirror's proxy 503s on.
|
||||
# JetPack 7 / L4T arm64 vllm + torch wheels come straight from PyPI now
|
||||
# (torch 2.11+ ships aarch64 + cu130 manylinux wheels and vllm 0.20+ ships
|
||||
# an aarch64 wheel pinned to that torch). They're cp312-only, so bump the
|
||||
# venv Python accordingly. JetPack 6 keeps cp310 + USE_PIP=true.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
@@ -103,25 +100,6 @@ if [ "x${BUILD_TYPE}" == "xintel" ]; then
|
||||
export CMAKE_PREFIX_PATH="$(python -c 'import site; print(site.getsitepackages()[0])'):${CMAKE_PREFIX_PATH:-}"
|
||||
VLLM_TARGET_DEVICE=xpu uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --no-deps .
|
||||
popd
|
||||
# L4T arm64 (JetPack 7): drive the install through pyproject.toml so that
|
||||
# [tool.uv.sources] can pin torch/vllm/flash-attn/torchvision/torchaudio
|
||||
# to the jetson-ai-lab index, while everything else (transitive deps and
|
||||
# PyPI-resolvable packages like transformers) comes from PyPI. Bypasses
|
||||
# installRequirements because uv pip install -r requirements.txt does not
|
||||
# honor sources — see backend/python/vllm/pyproject.toml for the rationale.
|
||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
ensureVenv
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
||||
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
|
||||
fi
|
||||
pushd "${backend_dir}"
|
||||
# Build deps first (matches installRequirements' requirements-install.txt
|
||||
# pass — fastsafetensors and friends need pybind11 in the venv before
|
||||
# their sdists can build under --no-build-isolation).
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements-install.txt
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement pyproject.toml
|
||||
popd
|
||||
runProtogen
|
||||
# FROM_SOURCE=true on a CPU build skips the prebuilt vllm wheel in
|
||||
# requirements-cpu-after.txt and compiles vllm locally against the host's
|
||||
# actual CPU. Not used by default because it takes ~30-40 minutes, but
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
# L4T arm64 (JetPack 7 / sbsa cu130) install spec for the vllm backend.
|
||||
#
|
||||
# Why this file exists, and why only the l4t13 BUILD_PROFILE consumes it:
|
||||
#
|
||||
# pypi.jetson-ai-lab.io hosts the L4T-specific torch / vllm / flash-attn
|
||||
# wheels we need on aarch64 + cuda13, but it ALSO transparently proxies the
|
||||
# rest of PyPI through `/+f/<sha>/<filename>` URLs that 503 frequently. With
|
||||
# `--extra-index-url` + `--index-strategy=unsafe-best-match` (the historical
|
||||
# fix in install.sh) uv would pick those proxy URLs for ordinary PyPI
|
||||
# packages — `anthropic`, `openai`, `propcache`, `annotated-types` — and
|
||||
# trip on the 503s. See e.g. CI run 25212201349 (anthropic-0.97.0).
|
||||
#
|
||||
# `explicit = true` on the index makes uv consult the L4T mirror ONLY for
|
||||
# packages mapped under [tool.uv.sources]. Everything else goes to PyPI.
|
||||
# This breaks the historical 503 path without losing access to the L4T
|
||||
# wheels we actually need from there.
|
||||
#
|
||||
# `uv pip install -r requirements.txt` does NOT honor [tool.uv.sources]
|
||||
# (sources are project-mode only, not pip-compat mode), so install.sh's
|
||||
# l4t13 branch invokes `uv pip install --requirement pyproject.toml`
|
||||
# directly. Other BUILD_PROFILEs continue to use the requirements-*.txt
|
||||
# pipeline through libbackend.sh's installRequirements and never read
|
||||
# this file.
|
||||
[project]
|
||||
name = "localai-vllm-l4t13"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.12,<3.13"
|
||||
dependencies = [
|
||||
# Mirror of requirements.txt — kept in sync manually for now since the
|
||||
# l4t13 path bypasses installRequirements (see install.sh).
|
||||
"grpcio==1.80.0",
|
||||
"protobuf",
|
||||
"certifi",
|
||||
"setuptools",
|
||||
"pillow",
|
||||
"charset-normalizer>=3.4.7",
|
||||
"chardet",
|
||||
# L4T-specific accelerator stack (sourced from jetson-ai-lab below).
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
"flash-attn",
|
||||
"vllm",
|
||||
# PyPI-resolvable packages that complete the runtime — accelerate,
|
||||
# transformers, bitsandbytes carry their own wheels for aarch64.
|
||||
"accelerate",
|
||||
"transformers",
|
||||
"bitsandbytes",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "jetson-ai-lab"
|
||||
url = "https://pypi.jetson-ai-lab.io/sbsa/cu130"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "jetson-ai-lab" }
|
||||
torchvision = { index = "jetson-ai-lab" }
|
||||
torchaudio = { index = "jetson-ai-lab" }
|
||||
flash-attn = { index = "jetson-ai-lab" }
|
||||
vllm = { index = "jetson-ai-lab" }
|
||||
4
backend/python/vllm/requirements-l4t13-after.txt
Normal file
4
backend/python/vllm/requirements-l4t13-after.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# vLLM 0.20+ ships an aarch64 manylinux wheel on PyPI whose Requires-Dist pins
|
||||
# torch==2.11.0 / torchvision==0.26.0 / torchaudio==2.11.0, locking an ABI-
|
||||
# consistent set with the cu130 torch wheel installed above.
|
||||
vllm
|
||||
8
backend/python/vllm/requirements-l4t13.txt
Normal file
8
backend/python/vllm/requirements-l4t13.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
# JetPack 7 / L4T arm64 + CUDA 13. Since PyTorch 2.11 (April 2026), PyPI ships
|
||||
# aarch64 + cu130 manylinux wheels for torch/torchvision/torchaudio directly,
|
||||
# so we no longer need a custom --extra-index-url for the L4T mirror.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
accelerate
|
||||
torch
|
||||
transformers
|
||||
bitsandbytes
|
||||
@@ -375,6 +375,15 @@ impl Backend for KokorosService {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
type AudioToAudioStreamStream = ReceiverStream<Result<backend::AudioToAudioResponse, Status>>;
|
||||
|
||||
async fn audio_to_audio_stream(
|
||||
&self,
|
||||
_: Request<tonic::Streaming<backend::AudioToAudioRequest>>,
|
||||
) -> Result<Response<Self::AudioToAudioStreamStream>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn sound_generation(
|
||||
&self,
|
||||
_: Request<backend::SoundGenerationRequest>,
|
||||
|
||||
@@ -9,11 +9,18 @@ import (
|
||||
|
||||
corebackend "github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
"github.com/mudler/LocalAI/core/services/facerecognition"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/routing/admission"
|
||||
"github.com/mudler/LocalAI/core/services/routing/billing"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/mitm"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/services/voicerecognition"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
pkggrpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
@@ -51,6 +58,22 @@ type Application struct {
|
||||
faceRegistry facerecognition.Registry
|
||||
voiceRegistry voicerecognition.Registry
|
||||
authDB *gorm.DB
|
||||
metricsService *monitoring.LocalAIMetricsService
|
||||
statsRecorder *billing.Recorder
|
||||
fallbackUser *auth.User
|
||||
piiRedactor *pii.Redactor
|
||||
piiEvents pii.EventStore
|
||||
mitmCA atomic.Pointer[mitm.CA]
|
||||
mitmServer atomic.Pointer[mitm.Server]
|
||||
mitmMutex sync.Mutex // serializes Stop+Start; readers use atomic loads
|
||||
// mitmHostConflicts records duplicate-host claims across model configs.
|
||||
// Non-empty disables the MITM listener until resolved — the strict
|
||||
// 1-to-1 host↔model invariant the dispatcher relies on. Read by
|
||||
// /api/middleware/status so the admin UI can surface the cause.
|
||||
mitmHostConflicts atomic.Pointer[map[string][]string]
|
||||
routerDecisions router.DecisionStore
|
||||
routerRegistry *router.Registry
|
||||
admissionLimiter *admission.Limiter
|
||||
watchdogMutex sync.Mutex
|
||||
watchdogStop chan bool
|
||||
p2pMutex sync.Mutex
|
||||
@@ -185,6 +208,103 @@ func (a *Application) AuthDB() *gorm.DB {
|
||||
return a.authDB
|
||||
}
|
||||
|
||||
// MetricsService returns the OTel + Prometheus metric service. nil when
|
||||
// --disable-metrics is set or initialisation failed at startup.
|
||||
//
|
||||
// The service is created in startup.go before any counter is registered
|
||||
// so that otel.SetMeterProvider runs early enough for the billing
|
||||
// recorder's counters to bind to the Prom-backed provider rather than
|
||||
// the no-op global. core/http/app.go reuses this instance instead of
|
||||
// constructing its own — two providers would orphan one set of counters
|
||||
// behind whichever provider lost the SetMeterProvider race.
|
||||
func (a *Application) MetricsService() *monitoring.LocalAIMetricsService {
|
||||
return a.metricsService
|
||||
}
|
||||
|
||||
// StatsRecorder returns the billing recorder used by the usage
|
||||
// middleware. It is non-nil whenever stats are not explicitly disabled
|
||||
// — i.e., the no-auth single-user path still gets a working recorder
|
||||
// (in-memory by default). Routes register UsageMiddleware against this
|
||||
// recorder regardless of auth state.
|
||||
func (a *Application) StatsRecorder() *billing.Recorder {
|
||||
return a.statsRecorder
|
||||
}
|
||||
|
||||
// FallbackUser is the synthetic "local" user that UsageMiddleware uses
|
||||
// to attribute requests when no authenticated user is on the context
|
||||
// (i.e., --auth is off). nil when auth is on, since real users are
|
||||
// always available there.
|
||||
func (a *Application) FallbackUser() *auth.User {
|
||||
return a.fallbackUser
|
||||
}
|
||||
|
||||
// PIIRedactor returns the regex-tier PII redactor or nil if PII
|
||||
// filtering is disabled. The chat-route middleware uses this to apply
|
||||
// redaction before dispatch.
|
||||
func (a *Application) PIIRedactor() *pii.Redactor {
|
||||
return a.piiRedactor
|
||||
}
|
||||
|
||||
// PIIEvents returns the PII event store. Same nil-when-disabled
|
||||
// semantics as PIIRedactor; admin REST and MCP read tools call List
|
||||
// against it.
|
||||
func (a *Application) PIIEvents() pii.EventStore {
|
||||
return a.piiEvents
|
||||
}
|
||||
|
||||
// MITMCA returns the cloudproxy MITM proxy's CA, or nil when the
|
||||
// MITM listener is disabled.
|
||||
func (a *Application) MITMCA() *mitm.CA { return a.mitmCA.Load() }
|
||||
|
||||
// MITMServer returns the running MITM proxy or nil.
|
||||
func (a *Application) MITMServer() *mitm.Server { return a.mitmServer.Load() }
|
||||
|
||||
// MITMHostConflicts returns a snapshot of host→[]model-name pairs that
|
||||
// are claimed by 2+ model configs. Empty when the 1-to-1 invariant
|
||||
// holds. Non-empty disables the MITM listener — read by the admin
|
||||
// status endpoint to explain why.
|
||||
func (a *Application) MITMHostConflicts() map[string][]string {
|
||||
p := a.mitmHostConflicts.Load()
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return *p
|
||||
}
|
||||
|
||||
// MITMHostOwners returns the host→model-name map, useful for the
|
||||
// admin status endpoint. The lookup is recomputed on each call to
|
||||
// stay current with model-config edits without needing a
|
||||
// MITMRestart.
|
||||
func (a *Application) MITMHostOwners() map[string]string {
|
||||
if a.backendLoader == nil {
|
||||
return nil
|
||||
}
|
||||
return a.backendLoader.MITMHostOwners().Owners
|
||||
}
|
||||
|
||||
// RouterDecisions returns the routing decision store. nil when stats
|
||||
// are disabled (--disable-stats); the RouteModel middleware skips the
|
||||
// log write in that case but still rewrites requests.
|
||||
func (a *Application) RouterDecisions() router.DecisionStore {
|
||||
return a.routerDecisions
|
||||
}
|
||||
|
||||
// RouterClassifierRegistry returns the process-wide classifier cache.
|
||||
// Shared between the OpenAI and Anthropic route middlewares so the
|
||||
// admin stats endpoint sees every live classifier — and so a
|
||||
// classifier built on the OpenAI route is reused on Anthropic.
|
||||
func (a *Application) RouterClassifierRegistry() *router.Registry {
|
||||
return a.routerRegistry
|
||||
}
|
||||
|
||||
// AdmissionLimiter returns the per-model admission limiter. The
|
||||
// admission middleware uses it to gate concurrent requests; the
|
||||
// admin status surface reads InFlight/Capacity from it for live
|
||||
// load visibility.
|
||||
func (a *Application) AdmissionLimiter() *admission.Limiter {
|
||||
return a.admissionLimiter
|
||||
}
|
||||
|
||||
// StartupConfig returns the original startup configuration (from env vars, before file loading)
|
||||
func (a *Application) StartupConfig() *config.ApplicationConfig {
|
||||
return a.startupConfig
|
||||
@@ -255,6 +375,15 @@ func (a *Application) start() error {
|
||||
a.modelLoader,
|
||||
a.galleryService,
|
||||
)
|
||||
// Wire usage tracking so the assistant's get_usage_stats tool
|
||||
// returns real data; nil values keep the tool returning a clear
|
||||
// "unavailable" error if startup ran with --disable-stats.
|
||||
assistantClient.StatsRecorder = a.statsRecorder
|
||||
assistantClient.FallbackUser = a.fallbackUser
|
||||
// PII filter — same nil-or-real wiring.
|
||||
assistantClient.PIIRedactor = a.piiRedactor
|
||||
assistantClient.PIIEvents = a.piiEvents
|
||||
assistantClient.RouterDecisions = a.routerDecisions
|
||||
if err := holder.Initialize(a.applicationConfig.Context, assistantClient, localaitools.Options{}); err != nil {
|
||||
// Why log+continue instead of fail: the assistant is an optional
|
||||
// feature; a failure here must not take down the whole server.
|
||||
|
||||
@@ -233,7 +233,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
xlog.Info("File stager initialized (HTTP direct transfer)")
|
||||
}
|
||||
// Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(
|
||||
registry,
|
||||
natsClient,
|
||||
cfg.Distributed.BackendInstallTimeoutOrDefault(),
|
||||
cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
|
||||
)
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
var conflictResolver nodes.ConcurrencyConflictResolver
|
||||
|
||||
146
core/application/mitm.go
Normal file
146
core/application/mitm.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/mitm"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
func startMITMProxy(app *Application, options *config.ApplicationConfig) error {
|
||||
app.mitmMutex.Lock()
|
||||
defer app.mitmMutex.Unlock()
|
||||
return startMITMLocked(app, options)
|
||||
}
|
||||
|
||||
func startMITMLocked(app *Application, options *config.ApplicationConfig) error {
|
||||
// Validate the host↔model-config 1-to-1 invariant before binding
|
||||
// the listener. Two configs claiming the same host means the
|
||||
// dispatcher would have ambiguous PII settings; refuse to start
|
||||
// rather than silently picking one. The conflict map is published
|
||||
// for /api/middleware/status to surface in the UI.
|
||||
ownership := app.backendLoader.MITMHostOwners()
|
||||
if len(ownership.Conflicts) > 0 {
|
||||
conflicts := ownership.Conflicts
|
||||
app.mitmHostConflicts.Store(&conflicts)
|
||||
hosts := make([]string, 0, len(conflicts))
|
||||
for h := range conflicts {
|
||||
hosts = append(hosts, h)
|
||||
}
|
||||
sort.Strings(hosts)
|
||||
xlog.Error("mitm: refusing to start — duplicate host claims across model configs",
|
||||
"hosts", hosts,
|
||||
"conflicts", conflicts,
|
||||
)
|
||||
return errors.New("mitm: configuration error: duplicate host claims (see /api/middleware/status)")
|
||||
}
|
||||
app.mitmHostConflicts.Store(nil)
|
||||
|
||||
caDir := options.MITMCADir
|
||||
if caDir == "" {
|
||||
base := options.DataPath
|
||||
if base == "" {
|
||||
base = "."
|
||||
}
|
||||
caDir = filepath.Join(base, "mitm-ca")
|
||||
}
|
||||
|
||||
if app.mitmCA.Load() == nil {
|
||||
ca, err := mitm.LoadOrCreateCA(caDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ca: %w", err)
|
||||
}
|
||||
app.mitmCA.Store(ca)
|
||||
}
|
||||
|
||||
// Allowlist is exactly the set of hosts claimed by model configs.
|
||||
// No global list — admins add hosts by creating an MITM model
|
||||
// config (template available in the Add Model UI). When no config
|
||||
// claims any host, the listener still starts but every CONNECT
|
||||
// tunnels through unmodified.
|
||||
effectiveHosts := make([]string, 0, len(ownership.Owners))
|
||||
for h := range ownership.Owners {
|
||||
effectiveHosts = append(effectiveHosts, h)
|
||||
}
|
||||
sort.Strings(effectiveHosts)
|
||||
|
||||
// Per-host PII gate inherits from the owning model's pii.enabled.
|
||||
// A non-cloud-proxy backend with no explicit pii.enabled resolves
|
||||
// to false → host is intercepted but the regex pass is skipped
|
||||
// (audit events still record).
|
||||
var piiDisabled []string
|
||||
for host, modelName := range ownership.Owners {
|
||||
cfg, exists := app.backendLoader.GetModelConfig(modelName)
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
if !cfg.PIIIsEnabled() {
|
||||
piiDisabled = append(piiDisabled, host)
|
||||
}
|
||||
}
|
||||
|
||||
handler := mitm.NewPIIHandler(mitm.PIIHandlerOptions{
|
||||
Redactor: app.piiRedactor,
|
||||
EventStore: app.piiEvents,
|
||||
HostsWithPIIDisabled: piiDisabled,
|
||||
})
|
||||
|
||||
srv, err := mitm.NewServer(mitm.Config{
|
||||
Addr: options.MITMListen,
|
||||
CA: app.mitmCA.Load(),
|
||||
InterceptHosts: effectiveHosts,
|
||||
Handler: handler,
|
||||
EventStore: app.piiEvents,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("server: %w", err)
|
||||
}
|
||||
if err := srv.Start(); err != nil {
|
||||
return fmt.Errorf("listen: %w", err)
|
||||
}
|
||||
app.mitmServer.Store(srv)
|
||||
|
||||
xlog.Info("mitm: cloudproxy listener started",
|
||||
"addr", srv.Addr(),
|
||||
"ca_dir", caDir,
|
||||
"intercept_hosts", effectiveHosts,
|
||||
"model_owned_hosts", len(ownership.Owners),
|
||||
"pii_disabled_hosts", len(piiDisabled),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopMITM is idempotent.
|
||||
func (a *Application) StopMITM() error {
|
||||
a.mitmMutex.Lock()
|
||||
defer a.mitmMutex.Unlock()
|
||||
stopMITMLocked(a)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestartMITM reuses the existing CA so trusted clients keep
|
||||
// working across listener flips.
|
||||
func (a *Application) RestartMITM() error {
|
||||
a.mitmMutex.Lock()
|
||||
defer a.mitmMutex.Unlock()
|
||||
stopMITMLocked(a)
|
||||
if a.applicationConfig.MITMListen == "" {
|
||||
xlog.Info("mitm: cloudproxy listener stays disabled (no listen address)")
|
||||
return nil
|
||||
}
|
||||
return startMITMLocked(a, a.applicationConfig)
|
||||
}
|
||||
|
||||
func stopMITMLocked(a *Application) {
|
||||
srv := a.mitmServer.Load()
|
||||
if srv == nil {
|
||||
return
|
||||
}
|
||||
srv.Stop()
|
||||
a.mitmServer.Store(nil)
|
||||
xlog.Info("mitm: cloudproxy listener stopped")
|
||||
}
|
||||
63
core/application/router_factories.go
Normal file
63
core/application/router_factories.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
// adapterConfig resolves a model name to its runtime ModelConfig, or
|
||||
// nil when the name is unknown. Shared by the router-facing factories
|
||||
// below and by ModelConfigLookup.
|
||||
func (a *Application) adapterConfig(modelName string) *config.ModelConfig {
|
||||
cfg, err := a.backendLoader.LoadModelConfigFileByNameDefaultOptions(modelName, a.applicationConfig)
|
||||
if err != nil || cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// ModelConfigLookup is the lookup function the router middleware's
|
||||
// classifier validator uses to confirm classifier_model declares
|
||||
// FLAG_SCORE before binding it.
|
||||
func (a *Application) ModelConfigLookup() func(modelName string) *config.ModelConfig {
|
||||
return a.adapterConfig
|
||||
}
|
||||
|
||||
// Scorer returns a backend.Scorer bound to the named model, or nil
|
||||
// when the model is unknown. Used as a method value (app.Scorer) by
|
||||
// router.ClassifierDeps — no factory-of-factory wrapper needed.
|
||||
func (a *Application) Scorer(modelName string) backend.Scorer {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewScorer(a.modelLoader, *cfg, a.applicationConfig)
|
||||
}
|
||||
|
||||
// Reranker returns a backend.Reranker bound to the named model, or
|
||||
// nil when unknown. The reranker model's `type:` (e.g. "colbert")
|
||||
// selects the scoring head inside the rerankers backend.
|
||||
func (a *Application) Reranker(modelName string) backend.Reranker {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewReranker(a.modelLoader, *cfg, a.applicationConfig)
|
||||
}
|
||||
|
||||
// Embedder returns a backend.Embedder bound to the named model, or
|
||||
// nil when unknown. Used by the router's L2 embedding cache.
|
||||
func (a *Application) Embedder(modelName string) backend.Embedder {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewEmbedder(a.modelLoader, *cfg, a.applicationConfig)
|
||||
}
|
||||
|
||||
// VectorStore returns a backend.VectorStore for the named collection,
|
||||
// or nil when the name is empty. Each router model gets its own
|
||||
// backend process via the model loader's cache keyed by storeName.
|
||||
func (a *Application) VectorStore(storeName string) backend.VectorStore {
|
||||
return backend.NewVectorStore(a.modelLoader, a.applicationConfig, storeName)
|
||||
}
|
||||
@@ -87,6 +87,28 @@ var _ = Describe("loadRuntimeSettingsFromFile", func() {
|
||||
})
|
||||
})
|
||||
|
||||
// MITM listener address. The file is the only source — no env var
|
||||
// exists — so a regression here means an admin who configured the
|
||||
// listener via /api/settings loses it after a reboot, even though
|
||||
// the value is still on disk in the volume. (Intercept hosts now
|
||||
// live in model YAML mitm.hosts: blocks, not runtime_settings.json.)
|
||||
Describe("MITM fields", func() {
|
||||
It("loads mitm_listen", func() {
|
||||
cfg := &config.ApplicationConfig{DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`)}
|
||||
loadRuntimeSettingsFromFile(cfg)
|
||||
Expect(cfg.MITMListen).To(Equal(":8443"))
|
||||
})
|
||||
|
||||
It("does not override an explicit CLI flag", func() {
|
||||
cfg := &config.ApplicationConfig{
|
||||
DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`),
|
||||
MITMListen: ":9999", // simulate WithMITMListen(":9999")
|
||||
}
|
||||
loadRuntimeSettingsFromFile(cfg)
|
||||
Expect(cfg.MITMListen).To(Equal(":9999"), "CLI flag must win over the persisted file value")
|
||||
})
|
||||
})
|
||||
|
||||
// The Agent Pool block has a mix of zero and non-zero defaults
|
||||
// (Enabled=true, EmbeddingModel="granite-...", MaxChunkingSize=400,
|
||||
// VectorEngine="chromem", AgentHubURL="https://agenthub.localai.io").
|
||||
|
||||
@@ -15,11 +15,18 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/routing/admission"
|
||||
"github.com/mudler/LocalAI/core/services/routing/billing"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
@@ -128,6 +135,117 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}()
|
||||
}
|
||||
|
||||
// Initialize the OTel + Prometheus metric pipeline before any
|
||||
// counter is created. monitoring.NewLocalAIMetricsService calls
|
||||
// otel.SetMeterProvider, so any subsequent otel.Meter() call —
|
||||
// including billing.NewRecorder below — sees the real provider
|
||||
// rather than the no-op global. Initialising metrics later (in
|
||||
// core/http/app.go) leaves billing's counters bound to a no-op
|
||||
// meter and never reaches /metrics. We deliberately ignore
|
||||
// DisableMetrics here for ordering purposes; the HTTP middleware
|
||||
// that records api_call histograms is still gated.
|
||||
if !options.DisableMetrics {
|
||||
ms, err := monitoring.NewLocalAIMetricsService()
|
||||
if err != nil {
|
||||
xlog.Error("failed to initialize metrics provider", "error", err)
|
||||
} else {
|
||||
application.metricsService = ms
|
||||
// Bind the billing package's counters to the same meter the
|
||||
// metrics service exports. Without this, billing's counters
|
||||
// resolve via the OTel global and never reach /metrics.
|
||||
billing.SetMeter(ms.Meter)
|
||||
}
|
||||
}
|
||||
|
||||
// Wire the routing-module billing recorder. The recorder runs in
|
||||
// every mode (auth on/off, distributed/single-node) so that token
|
||||
// tracking is not gated on auth — a no-auth single-user box still
|
||||
// gets dashboards and `/api/usage` populated.
|
||||
//
|
||||
// fallbackUser is wired *unconditionally* when stats are enabled.
|
||||
// UsageMiddleware uses it as the attribution source whenever
|
||||
// auth.GetUser(c) is nil — that covers (a) no-auth deployments and
|
||||
// (b) internal callers under auth-on (cron flushers, distributed
|
||||
// worker callbacks) that hit a recordable endpoint without a user
|
||||
// in context. The billing.user_id_present invariant still rejects
|
||||
// empty IDs; LocalUser() returns a stable UUID per data path.
|
||||
if !options.DisableStats {
|
||||
var statsBackend billing.StatsBackend
|
||||
switch {
|
||||
case application.authDB != nil:
|
||||
statsBackend = billing.NewGormBackend(application.authDB, 0, 0)
|
||||
xlog.Info("stats: using auth DB for usage records")
|
||||
default:
|
||||
statsBackend = billing.NewMemoryBackend(0)
|
||||
xlog.Info("stats: using in-memory ring buffer (no-auth single-user mode)")
|
||||
}
|
||||
application.fallbackUser = billing.LocalUser(options.DataPath)
|
||||
application.statsRecorder = billing.NewRecorder(statsBackend)
|
||||
// Drain pending records on SIGTERM. The GORM backend buffers up
|
||||
// to maxPending (5k) records across a 5s flush tick, so without
|
||||
// this the last few seconds of usage disappear on graceful exit.
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
_ = application.statsRecorder.Close()
|
||||
})
|
||||
xlog.Info("stats: fallback user wired", "local_user_id", application.fallbackUser.ID)
|
||||
} else {
|
||||
xlog.Info("stats: disabled by --disable-stats")
|
||||
}
|
||||
|
||||
// Wire the regex PII filter. Default-on: a single-user box gets
|
||||
// the built-in pattern set the first time it starts, with email/
|
||||
// phone/SSN/credit-card on mask and api_key_prefix on block. If
|
||||
// the operator wants different actions, --pii-config points at a
|
||||
// YAML file that overrides per-id; --disable-pii turns it off
|
||||
// entirely.
|
||||
if !options.DisablePII {
|
||||
patterns, err := pii.LoadConfig(options.PIIConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pii config: %w", err)
|
||||
}
|
||||
application.piiRedactor = pii.NewRedactor(patterns)
|
||||
application.piiEvents = pii.NewMemoryEventStore(0)
|
||||
// Apply persisted per-pattern overrides — admins toggling
|
||||
// action/disabled via the UI and clicking "Save to disk" land
|
||||
// here on the next start. Bad ids are warned and ignored so a
|
||||
// stale entry doesn't block startup.
|
||||
for id, ov := range options.PIIPatternOverrides {
|
||||
if ov.Action != nil {
|
||||
if err := application.piiRedactor.SetAction(id, pii.Action(*ov.Action)); err != nil {
|
||||
xlog.Warn("pii: persisted override skipped", "pattern", id, "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if ov.Disabled != nil {
|
||||
if err := application.piiRedactor.SetDisabled(id, *ov.Disabled); err != nil {
|
||||
xlog.Warn("pii: persisted disable skipped", "pattern", id, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
xlog.Info("pii: filter enabled",
|
||||
"patterns", len(patterns),
|
||||
"config_path", options.PIIConfigPath,
|
||||
"persisted_overrides", len(options.PIIPatternOverrides),
|
||||
)
|
||||
} else {
|
||||
xlog.Info("pii: disabled by --disable-pii")
|
||||
}
|
||||
|
||||
// Wire the routing decision log. Always-on when stats are enabled —
|
||||
// the per-router admin page reads this as the live activity feed
|
||||
// and as input to drift checks for subsystem 5.
|
||||
if !options.DisableStats {
|
||||
application.routerDecisions = router.NewMemoryDecisionStore(0)
|
||||
}
|
||||
// Process-wide classifier cache shared across all route middlewares so
|
||||
// the embedding-cache stats endpoint sees a single source of truth.
|
||||
application.routerRegistry = router.NewRegistry()
|
||||
|
||||
// Subsystem 5: admission control. Limiter is always wired so a
|
||||
// model that gains a limits: block via gallery install or YAML
|
||||
// edit takes effect on the next restart without conditional plumbing.
|
||||
application.admissionLimiter = admission.New()
|
||||
|
||||
// Wire JobStore for DB-backed task/job persistence whenever auth DB is available.
|
||||
// This ensures tasks and jobs survive restarts in both single-node and distributed modes.
|
||||
if application.authDB != nil && application.agentJobService != nil {
|
||||
@@ -195,12 +313,36 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
|
||||
}
|
||||
// Hydrate from the store first so the wildcard subscriber finds an
|
||||
// already-populated statuses map for any operations still in flight
|
||||
// on a peer replica.
|
||||
if err := application.galleryService.Hydrate(); err != nil {
|
||||
xlog.Warn("Gallery service hydrate failed", "error", err)
|
||||
}
|
||||
// Bind cache-invalidation handler before SubscribeBroadcasts so the
|
||||
// first inbound event is already routed. Peer replicas install a
|
||||
// model and broadcast on SubjectCacheInvalidateModels; this
|
||||
// callback re-runs LoadModelConfigsFromPath so a subsequent chat
|
||||
// completion that load-balances onto this replica finds the new
|
||||
// config. The originating replica reloads inline in modelHandler
|
||||
// and never enters this path.
|
||||
gs := application.galleryService
|
||||
sys := options.SystemState
|
||||
cfgLoaderOpts := options.ToConfigLoaderOptions()
|
||||
gs.OnModelsChanged = func(_ messaging.CacheInvalidateEvent) {
|
||||
if err := application.ModelConfigLoader().LoadModelConfigsFromPath(sys.Model.ModelsPath, cfgLoaderOpts...); err != nil {
|
||||
xlog.Warn("Failed to reload model configs after peer invalidation", "error", err)
|
||||
}
|
||||
}
|
||||
if err := application.galleryService.SubscribeBroadcasts(); err != nil {
|
||||
xlog.Warn("Gallery service subscribe failed", "error", err)
|
||||
}
|
||||
// Wire distributed model/backend managers so delete propagates to workers
|
||||
application.galleryService.SetModelManager(
|
||||
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
|
||||
)
|
||||
application.galleryService.SetBackendManager(
|
||||
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry),
|
||||
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry, application.galleryService),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -212,12 +354,12 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
||||
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.RequireBackendIntegrity, nil, options.ModelsURL...); err != nil {
|
||||
xlog.Error("error installing models", "error", err)
|
||||
}
|
||||
|
||||
for _, backend := range options.ExternalBackends {
|
||||
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
|
||||
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", "", options.RequireBackendIntegrity); err != nil {
|
||||
xlog.Error("error installing external backend", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -267,13 +409,13 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
|
||||
if options.PreloadJSONModels != "" {
|
||||
if err := galleryop.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
|
||||
if err := galleryop.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels, options.RequireBackendIntegrity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.PreloadModelsFromPath != "" {
|
||||
if err := galleryop.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
|
||||
if err := galleryop.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath, options.RequireBackendIntegrity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -291,6 +433,20 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
loadRuntimeSettingsFromFile(options)
|
||||
}
|
||||
|
||||
// Wire the cloudproxy MITM listener. Opt-in: empty MITMListen
|
||||
// means "no MITM" — operators must explicitly choose to start
|
||||
// it because clients have to install the generated CA cert.
|
||||
// The handler reuses the global redactor + event store so an
|
||||
// admin who's already configured PII filtering for direct API
|
||||
// traffic doesn't need a parallel config for MITM traffic.
|
||||
// Runs after loadRuntimeSettingsFromFile so a listener configured
|
||||
// via /api/settings is brought back up across restarts.
|
||||
if options.MITMListen != "" {
|
||||
if err := startMITMProxy(application, options); err != nil {
|
||||
return nil, fmt.Errorf("mitm: startup: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
application.ModelLoader().SetBackendLoggingEnabled(options.EnableBackendLogging)
|
||||
|
||||
// turn off any process that was started by GRPC if the context is canceled
|
||||
@@ -552,6 +708,13 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
options.TracingMaxItems = *settings.TracingMaxItems
|
||||
}
|
||||
}
|
||||
if settings.TracingMaxBodyBytes != nil {
|
||||
// Allow the on-disk setting to override the CLI/env default. The
|
||||
// startup default is non-zero (see NewApplicationConfig), so a plain
|
||||
// `== 0` guard like the others would never trigger; we instead respect
|
||||
// any value the file specifies. 0 in the file means "uncapped".
|
||||
options.TracingMaxBodyBytes = *settings.TracingMaxBodyBytes
|
||||
}
|
||||
|
||||
// Branding / whitelabeling. There are no env vars for these — the file is
|
||||
// the only source — so apply unconditionally. Without this block a server
|
||||
@@ -573,6 +736,25 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
options.Branding.FaviconFile = *settings.FaviconFile
|
||||
}
|
||||
|
||||
// MITM listener address. The CLI flag WithMITMListen populates
|
||||
// options at startup; if the user configured MITM via /api/settings
|
||||
// after the fact, only the file holds the value. Apply when the
|
||||
// CLI flag did not already set it. (Intercept hosts now live in
|
||||
// model YAML mitm.hosts: rather than runtime_settings.json.)
|
||||
if settings.MITMListen != nil && options.MITMListen == "" {
|
||||
options.MITMListen = *settings.MITMListen
|
||||
}
|
||||
|
||||
// PII pattern overrides — file is the only source; CLI flags don't
|
||||
// reach into this map. Apply unconditionally when present; the
|
||||
// redactor wiring below sees the result on first construction.
|
||||
if settings.PIIPatternOverrides != nil {
|
||||
options.PIIPatternOverrides = make(map[string]config.PIIPatternRuntimeOverride, len(*settings.PIIPatternOverrides))
|
||||
for id, ov := range *settings.PIIPatternOverrides {
|
||||
options.PIIPatternOverrides[id] = ov
|
||||
}
|
||||
}
|
||||
|
||||
// Backend upgrade flags
|
||||
if settings.AutoUpgradeBackends != nil {
|
||||
if !options.AutoUpgradeBackends {
|
||||
|
||||
@@ -217,7 +217,7 @@ func (uc *UpgradeChecker) runCheck(ctx context.Context) {
|
||||
err = bm.UpgradeBackend(ctx, name, nil)
|
||||
} else {
|
||||
err = gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
|
||||
uc.galleries, name, nil)
|
||||
uc.galleries, name, nil, uc.appConfig.RequireBackendIntegrity)
|
||||
}
|
||||
if err != nil {
|
||||
xlog.Error("Failed to auto-upgrade backend",
|
||||
|
||||
@@ -78,7 +78,7 @@ func ModelAudioTransform(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func ModelAudioTransform(
|
||||
data["sample_rate"] = res.SampleRate
|
||||
data["samples"] = res.Samples
|
||||
data["reference_provided"] = res.ReferenceProvided
|
||||
if snippet := trace.AudioSnippet(dst); snippet != nil {
|
||||
if snippet := trace.AudioSnippet(dst, appConfig.TracingMaxBodyBytes); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
}
|
||||
|
||||
169
core/backend/ctx_propagation_test.go
Normal file
169
core/backend/ctx_propagation_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package backend_test
|
||||
|
||||
// Regression spec for X-LocalAI-Node coverage on audio/image/TTS/rerank/VAD.
|
||||
//
|
||||
// The X-LocalAI-Node middleware (core/http/middleware.ExposeNodeHeader)
|
||||
// works end-to-end only if the per-request holder attached to the HTTP
|
||||
// request context reaches the SmartRouter via ml.Load(opts...). The chain
|
||||
// is:
|
||||
//
|
||||
// handler -> backend.Foo(ctx, ...) -> ModelOptions(cfg, app, WithContext(ctx))
|
||||
// -> ml.Load(opts...) -> grpcModel(..., o.context) -> modelRouter(ctx, ...)
|
||||
// -> SmartRouter -> distributedhdr.Stamp(ctx, nodeID)
|
||||
//
|
||||
// If any backend helper drops `ctx` and lets ModelOptions fall back to the
|
||||
// app context, the router never sees the per-request holder and the
|
||||
// header silently stays empty for that endpoint. These specs pin the
|
||||
// request-context-reaches-router contract for the five backend helpers
|
||||
// that were previously dropping ctx between the handler and Load.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
pbproto "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// newCapturingLoader returns a ModelLoader wired with a stub model router
|
||||
// that captures the context it receives and then short-circuits with a
|
||||
// sentinel error. The router callback is the exact seam where the
|
||||
// SmartRouter would call distributedhdr.Stamp in production, so observing
|
||||
// the holder here is equivalent to observing it at the real router.
|
||||
func newCapturingLoader() (*model.ModelLoader, *atomic.Value, func() context.Context) {
|
||||
loader := model.NewModelLoader(&system.SystemState{})
|
||||
var captured atomic.Value
|
||||
loader.SetModelRouter(func(ctx context.Context, _ string, _, _, _ string, _ *pbproto.ModelOptions, _ bool) (*model.Model, error) {
|
||||
captured.Store(ctx)
|
||||
// Return an error so the backend short-circuits before trying to
|
||||
// dial gRPC. We only care about the context-arrival contract.
|
||||
return nil, errRouterShortCircuit
|
||||
})
|
||||
get := func() context.Context {
|
||||
v, _ := captured.Load().(context.Context)
|
||||
return v
|
||||
}
|
||||
return loader, &captured, get
|
||||
}
|
||||
|
||||
var errRouterShortCircuit = sentinelErr("router short-circuit (test)")
|
||||
|
||||
type sentinelErr string
|
||||
|
||||
func (s sentinelErr) Error() string { return string(s) }
|
||||
|
||||
func newAppCfg() *config.ApplicationConfig {
|
||||
return config.NewApplicationConfig(config.WithSystemState(&system.SystemState{}))
|
||||
}
|
||||
|
||||
func newModelCfg() config.ModelConfig {
|
||||
threads := 1
|
||||
cfg := config.ModelConfig{
|
||||
Name: "test-model",
|
||||
Backend: "stub-backend",
|
||||
Threads: &threads,
|
||||
}
|
||||
cfg.Model = "test.bin"
|
||||
return cfg
|
||||
}
|
||||
|
||||
var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
|
||||
const fakeNodeID = "node-ctx-propagation-7"
|
||||
|
||||
var (
|
||||
appCfg *config.ApplicationConfig
|
||||
modelCfg config.ModelConfig
|
||||
loader *model.ModelLoader
|
||||
routerCtxOf func() context.Context
|
||||
holder *atomic.Value
|
||||
reqCtx context.Context
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
appCfg = newAppCfg()
|
||||
modelCfg = newModelCfg()
|
||||
loader, _, routerCtxOf = newCapturingLoader()
|
||||
holder = distributedhdr.NewHolder()
|
||||
reqCtx = distributedhdr.WithHolder(context.Background(), holder)
|
||||
})
|
||||
|
||||
// stampViaRouterCtx asserts the captured router context carries the
|
||||
// SAME holder that was attached to the request. We verify by stamping
|
||||
// through the router-side ctx and observing the value via the
|
||||
// request-side holder; if the holders were different objects the load
|
||||
// would return "".
|
||||
stampViaRouterCtx := func() {
|
||||
routerCtx := routerCtxOf()
|
||||
Expect(routerCtx).ToNot(BeNil(), "router callback must have been invoked")
|
||||
distributedhdr.Stamp(routerCtx, fakeNodeID)
|
||||
Expect(distributedhdr.Load(holder)).To(Equal(fakeNodeID),
|
||||
"stamp via router-side ctx must be observable via the request-side holder")
|
||||
}
|
||||
|
||||
It("Rerank forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.Rerank(reqCtx, &pbproto.RerankRequest{Query: "q"}, loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("VAD forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.VAD(&schema.VADRequest{}, reqCtx, loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTTS forwards the request context to the SmartRouter", func() {
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTTSStream forwards the request context to the SmartRouter", func() {
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTranscriptionWithOptions forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.ModelTranscriptionWithOptions(reqCtx, backend.TranscriptionRequest{Audio: "x.wav"}, loader, modelCfg, appCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTranscriptionStream forwards the request context to the SmartRouter", func() {
|
||||
err := backend.ModelTranscriptionStream(reqCtx, backend.TranscriptionRequest{Audio: "x.wav"}, loader, modelCfg, appCfg, func(backend.TranscriptionStreamChunk) {})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ImageGeneration forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.ImageGeneration(reqCtx, 64, 64, 1, 0, "p", "", "", "/tmp/out.png", loader, modelCfg, appCfg, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("does NOT leak the holder when the app context is used instead", func() {
|
||||
// Sanity: the bug being fixed manifests as the router getting
|
||||
// appCfg.Context (no holder) instead of reqCtx (holder). A direct
|
||||
// call with context.Background() must not see the holder via the
|
||||
// app context surface.
|
||||
appCtxOnly := appCfg.Context
|
||||
Expect(distributedhdr.Holder(appCtxOnly)).To(BeNil(),
|
||||
"the app context must not be the carrier of per-request holders")
|
||||
})
|
||||
})
|
||||
@@ -35,7 +35,7 @@ func Detection(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -11,9 +12,38 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
|
||||
// Embedder produces a fixed-dimension vector from a prompt. The
|
||||
// router's L2 embedding cache uses it to look up semantically-similar
|
||||
// past decisions.
|
||||
type Embedder interface {
|
||||
Embed(ctx context.Context, text string) ([]float32, error)
|
||||
}
|
||||
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// NewEmbedder binds (loader, modelConfig, appConfig) into an Embedder.
|
||||
func NewEmbedder(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Embedder {
|
||||
return &modelEmbedder{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
|
||||
}
|
||||
|
||||
type modelEmbedder struct {
|
||||
loader *model.ModelLoader
|
||||
modelConfig config.ModelConfig
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func (e *modelEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
|
||||
fn, err := ModelEmbedding(ctx, text, nil, e.loader, e.modelConfig, e.appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fn()
|
||||
}
|
||||
|
||||
func ModelEmbedding(ctx context.Context, s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
|
||||
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
@@ -67,7 +97,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConf
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
|
||||
traceData := map[string]any{
|
||||
"input_text": trace.TruncateString(s, 1000),
|
||||
|
||||
@@ -32,7 +32,7 @@ func FaceAnalyze(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func FaceVerify(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
@@ -10,9 +11,12 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ImageGeneration(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
func ImageGeneration(ctx context.Context, height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
inferenceModel, err := loader.Load(
|
||||
opts...,
|
||||
)
|
||||
@@ -23,7 +27,7 @@ func ImageGeneration(height, width, step, seed int, positive_prompt, negative_pr
|
||||
|
||||
fn := func() error {
|
||||
_, err := inferenceModel.GenerateImage(
|
||||
appConfig.Context,
|
||||
ctx,
|
||||
&proto.GenerateImageRequest{
|
||||
Height: int32(height),
|
||||
Width: int32(width),
|
||||
@@ -41,7 +45,7 @@ func ImageGeneration(height, width, step, seed int, positive_prompt, negative_pr
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
|
||||
traceData := map[string]any{
|
||||
"positive_prompt": positive_prompt,
|
||||
|
||||
@@ -86,7 +86,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
if !slices.Contains(modelNames, modelName) {
|
||||
utils.ResetDownloadTimers()
|
||||
// if we failed to load the model, we try to download it
|
||||
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, modelName, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
||||
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, modelName, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries, o.RequireBackendIntegrity)
|
||||
if err != nil {
|
||||
xlog.Error("failed to install model from gallery", "error", err, "model", modelFile)
|
||||
//return nil, err
|
||||
@@ -94,7 +94,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
}
|
||||
|
||||
opts := ModelOptions(*c, o)
|
||||
opts := ModelOptions(*c, o, model.WithContext(ctx))
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(o, c.Name, c.Backend, err, map[string]any{"model_file": modelFile})
|
||||
@@ -305,7 +305,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
|
||||
if o.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(o.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(o.TracingMaxItems, o.TracingMaxBodyBytes)
|
||||
|
||||
traceData := map[string]any{
|
||||
"chat_template": c.TemplateConfig.Chat,
|
||||
@@ -316,9 +316,13 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
"audios_count": len(audios),
|
||||
}
|
||||
|
||||
// Cap the captured fields up front: agent-pool LLM calls embed the
|
||||
// full augmented chat history in messages and the full reply in
|
||||
// response, so without a per-field cap a single trace can dwarf the
|
||||
// rest of the buffer. The cap matches the API-trace body cap.
|
||||
if len(messages) > 0 {
|
||||
if msgJSON, err := json.Marshal(messages); err == nil {
|
||||
traceData["messages"] = string(msgJSON)
|
||||
traceData["messages"] = trace.TruncateToBytes(string(msgJSON), o.TracingMaxBodyBytes)
|
||||
}
|
||||
}
|
||||
if reasoningJSON, err := json.Marshal(c.ReasoningConfig); err == nil {
|
||||
@@ -337,7 +341,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
resp, err := originalFn()
|
||||
duration := time.Since(startTime)
|
||||
|
||||
traceData["response"] = resp.Response
|
||||
traceData["response"] = trace.TruncateToBytes(resp.Response, o.TracingMaxBodyBytes)
|
||||
traceData["token_usage"] = map[string]any{
|
||||
"prompt": resp.Usage.Prompt,
|
||||
"completion": resp.Usage.Completion,
|
||||
@@ -359,10 +363,10 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
toolCallCount += len(d.ToolCalls)
|
||||
}
|
||||
if len(contentParts) > 0 {
|
||||
chatDeltasInfo["content"] = strings.Join(contentParts, "")
|
||||
chatDeltasInfo["content"] = trace.TruncateToBytes(strings.Join(contentParts, ""), o.TracingMaxBodyBytes)
|
||||
}
|
||||
if len(reasoningParts) > 0 {
|
||||
chatDeltasInfo["reasoning_content"] = strings.Join(reasoningParts, "")
|
||||
chatDeltasInfo["reasoning_content"] = trace.TruncateToBytes(strings.Join(reasoningParts, ""), o.TracingMaxBodyBytes)
|
||||
}
|
||||
if toolCallCount > 0 {
|
||||
chatDeltasInfo["tool_call_count"] = toolCallCount
|
||||
|
||||
@@ -21,7 +21,7 @@ func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, back
|
||||
if !appConfig.EnableTracing {
|
||||
return
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: time.Now(),
|
||||
Type: trace.BackendTraceModelLoad,
|
||||
@@ -242,6 +242,18 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
|
||||
Tokenizer: c.Tokenizer,
|
||||
}
|
||||
|
||||
if c.Backend == "cloud-proxy" {
|
||||
opts.Proxy = &pb.ProxyOptions{
|
||||
UpstreamUrl: c.Proxy.UpstreamURL,
|
||||
Mode: c.Proxy.Mode,
|
||||
Provider: c.Proxy.Provider,
|
||||
ApiKeyEnv: c.Proxy.APIKeyEnv,
|
||||
ApiKeyFile: c.Proxy.APIKeyFile,
|
||||
UpstreamModel: c.Proxy.UpstreamModel,
|
||||
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
|
||||
}
|
||||
}
|
||||
|
||||
if c.MMProj != "" {
|
||||
opts.MMProj = filepath.Join(modelPath, c.MMProj)
|
||||
}
|
||||
@@ -277,7 +289,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
|
||||
MinP: float32(*c.MinP),
|
||||
Tokens: int32(*c.Maxtokens),
|
||||
Threads: int32(*c.Threads),
|
||||
PromptCacheAll: c.PromptCacheAll,
|
||||
PromptCacheAll: *c.PromptCacheAll,
|
||||
PromptCacheRO: c.PromptCacheRO,
|
||||
PromptCachePath: promptCachePath,
|
||||
F16KV: *c.F16,
|
||||
|
||||
@@ -11,8 +11,56 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// RerankResult is the per-document score returned to consumers,
|
||||
// narrowed from proto.RerankResult so callers don't need to depend on
|
||||
// the proto package.
|
||||
type RerankResult struct {
|
||||
Index int
|
||||
RelevanceScore float32
|
||||
}
|
||||
|
||||
// Reranker scores a list of candidate documents against a query.
|
||||
// Returns one RerankResult per input document (no top-N truncation -
|
||||
// callers that need it can sort and slice).
|
||||
type Reranker interface {
|
||||
Rerank(ctx context.Context, query string, documents []string) ([]RerankResult, error)
|
||||
}
|
||||
|
||||
// NewReranker binds (loader, modelConfig, appConfig) into a Reranker.
|
||||
func NewReranker(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Reranker {
|
||||
return &modelReranker{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
|
||||
}
|
||||
|
||||
type modelReranker struct {
|
||||
loader *model.ModelLoader
|
||||
modelConfig config.ModelConfig
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func (r *modelReranker) Rerank(ctx context.Context, query string, documents []string) ([]RerankResult, error) {
|
||||
req := &proto.RerankRequest{
|
||||
Query: query,
|
||||
Documents: documents,
|
||||
// TopN=0: backend returns scores for every document. Truncating
|
||||
// here would silently zero out labels the reranker considered
|
||||
// unlikely, which the router classifier needs.
|
||||
}
|
||||
res, err := Rerank(ctx, req, r.loader, r.appConfig, r.modelConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]RerankResult, 0, len(res.GetResults()))
|
||||
for _, dr := range res.GetResults() {
|
||||
out = append(out, RerankResult{Index: int(dr.GetIndex()), RelevanceScore: dr.GetRelevanceScore()})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
rerankModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
@@ -25,7 +73,7 @@ func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.Mod
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
159
core/backend/score.go
Normal file
159
core/backend/score.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// ScoreOptions controls a single Score request.
|
||||
type ScoreOptions struct {
|
||||
// IncludeTokenLogprobs returns per-token log-probability detail for
|
||||
// each candidate. Off by default — the joint LogProb is enough for
|
||||
// ranking; callers that need calibration / entropy over the token
|
||||
// stream opt in.
|
||||
IncludeTokenLogprobs bool
|
||||
// LengthNormalize divides the joint log-prob by the candidate's
|
||||
// token count. Useful when comparing candidates of different
|
||||
// lengths — without it, longer candidates score lower by default.
|
||||
LengthNormalize bool
|
||||
}
|
||||
|
||||
// CandidateScore is the per-candidate result. Mirrors pb.CandidateScore
|
||||
// but avoids leaking the proto type to consumers.
|
||||
type CandidateScore struct {
|
||||
LogProb float64
|
||||
LengthNormalizedLogProb float64
|
||||
NumTokens int
|
||||
Tokens []TokenLogProb
|
||||
}
|
||||
|
||||
type TokenLogProb struct {
|
||||
Token string
|
||||
LogProb float64
|
||||
}
|
||||
|
||||
// Scorer evaluates a model's joint log-probability of each candidate
|
||||
// continuation given a shared prompt. Implemented by NewScorer over a
|
||||
// model-loaded backend; the router's score classifier consumes this
|
||||
// for multi-label policy selection.
|
||||
type Scorer interface {
|
||||
Score(ctx context.Context, prompt string, candidates []string) ([]CandidateScore, error)
|
||||
}
|
||||
|
||||
// NewScorer binds (loader, modelConfig, appConfig) into a Scorer. The
|
||||
// underlying backend is resolved lazily on the first Score call.
|
||||
// Returns nil only as a contract violation — callers that need to
|
||||
// detect "model not loadable" should look up the config first.
|
||||
func NewScorer(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Scorer {
|
||||
return &modelScorer{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
|
||||
}
|
||||
|
||||
type modelScorer struct {
|
||||
loader *model.ModelLoader
|
||||
modelConfig config.ModelConfig
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func (m *modelScorer) Score(ctx context.Context, prompt string, candidates []string) ([]CandidateScore, error) {
|
||||
fn, err := ModelScore(prompt, candidates, ScoreOptions{LengthNormalize: true}, m.loader, m.modelConfig, m.appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
// ModelScore loads the backend for modelConfig and returns a closure
|
||||
// that scores `candidates` against `prompt`. The closure is bound to
|
||||
// the loaded model so callers can keep it around for repeat scoring
|
||||
// within the same request without re-resolving the backend.
|
||||
func ModelScore(prompt string, candidates []string, opts ScoreOptions, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func(ctx context.Context) ([]CandidateScore, error), error) {
|
||||
modelOpts := ModelOptions(modelConfig, appConfig)
|
||||
inferenceModel, err := loader.Load(modelOpts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
b, ok := inferenceModel.(grpc.Backend)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("scoring not supported by backend %q", modelConfig.Backend)
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("Score: candidates must be non-empty")
|
||||
}
|
||||
return func(ctx context.Context) ([]CandidateScore, error) {
|
||||
// Surface score calls in the Traces UI alongside the LLM calls
|
||||
// they typically gate (router classifier, eval scoring). Without
|
||||
// this, a router-classified request shows only the downstream LLM
|
||||
// trace with no record of the classification that picked it.
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
resp, err := b.Score(ctx, &pb.ScoreRequest{
|
||||
Prompt: prompt,
|
||||
Candidates: candidates,
|
||||
IncludeTokenLogprobs: opts.IncludeTokenLogprobs,
|
||||
LengthNormalize: opts.LengthNormalize,
|
||||
})
|
||||
results := scoreResponseToCandidates(resp, opts.IncludeTokenLogprobs)
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceScore,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(prompt, 200),
|
||||
Error: errStr,
|
||||
Data: map[string]any{
|
||||
// Copy candidates so the trace buffer doesn't pin a
|
||||
// caller-owned slice for the lifetime of the ring.
|
||||
"candidates": append([]string(nil), candidates...),
|
||||
"results": results,
|
||||
},
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// scoreResponseToCandidates converts the wire-format pb response into
|
||||
// the value type consumed by callers. Extracted to keep ModelScore's
|
||||
// closure trivial and so the conversion can be unit-tested without a
|
||||
// real backend.
|
||||
func scoreResponseToCandidates(resp *pb.ScoreResponse, includeTokens bool) []CandidateScore {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]CandidateScore, len(resp.Candidates))
|
||||
for i, c := range resp.Candidates {
|
||||
cs := CandidateScore{
|
||||
LogProb: c.LogProb,
|
||||
LengthNormalizedLogProb: c.LengthNormalizedLogProb,
|
||||
NumTokens: int(c.NumTokens),
|
||||
}
|
||||
if includeTokens && len(c.Tokens) > 0 {
|
||||
cs.Tokens = make([]TokenLogProb, len(c.Tokens))
|
||||
for j, t := range c.Tokens {
|
||||
cs.Tokens[j] = TokenLogProb{Token: t.Token, LogProb: t.LogProb}
|
||||
}
|
||||
}
|
||||
out[i] = cs
|
||||
}
|
||||
return out
|
||||
}
|
||||
63
core/backend/score_test.go
Normal file
63
core/backend/score_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("scoreResponseToCandidates", func() {
|
||||
It("returns nil for a nil response", func() {
|
||||
Expect(scoreResponseToCandidates(nil, false)).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns an empty slice when the response has no candidates", func() {
|
||||
Expect(scoreResponseToCandidates(&pb.ScoreResponse{}, false)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("copies LogProb / LengthNormalizedLogProb / NumTokens for every candidate", func() {
|
||||
resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{
|
||||
{LogProb: -2.0, LengthNormalizedLogProb: -1.0, NumTokens: 2},
|
||||
{LogProb: -7.5, LengthNormalizedLogProb: -1.5, NumTokens: 5},
|
||||
}}
|
||||
got := scoreResponseToCandidates(resp, false)
|
||||
Expect(got).To(HaveLen(2))
|
||||
Expect(got[0].LogProb).To(Equal(-2.0))
|
||||
Expect(got[0].LengthNormalizedLogProb).To(Equal(-1.0))
|
||||
Expect(got[0].NumTokens).To(Equal(2))
|
||||
Expect(got[1].LogProb).To(Equal(-7.5))
|
||||
Expect(got[1].NumTokens).To(Equal(5))
|
||||
})
|
||||
|
||||
It("omits per-token detail when includeTokens=false even if the wire response carries it", func() {
|
||||
// Defensive: if the backend over-reports we still respect the
|
||||
// caller's opt-in so consumers don't pay marshaling for data
|
||||
// they didn't ask for.
|
||||
resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{{
|
||||
LogProb: -1.0,
|
||||
Tokens: []*pb.TokenLogProb{{Token: "hi", LogProb: -1.0}},
|
||||
}}}
|
||||
got := scoreResponseToCandidates(resp, false)
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Tokens).To(BeNil())
|
||||
})
|
||||
|
||||
It("populates per-token detail when includeTokens=true", func() {
|
||||
resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{{
|
||||
LogProb: -3.0,
|
||||
NumTokens: 2,
|
||||
Tokens: []*pb.TokenLogProb{
|
||||
{Token: "Hello", LogProb: -1.0},
|
||||
{Token: " world", LogProb: -2.0},
|
||||
},
|
||||
}}}
|
||||
got := scoreResponseToCandidates(resp, true)
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Tokens).To(HaveLen(2))
|
||||
Expect(got[0].Tokens[0].Token).To(Equal("Hello"))
|
||||
Expect(got[0].Tokens[0].LogProb).To(Equal(-1.0))
|
||||
Expect(got[0].Tokens[1].Token).To(Equal(" world"))
|
||||
Expect(got[0].Tokens[1].LogProb).To(Equal(-2.0))
|
||||
})
|
||||
})
|
||||
@@ -98,7 +98,7 @@ func SoundGeneration(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,74 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/store"
|
||||
)
|
||||
|
||||
// VectorStore is the narrowed KNN store used by the router's embedding
|
||||
// cache. Search returns the top-1 match (cosine similarity in [-1, 1])
|
||||
// and the serialised payload, or ok=false on a clean miss.
|
||||
type VectorStore interface {
|
||||
Search(ctx context.Context, vec []float32) (similarity float64, payload []byte, ok bool, err error)
|
||||
Insert(ctx context.Context, vec []float32, payload []byte) error
|
||||
}
|
||||
|
||||
// NewVectorStore returns a VectorStore backed by the local-store
|
||||
// gRPC backend, namespaced by storeName so two routers don't collide.
|
||||
func NewVectorStore(loader *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) VectorStore {
|
||||
if storeName == "" {
|
||||
return nil
|
||||
}
|
||||
return &localVectorStore{loader: loader, appConfig: appConfig, storeName: storeName}
|
||||
}
|
||||
|
||||
type localVectorStore struct {
|
||||
loader *model.ModelLoader
|
||||
appConfig *config.ApplicationConfig
|
||||
storeName string
|
||||
}
|
||||
|
||||
func (s *localVectorStore) backend(_ context.Context) (grpc.Backend, error) {
|
||||
return StoreBackend(s.loader, s.appConfig, s.storeName, "")
|
||||
}
|
||||
|
||||
func (s *localVectorStore) Search(ctx context.Context, vec []float32) (float64, []byte, bool, error) {
|
||||
be, err := s.backend(ctx)
|
||||
if err != nil {
|
||||
return 0, nil, false, fmt.Errorf("vector store load: %w", err)
|
||||
}
|
||||
_, values, similarities, err := store.Find(ctx, be, vec, 1)
|
||||
if err != nil {
|
||||
// local-store's Find returns "existing length is -1" before
|
||||
// any keys are inserted. Surface that as a clean miss so the
|
||||
// cache layer treats it as an empty store and proceeds to
|
||||
// Insert rather than skipping.
|
||||
if strings.Contains(err.Error(), "existing length is -1") {
|
||||
return 0, nil, false, nil
|
||||
}
|
||||
return 0, nil, false, fmt.Errorf("vector store find: %w", err)
|
||||
}
|
||||
if len(values) == 0 || len(similarities) == 0 {
|
||||
return 0, nil, false, nil
|
||||
}
|
||||
return float64(similarities[0]), values[0], true, nil
|
||||
}
|
||||
|
||||
func (s *localVectorStore) Insert(ctx context.Context, vec []float32, payload []byte) error {
|
||||
be, err := s.backend(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("vector store load: %w", err)
|
||||
}
|
||||
return store.SetSingle(ctx, be, vec, payload)
|
||||
}
|
||||
|
||||
func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string, backend string) (grpc.Backend, error) {
|
||||
if backend == "" {
|
||||
backend = model.LocalStoreBackend
|
||||
|
||||
@@ -27,7 +27,7 @@ func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.Model
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -41,11 +41,14 @@ func (r *TranscriptionRequest) toProto(threads uint32) *proto.TranscriptRequest
|
||||
}
|
||||
}
|
||||
|
||||
func loadTranscriptionModel(ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (grpcPkg.Backend, error) {
|
||||
func loadTranscriptionModel(ctx context.Context, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (grpcPkg.Backend, error) {
|
||||
if modelConfig.Backend == "" {
|
||||
modelConfig.Backend = model.WhisperBackend
|
||||
}
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
transcriptionModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
@@ -68,7 +71,7 @@ func ModelTranscription(ctx context.Context, audio, language string, translate,
|
||||
}
|
||||
|
||||
func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
||||
transcriptionModel, err := loadTranscriptionModel(ctx, ml, modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -76,10 +79,10 @@ func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest
|
||||
var startTime time.Time
|
||||
var audioSnippet map[string]any
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
// Capture audio before the backend call — the backend may delete the file.
|
||||
audioSnippet = trace.AudioSnippet(req.Audio)
|
||||
audioSnippet = trace.AudioSnippet(req.Audio, appConfig.TracingMaxBodyBytes)
|
||||
}
|
||||
|
||||
r, err := transcriptionModel.AudioTranscription(ctx, req.toProto(uint32(*modelConfig.Threads)))
|
||||
@@ -150,7 +153,7 @@ type TranscriptionStreamChunk struct {
|
||||
// support real streaming should still emit one terminal event with Final set,
|
||||
// which the HTTP layer turns into a single delta + done SSE pair.
|
||||
func ModelTranscriptionStream(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
|
||||
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
||||
transcriptionModel, err := loadTranscriptionModel(ctx, ml, modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -29,7 +29,10 @@ func ModelTTS(
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (string, *proto.Result, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
ttsModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
@@ -67,7 +70,7 @@ func ModelTTS(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
@@ -93,7 +96,7 @@ func ModelTTS(
|
||||
"language": language,
|
||||
}
|
||||
if err == nil && res.Success {
|
||||
if snippet := trace.AudioSnippet(filePath); snippet != nil {
|
||||
if snippet := trace.AudioSnippet(filePath, appConfig.TracingMaxBodyBytes); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
}
|
||||
@@ -131,7 +134,7 @@ func ModelTTSStream(
|
||||
modelConfig config.ModelConfig,
|
||||
audioCallback func([]byte) error,
|
||||
) error {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
ttsModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
@@ -161,7 +164,7 @@ func ModelTTSStream(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
@@ -260,7 +263,7 @@ func ModelTTSStream(
|
||||
"streaming": true,
|
||||
}
|
||||
if resultErr == nil && len(snippetPCM) > 0 {
|
||||
if snippet := trace.AudioSnippetFromPCM(snippetPCM, int(sampleRate), totalPCMBytes); snippet != nil {
|
||||
if snippet := trace.AudioSnippetFromPCM(snippetPCM, int(sampleRate), totalPCMBytes, appConfig.TracingMaxBodyBytes); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,10 @@ func VAD(request *schema.VADRequest,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig) (*schema.VADResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
vadModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
|
||||
@@ -42,7 +42,7 @@ func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, en
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
|
||||
traceData := map[string]any{
|
||||
"prompt": prompt,
|
||||
|
||||
@@ -31,7 +31,7 @@ func VoiceAnalyze(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ func VoiceEmbed(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func VoiceVerify(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -17,9 +17,10 @@ import (
|
||||
)
|
||||
|
||||
type BackendsCMDFlags struct {
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"`
|
||||
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"`
|
||||
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
||||
RequireBackendIntegrity bool `env:"LOCALAI_REQUIRE_BACKEND_INTEGRITY,REQUIRE_BACKEND_INTEGRITY" help:"If true, reject backend installs without a configured signature verification policy (OCI URIs) or SHA256 (tarball/HTTP URIs)." group:"hardening" default:"false"`
|
||||
}
|
||||
|
||||
type BackendsList struct {
|
||||
@@ -126,7 +127,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
|
||||
modelLoader := model.NewModelLoader(systemState)
|
||||
err = galleryop.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
|
||||
err = galleryop.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias, bi.RequireBackendIntegrity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -197,7 +198,7 @@ func (bu *BackendsUpgrade) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := gallery.UpgradeBackend(context.Background(), systemState, modelLoader, galleries, name, progressCallback); err != nil {
|
||||
if err := gallery.UpgradeBackend(context.Background(), systemState, modelLoader, galleries, name, progressCallback, bu.RequireBackendIntegrity); err != nil {
|
||||
fmt.Printf("Failed to upgrade %s: %v\n", name, err)
|
||||
} else {
|
||||
fmt.Printf("Backend %s upgraded successfully\n", name)
|
||||
|
||||
@@ -32,6 +32,7 @@ type ModelsList struct {
|
||||
|
||||
type ModelsInstall struct {
|
||||
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
|
||||
RequireBackendIntegrity bool `env:"LOCALAI_REQUIRE_BACKEND_INTEGRITY,REQUIRE_BACKEND_INTEGRITY" help:"If true, reject backend installs without a configured signature verification policy (OCI URIs) or SHA256 (tarball/HTTP URIs)." group:"hardening" default:"false"`
|
||||
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES" help:"If true, automatically loads backend galleries" group:"backends" default:"true"`
|
||||
ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"`
|
||||
|
||||
@@ -71,7 +72,6 @@ func (ml *ModelsList) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
|
||||
func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(mi.ModelsPath),
|
||||
system.WithBackendPath(mi.BackendsPath),
|
||||
@@ -135,7 +135,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
|
||||
modelLoader := model.NewModelLoader(systemState)
|
||||
err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
|
||||
err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, mi.RequireBackendIntegrity, progressCallback, modelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -39,19 +39,19 @@ type RunCMD struct {
|
||||
LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`
|
||||
LocalaiConfigDirPollInterval time.Duration `env:"LOCALAI_CONFIG_DIR_POLL_INTERVAL" help:"Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to an interval to poll the LocalAI Config Dir (example: 1m)" group:"storage"`
|
||||
// The alias on this option is there to preserve functionality with the old `--config-file` parameter
|
||||
ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
|
||||
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
|
||||
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
|
||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||
ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
|
||||
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
|
||||
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
|
||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||
AutoUpgradeBackends bool `env:"LOCALAI_AUTO_UPGRADE_BACKENDS,AUTO_UPGRADE_BACKENDS" help:"Automatically upgrade backends when new versions are detected" group:"backends" default:"false"`
|
||||
PreferDevelopmentBackends bool `env:"LOCALAI_PREFER_DEV_BACKENDS,PREFER_DEV_BACKENDS" help:"Prefer development backend versions (shows development backends by default in UI)" group:"backends" default:"false"`
|
||||
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
|
||||
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
||||
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
||||
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
||||
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
||||
|
||||
F16 bool `name:"f16" env:"LOCALAI_F16,F16" help:"Enable GPU acceleration" group:"performance"`
|
||||
Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"`
|
||||
@@ -67,6 +67,7 @@ type RunCMD struct {
|
||||
OllamaAPIRootEndpoint bool `env:"LOCALAI_OLLAMA_API_ROOT_ENDPOINT" default:"false" help:"Register Ollama-compatible health check on / (replaces web UI on root path). The /api/* Ollama endpoints are always available regardless of this flag" group:"api"`
|
||||
DisableRuntimeSettings bool `env:"LOCALAI_DISABLE_RUNTIME_SETTINGS,DISABLE_RUNTIME_SETTINGS" default:"false" help:"Disables the runtime settings. When set to true, the server will not load the runtime settings from the runtime_settings.json file" group:"api"`
|
||||
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
|
||||
RequireBackendIntegrity bool `env:"LOCALAI_REQUIRE_BACKEND_INTEGRITY,REQUIRE_BACKEND_INTEGRITY" help:"If true, backend installs without a configured signature verification policy (for OCI URIs) or SHA256 (for tarball/HTTP URIs) are rejected. Default is to warn and install. Set this in production once your gallery's verification: block is populated." group:"hardening" default:"false"`
|
||||
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
|
||||
UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"`
|
||||
DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"`
|
||||
@@ -99,6 +100,7 @@ type RunCMD struct {
|
||||
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
|
||||
EnableTracing bool `env:"LOCALAI_ENABLE_TRACING,ENABLE_TRACING" help:"Enable API tracing" group:"api"`
|
||||
TracingMaxItems int `env:"LOCALAI_TRACING_MAX_ITEMS" default:"1024" help:"Maximum number of traces to keep" group:"api"`
|
||||
TracingMaxBodyBytes int `env:"LOCALAI_TRACING_MAX_BODY_BYTES" default:"65536" help:"Maximum bytes captured per request/response body in the trace buffer (0 = uncapped). Caps memory growth from chatty endpoints like /embeddings." group:"api"`
|
||||
AgentJobRetentionDays int `env:"LOCALAI_AGENT_JOB_RETENTION_DAYS,AGENT_JOB_RETENTION_DAYS" default:"30" help:"Number of days to keep agent job history (default: 30)" group:"api"`
|
||||
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
|
||||
|
||||
@@ -143,18 +145,25 @@ type RunCMD struct {
|
||||
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
||||
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
|
||||
Version bool
|
||||
|
||||
// Cloud-proxy MITM listener (off by default).
|
||||
MITMListen string `env:"LOCALAI_MITM_LISTEN" help:"Address (host:port) for the cloudproxy MITM listener. Empty = disabled. Clients set HTTPS_PROXY=http://<this>:<port>. Intercept hosts are declared per-model via the model YAML mitm.hosts: block; create one from the Add Model UI." group:"middleware"`
|
||||
MITMCADir string `env:"LOCALAI_MITM_CA_DIR" type:"path" help:"Directory holding the MITM proxy CA cert + key. Defaults to <data-path>/mitm-ca." group:"middleware"`
|
||||
}
|
||||
|
||||
func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
@@ -213,6 +222,8 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
config.WithLoadToMemory(r.LoadToMemory),
|
||||
config.WithMachineTag(r.MachineTag),
|
||||
config.WithAPIAddress(r.Address),
|
||||
config.WithMITMListen(r.MITMListen),
|
||||
config.WithMITMCADir(r.MITMCADir),
|
||||
config.WithAgentJobRetentionDays(r.AgentJobRetentionDays),
|
||||
config.WithLlamaCPPTunnelCallback(func(tunnels []string) {
|
||||
tunnelEnvVar := strings.Join(tunnels, ",")
|
||||
@@ -253,12 +264,29 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.StorageSecretKey != "" {
|
||||
opts = append(opts, config.WithStorageSecretKey(r.StorageSecretKey))
|
||||
}
|
||||
if r.BackendInstallTimeout != "" {
|
||||
d, err := time.ParseDuration(r.BackendInstallTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT %q: %w", r.BackendInstallTimeout, err)
|
||||
}
|
||||
opts = append(opts, config.WithBackendInstallTimeout(d))
|
||||
}
|
||||
if r.BackendUpgradeTimeout != "" {
|
||||
d, err := time.ParseDuration(r.BackendUpgradeTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT %q: %w", r.BackendUpgradeTimeout, err)
|
||||
}
|
||||
opts = append(opts, config.WithBackendUpgradeTimeout(d))
|
||||
}
|
||||
if r.RegistrationToken != "" {
|
||||
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
|
||||
}
|
||||
if r.AutoApproveNodes {
|
||||
opts = append(opts, config.EnableAutoApproveNodes)
|
||||
}
|
||||
if r.ExposeNodeHeader {
|
||||
opts = append(opts, config.WithExposeNodeHeader(true))
|
||||
}
|
||||
|
||||
if r.DisableMetricsEndpoint {
|
||||
opts = append(opts, config.DisableMetricsEndpoint)
|
||||
@@ -272,6 +300,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.EnableTracing)
|
||||
}
|
||||
opts = append(opts, config.WithTracingMaxItems(r.TracingMaxItems))
|
||||
opts = append(opts, config.WithTracingMaxBodyBytes(r.TracingMaxBodyBytes))
|
||||
|
||||
token := ""
|
||||
if r.Peer2Peer || r.Peer2PeerToken != "" {
|
||||
@@ -503,6 +532,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.WithAutoUpgradeBackends(r.AutoUpgradeBackends))
|
||||
}
|
||||
|
||||
if r.RequireBackendIntegrity {
|
||||
opts = append(opts, config.WithRequireBackendIntegrity(r.RequireBackendIntegrity))
|
||||
}
|
||||
|
||||
if r.PreferDevelopmentBackends {
|
||||
opts = append(opts, config.WithPreferDevelopmentBackends(r.PreferDevelopmentBackends))
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package worker
|
||||
|
||||
type WorkerFlags struct {
|
||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
||||
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
|
||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
||||
RequireBackendIntegrity bool `env:"LOCALAI_REQUIRE_BACKEND_INTEGRITY,REQUIRE_BACKEND_INTEGRITY" help:"If true, reject backend installs without a configured signature verification policy (OCI URIs) or SHA256 (tarball/HTTP URIs)." group:"hardening" default:"false"`
|
||||
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
|
||||
}
|
||||
|
||||
type Worker struct {
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
// installing the backend from the gallery if it isn't present.
|
||||
// `name` is the gallery entry name (for vLLM the meta entry "vllm"
|
||||
// resolves to a platform-specific package via capability lookup).
|
||||
func findBackendPath(name, galleries string, systemState *system.SystemState) (string, error) {
|
||||
func findBackendPath(name, galleries string, systemState *system.SystemState, requireIntegrity bool) (string, error) {
|
||||
backends, err := gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -33,7 +33,7 @@ func findBackendPath(name, galleries string, systemState *system.SystemState) (s
|
||||
xlog.Error("failed loading galleries", "error", err)
|
||||
return "", err
|
||||
}
|
||||
if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, name, nil, true); err != nil {
|
||||
if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, name, nil, true, requireIntegrity); err != nil {
|
||||
xlog.Error("backend not found, failed to install it", "name", name, "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ const (
|
||||
llamaCPPGalleryName = "llama-cpp"
|
||||
)
|
||||
|
||||
func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (string, error) {
|
||||
func findLLamaCPPBackend(galleries string, systemState *system.SystemState, requireIntegrity bool) (string, error) {
|
||||
backends, err := gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed listing system backends", "error", err)
|
||||
@@ -43,7 +43,7 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (str
|
||||
xlog.Error("failed loading galleries", "error", err)
|
||||
return "", err
|
||||
}
|
||||
err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true)
|
||||
err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true, requireIntegrity)
|
||||
if err != nil {
|
||||
xlog.Error("llama-cpp backend not found, failed to install it", "error", err)
|
||||
return "", err
|
||||
@@ -76,7 +76,7 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState)
|
||||
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState, r.RequireBackendIntegrity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
|
||||
const mlxDistributedGalleryName = "mlx-distributed"
|
||||
|
||||
func findMLXDistributedBackendPath(galleries string, systemState *system.SystemState) (string, error) {
|
||||
return findBackendPath(mlxDistributedGalleryName, galleries, systemState)
|
||||
func findMLXDistributedBackendPath(galleries string, systemState *system.SystemState, requireIntegrity bool) (string, error) {
|
||||
return findBackendPath(mlxDistributedGalleryName, galleries, systemState, requireIntegrity)
|
||||
}
|
||||
|
||||
// buildMLXCommand builds the exec.Cmd to launch the mlx-distributed backend.
|
||||
|
||||
@@ -28,7 +28,7 @@ func (r *MLXDistributed) Run(ctx *cliContext.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState)
|
||||
backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState, r.RequireBackendIntegrity)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot find mlx-distributed backend: %w", err)
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
|
||||
for {
|
||||
xlog.Info("Starting llama-cpp-rpc-server", "address", address, "port", port)
|
||||
|
||||
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState)
|
||||
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState, r.RequireBackendIntegrity)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to find llama-cpp-rpc-server", "error", err)
|
||||
return
|
||||
|
||||
@@ -48,7 +48,7 @@ func (r *P2PMLX) Run(ctx *cliContext.Context) error {
|
||||
c, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState)
|
||||
backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState, r.RequireBackendIntegrity)
|
||||
if err != nil {
|
||||
xlog.Warn("Could not find mlx-distributed backend from gallery, will try backend.py directly", "error", err)
|
||||
}
|
||||
|
||||
@@ -77,7 +77,7 @@ func (r *VLLMDistributed) Run(ctx *cliContext.Context) error {
|
||||
return fmt.Errorf("getting system state: %w", err)
|
||||
}
|
||||
|
||||
backendPath, err := findBackendPath("vllm", r.BackendGalleries, systemState)
|
||||
backendPath, err := findBackendPath("vllm", r.BackendGalleries, systemState, r.RequireBackendIntegrity)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot find vllm backend: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ type ApplicationConfig struct {
|
||||
Debug bool
|
||||
EnableTracing bool
|
||||
TracingMaxItems int
|
||||
TracingMaxBodyBytes int // Per-body cap for captured request/response bodies; 0 disables the cap
|
||||
EnableBackendLogging bool
|
||||
GeneratedContentDir string
|
||||
|
||||
@@ -39,6 +40,54 @@ type ApplicationConfig struct {
|
||||
P2PNetworkID string
|
||||
Federated bool
|
||||
|
||||
// DisableStats turns off per-request token tracking. By default the
|
||||
// routing module's billing recorder runs in every mode (including
|
||||
// no-auth single-user) so dashboards and `/api/usage` are immediately
|
||||
// useful; set this to opt out of that, e.g., for ephemeral CI runs
|
||||
// or privacy-strict deployments where no token-count history should
|
||||
// touch disk or memory.
|
||||
DisableStats bool
|
||||
|
||||
// PIIConfigPath points to an optional YAML file describing the PII
|
||||
// pattern set. When empty, the routing/pii module's DefaultPatterns()
|
||||
// (email, phone, SSN, credit card, IPv4, API key prefixes) are
|
||||
// loaded with their default actions. Each entry overrides the
|
||||
// matching default by ID:
|
||||
//
|
||||
// patterns:
|
||||
// - id: email
|
||||
// action: route_local # downgrade default mask -> route_local
|
||||
// - id: ssn
|
||||
// action: block # upgrade default mask -> block
|
||||
//
|
||||
// Unknown ids are rejected with a clear error at startup.
|
||||
PIIConfigPath string
|
||||
|
||||
// DisablePII turns the regex PII filter off entirely. Default
|
||||
// (false) enables it on the OpenAI chat completions route.
|
||||
DisablePII bool
|
||||
|
||||
// MITMListen is the address (host:port) the cloudproxy MITM
|
||||
// listener binds on. Empty disables the MITM proxy entirely.
|
||||
// Use case: redacting PII from Claude Code / Codex CLI traffic
|
||||
// without LocalAI holding the upstream API key. Clients set
|
||||
// HTTPS_PROXY=http://localai:port and trust the CA cert
|
||||
// LocalAI exposes at /api/middleware/proxy-ca.crt.
|
||||
MITMListen string
|
||||
|
||||
// MITMCADir holds the persisted MITM proxy CA cert and private
|
||||
// key. The CA is generated on first start; subsequent starts
|
||||
// reload it so clients keep trusting the same root. The key
|
||||
// file is mode 0600.
|
||||
MITMCADir string
|
||||
|
||||
|
||||
// PIIPatternOverrides applies persisted per-id deltas (action,
|
||||
// disabled) to the live redactor at startup. Loaded from
|
||||
// runtime_settings.json and applied right after pii.NewRedactor.
|
||||
// nil/empty leaves the YAML defaults in place.
|
||||
PIIPatternOverrides map[string]PIIPatternRuntimeOverride
|
||||
|
||||
DisableWebUI bool
|
||||
OllamaAPIRootEndpoint bool
|
||||
EnforcePredownloadScans bool
|
||||
@@ -60,6 +109,13 @@ type ApplicationConfig struct {
|
||||
AutoUpgradeBackends bool
|
||||
PreferDevelopmentBackends bool
|
||||
|
||||
// RequireBackendIntegrity promotes a missing SHA256 (tarball/HTTP URIs)
|
||||
// or missing verification policy (OCI URIs) from a warning to a hard
|
||||
// failure during backend install/upgrade. Off by default to keep
|
||||
// upgrades non-breaking; operators opt in explicitly via
|
||||
// --require-backend-integrity / LOCALAI_REQUIRE_BACKEND_INTEGRITY.
|
||||
RequireBackendIntegrity bool
|
||||
|
||||
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
|
||||
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||
WatchDogIdle bool
|
||||
@@ -104,6 +160,18 @@ type ApplicationConfig struct {
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed DistributedConfig
|
||||
|
||||
// ExposeNodeHeader, when true, activates middleware.ExposeNodeHeader on
|
||||
// the inference routes (OpenAI chat/completions/embeddings, Anthropic
|
||||
// /v1/messages, Ollama /api/chat,/api/generate,/api/embed). The
|
||||
// middleware wraps the response writer and attaches an "X-LocalAI-Node"
|
||||
// response header carrying the ID of the distributed-mode worker node
|
||||
// that served the request. Off by default because the node ID is
|
||||
// internal topology that can aid attacker reconnaissance if surfaced on
|
||||
// a public endpoint; operators opt in explicitly via
|
||||
// --expose-node-header / LOCALAI_EXPOSE_NODE_HEADER for debugging,
|
||||
// observability and load-balancer attribution.
|
||||
ExposeNodeHeader bool
|
||||
|
||||
// LocalAI Assistant chat modality. Hard-disable the in-process admin MCP
|
||||
// server with this flag; runtime-toggleable via /api/settings.
|
||||
DisableLocalAIAssistant bool
|
||||
@@ -180,6 +248,7 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||
LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
|
||||
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
|
||||
TracingMaxItems: 1024,
|
||||
TracingMaxBodyBytes: 64 * 1024, // 64 KiB - caps each request/response body in the trace buffer
|
||||
AgentPool: AgentPoolConfig{
|
||||
Enabled: true,
|
||||
Timeout: "5m",
|
||||
@@ -436,6 +505,10 @@ func WithAutoUpgradeBackends(v bool) AppOption {
|
||||
return func(o *ApplicationConfig) { o.AutoUpgradeBackends = v }
|
||||
}
|
||||
|
||||
func WithRequireBackendIntegrity(v bool) AppOption {
|
||||
return func(o *ApplicationConfig) { o.RequireBackendIntegrity = v }
|
||||
}
|
||||
|
||||
func WithPreferDevelopmentBackends(v bool) AppOption {
|
||||
return func(o *ApplicationConfig) { o.PreferDevelopmentBackends = v }
|
||||
}
|
||||
@@ -567,6 +640,12 @@ func WithTracingMaxItems(items int) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithTracingMaxBodyBytes(bytes int) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.TracingMaxBodyBytes = bytes
|
||||
}
|
||||
}
|
||||
|
||||
func WithGeneratedContentDir(generatedContentDir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.GeneratedContentDir = generatedContentDir
|
||||
@@ -585,6 +664,45 @@ func WithDataPath(dataPath string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithDisableStats turns off the billing recorder. CLI: --disable-stats.
|
||||
func WithDisableStats(disable bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.DisableStats = disable
|
||||
}
|
||||
}
|
||||
|
||||
// WithPIIConfigPath points the routing PII filter at a YAML config
|
||||
// file. CLI: --pii-config.
|
||||
func WithPIIConfigPath(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.PIIConfigPath = path
|
||||
}
|
||||
}
|
||||
|
||||
// WithDisablePII turns the regex PII filter off. CLI: --disable-pii.
|
||||
func WithDisablePII(disable bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.DisablePII = disable
|
||||
}
|
||||
}
|
||||
|
||||
// WithMITMListen sets the address the cloudproxy MITM listener
|
||||
// binds on. Empty = disabled. CLI: --mitm-listen.
|
||||
func WithMITMListen(addr string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.MITMListen = addr
|
||||
}
|
||||
}
|
||||
|
||||
// WithMITMCADir sets the directory used to persist the MITM proxy
|
||||
// CA cert + key. CLI: --mitm-ca-dir.
|
||||
func WithMITMCADir(dir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.MITMCADir = dir
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func WithDynamicConfigDir(dynamicConfigsDir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.DynamicConfigsDir = dynamicConfigsDir
|
||||
@@ -874,6 +992,15 @@ func WithDisableLocalAIAssistant(disabled bool) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithExposeNodeHeader enables the X-LocalAI-Node response header on
|
||||
// inference endpoints. Default off; the node ID reveals internal cluster
|
||||
// topology and is opt-in for that reason.
|
||||
func WithExposeNodeHeader(enabled bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.ExposeNodeHeader = enabled
|
||||
}
|
||||
}
|
||||
|
||||
// ToConfigLoaderOptions returns a slice of ConfigLoader Option.
|
||||
// Some options defined at the application level are going to be passed as defaults for
|
||||
// all the configuration for the models.
|
||||
@@ -909,6 +1036,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
f16 := o.F16
|
||||
debug := o.Debug
|
||||
tracingMaxItems := o.TracingMaxItems
|
||||
tracingMaxBodyBytes := o.TracingMaxBodyBytes
|
||||
enableTracing := o.EnableTracing
|
||||
enableBackendLogging := o.EnableBackendLogging
|
||||
cors := o.CORS
|
||||
@@ -978,6 +1106,8 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
logoHorizontalFile := o.Branding.LogoHorizontalFile
|
||||
faviconFile := o.Branding.FaviconFile
|
||||
|
||||
mitmListen := o.MITMListen
|
||||
|
||||
return RuntimeSettings{
|
||||
WatchdogEnabled: &watchdogEnabled,
|
||||
WatchdogIdleEnabled: &watchdogIdle,
|
||||
@@ -997,6 +1127,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
F16: &f16,
|
||||
Debug: &debug,
|
||||
TracingMaxItems: &tracingMaxItems,
|
||||
TracingMaxBodyBytes: &tracingMaxBodyBytes,
|
||||
EnableTracing: &enableTracing,
|
||||
EnableBackendLogging: &enableBackendLogging,
|
||||
CORS: &cors,
|
||||
@@ -1030,6 +1161,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
LogoFile: &logoFile,
|
||||
LogoHorizontalFile: &logoHorizontalFile,
|
||||
FaviconFile: &faviconFile,
|
||||
MITMListen: &mitmListen,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1135,6 +1267,9 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
||||
if settings.TracingMaxItems != nil {
|
||||
o.TracingMaxItems = *settings.TracingMaxItems
|
||||
}
|
||||
if settings.TracingMaxBodyBytes != nil {
|
||||
o.TracingMaxBodyBytes = *settings.TracingMaxBodyBytes
|
||||
}
|
||||
if settings.EnableBackendLogging != nil {
|
||||
o.EnableBackendLogging = *settings.EnableBackendLogging
|
||||
}
|
||||
@@ -1252,6 +1387,10 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
||||
o.Branding.FaviconFile = *settings.FaviconFile
|
||||
}
|
||||
|
||||
if settings.MITMListen != nil {
|
||||
o.MITMListen = *settings.MITMListen
|
||||
}
|
||||
|
||||
// Note: ApiKeys requires special handling (merging with startup keys) - handled in caller
|
||||
|
||||
return requireRestart
|
||||
|
||||
@@ -40,7 +40,10 @@ type DistributedConfig struct {
|
||||
// model-row cleanup on MarkUnhealthy / MarkDraining).
|
||||
DisablePerModelHealthCheck bool
|
||||
|
||||
MCPCIJobTimeout time.Duration // MCP CI job execution timeout (default 10m)
|
||||
MCPCIJobTimeout time.Duration // MCP CI job execution timeout (default 10m)
|
||||
|
||||
BackendInstallTimeout time.Duration // NATS round-trip timeout for backend.install (default 15m)
|
||||
BackendUpgradeTimeout time.Duration // NATS round-trip timeout for backend.upgrade (default 15m)
|
||||
|
||||
MaxUploadSize int64 // Maximum upload body size in bytes (default 50 GB)
|
||||
|
||||
@@ -68,13 +71,15 @@ func (c DistributedConfig) Validate() error {
|
||||
}
|
||||
// Check for negative durations
|
||||
for name, d := range map[string]time.Duration{
|
||||
"mcp-tool-timeout": c.MCPToolTimeout,
|
||||
"mcp-discovery-timeout": c.MCPDiscoveryTimeout,
|
||||
"worker-wait-timeout": c.WorkerWaitTimeout,
|
||||
"drain-timeout": c.DrainTimeout,
|
||||
"health-check-interval": c.HealthCheckInterval,
|
||||
"stale-node-threshold": c.StaleNodeThreshold,
|
||||
"mcp-ci-job-timeout": c.MCPCIJobTimeout,
|
||||
FlagMCPToolTimeout: c.MCPToolTimeout,
|
||||
FlagMCPDiscoveryTimeout: c.MCPDiscoveryTimeout,
|
||||
FlagWorkerWaitTimeout: c.WorkerWaitTimeout,
|
||||
FlagDrainTimeout: c.DrainTimeout,
|
||||
FlagHealthCheckInterval: c.HealthCheckInterval,
|
||||
FlagStaleNodeThreshold: c.StaleNodeThreshold,
|
||||
FlagMCPCIJobTimeout: c.MCPCIJobTimeout,
|
||||
FlagBackendInstallTimeout: c.BackendInstallTimeout,
|
||||
FlagBackendUpgradeTimeout: c.BackendUpgradeTimeout,
|
||||
} {
|
||||
if d < 0 {
|
||||
return fmt.Errorf("%s must not be negative", name)
|
||||
@@ -137,24 +142,66 @@ func WithStorageSecretKey(key string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendInstallTimeout(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.BackendInstallTimeout = d
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendUpgradeTimeout(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.BackendUpgradeTimeout = d
|
||||
}
|
||||
}
|
||||
|
||||
var EnableAutoApproveNodes = func(o *ApplicationConfig) {
|
||||
o.Distributed.AutoApproveNodes = true
|
||||
}
|
||||
|
||||
// Flag names for distributed timeout / interval configuration. These are
|
||||
// the kebab-case identifiers kong derives from the matching RunCMD struct
|
||||
// fields; they appear in Validate error messages and any other operator-
|
||||
// facing surface that needs to reference a specific knob by name. Keeping
|
||||
// them as constants prevents the string from drifting from the actual
|
||||
// flag a future rename would produce.
|
||||
const (
|
||||
FlagMCPToolTimeout = "mcp-tool-timeout"
|
||||
FlagMCPDiscoveryTimeout = "mcp-discovery-timeout"
|
||||
FlagWorkerWaitTimeout = "worker-wait-timeout"
|
||||
FlagDrainTimeout = "drain-timeout"
|
||||
FlagHealthCheckInterval = "health-check-interval"
|
||||
FlagStaleNodeThreshold = "stale-node-threshold"
|
||||
FlagMCPCIJobTimeout = "mcp-ci-job-timeout"
|
||||
FlagBackendInstallTimeout = "backend-install-timeout"
|
||||
FlagBackendUpgradeTimeout = "backend-upgrade-timeout"
|
||||
)
|
||||
|
||||
// Defaults for distributed timeouts.
|
||||
const (
|
||||
DefaultMCPToolTimeout = 360 * time.Second
|
||||
DefaultMCPDiscoveryTimeout = 60 * time.Second
|
||||
DefaultWorkerWaitTimeout = 5 * time.Minute
|
||||
DefaultDrainTimeout = 30 * time.Second
|
||||
DefaultHealthCheckInterval = 15 * time.Second
|
||||
DefaultStaleNodeThreshold = 60 * time.Second
|
||||
DefaultMCPCIJobTimeout = 10 * time.Minute
|
||||
DefaultMCPToolTimeout = 360 * time.Second
|
||||
DefaultMCPDiscoveryTimeout = 60 * time.Second
|
||||
DefaultWorkerWaitTimeout = 5 * time.Minute
|
||||
DefaultDrainTimeout = 30 * time.Second
|
||||
DefaultHealthCheckInterval = 15 * time.Second
|
||||
DefaultStaleNodeThreshold = 60 * time.Second
|
||||
DefaultMCPCIJobTimeout = 10 * time.Minute
|
||||
DefaultBackendInstallTimeout = 15 * time.Minute
|
||||
DefaultBackendUpgradeTimeout = 15 * time.Minute
|
||||
)
|
||||
|
||||
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
|
||||
const DefaultMaxUploadSize int64 = 50 << 30
|
||||
|
||||
// BackendInstallTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout)
|
||||
}
|
||||
|
||||
// BackendUpgradeTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) BackendUpgradeTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.BackendUpgradeTimeout, DefaultBackendUpgradeTimeout)
|
||||
}
|
||||
|
||||
// MCPToolTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) MCPToolTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.MCPToolTimeout, DefaultMCPToolTimeout)
|
||||
|
||||
90
core/config/distributed_config_test.go
Normal file
90
core/config/distributed_config_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
var _ = Describe("DistributedConfig backend NATS timeouts", func() {
|
||||
Context("BackendInstallTimeoutOrDefault", func() {
|
||||
It("returns 15 minutes when unset", func() {
|
||||
c := config.DistributedConfig{}
|
||||
Expect(c.BackendInstallTimeoutOrDefault()).To(Equal(15 * time.Minute))
|
||||
})
|
||||
|
||||
It("returns the configured value when set", func() {
|
||||
c := config.DistributedConfig{BackendInstallTimeout: 42 * time.Minute}
|
||||
Expect(c.BackendInstallTimeoutOrDefault()).To(Equal(42 * time.Minute))
|
||||
})
|
||||
})
|
||||
|
||||
Context("BackendUpgradeTimeoutOrDefault", func() {
|
||||
It("returns 15 minutes when unset", func() {
|
||||
c := config.DistributedConfig{}
|
||||
Expect(c.BackendUpgradeTimeoutOrDefault()).To(Equal(15 * time.Minute))
|
||||
})
|
||||
|
||||
It("returns the configured value when set", func() {
|
||||
c := config.DistributedConfig{BackendUpgradeTimeout: 30 * time.Minute}
|
||||
Expect(c.BackendUpgradeTimeoutOrDefault()).To(Equal(30 * time.Minute))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("DistributedConfig flag-name constants", func() {
|
||||
// Pin the kebab-case strings so a rename of the Go field name (or a
|
||||
// CLI flag naming convention change) forces the constant to update,
|
||||
// keeping the Validate error messages and any future operator-facing
|
||||
// surface in sync with the actual CLI flag.
|
||||
DescribeTable("flag name constants",
|
||||
func(actual, expected string) {
|
||||
Expect(actual).To(Equal(expected))
|
||||
},
|
||||
Entry("MCP tool timeout", config.FlagMCPToolTimeout, "mcp-tool-timeout"),
|
||||
Entry("MCP discovery timeout", config.FlagMCPDiscoveryTimeout, "mcp-discovery-timeout"),
|
||||
Entry("worker wait timeout", config.FlagWorkerWaitTimeout, "worker-wait-timeout"),
|
||||
Entry("drain timeout", config.FlagDrainTimeout, "drain-timeout"),
|
||||
Entry("health check interval", config.FlagHealthCheckInterval, "health-check-interval"),
|
||||
Entry("stale node threshold", config.FlagStaleNodeThreshold, "stale-node-threshold"),
|
||||
Entry("MCP CI job timeout", config.FlagMCPCIJobTimeout, "mcp-ci-job-timeout"),
|
||||
Entry("backend install timeout", config.FlagBackendInstallTimeout, "backend-install-timeout"),
|
||||
Entry("backend upgrade timeout", config.FlagBackendUpgradeTimeout, "backend-upgrade-timeout"),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("DistributedConfig.Validate negative-duration errors", func() {
|
||||
It("rejects a negative BackendInstallTimeout with the flag name in the error", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
BackendInstallTimeout: -1 * time.Second,
|
||||
}
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(config.FlagBackendInstallTimeout))
|
||||
Expect(err.Error()).To(ContainSubstring("must not be negative"))
|
||||
})
|
||||
|
||||
It("rejects a negative BackendUpgradeTimeout with the flag name in the error", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
BackendUpgradeTimeout: -1 * time.Second,
|
||||
}
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(config.FlagBackendUpgradeTimeout))
|
||||
})
|
||||
|
||||
It("accepts all-zero durations as valid (defaults apply)", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
}
|
||||
Expect(c.Validate()).To(Succeed())
|
||||
})
|
||||
})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user