From 6a80e2373392f00cf03b2a61cbca1b8ac0b09541 Mon Sep 17 00:00:00 2001 From: Richard Palethorpe Date: Mon, 25 May 2026 08:28:27 +0100 Subject: [PATCH] feat(middleware): Model routing, PII filtering, Cloud model proxies (#9802) Add a routing middleware stack and a cloud-proxy backend. * cloud-proxy: a Go gRPC backend that forwards OpenAI- and Anthropic-shaped chat requests to upstream providers, with an optional translate mode (OpenAI request -> Anthropic /v1/messages -> OpenAI response) and full tool-calling support. * routing: admission control, content-aware model routing (embedding cache + classifier + rerank + Arch-Router score), PII detection/redaction (regex + NER) with streaming filter and OpenAI/Anthropic adapters, and a per-user/per-key billing recorder backed by GORM or in-memory storage. * middleware: UsageMiddleware records usage via the billing recorder, plus admission, route-model, usage-stamp and trace middlewares. * observability: BackendTrace ring buffer stores full request bodies (capped), MITM proxy emits structured trace events, and router classifier decisions surface at /api/router/decide. * gallery: Arch-Router-1.5B (Q4_K_M and Q8_0). * UI: cloud-proxy model-editor fields, classifier system-prompt and score-normalization config, and a Traces page rendering request bodies. 
Assisted-by: claude-code:claude-opus-4-7 [Read] [Edit] [Bash] Signed-off-by: Richard Palethorpe --- .dockerignore | 9 + .gitignore | 4 + Makefile | 15 +- backend/backend.proto | 151 +++ backend/cpp/llama-cpp/grpc-server.cpp | 248 ++++ backend/go/cloud-proxy/Makefile | 12 + .../go/cloud-proxy/cloud_proxy_suite_test.go | 16 + backend/go/cloud-proxy/main.go | 39 + backend/go/cloud-proxy/package.sh | 13 + .../go/cloud-proxy/passthrough_edge_test.go | 270 ++++ backend/go/cloud-proxy/provider_anthropic.go | 508 ++++++++ .../go/cloud-proxy/provider_anthropic_test.go | 334 +++++ backend/go/cloud-proxy/provider_edge_test.go | 119 ++ backend/go/cloud-proxy/provider_openai.go | 320 +++++ .../go/cloud-proxy/provider_openai_test.go | 170 +++ backend/go/cloud-proxy/proxy.go | 429 +++++++ backend/go/cloud-proxy/proxy_test.go | 206 +++ backend/go/cloud-proxy/run.sh | 6 + backend/go/cloud-proxy/toolcalls_test.go | 232 ++++ backend/go/local-store/debug.go | 2 +- backend/go/local-store/store.go | 686 ++++------ backend/go/local-store/store_suite_test.go | 13 + backend/go/local-store/store_test.go | 284 +++++ backend/python/transformers/backend.py | 50 +- backend/python/vllm/backend.py | 127 ++ core/application/application.go | 129 ++ core/application/mitm.go | 146 +++ core/application/router_factories.go | 63 + .../runtime_settings_branding_test.go | 22 + core/application/startup.go | 150 +++ core/backend/embeddings.go | 27 + core/backend/options.go | 12 + core/backend/rerank.go | 45 + core/backend/score.go | 159 +++ core/backend/score_test.go | 63 + core/backend/stores.go | 62 + core/cli/run.go | 6 + core/config/application_config.go | 94 ++ core/config/meta/constants.go | 13 +- core/config/meta/registry.go | 213 ++++ core/config/meta/registry_coverage_test.go | 258 ++++ core/config/mitm_host_owners_test.go | 133 ++ core/config/model_config.go | 497 +++++++- core/config/model_config_loader.go | 43 + core/config/model_config_test.go | 158 +++ core/config/runtime_settings.go | 22 + 
core/config/runtime_settings_persist_test.go | 19 + core/explorer/empty_db.json.lock | 0 core/explorer/test_db.json.lock | 0 core/http/app.go | 44 +- core/http/auth/usage.go | 28 + .../anthropic/anthropic_suite_test.go | 13 + core/http/endpoints/anthropic/messages.go | 141 ++- .../endpoints/anthropic/messages_pii_test.go | 114 ++ .../endpoints/localai/api_instructions.go | 24 + .../localai/api_instructions_test.go | 6 +- core/http/endpoints/localai/import_model.go | 10 +- core/http/endpoints/localai/mcp.go | 6 +- core/http/endpoints/localai/pii_decide.go | 85 ++ .../http/endpoints/localai/pii_decide_test.go | 107 ++ core/http/endpoints/localai/router_decide.go | 109 ++ .../endpoints/localai/router_decide_test.go | 248 ++++ core/http/endpoints/localai/score.go | 90 ++ core/http/endpoints/localai/settings.go | 10 + .../endpoints/mcp/localai_assistant_test.go | 28 + core/http/endpoints/openai/chat.go | 120 +- core/http/endpoints/openai/completion.go | 74 +- core/http/endpoints/openai/edit.go | 2 + core/http/endpoints/openai/embeddings.go | 9 + core/http/endpoints/openai/realtime.go | 6 +- core/http/endpoints/openai/realtime_model.go | 166 ++- core/http/middleware/admission.go | 81 ++ core/http/middleware/admission_test.go | 118 ++ core/http/middleware/context_keys.go | 50 + core/http/middleware/request.go | 11 + core/http/middleware/request_test.go | 50 +- core/http/middleware/route_model.go | 603 +++++++++ core/http/middleware/route_model_test.go | 551 ++++++++ core/http/middleware/trace.go | 12 + core/http/middleware/usage.go | 402 +++--- core/http/middleware/usage_stamp.go | 33 + core/http/middleware/usage_test.go | 393 ++++-- .../http/react-ui/e2e/middleware-page.spec.js | 308 +++++ .../http/react-ui/e2e/router-template.spec.js | 219 ++++ .../http/react-ui/e2e/usage-dashboard.spec.js | 148 +++ .../react-ui/e2e/users-tab-gating.spec.js | 74 ++ core/http/react-ui/public/locales/en/nav.json | 1 + core/http/react-ui/src/App.css | 18 + 
.../src/components/ConfigFieldRenderer.jsx | 70 +- .../src/components/PIIPatternListEditor.jsx | 120 ++ .../src/components/RequireAuthEnabled.jsx | 16 + .../src/components/RouterCandidatesEditor.jsx | 185 +++ .../src/components/RouterPoliciesEditor.jsx | 109 ++ core/http/react-ui/src/components/Sidebar.jsx | 3 +- .../src/components/StructuredCodeEditor.jsx | 80 ++ .../react-ui/src/contexts/FormContext.jsx | 26 + core/http/react-ui/src/hooks/useChat.js | 6 +- core/http/react-ui/src/pages/Manage.jsx | 15 +- core/http/react-ui/src/pages/Middleware.jsx | 1108 +++++++++++++++++ core/http/react-ui/src/pages/ModelEditor.jsx | 3 + core/http/react-ui/src/pages/Traces.jsx | 17 + core/http/react-ui/src/pages/Usage.jsx | 40 +- core/http/react-ui/src/router.jsx | 5 +- core/http/react-ui/src/utils/capabilities.js | 1 + .../http/react-ui/src/utils/modelTemplates.js | 90 ++ core/http/routes/anthropic.go | 29 +- core/http/routes/localai.go | 5 + core/http/routes/middleware.go | 362 ++++++ core/http/routes/ollama.go | 2 +- core/http/routes/openai.go | 43 +- core/http/routes/openresponses.go | 6 +- core/http/routes/pii.go | 260 ++++ core/http/routes/usage.go | 157 +++ core/http/routes/usage_test.go | 135 ++ core/schema/localai.go | 95 ++ core/schema/message.go | 12 +- core/schema/message_test.go | 53 + core/schema/openai.go | 50 +- core/schema/prediction.go | 8 +- core/services/cloudproxy/backend_forward.go | 237 ++++ .../cloudproxy/backend_forward_test.go | 160 +++ core/services/cloudproxy/build_filter_test.go | 72 ++ core/services/cloudproxy/mitm/ca.go | 177 +++ core/services/cloudproxy/mitm/ca_test.go | 79 ++ core/services/cloudproxy/mitm/handler.go | 442 +++++++ core/services/cloudproxy/mitm/handler_test.go | 329 +++++ core/services/cloudproxy/mitm/http2_test.go | 165 +++ core/services/cloudproxy/mitm/leaf.go | 102 ++ core/services/cloudproxy/mitm/leaf_test.go | 103 ++ .../cloudproxy/mitm/mitm_suite_test.go | 13 + core/services/cloudproxy/mitm/proxy.go | 306 +++++ 
core/services/cloudproxy/mitm/proxy_test.go | 278 +++++ core/services/cloudproxy/mitm/response.go | 105 ++ core/services/cloudproxy/mitm/restart_test.go | 98 ++ core/services/cloudproxy/proxy.go | 125 ++ core/services/cloudproxy/proxy_suite_test.go | 13 + core/services/cloudproxy/proxy_test.go | 38 + core/services/cloudproxy/ssewire/ssewire.go | 218 ++++ .../cloudproxy/ssewire/ssewire_suite_test.go | 13 + .../cloudproxy/ssewire/ssewire_test.go | 114 ++ core/services/monitoring/metrics.go | 10 + core/services/nodes/health_mock_test.go | 9 + core/services/nodes/inflight_test.go | 12 + core/services/routing/admission/admission.go | 105 ++ .../routing/admission/admission_suite_test.go | 13 + .../routing/admission/admission_test.go | 103 ++ core/services/routing/billing/backend.go | 52 + .../routing/billing/billing_suite_test.go | 13 + core/services/routing/billing/disabled.go | 20 + core/services/routing/billing/gorm.go | 111 ++ core/services/routing/billing/inmem.go | 157 +++ core/services/routing/billing/inmem_test.go | 140 +++ core/services/routing/billing/local_user.go | 84 ++ .../routing/billing/local_user_test.go | 70 ++ core/services/routing/billing/prom.go | 215 ++++ .../services/routing/billing/recorder_test.go | 82 ++ core/services/routing/contract/contract.go | 55 + core/services/routing/contract/strict_off.go | 5 + core/services/routing/contract/strict_on.go | 9 + core/services/routing/pii/config.go | 71 ++ core/services/routing/pii/config_test.go | 56 + core/services/routing/pii/middleware.go | 260 ++++ core/services/routing/pii/middleware_test.go | 309 +++++ core/services/routing/pii/ner.go | 97 ++ core/services/routing/pii/ner_test.go | 174 +++ core/services/routing/pii/patterns.go | 188 +++ core/services/routing/pii/pii_suite_test.go | 13 + core/services/routing/pii/redactor.go | 342 +++++ .../routing/pii/redactor_race_test.go | 66 + core/services/routing/pii/redactor_test.go | 184 +++ core/services/routing/pii/store.go | 130 ++ 
core/services/routing/pii/stream.go | 197 +++ core/services/routing/pii/stream_test.go | 184 +++ core/services/routing/pii/types.go | 170 +++ core/services/routing/piiadapter/anthropic.go | 81 ++ .../routing/piiadapter/anthropic_test.go | 69 + core/services/routing/piiadapter/openai.go | 112 ++ .../routing/piiadapter/openai_test.go | 93 ++ .../piiadapter/piiadapter_suite_test.go | 13 + core/services/routing/router/cache.go | 96 ++ core/services/routing/router/decisions.go | 166 +++ .../routing/router/embedding_cache.go | 227 ++++ .../routing/router/embedding_cache_test.go | 311 +++++ core/services/routing/router/registry.go | 76 ++ core/services/routing/router/rerank.go | 104 ++ core/services/routing/router/rerank_test.go | 121 ++ core/services/routing/router/resolve.go | 203 +++ core/services/routing/router/resolve_test.go | 130 ++ .../routing/router/router_suite_test.go | 13 + core/services/routing/router/score.go | 423 +++++++ core/services/routing/router/score_test.go | 337 +++++ core/services/routing/router/types.go | 133 ++ core/trace/backend_trace.go | 26 +- docs/content/features/cloud-proxy.md | 232 ++++ docs/content/features/middleware.md | 509 ++++++++ docs/content/features/mitm-proxy.md | 159 +++ gallery/bge-m3-colbert.yaml | 11 + gallery/index.yaml | 69 + go.mod | 2 +- pkg/grpc/backend.go | 11 + pkg/grpc/base/base.go | 5 + pkg/grpc/client.go | 111 ++ pkg/grpc/embed.go | 121 ++ pkg/grpc/forward_test.go | 94 ++ pkg/grpc/integration_toolcalls_test.go | 147 +++ pkg/grpc/interface.go | 25 + pkg/grpc/rich_test.go | 129 ++ pkg/grpc/server.go | 90 +- pkg/mcp/localaitools/client.go | 39 + pkg/mcp/localaitools/coverage_test.go | 28 +- pkg/mcp/localaitools/dto.go | 173 +++ pkg/mcp/localaitools/fakes_test.go | 81 +- pkg/mcp/localaitools/httpapi/client.go | 187 ++- pkg/mcp/localaitools/httpapi/routes.go | 38 +- pkg/mcp/localaitools/inproc/client.go | 320 ++++- pkg/mcp/localaitools/server.go | 3 + pkg/mcp/localaitools/server_test.go | 14 + 
pkg/mcp/localaitools/tools.go | 8 + pkg/mcp/localaitools/tools_middleware.go | 78 ++ pkg/mcp/localaitools/tools_pii.go | 45 + pkg/mcp/localaitools/tools_usage.go | 22 + pkg/model/connection_evicting_client.go | 12 + pkg/store/client.go | 82 +- pkg/store/proto.go | 46 + tests/e2e-ui/main.go | 39 +- tests/e2e/cloud_proxy_helpers_test.go | 206 +++ tests/e2e/e2e_cloud_proxy_test.go | 268 ++++ tests/e2e/e2e_suite_test.go | 105 ++ tests/e2e/mock-backend/main.go | 13 +- 229 files changed, 26339 insertions(+), 1030 deletions(-) create mode 100644 backend/go/cloud-proxy/Makefile create mode 100644 backend/go/cloud-proxy/cloud_proxy_suite_test.go create mode 100644 backend/go/cloud-proxy/main.go create mode 100755 backend/go/cloud-proxy/package.sh create mode 100644 backend/go/cloud-proxy/passthrough_edge_test.go create mode 100644 backend/go/cloud-proxy/provider_anthropic.go create mode 100644 backend/go/cloud-proxy/provider_anthropic_test.go create mode 100644 backend/go/cloud-proxy/provider_edge_test.go create mode 100644 backend/go/cloud-proxy/provider_openai.go create mode 100644 backend/go/cloud-proxy/provider_openai_test.go create mode 100644 backend/go/cloud-proxy/proxy.go create mode 100644 backend/go/cloud-proxy/proxy_test.go create mode 100755 backend/go/cloud-proxy/run.sh create mode 100644 backend/go/cloud-proxy/toolcalls_test.go create mode 100644 backend/go/local-store/store_suite_test.go create mode 100644 backend/go/local-store/store_test.go create mode 100644 core/application/mitm.go create mode 100644 core/application/router_factories.go create mode 100644 core/backend/score.go create mode 100644 core/backend/score_test.go create mode 100644 core/config/meta/registry_coverage_test.go create mode 100644 core/config/mitm_host_owners_test.go create mode 100644 core/explorer/empty_db.json.lock create mode 100644 core/explorer/test_db.json.lock create mode 100644 core/http/endpoints/anthropic/anthropic_suite_test.go create mode 100644 
core/http/endpoints/anthropic/messages_pii_test.go create mode 100644 core/http/endpoints/localai/pii_decide.go create mode 100644 core/http/endpoints/localai/pii_decide_test.go create mode 100644 core/http/endpoints/localai/router_decide.go create mode 100644 core/http/endpoints/localai/router_decide_test.go create mode 100644 core/http/endpoints/localai/score.go create mode 100644 core/http/middleware/admission.go create mode 100644 core/http/middleware/admission_test.go create mode 100644 core/http/middleware/context_keys.go create mode 100644 core/http/middleware/route_model.go create mode 100644 core/http/middleware/route_model_test.go create mode 100644 core/http/middleware/usage_stamp.go create mode 100644 core/http/react-ui/e2e/middleware-page.spec.js create mode 100644 core/http/react-ui/e2e/router-template.spec.js create mode 100644 core/http/react-ui/e2e/usage-dashboard.spec.js create mode 100644 core/http/react-ui/e2e/users-tab-gating.spec.js create mode 100644 core/http/react-ui/src/components/PIIPatternListEditor.jsx create mode 100644 core/http/react-ui/src/components/RequireAuthEnabled.jsx create mode 100644 core/http/react-ui/src/components/RouterCandidatesEditor.jsx create mode 100644 core/http/react-ui/src/components/RouterPoliciesEditor.jsx create mode 100644 core/http/react-ui/src/components/StructuredCodeEditor.jsx create mode 100644 core/http/react-ui/src/contexts/FormContext.jsx create mode 100644 core/http/react-ui/src/pages/Middleware.jsx create mode 100644 core/http/routes/middleware.go create mode 100644 core/http/routes/pii.go create mode 100644 core/http/routes/usage.go create mode 100644 core/http/routes/usage_test.go create mode 100644 core/services/cloudproxy/backend_forward.go create mode 100644 core/services/cloudproxy/backend_forward_test.go create mode 100644 core/services/cloudproxy/build_filter_test.go create mode 100644 core/services/cloudproxy/mitm/ca.go create mode 100644 core/services/cloudproxy/mitm/ca_test.go create mode 
100644 core/services/cloudproxy/mitm/handler.go create mode 100644 core/services/cloudproxy/mitm/handler_test.go create mode 100644 core/services/cloudproxy/mitm/http2_test.go create mode 100644 core/services/cloudproxy/mitm/leaf.go create mode 100644 core/services/cloudproxy/mitm/leaf_test.go create mode 100644 core/services/cloudproxy/mitm/mitm_suite_test.go create mode 100644 core/services/cloudproxy/mitm/proxy.go create mode 100644 core/services/cloudproxy/mitm/proxy_test.go create mode 100644 core/services/cloudproxy/mitm/response.go create mode 100644 core/services/cloudproxy/mitm/restart_test.go create mode 100644 core/services/cloudproxy/proxy.go create mode 100644 core/services/cloudproxy/proxy_suite_test.go create mode 100644 core/services/cloudproxy/proxy_test.go create mode 100644 core/services/cloudproxy/ssewire/ssewire.go create mode 100644 core/services/cloudproxy/ssewire/ssewire_suite_test.go create mode 100644 core/services/cloudproxy/ssewire/ssewire_test.go create mode 100644 core/services/routing/admission/admission.go create mode 100644 core/services/routing/admission/admission_suite_test.go create mode 100644 core/services/routing/admission/admission_test.go create mode 100644 core/services/routing/billing/backend.go create mode 100644 core/services/routing/billing/billing_suite_test.go create mode 100644 core/services/routing/billing/disabled.go create mode 100644 core/services/routing/billing/gorm.go create mode 100644 core/services/routing/billing/inmem.go create mode 100644 core/services/routing/billing/inmem_test.go create mode 100644 core/services/routing/billing/local_user.go create mode 100644 core/services/routing/billing/local_user_test.go create mode 100644 core/services/routing/billing/prom.go create mode 100644 core/services/routing/billing/recorder_test.go create mode 100644 core/services/routing/contract/contract.go create mode 100644 core/services/routing/contract/strict_off.go create mode 100644 
core/services/routing/contract/strict_on.go create mode 100644 core/services/routing/pii/config.go create mode 100644 core/services/routing/pii/config_test.go create mode 100644 core/services/routing/pii/middleware.go create mode 100644 core/services/routing/pii/middleware_test.go create mode 100644 core/services/routing/pii/ner.go create mode 100644 core/services/routing/pii/ner_test.go create mode 100644 core/services/routing/pii/patterns.go create mode 100644 core/services/routing/pii/pii_suite_test.go create mode 100644 core/services/routing/pii/redactor.go create mode 100644 core/services/routing/pii/redactor_race_test.go create mode 100644 core/services/routing/pii/redactor_test.go create mode 100644 core/services/routing/pii/store.go create mode 100644 core/services/routing/pii/stream.go create mode 100644 core/services/routing/pii/stream_test.go create mode 100644 core/services/routing/pii/types.go create mode 100644 core/services/routing/piiadapter/anthropic.go create mode 100644 core/services/routing/piiadapter/anthropic_test.go create mode 100644 core/services/routing/piiadapter/openai.go create mode 100644 core/services/routing/piiadapter/openai_test.go create mode 100644 core/services/routing/piiadapter/piiadapter_suite_test.go create mode 100644 core/services/routing/router/cache.go create mode 100644 core/services/routing/router/decisions.go create mode 100644 core/services/routing/router/embedding_cache.go create mode 100644 core/services/routing/router/embedding_cache_test.go create mode 100644 core/services/routing/router/registry.go create mode 100644 core/services/routing/router/rerank.go create mode 100644 core/services/routing/router/rerank_test.go create mode 100644 core/services/routing/router/resolve.go create mode 100644 core/services/routing/router/resolve_test.go create mode 100644 core/services/routing/router/router_suite_test.go create mode 100644 core/services/routing/router/score.go create mode 100644 
core/services/routing/router/score_test.go create mode 100644 core/services/routing/router/types.go create mode 100644 docs/content/features/cloud-proxy.md create mode 100644 docs/content/features/middleware.md create mode 100644 docs/content/features/mitm-proxy.md create mode 100644 gallery/bge-m3-colbert.yaml create mode 100644 pkg/grpc/forward_test.go create mode 100644 pkg/grpc/integration_toolcalls_test.go create mode 100644 pkg/grpc/rich_test.go create mode 100644 pkg/mcp/localaitools/tools_middleware.go create mode 100644 pkg/mcp/localaitools/tools_pii.go create mode 100644 pkg/mcp/localaitools/tools_usage.go create mode 100644 pkg/store/proto.go create mode 100644 tests/e2e/cloud_proxy_helpers_test.go create mode 100644 tests/e2e/e2e_cloud_proxy_test.go diff --git a/.dockerignore b/.dockerignore index 5b62e5f31..e8904d3be 100644 --- a/.dockerignore +++ b/.dockerignore @@ -4,6 +4,7 @@ .devcontainer models backends +volumes examples/chatbot-ui/models backend/go/image/stablediffusion-ggml/build/ backend/go/*/build @@ -21,3 +22,11 @@ __pycache__ # backend virtual environments **/venv backend/python/**/source + +# In-place llama.cpp clone + per-variant build copies. The Makefile +# clones llama.cpp itself at the pinned LLAMA_VERSION; if a stale +# local checkout is COPY'd into the image, the `llama.cpp:` target +# sees the directory and skips re-cloning, so grpc-server.cpp ends +# up compiled against whatever (likely older) commit the host had. +backend/cpp/llama-cpp/llama.cpp +backend/cpp/llama-cpp-*-build diff --git a/.gitignore b/.gitignore index 08873e8b2..ec105ac85 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,10 @@ go-bert LocalAI /local-ai /local-ai-launcher +# Root-level build artifacts when running `go build ./...` against +# Go backend packages whose main lives under backend/go/. 
+/cloud-proxy +/local-store # prevent above rules from omitting the helm chart !charts/* # prevent above rules from omitting the api/localai folder diff --git a/Makefile b/Makefile index ebeef4c41..3eebc1871 100644 --- a/Makefile +++ b/Makefile @@ -69,7 +69,7 @@ else GORELEASER=$(shell which goreleaser) endif -TEST_PATHS?=./api/... ./pkg/... ./core/... +TEST_PATHS?=./api/... ./pkg/... ./core/... ./backend/go/cloud-proxy/... ./backend/go/local-store/... .PHONY: all test build vendor lint lint-all @@ -268,12 +268,13 @@ prepare-e2e: run-e2e-image: docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --name e2e-tests-$(RANDOM) localai-tests -test-e2e: build-mock-backend prepare-e2e run-e2e-image +test-e2e: build-mock-backend build-cloud-proxy-backend prepare-e2e run-e2e-image @echo 'Running e2e tests' BUILD_TYPE=$(BUILD_TYPE) \ LOCALAI_API=http://$(E2E_BRIDGE_IP):5390 \ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e $(MAKE) clean-mock-backend + $(MAKE) clean-cloud-proxy-backend $(MAKE) teardown-e2e docker rmi localai-tests @@ -1064,6 +1065,7 @@ BACKEND_DS4 = ds4|ds4|.|false|false # Golang backends BACKEND_PIPER = piper|golang|.|false|true BACKEND_LOCAL_STORE = local-store|golang|.|false|true +BACKEND_CLOUD_PROXY = cloud-proxy|golang|.|false|true BACKEND_HUGGINGFACE = huggingface|golang|.|false|true BACKEND_SILERO_VAD = silero-vad|golang|.|false|true BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true @@ -1149,6 +1151,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_TURBOQUANT))) $(eval $(call generate-docker-build-target,$(BACKEND_DS4))) $(eval $(call generate-docker-build-target,$(BACKEND_PIPER))) $(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE))) +$(eval $(call generate-docker-build-target,$(BACKEND_CLOUD_PROXY))) $(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE))) $(eval $(call 
generate-docker-build-target,$(BACKEND_SILERO_VAD))) $(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML))) @@ -1201,7 +1204,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX))) docker-save-%: backend-images docker save local-ai-backend:$* -o backend-images/$*.tar -docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx +docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral 
docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy ######################################################## ### Mock Backend for E2E Tests @@ -1213,6 +1216,12 @@ build-mock-backend: protogen-go clean-mock-backend: rm -f tests/e2e/mock-backend/mock-backend +build-cloud-proxy-backend: protogen-go + $(GOCMD) build -o tests/e2e/mock-backend/cloud-proxy ./backend/go/cloud-proxy + +clean-cloud-proxy-backend: + rm -f tests/e2e/mock-backend/cloud-proxy + ######################################################## ### UI E2E Test Server ######################################################## diff --git a/backend/backend.proto b/backend/backend.proto index bf07f3bd4..8a0c8e696 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -37,6 +37,22 @@ service Backend { rpc Rerank(RerankRequest) returns (RerankResult) {} + // TokenClassify runs a token-classification (NER) model on the + // supplied text and returns each detected entity span. Used by the + // PII redactor's optional NER tier — the regex tier still handles + // formatted hits cheaply, while this catches names, locations, and + // other unformatted PII that regex misses. + rpc TokenClassify(TokenClassifyRequest) returns (TokenClassifyResponse) {} + + // Score evaluates the model's joint log-probability of each + // supplied candidate continuation given a shared prompt. The + // prompt's KV cache is computed once and reused across candidates. + // Used for routing-policy multi-label classification, reranking, + // calibrated confidence, and reward-model scoring — any task where + // the consumer wants the model's confidence in a pre-specified + // continuation rather than a generated one. 
+ rpc Score(ScoreRequest) returns (ScoreResponse) {} + rpc GetMetrics(MetricsRequest) returns (MetricsResponse); rpc VAD(VADRequest) returns (VADResponse) {} @@ -68,6 +84,23 @@ service Backend { rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {} rpc StopQuantization(QuantizationStopRequest) returns (Result) {} + // Forward proxies a raw HTTP request to an upstream provider. The + // cloud-proxy backend implements this for passthrough-mode model + // configs: the client wire format is preserved end-to-end (no + // translation through internal proto), which means new provider + // fields work the day they ship. Translation-mode proxies use the + // standard Predict/PredictStream RPCs instead. Backends that don't + // support this return UNIMPLEMENTED. + // + // The request is bidirectionally streamed so large bodies can flow + // without buffering. In practice the first ForwardRequest carries + // path, method, headers, and the initial body chunk; subsequent + // messages append body chunks. The first ForwardReply carries the + // upstream status and response headers; subsequent messages stream + // body chunks (SSE frames or chunked transfer). Cancellation of the + // gRPC context closes the upstream connection. + rpc Forward(stream ForwardRequest) returns (stream ForwardReply) {} + } // Define the empty request @@ -81,6 +114,76 @@ message MetricsResponse { int32 prompt_tokens_processed = 5; } +// TokenClassifyRequest carries the text to classify plus an optional +// score threshold. The transformers backend interprets threshold as +// the minimum confidence to include in the response; 0 = include all. +message TokenClassifyRequest { + string text = 1; + float threshold = 2; +} + +// TokenClassifyEntity is one detected entity span. Byte offsets are +// into the original UTF-8 text — start..end is a half-open range that +// addresses the substring corresponding to entity_group. 
+// +// entity_group follows HuggingFace's aggregated-tag convention (e.g. +// "PER", "LOC", "ORG", or a PII-specific label like "EMAIL" / +// "SSN" depending on the model). The redactor's per-pattern action +// map keys off this string. +message TokenClassifyEntity { + string entity_group = 1; + int32 start = 2; + int32 end = 3; + float score = 4; + string text = 5; +} + +message TokenClassifyResponse { + repeated TokenClassifyEntity entities = 1; +} + +// ScoreRequest carries one shared prompt and one or more continuations +// to score against it. The backend tokenises the prompt once and reuses +// the resulting KV cache across all candidates in this request. +message ScoreRequest { + string prompt = 1; + repeated string candidates = 2; + // Return per-token logprobs for each candidate when true. Default + // false to keep the wire response small; the joint log_prob field + // covers the common ranking case. + bool include_token_logprobs = 3; + // When true, the response also populates length_normalized_log_prob + // (joint log-prob divided by candidate token count). Useful when + // candidates differ in length and the consumer wants a per-token + // measure comparable across them (PMI-style scoring). + bool length_normalize = 4; +} + +// CandidateScore is one row in the ScoreResponse, matching by index +// the candidate in ScoreRequest.candidates. +message CandidateScore { + // Sum of log P(token_i | prompt, candidate_token_ #include #include +#include #include #include #include @@ -121,6 +122,40 @@ static std::string base64_encode_bytes(const unsigned char* data, size_t len) { bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model +// Score bypasses the slot loop (see the comment on Score below) so it +// must not run concurrently with any slot-loop RPC. 
These counters +// are a defence-in-depth tripwire — ModelConfig.Validate already +// rejects llama-cpp configs that mix score with chat/completion/ +// embeddings, so a healthy deployment never trips them. seq_cst is +// load-bearing for the increment-then-check pattern below. +static std::atomic slot_loop_inflight{0}; +static std::atomic score_inflight{0}; + +// Increment-then-check, not check-then-increment: two simultaneous +// racers both observe the other's increment and both abort cleanly. +// Reversed, both could see zero and proceed. +struct conflict_guard { + std::atomic& self; + conflict_guard(const char* rpc, std::atomic& self_, std::atomic& other, const char* other_name) + : self(self_) { + self.fetch_add(1, std::memory_order_seq_cst); + int o = other.load(std::memory_order_seq_cst); + if (o > 0) { + fprintf(stderr, + "FATAL: %s called with %s=%d. The llama-cpp backend cannot " + "service Score and slot-loop RPCs concurrently — Score " + "bypasses the slot loop and races the llama_context. 
Bind " + "Score-using features to a model dedicated to scoring " + "(known_usecases: [score] with no chat/completion/embeddings).\n", + rpc, other_name, o); + std::abort(); + } + } + ~conflict_guard() { + self.fetch_sub(1, std::memory_order_seq_cst); + } +}; + static std::function shutdown_handler; static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; @@ -1446,6 +1481,7 @@ public: if (params_base.model.path.empty()) { return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded"); } + conflict_guard guard("PredictStream", slot_loop_inflight, score_inflight, "score_inflight"); json data = parse_options(true, request, params_base, ctx_server.get_llama_context()); @@ -2205,6 +2241,7 @@ public: if (params_base.model.path.empty()) { return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded"); } + conflict_guard guard("Predict", slot_loop_inflight, score_inflight, "score_inflight"); json data = parse_options(true, request, params_base, ctx_server.get_llama_context()); data["stream"] = false; @@ -2963,6 +3000,7 @@ public: if (params_base.model.path.empty()) { return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded"); } + conflict_guard guard("Embedding", slot_loop_inflight, score_inflight, "score_inflight"); json body = parse_options(false, request, params_base, ctx_server.get_llama_context()); body["stream"] = false; @@ -3070,6 +3108,8 @@ public: return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array"); } + conflict_guard guard("Rerank", slot_loop_inflight, score_inflight, "score_inflight"); + // Create and queue the task auto rd = ctx_server.get_response_reader(); { @@ -3142,12 +3182,218 @@ public: return grpc::Status::OK; } + // Score returns the model's joint log-probability of each candidate + // continuation given a shared prompt. 
+ // + // WHY bypass the slot/task queue: upstream server_context exposes + // get_llama_context as "main thread only" and the slot loop's + // update_slots() owns the context whenever a task is in flight. + // No public synchronization primitive is available — so Score is + // unsafe to call concurrently with active generation through this + // backend. In practice routing-classifier calls happen before the + // request is routed to a generation backend, so the model used + // for Score is typically idle. Concurrent Score calls are + // serialised by a local mutex; KV-cache state is isolated behind + // a dedicated sequence ID cleared between candidates. + // + // A patch to server-context.cpp that adds SERVER_TASK_TYPE_SCORE + // and routes scoring through the slot loop would be the correct + // long-term fix; tracked as a follow-up. + // + // Perf TODO (measured: ~450 ms warm for 3 candidates on Arch- + // Router-1.5B Q4_K_M + Intel SYCL): the current loop re-decodes + // `prompt + candidate` from scratch for every candidate, throwing + // away the prompt's KV cache between iterations. A smarter + // version would: + // 1. Decode just the prompt once into score_seq_id. + // 2. Snapshot/cp that sequence (llama_memory_seq_cp) into a + // per-candidate sequence id. + // 3. For each candidate, decode only its tokens onto the copy + // (continuing from the saved prompt state), read logits. + // 4. llama_memory_seq_rm the copy. + // Estimated speedup: 3-candidate calls 450 ms -> ~150-200 ms, + // 6-candidate calls 630 ms -> ~220 ms. Single source-file change, + // no proto / Go-side changes needed. Worth doing once routing is + // wired into the middleware and Score is on the hot path of every + // chat request. 
+ grpc::Status Score(ServerContext* context, const backend::ScoreRequest* request, backend::ScoreResponse* response) override { + auto auth = checkAuth(context); + if (!auth.ok()) return auth; + if (params_base.model.path.empty()) { + return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded"); + } + if (request->candidates_size() == 0) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "candidates must be non-empty"); + } + + // Tripwire against the slot loop. Acquired before score_mutex + // so it fires even when this Score is queued behind another. + conflict_guard guard("Score", score_inflight, slot_loop_inflight, "slot_loop_inflight"); + + // Serialise concurrent Score calls. The slot loop is still + // free to race with us — see the class comment above. + static std::mutex score_mutex; + std::lock_guard score_lock(score_mutex); + + llama_context * lctx = ctx_server.get_llama_context(); + if (lctx == nullptr) { + return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "llama context unavailable (sleeping?)"); + } + const llama_vocab * vocab = ctx_server.impl->vocab; + const int32_t n_vocab = llama_vocab_n_tokens(vocab); + const int32_t n_ctx = llama_n_ctx(lctx); + llama_memory_t mem = llama_get_memory(lctx); + + // The KV-cache is sized to seq_to_stream.size() at load + // (typically equal to n_slots, often 1). Sequence IDs must + // be in [0, n_seq_max), so we can't pick a high-value + // "private" ID — we have to share with the slot. We clear + // the cache before AND after each candidate to keep + // scoring isolated from whatever state the slot held, and + // the static mutex above guarantees no other Score call is + // racing in the meantime. The slot loop is still free to + // race (see comment on this method) — Score must not run + // concurrently with generation through this backend. 
+ const llama_seq_id score_seq_id = 0; + llama_memory_seq_rm(mem, score_seq_id, -1, -1); + + // Tokenize the shared prompt once with add_special=true so + // BOS is prepended when the model requires it. parse_special + // keeps chat-template markers in the prompt intact. + const std::string prompt = request->prompt(); + std::vector prompt_tokens = common_tokenize(vocab, prompt, /*add_special=*/true, /*parse_special=*/true); + const int32_t prompt_len = (int32_t) prompt_tokens.size(); + + for (int ci = 0; ci < request->candidates_size(); ci++) { + const std::string & candidate_text = request->candidates(ci); + + // Re-tokenize prompt + candidate as a single string. BPE + // merges across the boundary can shift the tokenization + // versus tokenize(prompt) ++ tokenize(candidate), so we + // find the divergence point against prompt_tokens. + std::vector full_tokens = common_tokenize(vocab, prompt + candidate_text, /*add_special=*/true, /*parse_special=*/true); + int32_t divergence = prompt_len; + const int32_t min_len = std::min(prompt_len, (int32_t) full_tokens.size()); + for (int32_t i = 0; i < min_len; i++) { + if (prompt_tokens[i] != full_tokens[i]) { + divergence = i; + break; + } + } + const int32_t cand_len = (int32_t) full_tokens.size() - divergence; + backend::CandidateScore * cs = response->add_candidates(); + cs->set_num_tokens(cand_len); + if (cand_len <= 0) { + cs->set_log_prob(0.0); + if (request->length_normalize()) { + cs->set_length_normalized_log_prob(0.0); + } + continue; + } + if (divergence < 1) { + // Need at least one prior token (typically BOS) to + // predict the first candidate token's logit. Tokeniser + // models without BOS + an empty prompt fall in here. + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Score: prompt produced no leading tokens; need at least one (e.g. 
BOS) to predict candidate"); + } + if ((int32_t) full_tokens.size() > n_ctx) { + return grpc::Status(grpc::StatusCode::OUT_OF_RANGE, + "Score: prompt+candidate exceeds context size (got " + + std::to_string(full_tokens.size()) + ", n_ctx=" + std::to_string(n_ctx) + ")"); + } + + // Build a batch covering the entire prompt+candidate. We + // need logits at (divergence-1) onward — those are the + // predictions for each candidate token. + llama_batch batch = llama_batch_init((int32_t) full_tokens.size(), 0, 1); + for (int32_t i = 0; i < (int32_t) full_tokens.size(); i++) { + batch.token[i] = full_tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = score_seq_id; + // logits[i] is "do we want the prediction *for the + // next token*, computed from this position?" + // We want predictions for candidate tokens at + // positions divergence .. full_tokens.size()-1, which + // come from logits at positions (divergence-1) .. + // (full_tokens.size()-2). + bool need_logit = (i >= divergence - 1) && (i < (int32_t) full_tokens.size() - 1); + batch.logits[i] = need_logit ? 1 : 0; + } + batch.n_tokens = (int32_t) full_tokens.size(); + + // Decode the batch. If decode fails (e.g. KV slot + // exhaustion), surface as INTERNAL — the caller will + // typically fall back to a sampling-based classifier. + int decode_err = llama_decode(lctx, batch); + if (decode_err != 0) { + llama_batch_free(batch); + llama_memory_seq_rm(mem, score_seq_id, -1, -1); + return grpc::Status(grpc::StatusCode::INTERNAL, + "llama_decode failed during Score: " + std::to_string(decode_err)); + } + + // Sum log-probabilities of the actual candidate tokens. + double total_log_prob = 0.0; + for (int32_t k = 0; k < cand_len; k++) { + // The k-th candidate token sits at full_tokens index + // (divergence + k). Its predicting logit is at batch + // position (divergence + k - 1). 
+ int32_t logit_pos = divergence + k - 1; + const float * logits = llama_get_logits_ith(lctx, logit_pos); + if (logits == nullptr) { + llama_batch_free(batch); + llama_memory_seq_rm(mem, score_seq_id, -1, -1); + return grpc::Status(grpc::StatusCode::INTERNAL, + "llama_get_logits_ith returned null at position " + std::to_string(logit_pos)); + } + llama_token target_token = full_tokens[divergence + k]; + + // Compute log_softmax(logits)[target_token] with the + // max-subtraction stability trick. + float max_logit = logits[0]; + for (int32_t v = 1; v < n_vocab; v++) { + if (logits[v] > max_logit) max_logit = logits[v]; + } + double sum_exp = 0.0; + for (int32_t v = 0; v < n_vocab; v++) { + sum_exp += std::exp((double)(logits[v] - max_logit)); + } + double token_log_prob = (double)(logits[target_token] - max_logit) - std::log(sum_exp); + total_log_prob += token_log_prob; + + if (request->include_token_logprobs()) { + backend::TokenLogProb * tlp = cs->add_tokens(); + std::string piece = common_token_to_piece(lctx, target_token); + tlp->set_token(piece); + tlp->set_log_prob(token_log_prob); + } + } + + cs->set_log_prob(total_log_prob); + if (request->length_normalize() && cand_len > 0) { + cs->set_length_normalized_log_prob(total_log_prob / (double) cand_len); + } + + llama_batch_free(batch); + // Drop this candidate's KV-cache contribution so the next + // candidate starts from a clean state. Without this, the + // next decode would conflict at positions 0..N-1 for our + // sequence ID. 
+ llama_memory_seq_rm(mem, score_seq_id, -1, -1); + } + + return grpc::Status::OK; + } + grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override { auto auth = checkAuth(context); if (!auth.ok()) return auth; if (params_base.model.path.empty()) { return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded"); } + conflict_guard guard("TokenizeString", slot_loop_inflight, score_inflight, "score_inflight"); json body = parse_options(false, request, params_base, ctx_server.get_llama_context()); body["stream"] = false; @@ -3169,6 +3415,8 @@ public: grpc::Status GetMetrics(ServerContext* /*context*/, const backend::MetricsRequest* /*request*/, backend::MetricsResponse* response) override { + conflict_guard guard("GetMetrics", slot_loop_inflight, score_inflight, "score_inflight"); + // request slots data using task queue auto rd = ctx_server.get_response_reader(); int task_id = rd.queue_tasks.get_new_id(); diff --git a/backend/go/cloud-proxy/Makefile b/backend/go/cloud-proxy/Makefile new file mode 100644 index 000000000..7900905cd --- /dev/null +++ b/backend/go/cloud-proxy/Makefile @@ -0,0 +1,12 @@ +GOCMD=go + +cloud-proxy: + CGO_ENABLED=0 $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o cloud-proxy ./ + +package: + bash package.sh + +build: cloud-proxy package + +clean: + rm -f cloud-proxy diff --git a/backend/go/cloud-proxy/cloud_proxy_suite_test.go b/backend/go/cloud-proxy/cloud_proxy_suite_test.go new file mode 100644 index 000000000..e6c8bf322 --- /dev/null +++ b/backend/go/cloud-proxy/cloud_proxy_suite_test.go @@ -0,0 +1,16 @@ +package main + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// Ginkgo bootstrap. The other Test* functions in this package use +// raw testing.T and run independently; they coexist with Ginkgo +// specs registered via Describe / Context. 
+func TestCloudProxySpecs(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "cloud-proxy specs") +} diff --git a/backend/go/cloud-proxy/main.go b/backend/go/cloud-proxy/main.go new file mode 100644 index 000000000..7f75efb2a --- /dev/null +++ b/backend/go/cloud-proxy/main.go @@ -0,0 +1,39 @@ +package main + +// cloud-proxy is a LocalAI backend that forwards request traffic to an +// external HTTP provider (OpenAI, Anthropic, etc.). Two modes: +// +// - passthrough: serves the Forward RPC; the client wire format is +// preserved end-to-end, no translation. +// - translate: serves Predict/PredictStream; the backend converts +// internal proto to the provider's wire format. (Phases 5–6.) +// +// LoadModel reads UpstreamURL/Mode/Provider/key references from +// ProxyOptions and resolves the API key once at load time. + +import ( + "flag" + "os" + + grpc "github.com/mudler/LocalAI/pkg/grpc" + "github.com/mudler/xlog" + "golang.org/x/term" +) + +var addr = flag.String("addr", "localhost:50051", "the address to listen on") + +func main() { + // xlog's default handler emits ANSI color codes; that's fine for an + // interactive shell but unreadable when the backend's stdout is + // captured by LocalAI and tee'd to a log file. Force plain text when + // LOCALAI_LOG_FORMAT is unset and stdout isn't a terminal. + format := os.Getenv("LOCALAI_LOG_FORMAT") + if format == "" && !term.IsTerminal(int(os.Stdout.Fd())) { + format = xlog.TextFormat + } + xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(os.Getenv("LOCALAI_LOG_LEVEL")), format)) + flag.Parse() + if err := grpc.StartServer(*addr, NewCloudProxy()); err != nil { + panic(err) + } +} diff --git a/backend/go/cloud-proxy/package.sh b/backend/go/cloud-proxy/package.sh new file mode 100755 index 000000000..da86cd003 --- /dev/null +++ b/backend/go/cloud-proxy/package.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# Script to copy the cloud-proxy binary into the package dir for the +# final Dockerfile stage. 
Mirrors backend/go/local-store/package.sh — +# no extra runtime libs needed since the backend is pure Go. + +set -e + +CURDIR=$(dirname "$(realpath $0)") + +mkdir -p $CURDIR/package +cp -avf $CURDIR/cloud-proxy $CURDIR/package/ +cp -rfv $CURDIR/run.sh $CURDIR/package/ diff --git a/backend/go/cloud-proxy/passthrough_edge_test.go b/backend/go/cloud-proxy/passthrough_edge_test.go new file mode 100644 index 000000000..c727026f8 --- /dev/null +++ b/backend/go/cloud-proxy/passthrough_edge_test.go @@ -0,0 +1,270 @@ +package main + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strconv" + "sync" + + grpc "github.com/mudler/LocalAI/pkg/grpc" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("composeURL", func() { + // Upstream URL convention: gallery configs put the canonical path + // in upstream_url, so per-request Path is ignored. A bare-host + // upstream_url accepts the per-request path. 
+ DescribeTable("path resolution", + func(upstream, reqPath, want string) { + got, err := composeURL(upstream, reqPath) + Expect(err).NotTo(HaveOccurred()) + Expect(got).To(Equal(want)) + }, + Entry("full path wins", "https://api.openai.com/v1/chat/completions", "/v1/something-else", "https://api.openai.com/v1/chat/completions"), + Entry("bare host accepts path", "https://api.openai.com", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"), + Entry("root slash treated as bare", "https://api.openai.com/", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"), + Entry("bare host + empty path", "https://api.openai.com", "", "https://api.openai.com"), + ) + + It("returns an error on invalid upstream URL", func() { + _, err := composeURL("://garbage", "") + Expect(err).To(HaveOccurred()) + }) +}) + +var _ = Describe("applyAuthHeader", func() { + It("sets x-api-key and anthropic-version for Anthropic, no Authorization", func() { + req, _ := http.NewRequest("POST", "https://example.com", nil) + applyAuthHeader(req, providerAnthropic, "ant-key") + Expect(req.Header.Get("x-api-key")).To(Equal("ant-key")) + Expect(req.Header.Get("anthropic-version")).NotTo(BeEmpty()) + Expect(req.Header.Get("Authorization")).To(BeEmpty(), "Authorization must not leak on Anthropic backend") + }) + + It("sets Bearer Authorization for OpenAI, no x-api-key", func() { + req, _ := http.NewRequest("POST", "https://example.com", nil) + applyAuthHeader(req, providerOpenAI, "sk-key") + Expect(req.Header.Get("Authorization")).To(Equal("Bearer sk-key")) + Expect(req.Header.Get("x-api-key")).To(BeEmpty(), "x-api-key must not leak on OpenAI backend") + }) + + It("defaults to Bearer when provider is empty", func() { + // Passthrough mode often has provider == "" because the operator + // doesn't claim a specific upstream wire format. Most providers + // (including OpenAI-compatible ones) accept Bearer, so default to it. 
+ req, _ := http.NewRequest("POST", "https://example.com", nil) + applyAuthHeader(req, "", "some-key") + Expect(req.Header.Get("Authorization")).To(Equal("Bearer some-key")) + }) + + It("preserves an existing anthropic-version header", func() { + // If the client supplied anthropic-version (rare but legitimate + // for an upstream pinned to a specific date), the proxy must not + // clobber it. + req, _ := http.NewRequest("POST", "https://example.com", nil) + req.Header.Set("anthropic-version", "2024-10-01") + applyAuthHeader(req, providerAnthropic, "k") + Expect(req.Header.Get("anthropic-version")).To(Equal("2024-10-01")) + }) +}) + +var _ = Describe("isHopByHopHeader", func() { + DescribeTable("hop-by-hop classification", + func(header string, want bool) { + Expect(isHopByHopHeader(header)).To(Equal(want)) + }, + Entry("Connection is hop-by-hop", "Connection", true), + Entry("Keep-Alive is hop-by-hop", "Keep-Alive", true), + Entry("Proxy-Connection is hop-by-hop", "Proxy-Connection", true), + Entry("Transfer-Encoding is hop-by-hop", "Transfer-Encoding", true), + Entry("TE is hop-by-hop", "TE", true), + Entry("Trailer is hop-by-hop", "Trailer", true), + Entry("Upgrade is hop-by-hop", "Upgrade", true), + Entry("Host is hop-by-hop", "Host", true), + Entry("Content-Length is hop-by-hop", "Content-Length", true), + // Case-insensitive — RFC 7230 doesn't constrain header case. + Entry("lowercase connection is hop-by-hop", "connection", true), + Entry("uppercase HOST is hop-by-hop", "HOST", true), + // Non hop-by-hop — must NOT be stripped. 
+ Entry("Authorization is end-to-end", "Authorization", false), + Entry("Content-Type is end-to-end", "Content-Type", false), + Entry("Accept is end-to-end", "Accept", false), + Entry("X-Custom is end-to-end", "X-Custom", false), + ) +}) + +var _ = Describe("Forward", func() { + It("strips hop-by-hop and Connection headers before upstream, preserves custom headers", func() { + gotConnection := make(chan string, 1) + gotXCustom := make(chan string, 1) + gotHost := make(chan string, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotConnection <- r.Header.Get("Connection") + gotXCustom <- r.Header.Get("X-Custom") + gotHost <- r.Header.Get("Host") + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + cp := NewCloudProxy() + Expect(cp.Load(&pb.ModelOptions{ + Proxy: &pb.ProxyOptions{ + UpstreamUrl: upstream.URL, + Mode: modePassthrough, + }, + })).To(Succeed()) + + addr := "test://forward-hopbyhop" + grpc.Provide(addr, cp) + c := grpc.NewClient(addr, true, nil, false) + stream, err := c.Forward(context.Background()) + Expect(err).NotTo(HaveOccurred()) + Expect(stream.Send(&pb.ForwardRequest{ + Path: "/v1/chat/completions", + Method: "POST", + Headers: []*pb.ForwardHeader{ + {Name: "Connection", Value: "keep-alive"}, + {Name: "Host", Value: "spoofed.example.com"}, + {Name: "X-Custom", Value: "preserved"}, + }, + })).To(Succeed()) + Expect(stream.CloseSend()).To(Succeed()) + _, _ = stream.Recv() + for { + if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil { + break + } + } + + Expect(<-gotConnection).To(BeEmpty(), "Connection must not leak to upstream") + Expect(<-gotHost).NotTo(Equal("spoofed.example.com"), "Host header must not be spoofed through") + Expect(<-gotXCustom).To(Equal("preserved"), "X-Custom header must survive") + }) + + It("replaces caller-supplied Authorization with the configured key", func() { + // The proxy must overwrite a client-supplied Authorization header + // so a 
downstream caller can't smuggle stale or wrong credentials. + gotAuth := make(chan string, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth <- r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + GinkgoT().Setenv("CLOUD_PROXY_AUTH_REPLACE_KEY", "sk-real") + + cp := NewCloudProxy() + Expect(cp.Load(&pb.ModelOptions{ + Proxy: &pb.ProxyOptions{ + UpstreamUrl: upstream.URL, + Mode: modePassthrough, + ApiKeyEnv: "CLOUD_PROXY_AUTH_REPLACE_KEY", + }, + })).To(Succeed()) + + addr := "test://forward-replaces-auth" + grpc.Provide(addr, cp) + c := grpc.NewClient(addr, true, nil, false) + stream, err := c.Forward(context.Background()) + Expect(err).NotTo(HaveOccurred()) + Expect(stream.Send(&pb.ForwardRequest{ + Path: "/v1/chat/completions", + Method: "POST", + Headers: []*pb.ForwardHeader{ + // Client-supplied Authorization with the wrong scheme / key. + {Name: "Authorization", Value: "Basic Zm9vOmJhcg=="}, + }, + })).To(Succeed()) + Expect(stream.CloseSend()).To(Succeed()) + _, _ = stream.Recv() + for { + if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil { + break + } + } + + Expect(<-gotAuth).To(Equal("Bearer sk-real"), "caller-supplied Basic header must be replaced") + }) + + It("handles concurrent calls without interference", func() { + // CloudProxy explicitly omits base.SingleThread — independent + // Forward streams must not block each other or leak state. 
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(body) + })) + defer upstream.Close() + + cp := NewCloudProxy() + Expect(cp.Load(&pb.ModelOptions{ + Proxy: &pb.ProxyOptions{ + UpstreamUrl: upstream.URL, + Mode: modePassthrough, + }, + })).To(Succeed()) + addr := "test://forward-concurrent" + grpc.Provide(addr, cp) + c := grpc.NewClient(addr, true, nil, false) + + const N = 8 + var wg sync.WaitGroup + errs := make(chan error, N) + for i := 0; i < N; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + stream, err := c.Forward(context.Background()) + if err != nil { + errs <- err + return + } + payload := "request-" + string(rune('A'+idx)) + if err := stream.Send(&pb.ForwardRequest{ + Path: "/v1/chat/completions", + Method: "POST", + BodyChunk: []byte(payload), + }); err != nil { + errs <- err + return + } + _ = stream.CloseSend() + _, _ = stream.Recv() + var body []byte + for { + r, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + errs <- err + return + } + body = append(body, r.GetBodyChunk()...) 
+ } + if string(body) != payload { + errs <- &echoMismatch{want: payload, got: string(body)} + } + }(i) + } + wg.Wait() + close(errs) + var collected []error + for err := range errs { + collected = append(collected, err) + } + Expect(collected).To(BeEmpty(), "no concurrent Forward call should fail") + }) +}) + +type echoMismatch struct{ want, got string } + +func (e *echoMismatch) Error() string { + return "echo mismatch: want " + strconv.Quote(e.want) + " got " + strconv.Quote(e.got) +} diff --git a/backend/go/cloud-proxy/provider_anthropic.go b/backend/go/cloud-proxy/provider_anthropic.go new file mode 100644 index 000000000..d8382d454 --- /dev/null +++ b/backend/go/cloud-proxy/provider_anthropic.go @@ -0,0 +1,508 @@ +package main + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/xlog" +) + +// Anthropic Messages API wire-format types. Narrowed to what translate +// mode preserves through the Reply proto: text + tool_use blocks + +// usage tokens. Image blocks, prompt caching, metadata, and stop +// sequence metadata are not modelled — passthrough mode covers those. +// +// Notable differences from OpenAI: +// - max_tokens is REQUIRED. Anthropic 400s without it. +// - Roles are user/assistant only — system messages move to a +// top-level `system` string field. +// - Streaming SSE uses event: lines alongside data: lines. The +// events we care about: content_block_start (carries tool_use +// init: id + name), content_block_delta (text_delta with text; +// input_json_delta with partial_json for tool arguments), and +// message_stop (terminates the stream). Others are ignored. 

