From 59108fbe32463d50d4f4dbf485b1d40104e78c55 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Mon, 30 Mar 2026 00:47:27 +0200 Subject: [PATCH] feat: add distributed mode (#9124) * feat: add distributed mode (experimental) Signed-off-by: Ettore Di Giacinto * fix data races, mutexes, transactions Signed-off-by: Ettore Di Giacinto * refactorings Signed-off-by: Ettore Di Giacinto * fixups Signed-off-by: Ettore Di Giacinto * fix events and tool stream in agent chat Signed-off-by: Ettore Di Giacinto * use ginkgo Signed-off-by: Ettore Di Giacinto * refactoring and consolidation Signed-off-by: Ettore Di Giacinto * refactoring and consolidation Signed-off-by: Ettore Di Giacinto * refactoring and consolidation Signed-off-by: Ettore Di Giacinto * refactoring and consolidation Signed-off-by: Ettore Di Giacinto * refactoring and consolidation Signed-off-by: Ettore Di Giacinto * refactoring and consolidation Signed-off-by: Ettore Di Giacinto * refactoring and consolidation Signed-off-by: Ettore Di Giacinto * refactoring and consolidation Signed-off-by: Ettore Di Giacinto * fix(cron): compute correctly time boundaries avoiding re-triggering Signed-off-by: Ettore Di Giacinto * enhancements, refactorings Signed-off-by: Ettore Di Giacinto * do not flood of healthy checks Signed-off-by: Ettore Di Giacinto * do not list obvious backends as text backends Signed-off-by: Ettore Di Giacinto * tests fixups Signed-off-by: Ettore Di Giacinto * refactoring and consolidation Signed-off-by: Ettore Di Giacinto * Drop redundant healthcheck Signed-off-by: Ettore Di Giacinto * enhancements, refactorings Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- .github/gallery-agent/agent.go | 2 +- .github/gallery-agent/testing.go | 40 +- .github/workflows/test.yml | 4 +- Dockerfile | 3 +- README.md | 1 + backend/backend.proto | 2 + backend/cpp/llama-cpp/grpc-server.cpp | 58 +- backend/go/acestep-cpp/acestepcpp_test.go | 6 +- backend/go/acestep-cpp/goacestepcpp.go | 14 +- backend/go/llm/llama/llama.go | 1 - backend/go/local-store/debug.go | 1 - backend/go/local-store/production.go | 1 - backend/go/local-store/store.go | 4 +- backend/go/opus/opus_test.go | 24 +- backend/go/voxtral/voxtral_test.go | 2 +- backend/python/ace-step/backend.py | 6 + backend/python/chatterbox/backend.py | 8 +- backend/python/common/grpc_auth.py | 78 + backend/python/coqui/backend.py | 8 +- backend/python/diffusers/backend.py | 8 +- backend/python/faster-qwen3-tts/backend.py | 6 + backend/python/faster-whisper/backend.py | 8 +- backend/python/fish-speech/backend.py | 6 + backend/python/kitten-tts/backend.py | 8 +- backend/python/kokoro/backend.py | 8 +- .../python/llama-cpp-quantization/backend.py | 8 +- backend/python/mlx-audio/backend.py | 8 +- backend/python/mlx-distributed/backend.py | 6 + backend/python/mlx-vlm/backend.py | 8 +- backend/python/mlx/backend.py | 8 +- backend/python/moonshine/backend.py | 8 +- backend/python/nemo/backend.py | 8 +- backend/python/neutts/backend.py | 8 +- backend/python/outetts/backend.py | 8 +- backend/python/pocket-tts/backend.py | 8 +- backend/python/qwen-asr/backend.py | 8 +- backend/python/qwen-tts/backend.py | 6 + backend/python/rerankers/backend.py | 8 +- backend/python/rfdetr/backend.py | 8 +- backend/python/transformers/backend.py | 8 +- backend/python/trl/backend.py | 6 + backend/python/vibevoice/backend.py | 8 +- backend/python/vllm-omni/backend.py | 8 +- backend/python/vllm/backend.py | 8 +- backend/python/voxcpm/backend.py | 8 +- backend/python/whisperx/backend.py | 8 +- core/application/agent_jobs.go | 14 +- core/application/application.go | 114 +- core/application/distributed.go | 267 + core/application/p2p.go | 20 +- core/application/startup.go | 97 +- core/backend/llm.go | 26 +- core/backend/options.go | 4 +- core/backend/tokenize.go | 2 +- core/backend/vad.go | 80 +- core/cli/agent.go | 8 +- core/cli/agent_worker.go | 463 + core/cli/backends.go | 4 +- core/cli/cli.go | 4 +- core/cli/completion.go | 20 +- core/cli/completion_test.go | 4 +- core/cli/models.go | 4 +- core/cli/run.go | 82 +- core/cli/transcript.go | 3 +- core/cli/worker.go | 897 + core/cli/workerregistry/client.go | 272 + core/clients/store.go | 4 +- core/config/application_config.go | 149 +- core/config/application_config_test.go | 4 +- core/config/distributed_config.go | 188 + core/config/model_config.go | 32 +- core/config/model_config_filter.go | 70 +- core/config/model_config_loader.go | 7 +- core/config/runtime_settings.go | 26 +- core/explorer/database.go | 8 +- core/gallery/backend_resolve.go | 18 +- core/gallery/backends.go | 45 +- core/gallery/gallery.go | 66 +- core/gallery/gallery_test.go | 26 +- core/gallery/importers/local_test.go | 2 +- core/gallery/models.go | 4 +- core/gallery/models_test.go | 14 +- core/gallery/models_types.go | 4 +- core/gallery/worker.go | 66 + core/gallery/worker_test.go | 99 + core/http/app.go | 46 +- core/http/app_test.go | 50 +- core/http/auth/middleware.go | 19 + core/http/auth/models.go | 45 +- core/http/auth/permissions.go | 30 +- core/http/auth/quota.go | 22 +- core/http/auth/session.go | 2 +- core/http/auth/usage.go | 20 +- core/http/auth/usage_test.go | 2 +- core/http/endpoints/anthropic/messages.go | 163 +- core/http/endpoints/explorer/dashboard.go | 25 +- core/http/endpoints/localai/agent_jobs.go | 15 +- .../http/endpoints/localai/agent_responses.go | 130 +- core/http/endpoints/localai/agent_skills.go | 189 +- core/http/endpoints/localai/agents.go | 95 +- core/http/endpoints/localai/backend.go | 18 +- .../http/endpoints/localai/backend_monitor.go | 6 +- core/http/endpoints/localai/cors_proxy.go | 69 +- core/http/endpoints/localai/edit_model.go | 16 +- .../http/endpoints/localai/edit_model_test.go | 2 +- core/http/endpoints/localai/finetune.go | 22 +- core/http/endpoints/localai/gallery.go | 14 +- core/http/endpoints/localai/import_model.go | 14 +- core/http/endpoints/localai/mcp.go | 13 +- core/http/endpoints/localai/mcp_tools.go | 36 +- core/http/endpoints/localai/metrics.go | 6 +- core/http/endpoints/localai/nodes.go | 634 + core/http/endpoints/localai/nodes_test.go | 229 + core/http/endpoints/localai/quantization.go | 18 +- core/http/endpoints/localai/types.go | 12 +- core/http/endpoints/localai/welcome.go | 8 +- core/http/endpoints/mcp/executor.go | 132 + core/http/endpoints/mcp/tools.go | 244 +- core/http/endpoints/openai/chat.go | 1086 +- core/http/endpoints/openai/image.go | 2 +- core/http/endpoints/openai/inference.go | 2 +- core/http/endpoints/openai/list.go | 10 +- core/http/endpoints/openai/realtime.go | 12 +- core/http/endpoints/openai/realtime_model.go | 12 +- .../openai/realtime_transport_webrtc.go | 24 +- core/http/endpoints/openai/transcription.go | 3 +- .../http/endpoints/openresponses/responses.go | 1269 +- core/http/endpoints/openresponses/store.go | 12 +- .../endpoints/openresponses/store_test.go | 4 +- .../http/endpoints/openresponses/websocket.go | 6 +- core/http/middleware/auth.go | 2 +- core/http/middleware/request.go | 22 +- core/http/middleware/trace.go | 48 +- core/http/openresponses_test.go | 152 +- .../react-ui/src/components/ImageSelector.jsx | 81 + core/http/react-ui/src/components/Modal.jsx | 6 +- core/http/react-ui/src/components/Sidebar.jsx | 1 + core/http/react-ui/src/hooks/useAgentChat.js | 43 +- core/http/react-ui/src/pages/AgentChat.jsx | 104 +- core/http/react-ui/src/pages/AgentCreate.jsx | 219 +- core/http/react-ui/src/pages/AgentJobs.jsx | 17 +- .../react-ui/src/pages/AgentTaskDetails.jsx | 2 + core/http/react-ui/src/pages/Manage.jsx | 23 +- core/http/react-ui/src/pages/Models.jsx | 15 +- .../react-ui/src/pages/NodeBackendLogs.jsx | 268 + core/http/react-ui/src/pages/Nodes.jsx | 715 + core/http/react-ui/src/pages/P2P.jsx | 29 +- core/http/react-ui/src/router.jsx | 4 + core/http/react-ui/src/utils/api.js | 17 +- core/http/react-ui/src/utils/config.js | 13 + core/http/render.go | 6 +- core/http/routes/anthropic.go | 20 +- core/http/routes/auth.go | 80 +- core/http/routes/auth_test.go | 78 +- core/http/routes/finetuning.go | 4 +- core/http/routes/localai.go | 23 +- core/http/routes/nodes.go | 117 + core/http/routes/openai.go | 9 +- core/http/routes/openresponses.go | 8 + core/http/routes/quantization.go | 4 +- core/http/routes/ui.go | 183 +- core/http/routes/ui_api.go | 266 +- core/http/routes/ui_api_backends_test.go | 22 +- core/http/routes/ui_backend_gallery.go | 4 +- core/http/routes/ui_gallery.go | 4 +- core/p2p/node.go | 6 +- core/p2p/p2p.go | 4 +- core/schema/agent_jobs.go | 10 +- core/schema/anthropic.go | 67 +- core/schema/anthropic_test.go | 26 +- core/schema/elevenlabs.go | 80 +- core/schema/finetune.go | 22 +- core/schema/localai.go | 20 +- core/schema/message.go | 6 +- core/schema/message_test.go | 12 +- core/schema/openai.go | 24 +- core/schema/openresponses.go | 44 +- core/schema/quantization.go | 34 +- core/schema/request.go | 46 +- .../schema/transcription_format.go | 18 +- core/services/advisorylock/advisorylock.go | 73 + .../advisorylock/advisorylock_suite_test.go | 13 + .../advisorylock/advisorylock_test.go | 202 + core/services/advisorylock/keys.go | 13 + core/services/advisorylock/leader_loop.go | 41 + .../services/advisorylock/leader_loop_test.go | 115 + core/services/agent_pool.go | 2121 - .../agentpool/agent_config_backend.go | 60 + .../agentpool/agent_config_distributed.go | 178 + core/services/agentpool/agent_config_local.go | 191 + core/services/{ => agentpool}/agent_jobs.go | 431 +- .../{ => agentpool}/agent_jobs_test.go | 287 +- core/services/agentpool/agent_pool.go | 1091 + .../{ => agentpool}/agent_pool_sse.go | 6 +- core/services/agentpool/errors.go | 13 + core/services/agentpool/job_persister.go | 29 + core/services/agentpool/job_persister_db.go | 85 + core/services/agentpool/job_persister_file.go | 153 + core/services/agentpool/job_persister_test.go | 262 + .../{ => agentpool}/services_suite_test.go | 2 +- .../services/{ => agentpool}/user_services.go | 42 +- core/services/{ => agentpool}/user_storage.go | 6 +- .../services/agents/agents_suite_test.go | 6 +- core/services/agents/config.go | 217 + core/services/agents/configmeta.go | 155 + core/services/agents/dispatcher.go | 546 + core/services/agents/events.go | 319 + core/services/agents/executor.go | 399 + core/services/agents/executor_test.go | 203 + core/services/agents/knowledge.go | 262 + core/services/agents/mcp.go | 62 + core/services/agents/scheduler.go | 151 + core/services/agents/scheduler_test.go | 387 + core/services/agents/skills.go | 137 + core/services/agents/store.go | 218 + core/services/agents/store_test.go | 101 + core/services/dbutil/json.go | 29 + core/services/distributed/finetune.go | 105 + core/services/distributed/gallery.go | 136 + core/services/distributed/init.go | 40 + core/services/distributed/skills.go | 96 + .../{finetune.go => finetune/service.go} | 63 +- core/services/gallery.go | 166 - core/services/{ => galleryop}/backends.go | 40 +- .../services/{ => galleryop}/backends_test.go | 26 +- .../galleryop/galleryop_suite_test.go | 13 + core/services/{ => galleryop}/list_models.go | 126 +- core/services/galleryop/managers.go | 23 + core/services/galleryop/managers_local.go | 102 + core/services/{ => galleryop}/models.go | 86 +- core/services/{ => galleryop}/operation.go | 6 +- core/services/galleryop/service.go | 272 + core/services/jobs/conversions.go | 125 + core/services/jobs/conversions_test.go | 162 + core/services/jobs/dispatcher.go | 496 + core/services/jobs/dispatcher_test.go | 415 + core/services/jobs/jobs_suite_test.go | 13 + core/services/jobs/publish.go | 36 + core/services/jobs/sse.go | 97 + core/services/jobs/store.go | 305 + core/services/jobs/store_test.go | 367 + core/services/mcp/remote.go | 56 + core/services/messaging/cancel_registry.go | 34 + .../messaging/cancel_registry_test.go | 67 + core/services/messaging/client.go | 189 + core/services/messaging/interfaces.go | 27 + .../messaging/messaging_suite_test.go | 13 + core/services/messaging/subjects.go | 264 + .../{ => monitoring}/backend_monitor.go | 2 +- core/services/{ => monitoring}/metrics.go | 2 +- core/services/nodes/distributed_store.go | 88 + core/services/nodes/distributed_store_test.go | 133 + core/services/nodes/file_stager.go | 35 + core/services/nodes/file_stager_http.go | 238 + core/services/nodes/file_stager_s3.go | 176 + core/services/nodes/file_staging_client.go | 433 + core/services/nodes/file_transfer_server.go | 515 + .../nodes/file_transfer_server_test.go | 199 + core/services/nodes/health.go | 172 + core/services/nodes/health_mock_test.go | 309 + core/services/nodes/health_test.go | 279 + core/services/nodes/inflight.go | 100 + core/services/nodes/inflight_test.go | 241 + core/services/nodes/interfaces.go | 89 + core/services/nodes/managers_distributed.go | 156 + core/services/nodes/model_router.go | 94 + core/services/nodes/model_router_test.go | 154 + core/services/nodes/nodes_suite_test.go | 13 + core/services/nodes/registry.go | 580 + core/services/nodes/registry_test.go | 341 + core/services/nodes/router.go | 653 + core/services/nodes/router_test.go | 601 + core/services/nodes/staging_keys.go | 42 + core/services/nodes/staging_keys_test.go | 139 + core/services/nodes/unloader.go | 161 + core/services/nodes/unloader_test.go | 256 + .../service.go} | 9 +- core/services/skills/distributed.go | 183 + core/services/skills/filesystem.go | 522 + core/services/skills/manager.go | 50 + core/services/storage/cleanup.go | 101 + core/services/storage/filemanager.go | 255 + core/services/storage/filemanager_test.go | 312 + core/services/storage/filesystem.go | 146 + core/services/storage/objectstore.go | 37 + core/services/storage/s3.go | 164 + core/services/storage/storage_suite_test.go | 13 + core/services/testutil/testdb.go | 42 + core/startup/model_preload.go | 10 +- core/startup/model_preload_test.go | 6 +- core/templates/cache.go | 2 +- core/templates/evaluator.go | 2 +- core/templates/evaluator_test.go | 4 +- core/templates/multimodal.go | 6 +- core/trace/audio_snippet.go | 2 +- core/trace/backend_trace.go | 28 +- docker-compose.distributed.yaml | 191 + docs/content/features/_index.en.md | 1 + docs/content/features/distributed-mode.md | 307 + .../features/distributed_inferencing.md | 3 + docs/content/getting-started/quickstart.md | 112 +- docs/content/overview.md | 63 +- go.mod | 46 +- go.sum | 103 +- pkg/audio/audio.go | 120 +- pkg/concurrency/jobresult.go | 69 - pkg/concurrency/jobresult_test.go | 80 - pkg/concurrency/safego.go | 20 + pkg/downloader/uri.go | 51 +- pkg/functions/function_structure.go | 10 +- pkg/functions/functions.go | 18 +- pkg/functions/functions_test.go | 27 +- pkg/functions/grammars/bnf_rules.go | 2 +- pkg/functions/grammars/grammars_suite_test.go | 4 +- pkg/functions/grammars/json_schema.go | 56 +- pkg/functions/grammars/json_schema_test.go | 8 +- pkg/functions/grammars/llama31_schema.go | 46 +- pkg/functions/iterative_parser.go | 4 +- pkg/functions/parse.go | 13 +- pkg/functions/peg/arena.go | 2 +- pkg/functions/peg/builder.go | 2 +- pkg/functions/peg/chat.go | 36 +- pkg/functions/peg/parser.go | 9 +- pkg/functions/peg/parser_test.go | 4 +- pkg/grpc/auth_test.go | 114 + pkg/grpc/backend.go | 14 +- pkg/grpc/base/singlethread.go | 1 - pkg/grpc/client.go | 253 +- pkg/grpc/grpc_suite_test.go | 13 + pkg/grpc/server.go | 88 +- pkg/huggingface-api/client.go | 30 +- pkg/huggingface-api/hfapi_suite_test.go | 2 - pkg/langchain/langchain.go | 57 - pkg/model/initializers.go | 13 +- pkg/model/loader.go | 68 +- pkg/model/loader_test.go | 33 +- pkg/model/model.go | 39 +- pkg/model/process.go | 40 +- pkg/model/store.go | 51 + pkg/model/store_test.go | 149 + pkg/model/watchdog.go | 12 +- pkg/model/watchdog_test.go | 2 +- pkg/sanitize/url.go | 16 + pkg/signals/handler.go | 13 +- pkg/sound/int16.go | 4 +- pkg/sound/testutil_test.go | 2 +- pkg/utils/base64_test.go | 76 +- pkg/utils/strings.go | 9 +- pkg/vram/cache.go | 12 +- pkg/vram/estimate.go | 56 + pkg/vram/gguf_reader.go | 10 +- pkg/vram/types.go | 16 +- pkg/xsysinfo/cpu.go | 7 +- pkg/xsysinfo/gpu.go | 25 +- tests/e2e-aio/e2e_test.go | 2 +- tests/e2e-aio/sample_data_test.go | 480018 +++++++-------- .../e2e/distributed/agent_distributed_test.go | 195 + .../distributed/agent_native_executor_test.go | 1334 + tests/e2e/distributed/backend_logs_test.go | 537 + .../distributed/distributed_full_flow_test.go | 984 + .../e2e/distributed/distributed_store_test.go | 143 + .../e2e/distributed/distributed_suite_test.go | 13 + tests/e2e/distributed/file_staging_test.go | 136 + .../distributed/finetune_distributed_test.go | 141 + tests/e2e/distributed/foundation_test.go | 275 + .../distributed/gallery_distributed_test.go | 171 + tests/e2e/distributed/job_dispatch_test.go | 188 + .../e2e/distributed/job_distribution_test.go | 590 + tests/e2e/distributed/managers_test.go | 312 + .../e2e/distributed/mcp_ci_job_helper_test.go | 212 + tests/e2e/distributed/mcp_ci_job_test.go | 644 + tests/e2e/distributed/mcp_nats_test.go | 177 + tests/e2e/distributed/model_routing_test.go | 123 + tests/e2e/distributed/node_lifecycle_test.go | 188 + .../e2e/distributed/node_registration_test.go | 236 + tests/e2e/distributed/object_storage_test.go | 209 + tests/e2e/distributed/phase4_test.go | 222 + tests/e2e/distributed/registry_extra_test.go | 209 + tests/e2e/distributed/router_tracking_test.go | 207 + .../distributed/skills_distributed_test.go | 158 + tests/e2e/distributed/sse_routes_test.go | 106 + tests/e2e/distributed/testhelpers_test.go | 113 + tests/e2e/e2e_anthropic_test.go | 24 +- tests/e2e/e2e_mcp_test.go | 42 +- tests/e2e/e2e_suite_test.go | 28 +- tests/e2e/e2e_websocket_responses_test.go | 38 +- tests/e2e/mock-backend/main.go | 42 +- tests/integration/integration_suite_test.go | 2 +- tests/integration/stores_test.go | 6 +- 389 files changed, 276305 insertions(+), 246521 deletions(-) create mode 100644 backend/python/common/grpc_auth.py create mode 100644 core/application/distributed.go create mode 100644 core/cli/agent_worker.go create mode 100644 core/cli/worker.go create mode 100644 core/cli/workerregistry/client.go create mode 100644 core/config/distributed_config.go create mode 100644 core/gallery/worker.go create mode 100644 core/gallery/worker_test.go create mode 100644 core/http/endpoints/localai/nodes.go create mode 100644 core/http/endpoints/localai/nodes_test.go create mode 100644 core/http/endpoints/mcp/executor.go create mode 100644 core/http/react-ui/src/components/ImageSelector.jsx create mode 100644 core/http/react-ui/src/pages/NodeBackendLogs.jsx create mode 100644 core/http/react-ui/src/pages/Nodes.jsx create mode 100644 core/http/routes/nodes.go rename pkg/format/transcription.go => core/schema/transcription_format.go (64%) create mode 100644 core/services/advisorylock/advisorylock.go create mode 100644 core/services/advisorylock/advisorylock_suite_test.go create mode 100644 core/services/advisorylock/advisorylock_test.go create mode 100644 core/services/advisorylock/keys.go create mode 100644 core/services/advisorylock/leader_loop.go create mode 100644 core/services/advisorylock/leader_loop_test.go delete mode 100644 core/services/agent_pool.go create mode 100644 core/services/agentpool/agent_config_backend.go create mode 100644 core/services/agentpool/agent_config_distributed.go create mode 100644 core/services/agentpool/agent_config_local.go rename core/services/{ => agentpool}/agent_jobs.go (77%) rename core/services/{ => agentpool}/agent_jobs_test.go (67%) create mode 100644 core/services/agentpool/agent_pool.go rename core/services/{ => agentpool}/agent_pool_sse.go (93%) create mode 100644 core/services/agentpool/errors.go create mode 100644 core/services/agentpool/job_persister.go create mode 100644 core/services/agentpool/job_persister_db.go create mode 100644 core/services/agentpool/job_persister_file.go create mode 100644 core/services/agentpool/job_persister_test.go rename core/services/{ => agentpool}/services_suite_test.go (88%) rename core/services/{ => agentpool}/user_services.go (79%) rename core/services/{ => agentpool}/user_storage.go (98%) rename pkg/concurrency/concurrency_suite_test.go => core/services/agents/agents_suite_test.go (54%) create mode 100644 core/services/agents/config.go create mode 100644 core/services/agents/configmeta.go create mode 100644 core/services/agents/dispatcher.go create mode 100644 core/services/agents/events.go create mode 100644 core/services/agents/executor.go create mode 100644 core/services/agents/executor_test.go create mode 100644 core/services/agents/knowledge.go create mode 100644 core/services/agents/mcp.go create mode 100644 core/services/agents/scheduler.go create mode 100644 core/services/agents/scheduler_test.go create mode 100644 core/services/agents/skills.go create mode 100644 core/services/agents/store.go create mode 100644 core/services/agents/store_test.go create mode 100644 core/services/dbutil/json.go create mode 100644 core/services/distributed/finetune.go create mode 100644 core/services/distributed/gallery.go create mode 100644 core/services/distributed/init.go create mode 100644 core/services/distributed/skills.go rename core/services/{finetune.go => finetune/service.go} (91%) delete mode 100644 core/services/gallery.go rename core/services/{ => galleryop}/backends.go (76%) rename core/services/{ => galleryop}/backends_test.go (88%) create mode 100644 core/services/galleryop/galleryop_suite_test.go rename core/services/{ => galleryop}/list_models.go (90%) create mode 100644 core/services/galleryop/managers.go create mode 100644 core/services/galleryop/managers_local.go rename core/services/{ => galleryop}/models.go (64%) rename core/services/{ => galleryop}/operation.go (97%) create mode 100644 core/services/galleryop/service.go create mode 100644 core/services/jobs/conversions.go create mode 100644 core/services/jobs/conversions_test.go create mode 100644 core/services/jobs/dispatcher.go create mode 100644 core/services/jobs/dispatcher_test.go create mode 100644 core/services/jobs/jobs_suite_test.go create mode 100644 core/services/jobs/publish.go create mode 100644 core/services/jobs/sse.go create mode 100644 core/services/jobs/store.go create mode 100644 core/services/jobs/store_test.go create mode 100644 core/services/mcp/remote.go create mode 100644 core/services/messaging/cancel_registry.go create mode 100644 core/services/messaging/cancel_registry_test.go create mode 100644 core/services/messaging/client.go create mode 100644 core/services/messaging/interfaces.go create mode 100644 core/services/messaging/messaging_suite_test.go create mode 100644 core/services/messaging/subjects.go rename core/services/{ => monitoring}/backend_monitor.go (99%) rename core/services/{ => monitoring}/metrics.go (98%) create mode 100644 core/services/nodes/distributed_store.go create mode 100644 core/services/nodes/distributed_store_test.go create mode 100644 core/services/nodes/file_stager.go create mode 100644 core/services/nodes/file_stager_http.go create mode 100644 core/services/nodes/file_stager_s3.go create mode 100644 core/services/nodes/file_staging_client.go create mode 100644 core/services/nodes/file_transfer_server.go create mode 100644 core/services/nodes/file_transfer_server_test.go create mode 100644 core/services/nodes/health.go create mode 100644 core/services/nodes/health_mock_test.go create mode 100644 core/services/nodes/health_test.go create mode 100644 core/services/nodes/inflight.go create mode 100644 core/services/nodes/inflight_test.go create mode 100644 core/services/nodes/interfaces.go create mode 100644 core/services/nodes/managers_distributed.go create mode 100644 core/services/nodes/model_router.go create mode 100644 core/services/nodes/model_router_test.go create mode 100644 core/services/nodes/nodes_suite_test.go create mode 100644 core/services/nodes/registry.go create mode 100644 core/services/nodes/registry_test.go create mode 100644 core/services/nodes/router.go create mode 100644 core/services/nodes/router_test.go create mode 100644 core/services/nodes/staging_keys.go create mode 100644 core/services/nodes/staging_keys_test.go create mode 100644 core/services/nodes/unloader.go create mode 100644 core/services/nodes/unloader_test.go rename core/services/{quantization.go => quantization/service.go} (99%) create mode 100644 core/services/skills/distributed.go create mode 100644 core/services/skills/filesystem.go create mode 100644 core/services/skills/manager.go create mode 100644 core/services/storage/cleanup.go create mode 100644 core/services/storage/filemanager.go create mode 100644 core/services/storage/filemanager_test.go create mode 100644 core/services/storage/filesystem.go create mode 100644 core/services/storage/objectstore.go create mode 100644 core/services/storage/s3.go create mode 100644 core/services/storage/storage_suite_test.go create mode 100644 core/services/testutil/testdb.go create mode 100644 docker-compose.distributed.yaml create mode 100644 docs/content/features/distributed-mode.md delete mode 100644 pkg/concurrency/jobresult.go delete mode 100644 pkg/concurrency/jobresult_test.go create mode 100644 pkg/concurrency/safego.go create mode 100644 pkg/grpc/auth_test.go create mode 100644 pkg/grpc/grpc_suite_test.go delete mode 100644 pkg/langchain/langchain.go create mode 100644 pkg/model/store.go create mode 100644 pkg/model/store_test.go create mode 100644 pkg/sanitize/url.go create mode 100644 tests/e2e/distributed/agent_distributed_test.go create mode 100644 tests/e2e/distributed/agent_native_executor_test.go create mode 100644 tests/e2e/distributed/backend_logs_test.go create mode 100644 tests/e2e/distributed/distributed_full_flow_test.go create mode 100644 tests/e2e/distributed/distributed_store_test.go create mode 100644 tests/e2e/distributed/distributed_suite_test.go create mode 100644 tests/e2e/distributed/file_staging_test.go create mode 100644 tests/e2e/distributed/finetune_distributed_test.go create mode 100644 tests/e2e/distributed/foundation_test.go create mode 100644 tests/e2e/distributed/gallery_distributed_test.go create mode 100644 tests/e2e/distributed/job_dispatch_test.go create mode 100644 tests/e2e/distributed/job_distribution_test.go create mode 100644 tests/e2e/distributed/managers_test.go create mode 100644 tests/e2e/distributed/mcp_ci_job_helper_test.go create mode 100644 tests/e2e/distributed/mcp_ci_job_test.go create mode 100644 tests/e2e/distributed/mcp_nats_test.go create mode 100644 tests/e2e/distributed/model_routing_test.go create mode 100644 tests/e2e/distributed/node_lifecycle_test.go create mode 100644 tests/e2e/distributed/node_registration_test.go create mode 100644 tests/e2e/distributed/object_storage_test.go create mode 100644 tests/e2e/distributed/phase4_test.go create mode 100644 tests/e2e/distributed/registry_extra_test.go create mode 100644 tests/e2e/distributed/router_tracking_test.go create mode 100644 tests/e2e/distributed/skills_distributed_test.go create mode 100644 tests/e2e/distributed/sse_routes_test.go create mode 100644 tests/e2e/distributed/testhelpers_test.go diff --git a/.github/gallery-agent/agent.go b/.github/gallery-agent/agent.go index 87eee4f7e..4de05b51e 100644 --- a/.github/gallery-agent/agent.go +++ b/.github/gallery-agent/agent.go @@ -406,7 +406,7 @@ func getHuggingFaceAvatarURL(author string) string { } // Parse the response to get avatar URL - var userInfo map[string]interface{} + var userInfo map[string]any body, err := io.ReadAll(resp.Body) if err != nil { return "" diff --git a/.github/gallery-agent/testing.go b/.github/gallery-agent/testing.go index c7960a9f2..10170af01 100644 --- a/.github/gallery-agent/testing.go +++ b/.github/gallery-agent/testing.go @@ -3,7 +3,7 @@ package main import ( "context" "fmt" - "math/rand" + "math/rand/v2" "strings" "time" ) @@ -13,11 +13,11 @@ func runSyntheticMode() error { generator := NewSyntheticDataGenerator() // Generate a random number of synthetic models (1-3) - numModels := generator.rand.Intn(3) + 1 + numModels := generator.rand.IntN(3) + 1 fmt.Printf("Generating %d synthetic models for testing...\n", numModels) var models []ProcessedModel - for i := 0; i < numModels; i++ { + for i := range numModels { model := generator.GenerateProcessedModel() models = append(models, model) fmt.Printf("Generated synthetic model: %s\n", model.ModelID) @@ -42,14 +42,14 @@ type SyntheticDataGenerator struct { // NewSyntheticDataGenerator creates a new synthetic data generator func NewSyntheticDataGenerator() *SyntheticDataGenerator { return &SyntheticDataGenerator{ - rand: rand.New(rand.NewSource(time.Now().UnixNano())), + rand: rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0)), } } // GenerateProcessedModelFile creates a synthetic ProcessedModelFile func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile { fileTypes := []string{"model", "readme", "other"} - fileType := fileTypes[g.rand.Intn(len(fileTypes))] + fileType := fileTypes[g.rand.IntN(len(fileTypes))] var path string var isReadme bool @@ -68,7 +68,7 @@ func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile return ProcessedModelFile{ Path: path, - Size: int64(g.rand.Intn(1000000000) + 1000000), // 1MB to 1GB + Size: int64(g.rand.IntN(1000000000) + 1000000), // 1MB to 1GB SHA256: g.randomSHA256(), IsReadme: isReadme, FileType: fileType, @@ -80,19 +80,19 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel { authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"} modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"} - author := authors[g.rand.Intn(len(authors))] - modelName := modelNames[g.rand.Intn(len(modelNames))] + author := authors[g.rand.IntN(len(authors))] + modelName := modelNames[g.rand.IntN(len(modelNames))] modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6)) // Generate files - numFiles := g.rand.Intn(5) + 2 // 2-6 files + numFiles := g.rand.IntN(5) + 2 // 2-6 files files := make([]ProcessedModelFile, numFiles) // Ensure at least one model file and one readme hasModelFile := false hasReadme := false - for i := 0; i < numFiles; i++ { + for i := range numFiles { files[i] = g.GenerateProcessedModelFile() if files[i].FileType == "model" { hasModelFile = true @@ -140,27 +140,27 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel { // Generate sample metadata licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""} - license := licenses[g.rand.Intn(len(licenses))] + license := licenses[g.rand.IntN(len(licenses))] sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"} - numTags := g.rand.Intn(4) + 3 // 3-6 tags + numTags := g.rand.IntN(4) + 3 // 3-6 tags tags := make([]string, numTags) - for i := 0; i < numTags; i++ { - tags[i] = sampleTags[g.rand.Intn(len(sampleTags))] + for i := range numTags { + tags[i] = sampleTags[g.rand.IntN(len(sampleTags))] } // Remove duplicates tags = g.removeDuplicates(tags) // Optionally include icon (50% chance) icon := "" - if g.rand.Intn(2) == 0 { + if g.rand.IntN(2) == 0 { icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24)) } return ProcessedModel{ ModelID: modelID, Author: author, - Downloads: g.rand.Intn(1000000) + 1000, + Downloads: g.rand.IntN(1000000) + 1000, LastModified: g.randomDate(), Files: files, PreferredModelFile: preferredModelFile, @@ -180,7 +180,7 @@ func (g *SyntheticDataGenerator) randomString(length int) string { const charset = "abcdefghijklmnopqrstuvwxyz0123456789" b := make([]byte, length) for i := range b { - b[i] = charset[g.rand.Intn(len(charset))] + b[i] = charset[g.rand.IntN(len(charset))] } return string(b) } @@ -189,14 +189,14 @@ func (g *SyntheticDataGenerator) randomSHA256() string { const charset = "0123456789abcdef" b := make([]byte, 64) for i := range b { - b[i] = charset[g.rand.Intn(len(charset))] + b[i] = charset[g.rand.IntN(len(charset))] } return string(b) } func (g *SyntheticDataGenerator) randomDate() string { now := time.Now() - daysAgo := g.rand.Intn(365) // Random date within last year + daysAgo := g.rand.IntN(365) // Random date within last year pastDate := now.AddDate(0, 0, -daysAgo) return pastDate.Format("2006-01-02T15:04:05.000Z") } @@ -220,5 +220,5 @@ func (g *SyntheticDataGenerator) generateReadmeContent(modelName, author string) fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author), } - return templates[g.rand.Intn(len(templates))] + return templates[g.rand.IntN(len(templates))] } diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d58e3b077..51bae1cb1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,7 +21,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go-version: ['1.25.x'] + go-version: ['1.26.x'] steps: - name: Free Disk Space (Ubuntu) uses: jlumbroso/free-disk-space@main @@ -179,7 +179,7 @@ jobs: runs-on: macos-latest strategy: matrix: - go-version: ['1.25.x'] + go-version: ['1.26.x'] steps: - name: Clone uses: actions/checkout@v6 diff --git a/Dockerfile b/Dockerfile index 431839819..1567ef6f7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -176,7 +176,7 @@ ENV PATH=/opt/rocm/bin:${PATH} # The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it. FROM requirements-drivers AS build-requirements -ARG GO_VERSION=1.25.4 +ARG GO_VERSION=1.26.0 ARG CMAKE_VERSION=3.31.10 ARG CMAKE_FROM_SOURCE=false ARG TARGETARCH @@ -319,7 +319,6 @@ COPY ./.git ./.git # Some of the Go backends use libs from the main src, we could further optimize the caching by building the CPP backends before here COPY ./pkg/grpc ./pkg/grpc COPY ./pkg/utils ./pkg/utils -COPY ./pkg/langchain ./pkg/langchain RUN ls -l ./ RUN make protogen-go diff --git a/README.md b/README.md index 054f84c2d..dfe7ca8f1 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,7 @@ For older news and full release notes, see [GitHub Releases](https://github.com/ - [Object Detection](https://localai.io/features/object-detection/) - [Reranker API](https://localai.io/features/reranker/) - [P2P Inferencing](https://localai.io/features/distribute/) +- [Distributed Mode](https://localai.io/features/distributed-mode/) — Horizontal scaling with PostgreSQL + NATS - [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) - [Built-in Agents](https://localai.io/features/agents/) — Autonomous AI agents with tool use, RAG, skills, SSE streaming, and [Agent Hub](https://agenthub.localai.io) - [Backend Gallery](https://localai.io/backends/) — Install/remove backends on the fly via OCI images diff --git a/backend/backend.proto b/backend/backend.proto index 3f01efbe1..9a5eea630 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -51,6 +51,7 @@ service Backend { rpc StartQuantization(QuantizationRequest) returns (QuantizationJobResult) {} rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {} rpc StopQuantization(QuantizationStopRequest) returns (Result) {} + } // Define the empty request @@ -676,3 +677,4 @@ message QuantizationProgressUpdate { message QuantizationStopRequest { string job_id = 1; } + diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp index 89f03bf7d..d9d5a5ca4 100644 --- a/backend/cpp/llama-cpp/grpc-server.cpp +++ b/backend/cpp/llama-cpp/grpc-server.cpp @@ -22,8 +22,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -37,6 +39,47 @@ using grpc::Server; using grpc::ServerBuilder; using grpc::ServerContext; using grpc::Status; + +// gRPC bearer token auth via AuthMetadataProcessor for distributed mode. +// Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects +// requests without a matching "authorization: Bearer " metadata header. +class TokenAuthMetadataProcessor : public grpc::AuthMetadataProcessor { +public: + explicit TokenAuthMetadataProcessor(const std::string& token) : token_(token) {} + + bool IsBlocking() const override { return false; } + + grpc::Status Process(const InputMetadata& auth_metadata, + grpc::AuthContext* /*context*/, + OutputMetadata* /*consumed_auth_metadata*/, + OutputMetadata* /*response_metadata*/) override { + auto it = auth_metadata.find("authorization"); + if (it != auth_metadata.end()) { + std::string expected = "Bearer " + token_; + std::string got(it->second.data(), it->second.size()); + // Constant-time comparison + if (expected.size() == got.size() && ct_memcmp(expected.data(), got.data(), expected.size()) == 0) { + return grpc::Status::OK; + } + } + return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token"); + } + +private: + std::string token_; + + // Minimal constant-time comparison (avoids OpenSSL dependency) + static int ct_memcmp(const void* a, const void* b, size_t n) { + const unsigned char* pa = static_cast(a); + const unsigned char* pb = static_cast(b); + unsigned char result = 0; + for (size_t i = 0; i < n; i++) { + result |= pa[i] ^ pb[i]; + } + return result; + } +}; + // END LocalAI @@ -2760,11 +2803,24 @@ int main(int argc, char** argv) { BackendServiceImpl service(ctx_server); ServerBuilder builder; - builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + // Add bearer token auth via AuthMetadataProcessor if LOCALAI_GRPC_AUTH_TOKEN is set + const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN"); + std::shared_ptr creds; + if (auth_token != nullptr && auth_token[0] != '\0') { + creds = grpc::InsecureServerCredentials(); + creds->SetAuthMetadataProcessor( + std::make_shared(auth_token)); + std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl; + } else { + creds = grpc::InsecureServerCredentials(); + } + + builder.AddListeningPort(server_address, creds); builder.RegisterService(&service); builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB + std::unique_ptr server(builder.BuildAndStart()); // run the HTTP server in a thread - see comment below std::thread t([&]() diff --git a/backend/go/acestep-cpp/acestepcpp_test.go b/backend/go/acestep-cpp/acestepcpp_test.go index ad154dfb3..b47857663 100644 --- a/backend/go/acestep-cpp/acestepcpp_test.go +++ b/backend/go/acestep-cpp/acestepcpp_test.go @@ -106,10 +106,10 @@ func TestLoadModel(t *testing.T) { defer conn.Close() client := pb.NewBackendClient(conn) - + // Get base directory from main model file for relative paths mainModelPath := filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf") - + resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{ ModelFile: mainModelPath, ModelPath: modelDir, @@ -134,7 +134,7 @@ func TestSoundGeneration(t *testing.T) { if err != nil { t.Fatal(err) } - defer os.RemoveAll(tmpDir) + t.Cleanup(func() { os.RemoveAll(tmpDir) }) outputFile := filepath.Join(tmpDir, "output.wav") diff --git a/backend/go/acestep-cpp/goacestepcpp.go b/backend/go/acestep-cpp/goacestepcpp.go index 276c317d8..ec94626f1 100644 --- a/backend/go/acestep-cpp/goacestepcpp.go +++ b/backend/go/acestep-cpp/goacestepcpp.go @@ -11,7 +11,7 @@ import ( ) var ( - CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int + CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int ) @@ -29,18 +29,18 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error { var textEncoderModel, ditModel, vaeModel string for _, oo := range opts.Options { - parts := strings.SplitN(oo, ":", 2) - if len(parts) != 2 { + key, value, found := strings.Cut(oo, ":") + if !found { fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo) continue } - switch parts[0] { + switch key { case "text_encoder_model": - textEncoderModel = parts[1] + textEncoderModel = value case "dit_model": - ditModel = parts[1] + ditModel = value case "vae_model": - vaeModel = parts[1] + vaeModel = value default: fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo) } diff --git a/backend/go/llm/llama/llama.go b/backend/go/llm/llama/llama.go index ceca1fa5e..a1b06b980 100644 --- a/backend/go/llm/llama/llama.go +++ b/backend/go/llm/llama/llama.go @@ -18,7 +18,6 @@ type LLM struct { draftModel *llama.LLama } - // Free releases GPU resources and frees the llama model // This should be called when the model is being unloaded to properly release VRAM func (llm *LLM) Free() error { diff --git a/backend/go/local-store/debug.go b/backend/go/local-store/debug.go index 0654d2952..2c3d77cab 100644 --- a/backend/go/local-store/debug.go +++ b/backend/go/local-store/debug.go @@ -1,5 +1,4 @@ //go:build debug -// +build debug package main diff --git a/backend/go/local-store/production.go b/backend/go/local-store/production.go index 418b63972..ef9610cb2 100644 --- a/backend/go/local-store/production.go +++ b/backend/go/local-store/production.go @@ -1,5 +1,4 @@ //go:build !debug -// +build !debug package main diff --git a/backend/go/local-store/store.go b/backend/go/local-store/store.go index 2082684bc..b48c2e919 100644 --- a/backend/go/local-store/store.go +++ b/backend/go/local-store/store.go @@ -332,7 +332,7 @@ func normalizedCosineSimilarity(k1, k2 []float32) float32 { assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) var dot float32 - for i := 0; i < len(k1); i++ { + for i := range len(k1) { dot += k1[i] * k2[i] } @@ -419,7 +419,7 @@ func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 { assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) var dot, mag2 float64 - for i := 0; i < len(k1); i++ { + for i := range len(k1) { dot += float64(k1[i] * k2[i]) mag2 += float64(k2[i] * k2[i]) } diff --git a/backend/go/opus/opus_test.go b/backend/go/opus/opus_test.go index b3daf7148..0a1a5fb97 100644 --- a/backend/go/opus/opus_test.go +++ b/backend/go/opus/opus_test.go @@ -701,7 +701,7 @@ var _ = Describe("Opus", func() { // to one-shot (only difference is resampler batch boundaries). var maxDiff float64 var sumDiffSq float64 - for i := 0; i < minLen; i++ { + for i := range minLen { diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i])) if diff > maxDiff { maxDiff = diff @@ -774,7 +774,7 @@ var _ = Describe("Opus", func() { minLen := min(len(refTail), min(len(persistentTail), len(freshTail))) var persistentMaxDiff, freshMaxDiff float64 - for i := 0; i < minLen; i++ { + for i := range minLen { pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i])) fd := math.Abs(float64(refTail[i]) - float64(freshTail[i])) if pd > persistentMaxDiff { @@ -932,7 +932,7 @@ var _ = Describe("Opus", func() { GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n", mean, stddev, stddev/mean, 16000.0/440.0/2.0) - Expect(stddev / mean).To(BeNumerically("<", 0.15), + Expect(stddev/mean).To(BeNumerically("<", 0.15), fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean)) // Also check frequency is correct @@ -978,7 +978,7 @@ var _ = Describe("Opus", func() { // Every sample must be identical — the resampler is deterministic var maxDiff float64 - for i := 0; i < len(oneShot); i++ { + for i := range len(oneShot) { diff := math.Abs(float64(oneShot[i]) - float64(batched[i])) if diff > maxDiff { maxDiff = diff @@ -1037,13 +1037,13 @@ var _ = Describe("Opus", func() { binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen)) copy(hdr[8:12], "WAVE") copy(hdr[12:16], "fmt ") - binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size - binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM - binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono - binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate - binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate - binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align - binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample + binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size + binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM + binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono + binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate + binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate + binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align + binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample copy(hdr[36:40], "data") binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen)) @@ -1126,7 +1126,7 @@ var _ = Describe("Opus", func() { ) pcm := make([]byte, toneNumSamples*2) - for i := 0; i < toneNumSamples; i++ { + for i := range toneNumSamples { sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate))) binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample)) } diff --git a/backend/go/voxtral/voxtral_test.go b/backend/go/voxtral/voxtral_test.go index 018b332a3..6f2dee699 100644 --- a/backend/go/voxtral/voxtral_test.go +++ b/backend/go/voxtral/voxtral_test.go @@ -138,7 +138,7 @@ func TestAudioTranscription(t *testing.T) { if err != nil { t.Fatal(err) } - defer os.RemoveAll(tmpDir) + t.Cleanup(func() { os.RemoveAll(tmpDir) }) // Download sample audio — JFK "ask not what your country can do for you" clip audioFile := filepath.Join(tmpDir, "sample.wav") diff --git a/backend/python/ace-step/backend.py b/backend/python/ace-step/backend.py index 56ce1314a..ae7584ae3 100644 --- a/backend/python/ace-step/backend.py +++ b/backend/python/ace-step/backend.py @@ -19,6 +19,10 @@ import tempfile import backend_pb2 import backend_pb2_grpc import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + from acestep.inference import ( GenerationParams, GenerationConfig, @@ -444,6 +448,8 @@ def serve(address): ("grpc.max_send_message_length", 50 * 1024 * 1024), ("grpc.max_receive_message_length", 50 * 1024 * 1024), ], + + interceptors=get_auth_interceptors(), ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) diff --git a/backend/python/chatterbox/backend.py b/backend/python/chatterbox/backend.py index 45fd177e2..4dffeb95e 100644 --- a/backend/python/chatterbox/backend.py +++ b/backend/python/chatterbox/backend.py @@ -16,6 +16,10 @@ import torchaudio as ta from chatterbox.tts import ChatterboxTTS from chatterbox.mtl_tts import ChatterboxMultilingualTTS import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + import tempfile def is_float(s): @@ -225,7 +229,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/common/grpc_auth.py b/backend/python/common/grpc_auth.py new file mode 100644 index 000000000..eda138ab4 --- /dev/null +++ b/backend/python/common/grpc_auth.py @@ -0,0 +1,78 @@ +"""Shared gRPC bearer token authentication interceptor for LocalAI Python backends. + +When the environment variable LOCALAI_GRPC_AUTH_TOKEN is set, requests without +a valid Bearer token in the 'authorization' metadata header are rejected with +UNAUTHENTICATED. When the variable is empty or unset, no authentication is +performed (backward compatible). +""" + +import hmac +import os + +import grpc + + +class _AbortHandler(grpc.RpcMethodHandler): + """A method handler that immediately aborts with UNAUTHENTICATED.""" + + def __init__(self): + self.request_streaming = False + self.response_streaming = False + self.request_deserializer = None + self.response_serializer = None + self.unary_unary = self._abort + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + + @staticmethod + def _abort(request, context): + context.abort(grpc.StatusCode.UNAUTHENTICATED, "invalid token") + + +class TokenAuthInterceptor(grpc.ServerInterceptor): + """Sync gRPC server interceptor that validates a bearer token.""" + + def __init__(self, token: str): + self._token = token + self._abort_handler = _AbortHandler() + + def intercept_service(self, continuation, handler_call_details): + metadata = dict(handler_call_details.invocation_metadata) + auth = metadata.get("authorization", "") + expected = "Bearer " + self._token + if not hmac.compare_digest(auth, expected): + return self._abort_handler + return continuation(handler_call_details) + + +class AsyncTokenAuthInterceptor(grpc.aio.ServerInterceptor): + """Async gRPC server interceptor that validates a bearer token.""" + + def __init__(self, token: str): + self._token = token + + async def intercept_service(self, continuation, handler_call_details): + metadata = dict(handler_call_details.invocation_metadata) + auth = metadata.get("authorization", "") + expected = "Bearer " + self._token + if not hmac.compare_digest(auth, expected): + return _AbortHandler() + return await continuation(handler_call_details) + + +def get_auth_interceptors(*, aio: bool = False): + """Return a list of gRPC interceptors for bearer token auth. + + Args: + aio: If True, return async-compatible interceptors for grpc.aio.server(). + If False (default), return sync interceptors for grpc.server(). + + Returns an empty list when LOCALAI_GRPC_AUTH_TOKEN is not set. + """ + token = os.environ.get("LOCALAI_GRPC_AUTH_TOKEN", "") + if not token: + return [] + if aio: + return [AsyncTokenAuthInterceptor(token)] + return [TokenAuthInterceptor(token)] diff --git a/backend/python/coqui/backend.py b/backend/python/coqui/backend.py index 65b37e063..d0fafc1eb 100644 --- a/backend/python/coqui/backend.py +++ b/backend/python/coqui/backend.py @@ -15,6 +15,10 @@ import torch from TTS.api import TTS import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -93,7 +97,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/diffusers/backend.py b/backend/python/diffusers/backend.py index 19daeff20..c9ad3b0bd 100755 --- a/backend/python/diffusers/backend.py +++ b/backend/python/diffusers/backend.py @@ -22,6 +22,10 @@ import backend_pb2 import backend_pb2_grpc import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + # Import dynamic loader for pipeline discovery from diffusers_dynamic_loader import ( @@ -1042,7 +1046,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/faster-qwen3-tts/backend.py b/backend/python/faster-qwen3-tts/backend.py index d3bec3247..31a179423 100644 --- a/backend/python/faster-qwen3-tts/backend.py +++ b/backend/python/faster-qwen3-tts/backend.py @@ -15,6 +15,10 @@ import torch import soundfile as sf import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + def is_float(s): @@ -165,6 +169,8 @@ def serve(address): ('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024), ] + , + interceptors=get_auth_interceptors(), ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) diff --git a/backend/python/faster-whisper/backend.py b/backend/python/faster-whisper/backend.py index c94665b2b..9cfb9dfd1 100755 --- a/backend/python/faster-whisper/backend.py +++ b/backend/python/faster-whisper/backend.py @@ -14,6 +14,10 @@ import torch from faster_whisper import WhisperModel import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -70,7 +74,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/fish-speech/backend.py b/backend/python/fish-speech/backend.py index 921b71efc..a061d49e0 100644 --- a/backend/python/fish-speech/backend.py +++ b/backend/python/fish-speech/backend.py @@ -19,6 +19,10 @@ import numpy as np import json import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + def is_float(s): @@ -424,6 +428,8 @@ def serve(address): ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB ], + + interceptors=get_auth_interceptors(), ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) diff --git a/backend/python/kitten-tts/backend.py b/backend/python/kitten-tts/backend.py index b31023c8c..33abb3289 100644 --- a/backend/python/kitten-tts/backend.py +++ b/backend/python/kitten-tts/backend.py @@ -16,6 +16,10 @@ from kittentts import KittenTTS import soundfile as sf import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -77,7 +81,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/kokoro/backend.py b/backend/python/kokoro/backend.py index 43d22238f..32013b2ff 100644 --- a/backend/python/kokoro/backend.py +++ b/backend/python/kokoro/backend.py @@ -16,6 +16,10 @@ from kokoro import KPipeline import soundfile as sf import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -84,7 +88,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/llama-cpp-quantization/backend.py b/backend/python/llama-cpp-quantization/backend.py index 359133d37..b91343daa 100644 --- a/backend/python/llama-cpp-quantization/backend.py +++ b/backend/python/llama-cpp-quantization/backend.py @@ -17,6 +17,10 @@ import time from concurrent import futures import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + import backend_pb2 import backend_pb2_grpc @@ -398,7 +402,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): def serve(address): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)) + server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/mlx-audio/backend.py b/backend/python/mlx-audio/backend.py index da37d2c37..0dc197060 100644 --- a/backend/python/mlx-audio/backend.py +++ b/backend/python/mlx-audio/backend.py @@ -15,6 +15,10 @@ import backend_pb2 import backend_pb2_grpc import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + from mlx_audio.tts.utils import load_model import soundfile as sf import numpy as np @@ -436,7 +440,9 @@ async def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(aio=True), + ) # Add the servicer to the server backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) # Bind the server to the address diff --git a/backend/python/mlx-distributed/backend.py b/backend/python/mlx-distributed/backend.py index b21a98070..90d74eba8 100644 --- a/backend/python/mlx-distributed/backend.py +++ b/backend/python/mlx-distributed/backend.py @@ -23,6 +23,10 @@ import tempfile from typing import List import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + import backend_pb2 import backend_pb2_grpc @@ -468,6 +472,8 @@ async def serve(address): ('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024), ], + + interceptors=get_auth_interceptors(aio=True), ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) diff --git a/backend/python/mlx-vlm/backend.py b/backend/python/mlx-vlm/backend.py index 6c5f8b189..578a5e563 100644 --- a/backend/python/mlx-vlm/backend.py +++ b/backend/python/mlx-vlm/backend.py @@ -12,6 +12,10 @@ import backend_pb2 import backend_pb2_grpc import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + from mlx_vlm import load, generate, stream_generate from mlx_vlm.prompt_utils import apply_chat_template from mlx_vlm.utils import load_config, load_image @@ -446,7 +450,9 @@ async def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(aio=True), + ) # Add the servicer to the server backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) # Bind the server to the address diff --git a/backend/python/mlx/backend.py b/backend/python/mlx/backend.py index aaa0d6f34..1a41020f5 100644 --- a/backend/python/mlx/backend.py +++ b/backend/python/mlx/backend.py @@ -12,6 +12,10 @@ import backend_pb2 import backend_pb2_grpc import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + from mlx_lm import load, generate, stream_generate from mlx_lm.sample_utils import make_sampler from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache @@ -421,7 +425,9 @@ async def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(aio=True), + ) # Add the servicer to the server backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) # Bind the server to the address diff --git a/backend/python/moonshine/backend.py b/backend/python/moonshine/backend.py index 2e50d3844..988dae548 100644 --- a/backend/python/moonshine/backend.py +++ b/backend/python/moonshine/backend.py @@ -17,6 +17,10 @@ from moonshine_voice import ( ) import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -128,7 +132,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/nemo/backend.py b/backend/python/nemo/backend.py index fd2218f69..270e8fb3a 100644 --- a/backend/python/nemo/backend.py +++ b/backend/python/nemo/backend.py @@ -14,6 +14,10 @@ import torch import nemo.collections.asr as nemo_asr import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + def is_float(s): @@ -119,7 +123,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024), - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/neutts/backend.py b/backend/python/neutts/backend.py index e765436d1..139ed824b 100644 --- a/backend/python/neutts/backend.py +++ b/backend/python/neutts/backend.py @@ -15,6 +15,10 @@ from neuttsair.neutts import NeuTTSAir import soundfile as sf import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + def is_float(s): """Check if a string can be converted to float.""" @@ -130,7 +134,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/outetts/backend.py b/backend/python/outetts/backend.py index d98cc59e8..c6f5e8aa2 100644 --- a/backend/python/outetts/backend.py +++ b/backend/python/outetts/backend.py @@ -14,6 +14,10 @@ import backend_pb2 import backend_pb2_grpc import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + import outetts _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -116,7 +120,9 @@ async def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024), - ]) + ], + interceptors=get_auth_interceptors(aio=True), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) diff --git a/backend/python/pocket-tts/backend.py b/backend/python/pocket-tts/backend.py index 7c734e54a..d7678636b 100644 --- a/backend/python/pocket-tts/backend.py +++ b/backend/python/pocket-tts/backend.py @@ -16,6 +16,10 @@ import torch from pocket_tts import TTSModel import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + def is_float(s): """Check if a string can be converted to float.""" @@ -225,7 +229,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/qwen-asr/backend.py b/backend/python/qwen-asr/backend.py index 53660c82e..556f0e97a 100644 --- a/backend/python/qwen-asr/backend.py +++ b/backend/python/qwen-asr/backend.py @@ -14,6 +14,10 @@ import torch from qwen_asr import Qwen3ASRModel import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + def is_float(s): @@ -184,7 +188,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024), - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/qwen-tts/backend.py b/backend/python/qwen-tts/backend.py index f4aa71f6f..f24533966 100644 --- a/backend/python/qwen-tts/backend.py +++ b/backend/python/qwen-tts/backend.py @@ -23,6 +23,10 @@ import hashlib import pickle import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + def is_float(s): @@ -900,6 +904,8 @@ def serve(address): ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB ], + + interceptors=get_auth_interceptors(), ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) diff --git a/backend/python/rerankers/backend.py b/backend/python/rerankers/backend.py index 8ce2636d7..d70b2b60f 100755 --- a/backend/python/rerankers/backend.py +++ b/backend/python/rerankers/backend.py @@ -14,6 +14,10 @@ import backend_pb2 import backend_pb2_grpc import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + from rerankers import Reranker @@ -97,7 +101,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/rfdetr/backend.py b/backend/python/rfdetr/backend.py index 57f68647f..47985dfc8 100755 --- a/backend/python/rfdetr/backend.py +++ b/backend/python/rfdetr/backend.py @@ -13,6 +13,10 @@ import base64 import backend_pb2 import backend_pb2_grpc import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + import requests @@ -139,7 +143,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/transformers/backend.py b/backend/python/transformers/backend.py index 54fd9193d..f2f70acb3 100644 --- a/backend/python/transformers/backend.py +++ b/backend/python/transformers/backend.py @@ -16,6 +16,10 @@ import backend_pb2 import backend_pb2_grpc import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + import torch import torch.cuda @@ -532,7 +536,9 @@ async def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(aio=True), + ) # Add the servicer to the server backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) # Bind the server to the address diff --git a/backend/python/trl/backend.py b/backend/python/trl/backend.py index c414e6fb6..3ea4de975 100644 --- a/backend/python/trl/backend.py +++ b/backend/python/trl/backend.py @@ -17,6 +17,10 @@ import uuid from concurrent import futures import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + import backend_pb2 import backend_pb2_grpc @@ -832,6 +836,8 @@ def serve(address): ('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024), ], + + interceptors=get_auth_interceptors(), ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) diff --git a/backend/python/vibevoice/backend.py b/backend/python/vibevoice/backend.py index 2344188d2..4353f8a29 100644 --- a/backend/python/vibevoice/backend.py +++ b/backend/python/vibevoice/backend.py @@ -20,6 +20,10 @@ from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalG from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + def is_float(s): """Check if a string can be converted to float.""" @@ -724,7 +728,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/vllm-omni/backend.py b/backend/python/vllm-omni/backend.py index c21aeb0ad..96eb8a111 100644 --- a/backend/python/vllm-omni/backend.py +++ b/backend/python/vllm-omni/backend.py @@ -27,6 +27,10 @@ import backend_pb2 import backend_pb2_grpc import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + from vllm_omni.entrypoints.omni import Omni from vllm_omni.outputs import OmniRequestOutput @@ -650,7 +654,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024), - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/vllm/backend.py b/backend/python/vllm/backend.py index 56698a54e..07323c424 100644 --- a/backend/python/vllm/backend.py +++ b/backend/python/vllm/backend.py @@ -12,6 +12,10 @@ import backend_pb2 import backend_pb2_grpc import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.sampling_params import SamplingParams @@ -338,7 +342,9 @@ async def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(aio=True), + ) # Add the servicer to the server backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) # Bind the server to the address diff --git a/backend/python/voxcpm/backend.py b/backend/python/voxcpm/backend.py index 0c1970648..9ee3a6e12 100644 --- a/backend/python/voxcpm/backend.py +++ b/backend/python/voxcpm/backend.py @@ -18,6 +18,10 @@ import backend_pb2_grpc import torch import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + def is_float(s): """Check if a string can be converted to float.""" @@ -297,7 +301,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/backend/python/whisperx/backend.py b/backend/python/whisperx/backend.py index 7fd5cfb42..096f9ffdf 100644 --- a/backend/python/whisperx/backend.py +++ b/backend/python/whisperx/backend.py @@ -13,6 +13,10 @@ import backend_pb2 import backend_pb2_grpc import grpc +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors + _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -137,7 +141,9 @@ def serve(address): ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB - ]) + ], + interceptors=get_auth_interceptors(), + ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/core/application/agent_jobs.go b/core/application/agent_jobs.go index 0ed5d9283..b7cfb20a3 100644 --- a/core/application/agent_jobs.go +++ b/core/application/agent_jobs.go @@ -3,7 +3,7 @@ package application import ( "time" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/agentpool" "github.com/mudler/xlog" ) @@ -22,13 +22,23 @@ func (a *Application) RestartAgentJobService() error { } // Create new service instance - agentJobService := services.NewAgentJobService( + agentJobService := agentpool.NewAgentJobService( a.ApplicationConfig(), a.ModelLoader(), a.ModelConfigLoader(), a.TemplatesEvaluator(), ) + // Re-apply distributed wiring if available (matches startup.go logic) + if d := a.Distributed(); d != nil { + if d.Dispatcher != nil { + agentJobService.SetDistributedBackends(d.Dispatcher) + } + if d.JobStore != nil { + agentJobService.SetDistributedJobStore(d.JobStore) + } + } + // Start the service err := agentJobService.Start(a.ApplicationConfig().Context) if err != nil { diff --git a/core/application/application.go b/core/application/application.go index c636be38f..accba0330 100644 --- a/core/application/application.go +++ b/core/application/application.go @@ -2,12 +2,16 @@ package application import ( "context" + "math/rand/v2" "sync" "sync/atomic" + "time" "github.com/mudler/LocalAI/core/config" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/agentpool" + "github.com/mudler/LocalAI/core/services/galleryop" + "github.com/mudler/LocalAI/core/services/nodes" "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" @@ -20,9 +24,9 @@ type Application struct { applicationConfig *config.ApplicationConfig startupConfig *config.ApplicationConfig // Stores original config from env vars (before file loading) templatesEvaluator *templates.Evaluator - galleryService *services.GalleryService - agentJobService *services.AgentJobService - agentPoolService atomic.Pointer[services.AgentPoolService] + galleryService *galleryop.GalleryService + agentJobService *agentpool.AgentJobService + agentPoolService atomic.Pointer[agentpool.AgentPoolService] authDB *gorm.DB watchdogMutex sync.Mutex watchdogStop chan bool @@ -30,6 +34,9 @@ type Application struct { p2pCtx context.Context p2pCancel context.CancelFunc agentJobMutex sync.Mutex + + // Distributed mode services (nil when not in distributed mode) + distributed *DistributedServices } func newApplication(appConfig *config.ApplicationConfig) *Application { @@ -64,15 +71,15 @@ func (a *Application) TemplatesEvaluator() *templates.Evaluator { return a.templatesEvaluator } -func (a *Application) GalleryService() *services.GalleryService { +func (a *Application) GalleryService() *galleryop.GalleryService { return a.galleryService } -func (a *Application) AgentJobService() *services.AgentJobService { +func (a *Application) AgentJobService() *agentpool.AgentJobService { return a.agentJobService } -func (a *Application) AgentPoolService() *services.AgentPoolService { +func (a *Application) AgentPoolService() *agentpool.AgentPoolService { return a.agentPoolService.Load() } @@ -86,8 +93,53 @@ func (a *Application) StartupConfig() *config.ApplicationConfig { return a.startupConfig } +// Distributed returns the distributed services, or nil if not in distributed mode. +func (a *Application) Distributed() *DistributedServices { + return a.distributed +} + +// IsDistributed returns true if the application is running in distributed mode. +func (a *Application) IsDistributed() bool { + return a.distributed != nil +} + +// waitForHealthyWorker blocks until at least one healthy backend worker is registered. +// This prevents the agent pool from failing during startup when workers haven't connected yet. +func (a *Application) waitForHealthyWorker() { + maxWait := a.applicationConfig.Distributed.WorkerWaitTimeoutOrDefault() + const basePoll = 2 * time.Second + + xlog.Info("Waiting for at least one healthy backend worker before starting agent pool") + deadline := time.Now().Add(maxWait) + + for time.Now().Before(deadline) { + registered, err := a.distributed.Registry.List(context.Background()) + if err == nil { + for _, n := range registered { + if n.NodeType == nodes.NodeTypeBackend && n.Status == nodes.StatusHealthy { + xlog.Info("Healthy backend worker found", "node", n.Name) + return + } + } + } + // Add 0-1s jitter to prevent thundering-herd on the node registry + jitter := time.Duration(rand.Int64N(int64(time.Second))) + select { + case <-a.applicationConfig.Context.Done(): + return + case <-time.After(basePoll + jitter): + } + } + xlog.Warn("No healthy backend worker found after waiting, proceeding anyway") +} + +// InstanceID returns the unique identifier for this frontend instance. +func (a *Application) InstanceID() string { + return a.applicationConfig.Distributed.InstanceID +} + func (a *Application) start() error { - galleryService := services.NewGalleryService(a.ApplicationConfig(), a.ModelLoader()) + galleryService := galleryop.NewGalleryService(a.ApplicationConfig(), a.ModelLoader()) err := galleryService.Start(a.ApplicationConfig().Context, a.ModelConfigLoader(), a.ApplicationConfig().SystemState) if err != nil { return err @@ -95,19 +147,14 @@ func (a *Application) start() error { a.galleryService = galleryService - // Initialize agent job service - agentJobService := services.NewAgentJobService( + // Initialize agent job service (Start() is deferred to after distributed wiring) + agentJobService := agentpool.NewAgentJobService( a.ApplicationConfig(), a.ModelLoader(), a.ModelConfigLoader(), a.TemplatesEvaluator(), ) - err = agentJobService.Start(a.ApplicationConfig().Context) - if err != nil { - return err - } - a.agentJobService = agentJobService return nil @@ -120,27 +167,56 @@ func (a *Application) StartAgentPool() { if !a.applicationConfig.AgentPool.Enabled { return } - aps, err := services.NewAgentPoolService(a.applicationConfig) + // Build options struct from available dependencies + opts := agentpool.AgentPoolOptions{ + AuthDB: a.authDB, + } + if d := a.Distributed(); d != nil { + if d.DistStores != nil && d.DistStores.Skills != nil { + opts.SkillStore = d.DistStores.Skills + } + opts.NATSClient = d.Nats + opts.EventBridge = d.AgentBridge + opts.AgentStore = d.AgentStore + } + + aps, err := agentpool.NewAgentPoolService(a.applicationConfig, opts) if err != nil { xlog.Error("Failed to create agent pool service", "error", err) return } - if a.authDB != nil { - aps.SetAuthDB(a.authDB) + + // Wire distributed mode components + if d := a.Distributed(); d != nil { + // Wait for at least one healthy backend worker before starting the agent pool. + // Collections initialization calls embeddings which require a worker. + if d.Registry != nil { + a.waitForHealthyWorker() + } } + if err := aps.Start(a.applicationConfig.Context); err != nil { xlog.Error("Failed to start agent pool", "error", err) return } // Wire per-user scoped services so collections, skills, and jobs are isolated per user - usm := services.NewUserServicesManager( + usm := agentpool.NewUserServicesManager( aps.UserStorage(), a.applicationConfig, a.modelLoader, a.backendLoader, a.templatesEvaluator, ) + // Wire distributed backends to per-user job services + if a.agentJobService != nil { + if d := a.agentJobService.Dispatcher(); d != nil { + usm.SetJobDispatcher(d) + } + if s := a.agentJobService.DBStore(); s != nil { + usm.SetJobDBStore(s) + } + } aps.SetUserServicesManager(usm) a.agentPoolService.Store(aps) diff --git a/core/application/distributed.go b/core/application/distributed.go new file mode 100644 index 000000000..257fecdd4 --- /dev/null +++ b/core/application/distributed.go @@ -0,0 +1,267 @@ +package application + +import ( + "context" + "encoding/json" + "fmt" + "io" + "strings" + "sync" + + "github.com/google/uuid" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/services/agents" + "github.com/mudler/LocalAI/core/services/distributed" + "github.com/mudler/LocalAI/core/services/jobs" + "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/LocalAI/core/services/storage" + "github.com/mudler/LocalAI/pkg/sanitize" + "github.com/mudler/xlog" + "gorm.io/gorm" +) + +// DistributedServices holds all services initialized for distributed mode. +type DistributedServices struct { + Nats *messaging.Client + Store storage.ObjectStore + Registry *nodes.NodeRegistry + Router *nodes.SmartRouter + Health *nodes.HealthMonitor + JobStore *jobs.JobStore + Dispatcher *jobs.Dispatcher + AgentStore *agents.AgentStore + AgentBridge *agents.EventBridge + DistStores *distributed.Stores + FileMgr *storage.FileManager + FileStager nodes.FileStager + ModelAdapter *nodes.ModelRouterAdapter + Unloader *nodes.RemoteUnloaderAdapter + + shutdownOnce sync.Once +} + +// Shutdown stops all distributed services in reverse initialization order. +// It is safe to call on a nil receiver and is idempotent (uses sync.Once). +func (ds *DistributedServices) Shutdown() { + if ds == nil { + return + } + ds.shutdownOnce.Do(func() { + if ds.Health != nil { + ds.Health.Stop() + } + if ds.Dispatcher != nil { + ds.Dispatcher.Stop() + } + if closer, ok := ds.Store.(io.Closer); ok { + closer.Close() + } + // AgentBridge has no Close method — its NATS subscriptions are cleaned up + // when the NATS client is closed below. + if ds.Nats != nil { + ds.Nats.Close() + } + xlog.Info("Distributed services shut down") + }) +} + +// initDistributed validates distributed mode prerequisites and initializes +// NATS, object storage, node registry, and instance identity. +// Returns nil if distributed mode is not enabled. +func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*DistributedServices, error) { + if !cfg.Distributed.Enabled { + return nil, nil + } + + xlog.Info("Distributed mode enabled — validating prerequisites") + + // Validate distributed config (NATS URL, S3 credential pairing, durations, etc.) + if err := cfg.Distributed.Validate(); err != nil { + return nil, err + } + + // Validate PostgreSQL is configured (auth DB must be PostgreSQL for distributed mode) + if !cfg.Auth.Enabled { + return nil, fmt.Errorf("distributed mode requires authentication to be enabled (--auth / LOCALAI_AUTH=true)") + } + if !isPostgresURL(cfg.Auth.DatabaseURL) { + return nil, fmt.Errorf("distributed mode requires PostgreSQL for auth database (got %q)", sanitize.URL(cfg.Auth.DatabaseURL)) + } + + // Generate instance ID if not set + if cfg.Distributed.InstanceID == "" { + cfg.Distributed.InstanceID = uuid.New().String() + } + xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID) + + // Connect to NATS + natsClient, err := messaging.New(cfg.Distributed.NatsURL) + if err != nil { + return nil, fmt.Errorf("connecting to NATS: %w", err) + } + xlog.Info("Connected to NATS", "url", sanitize.URL(cfg.Distributed.NatsURL)) + + // Ensure NATS is closed if any subsequent initialization step fails. + success := false + defer func() { + if !success { + natsClient.Close() + } + }() + + // Initialize object storage + var store storage.ObjectStore + if cfg.Distributed.StorageURL != "" { + if cfg.Distributed.StorageBucket == "" { + return nil, fmt.Errorf("distributed storage bucket must be set when storage URL is configured") + } + s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{ + Endpoint: cfg.Distributed.StorageURL, + Region: cfg.Distributed.StorageRegion, + Bucket: cfg.Distributed.StorageBucket, + AccessKeyID: cfg.Distributed.StorageAccessKey, + SecretAccessKey: cfg.Distributed.StorageSecretKey, + ForcePathStyle: true, // required for MinIO + }) + if err != nil { + return nil, fmt.Errorf("initializing S3 storage: %w", err) + } + xlog.Info("Object storage initialized (S3)", "endpoint", cfg.Distributed.StorageURL, "bucket", cfg.Distributed.StorageBucket) + store = s3Store + } else { + // Fallback to filesystem storage in distributed mode (useful for single-node testing) + fsStore, err := storage.NewFilesystemStore(cfg.DataPath + "/objectstore") + if err != nil { + return nil, fmt.Errorf("initializing filesystem storage: %w", err) + } + xlog.Info("Object storage initialized (filesystem fallback)", "path", cfg.DataPath+"/objectstore") + store = fsStore + } + + // Initialize node registry (requires the auth DB which is PostgreSQL) + if authDB == nil { + return nil, fmt.Errorf("distributed mode requires auth database to be initialized first") + } + + registry, err := nodes.NewNodeRegistry(authDB) + if err != nil { + return nil, fmt.Errorf("initializing node registry: %w", err) + } + xlog.Info("Node registry initialized") + + // Collect SmartRouter option values; the router itself is created after all + // dependencies (including FileStager and Unloader) are ready. + var routerAuthToken string + if cfg.Distributed.RegistrationToken != "" { + routerAuthToken = cfg.Distributed.RegistrationToken + } + var routerGalleriesJSON string + if galleriesJSON, err := json.Marshal(cfg.BackendGalleries); err == nil { + routerGalleriesJSON = string(galleriesJSON) + } + + healthMon := nodes.NewHealthMonitor(registry, authDB, + cfg.Distributed.HealthCheckIntervalOrDefault(), + cfg.Distributed.StaleNodeThresholdOrDefault(), + routerAuthToken, + cfg.Distributed.PerModelHealthCheck, + ) + + // Initialize job store + jobStore, err := jobs.NewJobStore(authDB) + if err != nil { + return nil, fmt.Errorf("initializing job store: %w", err) + } + xlog.Info("Distributed job store initialized") + + // Initialize job dispatcher + dispatcher := jobs.NewDispatcher(jobStore, natsClient, authDB, cfg.Distributed.InstanceID, cfg.Distributed.JobWorkerConcurrency) + + // Initialize agent store + agentStore, err := agents.NewAgentStore(authDB) + if err != nil { + return nil, fmt.Errorf("initializing agent store: %w", err) + } + xlog.Info("Distributed agent store initialized") + + // Initialize agent event bridge + agentBridge := agents.NewEventBridge(natsClient, agentStore, cfg.Distributed.InstanceID) + + // Start observable persister — captures observable_update events from workers + // (which have no DB access) and persists them to PostgreSQL. + if err := agentBridge.StartObservablePersister(); err != nil { + xlog.Warn("Failed to start observable persister", "error", err) + } else { + xlog.Info("Observable persister started") + } + + // Initialize Phase 4 stores (MCP, Gallery, FineTune, Skills) + distStores, err := distributed.InitStores(authDB) + if err != nil { + return nil, fmt.Errorf("initializing distributed stores: %w", err) + } + + // Initialize file manager with local cache + cacheDir := cfg.DataPath + "/cache" + fileMgr, err := storage.NewFileManager(store, cacheDir) + if err != nil { + return nil, fmt.Errorf("initializing file manager: %w", err) + } + xlog.Info("File manager initialized", "cacheDir", cacheDir) + + // Create FileStager for distributed file transfer + var fileStager nodes.FileStager + if cfg.Distributed.StorageURL != "" { + fileStager = nodes.NewS3NATSFileStager(fileMgr, natsClient) + xlog.Info("File stager initialized (S3+NATS)") + } else { + fileStager = nodes.NewHTTPFileStager(func(nodeID string) (string, error) { + node, err := registry.Get(context.Background(), nodeID) + if err != nil { + return "", err + } + if node.HTTPAddress == "" { + return "", fmt.Errorf("node %s has no HTTP address for file transfer", nodeID) + } + return node.HTTPAddress, nil + }, cfg.Distributed.RegistrationToken) + xlog.Info("File stager initialized (HTTP direct transfer)") + } + // Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go + remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient) + + // All dependencies ready — build SmartRouter with all options at once + router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{ + Unloader: remoteUnloader, + FileStager: fileStager, + GalleriesJSON: routerGalleriesJSON, + AuthToken: routerAuthToken, + DB: authDB, + }) + + // Create ModelRouterAdapter to wire into ModelLoader + modelAdapter := nodes.NewModelRouterAdapter(router) + + success = true + return &DistributedServices{ + Nats: natsClient, + Store: store, + Registry: registry, + Router: router, + Health: healthMon, + JobStore: jobStore, + Dispatcher: dispatcher, + AgentStore: agentStore, + AgentBridge: agentBridge, + DistStores: distStores, + FileMgr: fileMgr, + FileStager: fileStager, + ModelAdapter: modelAdapter, + Unloader: remoteUnloader, + }, nil +} + +func isPostgresURL(url string) bool { + return strings.HasPrefix(url, "postgres://") || strings.HasPrefix(url, "postgresql://") +} diff --git a/core/application/p2p.go b/core/application/p2p.go index 8522a121d..451e38121 100644 --- a/core/application/p2p.go +++ b/core/application/p2p.go @@ -11,7 +11,7 @@ import ( "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/edgevpn/pkg/node" "github.com/mudler/xlog" @@ -146,22 +146,14 @@ func (a *Application) RestartP2P() error { return fmt.Errorf("P2P token is not set") } - // Create new context for P2P - ctx, cancel := context.WithCancel(appConfig.Context) - a.p2pCtx = ctx - a.p2pCancel = cancel - - // Get API address from config - address := appConfig.APIAddress - if address == "" { - address = "127.0.0.1:8080" // default - } - // Start P2P stack in a goroutine + // Note: StartP2P creates its own context and assigns a.p2pCtx/a.p2pCancel go func() { if err := a.StartP2P(); err != nil { xlog.Error("Failed to start P2P stack", "error", err) - cancel() // Cancel context on error + if a.p2pCancel != nil { + a.p2pCancel() + } } }() xlog.Info("P2P stack restarted with new settings") @@ -228,7 +220,7 @@ func syncState(ctx context.Context, n *node.Node, app *Application) error { continue } - app.GalleryService().ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ + app.GalleryService().ModelGalleryChannel <- galleryop.ManagementOp[gallery.GalleryModel, gallery.ModelConfig]{ ID: uuid.String(), GalleryElementName: model, Galleries: app.ApplicationConfig().Galleries, diff --git a/core/application/startup.go b/core/application/startup.go index 0c3bde5d5..026e55226 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -13,11 +13,15 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/http/auth" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/galleryop" + "github.com/mudler/LocalAI/core/services/jobs" + "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/LocalAI/core/services/storage" coreStartup "github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/sanitize" "github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/mudler/xlog" ) @@ -101,7 +105,7 @@ func New(opts ...config.AppOption) (*Application, error) { return nil, fmt.Errorf("failed to initialize auth database: %w", err) } application.authDB = authDB - xlog.Info("Auth enabled", "database", options.Auth.DatabaseURL) + xlog.Info("Auth enabled", "database", sanitize.URL(options.Auth.DatabaseURL)) // Start session and expired API key cleanup goroutine go func() { @@ -123,12 +127,92 @@ func New(opts ...config.AppOption) (*Application, error) { }() } + // Wire JobStore for DB-backed task/job persistence whenever auth DB is available. + // This ensures tasks and jobs survive restarts in both single-node and distributed modes. + if application.authDB != nil && application.agentJobService != nil { + dbJobStore, err := jobs.NewJobStore(application.authDB) + if err != nil { + xlog.Error("Failed to create job store for auth DB", "error", err) + } else { + application.agentJobService.SetDistributedJobStore(dbJobStore) + } + } + + // Initialize distributed mode services (NATS, object storage, node registry) + distSvc, err := initDistributed(options, application.authDB) + if err != nil { + return nil, fmt.Errorf("distributed mode initialization failed: %w", err) + } + if distSvc != nil { + application.distributed = distSvc + // Wire remote model unloader so ShutdownModel works for remote nodes + // Uses NATS to tell serve-backend nodes to Free + kill their backend process + application.modelLoader.SetRemoteUnloader(distSvc.Unloader) + // Wire ModelRouter so grpcModel() delegates to SmartRouter in distributed mode + application.modelLoader.SetModelRouter(distSvc.ModelAdapter.AsModelRouter()) + // Wire DistributedModelStore so shutdown/list/watchdog can find remote models + distStore := nodes.NewDistributedModelStore( + model.NewInMemoryModelStore(), + distSvc.Registry, + ) + application.modelLoader.SetModelStore(distStore) + // Start health monitor + distSvc.Health.Start(options.Context) + // In distributed mode, MCP CI jobs are executed by agent workers (not the frontend) + // because the frontend can't create MCP sessions (e.g., stdio servers using docker). + // The dispatcher still subscribes to jobs.new for persistence (result/progress subs) + // but does NOT set a workerFn — agent workers consume jobs from the same NATS queue. + + // Wire model config loader so job events include model config for agent workers + distSvc.Dispatcher.SetModelConfigLoader(application.backendLoader) + + // Start job dispatcher — abort startup if it fails, as jobs would be accepted but never dispatched + if err := distSvc.Dispatcher.Start(options.Context); err != nil { + return nil, fmt.Errorf("starting job dispatcher: %w", err) + } + // Start ephemeral file cleanup + storage.StartEphemeralCleanup(options.Context, distSvc.FileMgr, 0, 0) + // Wire distributed backends into AgentJobService (before Start) + if application.agentJobService != nil { + application.agentJobService.SetDistributedBackends(distSvc.Dispatcher) + application.agentJobService.SetDistributedJobStore(distSvc.JobStore) + } + // Wire skill store into AgentPoolService (wired at pool start time via closure) + // The actual wiring happens in StartAgentPool since the pool doesn't exist yet. + + // Wire NATS and gallery store into GalleryService for cross-instance progress/cancel + if application.galleryService != nil { + application.galleryService.SetNATSClient(distSvc.Nats) + if distSvc.DistStores != nil && distSvc.DistStores.Gallery != nil { + // Clean up stale in-progress operations from previous crashed instances + if err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil { + xlog.Warn("Failed to clean stale gallery operations", "error", err) + } + application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery) + } + // Wire distributed model/backend managers so delete propagates to workers + application.galleryService.SetModelManager( + nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader), + ) + application.galleryService.SetBackendManager( + nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry), + ) + } + } + + // Start AgentJobService (after distributed wiring so it knows whether to use local or NATS) + if application.agentJobService != nil { + if err := application.agentJobService.Start(options.Context); err != nil { + return nil, fmt.Errorf("starting agent job service: %w", err) + } + } + if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil { xlog.Error("error installing models", "error", err) } for _, backend := range options.ExternalBackends { - if err := services.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil { + if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil { xlog.Error("error installing external backend", "error", err) } } @@ -154,13 +238,13 @@ func New(opts ...config.AppOption) (*Application, error) { } if options.PreloadJSONModels != "" { - if err := services.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil { + if err := galleryop.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil { return nil, err } } if options.PreloadModelsFromPath != "" { - if err := services.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil { + if err := galleryop.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil { return nil, err } } @@ -184,6 +268,7 @@ func New(opts ...config.AppOption) (*Application, error) { go func() { <-options.Context.Done() xlog.Debug("Context canceled, shutting down") + application.distributed.Shutdown() err := application.ModelLoader().StopAllGRPC() if err != nil { xlog.Error("error while stopping all grpc backends", "error", err) @@ -207,7 +292,7 @@ func New(opts ...config.AppOption) (*Application, error) { var backendErr error _, backendErr = application.ModelLoader().Load(o...) if backendErr != nil { - return nil, err + return nil, backendErr } } } diff --git a/core/backend/llm.go b/core/backend/llm.go index b82533a9f..5b416a44d 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -13,9 +13,9 @@ import ( "github.com/mudler/xlog" "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/galleryop" + "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/grpc/proto" @@ -27,7 +27,7 @@ type LLMResponse struct { Response string // should this be []byte? Usage TokenUsage AudioOutput string - Logprobs *schema.Logprobs // Logprobs from the backend response + Logprobs *schema.Logprobs // Logprobs from the backend response ChatDeltas []*proto.ChatDelta // Pre-parsed tool calls/content from C++ autoparser } @@ -47,14 +47,18 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima // Check if the modelFile exists, if it doesn't try to load it from the gallery if o.AutoloadGalleries { // experimental - modelNames, err := services.ListModels(cl, loader, nil, services.SKIP_ALWAYS) + modelNames, err := galleryop.ListModels(cl, loader, nil, galleryop.SKIP_ALWAYS) if err != nil { return nil, err } - if !slices.Contains(modelNames, c.Name) { + modelName := c.Name + if modelName == "" { + modelName = c.Model + } + if !slices.Contains(modelNames, modelName) { utils.ResetDownloadTimers() // if we failed to load the model, we try to download it - err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries) + err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, modelName, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries) if err != nil { xlog.Error("failed to install model from gallery", "error", err, "model", modelFile) //return nil, err @@ -252,12 +256,12 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima trace.InitBackendTracingIfEnabled(o.TracingMaxItems) traceData := map[string]any{ - "chat_template": c.TemplateConfig.Chat, + "chat_template": c.TemplateConfig.Chat, "function_template": c.TemplateConfig.Functions, - "streaming": tokenCallback != nil, - "images_count": len(images), - "videos_count": len(videos), - "audios_count": len(audios), + "streaming": tokenCallback != nil, + "images_count": len(images), + "videos_count": len(videos), + "audios_count": len(audios), } if len(messages) > 0 { diff --git a/core/backend/options.go b/core/backend/options.go index 71b9d682a..2d410c90a 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -1,7 +1,7 @@ package backend import ( - "math/rand" + "math/rand/v2" "os" "path/filepath" "strings" @@ -86,7 +86,7 @@ func getSeed(c config.ModelConfig) int32 { } if seed == config.RAND_SEED { - seed = rand.Int31() + seed = rand.Int32() } return seed diff --git a/core/backend/tokenize.go b/core/backend/tokenize.go index f70ca14e2..761329973 100644 --- a/core/backend/tokenize.go +++ b/core/backend/tokenize.go @@ -4,8 +4,8 @@ import ( "time" "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/model" ) diff --git a/core/backend/vad.go b/core/backend/vad.go index bcf6f5976..c3ecb66c9 100644 --- a/core/backend/vad.go +++ b/core/backend/vad.go @@ -1,40 +1,40 @@ -package backend - -import ( - "context" - - "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/pkg/grpc/proto" - "github.com/mudler/LocalAI/pkg/model" -) - -func VAD(request *schema.VADRequest, - ctx context.Context, - ml *model.ModelLoader, - appConfig *config.ApplicationConfig, - modelConfig config.ModelConfig) (*schema.VADResponse, error) { - opts := ModelOptions(modelConfig, appConfig) - vadModel, err := ml.Load(opts...) - if err != nil { - recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) - return nil, err - } - - req := proto.VADRequest{ - Audio: request.Audio, - } - resp, err := vadModel.VAD(ctx, &req) - if err != nil { - return nil, err - } - - segments := []schema.VADSegment{} - for _, s := range resp.Segments { - segments = append(segments, schema.VADSegment{Start: s.Start, End: s.End}) - } - - return &schema.VADResponse{ - Segments: segments, - }, nil -} +package backend + +import ( + "context" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/LocalAI/pkg/model" +) + +func VAD(request *schema.VADRequest, + ctx context.Context, + ml *model.ModelLoader, + appConfig *config.ApplicationConfig, + modelConfig config.ModelConfig) (*schema.VADResponse, error) { + opts := ModelOptions(modelConfig, appConfig) + vadModel, err := ml.Load(opts...) + if err != nil { + recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) + return nil, err + } + + req := proto.VADRequest{ + Audio: request.Audio, + } + resp, err := vadModel.VAD(ctx, &req) + if err != nil { + return nil, err + } + + segments := []schema.VADSegment{} + for _, s := range resp.Segments { + segments = append(segments, schema.VADSegment{Start: s.Start, End: s.End}) + } + + return &schema.VADResponse{ + Segments: segments, + }, nil +} diff --git a/core/cli/agent.go b/core/cli/agent.go index acd240056..8193d27d8 100644 --- a/core/cli/agent.go +++ b/core/cli/agent.go @@ -8,11 +8,11 @@ import ( "os/signal" "syscall" - cliContext "github.com/mudler/LocalAI/core/cli/context" - "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAGI/core/state" coreTypes "github.com/mudler/LocalAGI/core/types" + cliContext "github.com/mudler/LocalAI/core/cli/context" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/services/agentpool" "github.com/mudler/xlog" ) @@ -59,7 +59,7 @@ func (r *AgentRunCMD) Run(ctx *cliContext.Context) error { appConfig := r.buildAppConfig() - poolService, err := services.NewAgentPoolService(appConfig) + poolService, err := agentpool.NewAgentPoolService(appConfig) if err != nil { return fmt.Errorf("failed to create agent pool service: %w", err) } diff --git a/core/cli/agent_worker.go b/core/cli/agent_worker.go new file mode 100644 index 000000000..2fdf7dd0c --- /dev/null +++ b/core/cli/agent_worker.go @@ -0,0 +1,463 @@ +package cli + +import ( + "cmp" + "context" + "encoding/json" + "fmt" + "os" + "os/signal" + "strings" + "syscall" + "time" + + cliContext "github.com/mudler/LocalAI/core/cli/context" + "github.com/mudler/LocalAI/core/cli/workerregistry" + "github.com/mudler/LocalAI/core/config" + mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" + "github.com/mudler/LocalAI/core/services/agents" + "github.com/mudler/LocalAI/core/services/jobs" + mcpRemote "github.com/mudler/LocalAI/core/services/mcp" + "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/pkg/sanitize" + "github.com/mudler/cogito" + "github.com/mudler/cogito/clients" + "github.com/mudler/xlog" +) + +// AgentWorkerCMD starts a dedicated agent worker process for distributed mode. +// It registers with the frontend, subscribes to the NATS agent execution queue, +// and executes agent chats using cogito. The worker is a pure executor — it +// receives the full agent config and skills in the NATS job payload, so it +// does not need direct database access. +// +// Usage: +// +// localai agent-worker --nats-url nats://... --register-to http://localai:8080 +type AgentWorkerCMD struct { + // NATS (required) + NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"` + + // Registration (required) + RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"` + NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"` + RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"` + HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"` + + // API access + APIURL string `env:"LOCALAI_API_URL" help:"LocalAI API URL for inference (auto-derived from RegisterTo if not set)" group:"api"` + APIToken string `env:"LOCALAI_API_TOKEN" help:"API token for LocalAI inference (auto-provisioned during registration if not set)" group:"api"` + + // NATS subjects + Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"` + Queue string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"` + + // Timeouts + MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"` +} + +func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error { + xlog.Info("Starting agent worker", "nats", sanitize.URL(cmd.NatsURL), "register_to", cmd.RegisterTo) + + // Resolve API URL + apiURL := cmp.Or(cmd.APIURL, strings.TrimRight(cmd.RegisterTo, "/")) + + // Register with frontend + regClient := &workerregistry.RegistrationClient{ + FrontendURL: cmd.RegisterTo, + RegistrationToken: cmd.RegistrationToken, + } + + nodeName := cmd.NodeName + if nodeName == "" { + hostname, _ := os.Hostname() + nodeName = "agent-" + hostname + } + registrationBody := map[string]any{ + "name": nodeName, + "node_type": "agent", + } + if cmd.RegistrationToken != "" { + registrationBody["token"] = cmd.RegistrationToken + } + + nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10) + if err != nil { + return fmt.Errorf("registration failed: %w", err) + } + xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo) + + // Use provisioned API token if none was set + if cmd.APIToken == "" { + cmd.APIToken = apiToken + } + + // Start heartbeat + heartbeatInterval, err := time.ParseDuration(cmd.HeartbeatInterval) + if err != nil && cmd.HeartbeatInterval != "" { + xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err) + } + heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second) + // Context cancelled on shutdown — used by heartbeat and other background goroutines + shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + defer shutdownCancel() + + go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} }) + + // Connect to NATS + natsClient, err := messaging.New(cmd.NatsURL) + if err != nil { + return fmt.Errorf("connecting to NATS: %w", err) + } + defer natsClient.Close() + + // Create event bridge for publishing results back via NATS + eventBridge := agents.NewEventBridge(natsClient, nil, "agent-worker-"+nodeID) + + // Start cancel listener + cancelSub, err := eventBridge.StartCancelListener() + if err != nil { + xlog.Warn("Failed to start cancel listener", "error", err) + } else { + defer cancelSub.Unsubscribe() + } + + // Create and start the NATS dispatcher. + // No ConfigProvider or SkillStore needed — config and skills arrive in the job payload. + dispatcher := agents.NewNATSDispatcher( + natsClient, + eventBridge, + nil, // no ConfigProvider: config comes in the enriched NATS payload + apiURL, cmd.APIToken, + cmd.Subject, cmd.Queue, + 0, // no concurrency limit (CLI worker) + ) + + if err := dispatcher.Start(shutdownCtx); err != nil { + return fmt.Errorf("starting dispatcher: %w", err) + } + + // Subscribe to MCP tool execution requests (load-balanced across workers). + // The frontend routes model-level MCP tool calls here via NATS request-reply. + if _, err := natsClient.QueueSubscribeReply(messaging.SubjectMCPToolExecute, messaging.QueueAgentWorkers, func(data []byte, reply func([]byte)) { + handleMCPToolRequest(data, reply) + }); err != nil { + return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPToolExecute, err) + } + + // Subscribe to MCP discovery requests (load-balanced across workers). + if _, err := natsClient.QueueSubscribeReply(messaging.SubjectMCPDiscovery, messaging.QueueAgentWorkers, func(data []byte, reply func([]byte)) { + handleMCPDiscoveryRequest(data, reply) + }); err != nil { + return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPDiscovery, err) + } + + // Subscribe to MCP CI job execution (load-balanced across agent workers). + // In distributed mode, MCP CI jobs are routed here because the frontend + // cannot create MCP sessions (e.g., stdio servers using docker). + mcpCIJobTimeout, err := time.ParseDuration(cmd.MCPCIJobTimeout) + if err != nil && cmd.MCPCIJobTimeout != "" { + xlog.Warn("invalid MCP CI job timeout, using default 10m", "input", cmd.MCPCIJobTimeout, "error", err) + } + mcpCIJobTimeout = cmp.Or(mcpCIJobTimeout, config.DefaultMCPCIJobTimeout) + + if _, err := natsClient.QueueSubscribe(messaging.SubjectMCPCIJobsNew, messaging.QueueWorkers, func(data []byte) { + handleMCPCIJob(shutdownCtx, data, apiURL, cmd.APIToken, natsClient, mcpCIJobTimeout) + }); err != nil { + return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPCIJobsNew, err) + } + + // Subscribe to backend stop events to clean up cached MCP sessions. + // In the main application this is done via ml.OnModelUnload, but the agent + // worker has no model loader — we listen for the NATS stop event instead. + if _, err := natsClient.Subscribe(messaging.SubjectNodeBackendStop(nodeID), func(data []byte) { + var req struct { + Backend string `json:"backend"` + } + if json.Unmarshal(data, &req) == nil && req.Backend != "" { + mcpTools.CloseMCPSessions(req.Backend) + } + }); err != nil { + return fmt.Errorf("subscribing to %s: %w", messaging.SubjectNodeBackendStop(nodeID), err) + } + + xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue) + + // Wait for shutdown + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + <-sigCh + + xlog.Info("Shutting down agent worker") + shutdownCancel() // stop heartbeat loop immediately + dispatcher.Stop() + mcpTools.CloseAllMCPSessions() + regClient.GracefulDeregister(nodeID) + return nil +} + +// handleMCPToolRequest handles a NATS request-reply for MCP tool execution. +// The worker creates/caches MCP sessions from the serialized config and executes the tool. +func handleMCPToolRequest(data []byte, reply func([]byte)) { + var req mcpRemote.MCPToolRequest + if err := json.Unmarshal(data, &req); err != nil { + sendMCPToolReply(reply, "", fmt.Sprintf("unmarshal error: %v", err)) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), config.DefaultMCPToolTimeout) + defer cancel() + + // Create/cache named MCP sessions from the provided config + namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(req.ModelName, req.RemoteServers, req.StdioServers, nil) + if err != nil { + sendMCPToolReply(reply, "", fmt.Sprintf("session error: %v", err)) + return + } + + // Discover tools to find the right session + tools, err := mcpTools.DiscoverMCPTools(ctx, namedSessions) + if err != nil { + sendMCPToolReply(reply, "", fmt.Sprintf("discovery error: %v", err)) + return + } + + // Execute the tool + argsJSON, _ := json.Marshal(req.Arguments) + result, err := mcpTools.ExecuteMCPToolCall(ctx, tools, req.ToolName, string(argsJSON)) + if err != nil { + sendMCPToolReply(reply, "", err.Error()) + return + } + + sendMCPToolReply(reply, result, "") +} + +func sendMCPToolReply(reply func([]byte), result, errMsg string) { + resp := mcpRemote.MCPToolResponse{Result: result, Error: errMsg} + data, _ := json.Marshal(resp) + reply(data) +} + +// handleMCPDiscoveryRequest handles a NATS request-reply for MCP tool/prompt/resource discovery. +func handleMCPDiscoveryRequest(data []byte, reply func([]byte)) { + var req mcpRemote.MCPDiscoveryRequest + if err := json.Unmarshal(data, &req); err != nil { + sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("unmarshal error: %v", err)) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), config.DefaultMCPDiscoveryTimeout) + defer cancel() + + // Create/cache named MCP sessions + namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(req.ModelName, req.RemoteServers, req.StdioServers, nil) + if err != nil { + sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("session error: %v", err)) + return + } + + // List servers with their tools/prompts/resources + serverInfos, err := mcpTools.ListMCPServers(ctx, namedSessions) + if err != nil { + sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("list error: %v", err)) + return + } + + // Also get tool function schemas for the frontend + tools, _ := mcpTools.DiscoverMCPTools(ctx, namedSessions) + var toolDefs []mcpRemote.MCPToolDef + for _, t := range tools { + toolDefs = append(toolDefs, mcpRemote.MCPToolDef{ + ServerName: t.ServerName, + ToolName: t.ToolName, + Function: t.Function, + }) + } + + // Convert server infos + var servers []mcpRemote.MCPServerInfo + for _, s := range serverInfos { + servers = append(servers, mcpRemote.MCPServerInfo{ + Name: s.Name, + Type: s.Type, + Tools: s.Tools, + Prompts: s.Prompts, + Resources: s.Resources, + }) + } + + sendMCPDiscoveryReply(reply, servers, toolDefs, "") +} + +func sendMCPDiscoveryReply(reply func([]byte), servers []mcpRemote.MCPServerInfo, tools []mcpRemote.MCPToolDef, errMsg string) { + resp := mcpRemote.MCPDiscoveryResponse{Servers: servers, Tools: tools, Error: errMsg} + data, _ := json.Marshal(resp) + reply(data) +} + +// handleMCPCIJob processes an MCP CI job on the agent worker. +// The agent worker can create MCP sessions (has docker) and call the LocalAI API for inference. +func handleMCPCIJob(shutdownCtx context.Context, data []byte, apiURL, apiToken string, natsClient messaging.MessagingClient, jobTimeout time.Duration) { + var evt jobs.JobEvent + if err := json.Unmarshal(data, &evt); err != nil { + xlog.Error("Failed to unmarshal job event", "error", err) + return + } + + job := evt.Job + task := evt.Task + if job == nil || task == nil { + xlog.Error("MCP CI job missing enriched data", "jobID", evt.JobID) + publishJobResult(natsClient, evt.JobID, "failed", "", "job or task data missing from NATS event") + return + } + + modelCfg := evt.ModelConfig + if modelCfg == nil { + publishJobResult(natsClient, evt.JobID, "failed", "", "model config missing from job event") + return + } + + xlog.Info("Processing MCP CI job", "jobID", evt.JobID, "taskID", evt.TaskID, "model", task.Model) + + // Publish running status + natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{ + JobID: evt.JobID, Status: "running", Message: "Job started on agent worker", + }) + + // Parse MCP config + if modelCfg.MCP.Servers == "" && modelCfg.MCP.Stdio == "" { + publishJobResult(natsClient, evt.JobID, "failed", "", "no MCP servers configured for model") + return + } + + remote, stdio, err := modelCfg.MCP.MCPConfigFromYAML() + if err != nil { + publishJobResult(natsClient, evt.JobID, "failed", "", fmt.Sprintf("failed to parse MCP config: %v", err)) + return + } + + // Create MCP sessions locally (agent worker has docker) + sessions, err := mcpTools.SessionsFromMCPConfig(modelCfg.Name, remote, stdio) + if err != nil || len(sessions) == 0 { + errMsg := "no working MCP servers found" + if err != nil { + errMsg = fmt.Sprintf("failed to create MCP sessions: %v", err) + } + publishJobResult(natsClient, evt.JobID, "failed", "", errMsg) + return + } + + // Build prompt from template + prompt := task.Prompt + if task.CronParametersJSON != "" { + var params map[string]string + if err := json.Unmarshal([]byte(task.CronParametersJSON), ¶ms); err != nil { + xlog.Warn("Failed to unmarshal parameters", "error", err) + } + for k, v := range params { + prompt = strings.ReplaceAll(prompt, "{{."+k+"}}", v) + } + } + if job.ParametersJSON != "" { + var params map[string]string + if err := json.Unmarshal([]byte(job.ParametersJSON), ¶ms); err != nil { + xlog.Warn("Failed to unmarshal parameters", "error", err) + } + for k, v := range params { + prompt = strings.ReplaceAll(prompt, "{{."+k+"}}", v) + } + } + + // Create LLM client pointing back to the frontend API + llm := clients.NewLocalAILLM(task.Model, apiToken, apiURL) + + // Build cogito options + ctx, cancel := context.WithTimeout(shutdownCtx, jobTimeout) + defer cancel() + + // Update job status to running in DB + publishJobStatus(natsClient, evt.JobID, "running", "") + + // Buffer stream tokens and flush as complete blocks + var reasoningBuf, contentBuf strings.Builder + var lastStreamType cogito.StreamEventType + + flushStreamBuf := func() { + if reasoningBuf.Len() > 0 { + natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{ + JobID: evt.JobID, TraceType: "reasoning", TraceContent: reasoningBuf.String(), + }) + reasoningBuf.Reset() + } + if contentBuf.Len() > 0 { + natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{ + JobID: evt.JobID, TraceType: "content", TraceContent: contentBuf.String(), + }) + contentBuf.Reset() + } + } + + cogitoOpts := modelCfg.BuildCogitoOptions() + cogitoOpts = append(cogitoOpts, + cogito.WithContext(ctx), + cogito.WithMCPs(sessions...), + cogito.WithStatusCallback(func(status string) { + flushStreamBuf() + natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{ + JobID: evt.JobID, TraceType: "status", TraceContent: status, + }) + }), + cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) { + flushStreamBuf() + natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{ + JobID: evt.JobID, TraceType: "tool_result", TraceContent: fmt.Sprintf("%s: %s", t.Name, t.Result), + }) + }), + cogito.WithStreamCallback(func(ev cogito.StreamEvent) { + // Flush if stream type changed (e.g., reasoning → content) + if ev.Type != lastStreamType { + flushStreamBuf() + lastStreamType = ev.Type + } + switch ev.Type { + case cogito.StreamEventReasoning: + reasoningBuf.WriteString(ev.Content) + case cogito.StreamEventContent: + contentBuf.WriteString(ev.Content) + case cogito.StreamEventToolCall: + natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{ + JobID: evt.JobID, TraceType: "tool_call", TraceContent: fmt.Sprintf("%s(%s)", ev.ToolName, ev.ToolArgs), + }) + } + }), + ) + + // Execute via cogito + fragment := cogito.NewEmptyFragment() + fragment = fragment.AddMessage("user", prompt) + + f, err := cogito.ExecuteTools(llm, fragment, cogitoOpts...) + flushStreamBuf() // flush any remaining buffered tokens + + if err != nil { + publishJobResult(natsClient, evt.JobID, "failed", "", fmt.Sprintf("cogito execution failed: %v", err)) + return + } + + result := "" + if msg := f.LastMessage(); msg != nil { + result = msg.Content + } + publishJobResult(natsClient, evt.JobID, "completed", result, "") + xlog.Info("MCP CI job completed", "jobID", evt.JobID, "resultLen", len(result)) +} + +func publishJobStatus(nc messaging.MessagingClient, jobID, status, message string) { + jobs.PublishJobProgress(nc, jobID, status, message) +} + +func publishJobResult(nc messaging.MessagingClient, jobID, status, result, errMsg string) { + jobs.PublishJobResult(nc, jobID, status, result, errMsg) +} diff --git a/core/cli/backends.go b/core/cli/backends.go index 9877d746a..23f6b3ff1 100644 --- a/core/cli/backends.go +++ b/core/cli/backends.go @@ -8,7 +8,7 @@ import ( cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" @@ -103,7 +103,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error { } modelLoader := model.NewModelLoader(systemState) - err = services.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias) + err = galleryop.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias) if err != nil { return err } diff --git a/core/cli/cli.go b/core/cli/cli.go index 9d448f88e..b87c81511 100644 --- a/core/cli/cli.go +++ b/core/cli/cli.go @@ -15,7 +15,9 @@ var CLI struct { TTS TTSCMD `cmd:"" help:"Convert text to speech"` SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"` Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"` - Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"` + P2PWorker worker.Worker `cmd:"" name:"p2p-worker" help:"Run workers to distribute workload via p2p (llama.cpp-only)"` + Worker WorkerCMD `cmd:"" help:"Start a worker for distributed mode (generic, backend-agnostic)"` + AgentWorker AgentWorkerCMD `cmd:"" name:"agent-worker" help:"Start an agent worker for distributed mode (executes agent chats via NATS)"` Util UtilCMD `cmd:"" help:"Utility commands"` Agent AgentCMD `cmd:"" help:"Run agents standalone without the full LocalAI server"` Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"` diff --git a/core/cli/completion.go b/core/cli/completion.go index 809599520..04bf4b30a 100644 --- a/core/cli/completion.go +++ b/core/cli/completion.go @@ -186,9 +186,9 @@ _local_ai_completions() } subcmds := []string{} for _, sub := range cmds { - parts := strings.SplitN(sub.fullName, " ", 2) - if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") { - subcmds = append(subcmds, parts[1]) + parent, child, found := strings.Cut(sub.fullName, " ") + if found && parent == cmd.name && !strings.Contains(child, " ") { + subcmds = append(subcmds, child) } } if len(subcmds) > 0 { @@ -279,8 +279,8 @@ _local_ai() { // Check for subcommands subcmds := []commandInfo{} for _, sub := range cmds { - parts := strings.SplitN(sub.fullName, " ", 2) - if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") { + parent, child, found := strings.Cut(sub.fullName, " ") + if found && parent == cmd.name && !strings.Contains(child, " ") { subcmds = append(subcmds, sub) } } @@ -289,11 +289,11 @@ _local_ai() { sb.WriteString(" local -a subcmds\n") sb.WriteString(" subcmds=(\n") for _, sub := range subcmds { - parts := strings.SplitN(sub.fullName, " ", 2) + _, child, _ := strings.Cut(sub.fullName, " ") help := strings.ReplaceAll(sub.help, "'", "'\\''") help = strings.ReplaceAll(help, "[", "\\[") help = strings.ReplaceAll(help, "]", "\\]") - sb.WriteString(fmt.Sprintf(" '%s:%s'\n", parts[1], help)) + sb.WriteString(fmt.Sprintf(" '%s:%s'\n", child, help)) } sb.WriteString(" )\n") sb.WriteString(" _describe -t commands 'subcommands' subcmds\n") @@ -372,10 +372,10 @@ func generateFishCompletion(app *kong.Application) string { // Subcommands for _, sub := range cmds { - parts := strings.SplitN(sub.fullName, " ", 2) - if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") { + parent, child, found := strings.Cut(sub.fullName, " ") + if found && parent == cmd.name && !strings.Contains(child, " ") { help := strings.ReplaceAll(sub.help, "'", "\\'") - sb.WriteString(fmt.Sprintf("complete -c local-ai -n '__fish_seen_subcommand_from %s' -a %s -d '%s'\n", cmd.name, parts[1], help)) + sb.WriteString(fmt.Sprintf("complete -c local-ai -n '__fish_seen_subcommand_from %s' -a %s -d '%s'\n", cmd.name, child, help)) } } diff --git a/core/cli/completion_test.go b/core/cli/completion_test.go index 5da1d49db..be625d051 100644 --- a/core/cli/completion_test.go +++ b/core/cli/completion_test.go @@ -9,8 +9,8 @@ import ( func getTestApp() *kong.Application { var testCLI struct { - Run struct{} `cmd:"" help:"Run the server"` - Models struct { + Run struct{} `cmd:"" help:"Run the server"` + Models struct { List struct{} `cmd:"" help:"List models"` Install struct{} `cmd:"" help:"Install a model"` } `cmd:"" help:"Manage models"` diff --git a/core/cli/models.go b/core/cli/models.go index 3006922c8..a947c18e2 100644 --- a/core/cli/models.go +++ b/core/cli/models.go @@ -8,7 +8,7 @@ import ( cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/startup" @@ -80,7 +80,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { return err } - galleryService := services.NewGalleryService(&config.ApplicationConfig{ + galleryService := galleryop.NewGalleryService(&config.ApplicationConfig{ SystemState: systemState, }, model.NewModelLoader(systemState)) err = galleryService.Start(context.Background(), config.NewModelConfigLoader(mi.ModelsPath), systemState) diff --git a/core/cli/run.go b/core/cli/run.go index c614e123d..d3f1ac103 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -44,9 +44,9 @@ type RunCMD struct { Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"` AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"` AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"` - BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"` - BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"` - BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"` + BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"` + BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"` + BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"` PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"` Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"` PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"` @@ -100,7 +100,7 @@ type RunCMD struct { OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"` // Agent Pool (LocalAGI) - DisableAgents bool `env:"LOCALAI_DISABLE_AGENTS" default:"false" help:"Disable the agent pool feature" group:"agents"` + DisableAgents bool `env:"LOCALAI_DISABLE_AGENTS" default:"false" help:"Disable the agent pool feature" group:"agents"` AgentPoolAPIURL string `env:"LOCALAI_AGENT_POOL_API_URL" help:"Default API URL for agents (defaults to self-referencing LocalAI)" group:"agents"` AgentPoolAPIKey string `env:"LOCALAI_AGENT_POOL_API_KEY" help:"Default API key for agents (defaults to first LocalAI API key)" group:"agents"` AgentPoolDefaultModel string `env:"LOCALAI_AGENT_POOL_DEFAULT_MODEL" help:"Default model for agents" group:"agents"` @@ -109,17 +109,17 @@ type RunCMD struct { AgentPoolTranscriptionLanguage string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_LANGUAGE" help:"Default transcription language for agents" group:"agents"` AgentPoolTTSModel string `env:"LOCALAI_AGENT_POOL_TTS_MODEL" help:"Default TTS model for agents" group:"agents"` AgentPoolStateDir string `env:"LOCALAI_AGENT_POOL_STATE_DIR" help:"State directory for agent pool" group:"agents"` - AgentPoolTimeout string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Default agent timeout" group:"agents"` - AgentPoolEnableSkills bool `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service for agents" group:"agents"` - AgentPoolVectorEngine string `env:"LOCALAI_AGENT_POOL_VECTOR_ENGINE" default:"chromem" help:"Vector engine type for agent knowledge base" group:"agents"` - AgentPoolEmbeddingModel string `env:"LOCALAI_AGENT_POOL_EMBEDDING_MODEL" default:"granite-embedding-107m-multilingual" help:"Embedding model for agent knowledge base" group:"agents"` - AgentPoolCustomActionsDir string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory for agents" group:"agents"` - AgentPoolDatabaseURL string `env:"LOCALAI_AGENT_POOL_DATABASE_URL" help:"Database URL for agent collections" group:"agents"` - AgentPoolMaxChunkingSize int `env:"LOCALAI_AGENT_POOL_MAX_CHUNKING_SIZE" default:"400" help:"Maximum chunking size for knowledge base documents" group:"agents"` - AgentPoolChunkOverlap int `env:"LOCALAI_AGENT_POOL_CHUNK_OVERLAP" default:"0" help:"Chunk overlap size for knowledge base documents" group:"agents"` - AgentPoolEnableLogs bool `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"` - AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"` - AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"` + AgentPoolTimeout string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Default agent timeout" group:"agents"` + AgentPoolEnableSkills bool `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service for agents" group:"agents"` + AgentPoolVectorEngine string `env:"LOCALAI_AGENT_POOL_VECTOR_ENGINE" default:"chromem" help:"Vector engine type for agent knowledge base" group:"agents"` + AgentPoolEmbeddingModel string `env:"LOCALAI_AGENT_POOL_EMBEDDING_MODEL" default:"granite-embedding-107m-multilingual" help:"Embedding model for agent knowledge base" group:"agents"` + AgentPoolCustomActionsDir string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory for agents" group:"agents"` + AgentPoolDatabaseURL string `env:"LOCALAI_AGENT_POOL_DATABASE_URL" help:"Database URL for agent collections" group:"agents"` + AgentPoolMaxChunkingSize int `env:"LOCALAI_AGENT_POOL_MAX_CHUNKING_SIZE" default:"400" help:"Maximum chunking size for knowledge base documents" group:"agents"` + AgentPoolChunkOverlap int `env:"LOCALAI_AGENT_POOL_CHUNK_OVERLAP" default:"0" help:"Chunk overlap size for knowledge base documents" group:"agents"` + AgentPoolEnableLogs bool `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"` + AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"` + AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"` // Authentication AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"` @@ -136,6 +136,18 @@ type RunCMD struct { AuthAPIKeyHMACSecret string `env:"LOCALAI_AUTH_HMAC_SECRET" help:"HMAC secret for API key hashing (auto-generated if empty)" group:"auth"` DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"` + // Distributed / Horizontal Scaling + Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"` + InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"` + NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"` + StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"` + StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"` + StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"` + StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"` + StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"` + RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"` + AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"` + Version bool } @@ -210,6 +222,38 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { }), } + // Distributed mode + if r.Distributed { + opts = append(opts, config.EnableDistributed) + } + if r.InstanceID != "" { + opts = append(opts, config.WithDistributedInstanceID(r.InstanceID)) + } + if r.NatsURL != "" { + opts = append(opts, config.WithNatsURL(r.NatsURL)) + } + if r.StorageURL != "" { + opts = append(opts, config.WithStorageURL(r.StorageURL)) + } + if r.StorageBucket != "" { + opts = append(opts, config.WithStorageBucket(r.StorageBucket)) + } + if r.StorageRegion != "" { + opts = append(opts, config.WithStorageRegion(r.StorageRegion)) + } + if r.StorageAccessKey != "" { + opts = append(opts, config.WithStorageAccessKey(r.StorageAccessKey)) + } + if r.StorageSecretKey != "" { + opts = append(opts, config.WithStorageSecretKey(r.StorageSecretKey)) + } + if r.RegistrationToken != "" { + opts = append(opts, config.WithRegistrationToken(r.RegistrationToken)) + } + if r.AutoApproveNodes { + opts = append(opts, config.EnableAutoApproveNodes) + } + if r.DisableMetricsEndpoint { opts = append(opts, config.DisableMetricsEndpoint) } @@ -218,10 +262,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { opts = append(opts, config.DisableRuntimeSettings) } - if r.EnableTracing { - opts = append(opts, config.EnableTracing) - } - if r.EnableTracing { opts = append(opts, config.EnableTracing) } @@ -479,6 +519,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { if err := app.ModelLoader().StopAllGRPC(); err != nil { xlog.Error("error while stopping all grpc backends", "error", err) } + // Clean up distributed services (idempotent — safe if already called) + if d := app.Distributed(); d != nil { + d.Shutdown() + } }) // Start the agent pool after the HTTP server is listening, because diff --git a/core/cli/transcript.go b/core/cli/transcript.go index 8da3892a0..47a8e61a5 100644 --- a/core/cli/transcript.go +++ b/core/cli/transcript.go @@ -12,7 +12,6 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/pkg/format" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" @@ -80,7 +79,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error { switch t.ResponseFormat { case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt, schema.TranscriptionResponseFormatText: - fmt.Println(format.TranscriptionResponse(tr, t.ResponseFormat)) + fmt.Println(schema.TranscriptionResponse(tr, t.ResponseFormat)) case schema.TranscriptionResponseFormatJson: tr.Segments = nil fallthrough diff --git a/core/cli/worker.go b/core/cli/worker.go new file mode 100644 index 000000000..4042ae352 --- /dev/null +++ b/core/cli/worker.go @@ -0,0 +1,897 @@ +package cli + +import ( + "cmp" + "context" + "encoding/json" + "fmt" + "maps" + "net" + "os" + "os/signal" + "path/filepath" + "slices" + "strconv" + "strings" + "sync" + "syscall" + "time" + + cliContext "github.com/mudler/LocalAI/core/cli/context" + "github.com/mudler/LocalAI/core/cli/workerregistry" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/LocalAI/core/services/storage" + grpc "github.com/mudler/LocalAI/pkg/grpc" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/sanitize" + "github.com/mudler/LocalAI/pkg/system" + "github.com/mudler/LocalAI/pkg/xsysinfo" + process "github.com/mudler/go-processmanager" + "github.com/mudler/xlog" +) + +// isPathAllowed checks if path is within one of the allowed directories. +func isPathAllowed(path string, allowedDirs []string) bool { + absPath, err := filepath.Abs(path) + if err != nil { + return false + } + resolved, err := filepath.EvalSymlinks(absPath) + if err != nil { + // Path may not exist yet; use the absolute path + resolved = absPath + } + for _, dir := range allowedDirs { + absDir, err := filepath.Abs(dir) + if err != nil { + continue + } + if strings.HasPrefix(resolved, absDir+string(filepath.Separator)) || resolved == absDir { + return true + } + } + return false +} + +// WorkerCMD starts a generic worker process for distributed mode. +// Workers are backend-agnostic — they wait for backend.install NATS events +// from the SmartRouter to install and start the required backend. +// +// NATS is required. The worker acts as a process supervisor: +// - Receives backend.install → installs backend from gallery, starts gRPC process, replies success +// - Receives backend.stop → stops the gRPC process +// - Receives stop → full shutdown (deregister + exit) +// +// Model loading (LoadModel) is always via direct gRPC — no NATS needed for that. +type WorkerCMD struct { + Addr string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"Address to bind the gRPC server to" group:"server"` + BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends" group:"server"` + BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends" group:"server"` + BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"server" default:"${backends}"` + ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models" group:"server"` + + // HTTP file transfer + HTTPAddr string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port + 1)" group:"server"` + AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server"` + + // Registration (required) + AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration"` + RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"` + NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"` + RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"` + HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"` + + // NATS (required) + NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"` + + // S3 storage for distributed file transfer + StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3 endpoint URL" group:"distributed"` + StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" help:"S3 bucket name" group:"distributed"` + StorageRegion string `env:"LOCALAI_STORAGE_REGION" help:"S3 region" group:"distributed"` + StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key" group:"distributed"` + StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret key" group:"distributed"` +} + +func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error { + xlog.Info("Starting worker", "addr", cmd.Addr) + + systemState, err := system.GetSystemState( + system.WithModelPath(cmd.ModelsPath), + system.WithBackendPath(cmd.BackendsPath), + system.WithBackendSystemPath(cmd.BackendsSystemPath), + ) + if err != nil { + return fmt.Errorf("getting system state: %w", err) + } + + ml := model.NewModelLoader(systemState) + ml.SetBackendLoggingEnabled(true) + + // Register already-installed backends + gallery.RegisterBackends(systemState, ml) + + // Parse galleries config + var galleries []config.Gallery + if err := json.Unmarshal([]byte(cmd.BackendGalleries), &galleries); err != nil { + xlog.Warn("Failed to parse backend galleries", "error", err) + } + + // Self-registration with frontend (with retry) + regClient := &workerregistry.RegistrationClient{ + FrontendURL: cmd.RegisterTo, + RegistrationToken: cmd.RegistrationToken, + } + + registrationBody := cmd.registrationBody() + nodeID, _, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10) + if err != nil { + return fmt.Errorf("failed to register with frontend: %w", err) + } + + xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo) + heartbeatInterval, err := time.ParseDuration(cmd.HeartbeatInterval) + if err != nil && cmd.HeartbeatInterval != "" { + xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err) + } + heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second) + // Context cancelled on shutdown — used by heartbeat and other background goroutines + shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + defer shutdownCancel() + + // Start HTTP file transfer server + httpAddr := cmd.resolveHTTPAddr() + stagingDir := filepath.Join(cmd.ModelsPath, "..", "staging") + dataDir := filepath.Join(cmd.ModelsPath, "..", "data") + httpServer, err := nodes.StartFileTransferServer(httpAddr, stagingDir, cmd.ModelsPath, dataDir, cmd.RegistrationToken, config.DefaultMaxUploadSize, ml.BackendLogs()) + if err != nil { + return fmt.Errorf("starting HTTP file transfer server: %w", err) + } + + // Connect to NATS + xlog.Info("Connecting to NATS", "url", sanitize.URL(cmd.NatsURL)) + natsClient, err := messaging.New(cmd.NatsURL) + if err != nil { + nodes.ShutdownFileTransferServer(httpServer) + return fmt.Errorf("connecting to NATS: %w", err) + } + defer natsClient.Close() + + // Start heartbeat goroutine (after NATS is connected so IsConnected check works) + go func() { + ticker := time.NewTicker(heartbeatInterval) + defer ticker.Stop() + for { + select { + case <-shutdownCtx.Done(): + return + case <-ticker.C: + if !natsClient.IsConnected() { + xlog.Warn("Skipping heartbeat: NATS disconnected") + continue + } + body := cmd.heartbeatBody() + if err := regClient.Heartbeat(shutdownCtx, nodeID, body); err != nil { + xlog.Warn("Heartbeat failed", "error", err) + } + } + } + }() + + // Process supervisor — manages multiple backend gRPC processes on different ports + basePort := 50051 + if cmd.Addr != "" { + // Extract port from addr (e.g., "0.0.0.0:50051" → 50051) + if _, portStr, err := net.SplitHostPort(cmd.Addr); err == nil { + if p, err := strconv.Atoi(portStr); err == nil { + basePort = p + } + } + } + // Buffered so NATS stop handler can send without blocking + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + // Set the registration token once before any backends are started + if cmd.RegistrationToken != "" { + os.Setenv(grpc.AuthTokenEnvVar, cmd.RegistrationToken) + } + + supervisor := &backendSupervisor{ + cmd: cmd, + ml: ml, + systemState: systemState, + galleries: galleries, + nodeID: nodeID, + nats: natsClient, + sigCh: sigCh, + processes: make(map[string]*backendProcess), + nextPort: basePort, + } + supervisor.subscribeLifecycleEvents() + + // Subscribe to file staging NATS subjects if S3 is configured + if cmd.StorageURL != "" { + if err := cmd.subscribeFileStaging(natsClient, nodeID); err != nil { + xlog.Error("Failed to subscribe to file staging subjects", "error", err) + } + } + + xlog.Info("Worker ready, waiting for backend.install events") + <-sigCh + + xlog.Info("Shutting down worker") + shutdownCancel() // stop heartbeat loop immediately + regClient.GracefulDeregister(nodeID) + supervisor.stopAllBackends() + nodes.ShutdownFileTransferServer(httpServer) + return nil +} + +// subscribeFileStaging subscribes to NATS file staging subjects for this node. +func (cmd *WorkerCMD) subscribeFileStaging(natsClient messaging.MessagingClient, nodeID string) error { + // Create FileManager with same S3 config as the frontend + // TODO: propagate a caller-provided context once WorkerCMD carries one + s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{ + Endpoint: cmd.StorageURL, + Region: cmd.StorageRegion, + Bucket: cmd.StorageBucket, + AccessKeyID: cmd.StorageAccessKey, + SecretAccessKey: cmd.StorageSecretKey, + ForcePathStyle: true, + }) + if err != nil { + return fmt.Errorf("initializing S3 store: %w", err) + } + + cacheDir := filepath.Join(cmd.ModelsPath, "..", "cache") + fm, err := storage.NewFileManager(s3Store, cacheDir) + if err != nil { + return fmt.Errorf("initializing file manager: %w", err) + } + + // Subscribe: files.ensure — download S3 key to local, reply with local path + natsClient.SubscribeReply(messaging.SubjectNodeFilesEnsure(nodeID), func(data []byte, reply func([]byte)) { + var req struct { + Key string `json:"key"` + } + if err := json.Unmarshal(data, &req); err != nil { + replyJSON(reply, map[string]string{"error": "invalid request"}) + return + } + + localPath, err := fm.Download(context.Background(), req.Key) + if err != nil { + xlog.Error("File ensure failed", "key", req.Key, "error", err) + replyJSON(reply, map[string]string{"error": err.Error()}) + return + } + + xlog.Debug("File ensured locally", "key", req.Key, "path", localPath) + replyJSON(reply, map[string]string{"local_path": localPath}) + }) + + // Subscribe: files.stage — upload local path to S3, reply with key + natsClient.SubscribeReply(messaging.SubjectNodeFilesStage(nodeID), func(data []byte, reply func([]byte)) { + var req struct { + LocalPath string `json:"local_path"` + Key string `json:"key"` + } + if err := json.Unmarshal(data, &req); err != nil { + replyJSON(reply, map[string]string{"error": "invalid request"}) + return + } + + allowedDirs := []string{cacheDir} + if cmd.ModelsPath != "" { + allowedDirs = append(allowedDirs, cmd.ModelsPath) + } + if !isPathAllowed(req.LocalPath, allowedDirs) { + replyJSON(reply, map[string]string{"error": "path outside allowed directories"}) + return + } + + if err := fm.Upload(context.Background(), req.Key, req.LocalPath); err != nil { + xlog.Error("File stage failed", "path", req.LocalPath, "key", req.Key, "error", err) + replyJSON(reply, map[string]string{"error": err.Error()}) + return + } + + xlog.Debug("File staged to S3", "path", req.LocalPath, "key", req.Key) + replyJSON(reply, map[string]string{"key": req.Key}) + }) + + // Subscribe: files.temp — allocate temp file, reply with local path + natsClient.SubscribeReply(messaging.SubjectNodeFilesTemp(nodeID), func(data []byte, reply func([]byte)) { + tmpDir := filepath.Join(cacheDir, "staging-tmp") + if err := os.MkdirAll(tmpDir, 0750); err != nil { + replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp dir: %v", err)}) + return + } + + f, err := os.CreateTemp(tmpDir, "localai-staging-*.tmp") + if err != nil { + replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp file: %v", err)}) + return + } + localPath := f.Name() + f.Close() + + xlog.Debug("Allocated temp file", "path", localPath) + replyJSON(reply, map[string]string{"local_path": localPath}) + }) + + // Subscribe: files.listdir — list files in a local directory, reply with relative paths + natsClient.SubscribeReply(messaging.SubjectNodeFilesListDir(nodeID), func(data []byte, reply func([]byte)) { + var req struct { + KeyPrefix string `json:"key_prefix"` + } + if err := json.Unmarshal(data, &req); err != nil { + replyJSON(reply, map[string]any{"error": "invalid request"}) + return + } + + // Resolve key prefix to local directory + dirPath := filepath.Join(cacheDir, req.KeyPrefix) + if rel, ok := strings.CutPrefix(req.KeyPrefix, storage.ModelKeyPrefix); ok && cmd.ModelsPath != "" { + dirPath = filepath.Join(cmd.ModelsPath, rel) + } else if rel, ok := strings.CutPrefix(req.KeyPrefix, storage.DataKeyPrefix); ok { + dirPath = filepath.Join(cacheDir, "..", "data", rel) + } + + // Sanitize to prevent directory traversal via crafted key_prefix + dirPath = filepath.Clean(dirPath) + cleanCache := filepath.Clean(cacheDir) + cleanModels := filepath.Clean(cmd.ModelsPath) + cleanData := filepath.Clean(filepath.Join(cacheDir, "..", "data")) + if !(strings.HasPrefix(dirPath, cleanCache+string(filepath.Separator)) || + dirPath == cleanCache || + (cleanModels != "." && strings.HasPrefix(dirPath, cleanModels+string(filepath.Separator))) || + dirPath == cleanModels || + strings.HasPrefix(dirPath, cleanData+string(filepath.Separator)) || + dirPath == cleanData) { + replyJSON(reply, map[string]any{"error": "invalid key prefix"}) + return + } + + var files []string + filepath.WalkDir(dirPath, func(path string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + if !d.IsDir() { + rel, err := filepath.Rel(dirPath, path) + if err == nil { + files = append(files, rel) + } + } + return nil + }) + + xlog.Debug("Listed remote dir", "keyPrefix", req.KeyPrefix, "dirPath", dirPath, "fileCount", len(files)) + replyJSON(reply, map[string]any{"files": files}) + }) + + xlog.Info("Subscribed to file staging NATS subjects", "nodeID", nodeID) + return nil +} + +// replyJSON marshals v to JSON and calls the reply function. +func replyJSON(reply func([]byte), v any) { + data, err := json.Marshal(v) + if err != nil { + xlog.Error("Failed to marshal NATS reply", "error", err) + data = []byte(`{"error":"internal marshal error"}`) + } + reply(data) +} + +// backendProcess represents a single gRPC backend process. +type backendProcess struct { + proc *process.Process + backend string + addr string // gRPC address (host:port) +} + +// backendSupervisor manages multiple backend gRPC processes on different ports. +// Each backend type (e.g., llama-cpp, bert-embeddings) gets its own process and port. +type backendSupervisor struct { + cmd *WorkerCMD + ml *model.ModelLoader + systemState *system.SystemState + galleries []config.Gallery + nodeID string + nats messaging.MessagingClient + sigCh chan<- os.Signal // send shutdown signal instead of os.Exit + + mu sync.Mutex + processes map[string]*backendProcess // key: backend name + nextPort int // next available port for new backends + freePorts []int // ports freed by stopBackend, reused before nextPort +} + +// startBackend starts a gRPC backend process on a dynamically allocated port. +// Returns the gRPC address. +func (s *backendSupervisor) startBackend(backend, backendPath string) (string, error) { + s.mu.Lock() + + // Already running? + if bp, ok := s.processes[backend]; ok { + if bp.proc != nil && bp.proc.IsAlive() { + s.mu.Unlock() + return bp.addr, nil + } + // Process died — clean up and restart + xlog.Warn("Backend process died unexpectedly, restarting", "backend", backend) + delete(s.processes, backend) + } + + // Allocate port — recycle freed ports first, then grow upward from basePort + var port int + if len(s.freePorts) > 0 { + port = s.freePorts[len(s.freePorts)-1] + s.freePorts = s.freePorts[:len(s.freePorts)-1] + } else { + port = s.nextPort + s.nextPort++ + } + bindAddr := fmt.Sprintf("0.0.0.0:%d", port) + clientAddr := fmt.Sprintf("127.0.0.1:%d", port) + + proc, err := s.ml.StartProcess(backendPath, backend, bindAddr) + if err != nil { + s.mu.Unlock() + return "", fmt.Errorf("starting backend process: %w", err) + } + + s.processes[backend] = &backendProcess{ + proc: proc, + backend: backend, + addr: clientAddr, + } + xlog.Info("Backend process started", "backend", backend, "addr", clientAddr) + + // Capture reference before unlocking for race-safe health check. + // Another goroutine could stopBackend and recycle the port while we poll. + bp := s.processes[backend] + s.mu.Unlock() + + // Wait for the gRPC server to be ready + client := grpc.NewClientWithToken(clientAddr, false, nil, false, s.cmd.RegistrationToken) + for range 20 { + time.Sleep(200 * time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + if ok, _ := client.HealthCheck(ctx); ok { + cancel() + // Verify the process wasn't stopped/replaced while health-checking + s.mu.Lock() + currentBP, exists := s.processes[backend] + s.mu.Unlock() + if !exists || currentBP != bp { + return "", fmt.Errorf("backend %s was stopped during startup", backend) + } + xlog.Debug("Backend gRPC server is ready", "backend", backend, "addr", clientAddr) + return clientAddr, nil + } + cancel() + } + + xlog.Warn("Backend gRPC server not ready after waiting, proceeding anyway", "backend", backend, "addr", clientAddr) + return clientAddr, nil +} + +// stopBackend stops a specific backend's gRPC process. +func (s *backendSupervisor) stopBackend(backend string) { + s.mu.Lock() + bp, ok := s.processes[backend] + if !ok || bp.proc == nil { + s.mu.Unlock() + return + } + // Clean up map and recycle port while holding lock + delete(s.processes, backend) + if _, portStr, err := net.SplitHostPort(bp.addr); err == nil { + if p, err := strconv.Atoi(portStr); err == nil { + s.freePorts = append(s.freePorts, p) + } + } + s.mu.Unlock() + + // Network I/O outside the lock + client := grpc.NewClientWithToken(bp.addr, false, nil, false, s.cmd.RegistrationToken) + if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok { + xlog.Debug("Calling Free() before stopping backend", "backend", backend) + if err := freeFunc.Free(context.Background()); err != nil { + xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err) + } + } + + xlog.Info("Stopping backend process", "backend", backend, "addr", bp.addr) + if err := bp.proc.Stop(); err != nil { + xlog.Error("Error stopping backend process", "backend", backend, "error", err) + } +} + +// stopAllBackends stops all running backend processes. +func (s *backendSupervisor) stopAllBackends() { + s.mu.Lock() + backends := slices.Collect(maps.Keys(s.processes)) + s.mu.Unlock() + + for _, b := range backends { + s.stopBackend(b) + } +} + +// isRunning returns whether a specific backend process is currently running. +func (s *backendSupervisor) isRunning(backend string) bool { + s.mu.Lock() + defer s.mu.Unlock() + bp, ok := s.processes[backend] + return ok && bp.proc != nil && bp.proc.IsAlive() +} + +// getAddr returns the gRPC address for a running backend, or empty string. +func (s *backendSupervisor) getAddr(backend string) string { + s.mu.Lock() + defer s.mu.Unlock() + if bp, ok := s.processes[backend]; ok { + return bp.addr + } + return "" +} + +// installBackend handles the backend.install flow: +// 1. If already running for this model, return existing address +// 2. Install backend from gallery (if not already installed) +// 3. Find backend binary +// 4. Start gRPC process on a new port +// Returns the gRPC address of the backend process. +func (s *backendSupervisor) installBackend(req messaging.BackendInstallRequest) (string, error) { + // Process key: use ModelID if provided (per-model process), else backend name + processKey := req.ModelID + if processKey == "" { + processKey = req.Backend + } + + // If already running for this model, return its address + if addr := s.getAddr(processKey); addr != "" { + xlog.Info("Backend already running for model", "backend", req.Backend, "model", req.ModelID, "addr", addr) + return addr, nil + } + + // Parse galleries from request (override local config if provided) + galleries := s.galleries + if req.BackendGalleries != "" { + var reqGalleries []config.Gallery + if err := json.Unmarshal([]byte(req.BackendGalleries), &reqGalleries); err == nil { + galleries = reqGalleries + } + } + + // Try to find the backend binary + backendPath := s.findBackend(req.Backend) + if backendPath == "" { + // Backend not found locally — try auto-installing from gallery + xlog.Info("Backend not found locally, attempting gallery install", "backend", req.Backend) + if err := gallery.InstallBackendFromGallery( + context.Background(), galleries, s.systemState, s.ml, req.Backend, nil, false, + ); err != nil { + return "", fmt.Errorf("installing backend from gallery: %w", err) + } + // Re-register after install and retry + gallery.RegisterBackends(s.systemState, s.ml) + backendPath = s.findBackend(req.Backend) + } + + if backendPath == "" { + return "", fmt.Errorf("backend %q not found after install attempt", req.Backend) + } + + xlog.Info("Found backend binary", "path", backendPath, "processKey", processKey) + + // Start the gRPC process on a new port (keyed by model, not just backend) + return s.startBackend(processKey, backendPath) +} + +// findBackend looks for the backend binary in the backends path and system path. +func (s *backendSupervisor) findBackend(backend string) string { + candidates := []string{ + filepath.Join(s.cmd.BackendsPath, backend), + filepath.Join(s.cmd.BackendsPath, backend, backend), + filepath.Join(s.cmd.BackendsSystemPath, backend), + filepath.Join(s.cmd.BackendsSystemPath, backend, backend), + } + if uri := s.ml.GetExternalBackend(backend); uri != "" { + if fi, err := os.Stat(uri); err == nil && !fi.IsDir() { + return uri + } + } + for _, path := range candidates { + fi, err := os.Stat(path) + if err == nil && !fi.IsDir() { + return path + } + } + return "" +} + +// subscribeLifecycleEvents subscribes to NATS backend lifecycle events. +func (s *backendSupervisor) subscribeLifecycleEvents() { + // backend.install — install backend + start gRPC process (request-reply) + s.nats.SubscribeReply(messaging.SubjectNodeBackendInstall(s.nodeID), func(data []byte, reply func([]byte)) { + xlog.Info("Received NATS backend.install event") + var req messaging.BackendInstallRequest + if err := json.Unmarshal(data, &req); err != nil { + resp := messaging.BackendInstallReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)} + replyJSON(reply, resp) + return + } + + addr, err := s.installBackend(req) + if err != nil { + xlog.Error("Failed to install backend via NATS", "error", err) + resp := messaging.BackendInstallReply{Success: false, Error: err.Error()} + replyJSON(reply, resp) + return + } + + // Return the gRPC address so the router knows which port to use + advertiseAddr := addr + if s.cmd.AdvertiseAddr != "" { + // Replace 0.0.0.0 with the advertised host but keep the dynamic port + _, port, _ := net.SplitHostPort(addr) + advertiseHost, _, _ := net.SplitHostPort(s.cmd.AdvertiseAddr) + advertiseAddr = net.JoinHostPort(advertiseHost, port) + } + resp := messaging.BackendInstallReply{Success: true, Address: advertiseAddr} + replyJSON(reply, resp) + }) + + // backend.stop — stop a specific backend process + s.nats.Subscribe(messaging.SubjectNodeBackendStop(s.nodeID), func(data []byte) { + // Try to parse backend name from payload; if empty, stop all + var req struct { + Backend string `json:"backend"` + } + if json.Unmarshal(data, &req) == nil && req.Backend != "" { + xlog.Info("Received NATS backend.stop event", "backend", req.Backend) + s.stopBackend(req.Backend) + } else { + xlog.Info("Received NATS backend.stop event (all)") + s.stopAllBackends() + } + }) + + // backend.delete — stop backend + delete files (request-reply) + s.nats.SubscribeReply(messaging.SubjectNodeBackendDelete(s.nodeID), func(data []byte, reply func([]byte)) { + xlog.Info("Received NATS backend.delete event") + var req messaging.BackendDeleteRequest + if err := json.Unmarshal(data, &req); err != nil { + resp := messaging.BackendDeleteReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)} + replyJSON(reply, resp) + return + } + + // Stop if running this backend + if s.isRunning(req.Backend) { + s.stopBackend(req.Backend) + } + + // Delete the backend files + if err := gallery.DeleteBackendFromSystem(s.systemState, req.Backend); err != nil { + xlog.Warn("Failed to delete backend files", "backend", req.Backend, "error", err) + resp := messaging.BackendDeleteReply{Success: false, Error: err.Error()} + replyJSON(reply, resp) + return + } + + // Re-register backends after deletion + gallery.RegisterBackends(s.systemState, s.ml) + + resp := messaging.BackendDeleteReply{Success: true} + replyJSON(reply, resp) + }) + + // backend.list — list installed backends (request-reply) + s.nats.SubscribeReply(messaging.SubjectNodeBackendList(s.nodeID), func(data []byte, reply func([]byte)) { + xlog.Info("Received NATS backend.list event") + backends, err := gallery.ListSystemBackends(s.systemState) + if err != nil { + resp := messaging.BackendListReply{Error: err.Error()} + replyJSON(reply, resp) + return + } + + var infos []messaging.NodeBackendInfo + for name, b := range backends { + info := messaging.NodeBackendInfo{ + Name: name, + IsSystem: b.IsSystem, + IsMeta: b.IsMeta, + } + if b.Metadata != nil { + info.InstalledAt = b.Metadata.InstalledAt + info.GalleryURL = b.Metadata.GalleryURL + } + infos = append(infos, info) + } + + resp := messaging.BackendListReply{Backends: infos} + replyJSON(reply, resp) + }) + + // model.unload — call gRPC Free() to release GPU memory (request-reply) + s.nats.SubscribeReply(messaging.SubjectNodeModelUnload(s.nodeID), func(data []byte, reply func([]byte)) { + xlog.Info("Received NATS model.unload event") + var req messaging.ModelUnloadRequest + if err := json.Unmarshal(data, &req); err != nil { + resp := messaging.ModelUnloadReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)} + replyJSON(reply, resp) + return + } + + // Find the backend address for this model's backend type + // The request includes an Address field if the router knows which process to target + targetAddr := req.Address + if targetAddr == "" { + // Fallback: try all running backends + s.mu.Lock() + for _, bp := range s.processes { + targetAddr = bp.addr + break + } + s.mu.Unlock() + } + + if targetAddr != "" { + // Best-effort gRPC Free() + client := grpc.NewClientWithToken(targetAddr, false, nil, false, s.cmd.RegistrationToken) + if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok { + if err := freeFunc.Free(context.Background()); err != nil { + xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr) + } + } + } + + resp := messaging.ModelUnloadReply{Success: true} + replyJSON(reply, resp) + }) + + // model.delete — remove model files from disk (request-reply) + s.nats.SubscribeReply(messaging.SubjectNodeModelDelete(s.nodeID), func(data []byte, reply func([]byte)) { + xlog.Info("Received NATS model.delete event") + var req messaging.ModelDeleteRequest + if err := json.Unmarshal(data, &req); err != nil { + replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: "invalid request"}) + return + } + + if err := gallery.DeleteStagedModelFiles(s.cmd.ModelsPath, req.ModelName); err != nil { + xlog.Warn("Failed to delete model files", "model", req.ModelName, "error", err) + replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: err.Error()}) + return + } + + replyJSON(reply, messaging.ModelDeleteReply{Success: true}) + }) + + // stop — trigger the normal shutdown path via sigCh so deferred cleanup runs + s.nats.Subscribe(messaging.SubjectNodeStop(s.nodeID), func(data []byte) { + xlog.Info("Received NATS stop event — signaling shutdown") + select { + case s.sigCh <- syscall.SIGTERM: + default: + xlog.Debug("Shutdown already signaled, ignoring duplicate stop") + } + }) +} + +// advertiseAddr returns the address the frontend should use to reach this node. +func (cmd *WorkerCMD) advertiseAddr() string { + if cmd.AdvertiseAddr != "" { + return cmd.AdvertiseAddr + } + host, port, ok := strings.Cut(cmd.Addr, ":") + if ok && (host == "0.0.0.0" || host == "") { + if hostname, err := os.Hostname(); err == nil { + return hostname + ":" + port + } + } + return cmd.Addr +} + +// resolveHTTPAddr returns the address to bind the HTTP file transfer server to. +// Uses basePort-1 so it doesn't conflict with dynamically allocated gRPC ports +// which grow upward from basePort. +func (cmd *WorkerCMD) resolveHTTPAddr() string { + if cmd.HTTPAddr != "" { + return cmd.HTTPAddr + } + host, port, ok := strings.Cut(cmd.Addr, ":") + if !ok { + return "0.0.0.0:50050" + } + portNum, _ := strconv.Atoi(port) + return fmt.Sprintf("%s:%d", host, portNum-1) +} + +// advertiseHTTPAddr returns the HTTP address the frontend should use to reach +// this node for file transfer. +func (cmd *WorkerCMD) advertiseHTTPAddr() string { + if cmd.AdvertiseHTTPAddr != "" { + return cmd.AdvertiseHTTPAddr + } + httpAddr := cmd.resolveHTTPAddr() + host, port, ok := strings.Cut(httpAddr, ":") + if ok && (host == "0.0.0.0" || host == "") { + if hostname, err := os.Hostname(); err == nil { + return hostname + ":" + port + } + } + return httpAddr +} + +// registrationBody builds the JSON body for node registration. +func (cmd *WorkerCMD) registrationBody() map[string]any { + nodeName := cmd.NodeName + if nodeName == "" { + hostname, err := os.Hostname() + if err != nil { + nodeName = fmt.Sprintf("node-%d", os.Getpid()) + } else { + nodeName = hostname + } + } + + // Detect GPU info for VRAM-aware scheduling + totalVRAM, _ := xsysinfo.TotalAvailableVRAM() + gpuVendor, _ := xsysinfo.DetectGPUVendor() + + body := map[string]any{ + "name": nodeName, + "address": cmd.advertiseAddr(), + "http_address": cmd.advertiseHTTPAddr(), + "total_vram": totalVRAM, + "available_vram": totalVRAM, // initially all VRAM is available + "gpu_vendor": gpuVendor, + } + + // If no GPU detected, report system RAM so the scheduler/UI has capacity info + if totalVRAM == 0 { + if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil { + body["total_ram"] = ramInfo.Total + body["available_ram"] = ramInfo.Available + } + } + if cmd.RegistrationToken != "" { + body["token"] = cmd.RegistrationToken + } + return body +} + +// heartbeatBody returns the current VRAM/RAM stats for heartbeat payloads. +func (cmd *WorkerCMD) heartbeatBody() map[string]any { + var availVRAM uint64 + aggregate := xsysinfo.GetGPUAggregateInfo() + if aggregate.TotalVRAM > 0 { + availVRAM = aggregate.FreeVRAM + } else { + // Fallback: report total as available (no usage tracking possible) + availVRAM, _ = xsysinfo.TotalAvailableVRAM() + } + + body := map[string]any{ + "available_vram": availVRAM, + } + + // If no GPU, report system RAM usage instead + if aggregate.TotalVRAM == 0 { + if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil { + body["available_ram"] = ramInfo.Available + } + } + return body +} diff --git a/core/cli/workerregistry/client.go b/core/cli/workerregistry/client.go new file mode 100644 index 000000000..b7236e6bf --- /dev/null +++ b/core/cli/workerregistry/client.go @@ -0,0 +1,272 @@ +// Package workerregistry provides a shared HTTP client for worker node +// registration, heartbeating, draining, and deregistration against a +// LocalAI frontend. Both the backend worker (WorkerCMD) and the agent +// worker (AgentWorkerCMD) use this instead of duplicating the logic. +package workerregistry + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/mudler/xlog" +) + +// RegistrationClient talks to the frontend's /api/node/* endpoints. +type RegistrationClient struct { + FrontendURL string + RegistrationToken string + HTTPTimeout time.Duration // used for registration calls; defaults to 10s + client *http.Client + clientOnce sync.Once +} + +// httpTimeout returns the configured timeout or a sensible default. +func (c *RegistrationClient) httpTimeout() time.Duration { + if c.HTTPTimeout > 0 { + return c.HTTPTimeout + } + return 10 * time.Second +} + +// httpClient returns the shared HTTP client, initializing it on first use. +func (c *RegistrationClient) httpClient() *http.Client { + c.clientOnce.Do(func() { + c.client = &http.Client{Timeout: c.httpTimeout()} + }) + return c.client +} + +// baseURL returns FrontendURL with any trailing slash stripped. +func (c *RegistrationClient) baseURL() string { + return strings.TrimRight(c.FrontendURL, "/") +} + +// setAuth adds an Authorization header when a token is configured. +func (c *RegistrationClient) setAuth(req *http.Request) { + if c.RegistrationToken != "" { + req.Header.Set("Authorization", "Bearer "+c.RegistrationToken) + } +} + +// RegisterResponse is the JSON body returned by /api/node/register. +type RegisterResponse struct { + ID string `json:"id"` + APIToken string `json:"api_token,omitempty"` +} + +// Register sends a single registration request and returns the node ID and +// (optionally) an auto-provisioned API token. +func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) { + jsonBody, _ := json.Marshal(body) + url := c.baseURL() + "/api/node/register" + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) + if err != nil { + return "", "", fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + c.setAuth(req) + + resp, err := c.httpClient().Do(req) + if err != nil { + return "", "", fmt.Errorf("posting to %s: %w", url, err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode) + } + + var result RegisterResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", "", fmt.Errorf("decoding response: %w", err) + } + return result.ID, result.APIToken, nil +} + +// RegisterWithRetry retries registration with exponential backoff. +func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) { + backoff := 2 * time.Second + maxBackoff := 30 * time.Second + + var nodeID, apiToken string + var err error + + for attempt := 1; attempt <= maxRetries; attempt++ { + nodeID, apiToken, err = c.Register(ctx, body) + if err == nil { + return nodeID, apiToken, nil + } + if attempt == maxRetries { + return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err) + } + xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err) + select { + case <-ctx.Done(): + return "", "", ctx.Err() + case <-time.After(backoff): + } + backoff = min(backoff*2, maxBackoff) + } + return nodeID, apiToken, err +} + +// Heartbeat sends a single heartbeat POST with the given body. +func (c *RegistrationClient) Heartbeat(ctx context.Context, nodeID string, body map[string]any) error { + jsonBody, _ := json.Marshal(body) + url := c.baseURL() + "/api/node/" + nodeID + "/heartbeat" + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) + if err != nil { + return fmt.Errorf("creating heartbeat request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + c.setAuth(req) + + resp, err := c.httpClient().Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + return nil +} + +// HeartbeatLoop runs heartbeats at the given interval until ctx is cancelled. +// bodyFn is called each tick to build the heartbeat payload (e.g. VRAM stats). +func (c *RegistrationClient) HeartbeatLoop(ctx context.Context, nodeID string, interval time.Duration, bodyFn func() map[string]any) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + body := bodyFn() + if err := c.Heartbeat(ctx, nodeID, body); err != nil { + xlog.Warn("Heartbeat failed", "error", err) + } + } + } +} + +// Drain sets the node to draining status via POST /api/node/:id/drain. +func (c *RegistrationClient) Drain(ctx context.Context, nodeID string) error { + url := c.baseURL() + "/api/node/" + nodeID + "/drain" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + if err != nil { + return fmt.Errorf("creating drain request: %w", err) + } + c.setAuth(req) + + resp, err := c.httpClient().Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("drain failed with status %d", resp.StatusCode) + } + return nil +} + +// WaitForDrain polls GET /api/node/:id/models until all models report 0 +// in-flight requests, or until timeout elapses. +func (c *RegistrationClient) WaitForDrain(ctx context.Context, nodeID string, timeout time.Duration) { + url := c.baseURL() + "/api/node/" + nodeID + "/models" + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + xlog.Warn("Failed to create drain poll request", "error", err) + return + } + c.setAuth(req) + + resp, err := c.httpClient().Do(req) + if err != nil { + xlog.Warn("Drain poll failed, will retry", "error", err) + select { + case <-ctx.Done(): + xlog.Warn("Drain wait cancelled") + return + case <-time.After(1 * time.Second): + } + continue + } + var models []struct { + InFlight int `json:"in_flight"` + } + json.NewDecoder(resp.Body).Decode(&models) + resp.Body.Close() + + total := 0 + for _, m := range models { + total += m.InFlight + } + if total == 0 { + xlog.Info("All in-flight requests drained") + return + } + xlog.Info("Waiting for in-flight requests", "count", total) + select { + case <-ctx.Done(): + xlog.Warn("Drain wait cancelled") + return + case <-time.After(1 * time.Second): + } + } + xlog.Warn("Drain timeout reached, proceeding with shutdown") +} + +// Deregister marks the node as offline via POST /api/node/:id/deregister. +// The node row is preserved in the database so re-registration restores +// approval status. +func (c *RegistrationClient) Deregister(ctx context.Context, nodeID string) error { + url := c.baseURL() + "/api/node/" + nodeID + "/deregister" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + if err != nil { + return fmt.Errorf("creating deregister request: %w", err) + } + c.setAuth(req) + + resp, err := c.httpClient().Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("deregistration failed with status %d", resp.StatusCode) + } + return nil +} + +// GracefulDeregister performs drain -> wait -> deregister in sequence. +// This is the standard shutdown sequence for backend workers. +func (c *RegistrationClient) GracefulDeregister(nodeID string) { + if c.FrontendURL == "" || nodeID == "" { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + if err := c.Drain(ctx, nodeID); err != nil { + xlog.Warn("Failed to set drain status", "error", err) + } else { + c.WaitForDrain(ctx, nodeID, 30*time.Second) + } + + if err := c.Deregister(ctx, nodeID); err != nil { + xlog.Error("Failed to deregister", "error", err) + } else { + xlog.Info("Deregistered from frontend") + } +} diff --git a/core/clients/store.go b/core/clients/store.go index f737ee421..4b0b9c2c1 100644 --- a/core/clients/store.go +++ b/core/clients/store.go @@ -94,7 +94,7 @@ func (c *StoreClient) Find(req FindRequest) (*FindResponse, error) { } // Helper function to perform a request without expecting a response body -func (c *StoreClient) doRequest(path string, data interface{}) error { +func (c *StoreClient) doRequest(path string, data any) error { jsonData, err := json.Marshal(data) if err != nil { return err @@ -120,7 +120,7 @@ func (c *StoreClient) doRequest(path string, data interface{}) error { } // Helper function to perform a request and parse the response body -func (c *StoreClient) doRequestWithResponse(path string, data interface{}) ([]byte, error) { +func (c *StoreClient) doRequestWithResponse(path string, data any) ([]byte, error) { jsonData, err := json.Marshal(data) if err != nil { return nil, err diff --git a/core/config/application_config.go b/core/config/application_config.go index 9c1be82d9..2bfb552d7 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -83,8 +83,8 @@ type ApplicationConfig struct { APIAddress string - LlamaCPPTunnelCallback func(tunnels []string) - MLXTunnelCallback func(tunnels []string) + LlamaCPPTunnelCallback func(tunnels []string) + MLXTunnelCallback func(tunnels []string) DisableRuntimeSettings bool @@ -99,47 +99,50 @@ type ApplicationConfig struct { // Authentication & Authorization Auth AuthConfig + + // Distributed / Horizontal Scaling + Distributed DistributedConfig } // AuthConfig holds configuration for user authentication and authorization. type AuthConfig struct { - Enabled bool - DatabaseURL string // "postgres://..." or file path for SQLite - GitHubClientID string - GitHubClientSecret string - OIDCIssuer string // OIDC issuer URL for auto-discovery (e.g. https://accounts.google.com) - OIDCClientID string - OIDCClientSecret string - BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080") - AdminEmail string // auto-promote to admin on login - RegistrationMode string // "open", "approval" (default when empty), "invite" - DisableLocalAuth bool // disable local email/password registration and login - APIKeyHMACSecret string // HMAC secret for API key hashing; auto-generated if empty + Enabled bool + DatabaseURL string // "postgres://..." or file path for SQLite + GitHubClientID string + GitHubClientSecret string + OIDCIssuer string // OIDC issuer URL for auto-discovery (e.g. https://accounts.google.com) + OIDCClientID string + OIDCClientSecret string + BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080") + AdminEmail string // auto-promote to admin on login + RegistrationMode string // "open", "approval" (default when empty), "invite" + DisableLocalAuth bool // disable local email/password registration and login + APIKeyHMACSecret string // HMAC secret for API key hashing; auto-generated if empty DefaultAPIKeyExpiry string // default expiry duration for API keys (e.g. "90d"); empty = no expiry } // AgentPoolConfig holds configuration for the LocalAGI agent pool integration. type AgentPoolConfig struct { - Enabled bool // default: true (disabled by LOCALAI_DISABLE_AGENTS=true) - StateDir string // default: DynamicConfigsDir (LocalAI configuration folder) - APIURL string // default: self-referencing LocalAI (http://127.0.0.1:) - APIKey string // default: first API key from LocalAI config - DefaultModel string - MultimodalModel string - TranscriptionModel string - TranscriptionLanguage string - TTSModel string - Timeout string // default: "5m" - EnableSkills bool - EnableLogs bool - CustomActionsDir string - CollectionDBPath string - VectorEngine string // default: "chromem" - EmbeddingModel string // default: "granite-embedding-107m-multilingual" - MaxChunkingSize int // default: 400 - ChunkOverlap int // default: 0 - DatabaseURL string - AgentHubURL string // default: "https://agenthub.localai.io" + Enabled bool // default: true (disabled by LOCALAI_DISABLE_AGENTS=true) + StateDir string // default: DynamicConfigsDir (LocalAI configuration folder) + APIURL string // default: self-referencing LocalAI (http://127.0.0.1:) + APIKey string // default: first API key from LocalAI config + DefaultModel string + MultimodalModel string + TranscriptionModel string + TranscriptionLanguage string + TTSModel string + Timeout string // default: "5m" + EnableSkills bool + EnableLogs bool + CustomActionsDir string + CollectionDBPath string + VectorEngine string // default: "chromem" + EmbeddingModel string // default: "granite-embedding-107m-multilingual" + MaxChunkingSize int // default: 400 + ChunkOverlap int // default: 0 + DatabaseURL string + AgentHubURL string // default: "https://agenthub.localai.io" } type AppOption func(*ApplicationConfig) @@ -155,12 +158,12 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig { WatchDogInterval: 500 * time.Millisecond, // Default: 500ms TracingMaxItems: 1024, AgentPool: AgentPoolConfig{ - Enabled: true, - Timeout: "5m", - VectorEngine: "chromem", - EmbeddingModel: "granite-embedding-107m-multilingual", + Enabled: true, + Timeout: "5m", + VectorEngine: "chromem", + EmbeddingModel: "granite-embedding-107m-multilingual", MaxChunkingSize: 400, - AgentHubURL: "https://agenthub.localai.io", + AgentHubURL: "https://agenthub.localai.io", }, PathWithoutAuth: []string{ "/static/", @@ -904,40 +907,40 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings { agentPoolCollectionDBPath := o.AgentPool.CollectionDBPath return RuntimeSettings{ - WatchdogEnabled: &watchdogEnabled, - WatchdogIdleEnabled: &watchdogIdle, - WatchdogBusyEnabled: &watchdogBusy, - WatchdogIdleTimeout: &idleTimeout, - WatchdogBusyTimeout: &busyTimeout, - WatchdogInterval: &watchdogInterval, - SingleBackend: &singleBackend, - MaxActiveBackends: &maxActiveBackends, - ParallelBackendRequests: ¶llelBackendRequests, - MemoryReclaimerEnabled: &memoryReclaimerEnabled, - MemoryReclaimerThreshold: &memoryReclaimerThreshold, - ForceEvictionWhenBusy: &forceEvictionWhenBusy, - LRUEvictionMaxRetries: &lruEvictionMaxRetries, - LRUEvictionRetryInterval: &lruEvictionRetryInterval, - Threads: &threads, - ContextSize: &contextSize, - F16: &f16, - Debug: &debug, - TracingMaxItems: &tracingMaxItems, - EnableTracing: &enableTracing, - EnableBackendLogging: &enableBackendLogging, - CORS: &cors, - CSRF: &csrf, - CORSAllowOrigins: &corsAllowOrigins, - P2PToken: &p2pToken, - P2PNetworkID: &p2pNetworkID, - Federated: &federated, - Galleries: &galleries, - BackendGalleries: &backendGalleries, - AutoloadGalleries: &autoloadGalleries, - AutoloadBackendGalleries: &autoloadBackendGalleries, - ApiKeys: &apiKeys, - AgentJobRetentionDays: &agentJobRetentionDays, - OpenResponsesStoreTTL: &openResponsesStoreTTL, + WatchdogEnabled: &watchdogEnabled, + WatchdogIdleEnabled: &watchdogIdle, + WatchdogBusyEnabled: &watchdogBusy, + WatchdogIdleTimeout: &idleTimeout, + WatchdogBusyTimeout: &busyTimeout, + WatchdogInterval: &watchdogInterval, + SingleBackend: &singleBackend, + MaxActiveBackends: &maxActiveBackends, + ParallelBackendRequests: ¶llelBackendRequests, + MemoryReclaimerEnabled: &memoryReclaimerEnabled, + MemoryReclaimerThreshold: &memoryReclaimerThreshold, + ForceEvictionWhenBusy: &forceEvictionWhenBusy, + LRUEvictionMaxRetries: &lruEvictionMaxRetries, + LRUEvictionRetryInterval: &lruEvictionRetryInterval, + Threads: &threads, + ContextSize: &contextSize, + F16: &f16, + Debug: &debug, + TracingMaxItems: &tracingMaxItems, + EnableTracing: &enableTracing, + EnableBackendLogging: &enableBackendLogging, + CORS: &cors, + CSRF: &csrf, + CORSAllowOrigins: &corsAllowOrigins, + P2PToken: &p2pToken, + P2PNetworkID: &p2pNetworkID, + Federated: &federated, + Galleries: &galleries, + BackendGalleries: &backendGalleries, + AutoloadGalleries: &autoloadGalleries, + AutoloadBackendGalleries: &autoloadBackendGalleries, + ApiKeys: &apiKeys, + AgentJobRetentionDays: &agentJobRetentionDays, + OpenResponsesStoreTTL: &openResponsesStoreTTL, AgentPoolEnabled: &agentPoolEnabled, AgentPoolDefaultModel: &agentPoolDefaultModel, AgentPoolEmbeddingModel: &agentPoolEmbeddingModel, diff --git a/core/config/application_config_test.go b/core/config/application_config_test.go index c6c6c15b9..4ea89e8d6 100644 --- a/core/config/application_config_test.go +++ b/core/config/application_config_test.go @@ -26,7 +26,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() { F16: true, Debug: true, CORS: true, - DisableCSRF: true, + DisableCSRF: true, CORSAllowOrigins: "https://example.com", P2PToken: "test-token", P2PNetworkID: "test-network", @@ -463,7 +463,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() { F16: true, Debug: false, CORS: true, - DisableCSRF: false, + DisableCSRF: false, CORSAllowOrigins: "https://test.com", P2PToken: "round-trip-token", P2PNetworkID: "round-trip-network", diff --git a/core/config/distributed_config.go b/core/config/distributed_config.go new file mode 100644 index 000000000..8fc7f6518 --- /dev/null +++ b/core/config/distributed_config.go @@ -0,0 +1,188 @@ +package config + +import ( + "cmp" + "fmt" + "time" + + "github.com/mudler/xlog" +) + +// DistributedConfig holds configuration for horizontal scaling mode. +// When Enabled is true, PostgreSQL and NATS are required. +type DistributedConfig struct { + Enabled bool // --distributed / LOCALAI_DISTRIBUTED + InstanceID string // --instance-id / LOCALAI_INSTANCE_ID (auto-generated UUID if empty) + NatsURL string // --nats-url / LOCALAI_NATS_URL + StorageURL string // --storage-url / LOCALAI_STORAGE_URL (S3 endpoint) + RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration) + AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers) + + // S3 configuration (used when StorageURL is set) + StorageBucket string // --storage-bucket / LOCALAI_STORAGE_BUCKET + StorageRegion string // --storage-region / LOCALAI_STORAGE_REGION + StorageAccessKey string // --storage-access-key / LOCALAI_STORAGE_ACCESS_KEY + StorageSecretKey string // --storage-secret-key / LOCALAI_STORAGE_SECRET_KEY + + // Timeout configuration (all have sensible defaults — zero means use default) + MCPToolTimeout time.Duration // MCP tool execution timeout (default 360s) + MCPDiscoveryTimeout time.Duration // MCP discovery timeout (default 60s) + WorkerWaitTimeout time.Duration // Max wait for healthy worker at startup (default 5m) + DrainTimeout time.Duration // Time to wait for in-flight requests during drain (default 30s) + HealthCheckInterval time.Duration // Health monitor check interval (default 15s) + StaleNodeThreshold time.Duration // Time before a node is considered stale (default 60s) + PerModelHealthCheck bool // Enable per-model backend health checking (default false) + MCPCIJobTimeout time.Duration // MCP CI job execution timeout (default 10m) + + MaxUploadSize int64 // Maximum upload body size in bytes (default 50 GB) + + AgentWorkerConcurrency int `yaml:"agent_worker_concurrency" json:"agent_worker_concurrency" env:"LOCALAI_AGENT_WORKER_CONCURRENCY"` + JobWorkerConcurrency int `yaml:"job_worker_concurrency" json:"job_worker_concurrency" env:"LOCALAI_JOB_WORKER_CONCURRENCY"` +} + +// Validate checks that the distributed configuration is internally consistent. +// It returns nil if distributed mode is disabled. +func (c DistributedConfig) Validate() error { + if !c.Enabled { + return nil + } + if c.NatsURL == "" { + return fmt.Errorf("distributed mode requires --nats-url / LOCALAI_NATS_URL") + } + // S3 credentials must be paired + if (c.StorageAccessKey != "" && c.StorageSecretKey == "") || + (c.StorageAccessKey == "" && c.StorageSecretKey != "") { + return fmt.Errorf("storage-access-key and storage-secret-key must both be set or both empty") + } + // Warn about missing registration token (not an error) + if c.RegistrationToken == "" { + xlog.Warn("distributed mode running without registration token — node endpoints are unprotected") + } + // Check for negative durations + for name, d := range map[string]time.Duration{ + "mcp-tool-timeout": c.MCPToolTimeout, + "mcp-discovery-timeout": c.MCPDiscoveryTimeout, + "worker-wait-timeout": c.WorkerWaitTimeout, + "drain-timeout": c.DrainTimeout, + "health-check-interval": c.HealthCheckInterval, + "stale-node-threshold": c.StaleNodeThreshold, + "mcp-ci-job-timeout": c.MCPCIJobTimeout, + } { + if d < 0 { + return fmt.Errorf("%s must not be negative", name) + } + } + return nil +} + +// Distributed config options + +var EnableDistributed = func(o *ApplicationConfig) { + o.Distributed.Enabled = true +} + +func WithDistributedInstanceID(id string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.InstanceID = id + } +} + +func WithNatsURL(url string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.NatsURL = url + } +} + +func WithRegistrationToken(token string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.RegistrationToken = token + } +} + +func WithStorageURL(url string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.StorageURL = url + } +} + +func WithStorageBucket(bucket string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.StorageBucket = bucket + } +} + +func WithStorageRegion(region string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.StorageRegion = region + } +} + +func WithStorageAccessKey(key string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.StorageAccessKey = key + } +} + +func WithStorageSecretKey(key string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.StorageSecretKey = key + } +} + +var EnableAutoApproveNodes = func(o *ApplicationConfig) { + o.Distributed.AutoApproveNodes = true +} + +// Defaults for distributed timeouts. +const ( + DefaultMCPToolTimeout = 360 * time.Second + DefaultMCPDiscoveryTimeout = 60 * time.Second + DefaultWorkerWaitTimeout = 5 * time.Minute + DefaultDrainTimeout = 30 * time.Second + DefaultHealthCheckInterval = 15 * time.Second + DefaultStaleNodeThreshold = 60 * time.Second + DefaultMCPCIJobTimeout = 10 * time.Minute +) + +// DefaultMaxUploadSize is the default maximum upload body size (50 GB). +const DefaultMaxUploadSize int64 = 50 << 30 + +// MCPToolTimeoutOrDefault returns the configured timeout or the default. +func (c DistributedConfig) MCPToolTimeoutOrDefault() time.Duration { + return cmp.Or(c.MCPToolTimeout, DefaultMCPToolTimeout) +} + +// MCPDiscoveryTimeoutOrDefault returns the configured timeout or the default. +func (c DistributedConfig) MCPDiscoveryTimeoutOrDefault() time.Duration { + return cmp.Or(c.MCPDiscoveryTimeout, DefaultMCPDiscoveryTimeout) +} + +// WorkerWaitTimeoutOrDefault returns the configured timeout or the default. +func (c DistributedConfig) WorkerWaitTimeoutOrDefault() time.Duration { + return cmp.Or(c.WorkerWaitTimeout, DefaultWorkerWaitTimeout) +} + +// DrainTimeoutOrDefault returns the configured timeout or the default. +func (c DistributedConfig) DrainTimeoutOrDefault() time.Duration { + return cmp.Or(c.DrainTimeout, DefaultDrainTimeout) +} + +// HealthCheckIntervalOrDefault returns the configured interval or the default. +func (c DistributedConfig) HealthCheckIntervalOrDefault() time.Duration { + return cmp.Or(c.HealthCheckInterval, DefaultHealthCheckInterval) +} + +// StaleNodeThresholdOrDefault returns the configured threshold or the default. +func (c DistributedConfig) StaleNodeThresholdOrDefault() time.Duration { + return cmp.Or(c.StaleNodeThreshold, DefaultStaleNodeThreshold) +} + +// MCPCIJobTimeoutOrDefault returns the configured MCP CI job timeout or the default. +func (c DistributedConfig) MCPCIJobTimeoutOrDefault() time.Duration { + return cmp.Or(c.MCPCIJobTimeout, DefaultMCPCIJobTimeout) +} + +// MaxUploadSizeOrDefault returns the configured max upload size or the default. +func (c DistributedConfig) MaxUploadSizeOrDefault() int64 { + return cmp.Or(c.MaxUploadSize, DefaultMaxUploadSize) +} diff --git a/core/config/model_config.go b/core/config/model_config.go index 0d148eac1..a4815c766 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -46,11 +46,11 @@ type ModelConfig struct { KnownUsecases *ModelConfigUsecase `yaml:"-" json:"-"` Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"` - PromptStrings, InputStrings []string `yaml:"-" json:"-"` - InputToken [][]int `yaml:"-" json:"-"` - functionCallString, functionCallNameString string `yaml:"-" json:"-"` - ResponseFormat string `yaml:"-" json:"-"` - ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"` + PromptStrings, InputStrings []string `yaml:"-" json:"-"` + InputToken [][]int `yaml:"-" json:"-"` + functionCallString, functionCallNameString string `yaml:"-" json:"-"` + ResponseFormat string `yaml:"-" json:"-"` + ResponseFormatMap map[string]any `yaml:"-" json:"-"` FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"` ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"` @@ -105,6 +105,11 @@ type AgentConfig struct { ForceReasoningTool bool `yaml:"force_reasoning_tool,omitempty" json:"force_reasoning_tool,omitempty"` } +// HasMCPServers returns true if any MCP servers (remote or stdio) are configured. +func (c MCPConfig) HasMCPServers() bool { + return c.Servers != "" || c.Stdio != "" +} + func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) { var remote MCPGenericConfig[MCPRemoteServers] var stdio MCPGenericConfig[MCPSTDIOServers] @@ -619,15 +624,32 @@ func (c *ModelConfig) HasUsecases(u ModelConfigUsecase) bool { // In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half. // This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently. func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool { + // Backends that are clearly not text-generation + nonTextGenBackends := []string{ + "whisper", "piper", "kokoro", + "diffusers", "stablediffusion", "stablediffusion-ggml", + "rerankers", "silero-vad", "rfdetr", + "transformers-musicgen", "ace-step", "acestep-cpp", + } + if (u & FLAG_CHAT) == FLAG_CHAT { if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate { return false } + if slices.Contains(nonTextGenBackends, c.Backend) { + return false + } + if c.Embeddings != nil && *c.Embeddings { + return false + } } if (u & FLAG_COMPLETION) == FLAG_COMPLETION { if c.TemplateConfig.Completion == "" { return false } + if slices.Contains(nonTextGenBackends, c.Backend) { + return false + } } if (u & FLAG_EDIT) == FLAG_EDIT { if c.TemplateConfig.Edit == "" { diff --git a/core/config/model_config_filter.go b/core/config/model_config_filter.go index cb7cc0bfd..f32e6c12a 100644 --- a/core/config/model_config_filter.go +++ b/core/config/model_config_filter.go @@ -1,35 +1,35 @@ -package config - -import "regexp" - -type ModelConfigFilterFn func(string, *ModelConfig) bool - -func NoFilterFn(_ string, _ *ModelConfig) bool { return true } - -func BuildNameFilterFn(filter string) (ModelConfigFilterFn, error) { - if filter == "" { - return NoFilterFn, nil - } - rxp, err := regexp.Compile(filter) - if err != nil { - return nil, err - } - return func(name string, config *ModelConfig) bool { - if config != nil { - return rxp.MatchString(config.Name) - } - return rxp.MatchString(name) - }, nil -} - -func BuildUsecaseFilterFn(usecases ModelConfigUsecase) ModelConfigFilterFn { - if usecases == FLAG_ANY { - return NoFilterFn - } - return func(name string, config *ModelConfig) bool { - if config == nil { - return false // TODO: Potentially make this a param, for now, no known usecase to include - } - return config.HasUsecases(usecases) - } -} +package config + +import "regexp" + +type ModelConfigFilterFn func(string, *ModelConfig) bool + +func NoFilterFn(_ string, _ *ModelConfig) bool { return true } + +func BuildNameFilterFn(filter string) (ModelConfigFilterFn, error) { + if filter == "" { + return NoFilterFn, nil + } + rxp, err := regexp.Compile(filter) + if err != nil { + return nil, err + } + return func(name string, config *ModelConfig) bool { + if config != nil { + return rxp.MatchString(config.Name) + } + return rxp.MatchString(name) + }, nil +} + +func BuildUsecaseFilterFn(usecases ModelConfigUsecase) ModelConfigFilterFn { + if usecases == FLAG_ANY { + return NoFilterFn + } + return func(name string, config *ModelConfig) bool { + if config == nil { + return false // TODO: Potentially make this a param, for now, no known usecase to include + } + return config.HasUsecases(usecases) + } +} diff --git a/core/config/model_config_loader.go b/core/config/model_config_loader.go index 68647a086..53f28f2cd 100644 --- a/core/config/model_config_loader.go +++ b/core/config/model_config_loader.go @@ -1,12 +1,13 @@ package config import ( + "cmp" "errors" "fmt" "io/fs" "os" "path/filepath" - "sort" + "slices" "strings" "sync" @@ -215,8 +216,8 @@ func (bcl *ModelConfigLoader) GetAllModelsConfigs() []ModelConfig { res = append(res, v) } - sort.SliceStable(res, func(i, j int) bool { - return res[i].Name < res[j].Name + slices.SortStableFunc(res, func(a, b ModelConfig) int { + return cmp.Compare(a.Name, b.Name) }) return res diff --git a/core/config/runtime_settings.go b/core/config/runtime_settings.go index 7637a0d94..611759110 100644 --- a/core/config/runtime_settings.go +++ b/core/config/runtime_settings.go @@ -27,15 +27,15 @@ type RuntimeSettings struct { MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%) // Eviction settings - ForceEvictionWhenBusy *bool `json:"force_eviction_when_busy,omitempty"` // Force eviction even when models have active API calls (default: false for safety) - LRUEvictionMaxRetries *int `json:"lru_eviction_max_retries,omitempty"` // Maximum number of retries when waiting for busy models to become idle (default: 30) - LRUEvictionRetryInterval *string `json:"lru_eviction_retry_interval,omitempty"` // Interval between retries when waiting for busy models (e.g., 1s, 2s) (default: 1s) + ForceEvictionWhenBusy *bool `json:"force_eviction_when_busy,omitempty"` // Force eviction even when models have active API calls (default: false for safety) + LRUEvictionMaxRetries *int `json:"lru_eviction_max_retries,omitempty"` // Maximum number of retries when waiting for busy models to become idle (default: 30) + LRUEvictionRetryInterval *string `json:"lru_eviction_retry_interval,omitempty"` // Interval between retries when waiting for busy models (e.g., 1s, 2s) (default: 1s) // Performance settings - Threads *int `json:"threads,omitempty"` - ContextSize *int `json:"context_size,omitempty"` - F16 *bool `json:"f16,omitempty"` - Debug *bool `json:"debug,omitempty"` + Threads *int `json:"threads,omitempty"` + ContextSize *int `json:"context_size,omitempty"` + F16 *bool `json:"f16,omitempty"` + Debug *bool `json:"debug,omitempty"` EnableTracing *bool `json:"enable_tracing,omitempty"` TracingMaxItems *int `json:"tracing_max_items,omitempty"` EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"` @@ -66,11 +66,11 @@ type RuntimeSettings struct { OpenResponsesStoreTTL *string `json:"open_responses_store_ttl,omitempty"` // TTL for stored responses (e.g., "1h", "30m", "0" = no expiration) // Agent Pool settings - AgentPoolEnabled *bool `json:"agent_pool_enabled,omitempty"` - AgentPoolDefaultModel *string `json:"agent_pool_default_model,omitempty"` - AgentPoolEmbeddingModel *string `json:"agent_pool_embedding_model,omitempty"` - AgentPoolMaxChunkingSize *int `json:"agent_pool_max_chunking_size,omitempty"` - AgentPoolChunkOverlap *int `json:"agent_pool_chunk_overlap,omitempty"` - AgentPoolEnableLogs *bool `json:"agent_pool_enable_logs,omitempty"` + AgentPoolEnabled *bool `json:"agent_pool_enabled,omitempty"` + AgentPoolDefaultModel *string `json:"agent_pool_default_model,omitempty"` + AgentPoolEmbeddingModel *string `json:"agent_pool_embedding_model,omitempty"` + AgentPoolMaxChunkingSize *int `json:"agent_pool_max_chunking_size,omitempty"` + AgentPoolChunkOverlap *int `json:"agent_pool_chunk_overlap,omitempty"` + AgentPoolEnableLogs *bool `json:"agent_pool_enable_logs,omitempty"` AgentPoolCollectionDBPath *string `json:"agent_pool_collection_db_path,omitempty"` } diff --git a/core/explorer/database.go b/core/explorer/database.go index e24de0aad..6c7365356 100644 --- a/core/explorer/database.go +++ b/core/explorer/database.go @@ -3,9 +3,10 @@ package explorer // A simple JSON database for storing and retrieving p2p network tokens and a name and description. import ( + "cmp" "encoding/json" "os" - "sort" + "slices" "sync" "github.com/gofrs/flock" @@ -89,9 +90,8 @@ func (db *Database) TokenList() []string { tokens = append(tokens, k) } - sort.Slice(tokens, func(i, j int) bool { - // sort by token - return tokens[i] < tokens[j] + slices.SortFunc(tokens, func(a, b string) int { + return cmp.Compare(a, b) }) return tokens diff --git a/core/gallery/backend_resolve.go b/core/gallery/backend_resolve.go index 64a89c504..c38c9b2c0 100644 --- a/core/gallery/backend_resolve.go +++ b/core/gallery/backend_resolve.go @@ -15,7 +15,7 @@ import ( // modelConfigCacheEntry holds a cached parsed config_file map from a URL-referenced model config. type modelConfigCacheEntry struct { - configMap map[string]interface{} + configMap map[string]any lastUpdated time.Time } @@ -57,7 +57,7 @@ func resolveBackend(m *GalleryModel, basePath string) string { // fetchModelConfigMap fetches a model config URL, parses the config_file YAML string // inside it, and returns the result as a map. Results are cached for 1 hour. // Local file:// URLs skip the cache so edits are picked up immediately. -func fetchModelConfigMap(modelURL, basePath string) map[string]interface{} { +func fetchModelConfigMap(modelURL, basePath string) map[string]any { // Check cache (skip for file:// URLs so local edits are picked up immediately) isLocal := strings.HasPrefix(modelURL, downloader.LocalPrefix) if !isLocal && modelConfigCache.Exists(modelURL) { @@ -75,15 +75,15 @@ func fetchModelConfigMap(modelURL, basePath string) map[string]interface{} { // Cache the failure for remote URLs to avoid repeated fetch attempts if !isLocal { modelConfigCache.Set(modelURL, modelConfigCacheEntry{ - configMap: map[string]interface{}{}, + configMap: map[string]any{}, lastUpdated: time.Now(), }) } - return map[string]interface{}{} + return map[string]any{} } // Parse the config_file YAML string into a map - configMap := make(map[string]interface{}) + configMap := make(map[string]any) if modelConfig.ConfigFile != "" { if err := yaml.Unmarshal([]byte(modelConfig.ConfigFile), &configMap); err != nil { xlog.Debug("Failed to parse config_file for backend resolution", "url", modelURL, "error", err) @@ -108,13 +108,11 @@ func prefetchModelConfigs(urls []string, basePath string) { sem := make(chan struct{}, maxConcurrency) var wg sync.WaitGroup for _, url := range urls { - wg.Add(1) - go func(u string) { - defer wg.Done() + wg.Go(func() { sem <- struct{}{} defer func() { <-sem }() - fetchModelConfigMap(u, basePath) - }(url) + fetchModelConfigMap(url, basePath) + }) } wg.Wait() } diff --git a/core/gallery/backends.go b/core/gallery/backends.go index 83d1cd40c..e06179074 100644 --- a/core/gallery/backends.go +++ b/core/gallery/backends.go @@ -4,10 +4,10 @@ package gallery import ( "context" - "os" "encoding/json" "errors" "fmt" + "os" "path/filepath" "strings" "time" @@ -20,6 +20,9 @@ import ( cp "github.com/otiai10/copy" ) +// ErrBackendNotFound is returned when a backend is not found in the system. +var ErrBackendNotFound = errors.New("backend not found") + const ( metadataFile = "metadata.json" runFile = "run.sh" @@ -198,9 +201,16 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL } else { xlog.Debug("Downloading backend", "uri", config.URI, "backendPath", backendPath) if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil { - // Don't remove backendPath here — fallback OCI extractions need the directory to exist xlog.Debug("Backend download failed, trying fallback", "backendPath", backendPath, "error", err) + // resetBackendPath cleans up partial state from a failed OCI extraction + // so the next download attempt starts fresh. The directory is re-created + // because OCI image extractors need it to exist for writing files into. + resetBackendPath := func() { + os.RemoveAll(backendPath) + os.MkdirAll(backendPath, 0750) + } + success := false // Try to download from mirrors for _, mirror := range config.Mirrors { @@ -210,6 +220,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL return ctx.Err() default: } + resetBackendPath() if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil { success = true xlog.Debug("Downloaded backend from mirror", "uri", config.URI, "backendPath", backendPath) @@ -221,28 +232,22 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL // Try fallback: replace latestTag + "-" with masterTag + "-" in the URI fallbackURI := strings.Replace(string(config.URI), latestTag+"-", masterTag+"-", 1) if fallbackURI != string(config.URI) { - xlog.Debug("Trying fallback URI", "original", config.URI, "fallback", fallbackURI) + resetBackendPath() + xlog.Info("Trying fallback URI", "original", config.URI, "fallback", fallbackURI) if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil { xlog.Info("Downloaded backend using fallback URI", "uri", fallbackURI, "backendPath", backendPath) success = true } else { - // Try another fallback: add "-" + devSuffix suffix to the backend name - // For example: master-gpu-nvidia-cuda-13-ace-step -> master-gpu-nvidia-cuda-13-ace-step-development + xlog.Info("Fallback URI failed", "fallback", fallbackURI, "error", err) if !strings.Contains(fallbackURI, "-"+devSuffix) { - // Extract backend name from URI and add -development - parts := strings.Split(fallbackURI, "-") - if len(parts) >= 2 { - // Find where the backend name ends (usually the last part before the tag) - // Pattern: quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-ace-step - lastDash := strings.LastIndex(fallbackURI, "-") - if lastDash > 0 { - devFallbackURI := fallbackURI[:lastDash] + "-" + devSuffix - xlog.Debug("Trying development fallback URI", "fallback", devFallbackURI) - if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil { - xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath) - success = true - } - } + resetBackendPath() + devFallbackURI := fallbackURI + "-" + devSuffix + xlog.Info("Trying development fallback URI", "fallback", devFallbackURI) + if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil { + xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath) + success = true + } else { + xlog.Info("Development fallback URI failed", "fallback", devFallbackURI, "error", err) } } } @@ -295,7 +300,7 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error backend, ok := backends.Get(name) if !ok { - return fmt.Errorf("backend %q not found", name) + return fmt.Errorf("backend %q: %w", name, ErrBackendNotFound) } if backend.IsSystem { diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go index 607ad37d1..0b0791afe 100644 --- a/core/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -5,7 +5,7 @@ import ( "fmt" "os" "path/filepath" - "sort" + "slices" "strings" "time" @@ -106,64 +106,64 @@ func (gm GalleryElements[T]) FilterByTag(tag string) GalleryElements[T] { } func (gm GalleryElements[T]) SortByName(sortOrder string) GalleryElements[T] { - sort.Slice(gm, func(i, j int) bool { - if sortOrder == "asc" { - return strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName()) - } else { - return strings.ToLower(gm[i].GetName()) > strings.ToLower(gm[j].GetName()) + slices.SortFunc(gm, func(a, b T) int { + r := strings.Compare(strings.ToLower(a.GetName()), strings.ToLower(b.GetName())) + if sortOrder == "desc" { + return -r } + return r }) return gm } func (gm GalleryElements[T]) SortByRepository(sortOrder string) GalleryElements[T] { - sort.Slice(gm, func(i, j int) bool { - if sortOrder == "asc" { - return strings.ToLower(gm[i].GetGallery().Name) < strings.ToLower(gm[j].GetGallery().Name) - } else { - return strings.ToLower(gm[i].GetGallery().Name) > strings.ToLower(gm[j].GetGallery().Name) + slices.SortFunc(gm, func(a, b T) int { + r := strings.Compare(strings.ToLower(a.GetGallery().Name), strings.ToLower(b.GetGallery().Name)) + if sortOrder == "desc" { + return -r } + return r }) return gm } func (gm GalleryElements[T]) SortByLicense(sortOrder string) GalleryElements[T] { - sort.Slice(gm, func(i, j int) bool { - licenseI := gm[i].GetLicense() - licenseJ := gm[j].GetLicense() - var result bool - if licenseI == "" && licenseJ != "" { - return sortOrder == "desc" - } else if licenseI != "" && licenseJ == "" { - return sortOrder == "asc" - } else if licenseI == "" && licenseJ == "" { - return false + slices.SortFunc(gm, func(a, b T) int { + licenseA := a.GetLicense() + licenseB := b.GetLicense() + var r int + if licenseA == "" && licenseB != "" { + r = 1 + } else if licenseA != "" && licenseB == "" { + r = -1 } else { - result = strings.ToLower(licenseI) < strings.ToLower(licenseJ) + r = strings.Compare(strings.ToLower(licenseA), strings.ToLower(licenseB)) } if sortOrder == "desc" { - return !result - } else { - return result + return -r } + return r }) return gm } func (gm GalleryElements[T]) SortByInstalled(sortOrder string) GalleryElements[T] { - sort.Slice(gm, func(i, j int) bool { - var result bool + slices.SortFunc(gm, func(a, b T) int { + var r int // Sort by installed status: installed items first (true > false) - if gm[i].GetInstalled() != gm[j].GetInstalled() { - result = gm[i].GetInstalled() + if a.GetInstalled() != b.GetInstalled() { + if a.GetInstalled() { + r = -1 + } else { + r = 1 + } } else { - result = strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName()) + r = strings.Compare(strings.ToLower(a.GetName()), strings.ToLower(b.GetName())) } if sortOrder == "desc" { - return !result - } else { - return result + return -r } + return r }) return gm } diff --git a/core/gallery/gallery_test.go b/core/gallery/gallery_test.go index ef09c076d..2d6512622 100644 --- a/core/gallery/gallery_test.go +++ b/core/gallery/gallery_test.go @@ -27,7 +27,7 @@ var _ = Describe("Gallery", func() { Describe("ReadConfigFile", func() { It("should read and unmarshal a valid YAML file", func() { - testConfig := map[string]interface{}{ + testConfig := map[string]any{ "name": "test-model", "description": "A test model", "license": "MIT", @@ -39,8 +39,8 @@ var _ = Describe("Gallery", func() { err = os.WriteFile(filePath, yamlData, 0644) Expect(err).NotTo(HaveOccurred()) - var result map[string]interface{} - config, err := ReadConfigFile[map[string]interface{}](filePath) + var result map[string]any + config, err := ReadConfigFile[map[string]any](filePath) Expect(err).NotTo(HaveOccurred()) Expect(config).NotTo(BeNil()) result = *config @@ -50,7 +50,7 @@ var _ = Describe("Gallery", func() { }) It("should return error when file does not exist", func() { - _, err := ReadConfigFile[map[string]interface{}]("nonexistent.yaml") + _, err := ReadConfigFile[map[string]any]("nonexistent.yaml") Expect(err).To(HaveOccurred()) }) @@ -59,7 +59,7 @@ var _ = Describe("Gallery", func() { err := os.WriteFile(filePath, []byte("invalid: yaml: content: [unclosed"), 0644) Expect(err).NotTo(HaveOccurred()) - _, err = ReadConfigFile[map[string]interface{}](filePath) + _, err = ReadConfigFile[map[string]any](filePath) Expect(err).To(HaveOccurred()) }) }) @@ -552,32 +552,32 @@ var _ = Describe("Gallery", func() { // Verify first model Expect(models[0].Name).To(Equal("nanbeige4.1-3b-q8")) Expect(models[0].Overrides).NotTo(BeNil()) - Expect(models[0].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{})) - params := models[0].Overrides["parameters"].(map[string]interface{}) + Expect(models[0].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]any{})) + params := models[0].Overrides["parameters"].(map[string]any) Expect(params["model"]).To(Equal("nanbeige4.1-3b-q8_0.gguf")) // Verify second model (merged) Expect(models[1].Name).To(Equal("nanbeige4.1-3b-q4")) Expect(models[1].Overrides).NotTo(BeNil()) - Expect(models[1].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{})) - params = models[1].Overrides["parameters"].(map[string]interface{}) + Expect(models[1].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]any{})) + params = models[1].Overrides["parameters"].(map[string]any) Expect(params["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf")) // Simulate the mergo.Merge call that was failing in models.go:251 // This should not panic with yaml.v3 - configMap := make(map[string]interface{}) + configMap := make(map[string]any) configMap["name"] = "test" configMap["backend"] = "llama-cpp" - configMap["parameters"] = map[string]interface{}{ + configMap["parameters"] = map[string]any{ "model": "original.gguf", } err = mergo.Merge(&configMap, models[1].Overrides, mergo.WithOverride) Expect(err).NotTo(HaveOccurred()) Expect(configMap["parameters"]).NotTo(BeNil()) - + // Verify the merge worked correctly - mergedParams := configMap["parameters"].(map[string]interface{}) + mergedParams := configMap["parameters"].(map[string]any) Expect(mergedParams["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf")) }) }) diff --git a/core/gallery/importers/local_test.go b/core/gallery/importers/local_test.go index 0de679462..9d6a6519a 100644 --- a/core/gallery/importers/local_test.go +++ b/core/gallery/importers/local_test.go @@ -59,7 +59,7 @@ var _ = Describe("ImportLocalPath", func() { adapterConfig := map[string]any{ "base_model_name_or_path": "meta-llama/Llama-2-7b-hf", - "peft_type": "LORA", + "peft_type": "LORA", } data, _ := json.Marshal(adapterConfig) Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed()) diff --git a/core/gallery/models.go b/core/gallery/models.go index 3aa5e4db8..c2277be90 100644 --- a/core/gallery/models.go +++ b/core/gallery/models.go @@ -158,7 +158,7 @@ func InstallModelFromGallery( return applyModel(model) } -func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) { +func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]any, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) { basePath := systemState.Model.ModelsPath // Create base path if it doesn't exist err := os.MkdirAll(basePath, 0750) @@ -239,7 +239,7 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver configFilePath := filepath.Join(basePath, name+".yaml") // Read and update config file as map[string]interface{} - configMap := make(map[string]interface{}) + configMap := make(map[string]any) err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap) if err != nil { return nil, fmt.Errorf("failed to unmarshal config YAML: %v", err) diff --git a/core/gallery/models_test.go b/core/gallery/models_test.go index c67243599..52071ab65 100644 --- a/core/gallery/models_test.go +++ b/core/gallery/models_test.go @@ -35,7 +35,7 @@ var _ = Describe("Model test", func() { system.WithModelPath(tempdir), ) Expect(err).ToNot(HaveOccurred()) - _, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) + _, err = InstallModel(context.TODO(), systemState, "", c, map[string]any{}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { @@ -43,7 +43,7 @@ var _ = Describe("Model test", func() { Expect(err).ToNot(HaveOccurred()) } - content := map[string]interface{}{} + content := map[string]any{} dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml")) Expect(err).ToNot(HaveOccurred()) @@ -95,7 +95,7 @@ var _ = Describe("Model test", func() { dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml")) Expect(err).ToNot(HaveOccurred()) - content := map[string]interface{}{} + content := map[string]any{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) @@ -130,7 +130,7 @@ var _ = Describe("Model test", func() { system.WithModelPath(tempdir), ) Expect(err).ToNot(HaveOccurred()) - _, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) + _, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]any{}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -150,7 +150,7 @@ var _ = Describe("Model test", func() { system.WithModelPath(tempdir), ) Expect(err).ToNot(HaveOccurred()) - _, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true) + _, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]any{"backend": "foo"}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -158,7 +158,7 @@ var _ = Describe("Model test", func() { Expect(err).ToNot(HaveOccurred()) } - content := map[string]interface{}{} + content := map[string]any{} dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml")) Expect(err).ToNot(HaveOccurred()) @@ -180,7 +180,7 @@ var _ = Describe("Model test", func() { system.WithModelPath(tempdir), ) Expect(err).ToNot(HaveOccurred()) - _, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) + _, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]any{}, func(string, string, string, float64) {}, true) Expect(err).To(HaveOccurred()) }) diff --git a/core/gallery/models_types.go b/core/gallery/models_types.go index 000aa2b26..f70a5b222 100644 --- a/core/gallery/models_types.go +++ b/core/gallery/models_types.go @@ -12,9 +12,9 @@ import ( type GalleryModel struct { Metadata `json:",inline" yaml:",inline"` // config_file is read in the situation where URL is blank - and therefore this is a base config. - ConfigFile map[string]interface{} `json:"config_file,omitempty" yaml:"config_file,omitempty"` + ConfigFile map[string]any `json:"config_file,omitempty" yaml:"config_file,omitempty"` // Overrides are used to override the configuration of the model located at URL - Overrides map[string]interface{} `json:"overrides,omitempty" yaml:"overrides,omitempty"` + Overrides map[string]any `json:"overrides,omitempty" yaml:"overrides,omitempty"` } func (m *GalleryModel) GetInstalled() bool { diff --git a/core/gallery/worker.go b/core/gallery/worker.go new file mode 100644 index 000000000..292151f27 --- /dev/null +++ b/core/gallery/worker.go @@ -0,0 +1,66 @@ +package gallery + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/mudler/xlog" +) + +// DeleteStagedModelFiles removes all staged files for a model from a worker's +// models directory. Files are expected to be in a subdirectory named after the +// model's tracking key (created by stageModelFiles in the router). +// +// Workers receive model files via S3/HTTP file staging, not gallery install, +// so they lack the YAML configs that DeleteModelFromSystem requires. +// +// Falls back to glob-based cleanup for single-file models or legacy layouts. +func DeleteStagedModelFiles(modelsPath, modelName string) error { + if modelName == "" { + return fmt.Errorf("empty model name") + } + + // Clean and validate: resolved path must stay within modelsPath + modelPath := filepath.Clean(filepath.Join(modelsPath, modelName)) + absModels := filepath.Clean(modelsPath) + if !strings.HasPrefix(modelPath, absModels+string(filepath.Separator)) { + return fmt.Errorf("model name %q escapes models directory", modelName) + } + + // Primary: remove the model's subdirectory (contains all staged files) + if info, err := os.Stat(modelPath); err == nil && info.IsDir() { + return os.RemoveAll(modelPath) + } + + // Fallback for single-file models or legacy layouts: + // remove exact file match + glob siblings + removed := false + if _, err := os.Stat(modelPath); err == nil { + if err := os.Remove(modelPath); err != nil { + xlog.Warn("Failed to remove model file", "path", modelPath, "error", err) + } else { + removed = true + } + } + + // Remove sibling files (e.g., model.gguf.mmproj alongside model.gguf) + matches, _ := filepath.Glob(modelPath + ".*") + for _, m := range matches { + clean := filepath.Clean(m) + if !strings.HasPrefix(clean, absModels+string(filepath.Separator)) { + continue // skip any glob result that escapes + } + if err := os.Remove(clean); err != nil { + xlog.Warn("Failed to remove model-related file", "path", clean, "error", err) + } else { + removed = true + } + } + + if !removed { + xlog.Debug("No files found to delete for model", "model", modelName, "path", modelPath) + } + return nil +} diff --git a/core/gallery/worker_test.go b/core/gallery/worker_test.go new file mode 100644 index 000000000..5f7dae6c4 --- /dev/null +++ b/core/gallery/worker_test.go @@ -0,0 +1,99 @@ +package gallery_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/mudler/LocalAI/core/gallery" +) + +func TestDeleteStagedModelFiles(t *testing.T) { + t.Run("rejects empty model name", func(t *testing.T) { + dir := t.TempDir() + err := gallery.DeleteStagedModelFiles(dir, "") + if err == nil { + t.Fatal("expected error for empty model name") + } + }) + + t.Run("rejects path traversal via ..", func(t *testing.T) { + dir := t.TempDir() + err := gallery.DeleteStagedModelFiles(dir, "../../etc/passwd") + if err == nil { + t.Fatal("expected error for path traversal attempt") + } + }) + + t.Run("rejects path traversal via ../foo", func(t *testing.T) { + dir := t.TempDir() + err := gallery.DeleteStagedModelFiles(dir, "../foo") + if err == nil { + t.Fatal("expected error for path traversal attempt") + } + }) + + t.Run("removes model subdirectory with all files", func(t *testing.T) { + dir := t.TempDir() + modelDir := filepath.Join(dir, "my-model", "sd-cpp", "models") + if err := os.MkdirAll(modelDir, 0o755); err != nil { + t.Fatal(err) + } + // Create model files in subdirectory + os.WriteFile(filepath.Join(modelDir, "flux.gguf"), []byte("model"), 0o644) + os.WriteFile(filepath.Join(modelDir, "flux.gguf.mmproj"), []byte("mmproj"), 0o644) + + err := gallery.DeleteStagedModelFiles(dir, "my-model") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Entire my-model directory should be gone + if _, err := os.Stat(filepath.Join(dir, "my-model")); !os.IsNotExist(err) { + t.Fatal("expected model directory to be removed") + } + }) + + t.Run("removes single file model", func(t *testing.T) { + dir := t.TempDir() + modelFile := filepath.Join(dir, "model.gguf") + os.WriteFile(modelFile, []byte("model"), 0o644) + + err := gallery.DeleteStagedModelFiles(dir, "model.gguf") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if _, err := os.Stat(modelFile); !os.IsNotExist(err) { + t.Fatal("expected model file to be removed") + } + }) + + t.Run("removes sibling files via glob", func(t *testing.T) { + dir := t.TempDir() + modelFile := filepath.Join(dir, "model.gguf") + siblingFile := filepath.Join(dir, "model.gguf.mmproj") + os.WriteFile(modelFile, []byte("model"), 0o644) + os.WriteFile(siblingFile, []byte("mmproj"), 0o644) + + err := gallery.DeleteStagedModelFiles(dir, "model.gguf") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if _, err := os.Stat(modelFile); !os.IsNotExist(err) { + t.Fatal("expected model file to be removed") + } + if _, err := os.Stat(siblingFile); !os.IsNotExist(err) { + t.Fatal("expected sibling file to be removed") + } + }) + + t.Run("no error when model does not exist", func(t *testing.T) { + dir := t.TempDir() + err := gallery.DeleteStagedModelFiles(dir, "nonexistent") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) +} diff --git a/core/http/app.go b/core/http/app.go index 94f36c89d..763ccb8b6 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -16,12 +16,17 @@ import ( "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/http/endpoints/localai" + httpMiddleware "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/routes" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/finetune" + "github.com/mudler/LocalAI/core/services/galleryop" + "github.com/mudler/LocalAI/core/services/monitoring" + "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/LocalAI/core/services/quantization" "github.com/mudler/xlog" ) @@ -155,7 +160,7 @@ func API(application *application.Application) (*echo.Echo, error) { // Metrics middleware if !application.ApplicationConfig().DisableMetrics { - metricsService, err := services.NewLocalAIMetricsService() + metricsService, err := monitoring.NewLocalAIMetricsService() if err != nil { return nil, err } @@ -295,9 +300,9 @@ func API(application *application.Application) (*echo.Echo, error) { routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) // Create opcache for tracking UI operations (used by both UI and LocalAI routes) - var opcache *services.OpCache + var opcache *galleryop.OpCache if !application.ApplicationConfig().DisableWebUI { - opcache = services.NewOpCache(application.GalleryService()) + opcache = galleryop.NewOpCache(application.GalleryService()) } mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP) @@ -305,22 +310,51 @@ func API(application *application.Application) (*echo.Echo, error) { routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw) // Fine-tuning routes fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning) - ftService := services.NewFineTuneService( + ftService := finetune.NewFineTuneService( application.ApplicationConfig(), application.ModelLoader(), application.ModelConfigLoader(), ) + if d := application.Distributed(); d != nil { + ftService.SetNATSClient(d.Nats) + if d.DistStores != nil && d.DistStores.FineTune != nil { + ftService.SetFineTuneStore(d.DistStores.FineTune) + } + } routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw) // Quantization routes quantizationMw := auth.RequireFeature(application.AuthDB(), auth.FeatureQuantization) - qService := services.NewQuantizationService( + qService := quantization.NewQuantizationService( application.ApplicationConfig(), application.ModelLoader(), application.ModelConfigLoader(), ) routes.RegisterQuantizationRoutes(e, qService, application.ApplicationConfig(), quantizationMw) + // Node management routes (distributed mode) + distCfg := application.ApplicationConfig().Distributed + var registry *nodes.NodeRegistry + var remoteUnloader nodes.NodeCommandSender + if d := application.Distributed(); d != nil { + registry = d.Registry + if d.Router != nil { + remoteUnloader = d.Router.Unloader() + } + } + routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret) + routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken) + + // Distributed SSE routes (job progress + agent events via NATS) + if d := application.Distributed(); d != nil { + if d.Dispatcher != nil { + e.GET("/api/agent/jobs/:id/progress", d.Dispatcher.SSEHandler(), mcpJobsMw) + } + if d.AgentBridge != nil { + e.GET("/api/agents/:name/sse/distributed", d.AgentBridge.SSEHandler(), agentsMw) + } + } + routes.RegisterOpenAIRoutes(e, requestExtractor, application) routes.RegisterAnthropicRoutes(e, requestExtractor, application) routes.RegisterOpenResponsesRoutes(e, requestExtractor, application) diff --git a/core/http/app_test.go b/core/http/app_test.go index 903aae17b..1640451e1 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -44,14 +44,14 @@ Say hello. ### Response:` type modelApplyRequest struct { - ID string `json:"id"` - URL string `json:"url"` - ConfigURL string `json:"config_url"` - Name string `json:"name"` - Overrides map[string]interface{} `json:"overrides"` + ID string `json:"id"` + URL string `json:"url"` + ConfigURL string `json:"config_url"` + Name string `json:"name"` + Overrides map[string]any `json:"overrides"` } -func getModelStatus(url string) (response map[string]interface{}) { +func getModelStatus(url string) (response map[string]any) { // Create the HTTP request req, err := http.NewRequest("GET", url, nil) req.Header.Set("Content-Type", "application/json") @@ -94,7 +94,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) { return response, err } -func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) { +func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]any) { //url := "http://localhost:AI/models/apply" @@ -336,7 +336,7 @@ var _ = Describe("API test", func() { Name: "bert", URL: bertEmbeddingsURL, }, - Overrides: map[string]interface{}{"backend": "llama-cpp"}, + Overrides: map[string]any{"backend": "llama-cpp"}, }, { Metadata: gallery.Metadata{ @@ -344,7 +344,7 @@ var _ = Describe("API test", func() { URL: bertEmbeddingsURL, AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: bertEmbeddingsURL}}, }, - Overrides: map[string]interface{}{"foo": "bar"}, + Overrides: map[string]any{"foo": "bar"}, }, } out, err := yaml.Marshal(g) @@ -464,7 +464,7 @@ var _ = Describe("API test", func() { Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) uuid := response["uuid"].(string) - resp := map[string]interface{}{} + resp := map[string]any{} Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) fmt.Println(response) @@ -479,7 +479,7 @@ var _ = Describe("API test", func() { _, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml")) Expect(err).ToNot(HaveOccurred()) - content := map[string]interface{}{} + content := map[string]any{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) @@ -503,7 +503,7 @@ var _ = Describe("API test", func() { response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ URL: bertEmbeddingsURL, Name: "bert", - Overrides: map[string]interface{}{ + Overrides: map[string]any{ "backend": "llama", }, }) @@ -520,7 +520,7 @@ var _ = Describe("API test", func() { dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml")) Expect(err).ToNot(HaveOccurred()) - content := map[string]interface{}{} + content := map[string]any{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["backend"]).To(Equal("llama")) @@ -529,7 +529,7 @@ var _ = Describe("API test", func() { response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ URL: bertEmbeddingsURL, Name: "bert", - Overrides: map[string]interface{}{}, + Overrides: map[string]any{}, }) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) @@ -544,7 +544,7 @@ var _ = Describe("API test", func() { dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml")) Expect(err).ToNot(HaveOccurred()) - content := map[string]interface{}{} + content := map[string]any{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) @@ -586,7 +586,7 @@ parameters: Expect(response.ID).ToNot(BeEmpty()) uuid := response.ID - resp := map[string]interface{}{} + resp := map[string]any{} Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) resp = response @@ -601,7 +601,7 @@ parameters: dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml")) Expect(err).ToNot(HaveOccurred()) - content := map[string]interface{}{} + content := map[string]any{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["name"]).To(Equal("test-import-model")) @@ -657,7 +657,7 @@ parameters: Expect(response.ID).ToNot(BeEmpty()) uuid := response.ID - resp := map[string]interface{}{} + resp := map[string]any{} Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) resp = response @@ -1248,7 +1248,7 @@ parameters: Context("Agent Jobs", Label("agent-jobs"), func() { It("creates and manages tasks", func() { // Create a task - taskBody := map[string]interface{}{ + taskBody := map[string]any{ "name": "Test Task", "description": "Test Description", "model": "testmodel.ggml", @@ -1256,7 +1256,7 @@ parameters: "enabled": true, } - var createResp map[string]interface{} + var createResp map[string]any err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) Expect(err).ToNot(HaveOccurred()) Expect(createResp["id"]).ToNot(BeEmpty()) @@ -1302,20 +1302,20 @@ parameters: It("executes and monitors jobs", func() { // Create a task first - taskBody := map[string]interface{}{ + taskBody := map[string]any{ "name": "Job Test Task", "model": "testmodel.ggml", "prompt": "Say hello", "enabled": true, } - var createResp map[string]interface{} + var createResp map[string]any err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) Expect(err).ToNot(HaveOccurred()) taskID := createResp["id"].(string) // Execute a job - jobBody := map[string]interface{}{ + jobBody := map[string]any{ "task_id": taskID, "parameters": map[string]string{}, } @@ -1357,14 +1357,14 @@ parameters: It("executes task by name", func() { // Create a task with a specific name - taskBody := map[string]interface{}{ + taskBody := map[string]any{ "name": "Named Task", "model": "testmodel.ggml", "prompt": "Hello", "enabled": true, } - var createResp map[string]interface{} + var createResp map[string]any err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) Expect(err).ToNot(HaveOccurred()) diff --git a/core/http/auth/middleware.go b/core/http/auth/middleware.go index 4799d51fa..01ec33a68 100644 --- a/core/http/auth/middleware.go +++ b/core/http/auth/middleware.go @@ -516,6 +516,17 @@ func isExemptPath(path string, appConfig *config.ApplicationConfig) bool { return true } + // Node self-service endpoints — authenticated via registration token, not global auth. + // Only exempt the specific known endpoints, not the entire prefix. + if strings.HasPrefix(path, "/api/node/") { + if path == "/api/node/register" || + strings.HasSuffix(path, "/heartbeat") || + strings.HasSuffix(path, "/drain") || + strings.HasSuffix(path, "/deregister") { + return true + } + } + // Check configured exempt paths for _, p := range appConfig.PathWithoutAuth { if strings.HasPrefix(path, p) { @@ -540,6 +551,14 @@ func isAPIPath(path string) bool { strings.HasPrefix(path, "/system") || strings.HasPrefix(path, "/ws/") || strings.HasPrefix(path, "/generated-") || + strings.HasPrefix(path, "/chat/") || + strings.HasPrefix(path, "/completions") || + strings.HasPrefix(path, "/edits") || + strings.HasPrefix(path, "/embeddings") || + strings.HasPrefix(path, "/audio/") || + strings.HasPrefix(path, "/images/") || + strings.HasPrefix(path, "/messages") || + strings.HasPrefix(path, "/responses") || path == "/metrics" } diff --git a/core/http/auth/models.go b/core/http/auth/models.go index 598c0342c..854d02e6c 100644 --- a/core/http/auth/models.go +++ b/core/http/auth/models.go @@ -9,24 +9,25 @@ import ( // Auth provider constants. const ( - ProviderLocal = "local" - ProviderGitHub = "github" - ProviderOIDC = "oidc" + ProviderLocal = "local" + ProviderGitHub = "github" + ProviderOIDC = "oidc" + ProviderAgentWorker = "agent-worker" ) // User represents an authenticated user. type User struct { - ID string `gorm:"primaryKey;size:36"` - Email string `gorm:"size:255;index"` - Name string `gorm:"size:255"` - AvatarURL string `gorm:"size:512"` - Provider string `gorm:"size:50"` // ProviderLocal, ProviderGitHub, ProviderOIDC - Subject string `gorm:"size:255"` // provider-specific user ID - PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users + ID string `gorm:"primaryKey;size:36"` + Email string `gorm:"size:255;index"` + Name string `gorm:"size:255"` + AvatarURL string `gorm:"size:512"` + Provider string `gorm:"size:50"` // ProviderLocal, ProviderGitHub, ProviderOIDC + Subject string `gorm:"size:255"` // provider-specific user ID + PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users Role string `gorm:"size:20;default:user"` Status string `gorm:"size:20;default:active"` // "active", "pending" - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt time.Time + UpdatedAt time.Time } // Session represents a user login session. @@ -90,16 +91,16 @@ func (p *PermissionMap) Scan(value any) error { // InviteCode represents an admin-generated invitation for user registration. type InviteCode struct { - ID string `gorm:"primaryKey;size:36"` - Code string `gorm:"uniqueIndex;not null;size:64"` // HMAC-SHA256 hash of invite code - CodePrefix string `gorm:"size:12"` // first 8 chars for admin display - CreatedBy string `gorm:"size:36;not null"` - UsedBy *string `gorm:"size:36"` - UsedAt *time.Time - ExpiresAt time.Time `gorm:"not null;index"` - CreatedAt time.Time - Creator User `gorm:"foreignKey:CreatedBy"` - Consumer *User `gorm:"foreignKey:UsedBy"` + ID string `gorm:"primaryKey;size:36"` + Code string `gorm:"uniqueIndex;not null;size:64"` // HMAC-SHA256 hash of invite code + CodePrefix string `gorm:"size:12"` // first 8 chars for admin display + CreatedBy string `gorm:"size:36;not null"` + UsedBy *string `gorm:"size:36"` + UsedAt *time.Time + ExpiresAt time.Time `gorm:"not null;index"` + CreatedAt time.Time + Creator User `gorm:"foreignKey:CreatedBy"` + Consumer *User `gorm:"foreignKey:UsedBy"` } // ModelAllowlist controls which models a user can access. diff --git a/core/http/auth/permissions.go b/core/http/auth/permissions.go index 63bce7d21..4361904dc 100644 --- a/core/http/auth/permissions.go +++ b/core/http/auth/permissions.go @@ -33,24 +33,24 @@ const ( FeatureMCPJobs = "mcp_jobs" // General features (default OFF for new users) - FeatureFineTuning = "fine_tuning" - FeatureQuantization = "quantization" + FeatureFineTuning = "fine_tuning" + FeatureQuantization = "quantization" // API features (default ON for new users) - FeatureChat = "chat" - FeatureImages = "images" - FeatureAudioSpeech = "audio_speech" + FeatureChat = "chat" + FeatureImages = "images" + FeatureAudioSpeech = "audio_speech" FeatureAudioTranscription = "audio_transcription" - FeatureVAD = "vad" - FeatureDetection = "detection" - FeatureVideo = "video" - FeatureEmbeddings = "embeddings" - FeatureSound = "sound" - FeatureRealtime = "realtime" - FeatureRerank = "rerank" - FeatureTokenize = "tokenize" - FeatureMCP = "mcp" - FeatureStores = "stores" + FeatureVAD = "vad" + FeatureDetection = "detection" + FeatureVideo = "video" + FeatureEmbeddings = "embeddings" + FeatureSound = "sound" + FeatureRealtime = "realtime" + FeatureRerank = "rerank" + FeatureTokenize = "tokenize" + FeatureMCP = "mcp" + FeatureStores = "stores" ) // AgentFeatures lists agent-related features (default OFF). diff --git a/core/http/auth/quota.go b/core/http/auth/quota.go index a79e1861b..8e26e787b 100644 --- a/core/http/auth/quota.go +++ b/core/http/auth/quota.go @@ -24,14 +24,14 @@ type QuotaRule struct { // QuotaStatus is returned to clients with current usage included. type QuotaStatus struct { - ID string `json:"id"` - Model string `json:"model"` - MaxRequests *int64 `json:"max_requests"` - MaxTotalTokens *int64 `json:"max_total_tokens"` - Window string `json:"window"` - CurrentRequests int64 `json:"current_requests"` - CurrentTokens int64 `json:"current_total_tokens"` - ResetsAt string `json:"resets_at,omitempty"` + ID string `json:"id"` + Model string `json:"model"` + MaxRequests *int64 `json:"max_requests"` + MaxTotalTokens *int64 `json:"max_total_tokens"` + Window string `json:"window"` + CurrentRequests int64 `json:"current_requests"` + CurrentTokens int64 `json:"current_total_tokens"` + ResetsAt string `json:"resets_at,omitempty"` } // ── CRUD ── @@ -209,9 +209,9 @@ func QuotaExceeded(db *gorm.DB, userID, model string) (bool, int64, string) { var quotaCache = newQuotaCacheStore() type quotaCacheStore struct { - mu sync.RWMutex - rules map[string]cachedRules // userID -> rules - usage map[string]cachedUsage // "userID|model|windowStart" -> counts + mu sync.RWMutex + rules map[string]cachedRules // userID -> rules + usage map[string]cachedUsage // "userID|model|windowStart" -> counts } type cachedRules struct { diff --git a/core/http/auth/session.go b/core/http/auth/session.go index 7c8bf68b6..028e8db19 100644 --- a/core/http/auth/session.go +++ b/core/http/auth/session.go @@ -13,7 +13,7 @@ import ( const ( sessionDuration = 30 * 24 * time.Hour // 30 days - sessionIDBytes = 32 // 32 bytes = 64 hex chars + sessionIDBytes = 32 // 32 bytes = 64 hex chars sessionCookie = "session" sessionRotationInterval = 1 * time.Hour ) diff --git a/core/http/auth/usage.go b/core/http/auth/usage.go index 08841a442..31c3202b2 100644 --- a/core/http/auth/usage.go +++ b/core/http/auth/usage.go @@ -10,15 +10,15 @@ import ( // UsageRecord represents a single API request's token usage. type UsageRecord struct { - ID uint `gorm:"primaryKey;autoIncrement"` - UserID string `gorm:"size:36;index:idx_usage_user_time"` - UserName string `gorm:"size:255"` - Model string `gorm:"size:255;index"` - Endpoint string `gorm:"size:255"` + ID uint `gorm:"primaryKey;autoIncrement"` + UserID string `gorm:"size:36;index:idx_usage_user_time"` + UserName string `gorm:"size:255"` + Model string `gorm:"size:255;index"` + Endpoint string `gorm:"size:255"` PromptTokens int64 CompletionTokens int64 TotalTokens int64 - Duration int64 // milliseconds + Duration int64 // milliseconds CreatedAt time.Time `gorm:"index:idx_usage_user_time"` } @@ -127,10 +127,10 @@ func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) { bucketExpr := fmt.Sprintf("%s as bucket", dateFmt) query := db.Model(&UsageRecord{}). - Select(bucketExpr+", model, user_id, user_name, "+ - "SUM(prompt_tokens) as prompt_tokens, "+ - "SUM(completion_tokens) as completion_tokens, "+ - "SUM(total_tokens) as total_tokens, "+ + Select(bucketExpr + ", model, user_id, user_name, " + + "SUM(prompt_tokens) as prompt_tokens, " + + "SUM(completion_tokens) as completion_tokens, " + + "SUM(total_tokens) as total_tokens, " + "COUNT(*) as request_count"). Group("bucket, model, user_id, user_name"). Order("bucket ASC") diff --git a/core/http/auth/usage_test.go b/core/http/auth/usage_test.go index 0c3fa5df5..8782ac095 100644 --- a/core/http/auth/usage_test.go +++ b/core/http/auth/usage_test.go @@ -36,7 +36,7 @@ var _ = Describe("Usage", func() { db := testDB() // Insert records for two users - for i := 0; i < 3; i++ { + for range 3 { err := auth.RecordUsage(db, &auth.UsageRecord{ UserID: "user-a", UserName: "Alice", diff --git a/core/http/endpoints/anthropic/messages.go b/core/http/endpoints/anthropic/messages.go index adb4b989f..7736e70b2 100644 --- a/core/http/endpoints/anthropic/messages.go +++ b/core/http/endpoints/anthropic/messages.go @@ -3,7 +3,6 @@ package anthropic import ( "encoding/json" "fmt" - "strings" "github.com/google/uuid" "github.com/labstack/echo/v4" @@ -25,7 +24,7 @@ import ( // @Param request body schema.AnthropicRequest true "query params" // @Success 200 {object} schema.AnthropicResponse "Response" // @Router /v1/messages [post] -func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { +func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient) echo.HandlerFunc { return func(c echo.Context) error { id := uuid.New().String() @@ -52,7 +51,7 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu funcs, shouldUseFn := convertAnthropicTools(input, cfg) // MCP injection: prompts, resources, and tools - var mcpToolInfos []mcpTools.MCPToolInfo + var mcpExecutor mcpTools.ToolExecutor mcpServers := mcpTools.MCPServersFromMetadata(input.Metadata) mcpPromptName, mcpPromptArgs := mcpTools.MCPPromptFromMetadata(input.Metadata) mcpResourceURIs := mcpTools.MCPResourcesFromMetadata(input.Metadata) @@ -60,76 +59,29 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu if (len(mcpServers) > 0 || mcpPromptName != "" || len(mcpResourceURIs) > 0) && (cfg.MCP.Servers != "" || cfg.MCP.Stdio != "") { remote, stdio, mcpErr := cfg.MCP.MCPConfigFromYAML() if mcpErr == nil { + mcpExecutor = mcpTools.NewToolExecutor(c.Request().Context(), natsClient, cfg.Name, remote, stdio, mcpServers) + + // Prompt and resource injection (pre-processing step — resolves locally regardless of distributed mode) namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, mcpServers) if sessErr == nil && len(namedSessions) > 0 { - // Prompt injection - if mcpPromptName != "" { - prompts, discErr := mcpTools.DiscoverMCPPrompts(c.Request().Context(), namedSessions) - if discErr == nil { - promptMsgs, getErr := mcpTools.GetMCPPrompt(c.Request().Context(), prompts, mcpPromptName, mcpPromptArgs) - if getErr == nil { - var injected []schema.Message - for _, pm := range promptMsgs { - injected = append(injected, schema.Message{ - Role: string(pm.Role), - Content: mcpTools.PromptMessageToText(pm), - }) - } - openAIMessages = append(injected, openAIMessages...) - xlog.Debug("Anthropic MCP prompt injected", "prompt", mcpPromptName, "messages", len(injected)) - } else { - xlog.Error("Failed to get MCP prompt", "error", getErr) - } - } + mcpCtx, _ := mcpTools.InjectMCPContext(c.Request().Context(), namedSessions, mcpPromptName, mcpPromptArgs, mcpResourceURIs) + if mcpCtx != nil { + openAIMessages = append(mcpCtx.PromptMessages, openAIMessages...) + mcpTools.AppendResourceSuffix(openAIMessages, mcpCtx.ResourceSuffix) } + } - // Resource injection - if len(mcpResourceURIs) > 0 { - resources, discErr := mcpTools.DiscoverMCPResources(c.Request().Context(), namedSessions) - if discErr == nil { - var resourceTexts []string - for _, uri := range mcpResourceURIs { - content, readErr := mcpTools.ReadMCPResource(c.Request().Context(), resources, uri) - if readErr != nil { - xlog.Error("Failed to read MCP resource", "error", readErr, "uri", uri) - continue - } - name := uri - for _, r := range resources { - if r.URI == uri { - name = r.Name - break - } - } - resourceTexts = append(resourceTexts, fmt.Sprintf("--- MCP Resource: %s ---\n%s", name, content)) - } - if len(resourceTexts) > 0 && len(openAIMessages) > 0 { - lastIdx := len(openAIMessages) - 1 - suffix := "\n\n" + strings.Join(resourceTexts, "\n\n") - switch ct := openAIMessages[lastIdx].Content.(type) { - case string: - openAIMessages[lastIdx].Content = ct + suffix - default: - openAIMessages[lastIdx].Content = fmt.Sprintf("%v%s", ct, suffix) - } - xlog.Debug("Anthropic MCP resources injected", "count", len(resourceTexts)) - } - } - } - - // Tool injection - if len(mcpServers) > 0 { - discovered, discErr := mcpTools.DiscoverMCPTools(c.Request().Context(), namedSessions) - if discErr == nil { - mcpToolInfos = discovered - for _, ti := range mcpToolInfos { - funcs = append(funcs, ti.Function) - } - shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions() - xlog.Debug("Anthropic MCP tools injected", "count", len(mcpToolInfos), "total_funcs", len(funcs)) - } else { - xlog.Error("Failed to discover MCP tools", "error", discErr) + // Tool injection via executor + if mcpExecutor.HasTools() { + mcpFuncs, discErr := mcpExecutor.DiscoverTools(c.Request().Context()) + if discErr == nil { + for _, fn := range mcpFuncs { + funcs = append(funcs, fn) } + shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions() + xlog.Debug("Anthropic MCP tools injected", "count", len(mcpFuncs), "total_funcs", len(funcs)) + } else { + xlog.Error("Failed to discover MCP tools", "error", discErr) } } } else { @@ -177,19 +129,19 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput) if input.Stream { - return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator) + return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator) } - return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator) + return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator) } } -func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error { +func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error { mcpMaxIterations := 10 if cfg.Agent.MaxIterations > 0 { mcpMaxIterations = cfg.Agent.MaxIterations } - hasMCPTools := len(mcpToolInfos) > 0 + hasMCPTools := mcpExecutor != nil && mcpExecutor.HasTools() for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ { // Re-template on each MCP iteration since messages may have changed @@ -227,7 +179,7 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic if hasMCPTools && shouldUseFn && len(toolCalls) > 0 { var hasMCPCalls bool for _, tc := range toolCalls { - if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) { + if mcpExecutor != nil && mcpExecutor.IsTool(tc.Name) { hasMCPCalls = true break } @@ -257,13 +209,12 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic // Execute each MCP tool call and append results for _, tc := range assistantMsg.ToolCalls { - if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { + if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) { continue } xlog.Debug("Executing MCP tool (Anthropic)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration) - toolResult, toolErr := mcpTools.ExecuteMCPToolCall( - c.Request().Context(), mcpToolInfos, - tc.FunctionCall.Name, tc.FunctionCall.Arguments, + toolResult, toolErr := mcpExecutor.ExecuteTool( + c.Request().Context(), tc.FunctionCall.Name, tc.FunctionCall.Arguments, ) if toolErr != nil { xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) @@ -290,10 +241,10 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic if shouldUseFn && len(toolCalls) > 0 { stopReason = "tool_use" for _, tc := range toolCalls { - var inputArgs map[string]interface{} + var inputArgs map[string]any if err := json.Unmarshal([]byte(tc.Arguments), &inputArgs); err != nil { xlog.Warn("Failed to parse tool call arguments as JSON", "error", err, "args", tc.Arguments) - inputArgs = map[string]interface{}{"raw": tc.Arguments} + inputArgs = map[string]any{"raw": tc.Arguments} } contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{ Type: "tool_use", @@ -316,9 +267,9 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{Type: "text", Text: stripped}) } for i, fc := range parsed { - var inputArgs map[string]interface{} + var inputArgs map[string]any if err := json.Unmarshal([]byte(fc.Arguments), &inputArgs); err != nil { - inputArgs = map[string]interface{}{"raw": fc.Arguments} + inputArgs = map[string]any{"raw": fc.Arguments} } toolCallID := fc.ID if toolCallID == "" { @@ -365,7 +316,7 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic return sendAnthropicError(c, 500, "api_error", "MCP iteration limit reached") } -func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error { +func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error { c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Cache-Control", "no-cache") c.Response().Header().Set("Connection", "keep-alive") @@ -388,7 +339,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq if cfg.Agent.MaxIterations > 0 { mcpMaxIterations = cfg.Agent.MaxIterations } - hasMCPTools := len(mcpToolInfos) > 0 + hasMCPTools := mcpExecutor != nil && mcpExecutor.HasTools() for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ { // Re-template on MCP iterations @@ -483,7 +434,14 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq _, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback) if err != nil { xlog.Error("Anthropic stream model inference failed", "error", err) - return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err)) + sendAnthropicSSE(c, schema.AnthropicStreamEvent{ + Type: "error", + Error: &schema.AnthropicError{ + Type: "api_error", + Message: fmt.Sprintf("model inference failed: %v", err), + }, + }) + return nil } // Also check chat deltas for tool calls @@ -495,7 +453,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq if hasMCPTools && len(collectedToolCalls) > 0 { var hasMCPCalls bool for _, tc := range collectedToolCalls { - if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) { + if mcpExecutor != nil && mcpExecutor.IsTool(tc.Name) { hasMCPCalls = true break } @@ -525,13 +483,12 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq // Execute MCP tool calls for _, tc := range assistantMsg.ToolCalls { - if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { + if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) { continue } xlog.Debug("Executing MCP tool (Anthropic stream)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration) - toolResult, toolErr := mcpTools.ExecuteMCPToolCall( - c.Request().Context(), mcpToolInfos, - tc.FunctionCall.Name, tc.FunctionCall.Arguments, + toolResult, toolErr := mcpExecutor.ExecuteTool( + c.Request().Context(), tc.FunctionCall.Name, tc.FunctionCall.Arguments, ) if toolErr != nil { xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) @@ -686,7 +643,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M case string: openAIMsg.StringContent = content openAIMsg.Content = content - case []interface{}: + case []any: // Handle array of content blocks var textContent string var stringImages []string @@ -694,7 +651,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M toolCallIndex := 0 for _, block := range content { - if blockMap, ok := block.(map[string]interface{}); ok { + if blockMap, ok := block.(map[string]any); ok { blockType, _ := blockMap["type"].(string) switch blockType { case "text": @@ -703,7 +660,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M } case "image": // Handle image content - if source, ok := blockMap["source"].(map[string]interface{}); ok { + if source, ok := blockMap["source"].(map[string]any); ok { if sourceType, ok := source["type"].(string); ok && sourceType == "base64" { if data, ok := source["data"].(string); ok { mediaType, _ := source["media_type"].(string) @@ -718,14 +675,14 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M toolID, _ := blockMap["id"].(string) toolName, _ := blockMap["name"].(string) toolInput := blockMap["input"] - + // Serialize input to JSON string inputJSON, err := json.Marshal(toolInput) if err != nil { xlog.Warn("Failed to marshal tool input", "error", err) inputJSON = []byte("{}") } - + toolCalls = append(toolCalls, schema.ToolCall{ Index: toolCallIndex, ID: toolID, @@ -745,16 +702,16 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M if isErrorPtr, ok := blockMap["is_error"].(*bool); ok && isErrorPtr != nil { isError = *isErrorPtr } - + var resultText string if resultContent, ok := blockMap["content"]; ok { switch rc := resultContent.(type) { case string: resultText = rc - case []interface{}: + case []any: // Array of content blocks for _, cb := range rc { - if cbMap, ok := cb.(map[string]interface{}); ok { + if cbMap, ok := cb.(map[string]any); ok { if cbMap["type"] == "text" { if text, ok := cbMap["text"].(string); ok { resultText += text @@ -764,7 +721,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M } } } - + // Add tool result as a tool role message // We need to handle this differently - create a new message if msg.Role == "user" { @@ -781,7 +738,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M openAIMsg.StringContent = textContent openAIMsg.Content = textContent openAIMsg.StringImages = stringImages - + // Add tool calls if present if len(toolCalls) > 0 { openAIMsg.ToolCalls = toolCalls @@ -799,7 +756,7 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf if len(input.Tools) == 0 { return nil, false } - + var funcs functions.Functions for _, tool := range input.Tools { f := functions.Function{ @@ -809,7 +766,7 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf } funcs = append(funcs, f) } - + // Handle tool_choice if input.ToolChoice != nil { switch tc := input.ToolChoice.(type) { @@ -823,7 +780,7 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf return nil, false } // "auto" is the default - let model decide - case map[string]interface{}: + case map[string]any: // Specific tool selection: {"type": "tool", "name": "tool_name"} if tcType, ok := tc["type"].(string); ok && tcType == "tool" { if name, ok := tc["name"].(string); ok { @@ -833,6 +790,6 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf } } } - + return funcs, len(funcs) > 0 && cfg.ShouldUseFunctions() } diff --git a/core/http/endpoints/explorer/dashboard.go b/core/http/endpoints/explorer/dashboard.go index 3c1e0ae91..7c759ec5a 100644 --- a/core/http/endpoints/explorer/dashboard.go +++ b/core/http/endpoints/explorer/dashboard.go @@ -1,9 +1,10 @@ package explorer import ( + "cmp" "encoding/base64" "net/http" - "sort" + "slices" "strings" "github.com/labstack/echo/v4" @@ -14,7 +15,7 @@ import ( func Dashboard() echo.HandlerFunc { return func(c echo.Context) error { - summary := map[string]interface{}{ + summary := map[string]any{ "Title": "LocalAI API - " + internal.PrintableVersion(), "Version": internal.PrintableVersion(), "BaseURL": middleware.BaseURL(c), @@ -61,8 +62,8 @@ func ShowNetworks(db *explorer.Database) echo.HandlerFunc { } // order by number of clusters - sort.Slice(results, func(i, j int) bool { - return len(results[i].Clusters) > len(results[j].Clusters) + slices.SortFunc(results, func(a, b Network) int { + return cmp.Compare(len(b.Clusters), len(a.Clusters)) }) return c.JSON(http.StatusOK, results) @@ -73,36 +74,36 @@ func AddNetwork(db *explorer.Database) echo.HandlerFunc { return func(c echo.Context) error { request := new(AddNetworkRequest) if err := c.Bind(request); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"}) + return c.JSON(http.StatusBadRequest, map[string]any{"error": "Cannot parse JSON"}) } if request.Token == "" { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"}) + return c.JSON(http.StatusBadRequest, map[string]any{"error": "Token is required"}) } if request.Name == "" { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"}) + return c.JSON(http.StatusBadRequest, map[string]any{"error": "Name is required"}) } if request.Description == "" { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"}) + return c.JSON(http.StatusBadRequest, map[string]any{"error": "Description is required"}) } // TODO: check if token is valid, otherwise reject // try to decode the token from base64 _, err := base64.StdEncoding.DecodeString(request.Token) if err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"}) + return c.JSON(http.StatusBadRequest, map[string]any{"error": "Invalid token"}) } if _, exists := db.Get(request.Token); exists { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"}) + return c.JSON(http.StatusBadRequest, map[string]any{"error": "Token already exists"}) } err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description}) if err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"}) + return c.JSON(http.StatusInternalServerError, map[string]any{"error": "Cannot add token"}) } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"}) + return c.JSON(http.StatusOK, map[string]any{"message": "Token added"}) } } diff --git a/core/http/endpoints/localai/agent_jobs.go b/core/http/endpoints/localai/agent_jobs.go index 8ed20d7df..b55fab65c 100644 --- a/core/http/endpoints/localai/agent_jobs.go +++ b/core/http/endpoints/localai/agent_jobs.go @@ -1,6 +1,7 @@ package localai import ( + "errors" "fmt" "net/http" "strconv" @@ -8,12 +9,12 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/agentpool" ) // getJobService returns the job service for the current user. // Falls back to the global service when no user is authenticated. -func getJobService(app *application.Application, c echo.Context) *services.AgentJobService { +func getJobService(app *application.Application, c echo.Context) *agentpool.AgentJobService { userID := getUserID(c) if userID == "" { return app.AgentJobService() @@ -54,7 +55,7 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc { } if err := getJobService(app, c).UpdateTask(id, task); err != nil { - if err.Error() == "task not found: "+id { + if errors.Is(err, agentpool.ErrTaskNotFound) { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) @@ -68,7 +69,7 @@ func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") if err := getJobService(app, c).DeleteTask(id); err != nil { - if err.Error() == "task not found: "+id { + if errors.Is(err, agentpool.ErrTaskNotFound) { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) @@ -244,7 +245,7 @@ func CancelJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") if err := getJobService(app, c).CancelJob(id); err != nil { - if err.Error() == "job not found: "+id { + if errors.Is(err, agentpool.ErrJobNotFound) { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) @@ -258,7 +259,7 @@ func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") if err := getJobService(app, c).DeleteJob(id); err != nil { - if err.Error() == "job not found: "+id { + if errors.Is(err, agentpool.ErrJobNotFound) { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) @@ -275,7 +276,7 @@ func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc { if c.Request().ContentLength > 0 { if err := c.Bind(¶ms); err != nil { - body := make(map[string]interface{}) + body := make(map[string]any) if err := c.Bind(&body); err == nil { params = make(map[string]string) for k, v := range body { diff --git a/core/http/endpoints/localai/agent_responses.go b/core/http/endpoints/localai/agent_responses.go index 391926223..d118b672b 100644 --- a/core/http/endpoints/localai/agent_responses.go +++ b/core/http/endpoints/localai/agent_responses.go @@ -2,6 +2,7 @@ package localai import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -10,8 +11,9 @@ import ( "github.com/google/uuid" "github.com/labstack/echo/v4" - "github.com/mudler/LocalAI/core/application" coreTypes "github.com/mudler/LocalAGI/core/types" + "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/services/agents" "github.com/mudler/xlog" "github.com/sashabaranov/go-openai" ) @@ -50,55 +52,105 @@ func AgentResponsesInterceptor(app *application.Application) echo.MiddlewareFunc return next(c) } - // Check if this model name is an agent - ag := svc.GetAgent(req.Model) - if ag == nil { - return next(c) - } - - // This is an agent — handle the request directly + // Check if this model name is an agent — try in-process agent first, + // fall back to config lookup (covers distributed mode where agents + // don't run in-process). messages := parseInputToMessages(req.Input) - if len(messages) == 0 { - return c.JSON(http.StatusBadRequest, map[string]any{ - "error": map[string]string{ - "type": "invalid_request_error", - "message": "no input messages provided", - }, - }) + userID := effectiveUserID(c) + ag := svc.GetAgent(req.Model) + if ag == nil && svc.GetAgentConfigForUser(userID, req.Model) == nil { + return next(c) // not an agent } - jobOptions := []coreTypes.JobOption{ - coreTypes.WithConversationHistory(messages), + // Extract the last user message for the executor + var userMessage string + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + userMessage = messages[i].Content + break + } } - res := ag.Ask(jobOptions...) + var responseText string - if res == nil { - return c.JSON(http.StatusInternalServerError, map[string]any{ - "error": map[string]string{ - "type": "server_error", - "message": "agent request failed or was cancelled", - }, - }) - } - if res.Error != nil { - xlog.Error("Error asking agent via responses API", "agent", req.Model, "error", res.Error) - return c.JSON(http.StatusInternalServerError, map[string]any{ - "error": map[string]string{ - "type": "server_error", - "message": res.Error.Error(), - }, + if ag != nil { + // Local mode: use LocalAGI agent directly + jobOptions := []coreTypes.JobOption{ + coreTypes.WithConversationHistory(messages), + } + + res := ag.Ask(jobOptions...) + if res == nil { + return c.JSON(http.StatusInternalServerError, map[string]any{ + "error": map[string]string{ + "type": "server_error", + "message": "agent request failed or was cancelled", + }, + }) + } + if res.Error != nil { + xlog.Error("Error asking agent via responses API", "agent", req.Model, "error", res.Error) + return c.JSON(http.StatusInternalServerError, map[string]any{ + "error": map[string]string{ + "type": "server_error", + "message": res.Error.Error(), + }, + }) + } + responseText = res.Response + } else { + // Distributed mode: dispatch via NATS + wait for response synchronously + var bridge *agents.EventBridge + if d := app.Distributed(); d != nil { + bridge = d.AgentBridge + } + if bridge == nil { + return next(c) + } + + // Subscribe BEFORE dispatching so we never miss a fast response + ctx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Minute) + defer cancel() + + responseCh := make(chan string, 1) + sub, err := bridge.SubscribeEvents(req.Model, userID, func(evt agents.AgentEvent) { + if evt.EventType == "json_message" && evt.Sender == "agent" { + responseCh <- evt.Content + } }) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]any{ + "error": map[string]string{"type": "server_error", "message": "failed to subscribe to agent events"}, + }) + } + defer sub.Unsubscribe() + + // Now dispatch via ChatForUser (publishes to NATS) + _, err = svc.ChatForUser(userID, req.Model, userMessage) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]any{ + "error": map[string]string{"type": "server_error", "message": err.Error()}, + }) + } + + select { + case responseText = <-responseCh: + // Got the response + case <-ctx.Done(): + return c.JSON(http.StatusGatewayTimeout, map[string]any{ + "error": map[string]string{"type": "server_error", "message": "agent response timeout"}, + }) + } } id := fmt.Sprintf("resp_%s", uuid.New().String()) return c.JSON(http.StatusOK, map[string]any{ - "id": id, - "object": "response", - "created_at": time.Now().Unix(), - "status": "completed", - "model": req.Model, + "id": id, + "object": "response", + "created_at": time.Now().Unix(), + "status": "completed", + "model": req.Model, "previous_response_id": nil, "output": []any{ map[string]any{ @@ -109,7 +161,7 @@ func AgentResponsesInterceptor(app *application.Application) echo.MiddlewareFunc "content": []map[string]any{ { "type": "output_text", - "text": res.Response, + "text": responseText, "annotations": []any{}, }, }, diff --git a/core/http/endpoints/localai/agent_skills.go b/core/http/endpoints/localai/agent_skills.go index 2256db2bb..6e8538f2e 100644 --- a/core/http/endpoints/localai/agent_skills.go +++ b/core/http/endpoints/localai/agent_skills.go @@ -7,6 +7,7 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" + skillsManager "github.com/mudler/LocalAI/core/services/skills" skilldomain "github.com/mudler/skillserver/pkg/domain" ) @@ -41,27 +42,48 @@ func skillsToResponses(skills []skilldomain.Skill) []skillResponse { return out } +// getSkillManager returns a SkillManager for the request's user. +func getSkillManager(c echo.Context, app *application.Application) (skillsManager.Manager, error) { + svc := app.AgentPoolService() + userID := getUserID(c) + return svc.SkillManagerForUser(userID) +} + +func getSkillManagerEffective(c echo.Context, app *application.Application) (skillsManager.Manager, error) { + svc := app.AgentPoolService() + userID := effectiveUserID(c) + return svc.SkillManagerForUser(userID) +} + func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) - skills, err := svc.ListSkillsForUser(userID) + mgr, err := getSkillManager(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + skills, err := mgr.List() if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } // Admin cross-user aggregation if wantsAllUsers(c) { + svc := app.AgentPoolService() usm := svc.UserServicesManager() if usm != nil { userIDs, _ := usm.ListAllUserIDs() userGroups := map[string]any{} + userID := getUserID(c) for _, uid := range userIDs { if uid == userID { continue } - userSkills, err := svc.ListSkillsForUser(uid) - if err != nil || len(userSkills) == 0 { + uidMgr, mgrErr := svc.SkillManagerForUser(uid) + if mgrErr != nil { + continue + } + userSkills, listErr := uidMgr.List() + if listErr != nil || len(userSkills) == 0 { continue } userGroups[uid] = map[string]any{"skills": skillsToResponses(userSkills)} @@ -76,25 +98,28 @@ func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc { } } - return c.JSON(http.StatusOK, skillsToResponses(skills)) + return c.JSON(http.StatusOK, map[string]any{"skills": skillsToResponses(skills)}) } } func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) - cfg := svc.GetSkillsConfigForUser(userID) - return c.JSON(http.StatusOK, cfg) + mgr, err := getSkillManager(c, app) + if err != nil { + return c.JSON(http.StatusOK, map[string]string{}) + } + return c.JSON(http.StatusOK, mgr.GetConfig()) } } func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) + mgr, err := getSkillManager(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } query := c.QueryParam("q") - skills, err := svc.SearchSkillsForUser(userID, query) + skills, err := mgr.Search(query) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } @@ -104,8 +129,10 @@ func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc { func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) + mgr, err := getSkillManager(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } var payload struct { Name string `json:"name"` Description string `json:"description"` @@ -118,7 +145,7 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc { if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - skill, err := svc.CreateSkillForUser(userID, payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) + skill, err := mgr.Create(payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) if err != nil { if strings.Contains(err.Error(), "already exists") { return c.JSON(http.StatusConflict, map[string]string{"error": err.Error()}) @@ -131,9 +158,11 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc { func GetSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := effectiveUserID(c) - skill, err := svc.GetSkillForUser(userID, c.Param("name")) + mgr, err := getSkillManagerEffective(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + skill, err := mgr.Get(c.Param("name")) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -143,8 +172,10 @@ func GetSkillEndpoint(app *application.Application) echo.HandlerFunc { func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := effectiveUserID(c) + mgr, err := getSkillManagerEffective(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } var payload struct { Description string `json:"description"` Content string `json:"content"` @@ -156,7 +187,7 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc { if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - skill, err := svc.UpdateSkillForUser(userID, c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) + skill, err := mgr.Update(c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -169,9 +200,11 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc { func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := effectiveUserID(c) - if err := svc.DeleteSkillForUser(userID, c.Param("name")); err != nil { + mgr, err := getSkillManagerEffective(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + if err := mgr.Delete(c.Param("name")); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -180,10 +213,12 @@ func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc { func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := effectiveUserID(c) + mgr, err := getSkillManagerEffective(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } name := c.Param("*") - data, err := svc.ExportSkillForUser(userID, name) + data, err := mgr.Export(name) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -195,8 +230,10 @@ func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc { func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) + mgr, err := getSkillManager(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } file, err := c.FormFile("file") if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"}) @@ -210,7 +247,7 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc { if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - skill, err := svc.ImportSkillForUser(userID, data) + skill, err := mgr.Import(data) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } @@ -222,9 +259,11 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc { func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := effectiveUserID(c) - resources, skill, err := svc.ListSkillResourcesForUser(userID, c.Param("name")) + mgr, err := getSkillManagerEffective(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + resources, skill, err := mgr.ListResources(c.Param("name")) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -260,9 +299,11 @@ func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc { func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := effectiveUserID(c) - content, info, err := svc.GetSkillResourceForUser(userID, c.Param("name"), c.Param("*")) + mgr, err := getSkillManagerEffective(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + content, info, err := mgr.GetResource(c.Param("name"), c.Param("*")) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -281,10 +322,12 @@ func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) - file, err := c.FormFile("file") + mgr, err := getSkillManager(c, app) if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + file, fileErr := c.FormFile("file") + if fileErr != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "file is required"}) } path := c.FormValue("path") @@ -300,7 +343,7 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } - if err := svc.CreateSkillResourceForUser(userID, c.Param("name"), path, data); err != nil { + if err := mgr.CreateResource(c.Param("name"), path, data); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"path": path}) @@ -309,15 +352,17 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) + mgr, err := getSkillManager(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } var payload struct { Content string `json:"content"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - if err := svc.UpdateSkillResourceForUser(userID, c.Param("name"), c.Param("*"), payload.Content); err != nil { + if err := mgr.UpdateResource(c.Param("name"), c.Param("*"), payload.Content); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -326,9 +371,11 @@ func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) - if err := svc.DeleteSkillResourceForUser(userID, c.Param("name"), c.Param("*")); err != nil { + mgr, err := getSkillManager(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + if err := mgr.DeleteResource(c.Param("name"), c.Param("*")); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -339,9 +386,11 @@ func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) - repos, err := svc.ListGitReposForUser(userID) + mgr, err := getSkillManager(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + repos, err := mgr.ListGitRepos() if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } @@ -351,15 +400,17 @@ func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc { func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) + mgr, err := getSkillManager(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } var payload struct { URL string `json:"url"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - repo, err := svc.AddGitRepoForUser(userID, payload.URL) + repo, err := mgr.AddGitRepo(payload.URL) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } @@ -369,8 +420,10 @@ func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) + mgr, err := getSkillManager(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } var payload struct { URL string `json:"url"` Enabled *bool `json:"enabled"` @@ -378,7 +431,7 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc { if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - repo, err := svc.UpdateGitRepoForUser(userID, c.Param("id"), payload.URL, payload.Enabled) + repo, err := mgr.UpdateGitRepo(c.Param("id"), payload.URL, payload.Enabled) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -391,9 +444,11 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) - if err := svc.DeleteGitRepoForUser(userID, c.Param("id")); err != nil { + mgr, err := getSkillManager(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + if err := mgr.DeleteGitRepo(c.Param("id")); err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -405,9 +460,11 @@ func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) - if err := svc.SyncGitRepoForUser(userID, c.Param("id")); err != nil { + mgr, err := getSkillManager(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + if err := mgr.SyncGitRepo(c.Param("id")); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusAccepted, map[string]string{"status": "syncing"}) @@ -416,9 +473,11 @@ func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func ToggleGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - svc := app.AgentPoolService() - userID := getUserID(c) - repo, err := svc.ToggleGitRepoForUser(userID, c.Param("id")) + mgr, err := getSkillManager(c, app) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + repo, err := mgr.ToggleGitRepo(c.Param("id")) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } diff --git a/core/http/endpoints/localai/agents.go b/core/http/endpoints/localai/agents.go index d2bc25c48..2bf2b3263 100644 --- a/core/http/endpoints/localai/agents.go +++ b/core/http/endpoints/localai/agents.go @@ -4,20 +4,23 @@ import ( "encoding/json" "fmt" "io" + "maps" "net/http" "os" "path/filepath" - "sort" + "slices" "strings" "github.com/labstack/echo/v4" - "github.com/mudler/LocalAI/core/application" - "github.com/mudler/LocalAI/core/http/auth" - "github.com/mudler/LocalAI/core/services" - "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAGI/core/state" coreTypes "github.com/mudler/LocalAGI/core/types" agiServices "github.com/mudler/LocalAGI/services" + "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/LocalAI/core/services/agentpool" + "github.com/mudler/LocalAI/core/services/agents" + "github.com/mudler/LocalAI/pkg/utils" + "github.com/mudler/xlog" ) // getUserID extracts the scoped user ID from the request context. @@ -42,25 +45,39 @@ func wantsAllUsers(c echo.Context) bool { } // effectiveUserID returns the user ID to scope operations to. -// SECURITY: Only admins may supply ?user_id= to operate on another user's -// resources. Non-admin callers always get their own ID regardless of query params. +// SECURITY: Only admins and agent-worker service accounts may supply +// ?user_id= to operate on another user's resources. Agent-worker users are +// created exclusively server-side during node registration and need to access +// collections on behalf of the user whose agent they are executing. +// Regular callers always get their own ID regardless of query params. func effectiveUserID(c echo.Context) string { - if targetUID := c.QueryParam("user_id"); targetUID != "" && isAdminUser(c) { + if targetUID := c.QueryParam("user_id"); targetUID != "" && canImpersonateUser(c) { + if callerID := getUserID(c); callerID != targetUID { + xlog.Info("User impersonation", "caller", callerID, "target", targetUID, "path", c.Path()) + } return targetUID } return getUserID(c) } +// canImpersonateUser returns true if the caller is allowed to use ?user_id= to +// scope operations to another user. Allowed for admins and agent-worker service +// accounts (ProviderAgentWorker is set server-side during node registration and +// cannot be self-assigned). +func canImpersonateUser(c echo.Context) bool { + user := auth.GetUser(c) + if user == nil { + return false + } + return user.Role == auth.RoleAdmin || user.Provider == auth.ProviderAgentWorker +} + func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) statuses := svc.ListAgentsForUser(userID) - agents := make([]string, 0, len(statuses)) - for name := range statuses { - agents = append(agents, name) - } - sort.Strings(agents) + agents := slices.Sorted(maps.Keys(statuses)) resp := map[string]any{ "agents": agents, "agentCount": len(agents), @@ -111,13 +128,13 @@ func GetAgentEndpoint(app *application.Application) echo.HandlerFunc { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") - ag := svc.GetAgentForUser(userID, name) - if ag == nil { + + statuses := svc.ListAgentsForUser(userID) + active, exists := statuses[name] + if !exists { return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"}) } - return c.JSON(http.StatusOK, map[string]any{ - "active": !ag.Paused(), - }) + return c.JSON(http.StatusOK, map[string]any{"active": active}) } } @@ -192,9 +209,13 @@ func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") + history := svc.GetAgentStatusForUser(userID, name) if history == nil { - history = &state.Status{ActionResults: []coreTypes.ActionState{}} + return c.JSON(http.StatusOK, map[string]any{ + "Name": name, + "History": []string{}, + }) } entries := []string{} for i := len(history.Results()) - 1; i >= 0; i-- { @@ -221,10 +242,14 @@ func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") + history, err := svc.GetAgentObservablesForUser(userID, name) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } + if history == nil { + history = []json.RawMessage{} + } return c.JSON(http.StatusOK, map[string]any{ "Name": name, "History": history, @@ -278,26 +303,30 @@ func AgentSSEEndpoint(app *application.Application) echo.HandlerFunc { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") - manager := svc.GetSSEManagerForUser(userID, name) - if manager == nil { - return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"}) - } - return services.HandleSSE(c, manager) - } -} -type agentConfigMetaResponse struct { - state.AgentConfigMeta - OutputsDir string `json:"OutputsDir"` + // Try local SSE manager first + manager := svc.GetSSEManagerForUser(userID, name) + if manager != nil { + return agentpool.HandleSSE(c, manager) + } + + // Fall back to distributed EventBridge SSE + var bridge *agents.EventBridge + if d := app.Distributed(); d != nil { + bridge = d.AgentBridge + } + if bridge != nil { + return bridge.HandleSSE(c, name, userID) + } + + return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"}) + } } func GetAgentConfigMetaEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - return c.JSON(http.StatusOK, agentConfigMetaResponse{ - AgentConfigMeta: svc.GetConfigMeta(), - OutputsDir: svc.OutputsDir(), - }) + return c.JSON(http.StatusOK, svc.GetConfigMetaResult()) } } diff --git a/core/http/endpoints/localai/backend.go b/core/http/endpoints/localai/backend.go index f804f1b35..2a10258d2 100644 --- a/core/http/endpoints/localai/backend.go +++ b/core/http/endpoints/localai/backend.go @@ -10,7 +10,7 @@ import ( "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" ) @@ -19,14 +19,14 @@ type BackendEndpointService struct { galleries []config.Gallery backendPath string backendSystemPath string - backendApplier *services.GalleryService + backendApplier *galleryop.GalleryService } type GalleryBackend struct { ID string `json:"id"` } -func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *services.GalleryService) BackendEndpointService { +func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *galleryop.GalleryService) BackendEndpointService { return BackendEndpointService{ galleries: galleries, backendPath: systemState.Backend.BackendsPath, @@ -37,7 +37,7 @@ func CreateBackendEndpointService(galleries []config.Gallery, systemState *syste // GetOpStatusEndpoint returns the job status // @Summary Returns the job status -// @Success 200 {object} services.GalleryOpStatus "Response" +// @Success 200 {object} galleryop.OpStatus "Response" // @Router /backends/jobs/{uuid} [get] func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc { return func(c echo.Context) error { @@ -51,7 +51,7 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc { // GetAllStatusEndpoint returns all the jobs status progress // @Summary Returns all the jobs status progress -// @Success 200 {object} map[string]services.GalleryOpStatus "Response" +// @Success 200 {object} map[string]galleryop.OpStatus "Response" // @Router /backends/jobs [get] func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc { return func(c echo.Context) error { @@ -76,7 +76,7 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc { if err != nil { return err } - mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{ + mgs.backendApplier.BackendGalleryChannel <- galleryop.ManagementOp[gallery.GalleryBackend, any]{ ID: uuid.String(), GalleryElementName: input.ID, Galleries: mgs.galleries, @@ -95,7 +95,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc { return func(c echo.Context) error { backendName := c.Param("name") - mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{ + mgs.backendApplier.BackendGalleryChannel <- galleryop.ManagementOp[gallery.GalleryBackend, any]{ Delete: true, GalleryElementName: backendName, Galleries: mgs.galleries, @@ -114,9 +114,9 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc { // @Summary List all Backends // @Success 200 {object} []gallery.GalleryBackend "Response" // @Router /backends [get] -func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc { +func (mgs *BackendEndpointService) ListBackendsEndpoint() echo.HandlerFunc { return func(c echo.Context) error { - backends, err := gallery.ListSystemBackends(systemState) + backends, err := mgs.backendApplier.ListBackends() if err != nil { return err } diff --git a/core/http/endpoints/localai/backend_monitor.go b/core/http/endpoints/localai/backend_monitor.go index 18016c579..43cff0b14 100644 --- a/core/http/endpoints/localai/backend_monitor.go +++ b/core/http/endpoints/localai/backend_monitor.go @@ -3,7 +3,7 @@ package localai import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/monitoring" ) // BackendMonitorEndpoint returns the status of the specified backend @@ -11,7 +11,7 @@ import ( // @Param request body schema.BackendMonitorRequest true "Backend statistics request" // @Success 200 {object} proto.StatusResponse "Response" // @Router /backend/monitor [get] -func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc { +func BackendMonitorEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFunc { return func(c echo.Context) error { input := new(schema.BackendMonitorRequest) @@ -32,7 +32,7 @@ func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc // @Summary Backend monitor endpoint // @Param request body schema.BackendMonitorRequest true "Backend statistics request" // @Router /backend/shutdown [post] -func BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc { +func BackendShutdownEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFunc { return func(c echo.Context) error { input := new(schema.BackendMonitorRequest) // Get input data from the request body diff --git a/core/http/endpoints/localai/cors_proxy.go b/core/http/endpoints/localai/cors_proxy.go index d776aa3b3..3b32d410b 100644 --- a/core/http/endpoints/localai/cors_proxy.go +++ b/core/http/endpoints/localai/cors_proxy.go @@ -1,8 +1,10 @@ package localai import ( + "context" "fmt" "io" + "net" "net/http" "net/url" "strings" @@ -13,8 +15,31 @@ import ( "github.com/mudler/xlog" ) -var corsProxyClient = &http.Client{ - Timeout: 10 * time.Minute, +var privateNetworks []*net.IPNet + +func init() { + for _, cidr := range []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "127.0.0.0/8", + "169.254.0.0/16", + "::1/128", + "fc00::/7", + "fe80::/10", + } { + _, network, _ := net.ParseCIDR(cidr) + privateNetworks = append(privateNetworks, network) + } +} + +func isPrivateIP(ip net.IP) bool { + for _, network := range privateNetworks { + if network.Contains(ip) { + return true + } + } + return false } // CORSProxyEndpoint proxies HTTP requests to external MCP servers, @@ -36,6 +61,35 @@ func CORSProxyEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc { return c.JSON(http.StatusBadRequest, map[string]string{"error": "only http and https schemes are supported"}) } + ips, err := net.LookupIP(parsed.Hostname()) + if err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "cannot resolve hostname"}) + } + for _, ip := range ips { + if isPrivateIP(ip) { + return c.JSON(http.StatusForbidden, map[string]string{"error": "requests to private networks are not allowed"}) + } + } + + // Pin the connection to the validated IP to prevent DNS rebinding (TOCTOU) + validIP := ips[0] + port := parsed.Port() + if port == "" { + if parsed.Scheme == "https" { + port = "443" + } else { + port = "80" + } + } + transport := &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return (&net.Dialer{Timeout: 10 * time.Second}).DialContext( + ctx, network, net.JoinHostPort(validIP.String(), port), + ) + }, + } + client := &http.Client{Transport: transport, Timeout: 10 * time.Minute} + xlog.Debug("CORS proxy request", "method", c.Request().Method, "target", targetURL) proxyReq, err := http.NewRequestWithContext( @@ -52,7 +106,9 @@ func CORSProxyEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc { skipHeaders := map[string]bool{ "Host": true, "Connection": true, "Keep-Alive": true, "Transfer-Encoding": true, "Upgrade": true, "Origin": true, - "Referer": true, + "Referer": true, + "Authorization": true, "Cookie": true, + "X-Api-Key": true, "Proxy-Authorization": true, } for key, values := range c.Request().Header { if skipHeaders[key] { @@ -63,7 +119,7 @@ func CORSProxyEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc { } } - resp, err := corsProxyClient.Do(proxyReq) + resp, err := client.Do(proxyReq) if err != nil { xlog.Error("CORS proxy request failed", "error", err, "target", targetURL) return c.JSON(http.StatusBadGateway, map[string]string{"error": "proxy request failed: " + err.Error()}) @@ -90,8 +146,9 @@ func CORSProxyEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc { c.Response().WriteHeader(resp.StatusCode) - // Stream the response body - _, err = io.Copy(c.Response().Writer, resp.Body) + // Stream the response body with a size limit + const maxProxyResponseSize = 100 << 20 // 100 MB + _, err = io.Copy(c.Response().Writer, io.LimitReader(resp.Body, maxProxyResponseSize)) return err } } diff --git a/core/http/endpoints/localai/edit_model.go b/core/http/endpoints/localai/edit_model.go index 50d80b92e..38fdfd1d4 100644 --- a/core/http/endpoints/localai/edit_model.go +++ b/core/http/endpoints/localai/edit_model.go @@ -60,14 +60,14 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio // Render the edit page with the current configuration templateData := struct { - Title string - ModelName string - Config *config.ModelConfig - ConfigJSON string - ConfigYAML string - BaseURL string - Version string - DisableRuntimeSettings bool + Title string + ModelName string + Config *config.ModelConfig + ConfigJSON string + ConfigYAML string + BaseURL string + Version string + DisableRuntimeSettings bool }{ Title: "LocalAI - Edit Model " + modelName, ModelName: modelName, diff --git a/core/http/endpoints/localai/edit_model_test.go b/core/http/endpoints/localai/edit_model_test.go index b354dbc2b..f944c28bf 100644 --- a/core/http/endpoints/localai/edit_model_test.go +++ b/core/http/endpoints/localai/edit_model_test.go @@ -20,7 +20,7 @@ import ( // testRenderer is a simple renderer for tests that returns JSON type testRenderer struct{} -func (t *testRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error { +func (t *testRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { // For tests, just return the data as JSON return json.NewEncoder(w).Encode(data) } diff --git a/core/http/endpoints/localai/finetune.go b/core/http/endpoints/localai/finetune.go index fe735acb2..23b0683c7 100644 --- a/core/http/endpoints/localai/finetune.go +++ b/core/http/endpoints/localai/finetune.go @@ -15,11 +15,11 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/finetune" ) // StartFineTuneJobEndpoint starts a new fine-tuning job. -func StartFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { +func StartFineTuneJobEndpoint(ftService *finetune.FineTuneService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) @@ -53,7 +53,7 @@ func StartFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerF } // ListFineTuneJobsEndpoint lists fine-tuning jobs for the current user. -func ListFineTuneJobsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { +func ListFineTuneJobsEndpoint(ftService *finetune.FineTuneService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobs := ftService.ListJobs(userID) @@ -65,7 +65,7 @@ func ListFineTuneJobsEndpoint(ftService *services.FineTuneService) echo.HandlerF } // GetFineTuneJobEndpoint gets a specific fine-tuning job. -func GetFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { +func GetFineTuneJobEndpoint(ftService *finetune.FineTuneService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobID := c.Param("id") @@ -82,7 +82,7 @@ func GetFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFun } // StopFineTuneJobEndpoint stops a running fine-tuning job. -func StopFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { +func StopFineTuneJobEndpoint(ftService *finetune.FineTuneService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobID := c.Param("id") @@ -105,7 +105,7 @@ func StopFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFu } // DeleteFineTuneJobEndpoint deletes a fine-tuning job and its data. -func DeleteFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { +func DeleteFineTuneJobEndpoint(ftService *finetune.FineTuneService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobID := c.Param("id") @@ -131,7 +131,7 @@ func DeleteFineTuneJobEndpoint(ftService *services.FineTuneService) echo.Handler } // FineTuneProgressEndpoint streams progress updates via SSE. -func FineTuneProgressEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { +func FineTuneProgressEndpoint(ftService *finetune.FineTuneService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobID := c.Param("id") @@ -161,7 +161,7 @@ func FineTuneProgressEndpoint(ftService *services.FineTuneService) echo.HandlerF } // ListCheckpointsEndpoint lists checkpoints for a job. -func ListCheckpointsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { +func ListCheckpointsEndpoint(ftService *finetune.FineTuneService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobID := c.Param("id") @@ -180,7 +180,7 @@ func ListCheckpointsEndpoint(ftService *services.FineTuneService) echo.HandlerFu } // ExportModelEndpoint exports a model from a checkpoint. -func ExportModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { +func ExportModelEndpoint(ftService *finetune.FineTuneService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobID := c.Param("id") @@ -208,7 +208,7 @@ func ExportModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { } // DownloadExportedModelEndpoint streams the exported model directory as a tar.gz archive. -func DownloadExportedModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { +func DownloadExportedModelEndpoint(ftService *finetune.FineTuneService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobID := c.Param("id") @@ -324,7 +324,7 @@ func ListFineTuneBackendsEndpoint(appConfig *config.ApplicationConfig) echo.Hand } // UploadDatasetEndpoint handles dataset file upload. -func UploadDatasetEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { +func UploadDatasetEndpoint(ftService *finetune.FineTuneService) echo.HandlerFunc { return func(c echo.Context) error { file, err := c.FormFile("file") if err != nil { diff --git a/core/http/endpoints/localai/gallery.go b/core/http/endpoints/localai/gallery.go index 5c87e6d05..e22ba3679 100644 --- a/core/http/endpoints/localai/gallery.go +++ b/core/http/endpoints/localai/gallery.go @@ -10,7 +10,7 @@ import ( "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" ) @@ -19,7 +19,7 @@ type ModelGalleryEndpointService struct { galleries []config.Gallery backendGalleries []config.Gallery modelPath string - galleryApplier *services.GalleryService + galleryApplier *galleryop.GalleryService configLoader *config.ModelConfigLoader } @@ -28,7 +28,7 @@ type GalleryModel struct { gallery.GalleryModel } -func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService, configLoader *config.ModelConfigLoader) ModelGalleryEndpointService { +func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *galleryop.GalleryService, configLoader *config.ModelConfigLoader) ModelGalleryEndpointService { return ModelGalleryEndpointService{ galleries: galleries, backendGalleries: backendGalleries, @@ -40,7 +40,7 @@ func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGaller // GetOpStatusEndpoint returns the job status // @Summary Returns the job status -// @Success 200 {object} services.GalleryOpStatus "Response" +// @Success 200 {object} galleryop.OpStatus "Response" // @Router /models/jobs/{uuid} [get] func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc { return func(c echo.Context) error { @@ -54,7 +54,7 @@ func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc { // GetAllStatusEndpoint returns all the jobs status progress // @Summary Returns all the jobs status progress -// @Success 200 {object} map[string]services.GalleryOpStatus "Response" +// @Success 200 {object} map[string]galleryop.OpStatus "Response" // @Router /models/jobs [get] func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc { return func(c echo.Context) error { @@ -79,7 +79,7 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.Handler if err != nil { return err } - mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ + mgs.galleryApplier.ModelGalleryChannel <- galleryop.ManagementOp[gallery.GalleryModel, gallery.ModelConfig]{ Req: input.GalleryModel, ID: uuid.String(), GalleryElementName: input.ID, @@ -100,7 +100,7 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.Handle return func(c echo.Context) error { modelName := c.Param("name") - mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ + mgs.galleryApplier.ModelGalleryChannel <- galleryop.ManagementOp[gallery.GalleryModel, gallery.ModelConfig]{ Delete: true, GalleryElementName: modelName, } diff --git a/core/http/endpoints/localai/import_model.go b/core/http/endpoints/localai/import_model.go index 7fb03c617..a1931bae9 100644 --- a/core/http/endpoints/localai/import_model.go +++ b/core/http/endpoints/localai/import_model.go @@ -18,7 +18,7 @@ import ( "github.com/mudler/LocalAI/core/gallery/importers" httpUtils "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/pkg/vram" @@ -26,7 +26,7 @@ import ( ) // ImportModelURIEndpoint handles creating new model configurations from a URI -func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) echo.HandlerFunc { +func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache) echo.HandlerFunc { return func(c echo.Context) error { input := new(schema.ImportModelRequest) @@ -51,8 +51,10 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl } estCtx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Second) defer cancel() - opts := vram.EstimateOptions{ContextLength: 8192} - result, err := vram.Estimate(estCtx, files, opts, vram.DefaultCachedSizeResolver(), vram.DefaultCachedGGUFReader()) + result, err := vram.EstimateModel(estCtx, vram.ModelEstimateInput{ + Files: files, + Options: vram.EstimateOptions{ContextLength: 8192}, + }) if err == nil { if result.SizeBytes > 0 { resp.EstimatedSizeBytes = result.SizeBytes @@ -81,9 +83,9 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl opcache.Set(galleryID, uuid.String()) } - galleryService.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ + galleryService.ModelGalleryChannel <- galleryop.ManagementOp[gallery.GalleryModel, gallery.ModelConfig]{ Req: gallery.GalleryModel{ - Overrides: map[string]interface{}{}, + Overrides: map[string]any{}, }, ID: uuid.String(), GalleryElementName: galleryID, diff --git a/core/http/endpoints/localai/mcp.go b/core/http/endpoints/localai/mcp.go index 0ff75f4a9..0db18a9c5 100644 --- a/core/http/endpoints/localai/mcp.go +++ b/core/http/endpoints/localai/mcp.go @@ -6,6 +6,7 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" + mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" "github.com/mudler/LocalAI/core/http/endpoints/openai" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" @@ -20,10 +21,10 @@ type MCPReasoningEvent struct { } type MCPToolCallEvent struct { - Type string `json:"type"` - Name string `json:"name"` - Arguments map[string]interface{} `json:"arguments"` - Reasoning string `json:"reasoning"` + Type string `json:"type"` + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + Reasoning string `json:"reasoning"` } type MCPToolResultEvent struct { @@ -55,8 +56,8 @@ type MCPErrorEvent struct { // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/mcp/chat/completions [post] -func MCPEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { - chatHandler := openai.ChatEndpoint(cl, ml, evaluator, appConfig) +func MCPEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient) echo.HandlerFunc { + chatHandler := openai.ChatEndpoint(cl, ml, evaluator, appConfig, natsClient) return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) diff --git a/core/http/endpoints/localai/mcp_tools.go b/core/http/endpoints/localai/mcp_tools.go index 0ec43529b..6cc239463 100644 --- a/core/http/endpoints/localai/mcp_tools.go +++ b/core/http/endpoints/localai/mcp_tools.go @@ -2,6 +2,7 @@ package localai import ( "fmt" + "net/http" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" @@ -11,7 +12,7 @@ import ( // MCPServersEndpoint returns the list of MCP servers and their tools for a given model. // GET /v1/mcp/servers/:model -func MCPServersEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { +func MCPServersEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient) echo.HandlerFunc { return func(c echo.Context) error { modelName := c.Param("model") if modelName == "" { @@ -20,7 +21,11 @@ func MCPServersEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicat cfg, exists := cl.GetModelConfig(modelName) if !exists { - return fmt.Errorf("model %q not found", modelName) + return c.JSON(http.StatusNotFound, map[string]any{ + "model": modelName, + "servers": []any{}, + "error": fmt.Sprintf("model %q not found", modelName), + }) } if cfg.MCP.Servers == "" && cfg.MCP.Stdio == "" { @@ -35,6 +40,19 @@ func MCPServersEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicat return fmt.Errorf("failed to parse MCP config: %w", err) } + // In distributed mode, route discovery through NATS to an agent worker + // that can actually connect to the MCP servers. + if natsClient != nil { + resp, err := mcpTools.DiscoverMCPToolsRemote(c.Request().Context(), natsClient, cfg.Name, remote, stdio) + if err != nil { + return fmt.Errorf("remote MCP discovery failed: %w", err) + } + return c.JSON(200, map[string]any{ + "model": modelName, + "servers": resp.Servers, + }) + } + namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, nil) if err != nil { return fmt.Errorf("failed to get MCP sessions: %w", err) @@ -54,7 +72,7 @@ func MCPServersEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicat // MCPServersEndpointFromMiddleware is a version that uses the middleware-resolved model config. // This allows it to use the same middleware chain as other endpoints. -func MCPServersEndpointFromMiddleware() echo.HandlerFunc { +func MCPServersEndpointFromMiddleware(natsClient mcpTools.MCPNATSClient) echo.HandlerFunc { return func(c echo.Context) error { cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { @@ -73,6 +91,18 @@ func MCPServersEndpointFromMiddleware() echo.HandlerFunc { return fmt.Errorf("failed to parse MCP config: %w", err) } + // In distributed mode, route discovery through NATS to an agent worker. + if natsClient != nil { + resp, err := mcpTools.DiscoverMCPToolsRemote(c.Request().Context(), natsClient, cfg.Name, remote, stdio) + if err != nil { + return fmt.Errorf("remote MCP discovery failed: %w", err) + } + return c.JSON(200, map[string]any{ + "model": cfg.Name, + "servers": resp.Servers, + }) + } + namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, nil) if err != nil { return fmt.Errorf("failed to get MCP sessions: %w", err) diff --git a/core/http/endpoints/localai/metrics.go b/core/http/endpoints/localai/metrics.go index a5f08a7f6..7fbd043a1 100644 --- a/core/http/endpoints/localai/metrics.go +++ b/core/http/endpoints/localai/metrics.go @@ -4,7 +4,7 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/monitoring" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -18,10 +18,10 @@ func LocalAIMetricsEndpoint() echo.HandlerFunc { type apiMiddlewareConfig struct { Filter func(c echo.Context) bool - metricsService *services.LocalAIMetricsService + metricsService *monitoring.LocalAIMetricsService } -func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) echo.MiddlewareFunc { +func LocalAIMetricsAPIMiddleware(metrics *monitoring.LocalAIMetricsService) echo.MiddlewareFunc { cfg := apiMiddlewareConfig{ metricsService: metrics, Filter: func(c echo.Context) bool { diff --git a/core/http/endpoints/localai/nodes.go b/core/http/endpoints/localai/nodes.go new file mode 100644 index 000000000..bbdb33471 --- /dev/null +++ b/core/http/endpoints/localai/nodes.go @@ -0,0 +1,634 @@ +package localai + +import ( + "context" + "crypto/sha256" + "crypto/subtle" + "encoding/hex" + "fmt" + "io" + "net/http" + "net/url" + "sync" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/xlog" + "gorm.io/gorm" +) + +// nodeError builds a schema.ErrorResponse for node endpoints. +func nodeError(code int, message string) schema.ErrorResponse { + return schema.ErrorResponse{ + Error: &schema.APIError{ + Code: code, + Message: message, + Type: "node_error", + }, + } +} + +// ListNodesEndpoint returns all registered backend nodes. +func ListNodesEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + nodeList, err := registry.List(ctx) + if err != nil { + xlog.Error("Failed to list nodes", "error", err) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list nodes")) + } + return c.JSON(http.StatusOK, nodeList) + } +} + +// GetNodeEndpoint returns a single node by ID. +func GetNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + id := c.Param("id") + node, err := registry.Get(ctx, id) + if err != nil { + return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found")) + } + return c.JSON(http.StatusOK, node) + } +} + +// RegisterNodeRequest is the request body for registering a new worker node. +type RegisterNodeRequest struct { + Name string `json:"name"` + NodeType string `json:"node_type,omitempty"` // "backend" (default) or "agent" + Address string `json:"address"` + HTTPAddress string `json:"http_address,omitempty"` + Token string `json:"token,omitempty"` + TotalVRAM uint64 `json:"total_vram,omitempty"` + AvailableVRAM uint64 `json:"available_vram,omitempty"` + TotalRAM uint64 `json:"total_ram,omitempty"` + AvailableRAM uint64 `json:"available_ram,omitempty"` + GPUVendor string `json:"gpu_vendor,omitempty"` +} + +// RegisterNodeEndpoint registers a new backend node. +// expectedToken is the registration token configured on the frontend (may be empty to disable auth). +// autoApprove controls whether new nodes go directly to "healthy" or require admin approval. +func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc { + return func(c echo.Context) error { + var req RegisterNodeRequest + if err := c.Bind(&req); err != nil { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body")) + } + + // Validate registration token if one is configured on the frontend + if expectedToken != "" { + if req.Token == "" { + return c.JSON(http.StatusUnauthorized, nodeError(http.StatusUnauthorized, "registration token required")) + } + expectedHash := sha256.Sum256([]byte(expectedToken)) + providedHash := sha256.Sum256([]byte(req.Token)) + if subtle.ConstantTimeCompare(expectedHash[:], providedHash[:]) != 1 { + return c.JSON(http.StatusUnauthorized, nodeError(http.StatusUnauthorized, "invalid registration token")) + } + } + + // Determine node type + nodeType := req.NodeType + if nodeType == "" { + nodeType = nodes.NodeTypeBackend + } + if nodeType != nodes.NodeTypeBackend && nodeType != nodes.NodeTypeAgent { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, + fmt.Sprintf("invalid node_type %q; must be %q or %q", nodeType, nodes.NodeTypeBackend, nodes.NodeTypeAgent))) + } + + // Backend workers require address; agent workers don't serve gRPC + if req.Name == "" { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "name is required")) + } + if nodeType == nodes.NodeTypeBackend && req.Address == "" { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "address is required for backend workers")) + } + if len(req.Name) > 255 { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "name exceeds 255 characters")) + } + if len(req.Address) > 512 { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "address exceeds 512 characters")) + } + if len(req.HTTPAddress) > 512 { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "http_address exceeds 512 characters")) + } + + // Hash the token for storage (if provided) + var tokenHash string + if req.Token != "" { + h := sha256.Sum256([]byte(req.Token)) + tokenHash = hex.EncodeToString(h[:]) + } + + node := &nodes.BackendNode{ + Name: req.Name, + NodeType: nodeType, + Address: req.Address, + HTTPAddress: req.HTTPAddress, + TokenHash: tokenHash, + TotalVRAM: req.TotalVRAM, + AvailableVRAM: req.AvailableVRAM, + TotalRAM: req.TotalRAM, + AvailableRAM: req.AvailableRAM, + GPUVendor: req.GPUVendor, + } + + ctx := c.Request().Context() + if err := registry.Register(ctx, node, autoApprove); err != nil { + xlog.Error("Failed to register node", "name", req.Name, "error", err) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to register node")) + } + + response := map[string]any{ + "id": node.ID, + "name": node.Name, + "node_type": node.NodeType, + "status": node.Status, + "created_at": node.CreatedAt, + } + + // Provision API key for agent workers that are approved (not pending). + // On re-registration of a previously approved node, revoke old + provision new. + if nodeType == nodes.NodeTypeAgent && authDB != nil && node.Status != nodes.StatusPending { + // Use a transaction so that if provisioning fails after revoking old creds, + // the old credentials are not lost. + txErr := authDB.Transaction(func(tx *gorm.DB) error { + if node.AuthUserID != "" { + if err := tx.Exec("DELETE FROM users WHERE id = ?", node.AuthUserID).Error; err != nil { + return fmt.Errorf("revoking old credentials: %w", err) + } + node.AuthUserID = "" + node.APIKeyID = "" + } + plaintext, err := provisionAgentWorkerKey(ctx, tx, registry, node, hmacSecret) + if err != nil { + return err + } + response["api_token"] = plaintext + return nil + }) + if txErr != nil { + xlog.Warn("Failed to auto-provision API key for agent worker", "node", node.Name, "error", txErr) + } + } + + return c.JSON(http.StatusCreated, response) + } +} + +// ApproveNodeEndpoint approves a pending node, setting its status to healthy. +// For agent workers, it also provisions an API key so they can call the inference API. +func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + id := c.Param("id") + if err := registry.ApproveNode(ctx, id); err != nil { + xlog.Error("Failed to approve node", "id", id, "error", err) + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "failed to approve node")) + } + node, err := registry.Get(ctx, id) + if err != nil { + return c.JSON(http.StatusOK, map[string]string{"message": "node approved"}) + } + + response := map[string]any{ + "id": node.ID, + "name": node.Name, + "node_type": node.NodeType, + "status": node.Status, + "message": "node approved", + } + + // Provision API key for newly approved agent workers + if node.NodeType == nodes.NodeTypeAgent && authDB != nil && node.AuthUserID == "" { + if plaintext, err := provisionAgentWorkerKey(ctx, authDB, registry, node, hmacSecret); err != nil { + xlog.Warn("Failed to provision API key on approval", "node", node.Name, "error", err) + } else { + response["api_token"] = plaintext + } + } + + return c.JSON(http.StatusOK, response) + } +} + +// provisionAgentWorkerKey creates a dedicated user and API key for an agent worker node. +// Returns the plaintext API key on success. +func provisionAgentWorkerKey(ctx context.Context, authDB *gorm.DB, registry *nodes.NodeRegistry, node *nodes.BackendNode, hmacSecret string) (string, error) { + workerUser := &auth.User{ + ID: uuid.New().String(), + Name: "agent-worker:" + node.Name, + Provider: auth.ProviderAgentWorker, + Subject: node.ID, + Role: "user", + Status: "active", + CreatedAt: time.Now(), + } + if err := authDB.Create(workerUser).Error; err != nil { + return "", fmt.Errorf("creating agent worker user: %w", err) + } + + plaintext, apiKey, err := auth.CreateAPIKey(authDB, workerUser.ID, "agent-worker:"+node.Name, "user", hmacSecret, nil) + if err != nil { + return "", fmt.Errorf("creating API key: %w", err) + } + + node.AuthUserID = workerUser.ID + node.APIKeyID = apiKey.ID + if err := registry.UpdateAuthRefs(ctx, node.ID, workerUser.ID, apiKey.ID); err != nil { + xlog.Warn("Failed to update auth refs on node", "node", node.Name, "error", err) + } + + // Grant collections feature so the worker can store/retrieve KB data on behalf of users. + perm := &auth.UserPermission{ + ID: uuid.New().String(), + UserID: workerUser.ID, + Permissions: auth.PermissionMap{auth.FeatureCollections: true}, + } + if err := authDB.Create(perm).Error; err != nil { + xlog.Warn("Failed to grant collections permission to agent worker", "node", node.Name, "error", err) + } + + xlog.Info("Provisioned API key for agent worker", "node", node.Name, "user", workerUser.ID) + return plaintext, nil +} + +// DeregisterNodeEndpoint removes a backend node permanently (admin use). +func DeregisterNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + id := c.Param("id") + if err := registry.Deregister(ctx, id); err != nil { + xlog.Error("Failed to deregister node", "id", id, "error", err) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to deregister node")) + } + return c.JSON(http.StatusOK, map[string]string{"message": "node deregistered"}) + } +} + +// DeactivateNodeEndpoint marks a node as offline without deleting it. +// Used by workers on graceful shutdown to preserve approval status across restarts. +func DeactivateNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + id := c.Param("id") + if err := registry.MarkOffline(ctx, id); err != nil { + xlog.Error("Failed to deactivate node", "id", id, "error", err) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to deactivate node")) + } + return c.JSON(http.StatusOK, map[string]string{"message": "node set to offline"}) + } +} + +// HeartbeatEndpoint updates the heartbeat for a node. +func HeartbeatEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + id := c.Param("id") + + // Parse optional VRAM update from body + var update nodes.HeartbeatUpdate + _ = c.Bind(&update) // best-effort — empty body is fine + + var updatePtr *nodes.HeartbeatUpdate + if update.AvailableVRAM != nil || update.TotalVRAM != nil || update.AvailableRAM != nil || update.GPUVendor != "" { + updatePtr = &update + } + + ctx := c.Request().Context() + if err := registry.Heartbeat(ctx, id, updatePtr); err != nil { + xlog.Warn("Heartbeat failed for node", "id", id, "error", err) + return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found")) + } + return c.JSON(http.StatusOK, map[string]string{"message": "heartbeat received"}) + } +} + +// GetNodeModelsEndpoint returns the models loaded on a node. +func GetNodeModelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + id := c.Param("id") + models, err := registry.GetNodeModels(ctx, id) + if err != nil { + xlog.Error("Failed to get node models", "id", id, "error", err) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get node models")) + } + return c.JSON(http.StatusOK, models) + } +} + +// DrainNodeEndpoint sets a node to draining status (no new requests). +func DrainNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + id := c.Param("id") + if err := registry.MarkDraining(ctx, id); err != nil { + xlog.Error("Failed to drain node", "id", id, "error", err) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to drain node")) + } + return c.JSON(http.StatusOK, map[string]string{"message": "node set to draining"}) + } +} + +// InstallBackendOnNodeEndpoint triggers backend installation on a worker node via NATS. +func InstallBackendOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.HandlerFunc { + return func(c echo.Context) error { + if unloader == nil { + return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured")) + } + nodeID := c.Param("id") + var req struct { + Backend string `json:"backend"` + BackendGalleries string `json:"backend_galleries,omitempty"` + } + if err := c.Bind(&req); err != nil || req.Backend == "" { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "backend name required")) + } + reply, err := unloader.InstallBackend(nodeID, req.Backend, "", req.BackendGalleries) + if err != nil { + xlog.Error("Failed to install backend on node", "node", nodeID, "backend", req.Backend, "error", err) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to install backend on node")) + } + if !reply.Success { + xlog.Error("Backend install failed on node", "node", nodeID, "backend", req.Backend, "error", reply.Error) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "backend installation failed")) + } + return c.JSON(http.StatusOK, map[string]string{"message": "backend installed"}) + } +} + +// DeleteBackendOnNodeEndpoint deletes a backend from a worker node via NATS. +func DeleteBackendOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.HandlerFunc { + return func(c echo.Context) error { + if unloader == nil { + return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured")) + } + nodeID := c.Param("id") + var req struct { + Backend string `json:"backend"` + } + if err := c.Bind(&req); err != nil || req.Backend == "" { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "backend name required")) + } + reply, err := unloader.DeleteBackend(nodeID, req.Backend) + if err != nil { + xlog.Error("Failed to delete backend on node", "node", nodeID, "backend", req.Backend, "error", err) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to delete backend on node")) + } + if !reply.Success { + xlog.Error("Backend delete failed on node", "node", nodeID, "backend", req.Backend, "error", reply.Error) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "backend deletion failed")) + } + return c.JSON(http.StatusOK, map[string]string{"message": "backend deleted"}) + } +} + +// ListBackendsOnNodeEndpoint lists installed backends on a worker node via NATS. +func ListBackendsOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.HandlerFunc { + return func(c echo.Context) error { + if unloader == nil { + return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured")) + } + nodeID := c.Param("id") + reply, err := unloader.ListBackends(nodeID) + if err != nil { + xlog.Error("Failed to list backends on node", "node", nodeID, "error", err) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list backends on node")) + } + if reply.Error != "" { + xlog.Error("List backends failed on node", "node", nodeID, "error", reply.Error) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list backends on node")) + } + return c.JSON(http.StatusOK, reply.Backends) + } +} + +// UnloadModelOnNodeEndpoint unloads a model from a worker node (gRPC Free) via NATS. +func UnloadModelOnNodeEndpoint(unloader nodes.NodeCommandSender, registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + if unloader == nil { + return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured")) + } + nodeID := c.Param("id") + var req struct { + ModelName string `json:"model_name"` + } + if err := c.Bind(&req); err != nil || req.ModelName == "" { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name required")) + } + if err := unloader.UnloadModelOnNode(nodeID, req.ModelName); err != nil { + xlog.Error("Failed to unload model on node", "node", nodeID, "model", req.ModelName, "error", err) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to unload model on node")) + } + // Also stop the backend process + if err := unloader.StopBackend(nodeID, req.ModelName); err != nil { + xlog.Error("Failed to stop backend after model unload", "node", nodeID, "model", req.ModelName, "error", err) + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "model unloaded but backend stop failed")) + } + // Remove from registry + registry.RemoveNodeModel(c.Request().Context(), nodeID, req.ModelName) + return c.JSON(http.StatusOK, map[string]string{"message": "model unloaded"}) + } +} + +// DeleteModelOnNodeEndpoint deletes model files from a worker node via NATS. +func DeleteModelOnNodeEndpoint(unloader nodes.NodeCommandSender, registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + if unloader == nil { + return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured")) + } + nodeID := c.Param("id") + var req struct { + ModelName string `json:"model_name"` + } + if err := c.Bind(&req); err != nil || req.ModelName == "" { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name required")) + } + // Unload model first if loaded + if err := unloader.UnloadModelOnNode(nodeID, req.ModelName); err != nil { + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to unload model before deletion")) + } + if err := unloader.StopBackend(nodeID, req.ModelName); err != nil { + // Non-fatal — backend process may not be running + xlog.Warn("StopBackend failed during model deletion (non-fatal)", "node", nodeID, "model", req.ModelName, "error", err) + } + registry.RemoveNodeModel(c.Request().Context(), nodeID, req.ModelName) + return c.JSON(http.StatusOK, map[string]string{"message": "model deleted from node"}) + } +} + +// NodeBackendLogsListEndpoint proxies a request to a worker node's /v1/backend-logs +// endpoint to list model IDs that have backend logs. +func NodeBackendLogsListEndpoint(registry *nodes.NodeRegistry, registrationToken string) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + nodeID := c.Param("id") + node, err := registry.Get(ctx, nodeID) + if err != nil { + return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found")) + } + + if node.HTTPAddress == "" { + return c.JSON(http.StatusBadGateway, nodeError(http.StatusBadGateway, "node has no HTTP address")) + } + + resp, err := proxyHTTPToWorker(node.HTTPAddress, "/v1/backend-logs", registrationToken) + if err != nil { + return c.JSON(http.StatusBadGateway, nodeError(http.StatusBadGateway, fmt.Sprintf("failed to reach worker: %v", err))) + } + defer resp.Body.Close() + + c.Response().Header().Set("Content-Type", "application/json") + c.Response().WriteHeader(resp.StatusCode) + io.Copy(c.Response(), resp.Body) + return nil + } +} + +// NodeBackendLogsLinesEndpoint proxies a request to a worker node's +// /v1/backend-logs/{modelId} endpoint to get buffered log lines. +func NodeBackendLogsLinesEndpoint(registry *nodes.NodeRegistry, registrationToken string) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + nodeID := c.Param("id") + modelID := c.Param("modelId") + + node, err := registry.Get(ctx, nodeID) + if err != nil { + return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found")) + } + + if node.HTTPAddress == "" { + return c.JSON(http.StatusBadGateway, nodeError(http.StatusBadGateway, "node has no HTTP address")) + } + + path := "/v1/backend-logs/" + url.PathEscape(modelID) + resp, err := proxyHTTPToWorker(node.HTTPAddress, path, registrationToken) + if err != nil { + return c.JSON(http.StatusBadGateway, nodeError(http.StatusBadGateway, fmt.Sprintf("failed to reach worker: %v", err))) + } + defer resp.Body.Close() + + c.Response().Header().Set("Content-Type", "application/json") + c.Response().WriteHeader(resp.StatusCode) + io.Copy(c.Response(), resp.Body) + return nil + } +} + +// NodeBackendLogsWSEndpoint proxies a WebSocket connection to a worker node's +// /v1/backend-logs/{modelId}/ws endpoint for real-time log streaming. +func NodeBackendLogsWSEndpoint(registry *nodes.NodeRegistry, registrationToken string) echo.HandlerFunc { + browserUpgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + return true // no origin header = same-origin or non-browser + } + // Parse origin URL and compare host with request host + u, err := url.Parse(origin) + if err != nil { + return false + } + return u.Host == r.Host + }, + } + + return func(c echo.Context) error { + ctx := c.Request().Context() + nodeID := c.Param("id") + modelID := c.Param("modelId") + + node, err := registry.Get(ctx, nodeID) + if err != nil { + return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found")) + } + + // Upgrade browser connection + browserWS, err := browserUpgrader.Upgrade(c.Response(), c.Request(), nil) + if err != nil { + return err + } + + // Dial the worker WebSocket + workerURL := fmt.Sprintf("ws://%s/v1/backend-logs/%s/ws", node.HTTPAddress, url.PathEscape(modelID)) + workerHeaders := http.Header{} + if registrationToken != "" { + workerHeaders.Set("Authorization", "Bearer "+registrationToken) + } + + workerDialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second} + workerWS, _, err := workerDialer.Dial(workerURL, workerHeaders) + if err != nil { + browserWS.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "failed to connect to worker")) + browserWS.Close() + return nil + } + + // Use sync.OnceFunc wrappers to avoid double-close and ensure each + // goroutine can safely close the *other* connection to unblock + // its peer's ReadMessage call. + done := make(chan struct{}) + closeWorker := sync.OnceFunc(func() { workerWS.Close() }) + closeBrowser := sync.OnceFunc(func() { browserWS.Close() }) + + // Worker → Browser + go func() { + defer close(done) + defer closeBrowser() // unblock Browser→Worker goroutine + for { + msgType, msg, err := workerWS.ReadMessage() + if err != nil { + return + } + if err := browserWS.WriteMessage(msgType, msg); err != nil { + return + } + } + }() + + // Browser → Worker (mainly for close detection) + go func() { + defer closeWorker() // unblock Worker→Browser goroutine + for { + msgType, msg, err := browserWS.ReadMessage() + if err != nil { + return + } + if err := workerWS.WriteMessage(msgType, msg); err != nil { + return + } + } + }() + + <-done + closeWorker() + closeBrowser() + return nil + } +} + +// proxyHTTPToWorker makes a GET request to a worker's HTTP server with bearer token auth. +func proxyHTTPToWorker(httpAddress, path, token string) (*http.Response, error) { + reqURL := fmt.Sprintf("http://%s%s", httpAddress, path) + req, err := http.NewRequest("GET", reqURL, nil) + if err != nil { + return nil, err + } + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + client := &http.Client{Timeout: 15 * time.Second} + return client.Do(req) +} diff --git a/core/http/endpoints/localai/nodes_test.go b/core/http/endpoints/localai/nodes_test.go new file mode 100644 index 000000000..526056cda --- /dev/null +++ b/core/http/endpoints/localai/nodes_test.go @@ -0,0 +1,229 @@ +package localai + +import ( + "context" + "crypto/sha256" + "crypto/subtle" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/LocalAI/core/services/testutil" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = DescribeTable("token validation", + func(expectedToken, providedToken string, wantMatch bool) { + if expectedToken == "" { + // No auth required — always matches + Expect(wantMatch).To(BeTrue(), "no-auth should always pass") + return + } + + if providedToken == "" { + Expect(wantMatch).To(BeFalse(), "empty token should be rejected") + return + } + + expectedHash := sha256.Sum256([]byte(expectedToken)) + providedHash := sha256.Sum256([]byte(providedToken)) + match := subtle.ConstantTimeCompare(expectedHash[:], providedHash[:]) == 1 + + Expect(match).To(Equal(wantMatch)) + }, + Entry("matching tokens", "my-secret-token", "my-secret-token", true), + Entry("mismatched tokens", "my-secret-token", "wrong-token", false), + Entry("empty expected (no auth)", "", "any-token", true), + Entry("empty provided when expected set", "my-secret-token", "", false), +) + +var _ = Describe("Node HTTP handlers", func() { + var ( + registry *nodes.NodeRegistry + ) + + BeforeEach(func() { + db := testutil.SetupTestDB() + var err error + registry, err = nodes.NewNodeRegistry(db) + Expect(err).ToNot(HaveOccurred()) + }) + + Describe("RegisterNodeEndpoint", func() { + It("registers a backend node and returns 201", func() { + e := echo.New() + body := `{"name":"worker-1","address":"10.0.0.1:50051"}` + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := RegisterNodeEndpoint(registry, "", true, nil, "") + Expect(handler(c)).To(Succeed()) + Expect(rec.Code).To(Equal(http.StatusCreated)) + + var resp map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + Expect(resp["name"]).To(Equal("worker-1")) + Expect(resp["id"]).ToNot(BeEmpty()) + Expect(resp["status"]).To(Equal(nodes.StatusHealthy)) + }) + + It("returns 400 when name is missing", func() { + e := echo.New() + body := `{"address":"10.0.0.1:50051"}` + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := RegisterNodeEndpoint(registry, "", true, nil, "") + Expect(handler(c)).To(Succeed()) + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + + var resp map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + errObj, ok := resp["error"].(map[string]any) + Expect(ok).To(BeTrue()) + Expect(errObj["message"]).To(ContainSubstring("name is required")) + }) + + It("returns 400 when name exceeds 255 characters", func() { + e := echo.New() + longName := strings.Repeat("x", 256) + body := `{"name":"` + longName + `","address":"10.0.0.1:50051"}` + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := RegisterNodeEndpoint(registry, "", true, nil, "") + Expect(handler(c)).To(Succeed()) + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + + var resp map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + errObj, ok := resp["error"].(map[string]any) + Expect(ok).To(BeTrue()) + Expect(errObj["message"]).To(ContainSubstring("exceeds 255 characters")) + }) + + It("returns 400 when address is missing for backend node type", func() { + e := echo.New() + body := `{"name":"worker-no-addr"}` + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := RegisterNodeEndpoint(registry, "", true, nil, "") + Expect(handler(c)).To(Succeed()) + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + + var resp map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + errObj, ok := resp["error"].(map[string]any) + Expect(ok).To(BeTrue()) + Expect(errObj["message"]).To(ContainSubstring("address is required")) + }) + + It("returns 400 when node_type is invalid", func() { + e := echo.New() + body := `{"name":"bad-type","address":"10.0.0.1:50051","node_type":"invalid"}` + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := RegisterNodeEndpoint(registry, "", true, nil, "") + Expect(handler(c)).To(Succeed()) + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + + var resp map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + errObj, ok := resp["error"].(map[string]any) + Expect(ok).To(BeTrue()) + Expect(errObj["message"]).To(ContainSubstring("invalid node_type")) + }) + + It("returns 401 when registration token is wrong", func() { + e := echo.New() + body := `{"name":"worker-1","address":"10.0.0.1:50051","token":"wrong-token"}` + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "") + Expect(handler(c)).To(Succeed()) + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("sets status to pending when autoApprove is false", func() { + e := echo.New() + body := `{"name":"pending-worker","address":"10.0.0.1:50051"}` + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := RegisterNodeEndpoint(registry, "", false, nil, "") + Expect(handler(c)).To(Succeed()) + Expect(rec.Code).To(Equal(http.StatusCreated)) + + var resp map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + Expect(resp["status"]).To(Equal(nodes.StatusPending)) + }) + }) + + Describe("ListNodesEndpoint", func() { + It("returns an empty list when no nodes are registered", func() { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := ListNodesEndpoint(registry) + Expect(handler(c)).To(Succeed()) + Expect(rec.Code).To(Equal(http.StatusOK)) + + var list []nodes.BackendNode + Expect(json.Unmarshal(rec.Body.Bytes(), &list)).To(Succeed()) + Expect(list).To(BeEmpty()) + }) + + It("returns registered nodes", func() { + // Register two nodes directly via the registry + ctx := context.Background() + Expect(registry.Register(ctx, &nodes.BackendNode{ + Name: "alpha", + Address: "10.0.0.1:50051", + }, true)).To(Succeed()) + Expect(registry.Register(ctx, &nodes.BackendNode{ + Name: "beta", + Address: "10.0.0.2:50051", + }, true)).To(Succeed()) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := ListNodesEndpoint(registry) + Expect(handler(c)).To(Succeed()) + Expect(rec.Code).To(Equal(http.StatusOK)) + + var list []nodes.BackendNode + Expect(json.Unmarshal(rec.Body.Bytes(), &list)).To(Succeed()) + Expect(list).To(HaveLen(2)) + names := []string{list[0].Name, list[1].Name} + Expect(names).To(ConsistOf("alpha", "beta")) + }) + }) +}) diff --git a/core/http/endpoints/localai/quantization.go b/core/http/endpoints/localai/quantization.go index e3bf011da..1c0af5ef8 100644 --- a/core/http/endpoints/localai/quantization.go +++ b/core/http/endpoints/localai/quantization.go @@ -10,11 +10,11 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/quantization" ) // StartQuantizationJobEndpoint starts a new quantization job. -func StartQuantizationJobEndpoint(qService *services.QuantizationService) echo.HandlerFunc { +func StartQuantizationJobEndpoint(qService *quantization.QuantizationService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) @@ -43,7 +43,7 @@ func StartQuantizationJobEndpoint(qService *services.QuantizationService) echo.H } // ListQuantizationJobsEndpoint lists quantization jobs for the current user. -func ListQuantizationJobsEndpoint(qService *services.QuantizationService) echo.HandlerFunc { +func ListQuantizationJobsEndpoint(qService *quantization.QuantizationService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobs := qService.ListJobs(userID) @@ -55,7 +55,7 @@ func ListQuantizationJobsEndpoint(qService *services.QuantizationService) echo.H } // GetQuantizationJobEndpoint gets a specific quantization job. -func GetQuantizationJobEndpoint(qService *services.QuantizationService) echo.HandlerFunc { +func GetQuantizationJobEndpoint(qService *quantization.QuantizationService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobID := c.Param("id") @@ -72,7 +72,7 @@ func GetQuantizationJobEndpoint(qService *services.QuantizationService) echo.Han } // StopQuantizationJobEndpoint stops a running quantization job. -func StopQuantizationJobEndpoint(qService *services.QuantizationService) echo.HandlerFunc { +func StopQuantizationJobEndpoint(qService *quantization.QuantizationService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobID := c.Param("id") @@ -92,7 +92,7 @@ func StopQuantizationJobEndpoint(qService *services.QuantizationService) echo.Ha } // DeleteQuantizationJobEndpoint deletes a quantization job and its data. -func DeleteQuantizationJobEndpoint(qService *services.QuantizationService) echo.HandlerFunc { +func DeleteQuantizationJobEndpoint(qService *quantization.QuantizationService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobID := c.Param("id") @@ -118,7 +118,7 @@ func DeleteQuantizationJobEndpoint(qService *services.QuantizationService) echo. } // QuantizationProgressEndpoint streams progress updates via SSE. -func QuantizationProgressEndpoint(qService *services.QuantizationService) echo.HandlerFunc { +func QuantizationProgressEndpoint(qService *quantization.QuantizationService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobID := c.Param("id") @@ -148,7 +148,7 @@ func QuantizationProgressEndpoint(qService *services.QuantizationService) echo.H } // ImportQuantizedModelEndpoint imports a quantized model into LocalAI. -func ImportQuantizedModelEndpoint(qService *services.QuantizationService) echo.HandlerFunc { +func ImportQuantizedModelEndpoint(qService *quantization.QuantizationService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobID := c.Param("id") @@ -176,7 +176,7 @@ func ImportQuantizedModelEndpoint(qService *services.QuantizationService) echo.H } // DownloadQuantizedModelEndpoint streams the quantized model file. -func DownloadQuantizedModelEndpoint(qService *services.QuantizationService) echo.HandlerFunc { +func DownloadQuantizedModelEndpoint(qService *quantization.QuantizationService) echo.HandlerFunc { return func(c echo.Context) error { userID := getUserID(c) jobID := c.Param("id") diff --git a/core/http/endpoints/localai/types.go b/core/http/endpoints/localai/types.go index 32a549089..f1c507472 100644 --- a/core/http/endpoints/localai/types.go +++ b/core/http/endpoints/localai/types.go @@ -2,10 +2,10 @@ package localai // ModelResponse represents the common response structure for model operations type ModelResponse struct { - Success bool `json:"success"` - Message string `json:"message"` - Filename string `json:"filename,omitempty"` - Config interface{} `json:"config,omitempty"` - Error string `json:"error,omitempty"` - Details []string `json:"details,omitempty"` + Success bool `json:"success"` + Message string `json:"message"` + Filename string `json:"filename,omitempty"` + Config any `json:"config,omitempty"` + Error string `json:"error,omitempty"` + Details []string `json:"details,omitempty"` } diff --git a/core/http/endpoints/localai/welcome.go b/core/http/endpoints/localai/welcome.go index ce197ba05..c4bdb8c33 100644 --- a/core/http/endpoints/localai/welcome.go +++ b/core/http/endpoints/localai/welcome.go @@ -7,13 +7,13 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/http/middleware" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/model" ) func WelcomeEndpoint(appConfig *config.ApplicationConfig, - cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) echo.HandlerFunc { + cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *galleryop.OpCache) echo.HandlerFunc { return func(c echo.Context) error { modelConfigs := cl.GetAllModelsConfigs() galleryConfigs := map[string]*gallery.ModelConfig{} @@ -37,12 +37,12 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig, loadedModelsMap[m.ID] = true } - modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) + modelsWithoutConfig, _ := galleryop.ListModels(cl, ml, config.NoFilterFn, galleryop.LOOSE_ONLY) // Get model statuses to display in the UI the operation in progress processingModels, taskTypes := opcache.GetStatus() - summary := map[string]interface{}{ + summary := map[string]any{ "Title": "LocalAI API - " + internal.PrintableVersion(), "Version": internal.PrintableVersion(), "BaseURL": middleware.BaseURL(c), diff --git a/core/http/endpoints/mcp/executor.go b/core/http/endpoints/mcp/executor.go new file mode 100644 index 000000000..9f9b279d6 --- /dev/null +++ b/core/http/endpoints/mcp/executor.go @@ -0,0 +1,132 @@ +package mcp + +import ( + "context" + "slices" + + "github.com/mudler/LocalAI/core/config" + mcpRemote "github.com/mudler/LocalAI/core/services/mcp" + "github.com/mudler/LocalAI/pkg/functions" + "github.com/mudler/xlog" +) + +// ToolExecutor abstracts MCP tool discovery and execution. +// Implementations handle local (in-process sessions) vs distributed (NATS) modes. +type ToolExecutor interface { + // DiscoverTools returns the tool function schemas available from MCP servers. + DiscoverTools(ctx context.Context) ([]functions.Function, error) + // IsTool returns true if the given function name is an MCP tool. + IsTool(name string) bool + // ExecuteTool executes an MCP tool by name with the given JSON arguments. + ExecuteTool(ctx context.Context, toolName, arguments string) (string, error) + // HasTools returns true if any MCP tools are available. + HasTools() bool +} + +// LocalToolExecutor uses in-process MCP sessions for tool operations. +type LocalToolExecutor struct { + tools []MCPToolInfo +} + +// NewLocalToolExecutor creates a ToolExecutor from local named sessions. +// It discovers tools immediately and caches the result. +func NewLocalToolExecutor(ctx context.Context, sessions []NamedSession) *LocalToolExecutor { + tools, err := DiscoverMCPTools(ctx, sessions) + if err != nil { + xlog.Error("Failed to discover MCP tools (local)", "error", err) + } + return &LocalToolExecutor{tools: tools} +} + +func (e *LocalToolExecutor) DiscoverTools(_ context.Context) ([]functions.Function, error) { + var fns []functions.Function + for _, t := range e.tools { + fns = append(fns, t.Function) + } + return fns, nil +} + +func (e *LocalToolExecutor) IsTool(name string) bool { + return IsMCPTool(e.tools, name) +} + +func (e *LocalToolExecutor) ExecuteTool(ctx context.Context, toolName, arguments string) (string, error) { + return ExecuteMCPToolCall(ctx, e.tools, toolName, arguments) +} + +func (e *LocalToolExecutor) HasTools() bool { + return len(e.tools) > 0 +} + +// DistributedToolExecutor routes tool operations through NATS to agent workers. +type DistributedToolExecutor struct { + natsClient MCPNATSClient + modelName string + remote config.MCPGenericConfig[config.MCPRemoteServers] + stdio config.MCPGenericConfig[config.MCPSTDIOServers] + toolDefs []mcpRemote.MCPToolDef +} + +// NewDistributedToolExecutor creates a ToolExecutor that routes through NATS. +// It discovers tools immediately via a NATS request-reply to an agent worker. +func NewDistributedToolExecutor(ctx context.Context, natsClient MCPNATSClient, modelName string, + remote config.MCPGenericConfig[config.MCPRemoteServers], + stdio config.MCPGenericConfig[config.MCPSTDIOServers], +) *DistributedToolExecutor { + e := &DistributedToolExecutor{ + natsClient: natsClient, + modelName: modelName, + remote: remote, + stdio: stdio, + } + resp, err := DiscoverMCPToolsRemote(ctx, natsClient, modelName, remote, stdio) + if err != nil { + xlog.Error("Failed to discover MCP tools (distributed)", "error", err) + } else if resp != nil { + e.toolDefs = resp.Tools + } + return e +} + +func (e *DistributedToolExecutor) DiscoverTools(_ context.Context) ([]functions.Function, error) { + var fns []functions.Function + for _, td := range e.toolDefs { + fns = append(fns, td.Function) + } + return fns, nil +} + +func (e *DistributedToolExecutor) IsTool(name string) bool { + return slices.ContainsFunc(e.toolDefs, func(td mcpRemote.MCPToolDef) bool { + return td.ToolName == name + }) +} + +func (e *DistributedToolExecutor) ExecuteTool(ctx context.Context, toolName, arguments string) (string, error) { + return ExecuteMCPToolCallRemote(ctx, e.natsClient, e.modelName, e.remote, e.stdio, toolName, arguments) +} + +func (e *DistributedToolExecutor) HasTools() bool { + return len(e.toolDefs) > 0 +} + +// NewToolExecutor creates the appropriate ToolExecutor based on the current mode. +// When natsClient is non-nil, returns a DistributedToolExecutor that routes through NATS. +// When natsClient is nil, creates local sessions and returns a LocalToolExecutor. +func NewToolExecutor(ctx context.Context, natsClient MCPNATSClient, modelName string, + remote config.MCPGenericConfig[config.MCPRemoteServers], + stdio config.MCPGenericConfig[config.MCPSTDIOServers], + enabledServers []string, +) ToolExecutor { + if natsClient != nil { + return NewDistributedToolExecutor(ctx, natsClient, modelName, remote, stdio) + } + sessions, err := NamedSessionsFromMCPConfig(modelName, remote, stdio, enabledServers) + if err != nil || len(sessions) == 0 { + if err != nil { + xlog.Error("Failed to create MCP sessions", "error", err) + } + return &LocalToolExecutor{} // empty, HasTools() returns false + } + return NewLocalToolExecutor(ctx, sessions) +} diff --git a/core/http/endpoints/mcp/tools.go b/core/http/endpoints/mcp/tools.go index fde990f01..4c94ae186 100644 --- a/core/http/endpoints/mcp/tools.go +++ b/core/http/endpoints/mcp/tools.go @@ -12,6 +12,10 @@ import ( "time" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + mcpRemote "github.com/mudler/LocalAI/core/services/mcp" + "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/signals" @@ -89,6 +93,11 @@ var ( client = mcp.NewClient(&mcp.Implementation{Name: "LocalAI", Version: "v1.0.0"}, nil) ) +// MCPNATSClient is the interface for NATS request-reply operations needed by MCP routing. +type MCPNATSClient interface { + Request(subject string, data []byte, timeout time.Duration) ([]byte, error) +} + // MCPServersFromMetadata extracts the MCP server list from the metadata map // and returns the list. The "mcp_servers" key is consumed (deleted from the map) // so it doesn't leak to the backend. @@ -114,6 +123,29 @@ func SessionsFromMCPConfig( defer cache.mu.Unlock() sessions, exists := cache.cache[name] + + // Verify cached sessions are still alive. + if exists { + pingCtx, pingCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pingCancel() + alive := true + for _, s := range sessions { + if err := s.Ping(pingCtx, nil); err != nil { + xlog.Warn("MCP session dead, evicting cache", "name", name, "error", err) + alive = false + break + } + } + if !alive { + if cancel, ok := cache.cancels[name]; ok { + cancel() + } + delete(cache.cache, name) + delete(cache.cancels, name) + exists = false + } + } + if exists { return sessions, nil } @@ -127,7 +159,7 @@ func SessionsFromMCPConfig( xlog.Debug("[MCP remote server] Configuration", "server", server) // Create HTTP client with custom roundtripper for bearer token injection httpClient := &http.Client{ - Timeout: 360 * time.Second, + Timeout: config.DefaultMCPToolTimeout, Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport), } @@ -177,13 +209,39 @@ func NamedSessionsFromMCPConfig( defer namedCache.mu.Unlock() allSessions, exists := namedCache.cache[name] + + // If cached, verify sessions are still alive via Ping. + // Dead sessions (e.g. exited stdio containers) are evicted so they get recreated. + if exists { + pingCtx, pingCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pingCancel() + alive := true + for _, ns := range allSessions { + if err := ns.Session.Ping(pingCtx, nil); err != nil { + xlog.Warn("MCP session dead, evicting cache", "server", ns.Name, "error", err) + alive = false + break + } + } + if !alive { + // Close dead sessions and recreate + if cancel, ok := namedCache.cancels[name]; ok { + cancel() + } + delete(namedCache.cache, name) + delete(namedCache.cancels, name) + exists = false + allSessions = nil + } + } + if !exists { ctx, cancel := context.WithCancel(context.Background()) for serverName, server := range remote.Servers { xlog.Debug("[MCP remote server] Configuration", "name", serverName, "server", server) httpClient := &http.Client{ - Timeout: 360 * time.Second, + Timeout: config.DefaultMCPToolTimeout, Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport), } @@ -266,20 +324,22 @@ func DiscoverMCPTools(ctx context.Context, sessions []NamedSession) ([]MCPToolIn Description: tool.Description, } - // Convert InputSchema to map[string]interface{} for functions.Function + // Convert InputSchema to map[string]any for functions.Function if tool.InputSchema != nil { schemaBytes, err := json.Marshal(tool.InputSchema) if err == nil { - var params map[string]interface{} - if json.Unmarshal(schemaBytes, ¶ms) == nil { + var params map[string]any + if err := json.Unmarshal(schemaBytes, ¶ms); err == nil { f.Parameters = params + } else { + xlog.Warn("Failed to unmarshal MCP tool input schema", "tool", tool.Name, "error", err) } } } if f.Parameters == nil { - f.Parameters = map[string]interface{}{ + f.Parameters = map[string]any{ "type": "object", - "properties": map[string]interface{}{}, + "properties": map[string]any{}, } } @@ -341,6 +401,86 @@ func ExecuteMCPToolCall(ctx context.Context, tools []MCPToolInfo, toolName strin return string(combined), nil } +// ExecuteMCPToolCallRemote routes an MCP tool execution request to an agent worker via NATS. +// Used in distributed mode when the frontend doesn't hold MCP sessions locally. +func ExecuteMCPToolCallRemote( + ctx context.Context, + natsClient MCPNATSClient, + modelName string, + remote config.MCPGenericConfig[config.MCPRemoteServers], + stdio config.MCPGenericConfig[config.MCPSTDIOServers], + toolName, arguments string, +) (string, error) { + if natsClient == nil { + return "", fmt.Errorf("NATS client not configured for distributed MCP") + } + + var args map[string]any + if arguments != "" { + if err := json.Unmarshal([]byte(arguments), &args); err != nil { + return "", fmt.Errorf("invalid tool arguments JSON: %w", err) + } + } + + req := mcpRemote.MCPToolRequest{ + ModelName: modelName, + ToolName: toolName, + Arguments: args, + RemoteServers: remote, + StdioServers: stdio, + } + reqData, _ := json.Marshal(req) + + replyData, err := natsClient.Request(messaging.SubjectMCPToolExecute, reqData, config.DefaultMCPToolTimeout) + if err != nil { + return "", fmt.Errorf("NATS MCP tool request failed: %w", err) + } + + var resp mcpRemote.MCPToolResponse + if err := json.Unmarshal(replyData, &resp); err != nil { + return "", fmt.Errorf("unmarshal MCP reply: %w", err) + } + if resp.Error != "" { + return "", fmt.Errorf("remote MCP tool error: %s", resp.Error) + } + return resp.Result, nil +} + +// DiscoverMCPToolsRemote routes an MCP discovery request to an agent worker via NATS. +// Returns server info and tool function schemas from the remote worker. +func DiscoverMCPToolsRemote( + ctx context.Context, + natsClient MCPNATSClient, + modelName string, + remote config.MCPGenericConfig[config.MCPRemoteServers], + stdio config.MCPGenericConfig[config.MCPSTDIOServers], +) (*mcpRemote.MCPDiscoveryResponse, error) { + if natsClient == nil { + return nil, fmt.Errorf("NATS client not configured for distributed MCP") + } + + req := mcpRemote.MCPDiscoveryRequest{ + ModelName: modelName, + RemoteServers: remote, + StdioServers: stdio, + } + reqData, _ := json.Marshal(req) + + replyData, err := natsClient.Request(messaging.SubjectMCPDiscovery, reqData, config.DefaultMCPDiscoveryTimeout) + if err != nil { + return nil, fmt.Errorf("NATS MCP discovery request failed: %w", err) + } + + var resp mcpRemote.MCPDiscoveryResponse + if err := json.Unmarshal(replyData, &resp); err != nil { + return nil, fmt.Errorf("unmarshal MCP discovery reply: %w", err) + } + if resp.Error != "" { + return nil, fmt.Errorf("remote MCP discovery error: %s", resp.Error) + } + return &resp, nil +} + // ListMCPServers returns server info with tool, prompt, and resource names for each session. func ListMCPServers(ctx context.Context, sessions []NamedSession) ([]MCPServerInfo, error) { var result []MCPServerInfo @@ -640,3 +780,93 @@ func newBearerTokenRoundTripper(token string, base http.RoundTripper) http.Round base: base, } } + +// MCPContextResult holds the results of MCP prompt and resource discovery +// so callers can inject them into their message slices. +type MCPContextResult struct { + // PromptMessages are schema.Message values converted from MCP prompts, + // intended to be prepended to the conversation. + PromptMessages []schema.Message + + // ResourceSuffix is the formatted text of all discovered MCP resources, + // intended to be appended to the last user message's content. + // Empty string when no resources were requested or found. + ResourceSuffix string +} + +// InjectMCPContext discovers MCP prompts and resources from the given named sessions +// and returns them in a form ready for injection into any endpoint's message list. +func InjectMCPContext( + ctx context.Context, + namedSessions []NamedSession, + mcpPromptName string, + mcpPromptArgs map[string]string, + mcpResourceURIs []string, +) (*MCPContextResult, error) { + result := &MCPContextResult{} + + if mcpPromptName != "" { + prompts, discErr := DiscoverMCPPrompts(ctx, namedSessions) + if discErr != nil { + xlog.Error("Failed to discover MCP prompts", "error", discErr) + } else { + promptMsgs, getErr := GetMCPPrompt(ctx, prompts, mcpPromptName, mcpPromptArgs) + if getErr != nil { + xlog.Error("Failed to get MCP prompt", "error", getErr) + } else { + for _, pm := range promptMsgs { + result.PromptMessages = append(result.PromptMessages, schema.Message{ + Role: string(pm.Role), + Content: PromptMessageToText(pm), + }) + } + xlog.Debug("MCP prompt discovered", "prompt", mcpPromptName, "messages", len(result.PromptMessages)) + } + } + } + + if len(mcpResourceURIs) > 0 { + resources, discErr := DiscoverMCPResources(ctx, namedSessions) + if discErr != nil { + xlog.Error("Failed to discover MCP resources", "error", discErr) + } else { + var resourceTexts []string + for _, uri := range mcpResourceURIs { + content, readErr := ReadMCPResource(ctx, resources, uri) + if readErr != nil { + xlog.Error("Failed to read MCP resource", "error", readErr, "uri", uri) + continue + } + name := uri + for _, r := range resources { + if r.URI == uri { + name = r.Name + break + } + } + resourceTexts = append(resourceTexts, fmt.Sprintf("--- MCP Resource: %s ---\n%s", name, content)) + } + if len(resourceTexts) > 0 { + result.ResourceSuffix = "\n\n" + strings.Join(resourceTexts, "\n\n") + xlog.Debug("MCP resources discovered", "count", len(resourceTexts)) + } + } + } + + return result, nil +} + +// AppendResourceSuffix appends the resource suffix from an MCPContextResult +// to the last message's content in the given message slice. +func AppendResourceSuffix(messages []schema.Message, suffix string) { + if suffix == "" || len(messages) == 0 { + return + } + lastIdx := len(messages) - 1 + switch ct := messages[lastIdx].Content.(type) { + case string: + messages[lastIdx].Content = ct + suffix + default: + messages[lastIdx].Content = fmt.Sprintf("%v%s", ct, suffix) + } +} diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 871084054..eb3a92a77 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -3,7 +3,6 @@ package openai import ( "encoding/json" "fmt" - "strings" "time" "github.com/google/uuid" @@ -59,11 +58,8 @@ func mergeToolCallDeltas(existing []schema.ToolCall, deltas []schema.ToolCall) [ // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/chat/completions [post] -func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) echo.HandlerFunc { - var id, textContentToReturn string - var created int - - process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error { +func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient) echo.HandlerFunc { + process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool, id string, created int) error { initialMessage := schema.OpenAIResponse{ ID: id, Created: created, @@ -120,7 +116,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator close(responses) return err } - processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error { + processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool, id string, created int, textContentToReturn *string) error { // Detect if thinking token is already in prompt or template var template string if config.TemplateConfig.UseTokenizerTemplate { @@ -309,18 +305,18 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator xlog.Debug("[ChatDeltas] Using pre-parsed tool calls from C++ autoparser", "count", len(deltaToolCalls)) functionResults = deltaToolCalls // Use content/reasoning from deltas too - textContentToReturn = functions.ContentFromChatDeltas(chatDeltas) + *textContentToReturn = functions.ContentFromChatDeltas(chatDeltas) reasoning = functions.ReasoningFromChatDeltas(chatDeltas) } else { // Fallback: parse tool calls from raw text (no chat deltas from backend) xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing") reasoning = extractor.Reasoning() cleanedResult := extractor.CleanedContent() - textContentToReturn = functions.ParseTextContent(cleanedResult, config.FunctionsConfig) + *textContentToReturn = functions.ParseTextContent(cleanedResult, config.FunctionsConfig) cleanedResult = functions.CleanupLLMResult(cleanedResult, config.FunctionsConfig) functionResults = functions.ParseFunctionCall(cleanedResult, config.FunctionsConfig) } - xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", textContentToReturn) + xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", *textContentToReturn) noActionToRun := len(functionResults) > 0 && functionResults[0].Name == noAction || len(functionResults) == 0 switch { @@ -413,7 +409,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator Choices: []schema.Choice{{ Delta: &schema.Message{ Role: "assistant", - Content: &textContentToReturn, + Content: textContentToReturn, ToolCalls: []schema.ToolCall{ { Index: i, @@ -438,9 +434,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator } return func(c echo.Context) error { - textContentToReturn = "" - id = uuid.New().String() - created = int(time.Now().Unix()) + var textContentToReturn string + id := uuid.New().String() + created := int(time.Now().Unix()) input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { @@ -461,7 +457,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator strictMode := false // MCP tool injection: when mcp_servers is set in metadata and model has MCP config - var mcpToolInfos []mcpTools.MCPToolInfo + var mcpExecutor mcpTools.ToolExecutor mcpServers := mcpTools.MCPServersFromMetadata(input.Metadata) // MCP prompt and resource injection (extracted before tool injection) @@ -471,82 +467,30 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator if (len(mcpServers) > 0 || mcpPromptName != "" || len(mcpResourceURIs) > 0) && (config.MCP.Servers != "" || config.MCP.Stdio != "") { remote, stdio, mcpErr := config.MCP.MCPConfigFromYAML() if mcpErr == nil { + mcpExecutor = mcpTools.NewToolExecutor(c.Request().Context(), natsClient, config.Name, remote, stdio, mcpServers) + + // Prompt and resource injection (pre-processing step — resolves locally regardless of distributed mode) namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(config.Name, remote, stdio, mcpServers) if sessErr == nil && len(namedSessions) > 0 { - // Prompt injection: prepend prompt messages to the conversation - if mcpPromptName != "" { - prompts, discErr := mcpTools.DiscoverMCPPrompts(c.Request().Context(), namedSessions) - if discErr == nil { - promptMsgs, getErr := mcpTools.GetMCPPrompt(c.Request().Context(), prompts, mcpPromptName, mcpPromptArgs) - if getErr == nil { - var injected []schema.Message - for _, pm := range promptMsgs { - injected = append(injected, schema.Message{ - Role: string(pm.Role), - Content: mcpTools.PromptMessageToText(pm), - }) - } - input.Messages = append(injected, input.Messages...) - xlog.Debug("MCP prompt injected", "prompt", mcpPromptName, "messages", len(injected)) - } else { - xlog.Error("Failed to get MCP prompt", "error", getErr) - } - } else { - xlog.Error("Failed to discover MCP prompts", "error", discErr) - } + mcpCtx, _ := mcpTools.InjectMCPContext(c.Request().Context(), namedSessions, mcpPromptName, mcpPromptArgs, mcpResourceURIs) + if mcpCtx != nil { + input.Messages = append(mcpCtx.PromptMessages, input.Messages...) + mcpTools.AppendResourceSuffix(input.Messages, mcpCtx.ResourceSuffix) } + } - // Resource injection: append resource content to the last user message - if len(mcpResourceURIs) > 0 { - resources, discErr := mcpTools.DiscoverMCPResources(c.Request().Context(), namedSessions) - if discErr == nil { - var resourceTexts []string - for _, uri := range mcpResourceURIs { - content, readErr := mcpTools.ReadMCPResource(c.Request().Context(), resources, uri) - if readErr != nil { - xlog.Error("Failed to read MCP resource", "error", readErr, "uri", uri) - continue - } - // Find resource name - name := uri - for _, r := range resources { - if r.URI == uri { - name = r.Name - break - } - } - resourceTexts = append(resourceTexts, fmt.Sprintf("--- MCP Resource: %s ---\n%s", name, content)) - } - if len(resourceTexts) > 0 && len(input.Messages) > 0 { - lastIdx := len(input.Messages) - 1 - suffix := "\n\n" + strings.Join(resourceTexts, "\n\n") - switch ct := input.Messages[lastIdx].Content.(type) { - case string: - input.Messages[lastIdx].Content = ct + suffix - default: - input.Messages[lastIdx].Content = fmt.Sprintf("%v%s", ct, suffix) - } - xlog.Debug("MCP resources injected", "count", len(resourceTexts)) - } - } else { - xlog.Error("Failed to discover MCP resources", "error", discErr) - } - } - - // Tool injection - if len(mcpServers) > 0 { - discovered, discErr := mcpTools.DiscoverMCPTools(c.Request().Context(), namedSessions) - if discErr == nil { - mcpToolInfos = discovered - for _, ti := range mcpToolInfos { - funcs = append(funcs, ti.Function) - input.Tools = append(input.Tools, functions.Tool{Type: "function", Function: ti.Function}) - } - shouldUseFn = len(funcs) > 0 && config.ShouldUseFunctions() - xlog.Debug("MCP tools injected", "count", len(mcpToolInfos), "total_funcs", len(funcs)) - } else { - xlog.Error("Failed to discover MCP tools", "error", discErr) + // Tool injection via executor + if mcpExecutor.HasTools() { + mcpFuncs, discErr := mcpExecutor.DiscoverTools(c.Request().Context()) + if discErr == nil { + for _, fn := range mcpFuncs { + funcs = append(funcs, fn) + input.Tools = append(input.Tools, functions.Tool{Type: "function", Function: fn}) } + shouldUseFn = len(funcs) > 0 && config.ShouldUseFunctions() + xlog.Debug("MCP tools injected", "count", len(mcpFuncs), "total_funcs", len(funcs)) + } else { + xlog.Error("Failed to discover MCP tools", "error", discErr) } } } else { @@ -630,9 +574,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator noActionGrammar := functions.Function{ Name: noActionName, Description: noActionDescription, - Parameters: map[string]interface{}{ - "properties": map[string]interface{}{ - "message": map[string]interface{}{ + Parameters: map[string]any{ + "properties": map[string]any{ + "message": map[string]any{ "type": "string", "description": "The message to reply the user with", }}, @@ -705,223 +649,228 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator if config.Agent.MaxIterations > 0 { mcpStreamMaxIterations = config.Agent.MaxIterations } - hasMCPToolsStream := len(mcpToolInfos) > 0 + hasMCPToolsStream := mcpExecutor != nil && mcpExecutor.HasTools() for mcpStreamIter := 0; mcpStreamIter <= mcpStreamMaxIterations; mcpStreamIter++ { - // Re-template on MCP iterations - if mcpStreamIter > 0 && !config.TemplateConfig.UseTokenizerTemplate { - predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn) - xlog.Debug("MCP stream re-templating", "iteration", mcpStreamIter) - } - - responses := make(chan schema.OpenAIResponse) - ended := make(chan error, 1) - - go func() { - if !shouldUseFn { - ended <- process(predInput, input, config, ml, responses, extraUsage) - } else { - ended <- processTools(noActionName, predInput, input, config, ml, responses, extraUsage) + // Re-template on MCP iterations + if mcpStreamIter > 0 && !config.TemplateConfig.UseTokenizerTemplate { + predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn) + xlog.Debug("MCP stream re-templating", "iteration", mcpStreamIter) } - }() - usage := &schema.OpenAIUsage{} - toolsCalled := false - var collectedToolCalls []schema.ToolCall - var collectedContent string + responses := make(chan schema.OpenAIResponse) + ended := make(chan error, 1) - LOOP: - for { - select { - case <-input.Context.Done(): - // Context was cancelled (client disconnected or request cancelled) - xlog.Debug("Request context cancelled, stopping stream") - input.Cancel() - break LOOP - case ev := <-responses: - if len(ev.Choices) == 0 { - xlog.Debug("No choices in the response, skipping") - continue - } - usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it - if len(ev.Choices[0].Delta.ToolCalls) > 0 { - toolsCalled = true - // Collect and merge tool call deltas for MCP execution - if hasMCPToolsStream { - collectedToolCalls = mergeToolCallDeltas(collectedToolCalls, ev.Choices[0].Delta.ToolCalls) - } - } - // Collect content for MCP conversation history and automatic tool parsing fallback - if (hasMCPToolsStream || config.FunctionsConfig.AutomaticToolParsingFallback) && ev.Choices[0].Delta != nil && ev.Choices[0].Delta.Content != nil { - if s, ok := ev.Choices[0].Delta.Content.(string); ok { - collectedContent += s - } else if sp, ok := ev.Choices[0].Delta.Content.(*string); ok && sp != nil { - collectedContent += *sp - } - } - respData, err := json.Marshal(ev) - if err != nil { - xlog.Debug("Failed to marshal response", "error", err) - input.Cancel() - continue - } - xlog.Debug("Sending chunk", "chunk", string(respData)) - _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData)) - if err != nil { - xlog.Debug("Sending chunk failed", "error", err) - input.Cancel() - return err - } - c.Response().Flush() - case err := <-ended: - if err == nil { - break LOOP - } - xlog.Error("Stream ended with error", "error", err) - - errorResp := schema.ErrorResponse{ - Error: &schema.APIError{ - Message: err.Error(), - Type: "server_error", - Code: "server_error", - }, - } - respData, marshalErr := json.Marshal(errorResp) - if marshalErr != nil { - xlog.Error("Failed to marshal error response", "error", marshalErr) - fmt.Fprintf(c.Response().Writer, "data: {\"error\":{\"message\":\"Internal error\",\"type\":\"server_error\"}}\n\n") + go func() { + if !shouldUseFn { + ended <- process(predInput, input, config, ml, responses, extraUsage, id, created) } else { - fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) + ended <- processTools(noActionName, predInput, input, config, ml, responses, extraUsage, id, created, &textContentToReturn) } - fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") - c.Response().Flush() + }() - return nil - } - } + usage := &schema.OpenAIUsage{} + toolsCalled := false + var collectedToolCalls []schema.ToolCall + var collectedContent string - // MCP streaming tool execution: if we collected MCP tool calls, execute and loop - if hasMCPToolsStream && toolsCalled && len(collectedToolCalls) > 0 { - var hasMCPCalls bool - for _, tc := range collectedToolCalls { - if mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { - hasMCPCalls = true - break - } - } - if hasMCPCalls { - // Append assistant message with tool_calls - assistantMsg := schema.Message{ - Role: "assistant", - Content: collectedContent, - ToolCalls: collectedToolCalls, - } - input.Messages = append(input.Messages, assistantMsg) - - // Execute MCP tool calls and stream results as tool_result events - for _, tc := range collectedToolCalls { - if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { + LOOP: + for { + select { + case <-input.Context.Done(): + // Context was cancelled (client disconnected or request cancelled) + xlog.Debug("Request context cancelled, stopping stream") + input.Cancel() + break LOOP + case ev := <-responses: + if len(ev.Choices) == 0 { + xlog.Debug("No choices in the response, skipping") continue } - xlog.Debug("Executing MCP tool (stream)", "tool", tc.FunctionCall.Name, "iteration", mcpStreamIter) - toolResult, toolErr := mcpTools.ExecuteMCPToolCall( - c.Request().Context(), mcpToolInfos, - tc.FunctionCall.Name, tc.FunctionCall.Arguments, - ) - if toolErr != nil { - xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) - toolResult = fmt.Sprintf("Error: %v", toolErr) + usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it + if len(ev.Choices[0].Delta.ToolCalls) > 0 { + toolsCalled = true + // Collect and merge tool call deltas for MCP execution + if hasMCPToolsStream { + collectedToolCalls = mergeToolCallDeltas(collectedToolCalls, ev.Choices[0].Delta.ToolCalls) + } } - input.Messages = append(input.Messages, schema.Message{ - Role: "tool", - Content: toolResult, - StringContent: toolResult, - ToolCallID: tc.ID, - Name: tc.FunctionCall.Name, - }) - - // Stream tool result event to client - mcpEvent := map[string]any{ - "type": "mcp_tool_result", - "name": tc.FunctionCall.Name, - "result": toolResult, + // Collect content for MCP conversation history and automatic tool parsing fallback + if (hasMCPToolsStream || config.FunctionsConfig.AutomaticToolParsingFallback) && ev.Choices[0].Delta != nil && ev.Choices[0].Delta.Content != nil { + if s, ok := ev.Choices[0].Delta.Content.(string); ok { + collectedContent += s + } else if sp, ok := ev.Choices[0].Delta.Content.(*string); ok && sp != nil { + collectedContent += *sp + } } - if mcpEventData, err := json.Marshal(mcpEvent); err == nil { - fmt.Fprintf(c.Response().Writer, "data: %s\n\n", mcpEventData) - c.Response().Flush() + respData, err := json.Marshal(ev) + if err != nil { + xlog.Debug("Failed to marshal response", "error", err) + input.Cancel() + continue } - } + xlog.Debug("Sending chunk", "chunk", string(respData)) + _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData)) + if err != nil { + xlog.Debug("Sending chunk failed", "error", err) + input.Cancel() + return err + } + c.Response().Flush() + case err := <-ended: + if err == nil { + break LOOP + } + xlog.Error("Stream ended with error", "error", err) - xlog.Debug("MCP streaming tools executed, re-running inference", "iteration", mcpStreamIter) - continue // next MCP stream iteration - } - } - - // Automatic tool parsing fallback for streaming: when no tools were - // requested but the model emitted tool call markup, parse and emit them. - if !shouldUseFn && config.FunctionsConfig.AutomaticToolParsingFallback && collectedContent != "" && !toolsCalled { - parsed := functions.ParseFunctionCall(collectedContent, config.FunctionsConfig) - for i, fc := range parsed { - toolCallID := fc.ID - if toolCallID == "" { - toolCallID = id - } - toolCallMsg := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, - Choices: []schema.Choice{{ - Delta: &schema.Message{ - Role: "assistant", - ToolCalls: []schema.ToolCall{{ - Index: i, - ID: toolCallID, - Type: "function", - FunctionCall: schema.FunctionCall{ - Name: fc.Name, - Arguments: fc.Arguments, - }, - }}, + errorResp := schema.ErrorResponse{ + Error: &schema.APIError{ + Message: err.Error(), + Type: "server_error", + Code: "server_error", }, - Index: 0, - }}, - Object: "chat.completion.chunk", + } + respData, marshalErr := json.Marshal(errorResp) + if marshalErr != nil { + xlog.Error("Failed to marshal error response", "error", marshalErr) + fmt.Fprintf(c.Response().Writer, "data: {\"error\":{\"message\":\"Internal error\",\"type\":\"server_error\"}}\n\n") + } else { + fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) + } + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + + return nil } - respData, _ := json.Marshal(toolCallMsg) - fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) - c.Response().Flush() - toolsCalled = true } - } - // No MCP tools to execute, send final stop message - finishReason := FinishReasonStop - if toolsCalled && len(input.Tools) > 0 { - finishReason = FinishReasonToolCalls - } else if toolsCalled { - finishReason = FinishReasonFunctionCall - } + // Drain responses channel to unblock the background goroutine if it's + // still trying to send (e.g., after client disconnect). The goroutine + // calls close(responses) when done, which terminates the drain. + if input.Context.Err() != nil { + go func() { for range responses {} }() + <-ended + } - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{ - { - FinishReason: &finishReason, - Index: 0, - Delta: &schema.Message{}, - }}, - Object: "chat.completion.chunk", - Usage: *usage, - } - respData, _ := json.Marshal(resp) + // MCP streaming tool execution: if we collected MCP tool calls, execute and loop + if hasMCPToolsStream && toolsCalled && len(collectedToolCalls) > 0 { + var hasMCPCalls bool + for _, tc := range collectedToolCalls { + if mcpExecutor != nil && mcpExecutor.IsTool(tc.FunctionCall.Name) { + hasMCPCalls = true + break + } + } + if hasMCPCalls { + // Append assistant message with tool_calls + assistantMsg := schema.Message{ + Role: "assistant", + Content: collectedContent, + ToolCalls: collectedToolCalls, + } + input.Messages = append(input.Messages, assistantMsg) - fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) - fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") - c.Response().Flush() - xlog.Debug("Stream ended") - return nil + // Execute MCP tool calls and stream results as tool_result events + for _, tc := range collectedToolCalls { + if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) { + continue + } + xlog.Debug("Executing MCP tool (stream)", "tool", tc.FunctionCall.Name, "iteration", mcpStreamIter) + toolResult, toolErr := mcpExecutor.ExecuteTool(c.Request().Context(), tc.FunctionCall.Name, tc.FunctionCall.Arguments) + if toolErr != nil { + xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) + toolResult = fmt.Sprintf("Error: %v", toolErr) + } + input.Messages = append(input.Messages, schema.Message{ + Role: "tool", + Content: toolResult, + StringContent: toolResult, + ToolCallID: tc.ID, + Name: tc.FunctionCall.Name, + }) + + // Stream tool result event to client + mcpEvent := map[string]any{ + "type": "mcp_tool_result", + "name": tc.FunctionCall.Name, + "result": toolResult, + } + if mcpEventData, err := json.Marshal(mcpEvent); err == nil { + fmt.Fprintf(c.Response().Writer, "data: %s\n\n", mcpEventData) + c.Response().Flush() + } + } + + xlog.Debug("MCP streaming tools executed, re-running inference", "iteration", mcpStreamIter) + continue // next MCP stream iteration + } + } + + // Automatic tool parsing fallback for streaming: when no tools were + // requested but the model emitted tool call markup, parse and emit them. + if !shouldUseFn && config.FunctionsConfig.AutomaticToolParsingFallback && collectedContent != "" && !toolsCalled { + parsed := functions.ParseFunctionCall(collectedContent, config.FunctionsConfig) + for i, fc := range parsed { + toolCallID := fc.ID + if toolCallID == "" { + toolCallID = id + } + toolCallMsg := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, + Choices: []schema.Choice{{ + Delta: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{{ + Index: i, + ID: toolCallID, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: fc.Name, + Arguments: fc.Arguments, + }, + }}, + }, + Index: 0, + }}, + Object: "chat.completion.chunk", + } + respData, _ := json.Marshal(toolCallMsg) + fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) + c.Response().Flush() + toolsCalled = true + } + } + + // No MCP tools to execute, send final stop message + finishReason := FinishReasonStop + if toolsCalled && len(input.Tools) > 0 { + finishReason = FinishReasonToolCalls + } else if toolsCalled { + finishReason = FinishReasonFunctionCall + } + + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{ + { + FinishReason: &finishReason, + Index: 0, + Delta: &schema.Message{}, + }}, + Object: "chat.completion.chunk", + Usage: *usage, + } + respData, _ := json.Marshal(resp) + + fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + xlog.Debug("Stream ended") + return nil } // end MCP stream iteration loop // Safety fallback @@ -935,322 +884,319 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator if config.Agent.MaxIterations > 0 { mcpMaxIterations = config.Agent.MaxIterations } - hasMCPTools := len(mcpToolInfos) > 0 + hasMCPTools := mcpExecutor != nil && mcpExecutor.HasTools() for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ { - // Re-template on each MCP iteration since messages may have changed - if mcpIteration > 0 && !config.TemplateConfig.UseTokenizerTemplate { - predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn) - xlog.Debug("MCP re-templating", "iteration", mcpIteration, "prompt_len", len(predInput)) - } - - // Detect if thinking token is already in prompt or template - var template string - if config.TemplateConfig.UseTokenizerTemplate { - template = config.GetModelTemplate() // TODO: this should be the parsed jinja template. But for now this is the best we can do. - } else { - template = predInput - } - thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig) - - xlog.Debug("Thinking start token", "thinkingStartToken", thinkingStartToken, "template", template) - - // When shouldUseFn, the callback just stores the raw text — tool parsing - // is deferred to after ComputeChoices so we can check chat deltas first - // and avoid redundant Go-side parsing. - var cbRawResult, cbReasoning string - - tokenCallback := func(s string, c *[]schema.Choice) { - reasoning, s := reason.ExtractReasoningWithConfig(s, thinkingStartToken, config.ReasoningConfig) - - if !shouldUseFn { - stopReason := FinishReasonStop - message := &schema.Message{Role: "assistant", Content: &s} - if reasoning != "" { - message.Reasoning = &reasoning - } - *c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: message}) - return + // Re-template on each MCP iteration since messages may have changed + if mcpIteration > 0 && !config.TemplateConfig.UseTokenizerTemplate { + predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn) + xlog.Debug("MCP re-templating", "iteration", mcpIteration, "prompt_len", len(predInput)) } - // Store raw text for deferred tool parsing - cbRawResult = s - cbReasoning = reasoning - } - - var result []schema.Choice - var tokenUsage backend.TokenUsage - var err error - - var chatDeltas []*pb.ChatDelta - result, tokenUsage, chatDeltas, err = ComputeChoices( - input, - predInput, - config, - cl, - startupOptions, - ml, - tokenCallback, - nil, - func(attempt int) bool { - if !shouldUseFn { - return false - } - // Retry when backend produced only reasoning and no content/tool calls. - // Full tool parsing is deferred until after ComputeChoices returns - // (when chat deltas are available), but we can detect the empty case here. - if cbRawResult == "" && textContentToReturn == "" { - xlog.Warn("Backend produced reasoning without actionable content, retrying", - "reasoning_len", len(cbReasoning), "attempt", attempt+1) - cbRawResult = "" - cbReasoning = "" - textContentToReturn = "" - return true - } - return false - }, - ) - if err != nil { - return err - } - - // Tool parsing is deferred here (only when shouldUseFn) so chat deltas are available - if shouldUseFn { - var funcResults []functions.FuncCallResults - - // Try pre-parsed tool calls from C++ autoparser first - if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { - xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser tool calls, skipping Go-side parsing", "count", len(deltaToolCalls)) - funcResults = deltaToolCalls - textContentToReturn = functions.ContentFromChatDeltas(chatDeltas) - cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas) + // Detect if thinking token is already in prompt or template + var template string + if config.TemplateConfig.UseTokenizerTemplate { + template = config.GetModelTemplate() // TODO: this should be the parsed jinja template. But for now this is the best we can do. } else { - // Fallback: parse tool calls from raw text - xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing") - textContentToReturn = functions.ParseTextContent(cbRawResult, config.FunctionsConfig) - cbRawResult = functions.CleanupLLMResult(cbRawResult, config.FunctionsConfig) - funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig) + template = predInput } + thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig) - // Content-based tool call fallback: if no tool calls were found, - // try parsing the raw result — ParseFunctionCall handles detection internally. - if len(funcResults) == 0 { - contentFuncResults := functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig) - if len(contentFuncResults) > 0 { - funcResults = contentFuncResults - textContentToReturn = functions.StripToolCallMarkup(cbRawResult) - } - } + xlog.Debug("Thinking start token", "thinkingStartToken", thinkingStartToken, "template", template) - noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0 + // When shouldUseFn, the callback just stores the raw text — tool parsing + // is deferred to after ComputeChoices so we can check chat deltas first + // and avoid redundant Go-side parsing. + var cbRawResult, cbReasoning string - switch { - case noActionsToRun: - qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput) - if qErr != nil { - xlog.Error("error handling question", "error", qErr) - } + tokenCallback := func(s string, c *[]schema.Choice) { + reasoning, s := reason.ExtractReasoningWithConfig(s, thinkingStartToken, config.ReasoningConfig) - stopReason := FinishReasonStop - message := &schema.Message{Role: "assistant", Content: &qResult} - if cbReasoning != "" { - message.Reasoning = &cbReasoning - } - result = append(result, schema.Choice{ - FinishReason: &stopReason, - Message: message, - }) - default: - toolCallsReason := FinishReasonToolCalls - toolChoice := schema.Choice{ - FinishReason: &toolCallsReason, - Message: &schema.Message{ - Role: "assistant", - }, - } - if cbReasoning != "" { - toolChoice.Message.Reasoning = &cbReasoning - } - - for _, ss := range funcResults { - name, args := ss.Name, ss.Arguments - toolCallID := ss.ID - if toolCallID == "" { - toolCallID = id + if !shouldUseFn { + stopReason := FinishReasonStop + message := &schema.Message{Role: "assistant", Content: &s} + if reasoning != "" { + message.Reasoning = &reasoning } + *c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: message}) + return + } + + // Store raw text for deferred tool parsing + cbRawResult = s + cbReasoning = reasoning + } + + var result []schema.Choice + var tokenUsage backend.TokenUsage + var err error + + var chatDeltas []*pb.ChatDelta + result, tokenUsage, chatDeltas, err = ComputeChoices( + input, + predInput, + config, + cl, + startupOptions, + ml, + tokenCallback, + nil, + func(attempt int) bool { + if !shouldUseFn { + return false + } + // Retry when backend produced only reasoning and no content/tool calls. + // Full tool parsing is deferred until after ComputeChoices returns + // (when chat deltas are available), but we can detect the empty case here. + if cbRawResult == "" && textContentToReturn == "" { + xlog.Warn("Backend produced reasoning without actionable content, retrying", + "reasoning_len", len(cbReasoning), "attempt", attempt+1) + cbRawResult = "" + cbReasoning = "" + textContentToReturn = "" + return true + } + return false + }, + ) + if err != nil { + return err + } + + // Tool parsing is deferred here (only when shouldUseFn) so chat deltas are available + if shouldUseFn { + var funcResults []functions.FuncCallResults + + // Try pre-parsed tool calls from C++ autoparser first + if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { + xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser tool calls, skipping Go-side parsing", "count", len(deltaToolCalls)) + funcResults = deltaToolCalls + textContentToReturn = functions.ContentFromChatDeltas(chatDeltas) + cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas) + } else { + // Fallback: parse tool calls from raw text + xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing") + textContentToReturn = functions.ParseTextContent(cbRawResult, config.FunctionsConfig) + cbRawResult = functions.CleanupLLMResult(cbRawResult, config.FunctionsConfig) + funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig) + } + + // Content-based tool call fallback: if no tool calls were found, + // try parsing the raw result — ParseFunctionCall handles detection internally. + if len(funcResults) == 0 { + contentFuncResults := functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig) + if len(contentFuncResults) > 0 { + funcResults = contentFuncResults + textContentToReturn = functions.StripToolCallMarkup(cbRawResult) + } + } + + noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0 + + switch { + case noActionsToRun: + qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput) + if qErr != nil { + xlog.Error("error handling question", "error", qErr) + } + + stopReason := FinishReasonStop + message := &schema.Message{Role: "assistant", Content: &qResult} + if cbReasoning != "" { + message.Reasoning = &cbReasoning + } + result = append(result, schema.Choice{ + FinishReason: &stopReason, + Message: message, + }) + default: + toolCallsReason := FinishReasonToolCalls + toolChoice := schema.Choice{ + FinishReason: &toolCallsReason, + Message: &schema.Message{ + Role: "assistant", + }, + } + if cbReasoning != "" { + toolChoice.Message.Reasoning = &cbReasoning + } + + for _, ss := range funcResults { + name, args := ss.Name, ss.Arguments + toolCallID := ss.ID + if toolCallID == "" { + toolCallID = id + } + if len(input.Tools) > 0 { + toolChoice.Message.Content = textContentToReturn + toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls, + schema.ToolCall{ + ID: toolCallID, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + Arguments: args, + }, + }, + ) + } else { + // Deprecated function_call format + functionCallReason := FinishReasonFunctionCall + message := &schema.Message{ + Role: "assistant", + Content: &textContentToReturn, + FunctionCall: map[string]any{ + "name": name, + "arguments": args, + }, + } + if cbReasoning != "" { + message.Reasoning = &cbReasoning + } + result = append(result, schema.Choice{ + FinishReason: &functionCallReason, + Message: message, + }) + } + } + if len(input.Tools) > 0 { - toolChoice.Message.Content = textContentToReturn - toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls, + result = append(result, toolChoice) + } + } + } + + // Automatic tool parsing fallback: when no tools/functions were in the + // request but the model emitted tool call markup, parse and surface them. + if !shouldUseFn && config.FunctionsConfig.AutomaticToolParsingFallback && len(result) > 0 { + for i, choice := range result { + if choice.Message == nil || choice.Message.Content == nil { + continue + } + contentStr, ok := choice.Message.Content.(string) + if !ok || contentStr == "" { + continue + } + parsed := functions.ParseFunctionCall(contentStr, config.FunctionsConfig) + if len(parsed) == 0 { + continue + } + stripped := functions.StripToolCallMarkup(contentStr) + toolCallsReason := FinishReasonToolCalls + result[i].FinishReason = &toolCallsReason + if stripped != "" { + result[i].Message.Content = &stripped + } else { + result[i].Message.Content = nil + } + for _, fc := range parsed { + toolCallID := fc.ID + if toolCallID == "" { + toolCallID = id + } + result[i].Message.ToolCalls = append(result[i].Message.ToolCalls, schema.ToolCall{ ID: toolCallID, Type: "function", FunctionCall: schema.FunctionCall{ - Name: name, - Arguments: args, + Name: fc.Name, + Arguments: fc.Arguments, }, }, ) - } else { - // Deprecated function_call format - functionCallReason := FinishReasonFunctionCall - message := &schema.Message{ - Role: "assistant", - Content: &textContentToReturn, - FunctionCall: map[string]interface{}{ - "name": name, - "arguments": args, - }, - } - if cbReasoning != "" { - message.Reasoning = &cbReasoning - } - result = append(result, schema.Choice{ - FinishReason: &functionCallReason, - Message: message, - }) } } - - if len(input.Tools) > 0 { - result = append(result, toolChoice) - } } - } - // Automatic tool parsing fallback: when no tools/functions were in the - // request but the model emitted tool call markup, parse and surface them. - if !shouldUseFn && config.FunctionsConfig.AutomaticToolParsingFallback && len(result) > 0 { - for i, choice := range result { - if choice.Message == nil || choice.Message.Content == nil { - continue - } - contentStr, ok := choice.Message.Content.(string) - if !ok || contentStr == "" { - continue - } - parsed := functions.ParseFunctionCall(contentStr, config.FunctionsConfig) - if len(parsed) == 0 { - continue - } - stripped := functions.StripToolCallMarkup(contentStr) - toolCallsReason := FinishReasonToolCalls - result[i].FinishReason = &toolCallsReason - if stripped != "" { - result[i].Message.Content = &stripped - } else { - result[i].Message.Content = nil - } - for _, fc := range parsed { - toolCallID := fc.ID - if toolCallID == "" { - toolCallID = id - } - result[i].Message.ToolCalls = append(result[i].Message.ToolCalls, - schema.ToolCall{ - ID: toolCallID, - Type: "function", - FunctionCall: schema.FunctionCall{ - Name: fc.Name, - Arguments: fc.Arguments, - }, - }, - ) - } - } - } - - // MCP server-side tool execution loop: - // If we have MCP tools and the model returned tool_calls, execute MCP tools - // and re-run inference with the results appended to the conversation. - if hasMCPTools && len(result) > 0 { - var mcpCallsExecuted bool - for _, choice := range result { - if choice.Message == nil || len(choice.Message.ToolCalls) == 0 { - continue - } - // Check if any tool calls are MCP tools - var hasMCPCalls bool - for _, tc := range choice.Message.ToolCalls { - if mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { - hasMCPCalls = true - break - } - } - if !hasMCPCalls { - continue - } - - // Append assistant message with tool_calls to conversation - assistantContent := "" - if choice.Message.Content != nil { - if s, ok := choice.Message.Content.(string); ok { - assistantContent = s - } else if sp, ok := choice.Message.Content.(*string); ok && sp != nil { - assistantContent = *sp - } - } - assistantMsg := schema.Message{ - Role: "assistant", - Content: assistantContent, - ToolCalls: choice.Message.ToolCalls, - } - input.Messages = append(input.Messages, assistantMsg) - - // Execute each MCP tool call and append results - for _, tc := range choice.Message.ToolCalls { - if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { + // MCP server-side tool execution loop: + // If we have MCP tools and the model returned tool_calls, execute MCP tools + // and re-run inference with the results appended to the conversation. + if hasMCPTools && len(result) > 0 { + var mcpCallsExecuted bool + for _, choice := range result { + if choice.Message == nil || len(choice.Message.ToolCalls) == 0 { continue } - xlog.Debug("Executing MCP tool", "tool", tc.FunctionCall.Name, "arguments", tc.FunctionCall.Arguments, "iteration", mcpIteration) - toolResult, toolErr := mcpTools.ExecuteMCPToolCall( - c.Request().Context(), mcpToolInfos, - tc.FunctionCall.Name, tc.FunctionCall.Arguments, - ) - if toolErr != nil { - xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) - toolResult = fmt.Sprintf("Error: %v", toolErr) + // Check if any tool calls are MCP tools + var hasMCPCalls bool + for _, tc := range choice.Message.ToolCalls { + if mcpExecutor != nil && mcpExecutor.IsTool(tc.FunctionCall.Name) { + hasMCPCalls = true + break + } } - input.Messages = append(input.Messages, schema.Message{ - Role: "tool", - Content: toolResult, - StringContent: toolResult, - ToolCallID: tc.ID, - Name: tc.FunctionCall.Name, - }) - mcpCallsExecuted = true + if !hasMCPCalls { + continue + } + + // Append assistant message with tool_calls to conversation + assistantContent := "" + if choice.Message.Content != nil { + if s, ok := choice.Message.Content.(string); ok { + assistantContent = s + } else if sp, ok := choice.Message.Content.(*string); ok && sp != nil { + assistantContent = *sp + } + } + assistantMsg := schema.Message{ + Role: "assistant", + Content: assistantContent, + ToolCalls: choice.Message.ToolCalls, + } + input.Messages = append(input.Messages, assistantMsg) + + // Execute each MCP tool call and append results + for _, tc := range choice.Message.ToolCalls { + if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) { + continue + } + xlog.Debug("Executing MCP tool", "tool", tc.FunctionCall.Name, "arguments", tc.FunctionCall.Arguments, "iteration", mcpIteration) + toolResult, toolErr := mcpExecutor.ExecuteTool(c.Request().Context(), tc.FunctionCall.Name, tc.FunctionCall.Arguments) + if toolErr != nil { + xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) + toolResult = fmt.Sprintf("Error: %v", toolErr) + } + input.Messages = append(input.Messages, schema.Message{ + Role: "tool", + Content: toolResult, + StringContent: toolResult, + ToolCallID: tc.ID, + Name: tc.FunctionCall.Name, + }) + mcpCallsExecuted = true + } + } + + if mcpCallsExecuted { + xlog.Debug("MCP tools executed, re-running inference", "iteration", mcpIteration, "messages", len(input.Messages)) + continue // next MCP iteration } } - if mcpCallsExecuted { - xlog.Debug("MCP tools executed, re-running inference", "iteration", mcpIteration, "messages", len(input.Messages)) - continue // next MCP iteration + // No MCP tools to execute (or no MCP tools configured), return response + usage := schema.OpenAIUsage{ + PromptTokens: tokenUsage.Prompt, + CompletionTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, + } + if extraUsage { + usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration + usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing } - } - // No MCP tools to execute (or no MCP tools configured), return response - usage := schema.OpenAIUsage{ - PromptTokens: tokenUsage.Prompt, - CompletionTokens: tokenUsage.Completion, - TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, - } - if extraUsage { - usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration - usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing - } + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "chat.completion", + Usage: usage, + } + respData, _ := json.Marshal(resp) + xlog.Debug("Response", "response", string(respData)) - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "chat.completion", - Usage: usage, - } - respData, _ := json.Marshal(resp) - xlog.Debug("Response", "response", string(respData)) - - // Return the prediction in the response body - return c.JSON(200, resp) + // Return the prediction in the response body + return c.JSON(200, resp) } // end MCP iteration loop // Should not reach here, but safety fallback @@ -1273,7 +1219,7 @@ func handleQuestion(config *config.ModelConfig, funcResults []functions.FuncCall arg = funcResults[0].Arguments } // If there is a message that the LLM already sends as part of the JSON reply, use it - arguments := map[string]interface{}{} + arguments := map[string]any{} if err := json.Unmarshal([]byte(arg), &arguments); err != nil { xlog.Debug("handleQuestion: function result did not contain a valid JSON object") } diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index 32834a923..3a6036964 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -154,7 +154,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi if input.N == 0 { n = 1 } - for j := 0; j < n; j++ { + for range n { prompts := strings.Split(i, "|") positive_prompt := prompts[0] negative_prompt := "" diff --git a/core/http/endpoints/openai/inference.go b/core/http/endpoints/openai/inference.go index 9189870a0..58a9faa35 100644 --- a/core/http/endpoints/openai/inference.go +++ b/core/http/endpoints/openai/inference.go @@ -103,7 +103,7 @@ func ComputeChoices( const maxRetries = 5 - for i := 0; i < n; i++ { + for range n { var prediction backend.LLMResponse for attempt := 0; attempt <= maxRetries; attempt++ { diff --git a/core/http/endpoints/openai/list.go b/core/http/endpoints/openai/list.go index 1f722bacf..4565773b7 100644 --- a/core/http/endpoints/openai/list.go +++ b/core/http/endpoints/openai/list.go @@ -5,7 +5,7 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/galleryop" model "github.com/mudler/LocalAI/pkg/model" "gorm.io/gorm" ) @@ -24,12 +24,12 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap filter := c.QueryParam("filter") // By default, exclude any loose files that are already referenced by a configuration file. - var policy services.LooseFilePolicy + var policy galleryop.LooseFilePolicy excludeConfigured := c.QueryParam("excludeConfigured") if excludeConfigured == "" || excludeConfigured == "true" { - policy = services.SKIP_IF_CONFIGURED + policy = galleryop.SKIP_IF_CONFIGURED } else { - policy = services.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user? + policy = galleryop.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user? } filterFn, err := config.BuildNameFilterFn(filter) @@ -37,7 +37,7 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap return err } - modelNames, err := services.ListModels(bcl, ml, filterFn, policy) + modelNames, err := galleryop.ListModels(bcl, ml, filterFn, policy) if err != nil { return err } diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index 32923a4ac..e5610d343 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -75,7 +75,7 @@ type Session struct { DefaultConversationID string ModelInterface Model // The pipeline model config or the config for an any-to-any model - ModelConfig *config.ModelConfig + ModelConfig *config.ModelConfig InputSampleRate int OutputSampleRate int MaxOutputTokens types.IntOrInf @@ -336,12 +336,10 @@ func runRealtimeSession(application *application.Application, t Transport, model if session.TurnDetection != nil && session.TurnDetection.ServerVad != nil && !vadServerStarted { xlog.Debug("Starting VAD goroutine...") done = make(chan struct{}) - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { conversation := session.Conversations[session.DefaultConversationID] handleVAD(session, conversation, t, done) - }() + }) vadServerStarted = true } else if (session.TurnDetection == nil || session.TurnDetection.ServerVad == nil) && vadServerStarted { xlog.Debug("Stopping VAD goroutine...") @@ -684,7 +682,7 @@ func sendTestTone(t Transport) { ) pcm := make([]byte, numSamples*2) // 16-bit samples = 2 bytes each - for i := 0; i < numSamples; i++ { + for i := range numSamples { sample := int16(amplitude * math.Sin(2*math.Pi*freq*float64(i)/sampleRate)) binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample)) } @@ -1337,7 +1335,7 @@ func triggerResponse(ctx context.Context, session *Session, conv *Conversation, if isNoAction { arg := toolCalls[0].Arguments - arguments := map[string]interface{}{} + arguments := map[string]any{} if err := json.Unmarshal([]byte(arg), &arguments); err == nil { if m, exists := arguments["message"]; exists { if message, ok := m.(string); ok { diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go index 224135b30..5c516aae3 100644 --- a/core/http/endpoints/openai/realtime_model.go +++ b/core/http/endpoints/openai/realtime_model.go @@ -130,9 +130,9 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im noActionGrammar := functions.Function{ Name: noActionName, Description: noActionDescription, - Parameters: map[string]interface{}{ - "properties": map[string]interface{}{ - "message": map[string]interface{}{ + Parameters: map[string]any{ + "properties": map[string]any{ + "message": map[string]any{ "type": "string", "description": "The message to reply the user with", }, @@ -199,16 +199,16 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im var chatTools []functions.Tool for _, t := range tools { if t.Function != nil { - var params map[string]interface{} + var params map[string]any switch p := t.Function.Parameters.(type) { - case map[string]interface{}: + case map[string]any: params = p case string: if err := json.Unmarshal([]byte(p), ¶ms); err != nil { xlog.Warn("Failed to parse parameters JSON string", "error", err, "function", t.Function.Name) } case nil: - params = map[string]interface{}{} + params = map[string]any{} default: // Try to marshal/unmarshal to get map b, err := json.Marshal(p) diff --git a/core/http/endpoints/openai/realtime_transport_webrtc.go b/core/http/endpoints/openai/realtime_transport_webrtc.go index af7aa046b..b687654bd 100644 --- a/core/http/endpoints/openai/realtime_transport_webrtc.go +++ b/core/http/endpoints/openai/realtime_transport_webrtc.go @@ -24,13 +24,13 @@ type WebRTCTransport struct { audioTrack *webrtc.TrackLocalStaticRTP opusBackend grpc.Backend inEvents chan []byte - outEvents chan []byte // buffered outbound event queue - closed chan struct{} - closeOnce sync.Once - flushed chan struct{} // closed when sender goroutine has drained outEvents - dcReady chan struct{} // closed when data channel is open - dcReadyOnce sync.Once - sessionCh chan *Session // delivers session from runRealtimeSession to handleIncomingAudioTrack + outEvents chan []byte // buffered outbound event queue + closed chan struct{} + closeDone func() // sync.OnceFunc that closes t.closed + flushed chan struct{} // closed when sender goroutine has drained outEvents + dcReady chan struct{} // closed when data channel is open + dcDone func() // sync.OnceFunc that closes t.dcReady + sessionCh chan *Session // delivers session from runRealtimeSession to handleIncomingAudioTrack // RTP state for outbound audio — protected by rtpMu rtpMu sync.Mutex @@ -54,6 +54,8 @@ func NewWebRTCTransport(pc *webrtc.PeerConnection, audioTrack *webrtc.TrackLocal rtpTimestamp: rand.Uint32(), rtpMarker: true, // first packet of the stream gets marker } + t.closeDone = sync.OnceFunc(func() { close(t.closed) }) + t.dcDone = sync.OnceFunc(func() { close(t.dcReady) }) // The client creates the "oai-events" data channel (so m=application is // included in the SDP offer). We receive it here via OnDataChannel. @@ -63,7 +65,7 @@ func NewWebRTCTransport(pc *webrtc.PeerConnection, audioTrack *webrtc.TrackLocal } t.dc = dc dc.OnOpen(func() { - t.dcReadyOnce.Do(func() { close(t.dcReady) }) + t.dcDone() }) dc.OnMessage(func(msg webrtc.DataChannelMessage) { select { @@ -73,7 +75,7 @@ func NewWebRTCTransport(pc *webrtc.PeerConnection, audioTrack *webrtc.TrackLocal }) // The channel may already be open by the time OnDataChannel fires if dc.ReadyState() == webrtc.DataChannelStateOpen { - t.dcReadyOnce.Do(func() { close(t.dcReady) }) + t.dcDone() } }) @@ -82,7 +84,7 @@ func NewWebRTCTransport(pc *webrtc.PeerConnection, audioTrack *webrtc.TrackLocal if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed || state == webrtc.PeerConnectionStateDisconnected { - t.closeOnce.Do(func() { close(t.closed) }) + t.closeDone() } }) @@ -244,7 +246,7 @@ func (t *WebRTCTransport) WaitForSession() *Session { func (t *WebRTCTransport) Close() error { // Signal no more events and unblock the sender if it's waiting - t.closeOnce.Do(func() { close(t.closed) }) + t.closeDone() // Wait for the sender to drain any remaining queued events <-t.flushed return t.pc.Close() diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index c52fe1914..e7dd27db2 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -13,7 +13,6 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/pkg/format" model "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" @@ -82,7 +81,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app switch responseFormat { case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatText, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt: - return c.String(http.StatusOK, format.TranscriptionResponse(tr, responseFormat)) + return c.String(http.StatusOK, schema.TranscriptionResponse(tr, responseFormat)) case schema.TranscriptionResponseFormatJson: tr.Segments = nil fallthrough diff --git a/core/http/endpoints/openresponses/responses.go b/core/http/endpoints/openresponses/responses.go index 37c59b568..7ab3efbe7 100644 --- a/core/http/endpoints/openresponses/responses.go +++ b/core/http/endpoints/openresponses/responses.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "strings" "time" "github.com/google/uuid" @@ -29,7 +28,7 @@ import ( // @Param request body schema.OpenResponsesRequest true "Request body" // @Success 200 {object} schema.ORResponseResource "Response" // @Router /v1/responses [post] -func ResponsesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { +func ResponsesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient) echo.HandlerFunc { return func(c echo.Context) error { createdAt := time.Now().Unix() responseID := fmt.Sprintf("resp_%s", uuid.New().String()) @@ -101,7 +100,7 @@ func ResponsesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eval // Handle tools var funcs functions.Functions var shouldUseFn bool - var mcpToolInfos []mcpTools.MCPToolInfo + var mcpExecutor mcpTools.ToolExecutor if len(input.Tools) > 0 { funcs, shouldUseFn = convertORToolsToFunctions(input, cfg) @@ -115,113 +114,47 @@ func ResponsesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eval hasMCPRequest := len(mcpServers) > 0 || mcpPromptName != "" || len(mcpResourceURIs) > 0 hasMCPConfig := cfg.MCP.Servers != "" || cfg.MCP.Stdio != "" - if hasMCPRequest && hasMCPConfig { + if (hasMCPRequest && hasMCPConfig) || (len(input.Tools) == 0 && hasMCPConfig) { remote, stdio, mcpErr := cfg.MCP.MCPConfigFromYAML() if mcpErr == nil { - namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, mcpServers) - if sessErr == nil && len(namedSessions) > 0 { - // Prompt injection - if mcpPromptName != "" { - prompts, discErr := mcpTools.DiscoverMCPPrompts(c.Request().Context(), namedSessions) - if discErr == nil { - promptMsgs, getErr := mcpTools.GetMCPPrompt(c.Request().Context(), prompts, mcpPromptName, mcpPromptArgs) - if getErr == nil { - var injected []schema.Message - for _, pm := range promptMsgs { - injected = append(injected, schema.Message{ - Role: string(pm.Role), - Content: mcpTools.PromptMessageToText(pm), - }) - } - messages = append(injected, messages...) - xlog.Debug("Open Responses MCP prompt injected", "prompt", mcpPromptName, "messages", len(injected)) - } else { - xlog.Error("Failed to get MCP prompt", "error", getErr) - } + enabledServers := mcpServers + if !hasMCPRequest { + enabledServers = nil // backward compat: auto-activate all servers + } + mcpExecutor = mcpTools.NewToolExecutor(c.Request().Context(), natsClient, cfg.Name, remote, stdio, enabledServers) + + // Prompt and resource injection (pre-processing step — resolves locally regardless of distributed mode) + if hasMCPRequest { + namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, mcpServers) + if sessErr == nil && len(namedSessions) > 0 { + mcpCtx, _ := mcpTools.InjectMCPContext(c.Request().Context(), namedSessions, mcpPromptName, mcpPromptArgs, mcpResourceURIs) + if mcpCtx != nil { + messages = append(mcpCtx.PromptMessages, messages...) + mcpTools.AppendResourceSuffix(messages, mcpCtx.ResourceSuffix) } } + } - // Resource injection - if len(mcpResourceURIs) > 0 { - resources, discErr := mcpTools.DiscoverMCPResources(c.Request().Context(), namedSessions) - if discErr == nil { - var resourceTexts []string - for _, uri := range mcpResourceURIs { - content, readErr := mcpTools.ReadMCPResource(c.Request().Context(), resources, uri) - if readErr != nil { - xlog.Error("Failed to read MCP resource", "error", readErr, "uri", uri) - continue - } - name := uri - for _, r := range resources { - if r.URI == uri { - name = r.Name - break - } - } - resourceTexts = append(resourceTexts, fmt.Sprintf("--- MCP Resource: %s ---\n%s", name, content)) - } - if len(resourceTexts) > 0 && len(messages) > 0 { - lastIdx := len(messages) - 1 - suffix := "\n\n" + strings.Join(resourceTexts, "\n\n") - switch ct := messages[lastIdx].Content.(type) { - case string: - messages[lastIdx].Content = ct + suffix - default: - messages[lastIdx].Content = fmt.Sprintf("%v%s", ct, suffix) - } - xlog.Debug("Open Responses MCP resources injected", "count", len(resourceTexts)) - } - } - } - - // Tool injection - if len(mcpServers) > 0 { - discovered, discErr := mcpTools.DiscoverMCPTools(c.Request().Context(), namedSessions) - if discErr == nil { - mcpToolInfos = discovered - for _, ti := range mcpToolInfos { - funcs = append(funcs, ti.Function) - input.Tools = append(input.Tools, schema.ORFunctionTool{ - Type: "function", - Name: ti.Function.Name, - Description: ti.Function.Description, - Parameters: ti.Function.Parameters, - }) - } - shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions() - xlog.Debug("Open Responses MCP tools injected", "count", len(mcpToolInfos), "total_funcs", len(funcs)) - } else { - xlog.Error("Failed to discover MCP tools", "error", discErr) + // Tool injection via executor + if mcpExecutor.HasTools() { + mcpFuncs, discErr := mcpExecutor.DiscoverTools(c.Request().Context()) + if discErr == nil { + for _, fn := range mcpFuncs { + funcs = append(funcs, fn) + input.Tools = append(input.Tools, schema.ORFunctionTool{ + Type: "function", + Name: fn.Name, + Description: fn.Description, + Parameters: fn.Parameters, + }) } + shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions() + xlog.Debug("Open Responses MCP tools injected", "count", len(mcpFuncs), "total_funcs", len(funcs)) } } } else { xlog.Error("Failed to parse MCP config", "error", mcpErr) } - } else if len(input.Tools) == 0 && hasMCPConfig { - // Backward compat: model has MCP config, no user tools and no mcp_servers field - remote, stdio, mcpErr := cfg.MCP.MCPConfigFromYAML() - if mcpErr == nil { - namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, nil) - if sessErr == nil && len(namedSessions) > 0 { - discovered, discErr := mcpTools.DiscoverMCPTools(c.Request().Context(), namedSessions) - if discErr == nil { - mcpToolInfos = discovered - for _, ti := range mcpToolInfos { - funcs = append(funcs, ti.Function) - input.Tools = append(input.Tools, schema.ORFunctionTool{ - Type: "function", - Name: ti.Function.Name, - Description: ti.Function.Description, - Parameters: ti.Function.Parameters, - }) - } - shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions() - xlog.Debug("Open Responses MCP tools auto-activated", "count", len(mcpToolInfos)) - } - } - } } // Create OpenAI-compatible request for internal processing @@ -259,9 +192,9 @@ func ResponsesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eval noActionGrammar := functions.Function{ Name: noActionName, Description: noActionDescription, - Parameters: map[string]interface{}{ - "properties": map[string]interface{}{ - "message": map[string]interface{}{ + Parameters: map[string]any{ + "properties": map[string]any{ + "message": map[string]any{ "type": "string", "description": "The message to reply the user with", }, @@ -327,10 +260,10 @@ func ResponsesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eval if input.Stream { // Background streaming processing (buffer events) - finalResponse, bgErr = handleBackgroundStream(bgCtx, store, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator) + finalResponse, bgErr = handleBackgroundStream(bgCtx, store, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator) } else { // Background non-streaming processing - finalResponse, bgErr = handleBackgroundNonStream(bgCtx, store, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator) + finalResponse, bgErr = handleBackgroundNonStream(bgCtx, store, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator) } if bgErr != nil { @@ -351,25 +284,25 @@ func ResponsesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eval } if input.Stream { - return handleOpenResponsesStream(c, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, shouldStore, mcpToolInfos, evaluator) + return handleOpenResponsesStream(c, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, shouldStore, mcpExecutor, evaluator) } - return handleOpenResponsesNonStream(c, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, shouldStore, mcpToolInfos, evaluator, 0) + return handleOpenResponsesNonStream(c, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, shouldStore, mcpExecutor, evaluator, 0) } } // convertORInputToMessages converts Open Responses input to internal Messages -func convertORInputToMessages(input interface{}, cfg *config.ModelConfig) ([]schema.Message, error) { +func convertORInputToMessages(input any, cfg *config.ModelConfig) ([]schema.Message, error) { var messages []schema.Message switch v := input.(type) { case string: // Simple string = user message return []schema.Message{{Role: "user", StringContent: v}}, nil - case []interface{}: + case []any: // Array of items for _, itemRaw := range v { - itemMap, ok := itemRaw.(map[string]interface{}) + itemMap, ok := itemRaw.(map[string]any) if !ok { continue } @@ -445,14 +378,14 @@ func convertORInputToMessages(input interface{}, cfg *config.ModelConfig) ([]sch } // convertORReasoningItemToMessage converts an Open Responses reasoning item to an assistant Message fragment (for merging). -func convertORReasoningItemToMessage(itemMap map[string]interface{}) (schema.Message, error) { +func convertORReasoningItemToMessage(itemMap map[string]any) (schema.Message, error) { var reasoning string if content := itemMap["content"]; content != nil { if s, ok := content.(string); ok { reasoning = s - } else if parts, ok := content.([]interface{}); ok { + } else if parts, ok := content.([]any); ok { for _, p := range parts { - if partMap, ok := p.(map[string]interface{}); ok { + if partMap, ok := p.(map[string]any); ok { if t, _ := partMap["type"].(string); (t == "output_text" || t == "input_text") && partMap["text"] != nil { if tStr, ok := partMap["text"].(string); ok { reasoning += tStr @@ -466,7 +399,7 @@ func convertORReasoningItemToMessage(itemMap map[string]interface{}) (schema.Mes } // convertORFunctionCallItemToMessage converts an Open Responses function_call item to an assistant Message fragment (for merging). -func convertORFunctionCallItemToMessage(itemMap map[string]interface{}) (schema.Message, error) { +func convertORFunctionCallItemToMessage(itemMap map[string]any) (schema.Message, error) { callID, _ := itemMap["call_id"].(string) name, _ := itemMap["name"].(string) arguments, _ := itemMap["arguments"].(string) @@ -694,7 +627,7 @@ func flushAssistantAccumulator(out *[]schema.Message, acc **schema.Message) { } // convertORMessageItem converts an Open Responses message item to internal Message -func convertORMessageItem(itemMap map[string]interface{}, cfg *config.ModelConfig) (schema.Message, error) { +func convertORMessageItem(itemMap map[string]any, cfg *config.ModelConfig) (schema.Message, error) { role, _ := itemMap["role"].(string) msg := schema.Message{Role: role} @@ -703,7 +636,7 @@ func convertORMessageItem(itemMap map[string]interface{}, cfg *config.ModelConfi case string: msg.StringContent = contentVal msg.Content = contentVal - case []interface{}: + case []any: // Array of content parts var textContent string var stringImages []string @@ -711,7 +644,7 @@ func convertORMessageItem(itemMap map[string]interface{}, cfg *config.ModelConfi var stringAudios []string for _, partRaw := range contentVal { - partMap, ok := partRaw.(map[string]interface{}) + partMap, ok := partRaw.(map[string]any) if !ok { continue } @@ -834,7 +767,7 @@ func convertORToolsToFunctions(input *schema.OpenResponsesRequest, cfg *config.M // "auto" is the default - let model decide whether to use tools // Tools are available but not forced } - case map[string]interface{}: + case map[string]any: if tcType, ok := tc["type"].(string); ok && tcType == "function" { if name, ok := tc["name"].(string); ok { cfg.SetFunctionCallString(name) @@ -847,31 +780,31 @@ func convertORToolsToFunctions(input *schema.OpenResponsesRequest, cfg *config.M } // convertTextFormatToResponseFormat converts Open Responses text_format to OpenAI response_format -func convertTextFormatToResponseFormat(textFormat interface{}) interface{} { +func convertTextFormatToResponseFormat(textFormat any) any { switch tf := textFormat.(type) { - case map[string]interface{}: + case map[string]any: if tfType, ok := tf["type"].(string); ok { if tfType == "json_schema" { - return map[string]interface{}{ + return map[string]any{ "type": "json_schema", "json_schema": tf, } } - return map[string]interface{}{"type": tfType} + return map[string]any{"type": tfType} } case string: - return map[string]interface{}{"type": tf} + return map[string]any{"type": tf} } return nil } // handleBackgroundNonStream handles background non-streaming responses -func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) (*schema.ORResponseResource, error) { +func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) (*schema.ORResponseResource, error) { mcpMaxIterations := 10 if cfg.Agent.MaxIterations > 0 { mcpMaxIterations = cfg.Agent.MaxIterations } - hasMCPTools := len(mcpToolInfos) > 0 + hasMCPTools := mcpExecutor != nil && mcpExecutor.HasTools() var allOutputItems []schema.ORItemField for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ { @@ -933,7 +866,7 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon for i, fc := range funcCallResults { if fc.Name == noActionName { if fc.Arguments != "" { - var args map[string]interface{} + var args map[string]any if err := json.Unmarshal([]byte(fc.Arguments), &args); err == nil { if msg, ok := args["message"].(string); ok && msg != "" { textContent = msg @@ -957,7 +890,7 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon if hasMCPTools && len(toolCalls) > 0 { var hasMCPCalls bool for _, tc := range toolCalls { - if mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { + if mcpExecutor != nil && mcpExecutor.IsTool(tc.FunctionCall.Name) { hasMCPCalls = true break } @@ -973,10 +906,10 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon Status: "completed", CallID: tc.ID, Name: tc.FunctionCall.Name, Arguments: tc.FunctionCall.Arguments, }) - if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { + if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) { continue } - toolResult, toolErr := mcpTools.ExecuteMCPToolCall(ctx, mcpToolInfos, tc.FunctionCall.Name, tc.FunctionCall.Arguments) + toolResult, toolErr := mcpExecutor.ExecuteTool(ctx, tc.FunctionCall.Name, tc.FunctionCall.Arguments) if toolErr != nil { toolResult = fmt.Sprintf("Error: %v", toolErr) } @@ -1062,7 +995,7 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon } // handleBackgroundStream handles background streaming responses with event buffering -func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) (*schema.ORResponseResource, error) { +func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) (*schema.ORResponseResource, error) { // Populate openAIReq fields for ComputeChoices openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools) openAIReq.ToolsChoice = input.ToolChoice @@ -1099,7 +1032,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI if cfg.Agent.MaxIterations > 0 { mcpBgStreamMaxIterations = cfg.Agent.MaxIterations } - hasMCPTools := len(mcpToolInfos) > 0 + hasMCPTools := mcpExecutor != nil && mcpExecutor.HasTools() var lastTokenUsage backend.TokenUsage var lastLogprobs *schema.Logprobs @@ -1201,14 +1134,14 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI } toolCalls = append(toolCalls, schema.ToolCall{ Index: i, ID: fmt.Sprintf("fc_%s", uuid.New().String()), - Type: "function", + Type: "function", FunctionCall: schema.FunctionCall{Name: fc.Name, Arguments: fc.Arguments}, }) } var hasMCPCalls bool for _, tc := range toolCalls { - if mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { + if mcpExecutor != nil && mcpExecutor.IsTool(tc.FunctionCall.Name) { hasMCPCalls = true break } @@ -1264,12 +1197,12 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI sequenceNumber++ collectedOutputItems = append(collectedOutputItems, *functionCallItem) - if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { + if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) { continue } xlog.Debug("Executing MCP tool (background stream)", "tool", tc.FunctionCall.Name, "iteration", mcpIter) - toolResult, toolErr := mcpTools.ExecuteMCPToolCall(ctx, mcpToolInfos, tc.FunctionCall.Name, tc.FunctionCall.Arguments) + toolResult, toolErr := mcpExecutor.ExecuteTool(ctx, tc.FunctionCall.Name, tc.FunctionCall.Arguments) if toolErr != nil { toolResult = fmt.Sprintf("Error: %v", toolErr) } @@ -1368,7 +1301,7 @@ func bufferEvent(store *ResponseStore, responseID string, event *schema.ORStream } // handleOpenResponsesNonStream handles non-streaming responses -func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, shouldStore bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator, mcpIteration int) error { +func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, shouldStore bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator, mcpIteration int) error { mcpMaxIterations := 10 if cfg.Agent.MaxIterations > 0 { mcpMaxIterations = cfg.Agent.MaxIterations @@ -1458,7 +1391,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i // This is a text response, not a tool call // Try to extract the message from the arguments if fc.Arguments != "" { - var args map[string]interface{} + var args map[string]any if err := json.Unmarshal([]byte(fc.Arguments), &args); err == nil { if msg, ok := args["message"].(string); ok && msg != "" { textContent = msg @@ -1479,10 +1412,10 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i } // MCP server-side tool execution: if any tool calls are MCP tools, execute and re-run - if len(mcpToolInfos) > 0 && len(toolCalls) > 0 { + if mcpExecutor != nil && mcpExecutor.HasTools() && len(toolCalls) > 0 { var hasMCPCalls bool for _, tc := range toolCalls { - if mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { + if mcpExecutor != nil && mcpExecutor.IsTool(tc.FunctionCall.Name) { hasMCPCalls = true break } @@ -1494,13 +1427,12 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i // Execute each MCP tool call and append results for _, tc := range toolCalls { - if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { + if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) { continue } xlog.Debug("Executing MCP tool (Open Responses)", "tool", tc.FunctionCall.Name) - toolResult, toolErr := mcpTools.ExecuteMCPToolCall( - c.Request().Context(), mcpToolInfos, - tc.FunctionCall.Name, tc.FunctionCall.Arguments, + toolResult, toolErr := mcpExecutor.ExecuteTool( + c.Request().Context(), tc.FunctionCall.Name, tc.FunctionCall.Arguments, ) if toolErr != nil { xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) @@ -1523,7 +1455,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i // Re-template and re-run inference predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn) - return handleOpenResponsesNonStream(c, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, shouldStore, mcpToolInfos, evaluator, mcpIteration+1) + return handleOpenResponsesNonStream(c, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, shouldStore, mcpExecutor, evaluator, mcpIteration+1) } } @@ -1648,7 +1580,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i } // handleOpenResponsesStream handles streaming responses -func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, shouldStore bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error { +func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, shouldStore bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error { c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Cache-Control", "no-cache") c.Response().Header().Set("Connection", "keep-alive") @@ -1712,7 +1644,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 if cfg.Agent.MaxIterations > 0 { mcpStreamMaxIterations = cfg.Agent.MaxIterations } - hasMCPToolsStream := len(mcpToolInfos) > 0 + hasMCPToolsStream := mcpExecutor != nil && mcpExecutor.HasTools() var result, finalReasoning, finalCleanedResult string var textContent string @@ -1722,146 +1654,70 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 var lastStreamLogprobs *schema.Logprobs for mcpStreamIter := 0; mcpStreamIter <= mcpStreamMaxIterations; mcpStreamIter++ { - if mcpStreamIter > 0 { - // Reset reasoning and tool-call state for re-inference so reasoning - // extraction runs again on subsequent iterations - inToolCallMode = false - extractor.Reset() - currentMessageID = "" - lastEmittedToolCallCount = 0 - currentReasoningID = "" + if mcpStreamIter > 0 { + // Reset reasoning and tool-call state for re-inference so reasoning + // extraction runs again on subsequent iterations + inToolCallMode = false + extractor.Reset() + currentMessageID = "" + lastEmittedToolCallCount = 0 + currentReasoningID = "" - predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn) - xlog.Debug("Open Responses stream MCP re-templating", "iteration", mcpStreamIter) - } - - // For tool calls, we need to track accumulated result and parse incrementally - // We'll handle this differently - track the full result and parse tool calls - accumulatedResult := "" - tokenCallback := func(token string, tokenUsage backend.TokenUsage) bool { - accumulatedResult += token - accumulatedText += token - - // Try to parse tool calls incrementally - cleanedResult := functions.CleanupLLMResult(accumulatedResult, cfg.FunctionsConfig) - - // Determine XML format from config - var xmlFormat *functions.XMLToolCallFormat - if cfg.FunctionsConfig.XMLFormat != nil { - xmlFormat = cfg.FunctionsConfig.XMLFormat - } else if cfg.FunctionsConfig.XMLFormatPreset != "" { - xmlFormat = functions.GetXMLFormatPreset(cfg.FunctionsConfig.XMLFormatPreset) + predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn) + xlog.Debug("Open Responses stream MCP re-templating", "iteration", mcpStreamIter) } - // Try XML parsing first - partialResults, parseErr := functions.ParseXMLIterative(cleanedResult, xmlFormat, true) - if parseErr == nil && len(partialResults) > lastEmittedToolCallCount { - // New tool calls detected - if !inToolCallMode && currentMessageID != "" { - // Close the current message content part - textPart := makeOutputTextPart(functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)) - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.content_part.done", - SequenceNumber: sequenceNumber, - ItemID: currentMessageID, - OutputIndex: &outputIndex, - ContentIndex: ¤tContentIndex, - Part: &textPart, - }) - sequenceNumber++ - inToolCallMode = true + // For tool calls, we need to track accumulated result and parse incrementally + // We'll handle this differently - track the full result and parse tool calls + accumulatedResult := "" + tokenCallback := func(token string, tokenUsage backend.TokenUsage) bool { + accumulatedResult += token + accumulatedText += token + + // Try to parse tool calls incrementally + cleanedResult := functions.CleanupLLMResult(accumulatedResult, cfg.FunctionsConfig) + + // Determine XML format from config + var xmlFormat *functions.XMLToolCallFormat + if cfg.FunctionsConfig.XMLFormat != nil { + xmlFormat = cfg.FunctionsConfig.XMLFormat + } else if cfg.FunctionsConfig.XMLFormatPreset != "" { + xmlFormat = functions.GetXMLFormatPreset(cfg.FunctionsConfig.XMLFormatPreset) } - // Emit new tool calls - for i := lastEmittedToolCallCount; i < len(partialResults); i++ { - tc := partialResults[i] - toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) - outputIndex++ - - // Emit function_call item added - functionCallItem := &schema.ORItemField{ - Type: "function_call", - ID: toolCallID, - Status: "in_progress", - CallID: toolCallID, - Name: tc.Name, - Arguments: "", + // Try XML parsing first + partialResults, parseErr := functions.ParseXMLIterative(cleanedResult, xmlFormat, true) + if parseErr == nil && len(partialResults) > lastEmittedToolCallCount { + // New tool calls detected + if !inToolCallMode && currentMessageID != "" { + // Close the current message content part + textPart := makeOutputTextPart(functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.content_part.done", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: ¤tContentIndex, + Part: &textPart, + }) + sequenceNumber++ + inToolCallMode = true } - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_item.added", - SequenceNumber: sequenceNumber, - OutputIndex: &outputIndex, - Item: functionCallItem, - }) - sequenceNumber++ - - // Emit arguments delta - if tc.Arguments != "" { - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.function_call_arguments.delta", - SequenceNumber: sequenceNumber, - ItemID: toolCallID, - OutputIndex: &outputIndex, - Delta: strPtr(tc.Arguments), - }) - sequenceNumber++ - - // Emit arguments done - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.function_call_arguments.done", - SequenceNumber: sequenceNumber, - ItemID: toolCallID, - OutputIndex: &outputIndex, - Arguments: strPtr(tc.Arguments), - }) - sequenceNumber++ - - // Emit function_call item done - functionCallItem.Status = "completed" - functionCallItem.Arguments = tc.Arguments - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_item.done", - SequenceNumber: sequenceNumber, - OutputIndex: &outputIndex, - Item: functionCallItem, - }) - sequenceNumber++ - - // Collect item for storage - collectedOutputItems = append(collectedOutputItems, *functionCallItem) - } - } - lastEmittedToolCallCount = len(partialResults) - c.Response().Flush() - return true - } - - // Try JSON parsing as fallback - jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true) - if jsonErr == nil && len(jsonResults) > lastEmittedToolCallCount { - for i := lastEmittedToolCallCount; i < len(jsonResults); i++ { - jsonObj := jsonResults[i] - if name, ok := jsonObj["name"].(string); ok && name != "" { - args := "{}" - if argsVal, ok := jsonObj["arguments"]; ok { - if argsStr, ok := argsVal.(string); ok { - args = argsStr - } else { - argsBytes, _ := json.Marshal(argsVal) - args = string(argsBytes) - } - } + // Emit new tool calls + for i := lastEmittedToolCallCount; i < len(partialResults); i++ { + tc := partialResults[i] toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) outputIndex++ + // Emit function_call item added functionCallItem := &schema.ORItemField{ Type: "function_call", ID: toolCallID, - Status: "completed", + Status: "in_progress", CallID: toolCallID, - Name: name, - Arguments: args, + Name: tc.Name, + Arguments: "", } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.added", @@ -1871,442 +1727,517 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 }) sequenceNumber++ - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_item.done", - SequenceNumber: sequenceNumber, - OutputIndex: &outputIndex, - Item: functionCallItem, - }) - sequenceNumber++ + // Emit arguments delta + if tc.Arguments != "" { + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.function_call_arguments.delta", + SequenceNumber: sequenceNumber, + ItemID: toolCallID, + OutputIndex: &outputIndex, + Delta: strPtr(tc.Arguments), + }) + sequenceNumber++ + + // Emit arguments done + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.function_call_arguments.done", + SequenceNumber: sequenceNumber, + ItemID: toolCallID, + OutputIndex: &outputIndex, + Arguments: strPtr(tc.Arguments), + }) + sequenceNumber++ + + // Emit function_call item done + functionCallItem.Status = "completed" + functionCallItem.Arguments = tc.Arguments + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: functionCallItem, + }) + sequenceNumber++ + + // Collect item for storage + collectedOutputItems = append(collectedOutputItems, *functionCallItem) + } } + lastEmittedToolCallCount = len(partialResults) + c.Response().Flush() + return true } - lastEmittedToolCallCount = len(jsonResults) - c.Response().Flush() - return true - } - // If no tool calls detected yet, handle reasoning and text - if !inToolCallMode { - reasoningDelta, contentDelta := extractor.ProcessToken(token) + // Try JSON parsing as fallback + jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true) + if jsonErr == nil && len(jsonResults) > lastEmittedToolCallCount { + for i := lastEmittedToolCallCount; i < len(jsonResults); i++ { + jsonObj := jsonResults[i] + if name, ok := jsonObj["name"].(string); ok && name != "" { + args := "{}" + if argsVal, ok := jsonObj["arguments"]; ok { + if argsStr, ok := argsVal.(string); ok { + args = argsStr + } else { + argsBytes, _ := json.Marshal(argsVal) + args = string(argsBytes) + } + } - // Handle reasoning item - if extractor.Reasoning() != "" { - // Check if we need to create reasoning item - if currentReasoningID == "" { - outputIndex++ - currentReasoningID = fmt.Sprintf("reasoning_%s", uuid.New().String()) - reasoningItem := &schema.ORItemField{ - Type: "reasoning", - ID: currentReasoningID, - Status: "in_progress", + toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) + outputIndex++ + + functionCallItem := &schema.ORItemField{ + Type: "function_call", + ID: toolCallID, + Status: "completed", + CallID: toolCallID, + Name: name, + Arguments: args, + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: functionCallItem, + }) + sequenceNumber++ + + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: functionCallItem, + }) + sequenceNumber++ } - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_item.added", - SequenceNumber: sequenceNumber, - OutputIndex: &outputIndex, - Item: reasoningItem, - }) - sequenceNumber++ + } + lastEmittedToolCallCount = len(jsonResults) + c.Response().Flush() + return true + } - // Emit content_part.added for reasoning - currentReasoningContentIndex = 0 - emptyPart := makeOutputTextPart("") - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.content_part.added", - SequenceNumber: sequenceNumber, - ItemID: currentReasoningID, - OutputIndex: &outputIndex, - ContentIndex: ¤tReasoningContentIndex, - Part: &emptyPart, - }) - sequenceNumber++ + // If no tool calls detected yet, handle reasoning and text + if !inToolCallMode { + reasoningDelta, contentDelta := extractor.ProcessToken(token) + + // Handle reasoning item + if extractor.Reasoning() != "" { + // Check if we need to create reasoning item + if currentReasoningID == "" { + outputIndex++ + currentReasoningID = fmt.Sprintf("reasoning_%s", uuid.New().String()) + reasoningItem := &schema.ORItemField{ + Type: "reasoning", + ID: currentReasoningID, + Status: "in_progress", + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: reasoningItem, + }) + sequenceNumber++ + + // Emit content_part.added for reasoning + currentReasoningContentIndex = 0 + emptyPart := makeOutputTextPart("") + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.content_part.added", + SequenceNumber: sequenceNumber, + ItemID: currentReasoningID, + OutputIndex: &outputIndex, + ContentIndex: ¤tReasoningContentIndex, + Part: &emptyPart, + }) + sequenceNumber++ + } + + // Emit reasoning delta if there's new content + if reasoningDelta != "" { + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_text.delta", + SequenceNumber: sequenceNumber, + ItemID: currentReasoningID, + OutputIndex: &outputIndex, + ContentIndex: ¤tReasoningContentIndex, + Delta: strPtr(reasoningDelta), + Logprobs: emptyLogprobs(), + }) + sequenceNumber++ + c.Response().Flush() + } } - // Emit reasoning delta if there's new content - if reasoningDelta != "" { + // Only emit message content if there's actual content (not just reasoning) + if contentDelta != "" { + if currentMessageID == "" { + // Emit output_item.added for message + outputIndex++ + currentMessageID = fmt.Sprintf("msg_%s", uuid.New().String()) + messageItem := &schema.ORItemField{ + Type: "message", + ID: currentMessageID, + Status: "in_progress", + Role: "assistant", + Content: []schema.ORContentPart{}, + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: messageItem, + }) + sequenceNumber++ + + // Emit content_part.added + currentContentIndex = 0 + emptyPart := makeOutputTextPart("") + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.content_part.added", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: ¤tContentIndex, + Part: &emptyPart, + }) + sequenceNumber++ + } + + // Emit text delta sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_text.delta", SequenceNumber: sequenceNumber, - ItemID: currentReasoningID, + ItemID: currentMessageID, OutputIndex: &outputIndex, - ContentIndex: ¤tReasoningContentIndex, - Delta: strPtr(reasoningDelta), + ContentIndex: ¤tContentIndex, + Delta: strPtr(contentDelta), Logprobs: emptyLogprobs(), }) sequenceNumber++ c.Response().Flush() } } - - // Only emit message content if there's actual content (not just reasoning) - if contentDelta != "" { - if currentMessageID == "" { - // Emit output_item.added for message - outputIndex++ - currentMessageID = fmt.Sprintf("msg_%s", uuid.New().String()) - messageItem := &schema.ORItemField{ - Type: "message", - ID: currentMessageID, - Status: "in_progress", - Role: "assistant", - Content: []schema.ORContentPart{}, - } - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_item.added", - SequenceNumber: sequenceNumber, - OutputIndex: &outputIndex, - Item: messageItem, - }) - sequenceNumber++ - - // Emit content_part.added - currentContentIndex = 0 - emptyPart := makeOutputTextPart("") - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.content_part.added", - SequenceNumber: sequenceNumber, - ItemID: currentMessageID, - OutputIndex: &outputIndex, - ContentIndex: ¤tContentIndex, - Part: &emptyPart, - }) - sequenceNumber++ - } - - // Emit text delta - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_text.delta", - SequenceNumber: sequenceNumber, - ItemID: currentMessageID, - OutputIndex: &outputIndex, - ContentIndex: ¤tContentIndex, - Delta: strPtr(contentDelta), - Logprobs: emptyLogprobs(), - }) - sequenceNumber++ - c.Response().Flush() - } + return true } - return true - } - var ccResult string - ccCb := func(s string, c *[]schema.Choice) { - ccResult = s - } - choices, ccTokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, ccCb, tokenCallback) - if err != nil { - xlog.Error("Open Responses stream model inference failed", "error", err) - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "error", - SequenceNumber: sequenceNumber, - Error: &schema.ORErrorPayload{ - Type: "model_error", - Message: fmt.Sprintf("model inference failed: %v", err), - }, - }) - sequenceNumber++ - responseFailed := responseCreated - responseFailed.Status = "failed" - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.failed", - SequenceNumber: sequenceNumber, - Response: responseFailed, - }) - // Send [DONE] even on error - fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") - c.Response().Flush() - return nil - } - result = ccResult - lastStreamTokenUsage = ccTokenUsage - if len(choices) > 0 { - lastStreamLogprobs = choices[0].Logprobs - } + var ccResult string + ccCb := func(s string, c *[]schema.Choice) { + ccResult = s + } + choices, ccTokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, ccCb, tokenCallback) + if err != nil { + xlog.Error("Open Responses stream model inference failed", "error", err) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "error", + SequenceNumber: sequenceNumber, + Error: &schema.ORErrorPayload{ + Type: "model_error", + Message: fmt.Sprintf("model inference failed: %v", err), + }, + }) + sequenceNumber++ + responseFailed := responseCreated + responseFailed.Status = "failed" + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.failed", + SequenceNumber: sequenceNumber, + Response: responseFailed, + }) + // Send [DONE] even on error + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil + } + result = ccResult + lastStreamTokenUsage = ccTokenUsage + if len(choices) > 0 { + lastStreamLogprobs = choices[0].Logprobs + } - // Source reasoning from: (1) ChatDeltas from C++ autoparser, (2) extractor's - // streaming state, (3) final extraction from the finetuned result. - if chatDeltaReasoning := functions.ReasoningFromChatDeltas(chatDeltas); chatDeltaReasoning != "" { - finalReasoning = chatDeltaReasoning - finalCleanedResult = functions.ContentFromChatDeltas(chatDeltas) - if finalCleanedResult == "" { + // Source reasoning from: (1) ChatDeltas from C++ autoparser, (2) extractor's + // streaming state, (3) final extraction from the finetuned result. + if chatDeltaReasoning := functions.ReasoningFromChatDeltas(chatDeltas); chatDeltaReasoning != "" { + finalReasoning = chatDeltaReasoning + finalCleanedResult = functions.ContentFromChatDeltas(chatDeltas) + if finalCleanedResult == "" { + finalCleanedResult = extractor.CleanedContent() + } + } else { + finalReasoning = extractor.Reasoning() finalCleanedResult = extractor.CleanedContent() } - } else { - finalReasoning = extractor.Reasoning() - finalCleanedResult = extractor.CleanedContent() - } - if finalReasoning == "" && finalCleanedResult == "" { - finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig) - } - - // Close reasoning item if it exists and wasn't closed yet - if currentReasoningID != "" && finalReasoning != "" { - // Emit output_text.done for reasoning - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_text.done", - SequenceNumber: sequenceNumber, - ItemID: currentReasoningID, - OutputIndex: &outputIndex, - ContentIndex: ¤tReasoningContentIndex, - Text: strPtr(finalReasoning), - Logprobs: emptyLogprobs(), - }) - sequenceNumber++ - - // Emit content_part.done for reasoning - reasoningPart := makeOutputTextPart(finalReasoning) - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.content_part.done", - SequenceNumber: sequenceNumber, - ItemID: currentReasoningID, - OutputIndex: &outputIndex, - ContentIndex: ¤tReasoningContentIndex, - Part: &reasoningPart, - }) - sequenceNumber++ - - // Emit output_item.done for reasoning - reasoningItem := &schema.ORItemField{ - Type: "reasoning", - ID: currentReasoningID, - Status: "completed", - Content: []schema.ORContentPart{reasoningPart}, + if finalReasoning == "" && finalCleanedResult == "" { + finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig) } - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_item.done", - SequenceNumber: sequenceNumber, - OutputIndex: &outputIndex, - Item: reasoningItem, - }) - sequenceNumber++ - // Collect reasoning item for storage - collectedOutputItems = append(collectedOutputItems, *reasoningItem) + // Close reasoning item if it exists and wasn't closed yet + if currentReasoningID != "" && finalReasoning != "" { + // Emit output_text.done for reasoning + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_text.done", + SequenceNumber: sequenceNumber, + ItemID: currentReasoningID, + OutputIndex: &outputIndex, + ContentIndex: ¤tReasoningContentIndex, + Text: strPtr(finalReasoning), + Logprobs: emptyLogprobs(), + }) + sequenceNumber++ - // Calculate reasoning tokens - reasoningTokens = len(finalReasoning) / 4 - if reasoningTokens == 0 && len(finalReasoning) > 0 { - reasoningTokens = 1 + // Emit content_part.done for reasoning + reasoningPart := makeOutputTextPart(finalReasoning) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.content_part.done", + SequenceNumber: sequenceNumber, + ItemID: currentReasoningID, + OutputIndex: &outputIndex, + ContentIndex: ¤tReasoningContentIndex, + Part: &reasoningPart, + }) + sequenceNumber++ + + // Emit output_item.done for reasoning + reasoningItem := &schema.ORItemField{ + Type: "reasoning", + ID: currentReasoningID, + Status: "completed", + Content: []schema.ORContentPart{reasoningPart}, + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: reasoningItem, + }) + sequenceNumber++ + + // Collect reasoning item for storage + collectedOutputItems = append(collectedOutputItems, *reasoningItem) + + // Calculate reasoning tokens + reasoningTokens = len(finalReasoning) / 4 + if reasoningTokens == 0 && len(finalReasoning) > 0 { + reasoningTokens = 1 + } } - } - parsedToolCalls = nil - textContent = "" + parsedToolCalls = nil + textContent = "" - // Try pre-parsed tool calls from C++ autoparser first - if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { - xlog.Debug("[ChatDeltas] OpenResponses Stream: using pre-parsed tool calls", "count", len(deltaToolCalls)) - parsedToolCalls = deltaToolCalls - textContent = functions.ContentFromChatDeltas(chatDeltas) - } else { - xlog.Debug("[ChatDeltas] OpenResponses Stream: no pre-parsed tool calls, falling back to Go-side text parsing") - cleanedResult := functions.CleanupLLMResult(finalCleanedResult, cfg.FunctionsConfig) - parsedToolCalls = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig) - textContent = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig) - } + // Try pre-parsed tool calls from C++ autoparser first + if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { + xlog.Debug("[ChatDeltas] OpenResponses Stream: using pre-parsed tool calls", "count", len(deltaToolCalls)) + parsedToolCalls = deltaToolCalls + textContent = functions.ContentFromChatDeltas(chatDeltas) + } else { + xlog.Debug("[ChatDeltas] OpenResponses Stream: no pre-parsed tool calls, falling back to Go-side text parsing") + cleanedResult := functions.CleanupLLMResult(finalCleanedResult, cfg.FunctionsConfig) + parsedToolCalls = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig) + textContent = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig) + } - // Handle noAction function (model chose to respond without tool) - noActionName := "answer" - if cfg.FunctionsConfig.NoActionFunctionName != "" { - noActionName = cfg.FunctionsConfig.NoActionFunctionName - } + // Handle noAction function (model chose to respond without tool) + noActionName := "answer" + if cfg.FunctionsConfig.NoActionFunctionName != "" { + noActionName = cfg.FunctionsConfig.NoActionFunctionName + } - // Filter out noAction calls and extract the message - toolCalls = nil - for _, fc := range parsedToolCalls { - if fc.Name == noActionName { - // This is a text response, not a tool call - if fc.Arguments != "" { - var args map[string]interface{} - if err := json.Unmarshal([]byte(fc.Arguments), &args); err == nil { - if msg, ok := args["message"].(string); ok && msg != "" { - textContent = msg + // Filter out noAction calls and extract the message + toolCalls = nil + for _, fc := range parsedToolCalls { + if fc.Name == noActionName { + // This is a text response, not a tool call + if fc.Arguments != "" { + var args map[string]any + if err := json.Unmarshal([]byte(fc.Arguments), &args); err == nil { + if msg, ok := args["message"].(string); ok && msg != "" { + textContent = msg + } } } + continue } - continue + toolCalls = append(toolCalls, fc) } - toolCalls = append(toolCalls, fc) - } - xlog.Debug("Open Responses Stream - Parsed", "toolCalls", len(toolCalls), "textContent", textContent) + xlog.Debug("Open Responses Stream - Parsed", "toolCalls", len(toolCalls), "textContent", textContent) - // MCP streaming tool execution: check if any tool calls are MCP tools - if hasMCPToolsStream && len(toolCalls) > 0 { - var hasMCPCalls bool - for _, tc := range toolCalls { - if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) { - hasMCPCalls = true - break + // MCP streaming tool execution: check if any tool calls are MCP tools + if hasMCPToolsStream && len(toolCalls) > 0 { + var hasMCPCalls bool + for _, tc := range toolCalls { + if mcpExecutor != nil && mcpExecutor.IsTool(tc.Name) { + hasMCPCalls = true + break + } + } + if hasMCPCalls { + // Build schema.ToolCall list for the assistant message + var schemaToolCalls []schema.ToolCall + for i, tc := range toolCalls { + schemaToolCalls = append(schemaToolCalls, schema.ToolCall{ + Index: i, ID: fmt.Sprintf("fc_%s", uuid.New().String()), + Type: "function", + FunctionCall: schema.FunctionCall{Name: tc.Name, Arguments: tc.Arguments}, + }) + } + assistantMsg := schema.Message{Role: "assistant", Content: result, ToolCalls: schemaToolCalls} + openAIReq.Messages = append(openAIReq.Messages, assistantMsg) + + for idx, tc := range toolCalls { + tcID := schemaToolCalls[idx].ID + + // Emit function_call item + outputIndex++ + functionCallItem := &schema.ORItemField{ + Type: "function_call", ID: tcID, Status: "completed", + CallID: tcID, Name: tc.Name, Arguments: tc.Arguments, + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, Item: functionCallItem, + }) + sequenceNumber++ + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, Item: functionCallItem, + }) + sequenceNumber++ + collectedOutputItems = append(collectedOutputItems, *functionCallItem) + + if mcpExecutor == nil || !mcpExecutor.IsTool(tc.Name) { + continue + } + + // Execute MCP tool + xlog.Debug("Executing MCP tool (Open Responses stream)", "tool", tc.Name, "iteration", mcpStreamIter) + toolResult, toolErr := mcpExecutor.ExecuteTool( + input.Context, tc.Name, tc.Arguments, + ) + if toolErr != nil { + xlog.Error("MCP tool execution failed", "tool", tc.Name, "error", toolErr) + toolResult = fmt.Sprintf("Error: %v", toolErr) + } + openAIReq.Messages = append(openAIReq.Messages, schema.Message{ + Role: "tool", Content: toolResult, StringContent: toolResult, ToolCallID: tcID, Name: tc.Name, + }) + + // Emit function_call_output item + outputIndex++ + outputItem := &schema.ORItemField{ + Type: "function_call_output", ID: fmt.Sprintf("fco_%s", uuid.New().String()), + Status: "completed", CallID: tcID, Output: toolResult, + } + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, Item: outputItem, + }) + sequenceNumber++ + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, Item: outputItem, + }) + sequenceNumber++ + collectedOutputItems = append(collectedOutputItems, *outputItem) + } + c.Response().Flush() + xlog.Debug("MCP streaming tools executed, re-running inference", "iteration", mcpStreamIter) + continue // next MCP stream iteration } } - if hasMCPCalls { - // Build schema.ToolCall list for the assistant message - var schemaToolCalls []schema.ToolCall - for i, tc := range toolCalls { - schemaToolCalls = append(schemaToolCalls, schema.ToolCall{ - Index: i, ID: fmt.Sprintf("fc_%s", uuid.New().String()), - Type: "function", - FunctionCall: schema.FunctionCall{Name: tc.Name, Arguments: tc.Arguments}, - }) + + // Convert logprobs for streaming events + streamEventLogprobs := convertLogprobsForStreaming(lastStreamLogprobs) + + // If we have no output but the model did produce something, use the cleaned result (without reasoning tags) + if textContent == "" && len(toolCalls) == 0 && finalCleanedResult != "" { + xlog.Debug("Open Responses Stream - No parsed output, using cleaned result") + textContent = finalCleanedResult + } + + // Close message if we have text content + if currentMessageID != "" && textContent != "" && !inToolCallMode { + // Emit output_text.done + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_text.done", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: ¤tContentIndex, + Text: strPtr(textContent), + Logprobs: logprobsPtr(streamEventLogprobs), + }) + sequenceNumber++ + + // Emit content_part.done (with actual logprobs) + textPart := makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.content_part.done", + SequenceNumber: sequenceNumber, + ItemID: currentMessageID, + OutputIndex: &outputIndex, + ContentIndex: ¤tContentIndex, + Part: &textPart, + }) + sequenceNumber++ + + // Emit output_item.done for message (with actual logprobs) + messageItem := &schema.ORItemField{ + Type: "message", + ID: currentMessageID, + Status: "completed", + Role: "assistant", + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)}, } - assistantMsg := schema.Message{Role: "assistant", Content: result, ToolCalls: schemaToolCalls} - openAIReq.Messages = append(openAIReq.Messages, assistantMsg) + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: messageItem, + }) + sequenceNumber++ - for idx, tc := range toolCalls { - tcID := schemaToolCalls[idx].ID + // Collect message item for storage + collectedOutputItems = append(collectedOutputItems, *messageItem) + } - // Emit function_call item - outputIndex++ - functionCallItem := &schema.ORItemField{ - Type: "function_call", ID: tcID, Status: "completed", - CallID: tcID, Name: tc.Name, Arguments: tc.Arguments, - } - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_item.added", SequenceNumber: sequenceNumber, - OutputIndex: &outputIndex, Item: functionCallItem, - }) - sequenceNumber++ - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_item.done", SequenceNumber: sequenceNumber, - OutputIndex: &outputIndex, Item: functionCallItem, - }) - sequenceNumber++ - collectedOutputItems = append(collectedOutputItems, *functionCallItem) + // Emit any remaining tool calls that weren't streamed + for i := lastEmittedToolCallCount; i < len(toolCalls); i++ { + tc := toolCalls[i] + toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) + outputIndex++ - if !mcpTools.IsMCPTool(mcpToolInfos, tc.Name) { - continue - } - - // Execute MCP tool - xlog.Debug("Executing MCP tool (Open Responses stream)", "tool", tc.Name, "iteration", mcpStreamIter) - toolResult, toolErr := mcpTools.ExecuteMCPToolCall( - input.Context, mcpToolInfos, tc.Name, tc.Arguments, - ) - if toolErr != nil { - xlog.Error("MCP tool execution failed", "tool", tc.Name, "error", toolErr) - toolResult = fmt.Sprintf("Error: %v", toolErr) - } - openAIReq.Messages = append(openAIReq.Messages, schema.Message{ - Role: "tool", Content: toolResult, StringContent: toolResult, ToolCallID: tcID, Name: tc.Name, - }) - - // Emit function_call_output item - outputIndex++ - outputItem := &schema.ORItemField{ - Type: "function_call_output", ID: fmt.Sprintf("fco_%s", uuid.New().String()), - Status: "completed", CallID: tcID, Output: toolResult, - } - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_item.added", SequenceNumber: sequenceNumber, - OutputIndex: &outputIndex, Item: outputItem, - }) - sequenceNumber++ - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_item.done", SequenceNumber: sequenceNumber, - OutputIndex: &outputIndex, Item: outputItem, - }) - sequenceNumber++ - collectedOutputItems = append(collectedOutputItems, *outputItem) + functionCallItem := &schema.ORItemField{ + Type: "function_call", + ID: toolCallID, + Status: "completed", + CallID: toolCallID, + Name: tc.Name, + Arguments: tc.Arguments, } - c.Response().Flush() - xlog.Debug("MCP streaming tools executed, re-running inference", "iteration", mcpStreamIter) - continue // next MCP stream iteration + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.added", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: functionCallItem, + }) + sequenceNumber++ + + sendSSEEvent(c, &schema.ORStreamEvent{ + Type: "response.output_item.done", + SequenceNumber: sequenceNumber, + OutputIndex: &outputIndex, + Item: functionCallItem, + }) + sequenceNumber++ + + // Collect function call item for storage + collectedOutputItems = append(collectedOutputItems, *functionCallItem) } - } - - // Convert logprobs for streaming events - streamEventLogprobs := convertLogprobsForStreaming(lastStreamLogprobs) - - // If we have no output but the model did produce something, use the cleaned result (without reasoning tags) - if textContent == "" && len(toolCalls) == 0 && finalCleanedResult != "" { - xlog.Debug("Open Responses Stream - No parsed output, using cleaned result") - textContent = finalCleanedResult - } - - // Close message if we have text content - if currentMessageID != "" && textContent != "" && !inToolCallMode { - // Emit output_text.done - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_text.done", - SequenceNumber: sequenceNumber, - ItemID: currentMessageID, - OutputIndex: &outputIndex, - ContentIndex: ¤tContentIndex, - Text: strPtr(textContent), - Logprobs: logprobsPtr(streamEventLogprobs), - }) - sequenceNumber++ - - // Emit content_part.done (with actual logprobs) - textPart := makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs) - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.content_part.done", - SequenceNumber: sequenceNumber, - ItemID: currentMessageID, - OutputIndex: &outputIndex, - ContentIndex: ¤tContentIndex, - Part: &textPart, - }) - sequenceNumber++ - - // Emit output_item.done for message (with actual logprobs) - messageItem := &schema.ORItemField{ - Type: "message", - ID: currentMessageID, - Status: "completed", - Role: "assistant", - Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)}, - } - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_item.done", - SequenceNumber: sequenceNumber, - OutputIndex: &outputIndex, - Item: messageItem, - }) - sequenceNumber++ - - // Collect message item for storage - collectedOutputItems = append(collectedOutputItems, *messageItem) - } - - // Emit any remaining tool calls that weren't streamed - for i := lastEmittedToolCallCount; i < len(toolCalls); i++ { - tc := toolCalls[i] - toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) - outputIndex++ - - functionCallItem := &schema.ORItemField{ - Type: "function_call", - ID: toolCallID, - Status: "completed", - CallID: toolCallID, - Name: tc.Name, - Arguments: tc.Arguments, - } - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_item.added", - SequenceNumber: sequenceNumber, - OutputIndex: &outputIndex, - Item: functionCallItem, - }) - sequenceNumber++ - - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.output_item.done", - SequenceNumber: sequenceNumber, - OutputIndex: &outputIndex, - Item: functionCallItem, - }) - sequenceNumber++ - - // Collect function call item for storage - collectedOutputItems = append(collectedOutputItems, *functionCallItem) - } - - break // no MCP tools to execute, exit loop + break // no MCP tools to execute, exit loop } // end MCP stream iteration loop // Build final response with all items (include reasoning first, then messages, then tool calls) @@ -2874,7 +2805,7 @@ func buildORResponse(responseID string, createdAt int64, completedAt *int64, sta } // Default tool_choice: "auto" if tools are present, "none" otherwise - var toolChoice interface{} + var toolChoice any if input.ToolChoice != nil { toolChoice = input.ToolChoice } else if len(tools) > 0 { @@ -2967,14 +2898,14 @@ func buildORResponse(responseID string, createdAt int64, completedAt *int64, sta // sendOpenResponsesError sends an error response func sendOpenResponsesError(c echo.Context, statusCode int, errorType, message, param string) error { - errorResp := map[string]interface{}{ - "error": map[string]interface{}{ + errorResp := map[string]any{ + "error": map[string]any{ "type": errorType, "message": message, }, } if param != "" { - errorResp["error"].(map[string]interface{})["param"] = param + errorResp["error"].(map[string]any)["param"] = param } return c.JSON(statusCode, errorResp) } @@ -3005,8 +2936,8 @@ func convertORToolsToOpenAIFormat(orTools []schema.ORFunctionTool) []functions.T // @Param stream query string false "Set to 'true' to resume streaming" // @Param starting_after query int false "Sequence number to resume from (for streaming)" // @Success 200 {object} schema.ORResponseResource "Response" -// @Failure 400 {object} map[string]interface{} "Bad Request" -// @Failure 404 {object} map[string]interface{} "Not Found" +// @Failure 400 {object} map[string]any "Bad Request" +// @Failure 404 {object} map[string]any "Not Found" // @Router /v1/responses/{id} [get] func GetResponseEndpoint() func(c echo.Context) error { return func(c echo.Context) error { @@ -3144,8 +3075,8 @@ func handleStreamResume(c echo.Context, store *ResponseStore, responseID string, // @Description Cancel a background response if it's still in progress // @Param id path string true "Response ID" // @Success 200 {object} schema.ORResponseResource "Response" -// @Failure 400 {object} map[string]interface{} "Bad Request" -// @Failure 404 {object} map[string]interface{} "Not Found" +// @Failure 400 {object} map[string]any "Bad Request" +// @Failure 404 {object} map[string]any "Not Found" // @Router /v1/responses/{id}/cancel [post] func CancelResponseEndpoint() func(c echo.Context) error { return func(c echo.Context) error { diff --git a/core/http/endpoints/openresponses/store.go b/core/http/endpoints/openresponses/store.go index a548254fb..bea5b7413 100644 --- a/core/http/endpoints/openresponses/store.go +++ b/core/http/endpoints/openresponses/store.go @@ -44,17 +44,13 @@ type StoredResponse struct { mu sync.RWMutex // Protect concurrent access to this response } -var ( - globalStore *ResponseStore - storeOnce sync.Once -) +var getGlobalStore = sync.OnceValue(func() *ResponseStore { + return NewResponseStore(0) // Default: no TTL, will be updated from appConfig +}) // GetGlobalStore returns the singleton response store instance func GetGlobalStore() *ResponseStore { - storeOnce.Do(func() { - globalStore = NewResponseStore(0) // Default: no TTL, will be updated from appConfig - }) - return globalStore + return getGlobalStore() } // SetTTL updates the TTL for the store diff --git a/core/http/endpoints/openresponses/store_test.go b/core/http/endpoints/openresponses/store_test.go index e0dcdba68..360e32df4 100644 --- a/core/http/endpoints/openresponses/store_test.go +++ b/core/http/endpoints/openresponses/store_test.go @@ -375,7 +375,7 @@ var _ = Describe("ResponseStore", func() { It("should handle concurrent stores and gets", func() { // This is a basic concurrency test done := make(chan bool, 10) - for i := 0; i < 10; i++ { + for i := range 10 { go func(id int) { responseID := fmt.Sprintf("resp_concurrent_%d", id) request := &schema.OpenResponsesRequest{Model: "test"} @@ -397,7 +397,7 @@ var _ = Describe("ResponseStore", func() { } // Wait for all goroutines - for i := 0; i < 10; i++ { + for range 10 { <-done } diff --git a/core/http/endpoints/openresponses/websocket.go b/core/http/endpoints/openresponses/websocket.go index 9e6ce7109..ffff7b044 100644 --- a/core/http/endpoints/openresponses/websocket.go +++ b/core/http/endpoints/openresponses/websocket.go @@ -262,9 +262,9 @@ func handleWSResponseCreate(connCtx context.Context, conn *lockedConn, input *sc noActionGrammar := functions.Function{ Name: noActionName, Description: noActionDescription, - Parameters: map[string]interface{}{ - "properties": map[string]interface{}{ - "message": map[string]interface{}{ + Parameters: map[string]any{ + "properties": map[string]any{ + "message": map[string]any{ "type": "string", "description": "The message to reply the user with", }, diff --git a/core/http/middleware/auth.go b/core/http/middleware/auth.go index 4dde8f732..c4ab6b1f7 100644 --- a/core/http/middleware/auth.go +++ b/core/http/middleware/auth.go @@ -113,7 +113,7 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(err }) } - return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{ + return c.Render(http.StatusUnauthorized, "views/login", map[string]any{ "BaseURL": BaseURL(c), }) } diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go index a853eb3d6..5f7c122cc 100644 --- a/core/http/middleware/request.go +++ b/core/http/middleware/request.go @@ -12,7 +12,7 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/model" @@ -65,7 +65,7 @@ func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) { auth := c.Request().Header.Get("Authorization") bearer := strings.TrimPrefix(auth, "Bearer ") if bearer != "" && bearer != auth { - exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE) + exists, err := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, galleryop.ALWAYS_INCLUDE) if err == nil && exists { model = bearer } @@ -98,7 +98,7 @@ func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn con return next(c) } - modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED) + modelNames, err := galleryop.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, galleryop.SKIP_IF_CONFIGURED) if err != nil { xlog.Error("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()", "error", err) return next(c) @@ -266,7 +266,7 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema. switch responseFormat := input.ResponseFormat.(type) { case string: config.ResponseFormat = responseFormat - case map[string]interface{}: + case map[string]any: config.ResponseFormatMap = responseFormat } } @@ -276,7 +276,7 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema. if stop != "" { config.StopWords = append(config.StopWords, stop) } - case []interface{}: + case []any: for _, pp := range stop { if s, ok := pp.(string); ok { config.StopWords = append(config.StopWords, s) @@ -296,11 +296,11 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema. switch content := input.ToolsChoice.(type) { case string: _ = json.Unmarshal([]byte(content), &toolChoice) - case map[string]interface{}: + case map[string]any: dat, _ := json.Marshal(content) _ = json.Unmarshal(dat, &toolChoice) } - input.FunctionCall = map[string]interface{}{ + input.FunctionCall = map[string]any{ "name": toolChoice.Function.Name, } } @@ -315,7 +315,7 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema. switch content := m.Content.(type) { case string: input.Messages[i].StringContent = content - case []interface{}: + case []any: dat, _ := json.Marshal(content) c := []schema.Content{} json.Unmarshal(dat, &c) @@ -451,7 +451,7 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema. if fnc != "" { config.SetFunctionCallString(fnc) } - case map[string]interface{}: + case map[string]any: var name string n, exists := fnc["name"] if exists { @@ -466,7 +466,7 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema. switch p := input.Prompt.(type) { case string: config.PromptStrings = append(config.PromptStrings, p) - case []interface{}: + case []any: for _, pp := range p { if s, ok := pp.(string); ok { config.PromptStrings = append(config.PromptStrings, s) @@ -575,7 +575,7 @@ func MergeOpenResponsesConfig(config *config.ModelConfig, input *schema.OpenResp // Don't use tools - handled in endpoint } // "auto" is default - let model decide - case map[string]interface{}: + case map[string]any: // Specific tool: {type:"function", name:"..."} if tcType, ok := tc["type"].(string); ok && tcType == "function" { if name, ok := tc["name"].(string); ok { diff --git a/core/http/middleware/trace.go b/core/http/middleware/trace.go index 22049083d..71a12d976 100644 --- a/core/http/middleware/trace.go +++ b/core/http/middleware/trace.go @@ -4,7 +4,7 @@ import ( "bytes" "io" "net/http" - "sort" + "slices" "sync" "time" @@ -41,7 +41,27 @@ type APIExchange struct { var traceBuffer *circularbuffer.Queue[APIExchange] var mu sync.Mutex var logChan = make(chan APIExchange, 100) -var initOnce sync.Once +var tracingMaxItems int + +var doInitializeTracing = sync.OnceFunc(func() { + maxItems := tracingMaxItems + if maxItems <= 0 { + maxItems = 100 + } + mu.Lock() + traceBuffer = circularbuffer.New[APIExchange](maxItems) + mu.Unlock() + + go func() { + for exchange := range logChan { + mu.Lock() + if traceBuffer != nil { + traceBuffer.Enqueue(exchange) + } + mu.Unlock() + } + }() +}) type bodyWriter struct { http.ResponseWriter @@ -60,24 +80,8 @@ func (w *bodyWriter) Flush() { } func initializeTracing(maxItems int) { - initOnce.Do(func() { - if maxItems <= 0 { - maxItems = 100 - } - mu.Lock() - traceBuffer = circularbuffer.New[APIExchange](maxItems) - mu.Unlock() - - go func() { - for exchange := range logChan { - mu.Lock() - if traceBuffer != nil { - traceBuffer.Enqueue(exchange) - } - mu.Unlock() - } - }() - }) + tracingMaxItems = maxItems + doInitializeTracing() } // TraceMiddleware intercepts and logs JSON API requests and responses @@ -176,8 +180,8 @@ func GetTraces() []APIExchange { traces := traceBuffer.Values() mu.Unlock() - sort.Slice(traces, func(i, j int) bool { - return traces[i].Timestamp.After(traces[j].Timestamp) + slices.SortFunc(traces, func(a, b APIExchange) int { + return b.Timestamp.Compare(a.Timestamp) }) return traces diff --git a/core/http/openresponses_test.go b/core/http/openresponses_test.go index 61a448c62..fb28df380 100644 --- a/core/http/openresponses_test.go +++ b/core/http/openresponses_test.go @@ -88,7 +88,7 @@ var _ = Describe("Open Responses API", func() { Context("HTTP Protocol Compliance", func() { It("MUST accept application/json Content-Type", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", } @@ -110,7 +110,7 @@ var _ = Describe("Open Responses API", func() { }) It("MUST return application/json for non-streaming responses", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", "stream": false, @@ -135,7 +135,7 @@ var _ = Describe("Open Responses API", func() { }) It("MUST return text/event-stream for streaming responses", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", "stream": true, @@ -160,7 +160,7 @@ var _ = Describe("Open Responses API", func() { }) It("MUST end streaming with [DONE] terminal event", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", "stream": true, @@ -188,7 +188,7 @@ var _ = Describe("Open Responses API", func() { }) It("MUST have event field matching type in body", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", "stream": true, @@ -219,7 +219,7 @@ var _ = Describe("Open Responses API", func() { // Next line should be data: with matching type if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") { dataLine := strings.TrimPrefix(lines[i+1], "data: ") - var eventData map[string]interface{} + var eventData map[string]any if err := json.Unmarshal([]byte(dataLine), &eventData); err == nil { if typeVal, ok := eventData["type"].(string); ok { Expect(typeVal).To(Equal(eventType)) @@ -234,7 +234,7 @@ var _ = Describe("Open Responses API", func() { Context("Response Structure", func() { It("MUST return id field", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", } @@ -252,7 +252,7 @@ var _ = Describe("Open Responses API", func() { defer resp.Body.Close() if resp.StatusCode == 200 { - var response map[string]interface{} + var response map[string]any body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) @@ -262,7 +262,7 @@ var _ = Describe("Open Responses API", func() { }) It("MUST return object field as 'response'", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", } @@ -280,7 +280,7 @@ var _ = Describe("Open Responses API", func() { defer resp.Body.Close() if resp.StatusCode == 200 { - var response map[string]interface{} + var response map[string]any body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) @@ -290,7 +290,7 @@ var _ = Describe("Open Responses API", func() { }) It("MUST return created_at timestamp", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", } @@ -308,7 +308,7 @@ var _ = Describe("Open Responses API", func() { defer resp.Body.Close() if resp.StatusCode == 200 { - var response map[string]interface{} + var response map[string]any body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) @@ -321,7 +321,7 @@ var _ = Describe("Open Responses API", func() { }) It("MUST return status field", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", } @@ -339,7 +339,7 @@ var _ = Describe("Open Responses API", func() { defer resp.Body.Close() if resp.StatusCode == 200 { - var response map[string]interface{} + var response map[string]any body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) @@ -351,7 +351,7 @@ var _ = Describe("Open Responses API", func() { }) It("MUST return model field", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", } @@ -369,7 +369,7 @@ var _ = Describe("Open Responses API", func() { defer resp.Body.Close() if resp.StatusCode == 200 { - var response map[string]interface{} + var response map[string]any body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) @@ -379,7 +379,7 @@ var _ = Describe("Open Responses API", func() { }) It("MUST return output array of items", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", } @@ -397,12 +397,12 @@ var _ = Describe("Open Responses API", func() { defer resp.Body.Close() if resp.StatusCode == 200 { - var response map[string]interface{} + var response map[string]any body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) Expect(response).To(HaveKey("output")) - output, ok := response["output"].([]interface{}) + output, ok := response["output"].([]any) Expect(ok).To(BeTrue()) Expect(output).ToNot(BeNil()) } @@ -411,7 +411,7 @@ var _ = Describe("Open Responses API", func() { Context("Items", func() { It("MUST include id field on all items", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", } @@ -429,15 +429,15 @@ var _ = Describe("Open Responses API", func() { defer resp.Body.Close() if resp.StatusCode == 200 { - var response map[string]interface{} + var response map[string]any body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) - output, ok := response["output"].([]interface{}) + output, ok := response["output"].([]any) if ok { for _, item := range output { - itemMap, ok := item.(map[string]interface{}) + itemMap, ok := item.(map[string]any) Expect(ok).To(BeTrue()) Expect(itemMap).To(HaveKey("id")) Expect(itemMap["id"]).ToNot(BeEmpty()) @@ -447,7 +447,7 @@ var _ = Describe("Open Responses API", func() { }) It("MUST include type field on all items", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", } @@ -465,15 +465,15 @@ var _ = Describe("Open Responses API", func() { defer resp.Body.Close() if resp.StatusCode == 200 { - var response map[string]interface{} + var response map[string]any body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) - output, ok := response["output"].([]interface{}) + output, ok := response["output"].([]any) if ok { for _, item := range output { - itemMap, ok := item.(map[string]interface{}) + itemMap, ok := item.(map[string]any) Expect(ok).To(BeTrue()) Expect(itemMap).To(HaveKey("type")) Expect(itemMap["type"]).ToNot(BeEmpty()) @@ -483,7 +483,7 @@ var _ = Describe("Open Responses API", func() { }) It("MUST include status field on all items", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", } @@ -501,15 +501,15 @@ var _ = Describe("Open Responses API", func() { defer resp.Body.Close() if resp.StatusCode == 200 { - var response map[string]interface{} + var response map[string]any body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) - output, ok := response["output"].([]interface{}) + output, ok := response["output"].([]any) if ok { for _, item := range output { - itemMap, ok := item.(map[string]interface{}) + itemMap, ok := item.(map[string]any) Expect(ok).To(BeTrue()) Expect(itemMap).To(HaveKey("status")) status, ok := itemMap["status"].(string) @@ -521,13 +521,13 @@ var _ = Describe("Open Responses API", func() { }) It("MUST support message items with role field", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, - "input": []map[string]interface{}{ + "input": []map[string]any{ { "type": "message", "role": "user", - "content": []map[string]interface{}{ + "content": []map[string]any{ { "type": "input_text", "text": "Hello", @@ -550,14 +550,14 @@ var _ = Describe("Open Responses API", func() { defer resp.Body.Close() if resp.StatusCode == 200 { - var response map[string]interface{} + var response map[string]any body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) - output, ok := response["output"].([]interface{}) + output, ok := response["output"].([]any) if ok && len(output) > 0 { - itemMap, ok := output[0].(map[string]interface{}) + itemMap, ok := output[0].(map[string]any) Expect(ok).To(BeTrue()) if itemMap["type"] == "message" { Expect(itemMap).To(HaveKey("role")) @@ -572,13 +572,13 @@ var _ = Describe("Open Responses API", func() { Context("Content Types", func() { It("MUST support input_text content", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, - "input": []map[string]interface{}{ + "input": []map[string]any{ { "type": "message", "role": "user", - "content": []map[string]interface{}{ + "content": []map[string]any{ { "type": "input_text", "text": "Hello world", @@ -605,13 +605,13 @@ var _ = Describe("Open Responses API", func() { }) It("MUST support input_image content with URL", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, - "input": []map[string]interface{}{ + "input": []map[string]any{ { "type": "message", "role": "user", - "content": []map[string]interface{}{ + "content": []map[string]any{ { "type": "input_image", "image_url": "https://example.com/image.png", @@ -639,13 +639,13 @@ var _ = Describe("Open Responses API", func() { }) It("MUST support input_image content with base64", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, - "input": []map[string]interface{}{ + "input": []map[string]any{ { "type": "message", "role": "user", - "content": []map[string]interface{}{ + "content": []map[string]any{ { "type": "input_image", "image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==", @@ -673,7 +673,7 @@ var _ = Describe("Open Responses API", func() { }) It("MUST support output_text content", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", } @@ -691,19 +691,19 @@ var _ = Describe("Open Responses API", func() { defer resp.Body.Close() if resp.StatusCode == 200 { - var response map[string]interface{} + var response map[string]any body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) - output, ok := response["output"].([]interface{}) + output, ok := response["output"].([]any) if ok && len(output) > 0 { - itemMap, ok := output[0].(map[string]interface{}) + itemMap, ok := output[0].(map[string]any) Expect(ok).To(BeTrue()) if itemMap["type"] == "message" { - content, ok := itemMap["content"].([]interface{}) + content, ok := itemMap["content"].([]any) if ok && len(content) > 0 { - contentMap, ok := content[0].(map[string]interface{}) + contentMap, ok := content[0].(map[string]any) if ok { contentType, _ := contentMap["type"].(string) if contentType == "output_text" { @@ -719,7 +719,7 @@ var _ = Describe("Open Responses API", func() { Context("Streaming Events", func() { It("MUST emit response.created as first event", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", "stream": true, @@ -748,7 +748,7 @@ var _ = Describe("Open Responses API", func() { }) It("MUST include sequence_number in all events", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Hello", "stream": true, @@ -777,7 +777,7 @@ var _ = Describe("Open Responses API", func() { if strings.HasPrefix(line, "data: ") { dataLine := strings.TrimPrefix(line, "data: ") if dataLine != "[DONE]" { - var eventData map[string]interface{} + var eventData map[string]any if err := json.Unmarshal([]byte(dataLine), &eventData); err == nil { if _, hasType := eventData["type"]; hasType { Expect(eventData).To(HaveKey("sequence_number")) @@ -792,7 +792,7 @@ var _ = Describe("Open Responses API", func() { Context("Error Handling", func() { It("MUST return structured error with type and message fields", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": "nonexistent-model", "input": "Hello", } @@ -810,12 +810,12 @@ var _ = Describe("Open Responses API", func() { defer resp.Body.Close() if resp.StatusCode >= 400 { - var errorResp map[string]interface{} + var errorResp map[string]any body, _ := io.ReadAll(resp.Body) json.Unmarshal(body, &errorResp) if errorResp["error"] != nil { - errorObj, ok := errorResp["error"].(map[string]interface{}) + errorObj, ok := errorResp["error"].(map[string]any) if ok { Expect(errorObj).To(HaveKey("type")) Expect(errorObj).To(HaveKey("message")) @@ -828,7 +828,7 @@ var _ = Describe("Open Responses API", func() { Context("Previous Response ID", func() { It("should load previous response and concatenate context", func() { // First, create a response - reqBody1 := map[string]interface{}{ + reqBody1 := map[string]any{ "model": testModel, "input": "What is 2+2?", } @@ -850,7 +850,7 @@ var _ = Describe("Open Responses API", func() { Skip("First response failed, skipping previous_response_id test (backend may not be available)") } - var response1 map[string]interface{} + var response1 map[string]any body1, err := io.ReadAll(resp1.Body) Expect(err).ToNot(HaveOccurred()) err = json.Unmarshal(body1, &response1) @@ -861,7 +861,7 @@ var _ = Describe("Open Responses API", func() { Expect(responseID).ToNot(BeEmpty()) // Now create a new response with previous_response_id - reqBody2 := map[string]interface{}{ + reqBody2 := map[string]any{ "model": testModel, "input": "What about 3+3?", "previous_response_id": responseID, @@ -878,7 +878,7 @@ var _ = Describe("Open Responses API", func() { Expect(err).ToNot(HaveOccurred()) defer resp2.Body.Close() - var response2 map[string]interface{} + var response2 map[string]any body2, err := io.ReadAll(resp2.Body) Expect(err).ToNot(HaveOccurred()) err = json.Unmarshal(body2, &response2) @@ -889,7 +889,7 @@ var _ = Describe("Open Responses API", func() { }) It("should return error for invalid previous_response_id", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, "input": "Test", "previous_response_id": "nonexistent_response_id", @@ -909,12 +909,12 @@ var _ = Describe("Open Responses API", func() { Expect(resp.StatusCode).To(Equal(404)) - var errorResp map[string]interface{} + var errorResp map[string]any body, _ := io.ReadAll(resp.Body) json.Unmarshal(body, &errorResp) if errorResp["error"] != nil { - errorObj, ok := errorResp["error"].(map[string]interface{}) + errorObj, ok := errorResp["error"].(map[string]any) if ok { Expect(errorObj["type"]).To(Equal("not_found")) Expect(errorObj["param"]).To(Equal("previous_response_id")) @@ -926,7 +926,7 @@ var _ = Describe("Open Responses API", func() { Context("Item Reference", func() { It("should resolve item_reference in input", func() { // First, create a response with items - reqBody1 := map[string]interface{}{ + reqBody1 := map[string]any{ "model": testModel, "input": "Hello", } @@ -948,32 +948,32 @@ var _ = Describe("Open Responses API", func() { Skip("First response failed, skipping item_reference test (backend may not be available)") } - var response1 map[string]interface{} + var response1 map[string]any body1, err := io.ReadAll(resp1.Body) Expect(err).ToNot(HaveOccurred()) err = json.Unmarshal(body1, &response1) Expect(err).ToNot(HaveOccurred()) // Get the first output item ID - output, ok := response1["output"].([]interface{}) + output, ok := response1["output"].([]any) Expect(ok).To(BeTrue()) Expect(len(output)).To(BeNumerically(">", 0)) - firstItem, ok := output[0].(map[string]interface{}) + firstItem, ok := output[0].(map[string]any) Expect(ok).To(BeTrue()) itemID, ok := firstItem["id"].(string) Expect(ok).To(BeTrue()) Expect(itemID).ToNot(BeEmpty()) // Now create a new response with item_reference - reqBody2 := map[string]interface{}{ + reqBody2 := map[string]any{ "model": testModel, - "input": []interface{}{ - map[string]interface{}{ + "input": []any{ + map[string]any{ "type": "item_reference", "item_id": itemID, }, - map[string]interface{}{ + map[string]any{ "type": "message", "role": "user", "content": "Continue from the previous message", @@ -997,10 +997,10 @@ var _ = Describe("Open Responses API", func() { }) It("should return error for invalid item_reference", func() { - reqBody := map[string]interface{}{ + reqBody := map[string]any{ "model": testModel, - "input": []interface{}{ - map[string]interface{}{ + "input": []any{ + map[string]any{ "type": "item_reference", "item_id": "nonexistent_item_id", }, diff --git a/core/http/react-ui/src/components/ImageSelector.jsx b/core/http/react-ui/src/components/ImageSelector.jsx new file mode 100644 index 000000000..323301b61 --- /dev/null +++ b/core/http/react-ui/src/components/ImageSelector.jsx @@ -0,0 +1,81 @@ +import { useState } from 'react' + +const GPU_OPTIONS = [ + { key: 'cpu', label: 'CPU', icon: 'fa-microchip', tag: 'latest-cpu', devTag: 'master-cpu', dockerFlags: '' }, + { key: 'cuda12', label: 'CUDA 12', icon: 'fa-bolt', tag: 'latest-gpu-nvidia-cuda-12', devTag: 'master-gpu-nvidia-cuda-12', dockerFlags: '--gpus all' }, + { key: 'cuda13', label: 'CUDA 13', icon: 'fa-bolt', tag: 'latest-gpu-nvidia-cuda-13', devTag: 'master-gpu-nvidia-cuda-13', dockerFlags: '--gpus all' }, + { key: 'l4t12', label: 'L4T CUDA 12', icon: 'fa-bolt', tag: 'latest-gpu-nvidia-l4t-cuda12',devTag: 'master-gpu-nvidia-l4t-cuda12',dockerFlags: '--runtime nvidia' }, + { key: 'l4t13', label: 'L4T CUDA 13', icon: 'fa-bolt', tag: 'latest-gpu-nvidia-l4t-cuda13',devTag: 'master-gpu-nvidia-l4t-cuda13',dockerFlags: '--runtime nvidia' }, + { key: 'amd', label: 'AMD', icon: 'fa-fire', tag: 'latest-gpu-hipblas', devTag: 'master-gpu-hipblas', dockerFlags: '--device /dev/kfd --device /dev/dri' }, + { key: 'intel', label: 'Intel', icon: 'fa-atom', tag: 'latest-gpu-intel', devTag: 'master-gpu-intel', dockerFlags: '--device /dev/dri' }, + { key: 'vulkan', label: 'Vulkan', icon: 'fa-globe', tag: 'latest-gpu-vulkan', devTag: 'master-gpu-vulkan', dockerFlags: '--device /dev/dri' }, +] + +export function useImageSelector(defaultKey = 'cpu') { + const [selected, setSelected] = useState(defaultKey) + const [dev, setDev] = useState(false) + const option = GPU_OPTIONS.find(o => o.key === selected) || GPU_OPTIONS[0] + return { selected, setSelected, option, options: GPU_OPTIONS, dev, setDev } +} + +export default function ImageSelector({ selected, onSelect, dev, onDevChange }) { + return ( +
+ {GPU_OPTIONS.map(opt => { + const active = selected === opt.key + return ( + + ) + })} + {onDevChange && ( + + )} +
+ ) +} + +// Helper to build a docker image string +export function dockerImage(option, dev = false) { + return `localai/localai:${dev ? option.devTag : option.tag}` +} + +// Helper to build docker run flags (--gpus all, --device, etc.) +export function dockerFlags(option) { + return option.dockerFlags +} diff --git a/core/http/react-ui/src/components/Modal.jsx b/core/http/react-ui/src/components/Modal.jsx index 6f1e7ee09..e13824ef6 100644 --- a/core/http/react-ui/src/components/Modal.jsx +++ b/core/http/react-ui/src/components/Modal.jsx @@ -4,6 +4,8 @@ import '../pages/auth.css' export default function Modal({ onClose, children, maxWidth = '600px' }) { const dialogRef = useRef(null) const lastFocusRef = useRef(null) + const onCloseRef = useRef(onClose) + onCloseRef.current = onClose useEffect(() => { lastFocusRef.current = document.activeElement @@ -20,7 +22,7 @@ export default function Modal({ onClose, children, maxWidth = '600px' }) { const handleKeyDown = (e) => { if (e.key === 'Escape') { - onClose?.() + onCloseRef.current?.() return } if (e.key !== 'Tab') return @@ -46,7 +48,7 @@ export default function Modal({ onClose, children, maxWidth = '600px' }) { document.removeEventListener('keydown', handleKeyDown) lastFocusRef.current?.focus() } - }, [onClose]) + }, []) // Run once on mount — onClose accessed via stable ref return (
{ if (prev.length <= 1) return prev const filtered = prev.filter(c => c.id !== id) - if (id === activeId && filtered.length > 0) { - setActiveId(filtered[0].id) + const newActiveId = id === activeId && filtered.length > 0 ? filtered[0].id : activeId + if (id === activeId) { + setActiveId(newActiveId) } + saveConversations(agentName, filtered, newActiveId) return filtered }) - }, [activeId]) + }, [activeId, agentName]) const deleteAllConversations = useCallback(() => { const conv = createConversation() setConversations([conv]) setActiveId(conv.id) - }, []) + saveConversations(agentName, [conv], conv.id) + }, [agentName]) const renameConversation = useCallback((id, name) => { setConversations(prev => prev.map(c => @@ -147,11 +150,34 @@ export function useAgentChat(agentName) { })) }, [activeId]) + // Add a message to a specific conversation by ID, regardless of which is active. + // Used by SSE handlers to pin responses to the conversation that initiated the request. + const addMessageToConversation = useCallback((conversationId, msg) => { + setConversations(prev => prev.map(c => { + if (c.id !== conversationId) return c + const updated = { + ...c, + messages: [...c.messages, msg], + updatedAt: Date.now(), + } + if (c.messages.length === 0 && msg.sender === 'user') { + const text = msg.content || '' + updated.name = text.slice(0, 40) + (text.length > 40 ? '...' : '') + } + return updated + })) + }, []) + const clearMessages = useCallback(() => { - setConversations(prev => prev.map(c => - c.id === activeId ? { ...c, messages: [], updatedAt: Date.now() } : c - )) - }, [activeId]) + setConversations(prev => { + const updated = prev.map(c => + c.id === activeId ? { ...c, messages: [], updatedAt: Date.now() } : c + ) + // Save immediately so a page refresh doesn't restore the old messages + saveConversations(agentName, updated, activeId) + return updated + }) + }, [activeId, agentName]) const getMessages = useCallback(() => { return activeConversation?.messages || [] @@ -167,6 +193,7 @@ export function useAgentChat(agentName) { deleteAllConversations, renameConversation, addMessage, + addMessageToConversation, clearMessages, getMessages, } diff --git a/core/http/react-ui/src/pages/AgentChat.jsx b/core/http/react-ui/src/pages/AgentChat.jsx index e796c2422..1a79e3d81 100644 --- a/core/http/react-ui/src/pages/AgentChat.jsx +++ b/core/http/react-ui/src/pages/AgentChat.jsx @@ -93,7 +93,7 @@ export default function AgentChat() { const { conversations, activeConversation, activeId, addConversation, switchConversation, deleteConversation, - deleteAllConversations, renameConversation, addMessage, clearMessages, + deleteAllConversations, renameConversation, addMessage, addMessageToConversation, clearMessages, } = useAgentChat(name) const messages = activeConversation?.messages || [] @@ -118,8 +118,15 @@ export default function AgentChat() { const messageIdCounter = useRef(0) const addMessageRef = useRef(addMessage) addMessageRef.current = addMessage + const addMessageToConvRef = useRef(addMessageToConversation) + addMessageToConvRef.current = addMessageToConversation const activeIdRef = useRef(activeId) activeIdRef.current = activeId + // Tracks which conversation initiated the current request — SSE responses + // are pinned to this ID so switching tabs doesn't misdirect them. + const processingChatIdRef = useRef(null) + // Maps backend messageID → conversationId for robust SSE routing across navigations. + const pendingRequestsRef = useRef(new Map()) const processing = processingChatId === activeId @@ -137,16 +144,34 @@ export default function AgentChat() { es.addEventListener('json_message', (e) => { try { const data = JSON.parse(e.data) + const sender = data.sender || (data.role === 'user' ? 'user' : 'agent') + // Skip user message echoes — already added locally in handleSend + if (sender === 'user') return const msg = { id: nextId(), - sender: data.sender || (data.role === 'user' ? 'user' : 'agent'), + sender, content: data.content || data.message || '', - timestamp: data.timestamp || Date.now(), + timestamp: data.timestamp ? Math.floor(data.timestamp / 1e6) : Date.now(), } if (data.metadata && Object.keys(data.metadata).length > 0) { msg.metadata = data.metadata } - addMessageRef.current(msg) + // Route to conversation: try messageID mapping first, then processingChatIdRef, then active + const msgId = data.message_id || '' + const baseId = msgId.replace(/-agent$/, '') + const targetId = pendingRequestsRef.current.get(baseId) + || processingChatIdRef.current + || activeIdRef.current + addMessageToConvRef.current(targetId, msg) + // Clear streaming + processing state when the final agent message arrives + if (sender === 'agent') { + pendingRequestsRef.current.delete(baseId) + processingChatIdRef.current = null + setProcessingChatId(null) + setStreamContent('') + setStreamReasoning('') + setStreamToolCalls([]) + } } catch (_err) { // ignore malformed messages } @@ -156,15 +181,20 @@ export default function AgentChat() { try { const data = JSON.parse(e.data) if (data.status === 'processing') { - setProcessingChatId(activeIdRef.current) + // Track which conversation is processing so responses go to the right place. + // Only set if not already pinned by handleSend (avoids race when user switches conversations). + if (!processingChatIdRef.current) { + processingChatIdRef.current = activeIdRef.current + setProcessingChatId(activeIdRef.current) + } setStreamContent('') setStreamReasoning('') setStreamToolCalls([]) } else if (data.status === 'completed') { - setProcessingChatId(null) - setStreamContent('') - setStreamReasoning('') - setStreamToolCalls([]) + // Don't clear processingChatIdRef, processingChatId, or streaming state here — + // they'll be cleared when the agent's json_message arrives, + // so reasoning and tool calls remain visible until the response replaces them + // and late-arriving messages still route to the correct conversation. } } catch (_err) { // ignore @@ -190,6 +220,16 @@ export default function AgentChat() { updated[updated.length - 1] = { ...updated[updated.length - 1], args: updated[updated.length - 1].args + args } return updated }) + } else if (data.type === 'tool_result') { + const tname = data.tool_name || '' + setStreamToolCalls(prev => { + const updated = [...prev] + const idx = updated.findLastIndex(tc => tc.name === tname && !tc.result) + if (idx >= 0) { + updated[idx] = { ...updated[idx], result: data.tool_result || 'done' } + } + return updated + }) } else if (data.type === 'done') { // Content will be finalized by json_message event } @@ -201,7 +241,8 @@ export default function AgentChat() { es.addEventListener('status', (e) => { const text = e.data if (!text) return - addMessageRef.current({ + const targetId = processingChatIdRef.current || activeIdRef.current + addMessageToConvRef.current(targetId, { id: nextId(), sender: 'system', content: text, @@ -216,6 +257,7 @@ export default function AgentChat() { } catch (_err) { addToast('Agent error', 'error') } + processingChatIdRef.current = null setProcessingChatId(null) }) @@ -226,6 +268,8 @@ export default function AgentChat() { return () => { es.close() eventSourceRef.current = null + processingChatIdRef.current = null + pendingRequestsRef.current.clear() } }, [name, userId, addToast, nextId]) @@ -307,14 +351,22 @@ export default function AgentChat() { if (!msg || processing) return setInput('') if (textareaRef.current) textareaRef.current.style.height = 'auto' + // Add user message locally immediately (like standard chat) + addMessage({ id: nextId(), sender: 'user', content: msg, timestamp: Date.now() }) setProcessingChatId(activeId) + processingChatIdRef.current = activeId try { - await agentsApi.chat(name, msg, userId) + const resp = await agentsApi.chat(name, msg, userId) + // Map backend messageID → conversation so SSE events route correctly + if (resp && resp.message_id) { + pendingRequestsRef.current.set(resp.message_id, activeId) + } } catch (err) { addToast(`Failed to send message: ${err.message}`, 'error') + processingChatIdRef.current = null setProcessingChatId(null) } - }, [input, processing, name, activeId, addToast, userId]) + }, [input, processing, name, activeId, addToast, userId, addMessage, nextId]) const handleKeyDown = (e) => { if (e.key === 'Enter' && !e.shiftKey) { @@ -604,16 +656,28 @@ export default function AgentChat() {
)} - {streamToolCalls.length > 0 && !streamContent && ( + {streamToolCalls.length > 0 && (
{streamToolCalls.map((tc, idx) => ( -
- - - {tc.name} - - calling... -
+
+ + + {tc.name} + + {tc.result ? 'done' : 'calling...'} + + + {tc.args && ( +
+                          {(() => { try { return JSON.stringify(JSON.parse(tc.args), null, 2) } catch { return tc.args } })()}
+                        
+ )} + {tc.result && ( +
+                          {tc.result}
+                        
+ )} +
))}
)} diff --git a/core/http/react-ui/src/pages/AgentCreate.jsx b/core/http/react-ui/src/pages/AgentCreate.jsx index a0d4fccf7..1ab5f5a76 100644 --- a/core/http/react-ui/src/pages/AgentCreate.jsx +++ b/core/http/react-ui/src/pages/AgentCreate.jsx @@ -1,6 +1,6 @@ import { useState, useEffect, useMemo } from 'react' import { useParams, useNavigate, useLocation, useOutletContext, useSearchParams } from 'react-router-dom' -import { agentsApi } from '../utils/api' +import { agentsApi, skillsApi } from '../utils/api' import SearchableModelSelect from '../components/SearchableModelSelect' import { CAP_CHAT, CAP_TRANSCRIPT, CAP_TTS } from '../utils/capabilities' import Toggle from '../components/Toggle' @@ -263,6 +263,21 @@ const SECTIONS = [ // Fields handled by custom editors in the MCP section const CUSTOM_FIELDS = new Set(['mcp_stdio_servers']) +// Fields not implemented in the native executor (distributed mode). +// These are hidden from the form when meta.distributed is true. +const HIDDEN_IN_DISTRIBUTED = new Set([ + 'mcp_prepare_script', + 'multimodal_model', 'transcription_model', 'transcription_language', 'tts_model', + 'plan_reviewer_model', + 'enable_planning', 'initiate_conversations', 'can_stop_itself', + 'scheduler_poll_interval', 'scheduler_task_template', + 'enable_reasoning', 'enable_reasoning_tool', // replaced by enable_reasoning_for_instruct + 'kb_auto_search', 'kb_as_tools', // replaced by kb_mode select + 'disable_sink_state', // always disabled in native executor + 'enable_kb_compaction', 'kb_compaction_interval', 'kb_compaction_summarize', + 'parallel_jobs', 'cancel_previous_on_new_message', +]) + // --- Main component --- export default function AgentCreate() { @@ -285,7 +300,11 @@ export default function AgentCreate() { const [filters, setFilters] = useState([]) const [dynamicPrompts, setDynamicPrompts] = useState([]) const [mcpHttpServers, setMcpHttpServers] = useState([]) + const [mcpJsonMode, setMcpJsonMode] = useState(false) + const [mcpRawJson, setMcpRawJson] = useState('') const [stdioServers, setStdioServers] = useState([]) + const [availableSkills, setAvailableSkills] = useState([]) + const [selectedSkills, setSelectedSkills] = useState([]) // Group metadata Fields by tags.section const fieldsBySection = useMemo(() => { @@ -293,6 +312,7 @@ export default function AgentCreate() { const groups = {} for (const field of meta.Fields) { if (CUSTOM_FIELDS.has(field.name)) continue + if (meta?.distributed && HIDDEN_IN_DISTRIBUTED.has(field.name)) continue const section = field.tags?.section || 'BasicInfo' if (!groups[section]) groups[section] = [] groups[section].push(field) @@ -301,19 +321,26 @@ export default function AgentCreate() { }, [meta]) const visibleSections = useMemo(() => { - const items = [...SECTIONS] + let items = [...SECTIONS] + // In distributed mode, hide LocalAGI-specific sections — use MCP Servers instead + if (meta?.distributed) { + const hiddenInDistributed = new Set(['actions', 'connectors', 'filters', 'dynamic_prompts']) + items = items.filter(s => !hiddenInDistributed.has(s.id)) + } if (isEdit) items.push({ id: 'export', icon: 'fa-download', label: 'Export' }) return items - }, [isEdit]) + }, [isEdit, meta]) useEffect(() => { const init = async () => { try { - const [metaData, config] = await Promise.all([ + const [metaData, config, skillsList] = await Promise.all([ agentsApi.configMeta().catch(() => null), isEdit ? agentsApi.getConfig(name, userId).catch(() => null) : Promise.resolve(null), + skillsApi.list().catch(() => null), ]) if (metaData) setMeta(metaData) + if (skillsList?.skills) setAvailableSkills(skillsList.skills) // Build defaults from metadata const initialForm = {} @@ -343,6 +370,7 @@ export default function AgentCreate() { setDynamicPrompts(Array.isArray(sourceConfig.dynamic_prompts) ? sourceConfig.dynamic_prompts : []) setMcpHttpServers(Array.isArray(sourceConfig.mcp_servers) ? sourceConfig.mcp_servers : []) setStdioServers(parseStdioServers(sourceConfig.mcp_stdio_servers)) + if (Array.isArray(sourceConfig.selected_skills)) setSelectedSkills(sourceConfig.selected_skills) } setForm(initialForm) @@ -365,6 +393,10 @@ export default function AgentCreate() { addToast('Agent name is required', 'warning') return } + if (!form.model?.toString().trim()) { + addToast('Model is required', 'warning') + return + } setSaving(true) try { const payload = { ...form } @@ -382,9 +414,16 @@ export default function AgentCreate() { payload.dynamic_prompts = dynamicPrompts payload.mcp_servers = mcpHttpServers.filter(s => s.url) // Send STDIO servers as JSON string in expected format - if (stdioServers.length > 0) { + if (mcpJsonMode && mcpRawJson.trim()) { + // In JSON editor mode, use the raw JSON directly + payload.mcp_stdio_servers = mcpRawJson + } else if (stdioServers.length > 0) { payload.mcp_stdio_servers = buildStdioJson(stdioServers) } + // Send selected skills + if (selectedSkills.length > 0) { + payload.selected_skills = selectedSkills + } if (isEdit) { await agentsApi.update(name, payload, userId) @@ -424,15 +463,21 @@ export default function AgentCreate() { if (!fields.length) { return

No fields available for this section.

} - return fields.map(field => ( - - )) + return fields + .filter(field => { + // Hide fields whose depends_on parent is falsy + if (field.tags?.depends_on && !form[field.tags.depends_on]) return false + return true + }) + .map(field => ( + + )) } const renderSection = () => { @@ -442,13 +487,151 @@ export default function AgentCreate() { case 'MemorySettings': case 'PromptsGoals': case 'AdvancedSettings': - return renderFieldSection(activeSection) + return ( + <> + {renderFieldSection(activeSection)} + {/* Skills picker — shown only in AdvancedSettings when enable_skills is checked */} + {activeSection === 'AdvancedSettings' && form.enable_skills && availableSkills.length > 0 && ( +
+

+ + Select Skills +

+

Choose which skills this agent can use. If none selected, all available skills are included.

+
+ +
+
+ {availableSkills.map(skill => ( + + ))} +
+
+ )} + + ) case 'MCP': return ( <> - {/* Other MCP metadata fields (mcp_prepare_script, etc.) */} - {renderFieldSection('MCP')} + {/* Mode toggle + import buttons */} +
+ + + {mcpJsonMode ? 'Edit as Claude Desktop JSON format' : 'Configure MCP servers visually'} + +
+ + {mcpJsonMode ? ( +
+ +