mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-26 01:31:27 -04:00
Compare commits
2 Commits
master
...
fix/turboq
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b6fed26271 | ||
|
|
66af748332 |
@@ -16,8 +16,7 @@ side (`pkg/oci/cosignverify` plus the gallery YAML).
|
||||
per-arch manifest before checking signatures.
|
||||
- **Storage:** Signatures are written as OCI 1.1 referrers
|
||||
(`--registry-referrers-mode=oci-1-1`) in the new Sigstore bundle format
|
||||
(current cosign releases do this by default; no `--new-bundle-format`
|
||||
flag). No `:sha256-<hex>.sig` tag clutter.
|
||||
(`--new-bundle-format`). No `:sha256-<hex>.sig` tag clutter.
|
||||
- **Consumer:** `pkg/oci/cosignverify` discovers the bundle via the
|
||||
referrers API, hands it to `sigstore-go`, and verifies it against the
|
||||
policy declared in the gallery YAML (`Gallery.Verification`).
|
||||
@@ -34,14 +33,15 @@ to sign. The job needs:
|
||||
|
||||
- `permissions: { id-token: write, contents: read }` at the job level so
|
||||
the runner can exchange its GitHub OIDC token for a Fulcio cert.
|
||||
- `sigstore/cosign-installer@v3` step (current cosign releases already
|
||||
default to the new bundle format).
|
||||
- `sigstore/cosign-installer@v3` step (cosign ≥ 2.2 for
|
||||
`--new-bundle-format`).
|
||||
- After each `docker buildx imagetools create`, resolve the resulting
|
||||
list digest with `docker buildx imagetools inspect <tag> --format
|
||||
'{{.Manifest.Digest}}'` and sign:
|
||||
|
||||
```sh
|
||||
cosign sign --yes --recursive \
|
||||
--new-bundle-format \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"${REGISTRY_REPO}@${DIGEST}"
|
||||
```
|
||||
@@ -49,12 +49,6 @@ cosign sign --yes --recursive \
|
||||
Sign by digest, never by tag — signing by tag binds the signature to
|
||||
whatever the tag points at *now*, and a subsequent tag push orphans it.
|
||||
|
||||
`--registry-referrers-mode=oci-1-1` is still gated behind
|
||||
`COSIGN_EXPERIMENTAL=1` in cosign v2.4.x (set at the job env level in
|
||||
`backend_merge.yml`). Re-evaluate when bumping the pinned cosign release
|
||||
— newer versions are expected to graduate this flag and the env var can
|
||||
then be dropped.
|
||||
|
||||
`backend_build_darwin.yml` builds and pushes single-arch darwin images
|
||||
that bypass the manifest-list merge. If/when those entries get a gallery
|
||||
`verification:` policy, the equivalent cosign step has to land there
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
.devcontainer
|
||||
models
|
||||
backends
|
||||
volumes
|
||||
examples/chatbot-ui/models
|
||||
backend/go/image/stablediffusion-ggml/build/
|
||||
backend/go/*/build
|
||||
@@ -22,11 +21,3 @@ __pycache__
|
||||
# backend virtual environments
|
||||
**/venv
|
||||
backend/python/**/source
|
||||
|
||||
# In-place llama.cpp clone + per-variant build copies. The Makefile
|
||||
# clones llama.cpp itself at the pinned LLAMA_VERSION; if a stale
|
||||
# local checkout is COPY'd into the image, the `llama.cpp:` target
|
||||
# sees the directory and skips re-cloning, so grpc-server.cpp ends
|
||||
# up compiled against whatever (likely older) commit the host had.
|
||||
backend/cpp/llama-cpp/llama.cpp
|
||||
backend/cpp/llama-cpp-*-build
|
||||
|
||||
10
.github/workflows/backend_merge.yml
vendored
10
.github/workflows/backend_merge.yml
vendored
@@ -40,11 +40,6 @@ jobs:
|
||||
id-token: write
|
||||
env:
|
||||
quay_username: ${{ secrets.quayUsername }}
|
||||
# cosign v2.4.x still gates --registry-referrers-mode=oci-1-1 behind
|
||||
# this flag. Without it, signing fails with:
|
||||
# invalid argument "oci-1-1" for "--registry-referrers-mode" flag:
|
||||
# in order to use mode "oci-1-1", you must set COSIGN_EXPERIMENTAL=1
|
||||
COSIGN_EXPERIMENTAL: '1'
|
||||
steps:
|
||||
# Sparse checkout: the merge job needs `.github/scripts/` (for the
|
||||
# keepalive cleanup script) but none of the source tree.
|
||||
@@ -71,8 +66,7 @@ jobs:
|
||||
|
||||
# cosign signs each pushed manifest list with --recursive so the
|
||||
# index and every per-arch entry get an attached Sigstore bundle.
|
||||
# Recent cosign releases always emit the new bundle format, so
|
||||
# there's no extra CLI flag to opt into it.
|
||||
# 2.2+ is required for --new-bundle-format.
|
||||
- name: Install cosign
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: sigstore/cosign-installer@v3
|
||||
@@ -159,6 +153,7 @@ jobs:
|
||||
# manifest before checking signatures need the per-arch
|
||||
# signatures, not just the list-level one.
|
||||
cosign sign --yes --recursive \
|
||||
--new-bundle-format \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"quay.io/go-skynet/local-ai-backends@${digest}"
|
||||
|
||||
@@ -185,6 +180,7 @@ jobs:
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
digest=$(docker buildx imagetools inspect "$first_tag" --format '{{.Manifest.Digest}}')
|
||||
cosign sign --yes --recursive \
|
||||
--new-bundle-format \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"localai/localai-backends@${digest}"
|
||||
|
||||
|
||||
1
.github/workflows/image_build.yml
vendored
1
.github/workflows/image_build.yml
vendored
@@ -106,7 +106,6 @@ jobs:
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{raw}}
|
||||
type=sha
|
||||
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
|
||||
flavor: |
|
||||
latest=${{ inputs.tag-latest }}
|
||||
suffix=${{ inputs.tag-suffix }},onlatest=true
|
||||
|
||||
1
.github/workflows/image_merge.yml
vendored
1
.github/workflows/image_merge.yml
vendored
@@ -80,7 +80,6 @@ jobs:
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{raw}}
|
||||
type=sha
|
||||
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
|
||||
flavor: |
|
||||
latest=${{ inputs.tag-latest }}
|
||||
suffix=${{ inputs.tag-suffix }},onlatest=true
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -26,10 +26,6 @@ go-bert
|
||||
LocalAI
|
||||
/local-ai
|
||||
/local-ai-launcher
|
||||
# Root-level build artifacts when running `go build ./...` against
|
||||
# Go backend packages whose main lives under backend/go/.
|
||||
/cloud-proxy
|
||||
/local-store
|
||||
# prevent above rules from omitting the helm chart
|
||||
!charts/*
|
||||
# prevent above rules from omitting the api/localai folder
|
||||
@@ -81,6 +77,3 @@ local-backends/
|
||||
tests/e2e-ui/ui-test-server
|
||||
core/http/react-ui/playwright-report/
|
||||
core/http/react-ui/test-results/
|
||||
|
||||
# Local worktrees
|
||||
.worktrees/
|
||||
|
||||
15
Makefile
15
Makefile
@@ -69,7 +69,7 @@ else
|
||||
GORELEASER=$(shell which goreleaser)
|
||||
endif
|
||||
|
||||
TEST_PATHS?=./api/... ./pkg/... ./core/... ./backend/go/cloud-proxy/... ./backend/go/local-store/...
|
||||
TEST_PATHS?=./api/... ./pkg/... ./core/...
|
||||
|
||||
|
||||
.PHONY: all test build vendor lint lint-all
|
||||
@@ -268,13 +268,12 @@ prepare-e2e:
|
||||
run-e2e-image:
|
||||
docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --name e2e-tests-$(RANDOM) localai-tests
|
||||
|
||||
test-e2e: build-mock-backend build-cloud-proxy-backend prepare-e2e run-e2e-image
|
||||
test-e2e: build-mock-backend prepare-e2e run-e2e-image
|
||||
@echo 'Running e2e tests'
|
||||
BUILD_TYPE=$(BUILD_TYPE) \
|
||||
LOCALAI_API=http://$(E2E_BRIDGE_IP):5390 \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e
|
||||
$(MAKE) clean-mock-backend
|
||||
$(MAKE) clean-cloud-proxy-backend
|
||||
$(MAKE) teardown-e2e
|
||||
docker rmi localai-tests
|
||||
|
||||
@@ -1065,7 +1064,6 @@ BACKEND_DS4 = ds4|ds4|.|false|false
|
||||
# Golang backends
|
||||
BACKEND_PIPER = piper|golang|.|false|true
|
||||
BACKEND_LOCAL_STORE = local-store|golang|.|false|true
|
||||
BACKEND_CLOUD_PROXY = cloud-proxy|golang|.|false|true
|
||||
BACKEND_HUGGINGFACE = huggingface|golang|.|false|true
|
||||
BACKEND_SILERO_VAD = silero-vad|golang|.|false|true
|
||||
BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true
|
||||
@@ -1151,7 +1149,6 @@ $(eval $(call generate-docker-build-target,$(BACKEND_TURBOQUANT)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_DS4)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PIPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CLOUD_PROXY)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
@@ -1204,7 +1201,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
@@ -1216,12 +1213,6 @@ build-mock-backend: protogen-go
|
||||
clean-mock-backend:
|
||||
rm -f tests/e2e/mock-backend/mock-backend
|
||||
|
||||
build-cloud-proxy-backend: protogen-go
|
||||
$(GOCMD) build -o tests/e2e/mock-backend/cloud-proxy ./backend/go/cloud-proxy
|
||||
|
||||
clean-cloud-proxy-backend:
|
||||
rm -f tests/e2e/mock-backend/cloud-proxy
|
||||
|
||||
########################################################
|
||||
### UI E2E Test Server
|
||||
########################################################
|
||||
|
||||
@@ -37,22 +37,6 @@ service Backend {
|
||||
|
||||
rpc Rerank(RerankRequest) returns (RerankResult) {}
|
||||
|
||||
// TokenClassify runs a token-classification (NER) model on the
|
||||
// supplied text and returns each detected entity span. Used by the
|
||||
// PII redactor's optional NER tier — the regex tier still handles
|
||||
// formatted hits cheaply, while this catches names, locations, and
|
||||
// other unformatted PII that regex misses.
|
||||
rpc TokenClassify(TokenClassifyRequest) returns (TokenClassifyResponse) {}
|
||||
|
||||
// Score evaluates the model's joint log-probability of each
|
||||
// supplied candidate continuation given a shared prompt. The
|
||||
// prompt's KV cache is computed once and reused across candidates.
|
||||
// Used for routing-policy multi-label classification, reranking,
|
||||
// calibrated confidence, and reward-model scoring — any task where
|
||||
// the consumer wants the model's confidence in a pre-specified
|
||||
// continuation rather than a generated one.
|
||||
rpc Score(ScoreRequest) returns (ScoreResponse) {}
|
||||
|
||||
rpc GetMetrics(MetricsRequest) returns (MetricsResponse);
|
||||
|
||||
rpc VAD(VADRequest) returns (VADResponse) {}
|
||||
@@ -84,23 +68,6 @@ service Backend {
|
||||
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
|
||||
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
|
||||
|
||||
// Forward proxies a raw HTTP request to an upstream provider. The
|
||||
// cloud-proxy backend implements this for passthrough-mode model
|
||||
// configs: the client wire format is preserved end-to-end (no
|
||||
// translation through internal proto), which means new provider
|
||||
// fields work the day they ship. Translation-mode proxies use the
|
||||
// standard Predict/PredictStream RPCs instead. Backends that don't
|
||||
// support this return UNIMPLEMENTED.
|
||||
//
|
||||
// The request is bidirectionally streamed so large bodies can flow
|
||||
// without buffering. In practice the first ForwardRequest carries
|
||||
// path, method, headers, and the initial body chunk; subsequent
|
||||
// messages append body chunks. The first ForwardReply carries the
|
||||
// upstream status and response headers; subsequent messages stream
|
||||
// body chunks (SSE frames or chunked transfer). Cancellation of the
|
||||
// gRPC context closes the upstream connection.
|
||||
rpc Forward(stream ForwardRequest) returns (stream ForwardReply) {}
|
||||
|
||||
}
|
||||
|
||||
// Define the empty request
|
||||
@@ -114,76 +81,6 @@ message MetricsResponse {
|
||||
int32 prompt_tokens_processed = 5;
|
||||
}
|
||||
|
||||
// TokenClassifyRequest carries the text to classify plus an optional
|
||||
// score threshold. The transformers backend interprets threshold as
|
||||
// the minimum confidence to include in the response; 0 = include all.
|
||||
message TokenClassifyRequest {
|
||||
string text = 1;
|
||||
float threshold = 2;
|
||||
}
|
||||
|
||||
// TokenClassifyEntity is one detected entity span. Byte offsets are
|
||||
// into the original UTF-8 text — start..end is a half-open range that
|
||||
// addresses the substring corresponding to entity_group.
|
||||
//
|
||||
// entity_group follows HuggingFace's aggregated-tag convention (e.g.
|
||||
// "PER", "LOC", "ORG", or a PII-specific label like "EMAIL" /
|
||||
// "SSN" depending on the model). The redactor's per-pattern action
|
||||
// map keys off this string.
|
||||
message TokenClassifyEntity {
|
||||
string entity_group = 1;
|
||||
int32 start = 2;
|
||||
int32 end = 3;
|
||||
float score = 4;
|
||||
string text = 5;
|
||||
}
|
||||
|
||||
message TokenClassifyResponse {
|
||||
repeated TokenClassifyEntity entities = 1;
|
||||
}
|
||||
|
||||
// ScoreRequest carries one shared prompt and one or more continuations
|
||||
// to score against it. The backend tokenises the prompt once and reuses
|
||||
// the resulting KV cache across all candidates in this request.
|
||||
message ScoreRequest {
|
||||
string prompt = 1;
|
||||
repeated string candidates = 2;
|
||||
// Return per-token logprobs for each candidate when true. Default
|
||||
// false to keep the wire response small; the joint log_prob field
|
||||
// covers the common ranking case.
|
||||
bool include_token_logprobs = 3;
|
||||
// When true, the response also populates length_normalized_log_prob
|
||||
// (joint log-prob divided by candidate token count). Useful when
|
||||
// candidates differ in length and the consumer wants a per-token
|
||||
// measure comparable across them (PMI-style scoring).
|
||||
bool length_normalize = 4;
|
||||
}
|
||||
|
||||
// CandidateScore is one row in the ScoreResponse, matching by index
|
||||
// the candidate in ScoreRequest.candidates.
|
||||
message CandidateScore {
|
||||
// Sum of log P(token_i | prompt, candidate_token_<i) across the
|
||||
// candidate's tokens. The primary ranking signal.
|
||||
double log_prob = 1;
|
||||
// log_prob / num_tokens — populated when length_normalize=true on
|
||||
// the request.
|
||||
double length_normalized_log_prob = 2;
|
||||
// Per-token detail — populated when include_token_logprobs=true.
|
||||
repeated TokenLogProb tokens = 3;
|
||||
// Number of tokens the backend tokenised this candidate into, after
|
||||
// any backend-specific normalisation (e.g. leading-space handling).
|
||||
int32 num_tokens = 4;
|
||||
}
|
||||
|
||||
message TokenLogProb {
|
||||
string token = 1;
|
||||
double log_prob = 2;
|
||||
}
|
||||
|
||||
message ScoreResponse {
|
||||
repeated CandidateScore candidates = 1;
|
||||
}
|
||||
|
||||
message RerankRequest {
|
||||
string query = 1;
|
||||
repeated string documents = 2;
|
||||
@@ -428,25 +325,6 @@ message ModelOptions {
|
||||
// applied verbatim to the backend's engine constructor (e.g. vLLM AsyncEngineArgs).
|
||||
// Unknown keys produce an error at LoadModel time.
|
||||
string EngineArgs = 73;
|
||||
|
||||
// Proxy carries the cloud-proxy backend's per-model configuration.
|
||||
// Empty for non-proxy backends.
|
||||
ProxyOptions Proxy = 74;
|
||||
}
|
||||
|
||||
// ProxyOptions configures the cloud-proxy backend. UpstreamURL and
|
||||
// Mode are always meaningful; Provider only matters in translate mode.
|
||||
// The two api_key_* fields are mutually exclusive and resolved by the
|
||||
// backend at LoadModel — core forwards the references rather than the
|
||||
// plaintext key.
|
||||
message ProxyOptions {
|
||||
string upstream_url = 1;
|
||||
string mode = 2;
|
||||
string provider = 3;
|
||||
string api_key_env = 4;
|
||||
string api_key_file = 5;
|
||||
string upstream_model = 6;
|
||||
int32 request_timeout_seconds = 7;
|
||||
}
|
||||
|
||||
message Result {
|
||||
@@ -1124,32 +1002,3 @@ message QuantizationStopRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
// ForwardHeader is one HTTP header on the request or response. Headers
|
||||
// like Authorization are typically injected by the backend (from the
|
||||
// resolved API key) rather than passed through from the client.
|
||||
message ForwardHeader {
|
||||
string name = 1;
|
||||
string value = 2;
|
||||
}
|
||||
|
||||
// ForwardRequest is a streamed HTTP request to the upstream. First
|
||||
// message carries path/method/headers; subsequent messages carry
|
||||
// body_chunk only. All fields except body_chunk are honoured on the
|
||||
// first message and ignored thereafter.
|
||||
message ForwardRequest {
|
||||
string path = 1; // e.g. "/v1/chat/completions" — appended to the model's upstream_url
|
||||
string method = 2; // usually "POST"
|
||||
repeated ForwardHeader headers = 3;
|
||||
bytes body_chunk = 4;
|
||||
}
|
||||
|
||||
// ForwardReply is a streamed HTTP response from the upstream. First
|
||||
// message carries status/headers; subsequent messages carry body_chunk
|
||||
// only. SSE responses arrive as a sequence of body_chunk frames; the
|
||||
// caller is responsible for any parsing.
|
||||
message ForwardReply {
|
||||
int32 status = 1;
|
||||
repeated ForwardHeader headers = 2;
|
||||
bytes body_chunk = 3;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=ad0209f6a4b067574d2b4afe896c08c177156b31
|
||||
# Upstream pin lives below as DS4_VERSION?=2606543be7a8c125a32cee37f5d1d85dc78f2fcf
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=ad0209f6a4b067574d2b4afe896c08c177156b31
|
||||
DS4_VERSION?=2606543be7a8c125a32cee37f5d1d85dc78f2fcf
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=b4e1d916c5ec7e75ea3c124dd090425a99fc613f
|
||||
IK_LLAMA_VERSION?=11a1fea9e291f12ce2c803a9d7812c30ca806bcf
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=549b9d84330c327e6791fa812a7d60c0cf63572e
|
||||
LLAMA_VERSION?=ad277572619fcfb6ddd38f4c6437283a4b2b8636
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -34,7 +34,6 @@
|
||||
#include <regex>
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
@@ -122,40 +121,6 @@ static std::string base64_encode_bytes(const unsigned char* data, size_t len) {
|
||||
|
||||
bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model
|
||||
|
||||
// Score bypasses the slot loop (see the comment on Score below) so it
|
||||
// must not run concurrently with any slot-loop RPC. These counters
|
||||
// are a defence-in-depth tripwire — ModelConfig.Validate already
|
||||
// rejects llama-cpp configs that mix score with chat/completion/
|
||||
// embeddings, so a healthy deployment never trips them. seq_cst is
|
||||
// load-bearing for the increment-then-check pattern below.
|
||||
static std::atomic<int> slot_loop_inflight{0};
|
||||
static std::atomic<int> score_inflight{0};
|
||||
|
||||
// Increment-then-check, not check-then-increment: two simultaneous
|
||||
// racers both observe the other's increment and both abort cleanly.
|
||||
// Reversed, both could see zero and proceed.
|
||||
struct conflict_guard {
|
||||
std::atomic<int>& self;
|
||||
conflict_guard(const char* rpc, std::atomic<int>& self_, std::atomic<int>& other, const char* other_name)
|
||||
: self(self_) {
|
||||
self.fetch_add(1, std::memory_order_seq_cst);
|
||||
int o = other.load(std::memory_order_seq_cst);
|
||||
if (o > 0) {
|
||||
fprintf(stderr,
|
||||
"FATAL: %s called with %s=%d. The llama-cpp backend cannot "
|
||||
"service Score and slot-loop RPCs concurrently — Score "
|
||||
"bypasses the slot loop and races the llama_context. Bind "
|
||||
"Score-using features to a model dedicated to scoring "
|
||||
"(known_usecases: [score] with no chat/completion/embeddings).\n",
|
||||
rpc, other_name, o);
|
||||
std::abort();
|
||||
}
|
||||
}
|
||||
~conflict_guard() {
|
||||
self.fetch_sub(1, std::memory_order_seq_cst);
|
||||
}
|
||||
};
|
||||
|
||||
static std::function<void(int)> shutdown_handler;
|
||||
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
|
||||
|
||||
@@ -552,27 +517,10 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.warmup = true;
|
||||
// no_op_offload: disable host tensor op offload (default: false)
|
||||
params.no_op_offload = false;
|
||||
// kv_unified: enable unified KV cache. Upstream's server auto-enables this
|
||||
// when the slot count is auto (-np <0), bumping n_parallel to 4 alongside.
|
||||
// LocalAI keeps n_parallel=1 by default, which would skip that auto path
|
||||
// and leave kv_unified=false. We flip the default to true here so the
|
||||
// server-side prompt cache (cache_idle_slots) is actually usable on the
|
||||
// single-slot path that LocalAI ships with: without it, idle slots are
|
||||
// never persisted across requests and the prompt cache is dead weight.
|
||||
// Users can opt out with `options: [ "kv_unified:false" ]`.
|
||||
params.kv_unified = true;
|
||||
// n_ctx_checkpoints: max context checkpoints per slot. Match upstream's
|
||||
// default (32); the previous LocalAI-specific 8 was unnecessarily tight
|
||||
// and limits partial-prefix recovery without a clear memory rationale.
|
||||
params.n_ctx_checkpoints = 32;
|
||||
// cache_idle_slots: save and clear idle slot KV to the prompt cache on
|
||||
// task switch. Upstream default is true; the server auto-disables it if
|
||||
// kv_unified=false or cache_ram_mib=0, so flipping kv_unified above is
|
||||
// what actually unlocks it.
|
||||
params.cache_idle_slots = true;
|
||||
// checkpoint_every_nt: create a context checkpoint every N tokens during
|
||||
// prefill (-1 disables). Match upstream's default (8192).
|
||||
params.checkpoint_every_nt = 8192;
|
||||
// kv_unified: enable unified KV cache (default: false)
|
||||
params.kv_unified = false;
|
||||
// n_ctx_checkpoints: max context checkpoints per slot (default: 8)
|
||||
params.n_ctx_checkpoints = 8;
|
||||
|
||||
// decode options. Options are in form optname:optvale, or if booleans only optname.
|
||||
for (int i = 0; i < request->options_size(); i++) {
|
||||
@@ -731,29 +679,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
try {
|
||||
params.n_ctx_checkpoints = std::stoi(optval_str);
|
||||
} catch (const std::exception& e) {
|
||||
// If conversion fails, keep default value (32)
|
||||
}
|
||||
}
|
||||
|
||||
// --- server-side idle-slot prompt cache toggle (upstream --cache-idle-slots) ---
|
||||
// Saves the slot's KV state into the host-side prompt cache on task
|
||||
// switch so a later request with the same prefix can warm-load it.
|
||||
// Auto-disabled by the server if kv_unified=false or cache_ram=0.
|
||||
} else if (!strcmp(optname, "cache_idle_slots") || !strcmp(optname, "idle_slots_cache")) {
|
||||
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||
params.cache_idle_slots = true;
|
||||
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||
params.cache_idle_slots = false;
|
||||
}
|
||||
|
||||
// --- prefill checkpoint cadence (upstream -cpent / --checkpoint-every-n-tokens) ---
|
||||
// -1 disables checkpointing during prefill.
|
||||
} else if (!strcmp(optname, "checkpoint_every_nt") || !strcmp(optname, "checkpoint_every_n_tokens")) {
|
||||
if (optval != NULL) {
|
||||
try {
|
||||
params.checkpoint_every_nt = std::stoi(optval_str);
|
||||
} catch (const std::exception& e) {
|
||||
// If conversion fails, keep default value (8192)
|
||||
// If conversion fails, keep default value (8)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1481,7 +1407,6 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("PredictStream", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
|
||||
@@ -2241,7 +2166,6 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("Predict", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
data["stream"] = false;
|
||||
@@ -3000,7 +2924,6 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("Embedding", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
body["stream"] = false;
|
||||
@@ -3108,8 +3031,6 @@ public:
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
|
||||
}
|
||||
|
||||
conflict_guard guard("Rerank", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
|
||||
// Create and queue the task
|
||||
auto rd = ctx_server.get_response_reader();
|
||||
{
|
||||
@@ -3182,218 +3103,12 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
// Score returns the model's joint log-probability of each candidate
|
||||
// continuation given a shared prompt.
|
||||
//
|
||||
// WHY bypass the slot/task queue: upstream server_context exposes
|
||||
// get_llama_context as "main thread only" and the slot loop's
|
||||
// update_slots() owns the context whenever a task is in flight.
|
||||
// No public synchronization primitive is available — so Score is
|
||||
// unsafe to call concurrently with active generation through this
|
||||
// backend. In practice routing-classifier calls happen before the
|
||||
// request is routed to a generation backend, so the model used
|
||||
// for Score is typically idle. Concurrent Score calls are
|
||||
// serialised by a local mutex; KV-cache state is isolated behind
|
||||
// a dedicated sequence ID cleared between candidates.
|
||||
//
|
||||
// A patch to server-context.cpp that adds SERVER_TASK_TYPE_SCORE
|
||||
// and routes scoring through the slot loop would be the correct
|
||||
// long-term fix; tracked as a follow-up.
|
||||
//
|
||||
// Perf TODO (measured: ~450 ms warm for 3 candidates on Arch-
|
||||
// Router-1.5B Q4_K_M + Intel SYCL): the current loop re-decodes
|
||||
// `prompt + candidate` from scratch for every candidate, throwing
|
||||
// away the prompt's KV cache between iterations. A smarter
|
||||
// version would:
|
||||
// 1. Decode just the prompt once into score_seq_id.
|
||||
// 2. Snapshot/cp that sequence (llama_memory_seq_cp) into a
|
||||
// per-candidate sequence id.
|
||||
// 3. For each candidate, decode only its tokens onto the copy
|
||||
// (continuing from the saved prompt state), read logits.
|
||||
// 4. llama_memory_seq_rm the copy.
|
||||
// Estimated speedup: 3-candidate calls 450 ms -> ~150-200 ms,
|
||||
// 6-candidate calls 630 ms -> ~220 ms. Single source-file change,
|
||||
// no proto / Go-side changes needed. Worth doing once routing is
|
||||
// wired into the middleware and Score is on the hot path of every
|
||||
// chat request.
|
||||
grpc::Status Score(ServerContext* context, const backend::ScoreRequest* request, backend::ScoreResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
if (request->candidates_size() == 0) {
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "candidates must be non-empty");
|
||||
}
|
||||
|
||||
// Tripwire against the slot loop. Acquired before score_mutex
|
||||
// so it fires even when this Score is queued behind another.
|
||||
conflict_guard guard("Score", score_inflight, slot_loop_inflight, "slot_loop_inflight");
|
||||
|
||||
// Serialise concurrent Score calls. The slot loop is still
|
||||
// free to race with us — see the class comment above.
|
||||
static std::mutex score_mutex;
|
||||
std::lock_guard<std::mutex> score_lock(score_mutex);
|
||||
|
||||
llama_context * lctx = ctx_server.get_llama_context();
|
||||
if (lctx == nullptr) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "llama context unavailable (sleeping?)");
|
||||
}
|
||||
const llama_vocab * vocab = ctx_server.impl->vocab;
|
||||
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
|
||||
const int32_t n_ctx = llama_n_ctx(lctx);
|
||||
llama_memory_t mem = llama_get_memory(lctx);
|
||||
|
||||
// The KV-cache is sized to seq_to_stream.size() at load
|
||||
// (typically equal to n_slots, often 1). Sequence IDs must
|
||||
// be in [0, n_seq_max), so we can't pick a high-value
|
||||
// "private" ID — we have to share with the slot. We clear
|
||||
// the cache before AND after each candidate to keep
|
||||
// scoring isolated from whatever state the slot held, and
|
||||
// the static mutex above guarantees no other Score call is
|
||||
// racing in the meantime. The slot loop is still free to
|
||||
// race (see comment on this method) — Score must not run
|
||||
// concurrently with generation through this backend.
|
||||
const llama_seq_id score_seq_id = 0;
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
|
||||
// Tokenize the shared prompt once with add_special=true so
|
||||
// BOS is prepended when the model requires it. parse_special
|
||||
// keeps chat-template markers in the prompt intact.
|
||||
const std::string prompt = request->prompt();
|
||||
std::vector<llama_token> prompt_tokens = common_tokenize(vocab, prompt, /*add_special=*/true, /*parse_special=*/true);
|
||||
const int32_t prompt_len = (int32_t) prompt_tokens.size();
|
||||
|
||||
for (int ci = 0; ci < request->candidates_size(); ci++) {
|
||||
const std::string & candidate_text = request->candidates(ci);
|
||||
|
||||
// Re-tokenize prompt + candidate as a single string. BPE
|
||||
// merges across the boundary can shift the tokenization
|
||||
// versus tokenize(prompt) ++ tokenize(candidate), so we
|
||||
// find the divergence point against prompt_tokens.
|
||||
std::vector<llama_token> full_tokens = common_tokenize(vocab, prompt + candidate_text, /*add_special=*/true, /*parse_special=*/true);
|
||||
int32_t divergence = prompt_len;
|
||||
const int32_t min_len = std::min<int32_t>(prompt_len, (int32_t) full_tokens.size());
|
||||
for (int32_t i = 0; i < min_len; i++) {
|
||||
if (prompt_tokens[i] != full_tokens[i]) {
|
||||
divergence = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const int32_t cand_len = (int32_t) full_tokens.size() - divergence;
|
||||
backend::CandidateScore * cs = response->add_candidates();
|
||||
cs->set_num_tokens(cand_len);
|
||||
if (cand_len <= 0) {
|
||||
cs->set_log_prob(0.0);
|
||||
if (request->length_normalize()) {
|
||||
cs->set_length_normalized_log_prob(0.0);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (divergence < 1) {
|
||||
// Need at least one prior token (typically BOS) to
|
||||
// predict the first candidate token's logit. Tokeniser
|
||||
// models without BOS + an empty prompt fall in here.
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
|
||||
"Score: prompt produced no leading tokens; need at least one (e.g. BOS) to predict candidate");
|
||||
}
|
||||
if ((int32_t) full_tokens.size() > n_ctx) {
|
||||
return grpc::Status(grpc::StatusCode::OUT_OF_RANGE,
|
||||
"Score: prompt+candidate exceeds context size (got " +
|
||||
std::to_string(full_tokens.size()) + ", n_ctx=" + std::to_string(n_ctx) + ")");
|
||||
}
|
||||
|
||||
// Build a batch covering the entire prompt+candidate. We
|
||||
// need logits at (divergence-1) onward — those are the
|
||||
// predictions for each candidate token.
|
||||
llama_batch batch = llama_batch_init((int32_t) full_tokens.size(), 0, 1);
|
||||
for (int32_t i = 0; i < (int32_t) full_tokens.size(); i++) {
|
||||
batch.token[i] = full_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = score_seq_id;
|
||||
// logits[i] is "do we want the prediction *for the
|
||||
// next token*, computed from this position?"
|
||||
// We want predictions for candidate tokens at
|
||||
// positions divergence .. full_tokens.size()-1, which
|
||||
// come from logits at positions (divergence-1) ..
|
||||
// (full_tokens.size()-2).
|
||||
bool need_logit = (i >= divergence - 1) && (i < (int32_t) full_tokens.size() - 1);
|
||||
batch.logits[i] = need_logit ? 1 : 0;
|
||||
}
|
||||
batch.n_tokens = (int32_t) full_tokens.size();
|
||||
|
||||
// Decode the batch. If decode fails (e.g. KV slot
|
||||
// exhaustion), surface as INTERNAL — the caller will
|
||||
// typically fall back to a sampling-based classifier.
|
||||
int decode_err = llama_decode(lctx, batch);
|
||||
if (decode_err != 0) {
|
||||
llama_batch_free(batch);
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL,
|
||||
"llama_decode failed during Score: " + std::to_string(decode_err));
|
||||
}
|
||||
|
||||
// Sum log-probabilities of the actual candidate tokens.
|
||||
double total_log_prob = 0.0;
|
||||
for (int32_t k = 0; k < cand_len; k++) {
|
||||
// The k-th candidate token sits at full_tokens index
|
||||
// (divergence + k). Its predicting logit is at batch
|
||||
// position (divergence + k - 1).
|
||||
int32_t logit_pos = divergence + k - 1;
|
||||
const float * logits = llama_get_logits_ith(lctx, logit_pos);
|
||||
if (logits == nullptr) {
|
||||
llama_batch_free(batch);
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL,
|
||||
"llama_get_logits_ith returned null at position " + std::to_string(logit_pos));
|
||||
}
|
||||
llama_token target_token = full_tokens[divergence + k];
|
||||
|
||||
// Compute log_softmax(logits)[target_token] with the
|
||||
// max-subtraction stability trick.
|
||||
float max_logit = logits[0];
|
||||
for (int32_t v = 1; v < n_vocab; v++) {
|
||||
if (logits[v] > max_logit) max_logit = logits[v];
|
||||
}
|
||||
double sum_exp = 0.0;
|
||||
for (int32_t v = 0; v < n_vocab; v++) {
|
||||
sum_exp += std::exp((double)(logits[v] - max_logit));
|
||||
}
|
||||
double token_log_prob = (double)(logits[target_token] - max_logit) - std::log(sum_exp);
|
||||
total_log_prob += token_log_prob;
|
||||
|
||||
if (request->include_token_logprobs()) {
|
||||
backend::TokenLogProb * tlp = cs->add_tokens();
|
||||
std::string piece = common_token_to_piece(lctx, target_token);
|
||||
tlp->set_token(piece);
|
||||
tlp->set_log_prob(token_log_prob);
|
||||
}
|
||||
}
|
||||
|
||||
cs->set_log_prob(total_log_prob);
|
||||
if (request->length_normalize() && cand_len > 0) {
|
||||
cs->set_length_normalized_log_prob(total_log_prob / (double) cand_len);
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
// Drop this candidate's KV-cache contribution so the next
|
||||
// candidate starts from a clean state. Without this, the
|
||||
// next decode would conflict at positions 0..N-1 for our
|
||||
// sequence ID.
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
}
|
||||
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("TokenizeString", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
|
||||
body["stream"] = false;
|
||||
|
||||
@@ -3415,8 +3130,6 @@ public:
|
||||
|
||||
grpc::Status GetMetrics(ServerContext* /*context*/, const backend::MetricsRequest* /*request*/, backend::MetricsResponse* response) override {
|
||||
|
||||
conflict_guard guard("GetMetrics", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
|
||||
// request slots data using task queue
|
||||
auto rd = ctx_server.get_response_reader();
|
||||
int task_id = rd.queue_tasks.get_new_id();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
|
||||
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
|
||||
TURBOQUANT_VERSION?=5aeb2fdbe26cd4c534c6fa15de73cb5749bd0403
|
||||
TURBOQUANT_VERSION?=4c1c3ac09d2dba0aa9a55b94f6c50c41a92f9c8c
|
||||
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,30 +1,23 @@
|
||||
#!/bin/bash
|
||||
# Patch the shared backend/cpp/llama-cpp/grpc-server.cpp *copy* used by the
|
||||
# turboquant build to account for the gaps between upstream and the fork:
|
||||
# turboquant build:
|
||||
#
|
||||
# 1. Augment the kv_cache_types[] allow-list so `LoadModel` accepts the
|
||||
# fork-specific `turbo2` / `turbo3` / `turbo4` cache types.
|
||||
# 2. Replace `get_media_marker()` (added upstream in ggml-org/llama.cpp#21962,
|
||||
# server-side random per-instance marker) with the legacy "<__media__>"
|
||||
# literal. The fork branched before that PR, so server-common.cpp has no
|
||||
# get_media_marker symbol. The fork's mtmd_default_marker() still returns
|
||||
# "<__media__>", and Go-side tooling falls back to that sentinel when the
|
||||
# backend does not expose media_marker, so substituting the literal keeps
|
||||
# behavior identical on the turboquant path.
|
||||
# 3. Revert the `common_params_speculative` field references to the
|
||||
# pre-refactor flat layout. Upstream ggml-org/llama.cpp#22397 split the
|
||||
# struct into nested `draft` / `ngram_simple` / `ngram_mod` / etc. members;
|
||||
# the turboquant fork branched before that PR and still exposes the flat
|
||||
# `n_max`, `mparams_dft`, `ngram_size_n`, ... fields. The substitutions
|
||||
# below map the new nested paths back to the legacy flat names so the
|
||||
# shared grpc-server.cpp keeps compiling against the fork's common.h.
|
||||
# Drop this block once the fork rebases past #22397.
|
||||
#
|
||||
# Historical context: this script used to also paper over API gaps between the
|
||||
# fork and upstream (flat vs nested `common_params_speculative`, missing
|
||||
# `get_media_marker()`, `ctx_server.impl->model` vs `model_tgt`, and a
|
||||
# LOCALAI_LEGACY_LLAMA_CPP_SPEC compile gate). As of TURBOQUANT_VERSION
|
||||
# 4c1c3ac0 the fork has rebased past ggml-org/llama.cpp#21962, #22397 and
|
||||
# #22838, so the shared grpc-server.cpp compiles unmodified against the fork.
|
||||
# Only the fork-specific KV-cache enum entries remain.
|
||||
#
|
||||
# We patch the *copy* sitting in turboquant-<flavor>-build/, never the original
|
||||
# under backend/cpp/llama-cpp/, so the stock llama-cpp build keeps compiling
|
||||
# under backend/cpp/llama-cpp/, so the stock llama-cpp build stays compiling
|
||||
# against vanilla upstream.
|
||||
#
|
||||
# Idempotent: skips each insertion if its marker is already present (so re-runs
|
||||
# Idempotent: skips the insertion if its marker is already present (so re-runs
|
||||
# of the same build dir don't double-insert).
|
||||
|
||||
set -euo pipefail
|
||||
@@ -52,7 +45,7 @@ else
|
||||
awk '
|
||||
/^ GGML_TYPE_Q5_1,$/ && !done {
|
||||
print
|
||||
print " // turboquant fork extras — added by patch-grpc-server.sh"
|
||||
print " // turboquant fork extras - added by patch-grpc-server.sh"
|
||||
print " GGML_TYPE_TURBO2_0,"
|
||||
print " GGML_TYPE_TURBO3_0,"
|
||||
print " GGML_TYPE_TURBO4_0,"
|
||||
@@ -72,83 +65,4 @@ else
|
||||
echo "==> KV allow-list patch OK"
|
||||
fi
|
||||
|
||||
if grep -q 'get_media_marker()' "$SRC"; then
|
||||
echo "==> patching $SRC to replace get_media_marker() with legacy \"<__media__>\" literal"
|
||||
# Only one call site today (ModelMetadata), but replace all occurrences to
|
||||
# stay robust if upstream adds more. Use a temp file to avoid relying on
|
||||
# sed -i portability (the builder image uses GNU sed, but keeping this
|
||||
# consistent with the awk block above).
|
||||
sed 's/get_media_marker()/"<__media__>"/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> get_media_marker() substitution OK"
|
||||
else
|
||||
echo "==> $SRC has no get_media_marker() call, skipping media-marker patch"
|
||||
fi
|
||||
|
||||
if grep -q 'params\.speculative\.draft\.\|params\.speculative\.ngram_simple\.' "$SRC"; then
|
||||
echo "==> patching $SRC to revert common_params_speculative refs to pre-#22397 flat layout"
|
||||
# Each substitution is the exact post-refactor path → legacy flat field.
|
||||
# Order doesn't matter because the source paths are disjoint, but we keep
|
||||
# the most-specific (mparams.path) first for readability.
|
||||
sed -E \
|
||||
-e 's/params\.speculative\.draft\.mparams\.path/params.speculative.mparams_dft.path/g' \
|
||||
-e 's/params\.speculative\.draft\.n_max/params.speculative.n_max/g' \
|
||||
-e 's/params\.speculative\.draft\.n_min/params.speculative.n_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_min/params.speculative.p_min/g' \
|
||||
-e 's/params\.speculative\.draft\.p_split/params.speculative.p_split/g' \
|
||||
-e 's/params\.speculative\.draft\.n_gpu_layers/params.speculative.n_gpu_layers/g' \
|
||||
-e 's/params\.speculative\.draft\.n_ctx/params.speculative.n_ctx/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_n/params.speculative.ngram_size_n/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.size_m/params.speculative.ngram_size_m/g' \
|
||||
-e 's/params\.speculative\.ngram_simple\.min_hits/params.speculative.ngram_min_hits/g' \
|
||||
"$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> speculative field rename OK"
|
||||
else
|
||||
echo "==> $SRC has no post-#22397 speculative field refs, skipping spec rename patch"
|
||||
fi
|
||||
|
||||
# 4. Revert the `ctx_server.impl->model_tgt` rename introduced by upstream
|
||||
# ggml-org/llama.cpp#22838 (parallel drafting). The turboquant fork still
|
||||
# exposes the field as `model` on `server_context_impl`. The two call sites
|
||||
# are in the Rerank and ModelMetadata RPC handlers.
|
||||
if grep -q 'ctx_server\.impl->model_tgt' "$SRC"; then
|
||||
echo "==> patching $SRC to revert ctx_server.impl->model_tgt -> ctx_server.impl->model"
|
||||
sed -E 's/ctx_server\.impl->model_tgt/ctx_server.impl->model/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> model_tgt rename OK"
|
||||
else
|
||||
echo "==> $SRC has no ctx_server.impl->model_tgt refs, skipping model_tgt rename patch"
|
||||
fi
|
||||
|
||||
# 5. Define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top of the file so the
|
||||
# grpc-server option parser skips the new option-handler blocks (ngram_mod,
|
||||
# ngram_map_k, ngram_map_k4v, ngram_cache, draft.cache_type_*, draft.cpuparams*,
|
||||
# draft.tensor_buft_overrides) introduced for the post-#22838 layout. Those
|
||||
# blocks reference struct fields that simply do not exist in the fork.
|
||||
if grep -q '^#define LOCALAI_LEGACY_LLAMA_CPP_SPEC' "$SRC"; then
|
||||
echo "==> $SRC already defines LOCALAI_LEGACY_LLAMA_CPP_SPEC, skipping"
|
||||
else
|
||||
echo "==> patching $SRC to define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top"
|
||||
# Insert the define before the very first `#include` so it precedes all the
|
||||
# speculative-decoding code paths.
|
||||
awk '
|
||||
!done && /^#include/ {
|
||||
print "#define LOCALAI_LEGACY_LLAMA_CPP_SPEC 1"
|
||||
print "// ^ injected by backend/cpp/turboquant/patch-grpc-server.sh"
|
||||
print ""
|
||||
done = 1
|
||||
}
|
||||
{ print }
|
||||
END {
|
||||
if (!done) {
|
||||
print "patch-grpc-server.sh: no #include anchor found to insert LOCALAI_LEGACY_LLAMA_CPP_SPEC" > "/dev/stderr"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> LOCALAI_LEGACY_LLAMA_CPP_SPEC define OK"
|
||||
fi
|
||||
|
||||
echo "==> all patches applied"
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
GOCMD=go
|
||||
|
||||
cloud-proxy:
|
||||
CGO_ENABLED=0 $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o cloud-proxy ./
|
||||
|
||||
package:
|
||||
bash package.sh
|
||||
|
||||
build: cloud-proxy package
|
||||
|
||||
clean:
|
||||
rm -f cloud-proxy
|
||||
@@ -1,16 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Ginkgo bootstrap. The other Test* functions in this package use
|
||||
// raw testing.T and run independently; they coexist with Ginkgo
|
||||
// specs registered via Describe / Context.
|
||||
func TestCloudProxySpecs(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "cloud-proxy specs")
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package main
|
||||
|
||||
// cloud-proxy is a LocalAI backend that forwards request traffic to an
|
||||
// external HTTP provider (OpenAI, Anthropic, etc.). Two modes:
|
||||
//
|
||||
// - passthrough: serves the Forward RPC; the client wire format is
|
||||
// preserved end-to-end, no translation.
|
||||
// - translate: serves Predict/PredictStream; the backend converts
|
||||
// internal proto to the provider's wire format. (Phases 5–6.)
|
||||
//
|
||||
// LoadModel reads UpstreamURL/Mode/Provider/key references from
|
||||
// ProxyOptions and resolves the API key once at load time.
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/xlog"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
var addr = flag.String("addr", "localhost:50051", "the address to listen on")
|
||||
|
||||
func main() {
|
||||
// xlog's default handler emits ANSI color codes; that's fine for an
|
||||
// interactive shell but unreadable when the backend's stdout is
|
||||
// captured by LocalAI and tee'd to a log file. Force plain text when
|
||||
// LOCALAI_LOG_FORMAT is unset and stdout isn't a terminal.
|
||||
format := os.Getenv("LOCALAI_LOG_FORMAT")
|
||||
if format == "" && !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
format = xlog.TextFormat
|
||||
}
|
||||
xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(os.Getenv("LOCALAI_LOG_LEVEL")), format))
|
||||
flag.Parse()
|
||||
if err := grpc.StartServer(*addr, NewCloudProxy()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the cloud-proxy binary into the package dir for the
|
||||
# final Dockerfile stage. Mirrors backend/go/local-store/package.sh —
|
||||
# no extra runtime libs needed since the backend is pure Go.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
mkdir -p $CURDIR/package
|
||||
cp -avf $CURDIR/cloud-proxy $CURDIR/package/
|
||||
cp -rfv $CURDIR/run.sh $CURDIR/package/
|
||||
@@ -1,270 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("composeURL", func() {
|
||||
// Upstream URL convention: gallery configs put the canonical path
|
||||
// in upstream_url, so per-request Path is ignored. A bare-host
|
||||
// upstream_url accepts the per-request path.
|
||||
DescribeTable("path resolution",
|
||||
func(upstream, reqPath, want string) {
|
||||
got, err := composeURL(upstream, reqPath)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(got).To(Equal(want))
|
||||
},
|
||||
Entry("full path wins", "https://api.openai.com/v1/chat/completions", "/v1/something-else", "https://api.openai.com/v1/chat/completions"),
|
||||
Entry("bare host accepts path", "https://api.openai.com", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"),
|
||||
Entry("root slash treated as bare", "https://api.openai.com/", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"),
|
||||
Entry("bare host + empty path", "https://api.openai.com", "", "https://api.openai.com"),
|
||||
)
|
||||
|
||||
It("returns an error on invalid upstream URL", func() {
|
||||
_, err := composeURL("://garbage", "")
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("applyAuthHeader", func() {
|
||||
It("sets x-api-key and anthropic-version for Anthropic, no Authorization", func() {
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
applyAuthHeader(req, providerAnthropic, "ant-key")
|
||||
Expect(req.Header.Get("x-api-key")).To(Equal("ant-key"))
|
||||
Expect(req.Header.Get("anthropic-version")).NotTo(BeEmpty())
|
||||
Expect(req.Header.Get("Authorization")).To(BeEmpty(), "Authorization must not leak on Anthropic backend")
|
||||
})
|
||||
|
||||
It("sets Bearer Authorization for OpenAI, no x-api-key", func() {
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
applyAuthHeader(req, providerOpenAI, "sk-key")
|
||||
Expect(req.Header.Get("Authorization")).To(Equal("Bearer sk-key"))
|
||||
Expect(req.Header.Get("x-api-key")).To(BeEmpty(), "x-api-key must not leak on OpenAI backend")
|
||||
})
|
||||
|
||||
It("defaults to Bearer when provider is empty", func() {
|
||||
// Passthrough mode often has provider == "" because the operator
|
||||
// doesn't claim a specific upstream wire format. Most providers
|
||||
// (including OpenAI-compatible ones) accept Bearer, so default to it.
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
applyAuthHeader(req, "", "some-key")
|
||||
Expect(req.Header.Get("Authorization")).To(Equal("Bearer some-key"))
|
||||
})
|
||||
|
||||
It("preserves an existing anthropic-version header", func() {
|
||||
// If the client supplied anthropic-version (rare but legitimate
|
||||
// for an upstream pinned to a specific date), the proxy must not
|
||||
// clobber it.
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
req.Header.Set("anthropic-version", "2024-10-01")
|
||||
applyAuthHeader(req, providerAnthropic, "k")
|
||||
Expect(req.Header.Get("anthropic-version")).To(Equal("2024-10-01"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("isHopByHopHeader", func() {
|
||||
DescribeTable("hop-by-hop classification",
|
||||
func(header string, want bool) {
|
||||
Expect(isHopByHopHeader(header)).To(Equal(want))
|
||||
},
|
||||
Entry("Connection is hop-by-hop", "Connection", true),
|
||||
Entry("Keep-Alive is hop-by-hop", "Keep-Alive", true),
|
||||
Entry("Proxy-Connection is hop-by-hop", "Proxy-Connection", true),
|
||||
Entry("Transfer-Encoding is hop-by-hop", "Transfer-Encoding", true),
|
||||
Entry("TE is hop-by-hop", "TE", true),
|
||||
Entry("Trailer is hop-by-hop", "Trailer", true),
|
||||
Entry("Upgrade is hop-by-hop", "Upgrade", true),
|
||||
Entry("Host is hop-by-hop", "Host", true),
|
||||
Entry("Content-Length is hop-by-hop", "Content-Length", true),
|
||||
// Case-insensitive — RFC 7230 doesn't constrain header case.
|
||||
Entry("lowercase connection is hop-by-hop", "connection", true),
|
||||
Entry("uppercase HOST is hop-by-hop", "HOST", true),
|
||||
// Non hop-by-hop — must NOT be stripped.
|
||||
Entry("Authorization is end-to-end", "Authorization", false),
|
||||
Entry("Content-Type is end-to-end", "Content-Type", false),
|
||||
Entry("Accept is end-to-end", "Accept", false),
|
||||
Entry("X-Custom is end-to-end", "X-Custom", false),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("Forward", func() {
|
||||
It("strips hop-by-hop and Connection headers before upstream, preserves custom headers", func() {
|
||||
gotConnection := make(chan string, 1)
|
||||
gotXCustom := make(chan string, 1)
|
||||
gotHost := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotConnection <- r.Header.Get("Connection")
|
||||
gotXCustom <- r.Header.Get("X-Custom")
|
||||
gotHost <- r.Header.Get("Host")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-hopbyhop"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{
|
||||
{Name: "Connection", Value: "keep-alive"},
|
||||
{Name: "Host", Value: "spoofed.example.com"},
|
||||
{Name: "X-Custom", Value: "preserved"},
|
||||
},
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
_, _ = stream.Recv()
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
Expect(<-gotConnection).To(BeEmpty(), "Connection must not leak to upstream")
|
||||
Expect(<-gotHost).NotTo(Equal("spoofed.example.com"), "Host header must not be spoofed through")
|
||||
Expect(<-gotXCustom).To(Equal("preserved"), "X-Custom header must survive")
|
||||
})
|
||||
|
||||
It("replaces caller-supplied Authorization with the configured key", func() {
|
||||
// The proxy must overwrite a client-supplied Authorization header
|
||||
// so a downstream caller can't smuggle stale or wrong credentials.
|
||||
gotAuth := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuth <- r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
GinkgoT().Setenv("CLOUD_PROXY_AUTH_REPLACE_KEY", "sk-real")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_AUTH_REPLACE_KEY",
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-replaces-auth"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{
|
||||
// Client-supplied Authorization with the wrong scheme / key.
|
||||
{Name: "Authorization", Value: "Basic Zm9vOmJhcg=="},
|
||||
},
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
_, _ = stream.Recv()
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
Expect(<-gotAuth).To(Equal("Bearer sk-real"), "caller-supplied Basic header must be replaced")
|
||||
})
|
||||
|
||||
It("handles concurrent calls without interference", func() {
|
||||
// CloudProxy explicitly omits base.SingleThread — independent
|
||||
// Forward streams must not block each other or leak state.
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(body)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
},
|
||||
})).To(Succeed())
|
||||
addr := "test://forward-concurrent"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
|
||||
const N = 8
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
stream, err := c.Forward(context.Background())
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
payload := "request-" + string(rune('A'+idx))
|
||||
if err := stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
BodyChunk: []byte(payload),
|
||||
}); err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
_ = stream.CloseSend()
|
||||
_, _ = stream.Recv()
|
||||
var body []byte
|
||||
for {
|
||||
r, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
body = append(body, r.GetBodyChunk()...)
|
||||
}
|
||||
if string(body) != payload {
|
||||
errs <- &echoMismatch{want: payload, got: string(body)}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
var collected []error
|
||||
for err := range errs {
|
||||
collected = append(collected, err)
|
||||
}
|
||||
Expect(collected).To(BeEmpty(), "no concurrent Forward call should fail")
|
||||
})
|
||||
})
|
||||
|
||||
type echoMismatch struct{ want, got string }
|
||||
|
||||
func (e *echoMismatch) Error() string {
|
||||
return "echo mismatch: want " + strconv.Quote(e.want) + " got " + strconv.Quote(e.got)
|
||||
}
|
||||
@@ -1,508 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// Anthropic Messages API wire-format types. Narrowed to what translate
|
||||
// mode preserves through the Reply proto: text + tool_use blocks +
|
||||
// usage tokens. Image blocks, prompt caching, metadata, and stop
|
||||
// sequence metadata are not modelled — passthrough mode covers those.
|
||||
//
|
||||
// Notable differences from OpenAI:
|
||||
// - max_tokens is REQUIRED. Anthropic 400s without it.
|
||||
// - Roles are user/assistant only — system messages move to a
|
||||
// top-level `system` string field.
|
||||
// - Streaming SSE uses event: lines alongside data: lines. The
|
||||
// events we care about: content_block_start (carries tool_use
|
||||
// init: id + name), content_block_delta (text_delta with text;
|
||||
// input_json_delta with partial_json for tool arguments), and
|
||||
// message_stop (terminates the stream). Others are ignored.
|
||||
|
||||
type anthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int32 `json:"max_tokens"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []anthropicMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []anthropicTool `json:"tools,omitempty"`
|
||||
ToolChoice *anthropicToolChoice `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
// Content is `any` because Anthropic accepts a bare string OR a
|
||||
// list of content blocks. Use the string form for plain user/
|
||||
// assistant turns; switch to []anthropicContentBlock when the
|
||||
// turn needs tool_use (assistant) or tool_result (user) blocks.
|
||||
type anthropicMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type anthropicTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"`
|
||||
}
|
||||
|
||||
// anthropicToolChoice mirrors the four shapes Anthropic accepts:
|
||||
// {"type":"auto"} | {"type":"any"} | {"type":"tool","name":"X"} |
|
||||
// {"type":"none"} (newer models). OpenAI's "auto"/"none"/
|
||||
// "required"/{"function":{"name":"X"}} all map here.
|
||||
type anthropicToolChoice struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// anthropicContentBlock is the union shape used both for response
|
||||
// blocks (text/tool_use we read off the wire) and outbound request
|
||||
// blocks (tool_use/tool_result we emit in the conversation history).
|
||||
// Anthropic encodes tool calls inline rather than as a separate field,
|
||||
// so we walk Content[] looking for type=="tool_use" on responses and
|
||||
// produce equivalent blocks when serialising prior-turn tool calls.
|
||||
type anthropicContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
// Tool-result block fields. tool_result uses `content` (not
|
||||
// `text`) and pairs with `tool_use_id`; modelling them as
|
||||
// distinct fields avoids ambiguity at marshal time.
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
ResultContent string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Content []anthropicContentBlock `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Usage *anthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// anthropicStreamEvent is the union shape used for every event type we
|
||||
// process. Type discriminates; only the matching fields are populated.
|
||||
// content_block_start carries ContentBlock (with id/name for tool_use);
|
||||
// content_block_delta carries Delta (text or partial_json).
|
||||
type anthropicStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Index int `json:"index,omitempty"`
|
||||
ContentBlock *anthropicContentBlock `json:"content_block,omitempty"`
|
||||
Delta *anthropicStreamDelta `json:"delta,omitempty"`
|
||||
Message *anthropicResponse `json:"message,omitempty"`
|
||||
Usage *anthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicStreamDelta struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
}
|
||||
|
||||
// Anthropic requires max_tokens. If the caller didn't set it, use a
|
||||
// generous-but-bounded default so the request doesn't 400.
|
||||
const anthropicDefaultMaxTokens int32 = 4096
|
||||
|
||||
const anthropicToolChoiceNone = "none"
|
||||
|
||||
// Reused JSON-Schema defaults for malformed inputs. Anthropic requires
|
||||
// input_schema to be a JSON object and tool_use.input to be a JSON
|
||||
// object; clients that omit them must not 400 the entire request.
|
||||
var (
|
||||
emptyJSONObject = json.RawMessage(`{}`)
|
||||
emptyObjectSchema = json.RawMessage(`{"type":"object","properties":{}}`)
|
||||
)
|
||||
|
||||
func buildAnthropicRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) {
|
||||
req := anthropicRequest{
|
||||
Model: modelName(cfg, opts),
|
||||
MaxTokens: opts.GetTokens(),
|
||||
Stream: stream,
|
||||
StopSequences: opts.GetStopPrompts(),
|
||||
}
|
||||
if req.MaxTokens <= 0 {
|
||||
req.MaxTokens = anthropicDefaultMaxTokens
|
||||
}
|
||||
// Newer Anthropic models 400 when both temperature and top_p are
|
||||
// set ("`temperature` and `top_p` cannot both be specified for
|
||||
// this model. Please use only one.") even though their docs only
|
||||
// "recommend" picking one. The OpenAI-compatible chat UI almost
|
||||
// always sends both with default values, so prefer temperature
|
||||
// and drop top_p when both are present.
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
} else if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
|
||||
req.Tools = convertOpenAITools(opts.GetTools())
|
||||
req.ToolChoice = convertOpenAIToolChoice(opts.GetToolChoice())
|
||||
// Anthropic rejects tool_choice without tools and older models
|
||||
// don't accept {"type":"none"} — collapse to a no-tools request.
|
||||
if req.ToolChoice != nil && req.ToolChoice.Type == anthropicToolChoiceNone {
|
||||
req.Tools, req.ToolChoice = nil, nil
|
||||
}
|
||||
|
||||
var systemParts []string
|
||||
for _, m := range opts.GetMessages() {
|
||||
role := m.GetRole()
|
||||
if role == "system" {
|
||||
if c := m.GetContent(); c != "" {
|
||||
systemParts = append(systemParts, c)
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch role {
|
||||
case "user":
|
||||
req.Messages = append(req.Messages, anthropicMessage{
|
||||
Role: "user",
|
||||
Content: m.GetContent(),
|
||||
})
|
||||
case "assistant":
|
||||
if blocks := assistantBlocks(m); blocks != nil {
|
||||
req.Messages = append(req.Messages, anthropicMessage{Role: "assistant", Content: blocks})
|
||||
continue
|
||||
}
|
||||
req.Messages = append(req.Messages, anthropicMessage{
|
||||
Role: "assistant",
|
||||
Content: m.GetContent(),
|
||||
})
|
||||
case "tool", "function":
|
||||
req.Messages = appendToolResult(req.Messages, anthropicContentBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: m.GetToolCallId(),
|
||||
ResultContent: m.GetContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
req.System = strings.Join(systemParts, "\n\n")
|
||||
|
||||
if len(req.Messages) == 0 && opts.GetPrompt() != "" {
|
||||
req.Messages = []anthropicMessage{{Role: "user", Content: opts.GetPrompt()}}
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
// appendToolResult appends a tool_result block as a user message,
|
||||
// merging into a preceding user message that already carries blocks.
|
||||
// Anthropic concatenates consecutive same-role messages on its end,
|
||||
// but explicit merging keeps the body smaller and the conversation
|
||||
// strictly alternating — which some upstream filters require.
|
||||
func appendToolResult(msgs []anthropicMessage, block anthropicContentBlock) []anthropicMessage {
|
||||
if n := len(msgs); n > 0 && msgs[n-1].Role == "user" {
|
||||
if existing, ok := msgs[n-1].Content.([]anthropicContentBlock); ok {
|
||||
msgs[n-1].Content = append(existing, block)
|
||||
return msgs
|
||||
}
|
||||
}
|
||||
return append(msgs, anthropicMessage{
|
||||
Role: "user",
|
||||
Content: []anthropicContentBlock{block},
|
||||
})
|
||||
}
|
||||
|
||||
func convertOpenAITools(toolsJSON string) []anthropicTool {
|
||||
if toolsJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var raw []openAITool
|
||||
if err := json.Unmarshal([]byte(toolsJSON), &raw); err != nil {
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unparseable tools JSON, dropping", "error", err)
|
||||
return nil
|
||||
}
|
||||
tools := make([]anthropicTool, 0, len(raw))
|
||||
for _, t := range raw {
|
||||
if t.Function.Name == "" {
|
||||
continue
|
||||
}
|
||||
schema := t.Function.Parameters
|
||||
if len(schema) == 0 {
|
||||
schema = emptyObjectSchema
|
||||
}
|
||||
tools = append(tools, anthropicTool{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
InputSchema: schema,
|
||||
})
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// convertOpenAIToolChoice accepts the spec form
|
||||
// ({type:function, function:{name:X}}) and the flat legacy form
|
||||
// ({type:function, name:X}) some clients send. Unknown object shapes
|
||||
// are warned and dropped rather than silently treated as auto.
|
||||
func convertOpenAIToolChoice(toolChoiceJSON string) *anthropicToolChoice {
|
||||
if toolChoiceJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var asString string
|
||||
if err := json.Unmarshal([]byte(toolChoiceJSON), &asString); err == nil {
|
||||
switch asString {
|
||||
case "auto":
|
||||
return &anthropicToolChoice{Type: "auto"}
|
||||
case "none":
|
||||
return &anthropicToolChoice{Type: anthropicToolChoiceNone}
|
||||
case "required":
|
||||
return &anthropicToolChoice{Type: "any"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var asObj struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"function"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(toolChoiceJSON), &asObj); err != nil {
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unparseable tool_choice, dropping", "error", err)
|
||||
return nil
|
||||
}
|
||||
if name := asObj.Function.Name; name != "" {
|
||||
return &anthropicToolChoice{Type: "tool", Name: name}
|
||||
}
|
||||
if asObj.Name != "" {
|
||||
return &anthropicToolChoice{Type: "tool", Name: asObj.Name}
|
||||
}
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unrecognised tool_choice shape, dropping", "shape", toolChoiceJSON)
|
||||
return nil
|
||||
}
|
||||
|
||||
// openAITool mirrors pkg/functions.Tool but keeps Parameters as
|
||||
// json.RawMessage so the input_schema passes through verbatim — no
|
||||
// re-marshal cost, no fidelity loss on exotic schemas.
|
||||
type openAITool struct {
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters json.RawMessage `json:"parameters"`
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
func assistantBlocks(m *pb.Message) []anthropicContentBlock {
|
||||
toolCallsJSON := m.GetToolCalls()
|
||||
if toolCallsJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var toolCalls []openAIToolCall
|
||||
if err := json.Unmarshal([]byte(toolCallsJSON), &toolCalls); err != nil || len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
blocks := make([]anthropicContentBlock, 0, len(toolCalls)+1)
|
||||
if text := m.GetContent(); text != "" {
|
||||
blocks = append(blocks, anthropicContentBlock{Type: "text", Text: text})
|
||||
}
|
||||
for _, tc := range toolCalls {
|
||||
// OpenAI's arguments are a JSON-encoded string; pass through
|
||||
// as RawMessage so a non-JSON string from a poorly-formed
|
||||
// local model doesn't crash the marshaller downstream.
|
||||
args := json.RawMessage(tc.Function.Arguments)
|
||||
if len(args) == 0 {
|
||||
args = emptyJSONObject
|
||||
}
|
||||
blocks = append(blocks, anthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: args,
|
||||
})
|
||||
}
|
||||
return blocks
|
||||
}
|
||||
|
||||
// doAnthropicRequest is the Anthropic counterpart of doOpenAIRequest.
|
||||
// applyAuthHeader sets x-api-key and anthropic-version when provider
|
||||
// is anthropic, so this method doesn't need to duplicate that.
|
||||
func (c *CloudProxy) doAnthropicRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// predictAnthropicRich returns the full Reply: joined text from all
|
||||
// text blocks, tool_use blocks mapped to ToolCallDelta, and usage
|
||||
// tokens.
|
||||
func (c *CloudProxy) predictAnthropicRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
body, err := buildAnthropicRequest(opts, cfg, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doAnthropicRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var parsed anthropicResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: decode response: %w", err)
|
||||
}
|
||||
|
||||
reply := &pb.Reply{}
|
||||
if parsed.Usage != nil {
|
||||
reply.PromptTokens = int32(parsed.Usage.InputTokens)
|
||||
reply.Tokens = int32(parsed.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
var content strings.Builder
|
||||
var toolCalls []*pb.ToolCallDelta
|
||||
toolIdx := 0
|
||||
for _, b := range parsed.Content {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
content.WriteString(b.Text)
|
||||
case "tool_use":
|
||||
// Input is a structured JSON object; we serialise to a
|
||||
// string so it fits the OpenAI-shaped arguments field
|
||||
// downstream consumers expect.
|
||||
args := ""
|
||||
if len(b.Input) > 0 {
|
||||
args = string(b.Input)
|
||||
}
|
||||
toolCalls = append(toolCalls, newToolCallDelta(toolIdx, b.ID, b.Name, args))
|
||||
toolIdx++
|
||||
}
|
||||
}
|
||||
reply.Message = []byte(content.String())
|
||||
if len(toolCalls) > 0 {
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{ToolCalls: toolCalls}}
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// predictAnthropicStreamRich streams Reply chunks from Anthropic's SSE.
|
||||
// Three event types matter: content_block_start (initialises tool_use
|
||||
// id+name), content_block_delta (carries text or input_json_delta),
|
||||
// message_stop (terminates). The block index from the wire feeds
|
||||
// straight into ToolCallDelta.Index so downstream consumers can
|
||||
// reassemble multiple parallel tool calls.
|
||||
func (c *CloudProxy) predictAnthropicStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
body, err := buildAnthropicRequest(opts, cfg, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doAnthropicRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "" {
|
||||
continue
|
||||
}
|
||||
var ev anthropicStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &ev); err != nil {
|
||||
xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err)
|
||||
continue
|
||||
}
|
||||
switch ev.Type {
|
||||
case "content_block_start":
|
||||
// tool_use blocks announce id + name here; arguments arrive
|
||||
// in subsequent input_json_delta events. Emit a Reply with
|
||||
// just the tool_call init fields so consumers can allocate
|
||||
// a slot at this index.
|
||||
if ev.ContentBlock != nil && ev.ContentBlock.Type == "tool_use" {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{
|
||||
newToolCallDelta(ev.Index, ev.ContentBlock.ID, ev.ContentBlock.Name, ""),
|
||||
}}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "content_block_delta":
|
||||
if ev.Delta == nil {
|
||||
continue
|
||||
}
|
||||
switch ev.Delta.Type {
|
||||
case "text_delta":
|
||||
if ev.Delta.Text == "" {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
Message: []byte(ev.Delta.Text),
|
||||
ChatDeltas: []*pb.ChatDelta{{Content: ev.Delta.Text}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
case "input_json_delta":
|
||||
if ev.Delta.PartialJSON == "" {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{
|
||||
newToolCallDelta(ev.Index, "", "", ev.Delta.PartialJSON),
|
||||
}}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "message_delta":
|
||||
// Anthropic sends final usage in message_delta.usage. Emit
|
||||
// a usage-only Reply so the consumer can record totals.
|
||||
if ev.Usage != nil {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
Tokens: int32(ev.Usage.OutputTokens),
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "message_stop":
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
@@ -1,334 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// fakeAnthropicUpstream mirrors fakeOpenAIUpstream but decodes the
|
||||
// request body as an anthropicRequest so tests can assert on the
|
||||
// translated wire shape (system field, max_tokens, etc.).
|
||||
func fakeAnthropicUpstream(t *testing.T, handler func(req anthropicRequest) (status int, body string, contentType string)) (*httptest.Server, *anthropicRequest) {
|
||||
t.Helper()
|
||||
var captured anthropicRequest
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(raw, &captured)
|
||||
status, body, ct := handler(captured)
|
||||
w.Header().Set("Content-Type", ct)
|
||||
w.WriteHeader(status)
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
return srv, &captured
|
||||
}
|
||||
|
||||
func newAnthropicTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy {
|
||||
t.Helper()
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_ANTHROPIC_FAKE", "sk-ant-fake")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Model: "claude-local",
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstreamURL,
|
||||
Mode: modeTranslate,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_FAKE",
|
||||
UpstreamModel: "claude-3-5-sonnet-20241022",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
return cp
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_BasicMessages(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hi there"}],"model":"claude-3-5-sonnet-20241022","usage":{"input_tokens":5,"output_tokens":2}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{
|
||||
{Role: "system", Content: "be brief"},
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
Temperature: 0.5,
|
||||
TopP: 0.9,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hi there"))
|
||||
|
||||
g.Expect(captured.Model).To(Equal("claude-3-5-sonnet-20241022"))
|
||||
// System message must be hoisted out of Messages into top-level field.
|
||||
g.Expect(captured.System).To(Equal("be brief"))
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.MaxTokens).To(Equal(int32(32)))
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
// Anthropic 400s when both temperature and top_p are set; the
|
||||
// translator must prefer temperature and drop top_p.
|
||||
g.Expect(captured.TopP).To(BeNil())
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
}
|
||||
|
||||
// When only top_p is set, it should be forwarded.
|
||||
func TestPredict_Anthropic_TopPOnly(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hello"}},
|
||||
TopP: 0.9,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Temperature).To(BeNil())
|
||||
// PredictOptions.TopP is float32 on the wire; the translator widens
|
||||
// to float64 so 0.9 round-trips as 0.8999999761581421… — compare
|
||||
// with a small tolerance rather than exact equality.
|
||||
g.Expect(captured.TopP).NotTo(BeNil())
|
||||
g.Expect(math.Abs(*captured.TopP - 0.9)).To(BeNumerically("<=", 1e-6))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_DefaultsMaxTokens(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Anthropic 400s without max_tokens. The translator must default
|
||||
// it when the caller doesn't supply Tokens.
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.MaxTokens).To(Equal(anthropicDefaultMaxTokens))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_PromptFallback(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?", Tokens: 16})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.Messages[0].Content).To(Equal("what time is it?"))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_ConcatenatesContentBlocks(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Anthropic may return multiple text blocks; the translator joins
|
||||
// them so the Predict() string return is the full assistant message.
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"hello "},{"type":"text","text":"world"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hello world"))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_UpstreamError(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 401, `{"error":{"type":"authentication_error","message":"bad key"}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("401"))
|
||||
}
|
||||
|
||||
func TestPredictStream_Anthropic_StreamsTextDeltas(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Real Anthropic SSE has event: lines + data: lines. The translator
|
||||
// only needs the data: payload; only content_block_delta with
|
||||
// delta.type=text_delta carries content. message_stop ends.
|
||||
frames := []string{
|
||||
"event: message_start\ndata: {\"type\":\"message_start\"}\n\n",
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" \"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"world\"}}\n\n",
|
||||
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
body := strings.Join(frames, "")
|
||||
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan string, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStream(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
Tokens: 16,
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var got []string
|
||||
for s := range results {
|
||||
got = append(got, s)
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(strings.Join(got, "")).To(Equal("hello world"))
|
||||
g.Expect(captured.Stream).To(BeTrue())
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_TranslatesOpenAITools(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
tools := `[{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}}}]`
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "weather in Paris?"}},
|
||||
Tools: tools,
|
||||
ToolChoice: `"auto"`,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Tools).To(HaveLen(1))
|
||||
g.Expect(captured.Tools[0].Name).To(Equal("get_weather"))
|
||||
g.Expect(captured.Tools[0].Description).To(Equal("Get weather"))
|
||||
// input_schema must be the parameters object verbatim.
|
||||
g.Expect(string(captured.Tools[0].InputSchema)).To(ContainSubstring(`"city"`))
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("auto"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_RequiredMapsToAny(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `"required"`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("any"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_NoneDropsTools(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `"none"`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Tools).To(BeNil())
|
||||
g.Expect(captured.ToolChoice).To(BeNil())
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_NamedFunction(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"weather","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `{"type":"function","function":{"name":"weather"}}`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("tool"))
|
||||
g.Expect(captured.ToolChoice.Name).To(Equal("weather"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_RoundTripsAssistantToolCalls(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// LocalAI Assistant's second turn: the LLM previously emitted a
|
||||
// tool_use, the server executed it, and the conversation now
|
||||
// includes the assistant turn (with tool_calls) plus a tool-role
|
||||
// result message. Both must convert to Anthropic block form.
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
tools := `[{"type":"function","function":{"name":"list_models","parameters":{"type":"object"}}}]`
|
||||
toolCallsJSON := `[{"id":"call_abc","type":"function","function":{"name":"list_models","arguments":"{}"}}]`
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Tools: tools,
|
||||
Messages: []*pb.Message{
|
||||
{Role: "user", Content: "what models are installed?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: toolCallsJSON},
|
||||
{Role: "tool", Content: `{"models":["a","b"]}`, ToolCallId: "call_abc"},
|
||||
},
|
||||
Tokens: 64,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
g.Expect(captured.Messages).To(HaveLen(3))
|
||||
// 1. user text — bare string
|
||||
s, ok := captured.Messages[0].Content.(string)
|
||||
g.Expect(ok).To(BeTrue())
|
||||
g.Expect(s).To(Equal("what models are installed?"))
|
||||
// 2. assistant — must be a content-block list with one tool_use
|
||||
// json.Unmarshal of `any` produces []any not []anthropicContentBlock.
|
||||
blocks, ok := captured.Messages[1].Content.([]any)
|
||||
g.Expect(ok).To(BeTrue())
|
||||
g.Expect(blocks).To(HaveLen(1))
|
||||
b0, _ := blocks[0].(map[string]any)
|
||||
g.Expect(b0["type"]).To(Equal("tool_use"))
|
||||
g.Expect(b0["id"]).To(Equal("call_abc"))
|
||||
g.Expect(b0["name"]).To(Equal("list_models"))
|
||||
// 3. tool → user with tool_result block
|
||||
g.Expect(captured.Messages[2].Role).To(Equal("user"))
|
||||
resBlocks, _ := captured.Messages[2].Content.([]any)
|
||||
r0, _ := resBlocks[0].(map[string]any)
|
||||
g.Expect(r0["type"]).To(Equal("tool_result"))
|
||||
g.Expect(r0["tool_use_id"]).To(Equal("call_abc"))
|
||||
g.Expect(r0["content"]).To(Equal(`{"models":["a","b"]}`))
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Verify buildOpenAIRequest preserves caller-supplied tools and
|
||||
// tool_choice as opaque JSON. PredictOptions carries them as strings;
|
||||
// they must land in the outbound request body unchanged so the
|
||||
// upstream sees the caller's intent verbatim. A regression here would
|
||||
// silently disable function calling for translate-mode clients.
|
||||
func TestBuildOpenAIRequest_ToolsAndToolChoicePassthrough(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cfg := &proxyConfig{upstreamModel: "gpt-4o"}
|
||||
toolsJSON := `[{"type":"function","function":{"name":"search","parameters":{"type":"object"}}}]`
|
||||
choiceJSON := `{"type":"function","function":{"name":"search"}}`
|
||||
|
||||
body, err := buildOpenAIRequest(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "find x"}},
|
||||
Tools: toolsJSON,
|
||||
ToolChoice: choiceJSON,
|
||||
}, cfg, false)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
var decoded openAIRequest
|
||||
err = json.Unmarshal(body, &decoded)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
// Compare the JSON-canonical form so whitespace differences are ignored.
|
||||
gotTools, _ := json.Marshal(json.RawMessage(decoded.Tools))
|
||||
wantTools, _ := json.Marshal(json.RawMessage(toolsJSON))
|
||||
g.Expect(string(gotTools)).To(Equal(string(wantTools)))
|
||||
gotChoice, _ := json.Marshal(json.RawMessage(decoded.ToolChoice))
|
||||
wantChoice, _ := json.Marshal(json.RawMessage(choiceJSON))
|
||||
g.Expect(string(gotChoice)).To(Equal(string(wantChoice)))
|
||||
}
|
||||
|
||||
// Garbage JSON in tools / tool_choice is silently dropped (omitted)
|
||||
// rather than blowing up the request. Documents the parseRawJSON
|
||||
// behaviour — operators shouldn't see hard failures from an upstream
|
||||
// caller's mis-formatted tools field.
|
||||
func TestBuildOpenAIRequest_InvalidToolsJSONDropped(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cfg := &proxyConfig{upstreamModel: "gpt-4o"}
|
||||
body, err := buildOpenAIRequest(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: "this is not json",
|
||||
ToolChoice: "{also bad",
|
||||
}, cfg, false)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(body)).NotTo(ContainSubstring("this is not json"))
|
||||
g.Expect(string(body)).NotTo(ContainSubstring("{also bad"))
|
||||
}
|
||||
|
||||
// Anthropic empty content array yields an empty Reply (not an error).
|
||||
// Mirrors how an upstream tool_use-only response might arrive — the
|
||||
// content array can legitimately be empty in some edge cases.
|
||||
func TestPredictRich_Anthropic_EmptyContent(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"id":"m1","type":"message","role":"assistant","content":[],"usage":{"input_tokens":3,"output_tokens":0}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal(""))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(0))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(3)))
|
||||
}
|
||||
|
||||
// A truncated / malformed SSE payload mid-stream should be tolerated:
|
||||
// the malformed chunk gets skipped (xlog.Debug logged), valid chunks
|
||||
// before AND after it still reach the channel.
|
||||
func TestPredictStreamRich_OpenAI_TolerantOfBadChunks(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
body := strings.Join([]string{
|
||||
`data: {"choices":[{"index":0,"delta":{"content":"hello"}}]}`,
|
||||
``,
|
||||
`data: this-is-not-json{{`,
|
||||
``,
|
||||
`data: {"choices":[{"index":0,"delta":{"content":" world"}}]}`,
|
||||
``,
|
||||
`data: [DONE]`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var assembled strings.Builder
|
||||
for reply := range results {
|
||||
assembled.Write(reply.GetMessage())
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
// The good chunks before and after the malformed one both made it through.
|
||||
g.Expect(assembled.String()).To(Equal("hello world"))
|
||||
}
|
||||
@@ -1,320 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// OpenAI Chat Completions wire-format types. Narrowed to the fields
|
||||
// translate mode needs to preserve through the Reply proto: content,
|
||||
// role, tool_calls (typed so we can map them to pb.ToolCallDelta),
|
||||
// and sampling params copied verbatim from PredictOptions.
|
||||
//
|
||||
// Provider-specific extensions (logit_bias, function calling beyond
|
||||
// tool_calls, etc.) are not modelled — passthrough mode covers callers
|
||||
// that need full upstream fidelity.
|
||||
|
||||
type openAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAIMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
MaxTokens *int32 `json:"max_tokens,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
type openAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// openAIToolCall covers both the non-streaming response shape (full
|
||||
// id+function+arguments) and the streaming-delta shape (sparse fields,
|
||||
// index assignment). The proto's ToolCallDelta absorbs both — name is
|
||||
// set on first appearance, arguments arrive incrementally in streaming.
|
||||
type openAIToolCall struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function openAIFunctionCall `json:"function,omitempty"`
|
||||
}
|
||||
|
||||
type openAIFunctionCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type openAIChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message openAIMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type openAIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Choices []openAIChoice `json:"choices"`
|
||||
Usage *openAIUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type openAIStreamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
|
||||
} `json:"delta"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
type openAIStreamChunk struct {
|
||||
Choices []openAIStreamChoice `json:"choices"`
|
||||
Usage *openAIUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type openAIUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// buildOpenAIRequest converts pb.PredictOptions into the OpenAI Chat
|
||||
// Completions request body. Prefers Messages when non-empty; falls
|
||||
// back to wrapping Prompt as a single user message so plain
|
||||
// /completions-style calls still work in translate mode.
|
||||
func buildOpenAIRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) {
|
||||
req := openAIRequest{
|
||||
Model: modelName(cfg, opts),
|
||||
Stream: stream,
|
||||
Stop: opts.GetStopPrompts(),
|
||||
Tools: parseRawJSON(opts.GetTools()),
|
||||
ToolChoice: parseRawJSON(opts.GetToolChoice()),
|
||||
}
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
}
|
||||
if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
if n := opts.GetTokens(); n > 0 {
|
||||
req.MaxTokens = &n
|
||||
}
|
||||
if p := opts.GetFrequencyPenalty(); p != 0 {
|
||||
v := float64(p)
|
||||
req.FrequencyPenalty = &v
|
||||
}
|
||||
if p := opts.GetPresencePenalty(); p != 0 {
|
||||
v := float64(p)
|
||||
req.PresencePenalty = &v
|
||||
}
|
||||
|
||||
for _, m := range opts.GetMessages() {
|
||||
msg := openAIMessage{
|
||||
Role: m.GetRole(),
|
||||
Content: m.GetContent(),
|
||||
Name: m.GetName(),
|
||||
ToolCallID: m.GetToolCallId(),
|
||||
}
|
||||
// Pre-existing tool_calls arrive as a JSON string from the
|
||||
// upstream caller's previous assistant turn; pass-through as-is.
|
||||
if tc := m.GetToolCalls(); tc != "" {
|
||||
_ = json.Unmarshal([]byte(tc), &msg.ToolCalls)
|
||||
}
|
||||
req.Messages = append(req.Messages, msg)
|
||||
}
|
||||
// Fallback for plain Prompt requests (no Messages array). LocalAI
|
||||
// templating may have produced a flat prompt; rewrap as a single
|
||||
// user message so the upstream chat endpoint accepts it.
|
||||
if len(req.Messages) == 0 && opts.GetPrompt() != "" {
|
||||
req.Messages = []openAIMessage{{Role: "user", Content: opts.GetPrompt()}}
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
// modelName picks the upstream model: upstream_model from the proxy
|
||||
// config wins (operator override), else the local model name captured
|
||||
// at LoadModel time. Operator sets upstream_model to map LocalAI's
|
||||
// alias (e.g. "claude-strict") to the upstream's canonical name
|
||||
// (e.g. "claude-3-5-sonnet-20241022").
|
||||
func modelName(cfg *proxyConfig, _ *pb.PredictOptions) string {
|
||||
if cfg.upstreamModel != "" {
|
||||
return cfg.upstreamModel
|
||||
}
|
||||
return cfg.localModel
|
||||
}
|
||||
|
||||
// parseRawJSON parses a JSON string into a RawMessage so it round-trips
|
||||
// into the upstream body. Returns nil for empty/invalid input so the
|
||||
// field is omitted (omitempty).
|
||||
func parseRawJSON(s string) json.RawMessage {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
var probe json.RawMessage
|
||||
if err := json.Unmarshal([]byte(s), &probe); err != nil {
|
||||
return nil
|
||||
}
|
||||
return probe
|
||||
}
|
||||
|
||||
// doOpenAIRequest builds + sends the upstream request. Returns the
|
||||
// raw response on success; caller handles status / body.
|
||||
func (c *CloudProxy) doOpenAIRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// predictOpenAIRich is the non-streaming translate path. Returns a
|
||||
// fully-populated *pb.Reply with assistant content, tool calls, and
|
||||
// token usage. The gRPC server forwards the Reply verbatim.
|
||||
func (c *CloudProxy) predictOpenAIRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
body, err := buildOpenAIRequest(opts, cfg, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doOpenAIRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var parsed openAIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: decode response: %w", err)
|
||||
}
|
||||
if len(parsed.Choices) == 0 {
|
||||
return nil, errors.New("cloud-proxy: upstream returned no choices")
|
||||
}
|
||||
|
||||
choice := parsed.Choices[0]
|
||||
reply := &pb.Reply{
|
||||
Message: []byte(choice.Message.Content),
|
||||
}
|
||||
if parsed.Usage != nil {
|
||||
reply.PromptTokens = int32(parsed.Usage.PromptTokens)
|
||||
reply.Tokens = int32(parsed.Usage.CompletionTokens)
|
||||
}
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
// Non-streaming: a single ChatDelta carries the full tool-call
|
||||
// set. Index/Name/Arguments are populated together; downstream
|
||||
// consumers don't need to assemble streaming deltas.
|
||||
delta := &pb.ChatDelta{}
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
delta.ToolCalls = append(delta.ToolCalls,
|
||||
newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments))
|
||||
}
|
||||
reply.ChatDeltas = []*pb.ChatDelta{delta}
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// predictOpenAIStreamRich streams *pb.Reply chunks. Each chunk carries
|
||||
// either a content delta (Message + ChatDeltas[].Content) or tool-call
|
||||
// deltas (ChatDeltas[].ToolCalls). The final Reply carries usage tokens
|
||||
// when the upstream sends them (stream_options.include_usage).
|
||||
func (c *CloudProxy) predictOpenAIStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
body, err := buildOpenAIRequest(opts, cfg, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doOpenAIRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
return nil
|
||||
}
|
||||
var chunk openAIStreamChunk
|
||||
if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
|
||||
xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err)
|
||||
continue
|
||||
}
|
||||
// Usage frames may arrive separately from content frames when
|
||||
// stream_options.include_usage is set; emit a usage-only Reply
|
||||
// in that case so the consumer sees the totals.
|
||||
if chunk.Usage != nil && len(chunk.Choices) == 0 {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
PromptTokens: int32(chunk.Usage.PromptTokens),
|
||||
Tokens: int32(chunk.Usage.CompletionTokens),
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, ch := range chunk.Choices {
|
||||
reply := &pb.Reply{}
|
||||
if ch.Delta.Content != "" {
|
||||
reply.Message = []byte(ch.Delta.Content)
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{Content: ch.Delta.Content}}
|
||||
}
|
||||
if len(ch.Delta.ToolCalls) > 0 {
|
||||
if len(reply.ChatDeltas) == 0 {
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{}}
|
||||
}
|
||||
for _, tc := range ch.Delta.ToolCalls {
|
||||
reply.ChatDeltas[0].ToolCalls = append(reply.ChatDeltas[0].ToolCalls,
|
||||
newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments))
|
||||
}
|
||||
}
|
||||
if reply.Message == nil && len(reply.ChatDeltas) == 0 {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, reply) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
@@ -1,170 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// fakeOpenAIUpstream returns an httptest.Server that decodes the
|
||||
// inbound request as an openAIRequest, calls handler with it, and
|
||||
// writes the handler's reply as the response.
|
||||
func fakeOpenAIUpstream(t *testing.T, handler func(req openAIRequest) (status int, body string, contentType string)) (*httptest.Server, *openAIRequest) {
|
||||
t.Helper()
|
||||
var captured openAIRequest
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(raw, &captured)
|
||||
status, body, ct := handler(captured)
|
||||
w.Header().Set("Content-Type", ct)
|
||||
w.WriteHeader(status)
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
return srv, &captured
|
||||
}
|
||||
|
||||
func newTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy {
|
||||
t.Helper()
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_OPENAI_FAKE", "sk-fake-openai")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Model: "gpt-4o-local",
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstreamURL,
|
||||
Mode: modeTranslate,
|
||||
Provider: providerOpenAI,
|
||||
ApiKeyEnv: "CLOUD_PROXY_OPENAI_FAKE",
|
||||
UpstreamModel: "gpt-4o",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
return cp
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_BasicChat(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"id":"resp-1","choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{
|
||||
{Role: "system", Content: "be brief"},
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
Temperature: 0.5,
|
||||
TopP: 0.9,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hi there"))
|
||||
|
||||
// Verify the upstream saw a properly-translated request.
|
||||
g.Expect(captured.Model).To(Equal("gpt-4o"))
|
||||
g.Expect(captured.Messages).To(HaveLen(2))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("system"))
|
||||
g.Expect(captured.Messages[1].Role).To(Equal("user"))
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
g.Expect(captured.MaxTokens).NotTo(BeNil())
|
||||
g.Expect(*captured.MaxTokens).To(Equal(int32(32)))
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_PromptFallback(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// No Messages array — backend should synth a single user message
|
||||
// from Prompt so non-chat clients still route through translate.
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"choices":[{"message":{"role":"assistant","content":"ok"}}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?"})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.Messages[0].Content).To(Equal("what time is it?"))
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_UpstreamError(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 401, `{"error":{"message":"bad key"}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("401"))
|
||||
}
|
||||
|
||||
func TestPredictStream_OpenAI_StreamsContent(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Stream three content deltas then [DONE]. Verify the channel
|
||||
// receives them in order with no missing pieces.
|
||||
chunks := []string{
|
||||
`{"choices":[{"index":0,"delta":{"role":"assistant"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"hello"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":" "}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"world"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||
}
|
||||
body := ""
|
||||
for _, c := range chunks {
|
||||
body += "data: " + c + "\n\n"
|
||||
}
|
||||
body += "data: [DONE]\n\n"
|
||||
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan string, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStream(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var got []string
|
||||
for s := range results {
|
||||
got = append(got, s)
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(strings.Join(got, "")).To(Equal("hello world"))
|
||||
g.Expect(captured.Stream).To(BeTrue())
|
||||
}
|
||||
|
||||
func TestPredict_RejectedInPassthroughMode(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_FAKE", "k")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_FAKE",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_, err = cp.Predict(&pb.PredictOptions{})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("only valid in translate"))
|
||||
}
|
||||
@@ -1,429 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// Mirror of core/config.Proxy{Mode,Provider}* — backends don't
|
||||
// import core to keep the boundary clean.
|
||||
const (
|
||||
modePassthrough = "passthrough"
|
||||
modeTranslate = "translate"
|
||||
|
||||
providerOpenAI = "openai"
|
||||
providerAnthropic = "anthropic"
|
||||
)
|
||||
|
||||
// CloudProxy is the LocalAI backend that proxies model traffic to a
|
||||
// configured upstream HTTP provider. Concurrency: base.SingleThread is
|
||||
// NOT embedded — forward calls are independent and HTTP transport is
|
||||
// goroutine-safe, so multiple Forward streams can run in parallel.
|
||||
// Locking would serialise requests to a chat provider for no benefit.
|
||||
type CloudProxy struct {
|
||||
base.Base
|
||||
|
||||
cfg atomic.Pointer[proxyConfig]
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type proxyConfig struct {
|
||||
upstreamURL string
|
||||
mode string
|
||||
provider string
|
||||
upstreamModel string
|
||||
localModel string // ModelOptions.Model — fallback when upstream_model is unset
|
||||
apiKey string // resolved at Load time
|
||||
}
|
||||
|
||||
func NewCloudProxy() *CloudProxy {
|
||||
// No Client-level Timeout — that would bound streaming SSE
|
||||
// responses too, which can legitimately last minutes. Per-request
|
||||
// deadlines come from the gRPC stream context.
|
||||
return &CloudProxy{client: &http.Client{}}
|
||||
}
|
||||
|
||||
func (c *CloudProxy) Load(opts *pb.ModelOptions) error {
|
||||
po := opts.GetProxy()
|
||||
if po == nil {
|
||||
return errors.New("cloud-proxy: Load requires ProxyOptions to be set")
|
||||
}
|
||||
if po.GetUpstreamUrl() == "" {
|
||||
return errors.New("cloud-proxy: upstream_url is required")
|
||||
}
|
||||
if _, err := url.ParseRequestURI(po.GetUpstreamUrl()); err != nil {
|
||||
return fmt.Errorf("cloud-proxy: upstream_url %q invalid: %w", po.GetUpstreamUrl(), err)
|
||||
}
|
||||
|
||||
mode := po.GetMode()
|
||||
if mode == "" {
|
||||
mode = modePassthrough
|
||||
}
|
||||
switch mode {
|
||||
case modePassthrough:
|
||||
case modeTranslate:
|
||||
switch po.GetProvider() {
|
||||
case providerOpenAI:
|
||||
// implemented in provider_openai.go
|
||||
case providerAnthropic:
|
||||
// implemented in provider_anthropic.go
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: translate mode requires provider in {%s, %s}, got %q",
|
||||
providerOpenAI, providerAnthropic, po.GetProvider())
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: unknown mode %q", mode)
|
||||
}
|
||||
|
||||
key, err := resolveAPIKey(po.GetApiKeyEnv(), po.GetApiKeyFile())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.cfg.Store(&proxyConfig{
|
||||
upstreamURL: po.GetUpstreamUrl(),
|
||||
mode: mode,
|
||||
provider: po.GetProvider(),
|
||||
upstreamModel: po.GetUpstreamModel(),
|
||||
localModel: opts.GetModel(),
|
||||
apiKey: key,
|
||||
})
|
||||
xlog.Info("cloud-proxy: ready",
|
||||
"upstream", po.GetUpstreamUrl(),
|
||||
"mode", mode,
|
||||
"provider", po.GetProvider(),
|
||||
"has_key", key != "")
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveAPIKey mirrors config.ProxyConfig.ResolveAPIKey. Duplicated
|
||||
// (a few lines) rather than importing core/config from a backend
|
||||
// binary — keeps backends independent of core's package layout.
|
||||
// Mutual-exclusion is enforced upstream in core/config.Validate.
|
||||
func resolveAPIKey(envName, filePath string) (string, error) {
|
||||
if envName != "" {
|
||||
v := os.Getenv(envName)
|
||||
if v == "" {
|
||||
return "", fmt.Errorf("cloud-proxy: api_key_env %q is unset", envName)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
if filePath != "" {
|
||||
b, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cloud-proxy: read api_key_file %q: %w", filePath, err)
|
||||
}
|
||||
return strings.TrimSpace(string(b)), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// PredictRich is the non-streaming translate path. Returns a fully-
|
||||
// populated *pb.Reply: content, tool-call deltas (ChatDeltas), and
|
||||
// usage tokens. Implements the optional grpc.AIModelRich interface;
|
||||
// the gRPC server prefers this path over Predict when present so
|
||||
// tool calls survive the round-trip. Passthrough mode rejects
|
||||
// PredictRich — callers must use Forward.
|
||||
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cloud-proxy: model not loaded")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
|
||||
}
|
||||
xlog.Info("cloud-proxy: predict", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: predict failed", "provider", cfg.provider, "error", err)
|
||||
}
|
||||
}()
|
||||
ctx := context.Background()
|
||||
switch cfg.provider {
|
||||
case providerOpenAI:
|
||||
return c.predictOpenAIRich(ctx, cfg, opts)
|
||||
case providerAnthropic:
|
||||
return c.predictAnthropicRich(ctx, cfg, opts)
|
||||
default:
|
||||
return nil, fmt.Errorf("cloud-proxy: predict not implemented for provider %q", cfg.provider)
|
||||
}
|
||||
}
|
||||
|
||||
// PredictStreamRich is the rich streaming counterpart of PredictRich.
|
||||
// Each emitted Reply carries either a content delta, tool-call deltas,
|
||||
// or usage tokens (the final upstream frame). base.Base.PredictStream
|
||||
// is bypassed when AIModelRich is implemented, so the channel is
|
||||
// closed by the gRPC server pump.
|
||||
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
|
||||
}
|
||||
xlog.Info("cloud-proxy: predict-stream", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: predict-stream failed", "provider", cfg.provider, "error", err)
|
||||
}
|
||||
}()
|
||||
ctx := context.Background()
|
||||
switch cfg.provider {
|
||||
case providerOpenAI:
|
||||
return c.predictOpenAIStreamRich(ctx, cfg, opts, results)
|
||||
case providerAnthropic:
|
||||
return c.predictAnthropicStreamRich(ctx, cfg, opts, results)
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: predictStream not implemented for provider %q", cfg.provider)
|
||||
}
|
||||
}
|
||||
|
||||
// Predict is the legacy (string, error) AIModel signature. Used only
|
||||
// if a caller goes through the non-rich path (it shouldn't, since
|
||||
// server.go prefers PredictRich). Provided so the AIModel interface
|
||||
// is satisfied for backends that haven't opted into the rich variant.
|
||||
func (c *CloudProxy) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
reply, err := c.PredictRich(opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(reply.GetMessage()), nil
|
||||
}
|
||||
|
||||
// PredictStream is the legacy chan-string streaming path. Adapts the
|
||||
// rich stream by extracting only content text — tool-call-only chunks
|
||||
// (no Message bytes) and usage-only chunks are silently dropped, since
|
||||
// the legacy chan-string contract cannot represent them. Consumers
|
||||
// that need tool calls must call PredictStreamRich directly.
|
||||
func (c *CloudProxy) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
defer close(results)
|
||||
richCh := make(chan *pb.Reply)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- c.PredictStreamRich(opts, richCh)
|
||||
close(richCh)
|
||||
}()
|
||||
for reply := range richCh {
|
||||
if msg := reply.GetMessage(); len(msg) > 0 {
|
||||
results <- string(msg)
|
||||
}
|
||||
}
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
// sendReply pushes one Reply onto a stream channel honouring ctx
|
||||
// cancellation. Returns false on cancel so the caller can exit with
|
||||
// ctx.Err(). Used by both translate-mode providers.
|
||||
func sendReply(ctx context.Context, results chan<- *pb.Reply, reply *pb.Reply) bool {
|
||||
select {
|
||||
case results <- reply:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// newToolCallDelta is a small constructor for the cross-provider
|
||||
// tool-call delta shape. Centralised so the int32 cast and the four
|
||||
// fields stay consistent across the OpenAI / Anthropic translators.
|
||||
// Empty name/args are valid — Anthropic streaming announces the call
|
||||
// with id+name then sends arguments incrementally; OpenAI's reverse
|
||||
// pattern (args without name) also lands here.
|
||||
func newToolCallDelta(index int, id, name, args string) *pb.ToolCallDelta {
|
||||
return &pb.ToolCallDelta{
|
||||
Index: int32(index),
|
||||
Id: id,
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward shovels bytes between a Forward gRPC stream and an upstream
|
||||
// HTTP request. First request message carries path/method/headers and
|
||||
// the initial body chunk; subsequent messages append body chunks. The
|
||||
// first reply carries upstream status + response headers; subsequent
|
||||
// replies stream body chunks until the upstream connection closes.
|
||||
// Cancellation of ctx (the gRPC stream context) closes the upstream
|
||||
// connection.
|
||||
func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest, out chan<- *pb.ForwardReply) error {
|
||||
defer close(out)
|
||||
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
}
|
||||
if cfg.mode != modePassthrough {
|
||||
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
|
||||
}
|
||||
|
||||
first, ok := <-in
|
||||
if !ok {
|
||||
return errors.New("cloud-proxy: Forward stream closed before first request")
|
||||
}
|
||||
|
||||
// Honour the per-request path only when the configured upstream_url
|
||||
// has no path of its own — gallery convention is to put the
|
||||
// canonical path in upstream_url.
|
||||
fullURL, err := composeURL(cfg.upstreamURL, first.GetPath())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
method := first.GetMethod()
|
||||
if method == "" {
|
||||
method = http.MethodPost
|
||||
}
|
||||
|
||||
// Pipe the body in from the gRPC stream so the HTTP request can
|
||||
// start before the client finishes sending. The pipe-reader is
|
||||
// closed via CloseWithError on the error paths so the writer
|
||||
// goroutine doesn't block forever.
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
go func() {
|
||||
var writeErr error
|
||||
defer func() { _ = pw.CloseWithError(writeErr) }()
|
||||
if len(first.GetBodyChunk()) > 0 {
|
||||
if _, writeErr = pw.Write(first.GetBodyChunk()); writeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
for req := range in {
|
||||
if len(req.GetBodyChunk()) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, writeErr = pw.Write(req.GetBodyChunk()); writeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, fullURL, pr)
|
||||
if err != nil {
|
||||
_ = pr.CloseWithError(err) // unblocks the body-pump's pw.Write
|
||||
return fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
|
||||
// Apply caller-supplied headers, then override with the
|
||||
// authorization header derived from the resolved key. Caller-
|
||||
// supplied Authorization is always replaced — operators may not
|
||||
// know the backend's auth scheme, and silently leaking through a
|
||||
// client Authorization header to a different upstream would
|
||||
// confuse the upstream and could leak credentials.
|
||||
for _, h := range first.GetHeaders() {
|
||||
if h == nil || h.GetName() == "" {
|
||||
continue
|
||||
}
|
||||
// Strip hop-by-hop headers that aren't meaningful to the
|
||||
// upstream (Host is set by the http client from the URL;
|
||||
// Content-Length is computed from the body).
|
||||
if isHopByHopHeader(h.GetName()) {
|
||||
continue
|
||||
}
|
||||
req.Header.Add(h.GetName(), h.GetValue())
|
||||
}
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
|
||||
xlog.Info("cloud-proxy: forward", "method", method, "url", fullURL, "provider", cfg.provider)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: forward upstream failed", "url", fullURL, "error", err)
|
||||
return fmt.Errorf("cloud-proxy: upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
logFn := xlog.Info
|
||||
if resp.StatusCode >= 400 {
|
||||
logFn = xlog.Warn
|
||||
}
|
||||
logFn("cloud-proxy: forward response", "url", fullURL, "status", resp.StatusCode)
|
||||
|
||||
// First reply: status + response headers, no body.
|
||||
headers := make([]*pb.ForwardHeader, 0, len(resp.Header))
|
||||
for k, vs := range resp.Header {
|
||||
for _, v := range vs {
|
||||
headers = append(headers, &pb.ForwardHeader{Name: k, Value: v})
|
||||
}
|
||||
}
|
||||
out <- &pb.ForwardReply{Status: int32(resp.StatusCode), Headers: headers}
|
||||
|
||||
// Subsequent replies: body chunks. Use a fixed 8KB buffer — small
|
||||
// enough that SSE token frames flush promptly, large enough that
|
||||
// long chunked-transfer bodies aren't death by a thousand reads.
|
||||
buf := make([]byte, 8*1024)
|
||||
for {
|
||||
n, rerr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
out <- &pb.ForwardReply{BodyChunk: chunk}
|
||||
}
|
||||
if rerr != nil {
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("cloud-proxy: upstream body read: %w", rerr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// composeURL combines the configured upstream URL with the per-request
|
||||
// path. The upstream URL typically already includes the canonical path
|
||||
// (e.g. https://api.openai.com/v1/chat/completions) so the per-request
|
||||
// path is ignored in that case. When upstream_url is a bare host
|
||||
// (https://api.openai.com), the request path is appended.
|
||||
func composeURL(upstream, reqPath string) (string, error) {
|
||||
u, err := url.Parse(upstream)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cloud-proxy: parse upstream_url %q: %w", upstream, err)
|
||||
}
|
||||
if u.Path == "" || u.Path == "/" {
|
||||
u.Path = reqPath
|
||||
}
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// applyAuthHeader writes the appropriate authorization header for the
|
||||
// provider. OpenAI/Anthropic/most providers use Bearer; Anthropic
|
||||
// historically uses x-api-key + anthropic-version, but accepts Bearer
|
||||
// too via the OpenAI-compatible path. Default to Bearer when provider
|
||||
// is empty (passthrough mode where the operator doesn't claim a
|
||||
// provider).
|
||||
func applyAuthHeader(req *http.Request, provider, key string) {
|
||||
switch provider {
|
||||
case providerAnthropic:
|
||||
req.Header.Set("x-api-key", key)
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
default:
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
}
|
||||
|
||||
// isHopByHopHeader returns true for headers that should not be
|
||||
// forwarded from the client request to the upstream (RFC 7230 §6.1
|
||||
// hop-by-hop list, plus a few that the http.Client sets itself).
|
||||
func isHopByHopHeader(name string) bool {
|
||||
switch strings.ToLower(name) {
|
||||
case "connection", "proxy-connection", "keep-alive", "transfer-encoding",
|
||||
"te", "trailer", "upgrade", "host", "content-length":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// helper: run a CloudProxy in-process via grpc.Provide so tests can
|
||||
// call Forward through the public Backend interface without listening
|
||||
// on a real socket.
|
||||
func newInProcClient(t *testing.T, proxy *CloudProxy) grpc.Backend {
|
||||
t.Helper()
|
||||
addr := "test://" + t.Name()
|
||||
grpc.Provide(addr, proxy)
|
||||
return grpc.NewClient(addr, true, nil, false)
|
||||
}
|
||||
|
||||
func TestForward_PassthroughEcho(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Fake upstream: echoes the request body back, prefixed with a
|
||||
// canary so the test can assert both that the body reached the
|
||||
// upstream and the response made it back to the client.
|
||||
gotBody := make(chan string, 1)
|
||||
gotAuth := make(chan string, 1)
|
||||
gotPath := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
gotBody <- string(body)
|
||||
gotAuth <- r.Header.Get("Authorization")
|
||||
gotPath <- r.URL.Path
|
||||
w.Header().Set("X-Echo", "true")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("echo: " + string(body)))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
t.Setenv("CLOUD_PROXY_FAKE_KEY", "sk-fake")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_FAKE_KEY",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{{Name: "Content-Type", Value: "application/json"}},
|
||||
BodyChunk: []byte(`{"prompt":`),
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.Send(&pb.ForwardRequest{BodyChunk: []byte(`"hi"}`)})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.CloseSend()
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// First reply: status + headers.
|
||||
first, err := stream.Recv()
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(first.Status).To(Equal(int32(http.StatusOK)))
|
||||
g.Expect(hasHeader(first.Headers, "X-Echo", "true")).To(BeTrue())
|
||||
|
||||
// Subsequent replies: body.
|
||||
var body []byte
|
||||
for {
|
||||
r, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
body = append(body, r.BodyChunk...)
|
||||
}
|
||||
g.Expect(string(body)).To(Equal(`echo: {"prompt":"hi"}`))
|
||||
|
||||
// Upstream observations.
|
||||
var gotBodyVal, gotAuthVal, gotPathVal string
|
||||
g.Eventually(gotBody).Should(Receive(&gotBodyVal), "upstream never saw body")
|
||||
g.Expect(gotBodyVal).To(Equal(`{"prompt":"hi"}`))
|
||||
g.Eventually(gotAuth).Should(Receive(&gotAuthVal), "upstream never saw auth header")
|
||||
g.Expect(gotAuthVal).To(Equal("Bearer sk-fake"))
|
||||
g.Eventually(gotPath).Should(Receive(&gotPathVal), "upstream never saw path")
|
||||
g.Expect(gotPathVal).To(Equal("/v1/chat/completions"))
|
||||
}
|
||||
|
||||
func TestForward_AnthropicAuthHeader(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
gotXAPIKey := make(chan string, 1)
|
||||
gotVersion := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotXAPIKey <- r.Header.Get("x-api-key")
|
||||
gotVersion <- r.Header.Get("anthropic-version")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
t.Setenv("CLOUD_PROXY_ANTHROPIC_KEY", "sk-ant-fake")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_KEY",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.Send(&pb.ForwardRequest{Path: "/v1/messages", Method: "POST"})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_ = stream.CloseSend()
|
||||
_, _ = stream.Recv() // drain status
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
g.Expect(<-gotXAPIKey).To(Equal("sk-ant-fake"))
|
||||
g.Expect(<-gotVersion).NotTo(BeEmpty())
|
||||
}
|
||||
|
||||
func TestLoad_ValidatesConfig(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cp := NewCloudProxy()
|
||||
|
||||
err := cp.Load(&pb.ModelOptions{})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("ProxyOptions"))
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("upstream_url"))
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
Mode: "rewrite",
|
||||
}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("unknown mode"))
|
||||
|
||||
// translate + openai should load successfully (Phase 5).
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com/v1/chat/completions",
|
||||
Mode: modeTranslate,
|
||||
Provider: providerOpenAI,
|
||||
}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// translate + anthropic should load successfully (Phase 6).
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com/v1/messages",
|
||||
Mode: modeTranslate,
|
||||
Provider: providerAnthropic,
|
||||
}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
ApiKeyEnv: "DEFINITELY_UNSET_ENV_VAR_XYZ",
|
||||
}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("unset"))
|
||||
}
|
||||
|
||||
func TestForward_RejectsWithoutLoad(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cp := NewCloudProxy()
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_ = stream.CloseSend()
|
||||
_, err = stream.Recv()
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("not loaded"))
|
||||
}
|
||||
|
||||
func hasHeader(hs []*pb.ForwardHeader, name, value string) bool {
|
||||
for _, h := range hs {
|
||||
if strings.EqualFold(h.GetName(), name) && h.GetValue() == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
exec $CURDIR/cloud-proxy "$@"
|
||||
@@ -1,232 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// OpenAI: non-streaming tool call response. Verify the response is
|
||||
// mapped to Reply.ChatDeltas[].ToolCalls with id/name/arguments intact,
|
||||
// and usage tokens land on Reply.PromptTokens / Reply.Tokens.
|
||||
func TestPredictRich_OpenAI_ToolCalls(t *testing.T) {
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{
|
||||
"id":"resp-1",
|
||||
"choices":[{
|
||||
"index":0,
|
||||
"message":{
|
||||
"role":"assistant",
|
||||
"content":"",
|
||||
"tool_calls":[
|
||||
{"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"SF\"}"}},
|
||||
{"id":"call_def","type":"function","function":{"name":"get_time","arguments":"{\"tz\":\"PT\"}"}}
|
||||
]
|
||||
},
|
||||
"finish_reason":"tool_calls"
|
||||
}],
|
||||
"usage":{"prompt_tokens":42,"completion_tokens":18,"total_tokens":60}
|
||||
}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal(""))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(42)))
|
||||
g.Expect(reply.GetTokens()).To(Equal(int32(18)))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(1))
|
||||
tcs := reply.GetChatDeltas()[0].GetToolCalls()
|
||||
g.Expect(tcs).To(HaveLen(2))
|
||||
g.Expect(tcs[0].GetId()).To(Equal("call_abc"))
|
||||
g.Expect(tcs[0].GetName()).To(Equal("get_weather"))
|
||||
g.Expect(tcs[0].GetArguments()).To(ContainSubstring(`"location":"SF"`))
|
||||
g.Expect(tcs[1].GetId()).To(Equal("call_def"))
|
||||
g.Expect(tcs[1].GetName()).To(Equal("get_time"))
|
||||
}
|
||||
|
||||
// OpenAI: streaming tool call. Arguments arrive as a sequence of
|
||||
// delta chunks; the consumer is expected to concatenate by tool index.
|
||||
// Verify each chunk reaches the channel and the assembled arguments
|
||||
// match the input.
|
||||
func TestPredictStreamRich_OpenAI_ToolCallDeltas(t *testing.T) {
|
||||
chunks := []string{
|
||||
// Frame 0: announce the tool call (id + name, no args yet).
|
||||
`{"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_xyz","type":"function","function":{"name":"search"}}]}}]}`,
|
||||
// Frames 1-3: arguments arrive in fragments.
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"q\":"}}]}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"clo"}}]}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"uds\"}"}}]}}]}`,
|
||||
// Stop frame.
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
}
|
||||
body := ""
|
||||
for _, c := range chunks {
|
||||
body += "data: " + c + "\n\n"
|
||||
}
|
||||
body += "data: [DONE]\n\n"
|
||||
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 16)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "find something"}},
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var (
|
||||
toolName string
|
||||
toolID string
|
||||
toolIndex int32 = -1
|
||||
argsBuf strings.Builder
|
||||
)
|
||||
for reply := range results {
|
||||
for _, cd := range reply.GetChatDeltas() {
|
||||
for _, tc := range cd.GetToolCalls() {
|
||||
if tc.GetName() != "" {
|
||||
toolName = tc.GetName()
|
||||
}
|
||||
if tc.GetId() != "" {
|
||||
toolID = tc.GetId()
|
||||
}
|
||||
if toolIndex == -1 {
|
||||
toolIndex = tc.GetIndex()
|
||||
}
|
||||
argsBuf.WriteString(tc.GetArguments())
|
||||
}
|
||||
}
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(toolID).To(Equal("call_xyz"))
|
||||
g.Expect(toolName).To(Equal("search"))
|
||||
g.Expect(toolIndex).To(Equal(int32(0)))
|
||||
g.Expect(argsBuf.String()).To(Equal(`{"q":"clouds"}`))
|
||||
}
|
||||
|
||||
// Anthropic: non-streaming tool_use block. The block appears in
|
||||
// Content[] alongside text blocks; the input field is a structured
|
||||
// JSON object. Map to ToolCallDelta with arguments as serialised JSON
|
||||
// so downstream OpenAI-shaped consumers see a familiar format.
|
||||
func TestPredictRich_Anthropic_ToolUse(t *testing.T) {
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{
|
||||
"id":"msg_1","type":"message","role":"assistant",
|
||||
"content":[
|
||||
{"type":"text","text":"Let me check that."},
|
||||
{"type":"tool_use","id":"toolu_01","name":"weather","input":{"location":"SF"}}
|
||||
],
|
||||
"model":"claude","usage":{"input_tokens":12,"output_tokens":34}
|
||||
}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}},
|
||||
Tokens: 64,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal("Let me check that."))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(12)))
|
||||
g.Expect(reply.GetTokens()).To(Equal(int32(34)))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(1))
|
||||
g.Expect(reply.GetChatDeltas()[0].GetToolCalls()).To(HaveLen(1))
|
||||
tc := reply.GetChatDeltas()[0].GetToolCalls()[0]
|
||||
g.Expect(tc.GetId()).To(Equal("toolu_01"))
|
||||
g.Expect(tc.GetName()).To(Equal("weather"))
|
||||
g.Expect(tc.GetArguments()).To(ContainSubstring(`"location":"SF"`))
|
||||
}
|
||||
|
||||
// Anthropic: streaming tool_use. content_block_start announces the
|
||||
// tool's id + name; input_json_delta events carry argument fragments
|
||||
// which the consumer accumulates. message_delta carries final usage.
|
||||
func TestPredictStreamRich_Anthropic_InputJSONDelta(t *testing.T) {
|
||||
frames := []string{
|
||||
"event: message_start\ndata: {\"type\":\"message_start\"}\n\n",
|
||||
// Block 0 is a tool_use; consumer should allocate a slot.
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_42\",\"name\":\"lookup\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"q\\\":\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"\\\"rain\\\"}\"}}\n\n",
|
||||
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_delta\ndata: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n",
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
body := strings.Join(frames, "")
|
||||
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 16)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "rain?"}},
|
||||
Tokens: 64,
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var (
|
||||
toolID, toolName string
|
||||
argsBuf strings.Builder
|
||||
finalTokens int32
|
||||
)
|
||||
for reply := range results {
|
||||
if reply.GetTokens() > 0 && len(reply.GetChatDeltas()) == 0 {
|
||||
finalTokens = reply.GetTokens()
|
||||
continue
|
||||
}
|
||||
for _, cd := range reply.GetChatDeltas() {
|
||||
for _, tc := range cd.GetToolCalls() {
|
||||
if tc.GetId() != "" {
|
||||
toolID = tc.GetId()
|
||||
}
|
||||
if tc.GetName() != "" {
|
||||
toolName = tc.GetName()
|
||||
}
|
||||
argsBuf.WriteString(tc.GetArguments())
|
||||
}
|
||||
}
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(toolID).To(Equal("toolu_42"))
|
||||
g.Expect(toolName).To(Equal("lookup"))
|
||||
g.Expect(argsBuf.String()).To(Equal(`{"q":"rain"}`))
|
||||
g.Expect(finalTokens).To(Equal(int32(7)))
|
||||
}
|
||||
|
||||
// Sanity: the legacy Predict() (string, error) signature still works
|
||||
// — it delegates to PredictRich and extracts Message.
|
||||
func TestPredict_LegacyWrapper_OpenAI(t *testing.T) {
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "hi"}}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hello"))
|
||||
}
|
||||
@@ -8,6 +8,6 @@ import (
|
||||
|
||||
func assert(cond bool, msg string) {
|
||||
if !cond {
|
||||
xlog.Fatal(msg)
|
||||
xlog.Fatal().Stack().Msg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +1,7 @@
|
||||
package main
|
||||
|
||||
// LocalAI's in-process vector store, exposed as a gRPC backend. Keep
|
||||
// the implementation here — NOT in a pkg/ library imported by the main
|
||||
// LocalAI process. The whole point of the gRPC surface is that vector
|
||||
// storage is a backend like any other (local-store, qdrant, pinecone,
|
||||
// ...) and can be swapped without changing the routing/recognition
|
||||
// code that consumes it.
|
||||
//
|
||||
// Storage is a sorted parallel-slice (keys [][]float32, values
|
||||
// [][]byte). Set/Delete preserve the sort so Get can binary-search.
|
||||
// Find scans linearly and uses a heap to keep the top-K — fine for
|
||||
// the tens-to-thousands range. The "normalized fast path" (Find when
|
||||
// every stored key has unit magnitude AND the query is normalized)
|
||||
// skips the per-item magnitude calculation.
|
||||
//
|
||||
// Concurrency: base.SingleThread serialises gRPC calls so the
|
||||
// non-thread-safe slice/heap manipulation here is sound.
|
||||
|
||||
// This is a wrapper to statisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
@@ -25,29 +10,32 @@ import (
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/store"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
base.SingleThread
|
||||
|
||||
keys [][]float32
|
||||
// The sorted keys
|
||||
keys [][]float32
|
||||
// The sorted values
|
||||
values [][]byte
|
||||
|
||||
// keysAreNormalized stays true until any non-unit-magnitude key
|
||||
// is added; once false, the magnitude-aware fallback path is
|
||||
// used by Find. Re-evaluated only at Set time, never again on
|
||||
// its own — a deletion of the offending key does NOT flip it
|
||||
// back to true (the bookkeeping cost would dominate the gain).
|
||||
// If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions
|
||||
// TODO: Should we normalize incoming keys if they are not instead?
|
||||
keysAreNormalized bool
|
||||
|
||||
// keyLen is the dimension of every stored key. -1 means "no
|
||||
// keys yet, dimension is open". Dimension mismatch on Set is
|
||||
// rejected so cosine similarity (which requires equal-length
|
||||
// vectors) doesn't silently mis-match.
|
||||
// The first key decides the length of the keys
|
||||
keyLen int
|
||||
}
|
||||
|
||||
// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because
|
||||
// that's theoretically best for memory layout and cache locality, but this isn't optimized yet.
|
||||
type Pair struct {
|
||||
Key []float32
|
||||
Value []byte
|
||||
}
|
||||
|
||||
func NewStore() *Store {
|
||||
return &Store{
|
||||
keys: make([][]float32, 0),
|
||||
@@ -57,278 +45,334 @@ func NewStore() *Store {
|
||||
}
|
||||
}
|
||||
|
||||
// Load is a no-op — local-store has no on-disk artefact. opts.Model is
|
||||
// just a namespace identifier; isolation is already handled upstream
|
||||
// (ModelLoader spawns a fresh local-store process per (backend,
|
||||
// model) tuple, so each namespace is its own Store{} instance).
|
||||
func compareSlices(k1, k2 []float32) int {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
return slices.Compare(k1, k2)
|
||||
}
|
||||
|
||||
func hasKey(unsortedSlice [][]float32, target []float32) bool {
|
||||
return slices.ContainsFunc(unsortedSlice, func(k []float32) bool {
|
||||
return compareSlices(k, target) == 0
|
||||
})
|
||||
}
|
||||
|
||||
func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) {
|
||||
return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int {
|
||||
return compareSlices(k, t)
|
||||
})
|
||||
}
|
||||
|
||||
func isSortedPairs(kvs []Pair) bool {
|
||||
for i := 1; i < len(kvs); i++ {
|
||||
if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isSortedKeys(keys [][]float32) bool {
|
||||
for i := 1; i < len(keys); i++ {
|
||||
if compareSlices(keys[i-1], keys[i]) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 {
|
||||
ks := make([][]float32, len(keys))
|
||||
|
||||
for i, k := range keys {
|
||||
ks[i] = k.Floats
|
||||
}
|
||||
|
||||
slices.SortFunc(ks, compareSlices)
|
||||
|
||||
assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys)))
|
||||
assert(isSortedKeys(ks), "keys are not sorted")
|
||||
|
||||
return ks
|
||||
}
|
||||
|
||||
func (s *Store) Load(opts *pb.ModelOptions) error {
|
||||
// local-store is an in-memory vector store with no on-disk artefact to
|
||||
// load — opts.Model is just a namespace identifier. The old `!= ""` guard
|
||||
// rejected any non-empty model name with "not implemented", which broke
|
||||
// callers that pass a namespace to isolate embedding spaces (face vs.
|
||||
// voice biometrics both go through local-store but need distinct stores
|
||||
// so ArcFace 512-D and ECAPA-TDNN 192-D don't collide). Namespace
|
||||
// isolation is already handled upstream: ModelLoader spawns a fresh
|
||||
// local-store process per (backend, model) tuple, so each namespace is
|
||||
// its own Store{} instance. Nothing to do here beyond accepting the load.
|
||||
_ = opts
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort the incoming kvs and merge them with the existing sorted kvs
|
||||
func (s *Store) StoresSet(opts *pb.StoresSetOptions) error {
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
values := store.UnwrapValues(opts.Values)
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("local-store: Set: no keys to add")
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to add")
|
||||
}
|
||||
if len(keys) != len(values) {
|
||||
return fmt.Errorf("local-store: Set: len(keys) = %d, len(values) = %d", len(keys), len(values))
|
||||
|
||||
if len(opts.Keys) != len(opts.Values) {
|
||||
return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values))
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(keys[0])
|
||||
} else if len(keys[0]) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Set: key length %d does not match existing %d", len(keys[0]), s.keyLen)
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
kvs := make([]incomingPair, len(keys))
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Set: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
if s.keysAreNormalized && !isNormalized(k) {
|
||||
kvs := make([]Pair, len(opts.Keys))
|
||||
|
||||
for i, k := range opts.Keys {
|
||||
if s.keysAreNormalized && !isNormalized(k.Floats) {
|
||||
s.keysAreNormalized = false
|
||||
var sample []float32
|
||||
if len(s.keys) > 5 {
|
||||
sample = k.Floats[:5]
|
||||
} else {
|
||||
sample = k.Floats
|
||||
}
|
||||
xlog.Debug("Key is not normalized", "sample", sample)
|
||||
}
|
||||
|
||||
kvs[i] = Pair{
|
||||
Key: k.Floats,
|
||||
Value: opts.Values[i].Bytes,
|
||||
}
|
||||
kvs[i] = incomingPair{key: k, value: values[i]}
|
||||
}
|
||||
|
||||
slices.SortFunc(kvs, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) })
|
||||
slices.SortFunc(kvs, func(a, b Pair) int {
|
||||
return compareSlices(a.Key, b.Key)
|
||||
})
|
||||
|
||||
assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys)))
|
||||
assert(isSortedPairs(kvs), "keys are not sorted")
|
||||
|
||||
l := len(kvs) + len(s.keys)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
i, j := 0, 0
|
||||
for {
|
||||
if i+j >= l {
|
||||
break
|
||||
}
|
||||
|
||||
if i >= len(kvs) {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
continue
|
||||
}
|
||||
|
||||
if j >= len(s.keys) {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
c := compareSlices(kvs[i].Key, s.keys[j])
|
||||
if c < 0 {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
} else if c > 0 {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
} else {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l))
|
||||
assert(isSortedKeys(merge_ks), "merge keys are not sorted")
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
|
||||
merged := mergeSortedPairs(s.keys, s.values, kvs)
|
||||
s.keys = merged.keys
|
||||
s.values = merged.values
|
||||
assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Set: s.keys not sorted post-merge")
|
||||
assert(len(s.keys) == len(s.values), "Set: keys/values length skew")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error {
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("local-store: Delete: no keys to delete")
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to delete")
|
||||
}
|
||||
if s.keyLen != -1 {
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Delete: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to add")
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
ks := sortIntoKeySlicese(opts.Keys)
|
||||
|
||||
l := len(s.keys) - len(ks)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
tail_ks := s.keys
|
||||
tail_vs := s.values
|
||||
for _, k := range ks {
|
||||
j, found := findInSortedSlice(tail_ks, k)
|
||||
|
||||
if found {
|
||||
merge_ks = append(merge_ks, tail_ks[:j]...)
|
||||
merge_vs = append(merge_vs, tail_vs[:j]...)
|
||||
tail_ks = tail_ks[j+1:]
|
||||
tail_vs = tail_vs[j+1:]
|
||||
} else {
|
||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k))
|
||||
}
|
||||
|
||||
xlog.Debug("Delete", "found", found, "tailLen", len(tail_ks), "j", j, "mergeKeysLen", len(merge_ks), "mergeValuesLen", len(merge_vs))
|
||||
}
|
||||
|
||||
merge_ks = append(merge_ks, tail_ks...)
|
||||
merge_vs = append(merge_vs, tail_vs...)
|
||||
|
||||
assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys)))
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
|
||||
assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l))
|
||||
assert(isSortedKeys(s.keys), "keys are not sorted")
|
||||
assert(func() bool {
|
||||
for _, k := range ks {
|
||||
if _, found := findInSortedSlice(s.keys, k); found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
sortedKeys := append([][]float32(nil), keys...)
|
||||
slices.SortFunc(sortedKeys, slices.Compare[[]float32])
|
||||
return true
|
||||
}(), "Keys to delete still present")
|
||||
|
||||
mergedK := make([][]float32, 0, len(s.keys))
|
||||
mergedV := make([][]byte, 0, len(s.keys))
|
||||
tailK := s.keys
|
||||
tailV := s.values
|
||||
for _, k := range sortedKeys {
|
||||
j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32])
|
||||
if ok {
|
||||
mergedK = append(mergedK, tailK[:j]...)
|
||||
mergedV = append(mergedV, tailV[:j]...)
|
||||
tailK = tailK[j+1:]
|
||||
tailV = tailV[j+1:]
|
||||
}
|
||||
if len(s.keys) != l {
|
||||
xlog.Debug("Delete: Some keys not found", "keysLen", len(s.keys), "expectedLen", l)
|
||||
}
|
||||
mergedK = append(mergedK, tailK...)
|
||||
mergedV = append(mergedV, tailV...)
|
||||
s.keys = mergedK
|
||||
s.values = mergedV
|
||||
assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Delete: s.keys not sorted post-merge")
|
||||
assert(len(s.keys) == len(s.values), "Delete: keys/values length skew")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StoresGet fetches values for the given keys. Missing keys are
|
||||
// omitted from the result rather than reported as an error — callers
|
||||
// compare returned-key length against requested-key length to detect
|
||||
// them. Returned slices are aligned.
|
||||
func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) {
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
pbKeys := make([]*pb.StoresKey, 0, len(opts.Keys))
|
||||
pbValues := make([]*pb.StoresValue, 0, len(opts.Keys))
|
||||
ks := sortIntoKeySlicese(opts.Keys)
|
||||
|
||||
if len(s.keys) == 0 {
|
||||
return pb.StoresGetResult{}, nil
|
||||
}
|
||||
if s.keyLen != -1 {
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return pb.StoresGetResult{}, fmt.Errorf("local-store: Get: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
sortedKeys := append([][]float32(nil), keys...)
|
||||
slices.SortFunc(sortedKeys, slices.Compare[[]float32])
|
||||
|
||||
var foundKeys [][]float32
|
||||
var foundValues [][]byte
|
||||
tailK := s.keys
|
||||
tailV := s.values
|
||||
for _, k := range sortedKeys {
|
||||
j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
foundKeys = append(foundKeys, tailK[j])
|
||||
foundValues = append(foundValues, tailV[j])
|
||||
tailK = tailK[j+1:]
|
||||
tailV = tailV[j+1:]
|
||||
}
|
||||
return pb.StoresGetResult{
|
||||
Keys: store.WrapKeys(foundKeys),
|
||||
Values: store.WrapValues(foundValues),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StoresFind returns the topK nearest stored entries by cosine
|
||||
// similarity, ordered most-similar first. An empty store returns
|
||||
// empty slices and no error.
|
||||
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
query := opts.Key.Floats
|
||||
topK := int(opts.TopK)
|
||||
if topK < 1 {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: topK = %d, must be >= 1", topK)
|
||||
}
|
||||
if len(s.keys) == 0 {
|
||||
return pb.StoresFindResult{}, nil
|
||||
}
|
||||
if len(query) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: query length %d does not match existing %d", len(query), s.keyLen)
|
||||
xlog.Debug("Get: No keys in store")
|
||||
}
|
||||
|
||||
var keys [][]float32
|
||||
var values [][]byte
|
||||
var sims []float32
|
||||
if s.keysAreNormalized && isNormalized(query) {
|
||||
keys, values, sims = s.findNormalized(query, topK)
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
keys, values, sims = s.findFallback(query, topK)
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return pb.StoresGetResult{}, fmt.Errorf("Try to get a key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
return pb.StoresFindResult{
|
||||
Keys: store.WrapKeys(keys),
|
||||
Values: store.WrapValues(values),
|
||||
Similarities: sims,
|
||||
|
||||
tail_k := s.keys
|
||||
tail_v := s.values
|
||||
for i, k := range ks {
|
||||
j, found := findInSortedSlice(tail_k, k)
|
||||
|
||||
if found {
|
||||
pbKeys = append(pbKeys, &pb.StoresKey{
|
||||
Floats: k,
|
||||
})
|
||||
pbValues = append(pbValues, &pb.StoresValue{
|
||||
Bytes: tail_v[j],
|
||||
})
|
||||
|
||||
tail_k = tail_k[j+1:]
|
||||
tail_v = tail_v[j+1:]
|
||||
} else {
|
||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: i=%d, %v", i, k))
|
||||
}
|
||||
}
|
||||
|
||||
if len(pbKeys) != len(opts.Keys) {
|
||||
xlog.Debug("Get: Some keys not found", "pbKeysLen", len(pbKeys), "optsKeysLen", len(opts.Keys), "storeKeysLen", len(s.keys))
|
||||
}
|
||||
|
||||
return pb.StoresGetResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) findNormalized(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
assert(s.keysAreNormalized, "findNormalized: s.keysAreNormalized is false")
|
||||
assert(isNormalized(query), "findNormalized: query is not unit-length")
|
||||
pq := make(priorityQueue, 0, topK)
|
||||
heap.Init(&pq)
|
||||
for i, k := range s.keys {
|
||||
var dot float32
|
||||
for j := range k {
|
||||
dot += query[j] * k[j]
|
||||
}
|
||||
assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("findNormalized: dot %f out of [-1, 1] — keysAreNormalized invariant violated", dot))
|
||||
heap.Push(&pq, &priorityItem{similarity: dot, key: k, value: s.values[i]})
|
||||
if pq.Len() > topK {
|
||||
heap.Pop(&pq)
|
||||
}
|
||||
}
|
||||
return drainPQ(&pq)
|
||||
}
|
||||
|
||||
func (s *Store) findFallback(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
var qmag float64
|
||||
for _, v := range query {
|
||||
qmag += float64(v) * float64(v)
|
||||
}
|
||||
qmag = math.Sqrt(qmag)
|
||||
pq := make(priorityQueue, 0, topK)
|
||||
heap.Init(&pq)
|
||||
for i, k := range s.keys {
|
||||
var dot, kmag float64
|
||||
for j := range k {
|
||||
dot += float64(query[j]) * float64(k[j])
|
||||
kmag += float64(k[j]) * float64(k[j])
|
||||
}
|
||||
denom := qmag * math.Sqrt(kmag)
|
||||
var sim float32
|
||||
if denom > 0 {
|
||||
sim = float32(dot / denom)
|
||||
}
|
||||
heap.Push(&pq, &priorityItem{similarity: sim, key: k, value: s.values[i]})
|
||||
if pq.Len() > topK {
|
||||
heap.Pop(&pq)
|
||||
}
|
||||
}
|
||||
return drainPQ(&pq)
|
||||
}
|
||||
|
||||
func isNormalized(k []float32) bool {
|
||||
var sum float64
|
||||
|
||||
for _, v := range k {
|
||||
sum += float64(v) * float64(v)
|
||||
v64 := float64(v)
|
||||
sum += v64 * v64
|
||||
}
|
||||
mag := math.Sqrt(sum)
|
||||
return mag >= 0.99 && mag <= 1.01
|
||||
|
||||
s := math.Sqrt(sum)
|
||||
|
||||
return s >= 0.99 && s <= 1.01
|
||||
}
|
||||
|
||||
type incomingPair struct {
|
||||
key []float32
|
||||
value []byte
|
||||
}
|
||||
// TODO: This we could replace with handwritten SIMD code
|
||||
func normalizedCosineSimilarity(k1, k2 []float32) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
type pairs struct {
|
||||
keys [][]float32
|
||||
values [][]byte
|
||||
}
|
||||
|
||||
// mergeSortedPairs merges (existing, incoming) into a fresh sorted
|
||||
// slice. Equal keys take the incoming value — Set is upsert.
|
||||
func mergeSortedPairs(existingK [][]float32, existingV [][]byte, incoming []incomingPair) pairs {
|
||||
assert(slices.IsSortedFunc(existingK, slices.Compare[[]float32]), "mergeSortedPairs: existing not sorted")
|
||||
assert(slices.IsSortedFunc(incoming, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) }), "mergeSortedPairs: incoming not sorted")
|
||||
l := len(existingK) + len(incoming)
|
||||
mk := make([][]float32, 0, l)
|
||||
mv := make([][]byte, 0, l)
|
||||
i, j := 0, 0
|
||||
for i < len(incoming) || j < len(existingK) {
|
||||
switch {
|
||||
case j >= len(existingK):
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
case i >= len(incoming):
|
||||
mk = append(mk, existingK[j])
|
||||
mv = append(mv, existingV[j])
|
||||
j++
|
||||
default:
|
||||
c := slices.Compare(incoming[i].key, existingK[j])
|
||||
switch {
|
||||
case c < 0:
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
case c > 0:
|
||||
mk = append(mk, existingK[j])
|
||||
mv = append(mv, existingV[j])
|
||||
j++
|
||||
default:
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
j++
|
||||
}
|
||||
}
|
||||
var dot float32
|
||||
for i := range len(k1) {
|
||||
dot += k1[i] * k2[i]
|
||||
}
|
||||
return pairs{keys: mk, values: mv}
|
||||
|
||||
assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("dot = %f", dot))
|
||||
|
||||
// 2.0 * (1.0 - dot) would be the Euclidean distance
|
||||
return dot
|
||||
}
|
||||
|
||||
type priorityItem struct {
|
||||
similarity float32
|
||||
key []float32
|
||||
value []byte
|
||||
type PriorityItem struct {
|
||||
Similarity float32
|
||||
Key []float32
|
||||
Value []byte
|
||||
}
|
||||
|
||||
type priorityQueue []*priorityItem
|
||||
type PriorityQueue []*PriorityItem
|
||||
|
||||
func (pq priorityQueue) Len() int { return len(pq) }
|
||||
func (pq priorityQueue) Less(i, j int) bool { return pq[i].similarity < pq[j].similarity }
|
||||
func (pq priorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] }
|
||||
func (pq *priorityQueue) Push(x any) { *pq = append(*pq, x.(*priorityItem)) }
|
||||
func (pq *priorityQueue) Pop() any {
|
||||
func (pq PriorityQueue) Len() int { return len(pq) }
|
||||
|
||||
func (pq PriorityQueue) Less(i, j int) bool {
|
||||
// Inverted because the most similar should be at the top
|
||||
return pq[i].Similarity < pq[j].Similarity
|
||||
}
|
||||
|
||||
func (pq PriorityQueue) Swap(i, j int) {
|
||||
pq[i], pq[j] = pq[j], pq[i]
|
||||
}
|
||||
|
||||
func (pq *PriorityQueue) Push(x any) {
|
||||
item := x.(*PriorityItem)
|
||||
*pq = append(*pq, item)
|
||||
}
|
||||
|
||||
func (pq *PriorityQueue) Pop() any {
|
||||
old := *pq
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
@@ -336,16 +380,142 @@ func (pq *priorityQueue) Pop() any {
|
||||
return item
|
||||
}
|
||||
|
||||
func drainPQ(pq *priorityQueue) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
n := pq.Len()
|
||||
keys = make([][]float32, n)
|
||||
values = make([][]byte, n)
|
||||
similarities = make([]float32, n)
|
||||
for i := n - 1; i >= 0; i-- {
|
||||
item := heap.Pop(pq).(*priorityItem)
|
||||
keys[i] = item.key
|
||||
values[i] = item.value
|
||||
similarities[i] = item.similarity
|
||||
func (s *Store) StoresFindNormalized(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
top_ks := make(PriorityQueue, 0, int(opts.TopK))
|
||||
heap.Init(&top_ks)
|
||||
|
||||
for i, k := range s.keys {
|
||||
sim := normalizedCosineSimilarity(tk, k)
|
||||
heap.Push(&top_ks, &PriorityItem{
|
||||
Similarity: sim,
|
||||
Key: k,
|
||||
Value: s.values[i],
|
||||
})
|
||||
|
||||
if top_ks.Len() > int(opts.TopK) {
|
||||
heap.Pop(&top_ks)
|
||||
}
|
||||
}
|
||||
|
||||
similarities := make([]float32, top_ks.Len())
|
||||
pbKeys := make([]*pb.StoresKey, top_ks.Len())
|
||||
pbValues := make([]*pb.StoresValue, top_ks.Len())
|
||||
|
||||
for i := top_ks.Len() - 1; i >= 0; i-- {
|
||||
item := heap.Pop(&top_ks).(*PriorityItem)
|
||||
|
||||
similarities[i] = item.Similarity
|
||||
pbKeys[i] = &pb.StoresKey{
|
||||
Floats: item.Key,
|
||||
}
|
||||
pbValues[i] = &pb.StoresValue{
|
||||
Bytes: item.Value,
|
||||
}
|
||||
}
|
||||
|
||||
return pb.StoresFindResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Similarities: similarities,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
var dot, mag2 float64
|
||||
for i := range len(k1) {
|
||||
dot += float64(k1[i] * k2[i])
|
||||
mag2 += float64(k2[i] * k2[i])
|
||||
}
|
||||
|
||||
sim := float32(dot / (mag1 * math.Sqrt(mag2)))
|
||||
|
||||
assert(sim >= -1.01 && sim <= 1.01, fmt.Sprintf("sim = %f", sim))
|
||||
|
||||
return sim
|
||||
}
|
||||
|
||||
func (s *Store) StoresFindFallback(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
top_ks := make(PriorityQueue, 0, int(opts.TopK))
|
||||
heap.Init(&top_ks)
|
||||
|
||||
var mag1 float64
|
||||
for _, v := range tk {
|
||||
mag1 += float64(v * v)
|
||||
}
|
||||
mag1 = math.Sqrt(mag1)
|
||||
|
||||
for i, k := range s.keys {
|
||||
dist := cosineSimilarity(tk, k, mag1)
|
||||
heap.Push(&top_ks, &PriorityItem{
|
||||
Similarity: dist,
|
||||
Key: k,
|
||||
Value: s.values[i],
|
||||
})
|
||||
|
||||
if top_ks.Len() > int(opts.TopK) {
|
||||
heap.Pop(&top_ks)
|
||||
}
|
||||
}
|
||||
|
||||
similarities := make([]float32, top_ks.Len())
|
||||
pbKeys := make([]*pb.StoresKey, top_ks.Len())
|
||||
pbValues := make([]*pb.StoresValue, top_ks.Len())
|
||||
|
||||
for i := top_ks.Len() - 1; i >= 0; i-- {
|
||||
item := heap.Pop(&top_ks).(*PriorityItem)
|
||||
|
||||
similarities[i] = item.Similarity
|
||||
pbKeys[i] = &pb.StoresKey{
|
||||
Floats: item.Key,
|
||||
}
|
||||
pbValues[i] = &pb.StoresValue{
|
||||
Bytes: item.Value,
|
||||
}
|
||||
}
|
||||
|
||||
return pb.StoresFindResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Similarities: similarities,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
|
||||
if len(tk) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("Try to find key with length %d when existing length is %d", len(tk), s.keyLen)
|
||||
}
|
||||
|
||||
if opts.TopK < 1 {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("opts.TopK = %d, must be >= 1", opts.TopK)
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Key.Floats)
|
||||
} else {
|
||||
if len(opts.Key.Floats) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Key.Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
if s.keysAreNormalized && isNormalized(tk) {
|
||||
return s.StoresFindNormalized(opts)
|
||||
} else {
|
||||
if s.keysAreNormalized {
|
||||
var sample []float32
|
||||
if len(s.keys) > 5 {
|
||||
sample = tk[:5]
|
||||
} else {
|
||||
sample = tk
|
||||
}
|
||||
xlog.Debug("Trying to compare non-normalized key with normalized keys", "sample", sample)
|
||||
}
|
||||
|
||||
return s.StoresFindFallback(opts)
|
||||
}
|
||||
return keys, values, similarities
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestLocalStore(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "local-store test suite")
|
||||
}
|
||||
@@ -1,284 +0,0 @@
|
||||
package main
|
||||
|
||||
// Regression suite for the local-store gRPC backend. Exercises the
|
||||
// Stores{Set,Get,Find,Delete} surface — the only public contract.
|
||||
// Callers (face/voice recognition, the routing KNN classifier) reach
|
||||
// this code via grpc.Backend, so testing at the wire-shaped boundary
|
||||
// matches the production import shape.
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("StoresSet", func() {
|
||||
It("rejects empty input", func() {
|
||||
Expect(NewStore().StoresSet(&pb.StoresSetOptions{})).NotTo(Succeed(), "Set with no keys should fail")
|
||||
})
|
||||
|
||||
It("rejects key/value length mismatch", func() {
|
||||
err := NewStore().StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("a"), []byte("b")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "len(keys) != len(values) should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch on later add", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("3d")})
|
||||
err := s.StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("2d")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "dimension mismatch on later Set should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch within batch", func() {
|
||||
err := NewStore().StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0, 0}, {1, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("3d"), []byte("2d")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "mixed-dimension within one batch should fail")
|
||||
})
|
||||
|
||||
It("merges sorted and updates existing key", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.3, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("c"), []byte("a")})
|
||||
mustSet(s, [][]float32{{0.2, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("b"), []byte("a-updated")})
|
||||
Expect(s.keys).To(HaveLen(3))
|
||||
got := singleGet(s, []float32{0.1, 0, 0})
|
||||
Expect(string(got)).To(Equal("a-updated"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresGet", func() {
|
||||
It("round-trips multi-key", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}},
|
||||
[][]byte{[]byte("a"), []byte("b"), []byte("c")},
|
||||
)
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{
|
||||
Keys: wrapKeys([][]float32{{0.7, 0.8, 0.9}, {0.1, 0.2, 0.3}}),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("omits missing keys rather than erroring", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")})
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{
|
||||
Keys: wrapKeys([][]float32{{0.1, 0, 0}, {0.9, 0, 0}}),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(1))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresDelete", func() {
|
||||
It("removes and preserves sort", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{{0.1, 0, 0}, {0.2, 0, 0}, {0.3, 0, 0}, {0.4, 0, 0}},
|
||||
[][]byte{[]byte("a"), []byte("b"), []byte("c"), []byte("d")},
|
||||
)
|
||||
Expect(s.StoresDelete(&pb.StoresDeleteOptions{
|
||||
Keys: wrapKeys([][]float32{{0.2, 0, 0}, {0.4, 0, 0}}),
|
||||
})).To(Succeed())
|
||||
Expect(s.keys).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("tolerates missing keys", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")})
|
||||
Expect(s.StoresDelete(&pb.StoresDeleteOptions{
|
||||
Keys: wrapKeys([][]float32{{0.9, 0, 0}}),
|
||||
})).To(Succeed(), "delete of missing key should succeed")
|
||||
Expect(s.keys).To(HaveLen(1))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresFind", func() {
|
||||
It("returns normalized top-K", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{
|
||||
normalizeVec([]float32{1, 0, 0}),
|
||||
normalizeVec([]float32{0, 1, 0}),
|
||||
normalizeVec([]float32{0, 0, 1}),
|
||||
},
|
||||
[][]byte{[]byte("x"), []byte("y"), []byte("z")},
|
||||
)
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: normalizeVec([]float32{0.9, 0.1, 0})},
|
||||
TopK: 2,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
Expect(res.Similarities[0]).To(BeNumerically(">=", res.Similarities[1]), "results not sorted desc by similarity")
|
||||
Expect(string(res.Values[0].Bytes)).To(Equal("x"))
|
||||
})
|
||||
|
||||
It("falls back for non-normalized keys", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{2, 0, 0}, {0, 3, 0}}, [][]byte{[]byte("x"), []byte("y")})
|
||||
Expect(s.keysAreNormalized).To(BeFalse(), "store should report non-normalized after Set with magnitude > 1")
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{4, 0, 0}},
|
||||
TopK: 1,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(string(res.Values[0].Bytes)).To(Equal("x"))
|
||||
Expect(res.Similarities[0]).To(BeNumerically(">=", float32(0.99)))
|
||||
Expect(res.Similarities[0]).To(BeNumerically("<=", float32(1.01)))
|
||||
})
|
||||
|
||||
It("rejects zero topK", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")})
|
||||
_, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0, 0}},
|
||||
TopK: 0,
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "Find with topK=0 should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")})
|
||||
_, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0}},
|
||||
TopK: 1,
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "Find with mismatched dimension should fail")
|
||||
})
|
||||
|
||||
It("returns empty result on empty store", func() {
|
||||
res, err := NewStore().StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0, 0}},
|
||||
TopK: 5,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred(), "Find on empty store should succeed")
|
||||
Expect(res.Keys).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("handles topK larger than store", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{normalizeVec([]float32{1, 0, 0}), normalizeVec([]float32{0, 1, 0})},
|
||||
[][]byte{[]byte("x"), []byte("y")},
|
||||
)
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: normalizeVec([]float32{1, 0, 0})},
|
||||
TopK: 10,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresLoad", func() {
|
||||
It("is a no-op", func() {
|
||||
Expect(NewStore().Load(&pb.ModelOptions{Model: "any-namespace"})).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
func BenchmarkStoresFindNormalized(b *testing.B) {
|
||||
const dim = 768
|
||||
for _, n := range []int{8, 32, 128, 512} {
|
||||
b.Run(fmtN(n), func(b *testing.B) {
|
||||
s := buildStore(b, n, dim)
|
||||
query := normalizeVec(randVec(dim, 42))
|
||||
req := &pb.StoresFindOptions{Key: &pb.StoresKey{Floats: query}, TopK: 1}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := s.StoresFind(req); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- test helpers ---
|
||||
|
||||
func mustSet(s *Store, keys [][]float32, values [][]byte) {
|
||||
ExpectWithOffset(1, s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)})).To(Succeed())
|
||||
}
|
||||
|
||||
func singleGet(s *Store, key []float32) []byte {
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{Keys: wrapKeys([][]float32{key})})
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred())
|
||||
if len(res.Values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return res.Values[0].Bytes
|
||||
}
|
||||
|
||||
func wrapKeys(in [][]float32) []*pb.StoresKey {
|
||||
out := make([]*pb.StoresKey, len(in))
|
||||
for i, k := range in {
|
||||
out[i] = &pb.StoresKey{Floats: k}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func wrapValues(in [][]byte) []*pb.StoresValue {
|
||||
out := make([]*pb.StoresValue, len(in))
|
||||
for i, v := range in {
|
||||
out[i] = &pb.StoresValue{Bytes: v}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildStore(tb testing.TB, n, dim int) *Store {
|
||||
tb.Helper()
|
||||
s := NewStore()
|
||||
keys := make([][]float32, n)
|
||||
values := make([][]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
keys[i] = normalizeVec(randVec(dim, int64(i)+1))
|
||||
values[i] = []byte{byte(i)}
|
||||
}
|
||||
if err := s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)}); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func randVec(dim int, seed int64) []float32 {
|
||||
r := rand.New(rand.NewPCG(uint64(seed), 0xabcdef))
|
||||
v := make([]float32, dim)
|
||||
for i := range v {
|
||||
v[i] = float32(r.NormFloat64())
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func normalizeVec(v []float32) []float32 {
|
||||
var sum float64
|
||||
for _, x := range v {
|
||||
sum += float64(x) * float64(x)
|
||||
}
|
||||
mag := math.Sqrt(sum)
|
||||
if mag == 0 {
|
||||
return v
|
||||
}
|
||||
out := make([]float32, len(v))
|
||||
for i, x := range v {
|
||||
out[i] = float32(float64(x) / mag)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func fmtN(n int) string {
|
||||
return map[int]string{8: "n=8", 32: "n=32", 128: "n=128", 512: "n=512"}[n]
|
||||
}
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=a397e03488cc27e1a42da646b82dfce9f50741c0
|
||||
STABLEDIFFUSION_GGML_VERSION?=5b0267e941cade15bd80089d89838795d9f4baa6
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@
|
||||
#include <stdlib.h>
|
||||
#include <regex>
|
||||
#include <errno.h>
|
||||
#include <inttypes.h>
|
||||
#include <signal.h>
|
||||
#include <unistd.h>
|
||||
#include <sys/wait.h>
|
||||
@@ -377,8 +376,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
const char *clip_g_path = "";
|
||||
const char *t5xxl_path = "";
|
||||
const char *vae_path = "";
|
||||
const char *audio_vae_path = "";
|
||||
const char *embeddings_connectors_path = "";
|
||||
const char *scheduler_str = "";
|
||||
const char *sampler = "";
|
||||
const char *clip_vision_path = "";
|
||||
@@ -434,12 +431,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
if (!strcmp(optname, "vae_path")) {
|
||||
vae_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "audio_vae_path")) {
|
||||
audio_vae_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "embeddings_connectors_path")) {
|
||||
embeddings_connectors_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "scheduler")) {
|
||||
scheduler_str = optval;
|
||||
}
|
||||
@@ -572,8 +563,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
ctx_params.diffusion_model_path = diffusion_model_path;
|
||||
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
|
||||
ctx_params.vae_path = vae_path;
|
||||
ctx_params.audio_vae_path = audio_vae_path;
|
||||
ctx_params.embeddings_connectors_path = embeddings_connectors_path;
|
||||
ctx_params.taesd_path = taesd_path;
|
||||
ctx_params.control_net_path = control_net_path;
|
||||
if (lora_dir && strlen(lora_dir) > 0) {
|
||||
@@ -1076,71 +1065,9 @@ static uint8_t* load_and_resize_image(const char* path, int target_width, int ta
|
||||
return buf;
|
||||
}
|
||||
|
||||
// Write sd.cpp's audio buffer to a temp WAV file (IEEE float, interleaved).
|
||||
// sd_audio_t.data is planar (all channel 0 samples, then channel 1, etc.) — we
|
||||
// interleave on the fly so ffmpeg's standard wav demuxer can read it directly.
|
||||
// Returns 0 on success and fills wav_path (must be at least 64 bytes).
|
||||
static int write_planar_float_wav(const sd_audio_t* a, char* wav_path, size_t wav_path_sz) {
|
||||
if (!a || !a->data || a->sample_count == 0 || a->channels == 0 || a->sample_rate == 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
snprintf(wav_path, wav_path_sz, "/tmp/gosd-audio-XXXXXX.wav");
|
||||
int fd = mkstemps(wav_path, 4);
|
||||
if (fd < 0) { perror("mkstemps wav"); return -1; }
|
||||
FILE* f = fdopen(fd, "wb");
|
||||
if (!f) { perror("fdopen wav"); close(fd); return -1; }
|
||||
|
||||
uint64_t frames = a->sample_count;
|
||||
uint32_t channels = a->channels;
|
||||
uint32_t sample_rate = a->sample_rate;
|
||||
uint64_t total_samples64 = frames * (uint64_t)channels;
|
||||
uint64_t data_bytes64 = total_samples64 * sizeof(float);
|
||||
if (data_bytes64 > 0xFFFFFFFFull - 44) {
|
||||
fprintf(stderr, "audio too large for 32-bit WAV (%" PRIu64 " bytes)\n", data_bytes64);
|
||||
fclose(f);
|
||||
unlink(wav_path);
|
||||
return -1;
|
||||
}
|
||||
uint32_t data_bytes = (uint32_t)data_bytes64;
|
||||
uint32_t riff_size = 36 + data_bytes;
|
||||
uint16_t fmt_code = 3; // WAVE_FORMAT_IEEE_FLOAT
|
||||
uint16_t bits_per_sample = 32;
|
||||
uint16_t block_align = (uint16_t)(channels * sizeof(float));
|
||||
uint32_t byte_rate = sample_rate * block_align;
|
||||
uint16_t ch16 = (uint16_t)channels;
|
||||
uint32_t fmt_size = 16;
|
||||
|
||||
fwrite("RIFF", 1, 4, f);
|
||||
fwrite(&riff_size, 4, 1, f);
|
||||
fwrite("WAVEfmt ", 1, 8, f);
|
||||
fwrite(&fmt_size, 4, 1, f);
|
||||
fwrite(&fmt_code, 2, 1, f);
|
||||
fwrite(&ch16, 2, 1, f);
|
||||
fwrite(&sample_rate, 4, 1, f);
|
||||
fwrite(&byte_rate, 4, 1, f);
|
||||
fwrite(&block_align, 2, 1, f);
|
||||
fwrite(&bits_per_sample, 2, 1, f);
|
||||
fwrite("data", 1, 4, f);
|
||||
fwrite(&data_bytes, 4, 1, f);
|
||||
|
||||
// Interleave planar [ch0_samples..., ch1_samples...] → [ch0_s0, ch1_s0, ...]
|
||||
for (uint64_t s = 0; s < frames; s++) {
|
||||
for (uint32_t c = 0; c < channels; c++) {
|
||||
float v = a->data[(size_t)c * frames + s];
|
||||
fwrite(&v, sizeof(float), 1, f);
|
||||
}
|
||||
}
|
||||
fclose(f);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Pipe raw RGB/RGBA frames to ffmpeg stdin and let it produce an MP4 at dst.
|
||||
// Uses fork+execvp to avoid shell interpretation of dst. When `audio` is
|
||||
// non-null, the audio waveform is staged to a temp WAV and added as a second
|
||||
// ffmpeg input so the final MP4 contains both video and AAC audio.
|
||||
static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps,
|
||||
const sd_audio_t* audio, const char* dst) {
|
||||
// Uses fork+execvp to avoid shell interpretation of dst.
|
||||
static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps, const char* dst) {
|
||||
if (num_frames <= 0 || !frames || !frames[0].data) {
|
||||
fprintf(stderr, "ffmpeg_mux: empty frames\n");
|
||||
return 1;
|
||||
@@ -1155,87 +1082,38 @@ static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps,
|
||||
snprintf(size_str, sizeof(size_str), "%dx%d", width, height);
|
||||
snprintf(fps_str, sizeof(fps_str), "%d", fps);
|
||||
|
||||
// Optional audio: write a temp WAV file if the model produced audio.
|
||||
char wav_path[64] = {0};
|
||||
bool have_audio = false;
|
||||
if (audio && audio->data && audio->sample_count > 0 && audio->channels > 0 && audio->sample_rate > 0) {
|
||||
if (write_planar_float_wav(audio, wav_path, sizeof(wav_path)) == 0) {
|
||||
have_audio = true;
|
||||
fprintf(stderr, "ffmpeg_mux: audio %u Hz × %u ch × %" PRIu64 " frames → %s\n",
|
||||
audio->sample_rate, audio->channels, audio->sample_count, wav_path);
|
||||
} else {
|
||||
fprintf(stderr, "ffmpeg_mux: failed to stage audio; producing silent video\n");
|
||||
}
|
||||
}
|
||||
|
||||
int pipefd[2];
|
||||
if (pipe(pipefd) != 0) {
|
||||
perror("pipe");
|
||||
if (have_audio) unlink(wav_path);
|
||||
return 1;
|
||||
}
|
||||
if (pipe(pipefd) != 0) { perror("pipe"); return 1; }
|
||||
|
||||
pid_t pid = fork();
|
||||
if (pid < 0) {
|
||||
perror("fork");
|
||||
close(pipefd[0]); close(pipefd[1]);
|
||||
if (have_audio) unlink(wav_path);
|
||||
return 1;
|
||||
}
|
||||
if (pid < 0) { perror("fork"); close(pipefd[0]); close(pipefd[1]); return 1; }
|
||||
|
||||
if (pid == 0) {
|
||||
// child
|
||||
close(pipefd[1]);
|
||||
if (dup2(pipefd[0], STDIN_FILENO) < 0) { perror("dup2"); _exit(127); }
|
||||
close(pipefd[0]);
|
||||
std::vector<char*> argv;
|
||||
argv.push_back(const_cast<char*>("ffmpeg"));
|
||||
argv.push_back(const_cast<char*>("-y"));
|
||||
argv.push_back(const_cast<char*>("-hide_banner"));
|
||||
argv.push_back(const_cast<char*>("-loglevel"));
|
||||
argv.push_back(const_cast<char*>("warning"));
|
||||
// Input 0: raw video from stdin
|
||||
argv.push_back(const_cast<char*>("-f"));
|
||||
argv.push_back(const_cast<char*>("rawvideo"));
|
||||
argv.push_back(const_cast<char*>("-pix_fmt"));
|
||||
argv.push_back(const_cast<char*>(pix_fmt_in));
|
||||
argv.push_back(const_cast<char*>("-s"));
|
||||
argv.push_back(size_str);
|
||||
argv.push_back(const_cast<char*>("-framerate"));
|
||||
argv.push_back(fps_str);
|
||||
argv.push_back(const_cast<char*>("-i"));
|
||||
argv.push_back(const_cast<char*>("-"));
|
||||
// Input 1: optional audio WAV
|
||||
if (have_audio) {
|
||||
argv.push_back(const_cast<char*>("-i"));
|
||||
argv.push_back(wav_path);
|
||||
argv.push_back(const_cast<char*>("-map"));
|
||||
argv.push_back(const_cast<char*>("0:v:0"));
|
||||
argv.push_back(const_cast<char*>("-map"));
|
||||
argv.push_back(const_cast<char*>("1:a:0"));
|
||||
argv.push_back(const_cast<char*>("-c:a"));
|
||||
argv.push_back(const_cast<char*>("aac"));
|
||||
argv.push_back(const_cast<char*>("-b:a"));
|
||||
argv.push_back(const_cast<char*>("192k"));
|
||||
// -shortest so the final clip ends with the shorter of the two
|
||||
// streams — guards against an audio buffer that overshoots the
|
||||
// video duration (or vice versa) on certain LTX variants.
|
||||
argv.push_back(const_cast<char*>("-shortest"));
|
||||
}
|
||||
argv.push_back(const_cast<char*>("-c:v"));
|
||||
argv.push_back(const_cast<char*>("libx264"));
|
||||
argv.push_back(const_cast<char*>("-pix_fmt"));
|
||||
argv.push_back(const_cast<char*>("yuv420p"));
|
||||
argv.push_back(const_cast<char*>("-movflags"));
|
||||
argv.push_back(const_cast<char*>("+faststart"));
|
||||
// Force MP4 container. Distributed LocalAI hands us a staging
|
||||
// path (e.g. /staging/localai-output-NNN.tmp) with a non-standard
|
||||
// extension; relying on filename suffix makes ffmpeg bail with
|
||||
// "Unable to choose an output format".
|
||||
argv.push_back(const_cast<char*>("-f"));
|
||||
argv.push_back(const_cast<char*>("mp4"));
|
||||
argv.push_back(const_cast<char*>(dst));
|
||||
argv.push_back(nullptr);
|
||||
std::vector<char*> argv = {
|
||||
const_cast<char*>("ffmpeg"),
|
||||
const_cast<char*>("-y"),
|
||||
const_cast<char*>("-hide_banner"),
|
||||
const_cast<char*>("-loglevel"), const_cast<char*>("warning"),
|
||||
const_cast<char*>("-f"), const_cast<char*>("rawvideo"),
|
||||
const_cast<char*>("-pix_fmt"), const_cast<char*>(pix_fmt_in),
|
||||
const_cast<char*>("-s"), size_str,
|
||||
const_cast<char*>("-framerate"), fps_str,
|
||||
const_cast<char*>("-i"), const_cast<char*>("-"),
|
||||
const_cast<char*>("-c:v"), const_cast<char*>("libx264"),
|
||||
const_cast<char*>("-pix_fmt"), const_cast<char*>("yuv420p"),
|
||||
const_cast<char*>("-movflags"), const_cast<char*>("+faststart"),
|
||||
// Force MP4 container. Distributed LocalAI hands us a staging
|
||||
// path (e.g. /staging/localai-output-NNN.tmp) with a non-standard
|
||||
// extension; relying on filename suffix makes ffmpeg bail with
|
||||
// "Unable to choose an output format".
|
||||
const_cast<char*>("-f"), const_cast<char*>("mp4"),
|
||||
const_cast<char*>(dst),
|
||||
nullptr
|
||||
};
|
||||
execvp(argv[0], argv.data());
|
||||
perror("execvp ffmpeg");
|
||||
_exit(127);
|
||||
@@ -1260,7 +1138,6 @@ static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps,
|
||||
close(pipefd[1]);
|
||||
int status;
|
||||
waitpid(pid, &status, 0);
|
||||
if (have_audio) unlink(wav_path);
|
||||
return 1;
|
||||
}
|
||||
p += n;
|
||||
@@ -1271,13 +1148,8 @@ static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps,
|
||||
|
||||
int status = 0;
|
||||
while (waitpid(pid, &status, 0) < 0) {
|
||||
if (errno != EINTR) {
|
||||
perror("waitpid");
|
||||
if (have_audio) unlink(wav_path);
|
||||
return 1;
|
||||
}
|
||||
if (errno != EINTR) { perror("waitpid"); return 1; }
|
||||
}
|
||||
if (have_audio) unlink(wav_path);
|
||||
if (!WIFEXITED(status) || WEXITSTATUS(status) != 0) {
|
||||
fprintf(stderr, "ffmpeg exited with status %d\n", status);
|
||||
return 1;
|
||||
@@ -1352,7 +1224,7 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
|
||||
|
||||
fprintf(stderr, "Generated %d frames, muxing to %s via ffmpeg\n", num_frames_out, dst);
|
||||
|
||||
int rc = ffmpeg_mux_raw_to_mp4(frames, num_frames_out, fps, audio, dst);
|
||||
int rc = ffmpeg_mux_raw_to_mp4(frames, num_frames_out, fps, dst);
|
||||
|
||||
for (int i = 0; i < num_frames_out; i++) {
|
||||
if (frames[i].data) free(frames[i].data);
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=0ccd896f5b882628e1c077f9769735ef4ce52860
|
||||
WHISPER_CPP_VERSION?=afa2ea544fb4b0448916b4a31ecd33c8685bd482
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -36,11 +36,15 @@ fi
|
||||
# flash-attn-4 4.0 stable lands.
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --prerelease=allow"
|
||||
|
||||
# JetPack 7 / L4T arm64 sglang + torch wheels come straight from PyPI now
|
||||
# (torch 2.11+ ships aarch64 + cu130 manylinux wheels and sglang 0.5.11+
|
||||
# ships a cp312 aarch64 wheel pinned to that torch). They're cp312-only,
|
||||
# so bump the venv Python accordingly.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
# JetPack 7 / L4T arm64 wheels are built for cp312 and shipped via
|
||||
# pypi.jetson-ai-lab.io. Bump the venv Python so the prebuilt sglang
|
||||
# wheel resolves cleanly. The actual install on l4t13 goes through
|
||||
# pyproject.toml (see the elif branch below) so [tool.uv.sources] can
|
||||
# pin only torch/torchvision/torchaudio/sglang to the jetson-ai-lab
|
||||
# index — leaving PyPI as the path for transitive deps like
|
||||
# markdown-it-py / anthropic / propcache that the L4T mirror's proxy
|
||||
# 503s on. No --index-strategy flag here: the explicit index keeps the
|
||||
# scoping clean.
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
@@ -106,6 +110,27 @@ if [ "x${BUILD_TYPE}" == "x" ] || [ "x${FROM_SOURCE:-}" == "xtrue" ]; then
|
||||
fi
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} .
|
||||
popd
|
||||
# L4T arm64 (JetPack 7): drive the install through pyproject.toml so that
|
||||
# [tool.uv.sources] can pin torch/torchvision/torchaudio/sglang to the
|
||||
# jetson-ai-lab index, while everything else (transitive deps and
|
||||
# PyPI-resolvable packages like transformers / accelerate) comes from
|
||||
# PyPI. Bypasses installRequirements because uv pip install -r
|
||||
# requirements.txt does not honor sources — see
|
||||
# backend/python/sglang/pyproject.toml for the rationale. Mirrors the
|
||||
# equivalent path in backend/python/vllm/install.sh.
|
||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
ensureVenv
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
||||
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
|
||||
fi
|
||||
pushd "${backend_dir}"
|
||||
# Build deps first (matches installRequirements' requirements-install.txt
|
||||
# pass — sglang/sgl-kernel sdists need packaging/setuptools-scm in the
|
||||
# venv before they can build under --no-build-isolation).
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements-install.txt
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement pyproject.toml
|
||||
popd
|
||||
runProtogen
|
||||
else
|
||||
installRequirements
|
||||
fi
|
||||
|
||||
68
backend/python/sglang/pyproject.toml
Normal file
68
backend/python/sglang/pyproject.toml
Normal file
@@ -0,0 +1,68 @@
|
||||
# L4T arm64 (JetPack 7 / sbsa cu130) install spec for the sglang backend.
|
||||
#
|
||||
# Why this file exists, and why only the l4t13 BUILD_PROFILE consumes it:
|
||||
#
|
||||
# pypi.jetson-ai-lab.io hosts the L4T-specific torch / sglang / sgl-kernel
|
||||
# wheels we need on aarch64 + cuda13, but it ALSO transparently proxies the
|
||||
# rest of PyPI through `/+f/<sha>/<filename>` URLs that 503 frequently.
|
||||
# With `--extra-index-url` + `--index-strategy=unsafe-best-match` (the
|
||||
# historical fix in install.sh) uv would pick those proxy URLs for ordinary
|
||||
# PyPI packages — markdown-it-py, anthropic, propcache, etc. — and trip on
|
||||
# the 503s. See e.g. CI run 25439791228 (markdown-it-py-4.0.0).
|
||||
#
|
||||
# `explicit = true` on the index makes uv consult the L4T mirror ONLY for
|
||||
# packages mapped under [tool.uv.sources]. Everything else goes to PyPI.
|
||||
# This breaks the historical 503 path without losing access to the L4T
|
||||
# wheels we actually need from there. Mirrors the equivalent fix already
|
||||
# in backend/python/vllm/pyproject.toml.
|
||||
#
|
||||
# `uv pip install -r requirements.txt` does NOT honor [tool.uv.sources]
|
||||
# (sources are project-mode only, not pip-compat mode), so install.sh's
|
||||
# l4t13 branch invokes `uv pip install --requirement pyproject.toml`
|
||||
# directly. Other BUILD_PROFILEs continue to use the requirements-*.txt
|
||||
# pipeline through libbackend.sh's installRequirements and never read
|
||||
# this file.
|
||||
[project]
|
||||
name = "localai-sglang-l4t13"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.12,<3.13"
|
||||
dependencies = [
|
||||
# Mirror of requirements.txt — kept in sync manually for now since the
|
||||
# l4t13 path bypasses installRequirements (see install.sh).
|
||||
"grpcio==1.80.0",
|
||||
"protobuf",
|
||||
"certifi",
|
||||
"setuptools",
|
||||
"pillow",
|
||||
# L4T-specific accelerator stack (sourced from jetson-ai-lab below).
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
# sglang on jetson — the [all] extra is deliberately omitted because it
|
||||
# pulls outlines/decord, and decord has no aarch64 cp312 wheel anywhere
|
||||
# (PyPI nor the jetson-ai-lab index ships only legacy cp35-cp37). With
|
||||
# [all] uv backtracks through versions trying to satisfy decord and
|
||||
# lands on sglang==0.1.16. The 0.5.0 floor matches the only major
|
||||
# series the jetson-ai-lab sbsa/cu130 mirror currently publishes
|
||||
# (sglang==0.5.1.post2 as of 2026-05-06). Bumping to >=0.5.11 here
|
||||
# would make the build unsatisfiable until the mirror catches up.
|
||||
# Gemma 4 / MTP recipes are therefore not supported on l4t13 — those
|
||||
# features land on cublas12/cublas13 hosts that pull the newer wheel
|
||||
# from PyPI. backend.py keeps backward compat with the 0.5.x SamplingParams
|
||||
# field rename via runtime detection.
|
||||
"sglang>=0.5.0",
|
||||
# PyPI-resolvable packages that complete the runtime.
|
||||
"accelerate",
|
||||
"transformers",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "jetson-ai-lab"
|
||||
url = "https://pypi.jetson-ai-lab.io/sbsa/cu130"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "jetson-ai-lab" }
|
||||
torchvision = { index = "jetson-ai-lab" }
|
||||
torchaudio = { index = "jetson-ai-lab" }
|
||||
sglang = { index = "jetson-ai-lab" }
|
||||
@@ -1,15 +0,0 @@
|
||||
# sglang 0.5.11+ ships an aarch64 manylinux wheel on PyPI whose Requires-Dist
|
||||
# pins torch==2.11.0 / torchaudio==2.11.0, locking an ABI-consistent set with
|
||||
# the cu130 torch wheel installed above. 0.5.11 is the floor for Gemma 4
|
||||
# support (sgl-project/sglang#21952).
|
||||
#
|
||||
# The [all] extra is deliberately NOT used on aarch64: it pulls the
|
||||
# [diffusion] sub-extra which requires `xatlas`, and xatlas ships no
|
||||
# aarch64 wheel and its sdist depends on scikit_build_core without
|
||||
# declaring it in build-system.requires — so under --no-build-isolation
|
||||
# uv can't build it. Upstream sglang gates st_attn and vsa on
|
||||
# platform_machine != aarch64 in the diffusion extra but forgot xatlas.
|
||||
# Plain `sglang` carries everything backend.py uses (Engine, ServerArgs,
|
||||
# FunctionCallParser, ReasoningParser); the [all] extras are optional
|
||||
# accelerators not required at import time.
|
||||
sglang>=0.5.11
|
||||
@@ -1,9 +0,0 @@
|
||||
# JetPack 7 / L4T arm64 + CUDA 13. Since PyTorch 2.11 (April 2026), PyPI ships
|
||||
# aarch64 + cu130 manylinux wheels for torch/torchvision/torchaudio directly,
|
||||
# so we no longer need a custom --extra-index-url for the L4T mirror.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
accelerate
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
transformers
|
||||
@@ -26,7 +26,7 @@ import torch.cuda
|
||||
|
||||
XPU=os.environ.get("XPU", "0") == "1"
|
||||
import transformers as transformers_module
|
||||
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, pipeline
|
||||
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
|
||||
from scipy.io import wavfile
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
@@ -200,21 +200,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
autoTokenizer = False
|
||||
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||
self.SentenceTransformer = True
|
||||
elif request.Type == "TokenClassification":
|
||||
# NER / PII tagging via HuggingFace's token-classification
|
||||
# pipeline. aggregation_strategy="simple" merges B-/I- tags
|
||||
# into single spans and gives byte offsets back. The
|
||||
# tokenizer is bundled inside the pipeline, so we skip the
|
||||
# AutoTokenizer load below.
|
||||
autoTokenizer = False
|
||||
self.tokenClassifier = pipeline(
|
||||
"token-classification",
|
||||
model=model_name,
|
||||
aggregation_strategy="simple",
|
||||
device=0 if self.CUDA else -1,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
)
|
||||
self.TokenClassification = True
|
||||
else:
|
||||
# Generic: dynamically resolve model class from transformers
|
||||
model_type = TYPE_ALIASES.get(request.Type, request.Type)
|
||||
@@ -268,39 +253,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def TokenClassify(self, request, context):
|
||||
# Runs HuggingFace's token-classification pipeline and returns
|
||||
# the aggregated entity spans. The pipeline gives us byte
|
||||
# offsets via aggregation_strategy="simple" (set at load
|
||||
# time), so the caller can slice the original text without
|
||||
# re-tokenising on the Go side.
|
||||
if not getattr(self, "TokenClassification", False):
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("model was not loaded as Type=TokenClassification")
|
||||
return backend_pb2.TokenClassifyResponse()
|
||||
try:
|
||||
results = self.tokenClassifier(request.text)
|
||||
except Exception as err:
|
||||
print("TokenClassify error:", err, file=sys.stderr)
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"token-classification failed: {err}")
|
||||
return backend_pb2.TokenClassifyResponse()
|
||||
|
||||
threshold = request.threshold if request.threshold > 0 else 0.0
|
||||
entities = []
|
||||
for r in results:
|
||||
score = float(r.get("score", 0.0))
|
||||
if score < threshold:
|
||||
continue
|
||||
entities.append(backend_pb2.TokenClassifyEntity(
|
||||
entity_group=str(r.get("entity_group") or r.get("entity") or ""),
|
||||
start=int(r.get("start", 0)),
|
||||
end=int(r.get("end", 0)),
|
||||
score=score,
|
||||
text=str(r.get("word", "")),
|
||||
))
|
||||
return backend_pb2.TokenClassifyResponse(entities=entities)
|
||||
|
||||
def Embedding(self, request, context):
|
||||
set_seed(request.Seed)
|
||||
# Tokenize input
|
||||
|
||||
@@ -13,14 +13,14 @@ else
|
||||
fi
|
||||
|
||||
# Handle l4t build profiles (Python 3.12, pip fallback) if needed.
|
||||
# Since PyTorch 2.11 (April 2026) PyPI ships aarch64 + cu130 manylinux wheels
|
||||
# directly for torch/torchvision/torchaudio and an aarch64 vllm wheel pinned
|
||||
# to that torch, so the jetson-ai-lab mirror is no longer needed.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
# unsafe-best-match is required on l4t13 because the jetson-ai-lab index
|
||||
# lists transitive deps at limited versions — without it uv pins to the
|
||||
# first matching index and fails to resolve a compatible wheel from PyPI.
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
PY_STANDALONE_TAG="20251120"
|
||||
EXTRA_PIP_INSTALL_FLAGS="${EXTRA_PIP_INSTALL_FLAGS:-} --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
@@ -42,11 +42,18 @@ if [ "x${BUILD_TYPE}" == "xhipblas" ]; then
|
||||
else
|
||||
uv pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700
|
||||
fi
|
||||
elif [ "x${BUILD_PROFILE}" == "xcublas13" ] || [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
# cublas13 (x86_64) and l4t13 (aarch64) both pull vllm from PyPI now:
|
||||
# vllm 0.19+ defaults to cu130 wheels on x86_64 and vllm 0.20+ ships an
|
||||
# aarch64 manylinux wheel pinned to torch==2.11.0. No extra index needed
|
||||
# in either case.
|
||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
# JetPack 7 / L4T arm64 cu130 — vllm comes from the prebuilt SBSA wheel
|
||||
# at jetson-ai-lab. Version is unpinned: the index ships whatever build
|
||||
# matches the cu130/cp312 ABI. unsafe-best-match lets uv fall through
|
||||
# to PyPI for transitive deps not present on the jetson-ai-lab index.
|
||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||
pip install vllm --extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
else
|
||||
uv pip install --index-strategy=unsafe-best-match vllm --extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
fi
|
||||
elif [ "x${BUILD_PROFILE}" == "xcublas13" ]; then
|
||||
# vllm 0.19+ defaults to cu130 wheels on PyPI, no extra index needed.
|
||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||
pip install vllm --torch-backend=auto
|
||||
else
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
# JetPack 7 / L4T arm64 + CUDA 13. PyPI ships aarch64 + cu130 manylinux wheels
|
||||
# for torch/torchvision/torchaudio directly since PyTorch 2.11 (April 2026),
|
||||
# so no custom index is needed. flash-attn is dropped here: PyPI has no
|
||||
# aarch64 wheel for it, but vLLM 0.20+ bundles its own vllm_flash_attn
|
||||
# (fa2 + fa3) inside the main wheel, so it is not required at runtime.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
accelerate
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
transformers
|
||||
bitsandbytes
|
||||
flash-attn
|
||||
diffusers
|
||||
librosa
|
||||
soundfile
|
||||
|
||||
@@ -356,133 +356,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
except Exception as e:
|
||||
return backend_pb2.Result(success=False, message=str(e))
|
||||
|
||||
async def Score(self, request, context):
|
||||
"""
|
||||
Joint log-probability of each candidate continuation given the
|
||||
shared prompt. Used by routing-policy multi-label classification
|
||||
(read the distribution rather than asking the model to emit a
|
||||
single argmax label), reranking, and reward-model scoring.
|
||||
|
||||
Implementation uses vLLM's `prompt_logprobs` to recover the
|
||||
per-token log P(token_i | tokens_<i) for the full concatenated
|
||||
sequence; the candidate's tokens are the suffix whose logprobs
|
||||
get summed. max_tokens=1 because vLLM requires at least one
|
||||
generated token; the generated token is discarded.
|
||||
"""
|
||||
if not hasattr(self, 'llm') or self.llm is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("Model not loaded")
|
||||
return backend_pb2.ScoreResponse()
|
||||
if not hasattr(self, 'tokenizer') or self.tokenizer is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("Tokenizer not available")
|
||||
return backend_pb2.ScoreResponse()
|
||||
if len(request.candidates) == 0:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("candidates must be non-empty")
|
||||
return backend_pb2.ScoreResponse()
|
||||
|
||||
try:
|
||||
prompt = request.prompt or ""
|
||||
prompt_token_ids = self.tokenizer.encode(prompt)
|
||||
prompt_len = len(prompt_token_ids)
|
||||
results = []
|
||||
|
||||
for candidate in request.candidates:
|
||||
# Tokenise the concatenated sequence. We can't naively
|
||||
# use len(prompt_tokens) + len(tokenizer.encode(candidate))
|
||||
# because BPE merges at the boundary may produce a
|
||||
# different tokenisation. Encoding the joined text and
|
||||
# walking the divergence point is the correct primitive.
|
||||
full_text = prompt + candidate
|
||||
full_token_ids = self.tokenizer.encode(full_text)
|
||||
|
||||
divergence = prompt_len
|
||||
min_len = min(prompt_len, len(full_token_ids))
|
||||
for i in range(min_len):
|
||||
if prompt_token_ids[i] != full_token_ids[i]:
|
||||
divergence = i
|
||||
break
|
||||
|
||||
candidate_token_ids = full_token_ids[divergence:]
|
||||
num_candidate_tokens = len(candidate_token_ids)
|
||||
if num_candidate_tokens == 0:
|
||||
results.append(backend_pb2.CandidateScore(
|
||||
log_prob=0.0,
|
||||
length_normalized_log_prob=0.0,
|
||||
num_tokens=0,
|
||||
))
|
||||
continue
|
||||
|
||||
sampling = SamplingParams(
|
||||
max_tokens=1,
|
||||
temperature=0.0,
|
||||
prompt_logprobs=1,
|
||||
detokenize=False,
|
||||
)
|
||||
|
||||
request_id = random_uuid()
|
||||
last_output = None
|
||||
outputs_iter = self.llm.generate(
|
||||
{"prompt": full_text},
|
||||
sampling_params=sampling,
|
||||
request_id=request_id,
|
||||
)
|
||||
try:
|
||||
async for out in outputs_iter:
|
||||
last_output = out
|
||||
finally:
|
||||
try:
|
||||
await outputs_iter.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if last_output is None or not getattr(last_output, "prompt_logprobs", None):
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details("vLLM did not return prompt_logprobs")
|
||||
return backend_pb2.ScoreResponse()
|
||||
|
||||
prompt_logprobs = last_output.prompt_logprobs
|
||||
total = 0.0
|
||||
tokens_proto = []
|
||||
for offset, tok_id in enumerate(candidate_token_ids):
|
||||
position = divergence + offset
|
||||
if position >= len(prompt_logprobs) or prompt_logprobs[position] is None:
|
||||
continue
|
||||
entry = prompt_logprobs[position]
|
||||
lp_obj = entry.get(tok_id)
|
||||
if lp_obj is not None:
|
||||
lp = lp_obj.logprob
|
||||
else:
|
||||
# Token not in top-K; vLLM's top-1 may miss it.
|
||||
# Fall back to the lowest available logprob in the
|
||||
# entry — a conservative lower-bound on the true
|
||||
# log P, biased against this candidate.
|
||||
lp = min(v.logprob for v in entry.values())
|
||||
total += lp
|
||||
if request.include_token_logprobs:
|
||||
tokens_proto.append(backend_pb2.TokenLogProb(
|
||||
token=self.tokenizer.decode([tok_id]),
|
||||
log_prob=lp,
|
||||
))
|
||||
|
||||
cs = backend_pb2.CandidateScore(
|
||||
log_prob=total,
|
||||
num_tokens=num_candidate_tokens,
|
||||
)
|
||||
if request.length_normalize and num_candidate_tokens > 0:
|
||||
cs.length_normalized_log_prob = total / num_candidate_tokens
|
||||
if tokens_proto:
|
||||
cs.tokens.extend(tokens_proto)
|
||||
results.append(cs)
|
||||
|
||||
return backend_pb2.ScoreResponse(candidates=results)
|
||||
except Exception as e:
|
||||
print(f"Score error: {e}", file=sys.stderr)
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(str(e))
|
||||
return backend_pb2.ScoreResponse()
|
||||
|
||||
async def _predict(self, request, context, streaming=False):
|
||||
# Build the sampling parameters
|
||||
# NOTE: this must stay in sync with the vllm backend
|
||||
|
||||
@@ -43,11 +43,14 @@ if [ "x${BUILD_PROFILE}" == "xcublas13" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
# JetPack 7 / L4T arm64 vllm + torch wheels come straight from PyPI now
|
||||
# (torch 2.11+ ships aarch64 + cu130 manylinux wheels and vllm 0.20+ ships
|
||||
# an aarch64 wheel pinned to that torch). They're cp312-only, so bump the
|
||||
# venv Python accordingly. JetPack 6 keeps cp310 + USE_PIP=true.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
# JetPack 7 / L4T arm64 wheels (torch, vllm, flash-attn) live on
|
||||
# pypi.jetson-ai-lab.io and are built for cp312, so bump the venv Python
|
||||
# accordingly. JetPack 6 keeps cp310 + USE_PIP=true.
|
||||
#
|
||||
# l4t13 uses pyproject.toml (see the elif branch below) to pin only the
|
||||
# L4T-specific wheels to the jetson-ai-lab index via [tool.uv.sources].
|
||||
# That keeps PyPI as the resolution path for transitive deps like
|
||||
# anthropic/openai/propcache, which the L4T mirror's proxy 503s on.
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
@@ -100,6 +103,25 @@ if [ "x${BUILD_TYPE}" == "xintel" ]; then
|
||||
export CMAKE_PREFIX_PATH="$(python -c 'import site; print(site.getsitepackages()[0])'):${CMAKE_PREFIX_PATH:-}"
|
||||
VLLM_TARGET_DEVICE=xpu uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --no-deps .
|
||||
popd
|
||||
# L4T arm64 (JetPack 7): drive the install through pyproject.toml so that
|
||||
# [tool.uv.sources] can pin torch/vllm/flash-attn/torchvision/torchaudio
|
||||
# to the jetson-ai-lab index, while everything else (transitive deps and
|
||||
# PyPI-resolvable packages like transformers) comes from PyPI. Bypasses
|
||||
# installRequirements because uv pip install -r requirements.txt does not
|
||||
# honor sources — see backend/python/vllm/pyproject.toml for the rationale.
|
||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
ensureVenv
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
||||
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
|
||||
fi
|
||||
pushd "${backend_dir}"
|
||||
# Build deps first (matches installRequirements' requirements-install.txt
|
||||
# pass — fastsafetensors and friends need pybind11 in the venv before
|
||||
# their sdists can build under --no-build-isolation).
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements-install.txt
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement pyproject.toml
|
||||
popd
|
||||
runProtogen
|
||||
# FROM_SOURCE=true on a CPU build skips the prebuilt vllm wheel in
|
||||
# requirements-cpu-after.txt and compiles vllm locally against the host's
|
||||
# actual CPU. Not used by default because it takes ~30-40 minutes, but
|
||||
|
||||
61
backend/python/vllm/pyproject.toml
Normal file
61
backend/python/vllm/pyproject.toml
Normal file
@@ -0,0 +1,61 @@
|
||||
# L4T arm64 (JetPack 7 / sbsa cu130) install spec for the vllm backend.
|
||||
#
|
||||
# Why this file exists, and why only the l4t13 BUILD_PROFILE consumes it:
|
||||
#
|
||||
# pypi.jetson-ai-lab.io hosts the L4T-specific torch / vllm / flash-attn
|
||||
# wheels we need on aarch64 + cuda13, but it ALSO transparently proxies the
|
||||
# rest of PyPI through `/+f/<sha>/<filename>` URLs that 503 frequently. With
|
||||
# `--extra-index-url` + `--index-strategy=unsafe-best-match` (the historical
|
||||
# fix in install.sh) uv would pick those proxy URLs for ordinary PyPI
|
||||
# packages — `anthropic`, `openai`, `propcache`, `annotated-types` — and
|
||||
# trip on the 503s. See e.g. CI run 25212201349 (anthropic-0.97.0).
|
||||
#
|
||||
# `explicit = true` on the index makes uv consult the L4T mirror ONLY for
|
||||
# packages mapped under [tool.uv.sources]. Everything else goes to PyPI.
|
||||
# This breaks the historical 503 path without losing access to the L4T
|
||||
# wheels we actually need from there.
|
||||
#
|
||||
# `uv pip install -r requirements.txt` does NOT honor [tool.uv.sources]
|
||||
# (sources are project-mode only, not pip-compat mode), so install.sh's
|
||||
# l4t13 branch invokes `uv pip install --requirement pyproject.toml`
|
||||
# directly. Other BUILD_PROFILEs continue to use the requirements-*.txt
|
||||
# pipeline through libbackend.sh's installRequirements and never read
|
||||
# this file.
|
||||
[project]
|
||||
name = "localai-vllm-l4t13"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.12,<3.13"
|
||||
dependencies = [
|
||||
# Mirror of requirements.txt — kept in sync manually for now since the
|
||||
# l4t13 path bypasses installRequirements (see install.sh).
|
||||
"grpcio==1.80.0",
|
||||
"protobuf",
|
||||
"certifi",
|
||||
"setuptools",
|
||||
"pillow",
|
||||
"charset-normalizer>=3.4.7",
|
||||
"chardet",
|
||||
# L4T-specific accelerator stack (sourced from jetson-ai-lab below).
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
"flash-attn",
|
||||
"vllm",
|
||||
# PyPI-resolvable packages that complete the runtime — accelerate,
|
||||
# transformers, bitsandbytes carry their own wheels for aarch64.
|
||||
"accelerate",
|
||||
"transformers",
|
||||
"bitsandbytes",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "jetson-ai-lab"
|
||||
url = "https://pypi.jetson-ai-lab.io/sbsa/cu130"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "jetson-ai-lab" }
|
||||
torchvision = { index = "jetson-ai-lab" }
|
||||
torchaudio = { index = "jetson-ai-lab" }
|
||||
flash-attn = { index = "jetson-ai-lab" }
|
||||
vllm = { index = "jetson-ai-lab" }
|
||||
@@ -1,4 +0,0 @@
|
||||
# vLLM 0.20+ ships an aarch64 manylinux wheel on PyPI whose Requires-Dist pins
|
||||
# torch==2.11.0 / torchvision==0.26.0 / torchaudio==2.11.0, locking an ABI-
|
||||
# consistent set with the cu130 torch wheel installed above.
|
||||
vllm
|
||||
@@ -1,8 +0,0 @@
|
||||
# JetPack 7 / L4T arm64 + CUDA 13. Since PyTorch 2.11 (April 2026), PyPI ships
|
||||
# aarch64 + cu130 manylinux wheels for torch/torchvision/torchaudio directly,
|
||||
# so we no longer need a custom --extra-index-url for the L4T mirror.
|
||||
# https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
|
||||
accelerate
|
||||
torch
|
||||
transformers
|
||||
bitsandbytes
|
||||
@@ -375,15 +375,6 @@ impl Backend for KokorosService {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
type AudioToAudioStreamStream = ReceiverStream<Result<backend::AudioToAudioResponse, Status>>;
|
||||
|
||||
async fn audio_to_audio_stream(
|
||||
&self,
|
||||
_: Request<tonic::Streaming<backend::AudioToAudioRequest>>,
|
||||
) -> Result<Response<Self::AudioToAudioStreamStream>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn sound_generation(
|
||||
&self,
|
||||
_: Request<backend::SoundGenerationRequest>,
|
||||
|
||||
@@ -9,18 +9,11 @@ import (
|
||||
|
||||
corebackend "github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
"github.com/mudler/LocalAI/core/services/facerecognition"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/routing/admission"
|
||||
"github.com/mudler/LocalAI/core/services/routing/billing"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/mitm"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/services/voicerecognition"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
pkggrpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
@@ -58,22 +51,6 @@ type Application struct {
|
||||
faceRegistry facerecognition.Registry
|
||||
voiceRegistry voicerecognition.Registry
|
||||
authDB *gorm.DB
|
||||
metricsService *monitoring.LocalAIMetricsService
|
||||
statsRecorder *billing.Recorder
|
||||
fallbackUser *auth.User
|
||||
piiRedactor *pii.Redactor
|
||||
piiEvents pii.EventStore
|
||||
mitmCA atomic.Pointer[mitm.CA]
|
||||
mitmServer atomic.Pointer[mitm.Server]
|
||||
mitmMutex sync.Mutex // serializes Stop+Start; readers use atomic loads
|
||||
// mitmHostConflicts records duplicate-host claims across model configs.
|
||||
// Non-empty disables the MITM listener until resolved — the strict
|
||||
// 1-to-1 host↔model invariant the dispatcher relies on. Read by
|
||||
// /api/middleware/status so the admin UI can surface the cause.
|
||||
mitmHostConflicts atomic.Pointer[map[string][]string]
|
||||
routerDecisions router.DecisionStore
|
||||
routerRegistry *router.Registry
|
||||
admissionLimiter *admission.Limiter
|
||||
watchdogMutex sync.Mutex
|
||||
watchdogStop chan bool
|
||||
p2pMutex sync.Mutex
|
||||
@@ -208,103 +185,6 @@ func (a *Application) AuthDB() *gorm.DB {
|
||||
return a.authDB
|
||||
}
|
||||
|
||||
// MetricsService returns the OTel + Prometheus metric service. nil when
|
||||
// --disable-metrics is set or initialisation failed at startup.
|
||||
//
|
||||
// The service is created in startup.go before any counter is registered
|
||||
// so that otel.SetMeterProvider runs early enough for the billing
|
||||
// recorder's counters to bind to the Prom-backed provider rather than
|
||||
// the no-op global. core/http/app.go reuses this instance instead of
|
||||
// constructing its own — two providers would orphan one set of counters
|
||||
// behind whichever provider lost the SetMeterProvider race.
|
||||
func (a *Application) MetricsService() *monitoring.LocalAIMetricsService {
|
||||
return a.metricsService
|
||||
}
|
||||
|
||||
// StatsRecorder returns the billing recorder used by the usage
|
||||
// middleware. It is non-nil whenever stats are not explicitly disabled
|
||||
// — i.e., the no-auth single-user path still gets a working recorder
|
||||
// (in-memory by default). Routes register UsageMiddleware against this
|
||||
// recorder regardless of auth state.
|
||||
func (a *Application) StatsRecorder() *billing.Recorder {
|
||||
return a.statsRecorder
|
||||
}
|
||||
|
||||
// FallbackUser is the synthetic "local" user that UsageMiddleware uses
|
||||
// to attribute requests when no authenticated user is on the context
|
||||
// (i.e., --auth is off). nil when auth is on, since real users are
|
||||
// always available there.
|
||||
func (a *Application) FallbackUser() *auth.User {
|
||||
return a.fallbackUser
|
||||
}
|
||||
|
||||
// PIIRedactor returns the regex-tier PII redactor or nil if PII
|
||||
// filtering is disabled. The chat-route middleware uses this to apply
|
||||
// redaction before dispatch.
|
||||
func (a *Application) PIIRedactor() *pii.Redactor {
|
||||
return a.piiRedactor
|
||||
}
|
||||
|
||||
// PIIEvents returns the PII event store. Same nil-when-disabled
|
||||
// semantics as PIIRedactor; admin REST and MCP read tools call List
|
||||
// against it.
|
||||
func (a *Application) PIIEvents() pii.EventStore {
|
||||
return a.piiEvents
|
||||
}
|
||||
|
||||
// MITMCA returns the cloudproxy MITM proxy's CA, or nil when the
|
||||
// MITM listener is disabled.
|
||||
func (a *Application) MITMCA() *mitm.CA { return a.mitmCA.Load() }
|
||||
|
||||
// MITMServer returns the running MITM proxy or nil.
|
||||
func (a *Application) MITMServer() *mitm.Server { return a.mitmServer.Load() }
|
||||
|
||||
// MITMHostConflicts returns a snapshot of host→[]model-name pairs that
|
||||
// are claimed by 2+ model configs. Empty when the 1-to-1 invariant
|
||||
// holds. Non-empty disables the MITM listener — read by the admin
|
||||
// status endpoint to explain why.
|
||||
func (a *Application) MITMHostConflicts() map[string][]string {
|
||||
p := a.mitmHostConflicts.Load()
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return *p
|
||||
}
|
||||
|
||||
// MITMHostOwners returns the host→model-name map, useful for the
|
||||
// admin status endpoint. The lookup is recomputed on each call to
|
||||
// stay current with model-config edits without needing a
|
||||
// MITMRestart.
|
||||
func (a *Application) MITMHostOwners() map[string]string {
|
||||
if a.backendLoader == nil {
|
||||
return nil
|
||||
}
|
||||
return a.backendLoader.MITMHostOwners().Owners
|
||||
}
|
||||
|
||||
// RouterDecisions returns the routing decision store. nil when stats
|
||||
// are disabled (--disable-stats); the RouteModel middleware skips the
|
||||
// log write in that case but still rewrites requests.
|
||||
func (a *Application) RouterDecisions() router.DecisionStore {
|
||||
return a.routerDecisions
|
||||
}
|
||||
|
||||
// RouterClassifierRegistry returns the process-wide classifier cache.
|
||||
// Shared between the OpenAI and Anthropic route middlewares so the
|
||||
// admin stats endpoint sees every live classifier — and so a
|
||||
// classifier built on the OpenAI route is reused on Anthropic.
|
||||
func (a *Application) RouterClassifierRegistry() *router.Registry {
|
||||
return a.routerRegistry
|
||||
}
|
||||
|
||||
// AdmissionLimiter returns the per-model admission limiter. The
|
||||
// admission middleware uses it to gate concurrent requests; the
|
||||
// admin status surface reads InFlight/Capacity from it for live
|
||||
// load visibility.
|
||||
func (a *Application) AdmissionLimiter() *admission.Limiter {
|
||||
return a.admissionLimiter
|
||||
}
|
||||
|
||||
// StartupConfig returns the original startup configuration (from env vars, before file loading)
|
||||
func (a *Application) StartupConfig() *config.ApplicationConfig {
|
||||
return a.startupConfig
|
||||
@@ -375,15 +255,6 @@ func (a *Application) start() error {
|
||||
a.modelLoader,
|
||||
a.galleryService,
|
||||
)
|
||||
// Wire usage tracking so the assistant's get_usage_stats tool
|
||||
// returns real data; nil values keep the tool returning a clear
|
||||
// "unavailable" error if startup ran with --disable-stats.
|
||||
assistantClient.StatsRecorder = a.statsRecorder
|
||||
assistantClient.FallbackUser = a.fallbackUser
|
||||
// PII filter — same nil-or-real wiring.
|
||||
assistantClient.PIIRedactor = a.piiRedactor
|
||||
assistantClient.PIIEvents = a.piiEvents
|
||||
assistantClient.RouterDecisions = a.routerDecisions
|
||||
if err := holder.Initialize(a.applicationConfig.Context, assistantClient, localaitools.Options{}); err != nil {
|
||||
// Why log+continue instead of fail: the assistant is an optional
|
||||
// feature; a failure here must not take down the whole server.
|
||||
|
||||
@@ -233,12 +233,7 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
xlog.Info("File stager initialized (HTTP direct transfer)")
|
||||
}
|
||||
// Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(
|
||||
registry,
|
||||
natsClient,
|
||||
cfg.Distributed.BackendInstallTimeoutOrDefault(),
|
||||
cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
|
||||
)
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
var conflictResolver nodes.ConcurrencyConflictResolver
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/mitm"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
func startMITMProxy(app *Application, options *config.ApplicationConfig) error {
|
||||
app.mitmMutex.Lock()
|
||||
defer app.mitmMutex.Unlock()
|
||||
return startMITMLocked(app, options)
|
||||
}
|
||||
|
||||
func startMITMLocked(app *Application, options *config.ApplicationConfig) error {
|
||||
// Validate the host↔model-config 1-to-1 invariant before binding
|
||||
// the listener. Two configs claiming the same host means the
|
||||
// dispatcher would have ambiguous PII settings; refuse to start
|
||||
// rather than silently picking one. The conflict map is published
|
||||
// for /api/middleware/status to surface in the UI.
|
||||
ownership := app.backendLoader.MITMHostOwners()
|
||||
if len(ownership.Conflicts) > 0 {
|
||||
conflicts := ownership.Conflicts
|
||||
app.mitmHostConflicts.Store(&conflicts)
|
||||
hosts := make([]string, 0, len(conflicts))
|
||||
for h := range conflicts {
|
||||
hosts = append(hosts, h)
|
||||
}
|
||||
sort.Strings(hosts)
|
||||
xlog.Error("mitm: refusing to start — duplicate host claims across model configs",
|
||||
"hosts", hosts,
|
||||
"conflicts", conflicts,
|
||||
)
|
||||
return errors.New("mitm: configuration error: duplicate host claims (see /api/middleware/status)")
|
||||
}
|
||||
app.mitmHostConflicts.Store(nil)
|
||||
|
||||
caDir := options.MITMCADir
|
||||
if caDir == "" {
|
||||
base := options.DataPath
|
||||
if base == "" {
|
||||
base = "."
|
||||
}
|
||||
caDir = filepath.Join(base, "mitm-ca")
|
||||
}
|
||||
|
||||
if app.mitmCA.Load() == nil {
|
||||
ca, err := mitm.LoadOrCreateCA(caDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ca: %w", err)
|
||||
}
|
||||
app.mitmCA.Store(ca)
|
||||
}
|
||||
|
||||
// Allowlist is exactly the set of hosts claimed by model configs.
|
||||
// No global list — admins add hosts by creating an MITM model
|
||||
// config (template available in the Add Model UI). When no config
|
||||
// claims any host, the listener still starts but every CONNECT
|
||||
// tunnels through unmodified.
|
||||
effectiveHosts := make([]string, 0, len(ownership.Owners))
|
||||
for h := range ownership.Owners {
|
||||
effectiveHosts = append(effectiveHosts, h)
|
||||
}
|
||||
sort.Strings(effectiveHosts)
|
||||
|
||||
// Per-host PII gate inherits from the owning model's pii.enabled.
|
||||
// A non-cloud-proxy backend with no explicit pii.enabled resolves
|
||||
// to false → host is intercepted but the regex pass is skipped
|
||||
// (audit events still record).
|
||||
var piiDisabled []string
|
||||
for host, modelName := range ownership.Owners {
|
||||
cfg, exists := app.backendLoader.GetModelConfig(modelName)
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
if !cfg.PIIIsEnabled() {
|
||||
piiDisabled = append(piiDisabled, host)
|
||||
}
|
||||
}
|
||||
|
||||
handler := mitm.NewPIIHandler(mitm.PIIHandlerOptions{
|
||||
Redactor: app.piiRedactor,
|
||||
EventStore: app.piiEvents,
|
||||
HostsWithPIIDisabled: piiDisabled,
|
||||
})
|
||||
|
||||
srv, err := mitm.NewServer(mitm.Config{
|
||||
Addr: options.MITMListen,
|
||||
CA: app.mitmCA.Load(),
|
||||
InterceptHosts: effectiveHosts,
|
||||
Handler: handler,
|
||||
EventStore: app.piiEvents,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("server: %w", err)
|
||||
}
|
||||
if err := srv.Start(); err != nil {
|
||||
return fmt.Errorf("listen: %w", err)
|
||||
}
|
||||
app.mitmServer.Store(srv)
|
||||
|
||||
xlog.Info("mitm: cloudproxy listener started",
|
||||
"addr", srv.Addr(),
|
||||
"ca_dir", caDir,
|
||||
"intercept_hosts", effectiveHosts,
|
||||
"model_owned_hosts", len(ownership.Owners),
|
||||
"pii_disabled_hosts", len(piiDisabled),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopMITM is idempotent.
|
||||
func (a *Application) StopMITM() error {
|
||||
a.mitmMutex.Lock()
|
||||
defer a.mitmMutex.Unlock()
|
||||
stopMITMLocked(a)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestartMITM reuses the existing CA so trusted clients keep
|
||||
// working across listener flips.
|
||||
func (a *Application) RestartMITM() error {
|
||||
a.mitmMutex.Lock()
|
||||
defer a.mitmMutex.Unlock()
|
||||
stopMITMLocked(a)
|
||||
if a.applicationConfig.MITMListen == "" {
|
||||
xlog.Info("mitm: cloudproxy listener stays disabled (no listen address)")
|
||||
return nil
|
||||
}
|
||||
return startMITMLocked(a, a.applicationConfig)
|
||||
}
|
||||
|
||||
func stopMITMLocked(a *Application) {
|
||||
srv := a.mitmServer.Load()
|
||||
if srv == nil {
|
||||
return
|
||||
}
|
||||
srv.Stop()
|
||||
a.mitmServer.Store(nil)
|
||||
xlog.Info("mitm: cloudproxy listener stopped")
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
// adapterConfig resolves a model name to its runtime ModelConfig, or
|
||||
// nil when the name is unknown. Shared by the router-facing factories
|
||||
// below and by ModelConfigLookup.
|
||||
func (a *Application) adapterConfig(modelName string) *config.ModelConfig {
|
||||
cfg, err := a.backendLoader.LoadModelConfigFileByNameDefaultOptions(modelName, a.applicationConfig)
|
||||
if err != nil || cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// ModelConfigLookup is the lookup function the router middleware's
|
||||
// classifier validator uses to confirm classifier_model declares
|
||||
// FLAG_SCORE before binding it.
|
||||
func (a *Application) ModelConfigLookup() func(modelName string) *config.ModelConfig {
|
||||
return a.adapterConfig
|
||||
}
|
||||
|
||||
// Scorer returns a backend.Scorer bound to the named model, or nil
|
||||
// when the model is unknown. Used as a method value (app.Scorer) by
|
||||
// router.ClassifierDeps — no factory-of-factory wrapper needed.
|
||||
func (a *Application) Scorer(modelName string) backend.Scorer {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewScorer(a.modelLoader, *cfg, a.applicationConfig)
|
||||
}
|
||||
|
||||
// Reranker returns a backend.Reranker bound to the named model, or
|
||||
// nil when unknown. The reranker model's `type:` (e.g. "colbert")
|
||||
// selects the scoring head inside the rerankers backend.
|
||||
func (a *Application) Reranker(modelName string) backend.Reranker {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewReranker(a.modelLoader, *cfg, a.applicationConfig)
|
||||
}
|
||||
|
||||
// Embedder returns a backend.Embedder bound to the named model, or
|
||||
// nil when unknown. Used by the router's L2 embedding cache.
|
||||
func (a *Application) Embedder(modelName string) backend.Embedder {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewEmbedder(a.modelLoader, *cfg, a.applicationConfig)
|
||||
}
|
||||
|
||||
// VectorStore returns a backend.VectorStore for the named collection,
|
||||
// or nil when the name is empty. Each router model gets its own
|
||||
// backend process via the model loader's cache keyed by storeName.
|
||||
func (a *Application) VectorStore(storeName string) backend.VectorStore {
|
||||
return backend.NewVectorStore(a.modelLoader, a.applicationConfig, storeName)
|
||||
}
|
||||
@@ -87,28 +87,6 @@ var _ = Describe("loadRuntimeSettingsFromFile", func() {
|
||||
})
|
||||
})
|
||||
|
||||
// MITM listener address. The file is the only source — no env var
|
||||
// exists — so a regression here means an admin who configured the
|
||||
// listener via /api/settings loses it after a reboot, even though
|
||||
// the value is still on disk in the volume. (Intercept hosts now
|
||||
// live in model YAML mitm.hosts: blocks, not runtime_settings.json.)
|
||||
Describe("MITM fields", func() {
|
||||
It("loads mitm_listen", func() {
|
||||
cfg := &config.ApplicationConfig{DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`)}
|
||||
loadRuntimeSettingsFromFile(cfg)
|
||||
Expect(cfg.MITMListen).To(Equal(":8443"))
|
||||
})
|
||||
|
||||
It("does not override an explicit CLI flag", func() {
|
||||
cfg := &config.ApplicationConfig{
|
||||
DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`),
|
||||
MITMListen: ":9999", // simulate WithMITMListen(":9999")
|
||||
}
|
||||
loadRuntimeSettingsFromFile(cfg)
|
||||
Expect(cfg.MITMListen).To(Equal(":9999"), "CLI flag must win over the persisted file value")
|
||||
})
|
||||
})
|
||||
|
||||
// The Agent Pool block has a mix of zero and non-zero defaults
|
||||
// (Enabled=true, EmbeddingModel="granite-...", MaxChunkingSize=400,
|
||||
// VectorEngine="chromem", AgentHubURL="https://agenthub.localai.io").
|
||||
|
||||
@@ -15,18 +15,11 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/routing/admission"
|
||||
"github.com/mudler/LocalAI/core/services/routing/billing"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
@@ -135,117 +128,6 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}()
|
||||
}
|
||||
|
||||
// Initialize the OTel + Prometheus metric pipeline before any
|
||||
// counter is created. monitoring.NewLocalAIMetricsService calls
|
||||
// otel.SetMeterProvider, so any subsequent otel.Meter() call —
|
||||
// including billing.NewRecorder below — sees the real provider
|
||||
// rather than the no-op global. Initialising metrics later (in
|
||||
// core/http/app.go) leaves billing's counters bound to a no-op
|
||||
// meter and never reaches /metrics. We deliberately ignore
|
||||
// DisableMetrics here for ordering purposes; the HTTP middleware
|
||||
// that records api_call histograms is still gated.
|
||||
if !options.DisableMetrics {
|
||||
ms, err := monitoring.NewLocalAIMetricsService()
|
||||
if err != nil {
|
||||
xlog.Error("failed to initialize metrics provider", "error", err)
|
||||
} else {
|
||||
application.metricsService = ms
|
||||
// Bind the billing package's counters to the same meter the
|
||||
// metrics service exports. Without this, billing's counters
|
||||
// resolve via the OTel global and never reach /metrics.
|
||||
billing.SetMeter(ms.Meter)
|
||||
}
|
||||
}
|
||||
|
||||
// Wire the routing-module billing recorder. The recorder runs in
|
||||
// every mode (auth on/off, distributed/single-node) so that token
|
||||
// tracking is not gated on auth — a no-auth single-user box still
|
||||
// gets dashboards and `/api/usage` populated.
|
||||
//
|
||||
// fallbackUser is wired *unconditionally* when stats are enabled.
|
||||
// UsageMiddleware uses it as the attribution source whenever
|
||||
// auth.GetUser(c) is nil — that covers (a) no-auth deployments and
|
||||
// (b) internal callers under auth-on (cron flushers, distributed
|
||||
// worker callbacks) that hit a recordable endpoint without a user
|
||||
// in context. The billing.user_id_present invariant still rejects
|
||||
// empty IDs; LocalUser() returns a stable UUID per data path.
|
||||
if !options.DisableStats {
|
||||
var statsBackend billing.StatsBackend
|
||||
switch {
|
||||
case application.authDB != nil:
|
||||
statsBackend = billing.NewGormBackend(application.authDB, 0, 0)
|
||||
xlog.Info("stats: using auth DB for usage records")
|
||||
default:
|
||||
statsBackend = billing.NewMemoryBackend(0)
|
||||
xlog.Info("stats: using in-memory ring buffer (no-auth single-user mode)")
|
||||
}
|
||||
application.fallbackUser = billing.LocalUser(options.DataPath)
|
||||
application.statsRecorder = billing.NewRecorder(statsBackend)
|
||||
// Drain pending records on SIGTERM. The GORM backend buffers up
|
||||
// to maxPending (5k) records across a 5s flush tick, so without
|
||||
// this the last few seconds of usage disappear on graceful exit.
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
_ = application.statsRecorder.Close()
|
||||
})
|
||||
xlog.Info("stats: fallback user wired", "local_user_id", application.fallbackUser.ID)
|
||||
} else {
|
||||
xlog.Info("stats: disabled by --disable-stats")
|
||||
}
|
||||
|
||||
// Wire the regex PII filter. Default-on: a single-user box gets
|
||||
// the built-in pattern set the first time it starts, with email/
|
||||
// phone/SSN/credit-card on mask and api_key_prefix on block. If
|
||||
// the operator wants different actions, --pii-config points at a
|
||||
// YAML file that overrides per-id; --disable-pii turns it off
|
||||
// entirely.
|
||||
if !options.DisablePII {
|
||||
patterns, err := pii.LoadConfig(options.PIIConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pii config: %w", err)
|
||||
}
|
||||
application.piiRedactor = pii.NewRedactor(patterns)
|
||||
application.piiEvents = pii.NewMemoryEventStore(0)
|
||||
// Apply persisted per-pattern overrides — admins toggling
|
||||
// action/disabled via the UI and clicking "Save to disk" land
|
||||
// here on the next start. Bad ids are warned and ignored so a
|
||||
// stale entry doesn't block startup.
|
||||
for id, ov := range options.PIIPatternOverrides {
|
||||
if ov.Action != nil {
|
||||
if err := application.piiRedactor.SetAction(id, pii.Action(*ov.Action)); err != nil {
|
||||
xlog.Warn("pii: persisted override skipped", "pattern", id, "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if ov.Disabled != nil {
|
||||
if err := application.piiRedactor.SetDisabled(id, *ov.Disabled); err != nil {
|
||||
xlog.Warn("pii: persisted disable skipped", "pattern", id, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
xlog.Info("pii: filter enabled",
|
||||
"patterns", len(patterns),
|
||||
"config_path", options.PIIConfigPath,
|
||||
"persisted_overrides", len(options.PIIPatternOverrides),
|
||||
)
|
||||
} else {
|
||||
xlog.Info("pii: disabled by --disable-pii")
|
||||
}
|
||||
|
||||
// Wire the routing decision log. Always-on when stats are enabled —
|
||||
// the per-router admin page reads this as the live activity feed
|
||||
// and as input to drift checks for subsystem 5.
|
||||
if !options.DisableStats {
|
||||
application.routerDecisions = router.NewMemoryDecisionStore(0)
|
||||
}
|
||||
// Process-wide classifier cache shared across all route middlewares so
|
||||
// the embedding-cache stats endpoint sees a single source of truth.
|
||||
application.routerRegistry = router.NewRegistry()
|
||||
|
||||
// Subsystem 5: admission control. Limiter is always wired so a
|
||||
// model that gains a limits: block via gallery install or YAML
|
||||
// edit takes effect on the next restart without conditional plumbing.
|
||||
application.admissionLimiter = admission.New()
|
||||
|
||||
// Wire JobStore for DB-backed task/job persistence whenever auth DB is available.
|
||||
// This ensures tasks and jobs survive restarts in both single-node and distributed modes.
|
||||
if application.authDB != nil && application.agentJobService != nil {
|
||||
@@ -313,36 +195,12 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
|
||||
}
|
||||
// Hydrate from the store first so the wildcard subscriber finds an
|
||||
// already-populated statuses map for any operations still in flight
|
||||
// on a peer replica.
|
||||
if err := application.galleryService.Hydrate(); err != nil {
|
||||
xlog.Warn("Gallery service hydrate failed", "error", err)
|
||||
}
|
||||
// Bind cache-invalidation handler before SubscribeBroadcasts so the
|
||||
// first inbound event is already routed. Peer replicas install a
|
||||
// model and broadcast on SubjectCacheInvalidateModels; this
|
||||
// callback re-runs LoadModelConfigsFromPath so a subsequent chat
|
||||
// completion that load-balances onto this replica finds the new
|
||||
// config. The originating replica reloads inline in modelHandler
|
||||
// and never enters this path.
|
||||
gs := application.galleryService
|
||||
sys := options.SystemState
|
||||
cfgLoaderOpts := options.ToConfigLoaderOptions()
|
||||
gs.OnModelsChanged = func(_ messaging.CacheInvalidateEvent) {
|
||||
if err := application.ModelConfigLoader().LoadModelConfigsFromPath(sys.Model.ModelsPath, cfgLoaderOpts...); err != nil {
|
||||
xlog.Warn("Failed to reload model configs after peer invalidation", "error", err)
|
||||
}
|
||||
}
|
||||
if err := application.galleryService.SubscribeBroadcasts(); err != nil {
|
||||
xlog.Warn("Gallery service subscribe failed", "error", err)
|
||||
}
|
||||
// Wire distributed model/backend managers so delete propagates to workers
|
||||
application.galleryService.SetModelManager(
|
||||
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
|
||||
)
|
||||
application.galleryService.SetBackendManager(
|
||||
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry, application.galleryService),
|
||||
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -433,20 +291,6 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
loadRuntimeSettingsFromFile(options)
|
||||
}
|
||||
|
||||
// Wire the cloudproxy MITM listener. Opt-in: empty MITMListen
|
||||
// means "no MITM" — operators must explicitly choose to start
|
||||
// it because clients have to install the generated CA cert.
|
||||
// The handler reuses the global redactor + event store so an
|
||||
// admin who's already configured PII filtering for direct API
|
||||
// traffic doesn't need a parallel config for MITM traffic.
|
||||
// Runs after loadRuntimeSettingsFromFile so a listener configured
|
||||
// via /api/settings is brought back up across restarts.
|
||||
if options.MITMListen != "" {
|
||||
if err := startMITMProxy(application, options); err != nil {
|
||||
return nil, fmt.Errorf("mitm: startup: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
application.ModelLoader().SetBackendLoggingEnabled(options.EnableBackendLogging)
|
||||
|
||||
// turn off any process that was started by GRPC if the context is canceled
|
||||
@@ -708,13 +552,6 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
options.TracingMaxItems = *settings.TracingMaxItems
|
||||
}
|
||||
}
|
||||
if settings.TracingMaxBodyBytes != nil {
|
||||
// Allow the on-disk setting to override the CLI/env default. The
|
||||
// startup default is non-zero (see NewApplicationConfig), so a plain
|
||||
// `== 0` guard like the others would never trigger; we instead respect
|
||||
// any value the file specifies. 0 in the file means "uncapped".
|
||||
options.TracingMaxBodyBytes = *settings.TracingMaxBodyBytes
|
||||
}
|
||||
|
||||
// Branding / whitelabeling. There are no env vars for these — the file is
|
||||
// the only source — so apply unconditionally. Without this block a server
|
||||
@@ -736,25 +573,6 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
options.Branding.FaviconFile = *settings.FaviconFile
|
||||
}
|
||||
|
||||
// MITM listener address. The CLI flag WithMITMListen populates
|
||||
// options at startup; if the user configured MITM via /api/settings
|
||||
// after the fact, only the file holds the value. Apply when the
|
||||
// CLI flag did not already set it. (Intercept hosts now live in
|
||||
// model YAML mitm.hosts: rather than runtime_settings.json.)
|
||||
if settings.MITMListen != nil && options.MITMListen == "" {
|
||||
options.MITMListen = *settings.MITMListen
|
||||
}
|
||||
|
||||
// PII pattern overrides — file is the only source; CLI flags don't
|
||||
// reach into this map. Apply unconditionally when present; the
|
||||
// redactor wiring below sees the result on first construction.
|
||||
if settings.PIIPatternOverrides != nil {
|
||||
options.PIIPatternOverrides = make(map[string]config.PIIPatternRuntimeOverride, len(*settings.PIIPatternOverrides))
|
||||
for id, ov := range *settings.PIIPatternOverrides {
|
||||
options.PIIPatternOverrides[id] = ov
|
||||
}
|
||||
}
|
||||
|
||||
// Backend upgrade flags
|
||||
if settings.AutoUpgradeBackends != nil {
|
||||
if !options.AutoUpgradeBackends {
|
||||
|
||||
@@ -78,7 +78,7 @@ func ModelAudioTransform(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func ModelAudioTransform(
|
||||
data["sample_rate"] = res.SampleRate
|
||||
data["samples"] = res.Samples
|
||||
data["reference_provided"] = res.ReferenceProvided
|
||||
if snippet := trace.AudioSnippet(dst, appConfig.TracingMaxBodyBytes); snippet != nil {
|
||||
if snippet := trace.AudioSnippet(dst); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
package backend_test
|
||||
|
||||
// Regression spec for X-LocalAI-Node coverage on audio/image/TTS/rerank/VAD.
|
||||
//
|
||||
// The X-LocalAI-Node middleware (core/http/middleware.ExposeNodeHeader)
|
||||
// works end-to-end only if the per-request holder attached to the HTTP
|
||||
// request context reaches the SmartRouter via ml.Load(opts...). The chain
|
||||
// is:
|
||||
//
|
||||
// handler -> backend.Foo(ctx, ...) -> ModelOptions(cfg, app, WithContext(ctx))
|
||||
// -> ml.Load(opts...) -> grpcModel(..., o.context) -> modelRouter(ctx, ...)
|
||||
// -> SmartRouter -> distributedhdr.Stamp(ctx, nodeID)
|
||||
//
|
||||
// If any backend helper drops `ctx` and lets ModelOptions fall back to the
|
||||
// app context, the router never sees the per-request holder and the
|
||||
// header silently stays empty for that endpoint. These specs pin the
|
||||
// request-context-reaches-router contract for the five backend helpers
|
||||
// that were previously dropping ctx between the handler and Load.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
pbproto "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// newCapturingLoader returns a ModelLoader wired with a stub model router
|
||||
// that captures the context it receives and then short-circuits with a
|
||||
// sentinel error. The router callback is the exact seam where the
|
||||
// SmartRouter would call distributedhdr.Stamp in production, so observing
|
||||
// the holder here is equivalent to observing it at the real router.
|
||||
func newCapturingLoader() (*model.ModelLoader, *atomic.Value, func() context.Context) {
|
||||
loader := model.NewModelLoader(&system.SystemState{})
|
||||
var captured atomic.Value
|
||||
loader.SetModelRouter(func(ctx context.Context, _ string, _, _, _ string, _ *pbproto.ModelOptions, _ bool) (*model.Model, error) {
|
||||
captured.Store(ctx)
|
||||
// Return an error so the backend short-circuits before trying to
|
||||
// dial gRPC. We only care about the context-arrival contract.
|
||||
return nil, errRouterShortCircuit
|
||||
})
|
||||
get := func() context.Context {
|
||||
v, _ := captured.Load().(context.Context)
|
||||
return v
|
||||
}
|
||||
return loader, &captured, get
|
||||
}
|
||||
|
||||
var errRouterShortCircuit = sentinelErr("router short-circuit (test)")
|
||||
|
||||
type sentinelErr string
|
||||
|
||||
func (s sentinelErr) Error() string { return string(s) }
|
||||
|
||||
func newAppCfg() *config.ApplicationConfig {
|
||||
return config.NewApplicationConfig(config.WithSystemState(&system.SystemState{}))
|
||||
}
|
||||
|
||||
func newModelCfg() config.ModelConfig {
|
||||
threads := 1
|
||||
cfg := config.ModelConfig{
|
||||
Name: "test-model",
|
||||
Backend: "stub-backend",
|
||||
Threads: &threads,
|
||||
}
|
||||
cfg.Model = "test.bin"
|
||||
return cfg
|
||||
}
|
||||
|
||||
var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
|
||||
const fakeNodeID = "node-ctx-propagation-7"
|
||||
|
||||
var (
|
||||
appCfg *config.ApplicationConfig
|
||||
modelCfg config.ModelConfig
|
||||
loader *model.ModelLoader
|
||||
routerCtxOf func() context.Context
|
||||
holder *atomic.Value
|
||||
reqCtx context.Context
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
appCfg = newAppCfg()
|
||||
modelCfg = newModelCfg()
|
||||
loader, _, routerCtxOf = newCapturingLoader()
|
||||
holder = distributedhdr.NewHolder()
|
||||
reqCtx = distributedhdr.WithHolder(context.Background(), holder)
|
||||
})
|
||||
|
||||
// stampViaRouterCtx asserts the captured router context carries the
|
||||
// SAME holder that was attached to the request. We verify by stamping
|
||||
// through the router-side ctx and observing the value via the
|
||||
// request-side holder; if the holders were different objects the load
|
||||
// would return "".
|
||||
stampViaRouterCtx := func() {
|
||||
routerCtx := routerCtxOf()
|
||||
Expect(routerCtx).ToNot(BeNil(), "router callback must have been invoked")
|
||||
distributedhdr.Stamp(routerCtx, fakeNodeID)
|
||||
Expect(distributedhdr.Load(holder)).To(Equal(fakeNodeID),
|
||||
"stamp via router-side ctx must be observable via the request-side holder")
|
||||
}
|
||||
|
||||
It("Rerank forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.Rerank(reqCtx, &pbproto.RerankRequest{Query: "q"}, loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("VAD forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.VAD(&schema.VADRequest{}, reqCtx, loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTTS forwards the request context to the SmartRouter", func() {
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTTSStream forwards the request context to the SmartRouter", func() {
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTranscriptionWithOptions forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.ModelTranscriptionWithOptions(reqCtx, backend.TranscriptionRequest{Audio: "x.wav"}, loader, modelCfg, appCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTranscriptionStream forwards the request context to the SmartRouter", func() {
|
||||
err := backend.ModelTranscriptionStream(reqCtx, backend.TranscriptionRequest{Audio: "x.wav"}, loader, modelCfg, appCfg, func(backend.TranscriptionStreamChunk) {})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ImageGeneration forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.ImageGeneration(reqCtx, 64, 64, 1, 0, "p", "", "", "/tmp/out.png", loader, modelCfg, appCfg, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("does NOT leak the holder when the app context is used instead", func() {
|
||||
// Sanity: the bug being fixed manifests as the router getting
|
||||
// appCfg.Context (no holder) instead of reqCtx (holder). A direct
|
||||
// call with context.Background() must not see the holder via the
|
||||
// app context surface.
|
||||
appCtxOnly := appCfg.Context
|
||||
Expect(distributedhdr.Holder(appCtxOnly)).To(BeNil(),
|
||||
"the app context must not be the carrier of per-request holders")
|
||||
})
|
||||
})
|
||||
@@ -35,7 +35,7 @@ func Detection(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -12,38 +11,9 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// Embedder produces a fixed-dimension vector from a prompt. The
|
||||
// router's L2 embedding cache uses it to look up semantically-similar
|
||||
// past decisions.
|
||||
type Embedder interface {
|
||||
Embed(ctx context.Context, text string) ([]float32, error)
|
||||
}
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
|
||||
|
||||
// NewEmbedder binds (loader, modelConfig, appConfig) into an Embedder.
|
||||
func NewEmbedder(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Embedder {
|
||||
return &modelEmbedder{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
|
||||
}
|
||||
|
||||
type modelEmbedder struct {
|
||||
loader *model.ModelLoader
|
||||
modelConfig config.ModelConfig
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func (e *modelEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
|
||||
fn, err := ModelEmbedding(ctx, text, nil, e.loader, e.modelConfig, e.appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fn()
|
||||
}
|
||||
|
||||
func ModelEmbedding(ctx context.Context, s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
|
||||
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
@@ -97,7 +67,7 @@ func ModelEmbedding(ctx context.Context, s string, tokens []int, loader *model.M
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
|
||||
traceData := map[string]any{
|
||||
"input_text": trace.TruncateString(s, 1000),
|
||||
|
||||
@@ -32,7 +32,7 @@ func FaceAnalyze(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func FaceVerify(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
@@ -11,12 +10,9 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ImageGeneration(ctx context.Context, height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
func ImageGeneration(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
inferenceModel, err := loader.Load(
|
||||
opts...,
|
||||
)
|
||||
@@ -27,7 +23,7 @@ func ImageGeneration(ctx context.Context, height, width, step, seed int, positiv
|
||||
|
||||
fn := func() error {
|
||||
_, err := inferenceModel.GenerateImage(
|
||||
ctx,
|
||||
appConfig.Context,
|
||||
&proto.GenerateImageRequest{
|
||||
Height: int32(height),
|
||||
Width: int32(width),
|
||||
@@ -45,7 +41,7 @@ func ImageGeneration(ctx context.Context, height, width, step, seed int, positiv
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
|
||||
traceData := map[string]any{
|
||||
"positive_prompt": positive_prompt,
|
||||
|
||||
@@ -94,7 +94,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
}
|
||||
|
||||
opts := ModelOptions(*c, o, model.WithContext(ctx))
|
||||
opts := ModelOptions(*c, o)
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(o, c.Name, c.Backend, err, map[string]any{"model_file": modelFile})
|
||||
@@ -305,7 +305,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
|
||||
if o.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(o.TracingMaxItems, o.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(o.TracingMaxItems)
|
||||
|
||||
traceData := map[string]any{
|
||||
"chat_template": c.TemplateConfig.Chat,
|
||||
@@ -316,13 +316,9 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
"audios_count": len(audios),
|
||||
}
|
||||
|
||||
// Cap the captured fields up front: agent-pool LLM calls embed the
|
||||
// full augmented chat history in messages and the full reply in
|
||||
// response, so without a per-field cap a single trace can dwarf the
|
||||
// rest of the buffer. The cap matches the API-trace body cap.
|
||||
if len(messages) > 0 {
|
||||
if msgJSON, err := json.Marshal(messages); err == nil {
|
||||
traceData["messages"] = trace.TruncateToBytes(string(msgJSON), o.TracingMaxBodyBytes)
|
||||
traceData["messages"] = string(msgJSON)
|
||||
}
|
||||
}
|
||||
if reasoningJSON, err := json.Marshal(c.ReasoningConfig); err == nil {
|
||||
@@ -341,7 +337,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
resp, err := originalFn()
|
||||
duration := time.Since(startTime)
|
||||
|
||||
traceData["response"] = trace.TruncateToBytes(resp.Response, o.TracingMaxBodyBytes)
|
||||
traceData["response"] = resp.Response
|
||||
traceData["token_usage"] = map[string]any{
|
||||
"prompt": resp.Usage.Prompt,
|
||||
"completion": resp.Usage.Completion,
|
||||
@@ -363,10 +359,10 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
toolCallCount += len(d.ToolCalls)
|
||||
}
|
||||
if len(contentParts) > 0 {
|
||||
chatDeltasInfo["content"] = trace.TruncateToBytes(strings.Join(contentParts, ""), o.TracingMaxBodyBytes)
|
||||
chatDeltasInfo["content"] = strings.Join(contentParts, "")
|
||||
}
|
||||
if len(reasoningParts) > 0 {
|
||||
chatDeltasInfo["reasoning_content"] = trace.TruncateToBytes(strings.Join(reasoningParts, ""), o.TracingMaxBodyBytes)
|
||||
chatDeltasInfo["reasoning_content"] = strings.Join(reasoningParts, "")
|
||||
}
|
||||
if toolCallCount > 0 {
|
||||
chatDeltasInfo["tool_call_count"] = toolCallCount
|
||||
|
||||
@@ -21,7 +21,7 @@ func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, back
|
||||
if !appConfig.EnableTracing {
|
||||
return
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: time.Now(),
|
||||
Type: trace.BackendTraceModelLoad,
|
||||
@@ -242,18 +242,6 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
|
||||
Tokenizer: c.Tokenizer,
|
||||
}
|
||||
|
||||
if c.Backend == "cloud-proxy" {
|
||||
opts.Proxy = &pb.ProxyOptions{
|
||||
UpstreamUrl: c.Proxy.UpstreamURL,
|
||||
Mode: c.Proxy.Mode,
|
||||
Provider: c.Proxy.Provider,
|
||||
ApiKeyEnv: c.Proxy.APIKeyEnv,
|
||||
ApiKeyFile: c.Proxy.APIKeyFile,
|
||||
UpstreamModel: c.Proxy.UpstreamModel,
|
||||
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
|
||||
}
|
||||
}
|
||||
|
||||
if c.MMProj != "" {
|
||||
opts.MMProj = filepath.Join(modelPath, c.MMProj)
|
||||
}
|
||||
@@ -289,7 +277,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
|
||||
MinP: float32(*c.MinP),
|
||||
Tokens: int32(*c.Maxtokens),
|
||||
Threads: int32(*c.Threads),
|
||||
PromptCacheAll: *c.PromptCacheAll,
|
||||
PromptCacheAll: c.PromptCacheAll,
|
||||
PromptCacheRO: c.PromptCacheRO,
|
||||
PromptCachePath: promptCachePath,
|
||||
F16KV: *c.F16,
|
||||
|
||||
@@ -11,56 +11,8 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// RerankResult is the per-document score returned to consumers,
|
||||
// narrowed from proto.RerankResult so callers don't need to depend on
|
||||
// the proto package.
|
||||
type RerankResult struct {
|
||||
Index int
|
||||
RelevanceScore float32
|
||||
}
|
||||
|
||||
// Reranker scores a list of candidate documents against a query.
|
||||
// Returns one RerankResult per input document (no top-N truncation -
|
||||
// callers that need it can sort and slice).
|
||||
type Reranker interface {
|
||||
Rerank(ctx context.Context, query string, documents []string) ([]RerankResult, error)
|
||||
}
|
||||
|
||||
// NewReranker binds (loader, modelConfig, appConfig) into a Reranker.
|
||||
func NewReranker(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Reranker {
|
||||
return &modelReranker{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
|
||||
}
|
||||
|
||||
type modelReranker struct {
|
||||
loader *model.ModelLoader
|
||||
modelConfig config.ModelConfig
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func (r *modelReranker) Rerank(ctx context.Context, query string, documents []string) ([]RerankResult, error) {
|
||||
req := &proto.RerankRequest{
|
||||
Query: query,
|
||||
Documents: documents,
|
||||
// TopN=0: backend returns scores for every document. Truncating
|
||||
// here would silently zero out labels the reranker considered
|
||||
// unlikely, which the router classifier needs.
|
||||
}
|
||||
res, err := Rerank(ctx, req, r.loader, r.appConfig, r.modelConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]RerankResult, 0, len(res.GetResults()))
|
||||
for _, dr := range res.GetResults() {
|
||||
out = append(out, RerankResult{Index: int(dr.GetIndex()), RelevanceScore: dr.GetRelevanceScore()})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
rerankModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
@@ -73,7 +25,7 @@ func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.Mod
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// ScoreOptions controls a single Score request.
|
||||
type ScoreOptions struct {
|
||||
// IncludeTokenLogprobs returns per-token log-probability detail for
|
||||
// each candidate. Off by default — the joint LogProb is enough for
|
||||
// ranking; callers that need calibration / entropy over the token
|
||||
// stream opt in.
|
||||
IncludeTokenLogprobs bool
|
||||
// LengthNormalize divides the joint log-prob by the candidate's
|
||||
// token count. Useful when comparing candidates of different
|
||||
// lengths — without it, longer candidates score lower by default.
|
||||
LengthNormalize bool
|
||||
}
|
||||
|
||||
// CandidateScore is the per-candidate result. Mirrors pb.CandidateScore
|
||||
// but avoids leaking the proto type to consumers.
|
||||
type CandidateScore struct {
|
||||
LogProb float64
|
||||
LengthNormalizedLogProb float64
|
||||
NumTokens int
|
||||
Tokens []TokenLogProb
|
||||
}
|
||||
|
||||
type TokenLogProb struct {
|
||||
Token string
|
||||
LogProb float64
|
||||
}
|
||||
|
||||
// Scorer evaluates a model's joint log-probability of each candidate
|
||||
// continuation given a shared prompt. Implemented by NewScorer over a
|
||||
// model-loaded backend; the router's score classifier consumes this
|
||||
// for multi-label policy selection.
|
||||
type Scorer interface {
|
||||
Score(ctx context.Context, prompt string, candidates []string) ([]CandidateScore, error)
|
||||
}
|
||||
|
||||
// NewScorer binds (loader, modelConfig, appConfig) into a Scorer. The
|
||||
// underlying backend is resolved lazily on the first Score call.
|
||||
// Returns nil only as a contract violation — callers that need to
|
||||
// detect "model not loadable" should look up the config first.
|
||||
func NewScorer(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Scorer {
|
||||
return &modelScorer{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
|
||||
}
|
||||
|
||||
type modelScorer struct {
|
||||
loader *model.ModelLoader
|
||||
modelConfig config.ModelConfig
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func (m *modelScorer) Score(ctx context.Context, prompt string, candidates []string) ([]CandidateScore, error) {
|
||||
fn, err := ModelScore(prompt, candidates, ScoreOptions{LengthNormalize: true}, m.loader, m.modelConfig, m.appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
// ModelScore loads the backend for modelConfig and returns a closure
|
||||
// that scores `candidates` against `prompt`. The closure is bound to
|
||||
// the loaded model so callers can keep it around for repeat scoring
|
||||
// within the same request without re-resolving the backend.
|
||||
func ModelScore(prompt string, candidates []string, opts ScoreOptions, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func(ctx context.Context) ([]CandidateScore, error), error) {
|
||||
modelOpts := ModelOptions(modelConfig, appConfig)
|
||||
inferenceModel, err := loader.Load(modelOpts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
b, ok := inferenceModel.(grpc.Backend)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("scoring not supported by backend %q", modelConfig.Backend)
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("Score: candidates must be non-empty")
|
||||
}
|
||||
return func(ctx context.Context) ([]CandidateScore, error) {
|
||||
// Surface score calls in the Traces UI alongside the LLM calls
|
||||
// they typically gate (router classifier, eval scoring). Without
|
||||
// this, a router-classified request shows only the downstream LLM
|
||||
// trace with no record of the classification that picked it.
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
resp, err := b.Score(ctx, &pb.ScoreRequest{
|
||||
Prompt: prompt,
|
||||
Candidates: candidates,
|
||||
IncludeTokenLogprobs: opts.IncludeTokenLogprobs,
|
||||
LengthNormalize: opts.LengthNormalize,
|
||||
})
|
||||
results := scoreResponseToCandidates(resp, opts.IncludeTokenLogprobs)
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceScore,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(prompt, 200),
|
||||
Error: errStr,
|
||||
Data: map[string]any{
|
||||
// Copy candidates so the trace buffer doesn't pin a
|
||||
// caller-owned slice for the lifetime of the ring.
|
||||
"candidates": append([]string(nil), candidates...),
|
||||
"results": results,
|
||||
},
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// scoreResponseToCandidates converts the wire-format pb response into
|
||||
// the value type consumed by callers. Extracted to keep ModelScore's
|
||||
// closure trivial and so the conversion can be unit-tested without a
|
||||
// real backend.
|
||||
func scoreResponseToCandidates(resp *pb.ScoreResponse, includeTokens bool) []CandidateScore {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]CandidateScore, len(resp.Candidates))
|
||||
for i, c := range resp.Candidates {
|
||||
cs := CandidateScore{
|
||||
LogProb: c.LogProb,
|
||||
LengthNormalizedLogProb: c.LengthNormalizedLogProb,
|
||||
NumTokens: int(c.NumTokens),
|
||||
}
|
||||
if includeTokens && len(c.Tokens) > 0 {
|
||||
cs.Tokens = make([]TokenLogProb, len(c.Tokens))
|
||||
for j, t := range c.Tokens {
|
||||
cs.Tokens[j] = TokenLogProb{Token: t.Token, LogProb: t.LogProb}
|
||||
}
|
||||
}
|
||||
out[i] = cs
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("scoreResponseToCandidates", func() {
|
||||
It("returns nil for a nil response", func() {
|
||||
Expect(scoreResponseToCandidates(nil, false)).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns an empty slice when the response has no candidates", func() {
|
||||
Expect(scoreResponseToCandidates(&pb.ScoreResponse{}, false)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("copies LogProb / LengthNormalizedLogProb / NumTokens for every candidate", func() {
|
||||
resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{
|
||||
{LogProb: -2.0, LengthNormalizedLogProb: -1.0, NumTokens: 2},
|
||||
{LogProb: -7.5, LengthNormalizedLogProb: -1.5, NumTokens: 5},
|
||||
}}
|
||||
got := scoreResponseToCandidates(resp, false)
|
||||
Expect(got).To(HaveLen(2))
|
||||
Expect(got[0].LogProb).To(Equal(-2.0))
|
||||
Expect(got[0].LengthNormalizedLogProb).To(Equal(-1.0))
|
||||
Expect(got[0].NumTokens).To(Equal(2))
|
||||
Expect(got[1].LogProb).To(Equal(-7.5))
|
||||
Expect(got[1].NumTokens).To(Equal(5))
|
||||
})
|
||||
|
||||
It("omits per-token detail when includeTokens=false even if the wire response carries it", func() {
|
||||
// Defensive: if the backend over-reports we still respect the
|
||||
// caller's opt-in so consumers don't pay marshaling for data
|
||||
// they didn't ask for.
|
||||
resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{{
|
||||
LogProb: -1.0,
|
||||
Tokens: []*pb.TokenLogProb{{Token: "hi", LogProb: -1.0}},
|
||||
}}}
|
||||
got := scoreResponseToCandidates(resp, false)
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Tokens).To(BeNil())
|
||||
})
|
||||
|
||||
It("populates per-token detail when includeTokens=true", func() {
|
||||
resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{{
|
||||
LogProb: -3.0,
|
||||
NumTokens: 2,
|
||||
Tokens: []*pb.TokenLogProb{
|
||||
{Token: "Hello", LogProb: -1.0},
|
||||
{Token: " world", LogProb: -2.0},
|
||||
},
|
||||
}}}
|
||||
got := scoreResponseToCandidates(resp, true)
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Tokens).To(HaveLen(2))
|
||||
Expect(got[0].Tokens[0].Token).To(Equal("Hello"))
|
||||
Expect(got[0].Tokens[0].LogProb).To(Equal(-1.0))
|
||||
Expect(got[0].Tokens[1].Token).To(Equal(" world"))
|
||||
Expect(got[0].Tokens[1].LogProb).To(Equal(-2.0))
|
||||
})
|
||||
})
|
||||
@@ -98,7 +98,7 @@ func SoundGeneration(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,74 +1,12 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/store"
|
||||
)
|
||||
|
||||
// VectorStore is the narrowed KNN store used by the router's embedding
|
||||
// cache. Search returns the top-1 match (cosine similarity in [-1, 1])
|
||||
// and the serialised payload, or ok=false on a clean miss.
|
||||
type VectorStore interface {
|
||||
Search(ctx context.Context, vec []float32) (similarity float64, payload []byte, ok bool, err error)
|
||||
Insert(ctx context.Context, vec []float32, payload []byte) error
|
||||
}
|
||||
|
||||
// NewVectorStore returns a VectorStore backed by the local-store
|
||||
// gRPC backend, namespaced by storeName so two routers don't collide.
|
||||
func NewVectorStore(loader *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) VectorStore {
|
||||
if storeName == "" {
|
||||
return nil
|
||||
}
|
||||
return &localVectorStore{loader: loader, appConfig: appConfig, storeName: storeName}
|
||||
}
|
||||
|
||||
type localVectorStore struct {
|
||||
loader *model.ModelLoader
|
||||
appConfig *config.ApplicationConfig
|
||||
storeName string
|
||||
}
|
||||
|
||||
func (s *localVectorStore) backend(_ context.Context) (grpc.Backend, error) {
|
||||
return StoreBackend(s.loader, s.appConfig, s.storeName, "")
|
||||
}
|
||||
|
||||
func (s *localVectorStore) Search(ctx context.Context, vec []float32) (float64, []byte, bool, error) {
|
||||
be, err := s.backend(ctx)
|
||||
if err != nil {
|
||||
return 0, nil, false, fmt.Errorf("vector store load: %w", err)
|
||||
}
|
||||
_, values, similarities, err := store.Find(ctx, be, vec, 1)
|
||||
if err != nil {
|
||||
// local-store's Find returns "existing length is -1" before
|
||||
// any keys are inserted. Surface that as a clean miss so the
|
||||
// cache layer treats it as an empty store and proceeds to
|
||||
// Insert rather than skipping.
|
||||
if strings.Contains(err.Error(), "existing length is -1") {
|
||||
return 0, nil, false, nil
|
||||
}
|
||||
return 0, nil, false, fmt.Errorf("vector store find: %w", err)
|
||||
}
|
||||
if len(values) == 0 || len(similarities) == 0 {
|
||||
return 0, nil, false, nil
|
||||
}
|
||||
return float64(similarities[0]), values[0], true, nil
|
||||
}
|
||||
|
||||
func (s *localVectorStore) Insert(ctx context.Context, vec []float32, payload []byte) error {
|
||||
be, err := s.backend(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("vector store load: %w", err)
|
||||
}
|
||||
return store.SetSingle(ctx, be, vec, payload)
|
||||
}
|
||||
|
||||
func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string, backend string) (grpc.Backend, error) {
|
||||
if backend == "" {
|
||||
backend = model.LocalStoreBackend
|
||||
|
||||
@@ -27,7 +27,7 @@ func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.Model
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -41,14 +41,11 @@ func (r *TranscriptionRequest) toProto(threads uint32) *proto.TranscriptRequest
|
||||
}
|
||||
}
|
||||
|
||||
func loadTranscriptionModel(ctx context.Context, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (grpcPkg.Backend, error) {
|
||||
func loadTranscriptionModel(ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (grpcPkg.Backend, error) {
|
||||
if modelConfig.Backend == "" {
|
||||
modelConfig.Backend = model.WhisperBackend
|
||||
}
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
transcriptionModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
@@ -71,7 +68,7 @@ func ModelTranscription(ctx context.Context, audio, language string, translate,
|
||||
}
|
||||
|
||||
func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||
transcriptionModel, err := loadTranscriptionModel(ctx, ml, modelConfig, appConfig)
|
||||
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -79,10 +76,10 @@ func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest
|
||||
var startTime time.Time
|
||||
var audioSnippet map[string]any
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
// Capture audio before the backend call — the backend may delete the file.
|
||||
audioSnippet = trace.AudioSnippet(req.Audio, appConfig.TracingMaxBodyBytes)
|
||||
audioSnippet = trace.AudioSnippet(req.Audio)
|
||||
}
|
||||
|
||||
r, err := transcriptionModel.AudioTranscription(ctx, req.toProto(uint32(*modelConfig.Threads)))
|
||||
@@ -153,7 +150,7 @@ type TranscriptionStreamChunk struct {
|
||||
// support real streaming should still emit one terminal event with Final set,
|
||||
// which the HTTP layer turns into a single delta + done SSE pair.
|
||||
func ModelTranscriptionStream(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
|
||||
transcriptionModel, err := loadTranscriptionModel(ctx, ml, modelConfig, appConfig)
|
||||
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -29,10 +29,7 @@ func ModelTTS(
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (string, *proto.Result, error) {
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
ttsModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
@@ -70,7 +67,7 @@ func ModelTTS(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
@@ -96,7 +93,7 @@ func ModelTTS(
|
||||
"language": language,
|
||||
}
|
||||
if err == nil && res.Success {
|
||||
if snippet := trace.AudioSnippet(filePath, appConfig.TracingMaxBodyBytes); snippet != nil {
|
||||
if snippet := trace.AudioSnippet(filePath); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
}
|
||||
@@ -134,7 +131,7 @@ func ModelTTSStream(
|
||||
modelConfig config.ModelConfig,
|
||||
audioCallback func([]byte) error,
|
||||
) error {
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
ttsModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
@@ -164,7 +161,7 @@ func ModelTTSStream(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
@@ -263,7 +260,7 @@ func ModelTTSStream(
|
||||
"streaming": true,
|
||||
}
|
||||
if resultErr == nil && len(snippetPCM) > 0 {
|
||||
if snippet := trace.AudioSnippetFromPCM(snippetPCM, int(sampleRate), totalPCMBytes, appConfig.TracingMaxBodyBytes); snippet != nil {
|
||||
if snippet := trace.AudioSnippetFromPCM(snippetPCM, int(sampleRate), totalPCMBytes); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,10 +14,7 @@ func VAD(request *schema.VADRequest,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig) (*schema.VADResponse, error) {
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
vadModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
|
||||
@@ -42,7 +42,7 @@ func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, en
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
|
||||
traceData := map[string]any{
|
||||
"prompt": prompt,
|
||||
|
||||
@@ -31,7 +31,7 @@ func VoiceAnalyze(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ func VoiceEmbed(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func VoiceVerify(
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
|
||||
@@ -39,19 +39,19 @@ type RunCMD struct {
|
||||
LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`
|
||||
LocalaiConfigDirPollInterval time.Duration `env:"LOCALAI_CONFIG_DIR_POLL_INTERVAL" help:"Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to an interval to poll the LocalAI Config Dir (example: 1m)" group:"storage"`
|
||||
// The alias on this option is there to preserve functionality with the old `--config-file` parameter
|
||||
ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
|
||||
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
|
||||
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
|
||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||
ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
|
||||
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
|
||||
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
|
||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||
AutoUpgradeBackends bool `env:"LOCALAI_AUTO_UPGRADE_BACKENDS,AUTO_UPGRADE_BACKENDS" help:"Automatically upgrade backends when new versions are detected" group:"backends" default:"false"`
|
||||
PreferDevelopmentBackends bool `env:"LOCALAI_PREFER_DEV_BACKENDS,PREFER_DEV_BACKENDS" help:"Prefer development backend versions (shows development backends by default in UI)" group:"backends" default:"false"`
|
||||
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
|
||||
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
||||
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
||||
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
||||
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
||||
|
||||
F16 bool `name:"f16" env:"LOCALAI_F16,F16" help:"Enable GPU acceleration" group:"performance"`
|
||||
Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"`
|
||||
@@ -100,7 +100,6 @@ type RunCMD struct {
|
||||
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
|
||||
EnableTracing bool `env:"LOCALAI_ENABLE_TRACING,ENABLE_TRACING" help:"Enable API tracing" group:"api"`
|
||||
TracingMaxItems int `env:"LOCALAI_TRACING_MAX_ITEMS" default:"1024" help:"Maximum number of traces to keep" group:"api"`
|
||||
TracingMaxBodyBytes int `env:"LOCALAI_TRACING_MAX_BODY_BYTES" default:"65536" help:"Maximum bytes captured per request/response body in the trace buffer (0 = uncapped). Caps memory growth from chatty endpoints like /embeddings." group:"api"`
|
||||
AgentJobRetentionDays int `env:"LOCALAI_AGENT_JOB_RETENTION_DAYS,AGENT_JOB_RETENTION_DAYS" default:"30" help:"Number of days to keep agent job history (default: 30)" group:"api"`
|
||||
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
|
||||
|
||||
@@ -145,25 +144,18 @@ type RunCMD struct {
|
||||
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
||||
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
|
||||
Version bool
|
||||
|
||||
// Cloud-proxy MITM listener (off by default).
|
||||
MITMListen string `env:"LOCALAI_MITM_LISTEN" help:"Address (host:port) for the cloudproxy MITM listener. Empty = disabled. Clients set HTTPS_PROXY=http://<this>:<port>. Intercept hosts are declared per-model via the model YAML mitm.hosts: block; create one from the Add Model UI." group:"middleware"`
|
||||
MITMCADir string `env:"LOCALAI_MITM_CA_DIR" type:"path" help:"Directory holding the MITM proxy CA cert + key. Defaults to <data-path>/mitm-ca." group:"middleware"`
|
||||
}
|
||||
|
||||
func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
@@ -222,8 +214,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
config.WithLoadToMemory(r.LoadToMemory),
|
||||
config.WithMachineTag(r.MachineTag),
|
||||
config.WithAPIAddress(r.Address),
|
||||
config.WithMITMListen(r.MITMListen),
|
||||
config.WithMITMCADir(r.MITMCADir),
|
||||
config.WithAgentJobRetentionDays(r.AgentJobRetentionDays),
|
||||
config.WithLlamaCPPTunnelCallback(func(tunnels []string) {
|
||||
tunnelEnvVar := strings.Join(tunnels, ",")
|
||||
@@ -264,29 +254,12 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.StorageSecretKey != "" {
|
||||
opts = append(opts, config.WithStorageSecretKey(r.StorageSecretKey))
|
||||
}
|
||||
if r.BackendInstallTimeout != "" {
|
||||
d, err := time.ParseDuration(r.BackendInstallTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT %q: %w", r.BackendInstallTimeout, err)
|
||||
}
|
||||
opts = append(opts, config.WithBackendInstallTimeout(d))
|
||||
}
|
||||
if r.BackendUpgradeTimeout != "" {
|
||||
d, err := time.ParseDuration(r.BackendUpgradeTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT %q: %w", r.BackendUpgradeTimeout, err)
|
||||
}
|
||||
opts = append(opts, config.WithBackendUpgradeTimeout(d))
|
||||
}
|
||||
if r.RegistrationToken != "" {
|
||||
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
|
||||
}
|
||||
if r.AutoApproveNodes {
|
||||
opts = append(opts, config.EnableAutoApproveNodes)
|
||||
}
|
||||
if r.ExposeNodeHeader {
|
||||
opts = append(opts, config.WithExposeNodeHeader(true))
|
||||
}
|
||||
|
||||
if r.DisableMetricsEndpoint {
|
||||
opts = append(opts, config.DisableMetricsEndpoint)
|
||||
@@ -300,7 +273,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.EnableTracing)
|
||||
}
|
||||
opts = append(opts, config.WithTracingMaxItems(r.TracingMaxItems))
|
||||
opts = append(opts, config.WithTracingMaxBodyBytes(r.TracingMaxBodyBytes))
|
||||
|
||||
token := ""
|
||||
if r.Peer2Peer || r.Peer2PeerToken != "" {
|
||||
|
||||
@@ -21,7 +21,6 @@ type ApplicationConfig struct {
|
||||
Debug bool
|
||||
EnableTracing bool
|
||||
TracingMaxItems int
|
||||
TracingMaxBodyBytes int // Per-body cap for captured request/response bodies; 0 disables the cap
|
||||
EnableBackendLogging bool
|
||||
GeneratedContentDir string
|
||||
|
||||
@@ -40,54 +39,6 @@ type ApplicationConfig struct {
|
||||
P2PNetworkID string
|
||||
Federated bool
|
||||
|
||||
// DisableStats turns off per-request token tracking. By default the
|
||||
// routing module's billing recorder runs in every mode (including
|
||||
// no-auth single-user) so dashboards and `/api/usage` are immediately
|
||||
// useful; set this to opt out of that, e.g., for ephemeral CI runs
|
||||
// or privacy-strict deployments where no token-count history should
|
||||
// touch disk or memory.
|
||||
DisableStats bool
|
||||
|
||||
// PIIConfigPath points to an optional YAML file describing the PII
|
||||
// pattern set. When empty, the routing/pii module's DefaultPatterns()
|
||||
// (email, phone, SSN, credit card, IPv4, API key prefixes) are
|
||||
// loaded with their default actions. Each entry overrides the
|
||||
// matching default by ID:
|
||||
//
|
||||
// patterns:
|
||||
// - id: email
|
||||
// action: route_local # downgrade default mask -> route_local
|
||||
// - id: ssn
|
||||
// action: block # upgrade default mask -> block
|
||||
//
|
||||
// Unknown ids are rejected with a clear error at startup.
|
||||
PIIConfigPath string
|
||||
|
||||
// DisablePII turns the regex PII filter off entirely. Default
|
||||
// (false) enables it on the OpenAI chat completions route.
|
||||
DisablePII bool
|
||||
|
||||
// MITMListen is the address (host:port) the cloudproxy MITM
|
||||
// listener binds on. Empty disables the MITM proxy entirely.
|
||||
// Use case: redacting PII from Claude Code / Codex CLI traffic
|
||||
// without LocalAI holding the upstream API key. Clients set
|
||||
// HTTPS_PROXY=http://localai:port and trust the CA cert
|
||||
// LocalAI exposes at /api/middleware/proxy-ca.crt.
|
||||
MITMListen string
|
||||
|
||||
// MITMCADir holds the persisted MITM proxy CA cert and private
|
||||
// key. The CA is generated on first start; subsequent starts
|
||||
// reload it so clients keep trusting the same root. The key
|
||||
// file is mode 0600.
|
||||
MITMCADir string
|
||||
|
||||
|
||||
// PIIPatternOverrides applies persisted per-id deltas (action,
|
||||
// disabled) to the live redactor at startup. Loaded from
|
||||
// runtime_settings.json and applied right after pii.NewRedactor.
|
||||
// nil/empty leaves the YAML defaults in place.
|
||||
PIIPatternOverrides map[string]PIIPatternRuntimeOverride
|
||||
|
||||
DisableWebUI bool
|
||||
OllamaAPIRootEndpoint bool
|
||||
EnforcePredownloadScans bool
|
||||
@@ -160,18 +111,6 @@ type ApplicationConfig struct {
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed DistributedConfig
|
||||
|
||||
// ExposeNodeHeader, when true, activates middleware.ExposeNodeHeader on
|
||||
// the inference routes (OpenAI chat/completions/embeddings, Anthropic
|
||||
// /v1/messages, Ollama /api/chat,/api/generate,/api/embed). The
|
||||
// middleware wraps the response writer and attaches an "X-LocalAI-Node"
|
||||
// response header carrying the ID of the distributed-mode worker node
|
||||
// that served the request. Off by default because the node ID is
|
||||
// internal topology that can aid attacker reconnaissance if surfaced on
|
||||
// a public endpoint; operators opt in explicitly via
|
||||
// --expose-node-header / LOCALAI_EXPOSE_NODE_HEADER for debugging,
|
||||
// observability and load-balancer attribution.
|
||||
ExposeNodeHeader bool
|
||||
|
||||
// LocalAI Assistant chat modality. Hard-disable the in-process admin MCP
|
||||
// server with this flag; runtime-toggleable via /api/settings.
|
||||
DisableLocalAIAssistant bool
|
||||
@@ -248,7 +187,6 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||
LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
|
||||
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
|
||||
TracingMaxItems: 1024,
|
||||
TracingMaxBodyBytes: 64 * 1024, // 64 KiB - caps each request/response body in the trace buffer
|
||||
AgentPool: AgentPoolConfig{
|
||||
Enabled: true,
|
||||
Timeout: "5m",
|
||||
@@ -640,12 +578,6 @@ func WithTracingMaxItems(items int) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithTracingMaxBodyBytes(bytes int) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.TracingMaxBodyBytes = bytes
|
||||
}
|
||||
}
|
||||
|
||||
func WithGeneratedContentDir(generatedContentDir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.GeneratedContentDir = generatedContentDir
|
||||
@@ -664,45 +596,6 @@ func WithDataPath(dataPath string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithDisableStats turns off the billing recorder. CLI: --disable-stats.
|
||||
func WithDisableStats(disable bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.DisableStats = disable
|
||||
}
|
||||
}
|
||||
|
||||
// WithPIIConfigPath points the routing PII filter at a YAML config
|
||||
// file. CLI: --pii-config.
|
||||
func WithPIIConfigPath(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.PIIConfigPath = path
|
||||
}
|
||||
}
|
||||
|
||||
// WithDisablePII turns the regex PII filter off. CLI: --disable-pii.
|
||||
func WithDisablePII(disable bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.DisablePII = disable
|
||||
}
|
||||
}
|
||||
|
||||
// WithMITMListen sets the address the cloudproxy MITM listener
|
||||
// binds on. Empty = disabled. CLI: --mitm-listen.
|
||||
func WithMITMListen(addr string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.MITMListen = addr
|
||||
}
|
||||
}
|
||||
|
||||
// WithMITMCADir sets the directory used to persist the MITM proxy
|
||||
// CA cert + key. CLI: --mitm-ca-dir.
|
||||
func WithMITMCADir(dir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.MITMCADir = dir
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func WithDynamicConfigDir(dynamicConfigsDir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.DynamicConfigsDir = dynamicConfigsDir
|
||||
@@ -992,15 +885,6 @@ func WithDisableLocalAIAssistant(disabled bool) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithExposeNodeHeader enables the X-LocalAI-Node response header on
|
||||
// inference endpoints. Default off; the node ID reveals internal cluster
|
||||
// topology and is opt-in for that reason.
|
||||
func WithExposeNodeHeader(enabled bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.ExposeNodeHeader = enabled
|
||||
}
|
||||
}
|
||||
|
||||
// ToConfigLoaderOptions returns a slice of ConfigLoader Option.
|
||||
// Some options defined at the application level are going to be passed as defaults for
|
||||
// all the configuration for the models.
|
||||
@@ -1036,7 +920,6 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
f16 := o.F16
|
||||
debug := o.Debug
|
||||
tracingMaxItems := o.TracingMaxItems
|
||||
tracingMaxBodyBytes := o.TracingMaxBodyBytes
|
||||
enableTracing := o.EnableTracing
|
||||
enableBackendLogging := o.EnableBackendLogging
|
||||
cors := o.CORS
|
||||
@@ -1106,8 +989,6 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
logoHorizontalFile := o.Branding.LogoHorizontalFile
|
||||
faviconFile := o.Branding.FaviconFile
|
||||
|
||||
mitmListen := o.MITMListen
|
||||
|
||||
return RuntimeSettings{
|
||||
WatchdogEnabled: &watchdogEnabled,
|
||||
WatchdogIdleEnabled: &watchdogIdle,
|
||||
@@ -1127,7 +1008,6 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
F16: &f16,
|
||||
Debug: &debug,
|
||||
TracingMaxItems: &tracingMaxItems,
|
||||
TracingMaxBodyBytes: &tracingMaxBodyBytes,
|
||||
EnableTracing: &enableTracing,
|
||||
EnableBackendLogging: &enableBackendLogging,
|
||||
CORS: &cors,
|
||||
@@ -1161,7 +1041,6 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
LogoFile: &logoFile,
|
||||
LogoHorizontalFile: &logoHorizontalFile,
|
||||
FaviconFile: &faviconFile,
|
||||
MITMListen: &mitmListen,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1267,9 +1146,6 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
||||
if settings.TracingMaxItems != nil {
|
||||
o.TracingMaxItems = *settings.TracingMaxItems
|
||||
}
|
||||
if settings.TracingMaxBodyBytes != nil {
|
||||
o.TracingMaxBodyBytes = *settings.TracingMaxBodyBytes
|
||||
}
|
||||
if settings.EnableBackendLogging != nil {
|
||||
o.EnableBackendLogging = *settings.EnableBackendLogging
|
||||
}
|
||||
@@ -1387,10 +1263,6 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
||||
o.Branding.FaviconFile = *settings.FaviconFile
|
||||
}
|
||||
|
||||
if settings.MITMListen != nil {
|
||||
o.MITMListen = *settings.MITMListen
|
||||
}
|
||||
|
||||
// Note: ApiKeys requires special handling (merging with startup keys) - handled in caller
|
||||
|
||||
return requireRestart
|
||||
|
||||
@@ -40,10 +40,7 @@ type DistributedConfig struct {
|
||||
// model-row cleanup on MarkUnhealthy / MarkDraining).
|
||||
DisablePerModelHealthCheck bool
|
||||
|
||||
MCPCIJobTimeout time.Duration // MCP CI job execution timeout (default 10m)
|
||||
|
||||
BackendInstallTimeout time.Duration // NATS round-trip timeout for backend.install (default 15m)
|
||||
BackendUpgradeTimeout time.Duration // NATS round-trip timeout for backend.upgrade (default 15m)
|
||||
MCPCIJobTimeout time.Duration // MCP CI job execution timeout (default 10m)
|
||||
|
||||
MaxUploadSize int64 // Maximum upload body size in bytes (default 50 GB)
|
||||
|
||||
@@ -71,15 +68,13 @@ func (c DistributedConfig) Validate() error {
|
||||
}
|
||||
// Check for negative durations
|
||||
for name, d := range map[string]time.Duration{
|
||||
FlagMCPToolTimeout: c.MCPToolTimeout,
|
||||
FlagMCPDiscoveryTimeout: c.MCPDiscoveryTimeout,
|
||||
FlagWorkerWaitTimeout: c.WorkerWaitTimeout,
|
||||
FlagDrainTimeout: c.DrainTimeout,
|
||||
FlagHealthCheckInterval: c.HealthCheckInterval,
|
||||
FlagStaleNodeThreshold: c.StaleNodeThreshold,
|
||||
FlagMCPCIJobTimeout: c.MCPCIJobTimeout,
|
||||
FlagBackendInstallTimeout: c.BackendInstallTimeout,
|
||||
FlagBackendUpgradeTimeout: c.BackendUpgradeTimeout,
|
||||
"mcp-tool-timeout": c.MCPToolTimeout,
|
||||
"mcp-discovery-timeout": c.MCPDiscoveryTimeout,
|
||||
"worker-wait-timeout": c.WorkerWaitTimeout,
|
||||
"drain-timeout": c.DrainTimeout,
|
||||
"health-check-interval": c.HealthCheckInterval,
|
||||
"stale-node-threshold": c.StaleNodeThreshold,
|
||||
"mcp-ci-job-timeout": c.MCPCIJobTimeout,
|
||||
} {
|
||||
if d < 0 {
|
||||
return fmt.Errorf("%s must not be negative", name)
|
||||
@@ -142,66 +137,24 @@ func WithStorageSecretKey(key string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendInstallTimeout(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.BackendInstallTimeout = d
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendUpgradeTimeout(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.BackendUpgradeTimeout = d
|
||||
}
|
||||
}
|
||||
|
||||
var EnableAutoApproveNodes = func(o *ApplicationConfig) {
|
||||
o.Distributed.AutoApproveNodes = true
|
||||
}
|
||||
|
||||
// Flag names for distributed timeout / interval configuration. These are
|
||||
// the kebab-case identifiers kong derives from the matching RunCMD struct
|
||||
// fields; they appear in Validate error messages and any other operator-
|
||||
// facing surface that needs to reference a specific knob by name. Keeping
|
||||
// them as constants prevents the string from drifting from the actual
|
||||
// flag a future rename would produce.
|
||||
const (
|
||||
FlagMCPToolTimeout = "mcp-tool-timeout"
|
||||
FlagMCPDiscoveryTimeout = "mcp-discovery-timeout"
|
||||
FlagWorkerWaitTimeout = "worker-wait-timeout"
|
||||
FlagDrainTimeout = "drain-timeout"
|
||||
FlagHealthCheckInterval = "health-check-interval"
|
||||
FlagStaleNodeThreshold = "stale-node-threshold"
|
||||
FlagMCPCIJobTimeout = "mcp-ci-job-timeout"
|
||||
FlagBackendInstallTimeout = "backend-install-timeout"
|
||||
FlagBackendUpgradeTimeout = "backend-upgrade-timeout"
|
||||
)
|
||||
|
||||
// Defaults for distributed timeouts.
|
||||
const (
|
||||
DefaultMCPToolTimeout = 360 * time.Second
|
||||
DefaultMCPDiscoveryTimeout = 60 * time.Second
|
||||
DefaultWorkerWaitTimeout = 5 * time.Minute
|
||||
DefaultDrainTimeout = 30 * time.Second
|
||||
DefaultHealthCheckInterval = 15 * time.Second
|
||||
DefaultStaleNodeThreshold = 60 * time.Second
|
||||
DefaultMCPCIJobTimeout = 10 * time.Minute
|
||||
DefaultBackendInstallTimeout = 15 * time.Minute
|
||||
DefaultBackendUpgradeTimeout = 15 * time.Minute
|
||||
DefaultMCPToolTimeout = 360 * time.Second
|
||||
DefaultMCPDiscoveryTimeout = 60 * time.Second
|
||||
DefaultWorkerWaitTimeout = 5 * time.Minute
|
||||
DefaultDrainTimeout = 30 * time.Second
|
||||
DefaultHealthCheckInterval = 15 * time.Second
|
||||
DefaultStaleNodeThreshold = 60 * time.Second
|
||||
DefaultMCPCIJobTimeout = 10 * time.Minute
|
||||
)
|
||||
|
||||
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
|
||||
const DefaultMaxUploadSize int64 = 50 << 30
|
||||
|
||||
// BackendInstallTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout)
|
||||
}
|
||||
|
||||
// BackendUpgradeTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) BackendUpgradeTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.BackendUpgradeTimeout, DefaultBackendUpgradeTimeout)
|
||||
}
|
||||
|
||||
// MCPToolTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) MCPToolTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.MCPToolTimeout, DefaultMCPToolTimeout)
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
var _ = Describe("DistributedConfig backend NATS timeouts", func() {
|
||||
Context("BackendInstallTimeoutOrDefault", func() {
|
||||
It("returns 15 minutes when unset", func() {
|
||||
c := config.DistributedConfig{}
|
||||
Expect(c.BackendInstallTimeoutOrDefault()).To(Equal(15 * time.Minute))
|
||||
})
|
||||
|
||||
It("returns the configured value when set", func() {
|
||||
c := config.DistributedConfig{BackendInstallTimeout: 42 * time.Minute}
|
||||
Expect(c.BackendInstallTimeoutOrDefault()).To(Equal(42 * time.Minute))
|
||||
})
|
||||
})
|
||||
|
||||
Context("BackendUpgradeTimeoutOrDefault", func() {
|
||||
It("returns 15 minutes when unset", func() {
|
||||
c := config.DistributedConfig{}
|
||||
Expect(c.BackendUpgradeTimeoutOrDefault()).To(Equal(15 * time.Minute))
|
||||
})
|
||||
|
||||
It("returns the configured value when set", func() {
|
||||
c := config.DistributedConfig{BackendUpgradeTimeout: 30 * time.Minute}
|
||||
Expect(c.BackendUpgradeTimeoutOrDefault()).To(Equal(30 * time.Minute))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("DistributedConfig flag-name constants", func() {
|
||||
// Pin the kebab-case strings so a rename of the Go field name (or a
|
||||
// CLI flag naming convention change) forces the constant to update,
|
||||
// keeping the Validate error messages and any future operator-facing
|
||||
// surface in sync with the actual CLI flag.
|
||||
DescribeTable("flag name constants",
|
||||
func(actual, expected string) {
|
||||
Expect(actual).To(Equal(expected))
|
||||
},
|
||||
Entry("MCP tool timeout", config.FlagMCPToolTimeout, "mcp-tool-timeout"),
|
||||
Entry("MCP discovery timeout", config.FlagMCPDiscoveryTimeout, "mcp-discovery-timeout"),
|
||||
Entry("worker wait timeout", config.FlagWorkerWaitTimeout, "worker-wait-timeout"),
|
||||
Entry("drain timeout", config.FlagDrainTimeout, "drain-timeout"),
|
||||
Entry("health check interval", config.FlagHealthCheckInterval, "health-check-interval"),
|
||||
Entry("stale node threshold", config.FlagStaleNodeThreshold, "stale-node-threshold"),
|
||||
Entry("MCP CI job timeout", config.FlagMCPCIJobTimeout, "mcp-ci-job-timeout"),
|
||||
Entry("backend install timeout", config.FlagBackendInstallTimeout, "backend-install-timeout"),
|
||||
Entry("backend upgrade timeout", config.FlagBackendUpgradeTimeout, "backend-upgrade-timeout"),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("DistributedConfig.Validate negative-duration errors", func() {
|
||||
It("rejects a negative BackendInstallTimeout with the flag name in the error", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
BackendInstallTimeout: -1 * time.Second,
|
||||
}
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(config.FlagBackendInstallTimeout))
|
||||
Expect(err.Error()).To(ContainSubstring("must not be negative"))
|
||||
})
|
||||
|
||||
It("rejects a negative BackendUpgradeTimeout with the flag name in the error", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
BackendUpgradeTimeout: -1 * time.Second,
|
||||
}
|
||||
err := c.Validate()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(config.FlagBackendUpgradeTimeout))
|
||||
})
|
||||
|
||||
It("accepts all-zero durations as valid (defaults apply)", func() {
|
||||
c := config.DistributedConfig{
|
||||
Enabled: true,
|
||||
NatsURL: "nats://localhost:4222",
|
||||
}
|
||||
Expect(c.Validate()).To(Succeed())
|
||||
})
|
||||
})
|
||||
@@ -136,36 +136,4 @@ var _ = Describe("Backend hooks and parser defaults", func() {
|
||||
Expect(cfg.EngineArgs["enable_chunked_prefill"]).To(Equal(true))
|
||||
})
|
||||
})
|
||||
|
||||
Context("PromptCacheAll default", func() {
|
||||
It("defaults to true when omitted from YAML", func() {
|
||||
cfg := &ModelConfig{}
|
||||
cfg.SetDefaults()
|
||||
|
||||
Expect(cfg.PromptCacheAll).NotTo(BeNil())
|
||||
Expect(*cfg.PromptCacheAll).To(BeTrue())
|
||||
})
|
||||
|
||||
It("preserves an explicit false from YAML", func() {
|
||||
falseV := false
|
||||
cfg := &ModelConfig{
|
||||
LLMConfig: LLMConfig{PromptCacheAll: &falseV},
|
||||
}
|
||||
cfg.SetDefaults()
|
||||
|
||||
Expect(cfg.PromptCacheAll).NotTo(BeNil())
|
||||
Expect(*cfg.PromptCacheAll).To(BeFalse())
|
||||
})
|
||||
|
||||
It("preserves an explicit true from YAML", func() {
|
||||
trueV := true
|
||||
cfg := &ModelConfig{
|
||||
LLMConfig: LLMConfig{PromptCacheAll: &trueV},
|
||||
}
|
||||
cfg.SetDefaults()
|
||||
|
||||
Expect(cfg.PromptCacheAll).NotTo(BeNil())
|
||||
Expect(*cfg.PromptCacheAll).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -49,31 +49,20 @@ var DiffusersPipelineOptions = []FieldOption{
|
||||
{Value: "StableVideoDiffusionPipeline", Label: "StableVideoDiffusionPipeline"},
|
||||
}
|
||||
|
||||
// UsecaseOptions must stay in sync with GetAllModelConfigUsecases in
|
||||
// core/config/model_config.go — a value missing here is silently
|
||||
// inaccessible from the model editor, which is how `score` (the router
|
||||
// classifier usecase) hid for an entire release.
|
||||
var UsecaseOptions = []FieldOption{
|
||||
{Value: "chat", Label: "Chat"},
|
||||
{Value: "completion", Label: "Completion"},
|
||||
{Value: "edit", Label: "Edit"},
|
||||
{Value: "embeddings", Label: "Embeddings"},
|
||||
{Value: "rerank", Label: "Rerank"},
|
||||
{Value: "score", Label: "Score (Router Classifier)"},
|
||||
{Value: "image", Label: "Image"},
|
||||
{Value: "vision", Label: "Vision"},
|
||||
{Value: "detection", Label: "Detection"},
|
||||
{Value: "face_recognition", Label: "Face Recognition"},
|
||||
{Value: "transcript", Label: "Transcript"},
|
||||
{Value: "diarization", Label: "Diarization"},
|
||||
{Value: "speaker_recognition", Label: "Speaker Recognition"},
|
||||
{Value: "tts", Label: "TTS"},
|
||||
{Value: "sound_generation", Label: "Sound Generation"},
|
||||
{Value: "audio_transform", Label: "Audio Transform"},
|
||||
{Value: "realtime_audio", Label: "Realtime Audio"},
|
||||
{Value: "tokenize", Label: "Tokenize"},
|
||||
{Value: "vad", Label: "VAD"},
|
||||
{Value: "video", Label: "Video"},
|
||||
{Value: "detection", Label: "Detection"},
|
||||
}
|
||||
|
||||
var DiffusersSchedulerOptions = []FieldOption{
|
||||
|
||||
@@ -232,17 +232,6 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Description: "Use the chat template from the model's tokenizer config",
|
||||
Order: 43,
|
||||
},
|
||||
// Router section template — kept in the templates UI section
|
||||
// (rather than the router section under "other") so operators
|
||||
// editing prompt shapes find all template-typed fields in one
|
||||
// place, mirroring how chat / chat_message are grouped.
|
||||
"router.classifier_system_template": {
|
||||
Section: "templates",
|
||||
Label: "Router Classifier System Prompt",
|
||||
Description: "Go text/template (with sprig functions) for the routing system prompt the score classifier feeds to its classifier_model. Executed with `.Policies` ([]{Label, Description}). Empty falls back to the built-in Arch-Router-shaped prompt (route-listing block + JSON output schema). Override when the classifier model was trained on a different schema or you need the routing instructions in a different language. The candidate format scored against the model is fixed at `{\"route\": \"<label>\"}` — keep your override's output schema instruction matching that.",
|
||||
Component: "code-editor",
|
||||
Order: 44,
|
||||
},
|
||||
|
||||
// --- Pipeline ---
|
||||
"pipeline.llm": {
|
||||
@@ -331,207 +320,5 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Description: "Enable CUDA for diffusers",
|
||||
Order: 82,
|
||||
},
|
||||
|
||||
// --- PII filtering (per-model) ---
|
||||
"pii.enabled": {
|
||||
Section: "other",
|
||||
Label: "PII Filtering Enabled",
|
||||
Description: "Enable PII redaction middleware for this model. Unset means use the default (off for local backends, on for proxy-* / cloud-hosted backends).",
|
||||
Component: "toggle",
|
||||
Order: 200,
|
||||
},
|
||||
"pii.patterns": {
|
||||
Section: "other",
|
||||
Label: "PII Pattern Overrides",
|
||||
Description: "Override the global default action for specific patterns on this model. Patterns not listed here inherit the global action (Settings → Middleware → Filtering).",
|
||||
Component: "pii-pattern-list",
|
||||
Order: 201,
|
||||
},
|
||||
|
||||
// --- Cloud passthrough proxy ---
|
||||
// These only have an effect when Backend is set to
|
||||
// "cloud-proxy". When the upstream URL is empty, the model
|
||||
// fails closed — the chat handler does NOT silently fall back
|
||||
// to the local gRPC pipeline.
|
||||
"proxy.mode": {
|
||||
Section: "other",
|
||||
Label: "Proxy Mode",
|
||||
Description: "passthrough forwards the client's OpenAI body verbatim — point upstream_url at an OpenAI-compatible endpoint (incl. Anthropic's /v1/chat/completions compat layer). translate converts OpenAI ↔ Anthropic Messages so you can target a native API (/v1/messages); tool_calls and usage tokens survive the round-trip.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "passthrough", Label: "passthrough (raw forward)"},
|
||||
{Value: "translate", Label: "translate (OpenAI ↔ native)"},
|
||||
},
|
||||
Default: "passthrough",
|
||||
Order: 208,
|
||||
},
|
||||
"proxy.provider": {
|
||||
Section: "other",
|
||||
Label: "Proxy Provider",
|
||||
Description: "Upstream API family. Drives auth header shape (Bearer vs x-api-key + anthropic-version) and, in translate mode, which request/response codec is used.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "openai", Label: "OpenAI"},
|
||||
{Value: "anthropic", Label: "Anthropic"},
|
||||
},
|
||||
Default: "openai",
|
||||
Order: 209,
|
||||
},
|
||||
"proxy.upstream_url": {
|
||||
Section: "other",
|
||||
Label: "Proxy Upstream URL",
|
||||
Description: "Full POST endpoint of the upstream provider (e.g. https://api.openai.com/v1/chat/completions). Only used when Backend is cloud-proxy.",
|
||||
Component: "input",
|
||||
Order: 210,
|
||||
},
|
||||
"proxy.api_key_env": {
|
||||
Section: "other",
|
||||
Label: "Proxy API Key Env Var",
|
||||
Description: "Name of the environment variable holding the upstream API key. Reading from env keeps the secret out of the YAML and the admin UI.",
|
||||
Component: "input",
|
||||
Order: 211,
|
||||
},
|
||||
"proxy.upstream_model": {
|
||||
Section: "other",
|
||||
Label: "Proxy Upstream Model",
|
||||
Description: "Model name sent to the upstream. Leave empty to forward the client's model field unchanged. Useful when the LocalAI alias differs from the upstream's canonical name.",
|
||||
Component: "input",
|
||||
Order: 212,
|
||||
},
|
||||
"proxy.request_timeout_seconds": {
|
||||
Section: "other",
|
||||
Label: "Proxy Request Timeout (seconds)",
|
||||
Description: "Caps the upstream HTTP request duration. 0 disables the deadline; the request still ends when the client disconnects.",
|
||||
Component: "number",
|
||||
Min: f64(0),
|
||||
Order: 213,
|
||||
},
|
||||
|
||||
// --- MITM intercept hosts ---
|
||||
// Each host listed here is claimed by this model config; the
|
||||
// cloudproxy MITM listener (see Middleware → MITM Proxy) uses
|
||||
// THIS config's pii: settings to filter the intercepted traffic.
|
||||
// A host claimed by two configs is a critical error — the
|
||||
// listener refuses to start until resolved.
|
||||
"mitm.hosts": {
|
||||
Section: "other",
|
||||
Label: "MITM Intercept Hosts",
|
||||
Description: "Hostnames the cloudproxy MITM proxy terminates TLS for on behalf of this model config. PII filtering and pattern overrides flow from this model when the host is intercepted. Each host must be unique across all configs.",
|
||||
Component: "string-list",
|
||||
Order: 220,
|
||||
},
|
||||
|
||||
// --- Router ---
|
||||
// Routing turns this model config into a dispatcher: the
|
||||
// classifier scores every policy label as a continuation of
|
||||
// the routing prompt and picks the first candidate whose
|
||||
// labels are a superset of the active set. The Routing tab of
|
||||
// the middleware admin page surfaces every model with a router
|
||||
// block.
|
||||
"router.classifier": {
|
||||
Section: "other",
|
||||
Label: "Classifier",
|
||||
Description: "Picks a candidate by scoring every policy label against the prompt. Only \"score\" is shipped today; it asks the classifier_model to rank each label and reads off the softmax. Empty defaults to \"score\".",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "score", Label: "Score (Arch-Router-style)"},
|
||||
},
|
||||
Order: 230,
|
||||
},
|
||||
"router.classifier_model": {
|
||||
Section: "other",
|
||||
Label: "Classifier Model",
|
||||
Description: "Loaded LocalAI model the score classifier asks to rank each policy label as a continuation. Must support the Score gRPC primitive (today: llama-cpp, vLLM) and use the ChatML template. Arch-Router-1.5B Q4_K_M is the canonical choice; any small ChatML instruct model also works at a higher activation_threshold.",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModelsChat,
|
||||
Order: 231,
|
||||
},
|
||||
"router.fallback": {
|
||||
Section: "other",
|
||||
Label: "Fallback Model",
|
||||
Description: "Model used when no candidate's labels cover the classifier's active label set, or when the classifier errors. Empty means router failures bubble up as HTTP 500 — fail-fast, not silent-bypass.",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModelsChat,
|
||||
Order: 232,
|
||||
},
|
||||
"router.activation_threshold": {
|
||||
Section: "other",
|
||||
Label: "Activation Threshold",
|
||||
Description: "Softmax-probability floor a policy must clear to join the active label set for a request. Higher → single-label dominant routes; lower → more multi-label activations. 0 picks the package default (0.15). On Arch-Router-1.5B a value around 0.40 keeps the dominant label clean without losing genuine compound activations.",
|
||||
Component: "slider",
|
||||
Min: f64(0),
|
||||
Max: f64(1),
|
||||
Step: f64(0.05),
|
||||
Order: 233,
|
||||
},
|
||||
"router.classifier_cache_size": {
|
||||
Section: "other",
|
||||
Label: "Classifier L1 Cache Size",
|
||||
Description: "Bounded LRU keyed on (case-folded, whitespace-trimmed) prompt — amortises the classifier round-trip across verbatim repeats common in agent loops. 0 here means \"use the default\" (1024); the cache cannot be disabled from YAML.",
|
||||
Component: "number",
|
||||
Min: f64(0),
|
||||
Order: 234,
|
||||
},
|
||||
"router.policies": {
|
||||
Section: "other",
|
||||
Label: "Policies",
|
||||
Description: "Label vocabulary the classifier scores over. Each policy has a label and a short natural-language description fed verbatim to the classifier model. Short action-oriented sentences work best (\"writing or debugging code\"; \"small talk\").",
|
||||
Component: "router-policies",
|
||||
Order: 235,
|
||||
},
|
||||
"router.candidates": {
|
||||
Section: "other",
|
||||
Label: "Candidates",
|
||||
Description: "Routing table: each entry binds a downstream model to a set of policy labels it can serve. Order matters — the middleware picks the FIRST candidate whose labels are a superset of the active set, so list candidates smallest → largest.",
|
||||
Component: "router-candidates",
|
||||
Order: 236,
|
||||
},
|
||||
"router.score_normalization": {
|
||||
Section: "other",
|
||||
Label: "Score Normalization",
|
||||
Description: "How the score classifier collapses per-candidate joint log-probs into the softmax input. \"raw\" (default) feeds joint log-prob as-is — on-distribution for Arch-Router (the route the model would actually emit if decoded freely). \"mean\" divides by candidate token count — fairer to long labels but off-distribution for models trained to emit fixed-format outputs.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "", Label: "Raw (default)"},
|
||||
{Value: "raw", Label: "Raw"},
|
||||
{Value: "mean", Label: "Mean (length-normalised)"},
|
||||
},
|
||||
Order: 240,
|
||||
},
|
||||
"router.embedding_cache.embedding_model": {
|
||||
Section: "other",
|
||||
Label: "L2 Cache: Embedding Model",
|
||||
Description: "Embedding model used by the L2 decision cache. Embeds incoming probes and looks them up in the per-router local-store collection. Empty disables the cache entirely. nomic-embed-text-v1.5 is the recommended default.",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModels,
|
||||
Order: 237,
|
||||
},
|
||||
"router.embedding_cache.similarity_threshold": {
|
||||
Section: "other",
|
||||
Label: "L2 Cache: Similarity Threshold",
|
||||
Description: "Cosine-similarity floor a cache candidate must clear to count as a hit. 0 picks the package default (0.80). Re-tune per embedding model — the histogram on the Routing tab shows where the cosine distribution actually sits.",
|
||||
Component: "slider",
|
||||
Min: f64(0),
|
||||
Max: f64(1),
|
||||
Step: f64(0.01),
|
||||
Order: 238,
|
||||
},
|
||||
"router.embedding_cache.confidence_threshold": {
|
||||
Section: "other",
|
||||
Label: "L2 Cache: Confidence Threshold",
|
||||
Description: "Minimum top-label probability a classifier decision must have to be inserted into the cache. 0 picks the package default (0.60). Uncertain decisions are skipped so they can't poison future paraphrases.",
|
||||
Component: "slider",
|
||||
Min: f64(0),
|
||||
Max: f64(1),
|
||||
Step: f64(0.05),
|
||||
Order: 239,
|
||||
},
|
||||
"router.embedding_cache.store_name": {
|
||||
Section: "other",
|
||||
Label: "L2 Cache: Store Name",
|
||||
Description: "Optional override for the local-store collection used by this router's cache. Empty defaults to \"router-cache-<router-model-name>\". Two routers sharing a store_name share their cache (rare).",
|
||||
Component: "input",
|
||||
Order: 240,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,258 +0,0 @@
|
||||
package meta_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/config/meta"
|
||||
)
|
||||
|
||||
// TestAllFieldsHaveRegistryEntries fails when a NEW ModelConfig field
|
||||
// is added without either a registry entry in DefaultRegistry() or an
|
||||
// entry in the grandfatheredUnregistered baseline below.
|
||||
//
|
||||
// Why this matters: fields without a registry entry render in the UI
|
||||
// with no description, the default `input` (single-line) component,
|
||||
// and land in the catch-all "other" section — which is what we just
|
||||
// hit for router.classifier_system_template. The reflection-based
|
||||
// fallback produces something that works mechanically but is hostile
|
||||
// to operators.
|
||||
//
|
||||
// How to fix when this test fails:
|
||||
//
|
||||
// 1. Preferred — add a registry entry to DefaultRegistry() in
|
||||
// registry.go with Section, Label, Description, and Component.
|
||||
// See e.g. "router.classifier" or "template.chat" for the pattern.
|
||||
//
|
||||
// 2. Escape hatch — append the field path to
|
||||
// grandfatheredUnregistered with a one-line comment justifying
|
||||
// why it has no UI surface (internal, deprecated, legacy
|
||||
// compatibility shim, etc.). The expectation is that this list
|
||||
// shrinks over time as fields get proper registry entries; it
|
||||
// should never grow without good reason.
|
||||
//
|
||||
// The grandfathered list was seeded from a one-time audit. Migrating
|
||||
// the existing 150+ entries to proper registry metadata is out of
|
||||
// scope for any single PR; the lock just stops the list from growing.
|
||||
func TestAllFieldsHaveRegistryEntries(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
md := meta.BuildConfigMetadata(reflect.TypeOf(config.ModelConfig{}))
|
||||
reg := meta.DefaultRegistry()
|
||||
|
||||
grand := make(map[string]struct{}, len(grandfatheredUnregistered))
|
||||
for _, p := range grandfatheredUnregistered {
|
||||
grand[p] = struct{}{}
|
||||
}
|
||||
|
||||
var missing []string
|
||||
for _, f := range md.Fields {
|
||||
if _, ok := reg[f.Path]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := grand[f.Path]; ok {
|
||||
continue
|
||||
}
|
||||
missing = append(missing, f.Path)
|
||||
}
|
||||
|
||||
sort.Strings(missing)
|
||||
g.Expect(missing).To(BeEmpty(),
|
||||
"%d config field(s) have no registry entry and are not on the grandfathered list.\n"+
|
||||
"Add a registry entry to core/config/meta/registry.go OR append to grandfatheredUnregistered in this file with a justification:\n %s",
|
||||
len(missing), strings.Join(missing, "\n "),
|
||||
)
|
||||
|
||||
// Inverse drift check: catch dead entries on the grandfathered
|
||||
// list (field was renamed/removed, or someone wrote a registry
|
||||
// entry without trimming the grandfathered duplicate).
|
||||
known := make(map[string]struct{}, len(md.Fields))
|
||||
for _, f := range md.Fields {
|
||||
known[f.Path] = struct{}{}
|
||||
}
|
||||
var stale, duplicated []string
|
||||
for _, p := range grandfatheredUnregistered {
|
||||
if _, ok := known[p]; !ok {
|
||||
stale = append(stale, p)
|
||||
}
|
||||
if _, ok := reg[p]; ok {
|
||||
duplicated = append(duplicated, p)
|
||||
}
|
||||
}
|
||||
sort.Strings(stale)
|
||||
g.Expect(stale).To(BeEmpty(),
|
||||
"grandfatheredUnregistered references fields that no longer exist in ModelConfig — remove them:\n %s",
|
||||
strings.Join(stale, "\n "))
|
||||
sort.Strings(duplicated)
|
||||
g.Expect(duplicated).To(BeEmpty(),
|
||||
"grandfatheredUnregistered references fields that now HAVE a registry entry — remove them so the test stays meaningful:\n %s",
|
||||
strings.Join(duplicated, "\n "))
|
||||
}
|
||||
|
||||
// grandfatheredUnregistered is the baseline of config fields that
|
||||
// pre-date the registry-coverage test and have no UI metadata yet.
|
||||
// Adding new entries here should be a deliberate, justified decision
|
||||
// — prefer adding a registry entry in registry.go instead.
|
||||
//
|
||||
// Keep the list sorted (one-line-per-entry) so the diff is minimal
|
||||
// when an entry is removed or (rarely) added.
|
||||
var grandfatheredUnregistered = []string{
|
||||
"agent.disable_sink_state",
|
||||
"agent.enable_mcp_prompts",
|
||||
"agent.enable_plan_re_evaluator",
|
||||
"agent.enable_planning",
|
||||
"agent.enable_reasoning",
|
||||
"agent.force_reasoning_tool",
|
||||
"agent.loop_detection",
|
||||
"agent.max_adjustment_attempts",
|
||||
"agent.max_attempts",
|
||||
"agent.max_iterations",
|
||||
"cfg_scale",
|
||||
"concurrency_groups",
|
||||
"cutstrings",
|
||||
"debug",
|
||||
"diffusers.clip_model",
|
||||
"diffusers.clip_skip",
|
||||
"diffusers.clip_subfolder",
|
||||
"diffusers.control_net",
|
||||
"diffusers.enable_parameters",
|
||||
"diffusers.img2img",
|
||||
"disable_log_stats",
|
||||
"disabled",
|
||||
"download_files",
|
||||
"draft_model",
|
||||
"dtype",
|
||||
"enforce_eager",
|
||||
"engine_args",
|
||||
"extract_regex",
|
||||
"feature_flags",
|
||||
"function.argument_regex",
|
||||
"function.argument_regex_key_name",
|
||||
"function.argument_regex_value_name",
|
||||
"function.automatic_tool_parsing_fallback",
|
||||
"function.capture_llm_results",
|
||||
"function.disable_no_action",
|
||||
"function.disable_peg_parser",
|
||||
"function.function_arguments_key",
|
||||
"function.function_name_key",
|
||||
"function.grammar.disable_parallel_new_lines",
|
||||
"function.grammar.expect_strings_after_json",
|
||||
"function.grammar.no_mixed_free_string",
|
||||
"function.grammar.prefix",
|
||||
"function.grammar.properties_order",
|
||||
"function.grammar.schema_type",
|
||||
"function.grammar.triggers",
|
||||
"function.json_regex_match",
|
||||
"function.no_action_description_name",
|
||||
"function.no_action_function_name",
|
||||
"function.replace_function_results",
|
||||
"function.replace_llm_results",
|
||||
"function.response_regex",
|
||||
"function.xml_format.allow_toolcall_in_think",
|
||||
"function.xml_format.key_start",
|
||||
"function.xml_format.key_val_sep",
|
||||
"function.xml_format.key_val_sep2",
|
||||
"function.xml_format.last_tool_end",
|
||||
"function.xml_format.last_val_end",
|
||||
"function.xml_format.raw_argval",
|
||||
"function.xml_format.scope_end",
|
||||
"function.xml_format.scope_start",
|
||||
"function.xml_format.tool_end",
|
||||
"function.xml_format.tool_sep",
|
||||
"function.xml_format.tool_start",
|
||||
"function.xml_format.trim_raw_argval",
|
||||
"function.xml_format.val_end",
|
||||
"function.xml_format_preset",
|
||||
"gpu_memory_utilization",
|
||||
"grammar",
|
||||
"grpc.attempts",
|
||||
"grpc.attempts_sleep_time",
|
||||
"limit_mm_per_prompt.audio",
|
||||
"limit_mm_per_prompt.image",
|
||||
"limit_mm_per_prompt.video",
|
||||
"limits.max_concurrent",
|
||||
"limits.retry_after_seconds",
|
||||
"load_format",
|
||||
"lora_adapter",
|
||||
"lora_adapters",
|
||||
"lora_base",
|
||||
"lora_scale",
|
||||
"lora_scales",
|
||||
"main_gpu",
|
||||
"max_model_len",
|
||||
"mcp.remote",
|
||||
"mcp.stdio",
|
||||
"mirostat",
|
||||
"mirostat_eta",
|
||||
"mirostat_tau",
|
||||
"mmproj",
|
||||
"n_draft",
|
||||
"ngqa",
|
||||
"no_kv_offloading",
|
||||
"no_mulmatq",
|
||||
"numa",
|
||||
"options",
|
||||
"overrides",
|
||||
"parameters.batch",
|
||||
"parameters.clip_skip",
|
||||
"parameters.echo",
|
||||
"parameters.encoding_format",
|
||||
"parameters.frequency_penalty",
|
||||
"parameters.ignore_eos",
|
||||
"parameters.language",
|
||||
"parameters.logit_bias",
|
||||
"parameters.logprobs",
|
||||
"parameters.min_p",
|
||||
"parameters.model",
|
||||
"parameters.n",
|
||||
"parameters.n_keep",
|
||||
"parameters.negative_prompt",
|
||||
"parameters.negative_prompt_scale",
|
||||
"parameters.presence_penalty",
|
||||
"parameters.repeat_last_n",
|
||||
"parameters.rope_freq_base",
|
||||
"parameters.rope_freq_scale",
|
||||
"parameters.tfz",
|
||||
"parameters.tokenizer",
|
||||
"parameters.top_logprobs",
|
||||
"parameters.translate",
|
||||
"parameters.typical_p",
|
||||
"pinned",
|
||||
"prompt_cache_all",
|
||||
"prompt_cache_path",
|
||||
"prompt_cache_ro",
|
||||
"proxy.api_key_file",
|
||||
"reasoning.disable",
|
||||
"reasoning.disable_reasoning_tag_prefill",
|
||||
"reasoning.strip_reasoning_only",
|
||||
"reasoning.tag_pairs",
|
||||
"reasoning.thinking_start_tokens",
|
||||
"reranking",
|
||||
"rms_norm_eps",
|
||||
"roles",
|
||||
"rope_scaling",
|
||||
"step",
|
||||
"stopwords",
|
||||
"swap_space",
|
||||
"system_prompt",
|
||||
"template.edit",
|
||||
"template.function",
|
||||
"template.join_chat_messages_by_character",
|
||||
"template.multimodal",
|
||||
"template.reply_prefix",
|
||||
"tensor_parallel_size",
|
||||
"tensor_split",
|
||||
"trimspace",
|
||||
"trimsuffix",
|
||||
"trust_remote_code",
|
||||
"tts.audio_path",
|
||||
"type",
|
||||
"yarn_attn_factor",
|
||||
"yarn_beta_fast",
|
||||
"yarn_beta_slow",
|
||||
"yarn_ext_factor",
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
// MITMHostOwners is the load-bearing piece of D2 — a duplicate host
|
||||
// across model configs is a critical error that disables the listener.
|
||||
// The test exercises both happy paths (no duplicates → clean Owners
|
||||
// map) and conflict detection (two configs on one host → entry in
|
||||
// Conflicts naming both).
|
||||
|
||||
var _ = Describe("ModelConfigLoader.MITMHostOwners", func() {
|
||||
var (
|
||||
dir string
|
||||
loader *config.ModelConfigLoader
|
||||
)
|
||||
|
||||
writeYAML := func(name, body string) {
|
||||
path := filepath.Join(dir, name+".yaml")
|
||||
Expect(os.WriteFile(path, []byte(body), 0o644)).To(Succeed())
|
||||
Expect(loader.ReadModelConfig(path)).To(Succeed())
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
dir, err = os.MkdirTemp("", "mitm-host-owners-test-*")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
loader = config.NewModelConfigLoader(dir)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
_ = os.RemoveAll(dir)
|
||||
})
|
||||
|
||||
It("returns empty maps when no model declares mitm.hosts", func() {
|
||||
writeYAML("plain", `name: plain
|
||||
backend: llama-cpp
|
||||
`)
|
||||
got := loader.MITMHostOwners()
|
||||
Expect(got.Owners).To(BeEmpty())
|
||||
Expect(got.Conflicts).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("indexes hosts to the owning model name", func() {
|
||||
writeYAML("claude", `name: claude
|
||||
backend: cloud-proxy
|
||||
mitm:
|
||||
hosts:
|
||||
- api.anthropic.com
|
||||
`)
|
||||
writeYAML("openai", `name: openai
|
||||
backend: cloud-proxy
|
||||
mitm:
|
||||
hosts:
|
||||
- api.openai.com
|
||||
- api.openai.azure.com
|
||||
`)
|
||||
got := loader.MITMHostOwners()
|
||||
Expect(got.Owners).To(Equal(map[string]string{
|
||||
"api.anthropic.com": "claude",
|
||||
"api.openai.com": "openai",
|
||||
"api.openai.azure.com": "openai",
|
||||
}))
|
||||
Expect(got.Conflicts).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("normalises case and trims whitespace before indexing", func() {
|
||||
writeYAML("claude", `name: claude
|
||||
backend: cloud-proxy
|
||||
mitm:
|
||||
hosts:
|
||||
- " API.ANTHROPIC.com "
|
||||
`)
|
||||
got := loader.MITMHostOwners()
|
||||
Expect(got.Owners).To(HaveKey("api.anthropic.com"))
|
||||
})
|
||||
|
||||
It("detects two configs claiming the same host as a conflict", func() {
|
||||
// The 1-to-1 invariant the D2 dispatcher relies on: a host
|
||||
// claimed twice means the owner lookup is ambiguous, so the
|
||||
// caller must NOT start the MITM listener until resolved.
|
||||
writeYAML("alpha", `name: alpha
|
||||
backend: cloud-proxy
|
||||
mitm:
|
||||
hosts:
|
||||
- api.anthropic.com
|
||||
`)
|
||||
writeYAML("beta", `name: beta
|
||||
backend: cloud-proxy
|
||||
mitm:
|
||||
hosts:
|
||||
- api.anthropic.com
|
||||
`)
|
||||
got := loader.MITMHostOwners()
|
||||
Expect(got.Conflicts).To(HaveKey("api.anthropic.com"))
|
||||
Expect(got.Conflicts["api.anthropic.com"]).To(ConsistOf("alpha", "beta"))
|
||||
})
|
||||
|
||||
It("treats the same host listed twice within ONE config as a no-op (not a conflict)", func() {
|
||||
// A single config repeating a host is benign — same owner
|
||||
// either way. The conflict signal must be cross-config only.
|
||||
writeYAML("dup", `name: dup
|
||||
backend: llama-cpp
|
||||
mitm:
|
||||
hosts:
|
||||
- api.example.com
|
||||
- api.example.com
|
||||
`)
|
||||
got := loader.MITMHostOwners()
|
||||
Expect(got.Owners).To(Equal(map[string]string{"api.example.com": "dup"}))
|
||||
Expect(got.Conflicts).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("ignores empty/whitespace-only host entries", func() {
|
||||
writeYAML("sloppy", `name: sloppy
|
||||
backend: llama-cpp
|
||||
mitm:
|
||||
hosts:
|
||||
- ""
|
||||
- " "
|
||||
- api.real.com
|
||||
`)
|
||||
got := loader.MITMHostOwners()
|
||||
Expect(got.Owners).To(Equal(map[string]string{"api.real.com": "sloppy"}))
|
||||
})
|
||||
})
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
@@ -96,330 +95,8 @@ type ModelConfig struct {
|
||||
Options []string `yaml:"options,omitempty" json:"options,omitempty"`
|
||||
Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`
|
||||
|
||||
MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"`
|
||||
Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"`
|
||||
PII PIIConfig `yaml:"pii,omitempty" json:"pii,omitempty"`
|
||||
Router RouterConfig `yaml:"router,omitempty" json:"router,omitempty"`
|
||||
Proxy ProxyConfig `yaml:"proxy,omitempty" json:"proxy,omitempty"`
|
||||
MITM MITMModelConfig `yaml:"mitm,omitempty" json:"mitm,omitempty"`
|
||||
Limits LimitsConfig `yaml:"limits,omitempty" json:"limits,omitempty"`
|
||||
}
|
||||
|
||||
// @Description Admission-control limits applied per request. The
|
||||
// admission middleware enforces these before invoking the handler;
|
||||
// requests that exceed a limit get 503 with a Retry-After hint so
|
||||
// clients back off rather than pile on. Per-model so cloud passthroughs
|
||||
// can have a stricter ceiling than local models.
|
||||
type LimitsConfig struct {
|
||||
// MaxConcurrent caps simultaneous in-flight requests for this
|
||||
// model. 0 = unlimited (default). Useful for cloud-passthrough
|
||||
// configs where the upstream rate-limits aggressively, or for
|
||||
// local backends whose memory budget tops out before LocalAI's
|
||||
// queue depth would.
|
||||
MaxConcurrent int `yaml:"max_concurrent,omitempty" json:"max_concurrent,omitempty"`
|
||||
|
||||
// RetryAfterSeconds advises clients how long to wait before
|
||||
// retrying when admission rejects. 0 defaults to 1s — enough to
|
||||
// let an in-flight request finish on a busy local model. The
|
||||
// value is sent verbatim in the Retry-After response header.
|
||||
RetryAfterSeconds int `yaml:"retry_after_seconds,omitempty" json:"retry_after_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// @Description MITM intercept binding for the model. When the cloudproxy
|
||||
// MITM listener is enabled and any host listed here appears in a CONNECT,
|
||||
// the proxy uses THIS model config's pii: settings to filter the
|
||||
// intercepted body. Strict 1-to-1: a host claimed by two configs is a
|
||||
// configuration error and disables the MITM listener until resolved.
|
||||
//
|
||||
// Lets an admin pair a host (api.anthropic.com) with the model's
|
||||
// PII overrides without maintaining a parallel per-host map.
|
||||
type MITMModelConfig struct {
|
||||
// Hosts is the list of hostnames this model claims for MITM
|
||||
// interception. Each entry must be unique across all model configs.
|
||||
Hosts []string `yaml:"hosts,omitempty" json:"hosts,omitempty"`
|
||||
}
|
||||
|
||||
// @Description Cloud proxy configuration. The cloud-proxy backend
|
||||
// forwards a model's traffic to an external provider. Two modes:
|
||||
//
|
||||
// - mode: passthrough — client and upstream must speak the same wire
|
||||
// format; the backend ships the raw request body to the upstream
|
||||
// URL and streams the response back untouched. The streaming PII
|
||||
// filter still runs because it operates on extracted token text.
|
||||
//
|
||||
// - mode: translate — the backend converts LocalAI's internal proto
|
||||
// to the provider's wire format and back. Unlocks cross-provider
|
||||
// routing (OpenAI client → Anthropic upstream, etc.) at the cost
|
||||
// of dropping provider-specific extensions that the internal proto
|
||||
// doesn't model.
|
||||
type ProxyConfig struct {
|
||||
// UpstreamURL is the full POST endpoint, e.g.
|
||||
// https://api.openai.com/v1/chat/completions or
|
||||
// https://api.anthropic.com/v1/messages. Required.
|
||||
UpstreamURL string `yaml:"upstream_url,omitempty" json:"upstream_url,omitempty"`
|
||||
|
||||
// Mode selects passthrough (wire-perfect) or translate (full
|
||||
// control via internal proto). Empty defaults to passthrough.
|
||||
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
||||
|
||||
// Provider identifies the upstream's wire format for translate
|
||||
// mode (openai, anthropic). Ignored in passthrough mode — the
|
||||
// wire format there is whatever the client sent.
|
||||
Provider string `yaml:"provider,omitempty" json:"provider,omitempty"`
|
||||
|
||||
// APIKeyEnv names the environment variable holding the upstream
|
||||
// API key. Mutually exclusive with APIKeyFile. Both empty is
|
||||
// allowed (no-auth upstreams).
|
||||
APIKeyEnv string `yaml:"api_key_env,omitempty" json:"api_key_env,omitempty"`
|
||||
|
||||
// APIKeyFile is a path to a file whose contents are the upstream
|
||||
// API key. Trailing whitespace is trimmed. Mutually exclusive
|
||||
// with APIKeyEnv. The integration point for K8s secret mounts,
|
||||
// Vault agent files, and similar external-secret workflows.
|
||||
APIKeyFile string `yaml:"api_key_file,omitempty" json:"api_key_file,omitempty"`
|
||||
|
||||
// UpstreamModel overrides the model name sent to the upstream.
|
||||
// Useful when the LocalAI-facing model alias differs from the
|
||||
// upstream's canonical name (e.g. local "claude-strict" maps to
|
||||
// upstream "claude-3-5-sonnet-20241022"). Empty means forward
|
||||
// the client's model field unchanged.
|
||||
UpstreamModel string `yaml:"upstream_model,omitempty" json:"upstream_model,omitempty"`
|
||||
|
||||
// RequestTimeoutSeconds caps the upstream request duration. 0
|
||||
// means no per-request timeout (only the request context, which
|
||||
// is bound to the client connection, applies).
|
||||
RequestTimeoutSeconds int `yaml:"request_timeout_seconds,omitempty" json:"request_timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// Proxy mode names. Validate() normalises an empty Mode to
|
||||
// ProxyModePassthrough so downstream code only sees concrete values.
|
||||
const (
|
||||
ProxyModePassthrough = "passthrough"
|
||||
ProxyModeTranslate = "translate"
|
||||
)
|
||||
|
||||
// Proxy provider names. Only meaningful in translate mode, where the
|
||||
// cloud-proxy backend picks the wire format to use against the
|
||||
// upstream URL.
|
||||
const (
|
||||
ProxyProviderOpenAI = "openai"
|
||||
ProxyProviderAnthropic = "anthropic"
|
||||
)
|
||||
|
||||
// IsCloudProxyBackendPassthrough reports whether this model uses the
|
||||
// cloud-proxy gRPC backend in passthrough mode. Empty Mode counts as
|
||||
// passthrough (SetDefaults normalises it, but Validate accepts empty
|
||||
// too — handlers should not rely on a particular call order).
|
||||
func (c *ModelConfig) IsCloudProxyBackendPassthrough() bool {
|
||||
if c.Backend != "cloud-proxy" {
|
||||
return false
|
||||
}
|
||||
return c.Proxy.Mode == "" || c.Proxy.Mode == ProxyModePassthrough
|
||||
}
|
||||
|
||||
// @Description Intelligent routing configuration. When a model declares
|
||||
// a Router block, requests addressed to it are reclassified at runtime
|
||||
// and dispatched to one of the named candidates. The router rewrites
|
||||
// input.Model in-place, then the standard model-resolution path picks
|
||||
// up the resolved config — meaning ACL checks, disabled-state, and
|
||||
// per-model PII still run against the chosen target.
|
||||
//
|
||||
// Depth-1 invariant: candidates must NOT themselves carry a Router
|
||||
// block. The router's "smart-router → claude-strict → cloud-proxy"
|
||||
// chain is fine, but "router-A → router-B → claude" is rejected at
|
||||
// config load to keep the dispatch graph acyclic and predictable. The
|
||||
// middleware also asserts depth ≤ 1 at runtime as a defensive check.
|
||||
type RouterConfig struct {
|
||||
// Classifier picks the implementation. Only "score" ships today:
|
||||
// it asks the classifier model to score every Policy label as a
|
||||
// continuation of the routing prompt and reads off the
|
||||
// distribution. Empty defaults to "score".
|
||||
Classifier string `yaml:"classifier,omitempty" json:"classifier,omitempty"`
|
||||
|
||||
// Policies is the label vocabulary the classifier scores over.
|
||||
// Each policy carries a natural-language description that ends up
|
||||
// in the system prompt the classifier model sees — short, action-
|
||||
// oriented sentences work best ("writing or debugging code",
|
||||
// "small talk", ...). The Score classifier picks the subset of
|
||||
// labels whose softmax probability passes ActivationThreshold.
|
||||
Policies []RouterPolicy `yaml:"policies,omitempty" json:"policies,omitempty"`
|
||||
|
||||
// Candidates is the routing table — each entry binds a downstream
|
||||
// model to a set of labels it can serve. The middleware picks the
|
||||
// FIRST candidate whose Labels are a superset of the active label
|
||||
// set from the classifier. Admins order this list smallest →
|
||||
// largest so a query that needs one label routes to the smallest
|
||||
// capable model, while a query that needs multiple falls to a
|
||||
// bigger candidate that covers them all.
|
||||
Candidates []RouterCandidate `yaml:"candidates,omitempty" json:"candidates,omitempty"`
|
||||
|
||||
// Fallback is the model used when no candidate matches the active
|
||||
// label set, or when the classifier returns nothing above
|
||||
// threshold. Empty fallback means router failures bubble up as
|
||||
// 500 — fail-fast, not silent-bypass.
|
||||
Fallback string `yaml:"fallback,omitempty" json:"fallback,omitempty"`
|
||||
|
||||
// ClassifierModel names the model the Score classifier scores
|
||||
// against (Arch-Router-1.5B is the canonical choice).
|
||||
ClassifierModel string `yaml:"classifier_model,omitempty" json:"classifier_model,omitempty"`
|
||||
|
||||
// ClassifierCacheSize bounds the per-prompt memo cache that
|
||||
// amortises the classifier round-trip across repeat probes.
|
||||
// 0 disables the cache. Default 1024.
|
||||
ClassifierCacheSize int `yaml:"classifier_cache_size,omitempty" json:"classifier_cache_size,omitempty"`
|
||||
|
||||
// ActivationThreshold is the softmax-probability floor a policy
|
||||
// must clear to be considered "active" for the request. 0
|
||||
// defaults to a sensible value (~0.15) inside the classifier.
|
||||
// Higher → narrower routes (single-label dominant); lower →
|
||||
// more multi-label activations.
|
||||
ActivationThreshold float64 `yaml:"activation_threshold,omitempty" json:"activation_threshold,omitempty"`
|
||||
|
||||
// ClassifierSystemTemplate overrides the routing system prompt
|
||||
// the score classifier feeds to its classifier_model. Go
|
||||
// text/template + Sprig, executed with `.Policies []ScorePolicy`
|
||||
// (Label + Description fields). Empty falls back to the built-in
|
||||
// Arch-Router-shaped template (route-listing block + JSON output
|
||||
// schema). Override when the classifier model was trained on a
|
||||
// different schema (e.g. bare label output, XML route block) or
|
||||
// when the routing instructions need to be in a different
|
||||
// language. The candidate format scored against the model is
|
||||
// fixed at `{"route": "<label>"}` and IS NOT templated — keep
|
||||
// your override's output schema instruction matching that, or
|
||||
// the per-candidate scores degenerate.
|
||||
ClassifierSystemTemplate string `yaml:"classifier_system_template,omitempty" json:"classifier_system_template,omitempty"`
|
||||
|
||||
// ScoreNormalization picks how the score classifier collapses
|
||||
// per-candidate joint log-probs into the softmax input.
|
||||
// - ""/"raw": use joint log-prob as-is (default). Matches the
|
||||
// distribution the classifier model was trained against — the
|
||||
// route the model would actually emit if decoded freely.
|
||||
// - "mean": divide by candidate token count. Fairer to long
|
||||
// labels (their joint log-prob is mechanically smaller because
|
||||
// it sums more negatives), but off-distribution for models
|
||||
// trained to emit fixed-format outputs like Arch-Router's
|
||||
// {"route": "name"}.
|
||||
// Future modes (e.g. "weighted_mean") will land here too.
|
||||
ScoreNormalization string `yaml:"score_normalization,omitempty" json:"score_normalization,omitempty"`
|
||||
|
||||
// EmbeddingCache configures the L2 cache that maps prompt
|
||||
// embeddings to past decisions, so semantically-similar prompts
|
||||
// reuse a classification instead of re-running the classifier
|
||||
// model. Omit the block to disable. See router/embedding_cache.go.
|
||||
EmbeddingCache *EmbeddingCacheConfig `yaml:"embedding_cache,omitempty" json:"embedding_cache,omitempty"`
|
||||
}
|
||||
|
||||
// EmbeddingCacheConfig configures the L2 embedding-similarity decision
|
||||
// cache. Pairs naturally with a larger / slower classifier model: the
|
||||
// classifier round-trip is amortised across paraphrases of the same
|
||||
// intent. The cache uses the standard /v1/embeddings backend for
|
||||
// vector generation and the local-store gRPC surface for KNN search.
|
||||
type EmbeddingCacheConfig struct {
|
||||
// EmbeddingModel names the loaded LocalAI model used to embed
|
||||
// router prompts. Required when the cache is enabled. Any model
|
||||
// that supports the Embeddings gRPC primitive works;
|
||||
// nomic-embed-text-v1.5 is the recommended default.
|
||||
EmbeddingModel string `yaml:"embedding_model" json:"embedding_model"`
|
||||
|
||||
// SimilarityThreshold is the cosine-similarity floor a cache
|
||||
// candidate must clear to be treated as a hit. 0 picks the
|
||||
// package default (0.80). Higher → fewer false hits, higher miss
|
||||
// rate; lower → more aggressive sharing across paraphrases.
|
||||
SimilarityThreshold float64 `yaml:"similarity_threshold,omitempty" json:"similarity_threshold,omitempty"`
|
||||
|
||||
// ConfidenceThreshold is the minimum classifier top-label
|
||||
// probability for a decision to be inserted into the cache. 0
|
||||
// picks the package default (0.60). Uncertain decisions are not
|
||||
// cached so they can't poison future paraphrases.
|
||||
ConfidenceThreshold float64 `yaml:"confidence_threshold,omitempty" json:"confidence_threshold,omitempty"`
|
||||
|
||||
// StoreName overrides the local-store collection name used for
|
||||
// this router's cache. Empty defaults to "router-cache-<router>"
|
||||
// where <router> is the parent model name. Useful when two
|
||||
// router models should share a cache (rare).
|
||||
StoreName string `yaml:"store_name,omitempty" json:"store_name,omitempty"`
|
||||
}
|
||||
|
||||
// RouterPolicy is one entry in the label vocabulary. The label string
|
||||
// is what the classifier model emits and what candidates reference in
|
||||
// their Labels field; the description is the natural-language hint
|
||||
// fed to the classifier so it can match user intent against the label
|
||||
// space.
|
||||
type RouterPolicy struct {
|
||||
Label string `yaml:"label" json:"label"`
|
||||
Description string `yaml:"description" json:"description"`
|
||||
}
|
||||
|
||||
// RouterCandidate names a downstream model and the policy labels it
|
||||
// is willing to serve. Labels are matched as a set: the middleware
|
||||
// picks the first candidate whose Labels is a superset of the
|
||||
// classifier's active set.
|
||||
type RouterCandidate struct {
|
||||
Model string `yaml:"model" json:"model"`
|
||||
Labels []string `yaml:"labels" json:"labels"`
|
||||
}
|
||||
|
||||
// HasRouter returns true when the model declares a router config with
|
||||
// at least one candidate. Used by the RouteModel middleware to decide
|
||||
// whether to engage the classifier.
|
||||
func (c *ModelConfig) HasRouter() bool {
|
||||
return len(c.Router.Candidates) > 0
|
||||
}
|
||||
|
||||
// @Description PII filtering configuration. PII redaction is per-model so
|
||||
// that local models don't pay the latency or behaviour change of regex
|
||||
// scanning, while cloud-bound traffic (cloud-proxy backend) can default to
|
||||
// on. Setting Enabled explicitly always wins over the backend default.
|
||||
type PIIConfig struct {
|
||||
// Enabled toggles redaction for this model. When unset (zero value),
|
||||
// the resolved default depends on Backend: cloud-proxy defaults to
|
||||
// true, everything else to false. A pointer is used so the absence of
|
||||
// the YAML key is distinguishable from explicit false.
|
||||
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
|
||||
|
||||
// Patterns lets a model upgrade or downgrade individual pattern
|
||||
// actions (mask | block | route_local) relative to the global
|
||||
// defaults loaded from --pii-config / DefaultPatterns. Pattern IDs
|
||||
// not listed inherit the global action. The regex itself stays
|
||||
// global — only the action is settable per-model.
|
||||
Patterns []PIIPatternOverride `yaml:"patterns,omitempty" json:"patterns,omitempty"`
|
||||
}
|
||||
|
||||
// @Description Per-model action override for a single PII pattern.
|
||||
type PIIPatternOverride struct {
|
||||
ID string `yaml:"id" json:"id"`
|
||||
Action string `yaml:"action" json:"action"`
|
||||
}
|
||||
|
||||
// PIIIsEnabled returns the resolved PII state for this model. Single
|
||||
// source of truth for the gating decision so the middleware and the
|
||||
// /api/middleware/status admin view agree.
|
||||
func (c *ModelConfig) PIIIsEnabled() bool {
|
||||
if c.PII.Enabled != nil {
|
||||
return *c.PII.Enabled
|
||||
}
|
||||
return c.Backend == "cloud-proxy"
|
||||
}
|
||||
|
||||
// PIIPatternOverrides returns the per-pattern action overrides as a map
|
||||
// keyed by pattern ID. The values are the raw action strings — the pii
|
||||
// package validates and converts them.
|
||||
//
|
||||
// Returned via the documented modelPIIConfig interface in
|
||||
// core/services/routing/pii/middleware.go without taking a config
|
||||
// dependency on this package.
|
||||
func (c *ModelConfig) PIIPatternOverrides() map[string]string {
|
||||
if len(c.PII.Patterns) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]string, len(c.PII.Patterns))
|
||||
for _, p := range c.PII.Patterns {
|
||||
if p.ID == "" {
|
||||
continue
|
||||
}
|
||||
out[p.ID] = p.Action
|
||||
}
|
||||
return out
|
||||
MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"`
|
||||
Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"`
|
||||
}
|
||||
|
||||
// @Description MCP configuration
|
||||
@@ -532,7 +209,7 @@ type LLMConfig struct {
|
||||
RMSNormEps float32 `yaml:"rms_norm_eps,omitempty" json:"rms_norm_eps,omitempty"`
|
||||
NGQA int32 `yaml:"ngqa,omitempty" json:"ngqa,omitempty"`
|
||||
PromptCachePath string `yaml:"prompt_cache_path,omitempty" json:"prompt_cache_path,omitempty"`
|
||||
PromptCacheAll *bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"`
|
||||
PromptCacheAll bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"`
|
||||
PromptCacheRO bool `yaml:"prompt_cache_ro,omitempty" json:"prompt_cache_ro,omitempty"`
|
||||
MirostatETA *float64 `yaml:"mirostat_eta,omitempty" json:"mirostat_eta,omitempty"`
|
||||
MirostatTAU *float64 `yaml:"mirostat_tau,omitempty" json:"mirostat_tau,omitempty"`
|
||||
@@ -724,14 +401,6 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
|
||||
f16 := lo.f16
|
||||
debug := lo.debug
|
||||
|
||||
// Cloud-proxy: normalise empty Mode so downstream consumers
|
||||
// switch on two concrete values only. Validate accepts empty too,
|
||||
// but SetDefaults is the chokepoint that runs before any
|
||||
// inference path reads cfg.Proxy.Mode.
|
||||
if cfg.Proxy.Mode == "" {
|
||||
cfg.Proxy.Mode = ProxyModePassthrough
|
||||
}
|
||||
|
||||
// Apply model-family-specific inference defaults before generic fallbacks.
|
||||
// This ensures gallery-installed and runtime-loaded models get optimal parameters.
|
||||
ApplyInferenceDefaults(cfg, cfg.Name, cfg.Model)
|
||||
@@ -825,13 +494,6 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
|
||||
cfg.Reranking = &falseV
|
||||
}
|
||||
|
||||
if cfg.PromptCacheAll == nil {
|
||||
// Match upstream llama.cpp's default (common/common.h: cache_prompt = true)
|
||||
// and let cache_idle_slots / kv_unified actually do useful work; users can
|
||||
// opt out with an explicit `prompt_cache_all: false` in the model YAML.
|
||||
cfg.PromptCacheAll = &trueV
|
||||
}
|
||||
|
||||
if threads == 0 {
|
||||
// Threads can't be 0
|
||||
threads = 4
|
||||
@@ -904,74 +566,9 @@ func (c *ModelConfig) Validate() (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Cloud-proxy: at most one of api_key_env / api_key_file may be
|
||||
// set. Both empty means no Authorization header (no-auth upstream
|
||||
// or a development passthrough). The mode field accepts the empty
|
||||
// string (defaults to passthrough), "passthrough", or "translate".
|
||||
if c.Proxy.APIKeyEnv != "" && c.Proxy.APIKeyFile != "" {
|
||||
return false, fmt.Errorf("proxy: api_key_env and api_key_file are mutually exclusive")
|
||||
}
|
||||
switch c.Proxy.Mode {
|
||||
case "", ProxyModePassthrough, ProxyModeTranslate:
|
||||
// Empty is accepted at validate-time and normalised to
|
||||
// passthrough by SetDefaults so it never reaches runtime.
|
||||
default:
|
||||
return false, fmt.Errorf("proxy: unknown mode %q (expected %s or %s)",
|
||||
c.Proxy.Mode, ProxyModePassthrough, ProxyModeTranslate)
|
||||
}
|
||||
if c.Proxy.Mode == ProxyModeTranslate && c.Proxy.Provider == "" {
|
||||
return false, fmt.Errorf("proxy: translate mode requires provider (%s, %s)",
|
||||
ProxyProviderOpenAI, ProxyProviderAnthropic)
|
||||
}
|
||||
|
||||
// Score on llama-cpp bypasses the slot loop and races the
|
||||
// llama_context against concurrent generation/embedding traffic
|
||||
// (see backend/cpp/llama-cpp/grpc-server.cpp on Score). Reject the
|
||||
// combination here so operators are forced to split the model.
|
||||
const scoreConflicts = FLAG_CHAT | FLAG_COMPLETION | FLAG_EMBEDDINGS
|
||||
if (c.Backend == "llama-cpp" || c.Backend == "llama") &&
|
||||
c.HasUsecases(FLAG_SCORE) && c.KnownUsecases != nil &&
|
||||
*c.KnownUsecases&scoreConflicts != 0 {
|
||||
return false, fmt.Errorf(
|
||||
"known_usecases conflict on llama-cpp: score is incompatible " +
|
||||
"with chat/completion/embeddings — split into separate model configs")
|
||||
}
|
||||
|
||||
// router.score_normalization is consumed lazily by the score
|
||||
// classifier at first-request time; without load-time validation
|
||||
// a typo wouldn't surface until the first router request panicked
|
||||
// inside NewScoreClassifier. Reject unknown values here so the
|
||||
// operator sees the offending key at startup.
|
||||
switch c.Router.ScoreNormalization {
|
||||
case "", ScoreNormalizationRaw, ScoreNormalizationMean:
|
||||
// ok
|
||||
default:
|
||||
return false, fmt.Errorf("router: unknown score_normalization %q (expected %q or %q)",
|
||||
c.Router.ScoreNormalization, ScoreNormalizationRaw, ScoreNormalizationMean)
|
||||
}
|
||||
|
||||
// router.classifier_system_template parses as Go text/template
|
||||
// (Sprig funcs available at execution time). Reject malformed
|
||||
// templates at load time so the operator sees the parse error
|
||||
// at startup rather than as a 500 on the first router request.
|
||||
if c.Router.ClassifierSystemTemplate != "" {
|
||||
if _, err := template.New("classifier_system").Parse(c.Router.ClassifierSystemTemplate); err != nil {
|
||||
return false, fmt.Errorf("router: classifier_system_template parse error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Score normalisation modes mirror router.ScoreNormalization* —
|
||||
// duplicated as constants on the config package so ModelConfig.Validate
|
||||
// can reject unknown values without taking a dependency on the router
|
||||
// package (which already depends on config).
|
||||
const (
|
||||
ScoreNormalizationRaw = "raw"
|
||||
ScoreNormalizationMean = "mean"
|
||||
)
|
||||
|
||||
func (c *ModelConfig) HasTemplate() bool {
|
||||
return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != "" || c.TemplateConfig.UseTokenizerTemplate
|
||||
}
|
||||
@@ -1020,19 +617,19 @@ func (c *ModelConfig) GetConcurrencyGroups() []string {
|
||||
type ModelConfigUsecase int
|
||||
|
||||
const (
|
||||
FLAG_ANY ModelConfigUsecase = 0b000000000000
|
||||
FLAG_CHAT ModelConfigUsecase = 0b000000000001
|
||||
FLAG_COMPLETION ModelConfigUsecase = 0b000000000010
|
||||
FLAG_EDIT ModelConfigUsecase = 0b000000000100
|
||||
FLAG_EMBEDDINGS ModelConfigUsecase = 0b000000001000
|
||||
FLAG_RERANK ModelConfigUsecase = 0b000000010000
|
||||
FLAG_IMAGE ModelConfigUsecase = 0b000000100000
|
||||
FLAG_TRANSCRIPT ModelConfigUsecase = 0b000001000000
|
||||
FLAG_TTS ModelConfigUsecase = 0b000010000000
|
||||
FLAG_SOUND_GENERATION ModelConfigUsecase = 0b000100000000
|
||||
FLAG_TOKENIZE ModelConfigUsecase = 0b001000000000
|
||||
FLAG_VAD ModelConfigUsecase = 0b010000000000
|
||||
FLAG_VIDEO ModelConfigUsecase = 0b100000000000
|
||||
FLAG_ANY ModelConfigUsecase = 0b000000000000
|
||||
FLAG_CHAT ModelConfigUsecase = 0b000000000001
|
||||
FLAG_COMPLETION ModelConfigUsecase = 0b000000000010
|
||||
FLAG_EDIT ModelConfigUsecase = 0b000000000100
|
||||
FLAG_EMBEDDINGS ModelConfigUsecase = 0b000000001000
|
||||
FLAG_RERANK ModelConfigUsecase = 0b000000010000
|
||||
FLAG_IMAGE ModelConfigUsecase = 0b000000100000
|
||||
FLAG_TRANSCRIPT ModelConfigUsecase = 0b000001000000
|
||||
FLAG_TTS ModelConfigUsecase = 0b000010000000
|
||||
FLAG_SOUND_GENERATION ModelConfigUsecase = 0b000100000000
|
||||
FLAG_TOKENIZE ModelConfigUsecase = 0b001000000000
|
||||
FLAG_VAD ModelConfigUsecase = 0b010000000000
|
||||
FLAG_VIDEO ModelConfigUsecase = 0b100000000000
|
||||
FLAG_DETECTION ModelConfigUsecase = 0b1000000000000
|
||||
FLAG_VISION ModelConfigUsecase = 0b10000000000000
|
||||
FLAG_FACE_RECOGNITION ModelConfigUsecase = 0b100000000000000
|
||||
@@ -1040,14 +637,6 @@ const (
|
||||
FLAG_AUDIO_TRANSFORM ModelConfigUsecase = 0b10000000000000000
|
||||
FLAG_DIARIZATION ModelConfigUsecase = 0b100000000000000000
|
||||
FLAG_REALTIME_AUDIO ModelConfigUsecase = 0b1000000000000000000
|
||||
// Marks a model as wired for the Score gRPC primitive (joint
|
||||
// log-prob of candidate continuations under a shared prompt). Must
|
||||
// be declared explicitly via `known_usecases: [score]` — there's
|
||||
// no heuristic for it. On the llama-cpp backend, Score bypasses
|
||||
// the slot loop and races the llama_context, so Validate() refuses
|
||||
// to load a llama-cpp config that combines FLAG_SCORE with
|
||||
// chat/completion/embeddings.
|
||||
FLAG_SCORE ModelConfigUsecase = 0b10000000000000000000
|
||||
|
||||
// Common Subsets
|
||||
FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
|
||||
@@ -1057,12 +646,12 @@ const (
|
||||
// Flags within the same group are NOT orthogonal (e.g., chat and completion are
|
||||
// both text/language). A model is multimodal when its usecases span 2+ groups.
|
||||
var ModalityGroups = []ModelConfigUsecase{
|
||||
FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT, // text/language
|
||||
FLAG_VISION | FLAG_DETECTION, // visual understanding
|
||||
FLAG_TRANSCRIPT | FLAG_REALTIME_AUDIO, // speech input — realtime_audio is any-to-any, so it counts here too
|
||||
FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT, // text/language
|
||||
FLAG_VISION | FLAG_DETECTION, // visual understanding
|
||||
FLAG_TRANSCRIPT | FLAG_REALTIME_AUDIO, // speech input — realtime_audio is any-to-any, so it counts here too
|
||||
FLAG_TTS | FLAG_SOUND_GENERATION | FLAG_REALTIME_AUDIO, // audio output — and here, so a lone realtime_audio flag still reads as multimodal
|
||||
FLAG_AUDIO_TRANSFORM, // audio in/out transforms
|
||||
FLAG_IMAGE | FLAG_VIDEO, // visual generation
|
||||
FLAG_AUDIO_TRANSFORM, // audio in/out transforms
|
||||
FLAG_IMAGE | FLAG_VIDEO, // visual generation
|
||||
}
|
||||
|
||||
// IsMultimodal returns true if the given usecases span two or more orthogonal
|
||||
@@ -1085,19 +674,19 @@ func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
|
||||
return map[string]ModelConfigUsecase{
|
||||
// Note: FLAG_ANY is intentionally excluded from this map
|
||||
// because it's 0 and would always match in HasUsecases checks
|
||||
"FLAG_CHAT": FLAG_CHAT,
|
||||
"FLAG_COMPLETION": FLAG_COMPLETION,
|
||||
"FLAG_EDIT": FLAG_EDIT,
|
||||
"FLAG_EMBEDDINGS": FLAG_EMBEDDINGS,
|
||||
"FLAG_RERANK": FLAG_RERANK,
|
||||
"FLAG_IMAGE": FLAG_IMAGE,
|
||||
"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
|
||||
"FLAG_TTS": FLAG_TTS,
|
||||
"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
|
||||
"FLAG_TOKENIZE": FLAG_TOKENIZE,
|
||||
"FLAG_VAD": FLAG_VAD,
|
||||
"FLAG_LLM": FLAG_LLM,
|
||||
"FLAG_VIDEO": FLAG_VIDEO,
|
||||
"FLAG_CHAT": FLAG_CHAT,
|
||||
"FLAG_COMPLETION": FLAG_COMPLETION,
|
||||
"FLAG_EDIT": FLAG_EDIT,
|
||||
"FLAG_EMBEDDINGS": FLAG_EMBEDDINGS,
|
||||
"FLAG_RERANK": FLAG_RERANK,
|
||||
"FLAG_IMAGE": FLAG_IMAGE,
|
||||
"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
|
||||
"FLAG_TTS": FLAG_TTS,
|
||||
"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
|
||||
"FLAG_TOKENIZE": FLAG_TOKENIZE,
|
||||
"FLAG_VAD": FLAG_VAD,
|
||||
"FLAG_LLM": FLAG_LLM,
|
||||
"FLAG_VIDEO": FLAG_VIDEO,
|
||||
"FLAG_DETECTION": FLAG_DETECTION,
|
||||
"FLAG_VISION": FLAG_VISION,
|
||||
"FLAG_FACE_RECOGNITION": FLAG_FACE_RECOGNITION,
|
||||
@@ -1105,7 +694,6 @@ func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
|
||||
"FLAG_AUDIO_TRANSFORM": FLAG_AUDIO_TRANSFORM,
|
||||
"FLAG_DIARIZATION": FLAG_DIARIZATION,
|
||||
"FLAG_REALTIME_AUDIO": FLAG_REALTIME_AUDIO,
|
||||
"FLAG_SCORE": FLAG_SCORE,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1131,23 +719,9 @@ func GetUsecasesFromYAML(input []string) *ModelConfigUsecase {
|
||||
}
|
||||
|
||||
// HasUsecases examines a ModelConfig and determines which endpoints have a chance of success.
|
||||
//
|
||||
// Declared known_usecases are normally additive — the guessing heuristic
|
||||
// still adds whatever it can infer from backend/templates. The one
|
||||
// exception is FLAG_SCORE: when the operator declared score, they
|
||||
// reserved the model for the router classifier. Letting GuessUsecases
|
||||
// paint chat/completion on top would surface it in chat pickers it was
|
||||
// deliberately kept out of, and (on llama-cpp) reintroduce the slot
|
||||
// contention the score/chat conflict check exists to prevent. So a
|
||||
// declared score list is authoritative.
|
||||
func (c *ModelConfig) HasUsecases(u ModelConfigUsecase) bool {
|
||||
if c.KnownUsecases != nil {
|
||||
if (u & *c.KnownUsecases) == u {
|
||||
return true
|
||||
}
|
||||
if (*c.KnownUsecases & FLAG_SCORE) == FLAG_SCORE {
|
||||
return false
|
||||
}
|
||||
if (c.KnownUsecases != nil) && ((u & *c.KnownUsecases) == u) {
|
||||
return true
|
||||
}
|
||||
return c.GuessUsecases(u)
|
||||
}
|
||||
@@ -1304,14 +878,6 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_SCORE) == FLAG_SCORE {
|
||||
// No heuristic: Score-intent is a deliberate operator choice
|
||||
// (it reserves the model from generation traffic on llama-cpp),
|
||||
// so HasUsecases(FLAG_SCORE) is true only when KnownUsecases
|
||||
// declares it explicitly.
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -388,49 +388,6 @@ func (bcl *ModelConfigLoader) Preload(modelPath string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MITMHostOwnership is the result of mapping intercept hosts to the
|
||||
// model configs that claim them. The invariant the dispatcher relies
|
||||
// on: every host belongs to AT MOST one model config. Any duplicate
|
||||
// is surfaced via Conflicts and disables the MITM listener until
|
||||
// resolved — a half-applied "first wins" rule would silently mask
|
||||
// configuration drift, so we fail loud.
|
||||
type MITMHostOwnership struct {
|
||||
// Owners maps lowercase hostname → owning model name. Empty when
|
||||
// no model declares mitm.hosts.
|
||||
Owners map[string]string
|
||||
// Conflicts lists hosts claimed by 2+ configs, with the names of
|
||||
// the configs that claim them. Non-empty Conflicts means callers
|
||||
// must NOT start the MITM listener.
|
||||
Conflicts map[string][]string
|
||||
}
|
||||
|
||||
// MITMHostOwners walks every loaded ModelConfig's mitm.hosts, builds
|
||||
// the host→owner index, and reports any duplicates. The lookup table
|
||||
// is hostname-lowercased to match the Server's allowlist semantics.
|
||||
func (bcl *ModelConfigLoader) MITMHostOwners() MITMHostOwnership {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
owners := map[string]string{}
|
||||
collisions := map[string][]string{}
|
||||
for name, cfg := range bcl.configs {
|
||||
for _, h := range cfg.MITM.Hosts {
|
||||
h = strings.ToLower(strings.TrimSpace(h))
|
||||
if h == "" {
|
||||
continue
|
||||
}
|
||||
if existing, ok := owners[h]; ok && existing != name {
|
||||
if _, seen := collisions[h]; !seen {
|
||||
collisions[h] = []string{existing}
|
||||
}
|
||||
collisions[h] = append(collisions[h], name)
|
||||
continue
|
||||
}
|
||||
owners[h] = name
|
||||
}
|
||||
}
|
||||
return MITMHostOwnership{Owners: owners, Conflicts: collisions}
|
||||
}
|
||||
|
||||
// LoadModelConfigsFromPath reads all the configurations of the models from a path
|
||||
// (non-recursive)
|
||||
func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
|
||||
|
||||
@@ -54,141 +54,6 @@ parameters:
|
||||
Expect(err).To(BeNil())
|
||||
Expect(valid).To(BeTrue())
|
||||
|
||||
// llama-cpp configs can't mix the score usecase with
|
||||
// chat/completion/embeddings — Score bypasses the slot
|
||||
// loop and would race the llama_context. The check fires
|
||||
// at load and save time; here we exercise it directly.
|
||||
scoreFlag := FLAG_SCORE | FLAG_CHAT
|
||||
conflicting := ModelConfig{
|
||||
Name: "router-but-also-chat",
|
||||
Backend: "llama-cpp",
|
||||
KnownUsecases: &scoreFlag,
|
||||
}
|
||||
valid, err = conflicting.Validate()
|
||||
Expect(valid).To(BeFalse())
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("score is incompatible"))
|
||||
|
||||
scoreOnly := FLAG_SCORE
|
||||
dedicated := ModelConfig{
|
||||
Name: "router-only",
|
||||
Backend: "llama-cpp",
|
||||
KnownUsecases: &scoreOnly,
|
||||
}
|
||||
valid, err = dedicated.Validate()
|
||||
Expect(valid).To(BeTrue())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// The constraint is llama-cpp-specific; other backends
|
||||
// may safely combine.
|
||||
scoreAndChat := FLAG_SCORE | FLAG_CHAT
|
||||
otherBackend := ModelConfig{
|
||||
Name: "vllm-router-and-chat",
|
||||
Backend: "vllm",
|
||||
KnownUsecases: &scoreAndChat,
|
||||
}
|
||||
valid, err = otherBackend.Validate()
|
||||
Expect(valid).To(BeTrue())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Cloud-proxy: api_key_env and api_key_file are mutually
|
||||
// exclusive — picking both is a config bug we catch at
|
||||
// load/save rather than at backend-load time.
|
||||
bothKeys := ModelConfig{
|
||||
Name: "both-keys",
|
||||
Backend: "cloud-proxy",
|
||||
Proxy: ProxyConfig{
|
||||
UpstreamURL: "https://example.com/v1",
|
||||
APIKeyEnv: "OPENAI_KEY",
|
||||
APIKeyFile: "/run/secrets/openai",
|
||||
},
|
||||
}
|
||||
valid, err = bothKeys.Validate()
|
||||
Expect(valid).To(BeFalse())
|
||||
Expect(err).To(MatchError(ContainSubstring("mutually exclusive")))
|
||||
|
||||
// Translate mode requires a provider — without one, the
|
||||
// backend has no way to pick a wire format.
|
||||
translateNoProvider := ModelConfig{
|
||||
Name: "translate-no-provider",
|
||||
Backend: "cloud-proxy",
|
||||
Proxy: ProxyConfig{UpstreamURL: "https://example.com/v1", Mode: ProxyModeTranslate},
|
||||
}
|
||||
valid, err = translateNoProvider.Validate()
|
||||
Expect(valid).To(BeFalse())
|
||||
Expect(err).To(MatchError(ContainSubstring("translate mode requires provider")))
|
||||
|
||||
// Unknown mode is rejected.
|
||||
badMode := ModelConfig{
|
||||
Name: "bad-mode",
|
||||
Backend: "cloud-proxy",
|
||||
Proxy: ProxyConfig{UpstreamURL: "https://example.com/v1", Mode: "rewrite"},
|
||||
}
|
||||
valid, err = badMode.Validate()
|
||||
Expect(valid).To(BeFalse())
|
||||
Expect(err).To(MatchError(ContainSubstring("unknown mode")))
|
||||
|
||||
// Passthrough (default) with one key source is happy.
|
||||
passthroughOK := ModelConfig{
|
||||
Name: "passthrough-ok",
|
||||
Backend: "cloud-proxy",
|
||||
Proxy: ProxyConfig{UpstreamURL: "https://example.com/v1", APIKeyEnv: "OPENAI_KEY"},
|
||||
}
|
||||
valid, err = passthroughOK.Validate()
|
||||
Expect(valid).To(BeTrue())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// router.score_normalization: load-time rejection of an
|
||||
// unknown value. The classifier consumes it lazily, so
|
||||
// without this validation a YAML typo wouldn't surface
|
||||
// until the first router request panicked deep in
|
||||
// NewScoreClassifier.
|
||||
badNorm := ModelConfig{
|
||||
Name: "bad-norm",
|
||||
Router: RouterConfig{
|
||||
ScoreNormalization: "men", // typo of "mean"
|
||||
},
|
||||
}
|
||||
valid, err = badNorm.Validate()
|
||||
Expect(valid).To(BeFalse())
|
||||
Expect(err).To(MatchError(ContainSubstring("unknown score_normalization")))
|
||||
|
||||
// Accepted values pass.
|
||||
for _, mode := range []string{"", ScoreNormalizationRaw, ScoreNormalizationMean} {
|
||||
goodNorm := ModelConfig{
|
||||
Name: "good-norm-" + mode,
|
||||
Router: RouterConfig{ScoreNormalization: mode},
|
||||
}
|
||||
valid, err = goodNorm.Validate()
|
||||
Expect(valid).To(BeTrue(), "score_normalization=%q should be accepted", mode)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
|
||||
// router.classifier_system_template: parse-time rejection
|
||||
// of malformed Go templates. Same reasoning as above —
|
||||
// without this the parse error wouldn't surface until
|
||||
// the first router request panicked in NewScoreClassifier.
|
||||
badTmpl := ModelConfig{
|
||||
Name: "bad-tmpl",
|
||||
Router: RouterConfig{
|
||||
ClassifierSystemTemplate: "Routes: {{range .Policies",
|
||||
},
|
||||
}
|
||||
valid, err = badTmpl.Validate()
|
||||
Expect(valid).To(BeFalse())
|
||||
Expect(err).To(MatchError(ContainSubstring("classifier_system_template parse error")))
|
||||
|
||||
// Well-formed template passes.
|
||||
goodTmpl := ModelConfig{
|
||||
Name: "good-tmpl",
|
||||
Router: RouterConfig{
|
||||
ClassifierSystemTemplate: `Routes: {{range .Policies}}{{.Label}} {{end}}`,
|
||||
},
|
||||
}
|
||||
valid, err = goodTmpl.Validate()
|
||||
Expect(valid).To(BeTrue())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// download https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml
|
||||
httpClient := http.Client{}
|
||||
resp, err := httpClient.Get("https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml")
|
||||
@@ -303,29 +168,6 @@ parameters:
|
||||
Expect(i.HasUsecases(FLAG_TTS)).To(BeFalse())
|
||||
Expect(i.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
|
||||
Expect(i.HasUsecases(FLAG_CHAT)).To(BeTrue())
|
||||
|
||||
// Declared `known_usecases: [score]` is authoritative — the
|
||||
// guessing heuristic must NOT add chat on top, even though the
|
||||
// inherited chatml template would otherwise satisfy the chat
|
||||
// heuristic. Score means "this model is reserved for the
|
||||
// router classifier"; surfacing it as a chat model defeats the
|
||||
// reservation and reintroduces the slot contention the load-time
|
||||
// score/chat conflict check exists to prevent.
|
||||
scoreReserved := FLAG_SCORE
|
||||
j := ModelConfig{
|
||||
Name: "arch-router",
|
||||
Backend: "llama-cpp",
|
||||
KnownUsecases: &scoreReserved,
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "inherited from chatml",
|
||||
ChatMessage: "inherited from chatml",
|
||||
Completion: "inherited from chatml",
|
||||
},
|
||||
}
|
||||
Expect(j.HasUsecases(FLAG_SCORE)).To(BeTrue())
|
||||
Expect(j.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
||||
Expect(j.HasUsecases(FLAG_COMPLETION)).To(BeFalse())
|
||||
Expect(j.HasUsecases(FLAG_EMBEDDINGS)).To(BeFalse())
|
||||
})
|
||||
It("Test Validate with invalid MCP config", func() {
|
||||
tmp, err := os.CreateTemp("", "config.yaml")
|
||||
|
||||
@@ -38,7 +38,6 @@ type RuntimeSettings struct {
|
||||
Debug *bool `json:"debug,omitempty"`
|
||||
EnableTracing *bool `json:"enable_tracing,omitempty"`
|
||||
TracingMaxItems *int `json:"tracing_max_items,omitempty"`
|
||||
TracingMaxBodyBytes *int `json:"tracing_max_body_bytes,omitempty"` // Per-body cap in bytes; 0 disables the cap
|
||||
EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"`
|
||||
|
||||
// Security/CORS settings
|
||||
@@ -90,26 +89,4 @@ type RuntimeSettings struct {
|
||||
LogoFile *string `json:"logo_file,omitempty"`
|
||||
LogoHorizontalFile *string `json:"logo_horizontal_file,omitempty"`
|
||||
FaviconFile *string `json:"favicon_file,omitempty"`
|
||||
|
||||
// Cloud-proxy MITM listener. MITMCADir is intentionally NOT
|
||||
// exposed at runtime — the CA dir is a startup-only path and
|
||||
// changing it after the CA has been generated would orphan
|
||||
// trusted clients.
|
||||
MITMListen *string `json:"mitm_listen,omitempty"`
|
||||
|
||||
// PII pattern overrides — keyed by pattern id, applied to the live
|
||||
// redactor at startup and persisted by POST /api/pii/patterns/persist.
|
||||
// Distinguishes from --pii-config (which replaces the entire
|
||||
// pattern set) by only carrying the per-id action/enabled deltas
|
||||
// against the global default catalog.
|
||||
PIIPatternOverrides *map[string]PIIPatternRuntimeOverride `json:"pii_pattern_overrides,omitempty"`
|
||||
}
|
||||
|
||||
// PIIPatternRuntimeOverride captures the persistable deltas an admin
|
||||
// has applied to a single global PII pattern. Both fields are pointers
|
||||
// so an override that only flips Disabled doesn't have to also restate
|
||||
// Action (and vice versa).
|
||||
type PIIPatternRuntimeOverride struct {
|
||||
Action *string `json:"action,omitempty"`
|
||||
Disabled *bool `json:"disabled,omitempty"`
|
||||
}
|
||||
|
||||
@@ -51,25 +51,6 @@ var _ = Describe("RuntimeSettings persistence helpers", func() {
|
||||
})
|
||||
})
|
||||
|
||||
// MITM round trip pins the contract that loadRuntimeSettingsFromFile
|
||||
// MITM listener address must survive a write/read round trip so the
|
||||
// next process restart can bring the listener back up. (Intercept
|
||||
// hosts now live in model YAML rather than runtime_settings.json.)
|
||||
Describe("MITM round trip", func() {
|
||||
It("preserves mitm_listen across read/write", func() {
|
||||
listen := ":8443"
|
||||
Expect(cfg.WritePersistedSettings(config.RuntimeSettings{
|
||||
MITMListen: &listen,
|
||||
})).To(Succeed())
|
||||
|
||||
got, err := cfg.ReadPersistedSettings()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(got.MITMListen).ToNot(BeNil())
|
||||
Expect(*got.MITMListen).To(Equal(":8443"))
|
||||
})
|
||||
})
|
||||
|
||||
// PreserveOnSaveDoesNotClobberAssets reproduces the user-reported
|
||||
// regression: an admin uploads a logo, then clicks Save on the
|
||||
// Settings page. The Save body still has the stale pre-upload
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/finetune"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/quantization"
|
||||
|
||||
@@ -211,18 +212,19 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
e.Use(middleware.Recover())
|
||||
}
|
||||
|
||||
// Metrics middleware. The metric service was created in
|
||||
// application.start() so the OTel global provider is set before any
|
||||
// counter is registered (the routing-module billing recorder relies
|
||||
// on this). We reuse that instance here rather than calling
|
||||
// monitoring.NewLocalAIMetricsService a second time, which would
|
||||
// create a second provider, second prometheus exporter, and orphan
|
||||
// whichever instance lost the SetMeterProvider race.
|
||||
if metricsService := application.MetricsService(); metricsService != nil {
|
||||
e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||
e.Server.RegisterOnShutdown(func() {
|
||||
_ = metricsService.Shutdown()
|
||||
})
|
||||
// Metrics middleware
|
||||
if !application.ApplicationConfig().DisableMetrics {
|
||||
metricsService, err := monitoring.NewLocalAIMetricsService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if metricsService != nil {
|
||||
e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||
e.Server.RegisterOnShutdown(func() {
|
||||
metricsService.Shutdown()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Health Checks should always be exempt from auth, so register these first
|
||||
@@ -265,9 +267,10 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
e.Static("/generated-videos", videoPath)
|
||||
}
|
||||
|
||||
// Usage recording is initialised in application/startup.go and
|
||||
// surfaced via application.StatsRecorder(); routes wire UsageMiddleware
|
||||
// against that recorder regardless of auth state.
|
||||
// Initialize usage recording when auth DB is available
|
||||
if application.AuthDB() != nil {
|
||||
httpMiddleware.InitUsageRecorder(application.AuthDB())
|
||||
}
|
||||
|
||||
// Auth is applied to _all_ endpoints. Filtering out endpoints to bypass is
|
||||
// the role of the exempt-path logic inside the middleware.
|
||||
@@ -354,33 +357,12 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
// Register auth routes (login, callback, API keys, user management)
|
||||
routes.RegisterAuthRoutes(e, application)
|
||||
|
||||
// Register routing-module usage endpoints. Unlike /api/auth/usage
|
||||
// these go through the StatsRecorder and work in no-auth single-user
|
||||
// mode by attributing requests to the synthetic "local" user.
|
||||
routes.RegisterUsageRoutes(e, application)
|
||||
routes.RegisterPIIRoutes(e, application)
|
||||
routes.RegisterMiddlewareRoutes(e, application)
|
||||
|
||||
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
|
||||
var opcache *galleryop.OpCache
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
opcache = galleryop.NewOpCache(application.GalleryService())
|
||||
// In distributed mode, wire the NATS client + gallery store so this
|
||||
// replica's OpCache stays in sync with peers — without this the
|
||||
// /api/operations endpoint returns whatever this single replica
|
||||
// happened to admit, and a load-balanced UI poll alternates between
|
||||
// "operation visible" and "operation gone" between replicas.
|
||||
if d := application.Distributed(); d != nil {
|
||||
opcache.SetMessagingClient(d.Nats)
|
||||
if d.DistStores != nil && d.DistStores.Gallery != nil {
|
||||
opcache.SetGalleryStore(d.DistStores.Gallery)
|
||||
}
|
||||
if err := opcache.Start(application.ApplicationConfig().Context); err != nil {
|
||||
xlog.Warn("OpCache distributed subscribe failed; running standalone", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP)
|
||||
@@ -421,7 +403,7 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
}
|
||||
}
|
||||
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
|
||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
|
||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
|
||||
|
||||
// Distributed SSE routes (job progress + agent events via NATS)
|
||||
if d := application.Distributed(); d != nil {
|
||||
|
||||
@@ -38,15 +38,9 @@ func InitDB(databaseURL string) (*gorm.DB, error) {
|
||||
}
|
||||
|
||||
// Backfill: users created before the provider column existed have an empty
|
||||
// provider - treat them as local accounts so the UI can identify them.
|
||||
// provider — treat them as local accounts so the UI can identify them.
|
||||
db.Exec("UPDATE users SET provider = ? WHERE provider = '' OR provider IS NULL", ProviderLocal)
|
||||
|
||||
// Backfill: pre-feature usage_records have no source column. Classify them so the
|
||||
// new per-source aggregators include them.
|
||||
if err := BackfillUsageSource(db); err != nil {
|
||||
return nil, fmt.Errorf("failed to backfill usage source: %w", err)
|
||||
}
|
||||
|
||||
// Create composite index on users(provider, subject) for fast OAuth lookups
|
||||
if err := db.Exec("CREATE INDEX IF NOT EXISTS idx_users_provider_subject ON users(provider, subject)").Error; err != nil {
|
||||
// Ignore error on postgres if index already exists
|
||||
|
||||
@@ -16,10 +16,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
contextKeyUser = "auth_user"
|
||||
contextKeyRole = "auth_role"
|
||||
contextKeyAPIKey = "auth_apikey"
|
||||
contextKeySource = "auth_source"
|
||||
contextKeyUser = "auth_user"
|
||||
contextKeyRole = "auth_role"
|
||||
)
|
||||
|
||||
// Middleware returns an Echo middleware that handles authentication.
|
||||
@@ -77,7 +75,6 @@ func Middleware(db *gorm.DB, appConfig *config.ApplicationConfig) echo.Middlewar
|
||||
}
|
||||
c.Set(contextKeyUser, syntheticUser)
|
||||
c.Set(contextKeyRole, RoleAdmin)
|
||||
c.Set(contextKeySource, UsageSourceLegacy)
|
||||
authenticated = true
|
||||
}
|
||||
}
|
||||
@@ -216,20 +213,6 @@ func GetUserRole(c echo.Context) string {
|
||||
return role
|
||||
}
|
||||
|
||||
// GetAPIKey returns the resolved API key from the echo context, or nil.
|
||||
// Nil for session-cookie and legacy-env-key authentication.
|
||||
func GetAPIKey(c echo.Context) *UserAPIKey {
|
||||
k, _ := c.Get(contextKeyAPIKey).(*UserAPIKey)
|
||||
return k
|
||||
}
|
||||
|
||||
// GetSource returns the request's authentication source: UsageSourceAPIKey,
|
||||
// UsageSourceWeb, UsageSourceLegacy, or empty if no authentication was performed.
|
||||
func GetSource(c echo.Context) string {
|
||||
s, _ := c.Get(contextKeySource).(string)
|
||||
return s
|
||||
}
|
||||
|
||||
// RequireRouteFeature returns a global middleware that checks the user has access
|
||||
// to the feature required by the matched route. It uses the RouteFeatureRegistry
|
||||
// to look up the required feature for each route pattern + HTTP method.
|
||||
@@ -438,67 +421,47 @@ func RequireQuota(db *gorm.DB) echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
// tryAuthenticate attempts to authenticate the request using the database.
|
||||
//
|
||||
// On success it returns the user and, as a side effect, sets the following
|
||||
// values on the Echo context:
|
||||
// - contextKeySource ("auth_source"): always set, one of UsageSourceWeb /
|
||||
// UsageSourceAPIKey. UsageSourceLegacy is set elsewhere by the parent
|
||||
// Middleware when a legacy env key matches.
|
||||
// - contextKeyAPIKey ("auth_apikey"): set to the resolved *UserAPIKey for
|
||||
// named-key branches (Bearer, x-api-key, xi-api-key, token cookie).
|
||||
// - "_auth_session": session record, used by Middleware to drive cookie
|
||||
// rotation. Only set on the session-cookie branch.
|
||||
//
|
||||
// contextKeyUser and contextKeyRole are populated by the parent Middleware
|
||||
// after this function returns.
|
||||
func tryAuthenticate(c echo.Context, db *gorm.DB, appConfig *config.ApplicationConfig) *User {
|
||||
hmacSecret := appConfig.Auth.APIKeyHMACSecret
|
||||
|
||||
// a. Session cookie -> web UI
|
||||
// a. Session cookie
|
||||
if cookie, err := c.Cookie(sessionCookie); err == nil && cookie.Value != "" {
|
||||
if user, session := ValidateSession(db, cookie.Value, hmacSecret); user != nil {
|
||||
// Store session for rotation check in middleware
|
||||
c.Set("_auth_session", session)
|
||||
c.Set(contextKeySource, UsageSourceWeb)
|
||||
return user
|
||||
}
|
||||
}
|
||||
|
||||
// b. Authorization: Bearer
|
||||
// b. Authorization: Bearer token
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
|
||||
// b1. Session token via Bearer -> still web UI
|
||||
// Try as session ID first
|
||||
if user, _ := ValidateSession(db, token, hmacSecret); user != nil {
|
||||
c.Set(contextKeySource, UsageSourceWeb)
|
||||
return user
|
||||
}
|
||||
|
||||
// b2. Named API key
|
||||
// Try as user API key
|
||||
if key, err := ValidateAPIKey(db, token, hmacSecret); err == nil {
|
||||
c.Set(contextKeySource, UsageSourceAPIKey)
|
||||
c.Set(contextKeyAPIKey, key)
|
||||
return &key.User
|
||||
}
|
||||
}
|
||||
|
||||
// c. x-api-key / xi-api-key -> named API key
|
||||
// c. x-api-key / xi-api-key headers
|
||||
for _, header := range []string{"x-api-key", "xi-api-key"} {
|
||||
if k := c.Request().Header.Get(header); k != "" {
|
||||
if apiKey, err := ValidateAPIKey(db, k, hmacSecret); err == nil {
|
||||
c.Set(contextKeySource, UsageSourceAPIKey)
|
||||
c.Set(contextKeyAPIKey, apiKey)
|
||||
if key := c.Request().Header.Get(header); key != "" {
|
||||
if apiKey, err := ValidateAPIKey(db, key, hmacSecret); err == nil {
|
||||
return &apiKey.User
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// d. token cookie -> named API key
|
||||
// d. token cookie (legacy)
|
||||
if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" {
|
||||
// Try as user API key
|
||||
if key, err := ValidateAPIKey(db, cookie.Value, hmacSecret); err == nil {
|
||||
c.Set(contextKeySource, UsageSourceAPIKey)
|
||||
c.Set(contextKeyAPIKey, key)
|
||||
return &key.User
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,122 +303,4 @@ var _ = Describe("Auth Middleware", func() {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Describe("auth context plumbing for usage source", func() {
|
||||
// probeApp builds a minimal echo app with the auth middleware and a single
|
||||
// "/probe" route that captures the user, source, and apikey from context.
|
||||
type probe struct {
|
||||
user *auth.User
|
||||
source string
|
||||
key *auth.UserAPIKey
|
||||
}
|
||||
probeApp := func(db *gorm.DB, appConfig *config.ApplicationConfig, p *probe) *echo.Echo {
|
||||
e := echo.New()
|
||||
e.Use(auth.Middleware(db, appConfig))
|
||||
e.GET("/probe", func(c echo.Context) error {
|
||||
p.user = auth.GetUser(c)
|
||||
p.source = auth.GetSource(c)
|
||||
p.key = auth.GetAPIKey(c)
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
return e
|
||||
}
|
||||
|
||||
It("session cookie sets source=web, apikey=nil", func() {
|
||||
db := testDB()
|
||||
appConfig := config.NewApplicationConfig()
|
||||
user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
|
||||
token := createTestSession(db, user.ID)
|
||||
|
||||
var p probe
|
||||
app := probeApp(db, appConfig, &p)
|
||||
rec := doRequest(app, http.MethodGet, "/probe", withSessionCookie(token))
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(p.user).ToNot(BeNil())
|
||||
Expect(p.user.ID).To(Equal(user.ID))
|
||||
Expect(p.source).To(Equal(auth.UsageSourceWeb))
|
||||
Expect(p.key).To(BeNil())
|
||||
})
|
||||
|
||||
It("Bearer session token sets source=web, apikey=nil", func() {
|
||||
db := testDB()
|
||||
appConfig := config.NewApplicationConfig()
|
||||
user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
|
||||
token := createTestSession(db, user.ID)
|
||||
|
||||
var p probe
|
||||
app := probeApp(db, appConfig, &p)
|
||||
rec := doRequest(app, http.MethodGet, "/probe", withBearerToken(token))
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(p.user).ToNot(BeNil())
|
||||
Expect(p.user.ID).To(Equal(user.ID))
|
||||
Expect(p.source).To(Equal(auth.UsageSourceWeb))
|
||||
Expect(p.key).To(BeNil())
|
||||
})
|
||||
|
||||
It("Bearer API key sets source=apikey and exposes the resolved *UserAPIKey", func() {
|
||||
db := testDB()
|
||||
appConfig := config.NewApplicationConfig()
|
||||
user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
|
||||
plaintext, key, err := auth.CreateAPIKey(db, user.ID, "ci", auth.RoleUser, appConfig.Auth.APIKeyHMACSecret, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var p probe
|
||||
app := probeApp(db, appConfig, &p)
|
||||
rec := doRequest(app, http.MethodGet, "/probe", withBearerToken(plaintext))
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(p.source).To(Equal(auth.UsageSourceAPIKey))
|
||||
Expect(p.key).ToNot(BeNil())
|
||||
Expect(p.key.ID).To(Equal(key.ID))
|
||||
})
|
||||
|
||||
It("x-api-key header sets source=apikey", func() {
|
||||
db := testDB()
|
||||
appConfig := config.NewApplicationConfig()
|
||||
user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
|
||||
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "ci", auth.RoleUser, appConfig.Auth.APIKeyHMACSecret, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var p probe
|
||||
app := probeApp(db, appConfig, &p)
|
||||
rec := doRequest(app, http.MethodGet, "/probe", withXApiKey(plaintext))
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(p.source).To(Equal(auth.UsageSourceAPIKey))
|
||||
Expect(p.key).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("token cookie sets source=apikey", func() {
|
||||
db := testDB()
|
||||
appConfig := config.NewApplicationConfig()
|
||||
user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal)
|
||||
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "ci", auth.RoleUser, appConfig.Auth.APIKeyHMACSecret, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var p probe
|
||||
app := probeApp(db, appConfig, &p)
|
||||
rec := doRequest(app, http.MethodGet, "/probe", withTokenCookie(plaintext))
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(p.source).To(Equal(auth.UsageSourceAPIKey))
|
||||
Expect(p.key).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("legacy env key sets source=legacy, apikey=nil", func() {
|
||||
db := testDB()
|
||||
appConfig := config.NewApplicationConfig()
|
||||
appConfig.ApiKeys = []string{"legacy-secret"}
|
||||
|
||||
var p probe
|
||||
app := probeApp(db, appConfig, &p)
|
||||
rec := doRequest(app, http.MethodGet, "/probe", withBearerToken("legacy-secret"))
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(p.source).To(Equal(auth.UsageSourceLegacy))
|
||||
Expect(p.key).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -5,43 +5,14 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Source classification for a UsageRecord.
|
||||
const (
|
||||
UsageSourceAPIKey = "apikey" // request authenticated with a named UserAPIKey
|
||||
UsageSourceWeb = "web" // request authenticated with a session cookie (web UI)
|
||||
UsageSourceLegacy = "legacy" // request authenticated with an env-configured legacy key
|
||||
)
|
||||
|
||||
// UsageRecord represents a single API request's token usage.
|
||||
//
|
||||
// Model semantics: Model is the legacy column kept for backward-compatible
|
||||
// aggregation; new code should write RequestedModel (what the client asked
|
||||
// for) and ServedModel (what actually ran after routing). When no router
|
||||
// is in play, all three are equal.
|
||||
//
|
||||
// PreFilterPromptTokens vs PromptTokens: PromptTokens is the count after
|
||||
// PII redaction (i.e., what the backend processed and was billed for).
|
||||
// PreFilterPromptTokens is the count of the original prompt before any
|
||||
// PII filtering; PostFilterPromptTokens duplicates PromptTokens for
|
||||
// queryability symmetry. For non-PII paths PreFilterPromptTokens ==
|
||||
// PostFilterPromptTokens == PromptTokens.
|
||||
type UsageRecord struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||
UserID string `gorm:"size:36;index:idx_usage_user_time"`
|
||||
UserName string `gorm:"size:255"`
|
||||
|
||||
// Source classifies how the request authenticated. One of UsageSource* constants.
|
||||
// Empty for pre-feature rows until the InitDB backfill runs.
|
||||
Source string `gorm:"size:16;index:idx_usage_source"`
|
||||
// APIKeyID is the UserAPIKey.ID when Source == UsageSourceAPIKey. Nil otherwise.
|
||||
APIKeyID *string `gorm:"size:36;index:idx_usage_apikey"`
|
||||
// APIKeyName is a snapshot of UserAPIKey.Name at write time. Survives key deletion.
|
||||
APIKeyName string `gorm:"size:255"`
|
||||
|
||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||
UserID string `gorm:"size:36;index:idx_usage_user_time"`
|
||||
UserName string `gorm:"size:255"`
|
||||
Model string `gorm:"size:255;index"`
|
||||
Endpoint string `gorm:"size:255"`
|
||||
PromptTokens int64
|
||||
@@ -49,22 +20,6 @@ type UsageRecord struct {
|
||||
TotalTokens int64
|
||||
Duration int64 // milliseconds
|
||||
CreatedAt time.Time `gorm:"index:idx_usage_user_time"`
|
||||
|
||||
// Routing extension fields. Nullable / zero-valued for legacy rows.
|
||||
RequestedModel string `gorm:"size:255;index"`
|
||||
ServedModel string `gorm:"size:255;index"`
|
||||
PreFilterPromptTokens int64 // tokens the client sent before PII redaction
|
||||
PostFilterPromptTokens int64 // tokens after redaction (== PromptTokens unless filter shrunk it)
|
||||
CachedTokens int64 // backend-reported KV-cache hit tokens
|
||||
PrefillTokens int64 // backend-reported prefill tokens (subset of prompt)
|
||||
DraftTokens int64 // speculative-decoding draft tokens
|
||||
PricingVersionID string `gorm:"size:64;index"` // FK to pricing_version; "" when no pricing was applied
|
||||
CostUSD float64 // computed at insert when pricing is available; 0 with empty PricingVersionID = unknown
|
||||
|
||||
// Cross-subsystem correlation. Empty when the subsystem didn't run.
|
||||
CorrelationID string `gorm:"size:64;index"`
|
||||
RouterDecisionID string `gorm:"size:64;index"`
|
||||
PIIEventID string `gorm:"size:64"`
|
||||
}
|
||||
|
||||
// RecordUsage inserts a usage record.
|
||||
@@ -75,12 +30,9 @@ func RecordUsage(db *gorm.DB, record *UsageRecord) error {
|
||||
// UsageBucket is an aggregated time bucket for the dashboard.
|
||||
type UsageBucket struct {
|
||||
Bucket string `json:"bucket"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Model string `json:"model"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
UserName string `json:"user_name,omitempty"`
|
||||
Source string `json:"source,omitempty"`
|
||||
APIKeyID string `json:"api_key_id,omitempty"`
|
||||
APIKeyName string `json:"api_key_name,omitempty"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
@@ -167,28 +119,6 @@ func GetUserUsage(db *gorm.DB, userID, period string) ([]UsageBucket, error) {
|
||||
return buckets, nil
|
||||
}
|
||||
|
||||
// BackfillUsageSource sets the Source column on pre-feature usage rows.
|
||||
// Idempotent: only touches rows where source is NULL or empty.
|
||||
// - rows whose user_id == "legacy-api-key" -> UsageSourceLegacy
|
||||
// - everything else -> UsageSourceWeb
|
||||
func BackfillUsageSource(db *gorm.DB) error {
|
||||
// Legacy first (more specific predicate)
|
||||
if err := db.Exec(
|
||||
`UPDATE usage_records SET source = ? WHERE (source IS NULL OR source = '') AND user_id = ?`,
|
||||
UsageSourceLegacy, "legacy-api-key",
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("backfill legacy usage source: %w", err)
|
||||
}
|
||||
// Everything else -> web
|
||||
if err := db.Exec(
|
||||
`UPDATE usage_records SET source = ? WHERE (source IS NULL OR source = '')`,
|
||||
UsageSourceWeb,
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("backfill web usage source: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllUsage returns aggregated usage for all users (admin). Optional userID filter.
|
||||
func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
|
||||
sqlite := isSQLiteDB(db)
|
||||
@@ -219,257 +149,3 @@ func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
|
||||
}
|
||||
return buckets, nil
|
||||
}
|
||||
|
||||
// TotalsEntry is a token+request roll-up.
|
||||
type TotalsEntry struct {
|
||||
Tokens int64 `json:"tokens"`
|
||||
Requests int64 `json:"requests"`
|
||||
}
|
||||
|
||||
// KeyTotal is the per-key roll-up returned by sources endpoints. UserID and
|
||||
// UserName are snapshotted from the UsageRecord so revoked-and-deleted keys
|
||||
// still carry their owner attribution in admin views.
|
||||
type KeyTotal struct {
|
||||
APIKeyID string `json:"api_key_id"`
|
||||
APIKeyName string `json:"api_key_name"`
|
||||
UserID string `json:"user_id"`
|
||||
UserName string `json:"user_name"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Requests int64 `json:"requests"`
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
}
|
||||
|
||||
// UserSourceTotal is a per-(user, source) roll-up for sources that don't carry
|
||||
// a named API key identity (web, legacy). It exists so admin views can show
|
||||
// which user generated each block of Web UI / legacy traffic; the per-apikey
|
||||
// breakdown for source=apikey already lives in KeyTotal.
|
||||
type UserSourceTotal struct {
|
||||
Source string `json:"source"`
|
||||
UserID string `json:"user_id"`
|
||||
UserName string `json:"user_name"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Requests int64 `json:"requests"`
|
||||
}
|
||||
|
||||
// SourceTotals summarises a per-source breakdown.
|
||||
type SourceTotals struct {
|
||||
BySource map[string]TotalsEntry `json:"by_source"`
|
||||
ByKey []KeyTotal `json:"by_key"` // server-sorted desc by tokens, capped
|
||||
ByUserSource []UserSourceTotal `json:"by_user_source,omitempty"` // populated only when includeLegacy=true
|
||||
GrandTotal TotalsEntry `json:"grand_total"`
|
||||
}
|
||||
|
||||
const maxKeyTotals = 200
|
||||
|
||||
// GetUserUsageBySource returns per-source aggregated usage for one user. Legacy
|
||||
// is excluded by design (visible to admins only via the admin variant).
|
||||
func GetUserUsageBySource(db *gorm.DB, userID, period string) ([]UsageBucket, SourceTotals, error) {
|
||||
sqlite := isSQLiteDB(db)
|
||||
since, dateFmt := periodToWindow(period, sqlite)
|
||||
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
|
||||
|
||||
query := db.Model(&UsageRecord{}).
|
||||
Select(bucketExpr+", source, COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
|
||||
"SUM(prompt_tokens) as prompt_tokens, "+
|
||||
"SUM(completion_tokens) as completion_tokens, "+
|
||||
"SUM(total_tokens) as total_tokens, "+
|
||||
"COUNT(*) as request_count").
|
||||
Where("user_id = ?", userID).
|
||||
Where("source <> ?", UsageSourceLegacy).
|
||||
Group("bucket, source, api_key_id, api_key_name").
|
||||
Order("bucket ASC")
|
||||
|
||||
if !since.IsZero() {
|
||||
query = query.Where("created_at >= ?", since)
|
||||
}
|
||||
|
||||
var buckets []UsageBucket
|
||||
if err := query.Find(&buckets).Error; err != nil {
|
||||
return nil, SourceTotals{}, err
|
||||
}
|
||||
|
||||
totals := computeSourceTotals(db, userID, "", since, false)
|
||||
return buckets, totals, nil
|
||||
}
|
||||
|
||||
// computeSourceTotals rolls up by_source / by_key / grand_total.
|
||||
// userID/apiKeyID are optional filters. includeLegacy controls whether the
|
||||
// legacy bucket is exposed (admin-only).
|
||||
func computeSourceTotals(db *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) SourceTotals {
|
||||
totals := SourceTotals{BySource: map[string]TotalsEntry{}}
|
||||
|
||||
bySourceQ := db.Model(&UsageRecord{}).
|
||||
Select("source, SUM(total_tokens) as tokens, COUNT(*) as requests").
|
||||
Group("source")
|
||||
bySourceQ = applyFilters(bySourceQ, userID, apiKeyID, since, includeLegacy)
|
||||
|
||||
var bySourceRows []struct {
|
||||
Source string
|
||||
Tokens int64
|
||||
Requests int64
|
||||
}
|
||||
if err := bySourceQ.Scan(&bySourceRows).Error; err != nil {
|
||||
xlog.Warn("computeSourceTotals: by-source Scan failed", "error", err)
|
||||
return totals
|
||||
}
|
||||
for _, r := range bySourceRows {
|
||||
totals.BySource[r.Source] = TotalsEntry{Tokens: r.Tokens, Requests: r.Requests}
|
||||
totals.GrandTotal.Tokens += r.Tokens
|
||||
totals.GrandTotal.Requests += r.Requests
|
||||
}
|
||||
|
||||
byKeyQ := db.Model(&UsageRecord{}).
|
||||
Select("COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
|
||||
"user_id, user_name, "+
|
||||
"SUM(total_tokens) as tokens, COUNT(*) as requests, MAX(created_at) as last_used").
|
||||
Where("api_key_id IS NOT NULL AND api_key_id <> ''").
|
||||
Group("api_key_id, api_key_name, user_id, user_name").
|
||||
Order("tokens DESC").
|
||||
Limit(maxKeyTotals)
|
||||
byKeyQ = applyFilters(byKeyQ, userID, apiKeyID, since, includeLegacy)
|
||||
|
||||
// Iterate Rows() manually because MAX(created_at) is returned as a string by
|
||||
// the SQLite driver, and Go's database/sql refuses to scan that into
|
||||
// *time.Time. Postgres returns a proper timestamp. We accept both shapes
|
||||
// via a Rows.Scan into a string column, then parse uniformly.
|
||||
rows, err := byKeyQ.Rows()
|
||||
if err != nil {
|
||||
xlog.Warn("computeSourceTotals: by-key Rows() failed", "error", err)
|
||||
} else {
|
||||
defer func() { _ = rows.Close() }()
|
||||
out := make([]KeyTotal, 0)
|
||||
for rows.Next() {
|
||||
var (
|
||||
apiKeyID, apiKeyName, userIDCol, userName, lastUsedRaw string
|
||||
tokens, requests int64
|
||||
)
|
||||
if scanErr := rows.Scan(&apiKeyID, &apiKeyName, &userIDCol, &userName, &tokens, &requests, &lastUsedRaw); scanErr != nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, KeyTotal{
|
||||
APIKeyID: apiKeyID,
|
||||
APIKeyName: apiKeyName,
|
||||
UserID: userIDCol,
|
||||
UserName: userName,
|
||||
Tokens: tokens,
|
||||
Requests: requests,
|
||||
LastUsed: parseLastUsedString(lastUsedRaw),
|
||||
})
|
||||
}
|
||||
if rerr := rows.Err(); rerr != nil {
|
||||
xlog.Warn("computeSourceTotals: by-key rows iteration failed", "error", rerr)
|
||||
}
|
||||
totals.ByKey = out
|
||||
}
|
||||
|
||||
// by_user_source: only populated for admin callers (includeLegacy=true) so
|
||||
// they can attribute Web UI / legacy traffic to specific users. Per-apikey
|
||||
// rows already carry user info via KeyTotal above, so this query only
|
||||
// covers source != apikey.
|
||||
if includeLegacy {
|
||||
byUserSourceQ := db.Model(&UsageRecord{}).
|
||||
Select("source, user_id, user_name, "+
|
||||
"SUM(total_tokens) as tokens, COUNT(*) as requests").
|
||||
Where("source <> ?", UsageSourceAPIKey).
|
||||
Group("source, user_id, user_name").
|
||||
Order("tokens DESC")
|
||||
byUserSourceQ = applyFilters(byUserSourceQ, userID, apiKeyID, since, includeLegacy)
|
||||
|
||||
var byUserSourceRows []UserSourceTotal
|
||||
if scanErr := byUserSourceQ.Scan(&byUserSourceRows).Error; scanErr != nil {
|
||||
xlog.Warn("computeSourceTotals: by-user-source Scan failed", "error", scanErr)
|
||||
} else {
|
||||
totals.ByUserSource = byUserSourceRows
|
||||
}
|
||||
}
|
||||
|
||||
return totals
|
||||
}
|
||||
|
||||
// parseLastUsedString converts the textual MAX(created_at) value returned by
|
||||
// SQLite (or any driver that surfaces the timestamp as a string) into a
|
||||
// time.Time. Returns the zero time on parse failure.
|
||||
func parseLastUsedString(s string) time.Time {
|
||||
if s == "" {
|
||||
return time.Time{}
|
||||
}
|
||||
// GORM's SQLite driver emits Go's default time formatting. Try the formats
|
||||
// it commonly produces, falling back to RFC3339Nano.
|
||||
layouts := []string{
|
||||
"2006-01-02 15:04:05.999999999 -0700 MST",
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02 15:04:05",
|
||||
time.RFC3339Nano,
|
||||
time.RFC3339,
|
||||
}
|
||||
for _, layout := range layouts {
|
||||
if t, err := time.Parse(layout, s); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
xlog.Warn("parseLastUsedString: unrecognised format", "value", s)
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// GetAllUsageBySource is the admin variant of GetUserUsageBySource.
|
||||
// Optional filters: userID and apiKeyID. Legacy is included.
|
||||
// truncated == true iff the per-key roll-up was capped at maxKeyTotals.
|
||||
func GetAllUsageBySource(db *gorm.DB, period, userID, apiKeyID string) ([]UsageBucket, SourceTotals, bool, error) {
|
||||
sqlite := isSQLiteDB(db)
|
||||
since, dateFmt := periodToWindow(period, sqlite)
|
||||
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
|
||||
|
||||
query := db.Model(&UsageRecord{}).
|
||||
Select(bucketExpr+", source, COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
|
||||
"user_id, user_name, "+
|
||||
"SUM(prompt_tokens) as prompt_tokens, "+
|
||||
"SUM(completion_tokens) as completion_tokens, "+
|
||||
"SUM(total_tokens) as total_tokens, "+
|
||||
"COUNT(*) as request_count").
|
||||
Group("bucket, source, api_key_id, api_key_name, user_id, user_name").
|
||||
Order("bucket ASC")
|
||||
|
||||
query = applyFilters(query, userID, apiKeyID, since, true)
|
||||
|
||||
var buckets []UsageBucket
|
||||
if err := query.Find(&buckets).Error; err != nil {
|
||||
return nil, SourceTotals{}, false, err
|
||||
}
|
||||
|
||||
totals := computeSourceTotals(db, userID, apiKeyID, since, true)
|
||||
|
||||
// Count distinct api_key_ids matching the filters. If > maxKeyTotals,
|
||||
// the by_key slice was capped and we signal truncation to the caller.
|
||||
truncated := false
|
||||
var distinct int64
|
||||
countQ := applyFilters(
|
||||
db.Model(&UsageRecord{}).
|
||||
Distinct("api_key_id").
|
||||
Where("api_key_id IS NOT NULL AND api_key_id <> ''"),
|
||||
userID, apiKeyID, since, true,
|
||||
)
|
||||
if err := countQ.Count(&distinct).Error; err != nil {
|
||||
xlog.Warn("GetAllUsageBySource: distinct api_key_id count failed", "error", err)
|
||||
} else {
|
||||
truncated = distinct > maxKeyTotals
|
||||
}
|
||||
|
||||
return buckets, totals, truncated, nil
|
||||
}
|
||||
|
||||
func applyFilters(q *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) *gorm.DB {
|
||||
if userID != "" {
|
||||
q = q.Where("user_id = ?", userID)
|
||||
}
|
||||
if apiKeyID != "" {
|
||||
q = q.Where("api_key_id = ?", apiKeyID)
|
||||
}
|
||||
if !since.IsZero() {
|
||||
q = q.Where("created_at >= ?", since)
|
||||
}
|
||||
if !includeLegacy {
|
||||
q = q.Where("source <> ?", UsageSourceLegacy)
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
@@ -3,13 +3,11 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ = Describe("Usage", func() {
|
||||
@@ -160,275 +158,4 @@ var _ = Describe("Usage", func() {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Usage source backfill", func() {
|
||||
It("backfills 'web' for pre-feature rows", func() {
|
||||
db := testDB()
|
||||
|
||||
rawDB, err := db.DB()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rawDB.Exec(
|
||||
`INSERT INTO usage_records (user_id, source, model, created_at, total_tokens, prompt_tokens, completion_tokens, duration) VALUES (?, '', ?, ?, 0, 0, 0, 0)`,
|
||||
"user-x", "gpt-4", time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(auth.BackfillUsageSource(db)).To(Succeed())
|
||||
|
||||
var loaded auth.UsageRecord
|
||||
Expect(db.Where("user_id = ?", "user-x").First(&loaded).Error).To(Succeed())
|
||||
Expect(loaded.Source).To(Equal(auth.UsageSourceWeb))
|
||||
})
|
||||
|
||||
It("backfills 'legacy' for pre-feature rows with legacy-api-key user_id", func() {
|
||||
db := testDB()
|
||||
|
||||
rawDB, err := db.DB()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rawDB.Exec(
|
||||
`INSERT INTO usage_records (user_id, source, model, created_at, total_tokens, prompt_tokens, completion_tokens, duration) VALUES (?, '', ?, ?, 0, 0, 0, 0)`,
|
||||
"legacy-api-key", "gpt-4", time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(auth.BackfillUsageSource(db)).To(Succeed())
|
||||
|
||||
var loaded auth.UsageRecord
|
||||
Expect(db.Where("user_id = ?", "legacy-api-key").First(&loaded).Error).To(Succeed())
|
||||
Expect(loaded.Source).To(Equal(auth.UsageSourceLegacy))
|
||||
})
|
||||
|
||||
It("is idempotent on re-run", func() {
|
||||
db := testDB()
|
||||
Expect(auth.BackfillUsageSource(db)).To(Succeed())
|
||||
Expect(auth.BackfillUsageSource(db)).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("UsageRecord with source fields", func() {
|
||||
It("persists Source, APIKeyID, APIKeyName", func() {
|
||||
db := testDB()
|
||||
keyID := "key-uuid-1"
|
||||
record := &auth.UsageRecord{
|
||||
UserID: "user-1",
|
||||
UserName: "Test User",
|
||||
Source: auth.UsageSourceAPIKey,
|
||||
APIKeyID: &keyID,
|
||||
APIKeyName: "ci-runner",
|
||||
Model: "gpt-4",
|
||||
Endpoint: "/v1/chat/completions",
|
||||
TotalTokens: 150,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
Expect(auth.RecordUsage(db, record)).To(Succeed())
|
||||
|
||||
var loaded auth.UsageRecord
|
||||
Expect(db.First(&loaded, record.ID).Error).To(Succeed())
|
||||
Expect(loaded.Source).To(Equal(auth.UsageSourceAPIKey))
|
||||
Expect(loaded.APIKeyID).ToNot(BeNil())
|
||||
Expect(*loaded.APIKeyID).To(Equal("key-uuid-1"))
|
||||
Expect(loaded.APIKeyName).To(Equal("ci-runner"))
|
||||
})
|
||||
|
||||
It("allows nil APIKeyID for web/legacy sources", func() {
|
||||
db := testDB()
|
||||
record := &auth.UsageRecord{
|
||||
UserID: "user-1",
|
||||
Source: auth.UsageSourceWeb,
|
||||
Model: "gpt-4",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
Expect(auth.RecordUsage(db, record)).To(Succeed())
|
||||
|
||||
var loaded auth.UsageRecord
|
||||
Expect(db.First(&loaded, record.ID).Error).To(Succeed())
|
||||
Expect(loaded.Source).To(Equal(auth.UsageSourceWeb))
|
||||
Expect(loaded.APIKeyID).To(BeNil())
|
||||
Expect(loaded.APIKeyName).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GetUserUsageBySource", func() {
|
||||
insert := func(db *gorm.DB, userID, source, keyID, keyName string, tokens int64, when time.Time) {
|
||||
rec := &auth.UsageRecord{
|
||||
UserID: userID,
|
||||
Source: source,
|
||||
Model: "gpt-4",
|
||||
TotalTokens: tokens,
|
||||
CreatedAt: when,
|
||||
}
|
||||
if keyID != "" {
|
||||
rec.APIKeyID = &keyID
|
||||
rec.APIKeyName = keyName
|
||||
}
|
||||
Expect(auth.RecordUsage(db, rec)).To(Succeed())
|
||||
}
|
||||
|
||||
It("returns only the caller's rows, never legacy", func() {
|
||||
db := testDB()
|
||||
now := time.Now()
|
||||
insert(db, "alice", auth.UsageSourceAPIKey, "k1", "ci", 100, now)
|
||||
insert(db, "alice", auth.UsageSourceWeb, "", "", 50, now)
|
||||
insert(db, "alice", auth.UsageSourceLegacy, "", "", 30, now)
|
||||
insert(db, "bob", auth.UsageSourceAPIKey, "k2", "bobk", 90, now)
|
||||
|
||||
buckets, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, b := range buckets {
|
||||
Expect(b.UserID).To(Or(BeEmpty(), Equal("alice")))
|
||||
Expect(b.Source).ToNot(Equal(auth.UsageSourceLegacy))
|
||||
}
|
||||
|
||||
Expect(totals.GrandTotal.Tokens).To(Equal(int64(150)))
|
||||
Expect(totals.BySource[auth.UsageSourceAPIKey].Tokens).To(Equal(int64(100)))
|
||||
Expect(totals.BySource[auth.UsageSourceWeb].Tokens).To(Equal(int64(50)))
|
||||
_, hasLegacy := totals.BySource[auth.UsageSourceLegacy]
|
||||
Expect(hasLegacy).To(BeFalse())
|
||||
})
|
||||
|
||||
It("snapshots survive key deletion", func() {
|
||||
db := testDB()
|
||||
now := time.Now()
|
||||
insert(db, "alice", auth.UsageSourceAPIKey, "deleted-key", "old-name", 42, now)
|
||||
_, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(totals.ByKey).To(HaveLen(1))
|
||||
Expect(totals.ByKey[0].APIKeyName).To(Equal("old-name"))
|
||||
Expect(totals.ByKey[0].APIKeyID).To(Equal("deleted-key"))
|
||||
Expect(totals.ByKey[0].LastUsed).ToNot(BeZero())
|
||||
Expect(totals.ByKey[0].LastUsed).To(BeTemporally("~", now, 2*time.Second))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GetAllUsageBySource", func() {
|
||||
insert := func(db *gorm.DB, userID, source, keyID string, tokens int64) {
|
||||
rec := &auth.UsageRecord{
|
||||
UserID: userID,
|
||||
Source: source,
|
||||
Model: "gpt-4",
|
||||
TotalTokens: tokens,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if keyID != "" {
|
||||
rec.APIKeyID = &keyID
|
||||
rec.APIKeyName = "name-" + keyID
|
||||
}
|
||||
Expect(auth.RecordUsage(db, rec)).To(Succeed())
|
||||
}
|
||||
|
||||
It("includes legacy for admins", func() {
|
||||
db := testDB()
|
||||
insert(db, "alice", auth.UsageSourceAPIKey, "k1", 10)
|
||||
insert(db, "legacy-api-key", auth.UsageSourceLegacy, "", 5)
|
||||
|
||||
_, totals, _, err := auth.GetAllUsageBySource(db, "month", "", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(totals.BySource).To(HaveKey(auth.UsageSourceLegacy))
|
||||
Expect(totals.BySource[auth.UsageSourceLegacy].Tokens).To(Equal(int64(5)))
|
||||
})
|
||||
|
||||
It("filters by user_id AND api_key_id", func() {
|
||||
db := testDB()
|
||||
insert(db, "alice", auth.UsageSourceAPIKey, "k1", 10)
|
||||
insert(db, "alice", auth.UsageSourceAPIKey, "k2", 20)
|
||||
insert(db, "bob", auth.UsageSourceAPIKey, "k3", 30)
|
||||
|
||||
_, totals, _, err := auth.GetAllUsageBySource(db, "month", "alice", "k2")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(totals.GrandTotal.Tokens).To(Equal(int64(20)))
|
||||
})
|
||||
|
||||
It("sets truncated=true when by_key exceeds the cap", func() {
|
||||
db := testDB()
|
||||
for i := 0; i < 210; i++ {
|
||||
insert(db, "alice", auth.UsageSourceAPIKey, fmt.Sprintf("key-%03d", i), int64(210-i))
|
||||
}
|
||||
|
||||
_, totals, truncated, err := auth.GetAllUsageBySource(db, "month", "", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(truncated).To(BeTrue())
|
||||
Expect(totals.ByKey).To(HaveLen(200))
|
||||
Expect(totals.ByKey[0].Tokens > totals.ByKey[199].Tokens).To(BeTrue())
|
||||
})
|
||||
|
||||
// insertNamed records a row with explicit user_id, user_name, source,
|
||||
// and optional api key snapshot. Used by the user-attribution tests
|
||||
// below which the older insert helper can't express.
|
||||
insertNamed := func(db *gorm.DB, userID, userName, source, keyID, keyName string, tokens int64) {
|
||||
rec := &auth.UsageRecord{
|
||||
UserID: userID,
|
||||
UserName: userName,
|
||||
Source: source,
|
||||
Model: "gpt-4",
|
||||
TotalTokens: tokens,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if keyID != "" {
|
||||
rec.APIKeyID = &keyID
|
||||
rec.APIKeyName = keyName
|
||||
}
|
||||
Expect(auth.RecordUsage(db, rec)).To(Succeed())
|
||||
}
|
||||
|
||||
It("attributes each KeyTotal to its owner user", func() {
|
||||
db := testDB()
|
||||
insertNamed(db, "alice", "Alice", auth.UsageSourceAPIKey, "k1", "ci-runner", 100)
|
||||
insertNamed(db, "bob", "Bob", auth.UsageSourceAPIKey, "k2", "lap", 50)
|
||||
|
||||
_, totals, _, err := auth.GetAllUsageBySource(db, "month", "", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(totals.ByKey).To(HaveLen(2))
|
||||
|
||||
byID := map[string]auth.KeyTotal{}
|
||||
for _, k := range totals.ByKey {
|
||||
byID[k.APIKeyID] = k
|
||||
}
|
||||
Expect(byID["k1"].UserID).To(Equal("alice"))
|
||||
Expect(byID["k1"].UserName).To(Equal("Alice"))
|
||||
Expect(byID["k2"].UserID).To(Equal("bob"))
|
||||
Expect(byID["k2"].UserName).To(Equal("Bob"))
|
||||
})
|
||||
|
||||
It("breaks Web UI and legacy traffic out per user in by_user_source for admin", func() {
|
||||
db := testDB()
|
||||
// Alice and Bob both have Web UI traffic; a synthetic legacy user
|
||||
// also contributes. ByUserSource should expose one row per
|
||||
// (source, user) pair, never for source=apikey.
|
||||
insertNamed(db, "alice", "Alice", auth.UsageSourceWeb, "", "", 30)
|
||||
insertNamed(db, "bob", "Bob", auth.UsageSourceWeb, "", "", 70)
|
||||
insertNamed(db, "legacy-api-key", "API Key User", auth.UsageSourceLegacy, "", "", 10)
|
||||
insertNamed(db, "alice", "Alice", auth.UsageSourceAPIKey, "k1", "ci-runner", 5)
|
||||
|
||||
_, totals, _, err := auth.GetAllUsageBySource(db, "month", "", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(totals.ByUserSource).ToNot(BeEmpty())
|
||||
|
||||
for _, r := range totals.ByUserSource {
|
||||
Expect(r.Source).ToNot(Equal(auth.UsageSourceAPIKey))
|
||||
}
|
||||
|
||||
webByUser := map[string]int64{}
|
||||
legacyByUser := map[string]int64{}
|
||||
for _, r := range totals.ByUserSource {
|
||||
switch r.Source {
|
||||
case auth.UsageSourceWeb:
|
||||
webByUser[r.UserID] = r.Tokens
|
||||
case auth.UsageSourceLegacy:
|
||||
legacyByUser[r.UserID] = r.Tokens
|
||||
}
|
||||
}
|
||||
Expect(webByUser["alice"]).To(Equal(int64(30)))
|
||||
Expect(webByUser["bob"]).To(Equal(int64(70)))
|
||||
Expect(legacyByUser["legacy-api-key"]).To(Equal(int64(10)))
|
||||
})
|
||||
|
||||
It("does NOT populate by_user_source in the non-admin path", func() {
|
||||
db := testDB()
|
||||
insertNamed(db, "alice", "Alice", auth.UsageSourceWeb, "", "", 30)
|
||||
|
||||
_, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Non-admin path uses includeLegacy=false, so by_user_source stays nil.
|
||||
Expect(totals.ByUserSource).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestAnthropic(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Anthropic test suite")
|
||||
}
|
||||
@@ -10,13 +10,10 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -30,7 +27,7 @@ import (
|
||||
// @Param request body schema.AnthropicRequest true "query params"
|
||||
// @Success 200 {object} schema.AnthropicResponse "Response"
|
||||
// @Router /v1/messages [post]
|
||||
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, piiRedactor *pii.Redactor, piiEvents pii.EventStore) echo.HandlerFunc {
|
||||
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := uuid.New().String()
|
||||
|
||||
@@ -50,12 +47,6 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
|
||||
|
||||
xlog.Debug("Anthropic Messages endpoint configuration read", "config", cfg)
|
||||
|
||||
// Cloud-proxy bail. Same shape as the OpenAI chat endpoint —
|
||||
// forwards via the cloud-proxy gRPC backend.
|
||||
if cfg.IsCloudProxyBackendPassthrough() {
|
||||
return forwardCloudProxyAnthropicViaBackend(c, cfg, input, piiRedactor, piiEvents, ml, appConfig)
|
||||
}
|
||||
|
||||
// Convert Anthropic messages to OpenAI format for internal processing
|
||||
openAIMessages := convertAnthropicToOpenAIMessages(input)
|
||||
|
||||
@@ -141,7 +132,7 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
|
||||
xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput)
|
||||
|
||||
if input.Stream {
|
||||
return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator, piiRedactor, piiEvents)
|
||||
return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
|
||||
}
|
||||
|
||||
return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
|
||||
@@ -322,45 +313,17 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
xlog.Debug("Anthropic Response", "response", string(respData))
|
||||
}
|
||||
|
||||
middleware.StampUsage(c, input.Model, tokenUsage.Prompt, tokenUsage.Completion)
|
||||
|
||||
return c.JSON(200, resp)
|
||||
} // end MCP iteration loop
|
||||
|
||||
return sendAnthropicError(c, 500, "api_error", "MCP iteration limit reached")
|
||||
}
|
||||
|
||||
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator, piiRedactor *pii.Redactor, piiEvents pii.EventStore) error {
|
||||
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error {
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Per-stream PII filter — same gating as the OpenAI chat path. The
|
||||
// filter is wire-format-agnostic; we feed it the text portion of
|
||||
// each text_delta and emit only what's safe to send. The filter
|
||||
// holds back a tail of size MaxPatternLength-1 so a pattern split
|
||||
// across chunk boundaries still gets masked. When PII is disabled
|
||||
// for this model the filter is nil and emits flow unchanged.
|
||||
var streamPIIFilter *pii.StreamFilter
|
||||
if piiRedactor != nil && cfg.PIIIsEnabled() {
|
||||
correlationID := c.Request().Header.Get("x-request-id")
|
||||
userID := ""
|
||||
if u := auth.GetUser(c); u != nil {
|
||||
userID = u.ID
|
||||
}
|
||||
var overrides map[string]pii.Action
|
||||
if raw := cfg.PIIPatternOverrides(); len(raw) > 0 {
|
||||
overrides = make(map[string]pii.Action, len(raw))
|
||||
for ovid, action := range raw {
|
||||
switch pii.Action(action) {
|
||||
case pii.ActionMask, pii.ActionBlock, pii.ActionRouteLocal:
|
||||
overrides[ovid] = pii.Action(action)
|
||||
}
|
||||
}
|
||||
}
|
||||
streamPIIFilter = pii.NewStreamFilter(piiRedactor, overrides, piiEvents, correlationID, userID)
|
||||
}
|
||||
|
||||
// Send message_start event
|
||||
messageStart := schema.AnthropicStreamEvent{
|
||||
Type: "message_start",
|
||||
@@ -440,7 +403,6 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
|
||||
if len(toolCalls) > toolCallsEmitted {
|
||||
if !inToolCall && currentBlockIndex == 0 {
|
||||
drainStreamPIIToText(c, streamPIIFilter, intPtr(currentBlockIndex))
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(currentBlockIndex),
|
||||
@@ -481,20 +443,14 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
}
|
||||
|
||||
if !inToolCall && token != "" {
|
||||
out := token
|
||||
if streamPIIFilter != nil {
|
||||
out = streamPIIFilter.Push(token)
|
||||
}
|
||||
if out != "" {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: out,
|
||||
},
|
||||
})
|
||||
}
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: token,
|
||||
},
|
||||
})
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -532,20 +488,14 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
// didn't already stream it (autoparser clears raw text, so
|
||||
// accumulatedContent will be empty in that case).
|
||||
if deltaContent != "" && !inToolCall && accumulatedContent == "" {
|
||||
out := deltaContent
|
||||
if streamPIIFilter != nil {
|
||||
out = streamPIIFilter.Push(deltaContent)
|
||||
}
|
||||
if out != "" {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: out,
|
||||
},
|
||||
})
|
||||
}
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: deltaContent,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Emit tool_use blocks from ChatDeltas
|
||||
@@ -553,7 +503,6 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
collectedToolCalls = deltaToolCalls
|
||||
|
||||
if !inToolCall && currentBlockIndex == 0 {
|
||||
drainStreamPIIToText(c, streamPIIFilter, intPtr(currentBlockIndex))
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(currentBlockIndex),
|
||||
@@ -657,9 +606,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
if !shouldUseFn && cfg.FunctionsConfig.AutomaticToolParsingFallback && accumulatedContent != "" && toolCallsEmitted == 0 {
|
||||
parsed := functions.ParseFunctionCall(accumulatedContent, cfg.FunctionsConfig)
|
||||
if len(parsed) > 0 {
|
||||
// Close the text content block (after flushing any
|
||||
// residual the streaming PII filter held back).
|
||||
drainStreamPIIToText(c, streamPIIFilter, intPtr(currentBlockIndex))
|
||||
// Close the text content block
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(currentBlockIndex),
|
||||
@@ -699,12 +646,8 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
}
|
||||
}
|
||||
|
||||
// No MCP tools to execute, close stream. drainStreamPIIToText
|
||||
// flushes any residual the streaming PII filter held back as
|
||||
// part of its trailing pattern-window before we close the
|
||||
// text content block.
|
||||
// No MCP tools to execute, close stream
|
||||
if !inToolCall {
|
||||
drainStreamPIIToText(c, streamPIIFilter, intPtr(0))
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(0),
|
||||
@@ -730,8 +673,6 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
Type: "message_stop",
|
||||
})
|
||||
|
||||
middleware.StampUsage(c, input.Model, tokenUsage.Prompt, tokenUsage.Completion)
|
||||
|
||||
return nil
|
||||
} // end MCP iteration loop
|
||||
|
||||
@@ -752,30 +693,6 @@ func convertFuncsToOpenAITools(funcs functions.Functions) []functions.Tool {
|
||||
|
||||
func intPtr(i int) *int { return &i }
|
||||
|
||||
// drainStreamPIIToText flushes any residual the streaming PII filter
|
||||
// has been holding back as part of its trailing pattern-window, and
|
||||
// emits it as one final text_delta into the named block before the
|
||||
// caller closes that block. Drain is idempotent: calling it twice on
|
||||
// the same filter returns "" the second time. Safe to call with a nil
|
||||
// filter (no-op).
|
||||
func drainStreamPIIToText(c echo.Context, sf *pii.StreamFilter, index *int) {
|
||||
if sf == nil {
|
||||
return
|
||||
}
|
||||
residual := sf.Drain()
|
||||
if residual == "" {
|
||||
return
|
||||
}
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: index,
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: residual,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
@@ -971,19 +888,3 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf
|
||||
|
||||
return funcs, len(funcs) > 0 && cfg.ShouldUseFunctions()
|
||||
}
|
||||
|
||||
// forwardCloudProxyAnthropicViaBackend marshals the Anthropic request,
|
||||
// constructs the streaming PII filter (when applicable), and hands the
|
||||
// body off to the cloud-proxy gRPC backend. Model swap + upstream auth
|
||||
// headers are applied inside the backend; the filter is built here
|
||||
// because the auth/correlation context only exists in the echo handler.
|
||||
func forwardCloudProxyAnthropicViaBackend(c echo.Context, cfg *config.ModelConfig, input *schema.AnthropicRequest, piiRedactor *pii.Redactor, piiEvents pii.EventStore, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
body, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return sendAnthropicError(c, 400, "invalid_request_error", "cloudproxy: marshal request: "+err.Error())
|
||||
}
|
||||
|
||||
correlationID := c.Request().Header.Get("x-request-id")
|
||||
streamFilter := cloudproxy.BuildStreamFilter(c, cfg, input.Stream, piiRedactor, piiEvents, correlationID)
|
||||
return cloudproxy.ForwardViaBackend(c, cfg, body, streamFilter, ml, appConfig)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user