// anthropicRequest is the JSON body that buildAnthropicRequest marshals
// for the upstream Messages endpoint. Every optional field carries
// omitempty so an unset value is left off the wire.
type anthropicRequest struct {
	Model string `json:"model"`
	// MaxTokens has no omitempty: it is always serialised.
	MaxTokens int32 `json:"max_tokens"`
	// System is a single top-level string rather than a message with a
	// "system" role.
	System   string             `json:"system,omitempty"`
	Messages []anthropicMessage `json:"messages"`
	Stream   bool               `json:"stream,omitempty"`
	// Temperature and TopP are pointers so that a nil value is omitted
	// from the marshalled body.
	Temperature   *float64             `json:"temperature,omitempty"`
	TopP          *float64             `json:"top_p,omitempty"`
	StopSequences []string             `json:"stop_sequences,omitempty"`
	Tools         []anthropicTool      `json:"tools,omitempty"`
	ToolChoice    *anthropicToolChoice `json:"tool_choice,omitempty"`
}

// Content is `any` because Anthropic accepts a bare string OR a
// list of content blocks. Use the string form for plain user/
// assistant turns; switch to []anthropicContentBlock when the
// turn needs tool_use (assistant) or tool_result (user) blocks.
type anthropicMessage struct {
	Role    string `json:"role"`
	Content any    `json:"content"`
}

