mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-25 17:18:18 -04:00
Compare commits
8 Commits
v4.3.1
...
fix/galler
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
11c5fd677d | ||
|
|
b3300ef207 | ||
|
|
8d6548c0b9 | ||
|
|
b02e3ffe61 | ||
|
|
a891eedd08 | ||
|
|
06e777b75e | ||
|
|
90ea327178 | ||
|
|
6a80e23733 |
@@ -4,6 +4,7 @@
|
||||
.devcontainer
|
||||
models
|
||||
backends
|
||||
volumes
|
||||
examples/chatbot-ui/models
|
||||
backend/go/image/stablediffusion-ggml/build/
|
||||
backend/go/*/build
|
||||
@@ -21,3 +22,11 @@ __pycache__
|
||||
# backend virtual environments
|
||||
**/venv
|
||||
backend/python/**/source
|
||||
|
||||
# In-place llama.cpp clone + per-variant build copies. The Makefile
|
||||
# clones llama.cpp itself at the pinned LLAMA_VERSION; if a stale
|
||||
# local checkout is COPY'd into the image, the `llama.cpp:` target
|
||||
# sees the directory and skips re-cloning, so grpc-server.cpp ends
|
||||
# up compiled against whatever (likely older) commit the host had.
|
||||
backend/cpp/llama-cpp/llama.cpp
|
||||
backend/cpp/llama-cpp-*-build
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -26,6 +26,10 @@ go-bert
|
||||
LocalAI
|
||||
/local-ai
|
||||
/local-ai-launcher
|
||||
# Root-level build artifacts when running `go build ./...` against
|
||||
# Go backend packages whose main lives under backend/go/.
|
||||
/cloud-proxy
|
||||
/local-store
|
||||
# prevent above rules from omitting the helm chart
|
||||
!charts/*
|
||||
# prevent above rules from omitting the api/localai folder
|
||||
|
||||
15
Makefile
15
Makefile
@@ -69,7 +69,7 @@ else
|
||||
GORELEASER=$(shell which goreleaser)
|
||||
endif
|
||||
|
||||
TEST_PATHS?=./api/... ./pkg/... ./core/...
|
||||
TEST_PATHS?=./api/... ./pkg/... ./core/... ./backend/go/cloud-proxy/... ./backend/go/local-store/...
|
||||
|
||||
|
||||
.PHONY: all test build vendor lint lint-all
|
||||
@@ -268,12 +268,13 @@ prepare-e2e:
|
||||
run-e2e-image:
|
||||
docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --name e2e-tests-$(RANDOM) localai-tests
|
||||
|
||||
test-e2e: build-mock-backend prepare-e2e run-e2e-image
|
||||
test-e2e: build-mock-backend build-cloud-proxy-backend prepare-e2e run-e2e-image
|
||||
@echo 'Running e2e tests'
|
||||
BUILD_TYPE=$(BUILD_TYPE) \
|
||||
LOCALAI_API=http://$(E2E_BRIDGE_IP):5390 \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e
|
||||
$(MAKE) clean-mock-backend
|
||||
$(MAKE) clean-cloud-proxy-backend
|
||||
$(MAKE) teardown-e2e
|
||||
docker rmi localai-tests
|
||||
|
||||
@@ -1064,6 +1065,7 @@ BACKEND_DS4 = ds4|ds4|.|false|false
|
||||
# Golang backends
|
||||
BACKEND_PIPER = piper|golang|.|false|true
|
||||
BACKEND_LOCAL_STORE = local-store|golang|.|false|true
|
||||
BACKEND_CLOUD_PROXY = cloud-proxy|golang|.|false|true
|
||||
BACKEND_HUGGINGFACE = huggingface|golang|.|false|true
|
||||
BACKEND_SILERO_VAD = silero-vad|golang|.|false|true
|
||||
BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true
|
||||
@@ -1149,6 +1151,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_TURBOQUANT)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_DS4)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PIPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CLOUD_PROXY)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
@@ -1201,7 +1204,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
@@ -1213,6 +1216,12 @@ build-mock-backend: protogen-go
|
||||
clean-mock-backend:
|
||||
rm -f tests/e2e/mock-backend/mock-backend
|
||||
|
||||
build-cloud-proxy-backend: protogen-go
|
||||
$(GOCMD) build -o tests/e2e/mock-backend/cloud-proxy ./backend/go/cloud-proxy
|
||||
|
||||
clean-cloud-proxy-backend:
|
||||
rm -f tests/e2e/mock-backend/cloud-proxy
|
||||
|
||||
########################################################
|
||||
### UI E2E Test Server
|
||||
########################################################
|
||||
|
||||
@@ -37,6 +37,22 @@ service Backend {
|
||||
|
||||
rpc Rerank(RerankRequest) returns (RerankResult) {}
|
||||
|
||||
// TokenClassify runs a token-classification (NER) model on the
|
||||
// supplied text and returns each detected entity span. Used by the
|
||||
// PII redactor's optional NER tier — the regex tier still handles
|
||||
// formatted hits cheaply, while this catches names, locations, and
|
||||
// other unformatted PII that regex misses.
|
||||
rpc TokenClassify(TokenClassifyRequest) returns (TokenClassifyResponse) {}
|
||||
|
||||
// Score evaluates the model's joint log-probability of each
|
||||
// supplied candidate continuation given a shared prompt. The
|
||||
// prompt's KV cache is computed once and reused across candidates.
|
||||
// Used for routing-policy multi-label classification, reranking,
|
||||
// calibrated confidence, and reward-model scoring — any task where
|
||||
// the consumer wants the model's confidence in a pre-specified
|
||||
// continuation rather than a generated one.
|
||||
rpc Score(ScoreRequest) returns (ScoreResponse) {}
|
||||
|
||||
rpc GetMetrics(MetricsRequest) returns (MetricsResponse);
|
||||
|
||||
rpc VAD(VADRequest) returns (VADResponse) {}
|
||||
@@ -68,6 +84,23 @@ service Backend {
|
||||
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
|
||||
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
|
||||
|
||||
// Forward proxies a raw HTTP request to an upstream provider. The
|
||||
// cloud-proxy backend implements this for passthrough-mode model
|
||||
// configs: the client wire format is preserved end-to-end (no
|
||||
// translation through internal proto), which means new provider
|
||||
// fields work the day they ship. Translation-mode proxies use the
|
||||
// standard Predict/PredictStream RPCs instead. Backends that don't
|
||||
// support this return UNIMPLEMENTED.
|
||||
//
|
||||
// The request is bidirectionally streamed so large bodies can flow
|
||||
// without buffering. In practice the first ForwardRequest carries
|
||||
// path, method, headers, and the initial body chunk; subsequent
|
||||
// messages append body chunks. The first ForwardReply carries the
|
||||
// upstream status and response headers; subsequent messages stream
|
||||
// body chunks (SSE frames or chunked transfer). Cancellation of the
|
||||
// gRPC context closes the upstream connection.
|
||||
rpc Forward(stream ForwardRequest) returns (stream ForwardReply) {}
|
||||
|
||||
}
|
||||
|
||||
// Define the empty request
|
||||
@@ -81,6 +114,76 @@ message MetricsResponse {
|
||||
int32 prompt_tokens_processed = 5;
|
||||
}
|
||||
|
||||
// TokenClassifyRequest carries the text to classify plus an optional
|
||||
// score threshold. The transformers backend interprets threshold as
|
||||
// the minimum confidence to include in the response; 0 = include all.
|
||||
message TokenClassifyRequest {
|
||||
string text = 1;
|
||||
float threshold = 2;
|
||||
}
|
||||
|
||||
// TokenClassifyEntity is one detected entity span. Byte offsets are
|
||||
// into the original UTF-8 text — start..end is a half-open range that
|
||||
// addresses the substring corresponding to entity_group.
|
||||
//
|
||||
// entity_group follows HuggingFace's aggregated-tag convention (e.g.
|
||||
// "PER", "LOC", "ORG", or a PII-specific label like "EMAIL" /
|
||||
// "SSN" depending on the model). The redactor's per-pattern action
|
||||
// map keys off this string.
|
||||
message TokenClassifyEntity {
|
||||
string entity_group = 1;
|
||||
int32 start = 2;
|
||||
int32 end = 3;
|
||||
float score = 4;
|
||||
string text = 5;
|
||||
}
|
||||
|
||||
message TokenClassifyResponse {
|
||||
repeated TokenClassifyEntity entities = 1;
|
||||
}
|
||||
|
||||
// ScoreRequest carries one shared prompt and one or more continuations
|
||||
// to score against it. The backend tokenises the prompt once and reuses
|
||||
// the resulting KV cache across all candidates in this request.
|
||||
message ScoreRequest {
|
||||
string prompt = 1;
|
||||
repeated string candidates = 2;
|
||||
// Return per-token logprobs for each candidate when true. Default
|
||||
// false to keep the wire response small; the joint log_prob field
|
||||
// covers the common ranking case.
|
||||
bool include_token_logprobs = 3;
|
||||
// When true, the response also populates length_normalized_log_prob
|
||||
// (joint log-prob divided by candidate token count). Useful when
|
||||
// candidates differ in length and the consumer wants a per-token
|
||||
// measure comparable across them (PMI-style scoring).
|
||||
bool length_normalize = 4;
|
||||
}
|
||||
|
||||
// CandidateScore is one row in the ScoreResponse, matching by index
|
||||
// the candidate in ScoreRequest.candidates.
|
||||
message CandidateScore {
|
||||
// Sum of log P(token_i | prompt, candidate_token_<i) across the
|
||||
// candidate's tokens. The primary ranking signal.
|
||||
double log_prob = 1;
|
||||
// log_prob / num_tokens — populated when length_normalize=true on
|
||||
// the request.
|
||||
double length_normalized_log_prob = 2;
|
||||
// Per-token detail — populated when include_token_logprobs=true.
|
||||
repeated TokenLogProb tokens = 3;
|
||||
// Number of tokens the backend tokenised this candidate into, after
|
||||
// any backend-specific normalisation (e.g. leading-space handling).
|
||||
int32 num_tokens = 4;
|
||||
}
|
||||
|
||||
message TokenLogProb {
|
||||
string token = 1;
|
||||
double log_prob = 2;
|
||||
}
|
||||
|
||||
message ScoreResponse {
|
||||
repeated CandidateScore candidates = 1;
|
||||
}
|
||||
|
||||
message RerankRequest {
|
||||
string query = 1;
|
||||
repeated string documents = 2;
|
||||
@@ -325,6 +428,25 @@ message ModelOptions {
|
||||
// applied verbatim to the backend's engine constructor (e.g. vLLM AsyncEngineArgs).
|
||||
// Unknown keys produce an error at LoadModel time.
|
||||
string EngineArgs = 73;
|
||||
|
||||
// Proxy carries the cloud-proxy backend's per-model configuration.
|
||||
// Empty for non-proxy backends.
|
||||
ProxyOptions Proxy = 74;
|
||||
}
|
||||
|
||||
// ProxyOptions configures the cloud-proxy backend. UpstreamURL and
|
||||
// Mode are always meaningful; Provider only matters in translate mode.
|
||||
// The two api_key_* fields are mutually exclusive and resolved by the
|
||||
// backend at LoadModel — core forwards the references rather than the
|
||||
// plaintext key.
|
||||
message ProxyOptions {
|
||||
string upstream_url = 1;
|
||||
string mode = 2;
|
||||
string provider = 3;
|
||||
string api_key_env = 4;
|
||||
string api_key_file = 5;
|
||||
string upstream_model = 6;
|
||||
int32 request_timeout_seconds = 7;
|
||||
}
|
||||
|
||||
message Result {
|
||||
@@ -1002,3 +1124,32 @@ message QuantizationStopRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
// ForwardHeader is one HTTP header on the request or response. Headers
|
||||
// like Authorization are typically injected by the backend (from the
|
||||
// resolved API key) rather than passed through from the client.
|
||||
message ForwardHeader {
|
||||
string name = 1;
|
||||
string value = 2;
|
||||
}
|
||||
|
||||
// ForwardRequest is a streamed HTTP request to the upstream. First
|
||||
// message carries path/method/headers; subsequent messages carry
|
||||
// body_chunk only. All fields except body_chunk are honoured on the
|
||||
// first message and ignored thereafter.
|
||||
message ForwardRequest {
|
||||
string path = 1; // e.g. "/v1/chat/completions" — appended to the model's upstream_url
|
||||
string method = 2; // usually "POST"
|
||||
repeated ForwardHeader headers = 3;
|
||||
bytes body_chunk = 4;
|
||||
}
|
||||
|
||||
// ForwardReply is a streamed HTTP response from the upstream. First
|
||||
// message carries status/headers; subsequent messages carry body_chunk
|
||||
// only. SSE responses arrive as a sequence of body_chunk frames; the
|
||||
// caller is responsible for any parsing.
|
||||
message ForwardReply {
|
||||
int32 status = 1;
|
||||
repeated ForwardHeader headers = 2;
|
||||
bytes body_chunk = 3;
|
||||
}
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
#include <regex>
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
@@ -121,6 +122,40 @@ static std::string base64_encode_bytes(const unsigned char* data, size_t len) {
|
||||
|
||||
bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model
|
||||
|
||||
// Score bypasses the slot loop (see the comment on Score below) so it
|
||||
// must not run concurrently with any slot-loop RPC. These counters
|
||||
// are a defence-in-depth tripwire — ModelConfig.Validate already
|
||||
// rejects llama-cpp configs that mix score with chat/completion/
|
||||
// embeddings, so a healthy deployment never trips them. seq_cst is
|
||||
// load-bearing for the increment-then-check pattern below.
|
||||
static std::atomic<int> slot_loop_inflight{0};
|
||||
static std::atomic<int> score_inflight{0};
|
||||
|
||||
// Increment-then-check, not check-then-increment: two simultaneous
|
||||
// racers both observe the other's increment and both abort cleanly.
|
||||
// Reversed, both could see zero and proceed.
|
||||
struct conflict_guard {
|
||||
std::atomic<int>& self;
|
||||
conflict_guard(const char* rpc, std::atomic<int>& self_, std::atomic<int>& other, const char* other_name)
|
||||
: self(self_) {
|
||||
self.fetch_add(1, std::memory_order_seq_cst);
|
||||
int o = other.load(std::memory_order_seq_cst);
|
||||
if (o > 0) {
|
||||
fprintf(stderr,
|
||||
"FATAL: %s called with %s=%d. The llama-cpp backend cannot "
|
||||
"service Score and slot-loop RPCs concurrently — Score "
|
||||
"bypasses the slot loop and races the llama_context. Bind "
|
||||
"Score-using features to a model dedicated to scoring "
|
||||
"(known_usecases: [score] with no chat/completion/embeddings).\n",
|
||||
rpc, other_name, o);
|
||||
std::abort();
|
||||
}
|
||||
}
|
||||
~conflict_guard() {
|
||||
self.fetch_sub(1, std::memory_order_seq_cst);
|
||||
}
|
||||
};
|
||||
|
||||
static std::function<void(int)> shutdown_handler;
|
||||
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
|
||||
|
||||
@@ -1446,6 +1481,7 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("PredictStream", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
|
||||
@@ -2205,6 +2241,7 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("Predict", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
data["stream"] = false;
|
||||
@@ -2963,6 +3000,7 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("Embedding", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
body["stream"] = false;
|
||||
@@ -3070,6 +3108,8 @@ public:
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
|
||||
}
|
||||
|
||||
conflict_guard guard("Rerank", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
|
||||
// Create and queue the task
|
||||
auto rd = ctx_server.get_response_reader();
|
||||
{
|
||||
@@ -3142,12 +3182,218 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
// Score returns the model's joint log-probability of each candidate
|
||||
// continuation given a shared prompt.
|
||||
//
|
||||
// WHY bypass the slot/task queue: upstream server_context exposes
|
||||
// get_llama_context as "main thread only" and the slot loop's
|
||||
// update_slots() owns the context whenever a task is in flight.
|
||||
// No public synchronization primitive is available — so Score is
|
||||
// unsafe to call concurrently with active generation through this
|
||||
// backend. In practice routing-classifier calls happen before the
|
||||
// request is routed to a generation backend, so the model used
|
||||
// for Score is typically idle. Concurrent Score calls are
|
||||
// serialised by a local mutex; KV-cache state is isolated behind
|
||||
// a dedicated sequence ID cleared between candidates.
|
||||
//
|
||||
// A patch to server-context.cpp that adds SERVER_TASK_TYPE_SCORE
|
||||
// and routes scoring through the slot loop would be the correct
|
||||
// long-term fix; tracked as a follow-up.
|
||||
//
|
||||
// Perf TODO (measured: ~450 ms warm for 3 candidates on Arch-
|
||||
// Router-1.5B Q4_K_M + Intel SYCL): the current loop re-decodes
|
||||
// `prompt + candidate` from scratch for every candidate, throwing
|
||||
// away the prompt's KV cache between iterations. A smarter
|
||||
// version would:
|
||||
// 1. Decode just the prompt once into score_seq_id.
|
||||
// 2. Snapshot/cp that sequence (llama_memory_seq_cp) into a
|
||||
// per-candidate sequence id.
|
||||
// 3. For each candidate, decode only its tokens onto the copy
|
||||
// (continuing from the saved prompt state), read logits.
|
||||
// 4. llama_memory_seq_rm the copy.
|
||||
// Estimated speedup: 3-candidate calls 450 ms -> ~150-200 ms,
|
||||
// 6-candidate calls 630 ms -> ~220 ms. Single source-file change,
|
||||
// no proto / Go-side changes needed. Worth doing once routing is
|
||||
// wired into the middleware and Score is on the hot path of every
|
||||
// chat request.
|
||||
grpc::Status Score(ServerContext* context, const backend::ScoreRequest* request, backend::ScoreResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
if (request->candidates_size() == 0) {
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "candidates must be non-empty");
|
||||
}
|
||||
|
||||
// Tripwire against the slot loop. Acquired before score_mutex
|
||||
// so it fires even when this Score is queued behind another.
|
||||
conflict_guard guard("Score", score_inflight, slot_loop_inflight, "slot_loop_inflight");
|
||||
|
||||
// Serialise concurrent Score calls. The slot loop is still
|
||||
// free to race with us — see the class comment above.
|
||||
static std::mutex score_mutex;
|
||||
std::lock_guard<std::mutex> score_lock(score_mutex);
|
||||
|
||||
llama_context * lctx = ctx_server.get_llama_context();
|
||||
if (lctx == nullptr) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "llama context unavailable (sleeping?)");
|
||||
}
|
||||
const llama_vocab * vocab = ctx_server.impl->vocab;
|
||||
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
|
||||
const int32_t n_ctx = llama_n_ctx(lctx);
|
||||
llama_memory_t mem = llama_get_memory(lctx);
|
||||
|
||||
// The KV-cache is sized to seq_to_stream.size() at load
|
||||
// (typically equal to n_slots, often 1). Sequence IDs must
|
||||
// be in [0, n_seq_max), so we can't pick a high-value
|
||||
// "private" ID — we have to share with the slot. We clear
|
||||
// the cache before AND after each candidate to keep
|
||||
// scoring isolated from whatever state the slot held, and
|
||||
// the static mutex above guarantees no other Score call is
|
||||
// racing in the meantime. The slot loop is still free to
|
||||
// race (see comment on this method) — Score must not run
|
||||
// concurrently with generation through this backend.
|
||||
const llama_seq_id score_seq_id = 0;
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
|
||||
// Tokenize the shared prompt once with add_special=true so
|
||||
// BOS is prepended when the model requires it. parse_special
|
||||
// keeps chat-template markers in the prompt intact.
|
||||
const std::string prompt = request->prompt();
|
||||
std::vector<llama_token> prompt_tokens = common_tokenize(vocab, prompt, /*add_special=*/true, /*parse_special=*/true);
|
||||
const int32_t prompt_len = (int32_t) prompt_tokens.size();
|
||||
|
||||
for (int ci = 0; ci < request->candidates_size(); ci++) {
|
||||
const std::string & candidate_text = request->candidates(ci);
|
||||
|
||||
// Re-tokenize prompt + candidate as a single string. BPE
|
||||
// merges across the boundary can shift the tokenization
|
||||
// versus tokenize(prompt) ++ tokenize(candidate), so we
|
||||
// find the divergence point against prompt_tokens.
|
||||
std::vector<llama_token> full_tokens = common_tokenize(vocab, prompt + candidate_text, /*add_special=*/true, /*parse_special=*/true);
|
||||
int32_t divergence = prompt_len;
|
||||
const int32_t min_len = std::min<int32_t>(prompt_len, (int32_t) full_tokens.size());
|
||||
for (int32_t i = 0; i < min_len; i++) {
|
||||
if (prompt_tokens[i] != full_tokens[i]) {
|
||||
divergence = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const int32_t cand_len = (int32_t) full_tokens.size() - divergence;
|
||||
backend::CandidateScore * cs = response->add_candidates();
|
||||
cs->set_num_tokens(cand_len);
|
||||
if (cand_len <= 0) {
|
||||
cs->set_log_prob(0.0);
|
||||
if (request->length_normalize()) {
|
||||
cs->set_length_normalized_log_prob(0.0);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (divergence < 1) {
|
||||
// Need at least one prior token (typically BOS) to
|
||||
// predict the first candidate token's logit. Tokeniser
|
||||
// models without BOS + an empty prompt fall in here.
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
|
||||
"Score: prompt produced no leading tokens; need at least one (e.g. BOS) to predict candidate");
|
||||
}
|
||||
if ((int32_t) full_tokens.size() > n_ctx) {
|
||||
return grpc::Status(grpc::StatusCode::OUT_OF_RANGE,
|
||||
"Score: prompt+candidate exceeds context size (got " +
|
||||
std::to_string(full_tokens.size()) + ", n_ctx=" + std::to_string(n_ctx) + ")");
|
||||
}
|
||||
|
||||
// Build a batch covering the entire prompt+candidate. We
|
||||
// need logits at (divergence-1) onward — those are the
|
||||
// predictions for each candidate token.
|
||||
llama_batch batch = llama_batch_init((int32_t) full_tokens.size(), 0, 1);
|
||||
for (int32_t i = 0; i < (int32_t) full_tokens.size(); i++) {
|
||||
batch.token[i] = full_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = score_seq_id;
|
||||
// logits[i] is "do we want the prediction *for the
|
||||
// next token*, computed from this position?"
|
||||
// We want predictions for candidate tokens at
|
||||
// positions divergence .. full_tokens.size()-1, which
|
||||
// come from logits at positions (divergence-1) ..
|
||||
// (full_tokens.size()-2).
|
||||
bool need_logit = (i >= divergence - 1) && (i < (int32_t) full_tokens.size() - 1);
|
||||
batch.logits[i] = need_logit ? 1 : 0;
|
||||
}
|
||||
batch.n_tokens = (int32_t) full_tokens.size();
|
||||
|
||||
// Decode the batch. If decode fails (e.g. KV slot
|
||||
// exhaustion), surface as INTERNAL — the caller will
|
||||
// typically fall back to a sampling-based classifier.
|
||||
int decode_err = llama_decode(lctx, batch);
|
||||
if (decode_err != 0) {
|
||||
llama_batch_free(batch);
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL,
|
||||
"llama_decode failed during Score: " + std::to_string(decode_err));
|
||||
}
|
||||
|
||||
// Sum log-probabilities of the actual candidate tokens.
|
||||
double total_log_prob = 0.0;
|
||||
for (int32_t k = 0; k < cand_len; k++) {
|
||||
// The k-th candidate token sits at full_tokens index
|
||||
// (divergence + k). Its predicting logit is at batch
|
||||
// position (divergence + k - 1).
|
||||
int32_t logit_pos = divergence + k - 1;
|
||||
const float * logits = llama_get_logits_ith(lctx, logit_pos);
|
||||
if (logits == nullptr) {
|
||||
llama_batch_free(batch);
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL,
|
||||
"llama_get_logits_ith returned null at position " + std::to_string(logit_pos));
|
||||
}
|
||||
llama_token target_token = full_tokens[divergence + k];
|
||||
|
||||
// Compute log_softmax(logits)[target_token] with the
|
||||
// max-subtraction stability trick.
|
||||
float max_logit = logits[0];
|
||||
for (int32_t v = 1; v < n_vocab; v++) {
|
||||
if (logits[v] > max_logit) max_logit = logits[v];
|
||||
}
|
||||
double sum_exp = 0.0;
|
||||
for (int32_t v = 0; v < n_vocab; v++) {
|
||||
sum_exp += std::exp((double)(logits[v] - max_logit));
|
||||
}
|
||||
double token_log_prob = (double)(logits[target_token] - max_logit) - std::log(sum_exp);
|
||||
total_log_prob += token_log_prob;
|
||||
|
||||
if (request->include_token_logprobs()) {
|
||||
backend::TokenLogProb * tlp = cs->add_tokens();
|
||||
std::string piece = common_token_to_piece(lctx, target_token);
|
||||
tlp->set_token(piece);
|
||||
tlp->set_log_prob(token_log_prob);
|
||||
}
|
||||
}
|
||||
|
||||
cs->set_log_prob(total_log_prob);
|
||||
if (request->length_normalize() && cand_len > 0) {
|
||||
cs->set_length_normalized_log_prob(total_log_prob / (double) cand_len);
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
// Drop this candidate's KV-cache contribution so the next
|
||||
// candidate starts from a clean state. Without this, the
|
||||
// next decode would conflict at positions 0..N-1 for our
|
||||
// sequence ID.
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
}
|
||||
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("TokenizeString", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
|
||||
body["stream"] = false;
|
||||
|
||||
@@ -3169,6 +3415,8 @@ public:
|
||||
|
||||
grpc::Status GetMetrics(ServerContext* /*context*/, const backend::MetricsRequest* /*request*/, backend::MetricsResponse* response) override {
|
||||
|
||||
conflict_guard guard("GetMetrics", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
|
||||
// request slots data using task queue
|
||||
auto rd = ctx_server.get_response_reader();
|
||||
int task_id = rd.queue_tasks.get_new_id();
|
||||
|
||||
12
backend/go/cloud-proxy/Makefile
Normal file
12
backend/go/cloud-proxy/Makefile
Normal file
@@ -0,0 +1,12 @@
|
||||
GOCMD=go
|
||||
|
||||
cloud-proxy:
|
||||
CGO_ENABLED=0 $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o cloud-proxy ./
|
||||
|
||||
package:
|
||||
bash package.sh
|
||||
|
||||
build: cloud-proxy package
|
||||
|
||||
clean:
|
||||
rm -f cloud-proxy
|
||||
16
backend/go/cloud-proxy/cloud_proxy_suite_test.go
Normal file
16
backend/go/cloud-proxy/cloud_proxy_suite_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Ginkgo bootstrap. The other Test* functions in this package use
|
||||
// raw testing.T and run independently; they coexist with Ginkgo
|
||||
// specs registered via Describe / Context.
|
||||
func TestCloudProxySpecs(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "cloud-proxy specs")
|
||||
}
|
||||
39
backend/go/cloud-proxy/main.go
Normal file
39
backend/go/cloud-proxy/main.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package main
|
||||
|
||||
// cloud-proxy is a LocalAI backend that forwards request traffic to an
|
||||
// external HTTP provider (OpenAI, Anthropic, etc.). Two modes:
|
||||
//
|
||||
// - passthrough: serves the Forward RPC; the client wire format is
|
||||
// preserved end-to-end, no translation.
|
||||
// - translate: serves Predict/PredictStream; the backend converts
|
||||
// internal proto to the provider's wire format. (Phases 5–6.)
|
||||
//
|
||||
// LoadModel reads UpstreamURL/Mode/Provider/key references from
|
||||
// ProxyOptions and resolves the API key once at load time.
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/xlog"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
var addr = flag.String("addr", "localhost:50051", "the address to listen on")
|
||||
|
||||
func main() {
|
||||
// xlog's default handler emits ANSI color codes; that's fine for an
|
||||
// interactive shell but unreadable when the backend's stdout is
|
||||
// captured by LocalAI and tee'd to a log file. Force plain text when
|
||||
// LOCALAI_LOG_FORMAT is unset and stdout isn't a terminal.
|
||||
format := os.Getenv("LOCALAI_LOG_FORMAT")
|
||||
if format == "" && !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
format = xlog.TextFormat
|
||||
}
|
||||
xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(os.Getenv("LOCALAI_LOG_LEVEL")), format))
|
||||
flag.Parse()
|
||||
if err := grpc.StartServer(*addr, NewCloudProxy()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
13
backend/go/cloud-proxy/package.sh
Executable file
13
backend/go/cloud-proxy/package.sh
Executable file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the cloud-proxy binary into the package dir for the
|
||||
# final Dockerfile stage. Mirrors backend/go/local-store/package.sh —
|
||||
# no extra runtime libs needed since the backend is pure Go.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
mkdir -p $CURDIR/package
|
||||
cp -avf $CURDIR/cloud-proxy $CURDIR/package/
|
||||
cp -rfv $CURDIR/run.sh $CURDIR/package/
|
||||
270
backend/go/cloud-proxy/passthrough_edge_test.go
Normal file
270
backend/go/cloud-proxy/passthrough_edge_test.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("composeURL", func() {
|
||||
// Upstream URL convention: gallery configs put the canonical path
|
||||
// in upstream_url, so per-request Path is ignored. A bare-host
|
||||
// upstream_url accepts the per-request path.
|
||||
DescribeTable("path resolution",
|
||||
func(upstream, reqPath, want string) {
|
||||
got, err := composeURL(upstream, reqPath)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(got).To(Equal(want))
|
||||
},
|
||||
Entry("full path wins", "https://api.openai.com/v1/chat/completions", "/v1/something-else", "https://api.openai.com/v1/chat/completions"),
|
||||
Entry("bare host accepts path", "https://api.openai.com", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"),
|
||||
Entry("root slash treated as bare", "https://api.openai.com/", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"),
|
||||
Entry("bare host + empty path", "https://api.openai.com", "", "https://api.openai.com"),
|
||||
)
|
||||
|
||||
It("returns an error on invalid upstream URL", func() {
|
||||
_, err := composeURL("://garbage", "")
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("applyAuthHeader", func() {
|
||||
It("sets x-api-key and anthropic-version for Anthropic, no Authorization", func() {
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
applyAuthHeader(req, providerAnthropic, "ant-key")
|
||||
Expect(req.Header.Get("x-api-key")).To(Equal("ant-key"))
|
||||
Expect(req.Header.Get("anthropic-version")).NotTo(BeEmpty())
|
||||
Expect(req.Header.Get("Authorization")).To(BeEmpty(), "Authorization must not leak on Anthropic backend")
|
||||
})
|
||||
|
||||
It("sets Bearer Authorization for OpenAI, no x-api-key", func() {
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
applyAuthHeader(req, providerOpenAI, "sk-key")
|
||||
Expect(req.Header.Get("Authorization")).To(Equal("Bearer sk-key"))
|
||||
Expect(req.Header.Get("x-api-key")).To(BeEmpty(), "x-api-key must not leak on OpenAI backend")
|
||||
})
|
||||
|
||||
It("defaults to Bearer when provider is empty", func() {
|
||||
// Passthrough mode often has provider == "" because the operator
|
||||
// doesn't claim a specific upstream wire format. Most providers
|
||||
// (including OpenAI-compatible ones) accept Bearer, so default to it.
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
applyAuthHeader(req, "", "some-key")
|
||||
Expect(req.Header.Get("Authorization")).To(Equal("Bearer some-key"))
|
||||
})
|
||||
|
||||
It("preserves an existing anthropic-version header", func() {
|
||||
// If the client supplied anthropic-version (rare but legitimate
|
||||
// for an upstream pinned to a specific date), the proxy must not
|
||||
// clobber it.
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
req.Header.Set("anthropic-version", "2024-10-01")
|
||||
applyAuthHeader(req, providerAnthropic, "k")
|
||||
Expect(req.Header.Get("anthropic-version")).To(Equal("2024-10-01"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("isHopByHopHeader", func() {
|
||||
DescribeTable("hop-by-hop classification",
|
||||
func(header string, want bool) {
|
||||
Expect(isHopByHopHeader(header)).To(Equal(want))
|
||||
},
|
||||
Entry("Connection is hop-by-hop", "Connection", true),
|
||||
Entry("Keep-Alive is hop-by-hop", "Keep-Alive", true),
|
||||
Entry("Proxy-Connection is hop-by-hop", "Proxy-Connection", true),
|
||||
Entry("Transfer-Encoding is hop-by-hop", "Transfer-Encoding", true),
|
||||
Entry("TE is hop-by-hop", "TE", true),
|
||||
Entry("Trailer is hop-by-hop", "Trailer", true),
|
||||
Entry("Upgrade is hop-by-hop", "Upgrade", true),
|
||||
Entry("Host is hop-by-hop", "Host", true),
|
||||
Entry("Content-Length is hop-by-hop", "Content-Length", true),
|
||||
// Case-insensitive — RFC 7230 doesn't constrain header case.
|
||||
Entry("lowercase connection is hop-by-hop", "connection", true),
|
||||
Entry("uppercase HOST is hop-by-hop", "HOST", true),
|
||||
// Non hop-by-hop — must NOT be stripped.
|
||||
Entry("Authorization is end-to-end", "Authorization", false),
|
||||
Entry("Content-Type is end-to-end", "Content-Type", false),
|
||||
Entry("Accept is end-to-end", "Accept", false),
|
||||
Entry("X-Custom is end-to-end", "X-Custom", false),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("Forward", func() {
|
||||
It("strips hop-by-hop and Connection headers before upstream, preserves custom headers", func() {
|
||||
gotConnection := make(chan string, 1)
|
||||
gotXCustom := make(chan string, 1)
|
||||
gotHost := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotConnection <- r.Header.Get("Connection")
|
||||
gotXCustom <- r.Header.Get("X-Custom")
|
||||
gotHost <- r.Header.Get("Host")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-hopbyhop"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{
|
||||
{Name: "Connection", Value: "keep-alive"},
|
||||
{Name: "Host", Value: "spoofed.example.com"},
|
||||
{Name: "X-Custom", Value: "preserved"},
|
||||
},
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
_, _ = stream.Recv()
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
Expect(<-gotConnection).To(BeEmpty(), "Connection must not leak to upstream")
|
||||
Expect(<-gotHost).NotTo(Equal("spoofed.example.com"), "Host header must not be spoofed through")
|
||||
Expect(<-gotXCustom).To(Equal("preserved"), "X-Custom header must survive")
|
||||
})
|
||||
|
||||
It("replaces caller-supplied Authorization with the configured key", func() {
|
||||
// The proxy must overwrite a client-supplied Authorization header
|
||||
// so a downstream caller can't smuggle stale or wrong credentials.
|
||||
gotAuth := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuth <- r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
GinkgoT().Setenv("CLOUD_PROXY_AUTH_REPLACE_KEY", "sk-real")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_AUTH_REPLACE_KEY",
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-replaces-auth"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{
|
||||
// Client-supplied Authorization with the wrong scheme / key.
|
||||
{Name: "Authorization", Value: "Basic Zm9vOmJhcg=="},
|
||||
},
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
_, _ = stream.Recv()
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
Expect(<-gotAuth).To(Equal("Bearer sk-real"), "caller-supplied Basic header must be replaced")
|
||||
})
|
||||
|
||||
It("handles concurrent calls without interference", func() {
|
||||
// CloudProxy explicitly omits base.SingleThread — independent
|
||||
// Forward streams must not block each other or leak state.
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(body)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
},
|
||||
})).To(Succeed())
|
||||
addr := "test://forward-concurrent"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
|
||||
const N = 8
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
stream, err := c.Forward(context.Background())
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
payload := "request-" + string(rune('A'+idx))
|
||||
if err := stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
BodyChunk: []byte(payload),
|
||||
}); err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
_ = stream.CloseSend()
|
||||
_, _ = stream.Recv()
|
||||
var body []byte
|
||||
for {
|
||||
r, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
body = append(body, r.GetBodyChunk()...)
|
||||
}
|
||||
if string(body) != payload {
|
||||
errs <- &echoMismatch{want: payload, got: string(body)}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
var collected []error
|
||||
for err := range errs {
|
||||
collected = append(collected, err)
|
||||
}
|
||||
Expect(collected).To(BeEmpty(), "no concurrent Forward call should fail")
|
||||
})
|
||||
})
|
||||
|
||||
type echoMismatch struct{ want, got string }
|
||||
|
||||
func (e *echoMismatch) Error() string {
|
||||
return "echo mismatch: want " + strconv.Quote(e.want) + " got " + strconv.Quote(e.got)
|
||||
}
|
||||
508
backend/go/cloud-proxy/provider_anthropic.go
Normal file
508
backend/go/cloud-proxy/provider_anthropic.go
Normal file
@@ -0,0 +1,508 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// Anthropic Messages API wire-format types. Narrowed to what translate
|
||||
// mode preserves through the Reply proto: text + tool_use blocks +
|
||||
// usage tokens. Image blocks, prompt caching, metadata, and stop
|
||||
// sequence metadata are not modelled — passthrough mode covers those.
|
||||
//
|
||||
// Notable differences from OpenAI:
|
||||
// - max_tokens is REQUIRED. Anthropic 400s without it.
|
||||
// - Roles are user/assistant only — system messages move to a
|
||||
// top-level `system` string field.
|
||||
// - Streaming SSE uses event: lines alongside data: lines. The
|
||||
// events we care about: content_block_start (carries tool_use
|
||||
// init: id + name), content_block_delta (text_delta with text;
|
||||
// input_json_delta with partial_json for tool arguments), and
|
||||
// message_stop (terminates the stream). Others are ignored.
|
||||
|
||||
type anthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int32 `json:"max_tokens"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []anthropicMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []anthropicTool `json:"tools,omitempty"`
|
||||
ToolChoice *anthropicToolChoice `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
// Content is `any` because Anthropic accepts a bare string OR a
|
||||
// list of content blocks. Use the string form for plain user/
|
||||
// assistant turns; switch to []anthropicContentBlock when the
|
||||
// turn needs tool_use (assistant) or tool_result (user) blocks.
|
||||
type anthropicMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type anthropicTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"`
|
||||
}
|
||||
|
||||
// anthropicToolChoice mirrors the four shapes Anthropic accepts:
|
||||
// {"type":"auto"} | {"type":"any"} | {"type":"tool","name":"X"} |
|
||||
// {"type":"none"} (newer models). OpenAI's "auto"/"none"/
|
||||
// "required"/{"function":{"name":"X"}} all map here.
|
||||
type anthropicToolChoice struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// anthropicContentBlock is the union shape used both for response
|
||||
// blocks (text/tool_use we read off the wire) and outbound request
|
||||
// blocks (tool_use/tool_result we emit in the conversation history).
|
||||
// Anthropic encodes tool calls inline rather than as a separate field,
|
||||
// so we walk Content[] looking for type=="tool_use" on responses and
|
||||
// produce equivalent blocks when serialising prior-turn tool calls.
|
||||
type anthropicContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
// Tool-result block fields. tool_result uses `content` (not
|
||||
// `text`) and pairs with `tool_use_id`; modelling them as
|
||||
// distinct fields avoids ambiguity at marshal time.
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
ResultContent string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Content []anthropicContentBlock `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Usage *anthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// anthropicStreamEvent is the union shape used for every event type we
|
||||
// process. Type discriminates; only the matching fields are populated.
|
||||
// content_block_start carries ContentBlock (with id/name for tool_use);
|
||||
// content_block_delta carries Delta (text or partial_json).
|
||||
type anthropicStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Index int `json:"index,omitempty"`
|
||||
ContentBlock *anthropicContentBlock `json:"content_block,omitempty"`
|
||||
Delta *anthropicStreamDelta `json:"delta,omitempty"`
|
||||
Message *anthropicResponse `json:"message,omitempty"`
|
||||
Usage *anthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicStreamDelta struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
}
|
||||
|
||||
// Anthropic requires max_tokens. If the caller didn't set it, use a
|
||||
// generous-but-bounded default so the request doesn't 400.
|
||||
const anthropicDefaultMaxTokens int32 = 4096
|
||||
|
||||
const anthropicToolChoiceNone = "none"
|
||||
|
||||
// Reused JSON-Schema defaults for malformed inputs. Anthropic requires
|
||||
// input_schema to be a JSON object and tool_use.input to be a JSON
|
||||
// object; clients that omit them must not 400 the entire request.
|
||||
var (
|
||||
emptyJSONObject = json.RawMessage(`{}`)
|
||||
emptyObjectSchema = json.RawMessage(`{"type":"object","properties":{}}`)
|
||||
)
|
||||
|
||||
func buildAnthropicRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) {
|
||||
req := anthropicRequest{
|
||||
Model: modelName(cfg, opts),
|
||||
MaxTokens: opts.GetTokens(),
|
||||
Stream: stream,
|
||||
StopSequences: opts.GetStopPrompts(),
|
||||
}
|
||||
if req.MaxTokens <= 0 {
|
||||
req.MaxTokens = anthropicDefaultMaxTokens
|
||||
}
|
||||
// Newer Anthropic models 400 when both temperature and top_p are
|
||||
// set ("`temperature` and `top_p` cannot both be specified for
|
||||
// this model. Please use only one.") even though their docs only
|
||||
// "recommend" picking one. The OpenAI-compatible chat UI almost
|
||||
// always sends both with default values, so prefer temperature
|
||||
// and drop top_p when both are present.
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
} else if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
|
||||
req.Tools = convertOpenAITools(opts.GetTools())
|
||||
req.ToolChoice = convertOpenAIToolChoice(opts.GetToolChoice())
|
||||
// Anthropic rejects tool_choice without tools and older models
|
||||
// don't accept {"type":"none"} — collapse to a no-tools request.
|
||||
if req.ToolChoice != nil && req.ToolChoice.Type == anthropicToolChoiceNone {
|
||||
req.Tools, req.ToolChoice = nil, nil
|
||||
}
|
||||
|
||||
var systemParts []string
|
||||
for _, m := range opts.GetMessages() {
|
||||
role := m.GetRole()
|
||||
if role == "system" {
|
||||
if c := m.GetContent(); c != "" {
|
||||
systemParts = append(systemParts, c)
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch role {
|
||||
case "user":
|
||||
req.Messages = append(req.Messages, anthropicMessage{
|
||||
Role: "user",
|
||||
Content: m.GetContent(),
|
||||
})
|
||||
case "assistant":
|
||||
if blocks := assistantBlocks(m); blocks != nil {
|
||||
req.Messages = append(req.Messages, anthropicMessage{Role: "assistant", Content: blocks})
|
||||
continue
|
||||
}
|
||||
req.Messages = append(req.Messages, anthropicMessage{
|
||||
Role: "assistant",
|
||||
Content: m.GetContent(),
|
||||
})
|
||||
case "tool", "function":
|
||||
req.Messages = appendToolResult(req.Messages, anthropicContentBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: m.GetToolCallId(),
|
||||
ResultContent: m.GetContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
req.System = strings.Join(systemParts, "\n\n")
|
||||
|
||||
if len(req.Messages) == 0 && opts.GetPrompt() != "" {
|
||||
req.Messages = []anthropicMessage{{Role: "user", Content: opts.GetPrompt()}}
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
// appendToolResult appends a tool_result block as a user message,
|
||||
// merging into a preceding user message that already carries blocks.
|
||||
// Anthropic concatenates consecutive same-role messages on its end,
|
||||
// but explicit merging keeps the body smaller and the conversation
|
||||
// strictly alternating — which some upstream filters require.
|
||||
func appendToolResult(msgs []anthropicMessage, block anthropicContentBlock) []anthropicMessage {
|
||||
if n := len(msgs); n > 0 && msgs[n-1].Role == "user" {
|
||||
if existing, ok := msgs[n-1].Content.([]anthropicContentBlock); ok {
|
||||
msgs[n-1].Content = append(existing, block)
|
||||
return msgs
|
||||
}
|
||||
}
|
||||
return append(msgs, anthropicMessage{
|
||||
Role: "user",
|
||||
Content: []anthropicContentBlock{block},
|
||||
})
|
||||
}
|
||||
|
||||
func convertOpenAITools(toolsJSON string) []anthropicTool {
|
||||
if toolsJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var raw []openAITool
|
||||
if err := json.Unmarshal([]byte(toolsJSON), &raw); err != nil {
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unparseable tools JSON, dropping", "error", err)
|
||||
return nil
|
||||
}
|
||||
tools := make([]anthropicTool, 0, len(raw))
|
||||
for _, t := range raw {
|
||||
if t.Function.Name == "" {
|
||||
continue
|
||||
}
|
||||
schema := t.Function.Parameters
|
||||
if len(schema) == 0 {
|
||||
schema = emptyObjectSchema
|
||||
}
|
||||
tools = append(tools, anthropicTool{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
InputSchema: schema,
|
||||
})
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// convertOpenAIToolChoice accepts the spec form
|
||||
// ({type:function, function:{name:X}}) and the flat legacy form
|
||||
// ({type:function, name:X}) some clients send. Unknown object shapes
|
||||
// are warned and dropped rather than silently treated as auto.
|
||||
func convertOpenAIToolChoice(toolChoiceJSON string) *anthropicToolChoice {
|
||||
if toolChoiceJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var asString string
|
||||
if err := json.Unmarshal([]byte(toolChoiceJSON), &asString); err == nil {
|
||||
switch asString {
|
||||
case "auto":
|
||||
return &anthropicToolChoice{Type: "auto"}
|
||||
case "none":
|
||||
return &anthropicToolChoice{Type: anthropicToolChoiceNone}
|
||||
case "required":
|
||||
return &anthropicToolChoice{Type: "any"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var asObj struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"function"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(toolChoiceJSON), &asObj); err != nil {
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unparseable tool_choice, dropping", "error", err)
|
||||
return nil
|
||||
}
|
||||
if name := asObj.Function.Name; name != "" {
|
||||
return &anthropicToolChoice{Type: "tool", Name: name}
|
||||
}
|
||||
if asObj.Name != "" {
|
||||
return &anthropicToolChoice{Type: "tool", Name: asObj.Name}
|
||||
}
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unrecognised tool_choice shape, dropping", "shape", toolChoiceJSON)
|
||||
return nil
|
||||
}
|
||||
|
||||
// openAITool mirrors pkg/functions.Tool but keeps Parameters as
|
||||
// json.RawMessage so the input_schema passes through verbatim — no
|
||||
// re-marshal cost, no fidelity loss on exotic schemas.
|
||||
type openAITool struct {
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters json.RawMessage `json:"parameters"`
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
func assistantBlocks(m *pb.Message) []anthropicContentBlock {
|
||||
toolCallsJSON := m.GetToolCalls()
|
||||
if toolCallsJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var toolCalls []openAIToolCall
|
||||
if err := json.Unmarshal([]byte(toolCallsJSON), &toolCalls); err != nil || len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
blocks := make([]anthropicContentBlock, 0, len(toolCalls)+1)
|
||||
if text := m.GetContent(); text != "" {
|
||||
blocks = append(blocks, anthropicContentBlock{Type: "text", Text: text})
|
||||
}
|
||||
for _, tc := range toolCalls {
|
||||
// OpenAI's arguments are a JSON-encoded string; pass through
|
||||
// as RawMessage so a non-JSON string from a poorly-formed
|
||||
// local model doesn't crash the marshaller downstream.
|
||||
args := json.RawMessage(tc.Function.Arguments)
|
||||
if len(args) == 0 {
|
||||
args = emptyJSONObject
|
||||
}
|
||||
blocks = append(blocks, anthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: args,
|
||||
})
|
||||
}
|
||||
return blocks
|
||||
}
|
||||
|
||||
// doAnthropicRequest is the Anthropic counterpart of doOpenAIRequest.
|
||||
// applyAuthHeader sets x-api-key and anthropic-version when provider
|
||||
// is anthropic, so this method doesn't need to duplicate that.
|
||||
func (c *CloudProxy) doAnthropicRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// predictAnthropicRich returns the full Reply: joined text from all
|
||||
// text blocks, tool_use blocks mapped to ToolCallDelta, and usage
|
||||
// tokens.
|
||||
func (c *CloudProxy) predictAnthropicRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
body, err := buildAnthropicRequest(opts, cfg, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doAnthropicRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var parsed anthropicResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: decode response: %w", err)
|
||||
}
|
||||
|
||||
reply := &pb.Reply{}
|
||||
if parsed.Usage != nil {
|
||||
reply.PromptTokens = int32(parsed.Usage.InputTokens)
|
||||
reply.Tokens = int32(parsed.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
var content strings.Builder
|
||||
var toolCalls []*pb.ToolCallDelta
|
||||
toolIdx := 0
|
||||
for _, b := range parsed.Content {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
content.WriteString(b.Text)
|
||||
case "tool_use":
|
||||
// Input is a structured JSON object; we serialise to a
|
||||
// string so it fits the OpenAI-shaped arguments field
|
||||
// downstream consumers expect.
|
||||
args := ""
|
||||
if len(b.Input) > 0 {
|
||||
args = string(b.Input)
|
||||
}
|
||||
toolCalls = append(toolCalls, newToolCallDelta(toolIdx, b.ID, b.Name, args))
|
||||
toolIdx++
|
||||
}
|
||||
}
|
||||
reply.Message = []byte(content.String())
|
||||
if len(toolCalls) > 0 {
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{ToolCalls: toolCalls}}
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// predictAnthropicStreamRich streams Reply chunks from Anthropic's SSE.
|
||||
// Three event types matter: content_block_start (initialises tool_use
|
||||
// id+name), content_block_delta (carries text or input_json_delta),
|
||||
// message_stop (terminates). The block index from the wire feeds
|
||||
// straight into ToolCallDelta.Index so downstream consumers can
|
||||
// reassemble multiple parallel tool calls.
|
||||
func (c *CloudProxy) predictAnthropicStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
body, err := buildAnthropicRequest(opts, cfg, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doAnthropicRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "" {
|
||||
continue
|
||||
}
|
||||
var ev anthropicStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &ev); err != nil {
|
||||
xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err)
|
||||
continue
|
||||
}
|
||||
switch ev.Type {
|
||||
case "content_block_start":
|
||||
// tool_use blocks announce id + name here; arguments arrive
|
||||
// in subsequent input_json_delta events. Emit a Reply with
|
||||
// just the tool_call init fields so consumers can allocate
|
||||
// a slot at this index.
|
||||
if ev.ContentBlock != nil && ev.ContentBlock.Type == "tool_use" {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{
|
||||
newToolCallDelta(ev.Index, ev.ContentBlock.ID, ev.ContentBlock.Name, ""),
|
||||
}}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "content_block_delta":
|
||||
if ev.Delta == nil {
|
||||
continue
|
||||
}
|
||||
switch ev.Delta.Type {
|
||||
case "text_delta":
|
||||
if ev.Delta.Text == "" {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
Message: []byte(ev.Delta.Text),
|
||||
ChatDeltas: []*pb.ChatDelta{{Content: ev.Delta.Text}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
case "input_json_delta":
|
||||
if ev.Delta.PartialJSON == "" {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{
|
||||
newToolCallDelta(ev.Index, "", "", ev.Delta.PartialJSON),
|
||||
}}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "message_delta":
|
||||
// Anthropic sends final usage in message_delta.usage. Emit
|
||||
// a usage-only Reply so the consumer can record totals.
|
||||
if ev.Usage != nil {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
Tokens: int32(ev.Usage.OutputTokens),
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "message_stop":
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
334
backend/go/cloud-proxy/provider_anthropic_test.go
Normal file
334
backend/go/cloud-proxy/provider_anthropic_test.go
Normal file
@@ -0,0 +1,334 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// fakeAnthropicUpstream mirrors fakeOpenAIUpstream but decodes the
|
||||
// request body as an anthropicRequest so tests can assert on the
|
||||
// translated wire shape (system field, max_tokens, etc.).
|
||||
func fakeAnthropicUpstream(t *testing.T, handler func(req anthropicRequest) (status int, body string, contentType string)) (*httptest.Server, *anthropicRequest) {
|
||||
t.Helper()
|
||||
var captured anthropicRequest
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(raw, &captured)
|
||||
status, body, ct := handler(captured)
|
||||
w.Header().Set("Content-Type", ct)
|
||||
w.WriteHeader(status)
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
return srv, &captured
|
||||
}
|
||||
|
||||
func newAnthropicTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy {
|
||||
t.Helper()
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_ANTHROPIC_FAKE", "sk-ant-fake")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Model: "claude-local",
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstreamURL,
|
||||
Mode: modeTranslate,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_FAKE",
|
||||
UpstreamModel: "claude-3-5-sonnet-20241022",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
return cp
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_BasicMessages(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hi there"}],"model":"claude-3-5-sonnet-20241022","usage":{"input_tokens":5,"output_tokens":2}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{
|
||||
{Role: "system", Content: "be brief"},
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
Temperature: 0.5,
|
||||
TopP: 0.9,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hi there"))
|
||||
|
||||
g.Expect(captured.Model).To(Equal("claude-3-5-sonnet-20241022"))
|
||||
// System message must be hoisted out of Messages into top-level field.
|
||||
g.Expect(captured.System).To(Equal("be brief"))
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.MaxTokens).To(Equal(int32(32)))
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
// Anthropic 400s when both temperature and top_p are set; the
|
||||
// translator must prefer temperature and drop top_p.
|
||||
g.Expect(captured.TopP).To(BeNil())
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
}
|
||||
|
||||
// When only top_p is set, it should be forwarded.
|
||||
func TestPredict_Anthropic_TopPOnly(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hello"}},
|
||||
TopP: 0.9,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Temperature).To(BeNil())
|
||||
// PredictOptions.TopP is float32 on the wire; the translator widens
|
||||
// to float64 so 0.9 round-trips as 0.8999999761581421… — compare
|
||||
// with a small tolerance rather than exact equality.
|
||||
g.Expect(captured.TopP).NotTo(BeNil())
|
||||
g.Expect(math.Abs(*captured.TopP - 0.9)).To(BeNumerically("<=", 1e-6))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_DefaultsMaxTokens(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Anthropic 400s without max_tokens. The translator must default
|
||||
// it when the caller doesn't supply Tokens.
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.MaxTokens).To(Equal(anthropicDefaultMaxTokens))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_PromptFallback(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?", Tokens: 16})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.Messages[0].Content).To(Equal("what time is it?"))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_ConcatenatesContentBlocks(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Anthropic may return multiple text blocks; the translator joins
|
||||
// them so the Predict() string return is the full assistant message.
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"hello "},{"type":"text","text":"world"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hello world"))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_UpstreamError(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 401, `{"error":{"type":"authentication_error","message":"bad key"}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("401"))
|
||||
}
|
||||
|
||||
func TestPredictStream_Anthropic_StreamsTextDeltas(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Real Anthropic SSE has event: lines + data: lines. The translator
|
||||
// only needs the data: payload; only content_block_delta with
|
||||
// delta.type=text_delta carries content. message_stop ends.
|
||||
frames := []string{
|
||||
"event: message_start\ndata: {\"type\":\"message_start\"}\n\n",
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" \"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"world\"}}\n\n",
|
||||
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
body := strings.Join(frames, "")
|
||||
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan string, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStream(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
Tokens: 16,
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var got []string
|
||||
for s := range results {
|
||||
got = append(got, s)
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(strings.Join(got, "")).To(Equal("hello world"))
|
||||
g.Expect(captured.Stream).To(BeTrue())
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_TranslatesOpenAITools(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
tools := `[{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}}}]`
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "weather in Paris?"}},
|
||||
Tools: tools,
|
||||
ToolChoice: `"auto"`,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Tools).To(HaveLen(1))
|
||||
g.Expect(captured.Tools[0].Name).To(Equal("get_weather"))
|
||||
g.Expect(captured.Tools[0].Description).To(Equal("Get weather"))
|
||||
// input_schema must be the parameters object verbatim.
|
||||
g.Expect(string(captured.Tools[0].InputSchema)).To(ContainSubstring(`"city"`))
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("auto"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_RequiredMapsToAny(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `"required"`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("any"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_NoneDropsTools(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `"none"`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Tools).To(BeNil())
|
||||
g.Expect(captured.ToolChoice).To(BeNil())
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_NamedFunction(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"weather","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `{"type":"function","function":{"name":"weather"}}`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("tool"))
|
||||
g.Expect(captured.ToolChoice.Name).To(Equal("weather"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_RoundTripsAssistantToolCalls(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// LocalAI Assistant's second turn: the LLM previously emitted a
|
||||
// tool_use, the server executed it, and the conversation now
|
||||
// includes the assistant turn (with tool_calls) plus a tool-role
|
||||
// result message. Both must convert to Anthropic block form.
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
tools := `[{"type":"function","function":{"name":"list_models","parameters":{"type":"object"}}}]`
|
||||
toolCallsJSON := `[{"id":"call_abc","type":"function","function":{"name":"list_models","arguments":"{}"}}]`
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Tools: tools,
|
||||
Messages: []*pb.Message{
|
||||
{Role: "user", Content: "what models are installed?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: toolCallsJSON},
|
||||
{Role: "tool", Content: `{"models":["a","b"]}`, ToolCallId: "call_abc"},
|
||||
},
|
||||
Tokens: 64,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
g.Expect(captured.Messages).To(HaveLen(3))
|
||||
// 1. user text — bare string
|
||||
s, ok := captured.Messages[0].Content.(string)
|
||||
g.Expect(ok).To(BeTrue())
|
||||
g.Expect(s).To(Equal("what models are installed?"))
|
||||
// 2. assistant — must be a content-block list with one tool_use
|
||||
// json.Unmarshal of `any` produces []any not []anthropicContentBlock.
|
||||
blocks, ok := captured.Messages[1].Content.([]any)
|
||||
g.Expect(ok).To(BeTrue())
|
||||
g.Expect(blocks).To(HaveLen(1))
|
||||
b0, _ := blocks[0].(map[string]any)
|
||||
g.Expect(b0["type"]).To(Equal("tool_use"))
|
||||
g.Expect(b0["id"]).To(Equal("call_abc"))
|
||||
g.Expect(b0["name"]).To(Equal("list_models"))
|
||||
// 3. tool → user with tool_result block
|
||||
g.Expect(captured.Messages[2].Role).To(Equal("user"))
|
||||
resBlocks, _ := captured.Messages[2].Content.([]any)
|
||||
r0, _ := resBlocks[0].(map[string]any)
|
||||
g.Expect(r0["type"]).To(Equal("tool_result"))
|
||||
g.Expect(r0["tool_use_id"]).To(Equal("call_abc"))
|
||||
g.Expect(r0["content"]).To(Equal(`{"models":["a","b"]}`))
|
||||
}
|
||||
119
backend/go/cloud-proxy/provider_edge_test.go
Normal file
119
backend/go/cloud-proxy/provider_edge_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Verify buildOpenAIRequest preserves caller-supplied tools and
|
||||
// tool_choice as opaque JSON. PredictOptions carries them as strings;
|
||||
// they must land in the outbound request body unchanged so the
|
||||
// upstream sees the caller's intent verbatim. A regression here would
|
||||
// silently disable function calling for translate-mode clients.
|
||||
func TestBuildOpenAIRequest_ToolsAndToolChoicePassthrough(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cfg := &proxyConfig{upstreamModel: "gpt-4o"}
|
||||
toolsJSON := `[{"type":"function","function":{"name":"search","parameters":{"type":"object"}}}]`
|
||||
choiceJSON := `{"type":"function","function":{"name":"search"}}`
|
||||
|
||||
body, err := buildOpenAIRequest(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "find x"}},
|
||||
Tools: toolsJSON,
|
||||
ToolChoice: choiceJSON,
|
||||
}, cfg, false)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
var decoded openAIRequest
|
||||
err = json.Unmarshal(body, &decoded)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
// Compare the JSON-canonical form so whitespace differences are ignored.
|
||||
gotTools, _ := json.Marshal(json.RawMessage(decoded.Tools))
|
||||
wantTools, _ := json.Marshal(json.RawMessage(toolsJSON))
|
||||
g.Expect(string(gotTools)).To(Equal(string(wantTools)))
|
||||
gotChoice, _ := json.Marshal(json.RawMessage(decoded.ToolChoice))
|
||||
wantChoice, _ := json.Marshal(json.RawMessage(choiceJSON))
|
||||
g.Expect(string(gotChoice)).To(Equal(string(wantChoice)))
|
||||
}
|
||||
|
||||
// Garbage JSON in tools / tool_choice is silently dropped (omitted)
|
||||
// rather than blowing up the request. Documents the parseRawJSON
|
||||
// behaviour — operators shouldn't see hard failures from an upstream
|
||||
// caller's mis-formatted tools field.
|
||||
func TestBuildOpenAIRequest_InvalidToolsJSONDropped(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cfg := &proxyConfig{upstreamModel: "gpt-4o"}
|
||||
body, err := buildOpenAIRequest(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: "this is not json",
|
||||
ToolChoice: "{also bad",
|
||||
}, cfg, false)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(body)).NotTo(ContainSubstring("this is not json"))
|
||||
g.Expect(string(body)).NotTo(ContainSubstring("{also bad"))
|
||||
}
|
||||
|
||||
// Anthropic empty content array yields an empty Reply (not an error).
|
||||
// Mirrors how an upstream tool_use-only response might arrive — the
|
||||
// content array can legitimately be empty in some edge cases.
|
||||
func TestPredictRich_Anthropic_EmptyContent(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"id":"m1","type":"message","role":"assistant","content":[],"usage":{"input_tokens":3,"output_tokens":0}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal(""))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(0))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(3)))
|
||||
}
|
||||
|
||||
// A truncated / malformed SSE payload mid-stream should be tolerated:
|
||||
// the malformed chunk gets skipped (xlog.Debug logged), valid chunks
|
||||
// before AND after it still reach the channel.
|
||||
func TestPredictStreamRich_OpenAI_TolerantOfBadChunks(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
body := strings.Join([]string{
|
||||
`data: {"choices":[{"index":0,"delta":{"content":"hello"}}]}`,
|
||||
``,
|
||||
`data: this-is-not-json{{`,
|
||||
``,
|
||||
`data: {"choices":[{"index":0,"delta":{"content":" world"}}]}`,
|
||||
``,
|
||||
`data: [DONE]`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var assembled strings.Builder
|
||||
for reply := range results {
|
||||
assembled.Write(reply.GetMessage())
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
// The good chunks before and after the malformed one both made it through.
|
||||
g.Expect(assembled.String()).To(Equal("hello world"))
|
||||
}
|
||||
320
backend/go/cloud-proxy/provider_openai.go
Normal file
320
backend/go/cloud-proxy/provider_openai.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// OpenAI Chat Completions wire-format types. Narrowed to the fields
|
||||
// translate mode needs to preserve through the Reply proto: content,
|
||||
// role, tool_calls (typed so we can map them to pb.ToolCallDelta),
|
||||
// and sampling params copied verbatim from PredictOptions.
|
||||
//
|
||||
// Provider-specific extensions (logit_bias, function calling beyond
|
||||
// tool_calls, etc.) are not modelled — passthrough mode covers callers
|
||||
// that need full upstream fidelity.
|
||||
|
||||
type openAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAIMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
MaxTokens *int32 `json:"max_tokens,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
type openAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// openAIToolCall covers both the non-streaming response shape (full
|
||||
// id+function+arguments) and the streaming-delta shape (sparse fields,
|
||||
// index assignment). The proto's ToolCallDelta absorbs both — name is
|
||||
// set on first appearance, arguments arrive incrementally in streaming.
|
||||
type openAIToolCall struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function openAIFunctionCall `json:"function,omitempty"`
|
||||
}
|
||||
|
||||
type openAIFunctionCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type openAIChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message openAIMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type openAIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Choices []openAIChoice `json:"choices"`
|
||||
Usage *openAIUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type openAIStreamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
|
||||
} `json:"delta"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
type openAIStreamChunk struct {
|
||||
Choices []openAIStreamChoice `json:"choices"`
|
||||
Usage *openAIUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type openAIUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// buildOpenAIRequest converts pb.PredictOptions into the OpenAI Chat
|
||||
// Completions request body. Prefers Messages when non-empty; falls
|
||||
// back to wrapping Prompt as a single user message so plain
|
||||
// /completions-style calls still work in translate mode.
|
||||
func buildOpenAIRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) {
|
||||
req := openAIRequest{
|
||||
Model: modelName(cfg, opts),
|
||||
Stream: stream,
|
||||
Stop: opts.GetStopPrompts(),
|
||||
Tools: parseRawJSON(opts.GetTools()),
|
||||
ToolChoice: parseRawJSON(opts.GetToolChoice()),
|
||||
}
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
}
|
||||
if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
if n := opts.GetTokens(); n > 0 {
|
||||
req.MaxTokens = &n
|
||||
}
|
||||
if p := opts.GetFrequencyPenalty(); p != 0 {
|
||||
v := float64(p)
|
||||
req.FrequencyPenalty = &v
|
||||
}
|
||||
if p := opts.GetPresencePenalty(); p != 0 {
|
||||
v := float64(p)
|
||||
req.PresencePenalty = &v
|
||||
}
|
||||
|
||||
for _, m := range opts.GetMessages() {
|
||||
msg := openAIMessage{
|
||||
Role: m.GetRole(),
|
||||
Content: m.GetContent(),
|
||||
Name: m.GetName(),
|
||||
ToolCallID: m.GetToolCallId(),
|
||||
}
|
||||
// Pre-existing tool_calls arrive as a JSON string from the
|
||||
// upstream caller's previous assistant turn; pass-through as-is.
|
||||
if tc := m.GetToolCalls(); tc != "" {
|
||||
_ = json.Unmarshal([]byte(tc), &msg.ToolCalls)
|
||||
}
|
||||
req.Messages = append(req.Messages, msg)
|
||||
}
|
||||
// Fallback for plain Prompt requests (no Messages array). LocalAI
|
||||
// templating may have produced a flat prompt; rewrap as a single
|
||||
// user message so the upstream chat endpoint accepts it.
|
||||
if len(req.Messages) == 0 && opts.GetPrompt() != "" {
|
||||
req.Messages = []openAIMessage{{Role: "user", Content: opts.GetPrompt()}}
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
// modelName picks the upstream model: upstream_model from the proxy
|
||||
// config wins (operator override), else the local model name captured
|
||||
// at LoadModel time. Operator sets upstream_model to map LocalAI's
|
||||
// alias (e.g. "claude-strict") to the upstream's canonical name
|
||||
// (e.g. "claude-3-5-sonnet-20241022").
|
||||
func modelName(cfg *proxyConfig, _ *pb.PredictOptions) string {
|
||||
if cfg.upstreamModel != "" {
|
||||
return cfg.upstreamModel
|
||||
}
|
||||
return cfg.localModel
|
||||
}
|
||||
|
||||
// parseRawJSON parses a JSON string into a RawMessage so it round-trips
|
||||
// into the upstream body. Returns nil for empty/invalid input so the
|
||||
// field is omitted (omitempty).
|
||||
func parseRawJSON(s string) json.RawMessage {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
var probe json.RawMessage
|
||||
if err := json.Unmarshal([]byte(s), &probe); err != nil {
|
||||
return nil
|
||||
}
|
||||
return probe
|
||||
}
|
||||
|
||||
// doOpenAIRequest builds + sends the upstream request. Returns the
|
||||
// raw response on success; caller handles status / body.
|
||||
func (c *CloudProxy) doOpenAIRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// predictOpenAIRich is the non-streaming translate path. Returns a
|
||||
// fully-populated *pb.Reply with assistant content, tool calls, and
|
||||
// token usage. The gRPC server forwards the Reply verbatim.
|
||||
func (c *CloudProxy) predictOpenAIRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
body, err := buildOpenAIRequest(opts, cfg, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doOpenAIRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var parsed openAIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: decode response: %w", err)
|
||||
}
|
||||
if len(parsed.Choices) == 0 {
|
||||
return nil, errors.New("cloud-proxy: upstream returned no choices")
|
||||
}
|
||||
|
||||
choice := parsed.Choices[0]
|
||||
reply := &pb.Reply{
|
||||
Message: []byte(choice.Message.Content),
|
||||
}
|
||||
if parsed.Usage != nil {
|
||||
reply.PromptTokens = int32(parsed.Usage.PromptTokens)
|
||||
reply.Tokens = int32(parsed.Usage.CompletionTokens)
|
||||
}
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
// Non-streaming: a single ChatDelta carries the full tool-call
|
||||
// set. Index/Name/Arguments are populated together; downstream
|
||||
// consumers don't need to assemble streaming deltas.
|
||||
delta := &pb.ChatDelta{}
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
delta.ToolCalls = append(delta.ToolCalls,
|
||||
newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments))
|
||||
}
|
||||
reply.ChatDeltas = []*pb.ChatDelta{delta}
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// predictOpenAIStreamRich streams *pb.Reply chunks. Each chunk carries
|
||||
// either a content delta (Message + ChatDeltas[].Content) or tool-call
|
||||
// deltas (ChatDeltas[].ToolCalls). The final Reply carries usage tokens
|
||||
// when the upstream sends them (stream_options.include_usage).
|
||||
func (c *CloudProxy) predictOpenAIStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
body, err := buildOpenAIRequest(opts, cfg, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doOpenAIRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
return nil
|
||||
}
|
||||
var chunk openAIStreamChunk
|
||||
if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
|
||||
xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err)
|
||||
continue
|
||||
}
|
||||
// Usage frames may arrive separately from content frames when
|
||||
// stream_options.include_usage is set; emit a usage-only Reply
|
||||
// in that case so the consumer sees the totals.
|
||||
if chunk.Usage != nil && len(chunk.Choices) == 0 {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
PromptTokens: int32(chunk.Usage.PromptTokens),
|
||||
Tokens: int32(chunk.Usage.CompletionTokens),
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, ch := range chunk.Choices {
|
||||
reply := &pb.Reply{}
|
||||
if ch.Delta.Content != "" {
|
||||
reply.Message = []byte(ch.Delta.Content)
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{Content: ch.Delta.Content}}
|
||||
}
|
||||
if len(ch.Delta.ToolCalls) > 0 {
|
||||
if len(reply.ChatDeltas) == 0 {
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{}}
|
||||
}
|
||||
for _, tc := range ch.Delta.ToolCalls {
|
||||
reply.ChatDeltas[0].ToolCalls = append(reply.ChatDeltas[0].ToolCalls,
|
||||
newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments))
|
||||
}
|
||||
}
|
||||
if reply.Message == nil && len(reply.ChatDeltas) == 0 {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, reply) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
170
backend/go/cloud-proxy/provider_openai_test.go
Normal file
170
backend/go/cloud-proxy/provider_openai_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// fakeOpenAIUpstream returns an httptest.Server that decodes the
|
||||
// inbound request as an openAIRequest, calls handler with it, and
|
||||
// writes the handler's reply as the response.
|
||||
func fakeOpenAIUpstream(t *testing.T, handler func(req openAIRequest) (status int, body string, contentType string)) (*httptest.Server, *openAIRequest) {
|
||||
t.Helper()
|
||||
var captured openAIRequest
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(raw, &captured)
|
||||
status, body, ct := handler(captured)
|
||||
w.Header().Set("Content-Type", ct)
|
||||
w.WriteHeader(status)
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
return srv, &captured
|
||||
}
|
||||
|
||||
func newTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy {
|
||||
t.Helper()
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_OPENAI_FAKE", "sk-fake-openai")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Model: "gpt-4o-local",
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstreamURL,
|
||||
Mode: modeTranslate,
|
||||
Provider: providerOpenAI,
|
||||
ApiKeyEnv: "CLOUD_PROXY_OPENAI_FAKE",
|
||||
UpstreamModel: "gpt-4o",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
return cp
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_BasicChat(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"id":"resp-1","choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{
|
||||
{Role: "system", Content: "be brief"},
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
Temperature: 0.5,
|
||||
TopP: 0.9,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hi there"))
|
||||
|
||||
// Verify the upstream saw a properly-translated request.
|
||||
g.Expect(captured.Model).To(Equal("gpt-4o"))
|
||||
g.Expect(captured.Messages).To(HaveLen(2))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("system"))
|
||||
g.Expect(captured.Messages[1].Role).To(Equal("user"))
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
g.Expect(captured.MaxTokens).NotTo(BeNil())
|
||||
g.Expect(*captured.MaxTokens).To(Equal(int32(32)))
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_PromptFallback(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// No Messages array — backend should synth a single user message
|
||||
// from Prompt so non-chat clients still route through translate.
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"choices":[{"message":{"role":"assistant","content":"ok"}}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?"})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.Messages[0].Content).To(Equal("what time is it?"))
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_UpstreamError(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 401, `{"error":{"message":"bad key"}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("401"))
|
||||
}
|
||||
|
||||
func TestPredictStream_OpenAI_StreamsContent(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Stream three content deltas then [DONE]. Verify the channel
|
||||
// receives them in order with no missing pieces.
|
||||
chunks := []string{
|
||||
`{"choices":[{"index":0,"delta":{"role":"assistant"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"hello"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":" "}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"world"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||
}
|
||||
body := ""
|
||||
for _, c := range chunks {
|
||||
body += "data: " + c + "\n\n"
|
||||
}
|
||||
body += "data: [DONE]\n\n"
|
||||
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan string, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStream(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var got []string
|
||||
for s := range results {
|
||||
got = append(got, s)
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(strings.Join(got, "")).To(Equal("hello world"))
|
||||
g.Expect(captured.Stream).To(BeTrue())
|
||||
}
|
||||
|
||||
func TestPredict_RejectedInPassthroughMode(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_FAKE", "k")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_FAKE",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_, err = cp.Predict(&pb.PredictOptions{})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("only valid in translate"))
|
||||
}
|
||||
429
backend/go/cloud-proxy/proxy.go
Normal file
429
backend/go/cloud-proxy/proxy.go
Normal file
@@ -0,0 +1,429 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// Mirror of core/config.Proxy{Mode,Provider}* — backends don't
|
||||
// import core to keep the boundary clean.
|
||||
const (
|
||||
modePassthrough = "passthrough"
|
||||
modeTranslate = "translate"
|
||||
|
||||
providerOpenAI = "openai"
|
||||
providerAnthropic = "anthropic"
|
||||
)
|
||||
|
||||
// CloudProxy is the LocalAI backend that proxies model traffic to a
|
||||
// configured upstream HTTP provider. Concurrency: base.SingleThread is
|
||||
// NOT embedded — forward calls are independent and HTTP transport is
|
||||
// goroutine-safe, so multiple Forward streams can run in parallel.
|
||||
// Locking would serialise requests to a chat provider for no benefit.
|
||||
type CloudProxy struct {
|
||||
base.Base
|
||||
|
||||
cfg atomic.Pointer[proxyConfig]
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type proxyConfig struct {
|
||||
upstreamURL string
|
||||
mode string
|
||||
provider string
|
||||
upstreamModel string
|
||||
localModel string // ModelOptions.Model — fallback when upstream_model is unset
|
||||
apiKey string // resolved at Load time
|
||||
}
|
||||
|
||||
func NewCloudProxy() *CloudProxy {
|
||||
// No Client-level Timeout — that would bound streaming SSE
|
||||
// responses too, which can legitimately last minutes. Per-request
|
||||
// deadlines come from the gRPC stream context.
|
||||
return &CloudProxy{client: &http.Client{}}
|
||||
}
|
||||
|
||||
func (c *CloudProxy) Load(opts *pb.ModelOptions) error {
|
||||
po := opts.GetProxy()
|
||||
if po == nil {
|
||||
return errors.New("cloud-proxy: Load requires ProxyOptions to be set")
|
||||
}
|
||||
if po.GetUpstreamUrl() == "" {
|
||||
return errors.New("cloud-proxy: upstream_url is required")
|
||||
}
|
||||
if _, err := url.ParseRequestURI(po.GetUpstreamUrl()); err != nil {
|
||||
return fmt.Errorf("cloud-proxy: upstream_url %q invalid: %w", po.GetUpstreamUrl(), err)
|
||||
}
|
||||
|
||||
mode := po.GetMode()
|
||||
if mode == "" {
|
||||
mode = modePassthrough
|
||||
}
|
||||
switch mode {
|
||||
case modePassthrough:
|
||||
case modeTranslate:
|
||||
switch po.GetProvider() {
|
||||
case providerOpenAI:
|
||||
// implemented in provider_openai.go
|
||||
case providerAnthropic:
|
||||
// implemented in provider_anthropic.go
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: translate mode requires provider in {%s, %s}, got %q",
|
||||
providerOpenAI, providerAnthropic, po.GetProvider())
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: unknown mode %q", mode)
|
||||
}
|
||||
|
||||
key, err := resolveAPIKey(po.GetApiKeyEnv(), po.GetApiKeyFile())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.cfg.Store(&proxyConfig{
|
||||
upstreamURL: po.GetUpstreamUrl(),
|
||||
mode: mode,
|
||||
provider: po.GetProvider(),
|
||||
upstreamModel: po.GetUpstreamModel(),
|
||||
localModel: opts.GetModel(),
|
||||
apiKey: key,
|
||||
})
|
||||
xlog.Info("cloud-proxy: ready",
|
||||
"upstream", po.GetUpstreamUrl(),
|
||||
"mode", mode,
|
||||
"provider", po.GetProvider(),
|
||||
"has_key", key != "")
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveAPIKey mirrors config.ProxyConfig.ResolveAPIKey. Duplicated
|
||||
// (a few lines) rather than importing core/config from a backend
|
||||
// binary — keeps backends independent of core's package layout.
|
||||
// Mutual-exclusion is enforced upstream in core/config.Validate.
|
||||
func resolveAPIKey(envName, filePath string) (string, error) {
|
||||
if envName != "" {
|
||||
v := os.Getenv(envName)
|
||||
if v == "" {
|
||||
return "", fmt.Errorf("cloud-proxy: api_key_env %q is unset", envName)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
if filePath != "" {
|
||||
b, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cloud-proxy: read api_key_file %q: %w", filePath, err)
|
||||
}
|
||||
return strings.TrimSpace(string(b)), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// PredictRich is the non-streaming translate path. Returns a fully-
|
||||
// populated *pb.Reply: content, tool-call deltas (ChatDeltas), and
|
||||
// usage tokens. Implements the optional grpc.AIModelRich interface;
|
||||
// the gRPC server prefers this path over Predict when present so
|
||||
// tool calls survive the round-trip. Passthrough mode rejects
|
||||
// PredictRich — callers must use Forward.
|
||||
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cloud-proxy: model not loaded")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
|
||||
}
|
||||
xlog.Info("cloud-proxy: predict", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: predict failed", "provider", cfg.provider, "error", err)
|
||||
}
|
||||
}()
|
||||
ctx := context.Background()
|
||||
switch cfg.provider {
|
||||
case providerOpenAI:
|
||||
return c.predictOpenAIRich(ctx, cfg, opts)
|
||||
case providerAnthropic:
|
||||
return c.predictAnthropicRich(ctx, cfg, opts)
|
||||
default:
|
||||
return nil, fmt.Errorf("cloud-proxy: predict not implemented for provider %q", cfg.provider)
|
||||
}
|
||||
}
|
||||
|
||||
// PredictStreamRich is the rich streaming counterpart of PredictRich.
|
||||
// Each emitted Reply carries either a content delta, tool-call deltas,
|
||||
// or usage tokens (the final upstream frame). base.Base.PredictStream
|
||||
// is bypassed when AIModelRich is implemented, so the channel is
|
||||
// closed by the gRPC server pump.
|
||||
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
|
||||
}
|
||||
xlog.Info("cloud-proxy: predict-stream", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: predict-stream failed", "provider", cfg.provider, "error", err)
|
||||
}
|
||||
}()
|
||||
ctx := context.Background()
|
||||
switch cfg.provider {
|
||||
case providerOpenAI:
|
||||
return c.predictOpenAIStreamRich(ctx, cfg, opts, results)
|
||||
case providerAnthropic:
|
||||
return c.predictAnthropicStreamRich(ctx, cfg, opts, results)
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: predictStream not implemented for provider %q", cfg.provider)
|
||||
}
|
||||
}
|
||||
|
||||
// Predict is the legacy (string, error) AIModel signature. Used only
|
||||
// if a caller goes through the non-rich path (it shouldn't, since
|
||||
// server.go prefers PredictRich). Provided so the AIModel interface
|
||||
// is satisfied for backends that haven't opted into the rich variant.
|
||||
func (c *CloudProxy) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
reply, err := c.PredictRich(opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(reply.GetMessage()), nil
|
||||
}
|
||||
|
||||
// PredictStream is the legacy chan-string streaming path. Adapts the
|
||||
// rich stream by extracting only content text — tool-call-only chunks
|
||||
// (no Message bytes) and usage-only chunks are silently dropped, since
|
||||
// the legacy chan-string contract cannot represent them. Consumers
|
||||
// that need tool calls must call PredictStreamRich directly.
|
||||
func (c *CloudProxy) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
defer close(results)
|
||||
richCh := make(chan *pb.Reply)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- c.PredictStreamRich(opts, richCh)
|
||||
close(richCh)
|
||||
}()
|
||||
for reply := range richCh {
|
||||
if msg := reply.GetMessage(); len(msg) > 0 {
|
||||
results <- string(msg)
|
||||
}
|
||||
}
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
// sendReply pushes one Reply onto a stream channel honouring ctx
|
||||
// cancellation. Returns false on cancel so the caller can exit with
|
||||
// ctx.Err(). Used by both translate-mode providers.
|
||||
func sendReply(ctx context.Context, results chan<- *pb.Reply, reply *pb.Reply) bool {
|
||||
select {
|
||||
case results <- reply:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// newToolCallDelta is a small constructor for the cross-provider
|
||||
// tool-call delta shape. Centralised so the int32 cast and the four
|
||||
// fields stay consistent across the OpenAI / Anthropic translators.
|
||||
// Empty name/args are valid — Anthropic streaming announces the call
|
||||
// with id+name then sends arguments incrementally; OpenAI's reverse
|
||||
// pattern (args without name) also lands here.
|
||||
func newToolCallDelta(index int, id, name, args string) *pb.ToolCallDelta {
|
||||
return &pb.ToolCallDelta{
|
||||
Index: int32(index),
|
||||
Id: id,
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward shovels bytes between a Forward gRPC stream and an upstream
|
||||
// HTTP request. First request message carries path/method/headers and
|
||||
// the initial body chunk; subsequent messages append body chunks. The
|
||||
// first reply carries upstream status + response headers; subsequent
|
||||
// replies stream body chunks until the upstream connection closes.
|
||||
// Cancellation of ctx (the gRPC stream context) closes the upstream
|
||||
// connection.
|
||||
func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest, out chan<- *pb.ForwardReply) error {
|
||||
defer close(out)
|
||||
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
}
|
||||
if cfg.mode != modePassthrough {
|
||||
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
|
||||
}
|
||||
|
||||
first, ok := <-in
|
||||
if !ok {
|
||||
return errors.New("cloud-proxy: Forward stream closed before first request")
|
||||
}
|
||||
|
||||
// Honour the per-request path only when the configured upstream_url
|
||||
// has no path of its own — gallery convention is to put the
|
||||
// canonical path in upstream_url.
|
||||
fullURL, err := composeURL(cfg.upstreamURL, first.GetPath())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
method := first.GetMethod()
|
||||
if method == "" {
|
||||
method = http.MethodPost
|
||||
}
|
||||
|
||||
// Pipe the body in from the gRPC stream so the HTTP request can
|
||||
// start before the client finishes sending. The pipe-reader is
|
||||
// closed via CloseWithError on the error paths so the writer
|
||||
// goroutine doesn't block forever.
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
go func() {
|
||||
var writeErr error
|
||||
defer func() { _ = pw.CloseWithError(writeErr) }()
|
||||
if len(first.GetBodyChunk()) > 0 {
|
||||
if _, writeErr = pw.Write(first.GetBodyChunk()); writeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
for req := range in {
|
||||
if len(req.GetBodyChunk()) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, writeErr = pw.Write(req.GetBodyChunk()); writeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, fullURL, pr)
|
||||
if err != nil {
|
||||
_ = pr.CloseWithError(err) // unblocks the body-pump's pw.Write
|
||||
return fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
|
||||
// Apply caller-supplied headers, then override with the
|
||||
// authorization header derived from the resolved key. Caller-
|
||||
// supplied Authorization is always replaced — operators may not
|
||||
// know the backend's auth scheme, and silently leaking through a
|
||||
// client Authorization header to a different upstream would
|
||||
// confuse the upstream and could leak credentials.
|
||||
for _, h := range first.GetHeaders() {
|
||||
if h == nil || h.GetName() == "" {
|
||||
continue
|
||||
}
|
||||
// Strip hop-by-hop headers that aren't meaningful to the
|
||||
// upstream (Host is set by the http client from the URL;
|
||||
// Content-Length is computed from the body).
|
||||
if isHopByHopHeader(h.GetName()) {
|
||||
continue
|
||||
}
|
||||
req.Header.Add(h.GetName(), h.GetValue())
|
||||
}
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
|
||||
xlog.Info("cloud-proxy: forward", "method", method, "url", fullURL, "provider", cfg.provider)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: forward upstream failed", "url", fullURL, "error", err)
|
||||
return fmt.Errorf("cloud-proxy: upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
logFn := xlog.Info
|
||||
if resp.StatusCode >= 400 {
|
||||
logFn = xlog.Warn
|
||||
}
|
||||
logFn("cloud-proxy: forward response", "url", fullURL, "status", resp.StatusCode)
|
||||
|
||||
// First reply: status + response headers, no body.
|
||||
headers := make([]*pb.ForwardHeader, 0, len(resp.Header))
|
||||
for k, vs := range resp.Header {
|
||||
for _, v := range vs {
|
||||
headers = append(headers, &pb.ForwardHeader{Name: k, Value: v})
|
||||
}
|
||||
}
|
||||
out <- &pb.ForwardReply{Status: int32(resp.StatusCode), Headers: headers}
|
||||
|
||||
// Subsequent replies: body chunks. Use a fixed 8KB buffer — small
|
||||
// enough that SSE token frames flush promptly, large enough that
|
||||
// long chunked-transfer bodies aren't death by a thousand reads.
|
||||
buf := make([]byte, 8*1024)
|
||||
for {
|
||||
n, rerr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
out <- &pb.ForwardReply{BodyChunk: chunk}
|
||||
}
|
||||
if rerr != nil {
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("cloud-proxy: upstream body read: %w", rerr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// composeURL combines the configured upstream URL with the per-request
|
||||
// path. The upstream URL typically already includes the canonical path
|
||||
// (e.g. https://api.openai.com/v1/chat/completions) so the per-request
|
||||
// path is ignored in that case. When upstream_url is a bare host
|
||||
// (https://api.openai.com), the request path is appended.
|
||||
func composeURL(upstream, reqPath string) (string, error) {
|
||||
u, err := url.Parse(upstream)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cloud-proxy: parse upstream_url %q: %w", upstream, err)
|
||||
}
|
||||
if u.Path == "" || u.Path == "/" {
|
||||
u.Path = reqPath
|
||||
}
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// applyAuthHeader writes the appropriate authorization header for the
|
||||
// provider. OpenAI/Anthropic/most providers use Bearer; Anthropic
|
||||
// historically uses x-api-key + anthropic-version, but accepts Bearer
|
||||
// too via the OpenAI-compatible path. Default to Bearer when provider
|
||||
// is empty (passthrough mode where the operator doesn't claim a
|
||||
// provider).
|
||||
func applyAuthHeader(req *http.Request, provider, key string) {
|
||||
switch provider {
|
||||
case providerAnthropic:
|
||||
req.Header.Set("x-api-key", key)
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
default:
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
}
|
||||
|
||||
// isHopByHopHeader returns true for headers that should not be
|
||||
// forwarded from the client request to the upstream (RFC 7230 §6.1
|
||||
// hop-by-hop list, plus a few that the http.Client sets itself).
|
||||
func isHopByHopHeader(name string) bool {
|
||||
switch strings.ToLower(name) {
|
||||
case "connection", "proxy-connection", "keep-alive", "transfer-encoding",
|
||||
"te", "trailer", "upgrade", "host", "content-length":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
206
backend/go/cloud-proxy/proxy_test.go
Normal file
206
backend/go/cloud-proxy/proxy_test.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// helper: run a CloudProxy in-process via grpc.Provide so tests can
|
||||
// call Forward through the public Backend interface without listening
|
||||
// on a real socket.
|
||||
func newInProcClient(t *testing.T, proxy *CloudProxy) grpc.Backend {
|
||||
t.Helper()
|
||||
addr := "test://" + t.Name()
|
||||
grpc.Provide(addr, proxy)
|
||||
return grpc.NewClient(addr, true, nil, false)
|
||||
}
|
||||
|
||||
func TestForward_PassthroughEcho(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Fake upstream: echoes the request body back, prefixed with a
|
||||
// canary so the test can assert both that the body reached the
|
||||
// upstream and the response made it back to the client.
|
||||
gotBody := make(chan string, 1)
|
||||
gotAuth := make(chan string, 1)
|
||||
gotPath := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
gotBody <- string(body)
|
||||
gotAuth <- r.Header.Get("Authorization")
|
||||
gotPath <- r.URL.Path
|
||||
w.Header().Set("X-Echo", "true")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("echo: " + string(body)))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
t.Setenv("CLOUD_PROXY_FAKE_KEY", "sk-fake")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_FAKE_KEY",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{{Name: "Content-Type", Value: "application/json"}},
|
||||
BodyChunk: []byte(`{"prompt":`),
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.Send(&pb.ForwardRequest{BodyChunk: []byte(`"hi"}`)})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.CloseSend()
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// First reply: status + headers.
|
||||
first, err := stream.Recv()
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(first.Status).To(Equal(int32(http.StatusOK)))
|
||||
g.Expect(hasHeader(first.Headers, "X-Echo", "true")).To(BeTrue())
|
||||
|
||||
// Subsequent replies: body.
|
||||
var body []byte
|
||||
for {
|
||||
r, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
body = append(body, r.BodyChunk...)
|
||||
}
|
||||
g.Expect(string(body)).To(Equal(`echo: {"prompt":"hi"}`))
|
||||
|
||||
// Upstream observations.
|
||||
var gotBodyVal, gotAuthVal, gotPathVal string
|
||||
g.Eventually(gotBody).Should(Receive(&gotBodyVal), "upstream never saw body")
|
||||
g.Expect(gotBodyVal).To(Equal(`{"prompt":"hi"}`))
|
||||
g.Eventually(gotAuth).Should(Receive(&gotAuthVal), "upstream never saw auth header")
|
||||
g.Expect(gotAuthVal).To(Equal("Bearer sk-fake"))
|
||||
g.Eventually(gotPath).Should(Receive(&gotPathVal), "upstream never saw path")
|
||||
g.Expect(gotPathVal).To(Equal("/v1/chat/completions"))
|
||||
}
|
||||
|
||||
func TestForward_AnthropicAuthHeader(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
gotXAPIKey := make(chan string, 1)
|
||||
gotVersion := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotXAPIKey <- r.Header.Get("x-api-key")
|
||||
gotVersion <- r.Header.Get("anthropic-version")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
t.Setenv("CLOUD_PROXY_ANTHROPIC_KEY", "sk-ant-fake")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_KEY",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.Send(&pb.ForwardRequest{Path: "/v1/messages", Method: "POST"})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_ = stream.CloseSend()
|
||||
_, _ = stream.Recv() // drain status
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
g.Expect(<-gotXAPIKey).To(Equal("sk-ant-fake"))
|
||||
g.Expect(<-gotVersion).NotTo(BeEmpty())
|
||||
}
|
||||
|
||||
func TestLoad_ValidatesConfig(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cp := NewCloudProxy()
|
||||
|
||||
err := cp.Load(&pb.ModelOptions{})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("ProxyOptions"))
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("upstream_url"))
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
Mode: "rewrite",
|
||||
}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("unknown mode"))
|
||||
|
||||
// translate + openai should load successfully (Phase 5).
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com/v1/chat/completions",
|
||||
Mode: modeTranslate,
|
||||
Provider: providerOpenAI,
|
||||
}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// translate + anthropic should load successfully (Phase 6).
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com/v1/messages",
|
||||
Mode: modeTranslate,
|
||||
Provider: providerAnthropic,
|
||||
}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
ApiKeyEnv: "DEFINITELY_UNSET_ENV_VAR_XYZ",
|
||||
}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("unset"))
|
||||
}
|
||||
|
||||
func TestForward_RejectsWithoutLoad(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cp := NewCloudProxy()
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_ = stream.CloseSend()
|
||||
_, err = stream.Recv()
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("not loaded"))
|
||||
}
|
||||
|
||||
func hasHeader(hs []*pb.ForwardHeader, name, value string) bool {
|
||||
for _, h := range hs {
|
||||
if strings.EqualFold(h.GetName(), name) && h.GetValue() == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
6
backend/go/cloud-proxy/run.sh
Executable file
6
backend/go/cloud-proxy/run.sh
Executable file
@@ -0,0 +1,6 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
exec $CURDIR/cloud-proxy "$@"
|
||||
232
backend/go/cloud-proxy/toolcalls_test.go
Normal file
232
backend/go/cloud-proxy/toolcalls_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// OpenAI: non-streaming tool call response. Verify the response is
|
||||
// mapped to Reply.ChatDeltas[].ToolCalls with id/name/arguments intact,
|
||||
// and usage tokens land on Reply.PromptTokens / Reply.Tokens.
|
||||
func TestPredictRich_OpenAI_ToolCalls(t *testing.T) {
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{
|
||||
"id":"resp-1",
|
||||
"choices":[{
|
||||
"index":0,
|
||||
"message":{
|
||||
"role":"assistant",
|
||||
"content":"",
|
||||
"tool_calls":[
|
||||
{"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"SF\"}"}},
|
||||
{"id":"call_def","type":"function","function":{"name":"get_time","arguments":"{\"tz\":\"PT\"}"}}
|
||||
]
|
||||
},
|
||||
"finish_reason":"tool_calls"
|
||||
}],
|
||||
"usage":{"prompt_tokens":42,"completion_tokens":18,"total_tokens":60}
|
||||
}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal(""))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(42)))
|
||||
g.Expect(reply.GetTokens()).To(Equal(int32(18)))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(1))
|
||||
tcs := reply.GetChatDeltas()[0].GetToolCalls()
|
||||
g.Expect(tcs).To(HaveLen(2))
|
||||
g.Expect(tcs[0].GetId()).To(Equal("call_abc"))
|
||||
g.Expect(tcs[0].GetName()).To(Equal("get_weather"))
|
||||
g.Expect(tcs[0].GetArguments()).To(ContainSubstring(`"location":"SF"`))
|
||||
g.Expect(tcs[1].GetId()).To(Equal("call_def"))
|
||||
g.Expect(tcs[1].GetName()).To(Equal("get_time"))
|
||||
}
|
||||
|
||||
// OpenAI: streaming tool call. Arguments arrive as a sequence of
|
||||
// delta chunks; the consumer is expected to concatenate by tool index.
|
||||
// Verify each chunk reaches the channel and the assembled arguments
|
||||
// match the input.
|
||||
func TestPredictStreamRich_OpenAI_ToolCallDeltas(t *testing.T) {
|
||||
chunks := []string{
|
||||
// Frame 0: announce the tool call (id + name, no args yet).
|
||||
`{"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_xyz","type":"function","function":{"name":"search"}}]}}]}`,
|
||||
// Frames 1-3: arguments arrive in fragments.
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"q\":"}}]}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"clo"}}]}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"uds\"}"}}]}}]}`,
|
||||
// Stop frame.
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
}
|
||||
body := ""
|
||||
for _, c := range chunks {
|
||||
body += "data: " + c + "\n\n"
|
||||
}
|
||||
body += "data: [DONE]\n\n"
|
||||
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 16)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "find something"}},
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var (
|
||||
toolName string
|
||||
toolID string
|
||||
toolIndex int32 = -1
|
||||
argsBuf strings.Builder
|
||||
)
|
||||
for reply := range results {
|
||||
for _, cd := range reply.GetChatDeltas() {
|
||||
for _, tc := range cd.GetToolCalls() {
|
||||
if tc.GetName() != "" {
|
||||
toolName = tc.GetName()
|
||||
}
|
||||
if tc.GetId() != "" {
|
||||
toolID = tc.GetId()
|
||||
}
|
||||
if toolIndex == -1 {
|
||||
toolIndex = tc.GetIndex()
|
||||
}
|
||||
argsBuf.WriteString(tc.GetArguments())
|
||||
}
|
||||
}
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(toolID).To(Equal("call_xyz"))
|
||||
g.Expect(toolName).To(Equal("search"))
|
||||
g.Expect(toolIndex).To(Equal(int32(0)))
|
||||
g.Expect(argsBuf.String()).To(Equal(`{"q":"clouds"}`))
|
||||
}
|
||||
|
||||
// Anthropic: non-streaming tool_use block. The block appears in
|
||||
// Content[] alongside text blocks; the input field is a structured
|
||||
// JSON object. Map to ToolCallDelta with arguments as serialised JSON
|
||||
// so downstream OpenAI-shaped consumers see a familiar format.
|
||||
func TestPredictRich_Anthropic_ToolUse(t *testing.T) {
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{
|
||||
"id":"msg_1","type":"message","role":"assistant",
|
||||
"content":[
|
||||
{"type":"text","text":"Let me check that."},
|
||||
{"type":"tool_use","id":"toolu_01","name":"weather","input":{"location":"SF"}}
|
||||
],
|
||||
"model":"claude","usage":{"input_tokens":12,"output_tokens":34}
|
||||
}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}},
|
||||
Tokens: 64,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal("Let me check that."))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(12)))
|
||||
g.Expect(reply.GetTokens()).To(Equal(int32(34)))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(1))
|
||||
g.Expect(reply.GetChatDeltas()[0].GetToolCalls()).To(HaveLen(1))
|
||||
tc := reply.GetChatDeltas()[0].GetToolCalls()[0]
|
||||
g.Expect(tc.GetId()).To(Equal("toolu_01"))
|
||||
g.Expect(tc.GetName()).To(Equal("weather"))
|
||||
g.Expect(tc.GetArguments()).To(ContainSubstring(`"location":"SF"`))
|
||||
}
|
||||
|
||||
// Anthropic: streaming tool_use. content_block_start announces the
|
||||
// tool's id + name; input_json_delta events carry argument fragments
|
||||
// which the consumer accumulates. message_delta carries final usage.
|
||||
func TestPredictStreamRich_Anthropic_InputJSONDelta(t *testing.T) {
|
||||
frames := []string{
|
||||
"event: message_start\ndata: {\"type\":\"message_start\"}\n\n",
|
||||
// Block 0 is a tool_use; consumer should allocate a slot.
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_42\",\"name\":\"lookup\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"q\\\":\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"\\\"rain\\\"}\"}}\n\n",
|
||||
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_delta\ndata: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n",
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
body := strings.Join(frames, "")
|
||||
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 16)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "rain?"}},
|
||||
Tokens: 64,
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var (
|
||||
toolID, toolName string
|
||||
argsBuf strings.Builder
|
||||
finalTokens int32
|
||||
)
|
||||
for reply := range results {
|
||||
if reply.GetTokens() > 0 && len(reply.GetChatDeltas()) == 0 {
|
||||
finalTokens = reply.GetTokens()
|
||||
continue
|
||||
}
|
||||
for _, cd := range reply.GetChatDeltas() {
|
||||
for _, tc := range cd.GetToolCalls() {
|
||||
if tc.GetId() != "" {
|
||||
toolID = tc.GetId()
|
||||
}
|
||||
if tc.GetName() != "" {
|
||||
toolName = tc.GetName()
|
||||
}
|
||||
argsBuf.WriteString(tc.GetArguments())
|
||||
}
|
||||
}
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(toolID).To(Equal("toolu_42"))
|
||||
g.Expect(toolName).To(Equal("lookup"))
|
||||
g.Expect(argsBuf.String()).To(Equal(`{"q":"rain"}`))
|
||||
g.Expect(finalTokens).To(Equal(int32(7)))
|
||||
}
|
||||
|
||||
// Sanity: the legacy Predict() (string, error) signature still works
|
||||
// — it delegates to PredictRich and extracts Message.
|
||||
func TestPredict_LegacyWrapper_OpenAI(t *testing.T) {
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "hi"}}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hello"))
|
||||
}
|
||||
@@ -8,6 +8,6 @@ import (
|
||||
|
||||
func assert(cond bool, msg string) {
|
||||
if !cond {
|
||||
xlog.Fatal().Stack().Msg(msg)
|
||||
xlog.Fatal(msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,22 @@
|
||||
package main
|
||||
|
||||
// This is a wrapper to statisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
// LocalAI's in-process vector store, exposed as a gRPC backend. Keep
|
||||
// the implementation here — NOT in a pkg/ library imported by the main
|
||||
// LocalAI process. The whole point of the gRPC surface is that vector
|
||||
// storage is a backend like any other (local-store, qdrant, pinecone,
|
||||
// ...) and can be swapped without changing the routing/recognition
|
||||
// code that consumes it.
|
||||
//
|
||||
// Storage is a sorted parallel-slice (keys [][]float32, values
|
||||
// [][]byte). Set/Delete preserve the sort so Get can binary-search.
|
||||
// Find scans linearly and uses a heap to keep the top-K — fine for
|
||||
// the tens-to-thousands range. The "normalized fast path" (Find when
|
||||
// every stored key has unit magnitude AND the query is normalized)
|
||||
// skips the per-item magnitude calculation.
|
||||
//
|
||||
// Concurrency: base.SingleThread serialises gRPC calls so the
|
||||
// non-thread-safe slice/heap manipulation here is sound.
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
@@ -10,30 +25,27 @@ import (
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/mudler/LocalAI/pkg/store"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
base.SingleThread
|
||||
|
||||
// The sorted keys
|
||||
keys [][]float32
|
||||
// The sorted values
|
||||
keys [][]float32
|
||||
values [][]byte
|
||||
|
||||
// If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions
|
||||
// TODO: Should we normalize incoming keys if they are not instead?
|
||||
// keysAreNormalized stays true until any non-unit-magnitude key
|
||||
// is added; once false, the magnitude-aware fallback path is
|
||||
// used by Find. Re-evaluated only at Set time, never again on
|
||||
// its own — a deletion of the offending key does NOT flip it
|
||||
// back to true (the bookkeeping cost would dominate the gain).
|
||||
keysAreNormalized bool
|
||||
// The first key decides the length of the keys
|
||||
keyLen int
|
||||
}
|
||||
|
||||
// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because
|
||||
// that's theoretically best for memory layout and cache locality, but this isn't optimized yet.
|
||||
type Pair struct {
|
||||
Key []float32
|
||||
Value []byte
|
||||
// keyLen is the dimension of every stored key. -1 means "no
|
||||
// keys yet, dimension is open". Dimension mismatch on Set is
|
||||
// rejected so cosine similarity (which requires equal-length
|
||||
// vectors) doesn't silently mis-match.
|
||||
keyLen int
|
||||
}
|
||||
|
||||
func NewStore() *Store {
|
||||
@@ -45,334 +57,278 @@ func NewStore() *Store {
|
||||
}
|
||||
}
|
||||
|
||||
func compareSlices(k1, k2 []float32) int {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
return slices.Compare(k1, k2)
|
||||
}
|
||||
|
||||
func hasKey(unsortedSlice [][]float32, target []float32) bool {
|
||||
return slices.ContainsFunc(unsortedSlice, func(k []float32) bool {
|
||||
return compareSlices(k, target) == 0
|
||||
})
|
||||
}
|
||||
|
||||
func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) {
|
||||
return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int {
|
||||
return compareSlices(k, t)
|
||||
})
|
||||
}
|
||||
|
||||
func isSortedPairs(kvs []Pair) bool {
|
||||
for i := 1; i < len(kvs); i++ {
|
||||
if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isSortedKeys(keys [][]float32) bool {
|
||||
for i := 1; i < len(keys); i++ {
|
||||
if compareSlices(keys[i-1], keys[i]) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 {
|
||||
ks := make([][]float32, len(keys))
|
||||
|
||||
for i, k := range keys {
|
||||
ks[i] = k.Floats
|
||||
}
|
||||
|
||||
slices.SortFunc(ks, compareSlices)
|
||||
|
||||
assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys)))
|
||||
assert(isSortedKeys(ks), "keys are not sorted")
|
||||
|
||||
return ks
|
||||
}
|
||||
|
||||
// Load is a no-op — local-store has no on-disk artefact. opts.Model is
|
||||
// just a namespace identifier; isolation is already handled upstream
|
||||
// (ModelLoader spawns a fresh local-store process per (backend,
|
||||
// model) tuple, so each namespace is its own Store{} instance).
|
||||
func (s *Store) Load(opts *pb.ModelOptions) error {
|
||||
// local-store is an in-memory vector store with no on-disk artefact to
|
||||
// load — opts.Model is just a namespace identifier. The old `!= ""` guard
|
||||
// rejected any non-empty model name with "not implemented", which broke
|
||||
// callers that pass a namespace to isolate embedding spaces (face vs.
|
||||
// voice biometrics both go through local-store but need distinct stores
|
||||
// so ArcFace 512-D and ECAPA-TDNN 192-D don't collide). Namespace
|
||||
// isolation is already handled upstream: ModelLoader spawns a fresh
|
||||
// local-store process per (backend, model) tuple, so each namespace is
|
||||
// its own Store{} instance. Nothing to do here beyond accepting the load.
|
||||
_ = opts
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort the incoming kvs and merge them with the existing sorted kvs
|
||||
func (s *Store) StoresSet(opts *pb.StoresSetOptions) error {
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to add")
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
values := store.UnwrapValues(opts.Values)
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("local-store: Set: no keys to add")
|
||||
}
|
||||
|
||||
if len(opts.Keys) != len(opts.Values) {
|
||||
return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values))
|
||||
if len(keys) != len(values) {
|
||||
return fmt.Errorf("local-store: Set: len(keys) = %d, len(values) = %d", len(keys), len(values))
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
s.keyLen = len(keys[0])
|
||||
} else if len(keys[0]) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Set: key length %d does not match existing %d", len(keys[0]), s.keyLen)
|
||||
}
|
||||
|
||||
kvs := make([]Pair, len(opts.Keys))
|
||||
|
||||
for i, k := range opts.Keys {
|
||||
if s.keysAreNormalized && !isNormalized(k.Floats) {
|
||||
kvs := make([]incomingPair, len(keys))
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Set: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
if s.keysAreNormalized && !isNormalized(k) {
|
||||
s.keysAreNormalized = false
|
||||
var sample []float32
|
||||
if len(s.keys) > 5 {
|
||||
sample = k.Floats[:5]
|
||||
} else {
|
||||
sample = k.Floats
|
||||
}
|
||||
xlog.Debug("Key is not normalized", "sample", sample)
|
||||
}
|
||||
|
||||
kvs[i] = Pair{
|
||||
Key: k.Floats,
|
||||
Value: opts.Values[i].Bytes,
|
||||
}
|
||||
kvs[i] = incomingPair{key: k, value: values[i]}
|
||||
}
|
||||
|
||||
slices.SortFunc(kvs, func(a, b Pair) int {
|
||||
return compareSlices(a.Key, b.Key)
|
||||
})
|
||||
|
||||
assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys)))
|
||||
assert(isSortedPairs(kvs), "keys are not sorted")
|
||||
|
||||
l := len(kvs) + len(s.keys)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
i, j := 0, 0
|
||||
for {
|
||||
if i+j >= l {
|
||||
break
|
||||
}
|
||||
|
||||
if i >= len(kvs) {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
continue
|
||||
}
|
||||
|
||||
if j >= len(s.keys) {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
c := compareSlices(kvs[i].Key, s.keys[j])
|
||||
if c < 0 {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
} else if c > 0 {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
} else {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l))
|
||||
assert(isSortedKeys(merge_ks), "merge keys are not sorted")
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
slices.SortFunc(kvs, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) })
|
||||
|
||||
merged := mergeSortedPairs(s.keys, s.values, kvs)
|
||||
s.keys = merged.keys
|
||||
s.values = merged.values
|
||||
assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Set: s.keys not sorted post-merge")
|
||||
assert(len(s.keys) == len(s.values), "Set: keys/values length skew")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error {
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to delete")
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("local-store: Delete: no keys to delete")
|
||||
}
|
||||
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to add")
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
ks := sortIntoKeySlicese(opts.Keys)
|
||||
|
||||
l := len(s.keys) - len(ks)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
tail_ks := s.keys
|
||||
tail_vs := s.values
|
||||
for _, k := range ks {
|
||||
j, found := findInSortedSlice(tail_ks, k)
|
||||
|
||||
if found {
|
||||
merge_ks = append(merge_ks, tail_ks[:j]...)
|
||||
merge_vs = append(merge_vs, tail_vs[:j]...)
|
||||
tail_ks = tail_ks[j+1:]
|
||||
tail_vs = tail_vs[j+1:]
|
||||
} else {
|
||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k))
|
||||
}
|
||||
|
||||
xlog.Debug("Delete", "found", found, "tailLen", len(tail_ks), "j", j, "mergeKeysLen", len(merge_ks), "mergeValuesLen", len(merge_vs))
|
||||
}
|
||||
|
||||
merge_ks = append(merge_ks, tail_ks...)
|
||||
merge_vs = append(merge_vs, tail_vs...)
|
||||
|
||||
assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys)))
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
|
||||
assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l))
|
||||
assert(isSortedKeys(s.keys), "keys are not sorted")
|
||||
assert(func() bool {
|
||||
for _, k := range ks {
|
||||
if _, found := findInSortedSlice(s.keys, k); found {
|
||||
return false
|
||||
if s.keyLen != -1 {
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Delete: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}(), "Keys to delete still present")
|
||||
|
||||
if len(s.keys) != l {
|
||||
xlog.Debug("Delete: Some keys not found", "keysLen", len(s.keys), "expectedLen", l)
|
||||
}
|
||||
sortedKeys := append([][]float32(nil), keys...)
|
||||
slices.SortFunc(sortedKeys, slices.Compare[[]float32])
|
||||
|
||||
mergedK := make([][]float32, 0, len(s.keys))
|
||||
mergedV := make([][]byte, 0, len(s.keys))
|
||||
tailK := s.keys
|
||||
tailV := s.values
|
||||
for _, k := range sortedKeys {
|
||||
j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32])
|
||||
if ok {
|
||||
mergedK = append(mergedK, tailK[:j]...)
|
||||
mergedV = append(mergedV, tailV[:j]...)
|
||||
tailK = tailK[j+1:]
|
||||
tailV = tailV[j+1:]
|
||||
}
|
||||
}
|
||||
mergedK = append(mergedK, tailK...)
|
||||
mergedV = append(mergedV, tailV...)
|
||||
s.keys = mergedK
|
||||
s.values = mergedV
|
||||
assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Delete: s.keys not sorted post-merge")
|
||||
assert(len(s.keys) == len(s.values), "Delete: keys/values length skew")
|
||||
return nil
|
||||
}
|
||||
|
||||
// StoresGet fetches values for the given keys. Missing keys are
|
||||
// omitted from the result rather than reported as an error — callers
|
||||
// compare returned-key length against requested-key length to detect
|
||||
// them. Returned slices are aligned.
|
||||
func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) {
|
||||
pbKeys := make([]*pb.StoresKey, 0, len(opts.Keys))
|
||||
pbValues := make([]*pb.StoresValue, 0, len(opts.Keys))
|
||||
ks := sortIntoKeySlicese(opts.Keys)
|
||||
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
if len(s.keys) == 0 {
|
||||
xlog.Debug("Get: No keys in store")
|
||||
return pb.StoresGetResult{}, nil
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return pb.StoresGetResult{}, fmt.Errorf("Try to get a key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
if s.keyLen != -1 {
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return pb.StoresGetResult{}, fmt.Errorf("local-store: Get: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
sortedKeys := append([][]float32(nil), keys...)
|
||||
slices.SortFunc(sortedKeys, slices.Compare[[]float32])
|
||||
|
||||
tail_k := s.keys
|
||||
tail_v := s.values
|
||||
for i, k := range ks {
|
||||
j, found := findInSortedSlice(tail_k, k)
|
||||
|
||||
if found {
|
||||
pbKeys = append(pbKeys, &pb.StoresKey{
|
||||
Floats: k,
|
||||
})
|
||||
pbValues = append(pbValues, &pb.StoresValue{
|
||||
Bytes: tail_v[j],
|
||||
})
|
||||
|
||||
tail_k = tail_k[j+1:]
|
||||
tail_v = tail_v[j+1:]
|
||||
} else {
|
||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: i=%d, %v", i, k))
|
||||
var foundKeys [][]float32
|
||||
var foundValues [][]byte
|
||||
tailK := s.keys
|
||||
tailV := s.values
|
||||
for _, k := range sortedKeys {
|
||||
j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
foundKeys = append(foundKeys, tailK[j])
|
||||
foundValues = append(foundValues, tailV[j])
|
||||
tailK = tailK[j+1:]
|
||||
tailV = tailV[j+1:]
|
||||
}
|
||||
|
||||
if len(pbKeys) != len(opts.Keys) {
|
||||
xlog.Debug("Get: Some keys not found", "pbKeysLen", len(pbKeys), "optsKeysLen", len(opts.Keys), "storeKeysLen", len(s.keys))
|
||||
}
|
||||
|
||||
return pb.StoresGetResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Keys: store.WrapKeys(foundKeys),
|
||||
Values: store.WrapValues(foundValues),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StoresFind returns the topK nearest stored entries by cosine
|
||||
// similarity, ordered most-similar first. An empty store returns
|
||||
// empty slices and no error.
|
||||
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
query := opts.Key.Floats
|
||||
topK := int(opts.TopK)
|
||||
if topK < 1 {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: topK = %d, must be >= 1", topK)
|
||||
}
|
||||
if len(s.keys) == 0 {
|
||||
return pb.StoresFindResult{}, nil
|
||||
}
|
||||
if len(query) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: query length %d does not match existing %d", len(query), s.keyLen)
|
||||
}
|
||||
|
||||
var keys [][]float32
|
||||
var values [][]byte
|
||||
var sims []float32
|
||||
if s.keysAreNormalized && isNormalized(query) {
|
||||
keys, values, sims = s.findNormalized(query, topK)
|
||||
} else {
|
||||
keys, values, sims = s.findFallback(query, topK)
|
||||
}
|
||||
return pb.StoresFindResult{
|
||||
Keys: store.WrapKeys(keys),
|
||||
Values: store.WrapValues(values),
|
||||
Similarities: sims,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) findNormalized(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
assert(s.keysAreNormalized, "findNormalized: s.keysAreNormalized is false")
|
||||
assert(isNormalized(query), "findNormalized: query is not unit-length")
|
||||
pq := make(priorityQueue, 0, topK)
|
||||
heap.Init(&pq)
|
||||
for i, k := range s.keys {
|
||||
var dot float32
|
||||
for j := range k {
|
||||
dot += query[j] * k[j]
|
||||
}
|
||||
assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("findNormalized: dot %f out of [-1, 1] — keysAreNormalized invariant violated", dot))
|
||||
heap.Push(&pq, &priorityItem{similarity: dot, key: k, value: s.values[i]})
|
||||
if pq.Len() > topK {
|
||||
heap.Pop(&pq)
|
||||
}
|
||||
}
|
||||
return drainPQ(&pq)
|
||||
}
|
||||
|
||||
func (s *Store) findFallback(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
var qmag float64
|
||||
for _, v := range query {
|
||||
qmag += float64(v) * float64(v)
|
||||
}
|
||||
qmag = math.Sqrt(qmag)
|
||||
pq := make(priorityQueue, 0, topK)
|
||||
heap.Init(&pq)
|
||||
for i, k := range s.keys {
|
||||
var dot, kmag float64
|
||||
for j := range k {
|
||||
dot += float64(query[j]) * float64(k[j])
|
||||
kmag += float64(k[j]) * float64(k[j])
|
||||
}
|
||||
denom := qmag * math.Sqrt(kmag)
|
||||
var sim float32
|
||||
if denom > 0 {
|
||||
sim = float32(dot / denom)
|
||||
}
|
||||
heap.Push(&pq, &priorityItem{similarity: sim, key: k, value: s.values[i]})
|
||||
if pq.Len() > topK {
|
||||
heap.Pop(&pq)
|
||||
}
|
||||
}
|
||||
return drainPQ(&pq)
|
||||
}
|
||||
|
||||
func isNormalized(k []float32) bool {
|
||||
var sum float64
|
||||
|
||||
for _, v := range k {
|
||||
v64 := float64(v)
|
||||
sum += v64 * v64
|
||||
sum += float64(v) * float64(v)
|
||||
}
|
||||
|
||||
s := math.Sqrt(sum)
|
||||
|
||||
return s >= 0.99 && s <= 1.01
|
||||
mag := math.Sqrt(sum)
|
||||
return mag >= 0.99 && mag <= 1.01
|
||||
}
|
||||
|
||||
// TODO: This we could replace with handwritten SIMD code
|
||||
func normalizedCosineSimilarity(k1, k2 []float32) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
type incomingPair struct {
|
||||
key []float32
|
||||
value []byte
|
||||
}
|
||||
|
||||
var dot float32
|
||||
for i := range len(k1) {
|
||||
dot += k1[i] * k2[i]
|
||||
type pairs struct {
|
||||
keys [][]float32
|
||||
values [][]byte
|
||||
}
|
||||
|
||||
// mergeSortedPairs merges (existing, incoming) into a fresh sorted
|
||||
// slice. Equal keys take the incoming value — Set is upsert.
|
||||
func mergeSortedPairs(existingK [][]float32, existingV [][]byte, incoming []incomingPair) pairs {
|
||||
assert(slices.IsSortedFunc(existingK, slices.Compare[[]float32]), "mergeSortedPairs: existing not sorted")
|
||||
assert(slices.IsSortedFunc(incoming, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) }), "mergeSortedPairs: incoming not sorted")
|
||||
l := len(existingK) + len(incoming)
|
||||
mk := make([][]float32, 0, l)
|
||||
mv := make([][]byte, 0, l)
|
||||
i, j := 0, 0
|
||||
for i < len(incoming) || j < len(existingK) {
|
||||
switch {
|
||||
case j >= len(existingK):
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
case i >= len(incoming):
|
||||
mk = append(mk, existingK[j])
|
||||
mv = append(mv, existingV[j])
|
||||
j++
|
||||
default:
|
||||
c := slices.Compare(incoming[i].key, existingK[j])
|
||||
switch {
|
||||
case c < 0:
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
case c > 0:
|
||||
mk = append(mk, existingK[j])
|
||||
mv = append(mv, existingV[j])
|
||||
j++
|
||||
default:
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
j++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("dot = %f", dot))
|
||||
|
||||
// 2.0 * (1.0 - dot) would be the Euclidean distance
|
||||
return dot
|
||||
return pairs{keys: mk, values: mv}
|
||||
}
|
||||
|
||||
type PriorityItem struct {
|
||||
Similarity float32
|
||||
Key []float32
|
||||
Value []byte
|
||||
type priorityItem struct {
|
||||
similarity float32
|
||||
key []float32
|
||||
value []byte
|
||||
}
|
||||
|
||||
type PriorityQueue []*PriorityItem
|
||||
type priorityQueue []*priorityItem
|
||||
|
||||
func (pq PriorityQueue) Len() int { return len(pq) }
|
||||
|
||||
func (pq PriorityQueue) Less(i, j int) bool {
|
||||
// Inverted because the most similar should be at the top
|
||||
return pq[i].Similarity < pq[j].Similarity
|
||||
}
|
||||
|
||||
func (pq PriorityQueue) Swap(i, j int) {
|
||||
pq[i], pq[j] = pq[j], pq[i]
|
||||
}
|
||||
|
||||
func (pq *PriorityQueue) Push(x any) {
|
||||
item := x.(*PriorityItem)
|
||||
*pq = append(*pq, item)
|
||||
}
|
||||
|
||||
func (pq *PriorityQueue) Pop() any {
|
||||
func (pq priorityQueue) Len() int { return len(pq) }
|
||||
func (pq priorityQueue) Less(i, j int) bool { return pq[i].similarity < pq[j].similarity }
|
||||
func (pq priorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] }
|
||||
func (pq *priorityQueue) Push(x any) { *pq = append(*pq, x.(*priorityItem)) }
|
||||
func (pq *priorityQueue) Pop() any {
|
||||
old := *pq
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
@@ -380,142 +336,16 @@ func (pq *PriorityQueue) Pop() any {
|
||||
return item
|
||||
}
|
||||
|
||||
func (s *Store) StoresFindNormalized(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
top_ks := make(PriorityQueue, 0, int(opts.TopK))
|
||||
heap.Init(&top_ks)
|
||||
|
||||
for i, k := range s.keys {
|
||||
sim := normalizedCosineSimilarity(tk, k)
|
||||
heap.Push(&top_ks, &PriorityItem{
|
||||
Similarity: sim,
|
||||
Key: k,
|
||||
Value: s.values[i],
|
||||
})
|
||||
|
||||
if top_ks.Len() > int(opts.TopK) {
|
||||
heap.Pop(&top_ks)
|
||||
}
|
||||
}
|
||||
|
||||
similarities := make([]float32, top_ks.Len())
|
||||
pbKeys := make([]*pb.StoresKey, top_ks.Len())
|
||||
pbValues := make([]*pb.StoresValue, top_ks.Len())
|
||||
|
||||
for i := top_ks.Len() - 1; i >= 0; i-- {
|
||||
item := heap.Pop(&top_ks).(*PriorityItem)
|
||||
|
||||
similarities[i] = item.Similarity
|
||||
pbKeys[i] = &pb.StoresKey{
|
||||
Floats: item.Key,
|
||||
}
|
||||
pbValues[i] = &pb.StoresValue{
|
||||
Bytes: item.Value,
|
||||
}
|
||||
}
|
||||
|
||||
return pb.StoresFindResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Similarities: similarities,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
var dot, mag2 float64
|
||||
for i := range len(k1) {
|
||||
dot += float64(k1[i] * k2[i])
|
||||
mag2 += float64(k2[i] * k2[i])
|
||||
}
|
||||
|
||||
sim := float32(dot / (mag1 * math.Sqrt(mag2)))
|
||||
|
||||
assert(sim >= -1.01 && sim <= 1.01, fmt.Sprintf("sim = %f", sim))
|
||||
|
||||
return sim
|
||||
}
|
||||
|
||||
func (s *Store) StoresFindFallback(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
top_ks := make(PriorityQueue, 0, int(opts.TopK))
|
||||
heap.Init(&top_ks)
|
||||
|
||||
var mag1 float64
|
||||
for _, v := range tk {
|
||||
mag1 += float64(v * v)
|
||||
}
|
||||
mag1 = math.Sqrt(mag1)
|
||||
|
||||
for i, k := range s.keys {
|
||||
dist := cosineSimilarity(tk, k, mag1)
|
||||
heap.Push(&top_ks, &PriorityItem{
|
||||
Similarity: dist,
|
||||
Key: k,
|
||||
Value: s.values[i],
|
||||
})
|
||||
|
||||
if top_ks.Len() > int(opts.TopK) {
|
||||
heap.Pop(&top_ks)
|
||||
}
|
||||
}
|
||||
|
||||
similarities := make([]float32, top_ks.Len())
|
||||
pbKeys := make([]*pb.StoresKey, top_ks.Len())
|
||||
pbValues := make([]*pb.StoresValue, top_ks.Len())
|
||||
|
||||
for i := top_ks.Len() - 1; i >= 0; i-- {
|
||||
item := heap.Pop(&top_ks).(*PriorityItem)
|
||||
|
||||
similarities[i] = item.Similarity
|
||||
pbKeys[i] = &pb.StoresKey{
|
||||
Floats: item.Key,
|
||||
}
|
||||
pbValues[i] = &pb.StoresValue{
|
||||
Bytes: item.Value,
|
||||
}
|
||||
}
|
||||
|
||||
return pb.StoresFindResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Similarities: similarities,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
|
||||
if len(tk) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("Try to find key with length %d when existing length is %d", len(tk), s.keyLen)
|
||||
}
|
||||
|
||||
if opts.TopK < 1 {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("opts.TopK = %d, must be >= 1", opts.TopK)
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Key.Floats)
|
||||
} else {
|
||||
if len(opts.Key.Floats) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Key.Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
if s.keysAreNormalized && isNormalized(tk) {
|
||||
return s.StoresFindNormalized(opts)
|
||||
} else {
|
||||
if s.keysAreNormalized {
|
||||
var sample []float32
|
||||
if len(s.keys) > 5 {
|
||||
sample = tk[:5]
|
||||
} else {
|
||||
sample = tk
|
||||
}
|
||||
xlog.Debug("Trying to compare non-normalized key with normalized keys", "sample", sample)
|
||||
}
|
||||
|
||||
return s.StoresFindFallback(opts)
|
||||
func drainPQ(pq *priorityQueue) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
n := pq.Len()
|
||||
keys = make([][]float32, n)
|
||||
values = make([][]byte, n)
|
||||
similarities = make([]float32, n)
|
||||
for i := n - 1; i >= 0; i-- {
|
||||
item := heap.Pop(pq).(*priorityItem)
|
||||
keys[i] = item.key
|
||||
values[i] = item.value
|
||||
similarities[i] = item.similarity
|
||||
}
|
||||
return keys, values, similarities
|
||||
}
|
||||
|
||||
13
backend/go/local-store/store_suite_test.go
Normal file
13
backend/go/local-store/store_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestLocalStore(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "local-store test suite")
|
||||
}
|
||||
284
backend/go/local-store/store_test.go
Normal file
284
backend/go/local-store/store_test.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package main
|
||||
|
||||
// Regression suite for the local-store gRPC backend. Exercises the
|
||||
// Stores{Set,Get,Find,Delete} surface — the only public contract.
|
||||
// Callers (face/voice recognition, the routing KNN classifier) reach
|
||||
// this code via grpc.Backend, so testing at the wire-shaped boundary
|
||||
// matches the production import shape.
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("StoresSet", func() {
|
||||
It("rejects empty input", func() {
|
||||
Expect(NewStore().StoresSet(&pb.StoresSetOptions{})).NotTo(Succeed(), "Set with no keys should fail")
|
||||
})
|
||||
|
||||
It("rejects key/value length mismatch", func() {
|
||||
err := NewStore().StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("a"), []byte("b")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "len(keys) != len(values) should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch on later add", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("3d")})
|
||||
err := s.StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("2d")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "dimension mismatch on later Set should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch within batch", func() {
|
||||
err := NewStore().StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0, 0}, {1, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("3d"), []byte("2d")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "mixed-dimension within one batch should fail")
|
||||
})
|
||||
|
||||
It("merges sorted and updates existing key", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.3, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("c"), []byte("a")})
|
||||
mustSet(s, [][]float32{{0.2, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("b"), []byte("a-updated")})
|
||||
Expect(s.keys).To(HaveLen(3))
|
||||
got := singleGet(s, []float32{0.1, 0, 0})
|
||||
Expect(string(got)).To(Equal("a-updated"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresGet", func() {
|
||||
It("round-trips multi-key", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}},
|
||||
[][]byte{[]byte("a"), []byte("b"), []byte("c")},
|
||||
)
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{
|
||||
Keys: wrapKeys([][]float32{{0.7, 0.8, 0.9}, {0.1, 0.2, 0.3}}),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("omits missing keys rather than erroring", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")})
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{
|
||||
Keys: wrapKeys([][]float32{{0.1, 0, 0}, {0.9, 0, 0}}),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(1))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresDelete", func() {
|
||||
It("removes and preserves sort", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{{0.1, 0, 0}, {0.2, 0, 0}, {0.3, 0, 0}, {0.4, 0, 0}},
|
||||
[][]byte{[]byte("a"), []byte("b"), []byte("c"), []byte("d")},
|
||||
)
|
||||
Expect(s.StoresDelete(&pb.StoresDeleteOptions{
|
||||
Keys: wrapKeys([][]float32{{0.2, 0, 0}, {0.4, 0, 0}}),
|
||||
})).To(Succeed())
|
||||
Expect(s.keys).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("tolerates missing keys", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")})
|
||||
Expect(s.StoresDelete(&pb.StoresDeleteOptions{
|
||||
Keys: wrapKeys([][]float32{{0.9, 0, 0}}),
|
||||
})).To(Succeed(), "delete of missing key should succeed")
|
||||
Expect(s.keys).To(HaveLen(1))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresFind", func() {
|
||||
It("returns normalized top-K", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{
|
||||
normalizeVec([]float32{1, 0, 0}),
|
||||
normalizeVec([]float32{0, 1, 0}),
|
||||
normalizeVec([]float32{0, 0, 1}),
|
||||
},
|
||||
[][]byte{[]byte("x"), []byte("y"), []byte("z")},
|
||||
)
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: normalizeVec([]float32{0.9, 0.1, 0})},
|
||||
TopK: 2,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
Expect(res.Similarities[0]).To(BeNumerically(">=", res.Similarities[1]), "results not sorted desc by similarity")
|
||||
Expect(string(res.Values[0].Bytes)).To(Equal("x"))
|
||||
})
|
||||
|
||||
It("falls back for non-normalized keys", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{2, 0, 0}, {0, 3, 0}}, [][]byte{[]byte("x"), []byte("y")})
|
||||
Expect(s.keysAreNormalized).To(BeFalse(), "store should report non-normalized after Set with magnitude > 1")
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{4, 0, 0}},
|
||||
TopK: 1,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(string(res.Values[0].Bytes)).To(Equal("x"))
|
||||
Expect(res.Similarities[0]).To(BeNumerically(">=", float32(0.99)))
|
||||
Expect(res.Similarities[0]).To(BeNumerically("<=", float32(1.01)))
|
||||
})
|
||||
|
||||
It("rejects zero topK", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")})
|
||||
_, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0, 0}},
|
||||
TopK: 0,
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "Find with topK=0 should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")})
|
||||
_, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0}},
|
||||
TopK: 1,
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "Find with mismatched dimension should fail")
|
||||
})
|
||||
|
||||
It("returns empty result on empty store", func() {
|
||||
res, err := NewStore().StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0, 0}},
|
||||
TopK: 5,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred(), "Find on empty store should succeed")
|
||||
Expect(res.Keys).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("handles topK larger than store", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{normalizeVec([]float32{1, 0, 0}), normalizeVec([]float32{0, 1, 0})},
|
||||
[][]byte{[]byte("x"), []byte("y")},
|
||||
)
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: normalizeVec([]float32{1, 0, 0})},
|
||||
TopK: 10,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresLoad", func() {
|
||||
It("is a no-op", func() {
|
||||
Expect(NewStore().Load(&pb.ModelOptions{Model: "any-namespace"})).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
func BenchmarkStoresFindNormalized(b *testing.B) {
|
||||
const dim = 768
|
||||
for _, n := range []int{8, 32, 128, 512} {
|
||||
b.Run(fmtN(n), func(b *testing.B) {
|
||||
s := buildStore(b, n, dim)
|
||||
query := normalizeVec(randVec(dim, 42))
|
||||
req := &pb.StoresFindOptions{Key: &pb.StoresKey{Floats: query}, TopK: 1}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := s.StoresFind(req); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- test helpers ---
|
||||
|
||||
func mustSet(s *Store, keys [][]float32, values [][]byte) {
|
||||
ExpectWithOffset(1, s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)})).To(Succeed())
|
||||
}
|
||||
|
||||
func singleGet(s *Store, key []float32) []byte {
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{Keys: wrapKeys([][]float32{key})})
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred())
|
||||
if len(res.Values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return res.Values[0].Bytes
|
||||
}
|
||||
|
||||
func wrapKeys(in [][]float32) []*pb.StoresKey {
|
||||
out := make([]*pb.StoresKey, len(in))
|
||||
for i, k := range in {
|
||||
out[i] = &pb.StoresKey{Floats: k}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func wrapValues(in [][]byte) []*pb.StoresValue {
|
||||
out := make([]*pb.StoresValue, len(in))
|
||||
for i, v := range in {
|
||||
out[i] = &pb.StoresValue{Bytes: v}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildStore(tb testing.TB, n, dim int) *Store {
|
||||
tb.Helper()
|
||||
s := NewStore()
|
||||
keys := make([][]float32, n)
|
||||
values := make([][]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
keys[i] = normalizeVec(randVec(dim, int64(i)+1))
|
||||
values[i] = []byte{byte(i)}
|
||||
}
|
||||
if err := s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)}); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func randVec(dim int, seed int64) []float32 {
|
||||
r := rand.New(rand.NewPCG(uint64(seed), 0xabcdef))
|
||||
v := make([]float32, dim)
|
||||
for i := range v {
|
||||
v[i] = float32(r.NormFloat64())
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func normalizeVec(v []float32) []float32 {
|
||||
var sum float64
|
||||
for _, x := range v {
|
||||
sum += float64(x) * float64(x)
|
||||
}
|
||||
mag := math.Sqrt(sum)
|
||||
if mag == 0 {
|
||||
return v
|
||||
}
|
||||
out := make([]float32, len(v))
|
||||
for i, x := range v {
|
||||
out[i] = float32(float64(x) / mag)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func fmtN(n int) string {
|
||||
return map[int]string{8: "n=8", 32: "n=32", 128: "n=128", 512: "n=512"}[n]
|
||||
}
|
||||
@@ -376,6 +376,8 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
const char *clip_g_path = "";
|
||||
const char *t5xxl_path = "";
|
||||
const char *vae_path = "";
|
||||
const char *audio_vae_path = "";
|
||||
const char *embeddings_connectors_path = "";
|
||||
const char *scheduler_str = "";
|
||||
const char *sampler = "";
|
||||
const char *clip_vision_path = "";
|
||||
@@ -431,6 +433,12 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
if (!strcmp(optname, "vae_path")) {
|
||||
vae_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "audio_vae_path")) {
|
||||
audio_vae_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "embeddings_connectors_path")) {
|
||||
embeddings_connectors_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "scheduler")) {
|
||||
scheduler_str = optval;
|
||||
}
|
||||
@@ -563,6 +571,8 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
ctx_params.diffusion_model_path = diffusion_model_path;
|
||||
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
|
||||
ctx_params.vae_path = vae_path;
|
||||
ctx_params.audio_vae_path = audio_vae_path;
|
||||
ctx_params.embeddings_connectors_path = embeddings_connectors_path;
|
||||
ctx_params.taesd_path = taesd_path;
|
||||
ctx_params.control_net_path = control_net_path;
|
||||
if (lora_dir && strlen(lora_dir) > 0) {
|
||||
|
||||
@@ -26,7 +26,7 @@ import torch.cuda
|
||||
|
||||
XPU=os.environ.get("XPU", "0") == "1"
|
||||
import transformers as transformers_module
|
||||
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
|
||||
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, pipeline
|
||||
from scipy.io import wavfile
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
@@ -200,6 +200,21 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
autoTokenizer = False
|
||||
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||
self.SentenceTransformer = True
|
||||
elif request.Type == "TokenClassification":
|
||||
# NER / PII tagging via HuggingFace's token-classification
|
||||
# pipeline. aggregation_strategy="simple" merges B-/I- tags
|
||||
# into single spans and gives byte offsets back. The
|
||||
# tokenizer is bundled inside the pipeline, so we skip the
|
||||
# AutoTokenizer load below.
|
||||
autoTokenizer = False
|
||||
self.tokenClassifier = pipeline(
|
||||
"token-classification",
|
||||
model=model_name,
|
||||
aggregation_strategy="simple",
|
||||
device=0 if self.CUDA else -1,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
)
|
||||
self.TokenClassification = True
|
||||
else:
|
||||
# Generic: dynamically resolve model class from transformers
|
||||
model_type = TYPE_ALIASES.get(request.Type, request.Type)
|
||||
@@ -253,6 +268,39 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def TokenClassify(self, request, context):
|
||||
# Runs HuggingFace's token-classification pipeline and returns
|
||||
# the aggregated entity spans. The pipeline gives us byte
|
||||
# offsets via aggregation_strategy="simple" (set at load
|
||||
# time), so the caller can slice the original text without
|
||||
# re-tokenising on the Go side.
|
||||
if not getattr(self, "TokenClassification", False):
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("model was not loaded as Type=TokenClassification")
|
||||
return backend_pb2.TokenClassifyResponse()
|
||||
try:
|
||||
results = self.tokenClassifier(request.text)
|
||||
except Exception as err:
|
||||
print("TokenClassify error:", err, file=sys.stderr)
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"token-classification failed: {err}")
|
||||
return backend_pb2.TokenClassifyResponse()
|
||||
|
||||
threshold = request.threshold if request.threshold > 0 else 0.0
|
||||
entities = []
|
||||
for r in results:
|
||||
score = float(r.get("score", 0.0))
|
||||
if score < threshold:
|
||||
continue
|
||||
entities.append(backend_pb2.TokenClassifyEntity(
|
||||
entity_group=str(r.get("entity_group") or r.get("entity") or ""),
|
||||
start=int(r.get("start", 0)),
|
||||
end=int(r.get("end", 0)),
|
||||
score=score,
|
||||
text=str(r.get("word", "")),
|
||||
))
|
||||
return backend_pb2.TokenClassifyResponse(entities=entities)
|
||||
|
||||
def Embedding(self, request, context):
|
||||
set_seed(request.Seed)
|
||||
# Tokenize input
|
||||
|
||||
@@ -356,6 +356,133 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
except Exception as e:
|
||||
return backend_pb2.Result(success=False, message=str(e))
|
||||
|
||||
async def Score(self, request, context):
|
||||
"""
|
||||
Joint log-probability of each candidate continuation given the
|
||||
shared prompt. Used by routing-policy multi-label classification
|
||||
(read the distribution rather than asking the model to emit a
|
||||
single argmax label), reranking, and reward-model scoring.
|
||||
|
||||
Implementation uses vLLM's `prompt_logprobs` to recover the
|
||||
per-token log P(token_i | tokens_<i) for the full concatenated
|
||||
sequence; the candidate's tokens are the suffix whose logprobs
|
||||
get summed. max_tokens=1 because vLLM requires at least one
|
||||
generated token; the generated token is discarded.
|
||||
"""
|
||||
if not hasattr(self, 'llm') or self.llm is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("Model not loaded")
|
||||
return backend_pb2.ScoreResponse()
|
||||
if not hasattr(self, 'tokenizer') or self.tokenizer is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("Tokenizer not available")
|
||||
return backend_pb2.ScoreResponse()
|
||||
if len(request.candidates) == 0:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("candidates must be non-empty")
|
||||
return backend_pb2.ScoreResponse()
|
||||
|
||||
try:
|
||||
prompt = request.prompt or ""
|
||||
prompt_token_ids = self.tokenizer.encode(prompt)
|
||||
prompt_len = len(prompt_token_ids)
|
||||
results = []
|
||||
|
||||
for candidate in request.candidates:
|
||||
# Tokenise the concatenated sequence. We can't naively
|
||||
# use len(prompt_tokens) + len(tokenizer.encode(candidate))
|
||||
# because BPE merges at the boundary may produce a
|
||||
# different tokenisation. Encoding the joined text and
|
||||
# walking the divergence point is the correct primitive.
|
||||
full_text = prompt + candidate
|
||||
full_token_ids = self.tokenizer.encode(full_text)
|
||||
|
||||
divergence = prompt_len
|
||||
min_len = min(prompt_len, len(full_token_ids))
|
||||
for i in range(min_len):
|
||||
if prompt_token_ids[i] != full_token_ids[i]:
|
||||
divergence = i
|
||||
break
|
||||
|
||||
candidate_token_ids = full_token_ids[divergence:]
|
||||
num_candidate_tokens = len(candidate_token_ids)
|
||||
if num_candidate_tokens == 0:
|
||||
results.append(backend_pb2.CandidateScore(
|
||||
log_prob=0.0,
|
||||
length_normalized_log_prob=0.0,
|
||||
num_tokens=0,
|
||||
))
|
||||
continue
|
||||
|
||||
sampling = SamplingParams(
|
||||
max_tokens=1,
|
||||
temperature=0.0,
|
||||
prompt_logprobs=1,
|
||||
detokenize=False,
|
||||
)
|
||||
|
||||
request_id = random_uuid()
|
||||
last_output = None
|
||||
outputs_iter = self.llm.generate(
|
||||
{"prompt": full_text},
|
||||
sampling_params=sampling,
|
||||
request_id=request_id,
|
||||
)
|
||||
try:
|
||||
async for out in outputs_iter:
|
||||
last_output = out
|
||||
finally:
|
||||
try:
|
||||
await outputs_iter.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if last_output is None or not getattr(last_output, "prompt_logprobs", None):
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details("vLLM did not return prompt_logprobs")
|
||||
return backend_pb2.ScoreResponse()
|
||||
|
||||
prompt_logprobs = last_output.prompt_logprobs
|
||||
total = 0.0
|
||||
tokens_proto = []
|
||||
for offset, tok_id in enumerate(candidate_token_ids):
|
||||
position = divergence + offset
|
||||
if position >= len(prompt_logprobs) or prompt_logprobs[position] is None:
|
||||
continue
|
||||
entry = prompt_logprobs[position]
|
||||
lp_obj = entry.get(tok_id)
|
||||
if lp_obj is not None:
|
||||
lp = lp_obj.logprob
|
||||
else:
|
||||
# Token not in top-K; vLLM's top-1 may miss it.
|
||||
# Fall back to the lowest available logprob in the
|
||||
# entry — a conservative lower-bound on the true
|
||||
# log P, biased against this candidate.
|
||||
lp = min(v.logprob for v in entry.values())
|
||||
total += lp
|
||||
if request.include_token_logprobs:
|
||||
tokens_proto.append(backend_pb2.TokenLogProb(
|
||||
token=self.tokenizer.decode([tok_id]),
|
||||
log_prob=lp,
|
||||
))
|
||||
|
||||
cs = backend_pb2.CandidateScore(
|
||||
log_prob=total,
|
||||
num_tokens=num_candidate_tokens,
|
||||
)
|
||||
if request.length_normalize and num_candidate_tokens > 0:
|
||||
cs.length_normalized_log_prob = total / num_candidate_tokens
|
||||
if tokens_proto:
|
||||
cs.tokens.extend(tokens_proto)
|
||||
results.append(cs)
|
||||
|
||||
return backend_pb2.ScoreResponse(candidates=results)
|
||||
except Exception as e:
|
||||
print(f"Score error: {e}", file=sys.stderr)
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(str(e))
|
||||
return backend_pb2.ScoreResponse()
|
||||
|
||||
async def _predict(self, request, context, streaming=False):
|
||||
# Build the sampling parameters
|
||||
# NOTE: this must stay in sync with the vllm backend
|
||||
|
||||
@@ -9,11 +9,18 @@ import (
|
||||
|
||||
corebackend "github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
"github.com/mudler/LocalAI/core/services/facerecognition"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/routing/admission"
|
||||
"github.com/mudler/LocalAI/core/services/routing/billing"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/mitm"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/services/voicerecognition"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
pkggrpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
@@ -51,6 +58,22 @@ type Application struct {
|
||||
faceRegistry facerecognition.Registry
|
||||
voiceRegistry voicerecognition.Registry
|
||||
authDB *gorm.DB
|
||||
metricsService *monitoring.LocalAIMetricsService
|
||||
statsRecorder *billing.Recorder
|
||||
fallbackUser *auth.User
|
||||
piiRedactor *pii.Redactor
|
||||
piiEvents pii.EventStore
|
||||
mitmCA atomic.Pointer[mitm.CA]
|
||||
mitmServer atomic.Pointer[mitm.Server]
|
||||
mitmMutex sync.Mutex // serializes Stop+Start; readers use atomic loads
|
||||
// mitmHostConflicts records duplicate-host claims across model configs.
|
||||
// Non-empty disables the MITM listener until resolved — the strict
|
||||
// 1-to-1 host↔model invariant the dispatcher relies on. Read by
|
||||
// /api/middleware/status so the admin UI can surface the cause.
|
||||
mitmHostConflicts atomic.Pointer[map[string][]string]
|
||||
routerDecisions router.DecisionStore
|
||||
routerRegistry *router.Registry
|
||||
admissionLimiter *admission.Limiter
|
||||
watchdogMutex sync.Mutex
|
||||
watchdogStop chan bool
|
||||
p2pMutex sync.Mutex
|
||||
@@ -185,6 +208,103 @@ func (a *Application) AuthDB() *gorm.DB {
|
||||
return a.authDB
|
||||
}
|
||||
|
||||
// MetricsService returns the OTel + Prometheus metric service. nil when
|
||||
// --disable-metrics is set or initialisation failed at startup.
|
||||
//
|
||||
// The service is created in startup.go before any counter is registered
|
||||
// so that otel.SetMeterProvider runs early enough for the billing
|
||||
// recorder's counters to bind to the Prom-backed provider rather than
|
||||
// the no-op global. core/http/app.go reuses this instance instead of
|
||||
// constructing its own — two providers would orphan one set of counters
|
||||
// behind whichever provider lost the SetMeterProvider race.
|
||||
func (a *Application) MetricsService() *monitoring.LocalAIMetricsService {
|
||||
return a.metricsService
|
||||
}
|
||||
|
||||
// StatsRecorder returns the billing recorder used by the usage
|
||||
// middleware. It is non-nil whenever stats are not explicitly disabled
|
||||
// — i.e., the no-auth single-user path still gets a working recorder
|
||||
// (in-memory by default). Routes register UsageMiddleware against this
|
||||
// recorder regardless of auth state.
|
||||
func (a *Application) StatsRecorder() *billing.Recorder {
|
||||
return a.statsRecorder
|
||||
}
|
||||
|
||||
// FallbackUser is the synthetic "local" user that UsageMiddleware uses
|
||||
// to attribute requests when no authenticated user is on the context
|
||||
// (i.e., --auth is off). nil when auth is on, since real users are
|
||||
// always available there.
|
||||
func (a *Application) FallbackUser() *auth.User {
|
||||
return a.fallbackUser
|
||||
}
|
||||
|
||||
// PIIRedactor returns the regex-tier PII redactor or nil if PII
|
||||
// filtering is disabled. The chat-route middleware uses this to apply
|
||||
// redaction before dispatch.
|
||||
func (a *Application) PIIRedactor() *pii.Redactor {
|
||||
return a.piiRedactor
|
||||
}
|
||||
|
||||
// PIIEvents returns the PII event store. Same nil-when-disabled
|
||||
// semantics as PIIRedactor; admin REST and MCP read tools call List
|
||||
// against it.
|
||||
func (a *Application) PIIEvents() pii.EventStore {
|
||||
return a.piiEvents
|
||||
}
|
||||
|
||||
// MITMCA returns the cloudproxy MITM proxy's CA, or nil when the
|
||||
// MITM listener is disabled.
|
||||
func (a *Application) MITMCA() *mitm.CA { return a.mitmCA.Load() }
|
||||
|
||||
// MITMServer returns the running MITM proxy or nil.
|
||||
func (a *Application) MITMServer() *mitm.Server { return a.mitmServer.Load() }
|
||||
|
||||
// MITMHostConflicts returns a snapshot of host→[]model-name pairs that
|
||||
// are claimed by 2+ model configs. Empty when the 1-to-1 invariant
|
||||
// holds. Non-empty disables the MITM listener — read by the admin
|
||||
// status endpoint to explain why.
|
||||
func (a *Application) MITMHostConflicts() map[string][]string {
|
||||
p := a.mitmHostConflicts.Load()
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return *p
|
||||
}
|
||||
|
||||
// MITMHostOwners returns the host→model-name map, useful for the
|
||||
// admin status endpoint. The lookup is recomputed on each call to
|
||||
// stay current with model-config edits without needing a
|
||||
// MITMRestart.
|
||||
func (a *Application) MITMHostOwners() map[string]string {
|
||||
if a.backendLoader == nil {
|
||||
return nil
|
||||
}
|
||||
return a.backendLoader.MITMHostOwners().Owners
|
||||
}
|
||||
|
||||
// RouterDecisions returns the routing decision store. nil when stats
|
||||
// are disabled (--disable-stats); the RouteModel middleware skips the
|
||||
// log write in that case but still rewrites requests.
|
||||
func (a *Application) RouterDecisions() router.DecisionStore {
|
||||
return a.routerDecisions
|
||||
}
|
||||
|
||||
// RouterClassifierRegistry returns the process-wide classifier cache.
|
||||
// Shared between the OpenAI and Anthropic route middlewares so the
|
||||
// admin stats endpoint sees every live classifier — and so a
|
||||
// classifier built on the OpenAI route is reused on Anthropic.
|
||||
func (a *Application) RouterClassifierRegistry() *router.Registry {
|
||||
return a.routerRegistry
|
||||
}
|
||||
|
||||
// AdmissionLimiter returns the per-model admission limiter. The
|
||||
// admission middleware uses it to gate concurrent requests; the
|
||||
// admin status surface reads InFlight/Capacity from it for live
|
||||
// load visibility.
|
||||
func (a *Application) AdmissionLimiter() *admission.Limiter {
|
||||
return a.admissionLimiter
|
||||
}
|
||||
|
||||
// StartupConfig returns the original startup configuration (from env vars, before file loading)
|
||||
func (a *Application) StartupConfig() *config.ApplicationConfig {
|
||||
return a.startupConfig
|
||||
@@ -255,6 +375,15 @@ func (a *Application) start() error {
|
||||
a.modelLoader,
|
||||
a.galleryService,
|
||||
)
|
||||
// Wire usage tracking so the assistant's get_usage_stats tool
|
||||
// returns real data; nil values keep the tool returning a clear
|
||||
// "unavailable" error if startup ran with --disable-stats.
|
||||
assistantClient.StatsRecorder = a.statsRecorder
|
||||
assistantClient.FallbackUser = a.fallbackUser
|
||||
// PII filter — same nil-or-real wiring.
|
||||
assistantClient.PIIRedactor = a.piiRedactor
|
||||
assistantClient.PIIEvents = a.piiEvents
|
||||
assistantClient.RouterDecisions = a.routerDecisions
|
||||
if err := holder.Initialize(a.applicationConfig.Context, assistantClient, localaitools.Options{}); err != nil {
|
||||
// Why log+continue instead of fail: the assistant is an optional
|
||||
// feature; a failure here must not take down the whole server.
|
||||
|
||||
146
core/application/mitm.go
Normal file
146
core/application/mitm.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/mitm"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
func startMITMProxy(app *Application, options *config.ApplicationConfig) error {
|
||||
app.mitmMutex.Lock()
|
||||
defer app.mitmMutex.Unlock()
|
||||
return startMITMLocked(app, options)
|
||||
}
|
||||
|
||||
func startMITMLocked(app *Application, options *config.ApplicationConfig) error {
|
||||
// Validate the host↔model-config 1-to-1 invariant before binding
|
||||
// the listener. Two configs claiming the same host means the
|
||||
// dispatcher would have ambiguous PII settings; refuse to start
|
||||
// rather than silently picking one. The conflict map is published
|
||||
// for /api/middleware/status to surface in the UI.
|
||||
ownership := app.backendLoader.MITMHostOwners()
|
||||
if len(ownership.Conflicts) > 0 {
|
||||
conflicts := ownership.Conflicts
|
||||
app.mitmHostConflicts.Store(&conflicts)
|
||||
hosts := make([]string, 0, len(conflicts))
|
||||
for h := range conflicts {
|
||||
hosts = append(hosts, h)
|
||||
}
|
||||
sort.Strings(hosts)
|
||||
xlog.Error("mitm: refusing to start — duplicate host claims across model configs",
|
||||
"hosts", hosts,
|
||||
"conflicts", conflicts,
|
||||
)
|
||||
return errors.New("mitm: configuration error: duplicate host claims (see /api/middleware/status)")
|
||||
}
|
||||
app.mitmHostConflicts.Store(nil)
|
||||
|
||||
caDir := options.MITMCADir
|
||||
if caDir == "" {
|
||||
base := options.DataPath
|
||||
if base == "" {
|
||||
base = "."
|
||||
}
|
||||
caDir = filepath.Join(base, "mitm-ca")
|
||||
}
|
||||
|
||||
if app.mitmCA.Load() == nil {
|
||||
ca, err := mitm.LoadOrCreateCA(caDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ca: %w", err)
|
||||
}
|
||||
app.mitmCA.Store(ca)
|
||||
}
|
||||
|
||||
// Allowlist is exactly the set of hosts claimed by model configs.
|
||||
// No global list — admins add hosts by creating an MITM model
|
||||
// config (template available in the Add Model UI). When no config
|
||||
// claims any host, the listener still starts but every CONNECT
|
||||
// tunnels through unmodified.
|
||||
effectiveHosts := make([]string, 0, len(ownership.Owners))
|
||||
for h := range ownership.Owners {
|
||||
effectiveHosts = append(effectiveHosts, h)
|
||||
}
|
||||
sort.Strings(effectiveHosts)
|
||||
|
||||
// Per-host PII gate inherits from the owning model's pii.enabled.
|
||||
// A non-cloud-proxy backend with no explicit pii.enabled resolves
|
||||
// to false → host is intercepted but the regex pass is skipped
|
||||
// (audit events still record).
|
||||
var piiDisabled []string
|
||||
for host, modelName := range ownership.Owners {
|
||||
cfg, exists := app.backendLoader.GetModelConfig(modelName)
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
if !cfg.PIIIsEnabled() {
|
||||
piiDisabled = append(piiDisabled, host)
|
||||
}
|
||||
}
|
||||
|
||||
handler := mitm.NewPIIHandler(mitm.PIIHandlerOptions{
|
||||
Redactor: app.piiRedactor,
|
||||
EventStore: app.piiEvents,
|
||||
HostsWithPIIDisabled: piiDisabled,
|
||||
})
|
||||
|
||||
srv, err := mitm.NewServer(mitm.Config{
|
||||
Addr: options.MITMListen,
|
||||
CA: app.mitmCA.Load(),
|
||||
InterceptHosts: effectiveHosts,
|
||||
Handler: handler,
|
||||
EventStore: app.piiEvents,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("server: %w", err)
|
||||
}
|
||||
if err := srv.Start(); err != nil {
|
||||
return fmt.Errorf("listen: %w", err)
|
||||
}
|
||||
app.mitmServer.Store(srv)
|
||||
|
||||
xlog.Info("mitm: cloudproxy listener started",
|
||||
"addr", srv.Addr(),
|
||||
"ca_dir", caDir,
|
||||
"intercept_hosts", effectiveHosts,
|
||||
"model_owned_hosts", len(ownership.Owners),
|
||||
"pii_disabled_hosts", len(piiDisabled),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopMITM is idempotent.
|
||||
func (a *Application) StopMITM() error {
|
||||
a.mitmMutex.Lock()
|
||||
defer a.mitmMutex.Unlock()
|
||||
stopMITMLocked(a)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestartMITM reuses the existing CA so trusted clients keep
|
||||
// working across listener flips.
|
||||
func (a *Application) RestartMITM() error {
|
||||
a.mitmMutex.Lock()
|
||||
defer a.mitmMutex.Unlock()
|
||||
stopMITMLocked(a)
|
||||
if a.applicationConfig.MITMListen == "" {
|
||||
xlog.Info("mitm: cloudproxy listener stays disabled (no listen address)")
|
||||
return nil
|
||||
}
|
||||
return startMITMLocked(a, a.applicationConfig)
|
||||
}
|
||||
|
||||
func stopMITMLocked(a *Application) {
|
||||
srv := a.mitmServer.Load()
|
||||
if srv == nil {
|
||||
return
|
||||
}
|
||||
srv.Stop()
|
||||
a.mitmServer.Store(nil)
|
||||
xlog.Info("mitm: cloudproxy listener stopped")
|
||||
}
|
||||
63
core/application/router_factories.go
Normal file
63
core/application/router_factories.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
// adapterConfig resolves a model name to its runtime ModelConfig, or
|
||||
// nil when the name is unknown. Shared by the router-facing factories
|
||||
// below and by ModelConfigLookup.
|
||||
func (a *Application) adapterConfig(modelName string) *config.ModelConfig {
|
||||
cfg, err := a.backendLoader.LoadModelConfigFileByNameDefaultOptions(modelName, a.applicationConfig)
|
||||
if err != nil || cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// ModelConfigLookup is the lookup function the router middleware's
|
||||
// classifier validator uses to confirm classifier_model declares
|
||||
// FLAG_SCORE before binding it.
|
||||
func (a *Application) ModelConfigLookup() func(modelName string) *config.ModelConfig {
|
||||
return a.adapterConfig
|
||||
}
|
||||
|
||||
// Scorer returns a backend.Scorer bound to the named model, or nil
|
||||
// when the model is unknown. Used as a method value (app.Scorer) by
|
||||
// router.ClassifierDeps — no factory-of-factory wrapper needed.
|
||||
func (a *Application) Scorer(modelName string) backend.Scorer {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewScorer(a.modelLoader, *cfg, a.applicationConfig)
|
||||
}
|
||||
|
||||
// Reranker returns a backend.Reranker bound to the named model, or
|
||||
// nil when unknown. The reranker model's `type:` (e.g. "colbert")
|
||||
// selects the scoring head inside the rerankers backend.
|
||||
func (a *Application) Reranker(modelName string) backend.Reranker {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewReranker(a.modelLoader, *cfg, a.applicationConfig)
|
||||
}
|
||||
|
||||
// Embedder returns a backend.Embedder bound to the named model, or
|
||||
// nil when unknown. Used by the router's L2 embedding cache.
|
||||
func (a *Application) Embedder(modelName string) backend.Embedder {
|
||||
cfg := a.adapterConfig(modelName)
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return backend.NewEmbedder(a.modelLoader, *cfg, a.applicationConfig)
|
||||
}
|
||||
|
||||
// VectorStore returns a backend.VectorStore for the named collection,
|
||||
// or nil when the name is empty. Each router model gets its own
|
||||
// backend process via the model loader's cache keyed by storeName.
|
||||
func (a *Application) VectorStore(storeName string) backend.VectorStore {
|
||||
return backend.NewVectorStore(a.modelLoader, a.applicationConfig, storeName)
|
||||
}
|
||||
@@ -87,6 +87,28 @@ var _ = Describe("loadRuntimeSettingsFromFile", func() {
|
||||
})
|
||||
})
|
||||
|
||||
// MITM listener address. The file is the only source — no env var
|
||||
// exists — so a regression here means an admin who configured the
|
||||
// listener via /api/settings loses it after a reboot, even though
|
||||
// the value is still on disk in the volume. (Intercept hosts now
|
||||
// live in model YAML mitm.hosts: blocks, not runtime_settings.json.)
|
||||
Describe("MITM fields", func() {
|
||||
It("loads mitm_listen", func() {
|
||||
cfg := &config.ApplicationConfig{DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`)}
|
||||
loadRuntimeSettingsFromFile(cfg)
|
||||
Expect(cfg.MITMListen).To(Equal(":8443"))
|
||||
})
|
||||
|
||||
It("does not override an explicit CLI flag", func() {
|
||||
cfg := &config.ApplicationConfig{
|
||||
DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`),
|
||||
MITMListen: ":9999", // simulate WithMITMListen(":9999")
|
||||
}
|
||||
loadRuntimeSettingsFromFile(cfg)
|
||||
Expect(cfg.MITMListen).To(Equal(":9999"), "CLI flag must win over the persisted file value")
|
||||
})
|
||||
})
|
||||
|
||||
// The Agent Pool block has a mix of zero and non-zero defaults
|
||||
// (Enabled=true, EmbeddingModel="granite-...", MaxChunkingSize=400,
|
||||
// VectorEngine="chromem", AgentHubURL="https://agenthub.localai.io").
|
||||
|
||||
@@ -15,8 +15,15 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/routing/admission"
|
||||
"github.com/mudler/LocalAI/core/services/routing/billing"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
@@ -128,6 +135,117 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}()
|
||||
}
|
||||
|
||||
// Initialize the OTel + Prometheus metric pipeline before any
|
||||
// counter is created. monitoring.NewLocalAIMetricsService calls
|
||||
// otel.SetMeterProvider, so any subsequent otel.Meter() call —
|
||||
// including billing.NewRecorder below — sees the real provider
|
||||
// rather than the no-op global. Initialising metrics later (in
|
||||
// core/http/app.go) leaves billing's counters bound to a no-op
|
||||
// meter and never reaches /metrics. We deliberately ignore
|
||||
// DisableMetrics here for ordering purposes; the HTTP middleware
|
||||
// that records api_call histograms is still gated.
|
||||
if !options.DisableMetrics {
|
||||
ms, err := monitoring.NewLocalAIMetricsService()
|
||||
if err != nil {
|
||||
xlog.Error("failed to initialize metrics provider", "error", err)
|
||||
} else {
|
||||
application.metricsService = ms
|
||||
// Bind the billing package's counters to the same meter the
|
||||
// metrics service exports. Without this, billing's counters
|
||||
// resolve via the OTel global and never reach /metrics.
|
||||
billing.SetMeter(ms.Meter)
|
||||
}
|
||||
}
|
||||
|
||||
// Wire the routing-module billing recorder. The recorder runs in
|
||||
// every mode (auth on/off, distributed/single-node) so that token
|
||||
// tracking is not gated on auth — a no-auth single-user box still
|
||||
// gets dashboards and `/api/usage` populated.
|
||||
//
|
||||
// fallbackUser is wired *unconditionally* when stats are enabled.
|
||||
// UsageMiddleware uses it as the attribution source whenever
|
||||
// auth.GetUser(c) is nil — that covers (a) no-auth deployments and
|
||||
// (b) internal callers under auth-on (cron flushers, distributed
|
||||
// worker callbacks) that hit a recordable endpoint without a user
|
||||
// in context. The billing.user_id_present invariant still rejects
|
||||
// empty IDs; LocalUser() returns a stable UUID per data path.
|
||||
if !options.DisableStats {
|
||||
var statsBackend billing.StatsBackend
|
||||
switch {
|
||||
case application.authDB != nil:
|
||||
statsBackend = billing.NewGormBackend(application.authDB, 0, 0)
|
||||
xlog.Info("stats: using auth DB for usage records")
|
||||
default:
|
||||
statsBackend = billing.NewMemoryBackend(0)
|
||||
xlog.Info("stats: using in-memory ring buffer (no-auth single-user mode)")
|
||||
}
|
||||
application.fallbackUser = billing.LocalUser(options.DataPath)
|
||||
application.statsRecorder = billing.NewRecorder(statsBackend)
|
||||
// Drain pending records on SIGTERM. The GORM backend buffers up
|
||||
// to maxPending (5k) records across a 5s flush tick, so without
|
||||
// this the last few seconds of usage disappear on graceful exit.
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
_ = application.statsRecorder.Close()
|
||||
})
|
||||
xlog.Info("stats: fallback user wired", "local_user_id", application.fallbackUser.ID)
|
||||
} else {
|
||||
xlog.Info("stats: disabled by --disable-stats")
|
||||
}
|
||||
|
||||
// Wire the regex PII filter. Default-on: a single-user box gets
|
||||
// the built-in pattern set the first time it starts, with email/
|
||||
// phone/SSN/credit-card on mask and api_key_prefix on block. If
|
||||
// the operator wants different actions, --pii-config points at a
|
||||
// YAML file that overrides per-id; --disable-pii turns it off
|
||||
// entirely.
|
||||
if !options.DisablePII {
|
||||
patterns, err := pii.LoadConfig(options.PIIConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pii config: %w", err)
|
||||
}
|
||||
application.piiRedactor = pii.NewRedactor(patterns)
|
||||
application.piiEvents = pii.NewMemoryEventStore(0)
|
||||
// Apply persisted per-pattern overrides — admins toggling
|
||||
// action/disabled via the UI and clicking "Save to disk" land
|
||||
// here on the next start. Bad ids are warned and ignored so a
|
||||
// stale entry doesn't block startup.
|
||||
for id, ov := range options.PIIPatternOverrides {
|
||||
if ov.Action != nil {
|
||||
if err := application.piiRedactor.SetAction(id, pii.Action(*ov.Action)); err != nil {
|
||||
xlog.Warn("pii: persisted override skipped", "pattern", id, "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if ov.Disabled != nil {
|
||||
if err := application.piiRedactor.SetDisabled(id, *ov.Disabled); err != nil {
|
||||
xlog.Warn("pii: persisted disable skipped", "pattern", id, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
xlog.Info("pii: filter enabled",
|
||||
"patterns", len(patterns),
|
||||
"config_path", options.PIIConfigPath,
|
||||
"persisted_overrides", len(options.PIIPatternOverrides),
|
||||
)
|
||||
} else {
|
||||
xlog.Info("pii: disabled by --disable-pii")
|
||||
}
|
||||
|
||||
// Wire the routing decision log. Always-on when stats are enabled —
|
||||
// the per-router admin page reads this as the live activity feed
|
||||
// and as input to drift checks for subsystem 5.
|
||||
if !options.DisableStats {
|
||||
application.routerDecisions = router.NewMemoryDecisionStore(0)
|
||||
}
|
||||
// Process-wide classifier cache shared across all route middlewares so
|
||||
// the embedding-cache stats endpoint sees a single source of truth.
|
||||
application.routerRegistry = router.NewRegistry()
|
||||
|
||||
// Subsystem 5: admission control. Limiter is always wired so a
|
||||
// model that gains a limits: block via gallery install or YAML
|
||||
// edit takes effect on the next restart without conditional plumbing.
|
||||
application.admissionLimiter = admission.New()
|
||||
|
||||
// Wire JobStore for DB-backed task/job persistence whenever auth DB is available.
|
||||
// This ensures tasks and jobs survive restarts in both single-node and distributed modes.
|
||||
if application.authDB != nil && application.agentJobService != nil {
|
||||
@@ -195,6 +313,30 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
|
||||
}
|
||||
// Hydrate from the store first so the wildcard subscriber finds an
|
||||
// already-populated statuses map for any operations still in flight
|
||||
// on a peer replica.
|
||||
if err := application.galleryService.Hydrate(); err != nil {
|
||||
xlog.Warn("Gallery service hydrate failed", "error", err)
|
||||
}
|
||||
// Bind cache-invalidation handler before SubscribeBroadcasts so the
|
||||
// first inbound event is already routed. Peer replicas install a
|
||||
// model and broadcast on SubjectCacheInvalidateModels; this
|
||||
// callback re-runs LoadModelConfigsFromPath so a subsequent chat
|
||||
// completion that load-balances onto this replica finds the new
|
||||
// config. The originating replica reloads inline in modelHandler
|
||||
// and never enters this path.
|
||||
gs := application.galleryService
|
||||
sys := options.SystemState
|
||||
cfgLoaderOpts := options.ToConfigLoaderOptions()
|
||||
gs.OnModelsChanged = func(_ messaging.CacheInvalidateEvent) {
|
||||
if err := application.ModelConfigLoader().LoadModelConfigsFromPath(sys.Model.ModelsPath, cfgLoaderOpts...); err != nil {
|
||||
xlog.Warn("Failed to reload model configs after peer invalidation", "error", err)
|
||||
}
|
||||
}
|
||||
if err := application.galleryService.SubscribeBroadcasts(); err != nil {
|
||||
xlog.Warn("Gallery service subscribe failed", "error", err)
|
||||
}
|
||||
// Wire distributed model/backend managers so delete propagates to workers
|
||||
application.galleryService.SetModelManager(
|
||||
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
|
||||
@@ -291,6 +433,20 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
loadRuntimeSettingsFromFile(options)
|
||||
}
|
||||
|
||||
// Wire the cloudproxy MITM listener. Opt-in: empty MITMListen
|
||||
// means "no MITM" — operators must explicitly choose to start
|
||||
// it because clients have to install the generated CA cert.
|
||||
// The handler reuses the global redactor + event store so an
|
||||
// admin who's already configured PII filtering for direct API
|
||||
// traffic doesn't need a parallel config for MITM traffic.
|
||||
// Runs after loadRuntimeSettingsFromFile so a listener configured
|
||||
// via /api/settings is brought back up across restarts.
|
||||
if options.MITMListen != "" {
|
||||
if err := startMITMProxy(application, options); err != nil {
|
||||
return nil, fmt.Errorf("mitm: startup: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
application.ModelLoader().SetBackendLoggingEnabled(options.EnableBackendLogging)
|
||||
|
||||
// turn off any process that was started by GRPC if the context is canceled
|
||||
@@ -580,6 +736,25 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
options.Branding.FaviconFile = *settings.FaviconFile
|
||||
}
|
||||
|
||||
// MITM listener address. The CLI flag WithMITMListen populates
|
||||
// options at startup; if the user configured MITM via /api/settings
|
||||
// after the fact, only the file holds the value. Apply when the
|
||||
// CLI flag did not already set it. (Intercept hosts now live in
|
||||
// model YAML mitm.hosts: rather than runtime_settings.json.)
|
||||
if settings.MITMListen != nil && options.MITMListen == "" {
|
||||
options.MITMListen = *settings.MITMListen
|
||||
}
|
||||
|
||||
// PII pattern overrides — file is the only source; CLI flags don't
|
||||
// reach into this map. Apply unconditionally when present; the
|
||||
// redactor wiring below sees the result on first construction.
|
||||
if settings.PIIPatternOverrides != nil {
|
||||
options.PIIPatternOverrides = make(map[string]config.PIIPatternRuntimeOverride, len(*settings.PIIPatternOverrides))
|
||||
for id, ov := range *settings.PIIPatternOverrides {
|
||||
options.PIIPatternOverrides[id] = ov
|
||||
}
|
||||
}
|
||||
|
||||
// Backend upgrade flags
|
||||
if settings.AutoUpgradeBackends != nil {
|
||||
if !options.AutoUpgradeBackends {
|
||||
|
||||
169
core/backend/ctx_propagation_test.go
Normal file
169
core/backend/ctx_propagation_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package backend_test
|
||||
|
||||
// Regression spec for X-LocalAI-Node coverage on audio/image/TTS/rerank/VAD.
|
||||
//
|
||||
// The X-LocalAI-Node middleware (core/http/middleware.ExposeNodeHeader)
|
||||
// works end-to-end only if the per-request holder attached to the HTTP
|
||||
// request context reaches the SmartRouter via ml.Load(opts...). The chain
|
||||
// is:
|
||||
//
|
||||
// handler -> backend.Foo(ctx, ...) -> ModelOptions(cfg, app, WithContext(ctx))
|
||||
// -> ml.Load(opts...) -> grpcModel(..., o.context) -> modelRouter(ctx, ...)
|
||||
// -> SmartRouter -> distributedhdr.Stamp(ctx, nodeID)
|
||||
//
|
||||
// If any backend helper drops `ctx` and lets ModelOptions fall back to the
|
||||
// app context, the router never sees the per-request holder and the
|
||||
// header silently stays empty for that endpoint. These specs pin the
|
||||
// request-context-reaches-router contract for the five backend helpers
|
||||
// that were previously dropping ctx between the handler and Load.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
pbproto "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// newCapturingLoader returns a ModelLoader wired with a stub model router
|
||||
// that captures the context it receives and then short-circuits with a
|
||||
// sentinel error. The router callback is the exact seam where the
|
||||
// SmartRouter would call distributedhdr.Stamp in production, so observing
|
||||
// the holder here is equivalent to observing it at the real router.
|
||||
func newCapturingLoader() (*model.ModelLoader, *atomic.Value, func() context.Context) {
|
||||
loader := model.NewModelLoader(&system.SystemState{})
|
||||
var captured atomic.Value
|
||||
loader.SetModelRouter(func(ctx context.Context, _ string, _, _, _ string, _ *pbproto.ModelOptions, _ bool) (*model.Model, error) {
|
||||
captured.Store(ctx)
|
||||
// Return an error so the backend short-circuits before trying to
|
||||
// dial gRPC. We only care about the context-arrival contract.
|
||||
return nil, errRouterShortCircuit
|
||||
})
|
||||
get := func() context.Context {
|
||||
v, _ := captured.Load().(context.Context)
|
||||
return v
|
||||
}
|
||||
return loader, &captured, get
|
||||
}
|
||||
|
||||
var errRouterShortCircuit = sentinelErr("router short-circuit (test)")
|
||||
|
||||
type sentinelErr string
|
||||
|
||||
func (s sentinelErr) Error() string { return string(s) }
|
||||
|
||||
func newAppCfg() *config.ApplicationConfig {
|
||||
return config.NewApplicationConfig(config.WithSystemState(&system.SystemState{}))
|
||||
}
|
||||
|
||||
func newModelCfg() config.ModelConfig {
|
||||
threads := 1
|
||||
cfg := config.ModelConfig{
|
||||
Name: "test-model",
|
||||
Backend: "stub-backend",
|
||||
Threads: &threads,
|
||||
}
|
||||
cfg.Model = "test.bin"
|
||||
return cfg
|
||||
}
|
||||
|
||||
var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
|
||||
const fakeNodeID = "node-ctx-propagation-7"
|
||||
|
||||
var (
|
||||
appCfg *config.ApplicationConfig
|
||||
modelCfg config.ModelConfig
|
||||
loader *model.ModelLoader
|
||||
routerCtxOf func() context.Context
|
||||
holder *atomic.Value
|
||||
reqCtx context.Context
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
appCfg = newAppCfg()
|
||||
modelCfg = newModelCfg()
|
||||
loader, _, routerCtxOf = newCapturingLoader()
|
||||
holder = distributedhdr.NewHolder()
|
||||
reqCtx = distributedhdr.WithHolder(context.Background(), holder)
|
||||
})
|
||||
|
||||
// stampViaRouterCtx asserts the captured router context carries the
|
||||
// SAME holder that was attached to the request. We verify by stamping
|
||||
// through the router-side ctx and observing the value via the
|
||||
// request-side holder; if the holders were different objects the load
|
||||
// would return "".
|
||||
stampViaRouterCtx := func() {
|
||||
routerCtx := routerCtxOf()
|
||||
Expect(routerCtx).ToNot(BeNil(), "router callback must have been invoked")
|
||||
distributedhdr.Stamp(routerCtx, fakeNodeID)
|
||||
Expect(distributedhdr.Load(holder)).To(Equal(fakeNodeID),
|
||||
"stamp via router-side ctx must be observable via the request-side holder")
|
||||
}
|
||||
|
||||
It("Rerank forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.Rerank(reqCtx, &pbproto.RerankRequest{Query: "q"}, loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("VAD forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.VAD(&schema.VADRequest{}, reqCtx, loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTTS forwards the request context to the SmartRouter", func() {
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTTSStream forwards the request context to the SmartRouter", func() {
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTranscriptionWithOptions forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.ModelTranscriptionWithOptions(reqCtx, backend.TranscriptionRequest{Audio: "x.wav"}, loader, modelCfg, appCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTranscriptionStream forwards the request context to the SmartRouter", func() {
|
||||
err := backend.ModelTranscriptionStream(reqCtx, backend.TranscriptionRequest{Audio: "x.wav"}, loader, modelCfg, appCfg, func(backend.TranscriptionStreamChunk) {})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ImageGeneration forwards the request context to the SmartRouter", func() {
|
||||
_, err := backend.ImageGeneration(reqCtx, 64, 64, 1, 0, "p", "", "", "/tmp/out.png", loader, modelCfg, appCfg, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("does NOT leak the holder when the app context is used instead", func() {
|
||||
// Sanity: the bug being fixed manifests as the router getting
|
||||
// appCfg.Context (no holder) instead of reqCtx (holder). A direct
|
||||
// call with context.Background() must not see the holder via the
|
||||
// app context surface.
|
||||
appCtxOnly := appCfg.Context
|
||||
Expect(distributedhdr.Holder(appCtxOnly)).To(BeNil(),
|
||||
"the app context must not be the carrier of per-request holders")
|
||||
})
|
||||
})
|
||||
@@ -1,6 +1,7 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -11,9 +12,38 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
|
||||
// Embedder produces a fixed-dimension vector from a prompt. The
|
||||
// router's L2 embedding cache uses it to look up semantically-similar
|
||||
// past decisions.
|
||||
type Embedder interface {
|
||||
Embed(ctx context.Context, text string) ([]float32, error)
|
||||
}
|
||||
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// NewEmbedder binds (loader, modelConfig, appConfig) into an Embedder.
|
||||
func NewEmbedder(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Embedder {
|
||||
return &modelEmbedder{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
|
||||
}
|
||||
|
||||
type modelEmbedder struct {
|
||||
loader *model.ModelLoader
|
||||
modelConfig config.ModelConfig
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func (e *modelEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
|
||||
fn, err := ModelEmbedding(ctx, text, nil, e.loader, e.modelConfig, e.appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fn()
|
||||
}
|
||||
|
||||
func ModelEmbedding(ctx context.Context, s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
|
||||
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
@@ -10,9 +11,12 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ImageGeneration(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
func ImageGeneration(ctx context.Context, height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
inferenceModel, err := loader.Load(
|
||||
opts...,
|
||||
)
|
||||
@@ -23,7 +27,7 @@ func ImageGeneration(height, width, step, seed int, positive_prompt, negative_pr
|
||||
|
||||
fn := func() error {
|
||||
_, err := inferenceModel.GenerateImage(
|
||||
appConfig.Context,
|
||||
ctx,
|
||||
&proto.GenerateImageRequest{
|
||||
Height: int32(height),
|
||||
Width: int32(width),
|
||||
|
||||
@@ -94,7 +94,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
}
|
||||
|
||||
opts := ModelOptions(*c, o)
|
||||
opts := ModelOptions(*c, o, model.WithContext(ctx))
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(o, c.Name, c.Backend, err, map[string]any{"model_file": modelFile})
|
||||
|
||||
@@ -242,6 +242,18 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
|
||||
Tokenizer: c.Tokenizer,
|
||||
}
|
||||
|
||||
if c.Backend == "cloud-proxy" {
|
||||
opts.Proxy = &pb.ProxyOptions{
|
||||
UpstreamUrl: c.Proxy.UpstreamURL,
|
||||
Mode: c.Proxy.Mode,
|
||||
Provider: c.Proxy.Provider,
|
||||
ApiKeyEnv: c.Proxy.APIKeyEnv,
|
||||
ApiKeyFile: c.Proxy.APIKeyFile,
|
||||
UpstreamModel: c.Proxy.UpstreamModel,
|
||||
RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds),
|
||||
}
|
||||
}
|
||||
|
||||
if c.MMProj != "" {
|
||||
opts.MMProj = filepath.Join(modelPath, c.MMProj)
|
||||
}
|
||||
|
||||
@@ -11,8 +11,56 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// RerankResult is the per-document score returned to consumers,
|
||||
// narrowed from proto.RerankResult so callers don't need to depend on
|
||||
// the proto package.
|
||||
type RerankResult struct {
|
||||
Index int
|
||||
RelevanceScore float32
|
||||
}
|
||||
|
||||
// Reranker scores a list of candidate documents against a query.
|
||||
// Returns one RerankResult per input document (no top-N truncation -
|
||||
// callers that need it can sort and slice).
|
||||
type Reranker interface {
|
||||
Rerank(ctx context.Context, query string, documents []string) ([]RerankResult, error)
|
||||
}
|
||||
|
||||
// NewReranker binds (loader, modelConfig, appConfig) into a Reranker.
|
||||
func NewReranker(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Reranker {
|
||||
return &modelReranker{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
|
||||
}
|
||||
|
||||
type modelReranker struct {
|
||||
loader *model.ModelLoader
|
||||
modelConfig config.ModelConfig
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func (r *modelReranker) Rerank(ctx context.Context, query string, documents []string) ([]RerankResult, error) {
|
||||
req := &proto.RerankRequest{
|
||||
Query: query,
|
||||
Documents: documents,
|
||||
// TopN=0: backend returns scores for every document. Truncating
|
||||
// here would silently zero out labels the reranker considered
|
||||
// unlikely, which the router classifier needs.
|
||||
}
|
||||
res, err := Rerank(ctx, req, r.loader, r.appConfig, r.modelConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]RerankResult, 0, len(res.GetResults()))
|
||||
for _, dr := range res.GetResults() {
|
||||
out = append(out, RerankResult{Index: int(dr.GetIndex()), RelevanceScore: dr.GetRelevanceScore()})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
rerankModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
|
||||
159
core/backend/score.go
Normal file
159
core/backend/score.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// ScoreOptions controls a single Score request.
|
||||
type ScoreOptions struct {
|
||||
// IncludeTokenLogprobs returns per-token log-probability detail for
|
||||
// each candidate. Off by default — the joint LogProb is enough for
|
||||
// ranking; callers that need calibration / entropy over the token
|
||||
// stream opt in.
|
||||
IncludeTokenLogprobs bool
|
||||
// LengthNormalize divides the joint log-prob by the candidate's
|
||||
// token count. Useful when comparing candidates of different
|
||||
// lengths — without it, longer candidates score lower by default.
|
||||
LengthNormalize bool
|
||||
}
|
||||
|
||||
// CandidateScore is the per-candidate result. Mirrors pb.CandidateScore
|
||||
// but avoids leaking the proto type to consumers.
|
||||
type CandidateScore struct {
|
||||
LogProb float64
|
||||
LengthNormalizedLogProb float64
|
||||
NumTokens int
|
||||
Tokens []TokenLogProb
|
||||
}
|
||||
|
||||
type TokenLogProb struct {
|
||||
Token string
|
||||
LogProb float64
|
||||
}
|
||||
|
||||
// Scorer evaluates a model's joint log-probability of each candidate
|
||||
// continuation given a shared prompt. Implemented by NewScorer over a
|
||||
// model-loaded backend; the router's score classifier consumes this
|
||||
// for multi-label policy selection.
|
||||
type Scorer interface {
|
||||
Score(ctx context.Context, prompt string, candidates []string) ([]CandidateScore, error)
|
||||
}
|
||||
|
||||
// NewScorer binds (loader, modelConfig, appConfig) into a Scorer. The
|
||||
// underlying backend is resolved lazily on the first Score call.
|
||||
// Returns nil only as a contract violation — callers that need to
|
||||
// detect "model not loadable" should look up the config first.
|
||||
func NewScorer(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Scorer {
|
||||
return &modelScorer{loader: loader, modelConfig: modelConfig, appConfig: appConfig}
|
||||
}
|
||||
|
||||
type modelScorer struct {
|
||||
loader *model.ModelLoader
|
||||
modelConfig config.ModelConfig
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func (m *modelScorer) Score(ctx context.Context, prompt string, candidates []string) ([]CandidateScore, error) {
|
||||
fn, err := ModelScore(prompt, candidates, ScoreOptions{LengthNormalize: true}, m.loader, m.modelConfig, m.appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
// ModelScore loads the backend for modelConfig and returns a closure
|
||||
// that scores `candidates` against `prompt`. The closure is bound to
|
||||
// the loaded model so callers can keep it around for repeat scoring
|
||||
// within the same request without re-resolving the backend.
|
||||
func ModelScore(prompt string, candidates []string, opts ScoreOptions, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func(ctx context.Context) ([]CandidateScore, error), error) {
|
||||
modelOpts := ModelOptions(modelConfig, appConfig)
|
||||
inferenceModel, err := loader.Load(modelOpts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
b, ok := inferenceModel.(grpc.Backend)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("scoring not supported by backend %q", modelConfig.Backend)
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("Score: candidates must be non-empty")
|
||||
}
|
||||
return func(ctx context.Context) ([]CandidateScore, error) {
|
||||
// Surface score calls in the Traces UI alongside the LLM calls
|
||||
// they typically gate (router classifier, eval scoring). Without
|
||||
// this, a router-classified request shows only the downstream LLM
|
||||
// trace with no record of the classification that picked it.
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
resp, err := b.Score(ctx, &pb.ScoreRequest{
|
||||
Prompt: prompt,
|
||||
Candidates: candidates,
|
||||
IncludeTokenLogprobs: opts.IncludeTokenLogprobs,
|
||||
LengthNormalize: opts.LengthNormalize,
|
||||
})
|
||||
results := scoreResponseToCandidates(resp, opts.IncludeTokenLogprobs)
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceScore,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(prompt, 200),
|
||||
Error: errStr,
|
||||
Data: map[string]any{
|
||||
// Copy candidates so the trace buffer doesn't pin a
|
||||
// caller-owned slice for the lifetime of the ring.
|
||||
"candidates": append([]string(nil), candidates...),
|
||||
"results": results,
|
||||
},
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// scoreResponseToCandidates converts the wire-format pb response into
|
||||
// the value type consumed by callers. Extracted to keep ModelScore's
|
||||
// closure trivial and so the conversion can be unit-tested without a
|
||||
// real backend.
|
||||
func scoreResponseToCandidates(resp *pb.ScoreResponse, includeTokens bool) []CandidateScore {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]CandidateScore, len(resp.Candidates))
|
||||
for i, c := range resp.Candidates {
|
||||
cs := CandidateScore{
|
||||
LogProb: c.LogProb,
|
||||
LengthNormalizedLogProb: c.LengthNormalizedLogProb,
|
||||
NumTokens: int(c.NumTokens),
|
||||
}
|
||||
if includeTokens && len(c.Tokens) > 0 {
|
||||
cs.Tokens = make([]TokenLogProb, len(c.Tokens))
|
||||
for j, t := range c.Tokens {
|
||||
cs.Tokens[j] = TokenLogProb{Token: t.Token, LogProb: t.LogProb}
|
||||
}
|
||||
}
|
||||
out[i] = cs
|
||||
}
|
||||
return out
|
||||
}
|
||||
63
core/backend/score_test.go
Normal file
63
core/backend/score_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("scoreResponseToCandidates", func() {
|
||||
It("returns nil for a nil response", func() {
|
||||
Expect(scoreResponseToCandidates(nil, false)).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns an empty slice when the response has no candidates", func() {
|
||||
Expect(scoreResponseToCandidates(&pb.ScoreResponse{}, false)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("copies LogProb / LengthNormalizedLogProb / NumTokens for every candidate", func() {
|
||||
resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{
|
||||
{LogProb: -2.0, LengthNormalizedLogProb: -1.0, NumTokens: 2},
|
||||
{LogProb: -7.5, LengthNormalizedLogProb: -1.5, NumTokens: 5},
|
||||
}}
|
||||
got := scoreResponseToCandidates(resp, false)
|
||||
Expect(got).To(HaveLen(2))
|
||||
Expect(got[0].LogProb).To(Equal(-2.0))
|
||||
Expect(got[0].LengthNormalizedLogProb).To(Equal(-1.0))
|
||||
Expect(got[0].NumTokens).To(Equal(2))
|
||||
Expect(got[1].LogProb).To(Equal(-7.5))
|
||||
Expect(got[1].NumTokens).To(Equal(5))
|
||||
})
|
||||
|
||||
It("omits per-token detail when includeTokens=false even if the wire response carries it", func() {
|
||||
// Defensive: if the backend over-reports we still respect the
|
||||
// caller's opt-in so consumers don't pay marshaling for data
|
||||
// they didn't ask for.
|
||||
resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{{
|
||||
LogProb: -1.0,
|
||||
Tokens: []*pb.TokenLogProb{{Token: "hi", LogProb: -1.0}},
|
||||
}}}
|
||||
got := scoreResponseToCandidates(resp, false)
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Tokens).To(BeNil())
|
||||
})
|
||||
|
||||
It("populates per-token detail when includeTokens=true", func() {
|
||||
resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{{
|
||||
LogProb: -3.0,
|
||||
NumTokens: 2,
|
||||
Tokens: []*pb.TokenLogProb{
|
||||
{Token: "Hello", LogProb: -1.0},
|
||||
{Token: " world", LogProb: -2.0},
|
||||
},
|
||||
}}}
|
||||
got := scoreResponseToCandidates(resp, true)
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Tokens).To(HaveLen(2))
|
||||
Expect(got[0].Tokens[0].Token).To(Equal("Hello"))
|
||||
Expect(got[0].Tokens[0].LogProb).To(Equal(-1.0))
|
||||
Expect(got[0].Tokens[1].Token).To(Equal(" world"))
|
||||
Expect(got[0].Tokens[1].LogProb).To(Equal(-2.0))
|
||||
})
|
||||
})
|
||||
@@ -1,12 +1,74 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/store"
|
||||
)
|
||||
|
||||
// VectorStore is the narrowed KNN store used by the router's embedding
|
||||
// cache. Search returns the top-1 match (cosine similarity in [-1, 1])
|
||||
// and the serialised payload, or ok=false on a clean miss.
|
||||
type VectorStore interface {
|
||||
Search(ctx context.Context, vec []float32) (similarity float64, payload []byte, ok bool, err error)
|
||||
Insert(ctx context.Context, vec []float32, payload []byte) error
|
||||
}
|
||||
|
||||
// NewVectorStore returns a VectorStore backed by the local-store
|
||||
// gRPC backend, namespaced by storeName so two routers don't collide.
|
||||
func NewVectorStore(loader *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) VectorStore {
|
||||
if storeName == "" {
|
||||
return nil
|
||||
}
|
||||
return &localVectorStore{loader: loader, appConfig: appConfig, storeName: storeName}
|
||||
}
|
||||
|
||||
type localVectorStore struct {
|
||||
loader *model.ModelLoader
|
||||
appConfig *config.ApplicationConfig
|
||||
storeName string
|
||||
}
|
||||
|
||||
func (s *localVectorStore) backend(_ context.Context) (grpc.Backend, error) {
|
||||
return StoreBackend(s.loader, s.appConfig, s.storeName, "")
|
||||
}
|
||||
|
||||
func (s *localVectorStore) Search(ctx context.Context, vec []float32) (float64, []byte, bool, error) {
|
||||
be, err := s.backend(ctx)
|
||||
if err != nil {
|
||||
return 0, nil, false, fmt.Errorf("vector store load: %w", err)
|
||||
}
|
||||
_, values, similarities, err := store.Find(ctx, be, vec, 1)
|
||||
if err != nil {
|
||||
// local-store's Find returns "existing length is -1" before
|
||||
// any keys are inserted. Surface that as a clean miss so the
|
||||
// cache layer treats it as an empty store and proceeds to
|
||||
// Insert rather than skipping.
|
||||
if strings.Contains(err.Error(), "existing length is -1") {
|
||||
return 0, nil, false, nil
|
||||
}
|
||||
return 0, nil, false, fmt.Errorf("vector store find: %w", err)
|
||||
}
|
||||
if len(values) == 0 || len(similarities) == 0 {
|
||||
return 0, nil, false, nil
|
||||
}
|
||||
return float64(similarities[0]), values[0], true, nil
|
||||
}
|
||||
|
||||
func (s *localVectorStore) Insert(ctx context.Context, vec []float32, payload []byte) error {
|
||||
be, err := s.backend(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("vector store load: %w", err)
|
||||
}
|
||||
return store.SetSingle(ctx, be, vec, payload)
|
||||
}
|
||||
|
||||
func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string, backend string) (grpc.Backend, error) {
|
||||
if backend == "" {
|
||||
backend = model.LocalStoreBackend
|
||||
|
||||
@@ -41,11 +41,14 @@ func (r *TranscriptionRequest) toProto(threads uint32) *proto.TranscriptRequest
|
||||
}
|
||||
}
|
||||
|
||||
func loadTranscriptionModel(ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (grpcPkg.Backend, error) {
|
||||
func loadTranscriptionModel(ctx context.Context, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (grpcPkg.Backend, error) {
|
||||
if modelConfig.Backend == "" {
|
||||
modelConfig.Backend = model.WhisperBackend
|
||||
}
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
transcriptionModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
@@ -68,7 +71,7 @@ func ModelTranscription(ctx context.Context, audio, language string, translate,
|
||||
}
|
||||
|
||||
func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
||||
transcriptionModel, err := loadTranscriptionModel(ctx, ml, modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -150,7 +153,7 @@ type TranscriptionStreamChunk struct {
|
||||
// support real streaming should still emit one terminal event with Final set,
|
||||
// which the HTTP layer turns into a single delta + done SSE pair.
|
||||
func ModelTranscriptionStream(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
|
||||
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
||||
transcriptionModel, err := loadTranscriptionModel(ctx, ml, modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -29,7 +29,10 @@ func ModelTTS(
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (string, *proto.Result, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
ttsModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
@@ -131,7 +134,7 @@ func ModelTTSStream(
|
||||
modelConfig config.ModelConfig,
|
||||
audioCallback func([]byte) error,
|
||||
) error {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
ttsModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
|
||||
@@ -14,7 +14,10 @@ func VAD(request *schema.VADRequest,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig) (*schema.VADResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
// model.WithContext(ctx) overrides the app-context default set in
|
||||
// ModelOptions so distributed routing decisions reach the request's
|
||||
// X-LocalAI-Node holder via distributedhdr.Stamp.
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
|
||||
vadModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
|
||||
@@ -157,8 +157,13 @@ type RunCMD struct {
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
|
||||
Version bool
|
||||
|
||||
// Cloud-proxy MITM listener (off by default).
|
||||
MITMListen string `env:"LOCALAI_MITM_LISTEN" help:"Address (host:port) for the cloudproxy MITM listener. Empty = disabled. Clients set HTTPS_PROXY=http://<this>:<port>. Intercept hosts are declared per-model via the model YAML mitm.hosts: block; create one from the Add Model UI." group:"middleware"`
|
||||
MITMCADir string `env:"LOCALAI_MITM_CA_DIR" type:"path" help:"Directory holding the MITM proxy CA cert + key. Defaults to <data-path>/mitm-ca." group:"middleware"`
|
||||
}
|
||||
|
||||
func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
@@ -217,6 +222,8 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
config.WithLoadToMemory(r.LoadToMemory),
|
||||
config.WithMachineTag(r.MachineTag),
|
||||
config.WithAPIAddress(r.Address),
|
||||
config.WithMITMListen(r.MITMListen),
|
||||
config.WithMITMCADir(r.MITMCADir),
|
||||
config.WithAgentJobRetentionDays(r.AgentJobRetentionDays),
|
||||
config.WithLlamaCPPTunnelCallback(func(tunnels []string) {
|
||||
tunnelEnvVar := strings.Join(tunnels, ",")
|
||||
@@ -277,6 +284,9 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.AutoApproveNodes {
|
||||
opts = append(opts, config.EnableAutoApproveNodes)
|
||||
}
|
||||
if r.ExposeNodeHeader {
|
||||
opts = append(opts, config.WithExposeNodeHeader(true))
|
||||
}
|
||||
|
||||
if r.DisableMetricsEndpoint {
|
||||
opts = append(opts, config.DisableMetricsEndpoint)
|
||||
|
||||
@@ -40,6 +40,54 @@ type ApplicationConfig struct {
|
||||
P2PNetworkID string
|
||||
Federated bool
|
||||
|
||||
// DisableStats turns off per-request token tracking. By default the
|
||||
// routing module's billing recorder runs in every mode (including
|
||||
// no-auth single-user) so dashboards and `/api/usage` are immediately
|
||||
// useful; set this to opt out of that, e.g., for ephemeral CI runs
|
||||
// or privacy-strict deployments where no token-count history should
|
||||
// touch disk or memory.
|
||||
DisableStats bool
|
||||
|
||||
// PIIConfigPath points to an optional YAML file describing the PII
|
||||
// pattern set. When empty, the routing/pii module's DefaultPatterns()
|
||||
// (email, phone, SSN, credit card, IPv4, API key prefixes) are
|
||||
// loaded with their default actions. Each entry overrides the
|
||||
// matching default by ID:
|
||||
//
|
||||
// patterns:
|
||||
// - id: email
|
||||
// action: route_local # downgrade default mask -> route_local
|
||||
// - id: ssn
|
||||
// action: block # upgrade default mask -> block
|
||||
//
|
||||
// Unknown ids are rejected with a clear error at startup.
|
||||
PIIConfigPath string
|
||||
|
||||
// DisablePII turns the regex PII filter off entirely. Default
|
||||
// (false) enables it on the OpenAI chat completions route.
|
||||
DisablePII bool
|
||||
|
||||
// MITMListen is the address (host:port) the cloudproxy MITM
|
||||
// listener binds on. Empty disables the MITM proxy entirely.
|
||||
// Use case: redacting PII from Claude Code / Codex CLI traffic
|
||||
// without LocalAI holding the upstream API key. Clients set
|
||||
// HTTPS_PROXY=http://localai:port and trust the CA cert
|
||||
// LocalAI exposes at /api/middleware/proxy-ca.crt.
|
||||
MITMListen string
|
||||
|
||||
// MITMCADir holds the persisted MITM proxy CA cert and private
|
||||
// key. The CA is generated on first start; subsequent starts
|
||||
// reload it so clients keep trusting the same root. The key
|
||||
// file is mode 0600.
|
||||
MITMCADir string
|
||||
|
||||
|
||||
// PIIPatternOverrides applies persisted per-id deltas (action,
|
||||
// disabled) to the live redactor at startup. Loaded from
|
||||
// runtime_settings.json and applied right after pii.NewRedactor.
|
||||
// nil/empty leaves the YAML defaults in place.
|
||||
PIIPatternOverrides map[string]PIIPatternRuntimeOverride
|
||||
|
||||
DisableWebUI bool
|
||||
OllamaAPIRootEndpoint bool
|
||||
EnforcePredownloadScans bool
|
||||
@@ -112,6 +160,18 @@ type ApplicationConfig struct {
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed DistributedConfig
|
||||
|
||||
// ExposeNodeHeader, when true, activates middleware.ExposeNodeHeader on
|
||||
// the inference routes (OpenAI chat/completions/embeddings, Anthropic
|
||||
// /v1/messages, Ollama /api/chat,/api/generate,/api/embed). The
|
||||
// middleware wraps the response writer and attaches an "X-LocalAI-Node"
|
||||
// response header carrying the ID of the distributed-mode worker node
|
||||
// that served the request. Off by default because the node ID is
|
||||
// internal topology that can aid attacker reconnaissance if surfaced on
|
||||
// a public endpoint; operators opt in explicitly via
|
||||
// --expose-node-header / LOCALAI_EXPOSE_NODE_HEADER for debugging,
|
||||
// observability and load-balancer attribution.
|
||||
ExposeNodeHeader bool
|
||||
|
||||
// LocalAI Assistant chat modality. Hard-disable the in-process admin MCP
|
||||
// server with this flag; runtime-toggleable via /api/settings.
|
||||
DisableLocalAIAssistant bool
|
||||
@@ -604,6 +664,45 @@ func WithDataPath(dataPath string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithDisableStats turns off the billing recorder. CLI: --disable-stats.
|
||||
func WithDisableStats(disable bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.DisableStats = disable
|
||||
}
|
||||
}
|
||||
|
||||
// WithPIIConfigPath points the routing PII filter at a YAML config
|
||||
// file. CLI: --pii-config.
|
||||
func WithPIIConfigPath(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.PIIConfigPath = path
|
||||
}
|
||||
}
|
||||
|
||||
// WithDisablePII turns the regex PII filter off. CLI: --disable-pii.
|
||||
func WithDisablePII(disable bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.DisablePII = disable
|
||||
}
|
||||
}
|
||||
|
||||
// WithMITMListen sets the address the cloudproxy MITM listener
|
||||
// binds on. Empty = disabled. CLI: --mitm-listen.
|
||||
func WithMITMListen(addr string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.MITMListen = addr
|
||||
}
|
||||
}
|
||||
|
||||
// WithMITMCADir sets the directory used to persist the MITM proxy
|
||||
// CA cert + key. CLI: --mitm-ca-dir.
|
||||
func WithMITMCADir(dir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.MITMCADir = dir
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func WithDynamicConfigDir(dynamicConfigsDir string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.DynamicConfigsDir = dynamicConfigsDir
|
||||
@@ -893,6 +992,15 @@ func WithDisableLocalAIAssistant(disabled bool) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithExposeNodeHeader enables the X-LocalAI-Node response header on
|
||||
// inference endpoints. Default off; the node ID reveals internal cluster
|
||||
// topology and is opt-in for that reason.
|
||||
func WithExposeNodeHeader(enabled bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.ExposeNodeHeader = enabled
|
||||
}
|
||||
}
|
||||
|
||||
// ToConfigLoaderOptions returns a slice of ConfigLoader Option.
|
||||
// Some options defined at the application level are going to be passed as defaults for
|
||||
// all the configuration for the models.
|
||||
@@ -998,6 +1106,8 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
logoHorizontalFile := o.Branding.LogoHorizontalFile
|
||||
faviconFile := o.Branding.FaviconFile
|
||||
|
||||
mitmListen := o.MITMListen
|
||||
|
||||
return RuntimeSettings{
|
||||
WatchdogEnabled: &watchdogEnabled,
|
||||
WatchdogIdleEnabled: &watchdogIdle,
|
||||
@@ -1051,6 +1161,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
LogoFile: &logoFile,
|
||||
LogoHorizontalFile: &logoHorizontalFile,
|
||||
FaviconFile: &faviconFile,
|
||||
MITMListen: &mitmListen,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1276,6 +1387,10 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
||||
o.Branding.FaviconFile = *settings.FaviconFile
|
||||
}
|
||||
|
||||
if settings.MITMListen != nil {
|
||||
o.MITMListen = *settings.MITMListen
|
||||
}
|
||||
|
||||
// Note: ApiKeys requires special handling (merging with startup keys) - handled in caller
|
||||
|
||||
return requireRestart
|
||||
|
||||
@@ -49,20 +49,31 @@ var DiffusersPipelineOptions = []FieldOption{
|
||||
{Value: "StableVideoDiffusionPipeline", Label: "StableVideoDiffusionPipeline"},
|
||||
}
|
||||
|
||||
// UsecaseOptions must stay in sync with GetAllModelConfigUsecases in
|
||||
// core/config/model_config.go — a value missing here is silently
|
||||
// inaccessible from the model editor, which is how `score` (the router
|
||||
// classifier usecase) hid for an entire release.
|
||||
var UsecaseOptions = []FieldOption{
|
||||
{Value: "chat", Label: "Chat"},
|
||||
{Value: "completion", Label: "Completion"},
|
||||
{Value: "edit", Label: "Edit"},
|
||||
{Value: "embeddings", Label: "Embeddings"},
|
||||
{Value: "rerank", Label: "Rerank"},
|
||||
{Value: "score", Label: "Score (Router Classifier)"},
|
||||
{Value: "image", Label: "Image"},
|
||||
{Value: "vision", Label: "Vision"},
|
||||
{Value: "detection", Label: "Detection"},
|
||||
{Value: "face_recognition", Label: "Face Recognition"},
|
||||
{Value: "transcript", Label: "Transcript"},
|
||||
{Value: "diarization", Label: "Diarization"},
|
||||
{Value: "speaker_recognition", Label: "Speaker Recognition"},
|
||||
{Value: "tts", Label: "TTS"},
|
||||
{Value: "sound_generation", Label: "Sound Generation"},
|
||||
{Value: "audio_transform", Label: "Audio Transform"},
|
||||
{Value: "realtime_audio", Label: "Realtime Audio"},
|
||||
{Value: "tokenize", Label: "Tokenize"},
|
||||
{Value: "vad", Label: "VAD"},
|
||||
{Value: "video", Label: "Video"},
|
||||
{Value: "detection", Label: "Detection"},
|
||||
}
|
||||
|
||||
var DiffusersSchedulerOptions = []FieldOption{
|
||||
|
||||
@@ -232,6 +232,17 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Description: "Use the chat template from the model's tokenizer config",
|
||||
Order: 43,
|
||||
},
|
||||
// Router section template — kept in the templates UI section
|
||||
// (rather than the router section under "other") so operators
|
||||
// editing prompt shapes find all template-typed fields in one
|
||||
// place, mirroring how chat / chat_message are grouped.
|
||||
"router.classifier_system_template": {
|
||||
Section: "templates",
|
||||
Label: "Router Classifier System Prompt",
|
||||
Description: "Go text/template (with sprig functions) for the routing system prompt the score classifier feeds to its classifier_model. Executed with `.Policies` ([]{Label, Description}). Empty falls back to the built-in Arch-Router-shaped prompt (route-listing block + JSON output schema). Override when the classifier model was trained on a different schema or you need the routing instructions in a different language. The candidate format scored against the model is fixed at `{\"route\": \"<label>\"}` — keep your override's output schema instruction matching that.",
|
||||
Component: "code-editor",
|
||||
Order: 44,
|
||||
},
|
||||
|
||||
// --- Pipeline ---
|
||||
"pipeline.llm": {
|
||||
@@ -320,5 +331,207 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Description: "Enable CUDA for diffusers",
|
||||
Order: 82,
|
||||
},
|
||||
|
||||
// --- PII filtering (per-model) ---
|
||||
"pii.enabled": {
|
||||
Section: "other",
|
||||
Label: "PII Filtering Enabled",
|
||||
Description: "Enable PII redaction middleware for this model. Unset means use the default (off for local backends, on for proxy-* / cloud-hosted backends).",
|
||||
Component: "toggle",
|
||||
Order: 200,
|
||||
},
|
||||
"pii.patterns": {
|
||||
Section: "other",
|
||||
Label: "PII Pattern Overrides",
|
||||
Description: "Override the global default action for specific patterns on this model. Patterns not listed here inherit the global action (Settings → Middleware → Filtering).",
|
||||
Component: "pii-pattern-list",
|
||||
Order: 201,
|
||||
},
|
||||
|
||||
// --- Cloud passthrough proxy ---
|
||||
// These only have an effect when Backend is set to
|
||||
// "cloud-proxy". When the upstream URL is empty, the model
|
||||
// fails closed — the chat handler does NOT silently fall back
|
||||
// to the local gRPC pipeline.
|
||||
"proxy.mode": {
|
||||
Section: "other",
|
||||
Label: "Proxy Mode",
|
||||
Description: "passthrough forwards the client's OpenAI body verbatim — point upstream_url at an OpenAI-compatible endpoint (incl. Anthropic's /v1/chat/completions compat layer). translate converts OpenAI ↔ Anthropic Messages so you can target a native API (/v1/messages); tool_calls and usage tokens survive the round-trip.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "passthrough", Label: "passthrough (raw forward)"},
|
||||
{Value: "translate", Label: "translate (OpenAI ↔ native)"},
|
||||
},
|
||||
Default: "passthrough",
|
||||
Order: 208,
|
||||
},
|
||||
"proxy.provider": {
|
||||
Section: "other",
|
||||
Label: "Proxy Provider",
|
||||
Description: "Upstream API family. Drives auth header shape (Bearer vs x-api-key + anthropic-version) and, in translate mode, which request/response codec is used.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "openai", Label: "OpenAI"},
|
||||
{Value: "anthropic", Label: "Anthropic"},
|
||||
},
|
||||
Default: "openai",
|
||||
Order: 209,
|
||||
},
|
||||
"proxy.upstream_url": {
|
||||
Section: "other",
|
||||
Label: "Proxy Upstream URL",
|
||||
Description: "Full POST endpoint of the upstream provider (e.g. https://api.openai.com/v1/chat/completions). Only used when Backend is cloud-proxy.",
|
||||
Component: "input",
|
||||
Order: 210,
|
||||
},
|
||||
"proxy.api_key_env": {
|
||||
Section: "other",
|
||||
Label: "Proxy API Key Env Var",
|
||||
Description: "Name of the environment variable holding the upstream API key. Reading from env keeps the secret out of the YAML and the admin UI.",
|
||||
Component: "input",
|
||||
Order: 211,
|
||||
},
|
||||
"proxy.upstream_model": {
|
||||
Section: "other",
|
||||
Label: "Proxy Upstream Model",
|
||||
Description: "Model name sent to the upstream. Leave empty to forward the client's model field unchanged. Useful when the LocalAI alias differs from the upstream's canonical name.",
|
||||
Component: "input",
|
||||
Order: 212,
|
||||
},
|
||||
"proxy.request_timeout_seconds": {
|
||||
Section: "other",
|
||||
Label: "Proxy Request Timeout (seconds)",
|
||||
Description: "Caps the upstream HTTP request duration. 0 disables the deadline; the request still ends when the client disconnects.",
|
||||
Component: "number",
|
||||
Min: f64(0),
|
||||
Order: 213,
|
||||
},
|
||||
|
||||
// --- MITM intercept hosts ---
|
||||
// Each host listed here is claimed by this model config; the
|
||||
// cloudproxy MITM listener (see Middleware → MITM Proxy) uses
|
||||
// THIS config's pii: settings to filter the intercepted traffic.
|
||||
// A host claimed by two configs is a critical error — the
|
||||
// listener refuses to start until resolved.
|
||||
"mitm.hosts": {
|
||||
Section: "other",
|
||||
Label: "MITM Intercept Hosts",
|
||||
Description: "Hostnames the cloudproxy MITM proxy terminates TLS for on behalf of this model config. PII filtering and pattern overrides flow from this model when the host is intercepted. Each host must be unique across all configs.",
|
||||
Component: "string-list",
|
||||
Order: 220,
|
||||
},
|
||||
|
||||
// --- Router ---
|
||||
// Routing turns this model config into a dispatcher: the
|
||||
// classifier scores every policy label as a continuation of
|
||||
// the routing prompt and picks the first candidate whose
|
||||
// labels are a superset of the active set. The Routing tab of
|
||||
// the middleware admin page surfaces every model with a router
|
||||
// block.
|
||||
"router.classifier": {
|
||||
Section: "other",
|
||||
Label: "Classifier",
|
||||
Description: "Picks a candidate by scoring every policy label against the prompt. Only \"score\" is shipped today; it asks the classifier_model to rank each label and reads off the softmax. Empty defaults to \"score\".",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "score", Label: "Score (Arch-Router-style)"},
|
||||
},
|
||||
Order: 230,
|
||||
},
|
||||
"router.classifier_model": {
|
||||
Section: "other",
|
||||
Label: "Classifier Model",
|
||||
Description: "Loaded LocalAI model the score classifier asks to rank each policy label as a continuation. Must support the Score gRPC primitive (today: llama-cpp, vLLM) and use the ChatML template. Arch-Router-1.5B Q4_K_M is the canonical choice; any small ChatML instruct model also works at a higher activation_threshold.",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModelsChat,
|
||||
Order: 231,
|
||||
},
|
||||
"router.fallback": {
|
||||
Section: "other",
|
||||
Label: "Fallback Model",
|
||||
Description: "Model used when no candidate's labels cover the classifier's active label set, or when the classifier errors. Empty means router failures bubble up as HTTP 500 — fail-fast, not silent-bypass.",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModelsChat,
|
||||
Order: 232,
|
||||
},
|
||||
"router.activation_threshold": {
|
||||
Section: "other",
|
||||
Label: "Activation Threshold",
|
||||
Description: "Softmax-probability floor a policy must clear to join the active label set for a request. Higher → single-label dominant routes; lower → more multi-label activations. 0 picks the package default (0.15). On Arch-Router-1.5B a value around 0.40 keeps the dominant label clean without losing genuine compound activations.",
|
||||
Component: "slider",
|
||||
Min: f64(0),
|
||||
Max: f64(1),
|
||||
Step: f64(0.05),
|
||||
Order: 233,
|
||||
},
|
||||
"router.classifier_cache_size": {
|
||||
Section: "other",
|
||||
Label: "Classifier L1 Cache Size",
|
||||
Description: "Bounded LRU keyed on (case-folded, whitespace-trimmed) prompt — amortises the classifier round-trip across verbatim repeats common in agent loops. 0 here means \"use the default\" (1024); the cache cannot be disabled from YAML.",
|
||||
Component: "number",
|
||||
Min: f64(0),
|
||||
Order: 234,
|
||||
},
|
||||
"router.policies": {
|
||||
Section: "other",
|
||||
Label: "Policies",
|
||||
Description: "Label vocabulary the classifier scores over. Each policy has a label and a short natural-language description fed verbatim to the classifier model. Short action-oriented sentences work best (\"writing or debugging code\"; \"small talk\").",
|
||||
Component: "router-policies",
|
||||
Order: 235,
|
||||
},
|
||||
"router.candidates": {
|
||||
Section: "other",
|
||||
Label: "Candidates",
|
||||
Description: "Routing table: each entry binds a downstream model to a set of policy labels it can serve. Order matters — the middleware picks the FIRST candidate whose labels are a superset of the active set, so list candidates smallest → largest.",
|
||||
Component: "router-candidates",
|
||||
Order: 236,
|
||||
},
|
||||
"router.score_normalization": {
|
||||
Section: "other",
|
||||
Label: "Score Normalization",
|
||||
Description: "How the score classifier collapses per-candidate joint log-probs into the softmax input. \"raw\" (default) feeds joint log-prob as-is — on-distribution for Arch-Router (the route the model would actually emit if decoded freely). \"mean\" divides by candidate token count — fairer to long labels but off-distribution for models trained to emit fixed-format outputs.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "", Label: "Raw (default)"},
|
||||
{Value: "raw", Label: "Raw"},
|
||||
{Value: "mean", Label: "Mean (length-normalised)"},
|
||||
},
|
||||
Order: 240,
|
||||
},
|
||||
"router.embedding_cache.embedding_model": {
|
||||
Section: "other",
|
||||
Label: "L2 Cache: Embedding Model",
|
||||
Description: "Embedding model used by the L2 decision cache. Embeds incoming probes and looks them up in the per-router local-store collection. Empty disables the cache entirely. nomic-embed-text-v1.5 is the recommended default.",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModels,
|
||||
Order: 237,
|
||||
},
|
||||
"router.embedding_cache.similarity_threshold": {
|
||||
Section: "other",
|
||||
Label: "L2 Cache: Similarity Threshold",
|
||||
Description: "Cosine-similarity floor a cache candidate must clear to count as a hit. 0 picks the package default (0.80). Re-tune per embedding model — the histogram on the Routing tab shows where the cosine distribution actually sits.",
|
||||
Component: "slider",
|
||||
Min: f64(0),
|
||||
Max: f64(1),
|
||||
Step: f64(0.01),
|
||||
Order: 238,
|
||||
},
|
||||
"router.embedding_cache.confidence_threshold": {
|
||||
Section: "other",
|
||||
Label: "L2 Cache: Confidence Threshold",
|
||||
Description: "Minimum top-label probability a classifier decision must have to be inserted into the cache. 0 picks the package default (0.60). Uncertain decisions are skipped so they can't poison future paraphrases.",
|
||||
Component: "slider",
|
||||
Min: f64(0),
|
||||
Max: f64(1),
|
||||
Step: f64(0.05),
|
||||
Order: 239,
|
||||
},
|
||||
"router.embedding_cache.store_name": {
|
||||
Section: "other",
|
||||
Label: "L2 Cache: Store Name",
|
||||
Description: "Optional override for the local-store collection used by this router's cache. Empty defaults to \"router-cache-<router-model-name>\". Two routers sharing a store_name share their cache (rare).",
|
||||
Component: "input",
|
||||
Order: 240,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
258
core/config/meta/registry_coverage_test.go
Normal file
258
core/config/meta/registry_coverage_test.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package meta_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/config/meta"
|
||||
)
|
||||
|
||||
// TestAllFieldsHaveRegistryEntries fails when a NEW ModelConfig field
|
||||
// is added without either a registry entry in DefaultRegistry() or an
|
||||
// entry in the grandfatheredUnregistered baseline below.
|
||||
//
|
||||
// Why this matters: fields without a registry entry render in the UI
|
||||
// with no description, the default `input` (single-line) component,
|
||||
// and land in the catch-all "other" section — which is what we just
|
||||
// hit for router.classifier_system_template. The reflection-based
|
||||
// fallback produces something that works mechanically but is hostile
|
||||
// to operators.
|
||||
//
|
||||
// How to fix when this test fails:
|
||||
//
|
||||
// 1. Preferred — add a registry entry to DefaultRegistry() in
|
||||
// registry.go with Section, Label, Description, and Component.
|
||||
// See e.g. "router.classifier" or "template.chat" for the pattern.
|
||||
//
|
||||
// 2. Escape hatch — append the field path to
|
||||
// grandfatheredUnregistered with a one-line comment justifying
|
||||
// why it has no UI surface (internal, deprecated, legacy
|
||||
// compatibility shim, etc.). The expectation is that this list
|
||||
// shrinks over time as fields get proper registry entries; it
|
||||
// should never grow without good reason.
|
||||
//
|
||||
// The grandfathered list was seeded from a one-time audit. Migrating
|
||||
// the existing 150+ entries to proper registry metadata is out of
|
||||
// scope for any single PR; the lock just stops the list from growing.
|
||||
func TestAllFieldsHaveRegistryEntries(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
md := meta.BuildConfigMetadata(reflect.TypeOf(config.ModelConfig{}))
|
||||
reg := meta.DefaultRegistry()
|
||||
|
||||
grand := make(map[string]struct{}, len(grandfatheredUnregistered))
|
||||
for _, p := range grandfatheredUnregistered {
|
||||
grand[p] = struct{}{}
|
||||
}
|
||||
|
||||
var missing []string
|
||||
for _, f := range md.Fields {
|
||||
if _, ok := reg[f.Path]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := grand[f.Path]; ok {
|
||||
continue
|
||||
}
|
||||
missing = append(missing, f.Path)
|
||||
}
|
||||
|
||||
sort.Strings(missing)
|
||||
g.Expect(missing).To(BeEmpty(),
|
||||
"%d config field(s) have no registry entry and are not on the grandfathered list.\n"+
|
||||
"Add a registry entry to core/config/meta/registry.go OR append to grandfatheredUnregistered in this file with a justification:\n %s",
|
||||
len(missing), strings.Join(missing, "\n "),
|
||||
)
|
||||
|
||||
// Inverse drift check: catch dead entries on the grandfathered
|
||||
// list (field was renamed/removed, or someone wrote a registry
|
||||
// entry without trimming the grandfathered duplicate).
|
||||
known := make(map[string]struct{}, len(md.Fields))
|
||||
for _, f := range md.Fields {
|
||||
known[f.Path] = struct{}{}
|
||||
}
|
||||
var stale, duplicated []string
|
||||
for _, p := range grandfatheredUnregistered {
|
||||
if _, ok := known[p]; !ok {
|
||||
stale = append(stale, p)
|
||||
}
|
||||
if _, ok := reg[p]; ok {
|
||||
duplicated = append(duplicated, p)
|
||||
}
|
||||
}
|
||||
sort.Strings(stale)
|
||||
g.Expect(stale).To(BeEmpty(),
|
||||
"grandfatheredUnregistered references fields that no longer exist in ModelConfig — remove them:\n %s",
|
||||
strings.Join(stale, "\n "))
|
||||
sort.Strings(duplicated)
|
||||
g.Expect(duplicated).To(BeEmpty(),
|
||||
"grandfatheredUnregistered references fields that now HAVE a registry entry — remove them so the test stays meaningful:\n %s",
|
||||
strings.Join(duplicated, "\n "))
|
||||
}
|
||||
|
||||
// grandfatheredUnregistered is the baseline of config fields that
|
||||
// pre-date the registry-coverage test and have no UI metadata yet.
|
||||
// Adding new entries here should be a deliberate, justified decision
|
||||
// — prefer adding a registry entry in registry.go instead.
|
||||
//
|
||||
// Keep the list sorted (one-line-per-entry) so the diff is minimal
|
||||
// when an entry is removed or (rarely) added.
|
||||
var grandfatheredUnregistered = []string{
|
||||
"agent.disable_sink_state",
|
||||
"agent.enable_mcp_prompts",
|
||||
"agent.enable_plan_re_evaluator",
|
||||
"agent.enable_planning",
|
||||
"agent.enable_reasoning",
|
||||
"agent.force_reasoning_tool",
|
||||
"agent.loop_detection",
|
||||
"agent.max_adjustment_attempts",
|
||||
"agent.max_attempts",
|
||||
"agent.max_iterations",
|
||||
"cfg_scale",
|
||||
"concurrency_groups",
|
||||
"cutstrings",
|
||||
"debug",
|
||||
"diffusers.clip_model",
|
||||
"diffusers.clip_skip",
|
||||
"diffusers.clip_subfolder",
|
||||
"diffusers.control_net",
|
||||
"diffusers.enable_parameters",
|
||||
"diffusers.img2img",
|
||||
"disable_log_stats",
|
||||
"disabled",
|
||||
"download_files",
|
||||
"draft_model",
|
||||
"dtype",
|
||||
"enforce_eager",
|
||||
"engine_args",
|
||||
"extract_regex",
|
||||
"feature_flags",
|
||||
"function.argument_regex",
|
||||
"function.argument_regex_key_name",
|
||||
"function.argument_regex_value_name",
|
||||
"function.automatic_tool_parsing_fallback",
|
||||
"function.capture_llm_results",
|
||||
"function.disable_no_action",
|
||||
"function.disable_peg_parser",
|
||||
"function.function_arguments_key",
|
||||
"function.function_name_key",
|
||||
"function.grammar.disable_parallel_new_lines",
|
||||
"function.grammar.expect_strings_after_json",
|
||||
"function.grammar.no_mixed_free_string",
|
||||
"function.grammar.prefix",
|
||||
"function.grammar.properties_order",
|
||||
"function.grammar.schema_type",
|
||||
"function.grammar.triggers",
|
||||
"function.json_regex_match",
|
||||
"function.no_action_description_name",
|
||||
"function.no_action_function_name",
|
||||
"function.replace_function_results",
|
||||
"function.replace_llm_results",
|
||||
"function.response_regex",
|
||||
"function.xml_format.allow_toolcall_in_think",
|
||||
"function.xml_format.key_start",
|
||||
"function.xml_format.key_val_sep",
|
||||
"function.xml_format.key_val_sep2",
|
||||
"function.xml_format.last_tool_end",
|
||||
"function.xml_format.last_val_end",
|
||||
"function.xml_format.raw_argval",
|
||||
"function.xml_format.scope_end",
|
||||
"function.xml_format.scope_start",
|
||||
"function.xml_format.tool_end",
|
||||
"function.xml_format.tool_sep",
|
||||
"function.xml_format.tool_start",
|
||||
"function.xml_format.trim_raw_argval",
|
||||
"function.xml_format.val_end",
|
||||
"function.xml_format_preset",
|
||||
"gpu_memory_utilization",
|
||||
"grammar",
|
||||
"grpc.attempts",
|
||||
"grpc.attempts_sleep_time",
|
||||
"limit_mm_per_prompt.audio",
|
||||
"limit_mm_per_prompt.image",
|
||||
"limit_mm_per_prompt.video",
|
||||
"limits.max_concurrent",
|
||||
"limits.retry_after_seconds",
|
||||
"load_format",
|
||||
"lora_adapter",
|
||||
"lora_adapters",
|
||||
"lora_base",
|
||||
"lora_scale",
|
||||
"lora_scales",
|
||||
"main_gpu",
|
||||
"max_model_len",
|
||||
"mcp.remote",
|
||||
"mcp.stdio",
|
||||
"mirostat",
|
||||
"mirostat_eta",
|
||||
"mirostat_tau",
|
||||
"mmproj",
|
||||
"n_draft",
|
||||
"ngqa",
|
||||
"no_kv_offloading",
|
||||
"no_mulmatq",
|
||||
"numa",
|
||||
"options",
|
||||
"overrides",
|
||||
"parameters.batch",
|
||||
"parameters.clip_skip",
|
||||
"parameters.echo",
|
||||
"parameters.encoding_format",
|
||||
"parameters.frequency_penalty",
|
||||
"parameters.ignore_eos",
|
||||
"parameters.language",
|
||||
"parameters.logit_bias",
|
||||
"parameters.logprobs",
|
||||
"parameters.min_p",
|
||||
"parameters.model",
|
||||
"parameters.n",
|
||||
"parameters.n_keep",
|
||||
"parameters.negative_prompt",
|
||||
"parameters.negative_prompt_scale",
|
||||
"parameters.presence_penalty",
|
||||
"parameters.repeat_last_n",
|
||||
"parameters.rope_freq_base",
|
||||
"parameters.rope_freq_scale",
|
||||
"parameters.tfz",
|
||||
"parameters.tokenizer",
|
||||
"parameters.top_logprobs",
|
||||
"parameters.translate",
|
||||
"parameters.typical_p",
|
||||
"pinned",
|
||||
"prompt_cache_all",
|
||||
"prompt_cache_path",
|
||||
"prompt_cache_ro",
|
||||
"proxy.api_key_file",
|
||||
"reasoning.disable",
|
||||
"reasoning.disable_reasoning_tag_prefill",
|
||||
"reasoning.strip_reasoning_only",
|
||||
"reasoning.tag_pairs",
|
||||
"reasoning.thinking_start_tokens",
|
||||
"reranking",
|
||||
"rms_norm_eps",
|
||||
"roles",
|
||||
"rope_scaling",
|
||||
"step",
|
||||
"stopwords",
|
||||
"swap_space",
|
||||
"system_prompt",
|
||||
"template.edit",
|
||||
"template.function",
|
||||
"template.join_chat_messages_by_character",
|
||||
"template.multimodal",
|
||||
"template.reply_prefix",
|
||||
"tensor_parallel_size",
|
||||
"tensor_split",
|
||||
"trimspace",
|
||||
"trimsuffix",
|
||||
"trust_remote_code",
|
||||
"tts.audio_path",
|
||||
"type",
|
||||
"yarn_attn_factor",
|
||||
"yarn_beta_fast",
|
||||
"yarn_beta_slow",
|
||||
"yarn_ext_factor",
|
||||
}
|
||||
133
core/config/mitm_host_owners_test.go
Normal file
133
core/config/mitm_host_owners_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
)
|
||||
|
||||
// MITMHostOwners is the load-bearing piece of D2 — a duplicate host
|
||||
// across model configs is a critical error that disables the listener.
|
||||
// The test exercises both happy paths (no duplicates → clean Owners
|
||||
// map) and conflict detection (two configs on one host → entry in
|
||||
// Conflicts naming both).
|
||||
|
||||
var _ = Describe("ModelConfigLoader.MITMHostOwners", func() {
|
||||
var (
|
||||
dir string
|
||||
loader *config.ModelConfigLoader
|
||||
)
|
||||
|
||||
writeYAML := func(name, body string) {
|
||||
path := filepath.Join(dir, name+".yaml")
|
||||
Expect(os.WriteFile(path, []byte(body), 0o644)).To(Succeed())
|
||||
Expect(loader.ReadModelConfig(path)).To(Succeed())
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
dir, err = os.MkdirTemp("", "mitm-host-owners-test-*")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
loader = config.NewModelConfigLoader(dir)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
_ = os.RemoveAll(dir)
|
||||
})
|
||||
|
||||
It("returns empty maps when no model declares mitm.hosts", func() {
|
||||
writeYAML("plain", `name: plain
|
||||
backend: llama-cpp
|
||||
`)
|
||||
got := loader.MITMHostOwners()
|
||||
Expect(got.Owners).To(BeEmpty())
|
||||
Expect(got.Conflicts).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("indexes hosts to the owning model name", func() {
|
||||
writeYAML("claude", `name: claude
|
||||
backend: cloud-proxy
|
||||
mitm:
|
||||
hosts:
|
||||
- api.anthropic.com
|
||||
`)
|
||||
writeYAML("openai", `name: openai
|
||||
backend: cloud-proxy
|
||||
mitm:
|
||||
hosts:
|
||||
- api.openai.com
|
||||
- api.openai.azure.com
|
||||
`)
|
||||
got := loader.MITMHostOwners()
|
||||
Expect(got.Owners).To(Equal(map[string]string{
|
||||
"api.anthropic.com": "claude",
|
||||
"api.openai.com": "openai",
|
||||
"api.openai.azure.com": "openai",
|
||||
}))
|
||||
Expect(got.Conflicts).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("normalises case and trims whitespace before indexing", func() {
|
||||
writeYAML("claude", `name: claude
|
||||
backend: cloud-proxy
|
||||
mitm:
|
||||
hosts:
|
||||
- " API.ANTHROPIC.com "
|
||||
`)
|
||||
got := loader.MITMHostOwners()
|
||||
Expect(got.Owners).To(HaveKey("api.anthropic.com"))
|
||||
})
|
||||
|
||||
It("detects two configs claiming the same host as a conflict", func() {
|
||||
// The 1-to-1 invariant the D2 dispatcher relies on: a host
|
||||
// claimed twice means the owner lookup is ambiguous, so the
|
||||
// caller must NOT start the MITM listener until resolved.
|
||||
writeYAML("alpha", `name: alpha
|
||||
backend: cloud-proxy
|
||||
mitm:
|
||||
hosts:
|
||||
- api.anthropic.com
|
||||
`)
|
||||
writeYAML("beta", `name: beta
|
||||
backend: cloud-proxy
|
||||
mitm:
|
||||
hosts:
|
||||
- api.anthropic.com
|
||||
`)
|
||||
got := loader.MITMHostOwners()
|
||||
Expect(got.Conflicts).To(HaveKey("api.anthropic.com"))
|
||||
Expect(got.Conflicts["api.anthropic.com"]).To(ConsistOf("alpha", "beta"))
|
||||
})
|
||||
|
||||
It("treats the same host listed twice within ONE config as a no-op (not a conflict)", func() {
|
||||
// A single config repeating a host is benign — same owner
|
||||
// either way. The conflict signal must be cross-config only.
|
||||
writeYAML("dup", `name: dup
|
||||
backend: llama-cpp
|
||||
mitm:
|
||||
hosts:
|
||||
- api.example.com
|
||||
- api.example.com
|
||||
`)
|
||||
got := loader.MITMHostOwners()
|
||||
Expect(got.Owners).To(Equal(map[string]string{"api.example.com": "dup"}))
|
||||
Expect(got.Conflicts).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("ignores empty/whitespace-only host entries", func() {
|
||||
writeYAML("sloppy", `name: sloppy
|
||||
backend: llama-cpp
|
||||
mitm:
|
||||
hosts:
|
||||
- ""
|
||||
- " "
|
||||
- api.real.com
|
||||
`)
|
||||
got := loader.MITMHostOwners()
|
||||
Expect(got.Owners).To(Equal(map[string]string{"api.real.com": "sloppy"}))
|
||||
})
|
||||
})
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
@@ -95,8 +96,330 @@ type ModelConfig struct {
|
||||
Options []string `yaml:"options,omitempty" json:"options,omitempty"`
|
||||
Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`
|
||||
|
||||
MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"`
|
||||
Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"`
|
||||
MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"`
|
||||
Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"`
|
||||
PII PIIConfig `yaml:"pii,omitempty" json:"pii,omitempty"`
|
||||
Router RouterConfig `yaml:"router,omitempty" json:"router,omitempty"`
|
||||
Proxy ProxyConfig `yaml:"proxy,omitempty" json:"proxy,omitempty"`
|
||||
MITM MITMModelConfig `yaml:"mitm,omitempty" json:"mitm,omitempty"`
|
||||
Limits LimitsConfig `yaml:"limits,omitempty" json:"limits,omitempty"`
|
||||
}
|
||||
|
||||
// @Description Admission-control limits applied per request. The
|
||||
// admission middleware enforces these before invoking the handler;
|
||||
// requests that exceed a limit get 503 with a Retry-After hint so
|
||||
// clients back off rather than pile on. Per-model so cloud passthroughs
|
||||
// can have a stricter ceiling than local models.
|
||||
type LimitsConfig struct {
|
||||
// MaxConcurrent caps simultaneous in-flight requests for this
|
||||
// model. 0 = unlimited (default). Useful for cloud-passthrough
|
||||
// configs where the upstream rate-limits aggressively, or for
|
||||
// local backends whose memory budget tops out before LocalAI's
|
||||
// queue depth would.
|
||||
MaxConcurrent int `yaml:"max_concurrent,omitempty" json:"max_concurrent,omitempty"`
|
||||
|
||||
// RetryAfterSeconds advises clients how long to wait before
|
||||
// retrying when admission rejects. 0 defaults to 1s — enough to
|
||||
// let an in-flight request finish on a busy local model. The
|
||||
// value is sent verbatim in the Retry-After response header.
|
||||
RetryAfterSeconds int `yaml:"retry_after_seconds,omitempty" json:"retry_after_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// @Description MITM intercept binding for the model. When the cloudproxy
|
||||
// MITM listener is enabled and any host listed here appears in a CONNECT,
|
||||
// the proxy uses THIS model config's pii: settings to filter the
|
||||
// intercepted body. Strict 1-to-1: a host claimed by two configs is a
|
||||
// configuration error and disables the MITM listener until resolved.
|
||||
//
|
||||
// Lets an admin pair a host (api.anthropic.com) with the model's
|
||||
// PII overrides without maintaining a parallel per-host map.
|
||||
type MITMModelConfig struct {
|
||||
// Hosts is the list of hostnames this model claims for MITM
|
||||
// interception. Each entry must be unique across all model configs.
|
||||
Hosts []string `yaml:"hosts,omitempty" json:"hosts,omitempty"`
|
||||
}
|
||||
|
||||
// @Description Cloud proxy configuration. The cloud-proxy backend
|
||||
// forwards a model's traffic to an external provider. Two modes:
|
||||
//
|
||||
// - mode: passthrough — client and upstream must speak the same wire
|
||||
// format; the backend ships the raw request body to the upstream
|
||||
// URL and streams the response back untouched. The streaming PII
|
||||
// filter still runs because it operates on extracted token text.
|
||||
//
|
||||
// - mode: translate — the backend converts LocalAI's internal proto
|
||||
// to the provider's wire format and back. Unlocks cross-provider
|
||||
// routing (OpenAI client → Anthropic upstream, etc.) at the cost
|
||||
// of dropping provider-specific extensions that the internal proto
|
||||
// doesn't model.
|
||||
type ProxyConfig struct {
|
||||
// UpstreamURL is the full POST endpoint, e.g.
|
||||
// https://api.openai.com/v1/chat/completions or
|
||||
// https://api.anthropic.com/v1/messages. Required.
|
||||
UpstreamURL string `yaml:"upstream_url,omitempty" json:"upstream_url,omitempty"`
|
||||
|
||||
// Mode selects passthrough (wire-perfect) or translate (full
|
||||
// control via internal proto). Empty defaults to passthrough.
|
||||
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
||||
|
||||
// Provider identifies the upstream's wire format for translate
|
||||
// mode (openai, anthropic). Ignored in passthrough mode — the
|
||||
// wire format there is whatever the client sent.
|
||||
Provider string `yaml:"provider,omitempty" json:"provider,omitempty"`
|
||||
|
||||
// APIKeyEnv names the environment variable holding the upstream
|
||||
// API key. Mutually exclusive with APIKeyFile. Both empty is
|
||||
// allowed (no-auth upstreams).
|
||||
APIKeyEnv string `yaml:"api_key_env,omitempty" json:"api_key_env,omitempty"`
|
||||
|
||||
// APIKeyFile is a path to a file whose contents are the upstream
|
||||
// API key. Trailing whitespace is trimmed. Mutually exclusive
|
||||
// with APIKeyEnv. The integration point for K8s secret mounts,
|
||||
// Vault agent files, and similar external-secret workflows.
|
||||
APIKeyFile string `yaml:"api_key_file,omitempty" json:"api_key_file,omitempty"`
|
||||
|
||||
// UpstreamModel overrides the model name sent to the upstream.
|
||||
// Useful when the LocalAI-facing model alias differs from the
|
||||
// upstream's canonical name (e.g. local "claude-strict" maps to
|
||||
// upstream "claude-3-5-sonnet-20241022"). Empty means forward
|
||||
// the client's model field unchanged.
|
||||
UpstreamModel string `yaml:"upstream_model,omitempty" json:"upstream_model,omitempty"`
|
||||
|
||||
// RequestTimeoutSeconds caps the upstream request duration. 0
|
||||
// means no per-request timeout (only the request context, which
|
||||
// is bound to the client connection, applies).
|
||||
RequestTimeoutSeconds int `yaml:"request_timeout_seconds,omitempty" json:"request_timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// Proxy mode names. Validate() normalises an empty Mode to
|
||||
// ProxyModePassthrough so downstream code only sees concrete values.
|
||||
const (
|
||||
ProxyModePassthrough = "passthrough"
|
||||
ProxyModeTranslate = "translate"
|
||||
)
|
||||
|
||||
// Proxy provider names. Only meaningful in translate mode, where the
|
||||
// cloud-proxy backend picks the wire format to use against the
|
||||
// upstream URL.
|
||||
const (
|
||||
ProxyProviderOpenAI = "openai"
|
||||
ProxyProviderAnthropic = "anthropic"
|
||||
)
|
||||
|
||||
// IsCloudProxyBackendPassthrough reports whether this model uses the
|
||||
// cloud-proxy gRPC backend in passthrough mode. Empty Mode counts as
|
||||
// passthrough (SetDefaults normalises it, but Validate accepts empty
|
||||
// too — handlers should not rely on a particular call order).
|
||||
func (c *ModelConfig) IsCloudProxyBackendPassthrough() bool {
|
||||
if c.Backend != "cloud-proxy" {
|
||||
return false
|
||||
}
|
||||
return c.Proxy.Mode == "" || c.Proxy.Mode == ProxyModePassthrough
|
||||
}
|
||||
|
||||
// @Description Intelligent routing configuration. When a model declares
|
||||
// a Router block, requests addressed to it are reclassified at runtime
|
||||
// and dispatched to one of the named candidates. The router rewrites
|
||||
// input.Model in-place, then the standard model-resolution path picks
|
||||
// up the resolved config — meaning ACL checks, disabled-state, and
|
||||
// per-model PII still run against the chosen target.
|
||||
//
|
||||
// Depth-1 invariant: candidates must NOT themselves carry a Router
|
||||
// block. The router's "smart-router → claude-strict → cloud-proxy"
|
||||
// chain is fine, but "router-A → router-B → claude" is rejected at
|
||||
// config load to keep the dispatch graph acyclic and predictable. The
|
||||
// middleware also asserts depth ≤ 1 at runtime as a defensive check.
|
||||
type RouterConfig struct {
|
||||
// Classifier picks the implementation. Only "score" ships today:
|
||||
// it asks the classifier model to score every Policy label as a
|
||||
// continuation of the routing prompt and reads off the
|
||||
// distribution. Empty defaults to "score".
|
||||
Classifier string `yaml:"classifier,omitempty" json:"classifier,omitempty"`
|
||||
|
||||
// Policies is the label vocabulary the classifier scores over.
|
||||
// Each policy carries a natural-language description that ends up
|
||||
// in the system prompt the classifier model sees — short, action-
|
||||
// oriented sentences work best ("writing or debugging code",
|
||||
// "small talk", ...). The Score classifier picks the subset of
|
||||
// labels whose softmax probability passes ActivationThreshold.
|
||||
Policies []RouterPolicy `yaml:"policies,omitempty" json:"policies,omitempty"`
|
||||
|
||||
// Candidates is the routing table — each entry binds a downstream
|
||||
// model to a set of labels it can serve. The middleware picks the
|
||||
// FIRST candidate whose Labels are a superset of the active label
|
||||
// set from the classifier. Admins order this list smallest →
|
||||
// largest so a query that needs one label routes to the smallest
|
||||
// capable model, while a query that needs multiple falls to a
|
||||
// bigger candidate that covers them all.
|
||||
Candidates []RouterCandidate `yaml:"candidates,omitempty" json:"candidates,omitempty"`
|
||||
|
||||
// Fallback is the model used when no candidate matches the active
|
||||
// label set, or when the classifier returns nothing above
|
||||
// threshold. Empty fallback means router failures bubble up as
|
||||
// 500 — fail-fast, not silent-bypass.
|
||||
Fallback string `yaml:"fallback,omitempty" json:"fallback,omitempty"`
|
||||
|
||||
// ClassifierModel names the model the Score classifier scores
|
||||
// against (Arch-Router-1.5B is the canonical choice).
|
||||
ClassifierModel string `yaml:"classifier_model,omitempty" json:"classifier_model,omitempty"`
|
||||
|
||||
// ClassifierCacheSize bounds the per-prompt memo cache that
|
||||
// amortises the classifier round-trip across repeat probes.
|
||||
// 0 disables the cache. Default 1024.
|
||||
ClassifierCacheSize int `yaml:"classifier_cache_size,omitempty" json:"classifier_cache_size,omitempty"`
|
||||
|
||||
// ActivationThreshold is the softmax-probability floor a policy
|
||||
// must clear to be considered "active" for the request. 0
|
||||
// defaults to a sensible value (~0.15) inside the classifier.
|
||||
// Higher → narrower routes (single-label dominant); lower →
|
||||
// more multi-label activations.
|
||||
ActivationThreshold float64 `yaml:"activation_threshold,omitempty" json:"activation_threshold,omitempty"`
|
||||
|
||||
// ClassifierSystemTemplate overrides the routing system prompt
|
||||
// the score classifier feeds to its classifier_model. Go
|
||||
// text/template + Sprig, executed with `.Policies []ScorePolicy`
|
||||
// (Label + Description fields). Empty falls back to the built-in
|
||||
// Arch-Router-shaped template (route-listing block + JSON output
|
||||
// schema). Override when the classifier model was trained on a
|
||||
// different schema (e.g. bare label output, XML route block) or
|
||||
// when the routing instructions need to be in a different
|
||||
// language. The candidate format scored against the model is
|
||||
// fixed at `{"route": "<label>"}` and IS NOT templated — keep
|
||||
// your override's output schema instruction matching that, or
|
||||
// the per-candidate scores degenerate.
|
||||
ClassifierSystemTemplate string `yaml:"classifier_system_template,omitempty" json:"classifier_system_template,omitempty"`
|
||||
|
||||
// ScoreNormalization picks how the score classifier collapses
|
||||
// per-candidate joint log-probs into the softmax input.
|
||||
// - ""/"raw": use joint log-prob as-is (default). Matches the
|
||||
// distribution the classifier model was trained against — the
|
||||
// route the model would actually emit if decoded freely.
|
||||
// - "mean": divide by candidate token count. Fairer to long
|
||||
// labels (their joint log-prob is mechanically smaller because
|
||||
// it sums more negatives), but off-distribution for models
|
||||
// trained to emit fixed-format outputs like Arch-Router's
|
||||
// {"route": "name"}.
|
||||
// Future modes (e.g. "weighted_mean") will land here too.
|
||||
ScoreNormalization string `yaml:"score_normalization,omitempty" json:"score_normalization,omitempty"`
|
||||
|
||||
// EmbeddingCache configures the L2 cache that maps prompt
|
||||
// embeddings to past decisions, so semantically-similar prompts
|
||||
// reuse a classification instead of re-running the classifier
|
||||
// model. Omit the block to disable. See router/embedding_cache.go.
|
||||
EmbeddingCache *EmbeddingCacheConfig `yaml:"embedding_cache,omitempty" json:"embedding_cache,omitempty"`
|
||||
}
|
||||
|
||||
// EmbeddingCacheConfig configures the L2 embedding-similarity decision
|
||||
// cache. Pairs naturally with a larger / slower classifier model: the
|
||||
// classifier round-trip is amortised across paraphrases of the same
|
||||
// intent. The cache uses the standard /v1/embeddings backend for
|
||||
// vector generation and the local-store gRPC surface for KNN search.
|
||||
type EmbeddingCacheConfig struct {
|
||||
// EmbeddingModel names the loaded LocalAI model used to embed
|
||||
// router prompts. Required when the cache is enabled. Any model
|
||||
// that supports the Embeddings gRPC primitive works;
|
||||
// nomic-embed-text-v1.5 is the recommended default.
|
||||
EmbeddingModel string `yaml:"embedding_model" json:"embedding_model"`
|
||||
|
||||
// SimilarityThreshold is the cosine-similarity floor a cache
|
||||
// candidate must clear to be treated as a hit. 0 picks the
|
||||
// package default (0.80). Higher → fewer false hits, higher miss
|
||||
// rate; lower → more aggressive sharing across paraphrases.
|
||||
SimilarityThreshold float64 `yaml:"similarity_threshold,omitempty" json:"similarity_threshold,omitempty"`
|
||||
|
||||
// ConfidenceThreshold is the minimum classifier top-label
|
||||
// probability for a decision to be inserted into the cache. 0
|
||||
// picks the package default (0.60). Uncertain decisions are not
|
||||
// cached so they can't poison future paraphrases.
|
||||
ConfidenceThreshold float64 `yaml:"confidence_threshold,omitempty" json:"confidence_threshold,omitempty"`
|
||||
|
||||
// StoreName overrides the local-store collection name used for
|
||||
// this router's cache. Empty defaults to "router-cache-<router>"
|
||||
// where <router> is the parent model name. Useful when two
|
||||
// router models should share a cache (rare).
|
||||
StoreName string `yaml:"store_name,omitempty" json:"store_name,omitempty"`
|
||||
}
|
||||
|
||||
// RouterPolicy is one entry in the label vocabulary. The label string
|
||||
// is what the classifier model emits and what candidates reference in
|
||||
// their Labels field; the description is the natural-language hint
|
||||
// fed to the classifier so it can match user intent against the label
|
||||
// space.
|
||||
type RouterPolicy struct {
|
||||
Label string `yaml:"label" json:"label"`
|
||||
Description string `yaml:"description" json:"description"`
|
||||
}
|
||||
|
||||
// RouterCandidate names a downstream model and the policy labels it
|
||||
// is willing to serve. Labels are matched as a set: the middleware
|
||||
// picks the first candidate whose Labels is a superset of the
|
||||
// classifier's active set.
|
||||
type RouterCandidate struct {
|
||||
Model string `yaml:"model" json:"model"`
|
||||
Labels []string `yaml:"labels" json:"labels"`
|
||||
}
|
||||
|
||||
// HasRouter returns true when the model declares a router config with
|
||||
// at least one candidate. Used by the RouteModel middleware to decide
|
||||
// whether to engage the classifier.
|
||||
func (c *ModelConfig) HasRouter() bool {
|
||||
return len(c.Router.Candidates) > 0
|
||||
}
|
||||
|
||||
// @Description PII filtering configuration. PII redaction is per-model so
|
||||
// that local models don't pay the latency or behaviour change of regex
|
||||
// scanning, while cloud-bound traffic (cloud-proxy backend) can default to
|
||||
// on. Setting Enabled explicitly always wins over the backend default.
|
||||
type PIIConfig struct {
|
||||
// Enabled toggles redaction for this model. When unset (zero value),
|
||||
// the resolved default depends on Backend: cloud-proxy defaults to
|
||||
// true, everything else to false. A pointer is used so the absence of
|
||||
// the YAML key is distinguishable from explicit false.
|
||||
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
|
||||
|
||||
// Patterns lets a model upgrade or downgrade individual pattern
|
||||
// actions (mask | block | route_local) relative to the global
|
||||
// defaults loaded from --pii-config / DefaultPatterns. Pattern IDs
|
||||
// not listed inherit the global action. The regex itself stays
|
||||
// global — only the action is settable per-model.
|
||||
Patterns []PIIPatternOverride `yaml:"patterns,omitempty" json:"patterns,omitempty"`
|
||||
}
|
||||
|
||||
// @Description Per-model action override for a single PII pattern.
|
||||
type PIIPatternOverride struct {
|
||||
ID string `yaml:"id" json:"id"`
|
||||
Action string `yaml:"action" json:"action"`
|
||||
}
|
||||
|
||||
// PIIIsEnabled returns the resolved PII state for this model. Single
|
||||
// source of truth for the gating decision so the middleware and the
|
||||
// /api/middleware/status admin view agree.
|
||||
func (c *ModelConfig) PIIIsEnabled() bool {
|
||||
if c.PII.Enabled != nil {
|
||||
return *c.PII.Enabled
|
||||
}
|
||||
return c.Backend == "cloud-proxy"
|
||||
}
|
||||
|
||||
// PIIPatternOverrides returns the per-pattern action overrides as a map
|
||||
// keyed by pattern ID. The values are the raw action strings — the pii
|
||||
// package validates and converts them.
|
||||
//
|
||||
// Returned via the documented modelPIIConfig interface in
|
||||
// core/services/routing/pii/middleware.go without taking a config
|
||||
// dependency on this package.
|
||||
func (c *ModelConfig) PIIPatternOverrides() map[string]string {
|
||||
if len(c.PII.Patterns) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]string, len(c.PII.Patterns))
|
||||
for _, p := range c.PII.Patterns {
|
||||
if p.ID == "" {
|
||||
continue
|
||||
}
|
||||
out[p.ID] = p.Action
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// @Description MCP configuration
|
||||
@@ -401,6 +724,14 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
|
||||
f16 := lo.f16
|
||||
debug := lo.debug
|
||||
|
||||
// Cloud-proxy: normalise empty Mode so downstream consumers
|
||||
// switch on two concrete values only. Validate accepts empty too,
|
||||
// but SetDefaults is the chokepoint that runs before any
|
||||
// inference path reads cfg.Proxy.Mode.
|
||||
if cfg.Proxy.Mode == "" {
|
||||
cfg.Proxy.Mode = ProxyModePassthrough
|
||||
}
|
||||
|
||||
// Apply model-family-specific inference defaults before generic fallbacks.
|
||||
// This ensures gallery-installed and runtime-loaded models get optimal parameters.
|
||||
ApplyInferenceDefaults(cfg, cfg.Name, cfg.Model)
|
||||
@@ -573,9 +904,74 @@ func (c *ModelConfig) Validate() (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Cloud-proxy: at most one of api_key_env / api_key_file may be
|
||||
// set. Both empty means no Authorization header (no-auth upstream
|
||||
// or a development passthrough). The mode field accepts the empty
|
||||
// string (defaults to passthrough), "passthrough", or "translate".
|
||||
if c.Proxy.APIKeyEnv != "" && c.Proxy.APIKeyFile != "" {
|
||||
return false, fmt.Errorf("proxy: api_key_env and api_key_file are mutually exclusive")
|
||||
}
|
||||
switch c.Proxy.Mode {
|
||||
case "", ProxyModePassthrough, ProxyModeTranslate:
|
||||
// Empty is accepted at validate-time and normalised to
|
||||
// passthrough by SetDefaults so it never reaches runtime.
|
||||
default:
|
||||
return false, fmt.Errorf("proxy: unknown mode %q (expected %s or %s)",
|
||||
c.Proxy.Mode, ProxyModePassthrough, ProxyModeTranslate)
|
||||
}
|
||||
if c.Proxy.Mode == ProxyModeTranslate && c.Proxy.Provider == "" {
|
||||
return false, fmt.Errorf("proxy: translate mode requires provider (%s, %s)",
|
||||
ProxyProviderOpenAI, ProxyProviderAnthropic)
|
||||
}
|
||||
|
||||
// Score on llama-cpp bypasses the slot loop and races the
|
||||
// llama_context against concurrent generation/embedding traffic
|
||||
// (see backend/cpp/llama-cpp/grpc-server.cpp on Score). Reject the
|
||||
// combination here so operators are forced to split the model.
|
||||
const scoreConflicts = FLAG_CHAT | FLAG_COMPLETION | FLAG_EMBEDDINGS
|
||||
if (c.Backend == "llama-cpp" || c.Backend == "llama") &&
|
||||
c.HasUsecases(FLAG_SCORE) && c.KnownUsecases != nil &&
|
||||
*c.KnownUsecases&scoreConflicts != 0 {
|
||||
return false, fmt.Errorf(
|
||||
"known_usecases conflict on llama-cpp: score is incompatible " +
|
||||
"with chat/completion/embeddings — split into separate model configs")
|
||||
}
|
||||
|
||||
// router.score_normalization is consumed lazily by the score
|
||||
// classifier at first-request time; without load-time validation
|
||||
// a typo wouldn't surface until the first router request panicked
|
||||
// inside NewScoreClassifier. Reject unknown values here so the
|
||||
// operator sees the offending key at startup.
|
||||
switch c.Router.ScoreNormalization {
|
||||
case "", ScoreNormalizationRaw, ScoreNormalizationMean:
|
||||
// ok
|
||||
default:
|
||||
return false, fmt.Errorf("router: unknown score_normalization %q (expected %q or %q)",
|
||||
c.Router.ScoreNormalization, ScoreNormalizationRaw, ScoreNormalizationMean)
|
||||
}
|
||||
|
||||
// router.classifier_system_template parses as Go text/template
|
||||
// (Sprig funcs available at execution time). Reject malformed
|
||||
// templates at load time so the operator sees the parse error
|
||||
// at startup rather than as a 500 on the first router request.
|
||||
if c.Router.ClassifierSystemTemplate != "" {
|
||||
if _, err := template.New("classifier_system").Parse(c.Router.ClassifierSystemTemplate); err != nil {
|
||||
return false, fmt.Errorf("router: classifier_system_template parse error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Score normalisation modes mirror router.ScoreNormalization* —
|
||||
// duplicated as constants on the config package so ModelConfig.Validate
|
||||
// can reject unknown values without taking a dependency on the router
|
||||
// package (which already depends on config).
|
||||
const (
|
||||
ScoreNormalizationRaw = "raw"
|
||||
ScoreNormalizationMean = "mean"
|
||||
)
|
||||
|
||||
func (c *ModelConfig) HasTemplate() bool {
|
||||
return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != "" || c.TemplateConfig.UseTokenizerTemplate
|
||||
}
|
||||
@@ -624,19 +1020,19 @@ func (c *ModelConfig) GetConcurrencyGroups() []string {
|
||||
type ModelConfigUsecase int
|
||||
|
||||
const (
|
||||
FLAG_ANY ModelConfigUsecase = 0b000000000000
|
||||
FLAG_CHAT ModelConfigUsecase = 0b000000000001
|
||||
FLAG_COMPLETION ModelConfigUsecase = 0b000000000010
|
||||
FLAG_EDIT ModelConfigUsecase = 0b000000000100
|
||||
FLAG_EMBEDDINGS ModelConfigUsecase = 0b000000001000
|
||||
FLAG_RERANK ModelConfigUsecase = 0b000000010000
|
||||
FLAG_IMAGE ModelConfigUsecase = 0b000000100000
|
||||
FLAG_TRANSCRIPT ModelConfigUsecase = 0b000001000000
|
||||
FLAG_TTS ModelConfigUsecase = 0b000010000000
|
||||
FLAG_SOUND_GENERATION ModelConfigUsecase = 0b000100000000
|
||||
FLAG_TOKENIZE ModelConfigUsecase = 0b001000000000
|
||||
FLAG_VAD ModelConfigUsecase = 0b010000000000
|
||||
FLAG_VIDEO ModelConfigUsecase = 0b100000000000
|
||||
FLAG_ANY ModelConfigUsecase = 0b000000000000
|
||||
FLAG_CHAT ModelConfigUsecase = 0b000000000001
|
||||
FLAG_COMPLETION ModelConfigUsecase = 0b000000000010
|
||||
FLAG_EDIT ModelConfigUsecase = 0b000000000100
|
||||
FLAG_EMBEDDINGS ModelConfigUsecase = 0b000000001000
|
||||
FLAG_RERANK ModelConfigUsecase = 0b000000010000
|
||||
FLAG_IMAGE ModelConfigUsecase = 0b000000100000
|
||||
FLAG_TRANSCRIPT ModelConfigUsecase = 0b000001000000
|
||||
FLAG_TTS ModelConfigUsecase = 0b000010000000
|
||||
FLAG_SOUND_GENERATION ModelConfigUsecase = 0b000100000000
|
||||
FLAG_TOKENIZE ModelConfigUsecase = 0b001000000000
|
||||
FLAG_VAD ModelConfigUsecase = 0b010000000000
|
||||
FLAG_VIDEO ModelConfigUsecase = 0b100000000000
|
||||
FLAG_DETECTION ModelConfigUsecase = 0b1000000000000
|
||||
FLAG_VISION ModelConfigUsecase = 0b10000000000000
|
||||
FLAG_FACE_RECOGNITION ModelConfigUsecase = 0b100000000000000
|
||||
@@ -644,6 +1040,14 @@ const (
|
||||
FLAG_AUDIO_TRANSFORM ModelConfigUsecase = 0b10000000000000000
|
||||
FLAG_DIARIZATION ModelConfigUsecase = 0b100000000000000000
|
||||
FLAG_REALTIME_AUDIO ModelConfigUsecase = 0b1000000000000000000
|
||||
// Marks a model as wired for the Score gRPC primitive (joint
|
||||
// log-prob of candidate continuations under a shared prompt). Must
|
||||
// be declared explicitly via `known_usecases: [score]` — there's
|
||||
// no heuristic for it. On the llama-cpp backend, Score bypasses
|
||||
// the slot loop and races the llama_context, so Validate() refuses
|
||||
// to load a llama-cpp config that combines FLAG_SCORE with
|
||||
// chat/completion/embeddings.
|
||||
FLAG_SCORE ModelConfigUsecase = 0b10000000000000000000
|
||||
|
||||
// Common Subsets
|
||||
FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
|
||||
@@ -653,12 +1057,12 @@ const (
|
||||
// Flags within the same group are NOT orthogonal (e.g., chat and completion are
|
||||
// both text/language). A model is multimodal when its usecases span 2+ groups.
|
||||
var ModalityGroups = []ModelConfigUsecase{
|
||||
FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT, // text/language
|
||||
FLAG_VISION | FLAG_DETECTION, // visual understanding
|
||||
FLAG_TRANSCRIPT | FLAG_REALTIME_AUDIO, // speech input — realtime_audio is any-to-any, so it counts here too
|
||||
FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT, // text/language
|
||||
FLAG_VISION | FLAG_DETECTION, // visual understanding
|
||||
FLAG_TRANSCRIPT | FLAG_REALTIME_AUDIO, // speech input — realtime_audio is any-to-any, so it counts here too
|
||||
FLAG_TTS | FLAG_SOUND_GENERATION | FLAG_REALTIME_AUDIO, // audio output — and here, so a lone realtime_audio flag still reads as multimodal
|
||||
FLAG_AUDIO_TRANSFORM, // audio in/out transforms
|
||||
FLAG_IMAGE | FLAG_VIDEO, // visual generation
|
||||
FLAG_AUDIO_TRANSFORM, // audio in/out transforms
|
||||
FLAG_IMAGE | FLAG_VIDEO, // visual generation
|
||||
}
|
||||
|
||||
// IsMultimodal returns true if the given usecases span two or more orthogonal
|
||||
@@ -681,19 +1085,19 @@ func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
|
||||
return map[string]ModelConfigUsecase{
|
||||
// Note: FLAG_ANY is intentionally excluded from this map
|
||||
// because it's 0 and would always match in HasUsecases checks
|
||||
"FLAG_CHAT": FLAG_CHAT,
|
||||
"FLAG_COMPLETION": FLAG_COMPLETION,
|
||||
"FLAG_EDIT": FLAG_EDIT,
|
||||
"FLAG_EMBEDDINGS": FLAG_EMBEDDINGS,
|
||||
"FLAG_RERANK": FLAG_RERANK,
|
||||
"FLAG_IMAGE": FLAG_IMAGE,
|
||||
"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
|
||||
"FLAG_TTS": FLAG_TTS,
|
||||
"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
|
||||
"FLAG_TOKENIZE": FLAG_TOKENIZE,
|
||||
"FLAG_VAD": FLAG_VAD,
|
||||
"FLAG_LLM": FLAG_LLM,
|
||||
"FLAG_VIDEO": FLAG_VIDEO,
|
||||
"FLAG_CHAT": FLAG_CHAT,
|
||||
"FLAG_COMPLETION": FLAG_COMPLETION,
|
||||
"FLAG_EDIT": FLAG_EDIT,
|
||||
"FLAG_EMBEDDINGS": FLAG_EMBEDDINGS,
|
||||
"FLAG_RERANK": FLAG_RERANK,
|
||||
"FLAG_IMAGE": FLAG_IMAGE,
|
||||
"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
|
||||
"FLAG_TTS": FLAG_TTS,
|
||||
"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
|
||||
"FLAG_TOKENIZE": FLAG_TOKENIZE,
|
||||
"FLAG_VAD": FLAG_VAD,
|
||||
"FLAG_LLM": FLAG_LLM,
|
||||
"FLAG_VIDEO": FLAG_VIDEO,
|
||||
"FLAG_DETECTION": FLAG_DETECTION,
|
||||
"FLAG_VISION": FLAG_VISION,
|
||||
"FLAG_FACE_RECOGNITION": FLAG_FACE_RECOGNITION,
|
||||
@@ -701,6 +1105,7 @@ func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
|
||||
"FLAG_AUDIO_TRANSFORM": FLAG_AUDIO_TRANSFORM,
|
||||
"FLAG_DIARIZATION": FLAG_DIARIZATION,
|
||||
"FLAG_REALTIME_AUDIO": FLAG_REALTIME_AUDIO,
|
||||
"FLAG_SCORE": FLAG_SCORE,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -726,9 +1131,23 @@ func GetUsecasesFromYAML(input []string) *ModelConfigUsecase {
|
||||
}
|
||||
|
||||
// HasUsecases examines a ModelConfig and determines which endpoints have a chance of success.
|
||||
//
|
||||
// Declared known_usecases are normally additive — the guessing heuristic
|
||||
// still adds whatever it can infer from backend/templates. The one
|
||||
// exception is FLAG_SCORE: when the operator declared score, they
|
||||
// reserved the model for the router classifier. Letting GuessUsecases
|
||||
// paint chat/completion on top would surface it in chat pickers it was
|
||||
// deliberately kept out of, and (on llama-cpp) reintroduce the slot
|
||||
// contention the score/chat conflict check exists to prevent. So a
|
||||
// declared score list is authoritative.
|
||||
func (c *ModelConfig) HasUsecases(u ModelConfigUsecase) bool {
|
||||
if (c.KnownUsecases != nil) && ((u & *c.KnownUsecases) == u) {
|
||||
return true
|
||||
if c.KnownUsecases != nil {
|
||||
if (u & *c.KnownUsecases) == u {
|
||||
return true
|
||||
}
|
||||
if (*c.KnownUsecases & FLAG_SCORE) == FLAG_SCORE {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return c.GuessUsecases(u)
|
||||
}
|
||||
@@ -885,6 +1304,14 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_SCORE) == FLAG_SCORE {
|
||||
// No heuristic: Score-intent is a deliberate operator choice
|
||||
// (it reserves the model from generation traffic on llama-cpp),
|
||||
// so HasUsecases(FLAG_SCORE) is true only when KnownUsecases
|
||||
// declares it explicitly.
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -388,6 +388,49 @@ func (bcl *ModelConfigLoader) Preload(modelPath string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MITMHostOwnership is the result of mapping intercept hosts to the
|
||||
// model configs that claim them. The invariant the dispatcher relies
|
||||
// on: every host belongs to AT MOST one model config. Any duplicate
|
||||
// is surfaced via Conflicts and disables the MITM listener until
|
||||
// resolved — a half-applied "first wins" rule would silently mask
|
||||
// configuration drift, so we fail loud.
|
||||
type MITMHostOwnership struct {
|
||||
// Owners maps lowercase hostname → owning model name. Empty when
|
||||
// no model declares mitm.hosts.
|
||||
Owners map[string]string
|
||||
// Conflicts lists hosts claimed by 2+ configs, with the names of
|
||||
// the configs that claim them. Non-empty Conflicts means callers
|
||||
// must NOT start the MITM listener.
|
||||
Conflicts map[string][]string
|
||||
}
|
||||
|
||||
// MITMHostOwners walks every loaded ModelConfig's mitm.hosts, builds
|
||||
// the host→owner index, and reports any duplicates. The lookup table
|
||||
// is hostname-lowercased to match the Server's allowlist semantics.
|
||||
func (bcl *ModelConfigLoader) MITMHostOwners() MITMHostOwnership {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
owners := map[string]string{}
|
||||
collisions := map[string][]string{}
|
||||
for name, cfg := range bcl.configs {
|
||||
for _, h := range cfg.MITM.Hosts {
|
||||
h = strings.ToLower(strings.TrimSpace(h))
|
||||
if h == "" {
|
||||
continue
|
||||
}
|
||||
if existing, ok := owners[h]; ok && existing != name {
|
||||
if _, seen := collisions[h]; !seen {
|
||||
collisions[h] = []string{existing}
|
||||
}
|
||||
collisions[h] = append(collisions[h], name)
|
||||
continue
|
||||
}
|
||||
owners[h] = name
|
||||
}
|
||||
}
|
||||
return MITMHostOwnership{Owners: owners, Conflicts: collisions}
|
||||
}
|
||||
|
||||
// LoadModelConfigsFromPath reads all the configurations of the models from a path
|
||||
// (non-recursive)
|
||||
func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
|
||||
|
||||
@@ -54,6 +54,141 @@ parameters:
|
||||
Expect(err).To(BeNil())
|
||||
Expect(valid).To(BeTrue())
|
||||
|
||||
// llama-cpp configs can't mix the score usecase with
|
||||
// chat/completion/embeddings — Score bypasses the slot
|
||||
// loop and would race the llama_context. The check fires
|
||||
// at load and save time; here we exercise it directly.
|
||||
scoreFlag := FLAG_SCORE | FLAG_CHAT
|
||||
conflicting := ModelConfig{
|
||||
Name: "router-but-also-chat",
|
||||
Backend: "llama-cpp",
|
||||
KnownUsecases: &scoreFlag,
|
||||
}
|
||||
valid, err = conflicting.Validate()
|
||||
Expect(valid).To(BeFalse())
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("score is incompatible"))
|
||||
|
||||
scoreOnly := FLAG_SCORE
|
||||
dedicated := ModelConfig{
|
||||
Name: "router-only",
|
||||
Backend: "llama-cpp",
|
||||
KnownUsecases: &scoreOnly,
|
||||
}
|
||||
valid, err = dedicated.Validate()
|
||||
Expect(valid).To(BeTrue())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// The constraint is llama-cpp-specific; other backends
|
||||
// may safely combine.
|
||||
scoreAndChat := FLAG_SCORE | FLAG_CHAT
|
||||
otherBackend := ModelConfig{
|
||||
Name: "vllm-router-and-chat",
|
||||
Backend: "vllm",
|
||||
KnownUsecases: &scoreAndChat,
|
||||
}
|
||||
valid, err = otherBackend.Validate()
|
||||
Expect(valid).To(BeTrue())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Cloud-proxy: api_key_env and api_key_file are mutually
|
||||
// exclusive — picking both is a config bug we catch at
|
||||
// load/save rather than at backend-load time.
|
||||
bothKeys := ModelConfig{
|
||||
Name: "both-keys",
|
||||
Backend: "cloud-proxy",
|
||||
Proxy: ProxyConfig{
|
||||
UpstreamURL: "https://example.com/v1",
|
||||
APIKeyEnv: "OPENAI_KEY",
|
||||
APIKeyFile: "/run/secrets/openai",
|
||||
},
|
||||
}
|
||||
valid, err = bothKeys.Validate()
|
||||
Expect(valid).To(BeFalse())
|
||||
Expect(err).To(MatchError(ContainSubstring("mutually exclusive")))
|
||||
|
||||
// Translate mode requires a provider — without one, the
|
||||
// backend has no way to pick a wire format.
|
||||
translateNoProvider := ModelConfig{
|
||||
Name: "translate-no-provider",
|
||||
Backend: "cloud-proxy",
|
||||
Proxy: ProxyConfig{UpstreamURL: "https://example.com/v1", Mode: ProxyModeTranslate},
|
||||
}
|
||||
valid, err = translateNoProvider.Validate()
|
||||
Expect(valid).To(BeFalse())
|
||||
Expect(err).To(MatchError(ContainSubstring("translate mode requires provider")))
|
||||
|
||||
// Unknown mode is rejected.
|
||||
badMode := ModelConfig{
|
||||
Name: "bad-mode",
|
||||
Backend: "cloud-proxy",
|
||||
Proxy: ProxyConfig{UpstreamURL: "https://example.com/v1", Mode: "rewrite"},
|
||||
}
|
||||
valid, err = badMode.Validate()
|
||||
Expect(valid).To(BeFalse())
|
||||
Expect(err).To(MatchError(ContainSubstring("unknown mode")))
|
||||
|
||||
// Passthrough (default) with one key source is happy.
|
||||
passthroughOK := ModelConfig{
|
||||
Name: "passthrough-ok",
|
||||
Backend: "cloud-proxy",
|
||||
Proxy: ProxyConfig{UpstreamURL: "https://example.com/v1", APIKeyEnv: "OPENAI_KEY"},
|
||||
}
|
||||
valid, err = passthroughOK.Validate()
|
||||
Expect(valid).To(BeTrue())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// router.score_normalization: load-time rejection of an
|
||||
// unknown value. The classifier consumes it lazily, so
|
||||
// without this validation a YAML typo wouldn't surface
|
||||
// until the first router request panicked deep in
|
||||
// NewScoreClassifier.
|
||||
badNorm := ModelConfig{
|
||||
Name: "bad-norm",
|
||||
Router: RouterConfig{
|
||||
ScoreNormalization: "men", // typo of "mean"
|
||||
},
|
||||
}
|
||||
valid, err = badNorm.Validate()
|
||||
Expect(valid).To(BeFalse())
|
||||
Expect(err).To(MatchError(ContainSubstring("unknown score_normalization")))
|
||||
|
||||
// Accepted values pass.
|
||||
for _, mode := range []string{"", ScoreNormalizationRaw, ScoreNormalizationMean} {
|
||||
goodNorm := ModelConfig{
|
||||
Name: "good-norm-" + mode,
|
||||
Router: RouterConfig{ScoreNormalization: mode},
|
||||
}
|
||||
valid, err = goodNorm.Validate()
|
||||
Expect(valid).To(BeTrue(), "score_normalization=%q should be accepted", mode)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
|
||||
// router.classifier_system_template: parse-time rejection
|
||||
// of malformed Go templates. Same reasoning as above —
|
||||
// without this the parse error wouldn't surface until
|
||||
// the first router request panicked in NewScoreClassifier.
|
||||
badTmpl := ModelConfig{
|
||||
Name: "bad-tmpl",
|
||||
Router: RouterConfig{
|
||||
ClassifierSystemTemplate: "Routes: {{range .Policies",
|
||||
},
|
||||
}
|
||||
valid, err = badTmpl.Validate()
|
||||
Expect(valid).To(BeFalse())
|
||||
Expect(err).To(MatchError(ContainSubstring("classifier_system_template parse error")))
|
||||
|
||||
// Well-formed template passes.
|
||||
goodTmpl := ModelConfig{
|
||||
Name: "good-tmpl",
|
||||
Router: RouterConfig{
|
||||
ClassifierSystemTemplate: `Routes: {{range .Policies}}{{.Label}} {{end}}`,
|
||||
},
|
||||
}
|
||||
valid, err = goodTmpl.Validate()
|
||||
Expect(valid).To(BeTrue())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// download https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml
|
||||
httpClient := http.Client{}
|
||||
resp, err := httpClient.Get("https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml")
|
||||
@@ -168,6 +303,29 @@ parameters:
|
||||
Expect(i.HasUsecases(FLAG_TTS)).To(BeFalse())
|
||||
Expect(i.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
|
||||
Expect(i.HasUsecases(FLAG_CHAT)).To(BeTrue())
|
||||
|
||||
// Declared `known_usecases: [score]` is authoritative — the
|
||||
// guessing heuristic must NOT add chat on top, even though the
|
||||
// inherited chatml template would otherwise satisfy the chat
|
||||
// heuristic. Score means "this model is reserved for the
|
||||
// router classifier"; surfacing it as a chat model defeats the
|
||||
// reservation and reintroduces the slot contention the load-time
|
||||
// score/chat conflict check exists to prevent.
|
||||
scoreReserved := FLAG_SCORE
|
||||
j := ModelConfig{
|
||||
Name: "arch-router",
|
||||
Backend: "llama-cpp",
|
||||
KnownUsecases: &scoreReserved,
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "inherited from chatml",
|
||||
ChatMessage: "inherited from chatml",
|
||||
Completion: "inherited from chatml",
|
||||
},
|
||||
}
|
||||
Expect(j.HasUsecases(FLAG_SCORE)).To(BeTrue())
|
||||
Expect(j.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
||||
Expect(j.HasUsecases(FLAG_COMPLETION)).To(BeFalse())
|
||||
Expect(j.HasUsecases(FLAG_EMBEDDINGS)).To(BeFalse())
|
||||
})
|
||||
It("Test Validate with invalid MCP config", func() {
|
||||
tmp, err := os.CreateTemp("", "config.yaml")
|
||||
|
||||
@@ -90,4 +90,26 @@ type RuntimeSettings struct {
|
||||
LogoFile *string `json:"logo_file,omitempty"`
|
||||
LogoHorizontalFile *string `json:"logo_horizontal_file,omitempty"`
|
||||
FaviconFile *string `json:"favicon_file,omitempty"`
|
||||
|
||||
// Cloud-proxy MITM listener. MITMCADir is intentionally NOT
|
||||
// exposed at runtime — the CA dir is a startup-only path and
|
||||
// changing it after the CA has been generated would orphan
|
||||
// trusted clients.
|
||||
MITMListen *string `json:"mitm_listen,omitempty"`
|
||||
|
||||
// PII pattern overrides — keyed by pattern id, applied to the live
|
||||
// redactor at startup and persisted by POST /api/pii/patterns/persist.
|
||||
// Distinguishes from --pii-config (which replaces the entire
|
||||
// pattern set) by only carrying the per-id action/enabled deltas
|
||||
// against the global default catalog.
|
||||
PIIPatternOverrides *map[string]PIIPatternRuntimeOverride `json:"pii_pattern_overrides,omitempty"`
|
||||
}
|
||||
|
||||
// PIIPatternRuntimeOverride captures the persistable deltas an admin
|
||||
// has applied to a single global PII pattern. Both fields are pointers
|
||||
// so an override that only flips Disabled doesn't have to also restate
|
||||
// Action (and vice versa).
|
||||
type PIIPatternRuntimeOverride struct {
|
||||
Action *string `json:"action,omitempty"`
|
||||
Disabled *bool `json:"disabled,omitempty"`
|
||||
}
|
||||
|
||||
@@ -51,6 +51,25 @@ var _ = Describe("RuntimeSettings persistence helpers", func() {
|
||||
})
|
||||
})
|
||||
|
||||
// MITM round trip pins the contract that loadRuntimeSettingsFromFile
|
||||
// MITM listener address must survive a write/read round trip so the
|
||||
// next process restart can bring the listener back up. (Intercept
|
||||
// hosts now live in model YAML rather than runtime_settings.json.)
|
||||
Describe("MITM round trip", func() {
|
||||
It("preserves mitm_listen across read/write", func() {
|
||||
listen := ":8443"
|
||||
Expect(cfg.WritePersistedSettings(config.RuntimeSettings{
|
||||
MITMListen: &listen,
|
||||
})).To(Succeed())
|
||||
|
||||
got, err := cfg.ReadPersistedSettings()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(got.MITMListen).ToNot(BeNil())
|
||||
Expect(*got.MITMListen).To(Equal(":8443"))
|
||||
})
|
||||
})
|
||||
|
||||
// PreserveOnSaveDoesNotClobberAssets reproduces the user-reported
|
||||
// regression: an admin uploads a logo, then clicks Save on the
|
||||
// Settings page. The Save body still has the stale pre-upload
|
||||
|
||||
0
core/explorer/empty_db.json.lock
Normal file
0
core/explorer/empty_db.json.lock
Normal file
0
core/explorer/test_db.json.lock
Normal file
0
core/explorer/test_db.json.lock
Normal file
@@ -25,10 +25,8 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/finetune"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/quantization"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
@@ -213,19 +211,18 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
e.Use(middleware.Recover())
|
||||
}
|
||||
|
||||
// Metrics middleware
|
||||
if !application.ApplicationConfig().DisableMetrics {
|
||||
metricsService, err := monitoring.NewLocalAIMetricsService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if metricsService != nil {
|
||||
e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||
e.Server.RegisterOnShutdown(func() {
|
||||
metricsService.Shutdown()
|
||||
})
|
||||
}
|
||||
// Metrics middleware. The metric service was created in
|
||||
// application.start() so the OTel global provider is set before any
|
||||
// counter is registered (the routing-module billing recorder relies
|
||||
// on this). We reuse that instance here rather than calling
|
||||
// monitoring.NewLocalAIMetricsService a second time, which would
|
||||
// create a second provider, second prometheus exporter, and orphan
|
||||
// whichever instance lost the SetMeterProvider race.
|
||||
if metricsService := application.MetricsService(); metricsService != nil {
|
||||
e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||
e.Server.RegisterOnShutdown(func() {
|
||||
_ = metricsService.Shutdown()
|
||||
})
|
||||
}
|
||||
|
||||
// Health Checks should always be exempt from auth, so register these first
|
||||
@@ -268,13 +265,9 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
e.Static("/generated-videos", videoPath)
|
||||
}
|
||||
|
||||
// Initialize usage recording when auth DB is available, and ensure the
|
||||
// batcher drains its in-memory queue on graceful shutdown so the last
|
||||
// few seconds of usage don't disappear when the process exits.
|
||||
if application.AuthDB() != nil {
|
||||
httpMiddleware.InitUsageRecorder(application.AuthDB())
|
||||
signals.RegisterGracefulTerminationHandler(httpMiddleware.ShutdownUsageRecorder)
|
||||
}
|
||||
// Usage recording is initialised in application/startup.go and
|
||||
// surfaced via application.StatsRecorder(); routes wire UsageMiddleware
|
||||
// against that recorder regardless of auth state.
|
||||
|
||||
// Auth is applied to _all_ endpoints. Filtering out endpoints to bypass is
|
||||
// the role of the exempt-path logic inside the middleware.
|
||||
@@ -361,12 +354,33 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
// Register auth routes (login, callback, API keys, user management)
|
||||
routes.RegisterAuthRoutes(e, application)
|
||||
|
||||
// Register routing-module usage endpoints. Unlike /api/auth/usage
|
||||
// these go through the StatsRecorder and work in no-auth single-user
|
||||
// mode by attributing requests to the synthetic "local" user.
|
||||
routes.RegisterUsageRoutes(e, application)
|
||||
routes.RegisterPIIRoutes(e, application)
|
||||
routes.RegisterMiddlewareRoutes(e, application)
|
||||
|
||||
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
|
||||
var opcache *galleryop.OpCache
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
opcache = galleryop.NewOpCache(application.GalleryService())
|
||||
// In distributed mode, wire the NATS client + gallery store so this
|
||||
// replica's OpCache stays in sync with peers — without this the
|
||||
// /api/operations endpoint returns whatever this single replica
|
||||
// happened to admit, and a load-balanced UI poll alternates between
|
||||
// "operation visible" and "operation gone" between replicas.
|
||||
if d := application.Distributed(); d != nil {
|
||||
opcache.SetMessagingClient(d.Nats)
|
||||
if d.DistStores != nil && d.DistStores.Gallery != nil {
|
||||
opcache.SetGalleryStore(d.DistStores.Gallery)
|
||||
}
|
||||
if err := opcache.Start(application.ApplicationConfig().Context); err != nil {
|
||||
xlog.Warn("OpCache distributed subscribe failed; running standalone", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP)
|
||||
|
||||
@@ -17,6 +17,18 @@ const (
|
||||
)
|
||||
|
||||
// UsageRecord represents a single API request's token usage.
|
||||
//
|
||||
// Model semantics: Model is the legacy column kept for backward-compatible
|
||||
// aggregation; new code should write RequestedModel (what the client asked
|
||||
// for) and ServedModel (what actually ran after routing). When no router
|
||||
// is in play, all three are equal.
|
||||
//
|
||||
// PreFilterPromptTokens vs PromptTokens: PromptTokens is the count after
|
||||
// PII redaction (i.e., what the backend processed and was billed for).
|
||||
// PreFilterPromptTokens is the count of the original prompt before any
|
||||
// PII filtering; PostFilterPromptTokens duplicates PromptTokens for
|
||||
// queryability symmetry. For non-PII paths PreFilterPromptTokens ==
|
||||
// PostFilterPromptTokens == PromptTokens.
|
||||
type UsageRecord struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||
UserID string `gorm:"size:36;index:idx_usage_user_time"`
|
||||
@@ -37,6 +49,22 @@ type UsageRecord struct {
|
||||
TotalTokens int64
|
||||
Duration int64 // milliseconds
|
||||
CreatedAt time.Time `gorm:"index:idx_usage_user_time"`
|
||||
|
||||
// Routing extension fields. Nullable / zero-valued for legacy rows.
|
||||
RequestedModel string `gorm:"size:255;index"`
|
||||
ServedModel string `gorm:"size:255;index"`
|
||||
PreFilterPromptTokens int64 // tokens the client sent before PII redaction
|
||||
PostFilterPromptTokens int64 // tokens after redaction (== PromptTokens unless filter shrunk it)
|
||||
CachedTokens int64 // backend-reported KV-cache hit tokens
|
||||
PrefillTokens int64 // backend-reported prefill tokens (subset of prompt)
|
||||
DraftTokens int64 // speculative-decoding draft tokens
|
||||
PricingVersionID string `gorm:"size:64;index"` // FK to pricing_version; "" when no pricing was applied
|
||||
CostUSD float64 // computed at insert when pricing is available; 0 with empty PricingVersionID = unknown
|
||||
|
||||
// Cross-subsystem correlation. Empty when the subsystem didn't run.
|
||||
CorrelationID string `gorm:"size:64;index"`
|
||||
RouterDecisionID string `gorm:"size:64;index"`
|
||||
PIIEventID string `gorm:"size:64"`
|
||||
}
|
||||
|
||||
// RecordUsage inserts a usage record.
|
||||
|
||||
13
core/http/endpoints/anthropic/anthropic_suite_test.go
Normal file
13
core/http/endpoints/anthropic/anthropic_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestAnthropic(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Anthropic test suite")
|
||||
}
|
||||
@@ -10,10 +10,13 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -27,7 +30,7 @@ import (
|
||||
// @Param request body schema.AnthropicRequest true "query params"
|
||||
// @Success 200 {object} schema.AnthropicResponse "Response"
|
||||
// @Router /v1/messages [post]
|
||||
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient) echo.HandlerFunc {
|
||||
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, piiRedactor *pii.Redactor, piiEvents pii.EventStore) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := uuid.New().String()
|
||||
|
||||
@@ -47,6 +50,12 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
|
||||
|
||||
xlog.Debug("Anthropic Messages endpoint configuration read", "config", cfg)
|
||||
|
||||
// Cloud-proxy bail. Same shape as the OpenAI chat endpoint —
|
||||
// forwards via the cloud-proxy gRPC backend.
|
||||
if cfg.IsCloudProxyBackendPassthrough() {
|
||||
return forwardCloudProxyAnthropicViaBackend(c, cfg, input, piiRedactor, piiEvents, ml, appConfig)
|
||||
}
|
||||
|
||||
// Convert Anthropic messages to OpenAI format for internal processing
|
||||
openAIMessages := convertAnthropicToOpenAIMessages(input)
|
||||
|
||||
@@ -132,7 +141,7 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
|
||||
xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput)
|
||||
|
||||
if input.Stream {
|
||||
return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
|
||||
return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator, piiRedactor, piiEvents)
|
||||
}
|
||||
|
||||
return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
|
||||
@@ -313,17 +322,45 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
xlog.Debug("Anthropic Response", "response", string(respData))
|
||||
}
|
||||
|
||||
middleware.StampUsage(c, input.Model, tokenUsage.Prompt, tokenUsage.Completion)
|
||||
|
||||
return c.JSON(200, resp)
|
||||
} // end MCP iteration loop
|
||||
|
||||
return sendAnthropicError(c, 500, "api_error", "MCP iteration limit reached")
|
||||
}
|
||||
|
||||
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error {
|
||||
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator, piiRedactor *pii.Redactor, piiEvents pii.EventStore) error {
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Per-stream PII filter — same gating as the OpenAI chat path. The
|
||||
// filter is wire-format-agnostic; we feed it the text portion of
|
||||
// each text_delta and emit only what's safe to send. The filter
|
||||
// holds back a tail of size MaxPatternLength-1 so a pattern split
|
||||
// across chunk boundaries still gets masked. When PII is disabled
|
||||
// for this model the filter is nil and emits flow unchanged.
|
||||
var streamPIIFilter *pii.StreamFilter
|
||||
if piiRedactor != nil && cfg.PIIIsEnabled() {
|
||||
correlationID := c.Request().Header.Get("x-request-id")
|
||||
userID := ""
|
||||
if u := auth.GetUser(c); u != nil {
|
||||
userID = u.ID
|
||||
}
|
||||
var overrides map[string]pii.Action
|
||||
if raw := cfg.PIIPatternOverrides(); len(raw) > 0 {
|
||||
overrides = make(map[string]pii.Action, len(raw))
|
||||
for ovid, action := range raw {
|
||||
switch pii.Action(action) {
|
||||
case pii.ActionMask, pii.ActionBlock, pii.ActionRouteLocal:
|
||||
overrides[ovid] = pii.Action(action)
|
||||
}
|
||||
}
|
||||
}
|
||||
streamPIIFilter = pii.NewStreamFilter(piiRedactor, overrides, piiEvents, correlationID, userID)
|
||||
}
|
||||
|
||||
// Send message_start event
|
||||
messageStart := schema.AnthropicStreamEvent{
|
||||
Type: "message_start",
|
||||
@@ -403,6 +440,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
|
||||
if len(toolCalls) > toolCallsEmitted {
|
||||
if !inToolCall && currentBlockIndex == 0 {
|
||||
drainStreamPIIToText(c, streamPIIFilter, intPtr(currentBlockIndex))
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(currentBlockIndex),
|
||||
@@ -443,14 +481,20 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
}
|
||||
|
||||
if !inToolCall && token != "" {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: token,
|
||||
},
|
||||
})
|
||||
out := token
|
||||
if streamPIIFilter != nil {
|
||||
out = streamPIIFilter.Push(token)
|
||||
}
|
||||
if out != "" {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: out,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -488,14 +532,20 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
// didn't already stream it (autoparser clears raw text, so
|
||||
// accumulatedContent will be empty in that case).
|
||||
if deltaContent != "" && !inToolCall && accumulatedContent == "" {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: deltaContent,
|
||||
},
|
||||
})
|
||||
out := deltaContent
|
||||
if streamPIIFilter != nil {
|
||||
out = streamPIIFilter.Push(deltaContent)
|
||||
}
|
||||
if out != "" {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: out,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Emit tool_use blocks from ChatDeltas
|
||||
@@ -503,6 +553,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
collectedToolCalls = deltaToolCalls
|
||||
|
||||
if !inToolCall && currentBlockIndex == 0 {
|
||||
drainStreamPIIToText(c, streamPIIFilter, intPtr(currentBlockIndex))
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(currentBlockIndex),
|
||||
@@ -606,7 +657,9 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
if !shouldUseFn && cfg.FunctionsConfig.AutomaticToolParsingFallback && accumulatedContent != "" && toolCallsEmitted == 0 {
|
||||
parsed := functions.ParseFunctionCall(accumulatedContent, cfg.FunctionsConfig)
|
||||
if len(parsed) > 0 {
|
||||
// Close the text content block
|
||||
// Close the text content block (after flushing any
|
||||
// residual the streaming PII filter held back).
|
||||
drainStreamPIIToText(c, streamPIIFilter, intPtr(currentBlockIndex))
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(currentBlockIndex),
|
||||
@@ -646,8 +699,12 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
}
|
||||
}
|
||||
|
||||
// No MCP tools to execute, close stream
|
||||
// No MCP tools to execute, close stream. drainStreamPIIToText
|
||||
// flushes any residual the streaming PII filter held back as
|
||||
// part of its trailing pattern-window before we close the
|
||||
// text content block.
|
||||
if !inToolCall {
|
||||
drainStreamPIIToText(c, streamPIIFilter, intPtr(0))
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(0),
|
||||
@@ -673,6 +730,8 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
Type: "message_stop",
|
||||
})
|
||||
|
||||
middleware.StampUsage(c, input.Model, tokenUsage.Prompt, tokenUsage.Completion)
|
||||
|
||||
return nil
|
||||
} // end MCP iteration loop
|
||||
|
||||
@@ -693,6 +752,30 @@ func convertFuncsToOpenAITools(funcs functions.Functions) []functions.Tool {
|
||||
|
||||
func intPtr(i int) *int { return &i }
|
||||
|
||||
// drainStreamPIIToText flushes any residual the streaming PII filter
|
||||
// has been holding back as part of its trailing pattern-window, and
|
||||
// emits it as one final text_delta into the named block before the
|
||||
// caller closes that block. Drain is idempotent: calling it twice on
|
||||
// the same filter returns "" the second time. Safe to call with a nil
|
||||
// filter (no-op).
|
||||
func drainStreamPIIToText(c echo.Context, sf *pii.StreamFilter, index *int) {
|
||||
if sf == nil {
|
||||
return
|
||||
}
|
||||
residual := sf.Drain()
|
||||
if residual == "" {
|
||||
return
|
||||
}
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: index,
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: residual,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
@@ -888,3 +971,19 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf
|
||||
|
||||
return funcs, len(funcs) > 0 && cfg.ShouldUseFunctions()
|
||||
}
|
||||
|
||||
// forwardCloudProxyAnthropicViaBackend marshals the Anthropic request,
|
||||
// constructs the streaming PII filter (when applicable), and hands the
|
||||
// body off to the cloud-proxy gRPC backend. Model swap + upstream auth
|
||||
// headers are applied inside the backend; the filter is built here
|
||||
// because the auth/correlation context only exists in the echo handler.
|
||||
func forwardCloudProxyAnthropicViaBackend(c echo.Context, cfg *config.ModelConfig, input *schema.AnthropicRequest, piiRedactor *pii.Redactor, piiEvents pii.EventStore, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
body, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return sendAnthropicError(c, 400, "invalid_request_error", "cloudproxy: marshal request: "+err.Error())
|
||||
}
|
||||
|
||||
correlationID := c.Request().Header.Get("x-request-id")
|
||||
streamFilter := cloudproxy.BuildStreamFilter(c, cfg, input.Stream, piiRedactor, piiEvents, correlationID)
|
||||
return cloudproxy.ForwardViaBackend(c, cfg, body, streamFilter, ml, appConfig)
|
||||
}
|
||||
|
||||
114
core/http/endpoints/anthropic/messages_pii_test.go
Normal file
114
core/http/endpoints/anthropic/messages_pii_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// drainStreamPIIToText is called from four sites in messages.go and is
|
||||
// the load-bearing primitive for "the streaming filter has buffered
|
||||
// some bytes that the request just ended on; flush them as a final
|
||||
// text_delta event before closing the content block". A regression
|
||||
// here would silently truncate the last few bytes of an assistant
|
||||
// response on every PII-enabled stream — invisible without coverage.
|
||||
|
||||
// newTestFilter compiles the default patterns and returns a filter
|
||||
// that holds back its trailing pattern-window; pushing a short string
|
||||
// (shorter than holdLen) keeps the bytes inside Drain.
|
||||
func newTestFilter() *pii.StreamFilter {
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred())
|
||||
red := pii.NewRedactor(patterns)
|
||||
return pii.NewStreamFilter(red, nil, nil, "", "")
|
||||
}
|
||||
|
||||
// newTestContext builds a recording echo context — the recorder
|
||||
// captures the SSE bytes drainStreamPIIToText writes.
|
||||
func newTestContext() (echo.Context, *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader("{}"))
|
||||
rec := httptest.NewRecorder()
|
||||
return echo.New().NewContext(req, rec), rec
|
||||
}
|
||||
|
||||
var _ = Describe("drainStreamPIIToText", func() {
|
||||
It("is a no-op when the filter is nil", func() {
|
||||
c, rec := newTestContext()
|
||||
drainStreamPIIToText(c, nil, intPtr(0))
|
||||
Expect(rec.Body.Len()).To(Equal(0), "nil filter wrote %d bytes: %q", rec.Body.Len(), rec.Body.String())
|
||||
})
|
||||
|
||||
It("emits nothing when the drain is empty", func() {
|
||||
// A filter with nothing buffered should not emit a phantom event;
|
||||
// otherwise every non-PII response would close with an empty
|
||||
// text_delta that pollutes downstream parsers.
|
||||
sf := newTestFilter()
|
||||
c, rec := newTestContext()
|
||||
drainStreamPIIToText(c, sf, intPtr(0))
|
||||
Expect(rec.Body.Len()).To(Equal(0), "empty drain wrote %d bytes: %q", rec.Body.Len(), rec.Body.String())
|
||||
})
|
||||
|
||||
It("flushes residual buffered bytes as a text_delta event", func() {
|
||||
sf := newTestFilter()
|
||||
// Push less than holdLen so all bytes are retained until Drain.
|
||||
// "tail" is short enough that no pattern is plausible.
|
||||
out := sf.Push("tail")
|
||||
Expect(out).To(Equal(""), "Push of short text emitted %q; want all bytes held", out)
|
||||
|
||||
c, rec := newTestContext()
|
||||
drainStreamPIIToText(c, sf, intPtr(2))
|
||||
|
||||
body := rec.Body.String()
|
||||
// Wire format: "event: content_block_delta\ndata: {…}\n\n"
|
||||
Expect(body).To(ContainSubstring("event: content_block_delta"))
|
||||
Expect(body).To(ContainSubstring(`"type":"content_block_delta"`))
|
||||
Expect(body).To(ContainSubstring(`"index":2`))
|
||||
Expect(body).To(ContainSubstring(`"text":"tail"`))
|
||||
Expect(body).To(ContainSubstring(`"type":"text_delta"`))
|
||||
Expect(strings.HasSuffix(body, "\n\n")).To(BeTrue(), "SSE event missing trailing blank line: %q", body)
|
||||
})
|
||||
|
||||
It("is idempotent across consecutive drains", func() {
|
||||
// Two consecutive Drains: the filter returns "" the second time,
|
||||
// so the second drainStreamPIIToText must emit nothing. The
|
||||
// production path in messages.go has at least four call sites
|
||||
// that may overlap (currentBlockIndex==0 emergency path + the
|
||||
// unconditional drain near the end of the stream); without
|
||||
// idempotence we'd duplicate the residual on the wire.
|
||||
sf := newTestFilter()
|
||||
sf.Push("tail")
|
||||
|
||||
c1, rec1 := newTestContext()
|
||||
drainStreamPIIToText(c1, sf, intPtr(0))
|
||||
first := rec1.Body.Len()
|
||||
Expect(first).NotTo(Equal(0), "first drain emitted nothing")
|
||||
|
||||
c2, rec2 := newTestContext()
|
||||
drainStreamPIIToText(c2, sf, intPtr(0))
|
||||
Expect(rec2.Body.Len()).To(Equal(0), "second drain wrote %d bytes; want idempotent no-op: %q", rec2.Body.Len(), rec2.Body.String())
|
||||
})
|
||||
|
||||
It("masks redacted residual instead of leaking it", func() {
|
||||
// The held tail must travel through the redactor on Drain. If
|
||||
// the bytes happen to form a complete pattern at end-of-stream,
|
||||
// the residual emit must contain the mask placeholder, not the
|
||||
// raw value.
|
||||
sf := newTestFilter()
|
||||
// "alice@example.com" is 17 bytes. holdLen for default patterns
|
||||
// is well above 17, so this stays buffered until Drain, which
|
||||
// then redacts it.
|
||||
out := sf.Push("alice@example.com")
|
||||
Expect(out).To(Equal(""), "Push emitted bytes early: %q", out)
|
||||
|
||||
c, rec := newTestContext()
|
||||
drainStreamPIIToText(c, sf, intPtr(0))
|
||||
body := rec.Body.String()
|
||||
Expect(body).NotTo(ContainSubstring("alice@example.com"), "raw email leaked in residual emit: %q", body)
|
||||
Expect(body).To(ContainSubstring("[REDACTED:email]"), "residual emit missing mask placeholder: %q", body)
|
||||
})
|
||||
})
|
||||
@@ -92,6 +92,30 @@ var instructionDefs = []instructionDef{
|
||||
Tags: []string{"branding"},
|
||||
Intro: "GET /api/branding is public so the login screen can render the configured logo before authentication. Text fields are saved through POST /api/settings; binary assets (logo, horizontal logo, favicon) use multipart upload at /api/branding/asset/{kind} and are served back from /branding/asset/{kind}.",
|
||||
},
|
||||
{
|
||||
Name: "usage-and-billing",
|
||||
Description: "Per-user token usage and request counts, with optional cost tracking",
|
||||
Tags: []string{"usage"},
|
||||
Intro: "GET /api/usage returns the current user's token usage in time-bucketed form (day/week/month/all). In single-user no-auth mode the records are attributed to a synthetic local user with stable UUID, so this endpoint and the dashboard work without --auth. /api/usage/all is the cluster-wide view and requires admin (the local user is admin in single-user mode). UsageRecord fields include RequestedModel/ServedModel and PreFilter/PostFilterPromptTokens for routing- and PII-aware accounting.",
|
||||
},
|
||||
{
|
||||
Name: "pii-filtering",
|
||||
Description: "Inspect and tune the regex PII filter applied to chat requests",
|
||||
Tags: []string{"pii"},
|
||||
Intro: "GET /api/pii/patterns lists the active pattern set with each one's action (mask, block, route_local). GET /api/pii/events returns recent redaction events filtered by correlation_id / user_id / pattern_id (admin or local-user only). POST /api/pii/test dry-runs the redactor against an admin-supplied string. POST /api/pii/decide is the programmatic decision oracle for external routers: send `{text}`, receive `{findings, suggested_action, redacted_preview}` without LocalAI mutating, recording, or acting on the call — caller composes the action with its own policy. Default patterns: email, phone, SSN, credit card (Luhn), IPv4, common API key prefixes (sk-, pk-, ghp_, github_pat_). PII is per-model: by default it is OFF for non-proxy backends and ON for backends starting with proxy-* (cloud passthroughs). Opt in with `pii: { enabled: true }` in a model's YAML; use `pii: { patterns: [{id, action}] }` to upgrade or downgrade individual actions for that model. Override global default actions via --pii-config pii.yaml; --disable-pii turns the filter off entirely.",
|
||||
},
|
||||
{
|
||||
Name: "middleware-admin",
|
||||
Description: "Inspect and configure the routing-module middleware (PII filter and routing)",
|
||||
Tags: []string{"middleware", "pii", "router"},
|
||||
Intro: "GET /api/middleware/status is the single round-trip the /app/middleware admin page reads to render the current state: active PII patterns and their actions, every model's resolved enabled/override state, recent event count, and the active routing models with their classifier configurations. Admin-only (the synthetic local user is admin in no-auth mode). PUT /api/pii/patterns/:id changes a pattern's action in-process — TRANSIENT, lost on restart. To persist, edit --pii-config YAML. GET /api/router/decisions returns the routing decision log filtered by correlation_id / user_id / router_model. The same surface is exposed as MCP tools (`get_middleware_status`, `set_pii_pattern_action`, `get_router_decisions`) for agent-driven configuration.",
|
||||
},
|
||||
{
|
||||
Name: "intelligent-routing",
|
||||
Description: "Per-model `router:` configuration that classifies requests and rewrites the served model",
|
||||
Tags: []string{"router"},
|
||||
Intro: "Add a `router:` block to a ModelConfig to turn it into a routing model. The block declares a classifier (today: `feature` — handcrafted rules over prompt length and code-fence presence), a list of candidates (label + downstream model + optional rule), and a fallback. When a client addresses the routing model, the RouteModel middleware invokes the classifier, picks a candidate, and rewrites input.Model — the standard model-resolution path then runs ACL, disabled-state, and per-model PII against the chosen target. Depth-1 invariant: candidates must NOT themselves carry a `router:` block; runtime check returns 500 on violation. Decisions are logged to GET /api/router/decisions and surfaced in the /app/middleware Routing tab. POST /api/router/decide is the programmatic decision-oracle: external routers (e.g. an organisation-wide router service) send `{router, input}` and receive the classifier's label set + candidate model WITHOUT LocalAI rewriting, forwarding, or recording the call. Shares the classifier cache with the in-band path so warm-up costs are paid once.",
|
||||
},
|
||||
}
|
||||
|
||||
// swaggerState holds parsed swagger spec data, initialised once.
|
||||
|
||||
@@ -39,7 +39,7 @@ var _ = Describe("API Instructions Endpoints", func() {
|
||||
|
||||
instructions, ok := resp["instructions"].([]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(instructions).To(HaveLen(12))
|
||||
Expect(instructions).To(HaveLen(16))
|
||||
|
||||
// Verify each instruction has required fields and correct URL format
|
||||
for _, s := range instructions {
|
||||
@@ -74,6 +74,10 @@ var _ = Describe("API Instructions Endpoints", func() {
|
||||
"monitoring",
|
||||
"agents",
|
||||
"face-recognition",
|
||||
"usage-and-billing",
|
||||
"pii-filtering",
|
||||
"middleware-admin",
|
||||
"intelligent-routing",
|
||||
))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -173,12 +173,12 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
// Validate without calling SetDefaults() — runtime defaults should not
|
||||
// be persisted to disk. SetDefaults() is called when loading configs
|
||||
// for inference via LoadModelConfigsFromPath().
|
||||
if valid, _ := modelConfig.Validate(); !valid {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Invalid configuration",
|
||||
if valid, vErr := modelConfig.Validate(); !valid {
|
||||
msg := "Invalid configuration"
|
||||
if vErr != nil {
|
||||
msg = vErr.Error()
|
||||
}
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
return c.JSON(http.StatusBadRequest, ModelResponse{Success: false, Error: msg})
|
||||
}
|
||||
|
||||
// Create the configuration file
|
||||
|
||||
@@ -61,7 +61,11 @@ func MCPEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
// The legacy /v1/mcp/chat/completions endpoint never opts into the
|
||||
// in-process LocalAI Assistant tool surface — pass nil holder so the
|
||||
// assistant branch in chat.go is unreachable from this code path.
|
||||
chatHandler := openai.ChatEndpoint(cl, ml, evaluator, appConfig, natsClient, nil)
|
||||
// Stream-side PII filter is also nil: this legacy endpoint pre-dates
|
||||
// the per-model PII config and is kept for backward compatibility.
|
||||
// The request-side middleware on the main chat route handles
|
||||
// filtering for the standard /v1/chat/completions path.
|
||||
chatHandler := openai.ChatEndpoint(cl, ml, evaluator, appConfig, natsClient, nil, nil, nil)
|
||||
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
|
||||
85
core/http/endpoints/localai/pii_decide.go
Normal file
85
core/http/endpoints/localai/pii_decide.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
)
|
||||
|
||||
// PIIDecideEndpoint exposes the PII redactor as a decision oracle:
|
||||
// scan the supplied text and return findings + the strongest action
|
||||
// the configured pattern set would take, without rewriting the
|
||||
// caller's request or recording an audit event.
|
||||
//
|
||||
// External routers (e.g. the localai-org/platform router) call this
|
||||
// before dispatching to learn whether to mask the prompt in place,
|
||||
// route to a local-only backend, block the request, or pass it
|
||||
// through. LocalAI's in-band PII middleware is the alternative path
|
||||
// for direct-to-LocalAI clients — same Redactor, different framing.
|
||||
//
|
||||
// Takes the *pii.Redactor directly rather than the whole
|
||||
// *application.Application so the handler stays unit-testable with a
|
||||
// freshly-constructed redactor (mirrors the pattern in
|
||||
// router_decide.go). The route-registration site is responsible for
|
||||
// stubbing this endpoint when --disable-pii is set so callers get a
|
||||
// 503 signalling "admin opted out" rather than a misleading allow.
|
||||
//
|
||||
// @Summary Scan text for PII and return findings + suggested action (decision oracle)
|
||||
// @Tags pii
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body schema.PIIDecideRequest true "decide params"
|
||||
// @Success 200 {object} schema.PIIDecideResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Router /api/pii/decide [post]
|
||||
func PIIDecideEndpoint(redactor *pii.Redactor) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req schema.PIIDecideRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid request body: "+err.Error())
|
||||
}
|
||||
if req.Text == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "text is required")
|
||||
}
|
||||
|
||||
res := redactor.Redact(req.Text)
|
||||
findings := make([]schema.PIIFinding, len(res.Spans))
|
||||
for i, s := range res.Spans {
|
||||
findings[i] = schema.PIIFinding{
|
||||
Start: s.Start,
|
||||
End: s.End,
|
||||
Pattern: s.Pattern,
|
||||
HashPrefix: s.HashPrefix,
|
||||
}
|
||||
}
|
||||
return c.JSON(http.StatusOK, schema.PIIDecideResponse{
|
||||
Findings: findings,
|
||||
SuggestedAction: suggestedAction(res),
|
||||
RedactedPreview: res.Redacted,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// actionAllow is the wire-only value for "no findings". The other
|
||||
// three map to existing pii.Action* constants; allow has no in-band
|
||||
// counterpart because the in-band middleware simply passes through.
|
||||
const actionAllow = "allow"
|
||||
|
||||
// suggestedAction collapses the Redactor's Result flags onto a single
|
||||
// wire-format action using the in-band ordering (block > route_local
|
||||
// > mask > allow). Spans-without-Blocked-or-LocalOnly means every
|
||||
// match resolved to ActionMask.
|
||||
func suggestedAction(res pii.Result) string {
|
||||
switch {
|
||||
case res.Blocked:
|
||||
return string(pii.ActionBlock)
|
||||
case res.LocalOnly:
|
||||
return string(pii.ActionRouteLocal)
|
||||
case len(res.Spans) > 0:
|
||||
return string(pii.ActionMask)
|
||||
default:
|
||||
return actionAllow
|
||||
}
|
||||
}
|
||||
107
core/http/endpoints/localai/pii_decide_test.go
Normal file
107
core/http/endpoints/localai/pii_decide_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// PIIDecideEndpoint exposes the redactor as a decision oracle. These
|
||||
// specs pin the validation surface and the suggested_action mapping
|
||||
// across all four actions (allow/mask/route_local/block). The redactor
|
||||
// itself is covered in core/services/routing/pii/redactor_test.go.
|
||||
|
||||
var _ = Describe("PIIDecideEndpoint", func() {
|
||||
var redactor *pii.Redactor
|
||||
|
||||
BeforeEach(func() {
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
redactor = pii.NewRedactor(patterns)
|
||||
})
|
||||
|
||||
It("rejects requests with no text field", func() {
|
||||
rec, _ := invokePIIDecide(redactor, `{}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
Expect(rec.Body.String()).To(ContainSubstring("text is required"))
|
||||
})
|
||||
|
||||
It("rejects malformed JSON", func() {
|
||||
rec, _ := invokePIIDecide(redactor, `not json`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
})
|
||||
|
||||
It("returns allow for clean text", func() {
|
||||
rec, body := invokePIIDecide(redactor, `{"text":"hello world"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.SuggestedAction).To(Equal("allow"))
|
||||
Expect(body.Findings).To(BeEmpty())
|
||||
Expect(body.RedactedPreview).To(Equal("hello world"))
|
||||
})
|
||||
|
||||
It("returns mask for text containing email (default action)", func() {
|
||||
rec, body := invokePIIDecide(redactor, `{"text":"reach me at alice@example.com please"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.SuggestedAction).To(Equal("mask"))
|
||||
Expect(body.Findings).To(HaveLen(1))
|
||||
Expect(body.Findings[0].Pattern).To(Equal("email"))
|
||||
Expect(body.Findings[0].HashPrefix).NotTo(BeEmpty())
|
||||
Expect(body.RedactedPreview).To(ContainSubstring("[REDACTED:email]"))
|
||||
Expect(body.RedactedPreview).NotTo(ContainSubstring("alice@example.com"))
|
||||
})
|
||||
|
||||
It("returns block when an api_key_prefix is present (block beats mask)", func() {
|
||||
// api_key_prefix defaults to ActionBlock per DefaultPatterns.
|
||||
// Mix in an email so we also confirm the block-action wins
|
||||
// over the mask-action via actionRank.
|
||||
rec, body := invokePIIDecide(redactor, `{"text":"my key is sk-1234567890abcdefghij and email alice@example.com"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.SuggestedAction).To(Equal("block"))
|
||||
Expect(len(body.Findings)).To(BeNumerically(">=", 1))
|
||||
})
|
||||
|
||||
It("returns route_local when an override sets that action", func() {
|
||||
// Promote the email pattern to route_local for this test —
|
||||
// exercises the route_local branch of suggestedAction without
|
||||
// needing a custom pattern set.
|
||||
Expect(redactor.SetAction("email", pii.ActionRouteLocal)).To(Succeed())
|
||||
rec, body := invokePIIDecide(redactor, `{"text":"contact alice@example.com"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.SuggestedAction).To(Equal("route_local"))
|
||||
// route_local leaves the original text intact — caller decides
|
||||
// whether to forward it to a local-only backend.
|
||||
Expect(body.RedactedPreview).To(ContainSubstring("alice@example.com"))
|
||||
})
|
||||
|
||||
It("never leaks the matched value via HashPrefix", func() {
|
||||
rec, body := invokePIIDecide(redactor, `{"text":"alice@example.com"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Findings).To(HaveLen(1))
|
||||
// HashPrefix is 8 hex chars of sha256 — definitely not the
|
||||
// matched value, but stable so admins can correlate leaks.
|
||||
Expect(body.Findings[0].HashPrefix).To(HaveLen(8))
|
||||
Expect(body.Findings[0].HashPrefix).NotTo(ContainSubstring("alice"))
|
||||
})
|
||||
})
|
||||
|
||||
func invokePIIDecide(redactor *pii.Redactor, body string) (*httptest.ResponseRecorder, schema.PIIDecideResponse) {
|
||||
e := echo.New()
|
||||
e.POST("/api/pii/decide", localai.PIIDecideEndpoint(redactor))
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/pii/decide", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
var parsed schema.PIIDecideResponse
|
||||
if rec.Code == http.StatusOK {
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &parsed)).To(Succeed())
|
||||
}
|
||||
return rec, parsed
|
||||
}
|
||||
109
core/http/endpoints/localai/router_decide.go
Normal file
109
core/http/endpoints/localai/router_decide.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
)
|
||||
|
||||
// RouterDecideEndpoint exposes the routing classifier as a decision
|
||||
// oracle: given a router model and a prompt, it runs the same
|
||||
// classifier the in-band RouteModel middleware would have run, returns
|
||||
// the active label set, and resolves which candidate model would have
|
||||
// been picked. It does NOT rewrite anything, forward to a backend, or
|
||||
// write to the decision store — Platform-side routers call this to get
|
||||
// LocalAI's opinion without committing LocalAI to handle the request.
|
||||
//
|
||||
// The classifier is shared with the in-band middleware via the
|
||||
// process-wide router.Registry on deps, so this endpoint and the
|
||||
// request path agree on cache state, embedding-cache hits, etc.
|
||||
//
|
||||
// Takes discrete deps rather than the whole *application.Application so
|
||||
// it stays unit-testable with a stub Scorer and a tmpdir-backed model
|
||||
// loader (mirrors the existing route_model_test.go setup).
|
||||
//
|
||||
// @Summary Classify a prompt against a router model's policies (decision oracle)
|
||||
// @Tags router
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body schema.RouterDecideRequest true "decide params"
|
||||
// @Success 200 {object} schema.RouterDecideResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 404 {object} map[string]string
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Failure 503 {object} map[string]string
|
||||
// @Router /api/router/decide [post]
|
||||
func RouterDecideEndpoint(loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig, deps middleware.ClassifierDeps) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req schema.RouterDecideRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid request body: "+err.Error())
|
||||
}
|
||||
if req.Router == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "router is required")
|
||||
}
|
||||
if req.Input == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "input is required")
|
||||
}
|
||||
|
||||
cfg, err := loader.LoadModelConfigFileByNameDefaultOptions(req.Router, appConfig)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to load model config: "+err.Error())
|
||||
}
|
||||
// LoadModelConfigFileByName returns a synthetic stub
|
||||
// (PredictionOptions.Model only, no Name) when neither an
|
||||
// in-memory config nor a YAML file exists for the requested
|
||||
// name. Use Name to discriminate "model unknown" (404) from
|
||||
// "model known but not a router" (400) — Platform wants both
|
||||
// signals.
|
||||
if cfg == nil || cfg.Name == "" {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "router model not found: "+req.Router)
|
||||
}
|
||||
if !cfg.HasRouter() {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "model "+req.Router+" is not a router (no `router:` block)")
|
||||
}
|
||||
|
||||
// Build (or reuse) the classifier via the same registry the
|
||||
// in-band middleware uses. Errors here are config problems —
|
||||
// classifier_model missing, policy without description, etc. —
|
||||
// so 503 is the right status: the router is configured but its
|
||||
// classifier can't be instantiated right now.
|
||||
classifier, err := middleware.GetOrBuildClassifier(deps.Registry, cfg, deps)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusServiceUnavailable, "classifier unavailable: "+err.Error())
|
||||
}
|
||||
|
||||
decision, err := classifier.Classify(c.Request().Context(), router.Probe{Prompt: req.Input})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "classify failed: "+err.Error())
|
||||
}
|
||||
|
||||
candidate := router.MatchCandidate(cfg.Router.Candidates, decision.Labels)
|
||||
fallback := false
|
||||
if candidate == "" && cfg.Router.Fallback != "" {
|
||||
candidate = cfg.Router.Fallback
|
||||
fallback = true
|
||||
}
|
||||
|
||||
classifierName := cfg.Router.Classifier
|
||||
if classifierName == "" {
|
||||
classifierName = router.ClassifierScore
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, schema.RouterDecideResponse{
|
||||
Router: req.Router,
|
||||
Classifier: classifierName,
|
||||
Labels: decision.Labels,
|
||||
Candidate: candidate,
|
||||
Fallback: fallback,
|
||||
Score: decision.Score,
|
||||
LatencyMs: decision.Latency.Milliseconds(),
|
||||
Cached: decision.Cached,
|
||||
CacheSimilarity: decision.CacheSimilarity,
|
||||
})
|
||||
}
|
||||
}
|
||||
248
core/http/endpoints/localai/router_decide_test.go
Normal file
248
core/http/endpoints/localai/router_decide_test.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// RouterDecideEndpoint is the programmatic decision oracle that
|
||||
// external routers call to get LocalAI's classifier opinion without
|
||||
// committing LocalAI to handle the request. These specs pin the
|
||||
// validation surface and the happy-path / fallback / depth-1
|
||||
// behaviours; the classifier itself is covered in
|
||||
// core/services/routing/router/score_test.go and the in-band
|
||||
// middleware is covered in core/http/middleware/route_model_test.go.
|
||||
|
||||
var _ = Describe("RouterDecideEndpoint", func() {
|
||||
var (
|
||||
modelDir string
|
||||
appConfig *config.ApplicationConfig
|
||||
loader *config.ModelConfigLoader
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
d, err := os.MkdirTemp("", "router-decide-test-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
modelDir = d
|
||||
appConfig = &config.ApplicationConfig{
|
||||
Context: context.Background(),
|
||||
SystemState: &system.SystemState{Model: system.Model{ModelsPath: modelDir}},
|
||||
}
|
||||
loader = config.NewModelConfigLoader(modelDir)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
_ = os.RemoveAll(modelDir)
|
||||
})
|
||||
|
||||
It("rejects requests with no router field", func() {
|
||||
rec, _ := invokeDecide(loader, appConfig, deps(nil), `{"input":"hello"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
Expect(rec.Body.String()).To(ContainSubstring("router is required"))
|
||||
})
|
||||
|
||||
It("rejects requests with no input field", func() {
|
||||
rec, _ := invokeDecide(loader, appConfig, deps(nil), `{"router":"smart-router"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
Expect(rec.Body.String()).To(ContainSubstring("input is required"))
|
||||
})
|
||||
|
||||
It("returns 404 for an unknown router model", func() {
|
||||
rec, _ := invokeDecide(loader, appConfig, deps(nil), `{"router":"missing","input":"hello"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusNotFound))
|
||||
Expect(rec.Body.String()).To(ContainSubstring("router model not found"))
|
||||
})
|
||||
|
||||
It("returns 400 when the named model has no router block", func() {
|
||||
writeBareModel(modelDir, "plain-model")
|
||||
rec, _ := invokeDecide(loader, appConfig, deps(nil), `{"router":"plain-model","input":"hello"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
Expect(rec.Body.String()).To(ContainSubstring("is not a router"))
|
||||
})
|
||||
|
||||
It("returns 503 when the classifier can't be built (no scorer wired)", func() {
|
||||
writeScoreRouter(modelDir, "smart-router")
|
||||
writeBareModel(modelDir, "small-model")
|
||||
writeBareModel(modelDir, "big-model")
|
||||
// deps(nil) provides no scorer — buildClassifier returns an
|
||||
// error and the handler maps that to 503.
|
||||
rec, _ := invokeDecide(loader, appConfig, deps(nil), `{"router":"smart-router","input":"hello"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusServiceUnavailable))
|
||||
Expect(rec.Body.String()).To(ContainSubstring("classifier unavailable"))
|
||||
})
|
||||
|
||||
It("returns the picked candidate when one covers the active labels", func() {
|
||||
writeScoreRouter(modelDir, "smart-router")
|
||||
writeBareModel(modelDir, "small-model")
|
||||
writeBareModel(modelDir, "big-model")
|
||||
scorer := &stubScorer{labelToLogProb: map[string]float64{
|
||||
"code-generation": -0.05, // dominant
|
||||
"casual-chat": -3.0,
|
||||
"math-reasoning": -4.0,
|
||||
}}
|
||||
rec, body := invokeDecide(loader, appConfig, deps(scorer), `{"router":"smart-router","input":"debug my Go null pointer"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Candidate).To(Equal("big-model"))
|
||||
Expect(body.Fallback).To(BeFalse())
|
||||
Expect(body.Labels).To(ContainElement("code-generation"))
|
||||
Expect(body.Classifier).To(Equal(router.ClassifierScore))
|
||||
Expect(body.Score).To(BeNumerically(">", 0))
|
||||
})
|
||||
|
||||
It("returns the fallback when no candidate covers the active labels", func() {
|
||||
// The router declares a label `math-reasoning` but no
|
||||
// candidate carries it — only small=[casual-chat] and
|
||||
// big=[code-generation, casual-chat]. A classifier output of
|
||||
// "math-reasoning" forces the fallback path.
|
||||
writeRouterNoFallbackCover(modelDir, "smart-router")
|
||||
writeBareModel(modelDir, "small-model")
|
||||
writeBareModel(modelDir, "big-model")
|
||||
writeBareModel(modelDir, "fallback-model")
|
||||
scorer := &stubScorer{labelToLogProb: map[string]float64{
|
||||
"math-reasoning": -0.05,
|
||||
"code-generation": -3.0,
|
||||
"casual-chat": -4.0,
|
||||
}}
|
||||
rec, body := invokeDecide(loader, appConfig, deps(scorer), `{"router":"smart-router","input":"3 apples cost $2.40"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Candidate).To(Equal("fallback-model"))
|
||||
Expect(body.Fallback).To(BeTrue())
|
||||
Expect(body.Labels).To(ContainElement("math-reasoning"))
|
||||
})
|
||||
})
|
||||
|
||||
// stubScorer mirrors the one in core/http/middleware/route_model_test.go.
|
||||
// Duplicated rather than exported because Go test helpers don't cross
|
||||
// _test.go package boundaries and exporting test-only types would
|
||||
// pollute the production surface.
|
||||
type stubScorer struct {
|
||||
labelToLogProb map[string]float64
|
||||
}
|
||||
|
||||
func (s *stubScorer) Score(_ context.Context, _ string, candidates []string) ([]backend.CandidateScore, error) {
|
||||
out := make([]backend.CandidateScore, len(candidates))
|
||||
for i, c := range candidates {
|
||||
// Candidate is the Arch-Router JSON envelope
|
||||
// `{"route": "<label>"}<stop>`; match against the full
|
||||
// envelope so overlapping labels (e.g. `code` vs
|
||||
// `code-generation`) can't collide under Go's randomised map
|
||||
// iteration. Without this the lookup misses on every
|
||||
// candidate and softmax flattens, making assertions pass for
|
||||
// accidental reasons.
|
||||
var lp float64
|
||||
for label, v := range s.labelToLogProb {
|
||||
if strings.Contains(c, `{"route": "`+label+`"}`) {
|
||||
lp = v
|
||||
break
|
||||
}
|
||||
}
|
||||
out[i] = backend.CandidateScore{
|
||||
LogProb: lp * 2,
|
||||
LengthNormalizedLogProb: lp,
|
||||
NumTokens: 2,
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// deps wires a ClassifierDeps with a fresh registry and (optionally) a
|
||||
// stub scorer. Nil scorer is used to exercise the unavailable path.
|
||||
func deps(s *stubScorer) middleware.ClassifierDeps {
|
||||
var scorer middleware.ScorerFactory
|
||||
if s != nil {
|
||||
scorer = func(string) backend.Scorer { return s }
|
||||
}
|
||||
return middleware.ClassifierDeps{
|
||||
Scorer: scorer,
|
||||
Registry: router.NewRegistry(),
|
||||
}
|
||||
}
|
||||
|
||||
func invokeDecide(loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig, d middleware.ClassifierDeps, body string) (*httptest.ResponseRecorder, schema.RouterDecideResponse) {
|
||||
// Route through echo's mux so the default HTTPErrorHandler
|
||||
// serialises echo.HTTPError into the response body. Calling the
|
||||
// handler directly with a fresh Context skips that step and
|
||||
// leaves the recorder empty on errors.
|
||||
e := echo.New()
|
||||
e.POST("/api/router/decide", localai.RouterDecideEndpoint(loader, appConfig, d))
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/router/decide", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
var parsed schema.RouterDecideResponse
|
||||
if rec.Code == http.StatusOK {
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &parsed)).To(Succeed())
|
||||
}
|
||||
return rec, parsed
|
||||
}
|
||||
|
||||
func writeScoreRouter(modelDir, name string) {
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Router: config.RouterConfig{
|
||||
Classifier: "score",
|
||||
ClassifierModel: "arch-router",
|
||||
Fallback: "small-model",
|
||||
Policies: []config.RouterPolicy{
|
||||
{Label: "code-generation", Description: "writing or debugging code"},
|
||||
{Label: "casual-chat", Description: "small talk"},
|
||||
{Label: "math-reasoning", Description: "arithmetic and word problems"},
|
||||
},
|
||||
Candidates: []config.RouterCandidate{
|
||||
{Model: "small-model", Labels: []string{"casual-chat"}},
|
||||
{Model: "big-model", Labels: []string{"code-generation", "casual-chat", "math-reasoning"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, err := yaml.Marshal(cfg)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, name+".yaml"), b, 0o644)).To(Succeed())
|
||||
}
|
||||
|
||||
// writeRouterNoFallbackCover declares math-reasoning as a policy but
|
||||
// has no candidate covering it. Combined with Fallback=fallback-model,
|
||||
// a math-reasoning classification forces the fallback branch.
|
||||
func writeRouterNoFallbackCover(modelDir, name string) {
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Router: config.RouterConfig{
|
||||
Classifier: "score",
|
||||
ClassifierModel: "arch-router",
|
||||
Fallback: "fallback-model",
|
||||
Policies: []config.RouterPolicy{
|
||||
{Label: "code-generation", Description: "writing or debugging code"},
|
||||
{Label: "casual-chat", Description: "small talk"},
|
||||
{Label: "math-reasoning", Description: "arithmetic and word problems"},
|
||||
},
|
||||
Candidates: []config.RouterCandidate{
|
||||
{Model: "small-model", Labels: []string{"casual-chat"}},
|
||||
{Model: "big-model", Labels: []string{"code-generation", "casual-chat"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, err := yaml.Marshal(cfg)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, name+".yaml"), b, 0o644)).To(Succeed())
|
||||
}
|
||||
|
||||
func writeBareModel(modelDir, name string) {
|
||||
body := "name: " + name + "\nbackend: mock-backend\n"
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, name+".yaml"), []byte(body), 0o644)).To(Succeed())
|
||||
}
|
||||
90
core/http/endpoints/localai/score.go
Normal file
90
core/http/endpoints/localai/score.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// ScoreRequest is the wire format for POST /api/score. Mirrors the
|
||||
// gRPC ScoreRequest one-to-one — the endpoint exists primarily to
|
||||
// smoke-test the new Score primitive end-to-end without writing a
|
||||
// custom gRPC client. Production routing will call backend.ModelScore
|
||||
// directly via the router-side adapter.
|
||||
type ScoreRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Candidates []string `json:"candidates"`
|
||||
IncludeTokenLogprobs bool `json:"include_token_logprobs,omitempty"`
|
||||
LengthNormalize bool `json:"length_normalize,omitempty"`
|
||||
}
|
||||
|
||||
type ScoreResponseCandidate struct {
|
||||
LogProb float64 `json:"log_prob"`
|
||||
LengthNormalizedLogProb float64 `json:"length_normalized_log_prob,omitempty"`
|
||||
NumTokens int `json:"num_tokens"`
|
||||
Tokens []ScoreTokenLP `json:"tokens,omitempty"`
|
||||
}
|
||||
|
||||
type ScoreTokenLP struct {
|
||||
Token string `json:"token"`
|
||||
LogProb float64 `json:"log_prob"`
|
||||
}
|
||||
|
||||
type ScoreResponse struct {
|
||||
Model string `json:"model"`
|
||||
Candidates []ScoreResponseCandidate `json:"candidates"`
|
||||
}
|
||||
|
||||
// ScoreEndpoint exposes the Score gRPC primitive over HTTP. Admin-only —
|
||||
// scoring loads a model and runs inference, same risk surface as
|
||||
// /v1/chat/completions.
|
||||
func ScoreEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req ScoreRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(400, "invalid request body: "+err.Error())
|
||||
}
|
||||
if req.Model == "" {
|
||||
return echo.NewHTTPError(400, "model is required")
|
||||
}
|
||||
if len(req.Candidates) == 0 {
|
||||
return echo.NewHTTPError(400, "candidates must be non-empty")
|
||||
}
|
||||
|
||||
modelConfig, err := cl.LoadModelConfigFileByNameDefaultOptions(req.Model, appConfig)
|
||||
if err != nil || modelConfig == nil {
|
||||
return echo.NewHTTPError(404, "model not found: "+req.Model)
|
||||
}
|
||||
|
||||
fn, err := backend.ModelScore(req.Prompt, req.Candidates, backend.ScoreOptions{
|
||||
IncludeTokenLogprobs: req.IncludeTokenLogprobs,
|
||||
LengthNormalize: req.LengthNormalize,
|
||||
}, ml, *modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(500, "failed to bind scorer: "+err.Error())
|
||||
}
|
||||
results, err := fn(c.Request().Context())
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(500, "score call failed: "+err.Error())
|
||||
}
|
||||
|
||||
out := ScoreResponse{Model: req.Model, Candidates: make([]ScoreResponseCandidate, len(results))}
|
||||
for i, r := range results {
|
||||
out.Candidates[i] = ScoreResponseCandidate{
|
||||
LogProb: r.LogProb,
|
||||
LengthNormalizedLogProb: r.LengthNormalizedLogProb,
|
||||
NumTokens: r.NumTokens,
|
||||
}
|
||||
if req.IncludeTokenLogprobs && len(r.Tokens) > 0 {
|
||||
toks := make([]ScoreTokenLP, len(r.Tokens))
|
||||
for j, t := range r.Tokens {
|
||||
toks[j] = ScoreTokenLP{Token: t.Token, LogProb: t.LogProb}
|
||||
}
|
||||
out.Candidates[i].Tokens = toks
|
||||
}
|
||||
}
|
||||
return c.JSON(200, out)
|
||||
}
|
||||
}
|
||||
@@ -253,6 +253,16 @@ func UpdateSettingsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
if settings.MITMListen != nil {
|
||||
if err := app.RestartMITM(); err != nil {
|
||||
xlog.Error("Failed to restart MITM proxy", "error", err)
|
||||
return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{
|
||||
Success: false,
|
||||
Error: "Settings saved but failed to restart MITM proxy: " + err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Restart P2P if P2P settings changed
|
||||
p2pChanged := settings.P2PToken != nil || settings.P2PNetworkID != nil || settings.Federated != nil
|
||||
if p2pChanged {
|
||||
|
||||
@@ -74,6 +74,34 @@ func (stubClient) GetBranding(_ context.Context) (*localaitools.Branding, error)
|
||||
func (stubClient) SetBranding(_ context.Context, _ localaitools.SetBrandingRequest) (*localaitools.Branding, error) {
|
||||
return &localaitools.Branding{InstanceName: "LocalAI"}, nil
|
||||
}
|
||||
func (stubClient) GetUsageStats(_ context.Context, _ localaitools.UsageStatsQuery) (*localaitools.UsageStats, error) {
|
||||
return &localaitools.UsageStats{Viewer: localaitools.UsageViewer{ID: "stub", Name: "stub"}, Period: "month"}, nil
|
||||
}
|
||||
func (stubClient) ListPIIPatterns(_ context.Context) ([]localaitools.PIIPattern, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (stubClient) GetPIIEvents(_ context.Context, _ localaitools.PIIEventsQuery) ([]localaitools.PIIEvent, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (stubClient) TestPIIRedaction(_ context.Context, req localaitools.PIIRedactTestRequest) (*localaitools.PIIRedactTestResult, error) {
|
||||
return &localaitools.PIIRedactTestResult{Redacted: req.Text}, nil
|
||||
}
|
||||
func (stubClient) SetPIIPatternAction(_ context.Context, _ localaitools.PIIPatternActionUpdate) error {
|
||||
return nil
|
||||
}
|
||||
func (stubClient) PersistPIIPatterns(_ context.Context) error { return nil }
|
||||
func (stubClient) GetMiddlewareStatus(_ context.Context) (*localaitools.MiddlewareStatus, error) {
|
||||
return &localaitools.MiddlewareStatus{
|
||||
PII: localaitools.MiddlewarePIIStatus{
|
||||
EnabledGlobally: true,
|
||||
Patterns: []localaitools.PIIPattern{},
|
||||
Models: []localaitools.MiddlewarePIIModel{},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
func (stubClient) GetRouterDecisions(_ context.Context, _ localaitools.RouterDecisionsQuery) ([]localaitools.RouterDecision, error) {
|
||||
return []localaitools.RouterDecision{}, nil
|
||||
}
|
||||
|
||||
var _ = Describe("LocalAIAssistantHolder", func() {
|
||||
var ctx context.Context
|
||||
|
||||
@@ -36,7 +36,7 @@ func EmbedEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
promptEvalCount := 0
|
||||
|
||||
for _, s := range inputStrings {
|
||||
embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *cfg, appConfig)
|
||||
embedFn, err := backend.ModelEmbedding(c.Request().Context(), s, []int{}, ml, *cfg, appConfig)
|
||||
if err != nil {
|
||||
xlog.Error("Ollama embed failed", "error", err)
|
||||
return ollamaError(c, 500, fmt.Sprintf("embedding failed: %v", err))
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
reason "github.com/mudler/LocalAI/pkg/reasoning"
|
||||
|
||||
@@ -72,7 +74,7 @@ func mergeToolCallDeltas(existing []schema.ToolCall, deltas []schema.ToolCall) [
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/chat/completions [post]
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, assistantHolder *mcpTools.LocalAIAssistantHolder) echo.HandlerFunc {
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, assistantHolder *mcpTools.LocalAIAssistantHolder, piiRedactor *pii.Redactor, piiEvents pii.EventStore) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var textContentToReturn string
|
||||
id := uuid.New().String()
|
||||
@@ -92,6 +94,15 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
|
||||
xlog.Debug("Chat endpoint configuration read", "config", config)
|
||||
|
||||
// Cloud-proxy bail. Bypasses the local pipeline (templating,
|
||||
// MCP injection, gRPC backend) and forwards via the cloud-
|
||||
// proxy backend, which does the outbound HTTP. The streaming
|
||||
// PII filter still runs because its input is per-token text
|
||||
// extracted from the wire envelope, not the envelope itself.
|
||||
if config.IsCloudProxyBackendPassthrough() {
|
||||
return forwardCloudProxyOpenAIViaBackend(c, config, input, piiRedactor, piiEvents, ml, startupOptions)
|
||||
}
|
||||
|
||||
funcs := input.Functions
|
||||
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
|
||||
strictMode := false
|
||||
@@ -326,6 +337,14 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
c.Response().Header().Set("X-Correlation-ID", id)
|
||||
|
||||
// Per-stream PII filter: when the resolved model has PII
|
||||
// enabled, wrap the response content so values spanning
|
||||
// chunk boundaries still get masked. Shared with the
|
||||
// cloud-proxy bail below via cloudproxy.BuildStreamFilter
|
||||
// so both paths apply the same per-model gate and override
|
||||
// rules.
|
||||
streamPIIFilter := cloudproxy.BuildStreamFilter(c, config, true, piiRedactor, piiEvents, id)
|
||||
|
||||
mcpStreamMaxIterations := 10
|
||||
if config.Agent.MaxIterations > 0 {
|
||||
mcpStreamMaxIterations = config.Agent.MaxIterations
|
||||
@@ -377,12 +396,52 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
collectedToolCalls = mergeToolCallDeltas(collectedToolCalls, ev.Choices[0].Delta.ToolCalls)
|
||||
}
|
||||
}
|
||||
// Collect content for MCP conversation history and automatic tool parsing fallback
|
||||
if (hasMCPToolsStream || config.FunctionsConfig.AutomaticToolParsingFallback) && ev.Choices[0].Delta != nil && ev.Choices[0].Delta.Content != nil {
|
||||
if s, ok := ev.Choices[0].Delta.Content.(string); ok {
|
||||
collectedContent += s
|
||||
} else if sp, ok := ev.Choices[0].Delta.Content.(*string); ok && sp != nil {
|
||||
collectedContent += *sp
|
||||
// Extract the raw content delta string once per chunk;
|
||||
// both the MCP collector and the PII filter need it
|
||||
// and the type-switch is otherwise duplicated.
|
||||
var rawContent string
|
||||
haveContent := false
|
||||
if ev.Choices[0].Delta != nil && ev.Choices[0].Delta.Content != nil {
|
||||
switch v := ev.Choices[0].Delta.Content.(type) {
|
||||
case string:
|
||||
rawContent = v
|
||||
haveContent = true
|
||||
case *string:
|
||||
if v != nil {
|
||||
rawContent = *v
|
||||
haveContent = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// Collect content for MCP conversation history and automatic tool parsing fallback.
|
||||
// We collect the RAW (unfiltered) content so the model's tool-call
|
||||
// markup keeps parsing correctly even when PII redaction would mask
|
||||
// substrings.
|
||||
if (hasMCPToolsStream || config.FunctionsConfig.AutomaticToolParsingFallback) && haveContent {
|
||||
collectedContent += rawContent
|
||||
}
|
||||
// Stream-side PII filter: feed the content delta
|
||||
// through the buffered-emit filter. The filter
|
||||
// holds back a tail to handle pattern boundaries
|
||||
// across chunks, so a Push may legitimately
|
||||
// return "" — drop the chunk in that case rather
|
||||
// than emitting an empty Delta to the wire.
|
||||
if streamPIIFilter != nil && haveContent {
|
||||
filtered := streamPIIFilter.Push(rawContent)
|
||||
if filtered == "" {
|
||||
// Fully buffered — skip this chunk's
|
||||
// content. Still emit non-content chunks
|
||||
// (role, tool_calls). When this delta is
|
||||
// content-only and we buffer it, drop the
|
||||
// whole event to avoid a vestigial
|
||||
// {"delta":{}} on the wire.
|
||||
if ev.Choices[0].Delta.Role == "" && len(ev.Choices[0].Delta.ToolCalls) == 0 && ev.Choices[0].Delta.Reasoning == nil {
|
||||
continue
|
||||
}
|
||||
// Mixed delta — strip content, keep the rest.
|
||||
ev.Choices[0].Delta.Content = nil
|
||||
} else {
|
||||
ev.Choices[0].Delta.Content = filtered
|
||||
}
|
||||
}
|
||||
respData, err := json.Marshal(ev)
|
||||
@@ -529,6 +588,31 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
}
|
||||
|
||||
// Drain the per-stream PII filter before the stop chunk
|
||||
// so any text held back by the buffered-emit invariant
|
||||
// reaches the client as a regular content delta. We
|
||||
// emit it as a chunk WITHOUT a finish_reason so the
|
||||
// next "stop" chunk still terminates the stream.
|
||||
if streamPIIFilter != nil {
|
||||
residual := streamPIIFilter.Drain()
|
||||
if residual != "" {
|
||||
drainResp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Content: residual},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
if drainBytes, err := json.Marshal(drainResp); err == nil {
|
||||
_, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", drainBytes)
|
||||
c.Response().Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No MCP tools to execute, send final stop message
|
||||
finishReason := FinishReasonStop
|
||||
if toolsCalled && len(input.Tools) > 0 {
|
||||
@@ -553,6 +637,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
middleware.StampUsage(c, input.Model, finalUsage.Prompt, finalUsage.Completion)
|
||||
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
|
||||
// Trailing usage chunk per OpenAI spec: emit only when the
|
||||
@@ -935,6 +1022,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
respData, _ := json.Marshal(resp)
|
||||
xlog.Debug("Response", "response", string(respData))
|
||||
|
||||
middleware.StampUsage(c, input.Model, usage.PromptTokens, usage.CompletionTokens)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
} // end MCP iteration loop
|
||||
@@ -981,3 +1070,20 @@ func handleQuestion(config *config.ModelConfig, funcResults []functions.FuncCall
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// forwardCloudProxyOpenAIViaBackend marshals the OpenAI request,
|
||||
// constructs the streaming PII filter (when this model has PII
|
||||
// enabled), and hands off to the cloud-proxy gRPC backend which does
|
||||
// the outbound HTTP. The chat endpoint owns the body+filter
|
||||
// construction because it's the only place the request lands as a
|
||||
// parsed *schema.OpenAIRequest.
|
||||
func forwardCloudProxyOpenAIViaBackend(c echo.Context, cfg *config.ModelConfig, input *schema.OpenAIRequest, piiRedactor *pii.Redactor, piiEvents pii.EventStore, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
body, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "cloudproxy: marshal request: "+err.Error())
|
||||
}
|
||||
|
||||
correlationID := c.Response().Header().Get("X-Correlation-ID")
|
||||
streamFilter := cloudproxy.BuildStreamFilter(c, cfg, input.Stream, piiRedactor, piiEvents, correlationID)
|
||||
return cloudproxy.ForwardViaBackend(c, cfg, body, streamFilter, ml, appConfig)
|
||||
}
|
||||
|
||||
@@ -21,6 +21,13 @@ import (
|
||||
// The caller owns the `responses` channel and is expected to read from
|
||||
// it while this function runs; processStream closes the channel before
|
||||
// returning.
|
||||
//
|
||||
// X-LocalAI-Node attribution (when --expose-node-header is on) is
|
||||
// handled by middleware.ExposeNodeHeader at the response writer wrapper
|
||||
// layer; no in-band signal from the worker is needed. The initial
|
||||
// role=assistant chunk is still emitted from the first token callback
|
||||
// rather than eagerly here, so the wrapper's lazy lookup against the
|
||||
// loader runs AFTER ml.Load has stamped the per-modelID node ID.
|
||||
func processStream(
|
||||
s string,
|
||||
req *schema.OpenAIRequest,
|
||||
@@ -32,13 +39,7 @@ func processStream(
|
||||
id string,
|
||||
created int,
|
||||
) (backend.TokenUsage, error) {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
sentInitialRole := false
|
||||
|
||||
// Detect if thinking token is already in prompt or template
|
||||
// When UseTokenizerTemplate is enabled, predInput is empty, so we check the template
|
||||
@@ -70,6 +71,17 @@ func processStream(
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
if !sentInitialRole {
|
||||
sentInitialRole = true
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
}
|
||||
|
||||
delta := &schema.Message{}
|
||||
if contentDelta != "" {
|
||||
delta.Content = &contentDelta
|
||||
@@ -130,6 +142,9 @@ func processStreamWithTools(
|
||||
hasChatDeltaToolCalls := false
|
||||
hasChatDeltaContent := false
|
||||
|
||||
// X-LocalAI-Node attribution is handled by middleware.ExposeNodeHeader
|
||||
// at the wrapper layer; no in-band signalling from this worker.
|
||||
|
||||
_, finalUsage, chatDeltas, err := ComputeChoices(req, prompt, cfg, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
|
||||
|
||||
@@ -9,10 +9,12 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -25,7 +27,7 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, piiRedactor *pii.Redactor, piiEvents pii.EventStore) echo.HandlerFunc {
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
|
||||
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
created := int(time.Now().Unix())
|
||||
@@ -111,6 +113,31 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||
}
|
||||
|
||||
// Per-stream PII filter — same gating as chat. /v1/completions
|
||||
// has no chat-message structure, so request-side PII isn't
|
||||
// wired here, but the response-side filter still catches PII
|
||||
// trained into the model. Filter is nil when this model has
|
||||
// PII disabled.
|
||||
var streamPIIFilter *pii.StreamFilter
|
||||
if piiRedactor != nil && config.PIIIsEnabled() {
|
||||
correlationID := id
|
||||
userID := ""
|
||||
if u := auth.GetUser(c); u != nil {
|
||||
userID = u.ID
|
||||
}
|
||||
var overrides map[string]pii.Action
|
||||
if raw := config.PIIPatternOverrides(); len(raw) > 0 {
|
||||
overrides = make(map[string]pii.Action, len(raw))
|
||||
for ovid, action := range raw {
|
||||
switch pii.Action(action) {
|
||||
case pii.ActionMask, pii.ActionBlock, pii.ActionRouteLocal:
|
||||
overrides[ovid] = pii.Action(action)
|
||||
}
|
||||
}
|
||||
}
|
||||
streamPIIFilter = pii.NewStreamFilter(piiRedactor, overrides, piiEvents, correlationID, userID)
|
||||
}
|
||||
|
||||
predInput := config.PromptStrings[0]
|
||||
|
||||
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
|
||||
@@ -143,12 +170,28 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
}
|
||||
// Capture running cumulative usage for the optional trailer
|
||||
// emitted after the final stop chunk when include_usage=true.
|
||||
// Done before the PII filter so a fully-buffered chunk
|
||||
// (which we drop from the wire) still contributes to the
|
||||
// running total.
|
||||
if ev.Usage != nil {
|
||||
latestUsage = ev.Usage
|
||||
}
|
||||
// OpenAI streaming spec: intermediate chunks must NOT
|
||||
// carry a `usage` field. Strip the tracking copy now.
|
||||
ev.Usage = nil
|
||||
// Run the per-chunk text through the streaming PII
|
||||
// filter. The filter holds back a tail to handle
|
||||
// pattern boundaries, so a Push may legitimately
|
||||
// return "" — drop the chunk's text rather than
|
||||
// emitting a 0-token delta. Choice.Text is the only
|
||||
// content surface in /v1/completions chunks.
|
||||
if streamPIIFilter != nil && ev.Choices[0].Text != "" {
|
||||
filtered := streamPIIFilter.Push(ev.Choices[0].Text)
|
||||
if filtered == "" {
|
||||
continue
|
||||
}
|
||||
ev.Choices[0].Text = filtered
|
||||
}
|
||||
respData, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
xlog.Debug("Failed to marshal response", "error", err)
|
||||
@@ -194,6 +237,25 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
}
|
||||
}
|
||||
|
||||
// Flush any residual the streaming PII filter held back as
|
||||
// part of its trailing pattern-window. Emit it as one final
|
||||
// text-bearing chunk before the synthetic stop chunk so the
|
||||
// completion body remains a contiguous text stream.
|
||||
if streamPIIFilter != nil {
|
||||
if residual := streamPIIFilter.Drain(); residual != "" {
|
||||
residualResp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model,
|
||||
Choices: []schema.Choice{{Index: 0, Text: residual}},
|
||||
Object: "text_completion",
|
||||
}
|
||||
if data, err := json.Marshal(residualResp); err == nil {
|
||||
_, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
@@ -208,6 +270,14 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
Object: "text_completion",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
pt, ct := 0, 0
|
||||
if latestUsage != nil {
|
||||
pt = latestUsage.PromptTokens
|
||||
ct = latestUsage.CompletionTokens
|
||||
}
|
||||
middleware.StampUsage(c, input.Model, pt, ct)
|
||||
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
|
||||
// Trailing usage chunk per OpenAI spec: emit only when the caller
|
||||
@@ -274,6 +344,8 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
xlog.Debug("Response", "response", string(jsonResult))
|
||||
|
||||
middleware.StampUsage(c, input.Model, totalTokenUsage.Prompt, totalTokenUsage.Completion)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
@@ -98,6 +98,8 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
xlog.Debug("Response", "response", string(jsonResult))
|
||||
|
||||
middleware.StampUsage(c, input.Model, totalTokenUsage.Prompt, totalTokenUsage.Completion)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
|
||||
for i, s := range config.InputToken {
|
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig)
|
||||
embedFn, err := backend.ModelEmbedding(input.Context, "", s, ml, *config, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
|
||||
for i, s := range config.InputStrings {
|
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig)
|
||||
embedFn, err := backend.ModelEmbedding(input.Context, s, []int{}, ml, *config, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -102,6 +102,15 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
xlog.Debug("Response", "response", string(jsonResult))
|
||||
|
||||
// LocalAI's embeddings endpoint does not currently track per-call
|
||||
// token counts (the gRPC Embedding RPC returns a vector, not a
|
||||
// usage block), so we stamp with zeros. The point of stamping is
|
||||
// that the billing pipeline still sees the request and emits the
|
||||
// localai_billed_requests_total counter; without this the call
|
||||
// would be silently dropped by the unrecorded-counter path. When
|
||||
// embeddings learn to report usage, swap the zeros for real counts.
|
||||
middleware.StampUsage(c, input.Model, 0, 0)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
@@ -198,7 +198,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
inputSrc = inputImages[0]
|
||||
}
|
||||
|
||||
fn, err := backend.ImageGeneration(height, width, step, *config.Seed, positive_prompt, negative_prompt, inputSrc, output, ml, *config, appConfig, refImages)
|
||||
fn, err := backend.ImageGeneration(c.Request().Context(), height, width, step, *config.Seed, positive_prompt, negative_prompt, inputSrc, output, ml, *config, appConfig, refImages)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -231,7 +231,7 @@ func InpaintingEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
// Note: ImageGenerationFunc will call into the loaded model's GenerateImage which expects src JSON
|
||||
// Also pass ref images (orig + mask) so backends that support ref images can use them.
|
||||
refImages := []string{origRef, maskRef}
|
||||
fn, err := backend.ImageGenerationFunc(height, width, steps, 0, prompt, "", jsonPath, dst, ml, *cfg, appConfig, refImages)
|
||||
fn, err := backend.ImageGenerationFunc(c.Request().Context(), height, width, steps, 0, prompt, "", jsonPath, dst, ml, *cfg, appConfig, refImages)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -56,7 +57,7 @@ var _ = Describe("Inpainting", func() {
|
||||
appConf := config.NewApplicationConfig(config.WithGeneratedContentDir(tmpDir))
|
||||
|
||||
orig := backend.ImageGenerationFunc
|
||||
backend.ImageGenerationFunc = func(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
backend.ImageGenerationFunc = func(ctx context.Context, height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
fn := func() error {
|
||||
return os.WriteFile(dst, []byte("PNGDATA"), 0644)
|
||||
}
|
||||
|
||||
@@ -497,6 +497,7 @@ func runRealtimeSession(application *application.Application, t Transport, model
|
||||
application.ModelLoader(),
|
||||
application.ApplicationConfig(),
|
||||
evaluator,
|
||||
buildRealtimeRoutingContext(application, sessionID),
|
||||
)
|
||||
if err != nil {
|
||||
xlog.Error("failed to load model", "error", err)
|
||||
@@ -627,6 +628,7 @@ func runRealtimeSession(application *application.Application, t Transport, model
|
||||
application.ModelLoader(),
|
||||
application.ApplicationConfig(),
|
||||
evaluator,
|
||||
buildRealtimeRoutingContext(application, session.ID),
|
||||
); err != nil {
|
||||
xlog.Error("failed to update session", "error", err)
|
||||
sendError(t, "session_update_error", "Failed to update session", "", "")
|
||||
@@ -946,7 +948,7 @@ func updateTransSession(session *Session, update *types.SessionUnion, cl *config
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateSession(session *Session, update *types.SessionUnion, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, evaluator *templates.Evaluator) error {
|
||||
func updateSession(session *Session, update *types.SessionUnion, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, evaluator *templates.Evaluator, routing *RealtimeRoutingContext) error {
|
||||
sessionLock.Lock()
|
||||
defer sessionLock.Unlock()
|
||||
|
||||
@@ -985,7 +987,7 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
|
||||
}
|
||||
|
||||
if rt.Model != "" || (rt.Audio != nil && rt.Audio.Output != nil && rt.Audio.Output.Voice != "") || (rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.Transcription != nil) {
|
||||
m, err := newModel(&session.ModelConfig.Pipeline, cl, ml, appConfig, evaluator)
|
||||
m, err := newModel(&session.ModelConfig.Pipeline, cl, ml, appConfig, evaluator, routing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,13 +2,18 @@ package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -34,6 +39,15 @@ type wrappedModel struct {
|
||||
modelLoader *model.ModelLoader
|
||||
confLoader *config.ModelConfigLoader
|
||||
evaluator *templates.Evaluator
|
||||
|
||||
// Routing — populated by newModel when the application wires routing
|
||||
// deps in. nil-safe: with classifierRegistry == nil the per-turn
|
||||
// routing block in Predict is skipped, preserving today's "one LLM
|
||||
// for the whole session" behaviour.
|
||||
routerDeps *middleware.ClassifierDeps
|
||||
routerStore router.DecisionStore
|
||||
routerSessionID string
|
||||
routerUserID string
|
||||
}
|
||||
|
||||
// anyToAnyModel represent a model which supports Any-to-Any operations
|
||||
@@ -90,9 +104,24 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
// Per-turn routing: when the session's LLMConfig is a router, swap
|
||||
// to the candidate the classifier picks for this turn's prompt.
|
||||
// LLMConfig itself is held by value (we never mutate it) — turnCfg
|
||||
// is the config we dispatch against.
|
||||
turnCfg := m.LLMConfig
|
||||
if m.LLMConfig.HasRouter() && m.routerDeps != nil {
|
||||
chosen, err := m.routeTurn(ctx, &input)
|
||||
if err != nil {
|
||||
xlog.Warn("realtime routing failed; using session default LLM",
|
||||
"router_model", m.LLMConfig.Name, "error", err)
|
||||
} else if chosen != nil {
|
||||
turnCfg = chosen
|
||||
}
|
||||
}
|
||||
|
||||
var predInput string
|
||||
var funcs []functions.Function
|
||||
if !m.LLMConfig.TemplateConfig.UseTokenizerTemplate {
|
||||
if !turnCfg.TemplateConfig.UseTokenizerTemplate {
|
||||
if len(tools) > 0 {
|
||||
for _, t := range tools {
|
||||
if t.Function != nil {
|
||||
@@ -120,11 +149,11 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im
|
||||
noActionName := "answer"
|
||||
noActionDescription := "use this action to answer without performing any action"
|
||||
|
||||
if m.LLMConfig.FunctionsConfig.NoActionFunctionName != "" {
|
||||
noActionName = m.LLMConfig.FunctionsConfig.NoActionFunctionName
|
||||
if turnCfg.FunctionsConfig.NoActionFunctionName != "" {
|
||||
noActionName = turnCfg.FunctionsConfig.NoActionFunctionName
|
||||
}
|
||||
if m.LLMConfig.FunctionsConfig.NoActionDescriptionName != "" {
|
||||
noActionDescription = m.LLMConfig.FunctionsConfig.NoActionDescriptionName
|
||||
if turnCfg.FunctionsConfig.NoActionDescriptionName != "" {
|
||||
noActionDescription = turnCfg.FunctionsConfig.NoActionDescriptionName
|
||||
}
|
||||
|
||||
noActionGrammar := functions.Function{
|
||||
@@ -140,16 +169,16 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im
|
||||
},
|
||||
}
|
||||
|
||||
if !m.LLMConfig.FunctionsConfig.DisableNoAction {
|
||||
if !turnCfg.FunctionsConfig.DisableNoAction {
|
||||
funcs = append(funcs, noActionGrammar)
|
||||
}
|
||||
}
|
||||
|
||||
predInput = m.evaluator.TemplateMessages(input, input.Messages, m.LLMConfig, funcs, len(funcs) > 0)
|
||||
predInput = m.evaluator.TemplateMessages(input, input.Messages, turnCfg, funcs, len(funcs) > 0)
|
||||
|
||||
xlog.Debug("Prompt (after templating)", "prompt", predInput)
|
||||
if m.LLMConfig.Grammar != "" {
|
||||
xlog.Debug("Grammar", "grammar", m.LLMConfig.Grammar)
|
||||
if turnCfg.Grammar != "" {
|
||||
xlog.Debug("Grammar", "grammar", turnCfg.Grammar)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,33 +188,33 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im
|
||||
// String values: "auto", "required", "none"
|
||||
switch toolChoice.Mode {
|
||||
case types.ToolChoiceModeRequired:
|
||||
m.LLMConfig.SetFunctionCallString("required")
|
||||
turnCfg.SetFunctionCallString("required")
|
||||
case types.ToolChoiceModeNone:
|
||||
// Don't use tools
|
||||
m.LLMConfig.SetFunctionCallString("none")
|
||||
turnCfg.SetFunctionCallString("none")
|
||||
case types.ToolChoiceModeAuto:
|
||||
// Default behavior - let model decide
|
||||
}
|
||||
} else if toolChoice.Function != nil {
|
||||
// Specific function specified
|
||||
m.LLMConfig.SetFunctionCallNameString(toolChoice.Function.Name)
|
||||
turnCfg.SetFunctionCallNameString(toolChoice.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate grammar for function calling if tools are provided and grammar generation is enabled
|
||||
shouldUseFn := len(tools) > 0 && m.LLMConfig.ShouldUseFunctions()
|
||||
shouldUseFn := len(tools) > 0 && turnCfg.ShouldUseFunctions()
|
||||
|
||||
if !m.LLMConfig.FunctionsConfig.GrammarConfig.NoGrammar && shouldUseFn {
|
||||
if !turnCfg.FunctionsConfig.GrammarConfig.NoGrammar && shouldUseFn {
|
||||
// Force picking one of the functions by the request
|
||||
if m.LLMConfig.FunctionToCall() != "" {
|
||||
funcs = functions.Functions(funcs).Select(m.LLMConfig.FunctionToCall())
|
||||
if turnCfg.FunctionToCall() != "" {
|
||||
funcs = functions.Functions(funcs).Select(turnCfg.FunctionToCall())
|
||||
}
|
||||
|
||||
// Generate grammar from function definitions
|
||||
jsStruct := functions.Functions(funcs).ToJSONStructure(m.LLMConfig.FunctionsConfig.FunctionNameKey, m.LLMConfig.FunctionsConfig.FunctionNameKey)
|
||||
g, err := jsStruct.Grammar(m.LLMConfig.FunctionsConfig.GrammarOptions()...)
|
||||
jsStruct := functions.Functions(funcs).ToJSONStructure(turnCfg.FunctionsConfig.FunctionNameKey, turnCfg.FunctionsConfig.FunctionNameKey)
|
||||
g, err := jsStruct.Grammar(turnCfg.FunctionsConfig.GrammarOptions()...)
|
||||
if err == nil {
|
||||
m.LLMConfig.Grammar = g
|
||||
turnCfg.Grammar = g
|
||||
xlog.Debug("Generated grammar for function calling", "grammar", g)
|
||||
} else {
|
||||
xlog.Error("Failed generating grammar", "error", err)
|
||||
@@ -237,7 +266,50 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im
|
||||
toolChoiceJSON = string(b)
|
||||
}
|
||||
|
||||
return backend.ModelInference(ctx, predInput, messages, images, videos, audios, m.modelLoader, m.LLMConfig, m.confLoader, m.appConfig, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias, nil)
|
||||
return backend.ModelInference(ctx, predInput, messages, images, videos, audios, m.modelLoader, turnCfg, m.confLoader, m.appConfig, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias, nil)
|
||||
}
|
||||
|
||||
// routeTurn classifies this turn's prompt against the session's router
|
||||
// LLM config and returns the candidate ModelConfig to dispatch against.
|
||||
// Returns nil with no error when routing was attempted but the resolver
|
||||
// signalled "no decision" — the caller falls back to the session
|
||||
// default. Records the decision in the store using the realtime session
|
||||
// id as the correlation id so the admin UI can group turn-by-turn
|
||||
// decisions under one session row.
|
||||
func (m *wrappedModel) routeTurn(ctx context.Context, req *schema.OpenAIRequest) (*config.ModelConfig, error) {
|
||||
if m.routerDeps == nil {
|
||||
return nil, nil
|
||||
}
|
||||
registry := m.routerDeps.Registry
|
||||
if registry == nil {
|
||||
registry = router.NewRegistry()
|
||||
}
|
||||
classifier, classifierErr := middleware.GetOrBuildClassifier(registry, m.LLMConfig, *m.routerDeps)
|
||||
if classifierErr != nil {
|
||||
xlog.Warn("realtime router: classifier unavailable — using fallback",
|
||||
"router_model", m.LLMConfig.Name, "error", classifierErr)
|
||||
classifier = nil
|
||||
}
|
||||
loader := func(name string) (*config.ModelConfig, error) {
|
||||
return m.confLoader.LoadModelConfigFileByNameDefaultOptions(name, m.appConfig)
|
||||
}
|
||||
probe := middleware.OpenAIProbeFromRequest(req)
|
||||
|
||||
result, err := router.Resolve(ctx, m.LLMConfig, classifier, loader, probe)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if m.routerStore != nil {
|
||||
_ = m.routerStore.Record(context.Background(), result.ToDecisionRecord(newRealtimeDecisionID(), m.routerSessionID, m.routerUserID, router.SourceRealtime))
|
||||
}
|
||||
return result.ChosenConfig, nil
|
||||
}
|
||||
|
||||
func newRealtimeDecisionID() string {
|
||||
var b [12]byte
|
||||
_, _ = rand.Read(b[:])
|
||||
return "rd_" + hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) {
|
||||
@@ -279,8 +351,49 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfig
|
||||
}, cfgSST, nil
|
||||
}
|
||||
|
||||
// RealtimeRoutingContext is the bundle of routing dependencies the
|
||||
// realtime pipeline needs to consult router.Resolve per turn. nil-safe:
|
||||
// passing nil skips routing entirely and preserves the historical "one
|
||||
// LLM for the whole session" behaviour.
|
||||
type RealtimeRoutingContext struct {
|
||||
Deps *middleware.ClassifierDeps
|
||||
Store router.DecisionStore
|
||||
SessionID string
|
||||
UserID string
|
||||
}
|
||||
|
||||
// buildRealtimeRoutingContext assembles the routing dependencies the
|
||||
// realtime pipeline needs from the application container. Returns nil
|
||||
// when no Application is wired (tests, stripped builds) — that path
|
||||
// leaves wrappedModel.Predict on the historical "no routing" path
|
||||
// instead of failing at session start.
|
||||
func buildRealtimeRoutingContext(a *application.Application, sessionID string) *RealtimeRoutingContext {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
deps := &middleware.ClassifierDeps{
|
||||
Scorer: a.Scorer,
|
||||
Embedder: a.Embedder,
|
||||
VectorStore: a.VectorStore,
|
||||
Reranker: a.Reranker,
|
||||
ModelLookup: a.ModelConfigLookup(),
|
||||
Registry: a.RouterClassifierRegistry(),
|
||||
Evaluator: a.TemplatesEvaluator(),
|
||||
}
|
||||
userID := ""
|
||||
if u := a.FallbackUser(); u != nil {
|
||||
userID = u.ID
|
||||
}
|
||||
return &RealtimeRoutingContext{
|
||||
Deps: deps,
|
||||
Store: a.RouterDecisions(),
|
||||
SessionID: sessionID,
|
||||
UserID: userID,
|
||||
}
|
||||
}
|
||||
|
||||
// returns and loads either a wrapped model or a model that support audio-to-audio
|
||||
func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, evaluator *templates.Evaluator) (Model, error) {
|
||||
func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, evaluator *templates.Evaluator, routing *RealtimeRoutingContext) (Model, error) {
|
||||
xlog.Debug("Creating new model pipeline model", "pipeline", pipeline)
|
||||
|
||||
cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
@@ -346,7 +459,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
return &wrappedModel{
|
||||
wm := &wrappedModel{
|
||||
TTSConfig: cfgTTS,
|
||||
TranscriptionConfig: cfgSST,
|
||||
LLMConfig: cfgLLM,
|
||||
@@ -356,5 +469,12 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
|
||||
modelLoader: ml,
|
||||
appConfig: appConfig,
|
||||
evaluator: evaluator,
|
||||
}, nil
|
||||
}
|
||||
if routing != nil {
|
||||
wm.routerDeps = routing.Deps
|
||||
wm.routerStore = routing.Store
|
||||
wm.routerSessionID = routing.SessionID
|
||||
wm.routerUserID = routing.UserID
|
||||
}
|
||||
return wm, nil
|
||||
}
|
||||
|
||||
81
core/http/middleware/admission.go
Normal file
81
core/http/middleware/admission.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/routing/admission"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
)
|
||||
|
||||
// AdmissionControl runs after RouteModel so the limit applies to the
|
||||
// SERVED model — a router fanout that lands on a saturated downstream
|
||||
// model gets rejected even though the requested router-model has slack.
|
||||
//
|
||||
// On reject: HTTP 503, Retry-After header, error JSON. An audit row
|
||||
// goes into the shared event store under KindAdmission so admins see
|
||||
// rejection rates alongside PII and proxy events.
|
||||
//
|
||||
// Models without limits.max_concurrent (the common case) hit a fast
|
||||
// no-op path — Acquire returns immediately for max <= 0.
|
||||
func AdmissionControl(limiter *admission.Limiter, events pii.EventStore) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return next(c)
|
||||
}
|
||||
max := cfg.Limits.MaxConcurrent
|
||||
release, ok := limiter.Acquire(cfg.Name, max)
|
||||
if !ok {
|
||||
retryAfter := admission.RetryAfter(cfg.Limits.RetryAfterSeconds)
|
||||
recordAdmissionRejection(events, cfg.Name, retryAfter)
|
||||
c.Response().Header().Set("Retry-After", strconv.Itoa(int(retryAfter.Seconds())))
|
||||
return c.JSON(http.StatusServiceUnavailable, map[string]any{
|
||||
"error": map[string]any{
|
||||
"type": "admission_rejected",
|
||||
"message": fmt.Sprintf("model %q is at capacity (max_concurrent=%d); retry after %s", cfg.Name, max, retryAfter),
|
||||
},
|
||||
})
|
||||
}
|
||||
defer release()
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// admissionEventSeq scopes IDs across the process so rapid
|
||||
// rejections under load get unique row IDs without coordinating
|
||||
// with the rest of the event-store ID schemes.
|
||||
var admissionEventSeq atomic.Uint64
|
||||
|
||||
func recordAdmissionRejection(events pii.EventStore, modelName string, retryAfter time.Duration) {
|
||||
if events == nil {
|
||||
return
|
||||
}
|
||||
statusCode := http.StatusServiceUnavailable
|
||||
durMS := retryAfter.Milliseconds()
|
||||
id := fmt.Sprintf("adm_%d_%s", admissionEventSeq.Add(1), randHex(4))
|
||||
_ = events.Record(context.Background(), pii.PIIEvent{
|
||||
ID: id,
|
||||
Kind: pii.KindAdmission,
|
||||
Host: modelName,
|
||||
StatusCode: statusCode,
|
||||
DurationMS: durMS,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
|
||||
func randHex(n int) string {
|
||||
b := make([]byte, n)
|
||||
_, _ = rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
118
core/http/middleware/admission_test.go
Normal file
118
core/http/middleware/admission_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/services/routing/admission"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// recordingStore captures admission rows so the test can assert
|
||||
// the audit trail without standing up the full pii event store.
|
||||
type recordingStore struct {
|
||||
mu sync.Mutex
|
||||
events []pii.PIIEvent
|
||||
}
|
||||
|
||||
func (r *recordingStore) Record(_ context.Context, e pii.PIIEvent) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.events = append(r.events, e)
|
||||
return nil
|
||||
}
|
||||
func (r *recordingStore) List(_ context.Context, _ pii.ListQuery) ([]pii.PIIEvent, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *recordingStore) Count(_ context.Context) (int, error) { return 0, nil }
|
||||
func (r *recordingStore) Close() error { return nil }
|
||||
|
||||
func runAdmission(lim *admission.Limiter, store *recordingStore, cfg *config.ModelConfig, handler echo.HandlerFunc) (*httptest.ResponseRecorder, error) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader("{}"))
|
||||
rec := httptest.NewRecorder()
|
||||
c := echo.New().NewContext(req, rec)
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
mw := AdmissionControl(lim, store)
|
||||
err := mw(handler)(c)
|
||||
return rec, err
|
||||
}
|
||||
|
||||
var _ = Describe("Admission", func() {
|
||||
It("allows when under limit", func() {
|
||||
lim := admission.New()
|
||||
cfg := &config.ModelConfig{Limits: config.LimitsConfig{MaxConcurrent: 2}}
|
||||
cfg.Name = "m"
|
||||
rec, err := runAdmission(lim, &recordingStore{}, cfg, func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
})
|
||||
|
||||
It("rejects when full", func() {
|
||||
// Saturate the limiter outside the middleware, then a request
|
||||
// at the same model gets 503 with a Retry-After header.
|
||||
lim := admission.New()
|
||||
release, ok := lim.Acquire("busy", 1)
|
||||
Expect(ok).To(BeTrue(), "setup acquire should succeed")
|
||||
defer release()
|
||||
|
||||
cfg := &config.ModelConfig{Limits: config.LimitsConfig{MaxConcurrent: 1, RetryAfterSeconds: 3}}
|
||||
cfg.Name = "busy"
|
||||
store := &recordingStore{}
|
||||
handlerCalled := false
|
||||
rec, err := runAdmission(lim, store, cfg, func(c echo.Context) error {
|
||||
handlerCalled = true
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(rec.Code).To(Equal(http.StatusServiceUnavailable))
|
||||
Expect(rec.Header().Get("Retry-After")).To(Equal("3"))
|
||||
Expect(handlerCalled).To(BeFalse(), "handler should not run when admission rejects")
|
||||
Expect(rec.Body.String()).To(ContainSubstring("admission_rejected"))
|
||||
Expect(store.events).To(HaveLen(1))
|
||||
Expect(store.events[0].Kind).To(Equal(pii.KindAdmission))
|
||||
Expect(store.events[0].Host).To(Equal("busy"), "audit row carries the model name")
|
||||
})
|
||||
|
||||
It("no limit configured is no-op", func() {
|
||||
// MaxConcurrent=0 means unlimited — handler always runs and no
|
||||
// audit row is written even after many calls.
|
||||
lim := admission.New()
|
||||
cfg := &config.ModelConfig{}
|
||||
cfg.Name = "open"
|
||||
store := &recordingStore{}
|
||||
for i := 0; i < 10; i++ {
|
||||
rec, err := runAdmission(lim, store, cfg, func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
}
|
||||
Expect(store.events).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("releases after handler", func() {
|
||||
// One slot, two SEQUENTIAL requests: the second succeeds because
|
||||
// the first's release runs on handler return.
|
||||
lim := admission.New()
|
||||
cfg := &config.ModelConfig{Limits: config.LimitsConfig{MaxConcurrent: 1}}
|
||||
cfg.Name = "tight"
|
||||
for i := 0; i < 3; i++ {
|
||||
rec, err := runAdmission(lim, &recordingStore{}, cfg, func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
}
|
||||
})
|
||||
})
|
||||
50
core/http/middleware/context_keys.go
Normal file
50
core/http/middleware/context_keys.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package middleware
|
||||
|
||||
// Context keys used by routing-module middlewares to communicate with
|
||||
// the usage recorder. Unlike the legacy CONTEXT_LOCALS_KEY_* constants
|
||||
// (which exist for backward-compatible callers), these are the
|
||||
// canonical names for new fields.
|
||||
const (
|
||||
// ContextKeyRequestedModel is set by content-router middleware to
|
||||
// the model name the client originally asked for, before any router
|
||||
// remapping. UsageMiddleware writes this into UsageRecord.RequestedModel.
|
||||
ContextKeyRequestedModel = "routing.requested_model"
|
||||
|
||||
// ContextKeyServedModel is set by content-router middleware to the
|
||||
// model that actually handled the request (post-routing). When no
|
||||
// router runs, callers may leave this unset and the response-reported
|
||||
// model name is used as the served value.
|
||||
ContextKeyServedModel = "routing.served_model"
|
||||
|
||||
// ContextKeyPreFilterPromptTokens / ContextKeyPostFilterPromptTokens
|
||||
// are set by the PII middleware to record how many prompt tokens
|
||||
// the user sent vs how many made it past redaction. When both are
|
||||
// zero or unset, UsageMiddleware uses the response-reported prompt
|
||||
// token count for both — i.e., no filter ran.
|
||||
ContextKeyPreFilterPromptTokens = "routing.pre_filter_prompt_tokens"
|
||||
ContextKeyPostFilterPromptTokens = "routing.post_filter_prompt_tokens"
|
||||
|
||||
// ContextKeyCorrelationID is the join key threaded across PII
|
||||
// events, router decisions, admission events, and usage records.
|
||||
// trace.go middleware sets X-Correlation-ID on the response; this
|
||||
// key mirrors the same value into echo.Context for in-process
|
||||
// propagation without re-parsing the header.
|
||||
ContextKeyCorrelationID = "routing.correlation_id"
|
||||
|
||||
// ContextKeyPromptTokens / ContextKeyCompletionTokens / ContextKeyTotalTokens
|
||||
// are the canonical token counts the request handler measured. Stamping
|
||||
// these from the handler is the only reliable path for streaming
|
||||
// responses, where the SSE chunks may not include a usage block (OpenAI
|
||||
// requires stream_options.include_usage; Anthropic uses a separate
|
||||
// message_delta event shape). UsageMiddleware prefers these context
|
||||
// values over body-parsing.
|
||||
ContextKeyPromptTokens = "routing.prompt_tokens"
|
||||
ContextKeyCompletionTokens = "routing.completion_tokens"
|
||||
ContextKeyTotalTokens = "routing.total_tokens"
|
||||
|
||||
// ContextKeyResponseModel is the model name the handler committed to
|
||||
// in its response payload. UsageMiddleware uses it when neither the
|
||||
// router nor the body-parse path has produced one. Distinct from
|
||||
// ContextKeyServedModel, which is the router's resolved choice.
|
||||
ContextKeyResponseModel = "routing.response_model"
|
||||
)
|
||||
127
core/http/middleware/node_header.go
Normal file
127
core/http/middleware/node_header.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
)
|
||||
|
||||
// NodeHeaderName is the HTTP response header that, when --expose-node-header
|
||||
// is enabled, carries the ID of the distributed-mode worker node that served
|
||||
// the inference request. Off by default: node IDs reveal internal topology
|
||||
// and should not be exposed on a public endpoint.
|
||||
const NodeHeaderName = "X-LocalAI-Node"
|
||||
|
||||
// nodeHeaderWriter wraps an http.ResponseWriter and stamps the X-LocalAI-Node
|
||||
// header lazily on the first Write / WriteHeader / Flush call. The lazy
|
||||
// resolve is what makes this work for streaming: the picked node ID is only
|
||||
// known AFTER the router runs (i.e. on the first SSE chunk), so resolving at
|
||||
// request entry would attach the previous request's routing decision (or
|
||||
// nothing on a cold cache).
|
||||
type nodeHeaderWriter struct {
|
||||
http.ResponseWriter
|
||||
resolve func() string
|
||||
set bool
|
||||
}
|
||||
|
||||
func (w *nodeHeaderWriter) maybeSet() {
|
||||
if w.set {
|
||||
return
|
||||
}
|
||||
w.set = true
|
||||
if id := w.resolve(); id != "" {
|
||||
w.Header().Set(NodeHeaderName, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *nodeHeaderWriter) Write(b []byte) (int, error) {
|
||||
w.maybeSet()
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (w *nodeHeaderWriter) WriteHeader(code int) {
|
||||
w.maybeSet()
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Flush keeps SSE handlers working: Echo's Response.Flush goes through
|
||||
// http.NewResponseController which walks Unwrap() chains and invokes Flush
|
||||
// on the first wrapper that implements http.Flusher. By implementing it
|
||||
// here we both stamp the header before the underlying writer flushes AND
|
||||
// keep the streaming path alive.
|
||||
func (w *nodeHeaderWriter) Flush() {
|
||||
w.maybeSet()
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack preserves WebSocket / raw-conn handlers that need to take over the
|
||||
// underlying TCP connection (e.g. /v1/realtime). Without this the wrapper
|
||||
// would silently break those endpoints.
|
||||
//
|
||||
// When the underlying writer does not implement http.Hijacker we return
|
||||
// http.ErrNotSupported so callers using errors.Is (notably
|
||||
// http.NewResponseController.Hijack) detect the condition through the
|
||||
// standard sentinel rather than a string-matched custom error.
|
||||
func (w *nodeHeaderWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if h, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return h.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("hijack not supported: %w", http.ErrNotSupported)
|
||||
}
|
||||
|
||||
// Unwrap lets http.NewResponseController reach through us to find optional
|
||||
// interfaces (CloseNotifier, SetReadDeadline, etc.) on the real writer.
|
||||
func (w *nodeHeaderWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
|
||||
// ExposeNodeHeader installs a per-request response writer wrapper that
|
||||
// stamps the X-LocalAI-Node header from the per-request holder published
|
||||
// by the distributed router on the first write. Off by default; opted in
|
||||
// via --expose-node-header / LOCALAI_EXPOSE_NODE_HEADER.
|
||||
//
|
||||
// Attribution is per-request correct: the middleware creates a fresh
|
||||
// holder per request, plumbs it through context.Context, and the router
|
||||
// writes the picked node ID for THIS request's routing decision. No
|
||||
// shared loader state, no overwriting across concurrent requests for the
|
||||
// same model on multiple replicas.
|
||||
func ExposeNodeHeader(appCfg *config.ApplicationConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if appCfg == nil || !appCfg.ExposeNodeHeader {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// One holder per request. The pointer is captured both in
|
||||
// the wrapper closure (read side) and in the request
|
||||
// context (write side, accessed by the router via
|
||||
// distributedhdr.Stamp). Both sides point at the same
|
||||
// atomic slot.
|
||||
holder := distributedhdr.NewHolder()
|
||||
|
||||
req := c.Request()
|
||||
c.SetRequest(req.WithContext(distributedhdr.WithHolder(req.Context(), holder)))
|
||||
|
||||
orig := c.Response().Writer
|
||||
wrapper := &nodeHeaderWriter{
|
||||
ResponseWriter: orig,
|
||||
resolve: func() string {
|
||||
return distributedhdr.Load(holder)
|
||||
},
|
||||
}
|
||||
c.Response().Writer = wrapper
|
||||
defer func() {
|
||||
c.Response().Writer = orig
|
||||
}()
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
89
core/http/middleware/node_header_concurrency_test.go
Normal file
89
core/http/middleware/node_header_concurrency_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package middleware
|
||||
|
||||
// Regression coverage for the multi-replica X-LocalAI-Node attribution
|
||||
// bug fixed by this PR.
|
||||
//
|
||||
// Pre-refactor failure mode: ExposeNodeHeader resolved the node ID by
|
||||
// calling ml.LookupNodeID(modelName), which read a single per-modelID
|
||||
// slot in the loader's in-memory store. The distributed router
|
||||
// overwrote that slot on every routing decision, so when N concurrent
|
||||
// requests for the same model were routed to N different replicas, the
|
||||
// header value the wrapper picked up at first-byte time depended on
|
||||
// goroutine interleaving and not on which replica THIS request was
|
||||
// actually sent to.
|
||||
//
|
||||
// The fix routes attribution through a per-request atomic holder
|
||||
// installed by the middleware and stamped by the router via
|
||||
// distributedhdr.Stamp. Each request carries its own slot, so peer
|
||||
// stamps cannot bleed in.
|
||||
//
|
||||
// This spec exercises the exact concurrency pattern the bug required to
|
||||
// reproduce: many goroutines, all running through the same middleware
|
||||
// instance (mirroring e.Use() in production), each stamping a distinct
|
||||
// node ID, each asserting its own response carries the matching ID.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
)
|
||||
|
||||
var _ = Describe("ExposeNodeHeader multi-replica attribution", func() {
|
||||
It("each concurrent request sees the node ID stamped on ITS OWN request, not a peer's", func() {
|
||||
appCfg := &config.ApplicationConfig{ExposeNodeHeader: true}
|
||||
e := echo.New()
|
||||
mw := ExposeNodeHeader(appCfg)
|
||||
|
||||
const N = 64
|
||||
results := make([]string, N)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(N)
|
||||
|
||||
// Drive N concurrent requests through the same middleware
|
||||
// instance. Each handler stamps a distinct, per-request node ID
|
||||
// derived from the request index, then yields before writing
|
||||
// the body so the goroutines have ample opportunity to
|
||||
// interleave. Under the old shared-loader design this is the
|
||||
// configuration that surfaced the bug; under the new
|
||||
// per-request-holder design every request must round-trip its
|
||||
// own stamp.
|
||||
for i := 0; i < N; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
expected := fmt.Sprintf("node-%d", i)
|
||||
handler := func(c echo.Context) error {
|
||||
distributedhdr.Stamp(c.Request().Context(), expected)
|
||||
// Yield to amplify interleaving: if any stamp were
|
||||
// shared across requests, the late writes would
|
||||
// observe peer-stamped values instead of their
|
||||
// own.
|
||||
for j := 0; j < 16; j++ {
|
||||
_ = j
|
||||
}
|
||||
return c.String(http.StatusOK, "ok")
|
||||
}
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
results[i] = rec.Header().Get(NodeHeaderName)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i := 0; i < N; i++ {
|
||||
Expect(results[i]).To(Equal(fmt.Sprintf("node-%d", i)),
|
||||
"request %d must see ITS OWN routing decision in the header, not a peer's", i)
|
||||
}
|
||||
})
|
||||
})
|
||||
201
core/http/middleware/node_header_integration_test.go
Normal file
201
core/http/middleware/node_header_integration_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package middleware_test
|
||||
|
||||
// Route-level integration coverage for the X-LocalAI-Node middleware.
|
||||
//
|
||||
// What this file pins (and why a separate spec on top of the unit tests
|
||||
// in node_header_test.go):
|
||||
//
|
||||
// - The unit tests in node_header_test.go exercise the wrapper by
|
||||
// invoking `mw(handler)(c)` directly against a hand-built
|
||||
// echo.Context. That misses regressions where the contract between
|
||||
// the real Echo router and the wrapper breaks: e.g. middleware
|
||||
// installation via e.Use() loses the wrapper because the framework
|
||||
// re-decorates c.Response().Writer after middleware setup, or a
|
||||
// handler that bypasses c.Response().Writer (writing to some other
|
||||
// captured surface).
|
||||
//
|
||||
// - This spec dispatches a real HTTP request through e.ServeHTTP into
|
||||
// a streaming handler shaped like chat.go's streaming branch: set
|
||||
// SSE headers, write chunks via c.Response().Write, Flush. It
|
||||
// proves that:
|
||||
// 1. Middleware installed via e.Use() is on the writer chain
|
||||
// when the handler runs.
|
||||
// 2. The per-request holder attached by ExposeNodeHeader is
|
||||
// visible to the handler (and, transitively, to anything that
|
||||
// shares the request context, including the SmartRouter).
|
||||
// 3. The wrapper's lazy maybeSet fires on the first underlying
|
||||
// Write/Flush, so X-LocalAI-Node lands on the response map
|
||||
// BEFORE the first body byte is committed.
|
||||
// 4. The header is present in the recorded response (i.e. it
|
||||
// isn't dropped because we tried to set it post-WriteHeader).
|
||||
//
|
||||
// Out of scope (and why):
|
||||
//
|
||||
// - We do NOT wire core/http/endpoints/openai.ChatEndpoint
|
||||
// end-to-end. ChatEndpoint depends on templates.Evaluator, the
|
||||
// MCP NATS client, and the LocalAI Assistant holder; standing
|
||||
// those up just to assert header ordering is out of proportion to
|
||||
// the property under test. The handler used here mirrors
|
||||
// chat.go's streaming branch and exercises the SAME middleware ->
|
||||
// c.Response().Writer -> SSE write path as production. If
|
||||
// chat.go's streaming branch ever stops going through
|
||||
// c.Response().Writer (e.g. it starts using a captured raw
|
||||
// http.ResponseWriter from a different seam), this test will not
|
||||
// notice; guard that with a code review checklist on chat.go.
|
||||
//
|
||||
// - We do NOT spin up the real SmartRouter. The contract between
|
||||
// router and middleware is "router calls distributedhdr.Stamp on
|
||||
// the request context; middleware reads the resulting holder on
|
||||
// first write". A synthetic stamp inside the handler exercises
|
||||
// the same code path; the router's own unit/integration tests
|
||||
// cover the routing decision itself.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
)
|
||||
|
||||
// orderRecorder snapshots the X-LocalAI-Node header value AT THE MOMENT
|
||||
// the underlying writer is asked to commit each event. Any header set on
|
||||
// the response map AFTER the first write/flush is dropped on the wire,
|
||||
// so this is the ground-truth observation a real SSE client would see.
|
||||
type orderRecorder struct {
|
||||
http.ResponseWriter
|
||||
mu sync.Mutex
|
||||
events []string
|
||||
}
|
||||
|
||||
func (o *orderRecorder) record(ev string) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
o.events = append(o.events, ev)
|
||||
}
|
||||
|
||||
func (o *orderRecorder) snapshot() []string {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
out := make([]string, len(o.events))
|
||||
copy(out, o.events)
|
||||
return out
|
||||
}
|
||||
|
||||
func (o *orderRecorder) WriteHeader(code int) {
|
||||
o.record(fmt.Sprintf("header:%d:node=%s", code, o.Header().Get(middleware.NodeHeaderName)))
|
||||
o.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (o *orderRecorder) Write(b []byte) (int, error) {
|
||||
o.record(fmt.Sprintf("write:node=%s", o.Header().Get(middleware.NodeHeaderName)))
|
||||
return o.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (o *orderRecorder) Flush() {
|
||||
o.record(fmt.Sprintf("flush:node=%s", o.Header().Get(middleware.NodeHeaderName)))
|
||||
if f, ok := o.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("ExposeNodeHeader middleware (route-level integration)", func() {
|
||||
const fakeNodeID = "node-route-7"
|
||||
|
||||
var appCfg *config.ApplicationConfig
|
||||
|
||||
BeforeEach(func() {
|
||||
appCfg = config.NewApplicationConfig()
|
||||
appCfg.ExposeNodeHeader = true
|
||||
})
|
||||
|
||||
It("stamps X-LocalAI-Node before the first SSE byte via the real router + middleware chain", func() {
|
||||
// Build a real Echo router. We need the tracker to sit BELOW
|
||||
// the ExposeNodeHeader wrapper in the writer chain (so its
|
||||
// recorded snapshot reflects what bytes-on-the-wire see AFTER
|
||||
// the wrapper has had a chance to stamp the header). Install
|
||||
// the tracker via a middleware that runs BEFORE
|
||||
// ExposeNodeHeader; Echo's middleware execution order matches
|
||||
// e.Use() call order, so the first Use() wraps the OUTER
|
||||
// layer of the writer chain (i.e. the wrapper installed by
|
||||
// the second Use() wraps the tracker installed by the first).
|
||||
var (
|
||||
recorderMu sync.Mutex
|
||||
tracker *orderRecorder
|
||||
)
|
||||
e := echo.New()
|
||||
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
recorderMu.Lock()
|
||||
tracker = &orderRecorder{ResponseWriter: c.Response().Writer}
|
||||
c.Response().Writer = tracker
|
||||
recorderMu.Unlock()
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
e.Use(middleware.ExposeNodeHeader(appCfg))
|
||||
|
||||
e.POST("/v1/chat/completions", func(c echo.Context) error {
|
||||
// Simulate the SmartRouter publishing the picked node ID
|
||||
// into the per-request holder installed by the middleware.
|
||||
// In production this happens inside ModelLoader.Load via
|
||||
// distributedhdr.Stamp(ctx, result.Node.ID).
|
||||
distributedhdr.Stamp(c.Request().Context(), fakeNodeID)
|
||||
|
||||
// SSE response prelude (same shape as chat.go).
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Emit a handful of SSE chunks. The very first
|
||||
// Write/Flush is what triggers the middleware
|
||||
// wrapper's maybeSet, so the X-LocalAI-Node header
|
||||
// MUST already be on the response map by the time the
|
||||
// byte is committed.
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := c.Response().Write([]byte(fmt.Sprintf("data: chunk %d\n\n", i)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Response().Flush()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(""))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
recorderMu.Lock()
|
||||
Expect(tracker).ToNot(BeNil(), "handler must run and install the order recorder")
|
||||
events := tracker.snapshot()
|
||||
recorderMu.Unlock()
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(rec.Header().Get(middleware.NodeHeaderName)).To(Equal(fakeNodeID),
|
||||
"production contract: header must reach the wire on a streamed response")
|
||||
|
||||
Expect(events).ToNot(BeEmpty(),
|
||||
"expected at least one underlying-writer event from the streaming handler")
|
||||
|
||||
// The very first observed event is the moment the wrapper
|
||||
// commits to the wire. Its recorded node= value is what a
|
||||
// real HTTP client would actually see; anything that lands
|
||||
// AFTER this byte is invisible.
|
||||
first := events[0]
|
||||
Expect(first).To(ContainSubstring("node="+fakeNodeID),
|
||||
"first writer event must carry the X-LocalAI-Node header (chain: middleware.Use -> e.POST -> handler.Write/Flush); got events: %v", events)
|
||||
|
||||
// Body sanity: SSE chunks made it to the recorder.
|
||||
Expect(rec.Body.String()).To(ContainSubstring("data: chunk 0"))
|
||||
})
|
||||
})
|
||||
260
core/http/middleware/node_header_test.go
Normal file
260
core/http/middleware/node_header_test.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
)
|
||||
|
||||
// orderedWriter records the order in which header-snapshot vs body-byte
|
||||
// events happen. Used by the streaming spec to assert that the X-LocalAI-Node
|
||||
// header lands on the response BEFORE the first body byte is committed to
|
||||
// the underlying writer.
|
||||
type orderedWriter struct {
|
||||
http.ResponseWriter
|
||||
events []string
|
||||
}
|
||||
|
||||
func (o *orderedWriter) WriteHeader(code int) {
|
||||
o.events = append(o.events, "header:"+http.StatusText(code))
|
||||
o.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (o *orderedWriter) Write(b []byte) (int, error) {
|
||||
// Snapshot the X-LocalAI-Node header value AT THE INSTANT the underlying
|
||||
// writer is asked to commit bytes. This is what real HTTP clients
|
||||
// effectively observe: anything set on the header map AFTER this point
|
||||
// would be silently dropped.
|
||||
o.events = append(o.events, "write:node="+o.Header().Get(NodeHeaderName))
|
||||
return o.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (o *orderedWriter) Flush() {
|
||||
o.events = append(o.events, "flush:node="+o.Header().Get(NodeHeaderName))
|
||||
if f, ok := o.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("ExposeNodeHeader middleware", func() {
|
||||
const (
|
||||
fakeNodeID = "node-abcdef"
|
||||
)
|
||||
|
||||
var (
|
||||
e *echo.Echo
|
||||
appCfg *config.ApplicationConfig
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
e = echo.New()
|
||||
appCfg = &config.ApplicationConfig{}
|
||||
})
|
||||
|
||||
// run executes the middleware against a fake handler. The handler may
|
||||
// reach into the per-request context to stamp the holder (simulating
|
||||
// what the distributed router does in production); the wrapper reads
|
||||
// the holder lazily on the first underlying write.
|
||||
run := func(handler echo.HandlerFunc) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
mw := ExposeNodeHeader(appCfg)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
return rec
|
||||
}
|
||||
|
||||
When("ExposeNodeHeader is false", func() {
|
||||
It("does not set the X-LocalAI-Node header", func() {
|
||||
appCfg.ExposeNodeHeader = false
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
// Even if a router were to stamp, with the flag off
|
||||
// there is no holder on the context so Stamp is a no-op.
|
||||
distributedhdr.Stamp(c.Request().Context(), fakeNodeID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("does not even install the wrapper (writer is unchanged)", func() {
|
||||
appCfg.ExposeNodeHeader = false
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
origWriter := c.Response().Writer
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
// Pass-through must leave the writer identity intact so
|
||||
// no overhead is added on the hot path when the feature
|
||||
// is off.
|
||||
Expect(c.Response().Writer).To(BeIdenticalTo(origWriter))
|
||||
// And no holder is attached to the request context.
|
||||
Expect(distributedhdr.Holder(c.Request().Context())).To(BeNil())
|
||||
return c.String(http.StatusOK, "ok")
|
||||
}
|
||||
mw := ExposeNodeHeader(appCfg)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true and the router stamps a node ID", func() {
|
||||
It("sets the X-LocalAI-Node header on a buffered response", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
distributedhdr.Stamp(c.Request().Context(), fakeNodeID)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
|
||||
It("sets the header even on a 500 error response (Write still triggers maybeSet)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
distributedhdr.Stamp(c.Request().Context(), fakeNodeID)
|
||||
return c.String(http.StatusInternalServerError, "boom")
|
||||
})
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusInternalServerError))
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
|
||||
It("installs a holder on the request context that the router can find", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
|
||||
var observed *string
|
||||
rec := run(func(c echo.Context) error {
|
||||
h := distributedhdr.Holder(c.Request().Context())
|
||||
Expect(h).ToNot(BeNil(), "middleware must attach a per-request holder when the flag is on")
|
||||
|
||||
distributedhdr.Stamp(c.Request().Context(), fakeNodeID)
|
||||
got := distributedhdr.Load(h)
|
||||
observed = &got
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(observed).ToNot(BeNil())
|
||||
Expect(*observed).To(Equal(fakeNodeID))
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
})
|
||||
|
||||
When("ExposeNodeHeader is true but the router never stamps", func() {
|
||||
It("does not set the header (in-process model, not distributed)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
|
||||
rec := run(func(c echo.Context) error {
|
||||
// Holder is present but nothing ever stamps it - this is
|
||||
// the in-process / non-distributed path.
|
||||
Expect(distributedhdr.Holder(c.Request().Context())).ToNot(BeNil())
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
When("the handler streams via Flush before any Write", func() {
|
||||
It("sets the header BEFORE the first byte hits the underlying writer", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
|
||||
// Wrap the recorder with an order-tracking writer so we can
|
||||
// assert that the header is on the response map by the time
|
||||
// the first body byte is committed. This is the property
|
||||
// that protected the pre-refactor streaming bug: if the
|
||||
// wrapper stamped lazily but AFTER the byte commit, real
|
||||
// SSE clients would see the body without the header.
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
tracker := &orderedWriter{ResponseWriter: rec}
|
||||
c := e.NewContext(req, rec)
|
||||
c.Response().Writer = tracker
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
// Simulate the router publishing the picked node ID
|
||||
// mid-request, then an SSE stream emitting chunks.
|
||||
distributedhdr.Stamp(c.Request().Context(), fakeNodeID)
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Flush()
|
||||
_, err := c.Response().Write([]byte("data: chunk\n\n"))
|
||||
return err
|
||||
}
|
||||
|
||||
mw := ExposeNodeHeader(appCfg)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
|
||||
// First recorded event on the underlying writer must show
|
||||
// the header already populated. The first event is either
|
||||
// flush or write; either way the node ID must be on it.
|
||||
Expect(tracker.events).ToNot(BeEmpty())
|
||||
Expect(tracker.events[0]).To(HavePrefix("flush:node=" + fakeNodeID))
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
})
|
||||
|
||||
When("the handler writes a body without an explicit WriteHeader", func() {
|
||||
It("still stamps the header before the implicit 200 commit", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
tracker := &orderedWriter{ResponseWriter: rec}
|
||||
c := e.NewContext(req, rec)
|
||||
c.Response().Writer = tracker
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
distributedhdr.Stamp(c.Request().Context(), fakeNodeID)
|
||||
_, err := c.Response().Write([]byte("body"))
|
||||
return err
|
||||
}
|
||||
|
||||
mw := ExposeNodeHeader(appCfg)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
|
||||
// Echo's Response.Write calls WriteHeader on the underlying
|
||||
// writer first, then Write. Both must see the header
|
||||
// already populated (the wrapper's maybeSet ran inside both
|
||||
// WriteHeader and Write before they hit `tracker`).
|
||||
Expect(len(tracker.events)).To(BeNumerically(">=", 2))
|
||||
Expect(tracker.events[0]).To(HavePrefix("header:"))
|
||||
Expect(tracker.events[1]).To(Equal("write:node=" + fakeNodeID))
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal(fakeNodeID))
|
||||
})
|
||||
})
|
||||
|
||||
When("the router stamps after request entry but before first write", func() {
|
||||
It("uses the value present AT the first write (late binding)", func() {
|
||||
appCfg.ExposeNodeHeader = true
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
// Simulate the router making a routing decision after
|
||||
// the handler has already started running but before the
|
||||
// first byte hits the wire. The wrapper must read the
|
||||
// holder lazily, not eagerly at request entry.
|
||||
distributedhdr.Stamp(c.Request().Context(), "fresh-node-B")
|
||||
return c.String(http.StatusOK, "ok")
|
||||
}
|
||||
|
||||
mw := ExposeNodeHeader(appCfg)
|
||||
Expect(mw(handler)(c)).To(Succeed())
|
||||
|
||||
Expect(rec.Header().Get(NodeHeaderName)).To(Equal("fresh-node-B"),
|
||||
"the wrapper must read the node ID lazily at first write, not eagerly at entry")
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -220,6 +221,7 @@ func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error {
|
||||
|
||||
// Add the correlation ID to the new context
|
||||
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
|
||||
ctxWithCorrelationID = distributedhdr.Inherit(ctxWithCorrelationID, reqCtx)
|
||||
|
||||
input.Context = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
@@ -308,6 +310,17 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
// Collapse the modern max_completion_tokens alias into the
|
||||
// legacy Maxtokens field so downstream code reads exactly one.
|
||||
// MaxCompletionTokens wins on conflict — it's the canonical
|
||||
// name per OpenAI's deprecation guidance, and a client that
|
||||
// took the trouble to send it intends that value. Clearing
|
||||
// the sibling prevents both names from being emitted if input
|
||||
// is re-marshaled (cloud-proxy passthrough).
|
||||
if input.MaxCompletionTokens != nil {
|
||||
input.Maxtokens = input.MaxCompletionTokens
|
||||
input.MaxCompletionTokens = nil
|
||||
}
|
||||
if input.Maxtokens != nil {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
@@ -621,6 +634,7 @@ func (re *RequestExtractor) SetOpenResponsesRequest(c echo.Context) error {
|
||||
|
||||
// Add the correlation ID to the new context
|
||||
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
|
||||
ctxWithCorrelationID = distributedhdr.Inherit(ctxWithCorrelationID, reqCtx)
|
||||
|
||||
input.Context = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
|
||||
@@ -156,9 +156,13 @@ var _ = Describe("SetModelAndConfig middleware", func() {
|
||||
// ---------------------------------------------------------------------------
|
||||
//
|
||||
// The OpenAI chat/completions spec nests the function name under "function":
|
||||
// {"type":"function", "function":{"name":"my_function"}}
|
||||
//
|
||||
// {"type":"function", "function":{"name":"my_function"}}
|
||||
//
|
||||
// The legacy Anthropic-compat shape puts it at the top level:
|
||||
// {"type":"function", "name":"my_function"}
|
||||
//
|
||||
// {"type":"function", "name":"my_function"}
|
||||
//
|
||||
// Both need to reach SetFunctionCallNameString (not SetFunctionCallString,
|
||||
// which is the mode field "none"/"auto"/"required").
|
||||
//
|
||||
@@ -550,4 +554,46 @@ var _ = Describe("SetModelAndConfig tool_choice parsing (chat completions)", fun
|
||||
Expect(capturedConfig.FunctionToCall()).To(Equal(""))
|
||||
})
|
||||
})
|
||||
|
||||
// OpenAI deprecated max_tokens in favour of max_completion_tokens
|
||||
// (gpt-5 / o-series reject the legacy name). The middleware accepts
|
||||
// both and collapses to the legacy internal Maxtokens field so
|
||||
// downstream code reads exactly one.
|
||||
Context("max_completion_tokens alias", func() {
|
||||
chatReqMaxTokens := func(fields string) string {
|
||||
return `{"model":"test-model",` +
|
||||
`"messages":[{"role":"user","content":"hi"}],` +
|
||||
fields + `}`
|
||||
}
|
||||
|
||||
It("accepts the modern max_completion_tokens name", func() {
|
||||
rec := postJSON(app, "/v1/chat/completions",
|
||||
chatReqMaxTokens(`"max_completion_tokens":64`))
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(capturedConfig).ToNot(BeNil())
|
||||
Expect(capturedConfig.Maxtokens).ToNot(BeNil())
|
||||
Expect(*capturedConfig.Maxtokens).To(Equal(64))
|
||||
})
|
||||
|
||||
It("still accepts the legacy max_tokens name", func() {
|
||||
rec := postJSON(app, "/v1/chat/completions",
|
||||
chatReqMaxTokens(`"max_tokens":48`))
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(capturedConfig).ToNot(BeNil())
|
||||
Expect(capturedConfig.Maxtokens).ToNot(BeNil())
|
||||
Expect(*capturedConfig.Maxtokens).To(Equal(48))
|
||||
})
|
||||
|
||||
It("prefers max_completion_tokens when both are set", func() {
|
||||
rec := postJSON(app, "/v1/chat/completions",
|
||||
chatReqMaxTokens(`"max_tokens":48,"max_completion_tokens":64`))
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(capturedConfig).ToNot(BeNil())
|
||||
Expect(capturedConfig.Maxtokens).ToNot(BeNil())
|
||||
Expect(*capturedConfig.Maxtokens).To(Equal(64))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
603
core/http/middleware/route_model.go
Normal file
603
core/http/middleware/route_model.go
Normal file
@@ -0,0 +1,603 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/xlog"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ScorerFactory returns a backend.Scorer bound to a named classifier
|
||||
// model. The score classifier uses it to compute joint log-prob of
|
||||
// every policy label against the routing prompt.
|
||||
type ScorerFactory func(modelName string) backend.Scorer
|
||||
|
||||
// EmbedderFactory returns a backend.Embedder bound to a named model.
|
||||
// Used by the L2 embedding cache. Returning nil signals "model not
|
||||
// loadable" — the middleware then falls back to the uncached
|
||||
// classifier so routing still happens.
|
||||
type EmbedderFactory func(modelName string) backend.Embedder
|
||||
|
||||
// VectorStoreFactory returns a backend.VectorStore bound to a named
|
||||
// collection. Each router model's cache lives in its own collection
|
||||
// so two routers can't poison each other's hits.
|
||||
type VectorStoreFactory func(storeName string) backend.VectorStore
|
||||
|
||||
// RerankerFactory returns a backend.Reranker bound to a named model.
|
||||
// Used by the colbert classifier to score policy descriptions against
|
||||
// the prompt via LocalAI's rerankers backend. Returning nil signals
|
||||
// "model not loadable" — buildClassifier reports a config error.
|
||||
type RerankerFactory func(modelName string) backend.Reranker
|
||||
|
||||
// ModelConfigLookup resolves a model name to its config, or nil when
|
||||
// unknown. Used by buildClassifier to confirm the classifier_model
|
||||
// declared the score usecase — the actual usecase-conflict check
|
||||
// lives in ModelConfig.Validate() and runs at config load/save time.
|
||||
type ModelConfigLookup func(modelName string) *config.ModelConfig
|
||||
|
||||
// ClassifierDeps bundles the backend factories the router middleware
|
||||
// needs to build a classifier and its optional L2 cache. Bundled into
|
||||
// one struct because RouteModel already takes many positional
|
||||
// arguments — additions to the dependency surface go here instead of
|
||||
// growing the signature.
|
||||
//
|
||||
// Embedder and VectorStore are optional: when both are non-nil and the
|
||||
// router config declares an embedding_cache block, the score
|
||||
// classifier is wrapped in EmbeddingCacheClassifier. Otherwise the
|
||||
// score classifier runs unwrapped and the embedding-cache YAML is
|
||||
// ignored with a warning.
|
||||
type ClassifierDeps struct {
|
||||
Scorer ScorerFactory
|
||||
Embedder EmbedderFactory
|
||||
VectorStore VectorStoreFactory
|
||||
Reranker RerankerFactory
|
||||
|
||||
// ModelLookup resolves the classifier_model name to its config so
|
||||
// buildClassifier can reject misconfigurations that would
|
||||
// otherwise crash the llama-cpp backend at request time. Optional
|
||||
// — when nil, the check is skipped (tests, embedded callers that
|
||||
// haven't wired the loader).
|
||||
ModelLookup ModelConfigLookup
|
||||
|
||||
// Registry is the shared classifier cache. Both the OpenAI and
|
||||
// Anthropic routes pass the same registry so the admin stats
|
||||
// endpoint sees every live classifier. Nil falls back to a local
|
||||
// registry — tests that don't need cross-route stats use this.
|
||||
Registry *router.Registry
|
||||
|
||||
// Evaluator renders the classifier model's chat template around
|
||||
// the routing system + user prompt. Optional — when nil, the
|
||||
// score classifier falls back to a built-in ChatML envelope,
|
||||
// which is correct for Arch-Router/Qwen but wrong for non-ChatML
|
||||
// routing models. Production wiring passes the app-wide
|
||||
// templates.Evaluator so any model the operator points at gets
|
||||
// its own chat template applied.
|
||||
Evaluator *templates.Evaluator
|
||||
}
|
||||
|
||||
// ProbeExtractor pulls the prompt content out of a parsed request so
|
||||
// the classifier can inspect it without taking a dependency on the
|
||||
// schema package. One extractor per request shape — wired by the
|
||||
// route registration site (mirrors the piiadapter pattern).
|
||||
//
|
||||
// Returns ok=false when the parsed value isn't the expected type — the
|
||||
// middleware then passes through without engaging the router.
|
||||
type ProbeExtractor func(parsed any) (router.Probe, bool)
|
||||
|
||||
// RouteModel runs after SetModelAndConfig and the schema-specific
|
||||
// SetXRequest, looks at the resolved model's Router config, and (when
|
||||
// present) reclassifies the request to one of the candidates.
|
||||
//
|
||||
// The middleware:
|
||||
//
|
||||
// 1. Loads MODEL_CONFIG from the echo context. If nil or HasRouter()
|
||||
// is false, passes through.
|
||||
// 2. Extracts the probe via the supplied ProbeExtractor.
|
||||
// 3. Invokes the classifier matching cfg.Router.Classifier
|
||||
// ("score" or "colbert"). If the classifier can't be built —
|
||||
// missing classifier_model, misconfigured policies, etc. — the
|
||||
// request fails with 503. cfg.Router.Fallback only catches
|
||||
// Classify-time errors and label-coverage misses, not config
|
||||
// bugs that would otherwise be silent.
|
||||
// 4. Resolves the chosen candidate to its model name. Reloads the
|
||||
// ModelConfig for that model and asserts depth-1 (the candidate
|
||||
// must NOT itself have a Router). Violation returns 500 — config
|
||||
// bug, not a request bug.
|
||||
// 5. Updates input.Model in place, replaces MODEL_CONFIG with the
|
||||
// candidate's config, and stamps RequestedModel/ServedModel on the
|
||||
// context so UsageMiddleware records the routing.
|
||||
// 6. Writes a DecisionRecord to the store for the admin page.
|
||||
//
|
||||
// store may be nil when --disable-stats turns off the routing log;
|
||||
// classification still runs.
|
||||
//
|
||||
// Composition with SmartRouter (distributed mode): this middleware
|
||||
// only does *model* selection. Node selection still happens in
|
||||
// SmartRouter.Route() downstream of this middleware.
|
||||
// RouteModel wires the router middleware. source is the value written to
|
||||
// DecisionRecord.Source (router.SourceChat / SourceAnthropic / ...) so
|
||||
// the admin page can split decisions by entry point. Pass
|
||||
// router.SourceChat for the OpenAI chat endpoint, router.SourceAnthropic
|
||||
// for the Anthropic messages endpoint.
|
||||
func RouteModel(loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig, store router.DecisionStore, fallbackUser *auth.User, extractor ProbeExtractor, source string, deps ClassifierDeps) echo.MiddlewareFunc {
|
||||
registry := deps.Registry
|
||||
if registry == nil {
|
||||
registry = router.NewRegistry()
|
||||
}
|
||||
candidateLoader := func(name string) (*config.ModelConfig, error) {
|
||||
return loader.LoadModelConfigFileByNameDefaultOptions(name, appConfig)
|
||||
}
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil || !cfg.HasRouter() {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
parsed := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST)
|
||||
if parsed == nil {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
probe, probeOK := extractor(parsed)
|
||||
if !probeOK {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
classifier, err := GetOrBuildClassifier(registry, cfg, deps)
|
||||
if err != nil {
|
||||
// Build-time failures are config bugs (missing
|
||||
// classifier_model, undeclared usecase, policy
|
||||
// validation, ...). Silently falling back would hide
|
||||
// them and make the router look "working" while the
|
||||
// classifier model is never invoked — surface as 503
|
||||
// with the underlying reason so operators see it.
|
||||
xlog.Warn("router: classifier build failed",
|
||||
"router_model", cfg.Name, "classifier", cfg.Router.Classifier, "error", err)
|
||||
return echo.NewHTTPError(503, "router classifier unavailable: "+err.Error())
|
||||
}
|
||||
|
||||
result, err := router.Resolve(c.Request().Context(), cfg, classifier, candidateLoader, probe)
|
||||
if err != nil {
|
||||
xlog.Warn("router: resolve failed", "router_model", cfg.Name, "error", err)
|
||||
return echo.NewHTTPError(500, err.Error())
|
||||
}
|
||||
|
||||
if req, ok := parsed.(schema.LocalAIRequest); ok {
|
||||
chosen := result.ChosenModel
|
||||
req.ModelName(&chosen)
|
||||
}
|
||||
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, result.ChosenConfig)
|
||||
c.Set(ContextKeyRequestedModel, result.RouterModel)
|
||||
c.Set(ContextKeyServedModel, result.ChosenModel)
|
||||
|
||||
if store != nil {
|
||||
recordHTTPDecision(c, store, result, fallbackUser, source)
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordHTTPDecision writes the resolved decision to the store with
|
||||
// HTTP-shaped audit metadata (correlation id from header, user from
|
||||
// auth middleware, fallback to the synthetic local user). Realtime
|
||||
// has its own recorder that supplies session-derived metadata
|
||||
// instead.
|
||||
func recordHTTPDecision(c echo.Context, store router.DecisionStore, result *router.ResolveResult, fallbackUser *auth.User, source string) {
|
||||
correlationID, _ := c.Get(ContextKeyCorrelationID).(string)
|
||||
if correlationID == "" {
|
||||
correlationID = c.Response().Header().Get("X-Correlation-ID")
|
||||
}
|
||||
userID := ""
|
||||
if u := auth.GetUser(c); u != nil {
|
||||
userID = u.ID
|
||||
} else if fallbackUser != nil {
|
||||
userID = fallbackUser.ID
|
||||
}
|
||||
_ = store.Record(context.Background(), result.ToDecisionRecord(newDecisionID(), correlationID, userID, source))
|
||||
}
|
||||
|
||||
|
||||
// GetOrBuildClassifier looks up a built Classifier for the named router
|
||||
// model in the registry and builds it on miss. Exported so the
|
||||
// /api/router/decide decision-oracle endpoint can share the same
|
||||
// build-once cache that the in-band RouteModel middleware uses.
|
||||
func GetOrBuildClassifier(registry *router.Registry, cfg *config.ModelConfig, deps ClassifierDeps) (router.Classifier, error) {
|
||||
// Fingerprint folds the classifier model's renderer-affecting
|
||||
// fields (chat templates + stopwords) in alongside the router
|
||||
// config. Without this, hot-reloading the classifier model's
|
||||
// YAML (via ReloadModelsEndpoint, /import-model, or the MCP
|
||||
// reload_models tool) wouldn't rebuild the cached classifier —
|
||||
// the candidates slice and renderer closure are baked at build
|
||||
// time from those fields and would silently keep the stale
|
||||
// stop token / template until process restart.
|
||||
var classifierCfg *config.ModelConfig
|
||||
if deps.ModelLookup != nil {
|
||||
classifierCfg = deps.ModelLookup(cfg.Router.ClassifierModel)
|
||||
}
|
||||
fp := routerConfigFingerprint(cfg.Router, classifierCfg)
|
||||
if cached, ok := registry.Get(cfg.Name, fp); ok {
|
||||
return cached, nil
|
||||
}
|
||||
c, err := buildClassifier(cfg, deps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
registry.Put(cfg.Name, fp, c)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// routerConfigFingerprint is a stable cache key for the (router cfg,
|
||||
// classifier model cfg) tuple. FNV-64 over the YAML form of the
|
||||
// router block plus the renderer-affecting fields of the classifier
|
||||
// model — equality-only, not cryptographic. YAML-marshal picks up
|
||||
// any future RouterConfig field without this function needing to be
|
||||
// touched; for the classifier model we hash a narrow projection so
|
||||
// unrelated changes (parameters, files, ...) don't burst the cache.
|
||||
// Pass classifierCfg=nil when no lookup is wired — the fingerprint
|
||||
// degenerates to the router-only form, matching pre-refactor behaviour.
|
||||
func routerConfigFingerprint(rc config.RouterConfig, classifierCfg *config.ModelConfig) uint64 {
|
||||
bytes, err := yaml.Marshal(rc)
|
||||
if err != nil {
|
||||
// Marshalling a value type can't fail in practice; fall
|
||||
// back to a hash that varies per call so we don't quietly
|
||||
// share a cache entry across distinct configs.
|
||||
return uint64(time.Now().UnixNano())
|
||||
}
|
||||
h := fnv.New64a()
|
||||
h.Write(bytes)
|
||||
if classifierCfg != nil {
|
||||
// Narrow projection: only the fields newTemplateRenderer and
|
||||
// firstStopWord actually read. Hashing the whole ModelConfig
|
||||
// would invalidate the cache on irrelevant parameter changes.
|
||||
h.Write([]byte{0}) // separator so empty fields don't collide
|
||||
h.Write([]byte(classifierCfg.TemplateConfig.Chat))
|
||||
h.Write([]byte{0})
|
||||
h.Write([]byte(classifierCfg.TemplateConfig.ChatMessage))
|
||||
h.Write([]byte{0})
|
||||
for _, sw := range classifierCfg.StopWords {
|
||||
h.Write([]byte(sw))
|
||||
h.Write([]byte{0})
|
||||
}
|
||||
}
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
func buildClassifier(cfg *config.ModelConfig, deps ClassifierDeps) (router.Classifier, error) {
|
||||
rc := cfg.Router
|
||||
name := rc.Classifier
|
||||
if name == "" {
|
||||
name = router.ClassifierScore
|
||||
}
|
||||
policies, err := validateRouterPolicies(name, rc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cacheCap := rc.ClassifierCacheSize
|
||||
if cacheCap == 0 {
|
||||
cacheCap = 1024
|
||||
}
|
||||
|
||||
var inner router.Classifier
|
||||
switch name {
|
||||
case router.ClassifierScore:
|
||||
if deps.Scorer == nil {
|
||||
return nil, fmt.Errorf("router classifier score unavailable: no scorer factory wired")
|
||||
}
|
||||
if err := assertClassifierDeclaresScore(rc.ClassifierModel, deps.ModelLookup); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scorer := deps.Scorer(rc.ClassifierModel)
|
||||
if scorer == nil {
|
||||
return nil, fmt.Errorf("router classifier score: classifier_model %q not loadable", rc.ClassifierModel)
|
||||
}
|
||||
opts := router.ScoreClassifierOptions{
|
||||
CacheCap: cacheCap,
|
||||
ActivationThreshold: rc.ActivationThreshold,
|
||||
Normalization: rc.ScoreNormalization,
|
||||
SystemPromptTemplate: rc.ClassifierSystemTemplate,
|
||||
}
|
||||
// Build the prompt renderer + stop token from the classifier
|
||||
// model's own config when available. Without ModelLookup
|
||||
// (tests, embedded callers) the score classifier's built-in
|
||||
// ChatML defaults kick in, which is correct for Arch-Router.
|
||||
if deps.ModelLookup != nil {
|
||||
if classifierCfg := deps.ModelLookup(rc.ClassifierModel); classifierCfg != nil {
|
||||
if deps.Evaluator != nil {
|
||||
opts.PromptRenderer = newTemplateRenderer(deps.Evaluator, classifierCfg)
|
||||
}
|
||||
if st := pickAssistantTurnEnd(classifierCfg.StopWords, classifierCfg.TemplateConfig.ChatMessage); st != "" {
|
||||
opts.StopToken = st
|
||||
}
|
||||
}
|
||||
}
|
||||
inner = router.NewScoreClassifier(policies, scorer, opts)
|
||||
case router.ClassifierColbert:
|
||||
if deps.Reranker == nil {
|
||||
return nil, fmt.Errorf("router classifier colbert unavailable: no reranker factory wired")
|
||||
}
|
||||
reranker := deps.Reranker(rc.ClassifierModel)
|
||||
if reranker == nil {
|
||||
return nil, fmt.Errorf("router classifier colbert: classifier_model %q not loadable", rc.ClassifierModel)
|
||||
}
|
||||
inner = router.NewRerankClassifier(policies, reranker, cacheCap, rc.ActivationThreshold)
|
||||
default:
|
||||
return nil, fmt.Errorf("router: unknown classifier %q (supported: %s)", name, strings.Join([]string{router.ClassifierScore, router.ClassifierColbert}, ", "))
|
||||
}
|
||||
|
||||
if rc.EmbeddingCache == nil {
|
||||
return inner, nil
|
||||
}
|
||||
wrapped, err := wrapWithEmbeddingCache(cfg, inner, deps)
|
||||
if err != nil {
|
||||
// Caching plumbing problems must not break routing — log,
|
||||
// drop the cache layer, and return the uncached classifier.
|
||||
// The admin UI surfaces the warning via the classifier-build
|
||||
// error path used elsewhere.
|
||||
xlog.Warn("router: embedding cache disabled",
|
||||
"router_model", cfg.Name, "error", err)
|
||||
return inner, nil
|
||||
}
|
||||
return wrapped, nil
|
||||
}
|
||||
|
||||
// assertClassifierDeclaresScore refuses to build the score classifier
|
||||
// unless classifier_model's config declares FLAG_SCORE. The actual
|
||||
// usecase-conflict check (score + chat/completion/embeddings on
|
||||
// llama-cpp) lives in ModelConfig.Validate() and fires at config load
|
||||
// and save time — by the time we get here, any model that reached the
|
||||
// loader is already conflict-free. This check just refuses to bind a
|
||||
// model that never declared itself for Score in the first place; that
|
||||
// model could be a misconfigured chat model the operator pointed at
|
||||
// by accident, and without FLAG_SCORE the validator never saw it.
|
||||
//
|
||||
// When lookup is nil (test wiring) the check is skipped and we fall
|
||||
// back to the C++ backend's runtime tripwire as the last line of
|
||||
// defence.
|
||||
func assertClassifierDeclaresScore(classifierModel string, lookup ModelConfigLookup) error {
|
||||
if lookup == nil {
|
||||
return nil
|
||||
}
|
||||
cfg := lookup(classifierModel)
|
||||
if cfg == nil {
|
||||
// Unknown model — Scorer() will produce a clearer "not
|
||||
// loadable" error a few lines down.
|
||||
return nil
|
||||
}
|
||||
if !cfg.HasUsecases(config.FLAG_SCORE) {
|
||||
return fmt.Errorf(
|
||||
"router classifier score: classifier_model %q does not declare the "+
|
||||
"score usecase. Add `known_usecases: [score]` to its config so "+
|
||||
"the loader can reject conflicting usecase combinations",
|
||||
classifierModel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRouterPolicies checks the shared invariants both classifiers
|
||||
// rely on (non-empty policies, every candidate label declared as a
|
||||
// policy, every candidate has a model + at least one label) and
|
||||
// returns the parsed []ScorePolicy. Both Score and Rerank classifiers
|
||||
// take the same policy shape.
|
||||
func validateRouterPolicies(classifierName string, rc config.RouterConfig) ([]router.ScorePolicy, error) {
|
||||
if rc.ClassifierModel == "" {
|
||||
return nil, fmt.Errorf("router classifier %s requires classifier_model", classifierName)
|
||||
}
|
||||
if len(rc.Policies) == 0 {
|
||||
return nil, fmt.Errorf("router classifier %s requires at least one policy", classifierName)
|
||||
}
|
||||
policies := make([]router.ScorePolicy, 0, len(rc.Policies))
|
||||
for _, p := range rc.Policies {
|
||||
if p.Label == "" {
|
||||
return nil, fmt.Errorf("router classifier %s: policy with empty label", classifierName)
|
||||
}
|
||||
if p.Description == "" {
|
||||
return nil, fmt.Errorf("router classifier %s: policy %q has no description", classifierName, p.Label)
|
||||
}
|
||||
policies = append(policies, router.ScorePolicy{Label: p.Label, Description: p.Description})
|
||||
}
|
||||
policyLabels := make(map[string]struct{}, len(policies))
|
||||
for _, p := range policies {
|
||||
policyLabels[p.Label] = struct{}{}
|
||||
}
|
||||
for _, c := range rc.Candidates {
|
||||
if c.Model == "" {
|
||||
return nil, fmt.Errorf("router classifier %s: candidate has empty model field", classifierName)
|
||||
}
|
||||
if len(c.Labels) == 0 {
|
||||
return nil, fmt.Errorf("router classifier %s: candidate %q has no labels", classifierName, c.Model)
|
||||
}
|
||||
for _, l := range c.Labels {
|
||||
if _, ok := policyLabels[l]; !ok {
|
||||
return nil, fmt.Errorf("router classifier %s: candidate %q references unknown label %q (not in policies)", classifierName, c.Model, l)
|
||||
}
|
||||
}
|
||||
}
|
||||
return policies, nil
|
||||
}
|
||||
|
||||
// newTemplateRenderer adapts the templates.Evaluator + the classifier
|
||||
// model's config into the router.PromptRenderer callback. The
|
||||
// resulting renderer pushes the routing system + user prompt through
|
||||
// the classifier model's full chat-template pipeline — per-role
|
||||
// formatting via TemplateConfig.ChatMessage, then the outer
|
||||
// TemplateConfig.Chat — so non-ChatML routing models render
|
||||
// correctly without router-package awareness of the template format.
|
||||
//
|
||||
// We must go through TemplateMessages, not EvaluateTemplateForPrompt
|
||||
// directly: the gallery's outer Chat templates are uniformly
|
||||
// `{{.Input -}}<|im_start|>assistant` (or the Llama-3 equivalent)
|
||||
// and reference {{.Input}} only — never {{.SystemPrompt}}. Passing
|
||||
// our routing system prompt through .SystemPrompt would silently
|
||||
// drop it because Go text/template ignores unreferenced fields.
|
||||
// TemplateMessages instead renders each role through ChatMessage and
|
||||
// joins them into the .Input the outer template DOES read.
|
||||
//
|
||||
// Returns nil (forcing the score classifier's chatMLRenderer
|
||||
// fallback) when either template piece is missing — partial
|
||||
// templating would still drop content.
|
||||
func newTemplateRenderer(eval *templates.Evaluator, classifierCfg *config.ModelConfig) router.PromptRenderer {
|
||||
if classifierCfg.TemplateConfig.Chat == "" || classifierCfg.TemplateConfig.ChatMessage == "" {
|
||||
return nil
|
||||
}
|
||||
cfgCopy := *classifierCfg
|
||||
return func(system, user string) (string, error) {
|
||||
messages := []schema.Message{
|
||||
{Role: "system", StringContent: system},
|
||||
{Role: "user", StringContent: user},
|
||||
}
|
||||
rendered := eval.TemplateMessages(schema.OpenAIRequest{}, messages, &cfgCopy, nil, false)
|
||||
if rendered == "" {
|
||||
return "", fmt.Errorf("router: classifier %q chat template produced empty output", cfgCopy.Name)
|
||||
}
|
||||
return rendered, nil
|
||||
}
|
||||
}
|
||||
|
||||
// pickAssistantTurnEnd returns the classifier model's assistant
|
||||
// turn-end token — the one to suffix candidates with so the model's
|
||||
// "I'm done" signal folds into the per-candidate joint log-prob.
|
||||
//
|
||||
// Strategy: prefer the stopword that *literally appears* in the
|
||||
// chat_message template, because that token is the assistant
|
||||
// turn-end by construction. ChatML's chat_message ends with
|
||||
// "<|im_end|>", Llama-3's ends with "<|eot_id|>", etc. — the
|
||||
// template is the source of truth.
|
||||
//
|
||||
// Fallback: the first non-empty stopword. That's right for
|
||||
// well-ordered configs (ChatML conventionally lists <|im_end|>
|
||||
// first) but wrong for some gallery Llama-3 templates that defensively
|
||||
// list <|im_end|> first even though the actual turn-end is <|eot_id|>.
|
||||
// The template-scan above catches those.
|
||||
//
|
||||
// When no stopwords are configured at all, return "" — caller falls
|
||||
// back to defaultStopToken (<|im_end|>) inside the score classifier.
|
||||
func pickAssistantTurnEnd(words []string, chatMessageTemplate string) string {
|
||||
if chatMessageTemplate != "" {
|
||||
for _, w := range words {
|
||||
if w != "" && strings.Contains(chatMessageTemplate, w) {
|
||||
return w
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, w := range words {
|
||||
if w != "" {
|
||||
return w
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func wrapWithEmbeddingCache(cfg *config.ModelConfig, inner router.Classifier, deps ClassifierDeps) (router.Classifier, error) {
|
||||
ec := cfg.Router.EmbeddingCache
|
||||
if ec.EmbeddingModel == "" {
|
||||
return nil, fmt.Errorf("embedding_cache requires embedding_model")
|
||||
}
|
||||
if deps.Embedder == nil || deps.VectorStore == nil {
|
||||
return nil, fmt.Errorf("embedding cache factories not wired")
|
||||
}
|
||||
embedder := deps.Embedder(ec.EmbeddingModel)
|
||||
if embedder == nil {
|
||||
return nil, fmt.Errorf("embedding_model %q not loadable", ec.EmbeddingModel)
|
||||
}
|
||||
storeName := ec.StoreName
|
||||
if storeName == "" {
|
||||
storeName = "router-cache-" + cfg.Name
|
||||
}
|
||||
vstore := deps.VectorStore(storeName)
|
||||
if vstore == nil {
|
||||
return nil, fmt.Errorf("vector store %q not loadable", storeName)
|
||||
}
|
||||
return router.NewEmbeddingCacheClassifier(inner, embedder, vstore, ec.SimilarityThreshold, ec.ConfidenceThreshold), nil
|
||||
}
|
||||
|
||||
func newDecisionID() string {
|
||||
var b [12]byte
|
||||
_, _ = rand.Read(b[:])
|
||||
return "rd_" + hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
// OpenAIProbe extracts a router.Probe from a parsed *schema.OpenAIRequest.
|
||||
// Concatenates message contents (string-form or text blocks of the
|
||||
// structured `[]any` content) so the classifier sees a single corpus
|
||||
// for length and content-shape rules. Image blocks are skipped — a
|
||||
// future multimodal classifier can take a different route.
|
||||
func OpenAIProbe(parsed any) (router.Probe, bool) {
|
||||
req, ok := parsed.(*schema.OpenAIRequest)
|
||||
if !ok || req == nil {
|
||||
return router.Probe{}, false
|
||||
}
|
||||
return OpenAIProbeFromRequest(req), true
|
||||
}
|
||||
|
||||
// OpenAIProbeFromRequest is the typed counterpart of OpenAIProbe — same
|
||||
// extraction logic, but takes the request struct directly. Realtime and
|
||||
// other non-HTTP callers use it to feed a probe to router.Resolve
|
||||
// without going through an echo.Context first.
|
||||
func OpenAIProbeFromRequest(req *schema.OpenAIRequest) router.Probe {
|
||||
if req == nil {
|
||||
return router.Probe{}
|
||||
}
|
||||
var b strings.Builder
|
||||
for i := range req.Messages {
|
||||
switch ct := req.Messages[i].Content.(type) {
|
||||
case string:
|
||||
b.WriteString(ct)
|
||||
b.WriteByte('\n')
|
||||
case []any:
|
||||
for _, block := range ct {
|
||||
if bm, ok := block.(map[string]any); ok && bm["type"] == "text" {
|
||||
if t, ok := bm["text"].(string); ok {
|
||||
b.WriteString(t)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return router.Probe{Prompt: b.String()}
|
||||
}
|
||||
|
||||
// AnthropicProbe is the AnthropicRequest analogue of OpenAIProbe.
|
||||
func AnthropicProbe(parsed any) (router.Probe, bool) {
|
||||
req, ok := parsed.(*schema.AnthropicRequest)
|
||||
if !ok || req == nil {
|
||||
return router.Probe{}, false
|
||||
}
|
||||
var b strings.Builder
|
||||
for i := range req.Messages {
|
||||
switch ct := req.Messages[i].Content.(type) {
|
||||
case string:
|
||||
b.WriteString(ct)
|
||||
b.WriteByte('\n')
|
||||
case []any:
|
||||
for _, block := range ct {
|
||||
if bm, ok := block.(map[string]any); ok && bm["type"] == "text" {
|
||||
if t, ok := bm["text"].(string); ok {
|
||||
b.WriteString(t)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return router.Probe{
|
||||
Prompt: b.String(),
|
||||
}, true
|
||||
}
|
||||
|
||||
551
core/http/middleware/route_model_test.go
Normal file
551
core/http/middleware/route_model_test.go
Normal file
@@ -0,0 +1,551 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// The RouteModel middleware wires the score classifier into request
|
||||
// rewriting. The classifier itself is covered in
|
||||
// router/score_test.go — these specs pin the middleware-level
|
||||
// behaviour: candidate matching against the active label set, the
|
||||
// fallback path, and the depth-1 invariant.
|
||||
|
||||
var _ = Describe("RouteModel middleware (score classifier)", func() {
|
||||
var (
|
||||
modelDir string
|
||||
appConfig *config.ApplicationConfig
|
||||
loader *config.ModelConfigLoader
|
||||
store *fakeDecisionStore
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
d, err := os.MkdirTemp("", "router-test-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
modelDir = d
|
||||
appConfig = &config.ApplicationConfig{
|
||||
Context: context.Background(),
|
||||
SystemState: &system.SystemState{Model: system.Model{ModelsPath: modelDir}},
|
||||
}
|
||||
loader = config.NewModelConfigLoader(modelDir)
|
||||
store = &fakeDecisionStore{}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
_ = os.RemoveAll(modelDir)
|
||||
})
|
||||
|
||||
It("routes to a candidate whose labels cover the active set", func() {
|
||||
// 3 policies, 2 candidates. Small model has [casual-chat],
|
||||
// bigger has [code-generation, math-reasoning, casual-chat].
|
||||
// A query that activates code-generation should fall to the
|
||||
// bigger candidate because it's the only one that covers it.
|
||||
routerCfg := newScoreRouterModel(modelDir, "smart-router")
|
||||
writeCandidate(modelDir, "small-model")
|
||||
writeCandidate(modelDir, "big-model")
|
||||
|
||||
s := &stubScorer{labelToLogProb: map[string]float64{
|
||||
"code-generation": -0.05, // dominant
|
||||
"casual-chat": -3.0,
|
||||
"math-reasoning": -4.0,
|
||||
}}
|
||||
rec, err := runRouter(loader, appConfig, store, routerCfg, openAIChat("debug my Go null pointer"), stubScorerFactory(s))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(rec.Body.String()).To(Equal("served:big-model"))
|
||||
Expect(store.records).To(HaveLen(1))
|
||||
Expect(store.records[0].ServedModel).To(Equal("big-model"))
|
||||
Expect(store.records[0].Label).To(ContainSubstring("code-generation"))
|
||||
})
|
||||
|
||||
It("prefers the smaller candidate when both cover the active set", func() {
|
||||
// Both candidates list casual-chat. Admins order small →
|
||||
// big, so a casual-chat-only request must route to small.
|
||||
routerCfg := newScoreRouterModel(modelDir, "smart-router")
|
||||
writeCandidate(modelDir, "small-model")
|
||||
writeCandidate(modelDir, "big-model")
|
||||
|
||||
s := &stubScorer{labelToLogProb: map[string]float64{
|
||||
"code-generation": -5.0,
|
||||
"casual-chat": -0.05, // dominant
|
||||
"math-reasoning": -5.0,
|
||||
}}
|
||||
rec, err := runRouter(loader, appConfig, store, routerCfg, openAIChat("hi"), stubScorerFactory(s))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(rec.Body.String()).To(Equal("served:small-model"))
|
||||
})
|
||||
|
||||
It("falls back when no candidate covers the active label set", func() {
|
||||
// Only the bigger candidate covers math-reasoning. We
|
||||
// deliberately drop it from the candidates list so neither
|
||||
// matches; expect Fallback to fire.
|
||||
routerCfg := newScoreRouterModel(modelDir, "smart-router")
|
||||
// Remove the second candidate so coverage gap appears.
|
||||
routerCfg.Router.Candidates = routerCfg.Router.Candidates[:1]
|
||||
writeCandidate(modelDir, "small-model")
|
||||
writeCandidate(modelDir, "qwen3-0.6b")
|
||||
|
||||
s := &stubScorer{labelToLogProb: map[string]float64{
|
||||
"code-generation": -5.0,
|
||||
"casual-chat": -5.0,
|
||||
"math-reasoning": -0.05, // dominant — but no candidate has it
|
||||
}}
|
||||
rec, err := runRouter(loader, appConfig, store, routerCfg, openAIChat("3 apples cost $2.40"), stubScorerFactory(s))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(rec.Body.String()).To(Equal("served:qwen3-0.6b"))
|
||||
})
|
||||
|
||||
It("rejects candidates that reference unknown labels at build time", func() {
|
||||
routerCfg := newScoreRouterModel(modelDir, "smart-router")
|
||||
routerCfg.Router.Candidates = append(routerCfg.Router.Candidates, config.RouterCandidate{
|
||||
Model: "broken",
|
||||
Labels: []string{"nonexistent-label"},
|
||||
})
|
||||
writeCandidate(modelDir, "small-model")
|
||||
writeCandidate(modelDir, "big-model")
|
||||
writeCandidate(modelDir, "broken")
|
||||
writeCandidate(modelDir, "qwen3-0.6b")
|
||||
|
||||
s := &stubScorer{labelToLogProb: map[string]float64{
|
||||
"code-generation": -0.05,
|
||||
"casual-chat": -3.0,
|
||||
"math-reasoning": -4.0,
|
||||
}}
|
||||
_, err := runRouter(loader, appConfig, store, routerCfg, openAIChat("debug something"), stubScorerFactory(s))
|
||||
// Build-time config bugs (here: a candidate referencing a
|
||||
// label not declared in policies) must surface to the client
|
||||
// — the previous silent-fallback behaviour hid the broken
|
||||
// config and left operators wondering why traces never showed
|
||||
// the classifier model running.
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("unknown label"))
|
||||
})
|
||||
|
||||
It("returns 500 when the candidate is itself a router (depth-1 invariant)", func() {
|
||||
// The candidate model is itself a router. We must reject
|
||||
// the dispatch — chained routers are deliberately
|
||||
// disallowed.
|
||||
routerCfg := newScoreRouterModel(modelDir, "smart-router")
|
||||
// Bend the test setup: replace one of the candidate-model
|
||||
// configs with a nested-router config.
|
||||
nestedRouter := newScoreRouterModel(modelDir, "small-model")
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "small-model.yaml"), []byte(toYAML(nestedRouter)), 0o644)).To(Succeed())
|
||||
writeCandidate(modelDir, "big-model")
|
||||
writeCandidate(modelDir, "qwen3-0.6b")
|
||||
|
||||
s := &stubScorer{labelToLogProb: map[string]float64{
|
||||
"code-generation": -5.0,
|
||||
"casual-chat": -0.05,
|
||||
"math-reasoning": -5.0,
|
||||
}}
|
||||
_, err := runRouter(loader, appConfig, store, routerCfg, openAIChat("hi"), stubScorerFactory(s))
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("depth-1 invariant"))
|
||||
})
|
||||
})
|
||||
|
||||
// Regression coverage for the rendered routing prompt — pins the
|
||||
// guarantee that the routing system prompt (route listing, JSON
|
||||
// output schema) actually reaches the classifier model. The first
|
||||
// implementation of the template-aware renderer routed through
|
||||
// EvaluateTemplateForPrompt, which only invokes the outer Chat
|
||||
// template — and the gallery's outer Chat templates are
|
||||
// `{{.Input -}}<|im_start|>assistant` shape, so .SystemPrompt was
|
||||
// silently dropped. The fix routes through TemplateMessages, which
|
||||
// renders each role through ChatMessage and joins the result into
|
||||
// .Input. These specs would fail loudly if the renderer ever
|
||||
// regresses back to bypassing per-role formatting.
|
||||
var _ = Describe("RouteModel rendered classifier prompt", func() {
|
||||
var (
|
||||
modelDir string
|
||||
appConfig *config.ApplicationConfig
|
||||
loader *config.ModelConfigLoader
|
||||
store *fakeDecisionStore
|
||||
eval *templates.Evaluator
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
d, err := os.MkdirTemp("", "router-render-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
modelDir = d
|
||||
appConfig = &config.ApplicationConfig{
|
||||
Context: context.Background(),
|
||||
SystemState: &system.SystemState{Model: system.Model{ModelsPath: modelDir}},
|
||||
}
|
||||
loader = config.NewModelConfigLoader(modelDir)
|
||||
store = &fakeDecisionStore{}
|
||||
eval = templates.NewEvaluator(modelDir)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
_ = os.RemoveAll(modelDir)
|
||||
})
|
||||
|
||||
It("includes the routing system prompt in the rendered ChatML envelope", func() {
|
||||
// Mirrors the live arch-router-1.5b.yaml: chatml-style chat +
|
||||
// chat_message templates. This is the production-wired path.
|
||||
writeChatMLClassifierModel(modelDir, "arch-router")
|
||||
routerCfg := newScoreRouterModel(modelDir, "smart-router")
|
||||
|
||||
s := &stubScorer{labelToLogProb: map[string]float64{
|
||||
"code-generation": -0.05,
|
||||
"casual-chat": -3.0,
|
||||
"math-reasoning": -4.0,
|
||||
}}
|
||||
_, err := runRouterWithDeps(loader, appConfig, store, routerCfg,
|
||||
openAIChat("debug this null pointer"),
|
||||
ClassifierDeps{
|
||||
Scorer: stubScorerFactory(s),
|
||||
ModelLookup: loaderLookup(loader, appConfig),
|
||||
Evaluator: eval,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// The routing system prompt must reach the scorer. Three
|
||||
// anchors: the route-listing block, one of the JSON-shaped
|
||||
// route entries (escapeJSONString preserves the description),
|
||||
// and the JSON output schema instruction.
|
||||
Expect(s.lastPrompt).To(ContainSubstring("<routes>"),
|
||||
"system prompt dropped: rendered prompt missing route-listing block. got: %q", s.lastPrompt)
|
||||
Expect(s.lastPrompt).To(ContainSubstring(`{"name": "code-generation"`),
|
||||
"system prompt dropped: rendered prompt missing route entries. got: %q", s.lastPrompt)
|
||||
Expect(s.lastPrompt).To(ContainSubstring(`{"route": "<name>"}`),
|
||||
"system prompt dropped: rendered prompt missing JSON output schema. got: %q", s.lastPrompt)
|
||||
|
||||
// And the per-role envelope must be present (proves we went
|
||||
// through ChatMessage, not the SystemPrompt-only path).
|
||||
Expect(s.lastPrompt).To(ContainSubstring("<|im_start|>system"),
|
||||
"system role marker missing — ChatMessage template wasn't invoked")
|
||||
Expect(s.lastPrompt).To(ContainSubstring("<|im_start|>user"),
|
||||
"user role marker missing")
|
||||
// User probe makes it through the per-role template. The trailing
|
||||
// \n on the probe content is added by OpenAIProbeFromRequest;
|
||||
// preserved through ChatMessage rendering.
|
||||
Expect(s.lastPrompt).To(ContainSubstring("debug this null pointer"),
|
||||
"user probe missing from rendered prompt")
|
||||
// Outer Chat template must add the assistant-open marker so
|
||||
// the scorer's first predicted token is the start of the
|
||||
// candidate.
|
||||
Expect(s.lastPrompt).To(MatchRegexp(`<\|im_start\|>assistant\s*$`),
|
||||
"rendered prompt must end at assistant-open marker. got: %q", s.lastPrompt)
|
||||
})
|
||||
|
||||
It("falls back to chatMLRenderer when the classifier model has no chat_message template", func() {
|
||||
// Partial template config: only outer Chat, no per-role
|
||||
// piece. The renderer must refuse rather than emit a prompt
|
||||
// that drops the system turn, so the score classifier's
|
||||
// built-in ChatML default takes over.
|
||||
writePartialClassifierModel(modelDir, "arch-router")
|
||||
routerCfg := newScoreRouterModel(modelDir, "smart-router")
|
||||
|
||||
s := &stubScorer{labelToLogProb: map[string]float64{
|
||||
"code-generation": -0.05,
|
||||
"casual-chat": -3.0,
|
||||
"math-reasoning": -4.0,
|
||||
}}
|
||||
_, err := runRouterWithDeps(loader, appConfig, store, routerCfg,
|
||||
openAIChat("hello world"),
|
||||
ClassifierDeps{
|
||||
Scorer: stubScorerFactory(s),
|
||||
ModelLookup: loaderLookup(loader, appConfig),
|
||||
Evaluator: eval,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// chatMLRenderer fallback emits its own envelope and still
|
||||
// embeds the routing system prompt. OpenAIProbeFromRequest
|
||||
// appends "\n" after each message body, so the user content
|
||||
// reaches the renderer as "hello world\n" — the substring
|
||||
// match accounts for that.
|
||||
Expect(s.lastPrompt).To(ContainSubstring("<routes>"),
|
||||
"fallback renderer also dropped the system prompt")
|
||||
Expect(s.lastPrompt).To(ContainSubstring("<|im_start|>system\n"))
|
||||
Expect(s.lastPrompt).To(ContainSubstring("<|im_start|>user\nhello world\n<|im_end|>"))
|
||||
Expect(strings.HasSuffix(s.lastPrompt, "<|im_start|>assistant\n")).To(BeTrue(),
|
||||
"chatMLRenderer fallback must end at assistant-open marker. got: %q", s.lastPrompt)
|
||||
})
|
||||
|
||||
It("uses the classifier model's first stopword as the candidate suffix", func() {
|
||||
writeChatMLClassifierModel(modelDir, "arch-router")
|
||||
routerCfg := newScoreRouterModel(modelDir, "smart-router")
|
||||
|
||||
s := &stubScorer{labelToLogProb: map[string]float64{
|
||||
"code-generation": -0.05,
|
||||
"casual-chat": -3.0,
|
||||
"math-reasoning": -4.0,
|
||||
}}
|
||||
_, err := runRouterWithDeps(loader, appConfig, store, routerCfg,
|
||||
openAIChat("hi"),
|
||||
ClassifierDeps{
|
||||
Scorer: stubScorerFactory(s),
|
||||
ModelLookup: loaderLookup(loader, appConfig),
|
||||
Evaluator: eval,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
// arch-router YAML lists <|im_end|> first.
|
||||
for _, c := range s.lastCandidates {
|
||||
Expect(c).To(HaveSuffix("<|im_end|>"),
|
||||
"candidate must end with the classifier model's turn-end token. got: %q", c)
|
||||
}
|
||||
})
|
||||
|
||||
It("picks the actual turn-end token when the stopwords list is misordered (Llama-3 style)", func() {
|
||||
// gallery/llama3-instruct.yaml et al. defensively list
|
||||
// <|im_end|> first even though the actual Llama-3 assistant
|
||||
// turn-end is <|eot_id|>. The naive "stopwords[0]" pick would
|
||||
// suffix candidates with <|im_end|> — a token Llama-3 never
|
||||
// emits at turn end. pickAssistantTurnEnd should scan the
|
||||
// chat_message template and recognise <|eot_id|> as the real
|
||||
// turn-end.
|
||||
writeLlama3StyleClassifierModel(modelDir, "arch-router")
|
||||
routerCfg := newScoreRouterModel(modelDir, "smart-router")
|
||||
|
||||
s := &stubScorer{labelToLogProb: map[string]float64{
|
||||
"code-generation": -0.05,
|
||||
"casual-chat": -3.0,
|
||||
"math-reasoning": -4.0,
|
||||
}}
|
||||
_, err := runRouterWithDeps(loader, appConfig, store, routerCfg,
|
||||
openAIChat("hi"),
|
||||
ClassifierDeps{
|
||||
Scorer: stubScorerFactory(s),
|
||||
ModelLookup: loaderLookup(loader, appConfig),
|
||||
Evaluator: eval,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
for _, c := range s.lastCandidates {
|
||||
Expect(c).To(HaveSuffix("<|eot_id|>"),
|
||||
"candidate must end with the Llama-3 turn-end token, not the misordered first stopword. got: %q", c)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
// stubScorer scores each candidate label according to a fixed
|
||||
// label→log-prob map; per-token length is faked at 2 tokens so length
|
||||
// normalisation is a no-op. Captures the prompt + candidate list of
|
||||
// the last Score call so regression tests can pin the rendered prompt
|
||||
// shape.
|
||||
type stubScorer struct {
|
||||
labelToLogProb map[string]float64
|
||||
lastPrompt string
|
||||
lastCandidates []string
|
||||
}
|
||||
|
||||
func (s *stubScorer) Score(_ context.Context, prompt string, candidates []string) ([]backend.CandidateScore, error) {
|
||||
s.lastPrompt = prompt
|
||||
s.lastCandidates = append([]string(nil), candidates...)
|
||||
out := make([]backend.CandidateScore, len(candidates))
|
||||
for i, c := range candidates {
|
||||
// Match against the full `{"route": "<label>"}` envelope.
|
||||
// Naively substring-matching on `"<label>"` would let a label
|
||||
// that's a substring of another collide via Go's randomised
|
||||
// map iteration order — `"code"` would also match the
|
||||
// `"code-generation"` candidate.
|
||||
var lp float64
|
||||
for label, v := range s.labelToLogProb {
|
||||
if strings.Contains(c, `{"route": "`+label+`"}`) {
|
||||
lp = v
|
||||
break
|
||||
}
|
||||
}
|
||||
out[i] = backend.CandidateScore{
|
||||
LogProb: lp * 2,
|
||||
LengthNormalizedLogProb: lp,
|
||||
NumTokens: 2,
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func stubScorerFactory(s *stubScorer) ScorerFactory {
|
||||
return func(string) backend.Scorer { return s }
|
||||
}
|
||||
|
||||
type fakeDecisionStore struct {
|
||||
records []router.DecisionRecord
|
||||
}
|
||||
|
||||
func (f *fakeDecisionStore) Record(_ context.Context, r router.DecisionRecord) error {
|
||||
f.records = append(f.records, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDecisionStore) List(_ context.Context, _ router.DecisionListQuery) ([]router.DecisionRecord, error) {
|
||||
out := append([]router.DecisionRecord(nil), f.records...)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (f *fakeDecisionStore) Close() error { return nil }
|
||||
func (f *fakeDecisionStore) Count(_ context.Context) (int, error) { return len(f.records), nil }
|
||||
|
||||
// newScoreRouterModel builds a smart-router config with 3 policies
|
||||
// and 2 candidates (small with one label, bigger with all three).
|
||||
// Admins are expected to order candidates small → large; the
|
||||
// middleware picks the first whose labels are a superset of the
|
||||
// active set.
|
||||
func newScoreRouterModel(modelDir, name string) *config.ModelConfig {
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Router: config.RouterConfig{
|
||||
Classifier: "score",
|
||||
ClassifierModel: "arch-router",
|
||||
Fallback: "qwen3-0.6b",
|
||||
Policies: []config.RouterPolicy{
|
||||
{Label: "code-generation", Description: "writing or debugging code"},
|
||||
{Label: "casual-chat", Description: "small talk"},
|
||||
{Label: "math-reasoning", Description: "arithmetic and word problems"},
|
||||
},
|
||||
Candidates: []config.RouterCandidate{
|
||||
{Model: "small-model", Labels: []string{"casual-chat"}},
|
||||
{Model: "big-model", Labels: []string{"code-generation", "casual-chat", "math-reasoning"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, name+".yaml"), []byte(toYAML(cfg)), 0o644)).To(Succeed())
|
||||
return cfg
|
||||
}
|
||||
|
||||
func writeCandidate(modelDir, name string) {
|
||||
body := "name: " + name + "\nbackend: mock-backend\n"
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, name+".yaml"), []byte(body), 0o644)).To(Succeed())
|
||||
}
|
||||
|
||||
func toYAML(cfg *config.ModelConfig) string {
|
||||
b, err := yaml.Marshal(cfg)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func openAIChat(content string) *schema.OpenAIRequest {
|
||||
req := &schema.OpenAIRequest{
|
||||
Messages: []schema.Message{
|
||||
{Role: "user", Content: content},
|
||||
},
|
||||
}
|
||||
req.Model = "smart-router"
|
||||
return req
|
||||
}
|
||||
|
||||
func runRouter(loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig, store router.DecisionStore, routerCfg *config.ModelConfig, parsed any, scorerFactory ScorerFactory) (*httptest.ResponseRecorder, error) {
|
||||
return runRouterWithDeps(loader, appConfig, store, routerCfg, parsed, ClassifierDeps{Scorer: scorerFactory})
|
||||
}
|
||||
|
||||
// runRouterWithDeps is runRouter's general form: lets the caller pass
|
||||
// a fully-populated ClassifierDeps (ModelLookup, Evaluator, ...) so
|
||||
// tests can exercise the template-renderer + stop-token derivation
|
||||
// paths, not just the bare-scorer fast path.
|
||||
func runRouterWithDeps(loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig, store router.DecisionStore, routerCfg *config.ModelConfig, parsed any, deps ClassifierDeps) (*httptest.ResponseRecorder, error) {
|
||||
mw := RouteModel(loader, appConfig, store, nil, OpenAIProbe, router.SourceChat, deps)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader("{}"))
|
||||
rec := httptest.NewRecorder()
|
||||
c := echo.New().NewContext(req, rec)
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, routerCfg)
|
||||
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, parsed)
|
||||
handler := mw(func(c echo.Context) error {
|
||||
served, _ := c.Get(ContextKeyServedModel).(string)
|
||||
return c.String(http.StatusOK, "served:"+served)
|
||||
})
|
||||
err := handler(c)
|
||||
return rec, err
|
||||
}
|
||||
|
||||
// loaderLookup mirrors application.ModelConfigLookup — bridges the
|
||||
// loader to the ModelConfigLookup signature ClassifierDeps wants.
|
||||
func loaderLookup(loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig) ModelConfigLookup {
|
||||
return func(name string) *config.ModelConfig {
|
||||
cfg, err := loader.LoadModelConfigFileByNameDefaultOptions(name, appConfig)
|
||||
if err != nil || cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
|
||||
// writeChatMLClassifierModel writes a classifier model YAML that
|
||||
// mirrors the live arch-router-1.5b.yaml shipped at
|
||||
// volumes/models/arch-router-1.5b.yaml: ChatML chat + chat_message
|
||||
// templates, score usecase, <|im_end|> first in stopwords.
|
||||
func writeChatMLClassifierModel(modelDir, name string) {
|
||||
body := `name: ` + name + `
|
||||
backend: llama-cpp
|
||||
known_usecases:
|
||||
- score
|
||||
stopwords:
|
||||
- <|im_end|>
|
||||
- <|endoftext|>
|
||||
template:
|
||||
chat: |
|
||||
{{.Input -}}
|
||||
<|im_start|>assistant
|
||||
chat_message: |
|
||||
<|im_start|>{{ .RoleName }}
|
||||
{{- if .Content }}
|
||||
{{ .Content }}
|
||||
{{- end }}<|im_end|>
|
||||
`
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, name+".yaml"), []byte(body), 0o644)).To(Succeed())
|
||||
}
|
||||
|
||||
// writeLlama3StyleClassifierModel writes a classifier model mirroring
|
||||
// gallery/llama3-instruct.yaml — stopwords defensively list <|im_end|>
|
||||
// first even though the assistant turn-end is actually <|eot_id|>.
|
||||
// Exercises pickAssistantTurnEnd's template scan: the right token is
|
||||
// the one that appears in chat_message, not the one at position 0.
|
||||
func writeLlama3StyleClassifierModel(modelDir, name string) {
|
||||
body := `name: ` + name + `
|
||||
backend: llama-cpp
|
||||
known_usecases:
|
||||
- score
|
||||
stopwords:
|
||||
- <|im_end|>
|
||||
- <dummy32000>
|
||||
- "<|eot_id|>"
|
||||
- <|end_of_text|>
|
||||
template:
|
||||
chat: |
|
||||
{{.Input }}
|
||||
<|start_header_id|>assistant<|end_header_id|>
|
||||
chat_message: |
|
||||
<|start_header_id|>{{ .RoleName }}<|end_header_id|>
|
||||
|
||||
{{ .Content }}<|eot_id|>
|
||||
`
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, name+".yaml"), []byte(body), 0o644)).To(Succeed())
|
||||
}
|
||||
|
||||
// writePartialClassifierModel writes a classifier model that has the
|
||||
// outer Chat template but no ChatMessage — exercises the
|
||||
// newTemplateRenderer "refuse partial templating" branch that hands
|
||||
// off to chatMLRenderer.
|
||||
func writePartialClassifierModel(modelDir, name string) {
|
||||
body := `name: ` + name + `
|
||||
backend: llama-cpp
|
||||
known_usecases:
|
||||
- score
|
||||
stopwords:
|
||||
- <|im_end|>
|
||||
template:
|
||||
chat: |
|
||||
{{.Input -}}
|
||||
<|im_start|>assistant
|
||||
`
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, name+".yaml"), []byte(body), 0o644)).To(Succeed())
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
@@ -116,6 +118,16 @@ func truncateForTrace(body []byte, maxBytes int) ([]byte, bool) {
|
||||
return out, true
|
||||
}
|
||||
|
||||
// Hijack lets WebSocket upgraders (gorilla/websocket) reach the
|
||||
// underlying connection. Without this, gorilla's Hijacker type-assertion
|
||||
// fails on the wrapped writer and the handshake returns 500.
|
||||
func (w *bodyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
func initializeTracing(maxItems int) {
|
||||
tracingMaxItems = maxItems
|
||||
doInitializeTracing()
|
||||
|
||||
@@ -2,165 +2,19 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/services/routing/billing"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
usageFlushInterval = 5 * time.Second
|
||||
// usageMaxPending bounds the in-memory queue. Sized for bursty inference
|
||||
// traffic on a self-hosted instance with a slow or unavailable DB.
|
||||
usageMaxPending = 50000
|
||||
)
|
||||
|
||||
// usageBatcher accumulates usage records and flushes them to the DB periodically.
|
||||
type usageBatcher struct {
|
||||
mu sync.Mutex
|
||||
pending []*auth.UsageRecord
|
||||
db *gorm.DB
|
||||
stop chan struct{}
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
// droppedRecords counts records discarded because the in-memory queue was full.
|
||||
// Used to rate-limit the warn log so a sustained outage doesn't flood it.
|
||||
var droppedRecords atomic.Uint64
|
||||
|
||||
func (b *usageBatcher) add(r *auth.UsageRecord) {
|
||||
b.mu.Lock()
|
||||
if len(b.pending) >= usageMaxPending {
|
||||
b.mu.Unlock()
|
||||
// Rate-limit: one warn per 1024 drops keeps the log readable.
|
||||
n := droppedRecords.Add(1)
|
||||
if n&1023 == 1 {
|
||||
xlog.Warn("usage batcher full, dropping record",
|
||||
"cap", usageMaxPending, "total_dropped", n)
|
||||
}
|
||||
return
|
||||
}
|
||||
b.pending = append(b.pending, r)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
func (b *usageBatcher) flush() {
|
||||
b.mu.Lock()
|
||||
batch := b.pending
|
||||
b.pending = nil
|
||||
b.mu.Unlock()
|
||||
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if err := b.db.Create(&batch).Error; err != nil {
|
||||
xlog.Error("Failed to flush usage batch", "count", len(batch), "error", err)
|
||||
// Cap-aware re-queue: prepend as much of the failed batch as fits
|
||||
// alongside any records added concurrently with the failed write.
|
||||
b.mu.Lock()
|
||||
room := usageMaxPending - len(b.pending)
|
||||
if room > 0 {
|
||||
if room > len(batch) {
|
||||
room = len(batch)
|
||||
}
|
||||
b.pending = append(batch[:room], b.pending...)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *usageBatcher) run() {
|
||||
defer close(b.done)
|
||||
ticker := time.NewTicker(usageFlushInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
b.flush()
|
||||
case <-b.stop:
|
||||
b.flush() // final drain
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *usageBatcher) shutdown() {
|
||||
b.stopOnce.Do(func() {
|
||||
close(b.stop)
|
||||
<-b.done
|
||||
})
|
||||
}
|
||||
|
||||
// The package-level batcher is guarded by batcherMu so Init / Shutdown cycles
|
||||
// (the test pattern) don't race against UsageMiddleware reads.
|
||||
var (
|
||||
batcherMu sync.RWMutex
|
||||
batcher *usageBatcher
|
||||
)
|
||||
|
||||
func currentBatcher() *usageBatcher {
|
||||
batcherMu.RLock()
|
||||
defer batcherMu.RUnlock()
|
||||
return batcher
|
||||
}
|
||||
|
||||
// InitUsageRecorder starts a background goroutine that periodically flushes
|
||||
// accumulated usage records to the database. Calling it more than once
|
||||
// shuts down the previous batcher first so its goroutine doesn't leak.
|
||||
func InitUsageRecorder(db *gorm.DB) {
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
|
||||
batcherMu.Lock()
|
||||
old := batcher
|
||||
batcher = nil
|
||||
batcherMu.Unlock()
|
||||
if old != nil {
|
||||
old.shutdown()
|
||||
}
|
||||
|
||||
b := &usageBatcher{
|
||||
db: db,
|
||||
stop: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
batcherMu.Lock()
|
||||
batcher = b
|
||||
batcherMu.Unlock()
|
||||
|
||||
go b.run()
|
||||
}
|
||||
|
||||
// ShutdownUsageRecorder stops the background flusher and synchronously drains
|
||||
// pending records once. Safe to call multiple times. Not yet wired into the
|
||||
// application lifecycle; intended for graceful process exit and tests.
|
||||
func ShutdownUsageRecorder() {
|
||||
batcherMu.Lock()
|
||||
b := batcher
|
||||
batcher = nil
|
||||
batcherMu.Unlock()
|
||||
if b != nil {
|
||||
b.shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
// FlushNow synchronously flushes any pending usage records. Intended for tests
|
||||
// that need deterministic behaviour without waiting for the ticker.
|
||||
func FlushNow() {
|
||||
if b := currentBatcher(); b != nil {
|
||||
b.flush()
|
||||
}
|
||||
}
|
||||
|
||||
// usageResponseBody is the minimal structure we need from the response JSON.
|
||||
// usageResponseBody is the minimal structure we need from an OpenAI-shaped
|
||||
// JSON response. Anthropic responses are decoded separately because their
|
||||
// usage block uses different field names (input_tokens / output_tokens).
|
||||
type usageResponseBody struct {
|
||||
Model string `json:"model"`
|
||||
Usage *struct {
|
||||
@@ -170,19 +24,47 @@ type usageResponseBody struct {
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
// UsageMiddleware extracts token usage from OpenAI-compatible response JSON
|
||||
// and records it per-user.
|
||||
func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
|
||||
// anthropicResponseBody covers /v1/messages JSON responses.
|
||||
type anthropicResponseBody struct {
|
||||
Model string `json:"model"`
|
||||
Usage *struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
// UsageMiddleware records token usage for inference requests via the
|
||||
// billing.Recorder. Two paths produce a record:
|
||||
//
|
||||
// 1. Handler-stamped (preferred): the request handler called
|
||||
// middleware.StampUsage with the canonical token counts before
|
||||
// returning. This is the only reliable path for streaming responses
|
||||
// — clients rarely set OpenAI's stream_options.include_usage, and
|
||||
// Anthropic's usage lives in a separate message_delta event.
|
||||
// 2. Body-parsed (fallback): the response is parsed for an OpenAI- or
|
||||
// Anthropic-shaped usage block. Used by passthrough proxies and
|
||||
// foreign endpoints.
|
||||
//
|
||||
// Recorder being nil (e.g., --disable-stats) makes the middleware a
|
||||
// transparent pass-through. fallbackUser is used when auth.GetUser(c)
|
||||
// returns nil; without it, an unauthenticated request would be dropped.
|
||||
//
|
||||
// Every request that fails to produce a record ticks
|
||||
// localai_usage_unrecorded_total so silent billing misses are observable.
|
||||
func UsageMiddleware(recorder *billing.Recorder, fallbackUser *auth.User) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
b := currentBatcher()
|
||||
if db == nil || b == nil {
|
||||
if recorder == nil {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// Wrap response writer to capture body
|
||||
// Wrap response writer to capture body for the fallback parser.
|
||||
// When the handler stamps the context we never read this buffer,
|
||||
// so the cost is the per-chunk Write going through one extra
|
||||
// indirection — accepted overhead in exchange for one billing
|
||||
// path that works for both stamping and body-parse callers.
|
||||
resBody := new(bytes.Buffer)
|
||||
origWriter := c.Response().Writer
|
||||
mw := &bodyWriter{
|
||||
@@ -193,53 +75,34 @@ func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
|
||||
|
||||
handlerErr := next(c)
|
||||
|
||||
// Restore original writer
|
||||
c.Response().Writer = origWriter
|
||||
|
||||
// Only record on successful responses
|
||||
endpoint := c.Request().URL.Path
|
||||
|
||||
if c.Response().Status < 200 || c.Response().Status >= 300 {
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
// Get authenticated user
|
||||
user := auth.GetUser(c)
|
||||
if user == nil {
|
||||
user = fallbackUser
|
||||
}
|
||||
if user == nil || user.ID == "" {
|
||||
billing.CountUnrecorded(context.Background(), endpoint, "no_user")
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
// Try to parse usage from response
|
||||
responseBytes := resBody.Bytes()
|
||||
if len(responseBytes) == 0 {
|
||||
model, prompt, completion, total, ok := tokensFromContext(c)
|
||||
if !ok {
|
||||
model, prompt, completion, total, ok = tokensFromBody(resBody.Bytes(), c.Response().Header().Get("Content-Type"))
|
||||
}
|
||||
if !ok {
|
||||
billing.CountUnrecorded(context.Background(), endpoint, "no_usage")
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
// Check content type
|
||||
ct := c.Response().Header().Get("Content-Type")
|
||||
isJSON := ct == "" || ct == "application/json" || bytes.HasPrefix([]byte(ct), []byte("application/json"))
|
||||
isSSE := bytes.HasPrefix([]byte(ct), []byte("text/event-stream"))
|
||||
|
||||
if !isJSON && !isSSE {
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
var resp usageResponseBody
|
||||
if isSSE {
|
||||
last, ok := lastSSEData(responseBytes)
|
||||
if !ok {
|
||||
return handlerErr
|
||||
}
|
||||
if err := json.Unmarshal(last, &resp); err != nil {
|
||||
return handlerErr
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal(responseBytes, &resp); err != nil {
|
||||
return handlerErr
|
||||
}
|
||||
}
|
||||
|
||||
if resp.Usage == nil {
|
||||
return handlerErr
|
||||
}
|
||||
requested, served := modelsFromContext(c, model)
|
||||
pre, post := promptTokensFromContext(c, prompt)
|
||||
|
||||
source := auth.GetSource(c)
|
||||
if source == "" {
|
||||
@@ -249,16 +112,21 @@ func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
record := &auth.UsageRecord{
|
||||
UserID: user.ID,
|
||||
UserName: user.Name,
|
||||
Source: source,
|
||||
Model: resp.Model,
|
||||
Endpoint: c.Request().URL.Path,
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
CompletionTokens: resp.Usage.CompletionTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
CreatedAt: startTime,
|
||||
UserID: user.ID,
|
||||
UserName: user.Name,
|
||||
Source: source,
|
||||
Model: model,
|
||||
Endpoint: endpoint,
|
||||
PromptTokens: prompt,
|
||||
CompletionTokens: completion,
|
||||
TotalTokens: total,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
CreatedAt: startTime,
|
||||
RequestedModel: requested,
|
||||
ServedModel: served,
|
||||
PreFilterPromptTokens: pre,
|
||||
PostFilterPromptTokens: post,
|
||||
CorrelationID: correlationIDFromContext(c),
|
||||
}
|
||||
|
||||
if key := auth.GetAPIKey(c); key != nil {
|
||||
@@ -267,13 +135,145 @@ func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
|
||||
record.APIKeyName = key.Name
|
||||
}
|
||||
|
||||
b.add(record)
|
||||
if err := recorder.Record(context.Background(), record); err != nil {
|
||||
xlog.Error("usage middleware: recorder.Record failed", "error", err, "user", user.ID, "model", model)
|
||||
billing.CountUnrecorded(context.Background(), endpoint, "record_failed")
|
||||
}
|
||||
|
||||
return handlerErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tokensFromContext returns canonical token counts stamped by a handler
|
||||
// via middleware.StampUsage. Returns ok=false when no stamp is present
|
||||
// — the caller then tries the body-parse fallback.
|
||||
//
|
||||
// A model name without token counts is not considered "stamped" because a
|
||||
// record with zero tokens looks the same as a never-recorded request to
|
||||
// later analytics; the second condition is what gates ok.
|
||||
func tokensFromContext(c echo.Context) (model string, prompt, completion, total int64, ok bool) {
|
||||
if v, found := c.Get(ContextKeyResponseModel).(string); found {
|
||||
model = v
|
||||
}
|
||||
pPresent := false
|
||||
cPresent := false
|
||||
if v, found := c.Get(ContextKeyPromptTokens).(int64); found {
|
||||
prompt = v
|
||||
pPresent = true
|
||||
}
|
||||
if v, found := c.Get(ContextKeyCompletionTokens).(int64); found {
|
||||
completion = v
|
||||
cPresent = true
|
||||
}
|
||||
if v, found := c.Get(ContextKeyTotalTokens).(int64); found {
|
||||
total = v
|
||||
} else {
|
||||
total = prompt + completion
|
||||
}
|
||||
ok = pPresent || cPresent
|
||||
return
|
||||
}
|
||||
|
||||
// tokensFromBody covers the passthrough-proxy / foreign-endpoint case
|
||||
// where no handler stamps the context. Returns ok=false on any parse
|
||||
// failure or missing-usage; the caller increments the unrecorded counter.
|
||||
func tokensFromBody(responseBytes []byte, contentType string) (model string, prompt, completion, total int64, ok bool) {
|
||||
if len(responseBytes) == 0 {
|
||||
return
|
||||
}
|
||||
isJSON := contentType == "" || contentType == "application/json" || bytes.HasPrefix([]byte(contentType), []byte("application/json"))
|
||||
isSSE := bytes.HasPrefix([]byte(contentType), []byte("text/event-stream"))
|
||||
if !isJSON && !isSSE {
|
||||
return
|
||||
}
|
||||
|
||||
payload := responseBytes
|
||||
if isSSE {
|
||||
// For SSE, the canonical usage chunk is the *last* non-[DONE] data
|
||||
// line. OpenAI clients only emit one if stream_options.include_usage
|
||||
// is set; Anthropic emits a final message_delta with usage. Both
|
||||
// fit the "last data: line" rule.
|
||||
last, lastOk := lastSSEData(responseBytes)
|
||||
if !lastOk {
|
||||
return
|
||||
}
|
||||
payload = last
|
||||
}
|
||||
|
||||
// Try OpenAI shape first (handles /v1/chat/completions, /v1/completions,
|
||||
// /v1/embeddings, /v1/edits, and any proxy that translates to OpenAI).
|
||||
// A usage block whose token fields all decoded to zero is ambiguous —
|
||||
// it could be an Anthropic body that happens to have a `usage` key —
|
||||
// so fall through to the Anthropic parser instead of recording zeros.
|
||||
var openAI usageResponseBody
|
||||
if err := json.Unmarshal(payload, &openAI); err == nil && openAI.Usage != nil {
|
||||
if openAI.Usage.PromptTokens != 0 || openAI.Usage.CompletionTokens != 0 || openAI.Usage.TotalTokens != 0 {
|
||||
model = openAI.Model
|
||||
prompt = openAI.Usage.PromptTokens
|
||||
completion = openAI.Usage.CompletionTokens
|
||||
total = openAI.Usage.TotalTokens
|
||||
if total == 0 {
|
||||
total = prompt + completion
|
||||
}
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Fall through to Anthropic shape (proxy passthrough territory).
|
||||
var ant anthropicResponseBody
|
||||
if err := json.Unmarshal(payload, &ant); err == nil && ant.Usage != nil {
|
||||
if ant.Usage.InputTokens != 0 || ant.Usage.OutputTokens != 0 {
|
||||
model = ant.Model
|
||||
prompt = ant.Usage.InputTokens
|
||||
completion = ant.Usage.OutputTokens
|
||||
total = prompt + completion
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// modelsFromContext returns (requested, served) using context-set values
|
||||
// when present, falling back to the response-reported model for both.
|
||||
// The router middleware (subsystem 2 of the routing plan) populates
|
||||
// these; until it lands they are equal.
|
||||
func modelsFromContext(c echo.Context, fallback string) (string, string) {
|
||||
requested := fallback
|
||||
served := fallback
|
||||
if v, ok := c.Get(ContextKeyRequestedModel).(string); ok && v != "" {
|
||||
requested = v
|
||||
}
|
||||
if v, ok := c.Get(ContextKeyServedModel).(string); ok && v != "" {
|
||||
served = v
|
||||
}
|
||||
return requested, served
|
||||
}
|
||||
|
||||
func promptTokensFromContext(c echo.Context, fallback int64) (int64, int64) {
|
||||
pre := fallback
|
||||
post := fallback
|
||||
if v, ok := c.Get(ContextKeyPreFilterPromptTokens).(int64); ok && v > 0 {
|
||||
pre = v
|
||||
}
|
||||
if v, ok := c.Get(ContextKeyPostFilterPromptTokens).(int64); ok && v > 0 {
|
||||
post = v
|
||||
}
|
||||
return pre, post
|
||||
}
|
||||
|
||||
func correlationIDFromContext(c echo.Context) string {
|
||||
if v, ok := c.Get(ContextKeyCorrelationID).(string); ok {
|
||||
return v
|
||||
}
|
||||
// X-Correlation-ID header is set by trace.go middleware; read it as a
|
||||
// fallback if the echo-context binding hasn't been populated yet.
|
||||
return c.Response().Header().Get("X-Correlation-ID")
|
||||
}
|
||||
|
||||
// lastSSEData returns the payload of the last "data: " line whose content is not "[DONE]".
|
||||
func lastSSEData(b []byte) ([]byte, bool) {
|
||||
prefix := []byte("data: ")
|
||||
|
||||
33
core/http/middleware/usage_stamp.go
Normal file
33
core/http/middleware/usage_stamp.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package middleware
|
||||
|
||||
import "github.com/labstack/echo/v4"
|
||||
|
||||
// StampUsage records the canonical token counts on the echo context so
|
||||
// UsageMiddleware can attribute the request without parsing the response
|
||||
// body. Handlers must call this for every successful response — the
|
||||
// body-parse fallback is reserved for foreign endpoints (e.g., the cloud
|
||||
// passthrough proxy).
|
||||
//
|
||||
// model is the name written into the response payload; passing it here
|
||||
// is what lets the middleware fill the UsageRecord even when the handler
|
||||
// abbreviates or rewrites the user-supplied model. Empty values are
|
||||
// ignored so partial information is still useful (e.g., embeddings calls
|
||||
// where completion is always 0).
|
||||
//
|
||||
// prompt and completion accept int because that's the native width of
|
||||
// LocalAI's TokenUsage / OpenAIUsage structs (token counts never come
|
||||
// close to overflow). Conversion to int64 happens once, here, so call
|
||||
// sites stay free of casts.
|
||||
func StampUsage(c echo.Context, model string, prompt, completion int) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if model != "" {
|
||||
c.Set(ContextKeyResponseModel, model)
|
||||
}
|
||||
p := int64(prompt)
|
||||
cp := int64(completion)
|
||||
c.Set(ContextKeyPromptTokens, p)
|
||||
c.Set(ContextKeyCompletionTokens, cp)
|
||||
c.Set(ContextKeyTotalTokens, p+cp)
|
||||
}
|
||||
@@ -1,140 +1,323 @@
|
||||
//go:build auth
|
||||
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/services/routing/billing"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// testAuthDB returns a fresh in-memory SQLite auth DB.
|
||||
func testAuthDB() *gorm.DB {
|
||||
db, err := auth.InitDB(":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return db
|
||||
// captureBackend collects records the recorder forwards. We assert on
|
||||
// it directly rather than going through StatsBackend.Aggregate because
|
||||
// these tests verify the middleware -> recorder hop, not aggregation
|
||||
// (which has its own tests in routing/billing).
|
||||
type captureBackend struct {
|
||||
records []*auth.UsageRecord
|
||||
}
|
||||
|
||||
func (c *captureBackend) Record(_ context.Context, r *auth.UsageRecord) error {
|
||||
c.records = append(c.records, r)
|
||||
return nil
|
||||
}
|
||||
func (c *captureBackend) Aggregate(_ context.Context, _ billing.AggregateQuery) ([]auth.UsageBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *captureBackend) Close() error { return nil }
|
||||
|
||||
var _ = Describe("UsageMiddleware", func() {
|
||||
var (
|
||||
e *echo.Echo
|
||||
db *gorm.DB
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
db = testAuthDB()
|
||||
e = echo.New()
|
||||
middleware.InitUsageRecorder(db)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
middleware.ShutdownUsageRecorder()
|
||||
})
|
||||
|
||||
okHandler := func(c echo.Context) error {
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "gpt-4",
|
||||
"usage": map[string]int{
|
||||
"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15,
|
||||
},
|
||||
})
|
||||
c.Response().Header().Set("Content-Type", "application/json")
|
||||
c.Response().WriteHeader(http.StatusOK)
|
||||
_, _ = c.Response().Write(body)
|
||||
return nil
|
||||
mockChat := func(usage string) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Response().Header().Set("Content-Type", "application/json")
|
||||
body := fmt.Sprintf(`{"model":"qwen-7b","usage":%s}`, usage)
|
||||
return c.String(http.StatusOK, body)
|
||||
}
|
||||
}
|
||||
|
||||
// FlushNow drains pending records synchronously, replacing the 6s sleep
|
||||
// that was previously needed to wait for the batcher's ticker.
|
||||
flush := middleware.FlushNow
|
||||
It("records under the synthetic local user when auth is off", func() {
|
||||
cap := &captureBackend{}
|
||||
rec := billing.NewRecorder(cap)
|
||||
fallback := &auth.User{ID: "local-uuid", Name: "local", Provider: auth.ProviderLocal}
|
||||
|
||||
It("records source=web when auth_source is web", func() {
|
||||
e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
e := echo.New()
|
||||
e.POST("/v1/chat/completions",
|
||||
mockChat(`{"prompt_tokens":12,"completion_tokens":8,"total_tokens":20}`),
|
||||
httpMiddleware.UsageMiddleware(rec, fallback),
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(cap.records).To(HaveLen(1))
|
||||
r := cap.records[0]
|
||||
Expect(r.UserID).To(Equal("local-uuid"))
|
||||
Expect(r.UserName).To(Equal("local"))
|
||||
Expect(r.Model).To(Equal("qwen-7b"))
|
||||
Expect(r.PromptTokens).To(Equal(int64(12)))
|
||||
Expect(r.CompletionTokens).To(Equal(int64(8)))
|
||||
Expect(r.TotalTokens).To(Equal(int64(20)))
|
||||
})
|
||||
|
||||
It("does nothing when recorder is nil (--disable-stats)", func() {
|
||||
fallback := &auth.User{ID: "local-uuid", Name: "local"}
|
||||
e := echo.New()
|
||||
e.POST("/v1/chat/completions",
|
||||
mockChat(`{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}`),
|
||||
httpMiddleware.UsageMiddleware(nil, fallback),
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
// no panic, no record — recorder=nil is the disable-stats path
|
||||
})
|
||||
|
||||
It("skips when neither auth nor fallback user is available", func() {
|
||||
cap := &captureBackend{}
|
||||
rec := billing.NewRecorder(cap)
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/v1/chat/completions",
|
||||
mockChat(`{"prompt_tokens":3,"completion_tokens":2,"total_tokens":5}`),
|
||||
httpMiddleware.UsageMiddleware(rec, nil),
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(cap.records).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("ignores 5xx responses (no usage to attribute)", func() {
|
||||
cap := &captureBackend{}
|
||||
rec := billing.NewRecorder(cap)
|
||||
fallback := &auth.User{ID: "local-uuid", Name: "local"}
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/v1/chat/completions",
|
||||
func(c echo.Context) error {
|
||||
return c.String(http.StatusInternalServerError, `{"error":"boom"}`)
|
||||
},
|
||||
httpMiddleware.UsageMiddleware(rec, fallback),
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
Expect(w.Code).To(Equal(http.StatusInternalServerError))
|
||||
Expect(cap.records).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("records via context-stamped tokens when handler called StampUsage (streaming-safe path)", func() {
|
||||
cap := &captureBackend{}
|
||||
rec := billing.NewRecorder(cap)
|
||||
fallback := &auth.User{ID: "local-uuid", Name: "local"}
|
||||
|
||||
// Simulate a streaming chat handler that emits SSE chunks WITHOUT a
|
||||
// terminal usage block (the common case — clients rarely set
|
||||
// stream_options.include_usage). The handler stamps the canonical
|
||||
// counts on the context just before returning. UsageMiddleware
|
||||
// must record from the stamp, not from body parsing.
|
||||
streamingHandler := func(c echo.Context) error {
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().WriteHeader(http.StatusOK)
|
||||
_, _ = fmt.Fprint(c.Response().Writer, "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n")
|
||||
_, _ = fmt.Fprint(c.Response().Writer, "data: [DONE]\n\n")
|
||||
httpMiddleware.StampUsage(c, "qwen-7b", 9, 5)
|
||||
return nil
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/v1/chat/completions",
|
||||
streamingHandler,
|
||||
httpMiddleware.UsageMiddleware(rec, fallback),
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(cap.records).To(HaveLen(1))
|
||||
Expect(cap.records[0].PromptTokens).To(Equal(int64(9)))
|
||||
Expect(cap.records[0].CompletionTokens).To(Equal(int64(5)))
|
||||
Expect(cap.records[0].TotalTokens).To(Equal(int64(14)))
|
||||
Expect(cap.records[0].Model).To(Equal("qwen-7b"))
|
||||
})
|
||||
|
||||
It("falls back to Anthropic body shape when no stamp is present", func() {
|
||||
cap := &captureBackend{}
|
||||
rec := billing.NewRecorder(cap)
|
||||
fallback := &auth.User{ID: "local-uuid", Name: "local"}
|
||||
|
||||
// Simulates a passthrough proxy / foreign endpoint: no handler stamp,
|
||||
// so the middleware must parse the response body. Anthropic's shape
|
||||
// uses input_tokens / output_tokens, not the OpenAI names.
|
||||
anthropicHandler := func(c echo.Context) error {
|
||||
c.Response().Header().Set("Content-Type", "application/json")
|
||||
body := `{"model":"claude-sonnet","usage":{"input_tokens":15,"output_tokens":7}}`
|
||||
return c.String(http.StatusOK, body)
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/v1/messages",
|
||||
anthropicHandler,
|
||||
httpMiddleware.UsageMiddleware(rec, fallback),
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(cap.records).To(HaveLen(1))
|
||||
Expect(cap.records[0].PromptTokens).To(Equal(int64(15)))
|
||||
Expect(cap.records[0].CompletionTokens).To(Equal(int64(7)))
|
||||
Expect(cap.records[0].TotalTokens).To(Equal(int64(22)))
|
||||
Expect(cap.records[0].Model).To(Equal("claude-sonnet"))
|
||||
})
|
||||
|
||||
It("populates RequestedModel/ServedModel from echo context when set", func() {
|
||||
cap := &captureBackend{}
|
||||
rec := billing.NewRecorder(cap)
|
||||
fallback := &auth.User{ID: "local-uuid", Name: "local"}
|
||||
|
||||
// A pre-handler stand-in for the future router middleware: it
|
||||
// rewrites Served and remembers the original Requested. Once the
|
||||
// real router lands, this is exactly the contract it must keep.
|
||||
setRouterContext := func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"})
|
||||
c.Set("auth_source", auth.UsageSourceWeb)
|
||||
c.Set(httpMiddleware.ContextKeyRequestedModel, "auto")
|
||||
c.Set(httpMiddleware.ContextKeyServedModel, "qwen-7b")
|
||||
return next(c)
|
||||
}
|
||||
}, middleware.UsageMiddleware(db))
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
e.ServeHTTP(httptest.NewRecorder(), req)
|
||||
flush()
|
||||
e := echo.New()
|
||||
e.POST("/v1/chat/completions",
|
||||
mockChat(`{"prompt_tokens":4,"completion_tokens":3,"total_tokens":7}`),
|
||||
httpMiddleware.UsageMiddleware(rec, fallback),
|
||||
setRouterContext,
|
||||
)
|
||||
|
||||
var rec auth.UsageRecord
|
||||
Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed())
|
||||
Expect(rec.Source).To(Equal(auth.UsageSourceWeb))
|
||||
Expect(rec.APIKeyID).To(BeNil())
|
||||
Expect(rec.APIKeyName).To(BeEmpty())
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(cap.records).To(HaveLen(1))
|
||||
Expect(cap.records[0].RequestedModel).To(Equal("auto"))
|
||||
Expect(cap.records[0].ServedModel).To(Equal("qwen-7b"))
|
||||
})
|
||||
|
||||
// stampAuth is a stand-in for the auth middleware: it sets the
|
||||
// echo-context keys UsageMiddleware reads. Pass source=="" to
|
||||
// simulate the unauthenticated/legacy path; pass key=nil to skip
|
||||
// the API-key snapshot.
|
||||
stampAuth := func(user *auth.User, source string, key *auth.UserAPIKey) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if user != nil {
|
||||
c.Set("auth_user", user)
|
||||
}
|
||||
if source != "" {
|
||||
c.Set("auth_source", source)
|
||||
}
|
||||
if key != nil {
|
||||
c.Set("auth_apikey", key)
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
It("records source=web when auth_source is web and snapshots no API key", func() {
|
||||
cap := &captureBackend{}
|
||||
rec := billing.NewRecorder(cap)
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/v1/chat/completions",
|
||||
mockChat(`{"prompt_tokens":2,"completion_tokens":1,"total_tokens":3}`),
|
||||
httpMiddleware.UsageMiddleware(rec, nil),
|
||||
stampAuth(&auth.User{ID: "alice", Name: "Alice"}, auth.UsageSourceWeb, nil),
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(cap.records).To(HaveLen(1))
|
||||
r := cap.records[0]
|
||||
Expect(r.UserID).To(Equal("alice"))
|
||||
Expect(r.Source).To(Equal(auth.UsageSourceWeb))
|
||||
Expect(r.APIKeyID).To(BeNil())
|
||||
Expect(r.APIKeyName).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("records source=apikey with snapshotted name when auth_apikey is set", func() {
|
||||
e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"})
|
||||
c.Set("auth_source", auth.UsageSourceAPIKey)
|
||||
c.Set("auth_apikey", &auth.UserAPIKey{ID: "key-1", Name: "ci-runner"})
|
||||
return next(c)
|
||||
}
|
||||
}, middleware.UsageMiddleware(db))
|
||||
cap := &captureBackend{}
|
||||
rec := billing.NewRecorder(cap)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
e.ServeHTTP(httptest.NewRecorder(), req)
|
||||
flush()
|
||||
e := echo.New()
|
||||
e.POST("/v1/chat/completions",
|
||||
mockChat(`{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}`),
|
||||
httpMiddleware.UsageMiddleware(rec, nil),
|
||||
stampAuth(
|
||||
&auth.User{ID: "alice", Name: "Alice"},
|
||||
auth.UsageSourceAPIKey,
|
||||
&auth.UserAPIKey{ID: "key-1", Name: "ci-runner"},
|
||||
),
|
||||
)
|
||||
|
||||
var rec auth.UsageRecord
|
||||
Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed())
|
||||
Expect(rec.Source).To(Equal(auth.UsageSourceAPIKey))
|
||||
Expect(rec.APIKeyID).ToNot(BeNil())
|
||||
Expect(*rec.APIKeyID).To(Equal("key-1"))
|
||||
Expect(rec.APIKeyName).To(Equal("ci-runner"))
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(cap.records).To(HaveLen(1))
|
||||
r := cap.records[0]
|
||||
Expect(r.Source).To(Equal(auth.UsageSourceAPIKey))
|
||||
Expect(r.APIKeyID).ToNot(BeNil())
|
||||
Expect(*r.APIKeyID).To(Equal("key-1"))
|
||||
Expect(r.APIKeyName).To(Equal("ci-runner"))
|
||||
})
|
||||
|
||||
It("FlushNow drains pending records synchronously", func() {
|
||||
e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Set("auth_user", &auth.User{ID: "carol", Name: "Carol"})
|
||||
c.Set("auth_source", auth.UsageSourceWeb)
|
||||
return next(c)
|
||||
}
|
||||
}, middleware.UsageMiddleware(db))
|
||||
It("defaults source=web when auth_source is empty", func() {
|
||||
cap := &captureBackend{}
|
||||
rec := billing.NewRecorder(cap)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
e.ServeHTTP(httptest.NewRecorder(), req)
|
||||
// Only user set, no source — the middleware must classify the
|
||||
// row as web rather than dropping it from per-source aggregates.
|
||||
e := echo.New()
|
||||
e.POST("/v1/chat/completions",
|
||||
mockChat(`{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}`),
|
||||
httpMiddleware.UsageMiddleware(rec, nil),
|
||||
stampAuth(&auth.User{ID: "alice", Name: "Alice"}, "", nil),
|
||||
)
|
||||
|
||||
// No sleep: FlushNow should drain immediately.
|
||||
middleware.FlushNow()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
|
||||
var rec auth.UsageRecord
|
||||
Expect(db.Where("user_id = ?", "carol").First(&rec).Error).To(Succeed())
|
||||
Expect(rec.Source).To(Equal(auth.UsageSourceWeb))
|
||||
})
|
||||
|
||||
It("falls back to source=web when auth_source is empty", func() {
|
||||
e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"})
|
||||
// no auth_source set
|
||||
return next(c)
|
||||
}
|
||||
}, middleware.UsageMiddleware(db))
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
e.ServeHTTP(httptest.NewRecorder(), req)
|
||||
flush()
|
||||
|
||||
var rec auth.UsageRecord
|
||||
Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed())
|
||||
Expect(rec.Source).To(Equal(auth.UsageSourceWeb))
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(cap.records).To(HaveLen(1))
|
||||
Expect(cap.records[0].Source).To(Equal(auth.UsageSourceWeb))
|
||||
})
|
||||
})
|
||||
|
||||
308
core/http/react-ui/e2e/middleware-page.spec.js
Normal file
308
core/http/react-ui/e2e/middleware-page.spec.js
Normal file
@@ -0,0 +1,308 @@
|
||||
import { test, expect } from '@playwright/test'
|
||||
|
||||
// Mocked fixture covering the three things the page renders:
|
||||
// - PII pattern catalogue (action badges, action-change buttons)
|
||||
// - Per-model resolved PII state (one with default off, one with proxy default on, one with explicit YAML)
|
||||
// - Recent events feed (the page must NEVER show the redacted content)
|
||||
const MOCK_STATUS = {
|
||||
pii: {
|
||||
enabled_globally: true,
|
||||
default_enabled_for_backends: ['cloud-proxy'],
|
||||
patterns: [
|
||||
{ id: 'email', description: 'Email addresses', action: 'mask', max_match_length: 254 },
|
||||
{ id: 'ssn', description: 'US Social Security Numbers', action: 'mask', max_match_length: 11 },
|
||||
{ id: 'api_key_prefix', description: 'API key prefixes', action: 'block', max_match_length: 200 },
|
||||
],
|
||||
models: [
|
||||
{ name: 'qwen-7b', backend: 'llama-cpp', enabled: false, explicit: false, default_for_backend: false, overrides: null },
|
||||
{ name: 'claude-sonnet', backend: 'cloud-proxy', enabled: true, explicit: false, default_for_backend: true, overrides: null },
|
||||
{ name: 'claude-strict', backend: 'cloud-proxy', enabled: true, explicit: true, default_for_backend: true, overrides: { ssn: 'block' } },
|
||||
],
|
||||
recent_event_count: 2,
|
||||
},
|
||||
router: {
|
||||
configured: true,
|
||||
models: [
|
||||
{
|
||||
name: 'smart-router',
|
||||
classifier: 'score',
|
||||
fallback: 'qwen-7b',
|
||||
policies: [
|
||||
{ label: 'casual-chat', description: 'small talk' },
|
||||
{ label: 'code-generation', description: 'writing or debugging code' },
|
||||
],
|
||||
candidates: [
|
||||
{ model: 'qwen-3b', labels: ['casual-chat'] },
|
||||
{ model: 'qwen-coder', labels: ['code-generation', 'casual-chat'] },
|
||||
],
|
||||
embedding_cache: {
|
||||
embedding_model: 'nomic-embed-text-v1.5',
|
||||
similarity_threshold: 0.80,
|
||||
confidence_threshold: 0.60,
|
||||
store_name: '',
|
||||
stats: {
|
||||
hits: 31,
|
||||
misses: 1,
|
||||
near_misses: 56,
|
||||
low_confidence: 29,
|
||||
embedder_errors: 0,
|
||||
store_errors: 0,
|
||||
// peak [0.4, 0.6) for paraphrases, secondary in [0.8, 1.0) for near-exact matches
|
||||
similarity_buckets: [0, 0, 0, 1, 22, 16, 3, 7, 19, 19],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
recent_decision_count: 1,
|
||||
available_classifiers: ['score'],
|
||||
},
|
||||
}
|
||||
|
||||
const MOCK_DECISIONS = {
|
||||
decisions: [
|
||||
{
|
||||
id: 'rd_a1', correlation_id: 'corr-1', user_id: 'local',
|
||||
router_model: 'smart-router', requested_model: 'smart-router', served_model: 'qwen-3b',
|
||||
classifier: 'score', label: 'casual-chat', score: 0.91, latency_ms: 15,
|
||||
cached: true, cache_similarity: 0.92,
|
||||
created_at: '2026-05-06T11:00:00Z',
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const MOCK_EVENTS = {
|
||||
events: [
|
||||
{
|
||||
id: 'pii_aaa', kind: 'pii', correlation_id: 'corr-1', user_id: 'local',
|
||||
direction: 'in', pattern_id: 'email', byte_offset: 12, length: 17,
|
||||
hash_prefix: 'ff8d9819', action: 'mask',
|
||||
created_at: '2026-05-06T10:00:00Z',
|
||||
},
|
||||
{
|
||||
id: 'proxy_connect_1', kind: 'proxy_connect',
|
||||
host: 'api.openai.com', intercepted: true,
|
||||
created_at: '2026-05-06T10:01:00Z',
|
||||
},
|
||||
{
|
||||
id: 'proxy_connect_2', kind: 'proxy_connect',
|
||||
host: 'github.com', intercepted: false,
|
||||
created_at: '2026-05-06T10:02:00Z',
|
||||
},
|
||||
{
|
||||
id: 'proxy_traffic_1', kind: 'proxy_traffic', correlation_id: 'corr-2',
|
||||
host: 'api.openai.com',
|
||||
bytes_sent: 412, bytes_received: 1228, status_code: 200, duration_ms: 240,
|
||||
created_at: '2026-05-06T10:03:00Z',
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
test.describe('Middleware page — admin in no-auth mode', () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.route('**/api/auth/status', (route) =>
|
||||
route.fulfill({
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify({ authEnabled: false, staticApiKeyRequired: false, providers: [] }),
|
||||
})
|
||||
)
|
||||
await page.route('**/api/middleware/status', (route) =>
|
||||
route.fulfill({ contentType: 'application/json', body: JSON.stringify(MOCK_STATUS) })
|
||||
)
|
||||
await page.route('**/api/pii/events?**', (route) =>
|
||||
route.fulfill({ contentType: 'application/json', body: JSON.stringify(MOCK_EVENTS) })
|
||||
)
|
||||
await page.route('**/api/router/decisions?**', (route) =>
|
||||
route.fulfill({ contentType: 'application/json', body: JSON.stringify(MOCK_DECISIONS) })
|
||||
)
|
||||
})
|
||||
|
||||
test('Filtering tab renders pattern catalogue and per-model state', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
|
||||
// Pattern table — at least one pattern id visible.
|
||||
await expect(page.getByText('email').first()).toBeVisible()
|
||||
await expect(page.getByText('api_key_prefix').first()).toBeVisible()
|
||||
|
||||
// Per-model state — each model's name is visible.
|
||||
await expect(page.getByText('qwen-7b').first()).toBeVisible()
|
||||
await expect(page.getByText('claude-strict').first()).toBeVisible()
|
||||
|
||||
// Default-policy banner names the backends with PII on by default.
|
||||
await expect(page.getByText(/cloud-proxy/).first()).toBeVisible()
|
||||
})
|
||||
|
||||
test('Routing tab renders configured routers and recent decisions', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
await page.getByRole('button', { name: /Routing/i }).click()
|
||||
// Active router model name visible.
|
||||
await expect(page.getByText('smart-router').first()).toBeVisible()
|
||||
// Candidate model names visible.
|
||||
await expect(page.getByText('qwen-coder').first()).toBeVisible()
|
||||
await expect(page.getByText('qwen-3b').first()).toBeVisible()
|
||||
// Decision row visible — label and served model.
|
||||
await expect(page.getByText('casual-chat').first()).toBeVisible()
|
||||
})
|
||||
|
||||
test('Routing tab renders embedding-cache stats and similarity histogram', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
await page.getByRole('button', { name: /Routing/i }).click()
|
||||
|
||||
// Embedding model name surfaces in the cache column.
|
||||
await expect(page.getByText('nomic-embed-text-v1.5').first()).toBeVisible()
|
||||
|
||||
// Hit-rate badge: 31 hits / (31 + 56 + 1) = 35% rounded.
|
||||
await expect(page.getByText(/35% hit/i).first()).toBeVisible()
|
||||
|
||||
// h/n/m counter row visible.
|
||||
await expect(page.getByText(/31h\/56n\/1m/).first()).toBeVisible()
|
||||
|
||||
// Skipped (low-confidence) counter visible.
|
||||
await expect(page.getByText(/29 skipped/).first()).toBeVisible()
|
||||
|
||||
// Threshold marker text matches the configured 0.80.
|
||||
await expect(page.getByText(/sim ≥ 0\.8/).first()).toBeVisible()
|
||||
|
||||
// Histogram bars rendered with hover titles that include the
|
||||
// bucket range and count. Bucket 4 (peak) has count 22; the
|
||||
// <div> with that exact title is the structural assertion.
|
||||
await expect(
|
||||
page.locator('div[title="[0.4, 0.5): 22"]')
|
||||
).toBeVisible()
|
||||
// Bucket 8 (just at threshold) has count 19.
|
||||
await expect(
|
||||
page.locator('div[title="[0.8, 0.9): 19"]')
|
||||
).toBeVisible()
|
||||
})
|
||||
|
||||
test('Routing tab shows a cached decision with cache_similarity', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
await page.getByRole('button', { name: /Routing/i }).click()
|
||||
|
||||
// The decision row exposes the cached flag and the cosine that
|
||||
// produced the hit so admins can correlate with the histogram.
|
||||
await expect(page.getByText('corr-1')).toBeVisible()
|
||||
})
|
||||
|
||||
test('Events tab renders rows but never the redacted content', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
await page.getByRole('button', { name: /Events/i }).click()
|
||||
// Hash prefix is visible — that's how admins audit recurring leaks.
|
||||
await expect(page.getByText('ff8d9819')).toBeVisible()
|
||||
// The page only ever shows fields the EventStore stores. The matched
|
||||
// value (e.g. "alice@example.com") would never appear because it's
|
||||
// not in the payload — explicit asserting absence here is the
|
||||
// contract the design relies on.
|
||||
await expect(page.getByText(/@example\.com/)).toHaveCount(0)
|
||||
})
|
||||
|
||||
test('Events tab renders proxy_connect rows with intercept decision', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
await page.getByRole('button', { name: /Events/i }).click()
|
||||
|
||||
// Both intercept and tunnel decisions visible.
|
||||
const interceptRow = page.locator('tr').filter({ hasText: 'api.openai.com' }).first()
|
||||
await expect(interceptRow).toContainText(/intercepted/i)
|
||||
const tunnelRow = page.locator('tr').filter({ hasText: 'github.com' }).first()
|
||||
await expect(tunnelRow).toContainText(/tunneled/i)
|
||||
})
|
||||
|
||||
test('Events tab renders proxy_traffic byte counts and status', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
await page.getByRole('button', { name: /Events/i }).click()
|
||||
|
||||
// The traffic row formats as "HTTP 200 · ↑412B ↓1.2KB · 240ms".
|
||||
// We assert on the durable parts: status code, byte values, duration unit.
|
||||
const trafficRow = page.locator('tr').filter({ hasText: 'corr-2' }).first()
|
||||
await expect(trafficRow).toContainText('HTTP 200')
|
||||
await expect(trafficRow).toContainText('412B')
|
||||
await expect(trafficRow).toContainText(/1\.2\s*KB/i)
|
||||
await expect(trafficRow).toContainText('240ms')
|
||||
})
|
||||
|
||||
test('Events kind filter narrows the table to the chosen kind', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
await page.getByRole('button', { name: /Events/i }).click()
|
||||
|
||||
// Default = All: pii row + 2 connect rows + 1 traffic row visible.
|
||||
await expect(page.getByText('ff8d9819')).toBeVisible()
|
||||
await expect(page.getByText('github.com')).toBeVisible()
|
||||
|
||||
// Click "PII" filter — proxy rows must disappear.
|
||||
await page.getByRole('button', { name: /^PII$/ }).click()
|
||||
await expect(page.getByText('ff8d9819')).toBeVisible()
|
||||
await expect(page.getByText('github.com')).toHaveCount(0)
|
||||
await expect(page.getByText('HTTP 200')).toHaveCount(0)
|
||||
|
||||
// Click "Proxy traffic" — only the traffic row remains.
|
||||
await page.getByRole('button', { name: /Proxy traffic/i }).click()
|
||||
await expect(page.getByText('HTTP 200')).toBeVisible()
|
||||
await expect(page.getByText('ff8d9819')).toHaveCount(0)
|
||||
await expect(page.getByText('github.com')).toHaveCount(0)
|
||||
|
||||
// Click "Proxy connect" — both connect rows visible, no PII or traffic.
|
||||
await page.getByRole('button', { name: /Proxy connect/i }).click()
|
||||
await expect(page.locator('tr').filter({ hasText: 'github.com' })).toHaveCount(1)
|
||||
await expect(page.locator('tr').filter({ hasText: 'api.openai.com' }).filter({ hasText: 'intercepted' })).toHaveCount(1)
|
||||
await expect(page.getByText('HTTP 200')).toHaveCount(0)
|
||||
await expect(page.getByText('ff8d9819')).toHaveCount(0)
|
||||
|
||||
// Click "All" — everything back.
|
||||
await page.getByRole('button', { name: /^All$/ }).click()
|
||||
await expect(page.getByText('ff8d9819')).toBeVisible()
|
||||
await expect(page.getByText('HTTP 200')).toBeVisible()
|
||||
})
|
||||
|
||||
test('Events tab shows the kind badge for each row', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
await page.getByRole('button', { name: /Events/i }).click()
|
||||
|
||||
// The Kind column header is present.
|
||||
await expect(page.locator('th').filter({ hasText: /^Kind$/ })).toBeVisible()
|
||||
// At least one cell renders each of the three kinds. Scope to
|
||||
// <span> elements so the "PII" filter button doesn't match.
|
||||
await expect(page.locator('span').getByText(/^pii$/i).first()).toBeVisible()
|
||||
await expect(page.getByText(/^proxy connect$/i).first()).toBeVisible()
|
||||
await expect(page.getByText(/^proxy traffic$/i).first()).toBeVisible()
|
||||
})
|
||||
|
||||
test('PUT /api/pii/patterns/:id fires when an action button is clicked', async ({ page }) => {
|
||||
let putHit = null
|
||||
await page.route('**/api/pii/patterns/email', (route) => {
|
||||
if (route.request().method() === 'PUT') {
|
||||
putHit = JSON.parse(route.request().postData() || '{}')
|
||||
route.fulfill({ contentType: 'application/json', body: JSON.stringify({ id: 'email', action: putHit.action, persisted: false }) })
|
||||
} else {
|
||||
route.continue()
|
||||
}
|
||||
})
|
||||
|
||||
await page.goto('/app/middleware')
|
||||
// Click the email row's "block" button (currently mask, so block is
|
||||
// enabled). Use a precise locator that matches the inner button.
|
||||
const emailRow = page.locator('tr').filter({ hasText: 'email' }).first()
|
||||
await emailRow.getByRole('button', { name: 'block' }).click()
|
||||
|
||||
await expect.poll(() => putHit).toEqual({ action: 'block' })
|
||||
})
|
||||
})
|
||||
|
||||
test.describe('Middleware page — non-admin under auth-on', () => {
|
||||
test('redirects to /app when the user is not admin', async ({ page }) => {
|
||||
await page.route('**/api/auth/status', (route) =>
|
||||
route.fulfill({
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify({
|
||||
authEnabled: true,
|
||||
staticApiKeyRequired: false,
|
||||
providers: ['local'],
|
||||
user: { id: 'bob', name: 'Bob', role: 'user', provider: 'local' },
|
||||
}),
|
||||
})
|
||||
)
|
||||
|
||||
await page.goto('/app/middleware')
|
||||
// RequireAdmin redirects non-admin viewers; the URL must not stay on /middleware.
|
||||
await page.waitForURL(/\/app(?!\/middleware)/, { timeout: 5000 })
|
||||
expect(page.url()).not.toMatch(/\/middleware/)
|
||||
})
|
||||
})
|
||||
219
core/http/react-ui/e2e/router-template.spec.js
Normal file
219
core/http/react-ui/e2e/router-template.spec.js
Normal file
@@ -0,0 +1,219 @@
|
||||
import { test, expect } from '@playwright/test'
|
||||
|
||||
// Router template + structured editor regression tests.
|
||||
//
|
||||
// The historical regression was: the "Create routing model" button
|
||||
// loaded the model editor with an array-shaped `router.candidates`
|
||||
// value, which crashed when a code-editor field received it instead
|
||||
// of a string ("(intermediate value).split is not a function").
|
||||
//
|
||||
// The current schema is also covered:
|
||||
// - classifier=score is the only shipped classifier
|
||||
// - router.policies surfaces in its own structured editor (label +
|
||||
// description rows with duplicate detection)
|
||||
// - router.candidates is the structured {model, labels[]} editor;
|
||||
// labels are chips populated from router.policies via FormContext
|
||||
// - router.embedding_cache.* surface as labelled fields with the
|
||||
// correct components (model-select / slider)
|
||||
// - router.activation_threshold and the two embedding_cache slider
|
||||
// fields render with slider min/max/step from the registry
|
||||
|
||||
const ROUTER_METADATA = {
|
||||
sections: [
|
||||
{ id: 'general', label: 'General', icon: 'settings', order: 0 },
|
||||
{ id: 'other', label: 'Other', icon: 'more-horizontal', order: 100 },
|
||||
],
|
||||
fields: [
|
||||
{ path: 'name', yaml_key: 'name', go_type: 'string', ui_type: 'string',
|
||||
section: 'general', label: 'Model Name', component: 'input', order: 0 },
|
||||
{
|
||||
path: 'router.classifier', yaml_key: 'classifier', go_type: 'string', ui_type: 'string',
|
||||
section: 'other', label: 'Classifier', component: 'select',
|
||||
options: [{ value: 'score', label: 'Score (Arch-Router-style)' }],
|
||||
description: 'Picks a candidate by scoring every policy label against the prompt. Only "score" is shipped today.',
|
||||
order: 230,
|
||||
},
|
||||
{
|
||||
path: 'router.classifier_model', yaml_key: 'classifier_model', go_type: 'string', ui_type: 'string',
|
||||
section: 'other', label: 'Classifier Model', component: 'model-select', autocomplete_provider: 'models:chat',
|
||||
description: 'Loaded LocalAI model the score classifier asks to rank each policy label.',
|
||||
order: 231,
|
||||
},
|
||||
{
|
||||
path: 'router.fallback', yaml_key: 'fallback', go_type: 'string', ui_type: 'string',
|
||||
section: 'other', label: 'Fallback Model', component: 'model-select', autocomplete_provider: 'models:chat',
|
||||
description: 'Model used when no candidate covers the active label set.',
|
||||
order: 232,
|
||||
},
|
||||
{
|
||||
path: 'router.activation_threshold', yaml_key: 'activation_threshold', go_type: 'float64', ui_type: 'float',
|
||||
section: 'other', label: 'Activation Threshold', component: 'slider',
|
||||
min: 0, max: 1, step: 0.05,
|
||||
description: 'Softmax-probability floor a policy must clear to join the active label set.',
|
||||
order: 233,
|
||||
},
|
||||
{
|
||||
path: 'router.policies', yaml_key: 'policies', go_type: '[]RouterPolicy', ui_type: 'object',
|
||||
section: 'other', label: 'Policies', component: 'router-policies',
|
||||
description: 'Label vocabulary the classifier scores over.',
|
||||
order: 235,
|
||||
},
|
||||
{
|
||||
path: 'router.candidates', yaml_key: 'candidates', go_type: '[]RouterCandidate', ui_type: 'object',
|
||||
section: 'other', label: 'Candidates', component: 'router-candidates',
|
||||
description: 'Routing table: each entry binds a downstream model to a set of policy labels.',
|
||||
order: 236,
|
||||
},
|
||||
{
|
||||
path: 'router.embedding_cache.embedding_model', yaml_key: 'embedding_model', go_type: 'string', ui_type: 'string',
|
||||
section: 'other', label: 'L2 Cache: Embedding Model', component: 'model-select', autocomplete_provider: 'models',
|
||||
description: 'Embedding model used by the L2 decision cache.',
|
||||
order: 237,
|
||||
},
|
||||
{
|
||||
path: 'router.embedding_cache.similarity_threshold', yaml_key: 'similarity_threshold', go_type: 'float64', ui_type: 'float',
|
||||
section: 'other', label: 'L2 Cache: Similarity Threshold', component: 'slider',
|
||||
min: 0, max: 1, step: 0.01,
|
||||
description: 'Cosine-similarity floor a cache candidate must clear to count as a hit.',
|
||||
order: 238,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const MIDDLEWARE_STATUS = {
|
||||
pii: { enabled_globally: false, patterns: [], models: [], recent_event_count: 0 },
|
||||
router: { configured: false, models: [], recent_decision_count: 0, available_classifiers: ['score'] },
|
||||
mitm: { running: false, listen_addr: '', configured_addr: '', host_owners: {}, host_conflicts: {}, models: [], ca_available: false, ca_cert_url: '' },
|
||||
}
|
||||
|
||||
test.describe('Router template — create flow', () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.route('**/api/auth/status', (route) =>
|
||||
route.fulfill({
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify({ authEnabled: false, staticApiKeyRequired: false, providers: [] }),
|
||||
})
|
||||
)
|
||||
await page.route('**/api/middleware/status', (route) =>
|
||||
route.fulfill({ contentType: 'application/json', body: JSON.stringify(MIDDLEWARE_STATUS) })
|
||||
)
|
||||
await page.route('**/api/router/decisions?**', (route) =>
|
||||
route.fulfill({ contentType: 'application/json', body: JSON.stringify({ decisions: [] }) })
|
||||
)
|
||||
await page.route('**/api/pii/events?**', (route) =>
|
||||
route.fulfill({ contentType: 'application/json', body: JSON.stringify({ events: [] }) })
|
||||
)
|
||||
await page.route('**/api/models/config-metadata*', (route) =>
|
||||
route.fulfill({ contentType: 'application/json', body: JSON.stringify(ROUTER_METADATA) })
|
||||
)
|
||||
await page.route('**/api/models/config-metadata/autocomplete/**', (route) =>
|
||||
route.fulfill({ contentType: 'application/json', body: JSON.stringify({ values: [] }) })
|
||||
)
|
||||
|
||||
// Surface any uncaught render-time error so the assertion fails
|
||||
// with a useful message rather than the test silently passing.
|
||||
page.on('pageerror', (err) => {
|
||||
throw new Error(`uncaught page error: ${err.message}`)
|
||||
})
|
||||
})
|
||||
|
||||
test('Routing tab links to the model editor with the router template loaded', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
await page.getByRole('button', { name: /Routing/i }).click()
|
||||
|
||||
// Empty-state button is the primary CTA.
|
||||
await page.getByRole('button', { name: /Create routing model/i }).click()
|
||||
|
||||
// Editor loads on a /app/model-editor URL with template=router.
|
||||
await expect(page).toHaveURL(/\/app\/model-editor.*template=router/)
|
||||
})
|
||||
|
||||
test('Router template renders without crashing on structured candidates/policies', async ({ page }) => {
|
||||
// Navigate straight to the create-with-template URL. This was the
|
||||
// regression that crashed with "(intermediate value).split is not
|
||||
// a function" when the template's array-shaped router.candidates
|
||||
// fell into a code-editor wrapper.
|
||||
await page.goto('/app/model-editor?template=router')
|
||||
|
||||
// The react-router error overlay must not appear.
|
||||
await expect(page.getByText(/Unexpected Application Error/i)).toHaveCount(0)
|
||||
|
||||
// Editor surface visible. Template URL is "create mode", so the
|
||||
// heading reads "Add Model" rather than "Model Editor".
|
||||
await expect(page.locator('h1.page-title')).toBeVisible({ timeout: 10_000 })
|
||||
|
||||
// Top-level field labels seeded by the template are visible.
|
||||
// embedding_cache.* fields are surfaced via "Add Field" search
|
||||
// rather than active by default — separate spec covers them.
|
||||
await expect(page.getByText('Classifier').first()).toBeVisible()
|
||||
await expect(page.getByText('Policies').first()).toBeVisible()
|
||||
await expect(page.getByText('Candidates').first()).toBeVisible()
|
||||
await expect(page.getByText('Activation Threshold').first()).toBeVisible()
|
||||
})
|
||||
|
||||
test('Classifier select offers only the score option', async ({ page }) => {
|
||||
await page.goto('/app/model-editor?template=router')
|
||||
|
||||
// SearchableSelect renders the current option's *label* inside the
|
||||
// trigger button. After the schema cleanup the only option is
|
||||
// "Score (Arch-Router-style)", pre-selected by the template.
|
||||
await expect(page.getByText('Score (Arch-Router-style)').first()).toBeVisible({ timeout: 10_000 })
|
||||
})
|
||||
|
||||
test('Policies editor renders structured rows with label + description fields', async ({ page }) => {
|
||||
await page.goto('/app/model-editor?template=router')
|
||||
|
||||
// The template seeds three example policies. Their labels are
|
||||
// pre-populated in input fields with monospace styling — the
|
||||
// editor signature is "Add policy" button + label/description
|
||||
// input pairs.
|
||||
await expect(page.getByRole('button', { name: /Add policy/i }).first()).toBeVisible()
|
||||
|
||||
// Pre-seeded labels visible as input values. RouterPoliciesEditor
|
||||
// renders each label in an input with a recognisable placeholder;
|
||||
// assert on their values by position.
|
||||
const labelInputs = page.locator('input[placeholder^="label ("]')
|
||||
await expect(labelInputs.nth(0)).toHaveValue('code-generation')
|
||||
await expect(labelInputs.nth(1)).toHaveValue('casual-chat')
|
||||
await expect(labelInputs.nth(2)).toHaveValue('math-reasoning')
|
||||
})
|
||||
|
||||
test('Candidates editor renders {model, labels} rows with policy-aware label chips', async ({ page }) => {
|
||||
await page.goto('/app/model-editor?template=router')
|
||||
|
||||
// "Add candidate" is the signature of the new RouterCandidatesEditor.
|
||||
await expect(page.getByRole('button', { name: /Add candidate/i }).first()).toBeVisible()
|
||||
|
||||
// Each candidate row should expose move-up/move-down controls,
|
||||
// a model picker, and label chips. The chip for a known policy
|
||||
// label appears as a button with the policy's label text.
|
||||
// Pre-seeded template: candidate[0] has labels=['casual-chat'];
|
||||
// candidate[1] has labels=['code-generation', 'casual-chat', 'math-reasoning'].
|
||||
//
|
||||
// The chips appear inside a flex row of buttons. Using getByRole
|
||||
// with the exact name catches typos/regressions cleanly.
|
||||
await expect(page.getByRole('button', { name: 'casual-chat' }).first()).toBeVisible()
|
||||
await expect(page.getByRole('button', { name: 'code-generation' }).first()).toBeVisible()
|
||||
await expect(page.getByRole('button', { name: 'math-reasoning' }).first()).toBeVisible()
|
||||
})
|
||||
|
||||
test('Adding a duplicate policy label flags the duplicate row', async ({ page }) => {
|
||||
await page.goto('/app/model-editor?template=router')
|
||||
|
||||
// Add a new empty policy row, then type a duplicate of the
|
||||
// existing 'casual-chat'. The duplicate detection in
|
||||
// RouterPoliciesEditor sets a warning border via inline style.
|
||||
await page.getByRole('button', { name: /Add policy/i }).first().click()
|
||||
|
||||
// Find the newly-added empty label input (placeholder catches it).
|
||||
const newLabel = page.locator('input[placeholder*="label (e.g. code-generation)"]').last()
|
||||
await newLabel.fill('casual-chat')
|
||||
|
||||
// Both rows now hold the same label. The duplicate-detection
|
||||
// logic flags the row visually; we assert on the title attribute
|
||||
// RouterPoliciesEditor sets on the input when duplicate=true.
|
||||
await expect(
|
||||
page.locator('input[title="Duplicate label — candidates won\'t be able to distinguish them"]').first()
|
||||
).toBeVisible()
|
||||
})
|
||||
})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user