mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-23 08:10:48 -04:00
Compare commits
82 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0f3e26245 | ||
|
|
e4cc1f11f3 | ||
|
|
6ed269d0b9 | ||
|
|
5756fb046d | ||
|
|
7980629bc5 | ||
|
|
d0a59be9de | ||
|
|
5cda4f1ccf | ||
|
|
c500461c69 | ||
|
|
834ecc36bf | ||
|
|
61bf34ea2f | ||
|
|
0b2ae3c6ca | ||
|
|
4735345105 | ||
|
|
7384fd800b | ||
|
|
6942713d85 | ||
|
|
0cf52c44d4 | ||
|
|
0d34cf7cbd | ||
|
|
f0cb02afb8 | ||
|
|
a39e025d64 | ||
|
|
05e8e1e9f4 | ||
|
|
a7f6cc8956 | ||
|
|
f15b9178ec | ||
|
|
959de86761 | ||
|
|
4c234abc2c | ||
|
|
c68818a62e | ||
|
|
11d5bd0cc3 | ||
|
|
12e056e96d | ||
|
|
308aa8908a | ||
|
|
b2d68a53a2 | ||
|
|
e3706c0512 | ||
|
|
1ffd82a050 | ||
|
|
f515168dbe | ||
|
|
ef6ca34513 | ||
|
|
9413c3767f | ||
|
|
3bf3cce232 | ||
|
|
06f8159035 | ||
|
|
f6a73f54fa | ||
|
|
24e04d8e81 | ||
|
|
b9a49449ae | ||
|
|
1879e11042 | ||
|
|
403d391316 | ||
|
|
fc3980dadd | ||
|
|
2009544b44 | ||
|
|
e859345b12 | ||
|
|
f30712f8e8 | ||
|
|
a19c77c5f8 | ||
|
|
4b02d23c0c | ||
|
|
21140e96b2 | ||
|
|
fc803e8d48 | ||
|
|
ca51606bfe | ||
|
|
cb502de309 | ||
|
|
5d0b549049 | ||
|
|
11cff1b309 | ||
|
|
4ca3d2cdc0 | ||
|
|
3cba35ed32 | ||
|
|
265ae35231 | ||
|
|
6a48157a80 | ||
|
|
41c838b2df | ||
|
|
21e793ad2a | ||
|
|
7c190bb4b9 | ||
|
|
d77a9137d8 | ||
|
|
661a0c3b9d | ||
|
|
00b8989886 | ||
|
|
43e0d397ca | ||
|
|
a1a7a219ed | ||
|
|
3937ec6527 | ||
|
|
1355b55794 | ||
|
|
5a2626d465 | ||
|
|
a39591f144 | ||
|
|
8c785dbe4a | ||
|
|
4abf5befbb | ||
|
|
195b910260 | ||
|
|
ba21bf667c | ||
|
|
7bd1693ad0 | ||
|
|
b5ac3a7373 | ||
|
|
53de474ef5 | ||
|
|
c33d36b870 | ||
|
|
57fa178a64 | ||
|
|
745473cbe6 | ||
|
|
594c9fd92e | ||
|
|
8af963bdd9 | ||
|
|
6e1dbae256 | ||
|
|
53bdb18d10 |
@@ -112,6 +112,8 @@ Add a YAML anchor definition in the `## metas` section (around line 2-300). Look
|
|||||||
|
|
||||||
Add image entries at the end of the file, following the pattern of similar backends such as `diffusers` or `chatterbox`. Include both `latest` (production) and `master` (development) tags.
|
Add image entries at the end of the file, following the pattern of similar backends such as `diffusers` or `chatterbox`. Include both `latest` (production) and `master` (development) tags.
|
||||||
|
|
||||||
|
**Note on integrity:** OCI backends installed from a gallery whose `verification:` block is set are verified against a keyless-cosign policy before extraction; tarball/HTTP backends use the optional `sha256:` field. New backends do not need any extra YAML — the gallery-level `verification:` block covers every entry. See [.agents/backend-signing.md](backend-signing.md) for the producer-side CI step.
|
||||||
|
|
||||||
## 4. Update the Makefile
|
## 4. Update the Makefile
|
||||||
|
|
||||||
The Makefile needs to be updated in several places to support building and testing the new backend:
|
The Makefile needs to be updated in several places to support building and testing the new backend:
|
||||||
|
|||||||
120
.agents/backend-signing.md
Normal file
120
.agents/backend-signing.md
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
# Backend image signing & verification
|
||||||
|
|
||||||
|
LocalAI verifies backend OCI images against a per-gallery keyless-cosign
|
||||||
|
policy. This page documents the trust model, the producer side
|
||||||
|
(`.github/workflows/backend_merge.yml` in this repo), and the consumer
|
||||||
|
side (`pkg/oci/cosignverify` plus the gallery YAML).
|
||||||
|
|
||||||
|
## Trust model
|
||||||
|
|
||||||
|
- **Producer:** `.github/workflows/backend_merge.yml` signs each pushed
|
||||||
|
manifest list with `cosign sign --recursive` in keyless mode after
|
||||||
|
`docker buildx imagetools create`. The signing cert is issued by
|
||||||
|
Fulcio bound to the workflow's OIDC identity. There is no long-lived
|
||||||
|
signing key. `--recursive` signs both the manifest list and every
|
||||||
|
per-arch entry — needed because our consumer resolves a tag to a
|
||||||
|
per-arch manifest before checking signatures.
|
||||||
|
- **Storage:** Signatures are written as OCI 1.1 referrers
|
||||||
|
(`--registry-referrers-mode=oci-1-1`) in the new Sigstore bundle format
|
||||||
|
(current cosign releases do this by default; no `--new-bundle-format`
|
||||||
|
flag). No `:sha256-<hex>.sig` tag clutter.
|
||||||
|
- **Consumer:** `pkg/oci/cosignverify` discovers the bundle via the
|
||||||
|
referrers API, hands it to `sigstore-go`, and verifies it against the
|
||||||
|
policy declared in the gallery YAML (`Gallery.Verification`).
|
||||||
|
- **Revocation:** Keyless cosign certs are ephemeral (10-minute Fulcio
|
||||||
|
validity), so revocation is policy-side, not CA-side. The gallery's
|
||||||
|
`verification.not_before` (RFC3339) is the kill-switch — advance it to
|
||||||
|
invalidate every signature produced before a known compromise window.
|
||||||
|
|
||||||
|
## Producer setup
|
||||||
|
|
||||||
|
`backend_merge.yml` is the workflow that joins per-arch digests into the
|
||||||
|
multi-arch manifest list users actually pull, so it's also the right place
|
||||||
|
to sign. The job needs:
|
||||||
|
|
||||||
|
- `permissions: { id-token: write, contents: read }` at the job level so
|
||||||
|
the runner can exchange its GitHub OIDC token for a Fulcio cert.
|
||||||
|
- `sigstore/cosign-installer@v3` step (current cosign releases already
|
||||||
|
default to the new bundle format).
|
||||||
|
- After each `docker buildx imagetools create`, resolve the resulting
|
||||||
|
list digest with `docker buildx imagetools inspect <tag> --format
|
||||||
|
'{{.Manifest.Digest}}'` and sign:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cosign sign --yes --recursive \
|
||||||
|
--registry-referrers-mode=oci-1-1 \
|
||||||
|
"${REGISTRY_REPO}@${DIGEST}"
|
||||||
|
```
|
||||||
|
|
||||||
|
Sign by digest, never by tag — signing by tag binds the signature to
|
||||||
|
whatever the tag points at *now*, and a subsequent tag push orphans it.
|
||||||
|
|
||||||
|
`backend_build_darwin.yml` builds and pushes single-arch darwin images
|
||||||
|
that bypass the manifest-list merge. If/when those entries get a gallery
|
||||||
|
`verification:` policy, the equivalent cosign step has to land there
|
||||||
|
too.
|
||||||
|
|
||||||
|
## Consumer setup (in `mudler/LocalAI` gallery YAML)
|
||||||
|
|
||||||
|
Once CI is signing, add a `verification:` block to the backend gallery
|
||||||
|
entry (`backend/index.yaml`):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
- name: localai
|
||||||
|
url: github:mudler/LocalAI/backend/index.yaml@master
|
||||||
|
verification:
|
||||||
|
issuer: "https://token.actions.githubusercontent.com"
|
||||||
|
identity_regex: "^https://github\\.com/mudler/LocalAI/\\.github/workflows/backend_merge\\.yml@refs/heads/master$"
|
||||||
|
# Optional revocation cutoff; advance during incident response.
|
||||||
|
# not_before: "2026-06-01T00:00:00Z"
|
||||||
|
```
|
||||||
|
|
||||||
|
Identity matching pins the OIDC subject Fulcio issued the signing cert
|
||||||
|
to. Without this, any image signed by *anyone* with a Fulcio cert would
|
||||||
|
pass — the regex is what makes a signature mean "produced by our CI".
|
||||||
|
|
||||||
|
## Strict mode
|
||||||
|
|
||||||
|
Default behaviour: OCI backends without a `verification:` block install
|
||||||
|
with a warning (logs include `installing OCI backend without signature
|
||||||
|
verification`). Tarball/HTTP backends without a `sha256` field log a
|
||||||
|
similar warning.
|
||||||
|
|
||||||
|
For production, set `LOCALAI_REQUIRE_BACKEND_INTEGRITY=1` (or pass
|
||||||
|
`--require-backend-integrity` to `local-ai run` / `local-ai backends
|
||||||
|
install` / `local-ai models install`). The warning becomes a hard error
|
||||||
|
and unverifiable backends refuse to install.
|
||||||
|
|
||||||
|
## Revocation playbook
|
||||||
|
|
||||||
|
If `backend_merge.yml` (or any workflow with `id-token: write`) is
|
||||||
|
compromised and we've shipped malicious signed images:
|
||||||
|
|
||||||
|
1. **Identify the compromise window.** Find the earliest IntegratedTime
|
||||||
|
from the bad signatures (Rekor search by `subject` filter).
|
||||||
|
2. **Set `verification.not_before`** in `backend/index.yaml` to a
|
||||||
|
timestamp just *after* that window's start.
|
||||||
|
3. **Push the YAML.** Deployed LocalAI instances pick it up on next
|
||||||
|
gallery refresh (1-hour cache in `core/gallery/gallery.go`).
|
||||||
|
4. **Fix the underlying compromise** in the workflow and re-sign images
|
||||||
|
with the new build, which will have IntegratedTime > `not_before`.
|
||||||
|
5. **Optional:** for absolute decisiveness, also rotate to a new
|
||||||
|
workflow path (`backend_merge_v2.yml`) and update `identity_regex`.
|
||||||
|
|
||||||
|
## Where the code lives
|
||||||
|
|
||||||
|
- `pkg/oci/cosignverify/` — verifier, policy, OCI referrer fetch, NotBefore enforcement.
|
||||||
|
- `pkg/downloader/uri.go` — `WithImageVerifier` option threaded through `DownloadFileWithContext`.
|
||||||
|
- `core/gallery/backends.go` — `backendDownloadOptions` builds the verifier from the gallery's policy.
|
||||||
|
- `core/config/gallery.go` — `Gallery.Verification` YAML schema.
|
||||||
|
- `core/cli/run.go`, `core/cli/backends.go`, `core/cli/models.go` — `--require-backend-integrity` flag propagation.
|
||||||
|
- `.github/workflows/backend_merge.yml` — producer-side `cosign sign --recursive` after each multi-arch manifest list push.
|
||||||
|
|
||||||
|
## Out of scope (follow-ups)
|
||||||
|
|
||||||
|
- **Signing the gallery YAML itself.** The index is fetched over HTTPS
|
||||||
|
from GitHub; we trust the host. A cosign blob signature on the YAML
|
||||||
|
would close that gap but adds key-management overhead. Revisit this
|
||||||
|
page if/when added.
|
||||||
|
- **Tarball/HTTP backend signing.** Cosign can sign arbitrary blobs, but
|
||||||
|
for now non-OCI backends keep using the `sha256:` field in YAML.
|
||||||
@@ -61,6 +61,12 @@ Always check `llama.cpp` for new model configuration options that should be supp
|
|||||||
- `reasoning_format` - Reasoning format options
|
- `reasoning_format` - Reasoning format options
|
||||||
- Any new flags or parameters
|
- Any new flags or parameters
|
||||||
|
|
||||||
|
### Speculative Decoding Types
|
||||||
|
|
||||||
|
The `spec_type` option in `grpc-server.cpp` delegates to upstream's `common_speculative_types_from_names()`, so new speculative types added to the `common_speculative_type_from_name` map in `common/speculative.cpp` are picked up automatically with no code changes - only docs need an entry in `docs/content/advanced/model-configuration.md`. Current values: `none`, `draft-simple`, `draft-eagle3`, `draft-mtp`, `ngram-simple`, `ngram-map-k`, `ngram-map-k4v`, `ngram-mod`, `ngram-cache`.
|
||||||
|
|
||||||
|
`draft-mtp` (Multi-Token Prediction, [ggml-org/llama.cpp#22673](https://github.com/ggml-org/llama.cpp/pull/22673)) does not need a separate draft GGUF: when `spec_type` includes `draft-mtp` and `draftmodel` is empty, the upstream server creates an MTP context off the target model itself. LocalAI's gRPC layer needs no changes for this — it works through the existing `params.speculative.types` plumbing and the derived `cparams.n_rs_seq = params.speculative.need_n_rs_seq()` in `common_context_params_to_llama`.
|
||||||
|
|
||||||
### Implementation Guidelines
|
### Implementation Guidelines
|
||||||
|
|
||||||
1. **Feature Parity**: Always aim for feature parity with llama.cpp's implementation
|
1. **Feature Parity**: Always aim for feature parity with llama.cpp's implementation
|
||||||
|
|||||||
54
.github/workflows/backend_merge.yml
vendored
54
.github/workflows/backend_merge.yml
vendored
@@ -31,6 +31,13 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
merge:
|
merge:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
# id-token: write is required for keyless cosign — the workflow
|
||||||
|
# exchanges the GitHub OIDC token for a short-lived Fulcio cert that
|
||||||
|
# signs each pushed manifest. Without this permission the runner
|
||||||
|
# cannot mint the token, and `cosign sign` fails with "no token".
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
id-token: write
|
||||||
env:
|
env:
|
||||||
quay_username: ${{ secrets.quayUsername }}
|
quay_username: ${{ secrets.quayUsername }}
|
||||||
steps:
|
steps:
|
||||||
@@ -57,6 +64,16 @@ jobs:
|
|||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@master
|
uses: docker/setup-buildx-action@master
|
||||||
|
|
||||||
|
# cosign signs each pushed manifest list with --recursive so the
|
||||||
|
# index and every per-arch entry get an attached Sigstore bundle.
|
||||||
|
# Recent cosign releases always emit the new bundle format, so
|
||||||
|
# there's no extra CLI flag to opt into it.
|
||||||
|
- name: Install cosign
|
||||||
|
if: github.event_name != 'pull_request'
|
||||||
|
uses: sigstore/cosign-installer@v3
|
||||||
|
with:
|
||||||
|
cosign-release: 'v2.4.1'
|
||||||
|
|
||||||
- name: Login to DockerHub
|
- name: Login to DockerHub
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request'
|
||||||
uses: docker/login-action@v4
|
uses: docker/login-action@v4
|
||||||
@@ -120,11 +137,25 @@ jobs:
|
|||||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||||
if [ -z "$tags" ]; then
|
if [ -z "$tags" ]; then
|
||||||
echo "No quay.io tags from docker/metadata-action; skipping quay merge"
|
echo "No quay.io tags from docker/metadata-action; skipping quay merge"
|
||||||
else
|
exit 0
|
||||||
# shellcheck disable=SC2086
|
|
||||||
docker buildx imagetools create $tags \
|
|
||||||
$(printf 'quay.io/go-skynet/ci-cache@sha256:%s ' *)
|
|
||||||
fi
|
fi
|
||||||
|
# shellcheck disable=SC2086
|
||||||
|
docker buildx imagetools create $tags \
|
||||||
|
$(printf 'quay.io/go-skynet/ci-cache@sha256:%s ' *)
|
||||||
|
# Resolve the manifest-list digest (any tag points at it) so
|
||||||
|
# cosign can sign by digest. Signing by tag would leave the
|
||||||
|
# signature orphaned the next time the tag moves.
|
||||||
|
first_tag=$(jq -cr '
|
||||||
|
.tags | map(select(startswith("quay.io/"))) | .[0]
|
||||||
|
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||||
|
digest=$(docker buildx imagetools inspect "$first_tag" --format '{{.Manifest.Digest}}')
|
||||||
|
# --recursive walks the list and signs every per-arch entry
|
||||||
|
# too — clients that resolve a tag to a platform-specific
|
||||||
|
# manifest before checking signatures need the per-arch
|
||||||
|
# signatures, not just the list-level one.
|
||||||
|
cosign sign --yes --recursive \
|
||||||
|
--registry-referrers-mode=oci-1-1 \
|
||||||
|
"quay.io/go-skynet/local-ai-backends@${digest}"
|
||||||
|
|
||||||
- name: Create manifest list and push (dockerhub)
|
- name: Create manifest list and push (dockerhub)
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request'
|
||||||
@@ -139,11 +170,18 @@ jobs:
|
|||||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||||
if [ -z "$tags" ]; then
|
if [ -z "$tags" ]; then
|
||||||
echo "No dockerhub tags from docker/metadata-action; skipping dockerhub merge"
|
echo "No dockerhub tags from docker/metadata-action; skipping dockerhub merge"
|
||||||
else
|
exit 0
|
||||||
# shellcheck disable=SC2086
|
|
||||||
docker buildx imagetools create $tags \
|
|
||||||
$(printf 'localai/localai-backends@sha256:%s ' *)
|
|
||||||
fi
|
fi
|
||||||
|
# shellcheck disable=SC2086
|
||||||
|
docker buildx imagetools create $tags \
|
||||||
|
$(printf 'localai/localai-backends@sha256:%s ' *)
|
||||||
|
first_tag=$(jq -cr '
|
||||||
|
.tags | map(select(startswith("localai/"))) | .[0]
|
||||||
|
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||||
|
digest=$(docker buildx imagetools inspect "$first_tag" --format '{{.Manifest.Digest}}')
|
||||||
|
cosign sign --yes --recursive \
|
||||||
|
--registry-referrers-mode=oci-1-1 \
|
||||||
|
"localai/localai-backends@${digest}"
|
||||||
|
|
||||||
- name: Inspect manifest
|
- name: Inspect manifest
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request'
|
||||||
|
|||||||
1
.github/workflows/image_build.yml
vendored
1
.github/workflows/image_build.yml
vendored
@@ -106,6 +106,7 @@ jobs:
|
|||||||
type=ref,event=branch
|
type=ref,event=branch
|
||||||
type=semver,pattern={{raw}}
|
type=semver,pattern={{raw}}
|
||||||
type=sha
|
type=sha
|
||||||
|
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
|
||||||
flavor: |
|
flavor: |
|
||||||
latest=${{ inputs.tag-latest }}
|
latest=${{ inputs.tag-latest }}
|
||||||
suffix=${{ inputs.tag-suffix }},onlatest=true
|
suffix=${{ inputs.tag-suffix }},onlatest=true
|
||||||
|
|||||||
1
.github/workflows/image_merge.yml
vendored
1
.github/workflows/image_merge.yml
vendored
@@ -80,6 +80,7 @@ jobs:
|
|||||||
type=ref,event=branch
|
type=ref,event=branch
|
||||||
type=semver,pattern={{raw}}
|
type=semver,pattern={{raw}}
|
||||||
type=sha
|
type=sha
|
||||||
|
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
|
||||||
flavor: |
|
flavor: |
|
||||||
latest=${{ inputs.tag-latest }}
|
latest=${{ inputs.tag-latest }}
|
||||||
suffix=${{ inputs.tag-suffix }},onlatest=true
|
suffix=${{ inputs.tag-suffix }},onlatest=true
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -77,3 +77,6 @@ local-backends/
|
|||||||
tests/e2e-ui/ui-test-server
|
tests/e2e-ui/ui-test-server
|
||||||
core/http/react-ui/playwright-report/
|
core/http/react-ui/playwright-report/
|
||||||
core/http/react-ui/test-results/
|
core/http/react-ui/test-results/
|
||||||
|
|
||||||
|
# Local worktrees
|
||||||
|
.worktrees/
|
||||||
|
|||||||
@@ -46,8 +46,52 @@ linters:
|
|||||||
msg: 'LocalAI tests must use Ginkgo/Gomega; use Fail(...) instead of t.Fail. See .agents/coding-style.md.'
|
msg: 'LocalAI tests must use Ginkgo/Gomega; use Fail(...) instead of t.Fail. See .agents/coding-style.md.'
|
||||||
- pattern: '^t\.FailNow$'
|
- pattern: '^t\.FailNow$'
|
||||||
msg: 'LocalAI tests must use Ginkgo/Gomega; use Fail(...) instead of t.FailNow. See .agents/coding-style.md.'
|
msg: 'LocalAI tests must use Ginkgo/Gomega; use Fail(...) instead of t.FailNow. See .agents/coding-style.md.'
|
||||||
|
# In-process config should flow through ApplicationConfig / kong-bound
|
||||||
|
# CLI flags, not via os.Getenv. The CLI layer is the legitimate
|
||||||
|
# env→struct boundary (kong's `env:"..."` tag); anything deeper that
|
||||||
|
# reads env directly leaks process state into business logic and
|
||||||
|
# makes flags impossible to test or override per-request. Backend
|
||||||
|
# subprocesses, the system/capabilities probe, and a few places that
|
||||||
|
# read non-LocalAI env vars (HOME, PATH, AUTH_TOKEN passed by parent)
|
||||||
|
# are exempt — see linters.exclusions.rules below.
|
||||||
|
- pattern: '^os\.(Getenv|LookupEnv|Environ)$'
|
||||||
|
msg: 'Plumb config through ApplicationConfig (or the relevant CLI struct) instead of reading env directly. CLI entry points (core/cli/) bind env vars via kong''s `env:` tag — that is the only sanctioned env→struct boundary. See .agents/coding-style.md.'
|
||||||
exclusions:
|
exclusions:
|
||||||
paths:
|
paths:
|
||||||
# Upstream whisper.cpp source tree fetched by the whisper backend Makefile.
|
# Upstream whisper.cpp source tree fetched by the whisper backend Makefile.
|
||||||
- 'backend/go/whisper/sources'
|
- 'backend/go/whisper/sources'
|
||||||
- 'docs/'
|
- 'docs/'
|
||||||
|
rules:
|
||||||
|
# CLI entry points: kong's `env:"..."` tag is the legitimate env→struct
|
||||||
|
# boundary, and a handful of subcommands legitimately propagate values
|
||||||
|
# to spawned subprocesses (LLAMACPP_GRPC_SERVERS, MLX hostfile, ...).
|
||||||
|
- path: ^core/cli/
|
||||||
|
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||||
|
linters: [forbidigo]
|
||||||
|
# Backend subprocesses are independent binaries with their own env
|
||||||
|
# surface; they're not "in-process config" of the LocalAI server.
|
||||||
|
- path: ^backend/
|
||||||
|
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||||
|
linters: [forbidigo]
|
||||||
|
# System capability probe reads HOME, PATH-style vars to discover
|
||||||
|
# GPUs, default paths, etc. — not LocalAI config.
|
||||||
|
- path: ^pkg/system/
|
||||||
|
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||||
|
linters: [forbidigo]
|
||||||
|
# gRPC server reads AUTH_TOKEN passed in by the parent process at spawn
|
||||||
|
# time; model.Loader sets/inherits env to communicate with subprocesses.
|
||||||
|
- path: ^pkg/grpc/
|
||||||
|
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||||
|
linters: [forbidigo]
|
||||||
|
- path: ^pkg/model/
|
||||||
|
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||||
|
linters: [forbidigo]
|
||||||
|
# Top-level main binaries (local-ai, launcher) are entry points.
|
||||||
|
- path: ^cmd/
|
||||||
|
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||||
|
linters: [forbidigo]
|
||||||
|
# Tests legitimately read $HOME, $TMPDIR, and gating env vars
|
||||||
|
# (LOCALAI_COSIGN_LIVE, etc.) to skip live-network specs.
|
||||||
|
- path: _test\.go$
|
||||||
|
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||||
|
linters: [forbidigo]
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
|||||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||||
| [.agents/adding-gallery-models.md](.agents/adding-gallery-models.md) | Adding GGUF models from HuggingFace to the model gallery |
|
| [.agents/adding-gallery-models.md](.agents/adding-gallery-models.md) | Adding GGUF models from HuggingFace to the model gallery |
|
||||||
| [.agents/localai-assistant-mcp.md](.agents/localai-assistant-mcp.md) | LocalAI Assistant chat modality — adding admin tools to the in-process MCP server, editing skill prompts, keeping REST + MCP + skills in sync |
|
| [.agents/localai-assistant-mcp.md](.agents/localai-assistant-mcp.md) | LocalAI Assistant chat modality — adding admin tools to the in-process MCP server, editing skill prompts, keeping REST + MCP + skills in sync |
|
||||||
|
| [.agents/backend-signing.md](.agents/backend-signing.md) | Backend OCI image signing (keyless cosign + sigstore-go) — producer-side CI setup, consumer-side gallery `verification:` block, strict mode (`LOCALAI_REQUIRE_BACKEND_INTEGRITY`), revocation via `not_before` |
|
||||||
|
|
||||||
## Quick Reference
|
## Quick Reference
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
# ds4 backend Makefile.
|
# ds4 backend Makefile.
|
||||||
#
|
#
|
||||||
# Upstream pin lives below as DS4_VERSION?=0cba357ca1bc0e7510421cc26888e420ea942123
|
# Upstream pin lives below as DS4_VERSION?=8d576642c39b9a2d782a80159ba84ef5a81c0b81
|
||||||
# (.github/bump_deps.sh) can find and update it - matches the
|
# (.github/bump_deps.sh) can find and update it - matches the
|
||||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||||
|
|
||||||
DS4_VERSION?=0cba357ca1bc0e7510421cc26888e420ea942123
|
DS4_VERSION?=8d576642c39b9a2d782a80159ba84ef5a81c0b81
|
||||||
DS4_REPO?=https://github.com/antirez/ds4
|
DS4_REPO?=https://github.com/antirez/ds4
|
||||||
|
|
||||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
IK_LLAMA_VERSION?=949bb8f1d660fc1264c137a6f3dbd619375f6134
|
IK_LLAMA_VERSION?=b3d39cff8bffbd67296d6badd4076a1486a0715c
|
||||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||||
|
|
||||||
CMAKE_ARGS?=
|
CMAKE_ARGS?=
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
LLAMA_VERSION?=a9883db8ee021cf16783016a60996d41820b5195
|
LLAMA_VERSION?=1acee6bf8939948f9bcbf4b14034e4b475f06069
|
||||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||||
|
|
||||||
CMAKE_ARGS?=
|
CMAKE_ARGS?=
|
||||||
|
|||||||
@@ -32,6 +32,7 @@
|
|||||||
#include <grpcpp/health_check_service_interface.h>
|
#include <grpcpp/health_check_service_interface.h>
|
||||||
#include <grpcpp/security/server_credentials.h>
|
#include <grpcpp/security/server_credentials.h>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
|
#include <algorithm>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
@@ -450,6 +451,8 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
|||||||
// vector; the turboquant fork still uses the legacy scalar. The
|
// vector; the turboquant fork still uses the legacy scalar. The
|
||||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC macro is injected by
|
// LOCALAI_LEGACY_LLAMA_CPP_SPEC macro is injected by
|
||||||
// backend/cpp/turboquant/patch-grpc-server.sh for fork builds only.
|
// backend/cpp/turboquant/patch-grpc-server.sh for fork builds only.
|
||||||
|
// Upstream renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE
|
||||||
|
// in ggml-org/llama.cpp#22964; the fork still uses the old name.
|
||||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||||
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
||||||
@@ -458,7 +461,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
|||||||
const bool no_spec_type = params.speculative.types.empty() ||
|
const bool no_spec_type = params.speculative.types.empty() ||
|
||||||
(params.speculative.types.size() == 1 && params.speculative.types[0] == COMMON_SPECULATIVE_TYPE_NONE);
|
(params.speculative.types.size() == 1 && params.speculative.types[0] == COMMON_SPECULATIVE_TYPE_NONE);
|
||||||
if (no_spec_type) {
|
if (no_spec_type) {
|
||||||
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT };
|
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE };
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
@@ -514,16 +517,27 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
|||||||
params.warmup = true;
|
params.warmup = true;
|
||||||
// no_op_offload: disable host tensor op offload (default: false)
|
// no_op_offload: disable host tensor op offload (default: false)
|
||||||
params.no_op_offload = false;
|
params.no_op_offload = false;
|
||||||
// kv_unified: enable unified KV cache (default: false)
|
// kv_unified: enable unified KV cache. Upstream's server auto-enables this
|
||||||
params.kv_unified = false;
|
// when the slot count is auto (-np <0), bumping n_parallel to 4 alongside.
|
||||||
// n_ctx_checkpoints: max context checkpoints per slot (default: 8)
|
// LocalAI keeps n_parallel=1 by default, which would skip that auto path
|
||||||
params.n_ctx_checkpoints = 8;
|
// and leave kv_unified=false. We flip the default to true here so the
|
||||||
|
// server-side prompt cache (cache_idle_slots) is actually usable on the
|
||||||
// llama memory fit fails if we don't provide a buffer for tensor overrides
|
// single-slot path that LocalAI ships with: without it, idle slots are
|
||||||
const size_t ntbo = llama_max_tensor_buft_overrides();
|
// never persisted across requests and the prompt cache is dead weight.
|
||||||
while (params.tensor_buft_overrides.size() < ntbo) {
|
// Users can opt out with `options: [ "kv_unified:false" ]`.
|
||||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
params.kv_unified = true;
|
||||||
}
|
// n_ctx_checkpoints: max context checkpoints per slot. Match upstream's
|
||||||
|
// default (32); the previous LocalAI-specific 8 was unnecessarily tight
|
||||||
|
// and limits partial-prefix recovery without a clear memory rationale.
|
||||||
|
params.n_ctx_checkpoints = 32;
|
||||||
|
// cache_idle_slots: save and clear idle slot KV to the prompt cache on
|
||||||
|
// task switch. Upstream default is true; the server auto-disables it if
|
||||||
|
// kv_unified=false or cache_ram_mib=0, so flipping kv_unified above is
|
||||||
|
// what actually unlocks it.
|
||||||
|
params.cache_idle_slots = true;
|
||||||
|
// checkpoint_every_nt: create a context checkpoint every N tokens during
|
||||||
|
// prefill (-1 disables). Match upstream's default (8192).
|
||||||
|
params.checkpoint_every_nt = 8192;
|
||||||
|
|
||||||
// decode options. Options are in form optname:optvale, or if booleans only optname.
|
// decode options. Options are in form optname:optvale, or if booleans only optname.
|
||||||
for (int i = 0; i < request->options_size(); i++) {
|
for (int i = 0; i < request->options_size(); i++) {
|
||||||
@@ -682,9 +696,161 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
|||||||
try {
|
try {
|
||||||
params.n_ctx_checkpoints = std::stoi(optval_str);
|
params.n_ctx_checkpoints = std::stoi(optval_str);
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
// If conversion fails, keep default value (8)
|
// If conversion fails, keep default value (32)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- server-side idle-slot prompt cache toggle (upstream --cache-idle-slots) ---
|
||||||
|
// Saves the slot's KV state into the host-side prompt cache on task
|
||||||
|
// switch so a later request with the same prefix can warm-load it.
|
||||||
|
// Auto-disabled by the server if kv_unified=false or cache_ram=0.
|
||||||
|
} else if (!strcmp(optname, "cache_idle_slots") || !strcmp(optname, "idle_slots_cache")) {
|
||||||
|
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||||
|
params.cache_idle_slots = true;
|
||||||
|
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||||
|
params.cache_idle_slots = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- prefill checkpoint cadence (upstream -cpent / --checkpoint-every-n-tokens) ---
|
||||||
|
// -1 disables checkpointing during prefill.
|
||||||
|
} else if (!strcmp(optname, "checkpoint_every_nt") || !strcmp(optname, "checkpoint_every_n_tokens")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try {
|
||||||
|
params.checkpoint_every_nt = std::stoi(optval_str);
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
// If conversion fails, keep default value (8192)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- physical batch size (upstream -ub / --ubatch-size) ---
|
||||||
|
// Note: line ~482 already aliases n_ubatch to n_batch as a default; this
|
||||||
|
// option lets users decouple the two (useful for embeddings/rerank).
|
||||||
|
} else if (!strcmp(optname, "n_ubatch") || !strcmp(optname, "ubatch")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.n_ubatch = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- main-model batch threads (upstream -tb / --threads-batch) ---
|
||||||
|
} else if (!strcmp(optname, "threads_batch") || !strcmp(optname, "n_threads_batch")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try {
|
||||||
|
int n = std::stoi(optval_str);
|
||||||
|
if (n <= 0) n = (int)std::thread::hardware_concurrency();
|
||||||
|
params.cpuparams_batch.n_threads = n;
|
||||||
|
} catch (...) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- pooling type for embeddings (upstream --pooling) ---
|
||||||
|
} else if (!strcmp(optname, "pooling_type") || !strcmp(optname, "pooling")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
if (optval_str == "none") params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||||
|
else if (optval_str == "mean") params.pooling_type = LLAMA_POOLING_TYPE_MEAN;
|
||||||
|
else if (optval_str == "cls") params.pooling_type = LLAMA_POOLING_TYPE_CLS;
|
||||||
|
else if (optval_str == "last") params.pooling_type = LLAMA_POOLING_TYPE_LAST;
|
||||||
|
else if (optval_str == "rank") params.pooling_type = LLAMA_POOLING_TYPE_RANK;
|
||||||
|
// unknown values silently leave UNSPECIFIED (auto-detect)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- llama log verbosity threshold (upstream -lv / --verbosity) ---
|
||||||
|
} else if (!strcmp(optname, "verbosity")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.verbosity = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- O_DIRECT model loading (upstream --direct-io) ---
|
||||||
|
} else if (!strcmp(optname, "direct_io") || !strcmp(optname, "use_direct_io")) {
|
||||||
|
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||||
|
params.use_direct_io = true;
|
||||||
|
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||||
|
params.use_direct_io = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- embedding normalization (upstream --embd-normalize) ---
|
||||||
|
// -1 none, 0 max-abs, 1 taxicab, 2 L2 (default), >2 p-norm
|
||||||
|
} else if (!strcmp(optname, "embd_normalize") || !strcmp(optname, "embedding_normalize")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.embd_normalize = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- reasoning parser (upstream --reasoning-format) ---
|
||||||
|
// Picks the parser for <think> blocks emitted by reasoning models.
|
||||||
|
// none / auto / deepseek / deepseek-legacy
|
||||||
|
} else if (!strcmp(optname, "reasoning_format")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
if (optval_str == "none") params.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||||
|
else if (optval_str == "auto") params.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
|
||||||
|
else if (optval_str == "deepseek") params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||||
|
else if (optval_str == "deepseek-legacy" || optval_str == "deepseek_legacy")
|
||||||
|
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY;
|
||||||
|
// unknown values silently keep the upstream default (DEEPSEEK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- reasoning budget (upstream --reasoning-budget) ---
|
||||||
|
// -1 unlimited, 0 disabled, >0 token budget for thinking blocks.
|
||||||
|
// Distinct from per-request `enable_thinking` (chat_template_kwargs).
|
||||||
|
} else if (!strcmp(optname, "enable_reasoning") || !strcmp(optname, "reasoning_budget")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.enable_reasoning = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- prefill assistant turn (upstream --no-prefill-assistant) ---
|
||||||
|
} else if (!strcmp(optname, "prefill_assistant")) {
|
||||||
|
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||||
|
params.prefill_assistant = true;
|
||||||
|
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||||
|
params.prefill_assistant = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- mmproj GPU offload (upstream --no-mmproj-offload, inverted) ---
|
||||||
|
} else if (!strcmp(optname, "mmproj_use_gpu") || !strcmp(optname, "mmproj_offload")) {
|
||||||
|
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||||
|
params.mmproj_use_gpu = true;
|
||||||
|
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||||
|
params.mmproj_use_gpu = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- per-image vision token budget (upstream --image-min/max-tokens) ---
|
||||||
|
} else if (!strcmp(optname, "image_min_tokens")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.image_min_tokens = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
} else if (!strcmp(optname, "image_max_tokens")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.image_max_tokens = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- main-model tensor buffer overrides (upstream --override-tensor) ---
|
||||||
|
// Format: <tensor regex>=<buffer type>,<tensor regex>=<buffer type>,...
|
||||||
|
// Mirrors the existing `draft_override_tensor` parser below.
|
||||||
|
} else if (!strcmp(optname, "override_tensor") || !strcmp(optname, "tensor_buft_overrides")) {
|
||||||
|
ggml_backend_load_all();
|
||||||
|
std::map<std::string, ggml_backend_buffer_type_t> buft_list;
|
||||||
|
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||||
|
auto * dev = ggml_backend_dev_get(i);
|
||||||
|
auto * buft = ggml_backend_dev_buffer_type(dev);
|
||||||
|
if (buft) {
|
||||||
|
buft_list[ggml_backend_buft_name(buft)] = buft;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
static std::list<std::string> override_names;
|
||||||
|
std::string cur;
|
||||||
|
auto flush = [&](const std::string & spec) {
|
||||||
|
auto pos = spec.find('=');
|
||||||
|
if (pos == std::string::npos) return;
|
||||||
|
const std::string name = spec.substr(0, pos);
|
||||||
|
const std::string type = spec.substr(pos + 1);
|
||||||
|
auto it = buft_list.find(type);
|
||||||
|
if (it == buft_list.end()) return; // unknown buffer type: ignore
|
||||||
|
override_names.push_back(name);
|
||||||
|
params.tensor_buft_overrides.push_back(
|
||||||
|
{override_names.back().c_str(), it->second});
|
||||||
|
};
|
||||||
|
for (char c : optval_str) {
|
||||||
|
if (c == ',') { if (!cur.empty()) { flush(cur); cur.clear(); } }
|
||||||
|
else { cur.push_back(c); }
|
||||||
|
}
|
||||||
|
if (!cur.empty()) flush(cur);
|
||||||
|
|
||||||
// Speculative decoding options
|
// Speculative decoding options
|
||||||
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
||||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||||
@@ -701,16 +867,27 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
|||||||
// Upstream switched to a vector of types (comma-separated for multi-type
|
// Upstream switched to a vector of types (comma-separated for multi-type
|
||||||
// chaining via common_speculative_types_from_names). We keep accepting a
|
// chaining via common_speculative_types_from_names). We keep accepting a
|
||||||
// single value here, but also tolerate comma-separated lists.
|
// single value here, but also tolerate comma-separated lists.
|
||||||
|
//
|
||||||
|
// ggml-org/llama.cpp#22964 also renamed the registered names from
|
||||||
|
// underscore- to dash-separated form, and replaced the bare
|
||||||
|
// `draft`/`eagle3` aliases with `draft-simple`/`draft-eagle3`. We
|
||||||
|
// normalize each token here so existing model configs keep working.
|
||||||
|
auto normalize_spec_name = [](std::string s) -> std::string {
|
||||||
|
std::replace(s.begin(), s.end(), '_', '-');
|
||||||
|
if (s == "draft") return "draft-simple";
|
||||||
|
if (s == "eagle3") return "draft-eagle3";
|
||||||
|
return s;
|
||||||
|
};
|
||||||
std::vector<std::string> names;
|
std::vector<std::string> names;
|
||||||
std::string item;
|
std::string item;
|
||||||
for (char c : optval_str) {
|
for (char c : optval_str) {
|
||||||
if (c == ',') {
|
if (c == ',') {
|
||||||
if (!item.empty()) { names.push_back(item); item.clear(); }
|
if (!item.empty()) { names.push_back(normalize_spec_name(item)); item.clear(); }
|
||||||
} else {
|
} else {
|
||||||
item.push_back(c);
|
item.push_back(c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!item.empty()) names.push_back(item);
|
if (!item.empty()) names.push_back(normalize_spec_name(item));
|
||||||
auto parsed = common_speculative_types_from_names(names);
|
auto parsed = common_speculative_types_from_names(names);
|
||||||
if (!parsed.empty()) {
|
if (!parsed.empty()) {
|
||||||
params.speculative.types = parsed;
|
params.speculative.types = parsed;
|
||||||
@@ -937,6 +1114,20 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
|||||||
params.kv_overrides.back().key[0] = 0;
|
params.kv_overrides.back().key[0] = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tensor_buft_overrides sentinel termination (mirrors upstream common/arg.cpp).
|
||||||
|
// Real entries are pushed during option parsing; here we pad/terminate so the
|
||||||
|
// model loader sees back().pattern == nullptr (GGML_ASSERT at common.cpp:1543)
|
||||||
|
// and so llama_params_fit has the placeholder slots it requires.
|
||||||
|
{
|
||||||
|
const size_t ntbo = llama_max_tensor_buft_overrides();
|
||||||
|
while (params.tensor_buft_overrides.size() < ntbo) {
|
||||||
|
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!params.speculative.draft.tensor_buft_overrides.empty()) {
|
||||||
|
params.speculative.draft.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Add yarn
|
// TODO: Add yarn
|
||||||
|
|
||||||
if (!request->tensorsplit().empty()) {
|
if (!request->tensorsplit().empty()) {
|
||||||
@@ -2794,7 +2985,9 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int embd_normalize = 2; // default to Euclidean/L2 norm
|
// Honor the load-time embd_normalize set via options:embd_normalize.
|
||||||
|
// -1 none, 0 max-abs, 1 taxicab, 2 L2 (default), >2 p-norm.
|
||||||
|
int embd_normalize = params_base.embd_normalize;
|
||||||
// create and queue the task
|
// create and queue the task
|
||||||
auto rd = ctx_server.get_response_reader();
|
auto rd = ctx_server.get_response_reader();
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
|||||||
|
|
||||||
# acestep.cpp version
|
# acestep.cpp version
|
||||||
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
|
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
|
||||||
ACESTEP_CPP_VERSION?=e0c8d75a672fca5684c88c68dbf6d12f58754258
|
ACESTEP_CPP_VERSION?=ed53caf164e4492a5620b2e3f2264629cf66da24
|
||||||
SO_TARGET?=libgoacestepcpp.so
|
SO_TARGET?=libgoacestepcpp.so
|
||||||
|
|
||||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||||
|
|||||||
@@ -22,12 +22,11 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// Global model contexts (loaded once, reused across requests)
|
// Global model contexts (loaded once, reused across requests)
|
||||||
static DiTGGML g_dit = {};
|
static DiTGGML g_dit = {};
|
||||||
static DiTGGMLConfig g_dit_cfg;
|
static VAEGGML g_vae = {};
|
||||||
static VAEGGML g_vae = {};
|
static bool g_dit_loaded = false;
|
||||||
static bool g_dit_loaded = false;
|
static bool g_vae_loaded = false;
|
||||||
static bool g_vae_loaded = false;
|
static bool g_is_turbo = false;
|
||||||
static bool g_is_turbo = false;
|
|
||||||
|
|
||||||
// Silence latent [15000, 64] — read once from DiT GGUF
|
// Silence latent [15000, 64] — read once from DiT GGUF
|
||||||
static std::vector<float> g_silence_full;
|
static std::vector<float> g_silence_full;
|
||||||
@@ -72,10 +71,9 @@ int load_model(const char * lm_model_path, const char * text_encoder_path,
|
|||||||
g_text_enc_path = text_encoder_path;
|
g_text_enc_path = text_encoder_path;
|
||||||
g_dit_path = dit_model_path;
|
g_dit_path = dit_model_path;
|
||||||
|
|
||||||
// Load DiT model
|
// Load DiT model (backend init + config are handled inside dit_ggml_load)
|
||||||
fprintf(stderr, "[acestep-cpp] Loading DiT from %s\n", dit_model_path);
|
fprintf(stderr, "[acestep-cpp] Loading DiT from %s\n", dit_model_path);
|
||||||
dit_ggml_init_backend(&g_dit);
|
if (!dit_ggml_load(&g_dit, dit_model_path)) {
|
||||||
if (!dit_ggml_load(&g_dit, dit_model_path, g_dit_cfg, nullptr, 0.0f)) {
|
|
||||||
fprintf(stderr, "[acestep-cpp] FATAL: failed to load DiT from %s\n", dit_model_path);
|
fprintf(stderr, "[acestep-cpp] FATAL: failed to load DiT from %s\n", dit_model_path);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@@ -149,16 +147,16 @@ int generate_music(const char * caption, const char * lyrics, int bpm,
|
|||||||
|
|
||||||
// Compute T (latent frames at 25Hz)
|
// Compute T (latent frames at 25Hz)
|
||||||
int T = (int)(duration * FRAMES_PER_SECOND);
|
int T = (int)(duration * FRAMES_PER_SECOND);
|
||||||
T = ((T + g_dit_cfg.patch_size - 1) / g_dit_cfg.patch_size) * g_dit_cfg.patch_size;
|
T = ((T + g_dit.cfg.patch_size - 1) / g_dit.cfg.patch_size) * g_dit.cfg.patch_size;
|
||||||
int S = T / g_dit_cfg.patch_size;
|
int S = T / g_dit.cfg.patch_size;
|
||||||
|
|
||||||
if (T > 15000) {
|
if (T > 15000) {
|
||||||
fprintf(stderr, "[acestep-cpp] ERROR: T=%d exceeds max 15000\n", T);
|
fprintf(stderr, "[acestep-cpp] ERROR: T=%d exceeds max 15000\n", T);
|
||||||
return 2;
|
return 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
int Oc = g_dit_cfg.out_channels; // 64
|
int Oc = g_dit.cfg.out_channels; // 64
|
||||||
int ctx_ch = g_dit_cfg.in_channels - Oc; // 128
|
int ctx_ch = g_dit.cfg.in_channels - Oc; // 128
|
||||||
|
|
||||||
fprintf(stderr, "[acestep-cpp] T=%d, S=%d, duration=%.1fs, seed=%d\n", T, S, duration, seed);
|
fprintf(stderr, "[acestep-cpp] T=%d, S=%d, duration=%.1fs, seed=%d\n", T, S, duration, seed);
|
||||||
|
|
||||||
@@ -191,9 +189,8 @@ int generate_music(const char * caption, const char * lyrics, int bpm,
|
|||||||
|
|
||||||
fprintf(stderr, "[acestep-cpp] caption: %d tokens, lyrics: %d tokens\n", S_text, S_lyric);
|
fprintf(stderr, "[acestep-cpp] caption: %d tokens, lyrics: %d tokens\n", S_text, S_lyric);
|
||||||
|
|
||||||
// 4. Text encoder forward
|
// 4. Text encoder forward (backend init handled inside qwen3_load_text_encoder)
|
||||||
Qwen3GGML text_enc = {};
|
Qwen3GGML text_enc = {};
|
||||||
qwen3_init_backend(&text_enc);
|
|
||||||
if (!qwen3_load_text_encoder(&text_enc, g_text_enc_path.c_str())) {
|
if (!qwen3_load_text_encoder(&text_enc, g_text_enc_path.c_str())) {
|
||||||
fprintf(stderr, "[acestep-cpp] FATAL: failed to load text encoder\n");
|
fprintf(stderr, "[acestep-cpp] FATAL: failed to load text encoder\n");
|
||||||
return 4;
|
return 4;
|
||||||
@@ -209,9 +206,8 @@ int generate_music(const char * caption, const char * lyrics, int bpm,
|
|||||||
std::vector<float> lyric_embed(H_text * S_lyric);
|
std::vector<float> lyric_embed(H_text * S_lyric);
|
||||||
qwen3_embed_lookup(&text_enc, lyric_ids.data(), S_lyric, lyric_embed.data());
|
qwen3_embed_lookup(&text_enc, lyric_ids.data(), S_lyric, lyric_embed.data());
|
||||||
|
|
||||||
// 6. Condition encoder
|
// 6. Condition encoder (backend init handled inside cond_ggml_load)
|
||||||
CondGGML cond = {};
|
CondGGML cond = {};
|
||||||
cond_ggml_init_backend(&cond);
|
|
||||||
if (!cond_ggml_load(&cond, g_dit_path.c_str())) {
|
if (!cond_ggml_load(&cond, g_dit_path.c_str())) {
|
||||||
fprintf(stderr, "[acestep-cpp] FATAL: failed to load condition encoder\n");
|
fprintf(stderr, "[acestep-cpp] FATAL: failed to load condition encoder\n");
|
||||||
qwen3_free(&text_enc);
|
qwen3_free(&text_enc);
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
|||||||
|
|
||||||
# stablediffusion.cpp (ggml)
|
# stablediffusion.cpp (ggml)
|
||||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||||
STABLEDIFFUSION_GGML_VERSION?=90e87bc846f17059771efb8aaa31e9ef0cab6f78
|
STABLEDIFFUSION_GGML_VERSION?=0baf721215f45335a5df8caf0ecb34e870c956e7
|
||||||
|
|
||||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||||
|
|
||||||
|
|||||||
@@ -1188,6 +1188,9 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
|
|||||||
p->high_noise_sample_params.scheduler = scheduler;
|
p->high_noise_sample_params.scheduler = scheduler;
|
||||||
p->high_noise_sample_params.flow_shift = flow_shift;
|
p->high_noise_sample_params.flow_shift = flow_shift;
|
||||||
|
|
||||||
|
// Pin output fps in params; upstream uses it for audio sync (and we also mux at this rate).
|
||||||
|
p->fps = fps;
|
||||||
|
|
||||||
// Load init/end reference images if provided (resized to output dims).
|
// Load init/end reference images if provided (resized to output dims).
|
||||||
uint8_t* init_buf = nullptr;
|
uint8_t* init_buf = nullptr;
|
||||||
uint8_t* end_buf = nullptr;
|
uint8_t* end_buf = nullptr;
|
||||||
@@ -1206,11 +1209,14 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
|
|||||||
|
|
||||||
// Generate
|
// Generate
|
||||||
int num_frames_out = 0;
|
int num_frames_out = 0;
|
||||||
sd_image_t* frames = generate_video(sd_c, p, &num_frames_out);
|
sd_image_t* frames = nullptr;
|
||||||
|
sd_audio_t* audio = nullptr;
|
||||||
|
bool ok = generate_video(sd_c, p, &frames, &num_frames_out, &audio);
|
||||||
std::free(p);
|
std::free(p);
|
||||||
|
|
||||||
if (!frames || num_frames_out == 0) {
|
if (!ok || !frames || num_frames_out == 0) {
|
||||||
fprintf(stderr, "generate_video produced no frames\n");
|
fprintf(stderr, "generate_video produced no frames\n");
|
||||||
|
if (audio) free_sd_audio(audio);
|
||||||
if (init_buf) free(init_buf);
|
if (init_buf) free(init_buf);
|
||||||
if (end_buf) free(end_buf);
|
if (end_buf) free(end_buf);
|
||||||
return 1;
|
return 1;
|
||||||
@@ -1224,6 +1230,7 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
|
|||||||
if (frames[i].data) free(frames[i].data);
|
if (frames[i].data) free(frames[i].data);
|
||||||
}
|
}
|
||||||
free(frames);
|
free(frames);
|
||||||
|
if (audio) free_sd_audio(audio);
|
||||||
if (init_buf) free(init_buf);
|
if (init_buf) free(init_buf);
|
||||||
if (end_buf) free(end_buf);
|
if (end_buf) free(end_buf);
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
|||||||
|
|
||||||
# whisper.cpp version
|
# whisper.cpp version
|
||||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||||
WHISPER_CPP_VERSION?=3e9b7d0fef3528ee2208da3cdb873a2c53d2ae2f
|
WHISPER_CPP_VERSION?=0ccd896f5b882628e1c077f9769735ef4ce52860
|
||||||
SO_TARGET?=libgowhisper.so
|
SO_TARGET?=libgowhisper.so
|
||||||
|
|
||||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||||
|
|||||||
@@ -36,15 +36,11 @@ fi
|
|||||||
# flash-attn-4 4.0 stable lands.
|
# flash-attn-4 4.0 stable lands.
|
||||||
EXTRA_PIP_INSTALL_FLAGS+=" --prerelease=allow"
|
EXTRA_PIP_INSTALL_FLAGS+=" --prerelease=allow"
|
||||||
|
|
||||||
# JetPack 7 / L4T arm64 wheels are built for cp312 and shipped via
|
# JetPack 7 / L4T arm64 sglang + torch wheels come straight from PyPI now
|
||||||
# pypi.jetson-ai-lab.io. Bump the venv Python so the prebuilt sglang
|
# (torch 2.11+ ships aarch64 + cu130 manylinux wheels and sglang 0.5.11+
|
||||||
# wheel resolves cleanly. The actual install on l4t13 goes through
|
# ships a cp312 aarch64 wheel pinned to that torch). They're cp312-only,
|
||||||
# pyproject.toml (see the elif branch below) so [tool.uv.sources] can
|
# so bump the venv Python accordingly.
|
||||||
# pin only torch/torchvision/torchaudio/sglang to the jetson-ai-lab
|
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||||
# index — leaving PyPI as the path for transitive deps like
|
|
||||||
# markdown-it-py / anthropic / propcache that the L4T mirror's proxy
|
|
||||||
# 503s on. No --index-strategy flag here: the explicit index keeps the
|
|
||||||
# scoping clean.
|
|
||||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||||
PYTHON_VERSION="3.12"
|
PYTHON_VERSION="3.12"
|
||||||
PYTHON_PATCH="12"
|
PYTHON_PATCH="12"
|
||||||
@@ -110,27 +106,6 @@ if [ "x${BUILD_TYPE}" == "x" ] || [ "x${FROM_SOURCE:-}" == "xtrue" ]; then
|
|||||||
fi
|
fi
|
||||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} .
|
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} .
|
||||||
popd
|
popd
|
||||||
# L4T arm64 (JetPack 7): drive the install through pyproject.toml so that
|
|
||||||
# [tool.uv.sources] can pin torch/torchvision/torchaudio/sglang to the
|
|
||||||
# jetson-ai-lab index, while everything else (transitive deps and
|
|
||||||
# PyPI-resolvable packages like transformers / accelerate) comes from
|
|
||||||
# PyPI. Bypasses installRequirements because uv pip install -r
|
|
||||||
# requirements.txt does not honor sources — see
|
|
||||||
# backend/python/sglang/pyproject.toml for the rationale. Mirrors the
|
|
||||||
# equivalent path in backend/python/vllm/install.sh.
|
|
||||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
|
||||||
ensureVenv
|
|
||||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
|
||||||
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
|
|
||||||
fi
|
|
||||||
pushd "${backend_dir}"
|
|
||||||
# Build deps first (matches installRequirements' requirements-install.txt
|
|
||||||
# pass — sglang/sgl-kernel sdists need packaging/setuptools-scm in the
|
|
||||||
# venv before they can build under --no-build-isolation).
|
|
||||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements-install.txt
|
|
||||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement pyproject.toml
|
|
||||||
popd
|
|
||||||
runProtogen
|
|
||||||
else
|
else
|
||||||
installRequirements
|
installRequirements
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -1,68 +0,0 @@
|
|||||||
# L4T arm64 (JetPack 7 / sbsa cu130) install spec for the sglang backend.
|
|
||||||
#
|
|
||||||
# Why this file exists, and why only the l4t13 BUILD_PROFILE consumes it:
|
|
||||||
#
|
|
||||||
# pypi.jetson-ai-lab.io hosts the L4T-specific torch / sglang / sgl-kernel
|
|
||||||
# wheels we need on aarch64 + cuda13, but it ALSO transparently proxies the
|
|
||||||
# rest of PyPI through `/+f/<sha>/<filename>` URLs that 503 frequently.
|
|
||||||
# With `--extra-index-url` + `--index-strategy=unsafe-best-match` (the
|
|
||||||
# historical fix in install.sh) uv would pick those proxy URLs for ordinary
|
|
||||||
# PyPI packages — markdown-it-py, anthropic, propcache, etc. — and trip on
|
|
||||||
# the 503s. See e.g. CI run 25439791228 (markdown-it-py-4.0.0).
|
|
||||||
#
|
|
||||||
# `explicit = true` on the index makes uv consult the L4T mirror ONLY for
|
|
||||||
# packages mapped under [tool.uv.sources]. Everything else goes to PyPI.
|
|
||||||
# This breaks the historical 503 path without losing access to the L4T
|
|
||||||
# wheels we actually need from there. Mirrors the equivalent fix already
|
|
||||||
# in backend/python/vllm/pyproject.toml.
|
|
||||||
#
|
|
||||||
# `uv pip install -r requirements.txt` does NOT honor [tool.uv.sources]
|
|
||||||
# (sources are project-mode only, not pip-compat mode), so install.sh's
|
|
||||||
# l4t13 branch invokes `uv pip install --requirement pyproject.toml`
|
|
||||||
# directly. Other BUILD_PROFILEs continue to use the requirements-*.txt
|
|
||||||
# pipeline through libbackend.sh's installRequirements and never read
|
|
||||||
# this file.
|
|
||||||
[project]
|
|
||||||
name = "localai-sglang-l4t13"
|
|
||||||
version = "0.0.0"
|
|
||||||
requires-python = ">=3.12,<3.13"
|
|
||||||
dependencies = [
|
|
||||||
# Mirror of requirements.txt — kept in sync manually for now since the
|
|
||||||
# l4t13 path bypasses installRequirements (see install.sh).
|
|
||||||
"grpcio==1.80.0",
|
|
||||||
"protobuf",
|
|
||||||
"certifi",
|
|
||||||
"setuptools",
|
|
||||||
"pillow",
|
|
||||||
# L4T-specific accelerator stack (sourced from jetson-ai-lab below).
|
|
||||||
"torch",
|
|
||||||
"torchvision",
|
|
||||||
"torchaudio",
|
|
||||||
# sglang on jetson — the [all] extra is deliberately omitted because it
|
|
||||||
# pulls outlines/decord, and decord has no aarch64 cp312 wheel anywhere
|
|
||||||
# (PyPI nor the jetson-ai-lab index ships only legacy cp35-cp37). With
|
|
||||||
# [all] uv backtracks through versions trying to satisfy decord and
|
|
||||||
# lands on sglang==0.1.16. The 0.5.0 floor matches the only major
|
|
||||||
# series the jetson-ai-lab sbsa/cu130 mirror currently publishes
|
|
||||||
# (sglang==0.5.1.post2 as of 2026-05-06). Bumping to >=0.5.11 here
|
|
||||||
# would make the build unsatisfiable until the mirror catches up.
|
|
||||||
# Gemma 4 / MTP recipes are therefore not supported on l4t13 — those
|
|
||||||
# features land on cublas12/cublas13 hosts that pull the newer wheel
|
|
||||||
# from PyPI. backend.py keeps backward compat with the 0.5.x SamplingParams
|
|
||||||
# field rename via runtime detection.
|
|
||||||
"sglang>=0.5.0",
|
|
||||||
# PyPI-resolvable packages that complete the runtime.
|
|
||||||
"accelerate",
|
|
||||||
"transformers",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[tool.uv.index]]
|
|
||||||
name = "jetson-ai-lab"
|
|
||||||
url = "https://pypi.jetson-ai-lab.io/sbsa/cu130"
|
|
||||||
explicit = true
|
|
||||||
|
|
||||||
[tool.uv.sources]
|
|
||||||
torch = { index = "jetson-ai-lab" }
|
|
||||||
torchvision = { index = "jetson-ai-lab" }
|
|
||||||
torchaudio = { index = "jetson-ai-lab" }
|
|
||||||
sglang = { index = "jetson-ai-lab" }
|
|
||||||
15
backend/python/sglang/requirements-l4t13-after.txt
Normal file
15
backend/python/sglang/requirements-l4t13-after.txt
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# sglang 0.5.11+ ships an aarch64 manylinux wheel on PyPI whose Requires-Dist
|
||||||
|
# pins torch==2.11.0 / torchaudio==2.11.0, locking an ABI-consistent set with
|
||||||
|
# the cu130 torch wheel installed above. 0.5.11 is the floor for Gemma 4
|
||||||
|
# support (sgl-project/sglang#21952).
|
||||||
|
#
|
||||||
|
# The [all] extra is deliberately NOT used on aarch64: it pulls the
|
||||||
|
# [diffusion] sub-extra which requires `xatlas`, and xatlas ships no
|
||||||
|
# aarch64 wheel and its sdist depends on scikit_build_core without
|
||||||
|
# declaring it in build-system.requires — so under --no-build-isolation
|
||||||
|
# uv can't build it. Upstream sglang gates st_attn and vsa on
|
||||||
|
# platform_machine != aarch64 in the diffusion extra but forgot xatlas.
|
||||||
|
# Plain `sglang` carries everything backend.py uses (Engine, ServerArgs,
|
||||||
|
# FunctionCallParser, ReasoningParser); the [all] extras are optional
|
||||||
|
# accelerators not required at import time.
|
||||||
|
sglang>=0.5.11
|
||||||
9
backend/python/sglang/requirements-l4t13.txt
Normal file
9
backend/python/sglang/requirements-l4t13.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# JetPack 7 / L4T arm64 + CUDA 13. Since PyTorch 2.11 (April 2026), PyPI ships
|
||||||
|
# aarch64 + cu130 manylinux wheels for torch/torchvision/torchaudio directly,
|
||||||
|
# so we no longer need a custom --extra-index-url for the L4T mirror.
|
||||||
|
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||||
|
accelerate
|
||||||
|
torch
|
||||||
|
torchvision
|
||||||
|
torchaudio
|
||||||
|
transformers
|
||||||
@@ -2,9 +2,9 @@ torch==2.7.1
|
|||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
accelerate
|
accelerate
|
||||||
transformers>=5.8.0
|
transformers>=5.8.1
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==5.4.0
|
sentence-transformers==5.5.0
|
||||||
diffusers
|
diffusers
|
||||||
soundfile
|
soundfile
|
||||||
protobuf==6.33.5
|
protobuf==6.33.5
|
||||||
@@ -2,9 +2,9 @@ torch==2.7.1
|
|||||||
accelerate
|
accelerate
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
transformers>=5.8.0
|
transformers>=5.8.1
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==5.4.0
|
sentence-transformers==5.5.0
|
||||||
diffusers
|
diffusers
|
||||||
soundfile
|
soundfile
|
||||||
protobuf==6.33.5
|
protobuf==6.33.5
|
||||||
@@ -2,9 +2,9 @@
|
|||||||
torch==2.9.0
|
torch==2.9.0
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
transformers>=5.8.0
|
transformers>=5.8.1
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==5.4.0
|
sentence-transformers==5.5.0
|
||||||
diffusers
|
diffusers
|
||||||
soundfile
|
soundfile
|
||||||
protobuf==6.33.5
|
protobuf==6.33.5
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/rocm7.0
|
--extra-index-url https://download.pytorch.org/whl/rocm7.0
|
||||||
torch==2.10.0+rocm7.0
|
torch==2.10.0+rocm7.0
|
||||||
accelerate
|
accelerate
|
||||||
transformers>=5.8.0
|
transformers>=5.8.1
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==5.4.0
|
sentence-transformers==5.5.0
|
||||||
diffusers
|
diffusers
|
||||||
soundfile
|
soundfile
|
||||||
protobuf==6.33.5
|
protobuf==6.33.5
|
||||||
@@ -3,9 +3,9 @@ torch
|
|||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
transformers>=5.8.0
|
transformers>=5.8.1
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==5.4.0
|
sentence-transformers==5.5.0
|
||||||
diffusers
|
diffusers
|
||||||
soundfile
|
soundfile
|
||||||
protobuf==6.33.5
|
protobuf==6.33.5
|
||||||
@@ -2,9 +2,9 @@ torch==2.7.1
|
|||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
accelerate
|
accelerate
|
||||||
transformers>=5.8.0
|
transformers>=5.8.1
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==5.4.0
|
sentence-transformers==5.5.0
|
||||||
diffusers
|
diffusers
|
||||||
soundfile
|
soundfile
|
||||||
protobuf==6.33.5
|
protobuf==6.33.5
|
||||||
|
|||||||
@@ -13,14 +13,14 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# Handle l4t build profiles (Python 3.12, pip fallback) if needed.
|
# Handle l4t build profiles (Python 3.12, pip fallback) if needed.
|
||||||
# unsafe-best-match is required on l4t13 because the jetson-ai-lab index
|
# Since PyTorch 2.11 (April 2026) PyPI ships aarch64 + cu130 manylinux wheels
|
||||||
# lists transitive deps at limited versions — without it uv pins to the
|
# directly for torch/torchvision/torchaudio and an aarch64 vllm wheel pinned
|
||||||
# first matching index and fails to resolve a compatible wheel from PyPI.
|
# to that torch, so the jetson-ai-lab mirror is no longer needed.
|
||||||
|
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||||
PYTHON_VERSION="3.12"
|
PYTHON_VERSION="3.12"
|
||||||
PYTHON_PATCH="12"
|
PYTHON_PATCH="12"
|
||||||
PY_STANDALONE_TAG="20251120"
|
PY_STANDALONE_TAG="20251120"
|
||||||
EXTRA_PIP_INSTALL_FLAGS="${EXTRA_PIP_INSTALL_FLAGS:-} --index-strategy=unsafe-best-match"
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||||
@@ -42,18 +42,11 @@ if [ "x${BUILD_TYPE}" == "xhipblas" ]; then
|
|||||||
else
|
else
|
||||||
uv pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700
|
uv pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700
|
||||||
fi
|
fi
|
||||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
elif [ "x${BUILD_PROFILE}" == "xcublas13" ] || [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||||
# JetPack 7 / L4T arm64 cu130 — vllm comes from the prebuilt SBSA wheel
|
# cublas13 (x86_64) and l4t13 (aarch64) both pull vllm from PyPI now:
|
||||||
# at jetson-ai-lab. Version is unpinned: the index ships whatever build
|
# vllm 0.19+ defaults to cu130 wheels on x86_64 and vllm 0.20+ ships an
|
||||||
# matches the cu130/cp312 ABI. unsafe-best-match lets uv fall through
|
# aarch64 manylinux wheel pinned to torch==2.11.0. No extra index needed
|
||||||
# to PyPI for transitive deps not present on the jetson-ai-lab index.
|
# in either case.
|
||||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
|
||||||
pip install vllm --extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
|
||||||
else
|
|
||||||
uv pip install --index-strategy=unsafe-best-match vllm --extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
|
||||||
fi
|
|
||||||
elif [ "x${BUILD_PROFILE}" == "xcublas13" ]; then
|
|
||||||
# vllm 0.19+ defaults to cu130 wheels on PyPI, no extra index needed.
|
|
||||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||||
pip install vllm --torch-backend=auto
|
pip install vllm --torch-backend=auto
|
||||||
else
|
else
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
--extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
# JetPack 7 / L4T arm64 + CUDA 13. PyPI ships aarch64 + cu130 manylinux wheels
|
||||||
|
# for torch/torchvision/torchaudio directly since PyTorch 2.11 (April 2026),
|
||||||
|
# so no custom index is needed. flash-attn is dropped here: PyPI has no
|
||||||
|
# aarch64 wheel for it, but vLLM 0.20+ bundles its own vllm_flash_attn
|
||||||
|
# (fa2 + fa3) inside the main wheel, so it is not required at runtime.
|
||||||
|
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||||
accelerate
|
accelerate
|
||||||
torch
|
torch
|
||||||
torchvision
|
torchvision
|
||||||
torchaudio
|
torchaudio
|
||||||
transformers
|
transformers
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
flash-attn
|
|
||||||
diffusers
|
diffusers
|
||||||
librosa
|
librosa
|
||||||
soundfile
|
soundfile
|
||||||
|
|||||||
@@ -43,14 +43,11 @@ if [ "x${BUILD_PROFILE}" == "xcublas13" ]; then
|
|||||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
|
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# JetPack 7 / L4T arm64 wheels (torch, vllm, flash-attn) live on
|
# JetPack 7 / L4T arm64 vllm + torch wheels come straight from PyPI now
|
||||||
# pypi.jetson-ai-lab.io and are built for cp312, so bump the venv Python
|
# (torch 2.11+ ships aarch64 + cu130 manylinux wheels and vllm 0.20+ ships
|
||||||
# accordingly. JetPack 6 keeps cp310 + USE_PIP=true.
|
# an aarch64 wheel pinned to that torch). They're cp312-only, so bump the
|
||||||
#
|
# venv Python accordingly. JetPack 6 keeps cp310 + USE_PIP=true.
|
||||||
# l4t13 uses pyproject.toml (see the elif branch below) to pin only the
|
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||||
# L4T-specific wheels to the jetson-ai-lab index via [tool.uv.sources].
|
|
||||||
# That keeps PyPI as the resolution path for transitive deps like
|
|
||||||
# anthropic/openai/propcache, which the L4T mirror's proxy 503s on.
|
|
||||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||||
USE_PIP=true
|
USE_PIP=true
|
||||||
fi
|
fi
|
||||||
@@ -103,25 +100,6 @@ if [ "x${BUILD_TYPE}" == "xintel" ]; then
|
|||||||
export CMAKE_PREFIX_PATH="$(python -c 'import site; print(site.getsitepackages()[0])'):${CMAKE_PREFIX_PATH:-}"
|
export CMAKE_PREFIX_PATH="$(python -c 'import site; print(site.getsitepackages()[0])'):${CMAKE_PREFIX_PATH:-}"
|
||||||
VLLM_TARGET_DEVICE=xpu uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --no-deps .
|
VLLM_TARGET_DEVICE=xpu uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --no-deps .
|
||||||
popd
|
popd
|
||||||
# L4T arm64 (JetPack 7): drive the install through pyproject.toml so that
|
|
||||||
# [tool.uv.sources] can pin torch/vllm/flash-attn/torchvision/torchaudio
|
|
||||||
# to the jetson-ai-lab index, while everything else (transitive deps and
|
|
||||||
# PyPI-resolvable packages like transformers) comes from PyPI. Bypasses
|
|
||||||
# installRequirements because uv pip install -r requirements.txt does not
|
|
||||||
# honor sources — see backend/python/vllm/pyproject.toml for the rationale.
|
|
||||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
|
||||||
ensureVenv
|
|
||||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
|
||||||
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
|
|
||||||
fi
|
|
||||||
pushd "${backend_dir}"
|
|
||||||
# Build deps first (matches installRequirements' requirements-install.txt
|
|
||||||
# pass — fastsafetensors and friends need pybind11 in the venv before
|
|
||||||
# their sdists can build under --no-build-isolation).
|
|
||||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements-install.txt
|
|
||||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement pyproject.toml
|
|
||||||
popd
|
|
||||||
runProtogen
|
|
||||||
# FROM_SOURCE=true on a CPU build skips the prebuilt vllm wheel in
|
# FROM_SOURCE=true on a CPU build skips the prebuilt vllm wheel in
|
||||||
# requirements-cpu-after.txt and compiles vllm locally against the host's
|
# requirements-cpu-after.txt and compiles vllm locally against the host's
|
||||||
# actual CPU. Not used by default because it takes ~30-40 minutes, but
|
# actual CPU. Not used by default because it takes ~30-40 minutes, but
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
# L4T arm64 (JetPack 7 / sbsa cu130) install spec for the vllm backend.
|
|
||||||
#
|
|
||||||
# Why this file exists, and why only the l4t13 BUILD_PROFILE consumes it:
|
|
||||||
#
|
|
||||||
# pypi.jetson-ai-lab.io hosts the L4T-specific torch / vllm / flash-attn
|
|
||||||
# wheels we need on aarch64 + cuda13, but it ALSO transparently proxies the
|
|
||||||
# rest of PyPI through `/+f/<sha>/<filename>` URLs that 503 frequently. With
|
|
||||||
# `--extra-index-url` + `--index-strategy=unsafe-best-match` (the historical
|
|
||||||
# fix in install.sh) uv would pick those proxy URLs for ordinary PyPI
|
|
||||||
# packages — `anthropic`, `openai`, `propcache`, `annotated-types` — and
|
|
||||||
# trip on the 503s. See e.g. CI run 25212201349 (anthropic-0.97.0).
|
|
||||||
#
|
|
||||||
# `explicit = true` on the index makes uv consult the L4T mirror ONLY for
|
|
||||||
# packages mapped under [tool.uv.sources]. Everything else goes to PyPI.
|
|
||||||
# This breaks the historical 503 path without losing access to the L4T
|
|
||||||
# wheels we actually need from there.
|
|
||||||
#
|
|
||||||
# `uv pip install -r requirements.txt` does NOT honor [tool.uv.sources]
|
|
||||||
# (sources are project-mode only, not pip-compat mode), so install.sh's
|
|
||||||
# l4t13 branch invokes `uv pip install --requirement pyproject.toml`
|
|
||||||
# directly. Other BUILD_PROFILEs continue to use the requirements-*.txt
|
|
||||||
# pipeline through libbackend.sh's installRequirements and never read
|
|
||||||
# this file.
|
|
||||||
[project]
|
|
||||||
name = "localai-vllm-l4t13"
|
|
||||||
version = "0.0.0"
|
|
||||||
requires-python = ">=3.12,<3.13"
|
|
||||||
dependencies = [
|
|
||||||
# Mirror of requirements.txt — kept in sync manually for now since the
|
|
||||||
# l4t13 path bypasses installRequirements (see install.sh).
|
|
||||||
"grpcio==1.80.0",
|
|
||||||
"protobuf",
|
|
||||||
"certifi",
|
|
||||||
"setuptools",
|
|
||||||
"pillow",
|
|
||||||
"charset-normalizer>=3.4.7",
|
|
||||||
"chardet",
|
|
||||||
# L4T-specific accelerator stack (sourced from jetson-ai-lab below).
|
|
||||||
"torch",
|
|
||||||
"torchvision",
|
|
||||||
"torchaudio",
|
|
||||||
"flash-attn",
|
|
||||||
"vllm",
|
|
||||||
# PyPI-resolvable packages that complete the runtime — accelerate,
|
|
||||||
# transformers, bitsandbytes carry their own wheels for aarch64.
|
|
||||||
"accelerate",
|
|
||||||
"transformers",
|
|
||||||
"bitsandbytes",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[tool.uv.index]]
|
|
||||||
name = "jetson-ai-lab"
|
|
||||||
url = "https://pypi.jetson-ai-lab.io/sbsa/cu130"
|
|
||||||
explicit = true
|
|
||||||
|
|
||||||
[tool.uv.sources]
|
|
||||||
torch = { index = "jetson-ai-lab" }
|
|
||||||
torchvision = { index = "jetson-ai-lab" }
|
|
||||||
torchaudio = { index = "jetson-ai-lab" }
|
|
||||||
flash-attn = { index = "jetson-ai-lab" }
|
|
||||||
vllm = { index = "jetson-ai-lab" }
|
|
||||||
@@ -3,5 +3,5 @@
|
|||||||
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
|
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
|
||||||
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
|
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
|
||||||
# so uv consults this index alongside PyPI.
|
# so uv consults this index alongside PyPI.
|
||||||
--extra-index-url https://wheels.vllm.ai/0.20.2/cu130
|
--extra-index-url https://wheels.vllm.ai/0.21.0/cu130
|
||||||
vllm==0.20.2
|
vllm==0.21.0
|
||||||
|
|||||||
4
backend/python/vllm/requirements-l4t13-after.txt
Normal file
4
backend/python/vllm/requirements-l4t13-after.txt
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
# vLLM 0.20+ ships an aarch64 manylinux wheel on PyPI whose Requires-Dist pins
|
||||||
|
# torch==2.11.0 / torchvision==0.26.0 / torchaudio==2.11.0, locking an ABI-
|
||||||
|
# consistent set with the cu130 torch wheel installed above.
|
||||||
|
vllm
|
||||||
8
backend/python/vllm/requirements-l4t13.txt
Normal file
8
backend/python/vllm/requirements-l4t13.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# JetPack 7 / L4T arm64 + CUDA 13. Since PyTorch 2.11 (April 2026), PyPI ships
|
||||||
|
# aarch64 + cu130 manylinux wheels for torch/torchvision/torchaudio directly,
|
||||||
|
# so we no longer need a custom --extra-index-url for the L4T mirror.
|
||||||
|
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||||
|
accelerate
|
||||||
|
torch
|
||||||
|
transformers
|
||||||
|
bitsandbytes
|
||||||
@@ -233,7 +233,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
|||||||
xlog.Info("File stager initialized (HTTP direct transfer)")
|
xlog.Info("File stager initialized (HTTP direct transfer)")
|
||||||
}
|
}
|
||||||
// Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go
|
// Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go
|
||||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
|
remoteUnloader := nodes.NewRemoteUnloaderAdapter(
|
||||||
|
registry,
|
||||||
|
natsClient,
|
||||||
|
cfg.Distributed.BackendInstallTimeoutOrDefault(),
|
||||||
|
cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
|
||||||
|
)
|
||||||
|
|
||||||
// All dependencies ready — build SmartRouter with all options at once
|
// All dependencies ready — build SmartRouter with all options at once
|
||||||
var conflictResolver nodes.ConcurrencyConflictResolver
|
var conflictResolver nodes.ConcurrencyConflictResolver
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ import (
|
|||||||
"github.com/mudler/LocalAI/core/services/jobs"
|
"github.com/mudler/LocalAI/core/services/jobs"
|
||||||
"github.com/mudler/LocalAI/core/services/nodes"
|
"github.com/mudler/LocalAI/core/services/nodes"
|
||||||
"github.com/mudler/LocalAI/core/services/storage"
|
"github.com/mudler/LocalAI/core/services/storage"
|
||||||
"github.com/mudler/LocalAI/pkg/vram"
|
|
||||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||||
"github.com/mudler/LocalAI/internal"
|
"github.com/mudler/LocalAI/internal"
|
||||||
|
"github.com/mudler/LocalAI/pkg/vram"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||||
@@ -200,7 +200,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
|||||||
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
|
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
|
||||||
)
|
)
|
||||||
application.galleryService.SetBackendManager(
|
application.galleryService.SetBackendManager(
|
||||||
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry),
|
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry, application.galleryService),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -212,12 +212,12 @@ func New(opts ...config.AppOption) (*Application, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.RequireBackendIntegrity, nil, options.ModelsURL...); err != nil {
|
||||||
xlog.Error("error installing models", "error", err)
|
xlog.Error("error installing models", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, backend := range options.ExternalBackends {
|
for _, backend := range options.ExternalBackends {
|
||||||
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
|
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", "", options.RequireBackendIntegrity); err != nil {
|
||||||
xlog.Error("error installing external backend", "error", err)
|
xlog.Error("error installing external backend", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -267,13 +267,13 @@ func New(opts ...config.AppOption) (*Application, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if options.PreloadJSONModels != "" {
|
if options.PreloadJSONModels != "" {
|
||||||
if err := galleryop.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
|
if err := galleryop.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels, options.RequireBackendIntegrity); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.PreloadModelsFromPath != "" {
|
if options.PreloadModelsFromPath != "" {
|
||||||
if err := galleryop.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
|
if err := galleryop.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath, options.RequireBackendIntegrity); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -552,6 +552,13 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
|||||||
options.TracingMaxItems = *settings.TracingMaxItems
|
options.TracingMaxItems = *settings.TracingMaxItems
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if settings.TracingMaxBodyBytes != nil {
|
||||||
|
// Allow the on-disk setting to override the CLI/env default. The
|
||||||
|
// startup default is non-zero (see NewApplicationConfig), so a plain
|
||||||
|
// `== 0` guard like the others would never trigger; we instead respect
|
||||||
|
// any value the file specifies. 0 in the file means "uncapped".
|
||||||
|
options.TracingMaxBodyBytes = *settings.TracingMaxBodyBytes
|
||||||
|
}
|
||||||
|
|
||||||
// Branding / whitelabeling. There are no env vars for these — the file is
|
// Branding / whitelabeling. There are no env vars for these — the file is
|
||||||
// the only source — so apply unconditionally. Without this block a server
|
// the only source — so apply unconditionally. Without this block a server
|
||||||
|
|||||||
@@ -217,7 +217,7 @@ func (uc *UpgradeChecker) runCheck(ctx context.Context) {
|
|||||||
err = bm.UpgradeBackend(ctx, name, nil)
|
err = bm.UpgradeBackend(ctx, name, nil)
|
||||||
} else {
|
} else {
|
||||||
err = gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
|
err = gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
|
||||||
uc.galleries, name, nil)
|
uc.galleries, name, nil, uc.appConfig.RequireBackendIntegrity)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("Failed to auto-upgrade backend",
|
xlog.Error("Failed to auto-upgrade backend",
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
|||||||
if !slices.Contains(modelNames, modelName) {
|
if !slices.Contains(modelNames, modelName) {
|
||||||
utils.ResetDownloadTimers()
|
utils.ResetDownloadTimers()
|
||||||
// if we failed to load the model, we try to download it
|
// if we failed to load the model, we try to download it
|
||||||
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, modelName, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, modelName, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries, o.RequireBackendIntegrity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("failed to install model from gallery", "error", err, "model", modelFile)
|
xlog.Error("failed to install model from gallery", "error", err, "model", modelFile)
|
||||||
//return nil, err
|
//return nil, err
|
||||||
|
|||||||
@@ -277,7 +277,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
|
|||||||
MinP: float32(*c.MinP),
|
MinP: float32(*c.MinP),
|
||||||
Tokens: int32(*c.Maxtokens),
|
Tokens: int32(*c.Maxtokens),
|
||||||
Threads: int32(*c.Threads),
|
Threads: int32(*c.Threads),
|
||||||
PromptCacheAll: c.PromptCacheAll,
|
PromptCacheAll: *c.PromptCacheAll,
|
||||||
PromptCacheRO: c.PromptCacheRO,
|
PromptCacheRO: c.PromptCacheRO,
|
||||||
PromptCachePath: promptCachePath,
|
PromptCachePath: promptCachePath,
|
||||||
F16KV: *c.F16,
|
F16KV: *c.F16,
|
||||||
|
|||||||
@@ -17,9 +17,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type BackendsCMDFlags struct {
|
type BackendsCMDFlags struct {
|
||||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"`
|
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"`
|
||||||
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
||||||
|
RequireBackendIntegrity bool `env:"LOCALAI_REQUIRE_BACKEND_INTEGRITY,REQUIRE_BACKEND_INTEGRITY" help:"If true, reject backend installs without a configured signature verification policy (OCI URIs) or SHA256 (tarball/HTTP URIs)." group:"hardening" default:"false"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type BackendsList struct {
|
type BackendsList struct {
|
||||||
@@ -126,7 +127,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelLoader := model.NewModelLoader(systemState)
|
modelLoader := model.NewModelLoader(systemState)
|
||||||
err = galleryop.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
|
err = galleryop.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias, bi.RequireBackendIntegrity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -197,7 +198,7 @@ func (bu *BackendsUpgrade) Run(ctx *cliContext.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := gallery.UpgradeBackend(context.Background(), systemState, modelLoader, galleries, name, progressCallback); err != nil {
|
if err := gallery.UpgradeBackend(context.Background(), systemState, modelLoader, galleries, name, progressCallback, bu.RequireBackendIntegrity); err != nil {
|
||||||
fmt.Printf("Failed to upgrade %s: %v\n", name, err)
|
fmt.Printf("Failed to upgrade %s: %v\n", name, err)
|
||||||
} else {
|
} else {
|
||||||
fmt.Printf("Backend %s upgraded successfully\n", name)
|
fmt.Printf("Backend %s upgraded successfully\n", name)
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ type ModelsList struct {
|
|||||||
|
|
||||||
type ModelsInstall struct {
|
type ModelsInstall struct {
|
||||||
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
|
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
|
||||||
|
RequireBackendIntegrity bool `env:"LOCALAI_REQUIRE_BACKEND_INTEGRITY,REQUIRE_BACKEND_INTEGRITY" help:"If true, reject backend installs without a configured signature verification policy (OCI URIs) or SHA256 (tarball/HTTP URIs)." group:"hardening" default:"false"`
|
||||||
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES" help:"If true, automatically loads backend galleries" group:"backends" default:"true"`
|
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES" help:"If true, automatically loads backend galleries" group:"backends" default:"true"`
|
||||||
ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"`
|
ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"`
|
||||||
|
|
||||||
@@ -71,7 +72,6 @@ func (ml *ModelsList) Run(ctx *cliContext.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
||||||
|
|
||||||
systemState, err := system.GetSystemState(
|
systemState, err := system.GetSystemState(
|
||||||
system.WithModelPath(mi.ModelsPath),
|
system.WithModelPath(mi.ModelsPath),
|
||||||
system.WithBackendPath(mi.BackendsPath),
|
system.WithBackendPath(mi.BackendsPath),
|
||||||
@@ -135,7 +135,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelLoader := model.NewModelLoader(systemState)
|
modelLoader := model.NewModelLoader(systemState)
|
||||||
err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
|
err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, mi.RequireBackendIntegrity, progressCallback, modelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,19 +39,19 @@ type RunCMD struct {
|
|||||||
LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`
|
LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`
|
||||||
LocalaiConfigDirPollInterval time.Duration `env:"LOCALAI_CONFIG_DIR_POLL_INTERVAL" help:"Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to an interval to poll the LocalAI Config Dir (example: 1m)" group:"storage"`
|
LocalaiConfigDirPollInterval time.Duration `env:"LOCALAI_CONFIG_DIR_POLL_INTERVAL" help:"Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to an interval to poll the LocalAI Config Dir (example: 1m)" group:"storage"`
|
||||||
// The alias on this option is there to preserve functionality with the old `--config-file` parameter
|
// The alias on this option is there to preserve functionality with the old `--config-file` parameter
|
||||||
ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"`
|
ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"`
|
||||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||||
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
|
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
|
||||||
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
|
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
|
||||||
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
|
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
|
||||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||||
AutoUpgradeBackends bool `env:"LOCALAI_AUTO_UPGRADE_BACKENDS,AUTO_UPGRADE_BACKENDS" help:"Automatically upgrade backends when new versions are detected" group:"backends" default:"false"`
|
AutoUpgradeBackends bool `env:"LOCALAI_AUTO_UPGRADE_BACKENDS,AUTO_UPGRADE_BACKENDS" help:"Automatically upgrade backends when new versions are detected" group:"backends" default:"false"`
|
||||||
PreferDevelopmentBackends bool `env:"LOCALAI_PREFER_DEV_BACKENDS,PREFER_DEV_BACKENDS" help:"Prefer development backend versions (shows development backends by default in UI)" group:"backends" default:"false"`
|
PreferDevelopmentBackends bool `env:"LOCALAI_PREFER_DEV_BACKENDS,PREFER_DEV_BACKENDS" help:"Prefer development backend versions (shows development backends by default in UI)" group:"backends" default:"false"`
|
||||||
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
|
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
|
||||||
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
||||||
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
||||||
|
|
||||||
F16 bool `name:"f16" env:"LOCALAI_F16,F16" help:"Enable GPU acceleration" group:"performance"`
|
F16 bool `name:"f16" env:"LOCALAI_F16,F16" help:"Enable GPU acceleration" group:"performance"`
|
||||||
Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"`
|
Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"`
|
||||||
@@ -67,6 +67,7 @@ type RunCMD struct {
|
|||||||
OllamaAPIRootEndpoint bool `env:"LOCALAI_OLLAMA_API_ROOT_ENDPOINT" default:"false" help:"Register Ollama-compatible health check on / (replaces web UI on root path). The /api/* Ollama endpoints are always available regardless of this flag" group:"api"`
|
OllamaAPIRootEndpoint bool `env:"LOCALAI_OLLAMA_API_ROOT_ENDPOINT" default:"false" help:"Register Ollama-compatible health check on / (replaces web UI on root path). The /api/* Ollama endpoints are always available regardless of this flag" group:"api"`
|
||||||
DisableRuntimeSettings bool `env:"LOCALAI_DISABLE_RUNTIME_SETTINGS,DISABLE_RUNTIME_SETTINGS" default:"false" help:"Disables the runtime settings. When set to true, the server will not load the runtime settings from the runtime_settings.json file" group:"api"`
|
DisableRuntimeSettings bool `env:"LOCALAI_DISABLE_RUNTIME_SETTINGS,DISABLE_RUNTIME_SETTINGS" default:"false" help:"Disables the runtime settings. When set to true, the server will not load the runtime settings from the runtime_settings.json file" group:"api"`
|
||||||
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
|
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
|
||||||
|
RequireBackendIntegrity bool `env:"LOCALAI_REQUIRE_BACKEND_INTEGRITY,REQUIRE_BACKEND_INTEGRITY" help:"If true, backend installs without a configured signature verification policy (for OCI URIs) or SHA256 (for tarball/HTTP URIs) are rejected. Default is to warn and install. Set this in production once your gallery's verification: block is populated." group:"hardening" default:"false"`
|
||||||
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
|
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
|
||||||
UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"`
|
UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"`
|
||||||
DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"`
|
DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"`
|
||||||
@@ -99,6 +100,7 @@ type RunCMD struct {
|
|||||||
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
|
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
|
||||||
EnableTracing bool `env:"LOCALAI_ENABLE_TRACING,ENABLE_TRACING" help:"Enable API tracing" group:"api"`
|
EnableTracing bool `env:"LOCALAI_ENABLE_TRACING,ENABLE_TRACING" help:"Enable API tracing" group:"api"`
|
||||||
TracingMaxItems int `env:"LOCALAI_TRACING_MAX_ITEMS" default:"1024" help:"Maximum number of traces to keep" group:"api"`
|
TracingMaxItems int `env:"LOCALAI_TRACING_MAX_ITEMS" default:"1024" help:"Maximum number of traces to keep" group:"api"`
|
||||||
|
TracingMaxBodyBytes int `env:"LOCALAI_TRACING_MAX_BODY_BYTES" default:"65536" help:"Maximum bytes captured per request/response body in the trace buffer (0 = uncapped). Caps memory growth from chatty endpoints like /embeddings." group:"api"`
|
||||||
AgentJobRetentionDays int `env:"LOCALAI_AGENT_JOB_RETENTION_DAYS,AGENT_JOB_RETENTION_DAYS" default:"30" help:"Number of days to keep agent job history (default: 30)" group:"api"`
|
AgentJobRetentionDays int `env:"LOCALAI_AGENT_JOB_RETENTION_DAYS,AGENT_JOB_RETENTION_DAYS" default:"30" help:"Number of days to keep agent job history (default: 30)" group:"api"`
|
||||||
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
|
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
|
||||||
|
|
||||||
@@ -143,16 +145,18 @@ type RunCMD struct {
|
|||||||
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
||||||
|
|
||||||
// Distributed / Horizontal Scaling
|
// Distributed / Horizontal Scaling
|
||||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||||
|
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||||
|
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||||
|
|
||||||
Version bool
|
Version bool
|
||||||
}
|
}
|
||||||
@@ -253,6 +257,20 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
|||||||
if r.StorageSecretKey != "" {
|
if r.StorageSecretKey != "" {
|
||||||
opts = append(opts, config.WithStorageSecretKey(r.StorageSecretKey))
|
opts = append(opts, config.WithStorageSecretKey(r.StorageSecretKey))
|
||||||
}
|
}
|
||||||
|
if r.BackendInstallTimeout != "" {
|
||||||
|
d, err := time.ParseDuration(r.BackendInstallTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT %q: %w", r.BackendInstallTimeout, err)
|
||||||
|
}
|
||||||
|
opts = append(opts, config.WithBackendInstallTimeout(d))
|
||||||
|
}
|
||||||
|
if r.BackendUpgradeTimeout != "" {
|
||||||
|
d, err := time.ParseDuration(r.BackendUpgradeTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT %q: %w", r.BackendUpgradeTimeout, err)
|
||||||
|
}
|
||||||
|
opts = append(opts, config.WithBackendUpgradeTimeout(d))
|
||||||
|
}
|
||||||
if r.RegistrationToken != "" {
|
if r.RegistrationToken != "" {
|
||||||
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
|
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
|
||||||
}
|
}
|
||||||
@@ -272,6 +290,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
|||||||
opts = append(opts, config.EnableTracing)
|
opts = append(opts, config.EnableTracing)
|
||||||
}
|
}
|
||||||
opts = append(opts, config.WithTracingMaxItems(r.TracingMaxItems))
|
opts = append(opts, config.WithTracingMaxItems(r.TracingMaxItems))
|
||||||
|
opts = append(opts, config.WithTracingMaxBodyBytes(r.TracingMaxBodyBytes))
|
||||||
|
|
||||||
token := ""
|
token := ""
|
||||||
if r.Peer2Peer || r.Peer2PeerToken != "" {
|
if r.Peer2Peer || r.Peer2PeerToken != "" {
|
||||||
@@ -503,6 +522,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
|||||||
opts = append(opts, config.WithAutoUpgradeBackends(r.AutoUpgradeBackends))
|
opts = append(opts, config.WithAutoUpgradeBackends(r.AutoUpgradeBackends))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.RequireBackendIntegrity {
|
||||||
|
opts = append(opts, config.WithRequireBackendIntegrity(r.RequireBackendIntegrity))
|
||||||
|
}
|
||||||
|
|
||||||
if r.PreferDevelopmentBackends {
|
if r.PreferDevelopmentBackends {
|
||||||
opts = append(opts, config.WithPreferDevelopmentBackends(r.PreferDevelopmentBackends))
|
opts = append(opts, config.WithPreferDevelopmentBackends(r.PreferDevelopmentBackends))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
package worker
|
package worker
|
||||||
|
|
||||||
type WorkerFlags struct {
|
type WorkerFlags struct {
|
||||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
|
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
|
||||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||||
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
||||||
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
|
RequireBackendIntegrity bool `env:"LOCALAI_REQUIRE_BACKEND_INTEGRITY,REQUIRE_BACKEND_INTEGRITY" help:"If true, reject backend installs without a configured signature verification policy (OCI URIs) or SHA256 (tarball/HTTP URIs)." group:"hardening" default:"false"`
|
||||||
|
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Worker struct {
|
type Worker struct {
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import (
|
|||||||
// installing the backend from the gallery if it isn't present.
|
// installing the backend from the gallery if it isn't present.
|
||||||
// `name` is the gallery entry name (for vLLM the meta entry "vllm"
|
// `name` is the gallery entry name (for vLLM the meta entry "vllm"
|
||||||
// resolves to a platform-specific package via capability lookup).
|
// resolves to a platform-specific package via capability lookup).
|
||||||
func findBackendPath(name, galleries string, systemState *system.SystemState) (string, error) {
|
func findBackendPath(name, galleries string, systemState *system.SystemState, requireIntegrity bool) (string, error) {
|
||||||
backends, err := gallery.ListSystemBackends(systemState)
|
backends, err := gallery.ListSystemBackends(systemState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -33,7 +33,7 @@ func findBackendPath(name, galleries string, systemState *system.SystemState) (s
|
|||||||
xlog.Error("failed loading galleries", "error", err)
|
xlog.Error("failed loading galleries", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, name, nil, true); err != nil {
|
if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, name, nil, true, requireIntegrity); err != nil {
|
||||||
xlog.Error("backend not found, failed to install it", "name", name, "error", err)
|
xlog.Error("backend not found, failed to install it", "name", name, "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ const (
|
|||||||
llamaCPPGalleryName = "llama-cpp"
|
llamaCPPGalleryName = "llama-cpp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (string, error) {
|
func findLLamaCPPBackend(galleries string, systemState *system.SystemState, requireIntegrity bool) (string, error) {
|
||||||
backends, err := gallery.ListSystemBackends(systemState)
|
backends, err := gallery.ListSystemBackends(systemState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Warn("Failed listing system backends", "error", err)
|
xlog.Warn("Failed listing system backends", "error", err)
|
||||||
@@ -43,7 +43,7 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (str
|
|||||||
xlog.Error("failed loading galleries", "error", err)
|
xlog.Error("failed loading galleries", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true)
|
err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true, requireIntegrity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("llama-cpp backend not found, failed to install it", "error", err)
|
xlog.Error("llama-cpp backend not found, failed to install it", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
@@ -76,7 +76,7 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState)
|
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState, r.RequireBackendIntegrity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
|
|
||||||
const mlxDistributedGalleryName = "mlx-distributed"
|
const mlxDistributedGalleryName = "mlx-distributed"
|
||||||
|
|
||||||
func findMLXDistributedBackendPath(galleries string, systemState *system.SystemState) (string, error) {
|
func findMLXDistributedBackendPath(galleries string, systemState *system.SystemState, requireIntegrity bool) (string, error) {
|
||||||
return findBackendPath(mlxDistributedGalleryName, galleries, systemState)
|
return findBackendPath(mlxDistributedGalleryName, galleries, systemState, requireIntegrity)
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildMLXCommand builds the exec.Cmd to launch the mlx-distributed backend.
|
// buildMLXCommand builds the exec.Cmd to launch the mlx-distributed backend.
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func (r *MLXDistributed) Run(ctx *cliContext.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState)
|
backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState, r.RequireBackendIntegrity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot find mlx-distributed backend: %w", err)
|
return fmt.Errorf("cannot find mlx-distributed backend: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
|
|||||||
for {
|
for {
|
||||||
xlog.Info("Starting llama-cpp-rpc-server", "address", address, "port", port)
|
xlog.Info("Starting llama-cpp-rpc-server", "address", address, "port", port)
|
||||||
|
|
||||||
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState)
|
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState, r.RequireBackendIntegrity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("Failed to find llama-cpp-rpc-server", "error", err)
|
xlog.Error("Failed to find llama-cpp-rpc-server", "error", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func (r *P2PMLX) Run(ctx *cliContext.Context) error {
|
|||||||
c, cancel := context.WithCancel(context.Background())
|
c, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState)
|
backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState, r.RequireBackendIntegrity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Warn("Could not find mlx-distributed backend from gallery, will try backend.py directly", "error", err)
|
xlog.Warn("Could not find mlx-distributed backend from gallery, will try backend.py directly", "error", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ func (r *VLLMDistributed) Run(ctx *cliContext.Context) error {
|
|||||||
return fmt.Errorf("getting system state: %w", err)
|
return fmt.Errorf("getting system state: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
backendPath, err := findBackendPath("vllm", r.BackendGalleries, systemState)
|
backendPath, err := findBackendPath("vllm", r.BackendGalleries, systemState, r.RequireBackendIntegrity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot find vllm backend: %w", err)
|
return fmt.Errorf("cannot find vllm backend: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ type ApplicationConfig struct {
|
|||||||
Debug bool
|
Debug bool
|
||||||
EnableTracing bool
|
EnableTracing bool
|
||||||
TracingMaxItems int
|
TracingMaxItems int
|
||||||
|
TracingMaxBodyBytes int // Per-body cap for captured request/response bodies; 0 disables the cap
|
||||||
EnableBackendLogging bool
|
EnableBackendLogging bool
|
||||||
GeneratedContentDir string
|
GeneratedContentDir string
|
||||||
|
|
||||||
@@ -60,6 +61,13 @@ type ApplicationConfig struct {
|
|||||||
AutoUpgradeBackends bool
|
AutoUpgradeBackends bool
|
||||||
PreferDevelopmentBackends bool
|
PreferDevelopmentBackends bool
|
||||||
|
|
||||||
|
// RequireBackendIntegrity promotes a missing SHA256 (tarball/HTTP URIs)
|
||||||
|
// or missing verification policy (OCI URIs) from a warning to a hard
|
||||||
|
// failure during backend install/upgrade. Off by default to keep
|
||||||
|
// upgrades non-breaking; operators opt in explicitly via
|
||||||
|
// --require-backend-integrity / LOCALAI_REQUIRE_BACKEND_INTEGRITY.
|
||||||
|
RequireBackendIntegrity bool
|
||||||
|
|
||||||
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
|
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
|
||||||
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||||
WatchDogIdle bool
|
WatchDogIdle bool
|
||||||
@@ -180,6 +188,7 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
|||||||
LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
|
LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
|
||||||
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
|
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
|
||||||
TracingMaxItems: 1024,
|
TracingMaxItems: 1024,
|
||||||
|
TracingMaxBodyBytes: 64 * 1024, // 64 KiB - caps each request/response body in the trace buffer
|
||||||
AgentPool: AgentPoolConfig{
|
AgentPool: AgentPoolConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Timeout: "5m",
|
Timeout: "5m",
|
||||||
@@ -436,6 +445,10 @@ func WithAutoUpgradeBackends(v bool) AppOption {
|
|||||||
return func(o *ApplicationConfig) { o.AutoUpgradeBackends = v }
|
return func(o *ApplicationConfig) { o.AutoUpgradeBackends = v }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithRequireBackendIntegrity(v bool) AppOption {
|
||||||
|
return func(o *ApplicationConfig) { o.RequireBackendIntegrity = v }
|
||||||
|
}
|
||||||
|
|
||||||
func WithPreferDevelopmentBackends(v bool) AppOption {
|
func WithPreferDevelopmentBackends(v bool) AppOption {
|
||||||
return func(o *ApplicationConfig) { o.PreferDevelopmentBackends = v }
|
return func(o *ApplicationConfig) { o.PreferDevelopmentBackends = v }
|
||||||
}
|
}
|
||||||
@@ -567,6 +580,12 @@ func WithTracingMaxItems(items int) AppOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithTracingMaxBodyBytes(bytes int) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.TracingMaxBodyBytes = bytes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithGeneratedContentDir(generatedContentDir string) AppOption {
|
func WithGeneratedContentDir(generatedContentDir string) AppOption {
|
||||||
return func(o *ApplicationConfig) {
|
return func(o *ApplicationConfig) {
|
||||||
o.GeneratedContentDir = generatedContentDir
|
o.GeneratedContentDir = generatedContentDir
|
||||||
@@ -909,6 +928,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
|||||||
f16 := o.F16
|
f16 := o.F16
|
||||||
debug := o.Debug
|
debug := o.Debug
|
||||||
tracingMaxItems := o.TracingMaxItems
|
tracingMaxItems := o.TracingMaxItems
|
||||||
|
tracingMaxBodyBytes := o.TracingMaxBodyBytes
|
||||||
enableTracing := o.EnableTracing
|
enableTracing := o.EnableTracing
|
||||||
enableBackendLogging := o.EnableBackendLogging
|
enableBackendLogging := o.EnableBackendLogging
|
||||||
cors := o.CORS
|
cors := o.CORS
|
||||||
@@ -997,6 +1017,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
|||||||
F16: &f16,
|
F16: &f16,
|
||||||
Debug: &debug,
|
Debug: &debug,
|
||||||
TracingMaxItems: &tracingMaxItems,
|
TracingMaxItems: &tracingMaxItems,
|
||||||
|
TracingMaxBodyBytes: &tracingMaxBodyBytes,
|
||||||
EnableTracing: &enableTracing,
|
EnableTracing: &enableTracing,
|
||||||
EnableBackendLogging: &enableBackendLogging,
|
EnableBackendLogging: &enableBackendLogging,
|
||||||
CORS: &cors,
|
CORS: &cors,
|
||||||
@@ -1135,6 +1156,9 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
|||||||
if settings.TracingMaxItems != nil {
|
if settings.TracingMaxItems != nil {
|
||||||
o.TracingMaxItems = *settings.TracingMaxItems
|
o.TracingMaxItems = *settings.TracingMaxItems
|
||||||
}
|
}
|
||||||
|
if settings.TracingMaxBodyBytes != nil {
|
||||||
|
o.TracingMaxBodyBytes = *settings.TracingMaxBodyBytes
|
||||||
|
}
|
||||||
if settings.EnableBackendLogging != nil {
|
if settings.EnableBackendLogging != nil {
|
||||||
o.EnableBackendLogging = *settings.EnableBackendLogging
|
o.EnableBackendLogging = *settings.EnableBackendLogging
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,7 +40,10 @@ type DistributedConfig struct {
|
|||||||
// model-row cleanup on MarkUnhealthy / MarkDraining).
|
// model-row cleanup on MarkUnhealthy / MarkDraining).
|
||||||
DisablePerModelHealthCheck bool
|
DisablePerModelHealthCheck bool
|
||||||
|
|
||||||
MCPCIJobTimeout time.Duration // MCP CI job execution timeout (default 10m)
|
MCPCIJobTimeout time.Duration // MCP CI job execution timeout (default 10m)
|
||||||
|
|
||||||
|
BackendInstallTimeout time.Duration // NATS round-trip timeout for backend.install (default 15m)
|
||||||
|
BackendUpgradeTimeout time.Duration // NATS round-trip timeout for backend.upgrade (default 15m)
|
||||||
|
|
||||||
MaxUploadSize int64 // Maximum upload body size in bytes (default 50 GB)
|
MaxUploadSize int64 // Maximum upload body size in bytes (default 50 GB)
|
||||||
|
|
||||||
@@ -68,13 +71,15 @@ func (c DistributedConfig) Validate() error {
|
|||||||
}
|
}
|
||||||
// Check for negative durations
|
// Check for negative durations
|
||||||
for name, d := range map[string]time.Duration{
|
for name, d := range map[string]time.Duration{
|
||||||
"mcp-tool-timeout": c.MCPToolTimeout,
|
FlagMCPToolTimeout: c.MCPToolTimeout,
|
||||||
"mcp-discovery-timeout": c.MCPDiscoveryTimeout,
|
FlagMCPDiscoveryTimeout: c.MCPDiscoveryTimeout,
|
||||||
"worker-wait-timeout": c.WorkerWaitTimeout,
|
FlagWorkerWaitTimeout: c.WorkerWaitTimeout,
|
||||||
"drain-timeout": c.DrainTimeout,
|
FlagDrainTimeout: c.DrainTimeout,
|
||||||
"health-check-interval": c.HealthCheckInterval,
|
FlagHealthCheckInterval: c.HealthCheckInterval,
|
||||||
"stale-node-threshold": c.StaleNodeThreshold,
|
FlagStaleNodeThreshold: c.StaleNodeThreshold,
|
||||||
"mcp-ci-job-timeout": c.MCPCIJobTimeout,
|
FlagMCPCIJobTimeout: c.MCPCIJobTimeout,
|
||||||
|
FlagBackendInstallTimeout: c.BackendInstallTimeout,
|
||||||
|
FlagBackendUpgradeTimeout: c.BackendUpgradeTimeout,
|
||||||
} {
|
} {
|
||||||
if d < 0 {
|
if d < 0 {
|
||||||
return fmt.Errorf("%s must not be negative", name)
|
return fmt.Errorf("%s must not be negative", name)
|
||||||
@@ -137,24 +142,66 @@ func WithStorageSecretKey(key string) AppOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithBackendInstallTimeout(d time.Duration) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.Distributed.BackendInstallTimeout = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithBackendUpgradeTimeout(d time.Duration) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.Distributed.BackendUpgradeTimeout = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var EnableAutoApproveNodes = func(o *ApplicationConfig) {
|
var EnableAutoApproveNodes = func(o *ApplicationConfig) {
|
||||||
o.Distributed.AutoApproveNodes = true
|
o.Distributed.AutoApproveNodes = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Flag names for distributed timeout / interval configuration. These are
|
||||||
|
// the kebab-case identifiers kong derives from the matching RunCMD struct
|
||||||
|
// fields; they appear in Validate error messages and any other operator-
|
||||||
|
// facing surface that needs to reference a specific knob by name. Keeping
|
||||||
|
// them as constants prevents the string from drifting from the actual
|
||||||
|
// flag a future rename would produce.
|
||||||
|
const (
|
||||||
|
FlagMCPToolTimeout = "mcp-tool-timeout"
|
||||||
|
FlagMCPDiscoveryTimeout = "mcp-discovery-timeout"
|
||||||
|
FlagWorkerWaitTimeout = "worker-wait-timeout"
|
||||||
|
FlagDrainTimeout = "drain-timeout"
|
||||||
|
FlagHealthCheckInterval = "health-check-interval"
|
||||||
|
FlagStaleNodeThreshold = "stale-node-threshold"
|
||||||
|
FlagMCPCIJobTimeout = "mcp-ci-job-timeout"
|
||||||
|
FlagBackendInstallTimeout = "backend-install-timeout"
|
||||||
|
FlagBackendUpgradeTimeout = "backend-upgrade-timeout"
|
||||||
|
)
|
||||||
|
|
||||||
// Defaults for distributed timeouts.
|
// Defaults for distributed timeouts.
|
||||||
const (
|
const (
|
||||||
DefaultMCPToolTimeout = 360 * time.Second
|
DefaultMCPToolTimeout = 360 * time.Second
|
||||||
DefaultMCPDiscoveryTimeout = 60 * time.Second
|
DefaultMCPDiscoveryTimeout = 60 * time.Second
|
||||||
DefaultWorkerWaitTimeout = 5 * time.Minute
|
DefaultWorkerWaitTimeout = 5 * time.Minute
|
||||||
DefaultDrainTimeout = 30 * time.Second
|
DefaultDrainTimeout = 30 * time.Second
|
||||||
DefaultHealthCheckInterval = 15 * time.Second
|
DefaultHealthCheckInterval = 15 * time.Second
|
||||||
DefaultStaleNodeThreshold = 60 * time.Second
|
DefaultStaleNodeThreshold = 60 * time.Second
|
||||||
DefaultMCPCIJobTimeout = 10 * time.Minute
|
DefaultMCPCIJobTimeout = 10 * time.Minute
|
||||||
|
DefaultBackendInstallTimeout = 15 * time.Minute
|
||||||
|
DefaultBackendUpgradeTimeout = 15 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
|
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
|
||||||
const DefaultMaxUploadSize int64 = 50 << 30
|
const DefaultMaxUploadSize int64 = 50 << 30
|
||||||
|
|
||||||
|
// BackendInstallTimeoutOrDefault returns the configured timeout or the default.
|
||||||
|
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration {
|
||||||
|
return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackendUpgradeTimeoutOrDefault returns the configured timeout or the default.
|
||||||
|
func (c DistributedConfig) BackendUpgradeTimeoutOrDefault() time.Duration {
|
||||||
|
return cmp.Or(c.BackendUpgradeTimeout, DefaultBackendUpgradeTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
// MCPToolTimeoutOrDefault returns the configured timeout or the default.
|
// MCPToolTimeoutOrDefault returns the configured timeout or the default.
|
||||||
func (c DistributedConfig) MCPToolTimeoutOrDefault() time.Duration {
|
func (c DistributedConfig) MCPToolTimeoutOrDefault() time.Duration {
|
||||||
return cmp.Or(c.MCPToolTimeout, DefaultMCPToolTimeout)
|
return cmp.Or(c.MCPToolTimeout, DefaultMCPToolTimeout)
|
||||||
|
|||||||
90
core/config/distributed_config_test.go
Normal file
90
core/config/distributed_config_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package config_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("DistributedConfig backend NATS timeouts", func() {
|
||||||
|
Context("BackendInstallTimeoutOrDefault", func() {
|
||||||
|
It("returns 15 minutes when unset", func() {
|
||||||
|
c := config.DistributedConfig{}
|
||||||
|
Expect(c.BackendInstallTimeoutOrDefault()).To(Equal(15 * time.Minute))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the configured value when set", func() {
|
||||||
|
c := config.DistributedConfig{BackendInstallTimeout: 42 * time.Minute}
|
||||||
|
Expect(c.BackendInstallTimeoutOrDefault()).To(Equal(42 * time.Minute))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("BackendUpgradeTimeoutOrDefault", func() {
|
||||||
|
It("returns 15 minutes when unset", func() {
|
||||||
|
c := config.DistributedConfig{}
|
||||||
|
Expect(c.BackendUpgradeTimeoutOrDefault()).To(Equal(15 * time.Minute))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the configured value when set", func() {
|
||||||
|
c := config.DistributedConfig{BackendUpgradeTimeout: 30 * time.Minute}
|
||||||
|
Expect(c.BackendUpgradeTimeoutOrDefault()).To(Equal(30 * time.Minute))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
var _ = Describe("DistributedConfig flag-name constants", func() {
|
||||||
|
// Pin the kebab-case strings so a rename of the Go field name (or a
|
||||||
|
// CLI flag naming convention change) forces the constant to update,
|
||||||
|
// keeping the Validate error messages and any future operator-facing
|
||||||
|
// surface in sync with the actual CLI flag.
|
||||||
|
DescribeTable("flag name constants",
|
||||||
|
func(actual, expected string) {
|
||||||
|
Expect(actual).To(Equal(expected))
|
||||||
|
},
|
||||||
|
Entry("MCP tool timeout", config.FlagMCPToolTimeout, "mcp-tool-timeout"),
|
||||||
|
Entry("MCP discovery timeout", config.FlagMCPDiscoveryTimeout, "mcp-discovery-timeout"),
|
||||||
|
Entry("worker wait timeout", config.FlagWorkerWaitTimeout, "worker-wait-timeout"),
|
||||||
|
Entry("drain timeout", config.FlagDrainTimeout, "drain-timeout"),
|
||||||
|
Entry("health check interval", config.FlagHealthCheckInterval, "health-check-interval"),
|
||||||
|
Entry("stale node threshold", config.FlagStaleNodeThreshold, "stale-node-threshold"),
|
||||||
|
Entry("MCP CI job timeout", config.FlagMCPCIJobTimeout, "mcp-ci-job-timeout"),
|
||||||
|
Entry("backend install timeout", config.FlagBackendInstallTimeout, "backend-install-timeout"),
|
||||||
|
Entry("backend upgrade timeout", config.FlagBackendUpgradeTimeout, "backend-upgrade-timeout"),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
var _ = Describe("DistributedConfig.Validate negative-duration errors", func() {
|
||||||
|
It("rejects a negative BackendInstallTimeout with the flag name in the error", func() {
|
||||||
|
c := config.DistributedConfig{
|
||||||
|
Enabled: true,
|
||||||
|
NatsURL: "nats://localhost:4222",
|
||||||
|
BackendInstallTimeout: -1 * time.Second,
|
||||||
|
}
|
||||||
|
err := c.Validate()
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring(config.FlagBackendInstallTimeout))
|
||||||
|
Expect(err.Error()).To(ContainSubstring("must not be negative"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects a negative BackendUpgradeTimeout with the flag name in the error", func() {
|
||||||
|
c := config.DistributedConfig{
|
||||||
|
Enabled: true,
|
||||||
|
NatsURL: "nats://localhost:4222",
|
||||||
|
BackendUpgradeTimeout: -1 * time.Second,
|
||||||
|
}
|
||||||
|
err := c.Validate()
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring(config.FlagBackendUpgradeTimeout))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("accepts all-zero durations as valid (defaults apply)", func() {
|
||||||
|
c := config.DistributedConfig{
|
||||||
|
Enabled: true,
|
||||||
|
NatsURL: "nats://localhost:4222",
|
||||||
|
}
|
||||||
|
Expect(c.Validate()).To(Succeed())
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -1,6 +1,37 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
type Gallery struct {
|
// GalleryVerification declares the keyless-cosign signature policy that
|
||||||
URL string `json:"url" yaml:"url"`
|
// every OCI backend image fetched from this gallery must satisfy.
|
||||||
Name string `json:"name" yaml:"name"`
|
//
|
||||||
|
// Verification is opt-in: galleries without a Verification block install
|
||||||
|
// backends with no signature check (the downloader logs a warning when
|
||||||
|
// LOCALAI_REQUIRE_BACKEND_INTEGRITY is unset; that flag turns the warning
|
||||||
|
// into a hard error).
|
||||||
|
//
|
||||||
|
// Identity matching: set Issuer (exact) or IssuerRegex, AND Identity
|
||||||
|
// (exact) or IdentityRegex. For GitHub Actions keyless signing the
|
||||||
|
// typical shape is:
|
||||||
|
//
|
||||||
|
// verification:
|
||||||
|
// issuer: "https://token.actions.githubusercontent.com"
|
||||||
|
// identity_regex: "^https://github\\.com/mudler/local-ai-backends/\\.github/workflows/build\\.yaml@refs/heads/master$"
|
||||||
|
// not_before: "2026-05-01T00:00:00Z"
|
||||||
|
//
|
||||||
|
// NotBefore is the revocation lever: advance it to invalidate every
|
||||||
|
// signature produced before a known compromise window. Keyless cosign
|
||||||
|
// certs are ephemeral so there is no CA-side revocation.
|
||||||
|
type GalleryVerification struct {
|
||||||
|
Issuer string `json:"issuer,omitempty" yaml:"issuer,omitempty"`
|
||||||
|
IssuerRegex string `json:"issuer_regex,omitempty" yaml:"issuer_regex,omitempty"`
|
||||||
|
Identity string `json:"identity,omitempty" yaml:"identity,omitempty"`
|
||||||
|
IdentityRegex string `json:"identity_regex,omitempty" yaml:"identity_regex,omitempty"`
|
||||||
|
|
||||||
|
// NotBefore is an RFC3339 timestamp. Empty disables the time check.
|
||||||
|
NotBefore string `json:"not_before,omitempty" yaml:"not_before,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Gallery struct {
|
||||||
|
URL string `json:"url" yaml:"url"`
|
||||||
|
Name string `json:"name" yaml:"name"`
|
||||||
|
Verification *GalleryVerification `json:"verification,omitempty" yaml:"verification,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,6 +54,13 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
|
|||||||
cfg.modelTemplate = chatTemplate.ValueString()
|
cfg.modelTemplate = chatTemplate.ValueString()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Auto-enable Multi-Token Prediction (ggml-org/llama.cpp#22673) when the
|
||||||
|
// GGUF carries an embedded MTP head. Skipped silently for non-MTP models
|
||||||
|
// and when the user already configured a spec_type.
|
||||||
|
if n, ok := HasEmbeddedMTPHead(f); ok {
|
||||||
|
ApplyMTPDefaults(cfg, n)
|
||||||
|
}
|
||||||
|
|
||||||
// Thinking support detection is done after model load via DetectThinkingSupportFromBackend
|
// Thinking support detection is done after model load via DetectThinkingSupportFromBackend
|
||||||
|
|
||||||
// template estimations
|
// template estimations
|
||||||
|
|||||||
@@ -136,4 +136,36 @@ var _ = Describe("Backend hooks and parser defaults", func() {
|
|||||||
Expect(cfg.EngineArgs["enable_chunked_prefill"]).To(Equal(true))
|
Expect(cfg.EngineArgs["enable_chunked_prefill"]).To(Equal(true))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("PromptCacheAll default", func() {
|
||||||
|
It("defaults to true when omitted from YAML", func() {
|
||||||
|
cfg := &ModelConfig{}
|
||||||
|
cfg.SetDefaults()
|
||||||
|
|
||||||
|
Expect(cfg.PromptCacheAll).NotTo(BeNil())
|
||||||
|
Expect(*cfg.PromptCacheAll).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("preserves an explicit false from YAML", func() {
|
||||||
|
falseV := false
|
||||||
|
cfg := &ModelConfig{
|
||||||
|
LLMConfig: LLMConfig{PromptCacheAll: &falseV},
|
||||||
|
}
|
||||||
|
cfg.SetDefaults()
|
||||||
|
|
||||||
|
Expect(cfg.PromptCacheAll).NotTo(BeNil())
|
||||||
|
Expect(*cfg.PromptCacheAll).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("preserves an explicit true from YAML", func() {
|
||||||
|
trueV := true
|
||||||
|
cfg := &ModelConfig{
|
||||||
|
LLMConfig: LLMConfig{PromptCacheAll: &trueV},
|
||||||
|
}
|
||||||
|
cfg.SetDefaults()
|
||||||
|
|
||||||
|
Expect(cfg.PromptCacheAll).NotTo(BeNil())
|
||||||
|
Expect(*cfg.PromptCacheAll).To(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ type LLMConfig struct {
|
|||||||
RMSNormEps float32 `yaml:"rms_norm_eps,omitempty" json:"rms_norm_eps,omitempty"`
|
RMSNormEps float32 `yaml:"rms_norm_eps,omitempty" json:"rms_norm_eps,omitempty"`
|
||||||
NGQA int32 `yaml:"ngqa,omitempty" json:"ngqa,omitempty"`
|
NGQA int32 `yaml:"ngqa,omitempty" json:"ngqa,omitempty"`
|
||||||
PromptCachePath string `yaml:"prompt_cache_path,omitempty" json:"prompt_cache_path,omitempty"`
|
PromptCachePath string `yaml:"prompt_cache_path,omitempty" json:"prompt_cache_path,omitempty"`
|
||||||
PromptCacheAll bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"`
|
PromptCacheAll *bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"`
|
||||||
PromptCacheRO bool `yaml:"prompt_cache_ro,omitempty" json:"prompt_cache_ro,omitempty"`
|
PromptCacheRO bool `yaml:"prompt_cache_ro,omitempty" json:"prompt_cache_ro,omitempty"`
|
||||||
MirostatETA *float64 `yaml:"mirostat_eta,omitempty" json:"mirostat_eta,omitempty"`
|
MirostatETA *float64 `yaml:"mirostat_eta,omitempty" json:"mirostat_eta,omitempty"`
|
||||||
MirostatTAU *float64 `yaml:"mirostat_tau,omitempty" json:"mirostat_tau,omitempty"`
|
MirostatTAU *float64 `yaml:"mirostat_tau,omitempty" json:"mirostat_tau,omitempty"`
|
||||||
@@ -494,6 +494,13 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
|
|||||||
cfg.Reranking = &falseV
|
cfg.Reranking = &falseV
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.PromptCacheAll == nil {
|
||||||
|
// Match upstream llama.cpp's default (common/common.h: cache_prompt = true)
|
||||||
|
// and let cache_idle_slots / kv_unified actually do useful work; users can
|
||||||
|
// opt out with an explicit `prompt_cache_all: false` in the model YAML.
|
||||||
|
cfg.PromptCacheAll = &trueV
|
||||||
|
}
|
||||||
|
|
||||||
if threads == 0 {
|
if threads == 0 {
|
||||||
// Threads can't be 0
|
// Threads can't be 0
|
||||||
threads = 4
|
threads = 4
|
||||||
|
|||||||
84
core/config/mtp.go
Normal file
84
core/config/mtp.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
gguf "github.com/gpustack/gguf-parser-go"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mtpSpecOptions lists the speculative-decoding option keys auto-applied when
|
||||||
|
// an MTP head is detected on a llama-cpp GGUF. Defaults track the upstream
|
||||||
|
// MTP PR (ggml-org/llama.cpp#22673):
|
||||||
|
//
|
||||||
|
// - spec_type:draft-mtp activates Multi-Token Prediction
|
||||||
|
// - spec_n_max:6 draft window
|
||||||
|
// - spec_p_min:0.75 pinned because upstream marked the 0.75 default
|
||||||
|
// with a "change to 0.0f" TODO; locking it here keeps acceptance
|
||||||
|
// thresholds stable across future bumps
|
||||||
|
var mtpSpecOptions = []string{
|
||||||
|
"spec_type:draft-mtp",
|
||||||
|
"spec_n_max:6",
|
||||||
|
"spec_p_min:0.75",
|
||||||
|
}
|
||||||
|
|
||||||
|
// MTPSpecOptions returns a copy of the option keys auto-applied when an MTP
|
||||||
|
// head is detected. Exported for testing and for the importer.
|
||||||
|
func MTPSpecOptions() []string {
|
||||||
|
out := make([]string, len(mtpSpecOptions))
|
||||||
|
copy(out, mtpSpecOptions)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a Multi-Token
|
||||||
|
// Prediction head. Detection reads `<arch>.nextn_predict_layers`, which is
|
||||||
|
// what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
|
||||||
|
// `conversion/qwen.py` MTP mixin. A positive layer count means the head is
|
||||||
|
// present in the same GGUF as the trunk.
|
||||||
|
func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
|
||||||
|
if f == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
arch := f.Architecture().Architecture
|
||||||
|
if arch == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
v, ok := f.Header.MetadataKV.Get(arch + ".nextn_predict_layers")
|
||||||
|
if !ok {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
n := gguf.ValueNumeric[uint32](v)
|
||||||
|
return n, n > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasSpecTypeOption returns true when the slice already contains a
|
||||||
|
// user-configured `spec_type:` / `speculative_type:` entry. Used to avoid
|
||||||
|
// clobbering an explicit choice with the MTP auto-defaults.
|
||||||
|
func hasSpecTypeOption(opts []string) bool {
|
||||||
|
for _, o := range opts {
|
||||||
|
if strings.HasPrefix(o, "spec_type:") || strings.HasPrefix(o, "speculative_type:") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyMTPDefaults appends the auto-MTP option keys to cfg.Options when none
|
||||||
|
// is already configured. It is a no-op when the user already picked a
|
||||||
|
// `spec_type` (either via YAML or via the importer's preferences flow).
|
||||||
|
//
|
||||||
|
// `layers` is the value read from `<arch>.nextn_predict_layers` and is only
|
||||||
|
// used for the diagnostic log line.
|
||||||
|
func ApplyMTPDefaults(cfg *ModelConfig, layers uint32) {
|
||||||
|
if cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if hasSpecTypeOption(cfg.Options) {
|
||||||
|
xlog.Debug("[mtp] embedded MTP head detected but spec_type already configured; leaving user choice intact",
|
||||||
|
"name", cfg.Name, "nextn_layers", layers)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.Options = append(cfg.Options, mtpSpecOptions...)
|
||||||
|
xlog.Info("[mtp] embedded MTP head detected; enabling draft-mtp speculative decoding",
|
||||||
|
"name", cfg.Name, "nextn_layers", layers, "spec_n_max", 6, "spec_p_min", 0.75)
|
||||||
|
}
|
||||||
86
core/config/mtp_test.go
Normal file
86
core/config/mtp_test.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package config_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/mudler/LocalAI/core/config"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("MTP auto-defaults", func() {
|
||||||
|
Context("MTPSpecOptions", func() {
|
||||||
|
It("returns the upstream-recommended speculative tuple", func() {
|
||||||
|
Expect(MTPSpecOptions()).To(Equal([]string{
|
||||||
|
"spec_type:draft-mtp",
|
||||||
|
"spec_n_max:6",
|
||||||
|
"spec_p_min:0.75",
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns a defensive copy so callers cannot mutate the package default", func() {
|
||||||
|
opts := MTPSpecOptions()
|
||||||
|
opts[0] = "spec_type:none"
|
||||||
|
Expect(MTPSpecOptions()[0]).To(Equal("spec_type:draft-mtp"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("ApplyMTPDefaults", func() {
|
||||||
|
It("appends MTP options when nothing is configured", func() {
|
||||||
|
cfg := &ModelConfig{Name: "qwen-mtp"}
|
||||||
|
ApplyMTPDefaults(cfg, 1)
|
||||||
|
Expect(cfg.Options).To(Equal([]string{
|
||||||
|
"spec_type:draft-mtp",
|
||||||
|
"spec_n_max:6",
|
||||||
|
"spec_p_min:0.75",
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("preserves unrelated options already on the config", func() {
|
||||||
|
cfg := &ModelConfig{
|
||||||
|
Name: "qwen-mtp",
|
||||||
|
Options: []string{"use_jinja:true", "cache_reuse:256"},
|
||||||
|
}
|
||||||
|
ApplyMTPDefaults(cfg, 1)
|
||||||
|
Expect(cfg.Options).To(Equal([]string{
|
||||||
|
"use_jinja:true",
|
||||||
|
"cache_reuse:256",
|
||||||
|
"spec_type:draft-mtp",
|
||||||
|
"spec_n_max:6",
|
||||||
|
"spec_p_min:0.75",
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("is a no-op when the user already configured spec_type", func() {
|
||||||
|
cfg := &ModelConfig{
|
||||||
|
Name: "qwen-mtp",
|
||||||
|
Options: []string{"spec_type:ngram-simple", "use_jinja:true"},
|
||||||
|
}
|
||||||
|
ApplyMTPDefaults(cfg, 1)
|
||||||
|
Expect(cfg.Options).To(Equal([]string{
|
||||||
|
"spec_type:ngram-simple",
|
||||||
|
"use_jinja:true",
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("also respects the legacy speculative_type alias", func() {
|
||||||
|
cfg := &ModelConfig{
|
||||||
|
Name: "qwen-mtp",
|
||||||
|
Options: []string{"speculative_type:ngram-mod"},
|
||||||
|
}
|
||||||
|
ApplyMTPDefaults(cfg, 1)
|
||||||
|
Expect(cfg.Options).To(Equal([]string{"speculative_type:ngram-mod"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("tolerates a nil config", func() {
|
||||||
|
Expect(func() { ApplyMTPDefaults(nil, 1) }).ToNot(Panic())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("HasEmbeddedMTPHead", func() {
|
||||||
|
It("returns false on a nil GGUF file", func() {
|
||||||
|
n, ok := HasEmbeddedMTPHead(nil)
|
||||||
|
Expect(ok).To(BeFalse())
|
||||||
|
Expect(n).To(BeZero())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -38,6 +38,7 @@ type RuntimeSettings struct {
|
|||||||
Debug *bool `json:"debug,omitempty"`
|
Debug *bool `json:"debug,omitempty"`
|
||||||
EnableTracing *bool `json:"enable_tracing,omitempty"`
|
EnableTracing *bool `json:"enable_tracing,omitempty"`
|
||||||
TracingMaxItems *int `json:"tracing_max_items,omitempty"`
|
TracingMaxItems *int `json:"tracing_max_items,omitempty"`
|
||||||
|
TracingMaxBodyBytes *int `json:"tracing_max_body_bytes,omitempty"` // Per-body cap in bytes; 0 disables the cap
|
||||||
EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"`
|
EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"`
|
||||||
|
|
||||||
// Security/CORS settings
|
// Security/CORS settings
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/mudler/LocalAI/pkg/downloader"
|
"github.com/mudler/LocalAI/pkg/downloader"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
"github.com/mudler/LocalAI/pkg/oci"
|
"github.com/mudler/LocalAI/pkg/oci"
|
||||||
|
"github.com/mudler/LocalAI/pkg/oci/cosignverify"
|
||||||
"github.com/mudler/LocalAI/pkg/system"
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
cp "github.com/otiai10/copy"
|
cp "github.com/otiai10/copy"
|
||||||
@@ -102,8 +103,81 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// backendDownloadOptions translates the gallery's verification policy into
|
||||||
|
// downloader options, and gates the call on strict-integrity mode. Both
|
||||||
|
// InstallBackend and UpgradeBackend MUST route their download through these
|
||||||
|
// options — without them, the corresponding code path silently downloads
|
||||||
|
// and activates unverified backend bytes even when the gallery has a
|
||||||
|
// verification: policy configured.
|
||||||
|
//
|
||||||
|
// For OCI URIs with a verification policy, returns a slice containing
|
||||||
|
// downloader.WithImageVerifier(v) — the downloader will then run cosign
|
||||||
|
// signature verification between fetching the manifest and extracting
|
||||||
|
// layers (see pkg/downloader/uri.go OCI branch).
|
||||||
|
//
|
||||||
|
// For OCI URIs without a verification policy, or non-OCI URIs without a
|
||||||
|
// SHA256, the function either returns a non-fatal warning (requireIntegrity
|
||||||
|
// false) or fails the install (requireIntegrity true).
|
||||||
|
func backendDownloadOptions(config *GalleryBackend, requireIntegrity bool) ([]downloader.DownloadOption, error) {
|
||||||
|
uri := downloader.URI(config.URI)
|
||||||
|
hasVerification := config.Gallery.Verification != nil
|
||||||
|
hasSHA := config.SHA256 != ""
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case uri.LooksLikeOCI():
|
||||||
|
if !hasVerification {
|
||||||
|
if requireIntegrity {
|
||||||
|
return nil, fmt.Errorf("strict integrity: gallery %q has no verification policy for OCI backend %q (set verification: in the gallery YAML or disable --require-backend-integrity)",
|
||||||
|
config.Gallery.Name, config.Name)
|
||||||
|
}
|
||||||
|
xlog.Warn("installing OCI backend without signature verification",
|
||||||
|
"backend", config.Name, "gallery", config.Gallery.Name, "uri", config.URI)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
v, err := newGalleryVerifier(config.Gallery.Verification)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gallery %q verification policy: %w", config.Gallery.Name, err)
|
||||||
|
}
|
||||||
|
return []downloader.DownloadOption{downloader.WithImageVerifier(v)}, nil
|
||||||
|
|
||||||
|
case uri.LooksLikeDir():
|
||||||
|
// Local directory — out of scope for integrity checks.
|
||||||
|
return nil, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
if !hasSHA && requireIntegrity {
|
||||||
|
return nil, fmt.Errorf("strict integrity: backend %q has no SHA256 (gallery %q)",
|
||||||
|
config.Name, config.Gallery.Name)
|
||||||
|
}
|
||||||
|
// Non-strict: pkg/downloader already emits a warning when sha is empty.
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newGalleryVerifier constructs a cosignverify.Verifier from the gallery
|
||||||
|
// policy. Parses NotBefore (RFC3339) here so YAML errors surface at install
|
||||||
|
// time rather than during signature verification.
|
||||||
|
func newGalleryVerifier(p *config.GalleryVerification) (*cosignverify.Verifier, error) {
|
||||||
|
pol := cosignverify.Policy{
|
||||||
|
Issuer: p.Issuer,
|
||||||
|
IssuerRegex: p.IssuerRegex,
|
||||||
|
Identity: p.Identity,
|
||||||
|
IdentityRegex: p.IdentityRegex,
|
||||||
|
}
|
||||||
|
if p.NotBefore != "" {
|
||||||
|
t, err := time.Parse(time.RFC3339, p.NotBefore)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("not_before %q: %w", p.NotBefore, err)
|
||||||
|
}
|
||||||
|
pol.NotBefore = t
|
||||||
|
}
|
||||||
|
return cosignverify.NewVerifier(pol, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
// InstallBackendFromGallery installs a backend from the gallery.
|
// InstallBackendFromGallery installs a backend from the gallery.
|
||||||
func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
|
// requireIntegrity escalates a missing SHA256 / verification policy from a
|
||||||
|
// warning to a hard failure (see backendDownloadOptions).
|
||||||
|
func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force, requireIntegrity bool) error {
|
||||||
if !force {
|
if !force {
|
||||||
// check if we already have the backend installed
|
// check if we already have the backend installed
|
||||||
backends, err := ListSystemBackends(systemState)
|
backends, err := ListSystemBackends(systemState)
|
||||||
@@ -149,7 +223,7 @@ func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery,
|
|||||||
xlog.Debug("Installing backend from meta backend", "name", name, "bestBackend", bestBackend.Name)
|
xlog.Debug("Installing backend from meta backend", "name", name, "bestBackend", bestBackend.Name)
|
||||||
|
|
||||||
// Then, let's install the best backend
|
// Then, let's install the best backend
|
||||||
if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus); err != nil {
|
if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus, requireIntegrity); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,10 +249,10 @@ func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus)
|
return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus, requireIntegrity)
|
||||||
}
|
}
|
||||||
|
|
||||||
func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
|
func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64), requireIntegrity bool) error {
|
||||||
// Get configurable fallback tag values from SystemState
|
// Get configurable fallback tag values from SystemState
|
||||||
latestTag, masterTag, devSuffix := getFallbackTagValues(systemState)
|
latestTag, masterTag, devSuffix := getFallbackTagValues(systemState)
|
||||||
|
|
||||||
@@ -213,6 +287,14 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
|||||||
return fmt.Errorf("failed to create base path: %v", err)
|
return fmt.Errorf("failed to create base path: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build the download options once and reuse for every retry path —
|
||||||
|
// mirrors and tag fallbacks must verify against the same gallery
|
||||||
|
// policy or we open a hole where a non-default URI bypasses the check.
|
||||||
|
downloadOpts, optsErr := backendDownloadOptions(config, requireIntegrity)
|
||||||
|
if optsErr != nil {
|
||||||
|
return fmt.Errorf("backend %q: %w", config.Name, optsErr)
|
||||||
|
}
|
||||||
|
|
||||||
uri := downloader.URI(config.URI)
|
uri := downloader.URI(config.URI)
|
||||||
// Check if it is a directory
|
// Check if it is a directory
|
||||||
if uri.LooksLikeDir() {
|
if uri.LooksLikeDir() {
|
||||||
@@ -222,7 +304,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
xlog.Debug("Downloading backend", "uri", config.URI, "backendPath", backendPath)
|
xlog.Debug("Downloading backend", "uri", config.URI, "backendPath", backendPath)
|
||||||
if err := uri.DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus); err != nil {
|
if err := uri.DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus, downloadOpts...); err != nil {
|
||||||
xlog.Debug("Backend download failed, trying fallback", "backendPath", backendPath, "error", err)
|
xlog.Debug("Backend download failed, trying fallback", "backendPath", backendPath, "error", err)
|
||||||
|
|
||||||
// resetBackendPath cleans up partial state from a failed OCI extraction
|
// resetBackendPath cleans up partial state from a failed OCI extraction
|
||||||
@@ -243,7 +325,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
resetBackendPath()
|
resetBackendPath()
|
||||||
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus); err == nil {
|
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus, downloadOpts...); err == nil {
|
||||||
success = true
|
success = true
|
||||||
xlog.Debug("Downloaded backend from mirror", "uri", config.URI, "backendPath", backendPath)
|
xlog.Debug("Downloaded backend from mirror", "uri", config.URI, "backendPath", backendPath)
|
||||||
break
|
break
|
||||||
@@ -256,7 +338,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
|||||||
if fallbackURI != string(config.URI) {
|
if fallbackURI != string(config.URI) {
|
||||||
resetBackendPath()
|
resetBackendPath()
|
||||||
xlog.Info("Trying fallback URI", "original", config.URI, "fallback", fallbackURI)
|
xlog.Info("Trying fallback URI", "original", config.URI, "fallback", fallbackURI)
|
||||||
if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus); err == nil {
|
if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus, downloadOpts...); err == nil {
|
||||||
xlog.Info("Downloaded backend using fallback URI", "uri", fallbackURI, "backendPath", backendPath)
|
xlog.Info("Downloaded backend using fallback URI", "uri", fallbackURI, "backendPath", backendPath)
|
||||||
success = true
|
success = true
|
||||||
} else {
|
} else {
|
||||||
@@ -265,7 +347,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
|||||||
resetBackendPath()
|
resetBackendPath()
|
||||||
devFallbackURI := fallbackURI + "-" + devSuffix
|
devFallbackURI := fallbackURI + "-" + devSuffix
|
||||||
xlog.Info("Trying development fallback URI", "fallback", devFallbackURI)
|
xlog.Info("Trying development fallback URI", "fallback", devFallbackURI)
|
||||||
if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus); err == nil {
|
if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus, downloadOpts...); err == nil {
|
||||||
xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath)
|
xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath)
|
||||||
success = true
|
success = true
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -117,13 +117,13 @@ var _ = Describe("Gallery Backends", func() {
|
|||||||
|
|
||||||
Describe("InstallBackendFromGallery", func() {
|
Describe("InstallBackendFromGallery", func() {
|
||||||
It("should return error when backend is not found", func() {
|
It("should return error when backend is not found", func() {
|
||||||
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true)
|
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true, false)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
|
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should install backend from gallery", func() {
|
It("should install backend from gallery", func() {
|
||||||
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true)
|
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true, false)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
|
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
|
||||||
})
|
})
|
||||||
@@ -545,7 +545,7 @@ var _ = Describe("Gallery Backends", func() {
|
|||||||
VRAM: 1000000000000,
|
VRAM: 1000000000000,
|
||||||
Backend: system.Backend{BackendsPath: tempDir},
|
Backend: system.Backend{BackendsPath: tempDir},
|
||||||
}
|
}
|
||||||
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true, false)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
||||||
@@ -625,7 +625,7 @@ var _ = Describe("Gallery Backends", func() {
|
|||||||
VRAM: 1000000000000,
|
VRAM: 1000000000000,
|
||||||
Backend: system.Backend{BackendsPath: tempDir},
|
Backend: system.Backend{BackendsPath: tempDir},
|
||||||
}
|
}
|
||||||
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true, false)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
||||||
@@ -709,7 +709,7 @@ var _ = Describe("Gallery Backends", func() {
|
|||||||
VRAM: 1000000000000,
|
VRAM: 1000000000000,
|
||||||
Backend: system.Backend{BackendsPath: tempDir},
|
Backend: system.Backend{BackendsPath: tempDir},
|
||||||
}
|
}
|
||||||
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true, false)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
||||||
@@ -808,7 +808,7 @@ var _ = Describe("Gallery Backends", func() {
|
|||||||
system.WithBackendPath(newPath),
|
system.WithBackendPath(newPath),
|
||||||
)
|
)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
|
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil, false)
|
||||||
Expect(newPath).To(BeADirectory())
|
Expect(newPath).To(BeADirectory())
|
||||||
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
|
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
|
||||||
})
|
})
|
||||||
@@ -840,7 +840,7 @@ var _ = Describe("Gallery Backends", func() {
|
|||||||
system.WithBackendPath(tempDir),
|
system.WithBackendPath(tempDir),
|
||||||
)
|
)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
|
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil, false)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
||||||
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
|
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
|
||||||
@@ -873,7 +873,7 @@ var _ = Describe("Gallery Backends", func() {
|
|||||||
|
|
||||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
|
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
|
||||||
|
|
||||||
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
|
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil, false)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
||||||
})
|
})
|
||||||
@@ -894,7 +894,7 @@ var _ = Describe("Gallery Backends", func() {
|
|||||||
system.WithBackendPath(tempDir),
|
system.WithBackendPath(tempDir),
|
||||||
)
|
)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
|
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil, false)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ var _ = Describe("Backend versioning", func() {
|
|||||||
backend.URI = srcDir
|
backend.URI = srcDir
|
||||||
backend.Version = "1.2.3"
|
backend.Version = "1.2.3"
|
||||||
|
|
||||||
err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil)
|
err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil, false)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
// Read the metadata file and check version
|
// Read the metadata file and check version
|
||||||
@@ -74,7 +74,7 @@ var _ = Describe("Backend versioning", func() {
|
|||||||
backend.URI = srcDir
|
backend.URI = srcDir
|
||||||
backend.Version = "2.0.0"
|
backend.Version = "2.0.0"
|
||||||
|
|
||||||
err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil)
|
err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil, false)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
metadataPath := filepath.Join(tempDir, "test-backend-uri", "metadata.json")
|
metadataPath := filepath.Join(tempDir, "test-backend-uri", "metadata.json")
|
||||||
@@ -100,7 +100,7 @@ var _ = Describe("Backend versioning", func() {
|
|||||||
backend.URI = srcDir
|
backend.URI = srcDir
|
||||||
// Version intentionally left empty
|
// Version intentionally left empty
|
||||||
|
|
||||||
err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil)
|
err = gallery.InstallBackend(context.Background(), systemState, modelLoader, backend, nil, false)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
metadataPath := filepath.Join(tempDir, "test-backend-noversion", "metadata.json")
|
metadataPath := filepath.Join(tempDir, "test-backend-noversion", "metadata.json")
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
package importers
|
package importers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gguf "github.com/gpustack/gguf-parser-go"
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/gallery"
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
@@ -261,6 +264,13 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
|||||||
// Apply per-model-family inference parameter defaults
|
// Apply per-model-family inference parameter defaults
|
||||||
config.ApplyInferenceDefaults(&modelConfig, details.URI)
|
config.ApplyInferenceDefaults(&modelConfig, details.URI)
|
||||||
|
|
||||||
|
// Auto-detect Multi-Token Prediction heads (ggml-org/llama.cpp#22673) and
|
||||||
|
// enable speculative decoding. Mirrors the load-time hook so freshly
|
||||||
|
// imported configs already carry spec_type:draft-mtp before the model is
|
||||||
|
// ever loaded - users see it in the YAML preview rather than discovering
|
||||||
|
// it after the first start.
|
||||||
|
maybeApplyMTPDefaults(&modelConfig, details, &cfg)
|
||||||
|
|
||||||
data, err := yaml.Marshal(modelConfig)
|
data, err := yaml.Marshal(modelConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return gallery.ModelConfig{}, err
|
return gallery.ModelConfig{}, err
|
||||||
@@ -291,6 +301,85 @@ func pickPreferredGroup(groups []hfapi.ShardGroup, prefs []string) *hfapi.ShardG
|
|||||||
return &groups[len(groups)-1]
|
return &groups[len(groups)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// maybeApplyMTPDefaults parses the picked GGUF header (range-fetched over
|
||||||
|
// HTTP for HF/URL imports) and, if the file declares a Multi-Token Prediction
|
||||||
|
// head, appends the auto-MTP option keys to modelConfig.Options. Failures
|
||||||
|
// during the probe are non-fatal: the importer keeps the config without MTP
|
||||||
|
// so an unrelated network blip or weird header doesn't break the import.
|
||||||
|
//
|
||||||
|
// OCI/Ollama URIs are skipped because the artifact isn't directly fetchable
|
||||||
|
// as a GGUF byte stream - the load-time hook (core/config/gguf.go) covers
|
||||||
|
// those once the model is materialised on disk.
|
||||||
|
func maybeApplyMTPDefaults(modelConfig *config.ModelConfig, details Details, cfg *gallery.ModelConfig) {
|
||||||
|
probeURL := pickMTPProbeURL(details, cfg)
|
||||||
|
if probeURL == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
xlog.Debug("[mtp-importer] panic while probing GGUF header", "uri", probeURL, "recover", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
f, err := gguf.ParseGGUFFileRemote(ctx, probeURL)
|
||||||
|
if err != nil {
|
||||||
|
xlog.Debug("[mtp-importer] failed to read remote GGUF header for MTP detection", "uri", probeURL, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n, ok := config.HasEmbeddedMTPHead(f)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config.ApplyMTPDefaults(modelConfig, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pickMTPProbeURL returns an HTTP(S) URL pointing at the main (non-mmproj)
|
||||||
|
// GGUF shard that should be inspected for an MTP head, or "" when no
|
||||||
|
// suitable URL is available. Custom URI schemes (`huggingface://`,
|
||||||
|
// `ollama://`, etc.) are run through `downloader.URI.ResolveURL` so the
|
||||||
|
// resulting URL is something `gguf.ParseGGUFFileRemote` can actually open.
|
||||||
|
// OCI/Ollama URIs are skipped because the artifact is not directly
|
||||||
|
// streamable as a GGUF byte range.
|
||||||
|
func pickMTPProbeURL(details Details, cfg *gallery.ModelConfig) string {
|
||||||
|
uri := downloader.URI(details.URI)
|
||||||
|
|
||||||
|
if uri.LooksLikeOCI() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasSuffix(strings.ToLower(details.URI), ".gguf") {
|
||||||
|
return resolveHTTPProbe(details.URI)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range cfg.Files {
|
||||||
|
lower := strings.ToLower(f.Filename)
|
||||||
|
if strings.Contains(lower, "mmproj") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(lower, ".gguf") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return resolveHTTPProbe(f.URI)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveHTTPProbe resolves an importer-side URI to the HTTP(S) URL that
|
||||||
|
// `gguf.ParseGGUFFileRemote` can range-fetch. Returns "" if the URI can't
|
||||||
|
// be reduced to an HTTP(S) endpoint (e.g. local path, unsupported scheme).
|
||||||
|
func resolveHTTPProbe(uri string) string {
|
||||||
|
resolved := downloader.URI(uri).ResolveURL()
|
||||||
|
if downloader.URI(resolved).LooksLikeHTTPURL() {
|
||||||
|
return resolved
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// appendShardGroup copies every shard of group into cfg.Files under dest,
|
// appendShardGroup copies every shard of group into cfg.Files under dest,
|
||||||
// skipping any entry whose target filename is already present so repeated
|
// skipping any entry whose target filename is already present so repeated
|
||||||
// calls (e.g. the rare case of mmproj + model picking the same group)
|
// calls (e.g. the rare case of mmproj + model picking the same group)
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ func InstallModelFromGallery(
|
|||||||
modelGalleries, backendGalleries []lconfig.Gallery,
|
modelGalleries, backendGalleries []lconfig.Gallery,
|
||||||
systemState *system.SystemState,
|
systemState *system.SystemState,
|
||||||
modelLoader *model.ModelLoader,
|
modelLoader *model.ModelLoader,
|
||||||
name string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error {
|
name string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend, requireBackendIntegrity bool) error {
|
||||||
|
|
||||||
applyModel := func(model *GalleryModel) error {
|
applyModel := func(model *GalleryModel) error {
|
||||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||||
@@ -137,7 +137,7 @@ func InstallModelFromGallery(
|
|||||||
if automaticallyInstallBackend && installedModel.Backend != "" {
|
if automaticallyInstallBackend && installedModel.Backend != "" {
|
||||||
xlog.Debug("Installing backend", "backend", installedModel.Backend)
|
xlog.Debug("Installing backend", "backend", installedModel.Backend)
|
||||||
|
|
||||||
if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
|
if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false, requireBackendIntegrity); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ var _ = Describe("Model test", func() {
|
|||||||
Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
|
Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
|
||||||
Expect(models[0].Installed).To(BeFalse())
|
Expect(models[0].Installed).To(BeFalse())
|
||||||
|
|
||||||
err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
|
err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true, false)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
|
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
|
||||||
|
|||||||
@@ -232,7 +232,7 @@ func summarizeNodeDrift(nodes []NodeBackendRef) (majority struct{ version, diges
|
|||||||
|
|
||||||
// UpgradeBackend upgrades a single backend to the latest gallery version using
|
// UpgradeBackend upgrades a single backend to the latest gallery version using
|
||||||
// an atomic swap with backup-based rollback on failure.
|
// an atomic swap with backup-based rollback on failure.
|
||||||
func UpgradeBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, galleries []config.Gallery, backendName string, downloadStatus func(string, string, string, float64)) error {
|
func UpgradeBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, galleries []config.Gallery, backendName string, downloadStatus func(string, string, string, float64), requireIntegrity bool) error {
|
||||||
// Look up the installed backend
|
// Look up the installed backend
|
||||||
installedBackends, err := ListSystemBackends(systemState)
|
installedBackends, err := ListSystemBackends(systemState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -251,7 +251,7 @@ func UpgradeBackend(ctx context.Context, systemState *system.SystemState, modelL
|
|||||||
// If this is a meta backend, recursively upgrade the concrete backend it points to
|
// If this is a meta backend, recursively upgrade the concrete backend it points to
|
||||||
if installed.Metadata != nil && installed.Metadata.MetaBackendFor != "" {
|
if installed.Metadata != nil && installed.Metadata.MetaBackendFor != "" {
|
||||||
xlog.Info("Meta backend detected, upgrading concrete backend", "meta", backendName, "concrete", installed.Metadata.MetaBackendFor)
|
xlog.Info("Meta backend detected, upgrading concrete backend", "meta", backendName, "concrete", installed.Metadata.MetaBackendFor)
|
||||||
return UpgradeBackend(ctx, systemState, modelLoader, galleries, installed.Metadata.MetaBackendFor, downloadStatus)
|
return UpgradeBackend(ctx, systemState, modelLoader, galleries, installed.Metadata.MetaBackendFor, downloadStatus, requireIntegrity)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the gallery entry
|
// Find the gallery entry
|
||||||
@@ -265,6 +265,16 @@ func UpgradeBackend(ctx context.Context, systemState *system.SystemState, modelL
|
|||||||
return fmt.Errorf("no gallery entry found for backend %q", backendName)
|
return fmt.Errorf("no gallery entry found for backend %q", backendName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resolve integrity options (cosign verifier for OCI URIs, strict-mode
|
||||||
|
// gate for missing SHA256/policy) BEFORE writing anything to disk.
|
||||||
|
// Without this, the upgrade path would atomically swap in an
|
||||||
|
// unverified backend even when the gallery has a verification policy
|
||||||
|
// — see backendDownloadOptions in backends.go.
|
||||||
|
downloadOpts, err := backendDownloadOptions(galleryEntry, requireIntegrity)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("upgrade %q: %w", backendName, err)
|
||||||
|
}
|
||||||
|
|
||||||
backendPath := filepath.Join(systemState.Backend.BackendsPath, backendName)
|
backendPath := filepath.Join(systemState.Backend.BackendsPath, backendName)
|
||||||
tmpPath := backendPath + ".upgrade-tmp"
|
tmpPath := backendPath + ".upgrade-tmp"
|
||||||
backupPath := backendPath + ".backup"
|
backupPath := backendPath + ".backup"
|
||||||
@@ -285,7 +295,7 @@ func UpgradeBackend(ctx context.Context, systemState *system.SystemState, modelL
|
|||||||
return fmt.Errorf("failed to copy backend from directory: %w", err)
|
return fmt.Errorf("failed to copy backend from directory: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := uri.DownloadFileWithContext(ctx, tmpPath, "", 1, 1, downloadStatus); err != nil {
|
if err := uri.DownloadFileWithContext(ctx, tmpPath, galleryEntry.SHA256, 1, 1, downloadStatus, downloadOpts...); err != nil {
|
||||||
os.RemoveAll(tmpPath)
|
os.RemoveAll(tmpPath)
|
||||||
return fmt.Errorf("failed to download backend: %w", err)
|
return fmt.Errorf("failed to download backend: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -383,7 +383,7 @@ var _ = Describe("Upgrade Detection and Execution", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ml := model.NewModelLoader(systemState)
|
ml := model.NewModelLoader(systemState)
|
||||||
err := UpgradeBackend(context.Background(), systemState, ml, galleries, "my-backend", nil)
|
err := UpgradeBackend(context.Background(), systemState, ml, galleries, "my-backend", nil, false)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
// Verify run.sh was updated
|
// Verify run.sh was updated
|
||||||
@@ -417,7 +417,7 @@ var _ = Describe("Upgrade Detection and Execution", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ml := model.NewModelLoader(systemState)
|
ml := model.NewModelLoader(systemState)
|
||||||
err := UpgradeBackend(context.Background(), systemState, ml, galleries, "my-backend", nil)
|
err := UpgradeBackend(context.Background(), systemState, ml, galleries, "my-backend", nil, false)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
|
|
||||||
// Verify v1 is still intact
|
// Verify v1 is still intact
|
||||||
@@ -432,5 +432,41 @@ var _ = Describe("Upgrade Detection and Execution", func() {
|
|||||||
Expect(json.Unmarshal(metaData, &meta)).To(Succeed())
|
Expect(json.Unmarshal(metaData, &meta)).To(Succeed())
|
||||||
Expect(meta.Version).To(Equal("1.0.0"))
|
Expect(meta.Version).To(Equal("1.0.0"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Regression: an earlier version of UpgradeBackend wrote the
|
||||||
|
// downloaded bytes to disk without going through
|
||||||
|
// backendDownloadOptions, so the gallery's verification policy
|
||||||
|
// (and strict-integrity gate) didn't apply on upgrade. This test
|
||||||
|
// pins the upgrade path to the same integrity gate as installs:
|
||||||
|
// strict mode + an OCI URI without a verification: block must
|
||||||
|
// hard-fail *before* anything is downloaded or swapped in.
|
||||||
|
It("should refuse to upgrade an OCI backend that bypasses integrity in strict mode", func() {
|
||||||
|
installBackendWithVersion("my-backend", "1.0.0", "#!/bin/sh\necho v1")
|
||||||
|
|
||||||
|
// OCI URI, no Gallery.Verification → backendDownloadOptions
|
||||||
|
// returns a strict-integrity error before any network call.
|
||||||
|
writeGalleryYAML([]GalleryBackend{
|
||||||
|
{
|
||||||
|
Metadata: Metadata{
|
||||||
|
Name: "my-backend",
|
||||||
|
},
|
||||||
|
URI: "oci://example.invalid/missing:never-fetched",
|
||||||
|
Version: "2.0.0",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ml := model.NewModelLoader(systemState)
|
||||||
|
err := UpgradeBackend(context.Background(), systemState, ml, galleries, "my-backend", nil, true)
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring("strict integrity"))
|
||||||
|
|
||||||
|
// The installed v1 must be untouched — the upgrade should
|
||||||
|
// have aborted before writing anything.
|
||||||
|
content, err := os.ReadFile(filepath.Join(backendsPath, "my-backend", "run.sh"))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(string(content)).To(Equal("#!/bin/sh\necho v1"))
|
||||||
|
Expect(filepath.Join(backendsPath, "my-backend.upgrade-tmp")).NotTo(BeAnExistingFile())
|
||||||
|
Expect(filepath.Join(backendsPath, "my-backend.backup")).NotTo(BeAnExistingFile())
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import (
|
|||||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||||
"github.com/mudler/LocalAI/core/services/nodes"
|
"github.com/mudler/LocalAI/core/services/nodes"
|
||||||
"github.com/mudler/LocalAI/core/services/quantization"
|
"github.com/mudler/LocalAI/core/services/quantization"
|
||||||
|
"github.com/mudler/LocalAI/pkg/signals"
|
||||||
|
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
)
|
)
|
||||||
@@ -267,9 +268,12 @@ func API(application *application.Application) (*echo.Echo, error) {
|
|||||||
e.Static("/generated-videos", videoPath)
|
e.Static("/generated-videos", videoPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize usage recording when auth DB is available
|
// Initialize usage recording when auth DB is available, and ensure the
|
||||||
|
// batcher drains its in-memory queue on graceful shutdown so the last
|
||||||
|
// few seconds of usage don't disappear when the process exits.
|
||||||
if application.AuthDB() != nil {
|
if application.AuthDB() != nil {
|
||||||
httpMiddleware.InitUsageRecorder(application.AuthDB())
|
httpMiddleware.InitUsageRecorder(application.AuthDB())
|
||||||
|
signals.RegisterGracefulTerminationHandler(httpMiddleware.ShutdownUsageRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Auth is applied to _all_ endpoints. Filtering out endpoints to bypass is
|
// Auth is applied to _all_ endpoints. Filtering out endpoints to bypass is
|
||||||
@@ -403,7 +407,7 @@ func API(application *application.Application) (*echo.Echo, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
|
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
|
||||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
|
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
|
||||||
|
|
||||||
// Distributed SSE routes (job progress + agent events via NATS)
|
// Distributed SSE routes (job progress + agent events via NATS)
|
||||||
if d := application.Distributed(); d != nil {
|
if d := application.Distributed(); d != nil {
|
||||||
|
|||||||
@@ -38,9 +38,15 @@ func InitDB(databaseURL string) (*gorm.DB, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Backfill: users created before the provider column existed have an empty
|
// Backfill: users created before the provider column existed have an empty
|
||||||
// provider — treat them as local accounts so the UI can identify them.
|
// provider - treat them as local accounts so the UI can identify them.
|
||||||
db.Exec("UPDATE users SET provider = ? WHERE provider = '' OR provider IS NULL", ProviderLocal)
|
db.Exec("UPDATE users SET provider = ? WHERE provider = '' OR provider IS NULL", ProviderLocal)
|
||||||
|
|
||||||
|
// Backfill: pre-feature usage_records have no source column. Classify them so the
|
||||||
|
// new per-source aggregators include them.
|
||||||
|
if err := BackfillUsageSource(db); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to backfill usage source: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Create composite index on users(provider, subject) for fast OAuth lookups
|
// Create composite index on users(provider, subject) for fast OAuth lookups
|
||||||
if err := db.Exec("CREATE INDEX IF NOT EXISTS idx_users_provider_subject ON users(provider, subject)").Error; err != nil {
|
if err := db.Exec("CREATE INDEX IF NOT EXISTS idx_users_provider_subject ON users(provider, subject)").Error; err != nil {
|
||||||
// Ignore error on postgres if index already exists
|
// Ignore error on postgres if index already exists
|
||||||
|
|||||||
@@ -16,8 +16,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
contextKeyUser = "auth_user"
|
contextKeyUser = "auth_user"
|
||||||
contextKeyRole = "auth_role"
|
contextKeyRole = "auth_role"
|
||||||
|
contextKeyAPIKey = "auth_apikey"
|
||||||
|
contextKeySource = "auth_source"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Middleware returns an Echo middleware that handles authentication.
|
// Middleware returns an Echo middleware that handles authentication.
|
||||||
@@ -75,6 +77,7 @@ func Middleware(db *gorm.DB, appConfig *config.ApplicationConfig) echo.Middlewar
|
|||||||
}
|
}
|
||||||
c.Set(contextKeyUser, syntheticUser)
|
c.Set(contextKeyUser, syntheticUser)
|
||||||
c.Set(contextKeyRole, RoleAdmin)
|
c.Set(contextKeyRole, RoleAdmin)
|
||||||
|
c.Set(contextKeySource, UsageSourceLegacy)
|
||||||
authenticated = true
|
authenticated = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -213,6 +216,20 @@ func GetUserRole(c echo.Context) string {
|
|||||||
return role
|
return role
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAPIKey returns the resolved API key from the echo context, or nil.
|
||||||
|
// Nil for session-cookie and legacy-env-key authentication.
|
||||||
|
func GetAPIKey(c echo.Context) *UserAPIKey {
|
||||||
|
k, _ := c.Get(contextKeyAPIKey).(*UserAPIKey)
|
||||||
|
return k
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSource returns the request's authentication source: UsageSourceAPIKey,
|
||||||
|
// UsageSourceWeb, UsageSourceLegacy, or empty if no authentication was performed.
|
||||||
|
func GetSource(c echo.Context) string {
|
||||||
|
s, _ := c.Get(contextKeySource).(string)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
// RequireRouteFeature returns a global middleware that checks the user has access
|
// RequireRouteFeature returns a global middleware that checks the user has access
|
||||||
// to the feature required by the matched route. It uses the RouteFeatureRegistry
|
// to the feature required by the matched route. It uses the RouteFeatureRegistry
|
||||||
// to look up the required feature for each route pattern + HTTP method.
|
// to look up the required feature for each route pattern + HTTP method.
|
||||||
@@ -421,47 +438,67 @@ func RequireQuota(db *gorm.DB) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// tryAuthenticate attempts to authenticate the request using the database.
|
// tryAuthenticate attempts to authenticate the request using the database.
|
||||||
|
//
|
||||||
|
// On success it returns the user and, as a side effect, sets the following
|
||||||
|
// values on the Echo context:
|
||||||
|
// - contextKeySource ("auth_source"): always set, one of UsageSourceWeb /
|
||||||
|
// UsageSourceAPIKey. UsageSourceLegacy is set elsewhere by the parent
|
||||||
|
// Middleware when a legacy env key matches.
|
||||||
|
// - contextKeyAPIKey ("auth_apikey"): set to the resolved *UserAPIKey for
|
||||||
|
// named-key branches (Bearer, x-api-key, xi-api-key, token cookie).
|
||||||
|
// - "_auth_session": session record, used by Middleware to drive cookie
|
||||||
|
// rotation. Only set on the session-cookie branch.
|
||||||
|
//
|
||||||
|
// contextKeyUser and contextKeyRole are populated by the parent Middleware
|
||||||
|
// after this function returns.
|
||||||
func tryAuthenticate(c echo.Context, db *gorm.DB, appConfig *config.ApplicationConfig) *User {
|
func tryAuthenticate(c echo.Context, db *gorm.DB, appConfig *config.ApplicationConfig) *User {
|
||||||
hmacSecret := appConfig.Auth.APIKeyHMACSecret
|
hmacSecret := appConfig.Auth.APIKeyHMACSecret
|
||||||
|
|
||||||
// a. Session cookie
|
// a. Session cookie -> web UI
|
||||||
if cookie, err := c.Cookie(sessionCookie); err == nil && cookie.Value != "" {
|
if cookie, err := c.Cookie(sessionCookie); err == nil && cookie.Value != "" {
|
||||||
if user, session := ValidateSession(db, cookie.Value, hmacSecret); user != nil {
|
if user, session := ValidateSession(db, cookie.Value, hmacSecret); user != nil {
|
||||||
// Store session for rotation check in middleware
|
// Store session for rotation check in middleware
|
||||||
c.Set("_auth_session", session)
|
c.Set("_auth_session", session)
|
||||||
|
c.Set(contextKeySource, UsageSourceWeb)
|
||||||
return user
|
return user
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// b. Authorization: Bearer token
|
// b. Authorization: Bearer
|
||||||
authHeader := c.Request().Header.Get("Authorization")
|
authHeader := c.Request().Header.Get("Authorization")
|
||||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
|
|
||||||
// Try as session ID first
|
// b1. Session token via Bearer -> still web UI
|
||||||
if user, _ := ValidateSession(db, token, hmacSecret); user != nil {
|
if user, _ := ValidateSession(db, token, hmacSecret); user != nil {
|
||||||
|
c.Set(contextKeySource, UsageSourceWeb)
|
||||||
return user
|
return user
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try as user API key
|
// b2. Named API key
|
||||||
if key, err := ValidateAPIKey(db, token, hmacSecret); err == nil {
|
if key, err := ValidateAPIKey(db, token, hmacSecret); err == nil {
|
||||||
|
c.Set(contextKeySource, UsageSourceAPIKey)
|
||||||
|
c.Set(contextKeyAPIKey, key)
|
||||||
return &key.User
|
return &key.User
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// c. x-api-key / xi-api-key headers
|
// c. x-api-key / xi-api-key -> named API key
|
||||||
for _, header := range []string{"x-api-key", "xi-api-key"} {
|
for _, header := range []string{"x-api-key", "xi-api-key"} {
|
||||||
if key := c.Request().Header.Get(header); key != "" {
|
if k := c.Request().Header.Get(header); k != "" {
|
||||||
if apiKey, err := ValidateAPIKey(db, key, hmacSecret); err == nil {
|
if apiKey, err := ValidateAPIKey(db, k, hmacSecret); err == nil {
|
||||||
|
c.Set(contextKeySource, UsageSourceAPIKey)
|
||||||
|
c.Set(contextKeyAPIKey, apiKey)
|
||||||
return &apiKey.User
|
return &apiKey.User
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// d. token cookie (legacy)
|
// d. token cookie -> named API key
|
||||||
if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" {
|
if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" {
|
||||||
// Try as user API key
|
|
||||||
if key, err := ValidateAPIKey(db, cookie.Value, hmacSecret); err == nil {
|
if key, err := ValidateAPIKey(db, cookie.Value, hmacSecret); err == nil {
|
||||||
|
c.Set(contextKeySource, UsageSourceAPIKey)
|
||||||
|
c.Set(contextKeyAPIKey, key)
|
||||||
return &key.User
|
return &key.User
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -303,4 +303,122 @@ var _ = Describe("Auth Middleware", func() {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Describe("auth context plumbing for usage source", func() {
|
||||||
|
// probeApp builds a minimal echo app with the auth middleware and a single
|
||||||
|
// "/probe" route that captures the user, source, and apikey from context.
|
||||||
|
type probe struct {
|
||||||
|
user *auth.User
|
||||||
|
source string
|
||||||
|
key *auth.UserAPIKey
|
||||||
|
}
|
||||||
|
probeApp := func(db *gorm.DB, appConfig *config.ApplicationConfig, p *probe) *echo.Echo {
|
||||||
|
e := echo.New()
|
||||||
|
e.Use(auth.Middleware(db, appConfig))
|
||||||
|
e.GET("/probe", func(c echo.Context) error {
|
||||||
|
p.user = auth.GetUser(c)
|
||||||
|
p.source = auth.GetSource(c)
|
||||||
|
p.key = auth.GetAPIKey(c)
|
||||||
|
return c.NoContent(http.StatusOK)
|
||||||
|
})
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
It("session cookie sets source=web, apikey=nil", func() {
|
||||||
|
db := testDB()
|
||||||
|
appConfig := config.NewApplicationConfig()
|
||||||
|
user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
|
||||||
|
token := createTestSession(db, user.ID)
|
||||||
|
|
||||||
|
var p probe
|
||||||
|
app := probeApp(db, appConfig, &p)
|
||||||
|
rec := doRequest(app, http.MethodGet, "/probe", withSessionCookie(token))
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
Expect(p.user).ToNot(BeNil())
|
||||||
|
Expect(p.user.ID).To(Equal(user.ID))
|
||||||
|
Expect(p.source).To(Equal(auth.UsageSourceWeb))
|
||||||
|
Expect(p.key).To(BeNil())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("Bearer session token sets source=web, apikey=nil", func() {
|
||||||
|
db := testDB()
|
||||||
|
appConfig := config.NewApplicationConfig()
|
||||||
|
user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
|
||||||
|
token := createTestSession(db, user.ID)
|
||||||
|
|
||||||
|
var p probe
|
||||||
|
app := probeApp(db, appConfig, &p)
|
||||||
|
rec := doRequest(app, http.MethodGet, "/probe", withBearerToken(token))
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
Expect(p.user).ToNot(BeNil())
|
||||||
|
Expect(p.user.ID).To(Equal(user.ID))
|
||||||
|
Expect(p.source).To(Equal(auth.UsageSourceWeb))
|
||||||
|
Expect(p.key).To(BeNil())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("Bearer API key sets source=apikey and exposes the resolved *UserAPIKey", func() {
|
||||||
|
db := testDB()
|
||||||
|
appConfig := config.NewApplicationConfig()
|
||||||
|
user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
|
||||||
|
plaintext, key, err := auth.CreateAPIKey(db, user.ID, "ci", auth.RoleUser, appConfig.Auth.APIKeyHMACSecret, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
var p probe
|
||||||
|
app := probeApp(db, appConfig, &p)
|
||||||
|
rec := doRequest(app, http.MethodGet, "/probe", withBearerToken(plaintext))
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
Expect(p.source).To(Equal(auth.UsageSourceAPIKey))
|
||||||
|
Expect(p.key).ToNot(BeNil())
|
||||||
|
Expect(p.key.ID).To(Equal(key.ID))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("x-api-key header sets source=apikey", func() {
|
||||||
|
db := testDB()
|
||||||
|
appConfig := config.NewApplicationConfig()
|
||||||
|
user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
|
||||||
|
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "ci", auth.RoleUser, appConfig.Auth.APIKeyHMACSecret, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
var p probe
|
||||||
|
app := probeApp(db, appConfig, &p)
|
||||||
|
rec := doRequest(app, http.MethodGet, "/probe", withXApiKey(plaintext))
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
Expect(p.source).To(Equal(auth.UsageSourceAPIKey))
|
||||||
|
Expect(p.key).ToNot(BeNil())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("token cookie sets source=apikey", func() {
|
||||||
|
db := testDB()
|
||||||
|
appConfig := config.NewApplicationConfig()
|
||||||
|
user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
|
||||||
|
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "ci", auth.RoleUser, appConfig.Auth.APIKeyHMACSecret, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
var p probe
|
||||||
|
app := probeApp(db, appConfig, &p)
|
||||||
|
rec := doRequest(app, http.MethodGet, "/probe", withTokenCookie(plaintext))
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
Expect(p.source).To(Equal(auth.UsageSourceAPIKey))
|
||||||
|
Expect(p.key).ToNot(BeNil())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("legacy env key sets source=legacy, apikey=nil", func() {
|
||||||
|
db := testDB()
|
||||||
|
appConfig := config.NewApplicationConfig()
|
||||||
|
appConfig.ApiKeys = []string{"legacy-secret"}
|
||||||
|
|
||||||
|
var p probe
|
||||||
|
app := probeApp(db, appConfig, &p)
|
||||||
|
rec := doRequest(app, http.MethodGet, "/probe", withBearerToken("legacy-secret"))
|
||||||
|
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||||
|
Expect(p.source).To(Equal(auth.UsageSourceLegacy))
|
||||||
|
Expect(p.key).To(BeNil())
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -5,14 +5,31 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mudler/xlog"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Source classification for a UsageRecord.
|
||||||
|
const (
|
||||||
|
UsageSourceAPIKey = "apikey" // request authenticated with a named UserAPIKey
|
||||||
|
UsageSourceWeb = "web" // request authenticated with a session cookie (web UI)
|
||||||
|
UsageSourceLegacy = "legacy" // request authenticated with an env-configured legacy key
|
||||||
|
)
|
||||||
|
|
||||||
// UsageRecord represents a single API request's token usage.
|
// UsageRecord represents a single API request's token usage.
|
||||||
type UsageRecord struct {
|
type UsageRecord struct {
|
||||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||||
UserID string `gorm:"size:36;index:idx_usage_user_time"`
|
UserID string `gorm:"size:36;index:idx_usage_user_time"`
|
||||||
UserName string `gorm:"size:255"`
|
UserName string `gorm:"size:255"`
|
||||||
|
|
||||||
|
// Source classifies how the request authenticated. One of UsageSource* constants.
|
||||||
|
// Empty for pre-feature rows until the InitDB backfill runs.
|
||||||
|
Source string `gorm:"size:16;index:idx_usage_source"`
|
||||||
|
// APIKeyID is the UserAPIKey.ID when Source == UsageSourceAPIKey. Nil otherwise.
|
||||||
|
APIKeyID *string `gorm:"size:36;index:idx_usage_apikey"`
|
||||||
|
// APIKeyName is a snapshot of UserAPIKey.Name at write time. Survives key deletion.
|
||||||
|
APIKeyName string `gorm:"size:255"`
|
||||||
|
|
||||||
Model string `gorm:"size:255;index"`
|
Model string `gorm:"size:255;index"`
|
||||||
Endpoint string `gorm:"size:255"`
|
Endpoint string `gorm:"size:255"`
|
||||||
PromptTokens int64
|
PromptTokens int64
|
||||||
@@ -30,9 +47,12 @@ func RecordUsage(db *gorm.DB, record *UsageRecord) error {
|
|||||||
// UsageBucket is an aggregated time bucket for the dashboard.
|
// UsageBucket is an aggregated time bucket for the dashboard.
|
||||||
type UsageBucket struct {
|
type UsageBucket struct {
|
||||||
Bucket string `json:"bucket"`
|
Bucket string `json:"bucket"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model,omitempty"`
|
||||||
UserID string `json:"user_id,omitempty"`
|
UserID string `json:"user_id,omitempty"`
|
||||||
UserName string `json:"user_name,omitempty"`
|
UserName string `json:"user_name,omitempty"`
|
||||||
|
Source string `json:"source,omitempty"`
|
||||||
|
APIKeyID string `json:"api_key_id,omitempty"`
|
||||||
|
APIKeyName string `json:"api_key_name,omitempty"`
|
||||||
PromptTokens int64 `json:"prompt_tokens"`
|
PromptTokens int64 `json:"prompt_tokens"`
|
||||||
CompletionTokens int64 `json:"completion_tokens"`
|
CompletionTokens int64 `json:"completion_tokens"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
@@ -119,6 +139,28 @@ func GetUserUsage(db *gorm.DB, userID, period string) ([]UsageBucket, error) {
|
|||||||
return buckets, nil
|
return buckets, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BackfillUsageSource sets the Source column on pre-feature usage rows.
|
||||||
|
// Idempotent: only touches rows where source is NULL or empty.
|
||||||
|
// - rows whose user_id == "legacy-api-key" -> UsageSourceLegacy
|
||||||
|
// - everything else -> UsageSourceWeb
|
||||||
|
func BackfillUsageSource(db *gorm.DB) error {
|
||||||
|
// Legacy first (more specific predicate)
|
||||||
|
if err := db.Exec(
|
||||||
|
`UPDATE usage_records SET source = ? WHERE (source IS NULL OR source = '') AND user_id = ?`,
|
||||||
|
UsageSourceLegacy, "legacy-api-key",
|
||||||
|
).Error; err != nil {
|
||||||
|
return fmt.Errorf("backfill legacy usage source: %w", err)
|
||||||
|
}
|
||||||
|
// Everything else -> web
|
||||||
|
if err := db.Exec(
|
||||||
|
`UPDATE usage_records SET source = ? WHERE (source IS NULL OR source = '')`,
|
||||||
|
UsageSourceWeb,
|
||||||
|
).Error; err != nil {
|
||||||
|
return fmt.Errorf("backfill web usage source: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetAllUsage returns aggregated usage for all users (admin). Optional userID filter.
|
// GetAllUsage returns aggregated usage for all users (admin). Optional userID filter.
|
||||||
func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
|
func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
|
||||||
sqlite := isSQLiteDB(db)
|
sqlite := isSQLiteDB(db)
|
||||||
@@ -149,3 +191,257 @@ func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
|
|||||||
}
|
}
|
||||||
return buckets, nil
|
return buckets, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TotalsEntry is a token+request roll-up.
|
||||||
|
type TotalsEntry struct {
|
||||||
|
Tokens int64 `json:"tokens"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyTotal is the per-key roll-up returned by sources endpoints. UserID and
|
||||||
|
// UserName are snapshotted from the UsageRecord so revoked-and-deleted keys
|
||||||
|
// still carry their owner attribution in admin views.
|
||||||
|
type KeyTotal struct {
|
||||||
|
APIKeyID string `json:"api_key_id"`
|
||||||
|
APIKeyName string `json:"api_key_name"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
UserName string `json:"user_name"`
|
||||||
|
Tokens int64 `json:"tokens"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
LastUsed time.Time `json:"last_used"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserSourceTotal is a per-(user, source) roll-up for sources that don't carry
|
||||||
|
// a named API key identity (web, legacy). It exists so admin views can show
|
||||||
|
// which user generated each block of Web UI / legacy traffic; the per-apikey
|
||||||
|
// breakdown for source=apikey already lives in KeyTotal.
|
||||||
|
type UserSourceTotal struct {
|
||||||
|
Source string `json:"source"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
UserName string `json:"user_name"`
|
||||||
|
Tokens int64 `json:"tokens"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SourceTotals summarises a per-source breakdown.
|
||||||
|
type SourceTotals struct {
|
||||||
|
BySource map[string]TotalsEntry `json:"by_source"`
|
||||||
|
ByKey []KeyTotal `json:"by_key"` // server-sorted desc by tokens, capped
|
||||||
|
ByUserSource []UserSourceTotal `json:"by_user_source,omitempty"` // populated only when includeLegacy=true
|
||||||
|
GrandTotal TotalsEntry `json:"grand_total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxKeyTotals = 200
|
||||||
|
|
||||||
|
// GetUserUsageBySource returns per-source aggregated usage for one user. Legacy
|
||||||
|
// is excluded by design (visible to admins only via the admin variant).
|
||||||
|
func GetUserUsageBySource(db *gorm.DB, userID, period string) ([]UsageBucket, SourceTotals, error) {
|
||||||
|
sqlite := isSQLiteDB(db)
|
||||||
|
since, dateFmt := periodToWindow(period, sqlite)
|
||||||
|
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
|
||||||
|
|
||||||
|
query := db.Model(&UsageRecord{}).
|
||||||
|
Select(bucketExpr+", source, COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
|
||||||
|
"SUM(prompt_tokens) as prompt_tokens, "+
|
||||||
|
"SUM(completion_tokens) as completion_tokens, "+
|
||||||
|
"SUM(total_tokens) as total_tokens, "+
|
||||||
|
"COUNT(*) as request_count").
|
||||||
|
Where("user_id = ?", userID).
|
||||||
|
Where("source <> ?", UsageSourceLegacy).
|
||||||
|
Group("bucket, source, api_key_id, api_key_name").
|
||||||
|
Order("bucket ASC")
|
||||||
|
|
||||||
|
if !since.IsZero() {
|
||||||
|
query = query.Where("created_at >= ?", since)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buckets []UsageBucket
|
||||||
|
if err := query.Find(&buckets).Error; err != nil {
|
||||||
|
return nil, SourceTotals{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
totals := computeSourceTotals(db, userID, "", since, false)
|
||||||
|
return buckets, totals, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// computeSourceTotals rolls up by_source / by_key / grand_total.
|
||||||
|
// userID/apiKeyID are optional filters. includeLegacy controls whether the
|
||||||
|
// legacy bucket is exposed (admin-only).
|
||||||
|
func computeSourceTotals(db *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) SourceTotals {
|
||||||
|
totals := SourceTotals{BySource: map[string]TotalsEntry{}}
|
||||||
|
|
||||||
|
bySourceQ := db.Model(&UsageRecord{}).
|
||||||
|
Select("source, SUM(total_tokens) as tokens, COUNT(*) as requests").
|
||||||
|
Group("source")
|
||||||
|
bySourceQ = applyFilters(bySourceQ, userID, apiKeyID, since, includeLegacy)
|
||||||
|
|
||||||
|
var bySourceRows []struct {
|
||||||
|
Source string
|
||||||
|
Tokens int64
|
||||||
|
Requests int64
|
||||||
|
}
|
||||||
|
if err := bySourceQ.Scan(&bySourceRows).Error; err != nil {
|
||||||
|
xlog.Warn("computeSourceTotals: by-source Scan failed", "error", err)
|
||||||
|
return totals
|
||||||
|
}
|
||||||
|
for _, r := range bySourceRows {
|
||||||
|
totals.BySource[r.Source] = TotalsEntry{Tokens: r.Tokens, Requests: r.Requests}
|
||||||
|
totals.GrandTotal.Tokens += r.Tokens
|
||||||
|
totals.GrandTotal.Requests += r.Requests
|
||||||
|
}
|
||||||
|
|
||||||
|
byKeyQ := db.Model(&UsageRecord{}).
|
||||||
|
Select("COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
|
||||||
|
"user_id, user_name, "+
|
||||||
|
"SUM(total_tokens) as tokens, COUNT(*) as requests, MAX(created_at) as last_used").
|
||||||
|
Where("api_key_id IS NOT NULL AND api_key_id <> ''").
|
||||||
|
Group("api_key_id, api_key_name, user_id, user_name").
|
||||||
|
Order("tokens DESC").
|
||||||
|
Limit(maxKeyTotals)
|
||||||
|
byKeyQ = applyFilters(byKeyQ, userID, apiKeyID, since, includeLegacy)
|
||||||
|
|
||||||
|
// Iterate Rows() manually because MAX(created_at) is returned as a string by
|
||||||
|
// the SQLite driver, and Go's database/sql refuses to scan that into
|
||||||
|
// *time.Time. Postgres returns a proper timestamp. We accept both shapes
|
||||||
|
// via a Rows.Scan into a string column, then parse uniformly.
|
||||||
|
rows, err := byKeyQ.Rows()
|
||||||
|
if err != nil {
|
||||||
|
xlog.Warn("computeSourceTotals: by-key Rows() failed", "error", err)
|
||||||
|
} else {
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
out := make([]KeyTotal, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
apiKeyID, apiKeyName, userIDCol, userName, lastUsedRaw string
|
||||||
|
tokens, requests int64
|
||||||
|
)
|
||||||
|
if scanErr := rows.Scan(&apiKeyID, &apiKeyName, &userIDCol, &userName, &tokens, &requests, &lastUsedRaw); scanErr != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, KeyTotal{
|
||||||
|
APIKeyID: apiKeyID,
|
||||||
|
APIKeyName: apiKeyName,
|
||||||
|
UserID: userIDCol,
|
||||||
|
UserName: userName,
|
||||||
|
Tokens: tokens,
|
||||||
|
Requests: requests,
|
||||||
|
LastUsed: parseLastUsedString(lastUsedRaw),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if rerr := rows.Err(); rerr != nil {
|
||||||
|
xlog.Warn("computeSourceTotals: by-key rows iteration failed", "error", rerr)
|
||||||
|
}
|
||||||
|
totals.ByKey = out
|
||||||
|
}
|
||||||
|
|
||||||
|
// by_user_source: only populated for admin callers (includeLegacy=true) so
|
||||||
|
// they can attribute Web UI / legacy traffic to specific users. Per-apikey
|
||||||
|
// rows already carry user info via KeyTotal above, so this query only
|
||||||
|
// covers source != apikey.
|
||||||
|
if includeLegacy {
|
||||||
|
byUserSourceQ := db.Model(&UsageRecord{}).
|
||||||
|
Select("source, user_id, user_name, "+
|
||||||
|
"SUM(total_tokens) as tokens, COUNT(*) as requests").
|
||||||
|
Where("source <> ?", UsageSourceAPIKey).
|
||||||
|
Group("source, user_id, user_name").
|
||||||
|
Order("tokens DESC")
|
||||||
|
byUserSourceQ = applyFilters(byUserSourceQ, userID, apiKeyID, since, includeLegacy)
|
||||||
|
|
||||||
|
var byUserSourceRows []UserSourceTotal
|
||||||
|
if scanErr := byUserSourceQ.Scan(&byUserSourceRows).Error; scanErr != nil {
|
||||||
|
xlog.Warn("computeSourceTotals: by-user-source Scan failed", "error", scanErr)
|
||||||
|
} else {
|
||||||
|
totals.ByUserSource = byUserSourceRows
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return totals
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseLastUsedString converts the textual MAX(created_at) value returned by
|
||||||
|
// SQLite (or any driver that surfaces the timestamp as a string) into a
|
||||||
|
// time.Time. Returns the zero time on parse failure.
|
||||||
|
func parseLastUsedString(s string) time.Time {
|
||||||
|
if s == "" {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
// GORM's SQLite driver emits Go's default time formatting. Try the formats
|
||||||
|
// it commonly produces, falling back to RFC3339Nano.
|
||||||
|
layouts := []string{
|
||||||
|
"2006-01-02 15:04:05.999999999 -0700 MST",
|
||||||
|
"2006-01-02 15:04:05.999999999-07:00",
|
||||||
|
"2006-01-02 15:04:05.999999999",
|
||||||
|
"2006-01-02 15:04:05",
|
||||||
|
time.RFC3339Nano,
|
||||||
|
time.RFC3339,
|
||||||
|
}
|
||||||
|
for _, layout := range layouts {
|
||||||
|
if t, err := time.Parse(layout, s); err == nil {
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
xlog.Warn("parseLastUsedString: unrecognised format", "value", s)
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllUsageBySource is the admin variant of GetUserUsageBySource.
|
||||||
|
// Optional filters: userID and apiKeyID. Legacy is included.
|
||||||
|
// truncated == true iff the per-key roll-up was capped at maxKeyTotals.
|
||||||
|
func GetAllUsageBySource(db *gorm.DB, period, userID, apiKeyID string) ([]UsageBucket, SourceTotals, bool, error) {
|
||||||
|
sqlite := isSQLiteDB(db)
|
||||||
|
since, dateFmt := periodToWindow(period, sqlite)
|
||||||
|
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
|
||||||
|
|
||||||
|
query := db.Model(&UsageRecord{}).
|
||||||
|
Select(bucketExpr+", source, COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
|
||||||
|
"user_id, user_name, "+
|
||||||
|
"SUM(prompt_tokens) as prompt_tokens, "+
|
||||||
|
"SUM(completion_tokens) as completion_tokens, "+
|
||||||
|
"SUM(total_tokens) as total_tokens, "+
|
||||||
|
"COUNT(*) as request_count").
|
||||||
|
Group("bucket, source, api_key_id, api_key_name, user_id, user_name").
|
||||||
|
Order("bucket ASC")
|
||||||
|
|
||||||
|
query = applyFilters(query, userID, apiKeyID, since, true)
|
||||||
|
|
||||||
|
var buckets []UsageBucket
|
||||||
|
if err := query.Find(&buckets).Error; err != nil {
|
||||||
|
return nil, SourceTotals{}, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
totals := computeSourceTotals(db, userID, apiKeyID, since, true)
|
||||||
|
|
||||||
|
// Count distinct api_key_ids matching the filters. If > maxKeyTotals,
|
||||||
|
// the by_key slice was capped and we signal truncation to the caller.
|
||||||
|
truncated := false
|
||||||
|
var distinct int64
|
||||||
|
countQ := applyFilters(
|
||||||
|
db.Model(&UsageRecord{}).
|
||||||
|
Distinct("api_key_id").
|
||||||
|
Where("api_key_id IS NOT NULL AND api_key_id <> ''"),
|
||||||
|
userID, apiKeyID, since, true,
|
||||||
|
)
|
||||||
|
if err := countQ.Count(&distinct).Error; err != nil {
|
||||||
|
xlog.Warn("GetAllUsageBySource: distinct api_key_id count failed", "error", err)
|
||||||
|
} else {
|
||||||
|
truncated = distinct > maxKeyTotals
|
||||||
|
}
|
||||||
|
|
||||||
|
return buckets, totals, truncated, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyFilters(q *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) *gorm.DB {
|
||||||
|
if userID != "" {
|
||||||
|
q = q.Where("user_id = ?", userID)
|
||||||
|
}
|
||||||
|
if apiKeyID != "" {
|
||||||
|
q = q.Where("api_key_id = ?", apiKeyID)
|
||||||
|
}
|
||||||
|
if !since.IsZero() {
|
||||||
|
q = q.Where("created_at >= ?", since)
|
||||||
|
}
|
||||||
|
if !includeLegacy {
|
||||||
|
q = q.Where("source <> ?", UsageSourceLegacy)
|
||||||
|
}
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,11 +3,13 @@
|
|||||||
package auth_test
|
package auth_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/http/auth"
|
"github.com/mudler/LocalAI/core/http/auth"
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("Usage", func() {
|
var _ = Describe("Usage", func() {
|
||||||
@@ -158,4 +160,275 @@ var _ = Describe("Usage", func() {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Describe("Usage source backfill", func() {
|
||||||
|
It("backfills 'web' for pre-feature rows", func() {
|
||||||
|
db := testDB()
|
||||||
|
|
||||||
|
rawDB, err := db.DB()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = rawDB.Exec(
|
||||||
|
`INSERT INTO usage_records (user_id, source, model, created_at, total_tokens, prompt_tokens, completion_tokens, duration) VALUES (?, '', ?, ?, 0, 0, 0, 0)`,
|
||||||
|
"user-x", "gpt-4", time.Now())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
Expect(auth.BackfillUsageSource(db)).To(Succeed())
|
||||||
|
|
||||||
|
var loaded auth.UsageRecord
|
||||||
|
Expect(db.Where("user_id = ?", "user-x").First(&loaded).Error).To(Succeed())
|
||||||
|
Expect(loaded.Source).To(Equal(auth.UsageSourceWeb))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("backfills 'legacy' for pre-feature rows with legacy-api-key user_id", func() {
|
||||||
|
db := testDB()
|
||||||
|
|
||||||
|
rawDB, err := db.DB()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = rawDB.Exec(
|
||||||
|
`INSERT INTO usage_records (user_id, source, model, created_at, total_tokens, prompt_tokens, completion_tokens, duration) VALUES (?, '', ?, ?, 0, 0, 0, 0)`,
|
||||||
|
"legacy-api-key", "gpt-4", time.Now())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
Expect(auth.BackfillUsageSource(db)).To(Succeed())
|
||||||
|
|
||||||
|
var loaded auth.UsageRecord
|
||||||
|
Expect(db.Where("user_id = ?", "legacy-api-key").First(&loaded).Error).To(Succeed())
|
||||||
|
Expect(loaded.Source).To(Equal(auth.UsageSourceLegacy))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("is idempotent on re-run", func() {
|
||||||
|
db := testDB()
|
||||||
|
Expect(auth.BackfillUsageSource(db)).To(Succeed())
|
||||||
|
Expect(auth.BackfillUsageSource(db)).To(Succeed())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("UsageRecord with source fields", func() {
|
||||||
|
It("persists Source, APIKeyID, APIKeyName", func() {
|
||||||
|
db := testDB()
|
||||||
|
keyID := "key-uuid-1"
|
||||||
|
record := &auth.UsageRecord{
|
||||||
|
UserID: "user-1",
|
||||||
|
UserName: "Test User",
|
||||||
|
Source: auth.UsageSourceAPIKey,
|
||||||
|
APIKeyID: &keyID,
|
||||||
|
APIKeyName: "ci-runner",
|
||||||
|
Model: "gpt-4",
|
||||||
|
Endpoint: "/v1/chat/completions",
|
||||||
|
TotalTokens: 150,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
Expect(auth.RecordUsage(db, record)).To(Succeed())
|
||||||
|
|
||||||
|
var loaded auth.UsageRecord
|
||||||
|
Expect(db.First(&loaded, record.ID).Error).To(Succeed())
|
||||||
|
Expect(loaded.Source).To(Equal(auth.UsageSourceAPIKey))
|
||||||
|
Expect(loaded.APIKeyID).ToNot(BeNil())
|
||||||
|
Expect(*loaded.APIKeyID).To(Equal("key-uuid-1"))
|
||||||
|
Expect(loaded.APIKeyName).To(Equal("ci-runner"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("allows nil APIKeyID for web/legacy sources", func() {
|
||||||
|
db := testDB()
|
||||||
|
record := &auth.UsageRecord{
|
||||||
|
UserID: "user-1",
|
||||||
|
Source: auth.UsageSourceWeb,
|
||||||
|
Model: "gpt-4",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
Expect(auth.RecordUsage(db, record)).To(Succeed())
|
||||||
|
|
||||||
|
var loaded auth.UsageRecord
|
||||||
|
Expect(db.First(&loaded, record.ID).Error).To(Succeed())
|
||||||
|
Expect(loaded.Source).To(Equal(auth.UsageSourceWeb))
|
||||||
|
Expect(loaded.APIKeyID).To(BeNil())
|
||||||
|
Expect(loaded.APIKeyName).To(BeEmpty())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("GetUserUsageBySource", func() {
|
||||||
|
insert := func(db *gorm.DB, userID, source, keyID, keyName string, tokens int64, when time.Time) {
|
||||||
|
rec := &auth.UsageRecord{
|
||||||
|
UserID: userID,
|
||||||
|
Source: source,
|
||||||
|
Model: "gpt-4",
|
||||||
|
TotalTokens: tokens,
|
||||||
|
CreatedAt: when,
|
||||||
|
}
|
||||||
|
if keyID != "" {
|
||||||
|
rec.APIKeyID = &keyID
|
||||||
|
rec.APIKeyName = keyName
|
||||||
|
}
|
||||||
|
Expect(auth.RecordUsage(db, rec)).To(Succeed())
|
||||||
|
}
|
||||||
|
|
||||||
|
It("returns only the caller's rows, never legacy", func() {
|
||||||
|
db := testDB()
|
||||||
|
now := time.Now()
|
||||||
|
insert(db, "alice", auth.UsageSourceAPIKey, "k1", "ci", 100, now)
|
||||||
|
insert(db, "alice", auth.UsageSourceWeb, "", "", 50, now)
|
||||||
|
insert(db, "alice", auth.UsageSourceLegacy, "", "", 30, now)
|
||||||
|
insert(db, "bob", auth.UsageSourceAPIKey, "k2", "bobk", 90, now)
|
||||||
|
|
||||||
|
buckets, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
for _, b := range buckets {
|
||||||
|
Expect(b.UserID).To(Or(BeEmpty(), Equal("alice")))
|
||||||
|
Expect(b.Source).ToNot(Equal(auth.UsageSourceLegacy))
|
||||||
|
}
|
||||||
|
|
||||||
|
Expect(totals.GrandTotal.Tokens).To(Equal(int64(150)))
|
||||||
|
Expect(totals.BySource[auth.UsageSourceAPIKey].Tokens).To(Equal(int64(100)))
|
||||||
|
Expect(totals.BySource[auth.UsageSourceWeb].Tokens).To(Equal(int64(50)))
|
||||||
|
_, hasLegacy := totals.BySource[auth.UsageSourceLegacy]
|
||||||
|
Expect(hasLegacy).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("snapshots survive key deletion", func() {
|
||||||
|
db := testDB()
|
||||||
|
now := time.Now()
|
||||||
|
insert(db, "alice", auth.UsageSourceAPIKey, "deleted-key", "old-name", 42, now)
|
||||||
|
_, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(totals.ByKey).To(HaveLen(1))
|
||||||
|
Expect(totals.ByKey[0].APIKeyName).To(Equal("old-name"))
|
||||||
|
Expect(totals.ByKey[0].APIKeyID).To(Equal("deleted-key"))
|
||||||
|
Expect(totals.ByKey[0].LastUsed).ToNot(BeZero())
|
||||||
|
Expect(totals.ByKey[0].LastUsed).To(BeTemporally("~", now, 2*time.Second))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("GetAllUsageBySource", func() {
|
||||||
|
insert := func(db *gorm.DB, userID, source, keyID string, tokens int64) {
|
||||||
|
rec := &auth.UsageRecord{
|
||||||
|
UserID: userID,
|
||||||
|
Source: source,
|
||||||
|
Model: "gpt-4",
|
||||||
|
TotalTokens: tokens,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if keyID != "" {
|
||||||
|
rec.APIKeyID = &keyID
|
||||||
|
rec.APIKeyName = "name-" + keyID
|
||||||
|
}
|
||||||
|
Expect(auth.RecordUsage(db, rec)).To(Succeed())
|
||||||
|
}
|
||||||
|
|
||||||
|
It("includes legacy for admins", func() {
|
||||||
|
db := testDB()
|
||||||
|
insert(db, "alice", auth.UsageSourceAPIKey, "k1", 10)
|
||||||
|
insert(db, "legacy-api-key", auth.UsageSourceLegacy, "", 5)
|
||||||
|
|
||||||
|
_, totals, _, err := auth.GetAllUsageBySource(db, "month", "", "")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(totals.BySource).To(HaveKey(auth.UsageSourceLegacy))
|
||||||
|
Expect(totals.BySource[auth.UsageSourceLegacy].Tokens).To(Equal(int64(5)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("filters by user_id AND api_key_id", func() {
|
||||||
|
db := testDB()
|
||||||
|
insert(db, "alice", auth.UsageSourceAPIKey, "k1", 10)
|
||||||
|
insert(db, "alice", auth.UsageSourceAPIKey, "k2", 20)
|
||||||
|
insert(db, "bob", auth.UsageSourceAPIKey, "k3", 30)
|
||||||
|
|
||||||
|
_, totals, _, err := auth.GetAllUsageBySource(db, "month", "alice", "k2")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(totals.GrandTotal.Tokens).To(Equal(int64(20)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sets truncated=true when by_key exceeds the cap", func() {
|
||||||
|
db := testDB()
|
||||||
|
for i := 0; i < 210; i++ {
|
||||||
|
insert(db, "alice", auth.UsageSourceAPIKey, fmt.Sprintf("key-%03d", i), int64(210-i))
|
||||||
|
}
|
||||||
|
|
||||||
|
_, totals, truncated, err := auth.GetAllUsageBySource(db, "month", "", "")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(truncated).To(BeTrue())
|
||||||
|
Expect(totals.ByKey).To(HaveLen(200))
|
||||||
|
Expect(totals.ByKey[0].Tokens > totals.ByKey[199].Tokens).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
// insertNamed records a row with explicit user_id, user_name, source,
|
||||||
|
// and optional api key snapshot. Used by the user-attribution tests
|
||||||
|
// below which the older insert helper can't express.
|
||||||
|
insertNamed := func(db *gorm.DB, userID, userName, source, keyID, keyName string, tokens int64) {
|
||||||
|
rec := &auth.UsageRecord{
|
||||||
|
UserID: userID,
|
||||||
|
UserName: userName,
|
||||||
|
Source: source,
|
||||||
|
Model: "gpt-4",
|
||||||
|
TotalTokens: tokens,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if keyID != "" {
|
||||||
|
rec.APIKeyID = &keyID
|
||||||
|
rec.APIKeyName = keyName
|
||||||
|
}
|
||||||
|
Expect(auth.RecordUsage(db, rec)).To(Succeed())
|
||||||
|
}
|
||||||
|
|
||||||
|
It("attributes each KeyTotal to its owner user", func() {
|
||||||
|
db := testDB()
|
||||||
|
insertNamed(db, "alice", "Alice", auth.UsageSourceAPIKey, "k1", "ci-runner", 100)
|
||||||
|
insertNamed(db, "bob", "Bob", auth.UsageSourceAPIKey, "k2", "lap", 50)
|
||||||
|
|
||||||
|
_, totals, _, err := auth.GetAllUsageBySource(db, "month", "", "")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(totals.ByKey).To(HaveLen(2))
|
||||||
|
|
||||||
|
byID := map[string]auth.KeyTotal{}
|
||||||
|
for _, k := range totals.ByKey {
|
||||||
|
byID[k.APIKeyID] = k
|
||||||
|
}
|
||||||
|
Expect(byID["k1"].UserID).To(Equal("alice"))
|
||||||
|
Expect(byID["k1"].UserName).To(Equal("Alice"))
|
||||||
|
Expect(byID["k2"].UserID).To(Equal("bob"))
|
||||||
|
Expect(byID["k2"].UserName).To(Equal("Bob"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("breaks Web UI and legacy traffic out per user in by_user_source for admin", func() {
|
||||||
|
db := testDB()
|
||||||
|
// Alice and Bob both have Web UI traffic; a synthetic legacy user
|
||||||
|
// also contributes. ByUserSource should expose one row per
|
||||||
|
// (source, user) pair, never for source=apikey.
|
||||||
|
insertNamed(db, "alice", "Alice", auth.UsageSourceWeb, "", "", 30)
|
||||||
|
insertNamed(db, "bob", "Bob", auth.UsageSourceWeb, "", "", 70)
|
||||||
|
insertNamed(db, "legacy-api-key", "API Key User", auth.UsageSourceLegacy, "", "", 10)
|
||||||
|
insertNamed(db, "alice", "Alice", auth.UsageSourceAPIKey, "k1", "ci-runner", 5)
|
||||||
|
|
||||||
|
_, totals, _, err := auth.GetAllUsageBySource(db, "month", "", "")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(totals.ByUserSource).ToNot(BeEmpty())
|
||||||
|
|
||||||
|
for _, r := range totals.ByUserSource {
|
||||||
|
Expect(r.Source).ToNot(Equal(auth.UsageSourceAPIKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
webByUser := map[string]int64{}
|
||||||
|
legacyByUser := map[string]int64{}
|
||||||
|
for _, r := range totals.ByUserSource {
|
||||||
|
switch r.Source {
|
||||||
|
case auth.UsageSourceWeb:
|
||||||
|
webByUser[r.UserID] = r.Tokens
|
||||||
|
case auth.UsageSourceLegacy:
|
||||||
|
legacyByUser[r.UserID] = r.Tokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Expect(webByUser["alice"]).To(Equal(int64(30)))
|
||||||
|
Expect(webByUser["bob"]).To(Equal(int64(70)))
|
||||||
|
Expect(legacyByUser["legacy-api-key"]).To(Equal(int64(10)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("does NOT populate by_user_source in the non-admin path", func() {
|
||||||
|
db := testDB()
|
||||||
|
insertNamed(db, "alice", "Alice", auth.UsageSourceWeb, "", "", 30)
|
||||||
|
|
||||||
|
_, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
// Non-admin path uses includeLegacy=false, so by_user_source stays nil.
|
||||||
|
Expect(totals.ByUserSource).To(BeNil())
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -16,8 +16,11 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
"github.com/mudler/LocalAI/core/http/auth"
|
"github.com/mudler/LocalAI/core/http/auth"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||||
"github.com/mudler/LocalAI/core/services/nodes"
|
"github.com/mudler/LocalAI/core/services/nodes"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -381,14 +384,24 @@ func ResumeNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// InstallBackendOnNodeEndpoint triggers backend installation on a worker node via NATS.
|
// InstallBackendOnNodeEndpoint triggers backend installation on a worker node.
|
||||||
|
// Async: enqueues a ManagementOp on the gallery service channel and returns a
|
||||||
|
// jobID immediately. The gallery service worker goroutine drives the actual
|
||||||
|
// install via DistributedBackendManager.InstallBackend, which honors the op's
|
||||||
|
// TargetNodeID to scope the fan-out to one node. The UI polls /api/backends/job/:uid
|
||||||
|
// for progress, mirroring /api/backends/install/:id.
|
||||||
|
//
|
||||||
// Backend can be either a gallery ID (resolved against BackendGalleries) or a
|
// Backend can be either a gallery ID (resolved against BackendGalleries) or a
|
||||||
// direct URI install (URI + Name + optional Alias) — same shape as the
|
// direct URI install (URI + Name + optional Alias) - same shape as the
|
||||||
// standalone /api/backends/install-external path, just scoped to one node.
|
// standalone /api/backends/install-external path, just scoped to one node.
|
||||||
func InstallBackendOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.HandlerFunc {
|
//
|
||||||
|
// The legacy unloader argument is retained for signature symmetry with
|
||||||
|
// DeleteBackendOnNodeEndpoint / ListBackendsOnNodeEndpoint but is no longer
|
||||||
|
// used here - the async path goes through galleryService.
|
||||||
|
func InstallBackendOnNodeEndpoint(_ nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
if unloader == nil {
|
if galleryService == nil {
|
||||||
return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured"))
|
return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "gallery service not configured"))
|
||||||
}
|
}
|
||||||
nodeID := c.Param("id")
|
nodeID := c.Param("id")
|
||||||
var req struct {
|
var req struct {
|
||||||
@@ -401,25 +414,65 @@ func InstallBackendOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.Handler
|
|||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
||||||
}
|
}
|
||||||
// Either a gallery backend name or a direct URI must be supplied.
|
|
||||||
if req.Backend == "" && req.URI == "" {
|
if req.Backend == "" && req.URI == "" {
|
||||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "backend name or uri required"))
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "backend name or uri required"))
|
||||||
}
|
}
|
||||||
// Admin-driven backend install: not tied to a specific replica slot
|
|
||||||
// (no model is being loaded). Pass replica 0 to match the worker's
|
jobUUID, err := uuid.NewUUID()
|
||||||
// admin process-key convention (`backend#0`). The worker's fast path
|
|
||||||
// takes over if the backend is already running — upgrades go through
|
|
||||||
// the dedicated /api/backends/upgrade path on backend.upgrade.
|
|
||||||
reply, err := unloader.InstallBackend(nodeID, req.Backend, "", req.BackendGalleries, req.URI, req.Name, req.Alias, 0)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("Failed to install backend on node", "node", nodeID, "backend", req.Backend, "uri", req.URI, "error", err)
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to generate job id"))
|
||||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to install backend on node"))
|
|
||||||
}
|
}
|
||||||
if !reply.Success {
|
jobID := jobUUID.String()
|
||||||
xlog.Error("Backend install failed on node", "node", nodeID, "backend", req.Backend, "uri", req.URI, "error", reply.Error)
|
|
||||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "backend installation failed"))
|
// Cache key: for gallery installs, use the backend slug; for URI
|
||||||
|
// installs prefer the provided Name (falling back to URI). All keys
|
||||||
|
// are node-scoped so concurrent installs of the same backend on
|
||||||
|
// different nodes do not stomp each other in opcache.
|
||||||
|
backendKey := req.Backend
|
||||||
|
if backendKey == "" {
|
||||||
|
backendKey = req.Name
|
||||||
|
if backendKey == "" {
|
||||||
|
backendKey = req.URI
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, map[string]string{"message": "backend installed"})
|
cacheKey := galleryop.NodeScopedKey(nodeID, backendKey)
|
||||||
|
opcache.SetBackend(cacheKey, jobID)
|
||||||
|
|
||||||
|
// Optional caller-supplied galleries override. Mirrors the standalone
|
||||||
|
// install path so an admin can point at a private gallery.
|
||||||
|
galleries := appConfig.BackendGalleries
|
||||||
|
if req.BackendGalleries != "" {
|
||||||
|
var custom []config.Gallery
|
||||||
|
if err := json.Unmarshal([]byte(req.BackendGalleries), &custom); err != nil {
|
||||||
|
xlog.Warn("Ignoring malformed backend_galleries override; falling back to configured galleries", "error", err, "nodeID", nodeID)
|
||||||
|
} else if len(custom) > 0 {
|
||||||
|
galleries = custom
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
op := galleryop.ManagementOp[gallery.GalleryBackend, any]{
|
||||||
|
ID: jobID,
|
||||||
|
GalleryElementName: req.Backend,
|
||||||
|
Galleries: galleries,
|
||||||
|
TargetNodeID: nodeID,
|
||||||
|
ExternalURI: req.URI,
|
||||||
|
ExternalName: req.Name,
|
||||||
|
ExternalAlias: req.Alias,
|
||||||
|
Context: ctx,
|
||||||
|
CancelFunc: cancelFunc,
|
||||||
|
}
|
||||||
|
galleryService.StoreCancellation(jobID, cancelFunc)
|
||||||
|
go func() {
|
||||||
|
galleryService.BackendGalleryChannel <- op
|
||||||
|
}()
|
||||||
|
|
||||||
|
xlog.Info("Node-scoped backend install dispatched", "node", nodeID, "backend", req.Backend, "uri", req.URI, "jobID", jobID)
|
||||||
|
return c.JSON(http.StatusAccepted, map[string]string{
|
||||||
|
"jobID": jobID,
|
||||||
|
"statusUrl": "/api/backends/job/" + jobID,
|
||||||
|
"message": "backend installation started",
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
123
core/http/endpoints/localai/nodes_install_async_test.go
Normal file
123
core/http/endpoints/localai/nodes_install_async_test.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package localai_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
|
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||||
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InstallBackendOnNodeEndpoint became async to stop blocking the browser on
|
||||||
|
// the 3-minute NATS reply timeout. These specs lock in the new contract:
|
||||||
|
// HTTP 202 with a jobID, a ManagementOp enqueued on the gallery channel, and
|
||||||
|
// an opcache entry keyed by NodeScopedKey so concurrent installs of the same
|
||||||
|
// backend on different nodes do not stomp each other.
|
||||||
|
var _ = Describe("InstallBackendOnNodeEndpoint async behavior", func() {
|
||||||
|
var (
|
||||||
|
e *echo.Echo
|
||||||
|
galleryService *galleryop.GalleryService
|
||||||
|
opcache *galleryop.OpCache
|
||||||
|
appCfg *config.ApplicationConfig
|
||||||
|
dispatched chan galleryop.ManagementOp[gallery.GalleryBackend, any]
|
||||||
|
done chan struct{}
|
||||||
|
drainExited chan struct{}
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
e = echo.New()
|
||||||
|
appCfg = &config.ApplicationConfig{
|
||||||
|
BackendGalleries: []config.Gallery{{Name: "test-gallery", URL: "http://example.com"}},
|
||||||
|
}
|
||||||
|
galleryService = galleryop.NewGalleryService(appCfg, nil)
|
||||||
|
opcache = galleryop.NewOpCache(galleryService)
|
||||||
|
// Drain the gallery channel into a buffered side channel so the
|
||||||
|
// handler's `go func() { ch <- op }()` send does not block waiting
|
||||||
|
// for the real worker (which is not running in this unit test).
|
||||||
|
dispatched = make(chan galleryop.ManagementOp[gallery.GalleryBackend, any], 4)
|
||||||
|
done = make(chan struct{})
|
||||||
|
drainExited = make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(drainExited)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case op := <-galleryService.BackendGalleryChannel:
|
||||||
|
dispatched <- op
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
// Signal the drain goroutine to exit. We do NOT close
|
||||||
|
// BackendGalleryChannel: the handler's dispatch goroutine may still
|
||||||
|
// be pending (specs that don't Eventually-Receive), and a send on a
|
||||||
|
// closed channel panics. Signalling via `done` lets the drain
|
||||||
|
// goroutine return without touching the gallery channel.
|
||||||
|
close(done)
|
||||||
|
Eventually(drainExited, "2s").Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns 202 with a jobID and dispatches a TargetNodeID-scoped op", func() {
|
||||||
|
body := `{"backend": "llama-cpp"}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/nodes/node-xyz/backends/install", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
c.SetParamNames("id")
|
||||||
|
c.SetParamValues("node-xyz")
|
||||||
|
|
||||||
|
handler := localai.InstallBackendOnNodeEndpoint(nil, galleryService, opcache, appCfg)
|
||||||
|
Expect(handler(c)).To(Succeed())
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusAccepted))
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||||
|
Expect(resp["jobID"]).To(BeAssignableToTypeOf(""))
|
||||||
|
Expect(resp["jobID"].(string)).ToNot(BeEmpty())
|
||||||
|
Expect(resp["message"]).To(Equal("backend installation started"))
|
||||||
|
|
||||||
|
Eventually(dispatched, "2s").Should(Receive())
|
||||||
|
Expect(opcache.Exists(galleryop.NodeScopedKey("node-xyz", "llama-cpp"))).To(BeTrue())
|
||||||
|
Expect(opcache.IsBackendOp(galleryop.NodeScopedKey("node-xyz", "llama-cpp"))).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns 400 when neither backend nor uri is supplied", func() {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/nodes/node-xyz/backends/install", bytes.NewBufferString(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
c.SetParamNames("id")
|
||||||
|
c.SetParamValues("node-xyz")
|
||||||
|
|
||||||
|
handler := localai.InstallBackendOnNodeEndpoint(nil, galleryService, opcache, appCfg)
|
||||||
|
Expect(handler(c)).To(Succeed())
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("accepts a direct URI install and uses the name as the cache key", func() {
|
||||||
|
body := `{"uri": "oci://example.com/custom-backend:v1", "name": "custom"}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/nodes/node-xyz/backends/install", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
c.SetParamNames("id")
|
||||||
|
c.SetParamValues("node-xyz")
|
||||||
|
|
||||||
|
handler := localai.InstallBackendOnNodeEndpoint(nil, galleryService, opcache, appCfg)
|
||||||
|
Expect(handler(c)).To(Succeed())
|
||||||
|
Expect(rec.Code).To(Equal(http.StatusAccepted))
|
||||||
|
|
||||||
|
Expect(opcache.Exists(galleryop.NodeScopedKey("node-xyz", "custom"))).To(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -22,12 +22,19 @@ import (
|
|||||||
"github.com/mudler/LocalAI/core/backend"
|
"github.com/mudler/LocalAI/core/backend"
|
||||||
|
|
||||||
model "github.com/mudler/LocalAI/pkg/model"
|
model "github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var videoDownloadClient = http.Client{Timeout: 30 * time.Second}
|
||||||
|
|
||||||
func downloadFile(url string) (string, error) {
|
func downloadFile(url string) (string, error) {
|
||||||
|
if err := utils.ValidateExternalURL(url); err != nil {
|
||||||
|
return "", fmt.Errorf("URL validation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Get the data
|
// Get the data
|
||||||
resp, err := http.Get(url)
|
resp, err := videoDownloadClient.Get(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,362 +73,6 @@ func mergeToolCallDeltas(existing []schema.ToolCall, deltas []schema.ToolCall) [
|
|||||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||||
// @Router /v1/chat/completions [post]
|
// @Router /v1/chat/completions [post]
|
||||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, assistantHolder *mcpTools.LocalAIAssistantHolder) echo.HandlerFunc {
|
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, assistantHolder *mcpTools.LocalAIAssistantHolder) echo.HandlerFunc {
|
||||||
process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool, id string, created int) error {
|
|
||||||
initialMessage := schema.OpenAIResponse{
|
|
||||||
ID: id,
|
|
||||||
Created: created,
|
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
}
|
|
||||||
responses <- initialMessage
|
|
||||||
|
|
||||||
// Detect if thinking token is already in prompt or template
|
|
||||||
// When UseTokenizerTemplate is enabled, predInput is empty, so we check the template
|
|
||||||
var template string
|
|
||||||
if config.TemplateConfig.UseTokenizerTemplate {
|
|
||||||
template = config.GetModelTemplate()
|
|
||||||
} else {
|
|
||||||
template = s
|
|
||||||
}
|
|
||||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
|
||||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig)
|
|
||||||
|
|
||||||
_, _, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
|
||||||
var reasoningDelta, contentDelta string
|
|
||||||
|
|
||||||
// Always keep the Go-side extractor in sync with raw tokens so it
|
|
||||||
// can serve as fallback for backends without an autoparser (e.g. vLLM).
|
|
||||||
goReasoning, goContent := extractor.ProcessToken(s)
|
|
||||||
|
|
||||||
// When C++ autoparser chat deltas are available, prefer them — they
|
|
||||||
// handle model-specific formats (Gemma 4, etc.) without Go-side tags.
|
|
||||||
// Otherwise fall back to Go-side extraction.
|
|
||||||
if tokenUsage.HasChatDeltaContent() {
|
|
||||||
rawReasoning, cd := tokenUsage.ChatDeltaReasoningAndContent()
|
|
||||||
contentDelta = cd
|
|
||||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
|
||||||
} else {
|
|
||||||
reasoningDelta = goReasoning
|
|
||||||
contentDelta = goContent
|
|
||||||
}
|
|
||||||
|
|
||||||
usage := schema.OpenAIUsage{
|
|
||||||
PromptTokens: tokenUsage.Prompt,
|
|
||||||
CompletionTokens: tokenUsage.Completion,
|
|
||||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
|
||||||
}
|
|
||||||
if extraUsage {
|
|
||||||
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
|
|
||||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
|
||||||
}
|
|
||||||
|
|
||||||
delta := &schema.Message{}
|
|
||||||
if contentDelta != "" {
|
|
||||||
delta.Content = &contentDelta
|
|
||||||
}
|
|
||||||
if reasoningDelta != "" {
|
|
||||||
delta.Reasoning = &reasoningDelta
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := schema.OpenAIResponse{
|
|
||||||
ID: id,
|
|
||||||
Created: created,
|
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Usage: usage,
|
|
||||||
}
|
|
||||||
|
|
||||||
responses <- resp
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
close(responses)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool, id string, created int, textContentToReturn *string) error {
|
|
||||||
// Detect if thinking token is already in prompt or template
|
|
||||||
var template string
|
|
||||||
if config.TemplateConfig.UseTokenizerTemplate {
|
|
||||||
template = config.GetModelTemplate()
|
|
||||||
} else {
|
|
||||||
template = prompt
|
|
||||||
}
|
|
||||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
|
||||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig)
|
|
||||||
|
|
||||||
result := ""
|
|
||||||
lastEmittedCount := 0
|
|
||||||
sentInitialRole := false
|
|
||||||
sentReasoning := false
|
|
||||||
hasChatDeltaToolCalls := false
|
|
||||||
hasChatDeltaContent := false
|
|
||||||
|
|
||||||
_, tokenUsage, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
|
||||||
result += s
|
|
||||||
|
|
||||||
// Track whether ChatDeltas from the C++ autoparser contain
|
|
||||||
// tool calls or content, so the retry decision can account for them.
|
|
||||||
for _, d := range usage.ChatDeltas {
|
|
||||||
if len(d.ToolCalls) > 0 {
|
|
||||||
hasChatDeltaToolCalls = true
|
|
||||||
}
|
|
||||||
if d.Content != "" {
|
|
||||||
hasChatDeltaContent = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var reasoningDelta, contentDelta string
|
|
||||||
|
|
||||||
goReasoning, goContent := extractor.ProcessToken(s)
|
|
||||||
|
|
||||||
if usage.HasChatDeltaContent() {
|
|
||||||
rawReasoning, cd := usage.ChatDeltaReasoningAndContent()
|
|
||||||
contentDelta = cd
|
|
||||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
|
||||||
} else {
|
|
||||||
reasoningDelta = goReasoning
|
|
||||||
contentDelta = goContent
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emit reasoning deltas in their own SSE chunks before any tool-call chunks
|
|
||||||
// (OpenAI spec: reasoning and tool_calls never share a delta)
|
|
||||||
if reasoningDelta != "" {
|
|
||||||
responses <- schema.OpenAIResponse{
|
|
||||||
ID: id,
|
|
||||||
Created: created,
|
|
||||||
Model: req.Model,
|
|
||||||
Choices: []schema.Choice{{
|
|
||||||
Delta: &schema.Message{Reasoning: &reasoningDelta},
|
|
||||||
Index: 0,
|
|
||||||
}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
}
|
|
||||||
sentReasoning = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stream content deltas (cleaned of reasoning tags) while no tool calls
|
|
||||||
// have been detected. Once the incremental parser finds tool calls,
|
|
||||||
// content stops — per OpenAI spec, content and tool_calls don't mix.
|
|
||||||
if lastEmittedCount == 0 && contentDelta != "" {
|
|
||||||
if !sentInitialRole {
|
|
||||||
responses <- schema.OpenAIResponse{
|
|
||||||
ID: id, Created: created, Model: req.Model,
|
|
||||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
}
|
|
||||||
sentInitialRole = true
|
|
||||||
}
|
|
||||||
responses <- schema.OpenAIResponse{
|
|
||||||
ID: id, Created: created, Model: req.Model,
|
|
||||||
Choices: []schema.Choice{{
|
|
||||||
Delta: &schema.Message{Content: &contentDelta},
|
|
||||||
Index: 0,
|
|
||||||
}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try incremental XML parsing for streaming support using iterative parser
|
|
||||||
// This allows emitting partial tool calls as they're being generated
|
|
||||||
cleanedResult := functions.CleanupLLMResult(result, config.FunctionsConfig)
|
|
||||||
|
|
||||||
// Determine XML format from config
|
|
||||||
var xmlFormat *functions.XMLToolCallFormat
|
|
||||||
if config.FunctionsConfig.XMLFormat != nil {
|
|
||||||
xmlFormat = config.FunctionsConfig.XMLFormat
|
|
||||||
} else if config.FunctionsConfig.XMLFormatPreset != "" {
|
|
||||||
xmlFormat = functions.GetXMLFormatPreset(config.FunctionsConfig.XMLFormatPreset)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use iterative parser for streaming (partial parsing enabled)
|
|
||||||
// Try XML parsing first
|
|
||||||
partialResults, parseErr := functions.ParseXMLIterative(cleanedResult, xmlFormat, true)
|
|
||||||
if parseErr == nil && len(partialResults) > 0 {
|
|
||||||
// Emit new XML tool calls that weren't emitted before
|
|
||||||
if len(partialResults) > lastEmittedCount {
|
|
||||||
for i := lastEmittedCount; i < len(partialResults); i++ {
|
|
||||||
toolCall := partialResults[i]
|
|
||||||
initialMessage := schema.OpenAIResponse{
|
|
||||||
ID: id,
|
|
||||||
Created: created,
|
|
||||||
Model: req.Model,
|
|
||||||
Choices: []schema.Choice{{
|
|
||||||
Delta: &schema.Message{
|
|
||||||
Role: "assistant",
|
|
||||||
ToolCalls: []schema.ToolCall{
|
|
||||||
{
|
|
||||||
Index: i,
|
|
||||||
ID: id,
|
|
||||||
Type: "function",
|
|
||||||
FunctionCall: schema.FunctionCall{
|
|
||||||
Name: toolCall.Name,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Index: 0,
|
|
||||||
FinishReason: nil,
|
|
||||||
}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case responses <- initialMessage:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
lastEmittedCount = len(partialResults)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Try JSON tool call parsing for streaming.
|
|
||||||
// Only emit NEW tool calls (same guard as XML parser above).
|
|
||||||
jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true)
|
|
||||||
if jsonErr == nil && len(jsonResults) > lastEmittedCount {
|
|
||||||
for i := lastEmittedCount; i < len(jsonResults); i++ {
|
|
||||||
jsonObj := jsonResults[i]
|
|
||||||
name, ok := jsonObj["name"].(string)
|
|
||||||
if !ok || name == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
args := "{}"
|
|
||||||
if argsVal, ok := jsonObj["arguments"]; ok {
|
|
||||||
if argsStr, ok := argsVal.(string); ok {
|
|
||||||
args = argsStr
|
|
||||||
} else {
|
|
||||||
argsBytes, _ := json.Marshal(argsVal)
|
|
||||||
args = string(argsBytes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
initialMessage := schema.OpenAIResponse{
|
|
||||||
ID: id,
|
|
||||||
Created: created,
|
|
||||||
Model: req.Model,
|
|
||||||
Choices: []schema.Choice{{
|
|
||||||
Delta: &schema.Message{
|
|
||||||
Role: "assistant",
|
|
||||||
ToolCalls: []schema.ToolCall{
|
|
||||||
{
|
|
||||||
Index: i,
|
|
||||||
ID: id,
|
|
||||||
Type: "function",
|
|
||||||
FunctionCall: schema.FunctionCall{
|
|
||||||
Name: name,
|
|
||||||
Arguments: args,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Index: 0,
|
|
||||||
FinishReason: nil,
|
|
||||||
}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
}
|
|
||||||
responses <- initialMessage
|
|
||||||
}
|
|
||||||
lastEmittedCount = len(jsonResults)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
},
|
|
||||||
func(attempt int) bool {
|
|
||||||
// After streaming completes: check if we got actionable content
|
|
||||||
cleaned := extractor.CleanedContent()
|
|
||||||
// Check for tool calls from chat deltas (will be re-checked after ComputeChoices,
|
|
||||||
// but we need to know here whether to retry).
|
|
||||||
// Also check ChatDelta flags — when the C++ autoparser is active,
|
|
||||||
// tool calls and content are delivered via ChatDeltas while the
|
|
||||||
// raw message is cleared. Without this check, we'd retry
|
|
||||||
// unnecessarily, losing valid results and concatenating output.
|
|
||||||
hasToolCalls := lastEmittedCount > 0 || hasChatDeltaToolCalls
|
|
||||||
hasContent := cleaned != "" || hasChatDeltaContent
|
|
||||||
if !hasContent && !hasToolCalls {
|
|
||||||
xlog.Warn("Streaming: backend produced only reasoning, retrying",
|
|
||||||
"reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1)
|
|
||||||
extractor.ResetAndSuppressReasoning()
|
|
||||||
result = ""
|
|
||||||
lastEmittedCount = 0
|
|
||||||
sentInitialRole = false
|
|
||||||
hasChatDeltaToolCalls = false
|
|
||||||
hasChatDeltaContent = false
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Try using pre-parsed tool calls from C++ autoparser (chat deltas)
|
|
||||||
var functionResults []functions.FuncCallResults
|
|
||||||
var reasoning string
|
|
||||||
|
|
||||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
|
||||||
xlog.Debug("[ChatDeltas] Using pre-parsed tool calls from C++ autoparser", "count", len(deltaToolCalls))
|
|
||||||
functionResults = deltaToolCalls
|
|
||||||
// Use content/reasoning from deltas too
|
|
||||||
*textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
|
|
||||||
reasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
|
||||||
} else {
|
|
||||||
// Fallback: parse tool calls from raw text (no chat deltas from backend)
|
|
||||||
xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing")
|
|
||||||
reasoning = extractor.Reasoning()
|
|
||||||
cleanedResult := extractor.CleanedContent()
|
|
||||||
*textContentToReturn = functions.ParseTextContent(cleanedResult, config.FunctionsConfig)
|
|
||||||
cleanedResult = functions.CleanupLLMResult(cleanedResult, config.FunctionsConfig)
|
|
||||||
functionResults = functions.ParseFunctionCall(cleanedResult, config.FunctionsConfig)
|
|
||||||
}
|
|
||||||
xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", *textContentToReturn)
|
|
||||||
// noAction is a sentinel "just answer" pseudo-function — not a real
|
|
||||||
// tool call. Scan the whole slice rather than only index 0 so we
|
|
||||||
// don't drop a real tool call that happens to follow a noAction
|
|
||||||
// entry, and so the default branch isn't entered with only noAction
|
|
||||||
// entries to emit as tool_calls.
|
|
||||||
noActionToRun := !hasRealCall(functionResults, noAction)
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case noActionToRun:
|
|
||||||
usage := schema.OpenAIUsage{
|
|
||||||
PromptTokens: tokenUsage.Prompt,
|
|
||||||
CompletionTokens: tokenUsage.Completion,
|
|
||||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
|
||||||
}
|
|
||||||
if extraUsage {
|
|
||||||
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
|
|
||||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
|
||||||
}
|
|
||||||
|
|
||||||
var result string
|
|
||||||
if !sentInitialRole {
|
|
||||||
var hqErr error
|
|
||||||
result, hqErr = handleQuestion(config, functionResults, extractor.CleanedContent(), prompt)
|
|
||||||
if hqErr != nil {
|
|
||||||
xlog.Error("error handling question", "error", hqErr)
|
|
||||||
return hqErr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, chunk := range buildNoActionFinalChunks(
|
|
||||||
id, req.Model, created,
|
|
||||||
sentInitialRole, sentReasoning,
|
|
||||||
result, reasoning, usage,
|
|
||||||
) {
|
|
||||||
responses <- chunk
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
for _, chunk := range buildDeferredToolCallChunks(
|
|
||||||
id, req.Model, created,
|
|
||||||
functionResults, lastEmittedCount,
|
|
||||||
sentInitialRole, *textContentToReturn,
|
|
||||||
sentReasoning, reasoning,
|
|
||||||
) {
|
|
||||||
responses <- chunk
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
close(responses)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
var textContentToReturn string
|
var textContentToReturn string
|
||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
@@ -696,17 +340,19 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
}
|
}
|
||||||
|
|
||||||
responses := make(chan schema.OpenAIResponse)
|
responses := make(chan schema.OpenAIResponse)
|
||||||
ended := make(chan error, 1)
|
ended := make(chan streamWorkerResult, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if !shouldUseFn {
|
if !shouldUseFn {
|
||||||
ended <- process(predInput, input, config, ml, responses, extraUsage, id, created)
|
u, err := processStream(predInput, input, config, cl, startupOptions, ml, responses, id, created)
|
||||||
|
ended <- streamWorkerResult{usage: u, err: err}
|
||||||
} else {
|
} else {
|
||||||
ended <- processTools(noActionName, predInput, input, config, ml, responses, extraUsage, id, created, &textContentToReturn)
|
u, err := processStreamWithTools(noActionName, predInput, input, config, cl, startupOptions, ml, responses, id, created, &textContentToReturn)
|
||||||
|
ended <- streamWorkerResult{usage: u, err: err}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
usage := &schema.OpenAIUsage{}
|
var finalUsage backend.TokenUsage
|
||||||
toolsCalled := false
|
toolsCalled := false
|
||||||
var collectedToolCalls []schema.ToolCall
|
var collectedToolCalls []schema.ToolCall
|
||||||
var collectedContent string
|
var collectedContent string
|
||||||
@@ -724,7 +370,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
xlog.Debug("No choices in the response, skipping")
|
xlog.Debug("No choices in the response, skipping")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
|
|
||||||
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
|
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
|
||||||
toolsCalled = true
|
toolsCalled = true
|
||||||
// Collect and merge tool call deltas for MCP execution
|
// Collect and merge tool call deltas for MCP execution
|
||||||
@@ -754,15 +399,16 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.Response().Flush()
|
c.Response().Flush()
|
||||||
case err := <-ended:
|
case res := <-ended:
|
||||||
if err == nil {
|
if res.err == nil {
|
||||||
|
finalUsage = res.usage
|
||||||
break LOOP
|
break LOOP
|
||||||
}
|
}
|
||||||
xlog.Error("Stream ended with error", "error", err)
|
xlog.Error("Stream ended with error", "error", res.err)
|
||||||
|
|
||||||
errorResp := schema.ErrorResponse{
|
errorResp := schema.ErrorResponse{
|
||||||
Error: &schema.APIError{
|
Error: &schema.APIError{
|
||||||
Message: err.Error(),
|
Message: res.err.Error(),
|
||||||
Type: "server_error",
|
Type: "server_error",
|
||||||
Code: "server_error",
|
Code: "server_error",
|
||||||
},
|
},
|
||||||
@@ -785,7 +431,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
// still trying to send (e.g., after client disconnect). The goroutine
|
// still trying to send (e.g., after client disconnect). The goroutine
|
||||||
// calls close(responses) when done, which terminates the drain.
|
// calls close(responses) when done, which terminates the drain.
|
||||||
if input.Context.Err() != nil {
|
if input.Context.Err() != nil {
|
||||||
go func() { for range responses {} }()
|
go func() {
|
||||||
|
for range responses {
|
||||||
|
}
|
||||||
|
}()
|
||||||
<-ended
|
<-ended
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -888,6 +537,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
finishReason = FinishReasonFunctionCall
|
finishReason = FinishReasonFunctionCall
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Final delta chunk: empty delta with finish_reason set. Per
|
||||||
|
// OpenAI streaming spec this chunk does NOT carry usage —
|
||||||
|
// the optional trailer (below) does, gated on include_usage.
|
||||||
resp := &schema.OpenAIResponse{
|
resp := &schema.OpenAIResponse{
|
||||||
ID: id,
|
ID: id,
|
||||||
Created: created,
|
Created: created,
|
||||||
@@ -899,11 +551,26 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
Delta: &schema.Message{},
|
Delta: &schema.Message{},
|
||||||
}},
|
}},
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Usage: *usage,
|
|
||||||
}
|
}
|
||||||
respData, _ := json.Marshal(resp)
|
respData, _ := json.Marshal(resp)
|
||||||
|
|
||||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||||
|
|
||||||
|
// Trailing usage chunk per OpenAI spec: emit only when the
|
||||||
|
// caller opted in via stream_options.include_usage. Shape:
|
||||||
|
// {"choices":[],"usage":{...},"object":"chat.completion.chunk",...}
|
||||||
|
//
|
||||||
|
// finalUsage is the authoritative TokenUsage returned by the
|
||||||
|
// worker function (process / processTools) via the `ended`
|
||||||
|
// channel. The worker reads it from ComputeChoices' return
|
||||||
|
// value, which is the cumulative count produced by the backend
|
||||||
|
// over the whole prediction. Issue #9927 was caused by the
|
||||||
|
// tools-path worker not surfacing this value at all.
|
||||||
|
if input.StreamOptions != nil && input.StreamOptions.IncludeUsage {
|
||||||
|
trailerUsage := streamUsageFromTokenUsage(finalUsage, extraUsage)
|
||||||
|
trailer := streamUsageTrailerJSON(id, input.Model, created, trailerUsage)
|
||||||
|
_, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", trailer)
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||||
c.Response().Flush()
|
c.Response().Flush()
|
||||||
xlog.Debug("Stream ended")
|
xlog.Debug("Stream ended")
|
||||||
@@ -1263,7 +930,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
Choices: result,
|
Choices: result,
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Usage: usage,
|
Usage: &usage,
|
||||||
}
|
}
|
||||||
respData, _ := json.Marshal(resp)
|
respData, _ := json.Marshal(resp)
|
||||||
xlog.Debug("Response", "response", string(respData))
|
xlog.Debug("Response", "response", string(respData))
|
||||||
|
|||||||
@@ -1,12 +1,74 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/backend"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/pkg/functions"
|
"github.com/mudler/LocalAI/pkg/functions"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// streamWorkerResult is what the streaming workers (process / processTools)
|
||||||
|
// hand back to the outer ChatEndpoint loop through the `ended` channel.
|
||||||
|
// Threading the final TokenUsage here, instead of piggy-backing it on the
|
||||||
|
// `responses` SSE channel, keeps the SSE channel single-purpose (wire chunks)
|
||||||
|
// and gives the trailer emitter a plain Go value to read after LOOP exits.
|
||||||
|
// Fix for issue #9927: the previous tools-path worker never surfaced the
|
||||||
|
// cumulative token counts at all, so the include_usage trailer reported zeros.
|
||||||
|
type streamWorkerResult struct {
|
||||||
|
usage backend.TokenUsage
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// streamUsageFromTokenUsage converts the backend's cumulative TokenUsage into
|
||||||
|
// the OpenAI-spec OpenAIUsage shape used on the wire. `extraUsage` controls
|
||||||
|
// whether the non-standard timing fields are forwarded.
|
||||||
|
func streamUsageFromTokenUsage(usage backend.TokenUsage, extraUsage bool) schema.OpenAIUsage {
|
||||||
|
out := schema.OpenAIUsage{
|
||||||
|
PromptTokens: usage.Prompt,
|
||||||
|
CompletionTokens: usage.Completion,
|
||||||
|
TotalTokens: usage.Prompt + usage.Completion,
|
||||||
|
}
|
||||||
|
if extraUsage {
|
||||||
|
out.TimingTokenGeneration = usage.TimingTokenGeneration
|
||||||
|
out.TimingPromptProcessing = usage.TimingPromptProcessing
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// streamUsageTrailerJSON returns the bytes of the OpenAI-spec trailing usage
|
||||||
|
// chunk emitted in streaming completions when the request opts in via
|
||||||
|
// `stream_options.include_usage: true`. The shape is:
|
||||||
|
//
|
||||||
|
// {"id":"...","object":"chat.completion.chunk","created":N,
|
||||||
|
// "model":"...","choices":[],"usage":{...}}
|
||||||
|
//
|
||||||
|
// `choices` is intentionally an empty array (not absent, not null) — that is
|
||||||
|
// what the OpenAI spec mandates, and what consumers like the official OpenAI
|
||||||
|
// SDK and Continue's openai-adapter look for to recognise this as the usage
|
||||||
|
// chunk rather than a content chunk. schema.OpenAIResponse has `omitempty`
|
||||||
|
// on Choices, so we cannot reuse it for the trailer.
|
||||||
|
func streamUsageTrailerJSON(id, model string, created int, usage schema.OpenAIUsage) []byte {
|
||||||
|
trailer := struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Created int `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Choices []schema.Choice `json:"choices"`
|
||||||
|
Usage schema.OpenAIUsage `json:"usage"`
|
||||||
|
}{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: model,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Choices: []schema.Choice{},
|
||||||
|
Usage: usage,
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(trailer)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
// hasRealCall reports whether functionResults contains at least one
|
// hasRealCall reports whether functionResults contains at least one
|
||||||
// entry whose Name is something other than the noAction sentinel.
|
// entry whose Name is something other than the noAction sentinel.
|
||||||
// Used by processTools to decide between the "answer the question"
|
// Used by processTools to decide between the "answer the question"
|
||||||
@@ -25,10 +87,10 @@ func hasRealCall(functionResults []functions.FuncCallResults, noAction string) b
|
|||||||
// pseudo-function or emitted no tool calls at all).
|
// pseudo-function or emitted no tool calls at all).
|
||||||
//
|
//
|
||||||
// When content was already streamed (contentAlreadyStreamed=true) the
|
// When content was already streamed (contentAlreadyStreamed=true) the
|
||||||
// helper emits a single trailing usage chunk, optionally carrying
|
// helper emits a trailing reasoning chunk if any non-streamed reasoning
|
||||||
// reasoning that was produced but not streamed incrementally. When
|
// remains, else nothing. When content was not streamed it emits a role
|
||||||
// content was not streamed it emits a role chunk followed by a
|
// chunk followed by a content (+reasoning) chunk — the "send everything
|
||||||
// content+reasoning+usage chunk — the "send everything at once" fallback.
|
// at once" fallback.
|
||||||
//
|
//
|
||||||
// Reasoning re-emission is guarded by reasoningAlreadyStreamed, not by
|
// Reasoning re-emission is guarded by reasoningAlreadyStreamed, not by
|
||||||
// probing the extractor's Go-side state: the C++ autoparser delivers
|
// probing the extractor's Go-side state: the C++ autoparser delivers
|
||||||
@@ -36,6 +98,10 @@ func hasRealCall(functionResults []functions.FuncCallResults, noAction string) b
|
|||||||
// separate accumulator that extractor.Reasoning() does not expose.
|
// separate accumulator that extractor.Reasoning() does not expose.
|
||||||
// Without this guard the callback would stream reasoning incrementally
|
// Without this guard the callback would stream reasoning incrementally
|
||||||
// and the final chunk would duplicate it.
|
// and the final chunk would duplicate it.
|
||||||
|
//
|
||||||
|
// The returned chunks intentionally do NOT carry a `usage` field. The
|
||||||
|
// usage trailer is emitted separately by the streaming handler when
|
||||||
|
// `stream_options.include_usage` is true, per OpenAI spec.
|
||||||
func buildNoActionFinalChunks(
|
func buildNoActionFinalChunks(
|
||||||
id, model string,
|
id, model string,
|
||||||
created int,
|
created int,
|
||||||
@@ -43,26 +109,26 @@ func buildNoActionFinalChunks(
|
|||||||
reasoningAlreadyStreamed bool,
|
reasoningAlreadyStreamed bool,
|
||||||
content string,
|
content string,
|
||||||
reasoning string,
|
reasoning string,
|
||||||
usage schema.OpenAIUsage,
|
|
||||||
) []schema.OpenAIResponse {
|
) []schema.OpenAIResponse {
|
||||||
var out []schema.OpenAIResponse
|
var out []schema.OpenAIResponse
|
||||||
|
|
||||||
if contentAlreadyStreamed {
|
if contentAlreadyStreamed {
|
||||||
delta := &schema.Message{}
|
if reasoning == "" || reasoningAlreadyStreamed {
|
||||||
if reasoning != "" && !reasoningAlreadyStreamed {
|
return nil
|
||||||
r := reasoning
|
|
||||||
delta.Reasoning = &r
|
|
||||||
}
|
}
|
||||||
|
r := reasoning
|
||||||
out = append(out, schema.OpenAIResponse{
|
out = append(out, schema.OpenAIResponse{
|
||||||
ID: id, Created: created, Model: model,
|
ID: id, Created: created, Model: model,
|
||||||
Choices: []schema.Choice{{Delta: delta, Index: 0}},
|
Choices: []schema.Choice{{
|
||||||
Object: "chat.completion.chunk",
|
Delta: &schema.Message{Reasoning: &r},
|
||||||
Usage: usage,
|
Index: 0,
|
||||||
|
}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
})
|
})
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
// Content was not streamed — send role, then content (+reasoning) + usage.
|
// Content was not streamed — send role, then content (+reasoning).
|
||||||
out = append(out, schema.OpenAIResponse{
|
out = append(out, schema.OpenAIResponse{
|
||||||
ID: id, Created: created, Model: model,
|
ID: id, Created: created, Model: model,
|
||||||
Choices: []schema.Choice{{
|
Choices: []schema.Choice{{
|
||||||
@@ -82,7 +148,6 @@ func buildNoActionFinalChunks(
|
|||||||
ID: id, Created: created, Model: model,
|
ID: id, Created: created, Model: model,
|
||||||
Choices: []schema.Choice{{Delta: delta, Index: 0}},
|
Choices: []schema.Choice{{Delta: delta, Index: 0}},
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Usage: usage,
|
|
||||||
})
|
})
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -609,54 +609,52 @@ var _ = Describe("buildNoActionFinalChunks", func() {
|
|||||||
testModel = "test-model"
|
testModel = "test-model"
|
||||||
testCreated = 1700000000
|
testCreated = 1700000000
|
||||||
)
|
)
|
||||||
usage := schema.OpenAIUsage{PromptTokens: 5, CompletionTokens: 7, TotalTokens: 12}
|
|
||||||
|
|
||||||
Describe("Content streamed — trailing usage chunk", func() {
|
Describe("Content streamed — trailing reasoning only", func() {
|
||||||
It("emits just one chunk with usage, no content, no reasoning when reasoning was streamed", func() {
|
It("emits nothing when content and reasoning were already streamed", func() {
|
||||||
|
// Before the streaming-usage-spec fix this branch emitted a
|
||||||
|
// content-less chunk solely to carry `usage`. Per the OpenAI
|
||||||
|
// spec usage no longer rides on delta chunks; the dedicated
|
||||||
|
// trailer (when include_usage=true) carries it instead — so
|
||||||
|
// with nothing to deliver the helper returns no chunks.
|
||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
true, true,
|
true, true,
|
||||||
"", "already-streamed-reasoning", usage,
|
"", "already-streamed-reasoning",
|
||||||
)
|
)
|
||||||
|
Expect(chunks).To(BeEmpty())
|
||||||
Expect(chunks).To(HaveLen(1))
|
|
||||||
Expect(chunks[0].Usage.TotalTokens).To(Equal(12))
|
|
||||||
Expect(contentOf(chunks[0])).To(BeEmpty())
|
|
||||||
Expect(reasoningOf(chunks[0])).To(BeEmpty(),
|
|
||||||
"reasoning must not be re-emitted once it was streamed via the callback")
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("emits a trailing reasoning delivery when reasoning came only at end", func() {
|
It("emits a trailing reasoning delivery when reasoning came only at end", func() {
|
||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
true, false,
|
true, false,
|
||||||
"", "autoparser final reasoning", usage,
|
"", "autoparser final reasoning",
|
||||||
)
|
)
|
||||||
|
|
||||||
Expect(chunks).To(HaveLen(1))
|
Expect(chunks).To(HaveLen(1))
|
||||||
Expect(reasoningOf(chunks[0])).To(Equal("autoparser final reasoning"))
|
Expect(reasoningOf(chunks[0])).To(Equal("autoparser final reasoning"))
|
||||||
Expect(contentOf(chunks[0])).To(BeEmpty())
|
Expect(contentOf(chunks[0])).To(BeEmpty())
|
||||||
Expect(chunks[0].Usage.TotalTokens).To(Equal(12))
|
Expect(chunks[0].Usage).To(BeNil(),
|
||||||
|
"intermediate chunks must not carry usage per OpenAI spec")
|
||||||
})
|
})
|
||||||
|
|
||||||
It("omits reasoning when it's empty regardless of streamed flag", func() {
|
It("returns no chunks when reasoning is empty and content was streamed", func() {
|
||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
true, false,
|
true, false,
|
||||||
"", "", usage,
|
"", "",
|
||||||
)
|
)
|
||||||
|
Expect(chunks).To(BeEmpty())
|
||||||
Expect(chunks).To(HaveLen(1))
|
|
||||||
Expect(reasoningOf(chunks[0])).To(BeEmpty())
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Describe("Content not streamed — role, then content+usage", func() {
|
Describe("Content not streamed — role, then content", func() {
|
||||||
It("emits role chunk then content chunk without reasoning when reasoning was streamed", func() {
|
It("emits role chunk then content chunk without reasoning when reasoning was streamed", func() {
|
||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
false, true,
|
false, true,
|
||||||
"the answer", "already-streamed-reasoning", usage,
|
"the answer", "already-streamed-reasoning",
|
||||||
)
|
)
|
||||||
|
|
||||||
Expect(chunks).To(HaveLen(2))
|
Expect(chunks).To(HaveLen(2))
|
||||||
@@ -666,14 +664,14 @@ var _ = Describe("buildNoActionFinalChunks", func() {
|
|||||||
Expect(contentOf(chunks[1])).To(Equal("the answer"))
|
Expect(contentOf(chunks[1])).To(Equal("the answer"))
|
||||||
Expect(reasoningOf(chunks[1])).To(BeEmpty(),
|
Expect(reasoningOf(chunks[1])).To(BeEmpty(),
|
||||||
"reasoning must not be re-emitted if it was streamed earlier")
|
"reasoning must not be re-emitted if it was streamed earlier")
|
||||||
Expect(chunks[1].Usage.TotalTokens).To(Equal(12))
|
Expect(chunks[1].Usage).To(BeNil())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("emits role, then content+reasoning when reasoning was not streamed", func() {
|
It("emits role, then content+reasoning when reasoning was not streamed", func() {
|
||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
false, false,
|
false, false,
|
||||||
"the answer", "autoparser final reasoning", usage,
|
"the answer", "autoparser final reasoning",
|
||||||
)
|
)
|
||||||
|
|
||||||
Expect(chunks).To(HaveLen(2))
|
Expect(chunks).To(HaveLen(2))
|
||||||
@@ -681,14 +679,14 @@ var _ = Describe("buildNoActionFinalChunks", func() {
|
|||||||
|
|
||||||
Expect(contentOf(chunks[1])).To(Equal("the answer"))
|
Expect(contentOf(chunks[1])).To(Equal("the answer"))
|
||||||
Expect(reasoningOf(chunks[1])).To(Equal("autoparser final reasoning"))
|
Expect(reasoningOf(chunks[1])).To(Equal("autoparser final reasoning"))
|
||||||
Expect(chunks[1].Usage.TotalTokens).To(Equal(12))
|
Expect(chunks[1].Usage).To(BeNil())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("still emits content even when reasoning is empty", func() {
|
It("still emits content even when reasoning is empty", func() {
|
||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
false, false,
|
false, false,
|
||||||
"just an answer", "", usage,
|
"just an answer", "",
|
||||||
)
|
)
|
||||||
|
|
||||||
Expect(chunks).To(HaveLen(2))
|
Expect(chunks).To(HaveLen(2))
|
||||||
@@ -702,7 +700,7 @@ var _ = Describe("buildNoActionFinalChunks", func() {
|
|||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
false, false,
|
false, false,
|
||||||
"hi", "reasoning", usage,
|
"hi", "reasoning",
|
||||||
)
|
)
|
||||||
for i, ch := range chunks {
|
for i, ch := range chunks {
|
||||||
Expect(ch.ID).To(Equal(testID), "chunk[%d] ID", i)
|
Expect(ch.ID).To(Equal(testID), "chunk[%d] ID", i)
|
||||||
|
|||||||
362
core/http/endpoints/openai/chat_stream_usage_test.go
Normal file
362
core/http/endpoints/openai/chat_stream_usage_test.go
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/backend"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
"github.com/mudler/LocalAI/pkg/functions"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
// These tests pin LocalAI's streaming chunks to the OpenAI spec for the
|
||||||
|
// `usage` field. The regression that motivated them (issue #8546) was that
|
||||||
|
// LocalAI emitted `"usage":{...zeros...}` on every chunk, which made the
|
||||||
|
// official OpenAI Node SDK consumers (Continue, Kilo Code, Roo Code, Zed,
|
||||||
|
// IntelliJ Continue) drop every content chunk via the filter at
|
||||||
|
// continuedev/continue packages/openai-adapters/src/apis/OpenAI.ts:275-288.
|
||||||
|
//
|
||||||
|
// Per OpenAI's chat-completion streaming contract:
|
||||||
|
// - intermediate chunks MUST NOT carry a `usage` field
|
||||||
|
// - usage is only delivered when the request opts in via
|
||||||
|
// `stream_options.include_usage: true`, on a final extra chunk whose
|
||||||
|
// `choices` is an empty array.
|
||||||
|
|
||||||
|
var _ = Describe("streaming usage spec compliance", func() {
|
||||||
|
Describe("OpenAIResponse JSON shape", func() {
|
||||||
|
It("does not emit a 'usage' key when Usage is unset", func() {
|
||||||
|
// A typical intermediate token chunk: no Usage populated.
|
||||||
|
content := "hello"
|
||||||
|
resp := schema.OpenAIResponse{
|
||||||
|
ID: "req-1",
|
||||||
|
Created: 1,
|
||||||
|
Model: "m",
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Choices: []schema.Choice{{
|
||||||
|
Index: 0,
|
||||||
|
Delta: &schema.Message{Content: &content},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
var raw map[string]any
|
||||||
|
Expect(json.Unmarshal(data, &raw)).To(Succeed())
|
||||||
|
_, present := raw["usage"]
|
||||||
|
Expect(present).To(BeFalse(),
|
||||||
|
"intermediate chunk must not include a 'usage' key; got: %s", string(data))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("emits the usage object when Usage is explicitly set", func() {
|
||||||
|
usage := &schema.OpenAIUsage{PromptTokens: 11, CompletionTokens: 22, TotalTokens: 33}
|
||||||
|
resp := schema.OpenAIResponse{
|
||||||
|
ID: "req-1",
|
||||||
|
Created: 1,
|
||||||
|
Model: "m",
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Usage: usage,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
var raw map[string]any
|
||||||
|
Expect(json.Unmarshal(data, &raw)).To(Succeed())
|
||||||
|
u, ok := raw["usage"].(map[string]any)
|
||||||
|
Expect(ok).To(BeTrue(), "expected 'usage' object, got: %s", string(data))
|
||||||
|
Expect(u["prompt_tokens"]).To(BeNumerically("==", 11))
|
||||||
|
Expect(u["completion_tokens"]).To(BeNumerically("==", 22))
|
||||||
|
Expect(u["total_tokens"]).To(BeNumerically("==", 33))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("buildNoActionFinalChunks", func() {
|
||||||
|
It("returns chunks with no Usage embedded", func() {
|
||||||
|
// Whatever the caller is doing, helpers must not bake usage
|
||||||
|
// into intermediate or final delta chunks. The usage trailer
|
||||||
|
// (when requested via include_usage) is emitted separately.
|
||||||
|
chunks := buildNoActionFinalChunks(
|
||||||
|
"req-1", "m", 1,
|
||||||
|
false, false,
|
||||||
|
"hi", "",
|
||||||
|
)
|
||||||
|
Expect(chunks).ToNot(BeEmpty())
|
||||||
|
for i, ch := range chunks {
|
||||||
|
Expect(ch.Usage).To(BeNil(),
|
||||||
|
"chunk[%d] must not carry Usage; got %+v", i, ch.Usage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns chunks with no Usage when only trailing reasoning needs delivery", func() {
|
||||||
|
chunks := buildNoActionFinalChunks(
|
||||||
|
"req-1", "m", 1,
|
||||||
|
true, false,
|
||||||
|
"", "autoparser late reasoning",
|
||||||
|
)
|
||||||
|
Expect(chunks).ToNot(BeEmpty())
|
||||||
|
for i, ch := range chunks {
|
||||||
|
Expect(ch.Usage).To(BeNil(),
|
||||||
|
"chunk[%d] must not carry Usage; got %+v", i, ch.Usage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("buildDeferredToolCallChunks", func() {
|
||||||
|
It("returns chunks with no Usage embedded", func() {
|
||||||
|
calls := []functions.FuncCallResults{{
|
||||||
|
Name: "do_thing", Arguments: `{"x":1}`,
|
||||||
|
}}
|
||||||
|
chunks := buildDeferredToolCallChunks(
|
||||||
|
"req-1", "m", 1, calls, 0,
|
||||||
|
false, "", false, "",
|
||||||
|
)
|
||||||
|
Expect(chunks).ToNot(BeEmpty())
|
||||||
|
for i, ch := range chunks {
|
||||||
|
Expect(ch.Usage).To(BeNil(),
|
||||||
|
"chunk[%d] must not carry Usage; got %+v", i, ch.Usage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("streamUsageTrailerJSON", func() {
|
||||||
|
It("produces JSON matching the OpenAI spec for the trailer chunk", func() {
|
||||||
|
// Trailing usage chunk shape (OpenAI streaming spec):
|
||||||
|
// {"id":"...","object":"chat.completion.chunk","created":...,
|
||||||
|
// "model":"...","choices":[],"usage":{...}}
|
||||||
|
usage := schema.OpenAIUsage{
|
||||||
|
PromptTokens: 18, CompletionTokens: 14, TotalTokens: 32,
|
||||||
|
}
|
||||||
|
data := streamUsageTrailerJSON("req-1", "m", 1, usage)
|
||||||
|
|
||||||
|
var raw map[string]any
|
||||||
|
Expect(json.Unmarshal(data, &raw)).To(Succeed(),
|
||||||
|
"trailer must be valid JSON, got: %s", string(data))
|
||||||
|
|
||||||
|
Expect(raw["id"]).To(Equal("req-1"))
|
||||||
|
Expect(raw["model"]).To(Equal("m"))
|
||||||
|
Expect(raw["object"]).To(Equal("chat.completion.chunk"))
|
||||||
|
Expect(raw["created"]).To(BeNumerically("==", 1))
|
||||||
|
|
||||||
|
// `choices` MUST be present as an empty array (not absent, not null).
|
||||||
|
rawChoices, present := raw["choices"]
|
||||||
|
Expect(present).To(BeTrue(), "choices key must be present, got: %s", string(data))
|
||||||
|
choicesArr, ok := rawChoices.([]any)
|
||||||
|
Expect(ok).To(BeTrue(), "choices must serialize as an array, got: %s", string(data))
|
||||||
|
Expect(choicesArr).To(BeEmpty(), "choices must be empty in usage trailer, got: %s", string(data))
|
||||||
|
|
||||||
|
// `usage` MUST be present and non-null with the populated counts.
|
||||||
|
u, ok := raw["usage"].(map[string]any)
|
||||||
|
Expect(ok).To(BeTrue(), "usage object must be present, got: %s", string(data))
|
||||||
|
Expect(u["prompt_tokens"]).To(BeNumerically("==", 18))
|
||||||
|
Expect(u["completion_tokens"]).To(BeNumerically("==", 14))
|
||||||
|
Expect(u["total_tokens"]).To(BeNumerically("==", 32))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("streamUsageFromTokenUsage", func() {
|
||||||
|
It("converts backend TokenUsage to schema OpenAIUsage", func() {
|
||||||
|
tu := backend.TokenUsage{Prompt: 18, Completion: 213}
|
||||||
|
u := streamUsageFromTokenUsage(tu, false)
|
||||||
|
Expect(u.PromptTokens).To(Equal(18))
|
||||||
|
Expect(u.CompletionTokens).To(Equal(213))
|
||||||
|
Expect(u.TotalTokens).To(Equal(231))
|
||||||
|
Expect(u.TimingTokenGeneration).To(BeZero())
|
||||||
|
Expect(u.TimingPromptProcessing).To(BeZero())
|
||||||
|
})
|
||||||
|
It("includes timings when extraUsage is true", func() {
|
||||||
|
tu := backend.TokenUsage{
|
||||||
|
Prompt: 10, Completion: 20,
|
||||||
|
TimingPromptProcessing: 0.5,
|
||||||
|
TimingTokenGeneration: 1.5,
|
||||||
|
}
|
||||||
|
u := streamUsageFromTokenUsage(tu, true)
|
||||||
|
Expect(u.TimingPromptProcessing).To(Equal(0.5))
|
||||||
|
Expect(u.TimingTokenGeneration).To(Equal(1.5))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("OpenAIRequest.StreamOptions", func() {
|
||||||
|
It("parses stream_options.include_usage=true", func() {
|
||||||
|
body := []byte(`{
|
||||||
|
"model": "m",
|
||||||
|
"stream": true,
|
||||||
|
"stream_options": {"include_usage": true},
|
||||||
|
"messages": []
|
||||||
|
}`)
|
||||||
|
var req schema.OpenAIRequest
|
||||||
|
Expect(json.Unmarshal(body, &req)).To(Succeed())
|
||||||
|
Expect(req.StreamOptions).ToNot(BeNil())
|
||||||
|
Expect(req.StreamOptions.IncludeUsage).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("defaults IncludeUsage to false when stream_options is absent", func() {
|
||||||
|
body := []byte(`{"model":"m","stream":true,"messages":[]}`)
|
||||||
|
var req schema.OpenAIRequest
|
||||||
|
Expect(json.Unmarshal(body, &req)).To(Succeed())
|
||||||
|
// Either a nil StreamOptions or one with IncludeUsage=false is acceptable.
|
||||||
|
if req.StreamOptions != nil {
|
||||||
|
Expect(req.StreamOptions.IncludeUsage).To(BeFalse())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Functional regression coverage for issue #9927: the streaming workers
|
||||||
|
// must surface the cumulative TokenUsage returned by ComputeChoices to
|
||||||
|
// their caller. The earlier broken implementations discarded that value
|
||||||
|
// (`_, _, chatDeltas, err := ComputeChoices(...)`) and threw away the
|
||||||
|
// counts on the floor, so the include_usage trailer always reported
|
||||||
|
// zeros when tools were enabled.
|
||||||
|
//
|
||||||
|
// These tests stub backend.ModelInferenceFunc so the worker exercises the
|
||||||
|
// real ComputeChoices → predFunc → LLMResponse pipeline. If a future change
|
||||||
|
// drops the TokenUsage somewhere along that path, the assertions on the
|
||||||
|
// returned value fail with a concrete count mismatch (e.g. 0 vs 213),
|
||||||
|
// not with a "function undefined" compile error.
|
||||||
|
var _ = Describe("streaming workers surface final TokenUsage (issue #9927)", func() {
|
||||||
|
var (
|
||||||
|
origInference modelInferenceFunc
|
||||||
|
appCfg *config.ApplicationConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
origInference = backend.ModelInferenceFunc
|
||||||
|
appCfg = config.NewApplicationConfig()
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
backend.ModelInferenceFunc = origInference
|
||||||
|
})
|
||||||
|
|
||||||
|
// mockBackendUsage installs a stub backend that yields one LLMResponse
|
||||||
|
// carrying the supplied TokenUsage. ComputeChoices' single-attempt path
|
||||||
|
// copies these counts into the value it returns to the worker.
|
||||||
|
mockBackendUsage := func(usage backend.TokenUsage, response string) {
|
||||||
|
backend.ModelInferenceFunc = func(
|
||||||
|
ctx context.Context, s string, messages schema.Messages,
|
||||||
|
images, videos, audios []string,
|
||||||
|
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
|
||||||
|
o *config.ApplicationConfig,
|
||||||
|
tokenCallback func(string, backend.TokenUsage) bool,
|
||||||
|
tools, toolChoice string,
|
||||||
|
logprobs, topLogprobs *int,
|
||||||
|
logitBias map[string]float64,
|
||||||
|
metadata map[string]string,
|
||||||
|
) (func() (backend.LLMResponse, error), error) {
|
||||||
|
return func() (backend.LLMResponse, error) {
|
||||||
|
return backend.LLMResponse{
|
||||||
|
Response: response,
|
||||||
|
Usage: usage,
|
||||||
|
}, nil
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
makeReq := func() *schema.OpenAIRequest {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
req := &schema.OpenAIRequest{
|
||||||
|
Context: ctx,
|
||||||
|
Cancel: cancel,
|
||||||
|
}
|
||||||
|
req.Model = "test-model" // promoted from BasicModelRequest
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
// drainResponses consumes everything the worker pushes onto the channel
|
||||||
|
// so the worker is never blocked on its send. The channel is unbuffered
|
||||||
|
// (matching production), so the drain goroutine must be running before
|
||||||
|
// the worker is called.
|
||||||
|
drainResponses := func(ch <-chan schema.OpenAIResponse) <-chan struct{} {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
for range ch {
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
return done
|
||||||
|
}
|
||||||
|
|
||||||
|
Describe("processStream (no-tools path)", func() {
|
||||||
|
It("returns the cumulative TokenUsage produced by the backend", func() {
|
||||||
|
mockBackendUsage(backend.TokenUsage{Prompt: 18, Completion: 213}, "Hello there")
|
||||||
|
|
||||||
|
req := makeReq()
|
||||||
|
cfg := &config.ModelConfig{}
|
||||||
|
responses := make(chan schema.OpenAIResponse)
|
||||||
|
done := drainResponses(responses)
|
||||||
|
|
||||||
|
actual, err := processStream("prompt", req, cfg, nil, appCfg, nil, responses, "req-1", 0)
|
||||||
|
<-done
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(actual.Prompt).To(Equal(18),
|
||||||
|
"prompt tokens must round-trip from backend through processStream")
|
||||||
|
Expect(actual.Completion).To(Equal(213),
|
||||||
|
"completion tokens must round-trip from backend through processStream")
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns zero TokenUsage when the backend reports zero (negative control)", func() {
|
||||||
|
mockBackendUsage(backend.TokenUsage{}, "x")
|
||||||
|
|
||||||
|
req := makeReq()
|
||||||
|
cfg := &config.ModelConfig{}
|
||||||
|
responses := make(chan schema.OpenAIResponse)
|
||||||
|
done := drainResponses(responses)
|
||||||
|
|
||||||
|
actual, err := processStream("prompt", req, cfg, nil, appCfg, nil, responses, "req-1", 0)
|
||||||
|
<-done
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(actual.Prompt).To(BeZero())
|
||||||
|
Expect(actual.Completion).To(BeZero())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("processStreamWithTools (tools path)", func() {
|
||||||
|
It("returns the cumulative TokenUsage produced by the backend", func() {
|
||||||
|
// This is the direct regression check for issue #9927: with tools
|
||||||
|
// enabled, the trailer was reporting {0,0,0} because the worker
|
||||||
|
// discarded ComputeChoices' second return value.
|
||||||
|
mockBackendUsage(backend.TokenUsage{Prompt: 18, Completion: 213}, "answer")
|
||||||
|
|
||||||
|
req := makeReq()
|
||||||
|
cfg := &config.ModelConfig{}
|
||||||
|
responses := make(chan schema.OpenAIResponse)
|
||||||
|
done := drainResponses(responses)
|
||||||
|
var textContent string
|
||||||
|
|
||||||
|
actual, err := processStreamWithTools("none", "prompt", req, cfg, nil, appCfg, nil, responses, "req-1", 0, &textContent)
|
||||||
|
<-done
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(actual.Prompt).To(Equal(18),
|
||||||
|
"prompt tokens must round-trip from backend through processStreamWithTools (issue #9927)")
|
||||||
|
Expect(actual.Completion).To(Equal(213),
|
||||||
|
"completion tokens must round-trip from backend through processStreamWithTools (issue #9927)")
|
||||||
|
})
|
||||||
|
|
||||||
|
It("forwards timing fields when the backend supplies them", func() {
|
||||||
|
mockBackendUsage(backend.TokenUsage{
|
||||||
|
Prompt: 10, Completion: 20,
|
||||||
|
TimingPromptProcessing: 0.5,
|
||||||
|
TimingTokenGeneration: 1.5,
|
||||||
|
}, "answer")
|
||||||
|
|
||||||
|
req := makeReq()
|
||||||
|
cfg := &config.ModelConfig{}
|
||||||
|
responses := make(chan schema.OpenAIResponse)
|
||||||
|
done := drainResponses(responses)
|
||||||
|
var textContent string
|
||||||
|
|
||||||
|
actual, err := processStreamWithTools("none", "prompt", req, cfg, nil, appCfg, nil, responses, "req-1", 0, &textContent)
|
||||||
|
<-done
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(actual.TimingPromptProcessing).To(Equal(0.5))
|
||||||
|
Expect(actual.TimingTokenGeneration).To(Equal(1.5))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
390
core/http/endpoints/openai/chat_stream_workers.go
Normal file
390
core/http/endpoints/openai/chat_stream_workers.go
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/backend"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
"github.com/mudler/LocalAI/pkg/functions"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
reason "github.com/mudler/LocalAI/pkg/reasoning"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// processStream is the streaming worker for chat completions with no
|
||||||
|
// tool/function calling involved. It pushes SSE-shaped chunks onto
|
||||||
|
// `responses` and returns the authoritative cumulative TokenUsage from
|
||||||
|
// the prediction so the caller can populate the include_usage trailer
|
||||||
|
// without having to peek inside the chunks.
|
||||||
|
//
|
||||||
|
// The caller owns the `responses` channel and is expected to read from
|
||||||
|
// it while this function runs; processStream closes the channel before
|
||||||
|
// returning.
|
||||||
|
func processStream(
|
||||||
|
s string,
|
||||||
|
req *schema.OpenAIRequest,
|
||||||
|
cfg *config.ModelConfig,
|
||||||
|
cl *config.ModelConfigLoader,
|
||||||
|
startupOptions *config.ApplicationConfig,
|
||||||
|
loader *model.ModelLoader,
|
||||||
|
responses chan schema.OpenAIResponse,
|
||||||
|
id string,
|
||||||
|
created int,
|
||||||
|
) (backend.TokenUsage, error) {
|
||||||
|
responses <- schema.OpenAIResponse{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect if thinking token is already in prompt or template
|
||||||
|
// When UseTokenizerTemplate is enabled, predInput is empty, so we check the template
|
||||||
|
var template string
|
||||||
|
if cfg.TemplateConfig.UseTokenizerTemplate {
|
||||||
|
template = cfg.GetModelTemplate()
|
||||||
|
} else {
|
||||||
|
template = s
|
||||||
|
}
|
||||||
|
thinkingStartToken := reason.DetectThinkingStartToken(template, &cfg.ReasoningConfig)
|
||||||
|
extractor := reason.NewReasoningExtractor(thinkingStartToken, cfg.ReasoningConfig)
|
||||||
|
|
||||||
|
_, finalUsage, _, err := ComputeChoices(req, s, cfg, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||||
|
var reasoningDelta, contentDelta string
|
||||||
|
|
||||||
|
// Always keep the Go-side extractor in sync with raw tokens so it
|
||||||
|
// can serve as fallback for backends without an autoparser (e.g. vLLM).
|
||||||
|
goReasoning, goContent := extractor.ProcessToken(s)
|
||||||
|
|
||||||
|
// When C++ autoparser chat deltas are available, prefer them: they
|
||||||
|
// handle model-specific formats (Gemma 4, etc.) without Go-side tags.
|
||||||
|
// Otherwise fall back to Go-side extraction.
|
||||||
|
if tokenUsage.HasChatDeltaContent() {
|
||||||
|
rawReasoning, cd := tokenUsage.ChatDeltaReasoningAndContent()
|
||||||
|
contentDelta = cd
|
||||||
|
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||||
|
} else {
|
||||||
|
reasoningDelta = goReasoning
|
||||||
|
contentDelta = goContent
|
||||||
|
}
|
||||||
|
|
||||||
|
delta := &schema.Message{}
|
||||||
|
if contentDelta != "" {
|
||||||
|
delta.Content = &contentDelta
|
||||||
|
}
|
||||||
|
if reasoningDelta != "" {
|
||||||
|
delta.Reasoning = &reasoningDelta
|
||||||
|
}
|
||||||
|
|
||||||
|
responses <- schema.OpenAIResponse{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
close(responses)
|
||||||
|
return finalUsage, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// processStreamWithTools is the streaming worker for chat completions
|
||||||
|
// with tools / function calling. Same contract as processStream: pushes
|
||||||
|
// chunks onto `responses`, closes the channel, returns the cumulative
|
||||||
|
// TokenUsage.
|
||||||
|
//
|
||||||
|
// Returning the TokenUsage as a normal Go value (rather than smuggling
|
||||||
|
// it on a sentinel chunk) is the fix for issue #9927 — the previous
|
||||||
|
// implementation discarded the value from ComputeChoices, so the
|
||||||
|
// include_usage trailer reported zeros whenever `tools` was in play.
|
||||||
|
func processStreamWithTools(
|
||||||
|
noAction string,
|
||||||
|
prompt string,
|
||||||
|
req *schema.OpenAIRequest,
|
||||||
|
cfg *config.ModelConfig,
|
||||||
|
cl *config.ModelConfigLoader,
|
||||||
|
startupOptions *config.ApplicationConfig,
|
||||||
|
loader *model.ModelLoader,
|
||||||
|
responses chan schema.OpenAIResponse,
|
||||||
|
id string,
|
||||||
|
created int,
|
||||||
|
textContentToReturn *string,
|
||||||
|
) (backend.TokenUsage, error) {
|
||||||
|
// Detect if thinking token is already in prompt or template
|
||||||
|
var template string
|
||||||
|
if cfg.TemplateConfig.UseTokenizerTemplate {
|
||||||
|
template = cfg.GetModelTemplate()
|
||||||
|
} else {
|
||||||
|
template = prompt
|
||||||
|
}
|
||||||
|
thinkingStartToken := reason.DetectThinkingStartToken(template, &cfg.ReasoningConfig)
|
||||||
|
extractor := reason.NewReasoningExtractor(thinkingStartToken, cfg.ReasoningConfig)
|
||||||
|
|
||||||
|
result := ""
|
||||||
|
lastEmittedCount := 0
|
||||||
|
sentInitialRole := false
|
||||||
|
sentReasoning := false
|
||||||
|
hasChatDeltaToolCalls := false
|
||||||
|
hasChatDeltaContent := false
|
||||||
|
|
||||||
|
_, finalUsage, chatDeltas, err := ComputeChoices(req, prompt, cfg, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||||
|
result += s
|
||||||
|
|
||||||
|
// Track whether ChatDeltas from the C++ autoparser contain
|
||||||
|
// tool calls or content, so the retry decision can account for them.
|
||||||
|
for _, d := range usage.ChatDeltas {
|
||||||
|
if len(d.ToolCalls) > 0 {
|
||||||
|
hasChatDeltaToolCalls = true
|
||||||
|
}
|
||||||
|
if d.Content != "" {
|
||||||
|
hasChatDeltaContent = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var reasoningDelta, contentDelta string
|
||||||
|
|
||||||
|
goReasoning, goContent := extractor.ProcessToken(s)
|
||||||
|
|
||||||
|
if usage.HasChatDeltaContent() {
|
||||||
|
rawReasoning, cd := usage.ChatDeltaReasoningAndContent()
|
||||||
|
contentDelta = cd
|
||||||
|
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||||
|
} else {
|
||||||
|
reasoningDelta = goReasoning
|
||||||
|
contentDelta = goContent
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit reasoning deltas in their own SSE chunks before any tool-call chunks
|
||||||
|
// (OpenAI spec: reasoning and tool_calls never share a delta)
|
||||||
|
if reasoningDelta != "" {
|
||||||
|
responses <- schema.OpenAIResponse{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: req.Model,
|
||||||
|
Choices: []schema.Choice{{
|
||||||
|
Delta: &schema.Message{Reasoning: &reasoningDelta},
|
||||||
|
Index: 0,
|
||||||
|
}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
sentReasoning = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream content deltas (cleaned of reasoning tags) while no tool calls
|
||||||
|
// have been detected. Once the incremental parser finds tool calls,
|
||||||
|
// content stops: per OpenAI spec, content and tool_calls don't mix.
|
||||||
|
if lastEmittedCount == 0 && contentDelta != "" {
|
||||||
|
if !sentInitialRole {
|
||||||
|
responses <- schema.OpenAIResponse{
|
||||||
|
ID: id, Created: created, Model: req.Model,
|
||||||
|
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
sentInitialRole = true
|
||||||
|
}
|
||||||
|
responses <- schema.OpenAIResponse{
|
||||||
|
ID: id, Created: created, Model: req.Model,
|
||||||
|
Choices: []schema.Choice{{
|
||||||
|
Delta: &schema.Message{Content: &contentDelta},
|
||||||
|
Index: 0,
|
||||||
|
}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try incremental XML parsing for streaming support using iterative parser
|
||||||
|
// This allows emitting partial tool calls as they're being generated
|
||||||
|
cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
|
||||||
|
|
||||||
|
// Determine XML format from config
|
||||||
|
var xmlFormat *functions.XMLToolCallFormat
|
||||||
|
if cfg.FunctionsConfig.XMLFormat != nil {
|
||||||
|
xmlFormat = cfg.FunctionsConfig.XMLFormat
|
||||||
|
} else if cfg.FunctionsConfig.XMLFormatPreset != "" {
|
||||||
|
xmlFormat = functions.GetXMLFormatPreset(cfg.FunctionsConfig.XMLFormatPreset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use iterative parser for streaming (partial parsing enabled)
|
||||||
|
// Try XML parsing first
|
||||||
|
partialResults, parseErr := functions.ParseXMLIterative(cleanedResult, xmlFormat, true)
|
||||||
|
if parseErr == nil && len(partialResults) > 0 {
|
||||||
|
// Emit new XML tool calls that weren't emitted before
|
||||||
|
if len(partialResults) > lastEmittedCount {
|
||||||
|
for i := lastEmittedCount; i < len(partialResults); i++ {
|
||||||
|
toolCall := partialResults[i]
|
||||||
|
initialMessage := schema.OpenAIResponse{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: req.Model,
|
||||||
|
Choices: []schema.Choice{{
|
||||||
|
Delta: &schema.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []schema.ToolCall{
|
||||||
|
{
|
||||||
|
Index: i,
|
||||||
|
ID: id,
|
||||||
|
Type: "function",
|
||||||
|
FunctionCall: schema.FunctionCall{
|
||||||
|
Name: toolCall.Name,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Index: 0,
|
||||||
|
FinishReason: nil,
|
||||||
|
}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case responses <- initialMessage:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lastEmittedCount = len(partialResults)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Try JSON tool call parsing for streaming.
|
||||||
|
// Only emit NEW tool calls (same guard as XML parser above).
|
||||||
|
jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true)
|
||||||
|
if jsonErr == nil && len(jsonResults) > lastEmittedCount {
|
||||||
|
for i := lastEmittedCount; i < len(jsonResults); i++ {
|
||||||
|
jsonObj := jsonResults[i]
|
||||||
|
name, ok := jsonObj["name"].(string)
|
||||||
|
if !ok || name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
args := "{}"
|
||||||
|
if argsVal, ok := jsonObj["arguments"]; ok {
|
||||||
|
if argsStr, ok := argsVal.(string); ok {
|
||||||
|
args = argsStr
|
||||||
|
} else {
|
||||||
|
argsBytes, _ := json.Marshal(argsVal)
|
||||||
|
args = string(argsBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
initialMessage := schema.OpenAIResponse{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: req.Model,
|
||||||
|
Choices: []schema.Choice{{
|
||||||
|
Delta: &schema.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []schema.ToolCall{
|
||||||
|
{
|
||||||
|
Index: i,
|
||||||
|
ID: id,
|
||||||
|
Type: "function",
|
||||||
|
FunctionCall: schema.FunctionCall{
|
||||||
|
Name: name,
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Index: 0,
|
||||||
|
FinishReason: nil,
|
||||||
|
}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
responses <- initialMessage
|
||||||
|
}
|
||||||
|
lastEmittedCount = len(jsonResults)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
func(attempt int) bool {
|
||||||
|
// After streaming completes: check if we got actionable content
|
||||||
|
cleaned := extractor.CleanedContent()
|
||||||
|
// Check for tool calls from chat deltas (will be re-checked after ComputeChoices,
|
||||||
|
// but we need to know here whether to retry).
|
||||||
|
// Also check ChatDelta flags: when the C++ autoparser is active,
|
||||||
|
// tool calls and content are delivered via ChatDeltas while the
|
||||||
|
// raw message is cleared. Without this check, we'd retry
|
||||||
|
// unnecessarily, losing valid results and concatenating output.
|
||||||
|
hasToolCalls := lastEmittedCount > 0 || hasChatDeltaToolCalls
|
||||||
|
hasContent := cleaned != "" || hasChatDeltaContent
|
||||||
|
if !hasContent && !hasToolCalls {
|
||||||
|
xlog.Warn("Streaming: backend produced only reasoning, retrying",
|
||||||
|
"reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1)
|
||||||
|
extractor.ResetAndSuppressReasoning()
|
||||||
|
result = ""
|
||||||
|
lastEmittedCount = 0
|
||||||
|
sentInitialRole = false
|
||||||
|
hasChatDeltaToolCalls = false
|
||||||
|
hasChatDeltaContent = false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return finalUsage, err
|
||||||
|
}
|
||||||
|
// Try using pre-parsed tool calls from C++ autoparser (chat deltas)
|
||||||
|
var functionResults []functions.FuncCallResults
|
||||||
|
var reasoning string
|
||||||
|
|
||||||
|
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||||
|
xlog.Debug("[ChatDeltas] Using pre-parsed tool calls from C++ autoparser", "count", len(deltaToolCalls))
|
||||||
|
functionResults = deltaToolCalls
|
||||||
|
// Use content/reasoning from deltas too
|
||||||
|
*textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
|
||||||
|
reasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
||||||
|
} else {
|
||||||
|
// Fallback: parse tool calls from raw text (no chat deltas from backend)
|
||||||
|
xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||||
|
reasoning = extractor.Reasoning()
|
||||||
|
cleanedResult := extractor.CleanedContent()
|
||||||
|
*textContentToReturn = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
|
||||||
|
cleanedResult = functions.CleanupLLMResult(cleanedResult, cfg.FunctionsConfig)
|
||||||
|
functionResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
|
||||||
|
}
|
||||||
|
xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", *textContentToReturn)
|
||||||
|
// noAction is a sentinel "just answer" pseudo-function: not a real
|
||||||
|
// tool call. Scan the whole slice rather than only index 0 so we
|
||||||
|
// don't drop a real tool call that happens to follow a noAction
|
||||||
|
// entry, and so the default branch isn't entered with only noAction
|
||||||
|
// entries to emit as tool_calls.
|
||||||
|
noActionToRun := !hasRealCall(functionResults, noAction)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case noActionToRun:
|
||||||
|
// The final usage trailer (when the caller opted in with
|
||||||
|
// stream_options.include_usage) is built by the outer streaming
|
||||||
|
// loop from the TokenUsage this function returns, not from any
|
||||||
|
// chunk on the responses channel.
|
||||||
|
var result string
|
||||||
|
if !sentInitialRole {
|
||||||
|
var hqErr error
|
||||||
|
result, hqErr = handleQuestion(cfg, functionResults, extractor.CleanedContent(), prompt)
|
||||||
|
if hqErr != nil {
|
||||||
|
xlog.Error("error handling question", "error", hqErr)
|
||||||
|
return finalUsage, hqErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, chunk := range buildNoActionFinalChunks(
|
||||||
|
id, req.Model, created,
|
||||||
|
sentInitialRole, sentReasoning,
|
||||||
|
result, reasoning,
|
||||||
|
) {
|
||||||
|
responses <- chunk
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
for _, chunk := range buildDeferredToolCallChunks(
|
||||||
|
id, req.Model, created,
|
||||||
|
functionResults, lastEmittedCount,
|
||||||
|
sentInitialRole, *textContentToReturn,
|
||||||
|
sentReasoning, reasoning,
|
||||||
|
) {
|
||||||
|
responses <- chunk
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
close(responses)
|
||||||
|
return finalUsage, err
|
||||||
|
}
|
||||||
@@ -39,6 +39,10 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
|||||||
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
|
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
|
||||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
||||||
}
|
}
|
||||||
|
// Usage rides on the struct for the consumer to track the
|
||||||
|
// running cumulative; the consumer strips it before marshalling
|
||||||
|
// so intermediate chunks stay OpenAI-spec compliant.
|
||||||
|
usageForChunk := usage
|
||||||
resp := schema.OpenAIResponse{
|
resp := schema.OpenAIResponse{
|
||||||
ID: id,
|
ID: id,
|
||||||
Created: created,
|
Created: created,
|
||||||
@@ -51,7 +55,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Object: "text_completion",
|
Object: "text_completion",
|
||||||
Usage: usage,
|
Usage: &usageForChunk,
|
||||||
}
|
}
|
||||||
xlog.Debug("Sending goroutine", "text", s)
|
xlog.Debug("Sending goroutine", "text", s)
|
||||||
|
|
||||||
@@ -127,6 +131,8 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
|||||||
ended <- process(id, predInput, input, config, ml, responses, extraUsage)
|
ended <- process(id, predInput, input, config, ml, responses, extraUsage)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var latestUsage *schema.OpenAIUsage
|
||||||
|
|
||||||
LOOP:
|
LOOP:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -135,6 +141,14 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
|||||||
xlog.Debug("No choices in the response, skipping")
|
xlog.Debug("No choices in the response, skipping")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// Capture running cumulative usage for the optional trailer
|
||||||
|
// emitted after the final stop chunk when include_usage=true.
|
||||||
|
if ev.Usage != nil {
|
||||||
|
latestUsage = ev.Usage
|
||||||
|
}
|
||||||
|
// OpenAI streaming spec: intermediate chunks must NOT
|
||||||
|
// carry a `usage` field. Strip the tracking copy now.
|
||||||
|
ev.Usage = nil
|
||||||
respData, err := json.Marshal(ev)
|
respData, err := json.Marshal(ev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Debug("Failed to marshal response", "error", err)
|
xlog.Debug("Failed to marshal response", "error", err)
|
||||||
@@ -194,8 +208,15 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
|||||||
Object: "text_completion",
|
Object: "text_completion",
|
||||||
}
|
}
|
||||||
respData, _ := json.Marshal(resp)
|
respData, _ := json.Marshal(resp)
|
||||||
|
|
||||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||||
|
|
||||||
|
// Trailing usage chunk per OpenAI spec: emit only when the caller
|
||||||
|
// opted in via stream_options.include_usage.
|
||||||
|
if input.StreamOptions != nil && input.StreamOptions.IncludeUsage && latestUsage != nil {
|
||||||
|
trailer := streamUsageTrailerJSON(id, input.Model, created, *latestUsage)
|
||||||
|
_, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", trailer)
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||||
c.Response().Flush()
|
c.Response().Flush()
|
||||||
return nil
|
return nil
|
||||||
@@ -247,7 +268,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
|||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
Choices: result,
|
Choices: result,
|
||||||
Object: "text_completion",
|
Object: "text_completion",
|
||||||
Usage: usage,
|
Usage: &usage,
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp)
|
jsonResult, _ := json.Marshal(resp)
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
Choices: result,
|
Choices: result,
|
||||||
Object: "edit",
|
Object: "edit",
|
||||||
Usage: usage,
|
Usage: &usage,
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp)
|
jsonResult, _ := json.Marshal(resp)
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
|||||||
ID: id,
|
ID: id,
|
||||||
Created: created,
|
Created: created,
|
||||||
Data: result,
|
Data: result,
|
||||||
Usage: schema.OpenAIUsage{
|
Usage: &schema.OpenAIUsage{
|
||||||
PromptTokens: 0,
|
PromptTokens: 0,
|
||||||
CompletionTokens: 0,
|
CompletionTokens: 0,
|
||||||
TotalTokens: 0,
|
TotalTokens: 0,
|
||||||
|
|||||||
@@ -258,7 +258,7 @@ func InpaintingEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
|||||||
Data: []schema.Item{{
|
Data: []schema.Item{{
|
||||||
URL: imgPath,
|
URL: imgPath,
|
||||||
}},
|
}},
|
||||||
Usage: schema.OpenAIUsage{
|
Usage: &schema.OpenAIUsage{
|
||||||
PromptTokens: 0,
|
PromptTokens: 0,
|
||||||
CompletionTokens: 0,
|
CompletionTokens: 0,
|
||||||
TotalTokens: 0,
|
TotalTokens: 0,
|
||||||
|
|||||||
@@ -54,6 +54,30 @@ const (
|
|||||||
"Avoid parenthetical asides, URLs, and anything that cannot be clearly vocalized."
|
"Avoid parenthetical asides, URLs, and anything that cannot be clearly vocalized."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// resolveOutputModalities returns the effective output modalities for a
|
||||||
|
// response: response-level overrides session-level, and the OpenAI Realtime
|
||||||
|
// spec default is ["audio"] when neither is set.
|
||||||
|
func resolveOutputModalities(session, response []types.Modality) []types.Modality {
|
||||||
|
if len(response) > 0 {
|
||||||
|
return response
|
||||||
|
}
|
||||||
|
if len(session) > 0 {
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
return []types.Modality{types.ModalityAudio}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modalitiesContainAudio reports whether the resolved modalities include audio
|
||||||
|
// output.
|
||||||
|
func modalitiesContainAudio(m []types.Modality) bool {
|
||||||
|
for _, x := range m {
|
||||||
|
if x == types.ModalityAudio {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// A model can be "emulated" that is: transcribe audio to text -> feed text to the LLM -> generate audio as result
|
// A model can be "emulated" that is: transcribe audio to text -> feed text to the LLM -> generate audio as result
|
||||||
// If the model support instead audio-to-audio, we will use the specific gRPC calls instead
|
// If the model support instead audio-to-audio, we will use the specific gRPC calls instead
|
||||||
|
|
||||||
@@ -82,6 +106,10 @@ type Session struct {
|
|||||||
InputSampleRate int
|
InputSampleRate int
|
||||||
OutputSampleRate int
|
OutputSampleRate int
|
||||||
MaxOutputTokens types.IntOrInf
|
MaxOutputTokens types.IntOrInf
|
||||||
|
// OutputModalities mirrors the OpenAI Realtime spec field of the same
|
||||||
|
// name. Empty means "use the spec default" (audio). ["text"] suppresses
|
||||||
|
// TTS so the client receives only response.output_text.* events.
|
||||||
|
OutputModalities []types.Modality
|
||||||
// MaxHistoryItems caps the number of MessageItems passed to the LLM each
|
// MaxHistoryItems caps the number of MessageItems passed to the LLM each
|
||||||
// turn (0 = unlimited). Small models — especially the LFM2.5-Audio 1.5B
|
// turn (0 = unlimited). Small models — especially the LFM2.5-Audio 1.5B
|
||||||
// served via the liquid-audio backend — degrade quickly past a handful
|
// served via the liquid-audio backend — degrade quickly past a handful
|
||||||
@@ -162,13 +190,14 @@ func (s *Session) ToServer() types.SessionUnion {
|
|||||||
} else {
|
} else {
|
||||||
return types.SessionUnion{
|
return types.SessionUnion{
|
||||||
Realtime: &types.RealtimeSession{
|
Realtime: &types.RealtimeSession{
|
||||||
ID: s.ID,
|
ID: s.ID,
|
||||||
Object: "realtime.session",
|
Object: "realtime.session",
|
||||||
Model: s.Model,
|
Model: s.Model,
|
||||||
Instructions: s.Instructions,
|
Instructions: s.Instructions,
|
||||||
Tools: s.Tools,
|
Tools: s.Tools,
|
||||||
ToolChoice: s.ToolChoice,
|
ToolChoice: s.ToolChoice,
|
||||||
MaxOutputTokens: s.MaxOutputTokens,
|
MaxOutputTokens: s.MaxOutputTokens,
|
||||||
|
OutputModalities: s.OutputModalities,
|
||||||
Audio: &types.RealtimeSessionAudio{
|
Audio: &types.RealtimeSessionAudio{
|
||||||
Input: &types.SessionAudioInput{
|
Input: &types.SessionAudioInput{
|
||||||
TurnDetection: s.TurnDetection,
|
TurnDetection: s.TurnDetection,
|
||||||
@@ -1015,6 +1044,10 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
|
|||||||
session.MaxOutputTokens = rt.MaxOutputTokens
|
session.MaxOutputTokens = rt.MaxOutputTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(rt.OutputModalities) > 0 {
|
||||||
|
session.OutputModalities = rt.OutputModalities
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1654,106 +1687,130 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for cancellation before TTS
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
xlog.Debug("Response cancelled before TTS (barge-in)")
|
|
||||||
sendCancelledResponse()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
audioFilePath, res, err := session.ModelInterface.TTS(ctx, finalSpeech, session.Voice, session.InputAudioTranscription.Language)
|
|
||||||
if err != nil {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
xlog.Debug("TTS cancelled (barge-in)")
|
|
||||||
sendCancelledResponse()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
xlog.Error("TTS failed", "error", err)
|
|
||||||
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !res.Success {
|
|
||||||
xlog.Error("TTS failed", "message", res.Message)
|
|
||||||
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer os.Remove(audioFilePath)
|
|
||||||
|
|
||||||
audioBytes, err := os.ReadFile(audioFilePath)
|
|
||||||
if err != nil {
|
|
||||||
xlog.Error("failed to read TTS file", "error", err)
|
|
||||||
sendError(t, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse WAV header to get raw PCM and the actual sample rate from the TTS backend.
|
|
||||||
pcmData, ttsSampleRate := laudio.ParseWAV(audioBytes)
|
|
||||||
if ttsSampleRate == 0 {
|
|
||||||
ttsSampleRate = localSampleRate
|
|
||||||
}
|
|
||||||
xlog.Debug("TTS audio parsed", "raw_bytes", len(audioBytes), "pcm_bytes", len(pcmData), "sample_rate", ttsSampleRate)
|
|
||||||
|
|
||||||
// SendAudio (WebRTC) passes PCM at the TTS sample rate directly to the
|
|
||||||
// Opus encoder, which resamples to 48kHz internally. This avoids a
|
|
||||||
// lossy intermediate resample through 16kHz.
|
|
||||||
// XXX: This is a noop in websocket mode; it's included in the JSON instead
|
|
||||||
if err := t.SendAudio(ctx, pcmData, ttsSampleRate); err != nil {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
xlog.Debug("Audio playback cancelled (barge-in)")
|
|
||||||
sendCancelledResponse()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
xlog.Error("failed to send audio via transport", "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, isWebRTC := t.(*WebRTCTransport)
|
|
||||||
|
|
||||||
// For WebSocket clients, resample to the session's output rate and
|
|
||||||
// deliver audio as base64 in JSON events. WebRTC clients already
|
|
||||||
// received audio over the RTP track, so skip the base64 payload.
|
|
||||||
var audioString string
|
var audioString string
|
||||||
if !isWebRTC {
|
_, isWebRTC := t.(*WebRTCTransport)
|
||||||
wsPCM := pcmData
|
var respMods []types.Modality
|
||||||
if ttsSampleRate != session.OutputSampleRate {
|
if overrides != nil {
|
||||||
samples := sound.BytesToInt16sLE(pcmData)
|
respMods = overrides.OutputModalities
|
||||||
resampled := sound.ResampleInt16(samples, ttsSampleRate, session.OutputSampleRate)
|
|
||||||
wsPCM = sound.Int16toBytesLE(resampled)
|
|
||||||
}
|
|
||||||
audioString = base64.StdEncoding.EncodeToString(wsPCM)
|
|
||||||
}
|
}
|
||||||
|
modalities := resolveOutputModalities(session.OutputModalities, respMods)
|
||||||
|
if modalitiesContainAudio(modalities) {
|
||||||
|
// Check for cancellation before TTS
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
xlog.Debug("Response cancelled before TTS (barge-in)")
|
||||||
|
sendCancelledResponse()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
sendEvent(t, types.ResponseOutputAudioTranscriptDeltaEvent{
|
audioFilePath, res, err := session.ModelInterface.TTS(ctx, finalSpeech, session.Voice, session.InputAudioTranscription.Language)
|
||||||
ServerEventBase: types.ServerEventBase{},
|
if err != nil {
|
||||||
ResponseID: responseID,
|
if ctx.Err() != nil {
|
||||||
ItemID: item.Assistant.ID,
|
xlog.Debug("TTS cancelled (barge-in)")
|
||||||
OutputIndex: 0,
|
sendCancelledResponse()
|
||||||
ContentIndex: 0,
|
return
|
||||||
Delta: finalSpeech,
|
}
|
||||||
})
|
xlog.Error("TTS failed", "error", err)
|
||||||
sendEvent(t, types.ResponseOutputAudioTranscriptDoneEvent{
|
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID)
|
||||||
ServerEventBase: types.ServerEventBase{},
|
return
|
||||||
ResponseID: responseID,
|
}
|
||||||
ItemID: item.Assistant.ID,
|
if !res.Success {
|
||||||
OutputIndex: 0,
|
xlog.Error("TTS failed", "message", res.Message)
|
||||||
ContentIndex: 0,
|
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID)
|
||||||
Transcript: finalSpeech,
|
return
|
||||||
})
|
}
|
||||||
|
defer func() { _ = os.Remove(audioFilePath) }()
|
||||||
|
|
||||||
if !isWebRTC {
|
audioBytes, err := os.ReadFile(audioFilePath)
|
||||||
sendEvent(t, types.ResponseOutputAudioDeltaEvent{
|
if err != nil {
|
||||||
|
xlog.Error("failed to read TTS file", "error", err)
|
||||||
|
sendError(t, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse WAV header to get raw PCM and the actual sample rate from the TTS backend.
|
||||||
|
pcmData, ttsSampleRate := laudio.ParseWAV(audioBytes)
|
||||||
|
if ttsSampleRate == 0 {
|
||||||
|
ttsSampleRate = localSampleRate
|
||||||
|
}
|
||||||
|
xlog.Debug("TTS audio parsed", "raw_bytes", len(audioBytes), "pcm_bytes", len(pcmData), "sample_rate", ttsSampleRate)
|
||||||
|
|
||||||
|
// SendAudio (WebRTC) passes PCM at the TTS sample rate directly to the
|
||||||
|
// Opus encoder, which resamples to 48kHz internally. This avoids a
|
||||||
|
// lossy intermediate resample through 16kHz.
|
||||||
|
// XXX: This is a noop in websocket mode; it's included in the JSON instead
|
||||||
|
if err := t.SendAudio(ctx, pcmData, ttsSampleRate); err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
xlog.Debug("Audio playback cancelled (barge-in)")
|
||||||
|
sendCancelledResponse()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
xlog.Error("failed to send audio via transport", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For WebSocket clients, resample to the session's output rate and
|
||||||
|
// deliver audio as base64 in JSON events. WebRTC clients already
|
||||||
|
// received audio over the RTP track, so skip the base64 payload.
|
||||||
|
if !isWebRTC {
|
||||||
|
wsPCM := pcmData
|
||||||
|
if ttsSampleRate != session.OutputSampleRate {
|
||||||
|
samples := sound.BytesToInt16sLE(pcmData)
|
||||||
|
resampled := sound.ResampleInt16(samples, ttsSampleRate, session.OutputSampleRate)
|
||||||
|
wsPCM = sound.Int16toBytesLE(resampled)
|
||||||
|
}
|
||||||
|
audioString = base64.StdEncoding.EncodeToString(wsPCM)
|
||||||
|
}
|
||||||
|
|
||||||
|
sendEvent(t, types.ResponseOutputAudioTranscriptDeltaEvent{
|
||||||
ServerEventBase: types.ServerEventBase{},
|
ServerEventBase: types.ServerEventBase{},
|
||||||
ResponseID: responseID,
|
ResponseID: responseID,
|
||||||
ItemID: item.Assistant.ID,
|
ItemID: item.Assistant.ID,
|
||||||
OutputIndex: 0,
|
OutputIndex: 0,
|
||||||
ContentIndex: 0,
|
ContentIndex: 0,
|
||||||
Delta: audioString,
|
Delta: finalSpeech,
|
||||||
})
|
})
|
||||||
sendEvent(t, types.ResponseOutputAudioDoneEvent{
|
sendEvent(t, types.ResponseOutputAudioTranscriptDoneEvent{
|
||||||
ServerEventBase: types.ServerEventBase{},
|
ServerEventBase: types.ServerEventBase{},
|
||||||
ResponseID: responseID,
|
ResponseID: responseID,
|
||||||
ItemID: item.Assistant.ID,
|
ItemID: item.Assistant.ID,
|
||||||
OutputIndex: 0,
|
OutputIndex: 0,
|
||||||
ContentIndex: 0,
|
ContentIndex: 0,
|
||||||
|
Transcript: finalSpeech,
|
||||||
|
})
|
||||||
|
|
||||||
|
if !isWebRTC {
|
||||||
|
sendEvent(t, types.ResponseOutputAudioDeltaEvent{
|
||||||
|
ServerEventBase: types.ServerEventBase{},
|
||||||
|
ResponseID: responseID,
|
||||||
|
ItemID: item.Assistant.ID,
|
||||||
|
OutputIndex: 0,
|
||||||
|
ContentIndex: 0,
|
||||||
|
Delta: audioString,
|
||||||
|
})
|
||||||
|
sendEvent(t, types.ResponseOutputAudioDoneEvent{
|
||||||
|
ServerEventBase: types.ServerEventBase{},
|
||||||
|
ResponseID: responseID,
|
||||||
|
ItemID: item.Assistant.ID,
|
||||||
|
OutputIndex: 0,
|
||||||
|
ContentIndex: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Text-only mode: skip TTS, emit only the text events.
|
||||||
|
sendEvent(t, types.ResponseOutputTextDeltaEvent{
|
||||||
|
ServerEventBase: types.ServerEventBase{},
|
||||||
|
ResponseID: responseID,
|
||||||
|
ItemID: item.Assistant.ID,
|
||||||
|
OutputIndex: 0,
|
||||||
|
ContentIndex: 0,
|
||||||
|
Delta: finalSpeech,
|
||||||
|
})
|
||||||
|
sendEvent(t, types.ResponseOutputTextDoneEvent{
|
||||||
|
ServerEventBase: types.ServerEventBase{},
|
||||||
|
ResponseID: responseID,
|
||||||
|
ItemID: item.Assistant.ID,
|
||||||
|
OutputIndex: 0,
|
||||||
|
ContentIndex: 0,
|
||||||
|
Text: finalSpeech,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
39
core/http/endpoints/openai/realtime_modality_test.go
Normal file
39
core/http/endpoints/openai/realtime_modality_test.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("resolveOutputModalities", func() {
|
||||||
|
It("defaults to audio when neither session nor response specify", func() {
|
||||||
|
got := resolveOutputModalities(nil, nil)
|
||||||
|
Expect(got).To(ConsistOf(types.ModalityAudio))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("uses session modalities when response omits them", func() {
|
||||||
|
sess := []types.Modality{types.ModalityText}
|
||||||
|
got := resolveOutputModalities(sess, nil)
|
||||||
|
Expect(got).To(ConsistOf(types.ModalityText))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("response modalities override session", func() {
|
||||||
|
sess := []types.Modality{types.ModalityAudio}
|
||||||
|
resp := []types.Modality{types.ModalityText}
|
||||||
|
got := resolveOutputModalities(sess, resp)
|
||||||
|
Expect(got).To(ConsistOf(types.ModalityText))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns false from modalitiesContainAudio for text-only", func() {
|
||||||
|
Expect(modalitiesContainAudio([]types.Modality{types.ModalityText})).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns true from modalitiesContainAudio for audio (default)", func() {
|
||||||
|
Expect(modalitiesContainAudio([]types.Modality{types.ModalityAudio})).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns true when both audio and text are present", func() {
|
||||||
|
Expect(modalitiesContainAudio([]types.Modality{types.ModalityText, types.ModalityAudio})).To(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -17,16 +17,20 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type APIExchangeRequest struct {
|
type APIExchangeRequest struct {
|
||||||
Method string `json:"method"`
|
Method string `json:"method"`
|
||||||
Path string `json:"path"`
|
Path string `json:"path"`
|
||||||
Headers *http.Header `json:"headers"`
|
Headers *http.Header `json:"headers"`
|
||||||
Body *[]byte `json:"body"`
|
Body *[]byte `json:"body"`
|
||||||
|
BodyTruncated bool `json:"body_truncated,omitempty"`
|
||||||
|
BodyBytes int `json:"body_bytes,omitempty"` // original size before truncation
|
||||||
}
|
}
|
||||||
|
|
||||||
type APIExchangeResponse struct {
|
type APIExchangeResponse struct {
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
Headers *http.Header `json:"headers"`
|
Headers *http.Header `json:"headers"`
|
||||||
Body *[]byte `json:"body"`
|
Body *[]byte `json:"body"`
|
||||||
|
BodyTruncated bool `json:"body_truncated,omitempty"`
|
||||||
|
BodyBytes int `json:"body_bytes,omitempty"` // original size before truncation
|
||||||
}
|
}
|
||||||
|
|
||||||
type APIExchange struct {
|
type APIExchange struct {
|
||||||
@@ -66,11 +70,29 @@ var doInitializeTracing = sync.OnceFunc(func() {
|
|||||||
|
|
||||||
type bodyWriter struct {
|
type bodyWriter struct {
|
||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
body *bytes.Buffer
|
body *bytes.Buffer
|
||||||
|
maxBytes int // 0 = unlimited capture
|
||||||
|
truncated bool
|
||||||
|
totalBytes int // bytes the upstream handler wrote, even past the cap
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *bodyWriter) Write(b []byte) (int, error) {
|
func (w *bodyWriter) Write(b []byte) (int, error) {
|
||||||
w.body.Write(b)
|
// Capture into the trace buffer up to maxBytes, then drop the overflow
|
||||||
|
// so a chatty endpoint can't grow the buffer without bound. The full
|
||||||
|
// payload still flows through to the real client below.
|
||||||
|
w.totalBytes += len(b)
|
||||||
|
if w.maxBytes <= 0 {
|
||||||
|
w.body.Write(b)
|
||||||
|
} else if remain := w.maxBytes - w.body.Len(); remain > 0 {
|
||||||
|
if remain >= len(b) {
|
||||||
|
w.body.Write(b)
|
||||||
|
} else {
|
||||||
|
w.body.Write(b[:remain])
|
||||||
|
w.truncated = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
w.truncated = true
|
||||||
|
}
|
||||||
return w.ResponseWriter.Write(b)
|
return w.ResponseWriter.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,6 +102,20 @@ func (w *bodyWriter) Flush() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// truncateForTrace returns a defensive copy of body capped at maxBytes,
|
||||||
|
// and a flag indicating whether the cap forced truncation. maxBytes <= 0
|
||||||
|
// disables the cap.
|
||||||
|
func truncateForTrace(body []byte, maxBytes int) ([]byte, bool) {
|
||||||
|
if maxBytes <= 0 || len(body) <= maxBytes {
|
||||||
|
out := make([]byte, len(body))
|
||||||
|
copy(out, body)
|
||||||
|
return out, false
|
||||||
|
}
|
||||||
|
out := make([]byte, maxBytes)
|
||||||
|
copy(out, body[:maxBytes])
|
||||||
|
return out, true
|
||||||
|
}
|
||||||
|
|
||||||
func initializeTracing(maxItems int) {
|
func initializeTracing(maxItems int) {
|
||||||
tracingMaxItems = maxItems
|
tracingMaxItems = maxItems
|
||||||
doInitializeTracing()
|
doInitializeTracing()
|
||||||
@@ -134,11 +170,18 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
|||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
|
// Cap captured payload size. Without this, /embeddings and
|
||||||
|
// streaming /chat/completions blow the in-memory buffer into the
|
||||||
|
// tens of MB, which then locks the admin Traces UI fetching the
|
||||||
|
// JSON dump faster than the 5s auto-refresh.
|
||||||
|
maxBodyBytes := app.ApplicationConfig().TracingMaxBodyBytes
|
||||||
|
|
||||||
// Wrap response writer to capture body
|
// Wrap response writer to capture body
|
||||||
resBody := new(bytes.Buffer)
|
resBody := new(bytes.Buffer)
|
||||||
mw := &bodyWriter{
|
mw := &bodyWriter{
|
||||||
ResponseWriter: c.Response().Writer,
|
ResponseWriter: c.Response().Writer,
|
||||||
body: resBody,
|
body: resBody,
|
||||||
|
maxBytes: maxBodyBytes,
|
||||||
}
|
}
|
||||||
c.Response().Writer = mw
|
c.Response().Writer = mw
|
||||||
|
|
||||||
@@ -159,8 +202,7 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
|||||||
// via any heap-dump-style introspection, and tokens shouldn't
|
// via any heap-dump-style introspection, and tokens shouldn't
|
||||||
// outlive the request that carried them.
|
// outlive the request that carried them.
|
||||||
requestHeaders := redactSensitiveHeaders(c.Request().Header)
|
requestHeaders := redactSensitiveHeaders(c.Request().Header)
|
||||||
requestBody := make([]byte, len(body))
|
requestBody, requestTruncated := truncateForTrace(body, maxBodyBytes)
|
||||||
copy(requestBody, body)
|
|
||||||
responseHeaders := redactSensitiveHeaders(c.Response().Header())
|
responseHeaders := redactSensitiveHeaders(c.Response().Header())
|
||||||
responseBody := make([]byte, resBody.Len())
|
responseBody := make([]byte, resBody.Len())
|
||||||
copy(responseBody, resBody.Bytes())
|
copy(responseBody, resBody.Bytes())
|
||||||
@@ -168,15 +210,19 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
|||||||
Timestamp: startTime,
|
Timestamp: startTime,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
Request: APIExchangeRequest{
|
Request: APIExchangeRequest{
|
||||||
Method: c.Request().Method,
|
Method: c.Request().Method,
|
||||||
Path: c.Path(),
|
Path: c.Path(),
|
||||||
Headers: &requestHeaders,
|
Headers: &requestHeaders,
|
||||||
Body: &requestBody,
|
Body: &requestBody,
|
||||||
|
BodyTruncated: requestTruncated,
|
||||||
|
BodyBytes: len(body),
|
||||||
},
|
},
|
||||||
Response: APIExchangeResponse{
|
Response: APIExchangeResponse{
|
||||||
Status: status,
|
Status: status,
|
||||||
Headers: &responseHeaders,
|
Headers: &responseHeaders,
|
||||||
Body: &responseBody,
|
Body: &responseBody,
|
||||||
|
BodyTruncated: mw.truncated,
|
||||||
|
BodyBytes: mw.totalBytes,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if handlerErr != nil {
|
if handlerErr != nil {
|
||||||
|
|||||||
116
core/http/middleware/trace_body_cap_test.go
Normal file
116
core/http/middleware/trace_body_cap_test.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The trace middleware copies request and response bodies into an in-memory
|
||||||
|
// buffer that backs the admin /api/traces endpoint. With no upper bound a
|
||||||
|
// chatty workload (embeddings, large completions) trivially produces a
|
||||||
|
// multi-MB response that locks the Traces UI in a loading state — fetching
|
||||||
|
// and parsing the payload outruns the 5-second auto-refresh. These specs
|
||||||
|
// pin the capping contract so future refactors keep both the cap and the
|
||||||
|
// passthrough to the real client intact.
|
||||||
|
|
||||||
|
var _ = Describe("bodyWriter capping", func() {
|
||||||
|
It("captures the full body when maxBytes is 0 (unlimited)", func() {
|
||||||
|
downstream := httptest.NewRecorder()
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 0}
|
||||||
|
|
||||||
|
payload := []byte(strings.Repeat("x", 4096))
|
||||||
|
n, err := bw.Write(payload)
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(len(payload)))
|
||||||
|
Expect(buf.Len()).To(Equal(len(payload)))
|
||||||
|
Expect(downstream.Body.Len()).To(Equal(len(payload)))
|
||||||
|
Expect(bw.truncated).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("stops appending to the trace buffer once maxBytes is reached but still forwards to the client", func() {
|
||||||
|
downstream := httptest.NewRecorder()
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 100}
|
||||||
|
|
||||||
|
payload := []byte(strings.Repeat("a", 250))
|
||||||
|
n, err := bw.Write(payload)
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(len(payload)), "Write must return the full byte count so callers see no short write")
|
||||||
|
Expect(buf.Len()).To(Equal(100), "trace buffer should hold exactly maxBytes")
|
||||||
|
Expect(downstream.Body.Len()).To(Equal(len(payload)), "client must still receive every byte")
|
||||||
|
Expect(bw.truncated).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("handles a write that straddles the cap by keeping only the leading slice", func() {
|
||||||
|
downstream := httptest.NewRecorder()
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 10}
|
||||||
|
|
||||||
|
_, err := bw.Write([]byte("12345"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(bw.truncated).To(BeFalse())
|
||||||
|
|
||||||
|
_, err = bw.Write([]byte("67890ABCDE"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
Expect(buf.String()).To(Equal("1234567890"))
|
||||||
|
Expect(downstream.Body.String()).To(Equal("1234567890ABCDE"))
|
||||||
|
Expect(bw.truncated).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("ignores further writes after the cap was already hit", func() {
|
||||||
|
downstream := httptest.NewRecorder()
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 4}
|
||||||
|
|
||||||
|
_, _ = bw.Write([]byte("AAAA"))
|
||||||
|
_, _ = bw.Write([]byte("BBBB"))
|
||||||
|
_, _ = bw.Write([]byte("CCCC"))
|
||||||
|
|
||||||
|
Expect(buf.String()).To(Equal("AAAA"))
|
||||||
|
Expect(downstream.Body.String()).To(Equal("AAAABBBBCCCC"))
|
||||||
|
Expect(bw.truncated).To(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
var _ = Describe("truncateForTrace", func() {
|
||||||
|
It("returns the input unchanged when below the cap", func() {
|
||||||
|
in := []byte("hello")
|
||||||
|
out, truncated := truncateForTrace(in, 1024)
|
||||||
|
Expect(truncated).To(BeFalse())
|
||||||
|
Expect(out).To(Equal(in))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("truncates when the input exceeds the cap and signals truncation", func() {
|
||||||
|
in := []byte(strings.Repeat("z", 200))
|
||||||
|
out, truncated := truncateForTrace(in, 64)
|
||||||
|
Expect(truncated).To(BeTrue())
|
||||||
|
Expect(out).To(HaveLen(64))
|
||||||
|
Expect(string(out)).To(Equal(strings.Repeat("z", 64)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("treats maxBytes <= 0 as unlimited (back-compat with current default)", func() {
|
||||||
|
in := []byte(strings.Repeat("q", 10_000))
|
||||||
|
out, truncated := truncateForTrace(in, 0)
|
||||||
|
Expect(truncated).To(BeFalse())
|
||||||
|
Expect(out).To(HaveLen(len(in)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("does not retain the caller's backing array (defensive copy)", func() {
|
||||||
|
in := []byte("abcdefghij")
|
||||||
|
out, truncated := truncateForTrace(in, 4)
|
||||||
|
Expect(truncated).To(BeTrue())
|
||||||
|
Expect(string(out)).To(Equal("abcd"))
|
||||||
|
|
||||||
|
// Mutating the source must not corrupt the trace copy.
|
||||||
|
in[0] = 'Z'
|
||||||
|
Expect(string(out)).To(Equal("abcd"))
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
@@ -14,18 +15,37 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
usageFlushInterval = 5 * time.Second
|
usageFlushInterval = 5 * time.Second
|
||||||
usageMaxPending = 5000
|
// usageMaxPending bounds the in-memory queue. Sized for bursty inference
|
||||||
|
// traffic on a self-hosted instance with a slow or unavailable DB.
|
||||||
|
usageMaxPending = 50000
|
||||||
)
|
)
|
||||||
|
|
||||||
// usageBatcher accumulates usage records and flushes them to the DB periodically.
|
// usageBatcher accumulates usage records and flushes them to the DB periodically.
|
||||||
type usageBatcher struct {
|
type usageBatcher struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
pending []*auth.UsageRecord
|
pending []*auth.UsageRecord
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
stop chan struct{}
|
||||||
|
done chan struct{}
|
||||||
|
stopOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// droppedRecords counts records discarded because the in-memory queue was full.
|
||||||
|
// Used to rate-limit the warn log so a sustained outage doesn't flood it.
|
||||||
|
var droppedRecords atomic.Uint64
|
||||||
|
|
||||||
func (b *usageBatcher) add(r *auth.UsageRecord) {
|
func (b *usageBatcher) add(r *auth.UsageRecord) {
|
||||||
b.mu.Lock()
|
b.mu.Lock()
|
||||||
|
if len(b.pending) >= usageMaxPending {
|
||||||
|
b.mu.Unlock()
|
||||||
|
// Rate-limit: one warn per 1024 drops keeps the log readable.
|
||||||
|
n := droppedRecords.Add(1)
|
||||||
|
if n&1023 == 1 {
|
||||||
|
xlog.Warn("usage batcher full, dropping record",
|
||||||
|
"cap", usageMaxPending, "total_dropped", n)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
b.pending = append(b.pending, r)
|
b.pending = append(b.pending, r)
|
||||||
b.mu.Unlock()
|
b.mu.Unlock()
|
||||||
}
|
}
|
||||||
@@ -42,31 +62,102 @@ func (b *usageBatcher) flush() {
|
|||||||
|
|
||||||
if err := b.db.Create(&batch).Error; err != nil {
|
if err := b.db.Create(&batch).Error; err != nil {
|
||||||
xlog.Error("Failed to flush usage batch", "count", len(batch), "error", err)
|
xlog.Error("Failed to flush usage batch", "count", len(batch), "error", err)
|
||||||
// Re-queue failed records with a cap to avoid unbounded growth
|
// Cap-aware re-queue: prepend as much of the failed batch as fits
|
||||||
|
// alongside any records added concurrently with the failed write.
|
||||||
b.mu.Lock()
|
b.mu.Lock()
|
||||||
if len(b.pending) < usageMaxPending {
|
room := usageMaxPending - len(b.pending)
|
||||||
b.pending = append(batch, b.pending...)
|
if room > 0 {
|
||||||
|
if room > len(batch) {
|
||||||
|
room = len(batch)
|
||||||
|
}
|
||||||
|
b.pending = append(batch[:room], b.pending...)
|
||||||
}
|
}
|
||||||
b.mu.Unlock()
|
b.mu.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var batcher *usageBatcher
|
func (b *usageBatcher) run() {
|
||||||
|
defer close(b.done)
|
||||||
|
ticker := time.NewTicker(usageFlushInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
b.flush()
|
||||||
|
case <-b.stop:
|
||||||
|
b.flush() // final drain
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *usageBatcher) shutdown() {
|
||||||
|
b.stopOnce.Do(func() {
|
||||||
|
close(b.stop)
|
||||||
|
<-b.done
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// The package-level batcher is guarded by batcherMu so Init / Shutdown cycles
|
||||||
|
// (the test pattern) don't race against UsageMiddleware reads.
|
||||||
|
var (
|
||||||
|
batcherMu sync.RWMutex
|
||||||
|
batcher *usageBatcher
|
||||||
|
)
|
||||||
|
|
||||||
|
func currentBatcher() *usageBatcher {
|
||||||
|
batcherMu.RLock()
|
||||||
|
defer batcherMu.RUnlock()
|
||||||
|
return batcher
|
||||||
|
}
|
||||||
|
|
||||||
// InitUsageRecorder starts a background goroutine that periodically flushes
|
// InitUsageRecorder starts a background goroutine that periodically flushes
|
||||||
// accumulated usage records to the database.
|
// accumulated usage records to the database. Calling it more than once
|
||||||
|
// shuts down the previous batcher first so its goroutine doesn't leak.
|
||||||
func InitUsageRecorder(db *gorm.DB) {
|
func InitUsageRecorder(db *gorm.DB) {
|
||||||
if db == nil {
|
if db == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
batcher = &usageBatcher{db: db}
|
|
||||||
go func() {
|
batcherMu.Lock()
|
||||||
ticker := time.NewTicker(usageFlushInterval)
|
old := batcher
|
||||||
defer ticker.Stop()
|
batcher = nil
|
||||||
for range ticker.C {
|
batcherMu.Unlock()
|
||||||
batcher.flush()
|
if old != nil {
|
||||||
}
|
old.shutdown()
|
||||||
}()
|
}
|
||||||
|
|
||||||
|
b := &usageBatcher{
|
||||||
|
db: db,
|
||||||
|
stop: make(chan struct{}),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
batcherMu.Lock()
|
||||||
|
batcher = b
|
||||||
|
batcherMu.Unlock()
|
||||||
|
|
||||||
|
go b.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShutdownUsageRecorder stops the background flusher and synchronously drains
|
||||||
|
// pending records once. Safe to call multiple times. Not yet wired into the
|
||||||
|
// application lifecycle; intended for graceful process exit and tests.
|
||||||
|
func ShutdownUsageRecorder() {
|
||||||
|
batcherMu.Lock()
|
||||||
|
b := batcher
|
||||||
|
batcher = nil
|
||||||
|
batcherMu.Unlock()
|
||||||
|
if b != nil {
|
||||||
|
b.shutdown()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlushNow synchronously flushes any pending usage records. Intended for tests
|
||||||
|
// that need deterministic behaviour without waiting for the ticker.
|
||||||
|
func FlushNow() {
|
||||||
|
if b := currentBatcher(); b != nil {
|
||||||
|
b.flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// usageResponseBody is the minimal structure we need from the response JSON.
|
// usageResponseBody is the minimal structure we need from the response JSON.
|
||||||
@@ -84,7 +175,8 @@ type usageResponseBody struct {
|
|||||||
func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
|
func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
if db == nil || batcher == nil {
|
b := currentBatcher()
|
||||||
|
if db == nil || b == nil {
|
||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,9 +241,17 @@ func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
|
|||||||
return handlerErr
|
return handlerErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
source := auth.GetSource(c)
|
||||||
|
if source == "" {
|
||||||
|
// Auth disabled or unrecognised path: classify as web so the row is still
|
||||||
|
// bucketable rather than silently dropped from per-source aggregates.
|
||||||
|
source = auth.UsageSourceWeb
|
||||||
|
}
|
||||||
|
|
||||||
record := &auth.UsageRecord{
|
record := &auth.UsageRecord{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
UserName: user.Name,
|
UserName: user.Name,
|
||||||
|
Source: source,
|
||||||
Model: resp.Model,
|
Model: resp.Model,
|
||||||
Endpoint: c.Request().URL.Path,
|
Endpoint: c.Request().URL.Path,
|
||||||
PromptTokens: resp.Usage.PromptTokens,
|
PromptTokens: resp.Usage.PromptTokens,
|
||||||
@@ -161,7 +261,13 @@ func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
|
|||||||
CreatedAt: startTime,
|
CreatedAt: startTime,
|
||||||
}
|
}
|
||||||
|
|
||||||
batcher.add(record)
|
if key := auth.GetAPIKey(c); key != nil {
|
||||||
|
id := key.ID
|
||||||
|
record.APIKeyID = &id
|
||||||
|
record.APIKeyName = key.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
b.add(record)
|
||||||
|
|
||||||
return handlerErr
|
return handlerErr
|
||||||
}
|
}
|
||||||
|
|||||||
140
core/http/middleware/usage_test.go
Normal file
140
core/http/middleware/usage_test.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
//go:build auth
|
||||||
|
|
||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/mudler/LocalAI/core/http/auth"
|
||||||
|
"github.com/mudler/LocalAI/core/http/middleware"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testAuthDB returns a fresh in-memory SQLite auth DB.
|
||||||
|
func testAuthDB() *gorm.DB {
|
||||||
|
db, err := auth.InitDB(":memory:")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = Describe("UsageMiddleware", func() {
|
||||||
|
var (
|
||||||
|
e *echo.Echo
|
||||||
|
db *gorm.DB
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
db = testAuthDB()
|
||||||
|
e = echo.New()
|
||||||
|
middleware.InitUsageRecorder(db)
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
middleware.ShutdownUsageRecorder()
|
||||||
|
})
|
||||||
|
|
||||||
|
okHandler := func(c echo.Context) error {
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"usage": map[string]int{
|
||||||
|
"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.Response().Header().Set("Content-Type", "application/json")
|
||||||
|
c.Response().WriteHeader(http.StatusOK)
|
||||||
|
_, _ = c.Response().Write(body)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlushNow drains pending records synchronously, replacing the 6s sleep
|
||||||
|
// that was previously needed to wait for the batcher's ticker.
|
||||||
|
flush := middleware.FlushNow
|
||||||
|
|
||||||
|
It("records source=web when auth_source is web", func() {
|
||||||
|
e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"})
|
||||||
|
c.Set("auth_source", auth.UsageSourceWeb)
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
}, middleware.UsageMiddleware(db))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||||
|
e.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
var rec auth.UsageRecord
|
||||||
|
Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed())
|
||||||
|
Expect(rec.Source).To(Equal(auth.UsageSourceWeb))
|
||||||
|
Expect(rec.APIKeyID).To(BeNil())
|
||||||
|
Expect(rec.APIKeyName).To(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("records source=apikey with snapshotted name when auth_apikey is set", func() {
|
||||||
|
e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"})
|
||||||
|
c.Set("auth_source", auth.UsageSourceAPIKey)
|
||||||
|
c.Set("auth_apikey", &auth.UserAPIKey{ID: "key-1", Name: "ci-runner"})
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
}, middleware.UsageMiddleware(db))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||||
|
e.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
var rec auth.UsageRecord
|
||||||
|
Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed())
|
||||||
|
Expect(rec.Source).To(Equal(auth.UsageSourceAPIKey))
|
||||||
|
Expect(rec.APIKeyID).ToNot(BeNil())
|
||||||
|
Expect(*rec.APIKeyID).To(Equal("key-1"))
|
||||||
|
Expect(rec.APIKeyName).To(Equal("ci-runner"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("FlushNow drains pending records synchronously", func() {
|
||||||
|
e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
c.Set("auth_user", &auth.User{ID: "carol", Name: "Carol"})
|
||||||
|
c.Set("auth_source", auth.UsageSourceWeb)
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
}, middleware.UsageMiddleware(db))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||||
|
e.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
|
||||||
|
// No sleep: FlushNow should drain immediately.
|
||||||
|
middleware.FlushNow()
|
||||||
|
|
||||||
|
var rec auth.UsageRecord
|
||||||
|
Expect(db.Where("user_id = ?", "carol").First(&rec).Error).To(Succeed())
|
||||||
|
Expect(rec.Source).To(Equal(auth.UsageSourceWeb))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("falls back to source=web when auth_source is empty", func() {
|
||||||
|
e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"})
|
||||||
|
// no auth_source set
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
}, middleware.UsageMiddleware(db))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||||
|
e.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
var rec auth.UsageRecord
|
||||||
|
Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed())
|
||||||
|
Expect(rec.Source).To(Equal(auth.UsageSourceWeb))
|
||||||
|
})
|
||||||
|
})
|
||||||
143
core/http/react-ui/e2e/chat-polling-selection.spec.js
Normal file
143
core/http/react-ui/e2e/chat-polling-selection.spec.js
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
import { test, expect } from '@playwright/test'
|
||||||
|
|
||||||
|
// Regression coverage for issue #9904:
|
||||||
|
// - /api/operations was polled every 1s and *always* re-rendered the Chat
|
||||||
|
// page, even when the response was unchanged. The reconciliation would
|
||||||
|
// collapse any text selection inside an assistant message.
|
||||||
|
// - The copy button next to each assistant message used navigator.clipboard
|
||||||
|
// without any fallback, which is undefined when the page is served over
|
||||||
|
// plain http (non-secure context) from a remote host.
|
||||||
|
|
||||||
|
async function setupChatPage(page) {
|
||||||
|
await page.route('**/api/models/capabilities', (route) => {
|
||||||
|
route.fulfill({
|
||||||
|
contentType: 'application/json',
|
||||||
|
body: JSON.stringify({
|
||||||
|
data: [{ id: 'test-model', capabilities: ['FLAG_CHAT'] }],
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Poll-tracking mock: assert the hook is hammering /api/operations every
|
||||||
|
// ~1s, and always return an empty list so its contents never change.
|
||||||
|
let operationsHits = 0
|
||||||
|
await page.route('**/api/operations', (route) => {
|
||||||
|
operationsHits++
|
||||||
|
route.fulfill({
|
||||||
|
contentType: 'application/json',
|
||||||
|
body: JSON.stringify({ operations: [] }),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
await page.route('**/v1/chat/completions', (route) => {
|
||||||
|
// One short SSE stream so the chat finishes streaming quickly and we
|
||||||
|
// can interact with a stable assistant message.
|
||||||
|
const body = [
|
||||||
|
'data: {"choices":[{"delta":{"content":"Hello world this is a long assistant reply that we can try to select."},"index":0}]}\n\n',
|
||||||
|
'data: {"choices":[{"delta":{},"index":0,"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}\n\n',
|
||||||
|
'data: [DONE]\n\n',
|
||||||
|
].join('')
|
||||||
|
route.fulfill({
|
||||||
|
status: 200,
|
||||||
|
headers: { 'Content-Type': 'text/event-stream' },
|
||||||
|
body,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return { getOperationsHits: () => operationsHits }
|
||||||
|
}
|
||||||
|
|
||||||
|
test.describe('Chat - /api/operations polling (#9904)', () => {
|
||||||
|
test('text selection inside an assistant message survives polling', async ({ page }) => {
|
||||||
|
const { getOperationsHits } = await setupChatPage(page)
|
||||||
|
|
||||||
|
await page.goto('/app/chat')
|
||||||
|
await expect(page.getByRole('button', { name: 'test-model' })).toBeVisible({ timeout: 10_000 })
|
||||||
|
|
||||||
|
await page.locator('.chat-input').fill('Hi')
|
||||||
|
await page.locator('.chat-send-btn').click()
|
||||||
|
|
||||||
|
const assistantContent = page.locator('.chat-message-assistant .chat-message-content').first()
|
||||||
|
await expect(assistantContent).toContainText('Hello world', { timeout: 10_000 })
|
||||||
|
|
||||||
|
// Sanity check: the polling we're regressing against is actually firing.
|
||||||
|
await page.waitForTimeout(2_500)
|
||||||
|
expect(getOperationsHits()).toBeGreaterThan(1)
|
||||||
|
|
||||||
|
// Sanity check that the bug we're guarding against is structurally
|
||||||
|
// possible: count how many times the assistant content node gets
|
||||||
|
// *touched* by React (childList / characterData mutations) over a
|
||||||
|
// 3-second window. Before the fix, every poll re-rendered Chat and
|
||||||
|
// re-set dangerouslySetInnerHTML, triggering a mutation cascade that
|
||||||
|
// collapsed the user's text selection. After the fix, polling with
|
||||||
|
// identical contents must not mutate the DOM at all.
|
||||||
|
const mutationCount = await assistantContent.evaluate((el) => new Promise((resolve) => {
|
||||||
|
let count = 0
|
||||||
|
const obs = new MutationObserver((records) => { count += records.length })
|
||||||
|
obs.observe(el, { childList: true, subtree: true, characterData: true })
|
||||||
|
setTimeout(() => { obs.disconnect(); resolve(count) }, 3_000)
|
||||||
|
}))
|
||||||
|
expect(mutationCount).toBe(0)
|
||||||
|
|
||||||
|
// Same sanity check translated to a user-observable property: a
|
||||||
|
// programmatically created selection survives the polling window.
|
||||||
|
await assistantContent.evaluate((el) => {
|
||||||
|
const range = document.createRange()
|
||||||
|
range.selectNodeContents(el)
|
||||||
|
const sel = window.getSelection()
|
||||||
|
sel.removeAllRanges()
|
||||||
|
sel.addRange(range)
|
||||||
|
})
|
||||||
|
|
||||||
|
const initialSelection = await page.evaluate(() => window.getSelection().toString())
|
||||||
|
expect(initialSelection).toContain('Hello world')
|
||||||
|
|
||||||
|
await page.waitForTimeout(2_500)
|
||||||
|
|
||||||
|
const selectionAfterPolling = await page.evaluate(() => window.getSelection().toString())
|
||||||
|
expect(selectionAfterPolling).toBe(initialSelection)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test.describe('Chat - copy button (#9904)', () => {
|
||||||
|
test('copy button works when navigator.clipboard is unavailable (plain http)', async ({ page }) => {
|
||||||
|
await setupChatPage(page)
|
||||||
|
|
||||||
|
// Simulate a non-secure context: hide navigator.clipboard before any of
|
||||||
|
// our app code touches it. This mirrors what browsers do over plain
|
||||||
|
// http from a remote host.
|
||||||
|
await page.addInitScript(() => {
|
||||||
|
Object.defineProperty(window, 'isSecureContext', { value: false, configurable: true })
|
||||||
|
try {
|
||||||
|
Object.defineProperty(navigator, 'clipboard', { value: undefined, configurable: true })
|
||||||
|
} catch { /* some browsers refuse — the secure-context flag is enough */ }
|
||||||
|
})
|
||||||
|
|
||||||
|
await page.goto('/app/chat')
|
||||||
|
await expect(page.getByRole('button', { name: 'test-model' })).toBeVisible({ timeout: 10_000 })
|
||||||
|
|
||||||
|
await page.locator('.chat-input').fill('Hi')
|
||||||
|
await page.locator('.chat-send-btn').click()
|
||||||
|
|
||||||
|
const assistantBubble = page.locator('.chat-message-assistant .chat-message-bubble').first()
|
||||||
|
await expect(assistantBubble).toContainText('Hello world', { timeout: 10_000 })
|
||||||
|
|
||||||
|
// Spy on document.execCommand so we can confirm the fallback path ran.
|
||||||
|
await page.evaluate(() => {
|
||||||
|
window.__execCommandCalls = []
|
||||||
|
const original = document.execCommand?.bind(document)
|
||||||
|
document.execCommand = (cmd, ...rest) => {
|
||||||
|
window.__execCommandCalls.push(cmd)
|
||||||
|
// execCommand('copy') in a headless browser may return false because
|
||||||
|
// there is no real clipboard, but the fact that we tried is what we
|
||||||
|
// care about for this regression.
|
||||||
|
return original ? original(cmd, ...rest) : false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
await assistantBubble.locator('.chat-message-actions button').first().click()
|
||||||
|
|
||||||
|
const execCommandCalls = await page.evaluate(() => window.__execCommandCalls)
|
||||||
|
expect(execCommandCalls).toContain('copy')
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -97,7 +97,8 @@
|
|||||||
},
|
},
|
||||||
"toasts": {
|
"toasts": {
|
||||||
"selectModel": "Bitte wählen Sie ein Modell",
|
"selectModel": "Bitte wählen Sie ein Modell",
|
||||||
"copied": "In die Zwischenablage kopiert"
|
"copied": "In die Zwischenablage kopiert",
|
||||||
|
"copyFailed": "Kopieren in die Zwischenablage fehlgeschlagen"
|
||||||
},
|
},
|
||||||
"menu": {
|
"menu": {
|
||||||
"trigger": "Chats",
|
"trigger": "Chats",
|
||||||
|
|||||||
@@ -53,7 +53,30 @@
|
|||||||
},
|
},
|
||||||
"usage": {
|
"usage": {
|
||||||
"title": "Usage",
|
"title": "Usage",
|
||||||
"subtitle": "API token usage statistics"
|
"subtitle": "API token usage statistics",
|
||||||
|
"sources": {
|
||||||
|
"tab": "Sources",
|
||||||
|
"mixTitle": "Source mix",
|
||||||
|
"ribbonAria": "{{apikey}}% API keys, {{web}}% Web UI, {{legacy}}% Legacy",
|
||||||
|
"topSources": "Top sources over time",
|
||||||
|
"searchPlaceholder": "Search by name or prefix",
|
||||||
|
"sortBy": "Sort",
|
||||||
|
"sortTokens": "Tokens",
|
||||||
|
"sortRequests": "Requests",
|
||||||
|
"sortLastUsed": "Last used",
|
||||||
|
"sortName": "Name",
|
||||||
|
"sortUser": "User",
|
||||||
|
"webUI": "Web UI",
|
||||||
|
"legacy": "Legacy",
|
||||||
|
"revoked": "revoked",
|
||||||
|
"filteredTo": "Filtered to: {{name}}",
|
||||||
|
"clearFilter": "Clear filter",
|
||||||
|
"other": "Other ({{count}})",
|
||||||
|
"noTrafficShort": "No requests in this period.",
|
||||||
|
"noKeysYet": "Once requests come in, you'll see them broken down here.",
|
||||||
|
"createKey": "Create your first API key",
|
||||||
|
"truncatedWarning": "Showing top 200 keys. Apply a filter to narrow further."
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"explorer": {
|
"explorer": {
|
||||||
"title": "Explorer",
|
"title": "Explorer",
|
||||||
|
|||||||
@@ -97,7 +97,8 @@
|
|||||||
},
|
},
|
||||||
"toasts": {
|
"toasts": {
|
||||||
"selectModel": "Please select a model",
|
"selectModel": "Please select a model",
|
||||||
"copied": "Copied to clipboard"
|
"copied": "Copied to clipboard",
|
||||||
|
"copyFailed": "Could not copy to clipboard"
|
||||||
},
|
},
|
||||||
"menu": {
|
"menu": {
|
||||||
"trigger": "Chats",
|
"trigger": "Chats",
|
||||||
|
|||||||
@@ -97,7 +97,8 @@
|
|||||||
},
|
},
|
||||||
"toasts": {
|
"toasts": {
|
||||||
"selectModel": "Por favor selecciona un modelo",
|
"selectModel": "Por favor selecciona un modelo",
|
||||||
"copied": "Copiado al portapapeles"
|
"copied": "Copiado al portapapeles",
|
||||||
|
"copyFailed": "No se pudo copiar al portapapeles"
|
||||||
},
|
},
|
||||||
"menu": {
|
"menu": {
|
||||||
"trigger": "Chats",
|
"trigger": "Chats",
|
||||||
|
|||||||
@@ -97,7 +97,8 @@
|
|||||||
},
|
},
|
||||||
"toasts": {
|
"toasts": {
|
||||||
"selectModel": "Seleziona un modello",
|
"selectModel": "Seleziona un modello",
|
||||||
"copied": "Copiato negli appunti"
|
"copied": "Copiato negli appunti",
|
||||||
|
"copyFailed": "Impossibile copiare negli appunti"
|
||||||
},
|
},
|
||||||
"menu": {
|
"menu": {
|
||||||
"trigger": "Chat",
|
"trigger": "Chat",
|
||||||
|
|||||||
@@ -97,7 +97,8 @@
|
|||||||
},
|
},
|
||||||
"toasts": {
|
"toasts": {
|
||||||
"selectModel": "请选择一个模型",
|
"selectModel": "请选择一个模型",
|
||||||
"copied": "已复制到剪贴板"
|
"copied": "已复制到剪贴板",
|
||||||
|
"copyFailed": "无法复制到剪贴板"
|
||||||
},
|
},
|
||||||
"menu": {
|
"menu": {
|
||||||
"trigger": "聊天",
|
"trigger": "聊天",
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user