// anthropicTool is one entry of the request's `tools` array.
type anthropicTool struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	// InputSchema is kept as raw JSON so the caller-supplied schema is
	// forwarded without being re-encoded.
	InputSchema json.RawMessage `json:"input_schema"`
}

// anthropicToolChoice mirrors the four shapes Anthropic accepts:
// {"type":"auto"} | {"type":"any"} | {"type":"tool","name":"X"} |
// {"type":"none"} (newer models). OpenAI's "auto"/"none"/
// "required"/{"function":{"name":"X"}} all map here.
type anthropicToolChoice struct {
	Type string `json:"type"`
	// Name is only populated for the {"type":"tool"} shape.
	Name string `json:"name,omitempty"`
}

// anthropicContentBlock is the union shape used both for response
// blocks (text/tool_use we read off the wire) and outbound request
// blocks (tool_use/tool_result we emit in the conversation history).
// Anthropic encodes tool calls inline rather than as a separate field,
// so we walk Content[] looking for type=="tool_use" on responses and
// produce equivalent blocks when serialising prior-turn tool calls.
type anthropicContentBlock struct {
	// Type discriminates the block; this file uses "text", "tool_use"
	// and "tool_result".
	Type string `json:"type"`
	// Text is the payload of a "text" block.
	Text string `json:"text,omitempty"`
	// ID, Name and Input describe a "tool_use" block.
	ID    string          `json:"id,omitempty"`
	Name  string          `json:"name,omitempty"`
	Input json.RawMessage `json:"input,omitempty"`
	// Tool-result block fields. tool_result uses `content` (not
	// `text`) and pairs with `tool_use_id`; modelling them as
	// distinct fields avoids ambiguity at marshal time.
	ToolUseID     string `json:"tool_use_id,omitempty"`
	ResultContent string `json:"content,omitempty"`
}

// anthropicResponse is the non-streaming response body; it is also
// embedded in stream events through anthropicStreamEvent.Message.
type anthropicResponse struct {
	ID      string                  `json:"id"`
	Type    string                  `json:"type"`
	Role    string                  `json:"role"`
	Content []anthropicContentBlock `json:"content"`
	Model   string                  `json:"model"`
	Usage   *anthropicUsage         `json:"usage,omitempty"`
}

// anthropicUsage carries the token counts reported by the upstream.
type anthropicUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

// anthropicStreamEvent is the union shape used for every event type we
// process. Type discriminates; only the matching fields are populated.
// content_block_start carries ContentBlock (with id/name for tool_use);
// content_block_delta carries Delta (text or partial_json).
type anthropicStreamEvent struct {
	Type         string                 `json:"type"`
	Index        int                    `json:"index,omitempty"`
	ContentBlock *anthropicContentBlock `json:"content_block,omitempty"`
	Delta        *anthropicStreamDelta  `json:"delta,omitempty"`
	Message      *anthropicResponse     `json:"message,omitempty"`
	Usage        *anthropicUsage        `json:"usage,omitempty"`
}

// anthropicStreamDelta is the `delta` payload of a streaming event:
// Text for text deltas, PartialJSON for tool-argument fragments.
type anthropicStreamDelta struct {
	Type        string `json:"type,omitempty"`
	Text        string `json:"text,omitempty"`
	PartialJSON string `json:"partial_json,omitempty"`
}

// Anthropic requires max_tokens. If the caller didn't set it, use a
// generous-but-bounded default so the request doesn't 400.
const anthropicDefaultMaxTokens int32 = 4096

// anthropicToolChoiceNone is the tool_choice type that
// buildAnthropicRequest collapses into a request without tools.
const anthropicToolChoiceNone = "none"

// Reused JSON-Schema defaults for malformed inputs.
Anthropic requires +// input_schema to be a JSON object and tool_use.input to be a JSON +// object; clients that omit them must not 400 the entire request. +var ( + emptyJSONObject = json.RawMessage(`{}`) + emptyObjectSchema = json.RawMessage(`{"type":"object","properties":{}}`) +) + +func buildAnthropicRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) { + req := anthropicRequest{ + Model: modelName(cfg, opts), + MaxTokens: opts.GetTokens(), + Stream: stream, + StopSequences: opts.GetStopPrompts(), + } + if req.MaxTokens <= 0 { + req.MaxTokens = anthropicDefaultMaxTokens + } + // Newer Anthropic models 400 when both temperature and top_p are + // set ("`temperature` and `top_p` cannot both be specified for + // this model. Please use only one.") even though their docs only + // "recommend" picking one. The OpenAI-compatible chat UI almost + // always sends both with default values, so prefer temperature + // and drop top_p when both are present. + if t := opts.GetTemperature(); t != 0 { + v := float64(t) + req.Temperature = &v + } else if t := opts.GetTopP(); t != 0 { + v := float64(t) + req.TopP = &v + } + + req.Tools = convertOpenAITools(opts.GetTools()) + req.ToolChoice = convertOpenAIToolChoice(opts.GetToolChoice()) + // Anthropic rejects tool_choice without tools and older models + // don't accept {"type":"none"} — collapse to a no-tools request. 
+ if req.ToolChoice != nil && req.ToolChoice.Type == anthropicToolChoiceNone { + req.Tools, req.ToolChoice = nil, nil + } + + var systemParts []string + for _, m := range opts.GetMessages() { + role := m.GetRole() + if role == "system" { + if c := m.GetContent(); c != "" { + systemParts = append(systemParts, c) + } + continue + } + switch role { + case "user": + req.Messages = append(req.Messages, anthropicMessage{ + Role: "user", + Content: m.GetContent(), + }) + case "assistant": + if blocks := assistantBlocks(m); blocks != nil { + req.Messages = append(req.Messages, anthropicMessage{Role: "assistant", Content: blocks}) + continue + } + req.Messages = append(req.Messages, anthropicMessage{ + Role: "assistant", + Content: m.GetContent(), + }) + case "tool", "function": + req.Messages = appendToolResult(req.Messages, anthropicContentBlock{ + Type: "tool_result", + ToolUseID: m.GetToolCallId(), + ResultContent: m.GetContent(), + }) + } + } + req.System = strings.Join(systemParts, "\n\n") + + if len(req.Messages) == 0 && opts.GetPrompt() != "" { + req.Messages = []anthropicMessage{{Role: "user", Content: opts.GetPrompt()}} + } + + return json.Marshal(req) +} + +// appendToolResult appends a tool_result block as a user message, +// merging into a preceding user message that already carries blocks. +// Anthropic concatenates consecutive same-role messages on its end, +// but explicit merging keeps the body smaller and the conversation +// strictly alternating — which some upstream filters require. 
+func appendToolResult(msgs []anthropicMessage, block anthropicContentBlock) []anthropicMessage { + if n := len(msgs); n > 0 && msgs[n-1].Role == "user" { + if existing, ok := msgs[n-1].Content.([]anthropicContentBlock); ok { + msgs[n-1].Content = append(existing, block) + return msgs + } + } + return append(msgs, anthropicMessage{ + Role: "user", + Content: []anthropicContentBlock{block}, + }) +} + +func convertOpenAITools(toolsJSON string) []anthropicTool { + if toolsJSON == "" { + return nil + } + var raw []openAITool + if err := json.Unmarshal([]byte(toolsJSON), &raw); err != nil { + xlog.Warn("cloud-proxy: anthropic translate: unparseable tools JSON, dropping", "error", err) + return nil + } + tools := make([]anthropicTool, 0, len(raw)) + for _, t := range raw { + if t.Function.Name == "" { + continue + } + schema := t.Function.Parameters + if len(schema) == 0 { + schema = emptyObjectSchema + } + tools = append(tools, anthropicTool{ + Name: t.Function.Name, + Description: t.Function.Description, + InputSchema: schema, + }) + } + return tools +} + +// convertOpenAIToolChoice accepts the spec form +// ({type:function, function:{name:X}}) and the flat legacy form +// ({type:function, name:X}) some clients send. Unknown object shapes +// are warned and dropped rather than silently treated as auto. 
+func convertOpenAIToolChoice(toolChoiceJSON string) *anthropicToolChoice { + if toolChoiceJSON == "" { + return nil + } + var asString string + if err := json.Unmarshal([]byte(toolChoiceJSON), &asString); err == nil { + switch asString { + case "auto": + return &anthropicToolChoice{Type: "auto"} + case "none": + return &anthropicToolChoice{Type: anthropicToolChoiceNone} + case "required": + return &anthropicToolChoice{Type: "any"} + } + return nil + } + var asObj struct { + Type string `json:"type"` + Name string `json:"name"` + Function struct { + Name string `json:"name"` + } `json:"function"` + } + if err := json.Unmarshal([]byte(toolChoiceJSON), &asObj); err != nil { + xlog.Warn("cloud-proxy: anthropic translate: unparseable tool_choice, dropping", "error", err) + return nil + } + if name := asObj.Function.Name; name != "" { + return &anthropicToolChoice{Type: "tool", Name: name} + } + if asObj.Name != "" { + return &anthropicToolChoice{Type: "tool", Name: asObj.Name} + } + xlog.Warn("cloud-proxy: anthropic translate: unrecognised tool_choice shape, dropping", "shape", toolChoiceJSON) + return nil +} + +// openAITool mirrors pkg/functions.Tool but keeps Parameters as +// json.RawMessage so the input_schema passes through verbatim — no +// re-marshal cost, no fidelity loss on exotic schemas. 
+type openAITool struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters json.RawMessage `json:"parameters"` + } `json:"function"` +} + +func assistantBlocks(m *pb.Message) []anthropicContentBlock { + toolCallsJSON := m.GetToolCalls() + if toolCallsJSON == "" { + return nil + } + var toolCalls []openAIToolCall + if err := json.Unmarshal([]byte(toolCallsJSON), &toolCalls); err != nil || len(toolCalls) == 0 { + return nil + } + blocks := make([]anthropicContentBlock, 0, len(toolCalls)+1) + if text := m.GetContent(); text != "" { + blocks = append(blocks, anthropicContentBlock{Type: "text", Text: text}) + } + for _, tc := range toolCalls { + // OpenAI's arguments are a JSON-encoded string; pass through + // as RawMessage so a non-JSON string from a poorly-formed + // local model doesn't crash the marshaller downstream. + args := json.RawMessage(tc.Function.Arguments) + if len(args) == 0 { + args = emptyJSONObject + } + blocks = append(blocks, anthropicContentBlock{ + Type: "tool_use", + ID: tc.ID, + Name: tc.Function.Name, + Input: args, + }) + } + return blocks +} + +// doAnthropicRequest is the Anthropic counterpart of doOpenAIRequest. +// applyAuthHeader sets x-api-key and anthropic-version when provider +// is anthropic, so this method doesn't need to duplicate that. 
+func (c *CloudProxy) doAnthropicRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("cloud-proxy: build request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "*/*") + if cfg.apiKey != "" { + applyAuthHeader(req, cfg.provider, cfg.apiKey) + } + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err) + } + return resp, nil +} + +// predictAnthropicRich returns the full Reply: joined text from all +// text blocks, tool_use blocks mapped to ToolCallDelta, and usage +// tokens. +func (c *CloudProxy) predictAnthropicRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) { + body, err := buildAnthropicRequest(opts, cfg, false) + if err != nil { + return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err) + } + resp, err := c.doAnthropicRequest(ctx, cfg, body) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody)) + } + + var parsed anthropicResponse + if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { + return nil, fmt.Errorf("cloud-proxy: decode response: %w", err) + } + + reply := &pb.Reply{} + if parsed.Usage != nil { + reply.PromptTokens = int32(parsed.Usage.InputTokens) + reply.Tokens = int32(parsed.Usage.OutputTokens) + } + + var content strings.Builder + var toolCalls []*pb.ToolCallDelta + toolIdx := 0 + for _, b := range parsed.Content { + switch b.Type { + case "text": + content.WriteString(b.Text) + case "tool_use": + // Input is a structured JSON object; we serialise to a + // string so it fits the OpenAI-shaped 
arguments field + // downstream consumers expect. + args := "" + if len(b.Input) > 0 { + args = string(b.Input) + } + toolCalls = append(toolCalls, newToolCallDelta(toolIdx, b.ID, b.Name, args)) + toolIdx++ + } + } + reply.Message = []byte(content.String()) + if len(toolCalls) > 0 { + reply.ChatDeltas = []*pb.ChatDelta{{ToolCalls: toolCalls}} + } + return reply, nil +} + +// predictAnthropicStreamRich streams Reply chunks from Anthropic's SSE. +// Three event types matter: content_block_start (initialises tool_use +// id+name), content_block_delta (carries text or input_json_delta), +// message_stop (terminates). The block index from the wire feeds +// straight into ToolCallDelta.Index so downstream consumers can +// reassemble multiple parallel tool calls. +func (c *CloudProxy) predictAnthropicStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error { + body, err := buildAnthropicRequest(opts, cfg, true) + if err != nil { + return fmt.Errorf("cloud-proxy: marshal request: %w", err) + } + resp, err := c.doAnthropicRequest(ctx, cfg, body) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 1<<20) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data:") { + continue + } + payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payload == "" { + continue + } + var ev anthropicStreamEvent + if err := json.Unmarshal([]byte(payload), &ev); err != nil { + xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err) + continue + } + switch ev.Type { + case "content_block_start": + // tool_use blocks announce id + name here; arguments arrive + // in subsequent input_json_delta events. 
Emit a Reply with + // just the tool_call init fields so consumers can allocate + // a slot at this index. + if ev.ContentBlock != nil && ev.ContentBlock.Type == "tool_use" { + if !sendReply(ctx, results, &pb.Reply{ + ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{ + newToolCallDelta(ev.Index, ev.ContentBlock.ID, ev.ContentBlock.Name, ""), + }}}, + }) { + return ctx.Err() + } + } + case "content_block_delta": + if ev.Delta == nil { + continue + } + switch ev.Delta.Type { + case "text_delta": + if ev.Delta.Text == "" { + continue + } + if !sendReply(ctx, results, &pb.Reply{ + Message: []byte(ev.Delta.Text), + ChatDeltas: []*pb.ChatDelta{{Content: ev.Delta.Text}}, + }) { + return ctx.Err() + } + case "input_json_delta": + if ev.Delta.PartialJSON == "" { + continue + } + if !sendReply(ctx, results, &pb.Reply{ + ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{ + newToolCallDelta(ev.Index, "", "", ev.Delta.PartialJSON), + }}}, + }) { + return ctx.Err() + } + } + case "message_delta": + // Anthropic sends final usage in message_delta.usage. Emit + // a usage-only Reply so the consumer can record totals. + if ev.Usage != nil { + if !sendReply(ctx, results, &pb.Reply{ + Tokens: int32(ev.Usage.OutputTokens), + }) { + return ctx.Err() + } + } + case "message_stop": + return nil + } + } + return scanner.Err() +} diff --git a/backend/go/cloud-proxy/provider_anthropic_test.go b/backend/go/cloud-proxy/provider_anthropic_test.go new file mode 100644 index 000000000..d46db7b12 --- /dev/null +++ b/backend/go/cloud-proxy/provider_anthropic_test.go @@ -0,0 +1,334 @@ +package main + +import ( + "encoding/json" + "io" + "math" + "net/http" + "net/http/httptest" + "strings" + "testing" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + . "github.com/onsi/gomega" +) + +// fakeAnthropicUpstream mirrors fakeOpenAIUpstream but decodes the +// request body as an anthropicRequest so tests can assert on the +// translated wire shape (system field, max_tokens, etc.). 
+func fakeAnthropicUpstream(t *testing.T, handler func(req anthropicRequest) (status int, body string, contentType string)) (*httptest.Server, *anthropicRequest) { + t.Helper() + var captured anthropicRequest + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(raw, &captured) + status, body, ct := handler(captured) + w.Header().Set("Content-Type", ct) + w.WriteHeader(status) + _, _ = io.WriteString(w, body) + })) + return srv, &captured +} + +func newAnthropicTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy { + t.Helper() + g := NewWithT(t) + t.Setenv("CLOUD_PROXY_ANTHROPIC_FAKE", "sk-ant-fake") + cp := NewCloudProxy() + err := cp.Load(&pb.ModelOptions{ + Model: "claude-local", + Proxy: &pb.ProxyOptions{ + UpstreamUrl: upstreamURL, + Mode: modeTranslate, + Provider: providerAnthropic, + ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_FAKE", + UpstreamModel: "claude-3-5-sonnet-20241022", + }, + }) + g.Expect(err).NotTo(HaveOccurred()) + return cp +} + +func TestPredict_Anthropic_BasicMessages(t *testing.T) { + g := NewWithT(t) + srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, `{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hi there"}],"model":"claude-3-5-sonnet-20241022","usage":{"input_tokens":5,"output_tokens":2}}`, "application/json" + }) + defer srv.Close() + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + + got, err := cp.Predict(&pb.PredictOptions{ + Messages: []*pb.Message{ + {Role: "system", Content: "be brief"}, + {Role: "user", Content: "hello"}, + }, + Temperature: 0.5, + TopP: 0.9, + Tokens: 32, + }) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(got).To(Equal("hi there")) + + g.Expect(captured.Model).To(Equal("claude-3-5-sonnet-20241022")) + // System message must be hoisted out of Messages into top-level field. 
+ g.Expect(captured.System).To(Equal("be brief")) + g.Expect(captured.Messages).To(HaveLen(1)) + g.Expect(captured.Messages[0].Role).To(Equal("user")) + g.Expect(captured.MaxTokens).To(Equal(int32(32))) + g.Expect(captured.Temperature).NotTo(BeNil()) + g.Expect(*captured.Temperature).To(Equal(0.5)) + // Anthropic 400s when both temperature and top_p are set; the + // translator must prefer temperature and drop top_p. + g.Expect(captured.TopP).To(BeNil()) + g.Expect(captured.Stream).To(BeFalse()) +} + +// When only top_p is set, it should be forwarded. +func TestPredict_Anthropic_TopPOnly(t *testing.T) { + g := NewWithT(t) + srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json" + }) + defer srv.Close() + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + + _, err := cp.Predict(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "hello"}}, + TopP: 0.9, + Tokens: 16, + }) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(captured.Temperature).To(BeNil()) + // PredictOptions.TopP is float32 on the wire; the translator widens + // to float64 so 0.9 round-trips as 0.8999999761581421… — compare + // with a small tolerance rather than exact equality. + g.Expect(captured.TopP).NotTo(BeNil()) + g.Expect(math.Abs(*captured.TopP - 0.9)).To(BeNumerically("<=", 1e-6)) +} + +func TestPredict_Anthropic_DefaultsMaxTokens(t *testing.T) { + g := NewWithT(t) + // Anthropic 400s without max_tokens. The translator must default + // it when the caller doesn't supply Tokens. 
+ srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json" + }) + defer srv.Close() + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + + _, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}}) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(captured.MaxTokens).To(Equal(anthropicDefaultMaxTokens)) +} + +func TestPredict_Anthropic_PromptFallback(t *testing.T) { + g := NewWithT(t) + srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json" + }) + defer srv.Close() + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + + _, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?", Tokens: 16}) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(captured.Messages).To(HaveLen(1)) + g.Expect(captured.Messages[0].Role).To(Equal("user")) + g.Expect(captured.Messages[0].Content).To(Equal("what time is it?")) +} + +func TestPredict_Anthropic_ConcatenatesContentBlocks(t *testing.T) { + g := NewWithT(t) + // Anthropic may return multiple text blocks; the translator joins + // them so the Predict() string return is the full assistant message. 
+ srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, `{"content":[{"type":"text","text":"hello "},{"type":"text","text":"world"}]}`, "application/json" + }) + defer srv.Close() + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + + got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16}) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(got).To(Equal("hello world")) +} + +func TestPredict_Anthropic_UpstreamError(t *testing.T) { + g := NewWithT(t) + srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 401, `{"error":{"type":"authentication_error","message":"bad key"}}`, "application/json" + }) + defer srv.Close() + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + + _, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16}) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("401")) +} + +func TestPredictStream_Anthropic_StreamsTextDeltas(t *testing.T) { + g := NewWithT(t) + // Real Anthropic SSE has event: lines + data: lines. The translator + // only needs the data: payload; only content_block_delta with + // delta.type=text_delta carries content. message_stop ends. 
+ frames := []string{ + "event: message_start\ndata: {\"type\":\"message_start\"}\n\n", + "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0}\n\n", + "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\n", + "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" \"}}\n\n", + "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"world\"}}\n\n", + "event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n", + "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", + } + body := strings.Join(frames, "") + + srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, body, "text/event-stream" + }) + defer srv.Close() + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + + results := make(chan string, 8) + done := make(chan error, 1) + go func() { + done <- cp.PredictStream(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "hi"}}, + Tokens: 16, + }, results) + }() + + var got []string + for s := range results { + got = append(got, s) + } + err := <-done + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(strings.Join(got, "")).To(Equal("hello world")) + g.Expect(captured.Stream).To(BeTrue()) +} + +func TestBuildAnthropic_TranslatesOpenAITools(t *testing.T) { + g := NewWithT(t) + srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json" + }) + defer srv.Close() + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + + tools := `[{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}}}]` + _, err := 
cp.Predict(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "weather in Paris?"}}, + Tools: tools, + ToolChoice: `"auto"`, + Tokens: 32, + }) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(captured.Tools).To(HaveLen(1)) + g.Expect(captured.Tools[0].Name).To(Equal("get_weather")) + g.Expect(captured.Tools[0].Description).To(Equal("Get weather")) + // input_schema must be the parameters object verbatim. + g.Expect(string(captured.Tools[0].InputSchema)).To(ContainSubstring(`"city"`)) + g.Expect(captured.ToolChoice).NotTo(BeNil()) + g.Expect(captured.ToolChoice.Type).To(Equal("auto")) +} + +func TestBuildAnthropic_ToolChoice_RequiredMapsToAny(t *testing.T) { + g := NewWithT(t) + srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, `{"content":[]}`, "application/json" + }) + defer srv.Close() + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + _, err := cp.Predict(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "x"}}, + Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`, + ToolChoice: `"required"`, + Tokens: 16, + }) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(captured.ToolChoice).NotTo(BeNil()) + g.Expect(captured.ToolChoice.Type).To(Equal("any")) +} + +func TestBuildAnthropic_ToolChoice_NoneDropsTools(t *testing.T) { + g := NewWithT(t) + srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, `{"content":[]}`, "application/json" + }) + defer srv.Close() + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + _, err := cp.Predict(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "x"}}, + Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`, + ToolChoice: `"none"`, + Tokens: 16, + }) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(captured.Tools).To(BeNil()) + g.Expect(captured.ToolChoice).To(BeNil()) +} + +func 
TestBuildAnthropic_ToolChoice_NamedFunction(t *testing.T) { + g := NewWithT(t) + srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, `{"content":[]}`, "application/json" + }) + defer srv.Close() + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + _, err := cp.Predict(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "x"}}, + Tools: `[{"type":"function","function":{"name":"weather","parameters":{"type":"object"}}}]`, + ToolChoice: `{"type":"function","function":{"name":"weather"}}`, + Tokens: 16, + }) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(captured.ToolChoice).NotTo(BeNil()) + g.Expect(captured.ToolChoice.Type).To(Equal("tool")) + g.Expect(captured.ToolChoice.Name).To(Equal("weather")) +} + +func TestBuildAnthropic_RoundTripsAssistantToolCalls(t *testing.T) { + g := NewWithT(t) + // LocalAI Assistant's second turn: the LLM previously emitted a + // tool_use, the server executed it, and the conversation now + // includes the assistant turn (with tool_calls) plus a tool-role + // result message. Both must convert to Anthropic block form. + srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json" + }) + defer srv.Close() + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + + tools := `[{"type":"function","function":{"name":"list_models","parameters":{"type":"object"}}}]` + toolCallsJSON := `[{"id":"call_abc","type":"function","function":{"name":"list_models","arguments":"{}"}}]` + _, err := cp.Predict(&pb.PredictOptions{ + Tools: tools, + Messages: []*pb.Message{ + {Role: "user", Content: "what models are installed?"}, + {Role: "assistant", Content: "", ToolCalls: toolCallsJSON}, + {Role: "tool", Content: `{"models":["a","b"]}`, ToolCallId: "call_abc"}, + }, + Tokens: 64, + }) + g.Expect(err).NotTo(HaveOccurred()) + + g.Expect(captured.Messages).To(HaveLen(3)) + // 1. 
user text — bare string + s, ok := captured.Messages[0].Content.(string) + g.Expect(ok).To(BeTrue()) + g.Expect(s).To(Equal("what models are installed?")) + // 2. assistant — must be a content-block list with one tool_use + // json.Unmarshal of `any` produces []any not []anthropicContentBlock. + blocks, ok := captured.Messages[1].Content.([]any) + g.Expect(ok).To(BeTrue()) + g.Expect(blocks).To(HaveLen(1)) + b0, _ := blocks[0].(map[string]any) + g.Expect(b0["type"]).To(Equal("tool_use")) + g.Expect(b0["id"]).To(Equal("call_abc")) + g.Expect(b0["name"]).To(Equal("list_models")) + // 3. tool → user with tool_result block + g.Expect(captured.Messages[2].Role).To(Equal("user")) + resBlocks, _ := captured.Messages[2].Content.([]any) + r0, _ := resBlocks[0].(map[string]any) + g.Expect(r0["type"]).To(Equal("tool_result")) + g.Expect(r0["tool_use_id"]).To(Equal("call_abc")) + g.Expect(r0["content"]).To(Equal(`{"models":["a","b"]}`)) +} diff --git a/backend/go/cloud-proxy/provider_edge_test.go b/backend/go/cloud-proxy/provider_edge_test.go new file mode 100644 index 000000000..5c6deb829 --- /dev/null +++ b/backend/go/cloud-proxy/provider_edge_test.go @@ -0,0 +1,119 @@ +package main + +import ( + "encoding/json" + "strings" + "testing" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + . "github.com/onsi/gomega" +) + +// Verify buildOpenAIRequest preserves caller-supplied tools and +// tool_choice as opaque JSON. PredictOptions carries them as strings; +// they must land in the outbound request body unchanged so the +// upstream sees the caller's intent verbatim. A regression here would +// silently disable function calling for translate-mode clients. 
+func TestBuildOpenAIRequest_ToolsAndToolChoicePassthrough(t *testing.T) { + g := NewWithT(t) + cfg := &proxyConfig{upstreamModel: "gpt-4o"} + toolsJSON := `[{"type":"function","function":{"name":"search","parameters":{"type":"object"}}}]` + choiceJSON := `{"type":"function","function":{"name":"search"}}` + + body, err := buildOpenAIRequest(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "find x"}}, + Tools: toolsJSON, + ToolChoice: choiceJSON, + }, cfg, false) + g.Expect(err).NotTo(HaveOccurred()) + + var decoded openAIRequest + err = json.Unmarshal(body, &decoded) + g.Expect(err).NotTo(HaveOccurred()) + // Compare the JSON-canonical form so whitespace differences are ignored. + gotTools, _ := json.Marshal(json.RawMessage(decoded.Tools)) + wantTools, _ := json.Marshal(json.RawMessage(toolsJSON)) + g.Expect(string(gotTools)).To(Equal(string(wantTools))) + gotChoice, _ := json.Marshal(json.RawMessage(decoded.ToolChoice)) + wantChoice, _ := json.Marshal(json.RawMessage(choiceJSON)) + g.Expect(string(gotChoice)).To(Equal(string(wantChoice))) +} + +// Garbage JSON in tools / tool_choice is silently dropped (omitted) +// rather than blowing up the request. Documents the parseRawJSON +// behaviour — operators shouldn't see hard failures from an upstream +// caller's mis-formatted tools field. +func TestBuildOpenAIRequest_InvalidToolsJSONDropped(t *testing.T) { + g := NewWithT(t) + cfg := &proxyConfig{upstreamModel: "gpt-4o"} + body, err := buildOpenAIRequest(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "x"}}, + Tools: "this is not json", + ToolChoice: "{also bad", + }, cfg, false) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(string(body)).NotTo(ContainSubstring("this is not json")) + g.Expect(string(body)).NotTo(ContainSubstring("{also bad")) +} + +// Anthropic empty content array yields an empty Reply (not an error). 
+// Mirrors how an upstream tool_use-only response might arrive — the +// content array can legitimately be empty in some edge cases. +func TestPredictRich_Anthropic_EmptyContent(t *testing.T) { + g := NewWithT(t) + srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, `{"id":"m1","type":"message","role":"assistant","content":[],"usage":{"input_tokens":3,"output_tokens":0}}`, "application/json" + }) + defer srv.Close() + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + + reply, err := cp.PredictRich(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "x"}}, + Tokens: 16, + }) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(string(reply.GetMessage())).To(Equal("")) + g.Expect(reply.GetChatDeltas()).To(HaveLen(0)) + g.Expect(reply.GetPromptTokens()).To(Equal(int32(3))) +} + +// A truncated / malformed SSE payload mid-stream should be tolerated: +// the malformed chunk gets skipped (xlog.Debug logged), valid chunks +// before AND after it still reach the channel. 
+func TestPredictStreamRich_OpenAI_TolerantOfBadChunks(t *testing.T) { + g := NewWithT(t) + body := strings.Join([]string{ + `data: {"choices":[{"index":0,"delta":{"content":"hello"}}]}`, + ``, + `data: this-is-not-json{{`, + ``, + `data: {"choices":[{"index":0,"delta":{"content":" world"}}]}`, + ``, + `data: [DONE]`, + ``, + }, "\n") + + srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) { + return 200, body, "text/event-stream" + }) + defer srv.Close() + cp := newTranslateCloudProxy(t, srv.URL) + + results := make(chan *pb.Reply, 8) + done := make(chan error, 1) + go func() { + done <- cp.PredictStreamRich(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "hi"}}, + }, results) + close(results) + }() + + var assembled strings.Builder + for reply := range results { + assembled.Write(reply.GetMessage()) + } + err := <-done + g.Expect(err).NotTo(HaveOccurred()) + // The good chunks before and after the malformed one both made it through. + g.Expect(assembled.String()).To(Equal("hello world")) +} diff --git a/backend/go/cloud-proxy/provider_openai.go b/backend/go/cloud-proxy/provider_openai.go new file mode 100644 index 000000000..d4911b6b7 --- /dev/null +++ b/backend/go/cloud-proxy/provider_openai.go @@ -0,0 +1,320 @@ +package main + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/xlog" +) + +// OpenAI Chat Completions wire-format types. Narrowed to the fields +// translate mode needs to preserve through the Reply proto: content, +// role, tool_calls (typed so we can map them to pb.ToolCallDelta), +// and sampling params copied verbatim from PredictOptions. +// +// Provider-specific extensions (logit_bias, function calling beyond +// tool_calls, etc.) are not modelled — passthrough mode covers callers +// that need full upstream fidelity. 
+ +type openAIRequest struct { + Model string `json:"model"` + Messages []openAIMessage `json:"messages"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + MaxTokens *int32 `json:"max_tokens,omitempty"` + Stop []string `json:"stop,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + Tools json.RawMessage `json:"tools,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` +} + +type openAIMessage struct { + Role string `json:"role"` + Content string `json:"content,omitempty"` + Name string `json:"name,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + ToolCalls []openAIToolCall `json:"tool_calls,omitempty"` +} + +// openAIToolCall covers both the non-streaming response shape (full +// id+function+arguments) and the streaming-delta shape (sparse fields, +// index assignment). The proto's ToolCallDelta absorbs both — name is +// set on first appearance, arguments arrive incrementally in streaming. 
+type openAIToolCall struct { + Index int `json:"index,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Function openAIFunctionCall `json:"function,omitempty"` +} + +type openAIFunctionCall struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` +} + +type openAIChoice struct { + Index int `json:"index"` + Message openAIMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type openAIResponse struct { + ID string `json:"id"` + Choices []openAIChoice `json:"choices"` + Usage *openAIUsage `json:"usage,omitempty"` +} + +type openAIStreamChoice struct { + Index int `json:"index"` + Delta struct { + Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` + ToolCalls []openAIToolCall `json:"tool_calls,omitempty"` + } `json:"delta"` + FinishReason string `json:"finish_reason,omitempty"` +} + +type openAIStreamChunk struct { + Choices []openAIStreamChoice `json:"choices"` + Usage *openAIUsage `json:"usage,omitempty"` +} + +type openAIUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// buildOpenAIRequest converts pb.PredictOptions into the OpenAI Chat +// Completions request body. Prefers Messages when non-empty; falls +// back to wrapping Prompt as a single user message so plain +// /completions-style calls still work in translate mode. 
+func buildOpenAIRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) { + req := openAIRequest{ + Model: modelName(cfg, opts), + Stream: stream, + Stop: opts.GetStopPrompts(), + Tools: parseRawJSON(opts.GetTools()), + ToolChoice: parseRawJSON(opts.GetToolChoice()), + } + if t := opts.GetTemperature(); t != 0 { + v := float64(t) + req.Temperature = &v + } + if t := opts.GetTopP(); t != 0 { + v := float64(t) + req.TopP = &v + } + if n := opts.GetTokens(); n > 0 { + req.MaxTokens = &n + } + if p := opts.GetFrequencyPenalty(); p != 0 { + v := float64(p) + req.FrequencyPenalty = &v + } + if p := opts.GetPresencePenalty(); p != 0 { + v := float64(p) + req.PresencePenalty = &v + } + + for _, m := range opts.GetMessages() { + msg := openAIMessage{ + Role: m.GetRole(), + Content: m.GetContent(), + Name: m.GetName(), + ToolCallID: m.GetToolCallId(), + } + // Pre-existing tool_calls arrive as a JSON string from the + // upstream caller's previous assistant turn; pass-through as-is. + if tc := m.GetToolCalls(); tc != "" { + _ = json.Unmarshal([]byte(tc), &msg.ToolCalls) + } + req.Messages = append(req.Messages, msg) + } + // Fallback for plain Prompt requests (no Messages array). LocalAI + // templating may have produced a flat prompt; rewrap as a single + // user message so the upstream chat endpoint accepts it. + if len(req.Messages) == 0 && opts.GetPrompt() != "" { + req.Messages = []openAIMessage{{Role: "user", Content: opts.GetPrompt()}} + } + + return json.Marshal(req) +} + +// modelName picks the upstream model: upstream_model from the proxy +// config wins (operator override), else the local model name captured +// at LoadModel time. Operator sets upstream_model to map LocalAI's +// alias (e.g. "claude-strict") to the upstream's canonical name +// (e.g. "claude-3-5-sonnet-20241022"). 
+func modelName(cfg *proxyConfig, _ *pb.PredictOptions) string { + if cfg.upstreamModel != "" { + return cfg.upstreamModel + } + return cfg.localModel +} + +// parseRawJSON parses a JSON string into a RawMessage so it round-trips +// into the upstream body. Returns nil for empty/invalid input so the +// field is omitted (omitempty). +func parseRawJSON(s string) json.RawMessage { + if s == "" { + return nil + } + var probe json.RawMessage + if err := json.Unmarshal([]byte(s), &probe); err != nil { + return nil + } + return probe +} + +// doOpenAIRequest builds + sends the upstream request. Returns the +// raw response on success; caller handles status / body. +func (c *CloudProxy) doOpenAIRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("cloud-proxy: build request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "*/*") + if cfg.apiKey != "" { + applyAuthHeader(req, cfg.provider, cfg.apiKey) + } + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err) + } + return resp, nil +} + +// predictOpenAIRich is the non-streaming translate path. Returns a +// fully-populated *pb.Reply with assistant content, tool calls, and +// token usage. The gRPC server forwards the Reply verbatim. 
+func (c *CloudProxy) predictOpenAIRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) { + body, err := buildOpenAIRequest(opts, cfg, false) + if err != nil { + return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err) + } + resp, err := c.doOpenAIRequest(ctx, cfg, body) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody)) + } + + var parsed openAIResponse + if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { + return nil, fmt.Errorf("cloud-proxy: decode response: %w", err) + } + if len(parsed.Choices) == 0 { + return nil, errors.New("cloud-proxy: upstream returned no choices") + } + + choice := parsed.Choices[0] + reply := &pb.Reply{ + Message: []byte(choice.Message.Content), + } + if parsed.Usage != nil { + reply.PromptTokens = int32(parsed.Usage.PromptTokens) + reply.Tokens = int32(parsed.Usage.CompletionTokens) + } + if len(choice.Message.ToolCalls) > 0 { + // Non-streaming: a single ChatDelta carries the full tool-call + // set. Index/Name/Arguments are populated together; downstream + // consumers don't need to assemble streaming deltas. + delta := &pb.ChatDelta{} + for _, tc := range choice.Message.ToolCalls { + delta.ToolCalls = append(delta.ToolCalls, + newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments)) + } + reply.ChatDeltas = []*pb.ChatDelta{delta} + } + return reply, nil +} + +// predictOpenAIStreamRich streams *pb.Reply chunks. Each chunk carries +// either a content delta (Message + ChatDeltas[].Content) or tool-call +// deltas (ChatDeltas[].ToolCalls). The final Reply carries usage tokens +// when the upstream sends them (stream_options.include_usage). 
+func (c *CloudProxy) predictOpenAIStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error { + body, err := buildOpenAIRequest(opts, cfg, true) + if err != nil { + return fmt.Errorf("cloud-proxy: marshal request: %w", err) + } + resp, err := c.doOpenAIRequest(ctx, cfg, body) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 1<<20) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data:") { + continue + } + payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payload == "" || payload == "[DONE]" { + return nil + } + var chunk openAIStreamChunk + if err := json.Unmarshal([]byte(payload), &chunk); err != nil { + xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err) + continue + } + // Usage frames may arrive separately from content frames when + // stream_options.include_usage is set; emit a usage-only Reply + // in that case so the consumer sees the totals. 
+ if chunk.Usage != nil && len(chunk.Choices) == 0 { + if !sendReply(ctx, results, &pb.Reply{ + PromptTokens: int32(chunk.Usage.PromptTokens), + Tokens: int32(chunk.Usage.CompletionTokens), + }) { + return ctx.Err() + } + continue + } + for _, ch := range chunk.Choices { + reply := &pb.Reply{} + if ch.Delta.Content != "" { + reply.Message = []byte(ch.Delta.Content) + reply.ChatDeltas = []*pb.ChatDelta{{Content: ch.Delta.Content}} + } + if len(ch.Delta.ToolCalls) > 0 { + if len(reply.ChatDeltas) == 0 { + reply.ChatDeltas = []*pb.ChatDelta{{}} + } + for _, tc := range ch.Delta.ToolCalls { + reply.ChatDeltas[0].ToolCalls = append(reply.ChatDeltas[0].ToolCalls, + newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments)) + } + } + if reply.Message == nil && len(reply.ChatDeltas) == 0 { + continue + } + if !sendReply(ctx, results, reply) { + return ctx.Err() + } + } + } + return scanner.Err() +} diff --git a/backend/go/cloud-proxy/provider_openai_test.go b/backend/go/cloud-proxy/provider_openai_test.go new file mode 100644 index 000000000..9ce4334db --- /dev/null +++ b/backend/go/cloud-proxy/provider_openai_test.go @@ -0,0 +1,170 @@ +package main + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + . "github.com/onsi/gomega" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" +) + +// fakeOpenAIUpstream returns an httptest.Server that decodes the +// inbound request as an openAIRequest, calls handler with it, and +// writes the handler's reply as the response. 
+func fakeOpenAIUpstream(t *testing.T, handler func(req openAIRequest) (status int, body string, contentType string)) (*httptest.Server, *openAIRequest) { + t.Helper() + var captured openAIRequest + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(raw, &captured) + status, body, ct := handler(captured) + w.Header().Set("Content-Type", ct) + w.WriteHeader(status) + _, _ = io.WriteString(w, body) + })) + return srv, &captured +} + +func newTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy { + t.Helper() + g := NewWithT(t) + t.Setenv("CLOUD_PROXY_OPENAI_FAKE", "sk-fake-openai") + cp := NewCloudProxy() + err := cp.Load(&pb.ModelOptions{ + Model: "gpt-4o-local", + Proxy: &pb.ProxyOptions{ + UpstreamUrl: upstreamURL, + Mode: modeTranslate, + Provider: providerOpenAI, + ApiKeyEnv: "CLOUD_PROXY_OPENAI_FAKE", + UpstreamModel: "gpt-4o", + }, + }) + g.Expect(err).NotTo(HaveOccurred()) + return cp +} + +func TestPredict_OpenAI_BasicChat(t *testing.T) { + g := NewWithT(t) + srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) { + return 200, `{"id":"resp-1","choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`, "application/json" + }) + defer srv.Close() + cp := newTranslateCloudProxy(t, srv.URL) + + got, err := cp.Predict(&pb.PredictOptions{ + Messages: []*pb.Message{ + {Role: "system", Content: "be brief"}, + {Role: "user", Content: "hello"}, + }, + Temperature: 0.5, + TopP: 0.9, + Tokens: 32, + }) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(got).To(Equal("hi there")) + + // Verify the upstream saw a properly-translated request. 
+ g.Expect(captured.Model).To(Equal("gpt-4o")) + g.Expect(captured.Messages).To(HaveLen(2)) + g.Expect(captured.Messages[0].Role).To(Equal("system")) + g.Expect(captured.Messages[1].Role).To(Equal("user")) + g.Expect(captured.Temperature).NotTo(BeNil()) + g.Expect(*captured.Temperature).To(Equal(0.5)) + g.Expect(captured.MaxTokens).NotTo(BeNil()) + g.Expect(*captured.MaxTokens).To(Equal(int32(32))) + g.Expect(captured.Stream).To(BeFalse()) +} + +func TestPredict_OpenAI_PromptFallback(t *testing.T) { + g := NewWithT(t) + // No Messages array — backend should synth a single user message + // from Prompt so non-chat clients still route through translate. + srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) { + return 200, `{"choices":[{"message":{"role":"assistant","content":"ok"}}]}`, "application/json" + }) + defer srv.Close() + cp := newTranslateCloudProxy(t, srv.URL) + + _, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?"}) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(captured.Messages).To(HaveLen(1)) + g.Expect(captured.Messages[0].Role).To(Equal("user")) + g.Expect(captured.Messages[0].Content).To(Equal("what time is it?")) +} + +func TestPredict_OpenAI_UpstreamError(t *testing.T) { + g := NewWithT(t) + srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) { + return 401, `{"error":{"message":"bad key"}}`, "application/json" + }) + defer srv.Close() + cp := newTranslateCloudProxy(t, srv.URL) + + _, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}}) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("401")) +} + +func TestPredictStream_OpenAI_StreamsContent(t *testing.T) { + g := NewWithT(t) + // Stream three content deltas then [DONE]. Verify the channel + // receives them in order with no missing pieces. 
+ chunks := []string{ + `{"choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + `{"choices":[{"index":0,"delta":{"content":"hello"}}]}`, + `{"choices":[{"index":0,"delta":{"content":" "}}]}`, + `{"choices":[{"index":0,"delta":{"content":"world"}}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, + } + body := "" + for _, c := range chunks { + body += "data: " + c + "\n\n" + } + body += "data: [DONE]\n\n" + + srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) { + return 200, body, "text/event-stream" + }) + defer srv.Close() + cp := newTranslateCloudProxy(t, srv.URL) + + results := make(chan string, 8) + done := make(chan error, 1) + go func() { + done <- cp.PredictStream(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "hi"}}, + }, results) + }() + + var got []string + for s := range results { + got = append(got, s) + } + err := <-done + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(strings.Join(got, "")).To(Equal("hello world")) + g.Expect(captured.Stream).To(BeTrue()) +} + +func TestPredict_RejectedInPassthroughMode(t *testing.T) { + g := NewWithT(t) + t.Setenv("CLOUD_PROXY_FAKE", "k") + cp := NewCloudProxy() + err := cp.Load(&pb.ModelOptions{ + Proxy: &pb.ProxyOptions{ + UpstreamUrl: "https://example.com", + Mode: modePassthrough, + ApiKeyEnv: "CLOUD_PROXY_FAKE", + }, + }) + g.Expect(err).NotTo(HaveOccurred()) + _, err = cp.Predict(&pb.PredictOptions{}) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("only valid in translate")) +} diff --git a/backend/go/cloud-proxy/proxy.go b/backend/go/cloud-proxy/proxy.go new file mode 100644 index 000000000..9015ffc4d --- /dev/null +++ b/backend/go/cloud-proxy/proxy.go @@ -0,0 +1,429 @@ +package main + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync/atomic" + + "github.com/mudler/LocalAI/pkg/grpc/base" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + 
"github.com/mudler/xlog" +) + +// Mirror of core/config.Proxy{Mode,Provider}* — backends don't +// import core to keep the boundary clean. +const ( + modePassthrough = "passthrough" + modeTranslate = "translate" + + providerOpenAI = "openai" + providerAnthropic = "anthropic" +) + +// CloudProxy is the LocalAI backend that proxies model traffic to a +// configured upstream HTTP provider. Concurrency: base.SingleThread is +// NOT embedded — forward calls are independent and HTTP transport is +// goroutine-safe, so multiple Forward streams can run in parallel. +// Locking would serialise requests to a chat provider for no benefit. +type CloudProxy struct { + base.Base + + cfg atomic.Pointer[proxyConfig] + client *http.Client +} + +type proxyConfig struct { + upstreamURL string + mode string + provider string + upstreamModel string + localModel string // ModelOptions.Model — fallback when upstream_model is unset + apiKey string // resolved at Load time +} + +func NewCloudProxy() *CloudProxy { + // No Client-level Timeout — that would bound streaming SSE + // responses too, which can legitimately last minutes. Per-request + // deadlines come from the gRPC stream context. 
+ return &CloudProxy{client: &http.Client{}} +} + +func (c *CloudProxy) Load(opts *pb.ModelOptions) error { + po := opts.GetProxy() + if po == nil { + return errors.New("cloud-proxy: Load requires ProxyOptions to be set") + } + if po.GetUpstreamUrl() == "" { + return errors.New("cloud-proxy: upstream_url is required") + } + if _, err := url.ParseRequestURI(po.GetUpstreamUrl()); err != nil { + return fmt.Errorf("cloud-proxy: upstream_url %q invalid: %w", po.GetUpstreamUrl(), err) + } + + mode := po.GetMode() + if mode == "" { + mode = modePassthrough + } + switch mode { + case modePassthrough: + case modeTranslate: + switch po.GetProvider() { + case providerOpenAI: + // implemented in provider_openai.go + case providerAnthropic: + // implemented in provider_anthropic.go + default: + return fmt.Errorf("cloud-proxy: translate mode requires provider in {%s, %s}, got %q", + providerOpenAI, providerAnthropic, po.GetProvider()) + } + default: + return fmt.Errorf("cloud-proxy: unknown mode %q", mode) + } + + key, err := resolveAPIKey(po.GetApiKeyEnv(), po.GetApiKeyFile()) + if err != nil { + return err + } + + c.cfg.Store(&proxyConfig{ + upstreamURL: po.GetUpstreamUrl(), + mode: mode, + provider: po.GetProvider(), + upstreamModel: po.GetUpstreamModel(), + localModel: opts.GetModel(), + apiKey: key, + }) + xlog.Info("cloud-proxy: ready", + "upstream", po.GetUpstreamUrl(), + "mode", mode, + "provider", po.GetProvider(), + "has_key", key != "") + return nil +} + +// resolveAPIKey mirrors config.ProxyConfig.ResolveAPIKey. Duplicated +// (a few lines) rather than importing core/config from a backend +// binary — keeps backends independent of core's package layout. +// Mutual-exclusion is enforced upstream in core/config.Validate. 
+func resolveAPIKey(envName, filePath string) (string, error) { + if envName != "" { + v := os.Getenv(envName) + if v == "" { + return "", fmt.Errorf("cloud-proxy: api_key_env %q is unset", envName) + } + return v, nil + } + if filePath != "" { + b, err := os.ReadFile(filePath) + if err != nil { + return "", fmt.Errorf("cloud-proxy: read api_key_file %q: %w", filePath, err) + } + return strings.TrimSpace(string(b)), nil + } + return "", nil +} + +// PredictRich is the non-streaming translate path. Returns a fully- +// populated *pb.Reply: content, tool-call deltas (ChatDeltas), and +// usage tokens. Implements the optional grpc.AIModelRich interface; +// the gRPC server prefers this path over Predict when present so +// tool calls survive the round-trip. Passthrough mode rejects +// PredictRich — callers must use Forward. +func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) { + cfg := c.cfg.Load() + if cfg == nil { + return nil, errors.New("cloud-proxy: model not loaded") + } + if cfg.mode != modeTranslate { + return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode) + } + xlog.Info("cloud-proxy: predict", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel) + defer func() { + if err != nil { + xlog.Warn("cloud-proxy: predict failed", "provider", cfg.provider, "error", err) + } + }() + ctx := context.Background() + switch cfg.provider { + case providerOpenAI: + return c.predictOpenAIRich(ctx, cfg, opts) + case providerAnthropic: + return c.predictAnthropicRich(ctx, cfg, opts) + default: + return nil, fmt.Errorf("cloud-proxy: predict not implemented for provider %q", cfg.provider) + } +} + +// PredictStreamRich is the rich streaming counterpart of PredictRich. +// Each emitted Reply carries either a content delta, tool-call deltas, +// or usage tokens (the final upstream frame). 
base.Base.PredictStream +// is bypassed when AIModelRich is implemented, so the channel is +// closed by the gRPC server pump. +func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) { + cfg := c.cfg.Load() + if cfg == nil { + return errors.New("cloud-proxy: model not loaded") + } + if cfg.mode != modeTranslate { + return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode) + } + xlog.Info("cloud-proxy: predict-stream", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel) + defer func() { + if err != nil { + xlog.Warn("cloud-proxy: predict-stream failed", "provider", cfg.provider, "error", err) + } + }() + ctx := context.Background() + switch cfg.provider { + case providerOpenAI: + return c.predictOpenAIStreamRich(ctx, cfg, opts, results) + case providerAnthropic: + return c.predictAnthropicStreamRich(ctx, cfg, opts, results) + default: + return fmt.Errorf("cloud-proxy: predictStream not implemented for provider %q", cfg.provider) + } +} + +// Predict is the legacy (string, error) AIModel signature. Used only +// if a caller goes through the non-rich path (it shouldn't, since +// server.go prefers PredictRich). Provided so the AIModel interface +// is satisfied for backends that haven't opted into the rich variant. +func (c *CloudProxy) Predict(opts *pb.PredictOptions) (string, error) { + reply, err := c.PredictRich(opts) + if err != nil { + return "", err + } + return string(reply.GetMessage()), nil +} + +// PredictStream is the legacy chan-string streaming path. Adapts the +// rich stream by extracting only content text — tool-call-only chunks +// (no Message bytes) and usage-only chunks are silently dropped, since +// the legacy chan-string contract cannot represent them. Consumers +// that need tool calls must call PredictStreamRich directly. 
+func (c *CloudProxy) PredictStream(opts *pb.PredictOptions, results chan string) error { + defer close(results) + richCh := make(chan *pb.Reply) + errCh := make(chan error, 1) + go func() { + errCh <- c.PredictStreamRich(opts, richCh) + close(richCh) + }() + for reply := range richCh { + if msg := reply.GetMessage(); len(msg) > 0 { + results <- string(msg) + } + } + return <-errCh +} + +// sendReply pushes one Reply onto a stream channel honouring ctx +// cancellation. Returns false on cancel so the caller can exit with +// ctx.Err(). Used by both translate-mode providers. +func sendReply(ctx context.Context, results chan<- *pb.Reply, reply *pb.Reply) bool { + select { + case results <- reply: + return true + case <-ctx.Done(): + return false + } +} + +// newToolCallDelta is a small constructor for the cross-provider +// tool-call delta shape. Centralised so the int32 cast and the four +// fields stay consistent across the OpenAI / Anthropic translators. +// Empty name/args are valid — Anthropic streaming announces the call +// with id+name then sends arguments incrementally; OpenAI's reverse +// pattern (args without name) also lands here. +func newToolCallDelta(index int, id, name, args string) *pb.ToolCallDelta { + return &pb.ToolCallDelta{ + Index: int32(index), + Id: id, + Name: name, + Arguments: args, + } +} + +// Forward shovels bytes between a Forward gRPC stream and an upstream +// HTTP request. First request message carries path/method/headers and +// the initial body chunk; subsequent messages append body chunks. The +// first reply carries upstream status + response headers; subsequent +// replies stream body chunks until the upstream connection closes. +// Cancellation of ctx (the gRPC stream context) closes the upstream +// connection. 
+func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest, out chan<- *pb.ForwardReply) error { + defer close(out) + + cfg := c.cfg.Load() + if cfg == nil { + return errors.New("cloud-proxy: model not loaded") + } + if cfg.mode != modePassthrough { + return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode) + } + + first, ok := <-in + if !ok { + return errors.New("cloud-proxy: Forward stream closed before first request") + } + + // Honour the per-request path only when the configured upstream_url + // has no path of its own — gallery convention is to put the + // canonical path in upstream_url. + fullURL, err := composeURL(cfg.upstreamURL, first.GetPath()) + if err != nil { + return err + } + + method := first.GetMethod() + if method == "" { + method = http.MethodPost + } + + // Pipe the body in from the gRPC stream so the HTTP request can + // start before the client finishes sending. The pipe-reader is + // closed via CloseWithError on the error paths so the writer + // goroutine doesn't block forever. + pr, pw := io.Pipe() + + go func() { + var writeErr error + defer func() { _ = pw.CloseWithError(writeErr) }() + if len(first.GetBodyChunk()) > 0 { + if _, writeErr = pw.Write(first.GetBodyChunk()); writeErr != nil { + return + } + } + for req := range in { + if len(req.GetBodyChunk()) == 0 { + continue + } + if _, writeErr = pw.Write(req.GetBodyChunk()); writeErr != nil { + return + } + } + }() + + req, err := http.NewRequestWithContext(ctx, method, fullURL, pr) + if err != nil { + _ = pr.CloseWithError(err) // unblocks the body-pump's pw.Write + return fmt.Errorf("cloud-proxy: build request: %w", err) + } + + // Apply caller-supplied headers, then override with the + // authorization header derived from the resolved key. 
Caller- + // supplied Authorization is always replaced — operators may not + // know the backend's auth scheme, and silently leaking through a + // client Authorization header to a different upstream would + // confuse the upstream and could leak credentials. + for _, h := range first.GetHeaders() { + if h == nil || h.GetName() == "" { + continue + } + // Strip hop-by-hop headers that aren't meaningful to the + // upstream (Host is set by the http client from the URL; + // Content-Length is computed from the body). + if isHopByHopHeader(h.GetName()) { + continue + } + req.Header.Add(h.GetName(), h.GetValue()) + } + if cfg.apiKey != "" { + applyAuthHeader(req, cfg.provider, cfg.apiKey) + } + + xlog.Info("cloud-proxy: forward", "method", method, "url", fullURL, "provider", cfg.provider) + resp, err := c.client.Do(req) + if err != nil { + xlog.Warn("cloud-proxy: forward upstream failed", "url", fullURL, "error", err) + return fmt.Errorf("cloud-proxy: upstream request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + logFn := xlog.Info + if resp.StatusCode >= 400 { + logFn = xlog.Warn + } + logFn("cloud-proxy: forward response", "url", fullURL, "status", resp.StatusCode) + + // First reply: status + response headers, no body. + headers := make([]*pb.ForwardHeader, 0, len(resp.Header)) + for k, vs := range resp.Header { + for _, v := range vs { + headers = append(headers, &pb.ForwardHeader{Name: k, Value: v}) + } + } + out <- &pb.ForwardReply{Status: int32(resp.StatusCode), Headers: headers} + + // Subsequent replies: body chunks. Use a fixed 8KB buffer — small + // enough that SSE token frames flush promptly, large enough that + // long chunked-transfer bodies aren't death by a thousand reads. 
+ buf := make([]byte, 8*1024) + for { + n, rerr := resp.Body.Read(buf) + if n > 0 { + chunk := make([]byte, n) + copy(chunk, buf[:n]) + out <- &pb.ForwardReply{BodyChunk: chunk} + } + if rerr != nil { + if errors.Is(rerr, io.EOF) { + return nil + } + return fmt.Errorf("cloud-proxy: upstream body read: %w", rerr) + } + } +} + +// composeURL combines the configured upstream URL with the per-request +// path. The upstream URL typically already includes the canonical path +// (e.g. https://api.openai.com/v1/chat/completions) so the per-request +// path is ignored in that case. When upstream_url is a bare host +// (https://api.openai.com), the request path is appended. +func composeURL(upstream, reqPath string) (string, error) { + u, err := url.Parse(upstream) + if err != nil { + return "", fmt.Errorf("cloud-proxy: parse upstream_url %q: %w", upstream, err) + } + if u.Path == "" || u.Path == "/" { + u.Path = reqPath + } + return u.String(), nil +} + +// applyAuthHeader writes the appropriate authorization header for the +// provider. OpenAI/Anthropic/most providers use Bearer; Anthropic +// historically uses x-api-key + anthropic-version, but accepts Bearer +// too via the OpenAI-compatible path. Default to Bearer when provider +// is empty (passthrough mode where the operator doesn't claim a +// provider). +func applyAuthHeader(req *http.Request, provider, key string) { + switch provider { + case providerAnthropic: + req.Header.Set("x-api-key", key) + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + default: + req.Header.Set("Authorization", "Bearer "+key) + } +} + +// isHopByHopHeader returns true for headers that should not be +// forwarded from the client request to the upstream (RFC 7230 §6.1 +// hop-by-hop list, plus a few that the http.Client sets itself). 
+func isHopByHopHeader(name string) bool { + switch strings.ToLower(name) { + case "connection", "proxy-connection", "keep-alive", "transfer-encoding", + "te", "trailer", "upgrade", "host", "content-length": + return true + } + return false +} + diff --git a/backend/go/cloud-proxy/proxy_test.go b/backend/go/cloud-proxy/proxy_test.go new file mode 100644 index 000000000..3881346d4 --- /dev/null +++ b/backend/go/cloud-proxy/proxy_test.go @@ -0,0 +1,206 @@ +package main + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + grpc "github.com/mudler/LocalAI/pkg/grpc" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + + . "github.com/onsi/gomega" +) + +// helper: run a CloudProxy in-process via grpc.Provide so tests can +// call Forward through the public Backend interface without listening +// on a real socket. +func newInProcClient(t *testing.T, proxy *CloudProxy) grpc.Backend { + t.Helper() + addr := "test://" + t.Name() + grpc.Provide(addr, proxy) + return grpc.NewClient(addr, true, nil, false) +} + +func TestForward_PassthroughEcho(t *testing.T) { + g := NewWithT(t) + // Fake upstream: echoes the request body back, prefixed with a + // canary so the test can assert both that the body reached the + // upstream and the response made it back to the client. 
+ gotBody := make(chan string, 1) + gotAuth := make(chan string, 1) + gotPath := make(chan string, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + gotBody <- string(body) + gotAuth <- r.Header.Get("Authorization") + gotPath <- r.URL.Path + w.Header().Set("X-Echo", "true") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("echo: " + string(body))) + })) + defer upstream.Close() + + t.Setenv("CLOUD_PROXY_FAKE_KEY", "sk-fake") + + cp := NewCloudProxy() + err := cp.Load(&pb.ModelOptions{ + Proxy: &pb.ProxyOptions{ + UpstreamUrl: upstream.URL, + Mode: modePassthrough, + ApiKeyEnv: "CLOUD_PROXY_FAKE_KEY", + }, + }) + g.Expect(err).NotTo(HaveOccurred()) + + c := newInProcClient(t, cp) + stream, err := c.Forward(context.Background()) + g.Expect(err).NotTo(HaveOccurred()) + + err = stream.Send(&pb.ForwardRequest{ + Path: "/v1/chat/completions", + Method: "POST", + Headers: []*pb.ForwardHeader{{Name: "Content-Type", Value: "application/json"}}, + BodyChunk: []byte(`{"prompt":`), + }) + g.Expect(err).NotTo(HaveOccurred()) + err = stream.Send(&pb.ForwardRequest{BodyChunk: []byte(`"hi"}`)}) + g.Expect(err).NotTo(HaveOccurred()) + err = stream.CloseSend() + g.Expect(err).NotTo(HaveOccurred()) + + // First reply: status + headers. + first, err := stream.Recv() + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(first.Status).To(Equal(int32(http.StatusOK))) + g.Expect(hasHeader(first.Headers, "X-Echo", "true")).To(BeTrue()) + + // Subsequent replies: body. + var body []byte + for { + r, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + g.Expect(err).NotTo(HaveOccurred()) + body = append(body, r.BodyChunk...) + } + g.Expect(string(body)).To(Equal(`echo: {"prompt":"hi"}`)) + + // Upstream observations. 
+ var gotBodyVal, gotAuthVal, gotPathVal string + g.Eventually(gotBody).Should(Receive(&gotBodyVal), "upstream never saw body") + g.Expect(gotBodyVal).To(Equal(`{"prompt":"hi"}`)) + g.Eventually(gotAuth).Should(Receive(&gotAuthVal), "upstream never saw auth header") + g.Expect(gotAuthVal).To(Equal("Bearer sk-fake")) + g.Eventually(gotPath).Should(Receive(&gotPathVal), "upstream never saw path") + g.Expect(gotPathVal).To(Equal("/v1/chat/completions")) +} + +func TestForward_AnthropicAuthHeader(t *testing.T) { + g := NewWithT(t) + gotXAPIKey := make(chan string, 1) + gotVersion := make(chan string, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotXAPIKey <- r.Header.Get("x-api-key") + gotVersion <- r.Header.Get("anthropic-version") + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + t.Setenv("CLOUD_PROXY_ANTHROPIC_KEY", "sk-ant-fake") + + cp := NewCloudProxy() + err := cp.Load(&pb.ModelOptions{ + Proxy: &pb.ProxyOptions{ + UpstreamUrl: upstream.URL, + Mode: modePassthrough, + Provider: providerAnthropic, + ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_KEY", + }, + }) + g.Expect(err).NotTo(HaveOccurred()) + + c := newInProcClient(t, cp) + stream, err := c.Forward(context.Background()) + g.Expect(err).NotTo(HaveOccurred()) + err = stream.Send(&pb.ForwardRequest{Path: "/v1/messages", Method: "POST"}) + g.Expect(err).NotTo(HaveOccurred()) + _ = stream.CloseSend() + _, _ = stream.Recv() // drain status + for { + if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil { + break + } + } + + g.Expect(<-gotXAPIKey).To(Equal("sk-ant-fake")) + g.Expect(<-gotVersion).NotTo(BeEmpty()) +} + +func TestLoad_ValidatesConfig(t *testing.T) { + g := NewWithT(t) + cp := NewCloudProxy() + + err := cp.Load(&pb.ModelOptions{}) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("ProxyOptions")) + + err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{}}) + g.Expect(err).To(HaveOccurred()) + 
g.Expect(err.Error()).To(ContainSubstring("upstream_url")) + + err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{ + UpstreamUrl: "https://example.com", + Mode: "rewrite", + }}) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("unknown mode")) + + // translate + openai should load successfully (Phase 5). + err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{ + UpstreamUrl: "https://example.com/v1/chat/completions", + Mode: modeTranslate, + Provider: providerOpenAI, + }}) + g.Expect(err).NotTo(HaveOccurred()) + + // translate + anthropic should load successfully (Phase 6). + err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{ + UpstreamUrl: "https://example.com/v1/messages", + Mode: modeTranslate, + Provider: providerAnthropic, + }}) + g.Expect(err).NotTo(HaveOccurred()) + + err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{ + UpstreamUrl: "https://example.com", + ApiKeyEnv: "DEFINITELY_UNSET_ENV_VAR_XYZ", + }}) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("unset")) +} + +func TestForward_RejectsWithoutLoad(t *testing.T) { + g := NewWithT(t) + cp := NewCloudProxy() + c := newInProcClient(t, cp) + stream, err := c.Forward(context.Background()) + g.Expect(err).NotTo(HaveOccurred()) + _ = stream.CloseSend() + _, err = stream.Recv() + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("not loaded")) +} + +func hasHeader(hs []*pb.ForwardHeader, name, value string) bool { + for _, h := range hs { + if strings.EqualFold(h.GetName(), name) && h.GetValue() == value { + return true + } + } + return false +} diff --git a/backend/go/cloud-proxy/run.sh b/backend/go/cloud-proxy/run.sh new file mode 100755 index 000000000..c533c093a --- /dev/null +++ b/backend/go/cloud-proxy/run.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -ex + +CURDIR=$(dirname "$(realpath $0)") + +exec $CURDIR/cloud-proxy "$@" diff --git a/backend/go/cloud-proxy/toolcalls_test.go 
b/backend/go/cloud-proxy/toolcalls_test.go new file mode 100644 index 000000000..e05de8023 --- /dev/null +++ b/backend/go/cloud-proxy/toolcalls_test.go @@ -0,0 +1,232 @@ +package main + +import ( + "strings" + "testing" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + . "github.com/onsi/gomega" +) + +// OpenAI: non-streaming tool call response. Verify the response is +// mapped to Reply.ChatDeltas[].ToolCalls with id/name/arguments intact, +// and usage tokens land on Reply.PromptTokens / Reply.Tokens. +func TestPredictRich_OpenAI_ToolCalls(t *testing.T) { + srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) { + return 200, `{ + "id":"resp-1", + "choices":[{ + "index":0, + "message":{ + "role":"assistant", + "content":"", + "tool_calls":[ + {"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"SF\"}"}}, + {"id":"call_def","type":"function","function":{"name":"get_time","arguments":"{\"tz\":\"PT\"}"}} + ] + }, + "finish_reason":"tool_calls" + }], + "usage":{"prompt_tokens":42,"completion_tokens":18,"total_tokens":60} + }`, "application/json" + }) + defer srv.Close() + g := NewWithT(t) + cp := newTranslateCloudProxy(t, srv.URL) + + reply, err := cp.PredictRich(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}}, + }) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(string(reply.GetMessage())).To(Equal("")) + g.Expect(reply.GetPromptTokens()).To(Equal(int32(42))) + g.Expect(reply.GetTokens()).To(Equal(int32(18))) + g.Expect(reply.GetChatDeltas()).To(HaveLen(1)) + tcs := reply.GetChatDeltas()[0].GetToolCalls() + g.Expect(tcs).To(HaveLen(2)) + g.Expect(tcs[0].GetId()).To(Equal("call_abc")) + g.Expect(tcs[0].GetName()).To(Equal("get_weather")) + g.Expect(tcs[0].GetArguments()).To(ContainSubstring(`"location":"SF"`)) + g.Expect(tcs[1].GetId()).To(Equal("call_def")) + g.Expect(tcs[1].GetName()).To(Equal("get_time")) +} + +// OpenAI: streaming tool call. 
Arguments arrive as a sequence of +// delta chunks; the consumer is expected to concatenate by tool index. +// Verify each chunk reaches the channel and the assembled arguments +// match the input. +func TestPredictStreamRich_OpenAI_ToolCallDeltas(t *testing.T) { + chunks := []string{ + // Frame 0: announce the tool call (id + name, no args yet). + `{"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_xyz","type":"function","function":{"name":"search"}}]}}]}`, + // Frames 1-3: arguments arrive in fragments. + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"q\":"}}]}}]}`, + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"clo"}}]}}]}`, + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"uds\"}"}}]}}]}`, + // Stop frame. + `{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`, + } + body := "" + for _, c := range chunks { + body += "data: " + c + "\n\n" + } + body += "data: [DONE]\n\n" + + srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) { + return 200, body, "text/event-stream" + }) + defer srv.Close() + g := NewWithT(t) + cp := newTranslateCloudProxy(t, srv.URL) + + results := make(chan *pb.Reply, 16) + done := make(chan error, 1) + go func() { + done <- cp.PredictStreamRich(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "find something"}}, + }, results) + close(results) + }() + + var ( + toolName string + toolID string + toolIndex int32 = -1 + argsBuf strings.Builder + ) + for reply := range results { + for _, cd := range reply.GetChatDeltas() { + for _, tc := range cd.GetToolCalls() { + if tc.GetName() != "" { + toolName = tc.GetName() + } + if tc.GetId() != "" { + toolID = tc.GetId() + } + if toolIndex == -1 { + toolIndex = tc.GetIndex() + } + argsBuf.WriteString(tc.GetArguments()) + } + } + } + err := <-done + g.Expect(err).NotTo(HaveOccurred()) + 
g.Expect(toolID).To(Equal("call_xyz")) + g.Expect(toolName).To(Equal("search")) + g.Expect(toolIndex).To(Equal(int32(0))) + g.Expect(argsBuf.String()).To(Equal(`{"q":"clouds"}`)) +} + +// Anthropic: non-streaming tool_use block. The block appears in +// Content[] alongside text blocks; the input field is a structured +// JSON object. Map to ToolCallDelta with arguments as serialised JSON +// so downstream OpenAI-shaped consumers see a familiar format. +func TestPredictRich_Anthropic_ToolUse(t *testing.T) { + srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, `{ + "id":"msg_1","type":"message","role":"assistant", + "content":[ + {"type":"text","text":"Let me check that."}, + {"type":"tool_use","id":"toolu_01","name":"weather","input":{"location":"SF"}} + ], + "model":"claude","usage":{"input_tokens":12,"output_tokens":34} + }`, "application/json" + }) + defer srv.Close() + g := NewWithT(t) + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + + reply, err := cp.PredictRich(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}}, + Tokens: 64, + }) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(string(reply.GetMessage())).To(Equal("Let me check that.")) + g.Expect(reply.GetPromptTokens()).To(Equal(int32(12))) + g.Expect(reply.GetTokens()).To(Equal(int32(34))) + g.Expect(reply.GetChatDeltas()).To(HaveLen(1)) + g.Expect(reply.GetChatDeltas()[0].GetToolCalls()).To(HaveLen(1)) + tc := reply.GetChatDeltas()[0].GetToolCalls()[0] + g.Expect(tc.GetId()).To(Equal("toolu_01")) + g.Expect(tc.GetName()).To(Equal("weather")) + g.Expect(tc.GetArguments()).To(ContainSubstring(`"location":"SF"`)) +} + +// Anthropic: streaming tool_use. content_block_start announces the +// tool's id + name; input_json_delta events carry argument fragments +// which the consumer accumulates. message_delta carries final usage. 
+func TestPredictStreamRich_Anthropic_InputJSONDelta(t *testing.T) { + frames := []string{ + "event: message_start\ndata: {\"type\":\"message_start\"}\n\n", + // Block 0 is a tool_use; consumer should allocate a slot. + "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_42\",\"name\":\"lookup\"}}\n\n", + "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"q\\\":\"}}\n\n", + "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"\\\"rain\\\"}\"}}\n\n", + "event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n", + "event: message_delta\ndata: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n", + "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", + } + body := strings.Join(frames, "") + + srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) { + return 200, body, "text/event-stream" + }) + defer srv.Close() + g := NewWithT(t) + cp := newAnthropicTranslateCloudProxy(t, srv.URL) + + results := make(chan *pb.Reply, 16) + done := make(chan error, 1) + go func() { + done <- cp.PredictStreamRich(&pb.PredictOptions{ + Messages: []*pb.Message{{Role: "user", Content: "rain?"}}, + Tokens: 64, + }, results) + close(results) + }() + + var ( + toolID, toolName string + argsBuf strings.Builder + finalTokens int32 + ) + for reply := range results { + if reply.GetTokens() > 0 && len(reply.GetChatDeltas()) == 0 { + finalTokens = reply.GetTokens() + continue + } + for _, cd := range reply.GetChatDeltas() { + for _, tc := range cd.GetToolCalls() { + if tc.GetId() != "" { + toolID = tc.GetId() + } + if tc.GetName() != "" { + toolName = tc.GetName() + } + argsBuf.WriteString(tc.GetArguments()) + } + } + } + err := <-done + 
g.Expect(err).NotTo(HaveOccurred()) + g.Expect(toolID).To(Equal("toolu_42")) + g.Expect(toolName).To(Equal("lookup")) + g.Expect(argsBuf.String()).To(Equal(`{"q":"rain"}`)) + g.Expect(finalTokens).To(Equal(int32(7))) +} + +// Sanity: the legacy Predict() (string, error) signature still works +// — it delegates to PredictRich and extracts Message. +func TestPredict_LegacyWrapper_OpenAI(t *testing.T) { + srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) { + return 200, `{"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, "application/json" + }) + defer srv.Close() + g := NewWithT(t) + cp := newTranslateCloudProxy(t, srv.URL) + + got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "hi"}}}) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(got).To(Equal("hello")) +} diff --git a/backend/go/local-store/debug.go b/backend/go/local-store/debug.go index 2c3d77cab..503b4ece2 100644 --- a/backend/go/local-store/debug.go +++ b/backend/go/local-store/debug.go @@ -8,6 +8,6 @@ import ( func assert(cond bool, msg string) { if !cond { - xlog.Fatal().Stack().Msg(msg) + xlog.Fatal(msg) } } diff --git a/backend/go/local-store/store.go b/backend/go/local-store/store.go index e2ad54098..2085f74a9 100644 --- a/backend/go/local-store/store.go +++ b/backend/go/local-store/store.go @@ -1,7 +1,22 @@ package main -// This is a wrapper to statisfy the GRPC service interface -// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +// LocalAI's in-process vector store, exposed as a gRPC backend. Keep +// the implementation here — NOT in a pkg/ library imported by the main +// LocalAI process. The whole point of the gRPC surface is that vector +// storage is a backend like any other (local-store, qdrant, pinecone, +// ...) and can be swapped without changing the routing/recognition +// code that consumes it. 
+// +// Storage is a sorted parallel-slice (keys [][]float32, values +// [][]byte). Set/Delete preserve the sort so Get can binary-search. +// Find scans linearly and uses a heap to keep the top-K — fine for +// the tens-to-thousands range. The "normalized fast path" (Find when +// every stored key has unit magnitude AND the query is normalized) +// skips the per-item magnitude calculation. +// +// Concurrency: base.SingleThread serialises gRPC calls so the +// non-thread-safe slice/heap manipulation here is sound. + import ( "container/heap" "fmt" @@ -10,30 +25,27 @@ import ( "github.com/mudler/LocalAI/pkg/grpc/base" pb "github.com/mudler/LocalAI/pkg/grpc/proto" - - "github.com/mudler/xlog" + "github.com/mudler/LocalAI/pkg/store" ) type Store struct { base.SingleThread - // The sorted keys - keys [][]float32 - // The sorted values + keys [][]float32 values [][]byte - // If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions - // TODO: Should we normalize incoming keys if they are not instead? + // keysAreNormalized stays true until any non-unit-magnitude key + // is added; once false, the magnitude-aware fallback path is + // used by Find. Re-evaluated only at Set time, never again on + // its own — a deletion of the offending key does NOT flip it + // back to true (the bookkeeping cost would dominate the gain). keysAreNormalized bool - // The first key decides the length of the keys - keyLen int -} -// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because -// that's theoretically best for memory layout and cache locality, but this isn't optimized yet. -type Pair struct { - Key []float32 - Value []byte + // keyLen is the dimension of every stored key. -1 means "no + // keys yet, dimension is open". Dimension mismatch on Set is + // rejected so cosine similarity (which requires equal-length + // vectors) doesn't silently mis-match. 
+ keyLen int } func NewStore() *Store { @@ -45,334 +57,278 @@ func NewStore() *Store { } } -func compareSlices(k1, k2 []float32) int { - assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) - - return slices.Compare(k1, k2) -} - -func hasKey(unsortedSlice [][]float32, target []float32) bool { - return slices.ContainsFunc(unsortedSlice, func(k []float32) bool { - return compareSlices(k, target) == 0 - }) -} - -func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) { - return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int { - return compareSlices(k, t) - }) -} - -func isSortedPairs(kvs []Pair) bool { - for i := 1; i < len(kvs); i++ { - if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 { - return false - } - } - - return true -} - -func isSortedKeys(keys [][]float32) bool { - for i := 1; i < len(keys); i++ { - if compareSlices(keys[i-1], keys[i]) > 0 { - return false - } - } - - return true -} - -func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 { - ks := make([][]float32, len(keys)) - - for i, k := range keys { - ks[i] = k.Floats - } - - slices.SortFunc(ks, compareSlices) - - assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys))) - assert(isSortedKeys(ks), "keys are not sorted") - - return ks -} - +// Load is a no-op — local-store has no on-disk artefact. opts.Model is +// just a namespace identifier; isolation is already handled upstream +// (ModelLoader spawns a fresh local-store process per (backend, +// model) tuple, so each namespace is its own Store{} instance). func (s *Store) Load(opts *pb.ModelOptions) error { - // local-store is an in-memory vector store with no on-disk artefact to - // load — opts.Model is just a namespace identifier. The old `!= ""` guard - // rejected any non-empty model name with "not implemented", which broke - // callers that pass a namespace to isolate embedding spaces (face vs. 
- // voice biometrics both go through local-store but need distinct stores - // so ArcFace 512-D and ECAPA-TDNN 192-D don't collide). Namespace - // isolation is already handled upstream: ModelLoader spawns a fresh - // local-store process per (backend, model) tuple, so each namespace is - // its own Store{} instance. Nothing to do here beyond accepting the load. _ = opts return nil } -// Sort the incoming kvs and merge them with the existing sorted kvs func (s *Store) StoresSet(opts *pb.StoresSetOptions) error { - if len(opts.Keys) == 0 { - return fmt.Errorf("no keys to add") + keys := store.UnwrapKeys(opts.Keys) + values := store.UnwrapValues(opts.Values) + if len(keys) == 0 { + return fmt.Errorf("local-store: Set: no keys to add") } - - if len(opts.Keys) != len(opts.Values) { - return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values)) + if len(keys) != len(values) { + return fmt.Errorf("local-store: Set: len(keys) = %d, len(values) = %d", len(keys), len(values)) } if s.keyLen == -1 { - s.keyLen = len(opts.Keys[0].Floats) - } else { - if len(opts.Keys[0].Floats) != s.keyLen { - return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) - } + s.keyLen = len(keys[0]) + } else if len(keys[0]) != s.keyLen { + return fmt.Errorf("local-store: Set: key length %d does not match existing %d", len(keys[0]), s.keyLen) } - kvs := make([]Pair, len(opts.Keys)) - - for i, k := range opts.Keys { - if s.keysAreNormalized && !isNormalized(k.Floats) { + kvs := make([]incomingPair, len(keys)) + for i, k := range keys { + if len(k) != s.keyLen { + return fmt.Errorf("local-store: Set: key %d length %d does not match existing %d", i, len(k), s.keyLen) + } + if s.keysAreNormalized && !isNormalized(k) { s.keysAreNormalized = false - var sample []float32 - if len(s.keys) > 5 { - sample = k.Floats[:5] - } else { - sample = k.Floats - } - xlog.Debug("Key is not normalized", "sample", sample) - } - - kvs[i] 
= Pair{ - Key: k.Floats, - Value: opts.Values[i].Bytes, } + kvs[i] = incomingPair{key: k, value: values[i]} } - slices.SortFunc(kvs, func(a, b Pair) int { - return compareSlices(a.Key, b.Key) - }) - - assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys))) - assert(isSortedPairs(kvs), "keys are not sorted") - - l := len(kvs) + len(s.keys) - merge_ks := make([][]float32, 0, l) - merge_vs := make([][]byte, 0, l) - - i, j := 0, 0 - for { - if i+j >= l { - break - } - - if i >= len(kvs) { - merge_ks = append(merge_ks, s.keys[j]) - merge_vs = append(merge_vs, s.values[j]) - j++ - continue - } - - if j >= len(s.keys) { - merge_ks = append(merge_ks, kvs[i].Key) - merge_vs = append(merge_vs, kvs[i].Value) - i++ - continue - } - - c := compareSlices(kvs[i].Key, s.keys[j]) - if c < 0 { - merge_ks = append(merge_ks, kvs[i].Key) - merge_vs = append(merge_vs, kvs[i].Value) - i++ - } else if c > 0 { - merge_ks = append(merge_ks, s.keys[j]) - merge_vs = append(merge_vs, s.values[j]) - j++ - } else { - merge_ks = append(merge_ks, kvs[i].Key) - merge_vs = append(merge_vs, kvs[i].Value) - i++ - j++ - } - } - - assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l)) - assert(isSortedKeys(merge_ks), "merge keys are not sorted") - - s.keys = merge_ks - s.values = merge_vs + slices.SortFunc(kvs, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) }) + merged := mergeSortedPairs(s.keys, s.values, kvs) + s.keys = merged.keys + s.values = merged.values + assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Set: s.keys not sorted post-merge") + assert(len(s.keys) == len(s.values), "Set: keys/values length skew") return nil } func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error { - if len(opts.Keys) == 0 { - return fmt.Errorf("no keys to delete") + keys := store.UnwrapKeys(opts.Keys) + if len(keys) == 0 { + return fmt.Errorf("local-store: Delete: no keys to delete") } 
- - if len(opts.Keys) == 0 { - return fmt.Errorf("no keys to add") - } - - if s.keyLen == -1 { - s.keyLen = len(opts.Keys[0].Floats) - } else { - if len(opts.Keys[0].Floats) != s.keyLen { - return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) - } - } - - ks := sortIntoKeySlicese(opts.Keys) - - l := len(s.keys) - len(ks) - merge_ks := make([][]float32, 0, l) - merge_vs := make([][]byte, 0, l) - - tail_ks := s.keys - tail_vs := s.values - for _, k := range ks { - j, found := findInSortedSlice(tail_ks, k) - - if found { - merge_ks = append(merge_ks, tail_ks[:j]...) - merge_vs = append(merge_vs, tail_vs[:j]...) - tail_ks = tail_ks[j+1:] - tail_vs = tail_vs[j+1:] - } else { - assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k)) - } - - xlog.Debug("Delete", "found", found, "tailLen", len(tail_ks), "j", j, "mergeKeysLen", len(merge_ks), "mergeValuesLen", len(merge_vs)) - } - - merge_ks = append(merge_ks, tail_ks...) - merge_vs = append(merge_vs, tail_vs...) - - assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys))) - - s.keys = merge_ks - s.values = merge_vs - - assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l)) - assert(isSortedKeys(s.keys), "keys are not sorted") - assert(func() bool { - for _, k := range ks { - if _, found := findInSortedSlice(s.keys, k); found { - return false + if s.keyLen != -1 { + for i, k := range keys { + if len(k) != s.keyLen { + return fmt.Errorf("local-store: Delete: key %d length %d does not match existing %d", i, len(k), s.keyLen) } } - return true - }(), "Keys to delete still present") - - if len(s.keys) != l { - xlog.Debug("Delete: Some keys not found", "keysLen", len(s.keys), "expectedLen", l) } + sortedKeys := append([][]float32(nil), keys...) 
+ slices.SortFunc(sortedKeys, slices.Compare[[]float32]) + mergedK := make([][]float32, 0, len(s.keys)) + mergedV := make([][]byte, 0, len(s.keys)) + tailK := s.keys + tailV := s.values + for _, k := range sortedKeys { + j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32]) + if ok { + mergedK = append(mergedK, tailK[:j]...) + mergedV = append(mergedV, tailV[:j]...) + tailK = tailK[j+1:] + tailV = tailV[j+1:] + } + } + mergedK = append(mergedK, tailK...) + mergedV = append(mergedV, tailV...) + s.keys = mergedK + s.values = mergedV + assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Delete: s.keys not sorted post-merge") + assert(len(s.keys) == len(s.values), "Delete: keys/values length skew") return nil } +// StoresGet fetches values for the given keys. Missing keys are +// omitted from the result rather than reported as an error — callers +// compare returned-key length against requested-key length to detect +// them. Returned slices are aligned. func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) { - pbKeys := make([]*pb.StoresKey, 0, len(opts.Keys)) - pbValues := make([]*pb.StoresValue, 0, len(opts.Keys)) - ks := sortIntoKeySlicese(opts.Keys) - + keys := store.UnwrapKeys(opts.Keys) if len(s.keys) == 0 { - xlog.Debug("Get: No keys in store") + return pb.StoresGetResult{}, nil } - - if s.keyLen == -1 { - s.keyLen = len(opts.Keys[0].Floats) - } else { - if len(opts.Keys[0].Floats) != s.keyLen { - return pb.StoresGetResult{}, fmt.Errorf("Try to get a key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) + if s.keyLen != -1 { + for i, k := range keys { + if len(k) != s.keyLen { + return pb.StoresGetResult{}, fmt.Errorf("local-store: Get: key %d length %d does not match existing %d", i, len(k), s.keyLen) + } } } + sortedKeys := append([][]float32(nil), keys...) 
+ slices.SortFunc(sortedKeys, slices.Compare[[]float32]) - tail_k := s.keys - tail_v := s.values - for i, k := range ks { - j, found := findInSortedSlice(tail_k, k) - - if found { - pbKeys = append(pbKeys, &pb.StoresKey{ - Floats: k, - }) - pbValues = append(pbValues, &pb.StoresValue{ - Bytes: tail_v[j], - }) - - tail_k = tail_k[j+1:] - tail_v = tail_v[j+1:] - } else { - assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: i=%d, %v", i, k)) + var foundKeys [][]float32 + var foundValues [][]byte + tailK := s.keys + tailV := s.values + for _, k := range sortedKeys { + j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32]) + if !ok { + continue } + foundKeys = append(foundKeys, tailK[j]) + foundValues = append(foundValues, tailV[j]) + tailK = tailK[j+1:] + tailV = tailV[j+1:] } - - if len(pbKeys) != len(opts.Keys) { - xlog.Debug("Get: Some keys not found", "pbKeysLen", len(pbKeys), "optsKeysLen", len(opts.Keys), "storeKeysLen", len(s.keys)) - } - return pb.StoresGetResult{ - Keys: pbKeys, - Values: pbValues, + Keys: store.WrapKeys(foundKeys), + Values: store.WrapValues(foundValues), }, nil } +// StoresFind returns the topK nearest stored entries by cosine +// similarity, ordered most-similar first. An empty store returns +// empty slices and no error. 
+func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { + query := opts.Key.Floats + topK := int(opts.TopK) + if topK < 1 { + return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: topK = %d, must be >= 1", topK) + } + if len(s.keys) == 0 { + return pb.StoresFindResult{}, nil + } + if len(query) != s.keyLen { + return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: query length %d does not match existing %d", len(query), s.keyLen) + } + + var keys [][]float32 + var values [][]byte + var sims []float32 + if s.keysAreNormalized && isNormalized(query) { + keys, values, sims = s.findNormalized(query, topK) + } else { + keys, values, sims = s.findFallback(query, topK) + } + return pb.StoresFindResult{ + Keys: store.WrapKeys(keys), + Values: store.WrapValues(values), + Similarities: sims, + }, nil +} + +func (s *Store) findNormalized(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) { + assert(s.keysAreNormalized, "findNormalized: s.keysAreNormalized is false") + assert(isNormalized(query), "findNormalized: query is not unit-length") + pq := make(priorityQueue, 0, topK) + heap.Init(&pq) + for i, k := range s.keys { + var dot float32 + for j := range k { + dot += query[j] * k[j] + } + assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("findNormalized: dot %f out of [-1, 1] — keysAreNormalized invariant violated", dot)) + heap.Push(&pq, &priorityItem{similarity: dot, key: k, value: s.values[i]}) + if pq.Len() > topK { + heap.Pop(&pq) + } + } + return drainPQ(&pq) +} + +func (s *Store) findFallback(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) { + var qmag float64 + for _, v := range query { + qmag += float64(v) * float64(v) + } + qmag = math.Sqrt(qmag) + pq := make(priorityQueue, 0, topK) + heap.Init(&pq) + for i, k := range s.keys { + var dot, kmag float64 + for j := range k { + dot += float64(query[j]) * float64(k[j]) + kmag += float64(k[j]) * 
float64(k[j]) + } + denom := qmag * math.Sqrt(kmag) + var sim float32 + if denom > 0 { + sim = float32(dot / denom) + } + heap.Push(&pq, &priorityItem{similarity: sim, key: k, value: s.values[i]}) + if pq.Len() > topK { + heap.Pop(&pq) + } + } + return drainPQ(&pq) +} + func isNormalized(k []float32) bool { var sum float64 - for _, v := range k { - v64 := float64(v) - sum += v64 * v64 + sum += float64(v) * float64(v) } - - s := math.Sqrt(sum) - - return s >= 0.99 && s <= 1.01 + mag := math.Sqrt(sum) + return mag >= 0.99 && mag <= 1.01 } -// TODO: This we could replace with handwritten SIMD code -func normalizedCosineSimilarity(k1, k2 []float32) float32 { - assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) +type incomingPair struct { + key []float32 + value []byte +} - var dot float32 - for i := range len(k1) { - dot += k1[i] * k2[i] +type pairs struct { + keys [][]float32 + values [][]byte +} + +// mergeSortedPairs merges (existing, incoming) into a fresh sorted +// slice. Equal keys take the incoming value — Set is upsert. 
+func mergeSortedPairs(existingK [][]float32, existingV [][]byte, incoming []incomingPair) pairs { + assert(slices.IsSortedFunc(existingK, slices.Compare[[]float32]), "mergeSortedPairs: existing not sorted") + assert(slices.IsSortedFunc(incoming, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) }), "mergeSortedPairs: incoming not sorted") + l := len(existingK) + len(incoming) + mk := make([][]float32, 0, l) + mv := make([][]byte, 0, l) + i, j := 0, 0 + for i < len(incoming) || j < len(existingK) { + switch { + case j >= len(existingK): + mk = append(mk, incoming[i].key) + mv = append(mv, incoming[i].value) + i++ + case i >= len(incoming): + mk = append(mk, existingK[j]) + mv = append(mv, existingV[j]) + j++ + default: + c := slices.Compare(incoming[i].key, existingK[j]) + switch { + case c < 0: + mk = append(mk, incoming[i].key) + mv = append(mv, incoming[i].value) + i++ + case c > 0: + mk = append(mk, existingK[j]) + mv = append(mv, existingV[j]) + j++ + default: + mk = append(mk, incoming[i].key) + mv = append(mv, incoming[i].value) + i++ + j++ + } + } } - - assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("dot = %f", dot)) - - // 2.0 * (1.0 - dot) would be the Euclidean distance - return dot + return pairs{keys: mk, values: mv} } -type PriorityItem struct { - Similarity float32 - Key []float32 - Value []byte +type priorityItem struct { + similarity float32 + key []float32 + value []byte } -type PriorityQueue []*PriorityItem +type priorityQueue []*priorityItem -func (pq PriorityQueue) Len() int { return len(pq) } - -func (pq PriorityQueue) Less(i, j int) bool { - // Inverted because the most similar should be at the top - return pq[i].Similarity < pq[j].Similarity -} - -func (pq PriorityQueue) Swap(i, j int) { - pq[i], pq[j] = pq[j], pq[i] -} - -func (pq *PriorityQueue) Push(x any) { - item := x.(*PriorityItem) - *pq = append(*pq, item) -} - -func (pq *PriorityQueue) Pop() any { +func (pq priorityQueue) Len() int { return len(pq) } +func (pq 
priorityQueue) Less(i, j int) bool { return pq[i].similarity < pq[j].similarity } +func (pq priorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] } +func (pq *priorityQueue) Push(x any) { *pq = append(*pq, x.(*priorityItem)) } +func (pq *priorityQueue) Pop() any { old := *pq n := len(old) item := old[n-1] @@ -380,142 +336,16 @@ func (pq *PriorityQueue) Pop() any { return item } -func (s *Store) StoresFindNormalized(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { - tk := opts.Key.Floats - top_ks := make(PriorityQueue, 0, int(opts.TopK)) - heap.Init(&top_ks) - - for i, k := range s.keys { - sim := normalizedCosineSimilarity(tk, k) - heap.Push(&top_ks, &PriorityItem{ - Similarity: sim, - Key: k, - Value: s.values[i], - }) - - if top_ks.Len() > int(opts.TopK) { - heap.Pop(&top_ks) - } - } - - similarities := make([]float32, top_ks.Len()) - pbKeys := make([]*pb.StoresKey, top_ks.Len()) - pbValues := make([]*pb.StoresValue, top_ks.Len()) - - for i := top_ks.Len() - 1; i >= 0; i-- { - item := heap.Pop(&top_ks).(*PriorityItem) - - similarities[i] = item.Similarity - pbKeys[i] = &pb.StoresKey{ - Floats: item.Key, - } - pbValues[i] = &pb.StoresValue{ - Bytes: item.Value, - } - } - - return pb.StoresFindResult{ - Keys: pbKeys, - Values: pbValues, - Similarities: similarities, - }, nil -} - -func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 { - assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) - - var dot, mag2 float64 - for i := range len(k1) { - dot += float64(k1[i] * k2[i]) - mag2 += float64(k2[i] * k2[i]) - } - - sim := float32(dot / (mag1 * math.Sqrt(mag2))) - - assert(sim >= -1.01 && sim <= 1.01, fmt.Sprintf("sim = %f", sim)) - - return sim -} - -func (s *Store) StoresFindFallback(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { - tk := opts.Key.Floats - top_ks := make(PriorityQueue, 0, int(opts.TopK)) - heap.Init(&top_ks) - - var mag1 float64 - for _, v := range tk { - mag1 += 
float64(v * v) - } - mag1 = math.Sqrt(mag1) - - for i, k := range s.keys { - dist := cosineSimilarity(tk, k, mag1) - heap.Push(&top_ks, &PriorityItem{ - Similarity: dist, - Key: k, - Value: s.values[i], - }) - - if top_ks.Len() > int(opts.TopK) { - heap.Pop(&top_ks) - } - } - - similarities := make([]float32, top_ks.Len()) - pbKeys := make([]*pb.StoresKey, top_ks.Len()) - pbValues := make([]*pb.StoresValue, top_ks.Len()) - - for i := top_ks.Len() - 1; i >= 0; i-- { - item := heap.Pop(&top_ks).(*PriorityItem) - - similarities[i] = item.Similarity - pbKeys[i] = &pb.StoresKey{ - Floats: item.Key, - } - pbValues[i] = &pb.StoresValue{ - Bytes: item.Value, - } - } - - return pb.StoresFindResult{ - Keys: pbKeys, - Values: pbValues, - Similarities: similarities, - }, nil -} - -func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { - tk := opts.Key.Floats - - if len(tk) != s.keyLen { - return pb.StoresFindResult{}, fmt.Errorf("Try to find key with length %d when existing length is %d", len(tk), s.keyLen) - } - - if opts.TopK < 1 { - return pb.StoresFindResult{}, fmt.Errorf("opts.TopK = %d, must be >= 1", opts.TopK) - } - - if s.keyLen == -1 { - s.keyLen = len(opts.Key.Floats) - } else { - if len(opts.Key.Floats) != s.keyLen { - return pb.StoresFindResult{}, fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Key.Floats), s.keyLen) - } - } - - if s.keysAreNormalized && isNormalized(tk) { - return s.StoresFindNormalized(opts) - } else { - if s.keysAreNormalized { - var sample []float32 - if len(s.keys) > 5 { - sample = tk[:5] - } else { - sample = tk - } - xlog.Debug("Trying to compare non-normalized key with normalized keys", "sample", sample) - } - - return s.StoresFindFallback(opts) +func drainPQ(pq *priorityQueue) (keys [][]float32, values [][]byte, similarities []float32) { + n := pq.Len() + keys = make([][]float32, n) + values = make([][]byte, n) + similarities = make([]float32, n) + for i := n - 1; i >= 0; i-- 
{ + item := heap.Pop(pq).(*priorityItem) + keys[i] = item.key + values[i] = item.value + similarities[i] = item.similarity } + return keys, values, similarities } diff --git a/backend/go/local-store/store_suite_test.go b/backend/go/local-store/store_suite_test.go new file mode 100644 index 000000000..63affb46b --- /dev/null +++ b/backend/go/local-store/store_suite_test.go @@ -0,0 +1,13 @@ +package main + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestLocalStore(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "local-store test suite") +} diff --git a/backend/go/local-store/store_test.go b/backend/go/local-store/store_test.go new file mode 100644 index 000000000..2043647c0 --- /dev/null +++ b/backend/go/local-store/store_test.go @@ -0,0 +1,284 @@ +package main + +// Regression suite for the local-store gRPC backend. Exercises the +// Stores{Set,Get,Find,Delete} surface — the only public contract. +// Callers (face/voice recognition, the routing KNN classifier) reach +// this code via grpc.Backend, so testing at the wire-shaped boundary +// matches the production import shape. + +import ( + "math" + "math/rand/v2" + "testing" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("StoresSet", func() { + It("rejects empty input", func() { + Expect(NewStore().StoresSet(&pb.StoresSetOptions{})).NotTo(Succeed(), "Set with no keys should fail") + }) + + It("rejects key/value length mismatch", func() { + err := NewStore().StoresSet(&pb.StoresSetOptions{ + Keys: wrapKeys([][]float32{{1, 0, 0}}), + Values: wrapValues([][]byte{[]byte("a"), []byte("b")}), + }) + Expect(err).To(HaveOccurred(), "len(keys) != len(values) should fail") + }) + + It("rejects dimension mismatch on later add", func() { + s := NewStore() + mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("3d")}) + err := s.StoresSet(&pb.StoresSetOptions{ + Keys: wrapKeys([][]float32{{1, 0}}), + Values: wrapValues([][]byte{[]byte("2d")}), + }) + Expect(err).To(HaveOccurred(), "dimension mismatch on later Set should fail") + }) + + It("rejects dimension mismatch within batch", func() { + err := NewStore().StoresSet(&pb.StoresSetOptions{ + Keys: wrapKeys([][]float32{{1, 0, 0}, {1, 0}}), + Values: wrapValues([][]byte{[]byte("3d"), []byte("2d")}), + }) + Expect(err).To(HaveOccurred(), "mixed-dimension within one batch should fail") + }) + + It("merges sorted and updates existing key", func() { + s := NewStore() + mustSet(s, [][]float32{{0.3, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("c"), []byte("a")}) + mustSet(s, [][]float32{{0.2, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("b"), []byte("a-updated")}) + Expect(s.keys).To(HaveLen(3)) + got := singleGet(s, []float32{0.1, 0, 0}) + Expect(string(got)).To(Equal("a-updated")) + }) +}) + +var _ = Describe("StoresGet", func() { + It("round-trips multi-key", func() { + s := NewStore() + mustSet(s, + [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}}, + [][]byte{[]byte("a"), []byte("b"), []byte("c")}, + ) + res, err := s.StoresGet(&pb.StoresGetOptions{ + Keys: wrapKeys([][]float32{{0.7, 0.8, 0.9}, {0.1, 0.2, 0.3}}), + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Keys).To(HaveLen(2)) + }) + + 
It("omits missing keys rather than erroring", func() { + s := NewStore() + mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")}) + res, err := s.StoresGet(&pb.StoresGetOptions{ + Keys: wrapKeys([][]float32{{0.1, 0, 0}, {0.9, 0, 0}}), + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Keys).To(HaveLen(1)) + }) +}) + +var _ = Describe("StoresDelete", func() { + It("removes and preserves sort", func() { + s := NewStore() + mustSet(s, + [][]float32{{0.1, 0, 0}, {0.2, 0, 0}, {0.3, 0, 0}, {0.4, 0, 0}}, + [][]byte{[]byte("a"), []byte("b"), []byte("c"), []byte("d")}, + ) + Expect(s.StoresDelete(&pb.StoresDeleteOptions{ + Keys: wrapKeys([][]float32{{0.2, 0, 0}, {0.4, 0, 0}}), + })).To(Succeed()) + Expect(s.keys).To(HaveLen(2)) + }) + + It("tolerates missing keys", func() { + s := NewStore() + mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")}) + Expect(s.StoresDelete(&pb.StoresDeleteOptions{ + Keys: wrapKeys([][]float32{{0.9, 0, 0}}), + })).To(Succeed(), "delete of missing key should succeed") + Expect(s.keys).To(HaveLen(1)) + }) +}) + +var _ = Describe("StoresFind", func() { + It("returns normalized top-K", func() { + s := NewStore() + mustSet(s, + [][]float32{ + normalizeVec([]float32{1, 0, 0}), + normalizeVec([]float32{0, 1, 0}), + normalizeVec([]float32{0, 0, 1}), + }, + [][]byte{[]byte("x"), []byte("y"), []byte("z")}, + ) + res, err := s.StoresFind(&pb.StoresFindOptions{ + Key: &pb.StoresKey{Floats: normalizeVec([]float32{0.9, 0.1, 0})}, + TopK: 2, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Keys).To(HaveLen(2)) + Expect(res.Similarities[0]).To(BeNumerically(">=", res.Similarities[1]), "results not sorted desc by similarity") + Expect(string(res.Values[0].Bytes)).To(Equal("x")) + }) + + It("falls back for non-normalized keys", func() { + s := NewStore() + mustSet(s, [][]float32{{2, 0, 0}, {0, 3, 0}}, [][]byte{[]byte("x"), []byte("y")}) + Expect(s.keysAreNormalized).To(BeFalse(), "store should report non-normalized after Set with magnitude > 
1") + res, err := s.StoresFind(&pb.StoresFindOptions{ + Key: &pb.StoresKey{Floats: []float32{4, 0, 0}}, + TopK: 1, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(string(res.Values[0].Bytes)).To(Equal("x")) + Expect(res.Similarities[0]).To(BeNumerically(">=", float32(0.99))) + Expect(res.Similarities[0]).To(BeNumerically("<=", float32(1.01))) + }) + + It("rejects zero topK", func() { + s := NewStore() + mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")}) + _, err := s.StoresFind(&pb.StoresFindOptions{ + Key: &pb.StoresKey{Floats: []float32{1, 0, 0}}, + TopK: 0, + }) + Expect(err).To(HaveOccurred(), "Find with topK=0 should fail") + }) + + It("rejects dimension mismatch", func() { + s := NewStore() + mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")}) + _, err := s.StoresFind(&pb.StoresFindOptions{ + Key: &pb.StoresKey{Floats: []float32{1, 0}}, + TopK: 1, + }) + Expect(err).To(HaveOccurred(), "Find with mismatched dimension should fail") + }) + + It("returns empty result on empty store", func() { + res, err := NewStore().StoresFind(&pb.StoresFindOptions{ + Key: &pb.StoresKey{Floats: []float32{1, 0, 0}}, + TopK: 5, + }) + Expect(err).NotTo(HaveOccurred(), "Find on empty store should succeed") + Expect(res.Keys).To(BeEmpty()) + }) + + It("handles topK larger than store", func() { + s := NewStore() + mustSet(s, + [][]float32{normalizeVec([]float32{1, 0, 0}), normalizeVec([]float32{0, 1, 0})}, + [][]byte{[]byte("x"), []byte("y")}, + ) + res, err := s.StoresFind(&pb.StoresFindOptions{ + Key: &pb.StoresKey{Floats: normalizeVec([]float32{1, 0, 0})}, + TopK: 10, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Keys).To(HaveLen(2)) + }) +}) + +var _ = Describe("StoresLoad", func() { + It("is a no-op", func() { + Expect(NewStore().Load(&pb.ModelOptions{Model: "any-namespace"})).To(Succeed()) + }) +}) + +func BenchmarkStoresFindNormalized(b *testing.B) { + const dim = 768 + for _, n := range []int{8, 32, 128, 512} { + b.Run(fmtN(n), func(b *testing.B) { + s 
:= buildStore(b, n, dim) + query := normalizeVec(randVec(dim, 42)) + req := &pb.StoresFindOptions{Key: &pb.StoresKey{Floats: query}, TopK: 1} + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := s.StoresFind(req); err != nil { + b.Fatal(err) + } + } + }) + } +} + +// --- test helpers --- + +func mustSet(s *Store, keys [][]float32, values [][]byte) { + ExpectWithOffset(1, s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)})).To(Succeed()) +} + +func singleGet(s *Store, key []float32) []byte { + res, err := s.StoresGet(&pb.StoresGetOptions{Keys: wrapKeys([][]float32{key})}) + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + if len(res.Values) == 0 { + return nil + } + return res.Values[0].Bytes +} + +func wrapKeys(in [][]float32) []*pb.StoresKey { + out := make([]*pb.StoresKey, len(in)) + for i, k := range in { + out[i] = &pb.StoresKey{Floats: k} + } + return out +} + +func wrapValues(in [][]byte) []*pb.StoresValue { + out := make([]*pb.StoresValue, len(in)) + for i, v := range in { + out[i] = &pb.StoresValue{Bytes: v} + } + return out +} + +func buildStore(tb testing.TB, n, dim int) *Store { + tb.Helper() + s := NewStore() + keys := make([][]float32, n) + values := make([][]byte, n) + for i := 0; i < n; i++ { + keys[i] = normalizeVec(randVec(dim, int64(i)+1)) + values[i] = []byte{byte(i)} + } + if err := s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)}); err != nil { + tb.Fatal(err) + } + return s +} + +func randVec(dim int, seed int64) []float32 { + r := rand.New(rand.NewPCG(uint64(seed), 0xabcdef)) + v := make([]float32, dim) + for i := range v { + v[i] = float32(r.NormFloat64()) + } + return v +} + +func normalizeVec(v []float32) []float32 { + var sum float64 + for _, x := range v { + sum += float64(x) * float64(x) + } + mag := math.Sqrt(sum) + if mag == 0 { + return v + } + out := make([]float32, len(v)) + for i, x := range v { + out[i] = float32(float64(x) / mag) + } + return out +} + 
+func fmtN(n int) string { + return map[int]string{8: "n=8", 32: "n=32", 128: "n=128", 512: "n=512"}[n] +} diff --git a/backend/python/transformers/backend.py b/backend/python/transformers/backend.py index f2f70acb3..a8c1840b3 100644 --- a/backend/python/transformers/backend.py +++ b/backend/python/transformers/backend.py @@ -26,7 +26,7 @@ import torch.cuda XPU=os.environ.get("XPU", "0") == "1" import transformers as transformers_module -from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria +from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, pipeline from scipy.io import wavfile from sentence_transformers import SentenceTransformer @@ -200,6 +200,21 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): autoTokenizer = False self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode) self.SentenceTransformer = True + elif request.Type == "TokenClassification": + # NER / PII tagging via HuggingFace's token-classification + # pipeline. aggregation_strategy="simple" merges B-/I- tags + # into single spans and gives character offsets back. The + # tokenizer is bundled inside the pipeline, so we skip the + # AutoTokenizer load below. 
+ autoTokenizer = False + self.tokenClassifier = pipeline( + "token-classification", + model=model_name, + aggregation_strategy="simple", + device=0 if self.CUDA else -1, + trust_remote_code=request.TrustRemoteCode, + ) + self.TokenClassification = True else: # Generic: dynamically resolve model class from transformers model_type = TYPE_ALIASES.get(request.Type, request.Type) @@ -253,6 +268,39 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(message="Model loaded successfully", success=True) + def TokenClassify(self, request, context): + # Runs HuggingFace's token-classification pipeline and returns + # the aggregated entity spans (aggregation_strategy="simple", + # set at load time). start/end are character offsets into the + # Python str, not UTF-8 byte offsets; NOTE(review): confirm the + # Go caller converts before slicing non-ASCII text. + if not getattr(self, "TokenClassification", False): + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("model was not loaded as Type=TokenClassification") + return backend_pb2.TokenClassifyResponse() + try: + results = self.tokenClassifier(request.text) + except Exception as err: + print("TokenClassify error:", err, file=sys.stderr) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"token-classification failed: {err}") + return backend_pb2.TokenClassifyResponse() + + threshold = request.threshold if request.threshold > 0 else 0.0 + entities = [] + for r in results: + score = float(r.get("score", 0.0)) + if score < threshold: + continue + entities.append(backend_pb2.TokenClassifyEntity( + entity_group=str(r.get("entity_group") or r.get("entity") or ""), + start=int(r.get("start", 0)), + end=int(r.get("end", 0)), + score=score, + text=str(r.get("word", "")), + )) + return backend_pb2.TokenClassifyResponse(entities=entities) + def Embedding(self, request, context): set_seed(request.Seed) # 
Tokenize input diff --git a/backend/python/vllm/backend.py b/backend/python/vllm/backend.py index 967c4420c..74598660b 100644 --- a/backend/python/vllm/backend.py +++ b/backend/python/vllm/backend.py @@ -356,6 +356,133 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): except Exception as e: return backend_pb2.Result(success=False, message=str(e)) + async def Score(self, request, context): + """ + Joint log-probability of each candidate continuation given the + shared prompt. Used by routing-policy multi-label classification + (read the distribution rather than asking the model to emit a + single argmax label), reranking, and reward-model scoring. + + Implementation uses vLLM's `prompt_logprobs` to recover the + per-token log P(token_i | tokens_= len(prompt_logprobs) or prompt_logprobs[position] is None: + continue + entry = prompt_logprobs[position] + lp_obj = entry.get(tok_id) + if lp_obj is not None: + lp = lp_obj.logprob + else: + # Token not in top-K; vLLM's top-1 may miss it. + # Fall back to the lowest available logprob in the + # entry — a conservative lower-bound on the true + # log P, biased against this candidate. 
+ lp = min(v.logprob for v in entry.values()) + total += lp + if request.include_token_logprobs: + tokens_proto.append(backend_pb2.TokenLogProb( + token=self.tokenizer.decode([tok_id]), + log_prob=lp, + )) + + cs = backend_pb2.CandidateScore( + log_prob=total, + num_tokens=num_candidate_tokens, + ) + if request.length_normalize and num_candidate_tokens > 0: + cs.length_normalized_log_prob = total / num_candidate_tokens + if tokens_proto: + cs.tokens.extend(tokens_proto) + results.append(cs) + + return backend_pb2.ScoreResponse(candidates=results) + except Exception as e: + print(f"Score error: {e}", file=sys.stderr) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + return backend_pb2.ScoreResponse() + async def _predict(self, request, context, streaming=False): # Build the sampling parameters # NOTE: this must stay in sync with the vllm backend diff --git a/core/application/application.go b/core/application/application.go index 852324e74..7a34279c9 100644 --- a/core/application/application.go +++ b/core/application/application.go @@ -9,11 +9,18 @@ import ( corebackend "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/auth" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" "github.com/mudler/LocalAI/core/services/agentpool" "github.com/mudler/LocalAI/core/services/facerecognition" "github.com/mudler/LocalAI/core/services/galleryop" + "github.com/mudler/LocalAI/core/services/monitoring" "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/LocalAI/core/services/routing/admission" + "github.com/mudler/LocalAI/core/services/routing/billing" + "github.com/mudler/LocalAI/core/services/cloudproxy/mitm" + "github.com/mudler/LocalAI/core/services/routing/pii" + "github.com/mudler/LocalAI/core/services/routing/router" "github.com/mudler/LocalAI/core/services/voicerecognition" "github.com/mudler/LocalAI/core/templates" pkggrpc 
"github.com/mudler/LocalAI/pkg/grpc" @@ -51,6 +58,22 @@ type Application struct { faceRegistry facerecognition.Registry voiceRegistry voicerecognition.Registry authDB *gorm.DB + metricsService *monitoring.LocalAIMetricsService + statsRecorder *billing.Recorder + fallbackUser *auth.User + piiRedactor *pii.Redactor + piiEvents pii.EventStore + mitmCA atomic.Pointer[mitm.CA] + mitmServer atomic.Pointer[mitm.Server] + mitmMutex sync.Mutex // serializes Stop+Start; readers use atomic loads + // mitmHostConflicts records duplicate-host claims across model configs. + // Non-empty disables the MITM listener until resolved — the strict + // 1-to-1 host↔model invariant the dispatcher relies on. Read by + // /api/middleware/status so the admin UI can surface the cause. + mitmHostConflicts atomic.Pointer[map[string][]string] + routerDecisions router.DecisionStore + routerRegistry *router.Registry + admissionLimiter *admission.Limiter watchdogMutex sync.Mutex watchdogStop chan bool p2pMutex sync.Mutex @@ -185,6 +208,103 @@ func (a *Application) AuthDB() *gorm.DB { return a.authDB } +// MetricsService returns the OTel + Prometheus metric service. nil when +// --disable-metrics is set or initialisation failed at startup. +// +// The service is created in startup.go before any counter is registered +// so that otel.SetMeterProvider runs early enough for the billing +// recorder's counters to bind to the Prom-backed provider rather than +// the no-op global. core/http/app.go reuses this instance instead of +// constructing its own — two providers would orphan one set of counters +// behind whichever provider lost the SetMeterProvider race. +func (a *Application) MetricsService() *monitoring.LocalAIMetricsService { + return a.metricsService +} + +// StatsRecorder returns the billing recorder used by the usage +// middleware. It is non-nil whenever stats are not explicitly disabled +// — i.e., the no-auth single-user path still gets a working recorder +// (in-memory by default). 
Routes register UsageMiddleware against this +// recorder regardless of auth state. +func (a *Application) StatsRecorder() *billing.Recorder { + return a.statsRecorder +} + +// FallbackUser is the synthetic "local" user that UsageMiddleware uses +// to attribute requests when no authenticated user is on the context. +// It is wired whenever stats are enabled, regardless of auth state, and +// is nil only when stats are disabled (--disable-stats). +func (a *Application) FallbackUser() *auth.User { + return a.fallbackUser +} + +// PIIRedactor returns the regex-tier PII redactor or nil if PII +// filtering is disabled. The chat-route middleware uses this to apply +// redaction before dispatch. +func (a *Application) PIIRedactor() *pii.Redactor { + return a.piiRedactor +} + +// PIIEvents returns the PII event store. Same nil-when-disabled +// semantics as PIIRedactor; admin REST and MCP read tools call List +// against it. +func (a *Application) PIIEvents() pii.EventStore { + return a.piiEvents +} + +// MITMCA returns the cloudproxy MITM proxy's CA, or nil when the +// MITM listener is disabled. +func (a *Application) MITMCA() *mitm.CA { return a.mitmCA.Load() } + +// MITMServer returns the running MITM proxy or nil. +func (a *Application) MITMServer() *mitm.Server { return a.mitmServer.Load() } + +// MITMHostConflicts returns a snapshot of host→[]model-name pairs that +// are claimed by 2+ model configs. Empty when the 1-to-1 invariant +// holds. Non-empty disables the MITM listener — read by the admin +// status endpoint to explain why. +func (a *Application) MITMHostConflicts() map[string][]string { + p := a.mitmHostConflicts.Load() + if p == nil { + return nil + } + return *p +} + +// MITMHostOwners returns the host→model-name map, useful for the +// admin status endpoint. The lookup is recomputed on each call to +// stay current with model-config edits without needing a +// MITMRestart. 
+func (a *Application) MITMHostOwners() map[string]string { + if a.backendLoader == nil { + return nil + } + return a.backendLoader.MITMHostOwners().Owners +} + +// RouterDecisions returns the routing decision store. nil when stats +// are disabled (--disable-stats); the RouteModel middleware skips the +// log write in that case but still rewrites requests. +func (a *Application) RouterDecisions() router.DecisionStore { + return a.routerDecisions +} + +// RouterClassifierRegistry returns the process-wide classifier cache. +// Shared between the OpenAI and Anthropic route middlewares so the +// admin stats endpoint sees every live classifier — and so a +// classifier built on the OpenAI route is reused on Anthropic. +func (a *Application) RouterClassifierRegistry() *router.Registry { + return a.routerRegistry +} + +// AdmissionLimiter returns the per-model admission limiter. The +// admission middleware uses it to gate concurrent requests; the +// admin status surface reads InFlight/Capacity from it for live +// load visibility. +func (a *Application) AdmissionLimiter() *admission.Limiter { + return a.admissionLimiter +} + // StartupConfig returns the original startup configuration (from env vars, before file loading) func (a *Application) StartupConfig() *config.ApplicationConfig { return a.startupConfig @@ -255,6 +375,15 @@ func (a *Application) start() error { a.modelLoader, a.galleryService, ) + // Wire usage tracking so the assistant's get_usage_stats tool + // returns real data; nil values keep the tool returning a clear + // "unavailable" error if startup ran with --disable-stats. + assistantClient.StatsRecorder = a.statsRecorder + assistantClient.FallbackUser = a.fallbackUser + // PII filter — same nil-or-real wiring. 
+ assistantClient.PIIRedactor = a.piiRedactor + assistantClient.PIIEvents = a.piiEvents + assistantClient.RouterDecisions = a.routerDecisions if err := holder.Initialize(a.applicationConfig.Context, assistantClient, localaitools.Options{}); err != nil { // Why log+continue instead of fail: the assistant is an optional // feature; a failure here must not take down the whole server. diff --git a/core/application/mitm.go b/core/application/mitm.go new file mode 100644 index 000000000..293b3d449 --- /dev/null +++ b/core/application/mitm.go @@ -0,0 +1,146 @@ +package application + +import ( + "errors" + "fmt" + "path/filepath" + "sort" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/services/cloudproxy/mitm" + "github.com/mudler/xlog" +) + +func startMITMProxy(app *Application, options *config.ApplicationConfig) error { + app.mitmMutex.Lock() + defer app.mitmMutex.Unlock() + return startMITMLocked(app, options) +} + +func startMITMLocked(app *Application, options *config.ApplicationConfig) error { + // Validate the host↔model-config 1-to-1 invariant before binding + // the listener. Two configs claiming the same host means the + // dispatcher would have ambiguous PII settings; refuse to start + // rather than silently picking one. The conflict map is published + // for /api/middleware/status to surface in the UI. 
+ ownership := app.backendLoader.MITMHostOwners() + if len(ownership.Conflicts) > 0 { + conflicts := ownership.Conflicts + app.mitmHostConflicts.Store(&conflicts) + hosts := make([]string, 0, len(conflicts)) + for h := range conflicts { + hosts = append(hosts, h) + } + sort.Strings(hosts) + xlog.Error("mitm: refusing to start — duplicate host claims across model configs", + "hosts", hosts, + "conflicts", conflicts, + ) + return errors.New("mitm: configuration error: duplicate host claims (see /api/middleware/status)") + } + app.mitmHostConflicts.Store(nil) + + caDir := options.MITMCADir + if caDir == "" { + base := options.DataPath + if base == "" { + base = "." + } + caDir = filepath.Join(base, "mitm-ca") + } + + if app.mitmCA.Load() == nil { + ca, err := mitm.LoadOrCreateCA(caDir) + if err != nil { + return fmt.Errorf("ca: %w", err) + } + app.mitmCA.Store(ca) + } + + // Allowlist is exactly the set of hosts claimed by model configs. + // No global list — admins add hosts by creating an MITM model + // config (template available in the Add Model UI). When no config + // claims any host, the listener still starts but every CONNECT + // tunnels through unmodified. + effectiveHosts := make([]string, 0, len(ownership.Owners)) + for h := range ownership.Owners { + effectiveHosts = append(effectiveHosts, h) + } + sort.Strings(effectiveHosts) + + // Per-host PII gate inherits from the owning model's pii.enabled. + // A non-cloud-proxy backend with no explicit pii.enabled resolves + // to false → host is intercepted but the regex pass is skipped + // (audit events still record). 
+ var piiDisabled []string + for host, modelName := range ownership.Owners { + cfg, exists := app.backendLoader.GetModelConfig(modelName) + if !exists { + continue + } + if !cfg.PIIIsEnabled() { + piiDisabled = append(piiDisabled, host) + } + } + + handler := mitm.NewPIIHandler(mitm.PIIHandlerOptions{ + Redactor: app.piiRedactor, + EventStore: app.piiEvents, + HostsWithPIIDisabled: piiDisabled, + }) + + srv, err := mitm.NewServer(mitm.Config{ + Addr: options.MITMListen, + CA: app.mitmCA.Load(), + InterceptHosts: effectiveHosts, + Handler: handler, + EventStore: app.piiEvents, + }) + if err != nil { + return fmt.Errorf("server: %w", err) + } + if err := srv.Start(); err != nil { + return fmt.Errorf("listen: %w", err) + } + app.mitmServer.Store(srv) + + xlog.Info("mitm: cloudproxy listener started", + "addr", srv.Addr(), + "ca_dir", caDir, + "intercept_hosts", effectiveHosts, + "model_owned_hosts", len(ownership.Owners), + "pii_disabled_hosts", len(piiDisabled), + ) + return nil +} + +// StopMITM is idempotent. +func (a *Application) StopMITM() error { + a.mitmMutex.Lock() + defer a.mitmMutex.Unlock() + stopMITMLocked(a) + return nil +} + +// RestartMITM reuses the existing CA so trusted clients keep +// working across listener flips. 
+func (a *Application) RestartMITM() error { + a.mitmMutex.Lock() + defer a.mitmMutex.Unlock() + stopMITMLocked(a) + if a.applicationConfig.MITMListen == "" { + xlog.Info("mitm: cloudproxy listener stays disabled (no listen address)") + return nil + } + return startMITMLocked(a, a.applicationConfig) +} + +func stopMITMLocked(a *Application) { + srv := a.mitmServer.Load() + if srv == nil { + return + } + srv.Stop() + a.mitmServer.Store(nil) + xlog.Info("mitm: cloudproxy listener stopped") +} diff --git a/core/application/router_factories.go b/core/application/router_factories.go new file mode 100644 index 000000000..d37cfb9d8 --- /dev/null +++ b/core/application/router_factories.go @@ -0,0 +1,63 @@ +package application + +import ( + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" +) + +// adapterConfig resolves a model name to its runtime ModelConfig, or +// nil when the name is unknown. Shared by the router-facing factories +// below and by ModelConfigLookup. +func (a *Application) adapterConfig(modelName string) *config.ModelConfig { + cfg, err := a.backendLoader.LoadModelConfigFileByNameDefaultOptions(modelName, a.applicationConfig) + if err != nil || cfg == nil { + return nil + } + return cfg +} + +// ModelConfigLookup is the lookup function the router middleware's +// classifier validator uses to confirm classifier_model declares +// FLAG_SCORE before binding it. +func (a *Application) ModelConfigLookup() func(modelName string) *config.ModelConfig { + return a.adapterConfig +} + +// Scorer returns a backend.Scorer bound to the named model, or nil +// when the model is unknown. Used as a method value (app.Scorer) by +// router.ClassifierDeps — no factory-of-factory wrapper needed. 
+func (a *Application) Scorer(modelName string) backend.Scorer { + cfg := a.adapterConfig(modelName) + if cfg == nil { + return nil + } + return backend.NewScorer(a.modelLoader, *cfg, a.applicationConfig) +} + +// Reranker returns a backend.Reranker bound to the named model, or +// nil when unknown. The reranker model's `type:` (e.g. "colbert") +// selects the scoring head inside the rerankers backend. +func (a *Application) Reranker(modelName string) backend.Reranker { + cfg := a.adapterConfig(modelName) + if cfg == nil { + return nil + } + return backend.NewReranker(a.modelLoader, *cfg, a.applicationConfig) +} + +// Embedder returns a backend.Embedder bound to the named model, or +// nil when unknown. Used by the router's L2 embedding cache. +func (a *Application) Embedder(modelName string) backend.Embedder { + cfg := a.adapterConfig(modelName) + if cfg == nil { + return nil + } + return backend.NewEmbedder(a.modelLoader, *cfg, a.applicationConfig) +} + +// VectorStore returns a backend.VectorStore for the named collection, +// or nil when the name is empty. Each router model gets its own +// backend process via the model loader's cache keyed by storeName. +func (a *Application) VectorStore(storeName string) backend.VectorStore { + return backend.NewVectorStore(a.modelLoader, a.applicationConfig, storeName) +} diff --git a/core/application/runtime_settings_branding_test.go b/core/application/runtime_settings_branding_test.go index 9f173864e..6300f4456 100644 --- a/core/application/runtime_settings_branding_test.go +++ b/core/application/runtime_settings_branding_test.go @@ -87,6 +87,28 @@ var _ = Describe("loadRuntimeSettingsFromFile", func() { }) }) + // MITM listener address. The file is the only source — no env var + // exists — so a regression here means an admin who configured the + // listener via /api/settings loses it after a reboot, even though + // the value is still on disk in the volume. 
(Intercept hosts now + // live in model YAML mitm.hosts: blocks, not runtime_settings.json.) + Describe("MITM fields", func() { + It("loads mitm_listen", func() { + cfg := &config.ApplicationConfig{DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`)} + loadRuntimeSettingsFromFile(cfg) + Expect(cfg.MITMListen).To(Equal(":8443")) + }) + + It("does not override an explicit CLI flag", func() { + cfg := &config.ApplicationConfig{ + DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`), + MITMListen: ":9999", // simulate WithMITMListen(":9999") + } + loadRuntimeSettingsFromFile(cfg) + Expect(cfg.MITMListen).To(Equal(":9999"), "CLI flag must win over the persisted file value") + }) + }) + // The Agent Pool block has a mix of zero and non-zero defaults // (Enabled=true, EmbeddingModel="granite-...", MaxChunkingSize=400, // VectorEngine="chromem", AgentHubURL="https://agenthub.localai.io"). diff --git a/core/application/startup.go b/core/application/startup.go index 183713f86..8268112bc 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -15,8 +15,14 @@ import ( "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/LocalAI/core/services/jobs" + "github.com/mudler/LocalAI/core/services/monitoring" "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/LocalAI/core/services/routing/admission" + "github.com/mudler/LocalAI/core/services/routing/billing" + "github.com/mudler/LocalAI/core/services/routing/pii" + "github.com/mudler/LocalAI/core/services/routing/router" "github.com/mudler/LocalAI/core/services/storage" + "github.com/mudler/LocalAI/pkg/signals" coreStartup "github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/vram" @@ -128,6 +134,117 @@ func New(opts ...config.AppOption) (*Application, error) { }() } + // Initialize the OTel + Prometheus metric pipeline before any + // counter is created. 
monitoring.NewLocalAIMetricsService calls + // otel.SetMeterProvider, so any subsequent otel.Meter() call — + // including billing.NewRecorder below — sees the real provider + // rather than the no-op global. Initialising metrics later (in + // core/http/app.go) leaves billing's counters bound to a no-op + // meter and never reaches /metrics. DisableMetrics is honoured + // here too: when it is set no metrics service is created and + // billing.SetMeter below is not called. + if !options.DisableMetrics { + ms, err := monitoring.NewLocalAIMetricsService() + if err != nil { + xlog.Error("failed to initialize metrics provider", "error", err) + } else { + application.metricsService = ms + // Bind the billing package's counters to the same meter the + // metrics service exports. Without this, billing's counters + // resolve via the OTel global and never reach /metrics. + billing.SetMeter(ms.Meter) + } + } + + // Wire the routing-module billing recorder. The recorder runs in + // every mode (auth on/off, distributed/single-node) so that token + // tracking is not gated on auth — a no-auth single-user box still + // gets dashboards and `/api/usage` populated. + // + // fallbackUser is wired *unconditionally* when stats are enabled. + // UsageMiddleware uses it as the attribution source whenever + // auth.GetUser(c) is nil — that covers (a) no-auth deployments and + // (b) internal callers under auth-on (cron flushers, distributed + // worker callbacks) that hit a recordable endpoint without a user + // in context. The billing.user_id_present invariant still rejects + // empty IDs; LocalUser() returns a stable UUID per data path. 
+ if !options.DisableStats { + var statsBackend billing.StatsBackend + switch { + case application.authDB != nil: + statsBackend = billing.NewGormBackend(application.authDB, 0, 0) + xlog.Info("stats: using auth DB for usage records") + default: + statsBackend = billing.NewMemoryBackend(0) + xlog.Info("stats: using in-memory ring buffer (no-auth single-user mode)") + } + application.fallbackUser = billing.LocalUser(options.DataPath) + application.statsRecorder = billing.NewRecorder(statsBackend) + // Drain pending records on SIGTERM. The GORM backend buffers up + // to maxPending (5k) records across a 5s flush tick, so without + // this the last few seconds of usage disappear on graceful exit. + signals.RegisterGracefulTerminationHandler(func() { + _ = application.statsRecorder.Close() + }) + xlog.Info("stats: fallback user wired", "local_user_id", application.fallbackUser.ID) + } else { + xlog.Info("stats: disabled by --disable-stats") + } + + // Wire the regex PII filter. Default-on: a single-user box gets + // the built-in pattern set the first time it starts, with email/ + // phone/SSN/credit-card on mask and api_key_prefix on block. If + // the operator wants different actions, --pii-config points at a + // YAML file that overrides per-id; --disable-pii turns it off + // entirely. + if !options.DisablePII { + patterns, err := pii.LoadConfig(options.PIIConfigPath) + if err != nil { + return nil, fmt.Errorf("pii config: %w", err) + } + application.piiRedactor = pii.NewRedactor(patterns) + application.piiEvents = pii.NewMemoryEventStore(0) + // Apply persisted per-pattern overrides — admins toggling + // action/disabled via the UI and clicking "Save to disk" land + // here on the next start. Bad ids are warned and ignored so a + // stale entry doesn't block startup. 
+ for id, ov := range options.PIIPatternOverrides { + if ov.Action != nil { + if err := application.piiRedactor.SetAction(id, pii.Action(*ov.Action)); err != nil { + xlog.Warn("pii: persisted override skipped", "pattern", id, "error", err) + continue + } + } + if ov.Disabled != nil { + if err := application.piiRedactor.SetDisabled(id, *ov.Disabled); err != nil { + xlog.Warn("pii: persisted disable skipped", "pattern", id, "error", err) + } + } + } + xlog.Info("pii: filter enabled", + "patterns", len(patterns), + "config_path", options.PIIConfigPath, + "persisted_overrides", len(options.PIIPatternOverrides), + ) + } else { + xlog.Info("pii: disabled by --disable-pii") + } + + // Wire the routing decision log. Always-on when stats are enabled — + // the per-router admin page reads this as the live activity feed + // and as input to drift checks for subsystem 5. + if !options.DisableStats { + application.routerDecisions = router.NewMemoryDecisionStore(0) + } + // Process-wide classifier cache shared across all route middlewares so + // the embedding-cache stats endpoint sees a single source of truth. + application.routerRegistry = router.NewRegistry() + + // Subsystem 5: admission control. Limiter is always wired so a + // model that gains a limits: block via gallery install or YAML + // edit takes effect on the next restart without conditional plumbing. + application.admissionLimiter = admission.New() + // Wire JobStore for DB-backed task/job persistence whenever auth DB is available. // This ensures tasks and jobs survive restarts in both single-node and distributed modes. if application.authDB != nil && application.agentJobService != nil { @@ -291,6 +408,20 @@ func New(opts ...config.AppOption) (*Application, error) { loadRuntimeSettingsFromFile(options) } + // Wire the cloudproxy MITM listener. Opt-in: empty MITMListen + // means "no MITM" — operators must explicitly choose to start + // it because clients have to install the generated CA cert. 
+ // The handler reuses the global redactor + event store so an + // admin who's already configured PII filtering for direct API + // traffic doesn't need a parallel config for MITM traffic. + // Runs after loadRuntimeSettingsFromFile so a listener configured + // via /api/settings is brought back up across restarts. + if options.MITMListen != "" { + if err := startMITMProxy(application, options); err != nil { + return nil, fmt.Errorf("mitm: startup: %w", err) + } + } + application.ModelLoader().SetBackendLoggingEnabled(options.EnableBackendLogging) // turn off any process that was started by GRPC if the context is canceled @@ -580,6 +711,25 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) { options.Branding.FaviconFile = *settings.FaviconFile } + // MITM listener address. The CLI flag WithMITMListen populates + // options at startup; if the user configured MITM via /api/settings + // after the fact, only the file holds the value. Apply when the + // CLI flag did not already set it. (Intercept hosts now live in + // model YAML mitm.hosts: rather than runtime_settings.json.) + if settings.MITMListen != nil && options.MITMListen == "" { + options.MITMListen = *settings.MITMListen + } + + // PII pattern overrides — file is the only source; CLI flags don't + // reach into this map. Apply unconditionally when present; the + // redactor wiring below sees the result on first construction. 
+ if settings.PIIPatternOverrides != nil { + options.PIIPatternOverrides = make(map[string]config.PIIPatternRuntimeOverride, len(*settings.PIIPatternOverrides)) + for id, ov := range *settings.PIIPatternOverrides { + options.PIIPatternOverrides[id] = ov + } + } + // Backend upgrade flags if settings.AutoUpgradeBackends != nil { if !options.AutoUpgradeBackends { diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index f7944827d..dd9b9cffb 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -1,6 +1,7 @@ package backend import ( + "context" "fmt" "time" @@ -11,6 +12,32 @@ import ( model "github.com/mudler/LocalAI/pkg/model" ) +// Embedder produces a fixed-dimension vector from a prompt. The +// router's L2 embedding cache uses it to look up semantically-similar +// past decisions. +type Embedder interface { + Embed(ctx context.Context, text string) ([]float32, error) +} + +// NewEmbedder binds (loader, modelConfig, appConfig) into an Embedder. +func NewEmbedder(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Embedder { + return &modelEmbedder{loader: loader, modelConfig: modelConfig, appConfig: appConfig} +} + +type modelEmbedder struct { + loader *model.ModelLoader + modelConfig config.ModelConfig + appConfig *config.ApplicationConfig +} + +func (e *modelEmbedder) Embed(_ context.Context, text string) ([]float32, error) { + fn, err := ModelEmbedding(text, nil, e.loader, e.modelConfig, e.appConfig) + if err != nil { + return nil, err + } + return fn() +} + func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { opts := ModelOptions(modelConfig, appConfig) diff --git a/core/backend/options.go b/core/backend/options.go index a7d332344..0215bf37a 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -242,6 +242,18 @@ func grpcModelOpts(c 
config.ModelConfig, modelPath string) *pb.ModelOptions { Tokenizer: c.Tokenizer, } + if c.Backend == "cloud-proxy" { + opts.Proxy = &pb.ProxyOptions{ + UpstreamUrl: c.Proxy.UpstreamURL, + Mode: c.Proxy.Mode, + Provider: c.Proxy.Provider, + ApiKeyEnv: c.Proxy.APIKeyEnv, + ApiKeyFile: c.Proxy.APIKeyFile, + UpstreamModel: c.Proxy.UpstreamModel, + RequestTimeoutSeconds: int32(c.Proxy.RequestTimeoutSeconds), + } + } + if c.MMProj != "" { opts.MMProj = filepath.Join(modelPath, c.MMProj) } diff --git a/core/backend/rerank.go b/core/backend/rerank.go index a90c2aad1..f82208f8c 100644 --- a/core/backend/rerank.go +++ b/core/backend/rerank.go @@ -11,6 +11,51 @@ import ( model "github.com/mudler/LocalAI/pkg/model" ) +// RerankResult is the per-document score returned to consumers, +// narrowed from proto.RerankResult so callers don't need to depend on +// the proto package. +type RerankResult struct { + Index int + RelevanceScore float32 +} + +// Reranker scores a list of candidate documents against a query. +// Returns one RerankResult per input document (no top-N truncation — +// callers that need it can sort and slice). +type Reranker interface { + Rerank(ctx context.Context, query string, documents []string) ([]RerankResult, error) +} + +// NewReranker binds (loader, modelConfig, appConfig) into a Reranker. +func NewReranker(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Reranker { + return &modelReranker{loader: loader, modelConfig: modelConfig, appConfig: appConfig} +} + +type modelReranker struct { + loader *model.ModelLoader + modelConfig config.ModelConfig + appConfig *config.ApplicationConfig +} + +func (r *modelReranker) Rerank(ctx context.Context, query string, documents []string) ([]RerankResult, error) { + req := &proto.RerankRequest{ + Query: query, + Documents: documents, + // TopN=0 → backend returns scores for every document. 
Truncating + // here would silently zero out labels the reranker considered + // unlikely, which the router classifier needs. + } + res, err := Rerank(ctx, req, r.loader, r.appConfig, r.modelConfig) + if err != nil { + return nil, err + } + out := make([]RerankResult, 0, len(res.GetResults())) + for _, dr := range res.GetResults() { + out = append(out, RerankResult{Index: int(dr.GetIndex()), RelevanceScore: dr.GetRelevanceScore()}) + } + return out, nil +} + func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) { opts := ModelOptions(modelConfig, appConfig) rerankModel, err := loader.Load(opts...) diff --git a/core/backend/score.go b/core/backend/score.go new file mode 100644 index 000000000..dce06213c --- /dev/null +++ b/core/backend/score.go @@ -0,0 +1,159 @@ +package backend + +import ( + "context" + "fmt" + "time" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/trace" + "github.com/mudler/LocalAI/pkg/grpc" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + model "github.com/mudler/LocalAI/pkg/model" +) + +// ScoreOptions controls a single Score request. +type ScoreOptions struct { + // IncludeTokenLogprobs returns per-token log-probability detail for + // each candidate. Off by default — the joint LogProb is enough for + // ranking; callers that need calibration / entropy over the token + // stream opt in. + IncludeTokenLogprobs bool + // LengthNormalize divides the joint log-prob by the candidate's + // token count. Useful when comparing candidates of different + // lengths — without it, longer candidates score lower by default. + LengthNormalize bool +} + +// CandidateScore is the per-candidate result. Mirrors pb.CandidateScore +// but avoids leaking the proto type to consumers. 
+type CandidateScore struct { + LogProb float64 + LengthNormalizedLogProb float64 + NumTokens int + Tokens []TokenLogProb +} + +type TokenLogProb struct { + Token string + LogProb float64 +} + +// Scorer evaluates a model's joint log-probability of each candidate +// continuation given a shared prompt. Implemented by NewScorer over a +// model-loaded backend; the router's score classifier consumes this +// for multi-label policy selection. +type Scorer interface { + Score(ctx context.Context, prompt string, candidates []string) ([]CandidateScore, error) +} + +// NewScorer binds (loader, modelConfig, appConfig) into a Scorer. The +// underlying backend is resolved lazily on the first Score call. +// Returns nil only as a contract violation — callers that need to +// detect "model not loadable" should look up the config first. +func NewScorer(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) Scorer { + return &modelScorer{loader: loader, modelConfig: modelConfig, appConfig: appConfig} +} + +type modelScorer struct { + loader *model.ModelLoader + modelConfig config.ModelConfig + appConfig *config.ApplicationConfig +} + +func (m *modelScorer) Score(ctx context.Context, prompt string, candidates []string) ([]CandidateScore, error) { + fn, err := ModelScore(prompt, candidates, ScoreOptions{LengthNormalize: true}, m.loader, m.modelConfig, m.appConfig) + if err != nil { + return nil, err + } + return fn(ctx) +} + +// ModelScore loads the backend for modelConfig and returns a closure +// that scores `candidates` against `prompt`. The closure is bound to +// the loaded model so callers can keep it around for repeat scoring +// within the same request without re-resolving the backend. 
+func ModelScore(prompt string, candidates []string, opts ScoreOptions, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func(ctx context.Context) ([]CandidateScore, error), error) { + modelOpts := ModelOptions(modelConfig, appConfig) + inferenceModel, err := loader.Load(modelOpts...) + if err != nil { + recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) + return nil, err + } + b, ok := inferenceModel.(grpc.Backend) + if !ok { + return nil, fmt.Errorf("scoring not supported by backend %q", modelConfig.Backend) + } + if len(candidates) == 0 { + return nil, fmt.Errorf("Score: candidates must be non-empty") + } + return func(ctx context.Context) ([]CandidateScore, error) { + // Surface score calls in the Traces UI alongside the LLM calls + // they typically gate (router classifier, eval scoring). Without + // this, a router-classified request shows only the downstream LLM + // trace with no record of the classification that picked it. + var startTime time.Time + if appConfig.EnableTracing { + trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes) + startTime = time.Now() + } + resp, err := b.Score(ctx, &pb.ScoreRequest{ + Prompt: prompt, + Candidates: candidates, + IncludeTokenLogprobs: opts.IncludeTokenLogprobs, + LengthNormalize: opts.LengthNormalize, + }) + results := scoreResponseToCandidates(resp, opts.IncludeTokenLogprobs) + if appConfig.EnableTracing { + errStr := "" + if err != nil { + errStr = err.Error() + } + trace.RecordBackendTrace(trace.BackendTrace{ + Timestamp: startTime, + Duration: time.Since(startTime), + Type: trace.BackendTraceScore, + ModelName: modelConfig.Name, + Backend: modelConfig.Backend, + Summary: trace.TruncateString(prompt, 200), + Error: errStr, + Data: map[string]any{ + // Copy candidates so the trace buffer doesn't pin a + // caller-owned slice for the lifetime of the ring. 
+ "candidates": append([]string(nil), candidates...), + "results": results, + }, + }) + } + if err != nil { + return nil, err + } + return results, nil + }, nil +} + +// scoreResponseToCandidates converts the wire-format pb response into +// the value type consumed by callers. Extracted to keep ModelScore's +// closure trivial and so the conversion can be unit-tested without a +// real backend. +func scoreResponseToCandidates(resp *pb.ScoreResponse, includeTokens bool) []CandidateScore { + if resp == nil { + return nil + } + out := make([]CandidateScore, len(resp.Candidates)) + for i, c := range resp.Candidates { + cs := CandidateScore{ + LogProb: c.LogProb, + LengthNormalizedLogProb: c.LengthNormalizedLogProb, + NumTokens: int(c.NumTokens), + } + if includeTokens && len(c.Tokens) > 0 { + cs.Tokens = make([]TokenLogProb, len(c.Tokens)) + for j, t := range c.Tokens { + cs.Tokens[j] = TokenLogProb{Token: t.Token, LogProb: t.LogProb} + } + } + out[i] = cs + } + return out +} diff --git a/core/backend/score_test.go b/core/backend/score_test.go new file mode 100644 index 000000000..48193efab --- /dev/null +++ b/core/backend/score_test.go @@ -0,0 +1,63 @@ +package backend + +import ( + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("scoreResponseToCandidates", func() { + It("returns nil for a nil response", func() { + Expect(scoreResponseToCandidates(nil, false)).To(BeNil()) + }) + + It("returns an empty slice when the response has no candidates", func() { + Expect(scoreResponseToCandidates(&pb.ScoreResponse{}, false)).To(BeEmpty()) + }) + + It("copies LogProb / LengthNormalizedLogProb / NumTokens for every candidate", func() { + resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{ + {LogProb: -2.0, LengthNormalizedLogProb: -1.0, NumTokens: 2}, + {LogProb: -7.5, LengthNormalizedLogProb: -1.5, NumTokens: 5}, + }} + got := scoreResponseToCandidates(resp, false) + Expect(got).To(HaveLen(2)) + Expect(got[0].LogProb).To(Equal(-2.0)) + Expect(got[0].LengthNormalizedLogProb).To(Equal(-1.0)) + Expect(got[0].NumTokens).To(Equal(2)) + Expect(got[1].LogProb).To(Equal(-7.5)) + Expect(got[1].NumTokens).To(Equal(5)) + }) + + It("omits per-token detail when includeTokens=false even if the wire response carries it", func() { + // Defensive: if the backend over-reports we still respect the + // caller's opt-in so consumers don't pay marshaling for data + // they didn't ask for. 
+ resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{{ + LogProb: -1.0, + Tokens: []*pb.TokenLogProb{{Token: "hi", LogProb: -1.0}}, + }}} + got := scoreResponseToCandidates(resp, false) + Expect(got).To(HaveLen(1)) + Expect(got[0].Tokens).To(BeNil()) + }) + + It("populates per-token detail when includeTokens=true", func() { + resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{{ + LogProb: -3.0, + NumTokens: 2, + Tokens: []*pb.TokenLogProb{ + {Token: "Hello", LogProb: -1.0}, + {Token: " world", LogProb: -2.0}, + }, + }}} + got := scoreResponseToCandidates(resp, true) + Expect(got).To(HaveLen(1)) + Expect(got[0].Tokens).To(HaveLen(2)) + Expect(got[0].Tokens[0].Token).To(Equal("Hello")) + Expect(got[0].Tokens[0].LogProb).To(Equal(-1.0)) + Expect(got[0].Tokens[1].Token).To(Equal(" world")) + Expect(got[0].Tokens[1].LogProb).To(Equal(-2.0)) + }) +}) diff --git a/core/backend/stores.go b/core/backend/stores.go index 2fd4cc148..4884765f2 100644 --- a/core/backend/stores.go +++ b/core/backend/stores.go @@ -1,12 +1,74 @@ package backend import ( + "context" + "fmt" + "strings" + "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/store" ) +// VectorStore is the narrowed KNN store used by the router's embedding +// cache. Search returns the top-1 match (cosine similarity in [-1, 1]) +// and the serialised payload, or ok=false on a clean miss. +type VectorStore interface { + Search(ctx context.Context, vec []float32) (similarity float64, payload []byte, ok bool, err error) + Insert(ctx context.Context, vec []float32, payload []byte) error +} + +// NewVectorStore returns a VectorStore backed by the local-store +// gRPC backend, namespaced by storeName so two routers don't collide. 
+func NewVectorStore(loader *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) VectorStore { + if storeName == "" { + return nil + } + return &localVectorStore{loader: loader, appConfig: appConfig, storeName: storeName} +} + +type localVectorStore struct { + loader *model.ModelLoader + appConfig *config.ApplicationConfig + storeName string +} + +func (s *localVectorStore) backend(_ context.Context) (grpc.Backend, error) { + return StoreBackend(s.loader, s.appConfig, s.storeName, "") +} + +func (s *localVectorStore) Search(ctx context.Context, vec []float32) (float64, []byte, bool, error) { + be, err := s.backend(ctx) + if err != nil { + return 0, nil, false, fmt.Errorf("vector store load: %w", err) + } + _, values, similarities, err := store.Find(ctx, be, vec, 1) + if err != nil { + // local-store's Find returns "existing length is -1" before + // any keys are inserted. Surface that as a clean miss so the + // cache layer treats it as an empty store and proceeds to + // Insert rather than skipping. 
+ if strings.Contains(err.Error(), "existing length is -1") { + return 0, nil, false, nil + } + return 0, nil, false, fmt.Errorf("vector store find: %w", err) + } + if len(values) == 0 || len(similarities) == 0 { + return 0, nil, false, nil + } + return float64(similarities[0]), values[0], true, nil +} + +func (s *localVectorStore) Insert(ctx context.Context, vec []float32, payload []byte) error { + be, err := s.backend(ctx) + if err != nil { + return fmt.Errorf("vector store load: %w", err) + } + return store.SetSingle(ctx, be, vec, payload) +} + func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string, backend string) (grpc.Backend, error) { if backend == "" { backend = model.LocalStoreBackend diff --git a/core/cli/run.go b/core/cli/run.go index 78e8b69c1..0f01e2ab4 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -159,6 +159,10 @@ type RunCMD struct { BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"` Version bool + + // Cloud-proxy MITM listener (off by default). + MITMListen string `env:"LOCALAI_MITM_LISTEN" help:"Address (host:port) for the cloudproxy MITM listener. Empty = disabled. Clients set HTTPS_PROXY=http://:. Intercept hosts are declared per-model via the model YAML mitm.hosts: block; create one from the Add Model UI." group:"middleware"` + MITMCADir string `env:"LOCALAI_MITM_CA_DIR" type:"path" help:"Directory holding the MITM proxy CA cert + key. Defaults to /mitm-ca." 
group:"middleware"` } func (r *RunCMD) Run(ctx *cliContext.Context) error { @@ -217,6 +221,8 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { config.WithLoadToMemory(r.LoadToMemory), config.WithMachineTag(r.MachineTag), config.WithAPIAddress(r.Address), + config.WithMITMListen(r.MITMListen), + config.WithMITMCADir(r.MITMCADir), config.WithAgentJobRetentionDays(r.AgentJobRetentionDays), config.WithLlamaCPPTunnelCallback(func(tunnels []string) { tunnelEnvVar := strings.Join(tunnels, ",") diff --git a/core/config/application_config.go b/core/config/application_config.go index a4119206d..24e7a82fc 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -40,6 +40,54 @@ type ApplicationConfig struct { P2PNetworkID string Federated bool + // DisableStats turns off per-request token tracking. By default the + // routing module's billing recorder runs in every mode (including + // no-auth single-user) so dashboards and `/api/usage` are immediately + // useful; set this to opt out of that, e.g., for ephemeral CI runs + // or privacy-strict deployments where no token-count history should + // touch disk or memory. + DisableStats bool + + // PIIConfigPath points to an optional YAML file describing the PII + // pattern set. When empty, the routing/pii module's DefaultPatterns() + // (email, phone, SSN, credit card, IPv4, API key prefixes) are + // loaded with their default actions. Each entry overrides the + // matching default by ID: + // + // patterns: + // - id: email + // action: route_local # downgrade default mask -> route_local + // - id: ssn + // action: block # upgrade default mask -> block + // + // Unknown ids are rejected with a clear error at startup. + PIIConfigPath string + + // DisablePII turns the regex PII filter off entirely. Default + // (false) enables it on the OpenAI chat completions route. + DisablePII bool + + // MITMListen is the address (host:port) the cloudproxy MITM + // listener binds on. 
Empty disables the MITM proxy entirely. + // Use case: redacting PII from Claude Code / Codex CLI traffic + // without LocalAI holding the upstream API key. Clients set + // HTTPS_PROXY=http://localai:port and trust the CA cert + // LocalAI exposes at /api/middleware/proxy-ca.crt. + MITMListen string + + // MITMCADir holds the persisted MITM proxy CA cert and private + // key. The CA is generated on first start; subsequent starts + // reload it so clients keep trusting the same root. The key + // file is mode 0600. + MITMCADir string + + + // PIIPatternOverrides applies persisted per-id deltas (action, + // disabled) to the live redactor at startup. Loaded from + // runtime_settings.json and applied right after pii.NewRedactor. + // nil/empty leaves the YAML defaults in place. + PIIPatternOverrides map[string]PIIPatternRuntimeOverride + DisableWebUI bool OllamaAPIRootEndpoint bool EnforcePredownloadScans bool @@ -604,6 +652,45 @@ func WithDataPath(dataPath string) AppOption { } } +// WithDisableStats turns off the billing recorder. CLI: --disable-stats. +func WithDisableStats(disable bool) AppOption { + return func(o *ApplicationConfig) { + o.DisableStats = disable + } +} + +// WithPIIConfigPath points the routing PII filter at a YAML config +// file. CLI: --pii-config. +func WithPIIConfigPath(path string) AppOption { + return func(o *ApplicationConfig) { + o.PIIConfigPath = path + } +} + +// WithDisablePII turns the regex PII filter off. CLI: --disable-pii. +func WithDisablePII(disable bool) AppOption { + return func(o *ApplicationConfig) { + o.DisablePII = disable + } +} + +// WithMITMListen sets the address the cloudproxy MITM listener +// binds on. Empty = disabled. CLI: --mitm-listen. +func WithMITMListen(addr string) AppOption { + return func(o *ApplicationConfig) { + o.MITMListen = addr + } +} + +// WithMITMCADir sets the directory used to persist the MITM proxy +// CA cert + key. CLI: --mitm-ca-dir. 
+func WithMITMCADir(dir string) AppOption { + return func(o *ApplicationConfig) { + o.MITMCADir = dir + } +} + + func WithDynamicConfigDir(dynamicConfigsDir string) AppOption { return func(o *ApplicationConfig) { o.DynamicConfigsDir = dynamicConfigsDir @@ -998,6 +1085,8 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings { logoHorizontalFile := o.Branding.LogoHorizontalFile faviconFile := o.Branding.FaviconFile + mitmListen := o.MITMListen + return RuntimeSettings{ WatchdogEnabled: &watchdogEnabled, WatchdogIdleEnabled: &watchdogIdle, @@ -1051,6 +1140,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings { LogoFile: &logoFile, LogoHorizontalFile: &logoHorizontalFile, FaviconFile: &faviconFile, + MITMListen: &mitmListen, } } @@ -1276,6 +1366,10 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req o.Branding.FaviconFile = *settings.FaviconFile } + if settings.MITMListen != nil { + o.MITMListen = *settings.MITMListen + } + // Note: ApiKeys requires special handling (merging with startup keys) - handled in caller return requireRestart diff --git a/core/config/meta/constants.go b/core/config/meta/constants.go index b0633c22d..b15eb53d0 100644 --- a/core/config/meta/constants.go +++ b/core/config/meta/constants.go @@ -49,20 +49,31 @@ var DiffusersPipelineOptions = []FieldOption{ {Value: "StableVideoDiffusionPipeline", Label: "StableVideoDiffusionPipeline"}, } +// UsecaseOptions must stay in sync with GetAllModelConfigUsecases in +// core/config/model_config.go — a value missing here is silently +// inaccessible from the model editor, which is how `score` (the router +// classifier usecase) hid for an entire release. 
var UsecaseOptions = []FieldOption{ {Value: "chat", Label: "Chat"}, {Value: "completion", Label: "Completion"}, {Value: "edit", Label: "Edit"}, {Value: "embeddings", Label: "Embeddings"}, {Value: "rerank", Label: "Rerank"}, + {Value: "score", Label: "Score (Router Classifier)"}, {Value: "image", Label: "Image"}, + {Value: "vision", Label: "Vision"}, + {Value: "detection", Label: "Detection"}, + {Value: "face_recognition", Label: "Face Recognition"}, {Value: "transcript", Label: "Transcript"}, + {Value: "diarization", Label: "Diarization"}, + {Value: "speaker_recognition", Label: "Speaker Recognition"}, {Value: "tts", Label: "TTS"}, {Value: "sound_generation", Label: "Sound Generation"}, + {Value: "audio_transform", Label: "Audio Transform"}, + {Value: "realtime_audio", Label: "Realtime Audio"}, {Value: "tokenize", Label: "Tokenize"}, {Value: "vad", Label: "VAD"}, {Value: "video", Label: "Video"}, - {Value: "detection", Label: "Detection"}, } var DiffusersSchedulerOptions = []FieldOption{ diff --git a/core/config/meta/registry.go b/core/config/meta/registry.go index 99f9e0298..54d891106 100644 --- a/core/config/meta/registry.go +++ b/core/config/meta/registry.go @@ -232,6 +232,17 @@ func DefaultRegistry() map[string]FieldMetaOverride { Description: "Use the chat template from the model's tokenizer config", Order: 43, }, + // Router section template — kept in the templates UI section + // (rather than the router section under "other") so operators + // editing prompt shapes find all template-typed fields in one + // place, mirroring how chat / chat_message are grouped. + "router.classifier_system_template": { + Section: "templates", + Label: "Router Classifier System Prompt", + Description: "Go text/template (with sprig functions) for the routing system prompt the score classifier feeds to its classifier_model. Executed with `.Policies` ([]{Label, Description}). Empty falls back to the built-in Arch-Router-shaped prompt (route-listing block + JSON output schema). 
Override when the classifier model was trained on a different schema or you need the routing instructions in a different language. The candidate format scored against the model is fixed at `{\"route\": \"